diff --git a/backend/internal/handler/admin/affiliate_handler.go b/backend/internal/handler/admin/affiliate_handler.go index 97e649ec..d443d344 100644 --- a/backend/internal/handler/admin/affiliate_handler.go +++ b/backend/internal/handler/admin/affiliate_handler.go @@ -2,8 +2,11 @@ package admin import ( "strconv" + "strings" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -181,3 +184,108 @@ func (h *AffiliateHandler) LookupUsers(c *gin.Context) { } response.Success(c, result) } + +// GetUserOverview returns one user's affiliate overview. +// GET /api/v1/admin/affiliates/users/:user_id/overview +func (h *AffiliateHandler) GetUserOverview(c *gin.Context) { + userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64) + if err != nil || userID <= 0 { + response.BadRequest(c, "Invalid user_id") + return + } + overview, err := h.affiliateService.AdminGetUserOverview(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, overview) +} + +// ListInviteRecords returns all inviter-invitee relationships. +// GET /api/v1/admin/affiliates/invites +func (h *AffiliateHandler) ListInviteRecords(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + filter := parseAffiliateRecordFilter(c, page, pageSize) + items, total, err := h.affiliateService.AdminListInviteRecords(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Paginated(c, items, total, filter.Page, filter.PageSize) +} + +// ListRebateRecords returns all order-level affiliate rebate records. +// GET /api/v1/admin/affiliates/rebates +func (h *AffiliateHandler) ListRebateRecords(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + filter := parseAffiliateRecordFilter(c, page, pageSize) + items, total, err := h.affiliateService.AdminListRebateRecords(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Paginated(c, items, total, filter.Page, filter.PageSize) +} + +// ListTransferRecords returns all affiliate quota-to-balance transfer records. +// GET /api/v1/admin/affiliates/transfers +func (h *AffiliateHandler) ListTransferRecords(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + filter := parseAffiliateRecordFilter(c, page, pageSize) + items, total, err := h.affiliateService.AdminListTransferRecords(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Paginated(c, items, total, filter.Page, filter.PageSize) +} + +func parseAffiliateRecordFilter(c *gin.Context, page, pageSize int) service.AffiliateRecordFilter { + filter := service.AffiliateRecordFilter{ + Search: c.Query("search"), + Page: page, + PageSize: pageSize, + SortBy: c.Query("sort_by"), + SortDesc: c.Query("sort_order") != "asc", + } + if filter.PageSize > 100 { + filter.PageSize = 100 + } + userTZ := c.Query("timezone") + if t := parseAffiliateRecordStartTime(c.Query("start_at"), userTZ); t != nil { + filter.StartAt = t + } + if t := parseAffiliateRecordEndTime(c.Query("end_at"), userTZ); t != nil { + filter.EndAt = t + } + return filter +} + +func parseAffiliateRecordStartTime(raw string, userTZ string) *time.Time { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + if parsed, err := time.Parse(time.RFC3339, raw); err == nil { + return &parsed + } + if parsed, err := timezone.ParseInUserLocation("2006-01-02", raw, userTZ); err == nil { + return &parsed + } + return nil +} + +func parseAffiliateRecordEndTime(raw string, userTZ string) *time.Time { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + if parsed, err := time.Parse(time.RFC3339, raw); err == nil { + return &parsed + } + if parsed, err := timezone.ParseInUserLocation("2006-01-02", raw, userTZ); err == nil { + end := parsed.AddDate(0, 0, 1).Add(-time.Nanosecond) + return &end + } + return nil +} diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 3d80107f..a297c56c 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -390,7 +390,7 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) { // GetBalanceHistory handles getting user's balance/concurrency change history // GET /api/v1/admin/users/:id/balance-history // Query params: -// - type: filter by record type (balance, admin_balance, concurrency, admin_concurrency, subscription) +// - type: filter by record type (balance, affiliate_balance, admin_balance, concurrency, admin_concurrency, subscription) func (h *UserHandler) GetBalanceHistory(c *gin.Context) { userID, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { diff --git a/backend/internal/pkg/apicompat/anthropic_responses_test.go b/backend/internal/pkg/apicompat/anthropic_responses_test.go index e8b25c2b..edde85d3 100644 --- a/backend/internal/pkg/apicompat/anthropic_responses_test.go +++ b/backend/internal/pkg/apicompat/anthropic_responses_test.go @@ -434,6 +434,45 @@ func TestStreamingTextOnly(t *testing.T) { assert.Equal(t, "message_stop", events[1].Type) } +func TestResponsesEventToAnthropicEvents_ResponseDone(t *testing.T) { + state := NewResponsesEventToAnthropicState() + state.Model = "gpt-4o" + + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.done", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4}, + }, + }, state) + require.Len(t, events, 2) + assert.Equal(t, "message_delta", events[0].Type) + assert.Equal(t, "end_turn", events[0].Delta.StopReason) + assert.Equal(t, 12, events[0].Usage.InputTokens) + assert.Equal(t, 4, events[0].Usage.OutputTokens) + assert.Equal(t, "message_stop", events[1].Type) + assert.Nil(t, FinalizeResponsesAnthropicStream(state)) +} + +func TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete(t *testing.T) { + state := NewResponsesEventToAnthropicState() + state.Model = "gpt-4o" + + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.done", + Response: &ResponsesResponse{ + Status: "incomplete", + IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"}, + Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4}, + }, + }, state) + require.Len(t, events, 2) + assert.Equal(t, "message_delta", events[0].Type) + assert.Equal(t, "max_tokens", events[0].Delta.StopReason) + assert.Equal(t, "message_stop", events[1].Type) + assert.Nil(t, FinalizeResponsesAnthropicStream(state)) +} + func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) { state := NewResponsesEventToAnthropicState() ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go index 35d42999..bf5c23d5 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go @@ -720,6 +720,49 @@ func TestResponsesEventToChatChunks_Completed(t *testing.T) { assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens) } +func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.done", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7}, + }, + }, state) + require.Len(t, chunks, 2) + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason) + require.NotNil(t, chunks[1].Usage) + assert.Equal(t, 13, chunks[1].Usage.PromptTokens) + assert.Equal(t, 7, chunks[1].Usage.CompletionTokens) + assert.Nil(t, FinalizeResponsesChatStream(state)) +} + +func TestResponsesEventToChatChunks_ResponseDoneIncomplete(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.done", + Response: &ResponsesResponse{ + Status: "incomplete", + IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"}, + Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7}, + }, + }, state) + require.Len(t, chunks, 2) + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "length", *chunks[0].Choices[0].FinishReason) + require.NotNil(t, chunks[1].Usage) + assert.Equal(t, 13, chunks[1].Usage.PromptTokens) + assert.Equal(t, 7, chunks[1].Usage.CompletionTokens) + assert.Nil(t, FinalizeResponsesChatStream(state)) +} + func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) { state := NewResponsesEventToChatState() state.Model = "gpt-4o" diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic.go b/backend/internal/pkg/apicompat/responses_to_anthropic.go index 489ed238..b76f384d 100644 --- a/backend/internal/pkg/apicompat/responses_to_anthropic.go +++ b/backend/internal/pkg/apicompat/responses_to_anthropic.go @@ -212,7 +212,9 @@ func ResponsesEventToAnthropicEvents( return resToAnthHandleReasoningDelta(evt, state) case "response.reasoning_summary_text.done": return resToAnthHandleBlockDone(state) - case "response.completed", "response.incomplete", "response.failed": + // response.done 是 Realtime/WS 与项目透传路径使用的终止别名; + // 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。 + case "response.completed", "response.done", "response.incomplete", "response.failed": return resToAnthHandleCompleted(evt, state) default: return nil diff --git a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go index 61b3bf9c..2386771d 100644 --- a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go +++ b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go @@ -160,7 +160,9 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent return resToChatHandleReasoningDelta(evt, state) case "response.reasoning_summary_text.done": return nil - case "response.completed", "response.incomplete", "response.failed": + // response.done 是 Realtime/WS 与项目透传路径使用的终止别名; + // 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。 + case "response.completed", "response.done", "response.incomplete", "response.failed": return resToChatHandleCompleted(evt, state) default: return nil diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go index f8c6b75f..0ff2cf49 100644 --- a/backend/internal/pkg/apicompat/types.go +++ b/backend/internal/pkg/apicompat/types.go @@ -314,7 +314,7 @@ type ResponsesOutputTokensDetails struct { type ResponsesStreamEvent struct { Type string `json:"type"` - // response.created / response.completed / response.failed / response.incomplete + // response.created / response.completed / response.done / response.failed / response.incomplete Response *ResponsesResponse `json:"response,omitempty"` // response.output_item.added / response.output_item.done diff --git a/backend/internal/repository/affiliate_repo.go b/backend/internal/repository/affiliate_repo.go index ef89e5b6..61da539e 100644 --- a/backend/internal/repository/affiliate_repo.go +++ b/backend/internal/repository/affiliate_repo.go @@ -22,6 +22,34 @@ const ( var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789") +const affiliateUserOverviewSQL = ` +SELECT ua.user_id, + COALESCE(u.email, ''), + COALESCE(u.username, ''), + ua.aff_code, + COALESCE(ua.aff_rebate_rate_percent, 0)::double precision, + (ua.aff_rebate_rate_percent IS NOT NULL) AS has_custom_rate, + ua.aff_count, + COALESCE(rebated.rebated_invitee_count, 0), + (ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0))::double precision, + ua.aff_history_quota::double precision +FROM user_affiliates ua +JOIN users u ON u.id = ua.user_id +LEFT JOIN ( + SELECT user_id, COUNT(DISTINCT source_user_id)::integer AS rebated_invitee_count + FROM user_affiliate_ledger + WHERE action = 'accrue' AND source_user_id IS NOT NULL + GROUP BY user_id +) rebated ON rebated.user_id = ua.user_id +LEFT JOIN ( + SELECT user_id, COALESCE(SUM(amount), 0)::double precision AS matured_frozen_quota + FROM user_affiliate_ledger + WHERE action = 'accrue' AND frozen_until IS NOT NULL AND frozen_until <= NOW() + GROUP BY user_id +) matured ON matured.user_id = ua.user_id +WHERE ua.user_id = $1 +LIMIT 1` + type affiliateQueryExecer interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) @@ -86,7 +114,7 @@ func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID return bound, nil } -func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) { +func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error) { if amount <= 0 { return false, nil } @@ -112,15 +140,15 @@ func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, invite if freezeHours > 0 { if _, err = txClient.ExecContext(txCtx, ` -INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, frozen_until, created_at, updated_at) -VALUES ($1, 'accrue', $2, $3, NOW() + make_interval(hours => $4), NOW(), NOW())`, - inviterID, amount, inviteeUserID, freezeHours); err != nil { +INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, source_order_id, frozen_until, created_at, updated_at) +VALUES ($1, 'accrue', $2, $3, $4, NOW() + make_interval(hours => $5), NOW(), NOW())`, + inviterID, amount, inviteeUserID, nullableInt64Arg(sourceOrderID), freezeHours); err != nil { return fmt.Errorf("insert affiliate accrue ledger: %w", err) } } else { if _, err = txClient.ExecContext(txCtx, ` -INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at) -VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil { +INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, source_order_id, created_at, updated_at) +VALUES ($1, 'accrue', $2, $3, $4, NOW(), NOW())`, inviterID, amount, inviteeUserID, nullableInt64Arg(sourceOrderID)); err != nil { return fmt.Errorf("insert affiliate accrue ledger: %w", err) } } @@ -275,9 +303,32 @@ FROM cleared`, userID) return err } + snapshot, err := queryAffiliateTransferSnapshot(txCtx, txClient, userID) + if err != nil { + return err + } + if _, err = txClient.ExecContext(txCtx, ` -INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at) -VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred); err != nil { +INSERT INTO user_affiliate_ledger ( + user_id, + action, + amount, + source_user_id, + balance_after, + aff_quota_after, + aff_frozen_quota_after, + aff_history_quota_after, + created_at, + updated_at +) +VALUES ($1, 'transfer', $2, NULL, $3, $4, $5, $6, NOW(), NOW())`, + userID, + transferred, + snapshot.BalanceAfter, + snapshot.AvailableQuotaAfter, + snapshot.FrozenQuotaAfter, + snapshot.HistoryQuotaAfter, + ); err != nil { return fmt.Errorf("insert affiliate transfer ledger: %w", err) } @@ -332,6 +383,349 @@ LIMIT $2`, inviterID, limit) return invitees, nil } +func (r *affiliateRepository) ListAffiliateInviteRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateInviteRecord, int64, error) { + client := clientFromContext(ctx, r.client) + where, args := buildAffiliateRecordWhere(filter, "ua.created_at", []string{ + "inviter.email", "inviter.username", "invitee.email", "invitee.username", + "ua.inviter_id::text", "ua.user_id::text", "inviter_aff.aff_code", + }) + + total, err := queryAffiliateRecordCount(ctx, client, ` +SELECT COUNT(*) +FROM user_affiliates ua +JOIN users invitee ON invitee.id = ua.user_id +JOIN users inviter ON inviter.id = ua.inviter_id +JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id +`+where, args...) + if err != nil { + return nil, 0, err + } + + orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{ + "inviter": "inviter.email", + "invitee": "invitee.email", + "aff_code": "inviter_aff.aff_code", + "total_rebate": "total_rebate", + "created_at": "ua.created_at", + }, "ua.created_at") + args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize) + rows, err := client.QueryContext(ctx, ` +SELECT ua.inviter_id, + COALESCE(inviter.email, ''), + COALESCE(inviter.username, ''), + ua.user_id, + COALESCE(invitee.email, ''), + COALESCE(invitee.username, ''), + COALESCE(inviter_aff.aff_code, ''), + COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate, + ua.created_at +FROM user_affiliates ua +JOIN users invitee ON invitee.id = ua.user_id +JOIN users inviter ON inviter.id = ua.inviter_id +JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id +LEFT JOIN user_affiliate_ledger ual + ON ual.user_id = ua.inviter_id + AND ual.source_user_id = ua.user_id + AND ual.action = 'accrue' +`+where+` +GROUP BY ua.inviter_id, inviter.email, inviter.username, ua.user_id, invitee.email, invitee.username, inviter_aff.aff_code, ua.created_at +`+orderBy+` +LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...) + if err != nil { + return nil, 0, err + } + defer func() { _ = rows.Close() }() + + items := make([]service.AffiliateInviteRecord, 0) + for rows.Next() { + var item service.AffiliateInviteRecord + if err := rows.Scan( + &item.InviterID, + &item.InviterEmail, + &item.InviterUsername, + &item.InviteeID, + &item.InviteeEmail, + &item.InviteeUsername, + &item.AffCode, + &item.TotalRebate, + &item.CreatedAt, + ); err != nil { + return nil, 0, err + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, 0, err + } + return items, total, nil +} + +func (r *affiliateRepository) ListAffiliateRebateRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateRebateRecord, int64, error) { + client := clientFromContext(ctx, r.client) + where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{ + "inviter.email", "inviter.username", "invitee.email", "invitee.username", + "po.id::text", "po.out_trade_no", "po.payment_type", "po.status", + }) + baseJoin := ` +FROM user_affiliate_ledger ual +JOIN payment_orders po ON po.id = ual.source_order_id +JOIN users invitee ON invitee.id = ual.source_user_id +JOIN users inviter ON inviter.id = ual.user_id +WHERE ual.action = 'accrue' + AND ual.source_order_id IS NOT NULL` + if where != "" { + where = strings.Replace(where, "WHERE ", " AND ", 1) + } + + total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...) + if err != nil { + return nil, 0, err + } + + orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{ + "order": "po.id", + "inviter": "inviter.email", + "invitee": "invitee.email", + "order_amount": "po.amount", + "pay_amount": "po.pay_amount", + "rebate_amount": "ual.amount", + "payment_type": "po.payment_type", + "order_status": "po.status", + "created_at": "ual.created_at", + }, "ual.created_at") + args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize) + rows, err := client.QueryContext(ctx, ` +SELECT po.id, + po.out_trade_no, + ual.user_id, + COALESCE(inviter.email, ''), + COALESCE(inviter.username, ''), + ual.source_user_id, + COALESCE(invitee.email, ''), + COALESCE(invitee.username, ''), + po.amount::double precision, + po.pay_amount::double precision, + ual.amount::double precision, + po.payment_type, + po.status, + ual.created_at +`+baseJoin+where+` +`+orderBy+` +LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...) + if err != nil { + return nil, 0, err + } + defer func() { _ = rows.Close() }() + + items := make([]service.AffiliateRebateRecord, 0) + for rows.Next() { + var item service.AffiliateRebateRecord + if err := rows.Scan( + &item.OrderID, + &item.OutTradeNo, + &item.InviterID, + &item.InviterEmail, + &item.InviterUsername, + &item.InviteeID, + &item.InviteeEmail, + &item.InviteeUsername, + &item.OrderAmount, + &item.PayAmount, + &item.RebateAmount, + &item.PaymentType, + &item.OrderStatus, + &item.CreatedAt, + ); err != nil { + return nil, 0, err + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, 0, err + } + return items, total, nil +} + +func (r *affiliateRepository) ListAffiliateTransferRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateTransferRecord, int64, error) { + client := clientFromContext(ctx, r.client) + where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{ + "u.email", "u.username", "u.id::text", + }) + baseJoin := ` +FROM user_affiliate_ledger ual +JOIN users u ON u.id = ual.user_id +WHERE ual.action = 'transfer'` + if where != "" { + where = strings.Replace(where, "WHERE ", " AND ", 1) + } + + total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...) + if err != nil { + return nil, 0, err + } + + orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{ + "user": "u.email", + "amount": "ual.amount", + "balance_after": "ual.balance_after", + "available_quota_after": "ual.aff_quota_after", + "frozen_quota_after": "ual.aff_frozen_quota_after", + "history_quota_after": "ual.aff_history_quota_after", + "created_at": "ual.created_at", + }, "ual.created_at") + args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize) + rows, err := client.QueryContext(ctx, ` +SELECT ual.id, + ual.user_id, + COALESCE(u.email, ''), + COALESCE(u.username, ''), + ual.amount::double precision, + ual.balance_after::double precision, + ual.aff_quota_after::double precision, + ual.aff_frozen_quota_after::double precision, + ual.aff_history_quota_after::double precision, + ual.created_at +`+baseJoin+where+` +`+orderBy+` +LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...) + if err != nil { + return nil, 0, err + } + defer func() { _ = rows.Close() }() + + items := make([]service.AffiliateTransferRecord, 0) + for rows.Next() { + var item service.AffiliateTransferRecord + var balanceAfter sql.NullFloat64 + var availableQuotaAfter sql.NullFloat64 + var frozenQuotaAfter sql.NullFloat64 + var historyQuotaAfter sql.NullFloat64 + if err := rows.Scan( + &item.LedgerID, + &item.UserID, + &item.UserEmail, + &item.Username, + &item.Amount, + &balanceAfter, + &availableQuotaAfter, + &frozenQuotaAfter, + &historyQuotaAfter, + &item.CreatedAt, + ); err != nil { + return nil, 0, err + } + item.BalanceAfter = nullableFloat64Ptr(balanceAfter) + item.AvailableQuotaAfter = nullableFloat64Ptr(availableQuotaAfter) + item.FrozenQuotaAfter = nullableFloat64Ptr(frozenQuotaAfter) + item.HistoryQuotaAfter = nullableFloat64Ptr(historyQuotaAfter) + item.SnapshotAvailable = balanceAfter.Valid && + availableQuotaAfter.Valid && + frozenQuotaAfter.Valid && + historyQuotaAfter.Valid + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, 0, err + } + return items, total, nil +} + +func (r *affiliateRepository) GetAffiliateUserOverview(ctx context.Context, userID int64) (*service.AffiliateUserOverview, error) { + if userID <= 0 { + return nil, service.ErrUserNotFound + } + client := clientFromContext(ctx, r.client) + rows, err := client.QueryContext(ctx, affiliateUserOverviewSQL, userID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + if err := rows.Err(); err != nil { + return nil, err + } + return nil, service.ErrUserNotFound + } + + var overview service.AffiliateUserOverview + var customRate float64 + var hasCustomRate bool + if err := rows.Scan( + &overview.UserID, + &overview.Email, + &overview.Username, + &overview.AffCode, + &customRate, + &hasCustomRate, + &overview.InvitedCount, + &overview.RebatedInviteeCount, + &overview.AvailableQuota, + &overview.HistoryQuota, + ); err != nil { + return nil, err + } + if hasCustomRate { + overview.RebateRatePercent = customRate + overview.RebateRateCustom = true + } + return &overview, rows.Err() +} + +func buildAffiliateRecordWhere(filter service.AffiliateRecordFilter, timeColumn string, searchColumns []string) (string, []any) { + clauses := make([]string, 0, 3) + args := make([]any, 0, 3) + if filter.StartAt != nil { + args = append(args, *filter.StartAt) + clauses = append(clauses, fmt.Sprintf("%s >= $%d", timeColumn, len(args))) + } + if filter.EndAt != nil { + args = append(args, *filter.EndAt) + clauses = append(clauses, fmt.Sprintf("%s <= $%d", timeColumn, len(args))) + } + search := strings.TrimSpace(filter.Search) + if search != "" && len(searchColumns) > 0 { + args = append(args, "%"+strings.ToLower(search)+"%") + parts := make([]string, 0, len(searchColumns)) + for _, col := range searchColumns { + parts = append(parts, fmt.Sprintf("LOWER(%s) LIKE $%d", col, len(args))) + } + clauses = append(clauses, "("+strings.Join(parts, " OR ")+")") + } + if len(clauses) == 0 { + return "", args + } + return "WHERE " + strings.Join(clauses, " AND "), args +} + +func buildAffiliateRecordOrderBy(filter service.AffiliateRecordFilter, sortColumns map[string]string, fallbackColumn string) string { + column := sortColumns[filter.SortBy] + if column == "" { + column = fallbackColumn + } + direction := "DESC" + if !filter.SortDesc { + direction = "ASC" + } + return "ORDER BY " + column + " " + direction + " NULLS LAST" +} + +func queryAffiliateRecordCount(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (int64, error) { + rows, err := client.QueryContext(ctx, query, args...) + if err != nil { + return 0, err + } + defer func() { _ = rows.Close() }() + if !rows.Next() { + return 0, rows.Err() + } + var total int64 + if err := rows.Scan(&total); err != nil { + return 0, err + } + return total, rows.Err() +} + func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error { if tx := dbent.TxFromContext(ctx); tx != nil { return fn(ctx, tx.Client()) @@ -516,6 +910,54 @@ func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID i return balance, nil } +type affiliateTransferSnapshot struct { + BalanceAfter float64 + AvailableQuotaAfter float64 + FrozenQuotaAfter float64 + HistoryQuotaAfter float64 +} + +func queryAffiliateTransferSnapshot(ctx context.Context, client affiliateQueryExecer, userID int64) (*affiliateTransferSnapshot, error) { + rows, err := client.QueryContext(ctx, ` +SELECT u.balance::double precision, + ua.aff_quota::double precision, + ua.aff_frozen_quota::double precision, + ua.aff_history_quota::double precision +FROM users u +JOIN user_affiliates ua ON ua.user_id = u.id +WHERE u.id = $1 +LIMIT 1`, userID) + if err != nil { + return nil, fmt.Errorf("query affiliate transfer snapshot: %w", err) + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + if err := rows.Err(); err != nil { + return nil, err + } + return nil, service.ErrUserNotFound + } + + var snapshot affiliateTransferSnapshot + if err := rows.Scan( + &snapshot.BalanceAfter, + &snapshot.AvailableQuotaAfter, + &snapshot.FrozenQuotaAfter, + &snapshot.HistoryQuotaAfter, + ); err != nil { + return nil, err + } + return &snapshot, rows.Err() +} + +func nullableFloat64Ptr(v sql.NullFloat64) *float64 { + if !v.Valid { + return nil + } + return &v.Float64 +} + func generateAffiliateCode() (string, error) { buf := make([]byte, affiliateCodeLength) if _, err := rand.Read(buf); err != nil { @@ -674,6 +1116,13 @@ func nullableArg(v *float64) any { return *v } +func nullableInt64Arg(v *int64) any { + if v == nil { + return nil + } + return *v +} + // ListUsersWithCustomSettings 列出有专属配置(自定义码或专属比例)的用户。 // // 单一查询同时处理"无搜索"与"按邮箱/用户名模糊搜索": diff --git a/backend/internal/repository/affiliate_repo_integration_test.go b/backend/internal/repository/affiliate_repo_integration_test.go index 697a193b..b01ed528 100644 --- a/backend/internal/repository/affiliate_repo_integration_test.go +++ b/backend/internal/repository/affiliate_repo_integration_test.go @@ -78,6 +78,26 @@ VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34) ledgerCount := querySingleInt(t, txCtx, client, "SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID) require.Equal(t, 1, ledgerCount) + + rows, err := client.QueryContext(txCtx, ` +SELECT amount::double precision, + balance_after::double precision, + aff_quota_after::double precision, + aff_frozen_quota_after::double precision, + aff_history_quota_after::double precision +FROM user_affiliate_ledger +WHERE user_id = $1 AND action = 'transfer' +LIMIT 1`, u.ID) + require.NoError(t, err) + defer func() { _ = rows.Close() }() + require.True(t, rows.Next(), "expected transfer ledger") + var amount, balanceAfter, quotaAfter, frozenAfter, historyAfter float64 + require.NoError(t, rows.Scan(&amount, &balanceAfter, "aAfter, &frozenAfter, &historyAfter)) + require.InDelta(t, 12.34, amount, 1e-9) + require.InDelta(t, 17.84, balanceAfter, 1e-9) + require.InDelta(t, 0.0, quotaAfter, 1e-9) + require.InDelta(t, 0.0, frozenAfter, 1e-9) + require.InDelta(t, 12.34, historyAfter, 1e-9) } // TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the @@ -125,7 +145,7 @@ func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) { require.NoError(t, err) require.True(t, bound, "invitee must bind to inviter") - applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0) + applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0, nil) require.NoError(t, err) require.True(t, applied, "AccrueQuota must report applied=true") diff --git a/backend/internal/repository/affiliate_repo_test.go b/backend/internal/repository/affiliate_repo_test.go new file mode 100644 index 00000000..ccb7bb3d --- /dev/null +++ b/backend/internal/repository/affiliate_repo_test.go @@ -0,0 +1,28 @@ +package repository + +import ( + "os" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAffiliateUserOverviewSQLIncludesMaturedFrozenQuota(t *testing.T) { + query := strings.Join(strings.Fields(affiliateUserOverviewSQL), " ") + + require.Contains(t, query, "ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0)") + require.Contains(t, query, "frozen_until <= NOW()") +} + +func TestAffiliateRecordQueriesUseLedgerAuditFields(t *testing.T) { + source, err := os.ReadFile("affiliate_repo.go") + require.NoError(t, err) + content := string(source) + + require.Contains(t, content, "JOIN payment_orders po ON po.id = ual.source_order_id") + require.Contains(t, content, "ual.amount::double precision") + require.Contains(t, content, "ual.balance_after::double precision") + require.NotContains(t, content, "parseAffiliateRebateAmount") + require.NotContains(t, content, `"current_balance": "u.balance"`) +} diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 1c786f50..fe4c4b1b 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -602,11 +602,16 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) { affiliates := admin.Group("/affiliates") { + affiliates.GET("/invites", h.Admin.Affiliate.ListInviteRecords) + affiliates.GET("/rebates", h.Admin.Affiliate.ListRebateRecords) + affiliates.GET("/transfers", h.Admin.Affiliate.ListTransferRecords) + users := affiliates.Group("/users") { users.GET("", h.Admin.Affiliate.ListUsers) users.GET("/lookup", h.Admin.Affiliate.LookupUsers) users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate) + users.GET("/:user_id/overview", h.Admin.Affiliate.GetUserOverview) users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings) users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings) } diff --git a/backend/internal/service/admin_balance_history_test.go b/backend/internal/service/admin_balance_history_test.go new file mode 100644 index 00000000..291d3f7b --- /dev/null +++ b/backend/internal/service/admin_balance_history_test.go @@ -0,0 +1,86 @@ +package service + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +func TestMergeBalanceHistoryCodesIncludesAffiliateTransfersByDefault(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC) + older := now.Add(-2 * time.Hour) + newer := now.Add(time.Hour) + + usedBy := int64(10) + redeemCodes := []RedeemCode{ + { + ID: 1, + Type: RedeemTypeBalance, + Value: 8, + Status: StatusUsed, + UsedBy: &usedBy, + UsedAt: &now, + CreatedAt: now, + }, + { + ID: 2, + Type: RedeemTypeConcurrency, + Value: 1, + Status: StatusUsed, + UsedBy: &usedBy, + UsedAt: &older, + CreatedAt: older, + }, + } + affiliateCodes := []RedeemCode{ + { + ID: -20, + Type: RedeemTypeAffiliateBalance, + Value: 3.5, + Status: StatusUsed, + UsedBy: &usedBy, + UsedAt: &newer, + CreatedAt: newer, + }, + } + + got := mergeBalanceHistoryCodes(redeemCodes, affiliateCodes, pagination.PaginationParams{ + Page: 1, + PageSize: 2, + }) + + require.Len(t, got, 2) + require.Equal(t, RedeemTypeAffiliateBalance, got[0].Type) + require.Equal(t, RedeemTypeBalance, got[1].Type) +} + +func TestMergeBalanceHistoryCodesPaginatesAfterCombiningSources(t *testing.T) { + t.Parallel() + + base := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC) + usedBy := int64(10) + at := func(hours int) *time.Time { + v := base.Add(time.Duration(hours) * time.Hour) + return &v + } + + got := mergeBalanceHistoryCodes( + []RedeemCode{ + {ID: 1, Type: RedeemTypeBalance, UsedBy: &usedBy, UsedAt: at(4), CreatedAt: *at(4)}, + {ID: 2, Type: RedeemTypeConcurrency, UsedBy: &usedBy, UsedAt: at(2), CreatedAt: *at(2)}, + }, + []RedeemCode{ + {ID: -3, Type: RedeemTypeAffiliateBalance, UsedBy: &usedBy, UsedAt: at(3), CreatedAt: *at(3)}, + {ID: -4, Type: RedeemTypeAffiliateBalance, UsedBy: &usedBy, UsedAt: at(1), CreatedAt: *at(1)}, + }, + pagination.PaginationParams{Page: 2, PageSize: 2}, + ) + + require.Len(t, got, 2) + require.Equal(t, RedeemTypeConcurrency, got[0].Type) + require.Equal(t, int64(-4), got[1].ID) +} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index d966c684..be4c23dc 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -2,6 +2,7 @@ package service import ( "context" + "database/sql" "encoding/json" "errors" "fmt" @@ -973,16 +974,213 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, // GetUserBalanceHistory returns paginated balance/concurrency change records for a user. func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} + if codeType == RedeemTypeAffiliateBalance { + codes, total, err := s.listAffiliateBalanceHistory(ctx, userID, params) + if err != nil { + return nil, 0, 0, err + } + totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID) + if err != nil { + return nil, 0, 0, err + } + return codes, total, totalRecharged, nil + } + + if codeType == "" { + return s.getAllUserBalanceHistory(ctx, userID, params) + } + codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType) if err != nil { return nil, 0, 0, err } + total := result.Total // Aggregate total recharged amount (only once, regardless of type filter) totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID) if err != nil { return nil, 0, 0, err } - return codes, result.Total, totalRecharged, nil + return codes, total, totalRecharged, nil +} + +func (s *adminServiceImpl) getAllUserBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, float64, error) { + needed := params.Offset() + params.Limit() + if needed < params.Limit() { + needed = params.Limit() + } + + redeemCodes, redeemTotal, err := s.listRedeemBalanceHistoryForMerge(ctx, userID, needed) + if err != nil { + return nil, 0, 0, err + } + affiliateCodes, affiliateTotal, err := s.listAffiliateBalanceHistoryForMerge(ctx, userID, needed) + if err != nil { + return nil, 0, 0, err + } + codes := mergeBalanceHistoryCodes(redeemCodes, affiliateCodes, params) + + totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID) + if err != nil { + return nil, 0, 0, err + } + return codes, redeemTotal + affiliateTotal, totalRecharged, nil +} + +func (s *adminServiceImpl) listRedeemBalanceHistoryForMerge(ctx context.Context, userID int64, needed int) ([]RedeemCode, int64, error) { + if needed <= 0 { + return nil, 0, nil + } + + var ( + out []RedeemCode + total int64 + ) + for page := 1; len(out) < needed; page++ { + params := pagination.PaginationParams{Page: page, PageSize: 1000} + codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, "") + if err != nil { + return nil, 0, err + } + if result != nil { + total = result.Total + } + out = append(out, codes...) + if len(codes) < params.Limit() || int64(len(out)) >= total { + break + } + } + if len(out) > needed { + out = out[:needed] + } + return out, total, nil +} + +func (s *adminServiceImpl) listAffiliateBalanceHistoryForMerge(ctx context.Context, userID int64, needed int) ([]RedeemCode, int64, error) { + if needed <= 0 { + return nil, 0, nil + } + + var ( + out []RedeemCode + total int64 + ) + for page := 1; len(out) < needed; page++ { + params := pagination.PaginationParams{Page: page, PageSize: 1000} + codes, currentTotal, err := s.listAffiliateBalanceHistory(ctx, userID, params) + if err != nil { + return nil, 0, err + } + total = currentTotal + out = append(out, codes...) + if len(codes) < params.Limit() || int64(len(out)) >= total { + break + } + } + if len(out) > needed { + out = out[:needed] + } + return out, total, nil +} + +func (s *adminServiceImpl) listAffiliateBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, error) { + if s == nil || s.entClient == nil || userID <= 0 { + return nil, 0, nil + } + + rows, err := s.entClient.QueryContext(ctx, ` +SELECT id, + amount::double precision, + created_at +FROM user_affiliate_ledger +WHERE user_id = $1 + AND action = 'transfer' +ORDER BY created_at DESC, id DESC +OFFSET $2 +LIMIT $3`, userID, params.Offset(), params.Limit()) + if err != nil { + return nil, 0, err + } + defer func() { _ = rows.Close() }() + + codes := make([]RedeemCode, 0, params.Limit()) + for rows.Next() { + var id int64 + var amount float64 + var createdAt time.Time + if err := rows.Scan(&id, &amount, &createdAt); err != nil { + return nil, 0, err + } + usedBy := userID + usedAt := createdAt + codes = append(codes, RedeemCode{ + ID: -id, + Code: fmt.Sprintf("AFF-%d", id), + Type: RedeemTypeAffiliateBalance, + Value: amount, + Status: StatusUsed, + UsedBy: &usedBy, + UsedAt: &usedAt, + CreatedAt: createdAt, + }) + } + if err := rows.Err(); err != nil { + return nil, 0, err + } + + total, err := countAffiliateBalanceHistory(ctx, s.entClient, userID) + if err != nil { + return nil, 0, err + } + return codes, total, nil +} + +func countAffiliateBalanceHistory(ctx context.Context, client *dbent.Client, userID int64) (int64, error) { + rows, err := client.QueryContext(ctx, ` +SELECT COUNT(*) +FROM user_affiliate_ledger +WHERE user_id = $1 + AND action = 'transfer'`, userID) + if err != nil { + return 0, err + } + defer func() { _ = rows.Close() }() + + var total sql.NullInt64 + if rows.Next() { + if err := rows.Scan(&total); err != nil { + return 0, err + } + } + if err := rows.Err(); err != nil { + return 0, err + } + if !total.Valid { + return 0, nil + } + return total.Int64, nil +} + +func mergeBalanceHistoryCodes(redeemCodes, affiliateCodes []RedeemCode, params pagination.PaginationParams) []RedeemCode { + combined := append(append([]RedeemCode{}, redeemCodes...), affiliateCodes...) + sort.SliceStable(combined, func(i, j int) bool { + return redeemCodeHistoryTime(combined[i]).After(redeemCodeHistoryTime(combined[j])) + }) + offset := params.Offset() + if offset >= len(combined) { + return []RedeemCode{} + } + end := offset + params.Limit() + if end > len(combined) { + end = len(combined) + } + return combined[offset:end] +} + +func redeemCodeHistoryTime(code RedeemCode) time.Time { + if code.UsedAt != nil { + return *code.UsedAt + } + return code.CreatedAt } func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) { diff --git a/backend/internal/service/affiliate_service.go b/backend/internal/service/affiliate_service.go index 5a4e91e7..91cca5e2 100644 --- a/backend/internal/service/affiliate_service.go +++ b/backend/internal/service/affiliate_service.go @@ -98,7 +98,7 @@ type AffiliateRepository interface { EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error) BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) - AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) + AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error) GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) ThawFrozenQuota(ctx context.Context, userID int64) (float64, error) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) @@ -110,6 +110,10 @@ type AffiliateRepository interface { SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error) + ListAffiliateInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error) + ListAffiliateRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error) + ListAffiliateTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error) + GetAffiliateUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error) } // AffiliateAdminFilter 列表筛选条件 @@ -130,6 +134,76 @@ type AffiliateAdminEntry struct { AffCount int `json:"aff_count"` } +type AffiliateRecordFilter struct { + Search string + Page int + PageSize int + StartAt *time.Time + EndAt *time.Time + SortBy string + SortDesc bool +} + +type AffiliateInviteRecord struct { + InviterID int64 `json:"inviter_id"` + InviterEmail string `json:"inviter_email"` + InviterUsername string `json:"inviter_username"` + InviteeID int64 `json:"invitee_id"` + InviteeEmail string `json:"invitee_email"` + InviteeUsername string `json:"invitee_username"` + AffCode string `json:"aff_code"` + TotalRebate float64 `json:"total_rebate"` + CreatedAt time.Time `json:"created_at"` +} + +type AffiliateRebateRecord struct { + OrderID int64 `json:"order_id"` + OutTradeNo string `json:"out_trade_no"` + InviterID int64 `json:"inviter_id"` + InviterEmail string `json:"inviter_email"` + InviterUsername string `json:"inviter_username"` + InviteeID int64 `json:"invitee_id"` + InviteeEmail string `json:"invitee_email"` + InviteeUsername string `json:"invitee_username"` + OrderAmount float64 `json:"order_amount"` + PayAmount float64 `json:"pay_amount"` + RebateAmount float64 `json:"rebate_amount"` + PaymentType string `json:"payment_type"` + OrderStatus string `json:"order_status"` + CreatedAt time.Time `json:"created_at"` +} + +type AffiliateTransferRecord struct { + LedgerID int64 `json:"ledger_id"` + UserID int64 `json:"user_id"` + UserEmail string `json:"user_email"` + Username string `json:"username"` + Amount float64 `json:"amount"` + BalanceAfter *float64 `json:"balance_after,omitempty"` + AvailableQuotaAfter *float64 `json:"available_quota_after,omitempty"` + FrozenQuotaAfter *float64 `json:"frozen_quota_after,omitempty"` + HistoryQuotaAfter *float64 `json:"history_quota_after,omitempty"` + SnapshotAvailable bool `json:"snapshot_available"` + CurrentBalance float64 `json:"-"` + RemainingQuota float64 `json:"-"` + FrozenQuota float64 `json:"-"` + HistoryQuota float64 `json:"-"` + CreatedAt time.Time `json:"created_at"` +} + +type AffiliateUserOverview struct { + UserID int64 `json:"user_id"` + Email string `json:"email"` + Username string `json:"username"` + AffCode string `json:"aff_code"` + RebateRatePercent float64 `json:"rebate_rate_percent"` + RebateRateCustom bool `json:"-"` + InvitedCount int `json:"invited_count"` + RebatedInviteeCount int `json:"rebated_invitee_count"` + AvailableQuota float64 `json:"available_quota"` + HistoryQuota float64 `json:"history_quota"` +} + type AffiliateService struct { repo AffiliateRepository settingService *SettingService @@ -238,6 +312,10 @@ func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64, } func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) { + return s.AccrueInviteRebateForOrder(ctx, inviteeUserID, baseRechargeAmount, nil) +} + +func (s *AffiliateService) AccrueInviteRebateForOrder(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64, sourceOrderID *int64) (float64, error) { if s == nil || s.repo == nil { return 0, nil } @@ -298,7 +376,7 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx) } - applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours) + applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours, sourceOrderID) if err != nil { return 0, err } @@ -488,3 +566,59 @@ func (s *AffiliateService) AdminListCustomUsers(ctx context.Context, filter Affi } return s.repo.ListUsersWithCustomSettings(ctx, filter) } + +func (s *AffiliateService) AdminListInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error) { + if s == nil || s.repo == nil { + return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + return s.repo.ListAffiliateInviteRecords(ctx, normalizeAffiliateRecordFilter(filter)) +} + +func (s *AffiliateService) AdminListRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error) { + if s == nil || s.repo == nil { + return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + return s.repo.ListAffiliateRebateRecords(ctx, normalizeAffiliateRecordFilter(filter)) +} + +func (s *AffiliateService) AdminListTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error) { + if s == nil || s.repo == nil { + return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + return s.repo.ListAffiliateTransferRecords(ctx, normalizeAffiliateRecordFilter(filter)) +} + +func (s *AffiliateService) AdminGetUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error) { + if userID <= 0 { + return nil, infraerrors.BadRequest("INVALID_USER", "invalid user") + } + if s == nil || s.repo == nil { + return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + overview, err := s.repo.GetAffiliateUserOverview(ctx, userID) + if err != nil { + return nil, err + } + if overview != nil { + if !overview.RebateRateCustom { + overview.RebateRatePercent = s.globalRebateRatePercent(ctx) + } + overview.RebateRatePercent = clampAffiliateRebateRate(overview.RebateRatePercent) + } + return overview, nil +} + +func normalizeAffiliateRecordFilter(filter AffiliateRecordFilter) AffiliateRecordFilter { + if filter.Page <= 0 { + filter.Page = 1 + } + if filter.PageSize <= 0 { + filter.PageSize = 20 + } + if filter.PageSize > 100 { + filter.PageSize = 100 + } + filter.Search = strings.TrimSpace(filter.Search) + filter.SortBy = strings.TrimSpace(filter.SortBy) + return filter +} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index bb32540b..632ebf5f 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -51,10 +51,11 @@ const ( // Redeem type constants const ( - RedeemTypeBalance = domain.RedeemTypeBalance - RedeemTypeConcurrency = domain.RedeemTypeConcurrency - RedeemTypeSubscription = domain.RedeemTypeSubscription - RedeemTypeInvitation = domain.RedeemTypeInvitation + RedeemTypeBalance = domain.RedeemTypeBalance + RedeemTypeConcurrency = domain.RedeemTypeConcurrency + RedeemTypeSubscription = domain.RedeemTypeSubscription + RedeemTypeInvitation = domain.RedeemTypeInvitation + RedeemTypeAffiliateBalance = "affiliate_balance" ) // PromoCode status constants diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 074013c3..67d19720 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -8174,9 +8174,16 @@ func detachedBillingContext(ctx context.Context) (context.Context, context.Cance } func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { + if ctx == nil { + return context.Background(), func() {} + } if !stream { return ctx, func() {} } + return context.WithoutCancel(ctx), func() {} +} + +func detachUpstreamContext(ctx context.Context) (context.Context, context.CancelFunc) { if ctx == nil { return context.Background(), func() {} } diff --git a/backend/internal/service/gateway_service_streaming_test.go b/backend/internal/service/gateway_service_streaming_test.go index c8803d39..39a7d3b0 100644 --- a/backend/internal/service/gateway_service_streaming_test.go +++ b/backend/internal/service/gateway_service_streaming_test.go @@ -13,6 +13,8 @@ import ( "github.com/stretchr/testify/require" ) +type upstreamContextTestKey string + func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ @@ -50,3 +52,14 @@ func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testi require.Equal(t, 3, result.usage.InputTokens) require.Equal(t, 7, result.usage.OutputTokens) } + +func TestDetachUpstreamContextIgnoresClientCancel(t *testing.T) { + parent, cancel := context.WithCancel(context.WithValue(context.Background(), upstreamContextTestKey("test-key"), "test-value")) + upstreamCtx, release := detachUpstreamContext(parent) + defer release() + + cancel() + + require.NoError(t, upstreamCtx.Err()) + require.Equal(t, "test-value", upstreamCtx.Value(upstreamContextTestKey("test-key"))) +} diff --git a/backend/internal/service/openai_compat_model_test.go b/backend/internal/service/openai_compat_model_test.go index 4396c15f..840784bf 100644 --- a/backend/internal/service/openai_compat_model_test.go +++ b/backend/internal/service/openai_compat_model_test.go @@ -3,13 +3,16 @@ package service import ( "bytes" "context" + "errors" "io" "net/http" "net/http/httptest" "os" "path/filepath" "strings" + "sync" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" @@ -18,6 +21,51 @@ import ( "github.com/tidwall/gjson" ) +type openAICompatFailingWriter struct { + gin.ResponseWriter + failAfter int + writes int +} + +func (w *openAICompatFailingWriter) Write(p []byte) (int, error) { + if w.writes >= w.failAfter { + return 0, errors.New("write failed: client disconnected") + } + w.writes++ + return w.ResponseWriter.Write(p) +} + +type openAICompatBlockingReadCloser struct { + data []byte + offset int + closed chan struct{} + closeOnce sync.Once +} + +func newOpenAICompatBlockingReadCloser(data []byte) *openAICompatBlockingReadCloser { + return &openAICompatBlockingReadCloser{ + data: data, + closed: make(chan struct{}), + } +} + +func (r *openAICompatBlockingReadCloser) Read(p []byte) (int, error) { + if r.offset < len(r.data) { + n := copy(p, r.data[r.offset:]) + r.offset += n + return n, nil + } + <-r.closed + return 0, io.EOF +} + +func (r *openAICompatBlockingReadCloser) Close() error { + r.closeOnce.Do(func() { + close(r.closed) + }) + return nil +} + func TestNormalizeOpenAICompatRequestedModel(t *testing.T) { t.Parallel() @@ -228,3 +276,242 @@ func TestForwardAsAnthropic_ForcedCodexInstructionsTemplateUsesCachedTemplateCon require.NotNil(t, result) require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String()) } + +func TestForwardAsAnthropic_ClientDisconnectDrainsUpstreamUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`, + "", + `data: {"type":"response.output_text.delta","delta":"ok"}`, + "", + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":9,"output_tokens":4,"total_tokens":13,"input_tokens_details":{"cached_tokens":3}}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_disconnect"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 9, result.Usage.InputTokens) + require.Equal(t, 4, result.Usage.OutputTokens) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) +} + +func TestForwardAsAnthropic_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n") + upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) + defer func() { + require.NoError(t, upstreamStream.Close()) + }() + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_terminal_no_close"}}, + Body: upstreamStream, + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + type forwardResult struct { + result *OpenAIForwardResult + err error + } + resultCh := make(chan forwardResult, 1) + go func() { + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + resultCh <- forwardResult{result: result, err: err} + }() + + select { + case got := <-resultCh: + require.NoError(t, got.err) + require.NotNil(t, got.result) + require.Equal(t, 15, got.result.Usage.InputTokens) + require.Equal(t, 6, got.result.Usage.OutputTokens) + require.Equal(t, 5, got.result.Usage.CacheReadInputTokens) + case <-time.After(time.Second): + require.Fail(t, "ForwardAsAnthropic should return after terminal usage event even if upstream keeps the connection open") + } +} + +func TestForwardAsAnthropic_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n") + upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) + defer func() { + require.NoError(t, upstreamStream.Close()) + }() + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_buffered_terminal_no_close"}}, + Body: upstreamStream, + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + type forwardResult struct { + result *OpenAIForwardResult + err error + } + resultCh := make(chan forwardResult, 1) + go func() { + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + resultCh <- forwardResult{result: result, err: err} + }() + + select { + case got := <-resultCh: + require.NoError(t, got.err) + require.NotNil(t, got.result) + require.Equal(t, 15, got.result.Usage.InputTokens) + require.Equal(t, 6, got.result.Usage.OutputTokens) + require.Equal(t, 5, got.result.Usage.CacheReadInputTokens) + require.Contains(t, rec.Body.String(), `"stop_reason":"end_turn"`) + case <-time.After(time.Second): + require.Fail(t, "ForwardAsAnthropic buffered response should return after terminal usage event even if upstream keeps the connection open") + } +} + +func TestForwardAsAnthropic_DoneSentinelWithoutTerminalReturnsError(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := "data: [DONE]\n\n" + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_missing_terminal"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + require.Error(t, err) + require.Contains(t, err.Error(), "missing terminal event") + require.NotNil(t, result) + require.Zero(t, result.Usage.InputTokens) + require.Zero(t, result.Usage.OutputTokens) +} + +func TestForwardAsAnthropic_UpstreamRequestIgnoresClientCancel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + reqCtx, cancel := context.WithCancel(context.Background()) + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)).WithContext(reqCtx) + c.Request.Header.Set("Content-Type", "application/json") + cancel() + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_ctx"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsAnthropic(reqCtx, c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) +} diff --git a/backend/internal/service/openai_gateway_403_reset_test.go b/backend/internal/service/openai_gateway_403_reset_test.go index c6805464..440b94a9 100644 --- a/backend/internal/service/openai_gateway_403_reset_test.go +++ b/backend/internal/service/openai_gateway_403_reset_test.go @@ -20,20 +20,29 @@ func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accou return nil } -func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) { +func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterForZeroUsage(t *testing.T) { counter := &openAI403CounterResetStub{} rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil) rateLimitSvc.SetOpenAI403CounterCache(counter) - svc := &OpenAIGatewayService{ - rateLimitService: rateLimitSvc, - } + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + svc.rateLimitService = rateLimitSvc err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ - Result: &OpenAIForwardResult{}, + Result: &OpenAIForwardResult{ + RequestID: "resp_zero_usage_reset_403", + Model: "gpt-5.1", + }, + APIKey: &APIKey{ID: 1001, Group: &Group{RateMultiplier: 1}}, + User: &User{ID: 2001}, Account: &Account{ID: 777, Platform: PlatformOpenAI}, }) require.NoError(t, err) require.Equal(t, []int64{777}, counter.resetCalls) + require.Equal(t, 1, usageRepo.calls) } diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index 5822ae4c..3456cce0 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" @@ -189,7 +190,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( } // 6. Build upstream request - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, promptCacheKey, false) + releaseUpstreamCtx() if err != nil { return nil, fmt.Errorf("build upstream request: %w", err) } @@ -348,59 +351,9 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") - scanner := bufio.NewScanner(resp.Body) - maxLineSize := defaultMaxLineSize - if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { - maxLineSize = s.cfg.Gateway.MaxLineSize - } - scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) - - var finalResponse *apicompat.ResponsesResponse - var usage OpenAIUsage - acc := apicompat.NewBufferedResponseAccumulator() - - for scanner.Scan() { - line := scanner.Text() - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { - continue - } - payload := line[6:] - - var event apicompat.ResponsesStreamEvent - if err := json.Unmarshal([]byte(payload), &event); err != nil { - logger.L().Warn("openai chat_completions buffered: failed to parse event", - zap.Error(err), - zap.String("request_id", requestID), - ) - continue - } - - // Accumulate delta content for fallback when terminal output is empty. - acc.ProcessEvent(&event) - - if (event.Type == "response.completed" || event.Type == "response.done" || - event.Type == "response.incomplete" || event.Type == "response.failed") && - event.Response != nil { - finalResponse = event.Response - if event.Response.Usage != nil { - usage = OpenAIUsage{ - InputTokens: event.Response.Usage.InputTokens, - OutputTokens: event.Response.Usage.OutputTokens, - } - if event.Response.Usage.InputTokensDetails != nil { - usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens - } - } - } - } - - if err := scanner.Err(); err != nil { - if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { - logger.L().Warn("openai chat_completions buffered: read error", - zap.Error(err), - zap.String("request_id", requestID), - ) - } + finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai chat_completions buffered", requestID) + if err != nil { + return nil, err } if finalResponse == nil { @@ -459,6 +412,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( var usage OpenAIUsage var firstTokenMs *int firstChunk := true + clientDisconnected := false scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -467,6 +421,20 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( } scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + resultWithUsage := func() *OpenAIForwardResult { return &OpenAIForwardResult{ RequestID: requestID, @@ -496,54 +464,66 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( return false } - // Extract usage from completion events - if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && - event.Response != nil && event.Response.Usage != nil { - usage = OpenAIUsage{ - InputTokens: event.Response.Usage.InputTokens, - OutputTokens: event.Response.Usage.OutputTokens, - } - if event.Response.Usage.InputTokensDetails != nil { - usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens - } + // 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。 + isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type) + if isTerminalEvent && event.Response != nil && event.Response.Usage != nil { + usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage) } chunks := apicompat.ResponsesEventToChatChunks(&event, state) - for _, chunk := range chunks { - sse, err := apicompat.ChatChunkToSSE(chunk) - if err != nil { - logger.L().Warn("openai chat_completions stream: failed to marshal chunk", - zap.Error(err), - zap.String("request_id", requestID), - ) - continue - } - if _, err := fmt.Fprint(c.Writer, sse); err != nil { - logger.L().Info("openai chat_completions stream: client disconnected", - zap.String("request_id", requestID), - ) - return true + if !clientDisconnected { + for _, chunk := range chunks { + sse, err := apicompat.ChatChunkToSSE(chunk) + if err != nil { + logger.L().Warn("openai chat_completions stream: failed to marshal chunk", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + clientDisconnected = true + logger.L().Info("openai chat_completions stream: client disconnected, continuing to drain upstream for billing", + zap.String("request_id", requestID), + ) + break + } } } - if len(chunks) > 0 { + if len(chunks) > 0 && !clientDisconnected { c.Writer.Flush() } - return false + return isTerminalEvent } finalizeStream := func() (*OpenAIForwardResult, error) { - if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 { + if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 && !clientDisconnected { for _, chunk := range finalChunks { sse, err := apicompat.ChatChunkToSSE(chunk) if err != nil { continue } - fmt.Fprint(c.Writer, sse) //nolint:errcheck + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + clientDisconnected = true + logger.L().Info("openai chat_completions stream: client disconnected during final flush", + zap.String("request_id", requestID), + ) + break + } } } // Send [DONE] sentinel - fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck - c.Writer.Flush() + if !clientDisconnected { + if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil { + clientDisconnected = true + logger.L().Info("openai chat_completions stream: client disconnected during done flush", + zap.String("request_id", requestID), + ) + } + } + if !clientDisconnected { + c.Writer.Flush() + } return resultWithUsage(), nil } @@ -555,6 +535,9 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( ) } } + missingTerminalErr := func() (*OpenAIForwardResult, error) { + return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event") + } // Determine keepalive interval keepaliveInterval := time.Duration(0) @@ -563,18 +546,25 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( } // No keepalive: fast synchronous path - if keepaliveInterval <= 0 { + if streamInterval <= 0 && keepaliveInterval <= 0 { for scanner.Scan() { line := scanner.Text() - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + payload, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - if processDataLine(line[6:]) { - return resultWithUsage(), nil + if strings.TrimSpace(payload) == "[DONE]" { + return missingTerminalErr() + } + if processDataLine(payload) { + return finalizeStream() } } - handleScanErr(scanner.Err()) - return finalizeStream() + if err := scanner.Err(); err != nil { + handleScanErr(err) + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err) + } + return missingTerminalErr() } // With keepalive: goroutine + channel + select @@ -584,6 +574,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( } events := make(chan scanEvent, 16) done := make(chan struct{}) + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) sendEvent := func(ev scanEvent) bool { select { case events <- ev: @@ -595,6 +587,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( go func() { defer close(events) for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) if !sendEvent(scanEvent{line: scanner.Text()}) { return } @@ -605,30 +598,59 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( }() defer close(done) - keepaliveTicker := time.NewTicker(keepaliveInterval) - defer keepaliveTicker.Stop() + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } lastDataAt := time.Now() for { select { case ev, ok := <-events: if !ok { - return finalizeStream() + return missingTerminalErr() } if ev.err != nil { handleScanErr(ev.err) - return finalizeStream() + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err) } lastDataAt = time.Now() line := ev.line - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + payload, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - if processDataLine(line[6:]) { - return resultWithUsage(), nil + if strings.TrimSpace(payload) == "[DONE]" { + return missingTerminalErr() + } + if processDataLine(payload) { + return finalizeStream() } - case <-keepaliveTicker.C: + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout") + } + logger.L().Warn("openai chat_completions stream: data interval timeout", + zap.String("request_id", requestID), + zap.String("model", originalModel), + zap.Duration("interval", streamInterval), + ) + return resultWithUsage(), fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if clientDisconnected { + continue + } if time.Since(lastDataAt) < keepaliveInterval { continue } @@ -637,7 +659,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( logger.L().Info("openai chat_completions stream: client disconnected during keepalive", zap.String("request_id", requestID), ) - return resultWithUsage(), nil + clientDisconnected = true + continue } c.Writer.Flush() } diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go index 6846e03a..c129a4df 100644 --- a/backend/internal/service/openai_gateway_chat_completions_test.go +++ b/backend/internal/service/openai_gateway_chat_completions_test.go @@ -1,13 +1,36 @@ package service import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" ) +type openAIChatFailingWriter struct { + gin.ResponseWriter + failAfter int + writes int +} + +func (w *openAIChatFailingWriter) Write(p []byte) (int, error) { + if w.writes >= w.failAfter { + return 0, errors.New("write failed: client disconnected") + } + w.writes++ + return w.ResponseWriter.Write(p) +} + func TestNormalizeResponsesRequestServiceTier(t *testing.T) { t.Parallel() @@ -73,3 +96,242 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) { require.Empty(t, tier) require.False(t, gjson.GetBytes(body, "service_tier").Exists()) } + +func TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`, + "", + `data: {"type":"response.output_text.delta","delta":"ok"}`, + "", + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":11,"output_tokens":5,"total_tokens":16,"input_tokens_details":{"cached_tokens":4}}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_disconnect"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 11, result.Usage.InputTokens) + require.Equal(t, 5, result.Usage.OutputTokens) + require.Equal(t, 4, result.Usage.CacheReadInputTokens) +} + +func TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n") + upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) + defer func() { + require.NoError(t, upstreamStream.Close()) + }() + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_terminal_no_close"}}, + Body: upstreamStream, + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + type forwardResult struct { + result *OpenAIForwardResult + err error + } + resultCh := make(chan forwardResult, 1) + go func() { + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + resultCh <- forwardResult{result: result, err: err} + }() + + select { + case got := <-resultCh: + require.NoError(t, got.err) + require.NotNil(t, got.result) + require.Equal(t, 17, got.result.Usage.InputTokens) + require.Equal(t, 8, got.result.Usage.OutputTokens) + require.Equal(t, 6, got.result.Usage.CacheReadInputTokens) + case <-time.After(time.Second): + require.Fail(t, "ForwardAsChatCompletions should return after terminal usage event even if upstream keeps the connection open") + } +} + +func TestForwardAsChatCompletions_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n") + upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) + defer func() { + require.NoError(t, upstreamStream.Close()) + }() + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_buffered_terminal_no_close"}}, + Body: upstreamStream, + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + type forwardResult struct { + result *OpenAIForwardResult + err error + } + resultCh := make(chan forwardResult, 1) + go func() { + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + resultCh <- forwardResult{result: result, err: err} + }() + + select { + case got := <-resultCh: + require.NoError(t, got.err) + require.NotNil(t, got.result) + require.Equal(t, 17, got.result.Usage.InputTokens) + require.Equal(t, 8, got.result.Usage.OutputTokens) + require.Equal(t, 6, got.result.Usage.CacheReadInputTokens) + require.Contains(t, rec.Body.String(), `"finish_reason":"stop"`) + case <-time.After(time.Second): + require.Fail(t, "ForwardAsChatCompletions buffered response should return after terminal usage event even if upstream keeps the connection open") + } +} + +func TestForwardAsChatCompletions_DoneSentinelWithoutTerminalReturnsError(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := "data: [DONE]\n\n" + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_missing_terminal"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + require.Error(t, err) + require.Contains(t, err.Error(), "missing terminal event") + require.NotNil(t, result) + require.Zero(t, result.Usage.InputTokens) + require.Zero(t, result.Usage.OutputTokens) +} + +func TestForwardAsChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + reqCtx, cancel := context.WithCancel(context.Background()) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx) + c.Request.Header.Set("Content-Type", "application/json") + cancel() + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_ctx"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(reqCtx, c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) +} diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 4e0ebb2e..5f3bf5c1 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" @@ -163,7 +164,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( } // 6. Build upstream request - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, isStream, promptCacheKey, false) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, isStream, promptCacheKey, false) + releaseUpstreamCtx() if err != nil { return nil, fmt.Errorf("build upstream request: %w", err) } @@ -296,61 +299,9 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") - scanner := bufio.NewScanner(resp.Body) - maxLineSize := defaultMaxLineSize - if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { - maxLineSize = s.cfg.Gateway.MaxLineSize - } - scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) - - var finalResponse *apicompat.ResponsesResponse - var usage OpenAIUsage - acc := apicompat.NewBufferedResponseAccumulator() - - for scanner.Scan() { - line := scanner.Text() - - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { - continue - } - payload := line[6:] - - var event apicompat.ResponsesStreamEvent - if err := json.Unmarshal([]byte(payload), &event); err != nil { - logger.L().Warn("openai messages buffered: failed to parse event", - zap.Error(err), - zap.String("request_id", requestID), - ) - continue - } - - // Accumulate delta content for fallback when terminal output is empty. - acc.ProcessEvent(&event) - - // Terminal events carry the complete ResponsesResponse with output + usage. - if (event.Type == "response.completed" || event.Type == "response.done" || - event.Type == "response.incomplete" || event.Type == "response.failed") && - event.Response != nil { - finalResponse = event.Response - if event.Response.Usage != nil { - usage = OpenAIUsage{ - InputTokens: event.Response.Usage.InputTokens, - OutputTokens: event.Response.Usage.OutputTokens, - } - if event.Response.Usage.InputTokensDetails != nil { - usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens - } - } - } - } - - if err := scanner.Err(); err != nil { - if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { - logger.L().Warn("openai messages buffered: read error", - zap.Error(err), - zap.String("request_id", requestID), - ) - } + finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai messages buffered", requestID) + if err != nil { + return nil, err } if finalResponse == nil { @@ -380,6 +331,153 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( }, nil } +func isOpenAICompatResponsesTerminalEvent(eventType string) bool { + switch strings.TrimSpace(eventType) { + case "response.completed", "response.done", "response.incomplete", "response.failed": + return true + default: + return false + } +} + +func isOpenAICompatDoneSentinelLine(line string) bool { + payload, ok := extractOpenAISSEDataLine(line) + return ok && strings.TrimSpace(payload) == "[DONE]" +} + +func (s *OpenAIGatewayService) readOpenAICompatBufferedTerminal( + resp *http.Response, + logPrefix string, + requestID string, +) (*apicompat.ResponsesResponse, OpenAIUsage, *apicompat.BufferedResponseAccumulator, error) { + acc := apicompat.NewBufferedResponseAccumulator() + var usage OpenAIUsage + if resp == nil || resp.Body == nil { + return nil, usage, acc, errors.New("upstream response body is nil") + } + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var timeoutCh <-chan time.Time + var timeoutTimer *time.Timer + resetTimeout := func() { + if streamInterval <= 0 { + return + } + if timeoutTimer == nil { + timeoutTimer = time.NewTimer(streamInterval) + timeoutCh = timeoutTimer.C + return + } + if !timeoutTimer.Stop() { + select { + case <-timeoutTimer.C: + default: + } + } + timeoutTimer.Reset(streamInterval) + } + stopTimeout := func() { + if timeoutTimer == nil { + return + } + if !timeoutTimer.Stop() { + select { + case <-timeoutTimer.C: + default: + } + } + } + resetTimeout() + defer stopTimeout() + + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + go func() { + defer close(events) + for scanner.Scan() { + select { + case events <- scanEvent{line: scanner.Text()}: + case <-done: + return + } + } + if err := scanner.Err(); err != nil { + select { + case events <- scanEvent{err: err}: + case <-done: + } + } + }() + defer close(done) + + for { + select { + case ev, ok := <-events: + if !ok { + return nil, usage, acc, nil + } + resetTimeout() + if ev.err != nil { + if !errors.Is(ev.err, context.Canceled) && !errors.Is(ev.err, context.DeadlineExceeded) { + logger.L().Warn(logPrefix+": read error", + zap.Error(ev.err), + zap.String("request_id", requestID), + ) + } + return nil, usage, acc, ev.err + } + + if isOpenAICompatDoneSentinelLine(ev.line) { + return nil, usage, acc, nil + } + payload, ok := extractOpenAISSEDataLine(ev.line) + if !ok || payload == "" { + continue + } + + var event apicompat.ResponsesStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + logger.L().Warn(logPrefix+": failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + + acc.ProcessEvent(&event) + + if isOpenAICompatResponsesTerminalEvent(event.Type) && event.Response != nil { + if event.Response.Usage != nil { + usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage) + } + return event.Response, usage, acc, nil + } + + case <-timeoutCh: + _ = resp.Body.Close() + logger.L().Warn(logPrefix+": data interval timeout", + zap.String("request_id", requestID), + zap.Duration("interval", streamInterval), + ) + return nil, usage, acc, fmt.Errorf("stream data interval timeout") + } + } +} + // handleAnthropicStreamingResponse reads Responses SSE events from upstream, // converts each to Anthropic SSE events, and writes them to the client. // When StreamKeepaliveInterval is configured, it uses a goroutine + channel @@ -409,6 +507,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( var usage OpenAIUsage var firstTokenMs *int firstChunk := true + clientDisconnected := false scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -417,6 +516,20 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( } scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + // resultWithUsage builds the final result snapshot. resultWithUsage := func() *OpenAIForwardResult { return &OpenAIForwardResult{ @@ -432,7 +545,6 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( } // processDataLine handles a single "data: ..." SSE line from upstream. - // Returns (clientDisconnected bool). processDataLine := func(payload string) bool { if firstChunk { firstChunk = false @@ -449,53 +561,58 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( return false } - // Extract usage from completion events - if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && - event.Response != nil && event.Response.Usage != nil { - usage = OpenAIUsage{ - InputTokens: event.Response.Usage.InputTokens, - OutputTokens: event.Response.Usage.OutputTokens, - } - if event.Response.Usage.InputTokensDetails != nil { - usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens - } + // 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。 + isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type) + if isTerminalEvent && event.Response != nil && event.Response.Usage != nil { + usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage) } // Convert to Anthropic events events := apicompat.ResponsesEventToAnthropicEvents(&event, state) - for _, evt := range events { - sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) - if err != nil { - logger.L().Warn("openai messages stream: failed to marshal event", - zap.Error(err), - zap.String("request_id", requestID), - ) - continue - } - if _, err := fmt.Fprint(c.Writer, sse); err != nil { - logger.L().Info("openai messages stream: client disconnected", - zap.String("request_id", requestID), - ) - return true + if !clientDisconnected { + for _, evt := range events { + sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) + if err != nil { + logger.L().Warn("openai messages stream: failed to marshal event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + clientDisconnected = true + logger.L().Info("openai messages stream: client disconnected, continuing to drain upstream for billing", + zap.String("request_id", requestID), + ) + break + } } } - if len(events) > 0 { + if len(events) > 0 && !clientDisconnected { c.Writer.Flush() } - return false + return isTerminalEvent } // finalizeStream sends any remaining Anthropic events and returns the result. finalizeStream := func() (*OpenAIForwardResult, error) { - if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 { + if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 && !clientDisconnected { for _, evt := range finalEvents { sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) if err != nil { continue } - fmt.Fprint(c.Writer, sse) //nolint:errcheck + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + clientDisconnected = true + logger.L().Info("openai messages stream: client disconnected during final flush", + zap.String("request_id", requestID), + ) + break + } + } + if !clientDisconnected { + c.Writer.Flush() } - c.Writer.Flush() } return resultWithUsage(), nil } @@ -509,6 +626,9 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( ) } } + missingTerminalErr := func() (*OpenAIForwardResult, error) { + return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event") + } // ── Determine keepalive interval ── keepaliveInterval := time.Duration(0) @@ -517,18 +637,25 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( } // ── No keepalive: fast synchronous path (no goroutine overhead) ── - if keepaliveInterval <= 0 { + if streamInterval <= 0 && keepaliveInterval <= 0 { for scanner.Scan() { line := scanner.Text() - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + if isOpenAICompatDoneSentinelLine(line) { + return missingTerminalErr() + } + payload, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - if processDataLine(line[6:]) { - return resultWithUsage(), nil + if processDataLine(payload) { + return finalizeStream() } } - handleScanErr(scanner.Err()) - return finalizeStream() + if err := scanner.Err(); err != nil { + handleScanErr(err) + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err) + } + return missingTerminalErr() } // ── With keepalive: goroutine + channel + select ── @@ -538,6 +665,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( } events := make(chan scanEvent, 16) done := make(chan struct{}) + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) sendEvent := func(ev scanEvent) bool { select { case events <- ev: @@ -549,6 +678,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( go func() { defer close(events) for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) if !sendEvent(scanEvent{line: scanner.Text()}) { return } @@ -559,8 +689,15 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( }() defer close(done) - keepaliveTicker := time.NewTicker(keepaliveInterval) - defer keepaliveTicker.Stop() + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } lastDataAt := time.Now() for { @@ -568,22 +705,44 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( case ev, ok := <-events: if !ok { // Upstream closed - return finalizeStream() + return missingTerminalErr() } if ev.err != nil { handleScanErr(ev.err) - return finalizeStream() + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err) } lastDataAt = time.Now() line := ev.line - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + if isOpenAICompatDoneSentinelLine(line) { + return missingTerminalErr() + } + payload, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - if processDataLine(line[6:]) { - return resultWithUsage(), nil + if processDataLine(payload) { + return finalizeStream() } - case <-keepaliveTicker.C: + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout") + } + logger.L().Warn("openai messages stream: data interval timeout", + zap.String("request_id", requestID), + zap.String("model", originalModel), + zap.Duration("interval", streamInterval), + ) + return resultWithUsage(), fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if clientDisconnected { + continue + } if time.Since(lastDataAt) < keepaliveInterval { continue } @@ -593,7 +752,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( logger.L().Info("openai messages stream: client disconnected during keepalive", zap.String("request_id", requestID), ) - return resultWithUsage(), nil + clientDisconnected = true + continue } c.Writer.Flush() } @@ -610,3 +770,17 @@ func writeAnthropicError(c *gin.Context, statusCode int, errType, message string }, }) } + +func copyOpenAIUsageFromResponsesUsage(usage *apicompat.ResponsesUsage) OpenAIUsage { + if usage == nil { + return OpenAIUsage{} + } + result := OpenAIUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + } + if usage.InputTokensDetails != nil { + result.CacheReadInputTokens = usage.InputTokensDetails.CachedTokens + } + return result +} diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 47ff4e3b..76fbb794 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -186,6 +186,56 @@ func max(a, b int) int { return b } +func TestOpenAIGatewayServiceRecordUsage_ZeroUsageStillWritesUsageLog(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_zero_usage", + Usage: OpenAIUsage{}, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1000, Quota: 100, Group: &Group{RateMultiplier: 1}}, + User: &User{ID: 2000}, + Account: &Account{ID: 3000, Type: AccountTypeAPIKey}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, billingRepo.calls) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) + require.Equal(t, 0, quotaSvc.quotaCalls) + require.Equal(t, 0, quotaSvc.rateLimitCalls) + + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "resp_zero_usage", usageRepo.lastLog.RequestID) + require.Zero(t, usageRepo.lastLog.InputTokens) + require.Zero(t, usageRepo.lastLog.OutputTokens) + require.Zero(t, usageRepo.lastLog.CacheCreationTokens) + require.Zero(t, usageRepo.lastLog.CacheReadTokens) + require.Zero(t, usageRepo.lastLog.ImageOutputTokens) + require.Zero(t, usageRepo.lastLog.ImageCount) + require.Zero(t, usageRepo.lastLog.InputCost) + require.Zero(t, usageRepo.lastLog.OutputCost) + require.Zero(t, usageRepo.lastLog.TotalCost) + require.Zero(t, usageRepo.lastLog.ActualCost) + + require.NotNil(t, billingRepo.lastCmd) + require.Zero(t, billingRepo.lastCmd.BalanceCost) + require.Zero(t, billingRepo.lastCmd.SubscriptionCost) + require.Zero(t, billingRepo.lastCmd.APIKeyQuotaCost) + require.Zero(t, billingRepo.lastCmd.APIKeyRateLimitCost) + require.Zero(t, billingRepo.lastCmd.AccountQuotaCost) +} + func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) { groupID := int64(11) groupRate := 1.4 diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index ed69730c..b818fa4a 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -2601,7 +2601,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco httpInvalidEncryptedContentRetryTried := false for { // Build upstream request - upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) releaseUpstreamCtx() if err != nil { @@ -2852,7 +2852,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( return nil, err } - upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token) releaseUpstreamCtx() if err != nil { @@ -5041,13 +5041,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID) } - // 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库 - if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 && - result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 && - result.Usage.ImageOutputTokens == 0 && result.ImageCount == 0 { - return nil - } - apiKey := input.APIKey user := input.User account := input.Account diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go index 4badcb1c..3da76525 100644 --- a/backend/internal/service/openai_images.go +++ b/backend/internal/service/openai_images.go @@ -596,7 +596,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey( var usage OpenAIUsage imageCount := parsed.N var firstTokenMs *int - if parsed.Stream { + if parsed.Stream && isEventStreamResponse(resp.Header) { streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime) if err != nil { return nil, err @@ -811,6 +811,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( usage := OpenAIUsage{} imageCount := 0 var firstTokenMs *int + var fallbackBody bytes.Buffer + fallbackBytes := int64(0) + fallbackLimit := resolveUpstreamResponseReadLimit(s.cfg) + seenSSEData := false + fallbackTooLarge := false for { line, err := reader.ReadBytes('\n') @@ -824,11 +829,24 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( } flusher.Flush() - if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok && data != "" && data != "[DONE]" { - dataBytes := []byte(data) - mergeOpenAIUsage(&usage, dataBytes) - if count := extractOpenAIImageCountFromJSONBytes(dataBytes); count > imageCount { - imageCount = count + if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok { + if data != "" && data != "[DONE]" { + seenSSEData = true + fallbackBody.Reset() + fallbackBytes = 0 + dataBytes := []byte(data) + mergeOpenAIUsage(&usage, dataBytes) + if count := extractOpenAIImagesBillableCountFromJSONBytes(dataBytes); count > imageCount { + imageCount = count + } + } + } else if !seenSSEData && !fallbackTooLarge { + fallbackBytes += int64(len(line)) + if fallbackBytes <= fallbackLimit { + _, _ = fallbackBody.Write(line) + } else { + fallbackTooLarge = true + fallbackBody.Reset() } } } @@ -839,9 +857,41 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( return OpenAIUsage{}, 0, firstTokenMs, err } } + if !seenSSEData && fallbackBody.Len() > 0 { + body := bytes.TrimSpace(fallbackBody.Bytes()) + if len(body) > 0 { + mergeOpenAIUsage(&usage, body) + if count := extractOpenAIImagesBillableCountFromJSONBytes(body); count > imageCount { + imageCount = count + } + } + } return usage, imageCount, firstTokenMs, nil } +func extractOpenAIImagesBillableCountFromJSONBytes(body []byte) int { + if count := extractOpenAIImageCountFromJSONBytes(body); count > 0 { + return count + } + if len(body) == 0 || !gjson.ValidBytes(body) { + return 0 + } + if count := int(gjson.GetBytes(body, "usage.images").Int()); count > 0 { + return count + } + if count := int(gjson.GetBytes(body, "tool_usage.image_gen.images").Int()); count > 0 { + return count + } + eventType := strings.TrimSpace(gjson.GetBytes(body, "type").String()) + if eventType == "" || !strings.HasSuffix(eventType, ".completed") { + return 0 + } + if gjson.GetBytes(body, "b64_json").Exists() || gjson.GetBytes(body, "url").Exists() { + return 1 + } + return 0 +} + func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) { if dst == nil { return diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index 47113d4d..681e0e8e 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -446,6 +446,109 @@ func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseU require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) } +func TestOpenAIGatewayServiceForwardImages_APIKeyStreamJSONResponseBillsImage(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "X-Request-Id": []string{"req_img_stream_json"}, + }, + Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"usage":{"input_tokens":12,"output_tokens":21,"output_tokens_details":{"image_tokens":9}},"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)), + }, + }, + } + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + account := &Account{ + ID: 7, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "test-api-key", + "base_url": "https://image-upstream.example/v1", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 12, result.Usage.InputTokens) + require.Equal(t, 21, result.Usage.OutputTokens) + require.Equal(t, 9, result.Usage.ImageOutputTokens) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) +} + +func TestOpenAIGatewayServiceForwardImages_APIKeyStreamRawJSONEventStreamFallbackBillsImage(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_stream_json_mislabeled"}, + }, + Body: io.NopCloser(strings.NewReader(`{"created":1710000009,"usage":{"input_tokens":10,"output_tokens":18,"output_tokens_details":{"image_tokens":8}},"data":[{"b64_json":"ZmluYWw="}]}`)), + }, + }, + } + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + account := &Account{ + ID: 8, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "test-api-key", + "base_url": "https://image-upstream.example/v1", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 10, result.Usage.InputTokens) + require.Equal(t, 18, result.Usage.OutputTokens) + require.Equal(t, 8, result.Usage.ImageOutputTokens) + require.Equal(t, "ZmluYWw=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) +} + +func TestExtractOpenAIImagesBillableCountFromJSONBytes_CompletedEvent(t *testing.T) { + body := []byte(`{"type":"image_generation.completed","b64_json":"ZmluYWw=","usage":{"input_tokens":10,"output_tokens":18}}`) + + require.Equal(t, 1, extractOpenAIImagesBillableCountFromJSONBytes(body)) +} + func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 049ffdd8..87a05b14 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -307,6 +307,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_CompactUsesJSONAndKeepsNonStreami require.Contains(t, rec.Body.String(), `"id":"cmp_123"`) } +func TestOpenAIGatewayService_OAuthPassthrough_UpstreamRequestIgnoresClientCancel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + reqCtx, cancel := context.WithCancel(context.Background()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + cancel() + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`) + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_passthrough_ctx"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1}}}`, + "", + "data: [DONE]", + "", + }, "\n"))), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(reqCtx, c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) +} + func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) { gin.SetMode(gin.TestMode) logSink, restore := captureStructuredLog(t) @@ -405,6 +451,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *te require.Contains(t, string(upstream.lastBody), `"stream":true`) } +func TestOpenAIGatewayService_OAuthLegacy_UpstreamRequestIgnoresClientCancel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + reqCtx, cancel := context.WithCancel(context.Background()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + cancel() + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_legacy_ctx"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`, + "", + "data: [DONE]", + "", + }, "\n"))), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": false, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(reqCtx, c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) +} + func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 5df69aea..4ae6d134 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -394,7 +394,8 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db return nil } - rebateAmount, err := s.affiliateService.AccrueInviteRebate(txCtx, o.UserID, o.Amount) + sourceOrderID := o.ID + rebateAmount, err := s.affiliateService.AccrueInviteRebateForOrder(txCtx, o.UserID, o.Amount, &sourceOrderID) if err != nil { s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ "error": err.Error(), diff --git a/backend/migrations/134_affiliate_ledger_audit_snapshots.sql b/backend/migrations/134_affiliate_ledger_audit_snapshots.sql new file mode 100644 index 00000000..8a87ed1f --- /dev/null +++ b/backend/migrations/134_affiliate_ledger_audit_snapshots.sql @@ -0,0 +1,85 @@ +-- 邀请返利流水补充订单关联和转余额快照。 +-- 这些字段只用于审计展示;历史旧流水无法可靠反推的字段保持 NULL,避免把当前状态误展示为历史状态。 + +ALTER TABLE user_affiliate_ledger + ADD COLUMN IF NOT EXISTS source_order_id BIGINT NULL REFERENCES payment_orders(id) ON DELETE SET NULL; + +ALTER TABLE user_affiliate_ledger + ADD COLUMN IF NOT EXISTS balance_after DECIMAL(20,8) NULL; + +ALTER TABLE user_affiliate_ledger + ADD COLUMN IF NOT EXISTS aff_quota_after DECIMAL(20,8) NULL; + +ALTER TABLE user_affiliate_ledger + ADD COLUMN IF NOT EXISTS aff_frozen_quota_after DECIMAL(20,8) NULL; + +ALTER TABLE user_affiliate_ledger + ADD COLUMN IF NOT EXISTS aff_history_quota_after DECIMAL(20,8) NULL; + +COMMENT ON COLUMN user_affiliate_ledger.source_order_id IS '产生该返利流水的充值订单;转余额或无法可靠回填的历史数据为 NULL'; +COMMENT ON COLUMN user_affiliate_ledger.balance_after IS '邀请返利转余额后的用户余额快照;无法取得时为 NULL'; +COMMENT ON COLUMN user_affiliate_ledger.aff_quota_after IS '邀请返利转余额后的可用返利额度快照;无法取得时为 NULL'; +COMMENT ON COLUMN user_affiliate_ledger.aff_frozen_quota_after IS '邀请返利转余额后的冻结返利额度快照;无法取得时为 NULL'; +COMMENT ON COLUMN user_affiliate_ledger.aff_history_quota_after IS '邀请返利转余额后的历史返利总额快照;无法取得时为 NULL'; + +CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_source_order_id + ON user_affiliate_ledger(source_order_id) + WHERE source_order_id IS NOT NULL; + +CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_rebate_lookup + ON user_affiliate_ledger(action, source_order_id, user_id, source_user_id, created_at) + WHERE action = 'accrue'; + +-- 尽力回填 PR #2169 合并后、该迁移前已经产生的返利流水。 +-- 只有在同一订单只能匹配到一条返利流水时才回填,避免把多笔同额流水错误绑定到订单。 +WITH rebate_audits AS ( + SELECT po.id AS order_id, + po.user_id AS invitee_user_id, + invitee_aff.inviter_id, + rebate_detail.rebate_amount, + pal.created_at AS audit_created_at + FROM payment_audit_logs pal + CROSS JOIN LATERAL ( + SELECT substring( + pal.detail + FROM '"rebateAmount"[[:space:]]*:[[:space:]]*(-?[0-9]+(\.[0-9]+)?)' + )::numeric AS rebate_amount + ) rebate_detail + JOIN payment_orders po ON po.id::text = pal.order_id + JOIN user_affiliates invitee_aff ON invitee_aff.user_id = po.user_id + WHERE pal.action = 'AFFILIATE_REBATE_APPLIED' + AND rebate_detail.rebate_amount IS NOT NULL +), +ranked_matches AS ( + SELECT ual.id AS ledger_id, + ra.order_id, + COUNT(*) OVER (PARTITION BY ra.order_id) AS order_match_count, + COUNT(*) OVER (PARTITION BY ual.id) AS ledger_match_count, + ROW_NUMBER() OVER ( + PARTITION BY ual.id + ORDER BY ABS(EXTRACT(EPOCH FROM (ual.created_at - ra.audit_created_at))), ra.order_id + ) AS ledger_rank + FROM rebate_audits ra + JOIN user_affiliate_ledger ual + ON ual.action = 'accrue' + AND ual.source_order_id IS NULL + AND ual.user_id = ra.inviter_id + AND ual.source_user_id = ra.invitee_user_id + AND ABS(ual.amount - ra.rebate_amount) < 0.00000001 + AND ual.created_at BETWEEN ra.audit_created_at - INTERVAL '10 minutes' + AND ra.audit_created_at + INTERVAL '10 minutes' +) +UPDATE user_affiliate_ledger ual +SET source_order_id = ranked_matches.order_id, + updated_at = NOW() +FROM ranked_matches +WHERE ual.id = ranked_matches.ledger_id + AND ranked_matches.order_match_count = 1 + AND ranked_matches.ledger_match_count = 1 + AND ranked_matches.ledger_rank = 1 + AND NOT EXISTS ( + SELECT 1 + FROM user_affiliate_ledger existing + WHERE existing.source_order_id = ranked_matches.order_id + AND existing.action = 'accrue' + ); diff --git a/backend/migrations/auth_identity_payment_migrations_regression_test.go b/backend/migrations/auth_identity_payment_migrations_regression_test.go index 798ae0fe..99216296 100644 --- a/backend/migrations/auth_identity_payment_migrations_regression_test.go +++ b/backend/migrations/auth_identity_payment_migrations_regression_test.go @@ -127,3 +127,18 @@ func TestMigration124BackfillsLegacyOIDCSecurityFlagsSafely(t *testing.T) { require.Contains(t, sql, "oidc_connect_enabled") require.Contains(t, sql, "'false'") } + +func TestMigration134AddsAffiliateLedgerAuditFieldsWithoutJSONCast(t *testing.T) { + content, err := FS.ReadFile("134_affiliate_ledger_audit_snapshots.sql") + require.NoError(t, err) + + sql := string(content) + require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS source_order_id BIGINT") + require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS balance_after DECIMAL(20,8)") + require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS aff_quota_after DECIMAL(20,8)") + require.Contains(t, sql, "substring(") + require.Contains(t, sql, `"rebateAmount"`) + require.Contains(t, sql, "COUNT(*) OVER (PARTITION BY ra.order_id) AS order_match_count") + require.Contains(t, sql, "COUNT(*) OVER (PARTITION BY ual.id) AS ledger_match_count") + require.NotContains(t, sql, "detail::jsonb") +} diff --git a/frontend/src/api/admin/affiliates.ts b/frontend/src/api/admin/affiliates.ts index 22639bd2..dadb0ae9 100644 --- a/frontend/src/api/admin/affiliates.ts +++ b/frontend/src/api/admin/affiliates.ts @@ -23,6 +23,72 @@ export interface ListAffiliateUsersParams { search?: string } +export interface ListAffiliateRecordsParams { + page?: number + page_size?: number + search?: string + start_at?: string + end_at?: string + sort_by?: string + sort_order?: 'asc' | 'desc' + timezone?: string +} + +export interface AffiliateInviteRecord { + inviter_id: number + inviter_email: string + inviter_username: string + invitee_id: number + invitee_email: string + invitee_username: string + aff_code: string + total_rebate: number + created_at: string +} + +export interface AffiliateRebateRecord { + order_id: number + out_trade_no: string + inviter_id: number + inviter_email: string + inviter_username: string + invitee_id: number + invitee_email: string + invitee_username: string + order_amount: number + pay_amount: number + rebate_amount: number + payment_type: string + order_status: string + created_at: string +} + +export interface AffiliateTransferRecord { + ledger_id: number + user_id: number + user_email: string + username: string + amount: number + balance_after?: number | null + available_quota_after?: number | null + frozen_quota_after?: number | null + history_quota_after?: number | null + snapshot_available: boolean + created_at: string +} + +export interface AffiliateUserOverview { + user_id: number + email: string + username: string + aff_code: string + rebate_rate_percent: number + invited_count: number + rebated_invitee_count: number + available_quota: number + history_quota: number +} + export interface UpdateAffiliateUserRequest { aff_code?: string aff_rebate_rate_percent?: number | null @@ -97,12 +163,68 @@ export async function batchSetRate( return data } +function recordParams(params: ListAffiliateRecordsParams = {}) { + return { + page: params.page ?? 1, + page_size: params.page_size ?? 20, + search: params.search ?? '', + start_at: params.start_at || undefined, + end_at: params.end_at || undefined, + sort_by: params.sort_by || undefined, + sort_order: params.sort_order || undefined, + timezone: params.timezone || undefined, + } +} + +export async function listInviteRecords( + params: ListAffiliateRecordsParams = {}, +): Promise> { + const { data } = await apiClient.get>( + '/admin/affiliates/invites', + { params: recordParams(params) }, + ) + return data +} + +export async function listRebateRecords( + params: ListAffiliateRecordsParams = {}, +): Promise> { + const { data } = await apiClient.get>( + '/admin/affiliates/rebates', + { params: recordParams(params) }, + ) + return data +} + +export async function listTransferRecords( + params: ListAffiliateRecordsParams = {}, +): Promise> { + const { data } = await apiClient.get>( + '/admin/affiliates/transfers', + { params: recordParams(params) }, + ) + return data +} + +export async function getUserOverview( + userId: number, +): Promise { + const { data } = await apiClient.get( + `/admin/affiliates/users/${userId}/overview`, + ) + return data +} + export const affiliatesAPI = { listUsers, lookupUsers, updateUserSettings, clearUserSettings, batchSetRate, + listInviteRecords, + listRebateRecords, + listTransferRecords, + getUserOverview, } export default affiliatesAPI diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts index 3c75a6c4..fabc69bc 100644 --- a/frontend/src/api/admin/users.ts +++ b/frontend/src/api/admin/users.ts @@ -249,7 +249,7 @@ export interface BalanceHistoryResponse extends PaginatedResponse Math.ceil(total.value / pageSize) || 1) const typeOptions = computed(() => [ { value: '', label: t('admin.users.allTypes') }, { value: 'balance', label: t('admin.users.typeBalance') }, + { value: 'affiliate_balance', label: t('admin.users.typeAffiliateBalance') }, { value: 'admin_balance', label: t('admin.users.typeAdminBalance') }, { value: 'concurrency', label: t('admin.users.typeConcurrency') }, { value: 'admin_concurrency', label: t('admin.users.typeAdminConcurrency') }, @@ -235,7 +236,7 @@ const loadHistory = async (page: number) => { const isAdminType = (type: string) => type === 'admin_balance' || type === 'admin_concurrency' // Helper: check if balance type (includes admin_balance) -const isBalanceType = (type: string) => type === 'balance' || type === 'admin_balance' +const isBalanceType = (type: string) => type === 'balance' || type === 'admin_balance' || type === 'affiliate_balance' // Helper: check if subscription type const isSubscriptionType = (type: string) => type === 'subscription' @@ -291,6 +292,8 @@ const getItemTitle = (item: BalanceHistoryItem) => { switch (item.type) { case 'balance': return t('redeem.balanceAddedRedeem') + case 'affiliate_balance': + return t('redeem.balanceAddedAffiliate') case 'admin_balance': return item.value >= 0 ? t('redeem.balanceAddedAdmin') : t('redeem.balanceDeductedAdmin') case 'concurrency': diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue index d8e2794e..4488bf60 100644 --- a/frontend/src/components/layout/AppSidebar.vue +++ b/frontend/src/components/layout/AppSidebar.vue @@ -721,6 +721,19 @@ const adminNavItems = computed((): NavItem[] => { { path: '/admin/proxies', label: t('nav.proxies'), icon: ServerIcon }, { path: '/admin/redeem', label: t('nav.redeemCodes'), icon: TicketIcon, hideInSimpleMode: true }, { path: '/admin/promo-codes', label: t('nav.promoCodes'), icon: GiftIcon, hideInSimpleMode: true }, + { + path: '/admin/affiliates', + label: t('nav.affiliateManagement'), + icon: UsersIcon, + hideInSimpleMode: true, + expandOnly: true, + featureFlag: flagAffiliate, + children: [ + { path: '/admin/affiliates/invites', label: t('nav.affiliateInviteRecords'), icon: UsersIcon }, + { path: '/admin/affiliates/rebates', label: t('nav.affiliateRebateRecords'), icon: OrderIcon }, + { path: '/admin/affiliates/transfers', label: t('nav.affiliateTransferRecords'), icon: CreditCardIcon }, + ], + }, { path: '/admin/orders', label: t('nav.orderManagement'), diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 2da121fb..195d0237 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -347,6 +347,10 @@ export default { usage: 'Usage', redeem: 'Redeem', affiliate: 'Affiliate Rebates', + affiliateManagement: 'Affiliate Rebates', + affiliateInviteRecords: 'Invite Records', + affiliateRebateRecords: 'Rebate Records', + affiliateTransferRecords: 'Transfer Records', profile: 'Profile', users: 'Users', groups: 'Groups', @@ -1046,6 +1050,7 @@ export default { recentActivity: 'Recent Activity', historyWillAppear: 'Your redemption history will appear here', balanceAddedRedeem: 'Balance Added (Redeem)', + balanceAddedAffiliate: 'Balance Added (Affiliate Transfer)', balanceAddedAdmin: 'Balance Added (Admin)', balanceDeductedAdmin: 'Balance Deducted (Admin)', concurrencyAddedRedeem: 'Concurrency Added (Redeem)', @@ -1635,6 +1640,49 @@ export default { } }, + affiliates: { + invitesDescription: 'View site-wide inviter and invitee relationships', + rebatesDescription: 'View recharge orders that generated affiliate rebates', + transfersDescription: 'View affiliate quota transfers into account balance', + errors: { + loadFailed: 'Failed to load affiliate records' + }, + records: { + search: 'Search', + searchPlaceholder: 'Email, username, user ID, or order number', + startAt: 'Start date', + endAt: 'End date', + inviter: 'Inviter', + invitee: 'Invitee', + user: 'User', + affCode: 'Invite Code', + order: 'Order', + totalRebate: 'Total Rebate', + orderAmount: 'Top-up Amount', + payAmount: 'Paid Amount', + rebateAmount: 'Rebate Amount', + paymentType: 'Payment Method', + orderStatus: 'Order Status', + transferAmount: 'Transfer Amount', + balanceAfter: 'Balance After', + availableQuotaAfter: 'Available After', + frozenQuotaAfter: 'Frozen After', + historyQuotaAfter: 'Historical Rebate After', + invitedAt: 'Invited At', + rebatedAt: 'Rebated At', + transferredAt: 'Transferred At' + }, + overview: { + title: 'Affiliate User Overview', + affCode: 'Invite Code', + rebateRate: 'Rebate Rate', + invitedCount: 'Invited Users', + rebatedInviteeCount: 'Rebated Invitees', + availableQuota: 'Available Quota', + historyQuota: 'Historical Rebate' + } + }, + // Users users: { title: 'User Management', @@ -1787,6 +1835,7 @@ export default { noBalanceHistory: 'No records found for this user', allTypes: 'All Types', typeBalance: 'Balance (Redeem)', + typeAffiliateBalance: 'Balance (Affiliate Transfer)', typeAdminBalance: 'Balance (Admin)', typeConcurrency: 'Concurrency (Redeem)', typeAdminConcurrency: 'Concurrency (Admin)', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 7d266522..0f95d652 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -347,6 +347,10 @@ export default { usage: '使用记录', redeem: '兑换', affiliate: '邀请返利', + affiliateManagement: '邀请返利', + affiliateInviteRecords: '邀请记录', + affiliateRebateRecords: '返利记录', + affiliateTransferRecords: '提取记录', profile: '个人资料', users: '用户管理', groups: '分组管理', @@ -1050,6 +1054,7 @@ export default { recentActivity: '最近活动', historyWillAppear: '您的兑换历史将显示在这里', balanceAddedRedeem: '余额充值(兑换)', + balanceAddedAffiliate: '余额充值(返利转入)', balanceAddedAdmin: '余额充值(管理员)', balanceDeductedAdmin: '余额扣除(管理员)', concurrencyAddedRedeem: '并发增加(兑换)', @@ -1656,6 +1661,49 @@ export default { } }, + affiliates: { + invitesDescription: '查看全站邀请关系和被邀请用户累计返利', + rebatesDescription: '查看每一笔产生返利的充值订单', + transfersDescription: '查看返利额度转入账户余额的提取流水', + errors: { + loadFailed: '加载邀请返利记录失败' + }, + records: { + search: '搜索', + searchPlaceholder: '邮箱、用户名、用户 ID、订单号', + startAt: '开始日期', + endAt: '结束日期', + inviter: '邀请人', + invitee: '被邀请人', + user: '用户', + affCode: '邀请码', + order: '订单', + totalRebate: '累计返利', + orderAmount: '充值金额', + payAmount: '支付金额', + rebateAmount: '返利金额', + paymentType: '支付方式', + orderStatus: '订单状态', + transferAmount: '提取金额', + balanceAfter: '提取后余额', + availableQuotaAfter: '提取后可提', + frozenQuotaAfter: '提取后冻结', + historyQuotaAfter: '提取后历史返利', + invitedAt: '邀请时间', + rebatedAt: '返利时间', + transferredAt: '提取时间' + }, + overview: { + title: '用户返利概览', + affCode: '邀请码', + rebateRate: '返利比例', + invitedCount: '邀请人数', + rebatedInviteeCount: '已产生返利人数', + availableQuota: '可提余额', + historyQuota: '历史返利' + } + }, + // Users Management users: { title: '用户管理', @@ -1844,6 +1892,7 @@ export default { noBalanceHistory: '暂无变动记录', allTypes: '全部类型', typeBalance: '余额(兑换码)', + typeAffiliateBalance: '余额(返利转入)', typeAdminBalance: '余额(管理员调整)', typeConcurrency: '并发(兑换码)', typeAdminConcurrency: '并发(管理员调整)', diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 06f6b212..238f6a71 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -517,6 +517,46 @@ const routes: RouteRecordRaw[] = [ descriptionKey: 'admin.usage.description' } }, + { + path: '/admin/affiliates', + redirect: '/admin/affiliates/invites' + }, + { + path: '/admin/affiliates/invites', + name: 'AdminAffiliateInvites', + component: () => import('@/views/admin/affiliates/AdminAffiliateInvitesView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: true, + title: 'Affiliate Invite Records', + titleKey: 'nav.affiliateInviteRecords', + descriptionKey: 'admin.affiliates.invitesDescription' + } + }, + { + path: '/admin/affiliates/rebates', + name: 'AdminAffiliateRebates', + component: () => import('@/views/admin/affiliates/AdminAffiliateRebatesView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: true, + title: 'Affiliate Rebate Records', + titleKey: 'nav.affiliateRebateRecords', + descriptionKey: 'admin.affiliates.rebatesDescription' + } + }, + { + path: '/admin/affiliates/transfers', + name: 'AdminAffiliateTransfers', + component: () => import('@/views/admin/affiliates/AdminAffiliateTransfersView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: true, + title: 'Affiliate Transfer Records', + titleKey: 'nav.affiliateTransferRecords', + descriptionKey: 'admin.affiliates.transfersDescription' + } + }, // ==================== Payment Admin Routes ==================== diff --git a/frontend/src/views/admin/affiliates/AdminAffiliateInvitesView.vue b/frontend/src/views/admin/affiliates/AdminAffiliateInvitesView.vue new file mode 100644 index 00000000..62c96ff8 --- /dev/null +++ b/frontend/src/views/admin/affiliates/AdminAffiliateInvitesView.vue @@ -0,0 +1,7 @@ + + + diff --git a/frontend/src/views/admin/affiliates/AdminAffiliateRebatesView.vue b/frontend/src/views/admin/affiliates/AdminAffiliateRebatesView.vue new file mode 100644 index 00000000..1acd7b1b --- /dev/null +++ b/frontend/src/views/admin/affiliates/AdminAffiliateRebatesView.vue @@ -0,0 +1,7 @@ + + + diff --git a/frontend/src/views/admin/affiliates/AdminAffiliateRecordsTable.vue b/frontend/src/views/admin/affiliates/AdminAffiliateRecordsTable.vue new file mode 100644 index 00000000..789df41a --- /dev/null +++ b/frontend/src/views/admin/affiliates/AdminAffiliateRecordsTable.vue @@ -0,0 +1,407 @@ + + + diff --git a/frontend/src/views/admin/affiliates/AdminAffiliateTransfersView.vue b/frontend/src/views/admin/affiliates/AdminAffiliateTransfersView.vue new file mode 100644 index 00000000..5a56f179 --- /dev/null +++ b/frontend/src/views/admin/affiliates/AdminAffiliateTransfersView.vue @@ -0,0 +1,7 @@ + + +