Merge branch 'main' into fix/openai-ws-passthrough-reasoning-effort
This commit is contained in:
commit
11fe29223d
@ -2,8 +2,11 @@ package admin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -181,3 +184,108 @@ func (h *AffiliateHandler) LookupUsers(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
response.Success(c, result)
|
response.Success(c, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserOverview returns one user's affiliate overview.
|
||||||
|
// GET /api/v1/admin/affiliates/users/:user_id/overview
|
||||||
|
func (h *AffiliateHandler) GetUserOverview(c *gin.Context) {
|
||||||
|
userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64)
|
||||||
|
if err != nil || userID <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid user_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
overview, err := h.affiliateService.AdminGetUserOverview(c.Request.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, overview)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListInviteRecords returns all inviter-invitee relationships.
|
||||||
|
// GET /api/v1/admin/affiliates/invites
|
||||||
|
func (h *AffiliateHandler) ListInviteRecords(c *gin.Context) {
|
||||||
|
page, pageSize := response.ParsePagination(c)
|
||||||
|
filter := parseAffiliateRecordFilter(c, page, pageSize)
|
||||||
|
items, total, err := h.affiliateService.AdminListInviteRecords(c.Request.Context(), filter)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Paginated(c, items, total, filter.Page, filter.PageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListRebateRecords returns all order-level affiliate rebate records.
|
||||||
|
// GET /api/v1/admin/affiliates/rebates
|
||||||
|
func (h *AffiliateHandler) ListRebateRecords(c *gin.Context) {
|
||||||
|
page, pageSize := response.ParsePagination(c)
|
||||||
|
filter := parseAffiliateRecordFilter(c, page, pageSize)
|
||||||
|
items, total, err := h.affiliateService.AdminListRebateRecords(c.Request.Context(), filter)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Paginated(c, items, total, filter.Page, filter.PageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListTransferRecords returns all affiliate quota-to-balance transfer records.
|
||||||
|
// GET /api/v1/admin/affiliates/transfers
|
||||||
|
func (h *AffiliateHandler) ListTransferRecords(c *gin.Context) {
|
||||||
|
page, pageSize := response.ParsePagination(c)
|
||||||
|
filter := parseAffiliateRecordFilter(c, page, pageSize)
|
||||||
|
items, total, err := h.affiliateService.AdminListTransferRecords(c.Request.Context(), filter)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Paginated(c, items, total, filter.Page, filter.PageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAffiliateRecordFilter(c *gin.Context, page, pageSize int) service.AffiliateRecordFilter {
|
||||||
|
filter := service.AffiliateRecordFilter{
|
||||||
|
Search: c.Query("search"),
|
||||||
|
Page: page,
|
||||||
|
PageSize: pageSize,
|
||||||
|
SortBy: c.Query("sort_by"),
|
||||||
|
SortDesc: c.Query("sort_order") != "asc",
|
||||||
|
}
|
||||||
|
if filter.PageSize > 100 {
|
||||||
|
filter.PageSize = 100
|
||||||
|
}
|
||||||
|
userTZ := c.Query("timezone")
|
||||||
|
if t := parseAffiliateRecordStartTime(c.Query("start_at"), userTZ); t != nil {
|
||||||
|
filter.StartAt = t
|
||||||
|
}
|
||||||
|
if t := parseAffiliateRecordEndTime(c.Query("end_at"), userTZ); t != nil {
|
||||||
|
filter.EndAt = t
|
||||||
|
}
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAffiliateRecordStartTime(raw string, userTZ string) *time.Time {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if parsed, err := time.Parse(time.RFC3339, raw); err == nil {
|
||||||
|
return &parsed
|
||||||
|
}
|
||||||
|
if parsed, err := timezone.ParseInUserLocation("2006-01-02", raw, userTZ); err == nil {
|
||||||
|
return &parsed
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAffiliateRecordEndTime(raw string, userTZ string) *time.Time {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if parsed, err := time.Parse(time.RFC3339, raw); err == nil {
|
||||||
|
return &parsed
|
||||||
|
}
|
||||||
|
if parsed, err := timezone.ParseInUserLocation("2006-01-02", raw, userTZ); err == nil {
|
||||||
|
end := parsed.AddDate(0, 0, 1).Add(-time.Nanosecond)
|
||||||
|
return &end
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@ -390,7 +390,7 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
|
|||||||
// GetBalanceHistory handles getting user's balance/concurrency change history
|
// GetBalanceHistory handles getting user's balance/concurrency change history
|
||||||
// GET /api/v1/admin/users/:id/balance-history
|
// GET /api/v1/admin/users/:id/balance-history
|
||||||
// Query params:
|
// Query params:
|
||||||
// - type: filter by record type (balance, admin_balance, concurrency, admin_concurrency, subscription)
|
// - type: filter by record type (balance, affiliate_balance, admin_balance, concurrency, admin_concurrency, subscription)
|
||||||
func (h *UserHandler) GetBalanceHistory(c *gin.Context) {
|
func (h *UserHandler) GetBalanceHistory(c *gin.Context) {
|
||||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -434,6 +434,45 @@ func TestStreamingTextOnly(t *testing.T) {
|
|||||||
assert.Equal(t, "message_stop", events[1].Type)
|
assert.Equal(t, "message_stop", events[1].Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestResponsesEventToAnthropicEvents_ResponseDone(t *testing.T) {
|
||||||
|
state := NewResponsesEventToAnthropicState()
|
||||||
|
state.Model = "gpt-4o"
|
||||||
|
|
||||||
|
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||||
|
Type: "response.done",
|
||||||
|
Response: &ResponsesResponse{
|
||||||
|
Status: "completed",
|
||||||
|
Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4},
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
require.Len(t, events, 2)
|
||||||
|
assert.Equal(t, "message_delta", events[0].Type)
|
||||||
|
assert.Equal(t, "end_turn", events[0].Delta.StopReason)
|
||||||
|
assert.Equal(t, 12, events[0].Usage.InputTokens)
|
||||||
|
assert.Equal(t, 4, events[0].Usage.OutputTokens)
|
||||||
|
assert.Equal(t, "message_stop", events[1].Type)
|
||||||
|
assert.Nil(t, FinalizeResponsesAnthropicStream(state))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete(t *testing.T) {
|
||||||
|
state := NewResponsesEventToAnthropicState()
|
||||||
|
state.Model = "gpt-4o"
|
||||||
|
|
||||||
|
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||||
|
Type: "response.done",
|
||||||
|
Response: &ResponsesResponse{
|
||||||
|
Status: "incomplete",
|
||||||
|
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
|
||||||
|
Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4},
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
require.Len(t, events, 2)
|
||||||
|
assert.Equal(t, "message_delta", events[0].Type)
|
||||||
|
assert.Equal(t, "max_tokens", events[0].Delta.StopReason)
|
||||||
|
assert.Equal(t, "message_stop", events[1].Type)
|
||||||
|
assert.Nil(t, FinalizeResponsesAnthropicStream(state))
|
||||||
|
}
|
||||||
|
|
||||||
func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
|
func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
|
||||||
state := NewResponsesEventToAnthropicState()
|
state := NewResponsesEventToAnthropicState()
|
||||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||||
|
|||||||
@ -720,6 +720,49 @@ func TestResponsesEventToChatChunks_Completed(t *testing.T) {
|
|||||||
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
|
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) {
|
||||||
|
state := NewResponsesEventToChatState()
|
||||||
|
state.Model = "gpt-4o"
|
||||||
|
state.IncludeUsage = true
|
||||||
|
|
||||||
|
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.done",
|
||||||
|
Response: &ResponsesResponse{
|
||||||
|
Status: "completed",
|
||||||
|
Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7},
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 2)
|
||||||
|
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||||
|
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
|
||||||
|
require.NotNil(t, chunks[1].Usage)
|
||||||
|
assert.Equal(t, 13, chunks[1].Usage.PromptTokens)
|
||||||
|
assert.Equal(t, 7, chunks[1].Usage.CompletionTokens)
|
||||||
|
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesEventToChatChunks_ResponseDoneIncomplete(t *testing.T) {
|
||||||
|
state := NewResponsesEventToChatState()
|
||||||
|
state.Model = "gpt-4o"
|
||||||
|
state.IncludeUsage = true
|
||||||
|
|
||||||
|
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.done",
|
||||||
|
Response: &ResponsesResponse{
|
||||||
|
Status: "incomplete",
|
||||||
|
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
|
||||||
|
Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7},
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 2)
|
||||||
|
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||||
|
assert.Equal(t, "length", *chunks[0].Choices[0].FinishReason)
|
||||||
|
require.NotNil(t, chunks[1].Usage)
|
||||||
|
assert.Equal(t, 13, chunks[1].Usage.PromptTokens)
|
||||||
|
assert.Equal(t, 7, chunks[1].Usage.CompletionTokens)
|
||||||
|
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||||
|
}
|
||||||
|
|
||||||
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
|
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
|
||||||
state := NewResponsesEventToChatState()
|
state := NewResponsesEventToChatState()
|
||||||
state.Model = "gpt-4o"
|
state.Model = "gpt-4o"
|
||||||
|
|||||||
@ -212,7 +212,9 @@ func ResponsesEventToAnthropicEvents(
|
|||||||
return resToAnthHandleReasoningDelta(evt, state)
|
return resToAnthHandleReasoningDelta(evt, state)
|
||||||
case "response.reasoning_summary_text.done":
|
case "response.reasoning_summary_text.done":
|
||||||
return resToAnthHandleBlockDone(state)
|
return resToAnthHandleBlockDone(state)
|
||||||
case "response.completed", "response.incomplete", "response.failed":
|
// response.done 是 Realtime/WS 与项目透传路径使用的终止别名;
|
||||||
|
// 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。
|
||||||
|
case "response.completed", "response.done", "response.incomplete", "response.failed":
|
||||||
return resToAnthHandleCompleted(evt, state)
|
return resToAnthHandleCompleted(evt, state)
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -160,7 +160,9 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
|
|||||||
return resToChatHandleReasoningDelta(evt, state)
|
return resToChatHandleReasoningDelta(evt, state)
|
||||||
case "response.reasoning_summary_text.done":
|
case "response.reasoning_summary_text.done":
|
||||||
return nil
|
return nil
|
||||||
case "response.completed", "response.incomplete", "response.failed":
|
// response.done 是 Realtime/WS 与项目透传路径使用的终止别名;
|
||||||
|
// 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。
|
||||||
|
case "response.completed", "response.done", "response.incomplete", "response.failed":
|
||||||
return resToChatHandleCompleted(evt, state)
|
return resToChatHandleCompleted(evt, state)
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -314,7 +314,7 @@ type ResponsesOutputTokensDetails struct {
|
|||||||
type ResponsesStreamEvent struct {
|
type ResponsesStreamEvent struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
// response.created / response.completed / response.failed / response.incomplete
|
// response.created / response.completed / response.done / response.failed / response.incomplete
|
||||||
Response *ResponsesResponse `json:"response,omitempty"`
|
Response *ResponsesResponse `json:"response,omitempty"`
|
||||||
|
|
||||||
// response.output_item.added / response.output_item.done
|
// response.output_item.added / response.output_item.done
|
||||||
|
|||||||
@ -22,6 +22,34 @@ const (
|
|||||||
|
|
||||||
var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789")
|
var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789")
|
||||||
|
|
||||||
|
const affiliateUserOverviewSQL = `
|
||||||
|
SELECT ua.user_id,
|
||||||
|
COALESCE(u.email, ''),
|
||||||
|
COALESCE(u.username, ''),
|
||||||
|
ua.aff_code,
|
||||||
|
COALESCE(ua.aff_rebate_rate_percent, 0)::double precision,
|
||||||
|
(ua.aff_rebate_rate_percent IS NOT NULL) AS has_custom_rate,
|
||||||
|
ua.aff_count,
|
||||||
|
COALESCE(rebated.rebated_invitee_count, 0),
|
||||||
|
(ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0))::double precision,
|
||||||
|
ua.aff_history_quota::double precision
|
||||||
|
FROM user_affiliates ua
|
||||||
|
JOIN users u ON u.id = ua.user_id
|
||||||
|
LEFT JOIN (
|
||||||
|
SELECT user_id, COUNT(DISTINCT source_user_id)::integer AS rebated_invitee_count
|
||||||
|
FROM user_affiliate_ledger
|
||||||
|
WHERE action = 'accrue' AND source_user_id IS NOT NULL
|
||||||
|
GROUP BY user_id
|
||||||
|
) rebated ON rebated.user_id = ua.user_id
|
||||||
|
LEFT JOIN (
|
||||||
|
SELECT user_id, COALESCE(SUM(amount), 0)::double precision AS matured_frozen_quota
|
||||||
|
FROM user_affiliate_ledger
|
||||||
|
WHERE action = 'accrue' AND frozen_until IS NOT NULL AND frozen_until <= NOW()
|
||||||
|
GROUP BY user_id
|
||||||
|
) matured ON matured.user_id = ua.user_id
|
||||||
|
WHERE ua.user_id = $1
|
||||||
|
LIMIT 1`
|
||||||
|
|
||||||
type affiliateQueryExecer interface {
|
type affiliateQueryExecer interface {
|
||||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||||
@ -86,7 +114,7 @@ func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID
|
|||||||
return bound, nil
|
return bound, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) {
|
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error) {
|
||||||
if amount <= 0 {
|
if amount <= 0 {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
@ -112,15 +140,15 @@ func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, invite
|
|||||||
|
|
||||||
if freezeHours > 0 {
|
if freezeHours > 0 {
|
||||||
if _, err = txClient.ExecContext(txCtx, `
|
if _, err = txClient.ExecContext(txCtx, `
|
||||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, frozen_until, created_at, updated_at)
|
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, source_order_id, frozen_until, created_at, updated_at)
|
||||||
VALUES ($1, 'accrue', $2, $3, NOW() + make_interval(hours => $4), NOW(), NOW())`,
|
VALUES ($1, 'accrue', $2, $3, $4, NOW() + make_interval(hours => $5), NOW(), NOW())`,
|
||||||
inviterID, amount, inviteeUserID, freezeHours); err != nil {
|
inviterID, amount, inviteeUserID, nullableInt64Arg(sourceOrderID), freezeHours); err != nil {
|
||||||
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
|
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if _, err = txClient.ExecContext(txCtx, `
|
if _, err = txClient.ExecContext(txCtx, `
|
||||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
|
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, source_order_id, created_at, updated_at)
|
||||||
VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil {
|
VALUES ($1, 'accrue', $2, $3, $4, NOW(), NOW())`, inviterID, amount, inviteeUserID, nullableInt64Arg(sourceOrderID)); err != nil {
|
||||||
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
|
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -275,9 +303,32 @@ FROM cleared`, userID)
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
snapshot, err := queryAffiliateTransferSnapshot(txCtx, txClient, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if _, err = txClient.ExecContext(txCtx, `
|
if _, err = txClient.ExecContext(txCtx, `
|
||||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
|
INSERT INTO user_affiliate_ledger (
|
||||||
VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred); err != nil {
|
user_id,
|
||||||
|
action,
|
||||||
|
amount,
|
||||||
|
source_user_id,
|
||||||
|
balance_after,
|
||||||
|
aff_quota_after,
|
||||||
|
aff_frozen_quota_after,
|
||||||
|
aff_history_quota_after,
|
||||||
|
created_at,
|
||||||
|
updated_at
|
||||||
|
)
|
||||||
|
VALUES ($1, 'transfer', $2, NULL, $3, $4, $5, $6, NOW(), NOW())`,
|
||||||
|
userID,
|
||||||
|
transferred,
|
||||||
|
snapshot.BalanceAfter,
|
||||||
|
snapshot.AvailableQuotaAfter,
|
||||||
|
snapshot.FrozenQuotaAfter,
|
||||||
|
snapshot.HistoryQuotaAfter,
|
||||||
|
); err != nil {
|
||||||
return fmt.Errorf("insert affiliate transfer ledger: %w", err)
|
return fmt.Errorf("insert affiliate transfer ledger: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -332,6 +383,349 @@ LIMIT $2`, inviterID, limit)
|
|||||||
return invitees, nil
|
return invitees, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *affiliateRepository) ListAffiliateInviteRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateInviteRecord, int64, error) {
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
where, args := buildAffiliateRecordWhere(filter, "ua.created_at", []string{
|
||||||
|
"inviter.email", "inviter.username", "invitee.email", "invitee.username",
|
||||||
|
"ua.inviter_id::text", "ua.user_id::text", "inviter_aff.aff_code",
|
||||||
|
})
|
||||||
|
|
||||||
|
total, err := queryAffiliateRecordCount(ctx, client, `
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM user_affiliates ua
|
||||||
|
JOIN users invitee ON invitee.id = ua.user_id
|
||||||
|
JOIN users inviter ON inviter.id = ua.inviter_id
|
||||||
|
JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id
|
||||||
|
`+where, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
|
||||||
|
"inviter": "inviter.email",
|
||||||
|
"invitee": "invitee.email",
|
||||||
|
"aff_code": "inviter_aff.aff_code",
|
||||||
|
"total_rebate": "total_rebate",
|
||||||
|
"created_at": "ua.created_at",
|
||||||
|
}, "ua.created_at")
|
||||||
|
args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
|
||||||
|
rows, err := client.QueryContext(ctx, `
|
||||||
|
SELECT ua.inviter_id,
|
||||||
|
COALESCE(inviter.email, ''),
|
||||||
|
COALESCE(inviter.username, ''),
|
||||||
|
ua.user_id,
|
||||||
|
COALESCE(invitee.email, ''),
|
||||||
|
COALESCE(invitee.username, ''),
|
||||||
|
COALESCE(inviter_aff.aff_code, ''),
|
||||||
|
COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate,
|
||||||
|
ua.created_at
|
||||||
|
FROM user_affiliates ua
|
||||||
|
JOIN users invitee ON invitee.id = ua.user_id
|
||||||
|
JOIN users inviter ON inviter.id = ua.inviter_id
|
||||||
|
JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id
|
||||||
|
LEFT JOIN user_affiliate_ledger ual
|
||||||
|
ON ual.user_id = ua.inviter_id
|
||||||
|
AND ual.source_user_id = ua.user_id
|
||||||
|
AND ual.action = 'accrue'
|
||||||
|
`+where+`
|
||||||
|
GROUP BY ua.inviter_id, inviter.email, inviter.username, ua.user_id, invitee.email, invitee.username, inviter_aff.aff_code, ua.created_at
|
||||||
|
`+orderBy+`
|
||||||
|
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
items := make([]service.AffiliateInviteRecord, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var item service.AffiliateInviteRecord
|
||||||
|
if err := rows.Scan(
|
||||||
|
&item.InviterID,
|
||||||
|
&item.InviterEmail,
|
||||||
|
&item.InviterUsername,
|
||||||
|
&item.InviteeID,
|
||||||
|
&item.InviteeEmail,
|
||||||
|
&item.InviteeUsername,
|
||||||
|
&item.AffCode,
|
||||||
|
&item.TotalRebate,
|
||||||
|
&item.CreatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
items = append(items, item)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return items, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *affiliateRepository) ListAffiliateRebateRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateRebateRecord, int64, error) {
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{
|
||||||
|
"inviter.email", "inviter.username", "invitee.email", "invitee.username",
|
||||||
|
"po.id::text", "po.out_trade_no", "po.payment_type", "po.status",
|
||||||
|
})
|
||||||
|
baseJoin := `
|
||||||
|
FROM user_affiliate_ledger ual
|
||||||
|
JOIN payment_orders po ON po.id = ual.source_order_id
|
||||||
|
JOIN users invitee ON invitee.id = ual.source_user_id
|
||||||
|
JOIN users inviter ON inviter.id = ual.user_id
|
||||||
|
WHERE ual.action = 'accrue'
|
||||||
|
AND ual.source_order_id IS NOT NULL`
|
||||||
|
if where != "" {
|
||||||
|
where = strings.Replace(where, "WHERE ", " AND ", 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
|
||||||
|
"order": "po.id",
|
||||||
|
"inviter": "inviter.email",
|
||||||
|
"invitee": "invitee.email",
|
||||||
|
"order_amount": "po.amount",
|
||||||
|
"pay_amount": "po.pay_amount",
|
||||||
|
"rebate_amount": "ual.amount",
|
||||||
|
"payment_type": "po.payment_type",
|
||||||
|
"order_status": "po.status",
|
||||||
|
"created_at": "ual.created_at",
|
||||||
|
}, "ual.created_at")
|
||||||
|
args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
|
||||||
|
rows, err := client.QueryContext(ctx, `
|
||||||
|
SELECT po.id,
|
||||||
|
po.out_trade_no,
|
||||||
|
ual.user_id,
|
||||||
|
COALESCE(inviter.email, ''),
|
||||||
|
COALESCE(inviter.username, ''),
|
||||||
|
ual.source_user_id,
|
||||||
|
COALESCE(invitee.email, ''),
|
||||||
|
COALESCE(invitee.username, ''),
|
||||||
|
po.amount::double precision,
|
||||||
|
po.pay_amount::double precision,
|
||||||
|
ual.amount::double precision,
|
||||||
|
po.payment_type,
|
||||||
|
po.status,
|
||||||
|
ual.created_at
|
||||||
|
`+baseJoin+where+`
|
||||||
|
`+orderBy+`
|
||||||
|
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
items := make([]service.AffiliateRebateRecord, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var item service.AffiliateRebateRecord
|
||||||
|
if err := rows.Scan(
|
||||||
|
&item.OrderID,
|
||||||
|
&item.OutTradeNo,
|
||||||
|
&item.InviterID,
|
||||||
|
&item.InviterEmail,
|
||||||
|
&item.InviterUsername,
|
||||||
|
&item.InviteeID,
|
||||||
|
&item.InviteeEmail,
|
||||||
|
&item.InviteeUsername,
|
||||||
|
&item.OrderAmount,
|
||||||
|
&item.PayAmount,
|
||||||
|
&item.RebateAmount,
|
||||||
|
&item.PaymentType,
|
||||||
|
&item.OrderStatus,
|
||||||
|
&item.CreatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
items = append(items, item)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return items, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *affiliateRepository) ListAffiliateTransferRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateTransferRecord, int64, error) {
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{
|
||||||
|
"u.email", "u.username", "u.id::text",
|
||||||
|
})
|
||||||
|
baseJoin := `
|
||||||
|
FROM user_affiliate_ledger ual
|
||||||
|
JOIN users u ON u.id = ual.user_id
|
||||||
|
WHERE ual.action = 'transfer'`
|
||||||
|
if where != "" {
|
||||||
|
where = strings.Replace(where, "WHERE ", " AND ", 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
|
||||||
|
"user": "u.email",
|
||||||
|
"amount": "ual.amount",
|
||||||
|
"balance_after": "ual.balance_after",
|
||||||
|
"available_quota_after": "ual.aff_quota_after",
|
||||||
|
"frozen_quota_after": "ual.aff_frozen_quota_after",
|
||||||
|
"history_quota_after": "ual.aff_history_quota_after",
|
||||||
|
"created_at": "ual.created_at",
|
||||||
|
}, "ual.created_at")
|
||||||
|
args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
|
||||||
|
rows, err := client.QueryContext(ctx, `
|
||||||
|
SELECT ual.id,
|
||||||
|
ual.user_id,
|
||||||
|
COALESCE(u.email, ''),
|
||||||
|
COALESCE(u.username, ''),
|
||||||
|
ual.amount::double precision,
|
||||||
|
ual.balance_after::double precision,
|
||||||
|
ual.aff_quota_after::double precision,
|
||||||
|
ual.aff_frozen_quota_after::double precision,
|
||||||
|
ual.aff_history_quota_after::double precision,
|
||||||
|
ual.created_at
|
||||||
|
`+baseJoin+where+`
|
||||||
|
`+orderBy+`
|
||||||
|
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
items := make([]service.AffiliateTransferRecord, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var item service.AffiliateTransferRecord
|
||||||
|
var balanceAfter sql.NullFloat64
|
||||||
|
var availableQuotaAfter sql.NullFloat64
|
||||||
|
var frozenQuotaAfter sql.NullFloat64
|
||||||
|
var historyQuotaAfter sql.NullFloat64
|
||||||
|
if err := rows.Scan(
|
||||||
|
&item.LedgerID,
|
||||||
|
&item.UserID,
|
||||||
|
&item.UserEmail,
|
||||||
|
&item.Username,
|
||||||
|
&item.Amount,
|
||||||
|
&balanceAfter,
|
||||||
|
&availableQuotaAfter,
|
||||||
|
&frozenQuotaAfter,
|
||||||
|
&historyQuotaAfter,
|
||||||
|
&item.CreatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
item.BalanceAfter = nullableFloat64Ptr(balanceAfter)
|
||||||
|
item.AvailableQuotaAfter = nullableFloat64Ptr(availableQuotaAfter)
|
||||||
|
item.FrozenQuotaAfter = nullableFloat64Ptr(frozenQuotaAfter)
|
||||||
|
item.HistoryQuotaAfter = nullableFloat64Ptr(historyQuotaAfter)
|
||||||
|
item.SnapshotAvailable = balanceAfter.Valid &&
|
||||||
|
availableQuotaAfter.Valid &&
|
||||||
|
frozenQuotaAfter.Valid &&
|
||||||
|
historyQuotaAfter.Valid
|
||||||
|
items = append(items, item)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return items, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *affiliateRepository) GetAffiliateUserOverview(ctx context.Context, userID int64) (*service.AffiliateUserOverview, error) {
|
||||||
|
if userID <= 0 {
|
||||||
|
return nil, service.ErrUserNotFound
|
||||||
|
}
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
rows, err := client.QueryContext(ctx, affiliateUserOverviewSQL, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return nil, service.ErrUserNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
var overview service.AffiliateUserOverview
|
||||||
|
var customRate float64
|
||||||
|
var hasCustomRate bool
|
||||||
|
if err := rows.Scan(
|
||||||
|
&overview.UserID,
|
||||||
|
&overview.Email,
|
||||||
|
&overview.Username,
|
||||||
|
&overview.AffCode,
|
||||||
|
&customRate,
|
||||||
|
&hasCustomRate,
|
||||||
|
&overview.InvitedCount,
|
||||||
|
&overview.RebatedInviteeCount,
|
||||||
|
&overview.AvailableQuota,
|
||||||
|
&overview.HistoryQuota,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if hasCustomRate {
|
||||||
|
overview.RebateRatePercent = customRate
|
||||||
|
overview.RebateRateCustom = true
|
||||||
|
}
|
||||||
|
return &overview, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAffiliateRecordWhere(filter service.AffiliateRecordFilter, timeColumn string, searchColumns []string) (string, []any) {
|
||||||
|
clauses := make([]string, 0, 3)
|
||||||
|
args := make([]any, 0, 3)
|
||||||
|
if filter.StartAt != nil {
|
||||||
|
args = append(args, *filter.StartAt)
|
||||||
|
clauses = append(clauses, fmt.Sprintf("%s >= $%d", timeColumn, len(args)))
|
||||||
|
}
|
||||||
|
if filter.EndAt != nil {
|
||||||
|
args = append(args, *filter.EndAt)
|
||||||
|
clauses = append(clauses, fmt.Sprintf("%s <= $%d", timeColumn, len(args)))
|
||||||
|
}
|
||||||
|
search := strings.TrimSpace(filter.Search)
|
||||||
|
if search != "" && len(searchColumns) > 0 {
|
||||||
|
args = append(args, "%"+strings.ToLower(search)+"%")
|
||||||
|
parts := make([]string, 0, len(searchColumns))
|
||||||
|
for _, col := range searchColumns {
|
||||||
|
parts = append(parts, fmt.Sprintf("LOWER(%s) LIKE $%d", col, len(args)))
|
||||||
|
}
|
||||||
|
clauses = append(clauses, "("+strings.Join(parts, " OR ")+")")
|
||||||
|
}
|
||||||
|
if len(clauses) == 0 {
|
||||||
|
return "", args
|
||||||
|
}
|
||||||
|
return "WHERE " + strings.Join(clauses, " AND "), args
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAffiliateRecordOrderBy(filter service.AffiliateRecordFilter, sortColumns map[string]string, fallbackColumn string) string {
|
||||||
|
column := sortColumns[filter.SortBy]
|
||||||
|
if column == "" {
|
||||||
|
column = fallbackColumn
|
||||||
|
}
|
||||||
|
direction := "DESC"
|
||||||
|
if !filter.SortDesc {
|
||||||
|
direction = "ASC"
|
||||||
|
}
|
||||||
|
return "ORDER BY " + column + " " + direction + " NULLS LAST"
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryAffiliateRecordCount(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (int64, error) {
|
||||||
|
rows, err := client.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
if !rows.Next() {
|
||||||
|
return 0, rows.Err()
|
||||||
|
}
|
||||||
|
var total int64
|
||||||
|
if err := rows.Scan(&total); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return total, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
|
func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
|
||||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||||
return fn(ctx, tx.Client())
|
return fn(ctx, tx.Client())
|
||||||
@ -516,6 +910,54 @@ func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID i
|
|||||||
return balance, nil
|
return balance, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type affiliateTransferSnapshot struct {
|
||||||
|
BalanceAfter float64
|
||||||
|
AvailableQuotaAfter float64
|
||||||
|
FrozenQuotaAfter float64
|
||||||
|
HistoryQuotaAfter float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryAffiliateTransferSnapshot(ctx context.Context, client affiliateQueryExecer, userID int64) (*affiliateTransferSnapshot, error) {
|
||||||
|
rows, err := client.QueryContext(ctx, `
|
||||||
|
SELECT u.balance::double precision,
|
||||||
|
ua.aff_quota::double precision,
|
||||||
|
ua.aff_frozen_quota::double precision,
|
||||||
|
ua.aff_history_quota::double precision
|
||||||
|
FROM users u
|
||||||
|
JOIN user_affiliates ua ON ua.user_id = u.id
|
||||||
|
WHERE u.id = $1
|
||||||
|
LIMIT 1`, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("query affiliate transfer snapshot: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return nil, service.ErrUserNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
var snapshot affiliateTransferSnapshot
|
||||||
|
if err := rows.Scan(
|
||||||
|
&snapshot.BalanceAfter,
|
||||||
|
&snapshot.AvailableQuotaAfter,
|
||||||
|
&snapshot.FrozenQuotaAfter,
|
||||||
|
&snapshot.HistoryQuotaAfter,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &snapshot, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func nullableFloat64Ptr(v sql.NullFloat64) *float64 {
|
||||||
|
if !v.Valid {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &v.Float64
|
||||||
|
}
|
||||||
|
|
||||||
func generateAffiliateCode() (string, error) {
|
func generateAffiliateCode() (string, error) {
|
||||||
buf := make([]byte, affiliateCodeLength)
|
buf := make([]byte, affiliateCodeLength)
|
||||||
if _, err := rand.Read(buf); err != nil {
|
if _, err := rand.Read(buf); err != nil {
|
||||||
@ -674,6 +1116,13 @@ func nullableArg(v *float64) any {
|
|||||||
return *v
|
return *v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func nullableInt64Arg(v *int64) any {
|
||||||
|
if v == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return *v
|
||||||
|
}
|
||||||
|
|
||||||
// ListUsersWithCustomSettings 列出有专属配置(自定义码或专属比例)的用户。
|
// ListUsersWithCustomSettings 列出有专属配置(自定义码或专属比例)的用户。
|
||||||
//
|
//
|
||||||
// 单一查询同时处理"无搜索"与"按邮箱/用户名模糊搜索":
|
// 单一查询同时处理"无搜索"与"按邮箱/用户名模糊搜索":
|
||||||
|
|||||||
@ -78,6 +78,26 @@ VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34)
|
|||||||
ledgerCount := querySingleInt(t, txCtx, client,
|
ledgerCount := querySingleInt(t, txCtx, client,
|
||||||
"SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID)
|
"SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID)
|
||||||
require.Equal(t, 1, ledgerCount)
|
require.Equal(t, 1, ledgerCount)
|
||||||
|
|
||||||
|
rows, err := client.QueryContext(txCtx, `
|
||||||
|
SELECT amount::double precision,
|
||||||
|
balance_after::double precision,
|
||||||
|
aff_quota_after::double precision,
|
||||||
|
aff_frozen_quota_after::double precision,
|
||||||
|
aff_history_quota_after::double precision
|
||||||
|
FROM user_affiliate_ledger
|
||||||
|
WHERE user_id = $1 AND action = 'transfer'
|
||||||
|
LIMIT 1`, u.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
require.True(t, rows.Next(), "expected transfer ledger")
|
||||||
|
var amount, balanceAfter, quotaAfter, frozenAfter, historyAfter float64
|
||||||
|
require.NoError(t, rows.Scan(&amount, &balanceAfter, "aAfter, &frozenAfter, &historyAfter))
|
||||||
|
require.InDelta(t, 12.34, amount, 1e-9)
|
||||||
|
require.InDelta(t, 17.84, balanceAfter, 1e-9)
|
||||||
|
require.InDelta(t, 0.0, quotaAfter, 1e-9)
|
||||||
|
require.InDelta(t, 0.0, frozenAfter, 1e-9)
|
||||||
|
require.InDelta(t, 12.34, historyAfter, 1e-9)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the
|
// TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the
|
||||||
@ -125,7 +145,7 @@ func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.True(t, bound, "invitee must bind to inviter")
|
require.True(t, bound, "invitee must bind to inviter")
|
||||||
|
|
||||||
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0)
|
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.True(t, applied, "AccrueQuota must report applied=true")
|
require.True(t, applied, "AccrueQuota must report applied=true")
|
||||||
|
|
||||||
|
|||||||
28
backend/internal/repository/affiliate_repo_test.go
Normal file
28
backend/internal/repository/affiliate_repo_test.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAffiliateUserOverviewSQLIncludesMaturedFrozenQuota(t *testing.T) {
|
||||||
|
query := strings.Join(strings.Fields(affiliateUserOverviewSQL), " ")
|
||||||
|
|
||||||
|
require.Contains(t, query, "ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0)")
|
||||||
|
require.Contains(t, query, "frozen_until <= NOW()")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAffiliateRecordQueriesUseLedgerAuditFields(t *testing.T) {
|
||||||
|
source, err := os.ReadFile("affiliate_repo.go")
|
||||||
|
require.NoError(t, err)
|
||||||
|
content := string(source)
|
||||||
|
|
||||||
|
require.Contains(t, content, "JOIN payment_orders po ON po.id = ual.source_order_id")
|
||||||
|
require.Contains(t, content, "ual.amount::double precision")
|
||||||
|
require.Contains(t, content, "ual.balance_after::double precision")
|
||||||
|
require.NotContains(t, content, "parseAffiliateRebateAmount")
|
||||||
|
require.NotContains(t, content, `"current_balance": "u.balance"`)
|
||||||
|
}
|
||||||
@ -602,11 +602,16 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||||
affiliates := admin.Group("/affiliates")
|
affiliates := admin.Group("/affiliates")
|
||||||
{
|
{
|
||||||
|
affiliates.GET("/invites", h.Admin.Affiliate.ListInviteRecords)
|
||||||
|
affiliates.GET("/rebates", h.Admin.Affiliate.ListRebateRecords)
|
||||||
|
affiliates.GET("/transfers", h.Admin.Affiliate.ListTransferRecords)
|
||||||
|
|
||||||
users := affiliates.Group("/users")
|
users := affiliates.Group("/users")
|
||||||
{
|
{
|
||||||
users.GET("", h.Admin.Affiliate.ListUsers)
|
users.GET("", h.Admin.Affiliate.ListUsers)
|
||||||
users.GET("/lookup", h.Admin.Affiliate.LookupUsers)
|
users.GET("/lookup", h.Admin.Affiliate.LookupUsers)
|
||||||
users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate)
|
users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate)
|
||||||
|
users.GET("/:user_id/overview", h.Admin.Affiliate.GetUserOverview)
|
||||||
users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings)
|
users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings)
|
||||||
users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings)
|
users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings)
|
||||||
}
|
}
|
||||||
|
|||||||
86
backend/internal/service/admin_balance_history_test.go
Normal file
86
backend/internal/service/admin_balance_history_test.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMergeBalanceHistoryCodesIncludesAffiliateTransfersByDefault(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC)
|
||||||
|
older := now.Add(-2 * time.Hour)
|
||||||
|
newer := now.Add(time.Hour)
|
||||||
|
|
||||||
|
usedBy := int64(10)
|
||||||
|
redeemCodes := []RedeemCode{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Type: RedeemTypeBalance,
|
||||||
|
Value: 8,
|
||||||
|
Status: StatusUsed,
|
||||||
|
UsedBy: &usedBy,
|
||||||
|
UsedAt: &now,
|
||||||
|
CreatedAt: now,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 2,
|
||||||
|
Type: RedeemTypeConcurrency,
|
||||||
|
Value: 1,
|
||||||
|
Status: StatusUsed,
|
||||||
|
UsedBy: &usedBy,
|
||||||
|
UsedAt: &older,
|
||||||
|
CreatedAt: older,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
affiliateCodes := []RedeemCode{
|
||||||
|
{
|
||||||
|
ID: -20,
|
||||||
|
Type: RedeemTypeAffiliateBalance,
|
||||||
|
Value: 3.5,
|
||||||
|
Status: StatusUsed,
|
||||||
|
UsedBy: &usedBy,
|
||||||
|
UsedAt: &newer,
|
||||||
|
CreatedAt: newer,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
got := mergeBalanceHistoryCodes(redeemCodes, affiliateCodes, pagination.PaginationParams{
|
||||||
|
Page: 1,
|
||||||
|
PageSize: 2,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Len(t, got, 2)
|
||||||
|
require.Equal(t, RedeemTypeAffiliateBalance, got[0].Type)
|
||||||
|
require.Equal(t, RedeemTypeBalance, got[1].Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeBalanceHistoryCodesPaginatesAfterCombiningSources(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
base := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC)
|
||||||
|
usedBy := int64(10)
|
||||||
|
at := func(hours int) *time.Time {
|
||||||
|
v := base.Add(time.Duration(hours) * time.Hour)
|
||||||
|
return &v
|
||||||
|
}
|
||||||
|
|
||||||
|
got := mergeBalanceHistoryCodes(
|
||||||
|
[]RedeemCode{
|
||||||
|
{ID: 1, Type: RedeemTypeBalance, UsedBy: &usedBy, UsedAt: at(4), CreatedAt: *at(4)},
|
||||||
|
{ID: 2, Type: RedeemTypeConcurrency, UsedBy: &usedBy, UsedAt: at(2), CreatedAt: *at(2)},
|
||||||
|
},
|
||||||
|
[]RedeemCode{
|
||||||
|
{ID: -3, Type: RedeemTypeAffiliateBalance, UsedBy: &usedBy, UsedAt: at(3), CreatedAt: *at(3)},
|
||||||
|
{ID: -4, Type: RedeemTypeAffiliateBalance, UsedBy: &usedBy, UsedAt: at(1), CreatedAt: *at(1)},
|
||||||
|
},
|
||||||
|
pagination.PaginationParams{Page: 2, PageSize: 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
require.Len(t, got, 2)
|
||||||
|
require.Equal(t, RedeemTypeConcurrency, got[0].Type)
|
||||||
|
require.Equal(t, int64(-4), got[1].ID)
|
||||||
|
}
|
||||||
@ -2,6 +2,7 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -973,16 +974,213 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
|
|||||||
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
|
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
|
||||||
func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) {
|
func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) {
|
||||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
|
if codeType == RedeemTypeAffiliateBalance {
|
||||||
|
codes, total, err := s.listAffiliateBalanceHistory(ctx, userID, params)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, err
|
||||||
|
}
|
||||||
|
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, err
|
||||||
|
}
|
||||||
|
return codes, total, totalRecharged, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if codeType == "" {
|
||||||
|
return s.getAllUserBalanceHistory(ctx, userID, params)
|
||||||
|
}
|
||||||
|
|
||||||
codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType)
|
codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, 0, err
|
return nil, 0, 0, err
|
||||||
}
|
}
|
||||||
|
total := result.Total
|
||||||
// Aggregate total recharged amount (only once, regardless of type filter)
|
// Aggregate total recharged amount (only once, regardless of type filter)
|
||||||
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
|
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, 0, err
|
return nil, 0, 0, err
|
||||||
}
|
}
|
||||||
return codes, result.Total, totalRecharged, nil
|
return codes, total, totalRecharged, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) getAllUserBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, float64, error) {
|
||||||
|
needed := params.Offset() + params.Limit()
|
||||||
|
if needed < params.Limit() {
|
||||||
|
needed = params.Limit()
|
||||||
|
}
|
||||||
|
|
||||||
|
redeemCodes, redeemTotal, err := s.listRedeemBalanceHistoryForMerge(ctx, userID, needed)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, err
|
||||||
|
}
|
||||||
|
affiliateCodes, affiliateTotal, err := s.listAffiliateBalanceHistoryForMerge(ctx, userID, needed)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, err
|
||||||
|
}
|
||||||
|
codes := mergeBalanceHistoryCodes(redeemCodes, affiliateCodes, params)
|
||||||
|
|
||||||
|
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, err
|
||||||
|
}
|
||||||
|
return codes, redeemTotal + affiliateTotal, totalRecharged, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) listRedeemBalanceHistoryForMerge(ctx context.Context, userID int64, needed int) ([]RedeemCode, int64, error) {
|
||||||
|
if needed <= 0 {
|
||||||
|
return nil, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
out []RedeemCode
|
||||||
|
total int64
|
||||||
|
)
|
||||||
|
for page := 1; len(out) < needed; page++ {
|
||||||
|
params := pagination.PaginationParams{Page: page, PageSize: 1000}
|
||||||
|
codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
if result != nil {
|
||||||
|
total = result.Total
|
||||||
|
}
|
||||||
|
out = append(out, codes...)
|
||||||
|
if len(codes) < params.Limit() || int64(len(out)) >= total {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(out) > needed {
|
||||||
|
out = out[:needed]
|
||||||
|
}
|
||||||
|
return out, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) listAffiliateBalanceHistoryForMerge(ctx context.Context, userID int64, needed int) ([]RedeemCode, int64, error) {
|
||||||
|
if needed <= 0 {
|
||||||
|
return nil, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
out []RedeemCode
|
||||||
|
total int64
|
||||||
|
)
|
||||||
|
for page := 1; len(out) < needed; page++ {
|
||||||
|
params := pagination.PaginationParams{Page: page, PageSize: 1000}
|
||||||
|
codes, currentTotal, err := s.listAffiliateBalanceHistory(ctx, userID, params)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
total = currentTotal
|
||||||
|
out = append(out, codes...)
|
||||||
|
if len(codes) < params.Limit() || int64(len(out)) >= total {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(out) > needed {
|
||||||
|
out = out[:needed]
|
||||||
|
}
|
||||||
|
return out, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) listAffiliateBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, error) {
|
||||||
|
if s == nil || s.entClient == nil || userID <= 0 {
|
||||||
|
return nil, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := s.entClient.QueryContext(ctx, `
|
||||||
|
SELECT id,
|
||||||
|
amount::double precision,
|
||||||
|
created_at
|
||||||
|
FROM user_affiliate_ledger
|
||||||
|
WHERE user_id = $1
|
||||||
|
AND action = 'transfer'
|
||||||
|
ORDER BY created_at DESC, id DESC
|
||||||
|
OFFSET $2
|
||||||
|
LIMIT $3`, userID, params.Offset(), params.Limit())
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
codes := make([]RedeemCode, 0, params.Limit())
|
||||||
|
for rows.Next() {
|
||||||
|
var id int64
|
||||||
|
var amount float64
|
||||||
|
var createdAt time.Time
|
||||||
|
if err := rows.Scan(&id, &amount, &createdAt); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
usedBy := userID
|
||||||
|
usedAt := createdAt
|
||||||
|
codes = append(codes, RedeemCode{
|
||||||
|
ID: -id,
|
||||||
|
Code: fmt.Sprintf("AFF-%d", id),
|
||||||
|
Type: RedeemTypeAffiliateBalance,
|
||||||
|
Value: amount,
|
||||||
|
Status: StatusUsed,
|
||||||
|
UsedBy: &usedBy,
|
||||||
|
UsedAt: &usedAt,
|
||||||
|
CreatedAt: createdAt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
total, err := countAffiliateBalanceHistory(ctx, s.entClient, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return codes, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func countAffiliateBalanceHistory(ctx context.Context, client *dbent.Client, userID int64) (int64, error) {
|
||||||
|
rows, err := client.QueryContext(ctx, `
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM user_affiliate_ledger
|
||||||
|
WHERE user_id = $1
|
||||||
|
AND action = 'transfer'`, userID)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var total sql.NullInt64
|
||||||
|
if rows.Next() {
|
||||||
|
if err := rows.Scan(&total); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if !total.Valid {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return total.Int64, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeBalanceHistoryCodes(redeemCodes, affiliateCodes []RedeemCode, params pagination.PaginationParams) []RedeemCode {
|
||||||
|
combined := append(append([]RedeemCode{}, redeemCodes...), affiliateCodes...)
|
||||||
|
sort.SliceStable(combined, func(i, j int) bool {
|
||||||
|
return redeemCodeHistoryTime(combined[i]).After(redeemCodeHistoryTime(combined[j]))
|
||||||
|
})
|
||||||
|
offset := params.Offset()
|
||||||
|
if offset >= len(combined) {
|
||||||
|
return []RedeemCode{}
|
||||||
|
}
|
||||||
|
end := offset + params.Limit()
|
||||||
|
if end > len(combined) {
|
||||||
|
end = len(combined)
|
||||||
|
}
|
||||||
|
return combined[offset:end]
|
||||||
|
}
|
||||||
|
|
||||||
|
func redeemCodeHistoryTime(code RedeemCode) time.Time {
|
||||||
|
if code.UsedAt != nil {
|
||||||
|
return *code.UsedAt
|
||||||
|
}
|
||||||
|
return code.CreatedAt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) {
|
func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) {
|
||||||
|
|||||||
@ -98,7 +98,7 @@ type AffiliateRepository interface {
|
|||||||
EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
|
EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
|
||||||
GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
|
GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
|
||||||
BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
|
BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
|
||||||
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error)
|
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error)
|
||||||
GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
|
GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
|
||||||
ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
|
ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
|
||||||
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
|
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
|
||||||
@ -110,6 +110,10 @@ type AffiliateRepository interface {
|
|||||||
SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error
|
SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error
|
||||||
BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error
|
BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error
|
||||||
ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error)
|
ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error)
|
||||||
|
ListAffiliateInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error)
|
||||||
|
ListAffiliateRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error)
|
||||||
|
ListAffiliateTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error)
|
||||||
|
GetAffiliateUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AffiliateAdminFilter 列表筛选条件
|
// AffiliateAdminFilter 列表筛选条件
|
||||||
@ -130,6 +134,76 @@ type AffiliateAdminEntry struct {
|
|||||||
AffCount int `json:"aff_count"`
|
AffCount int `json:"aff_count"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AffiliateRecordFilter struct {
|
||||||
|
Search string
|
||||||
|
Page int
|
||||||
|
PageSize int
|
||||||
|
StartAt *time.Time
|
||||||
|
EndAt *time.Time
|
||||||
|
SortBy string
|
||||||
|
SortDesc bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type AffiliateInviteRecord struct {
|
||||||
|
InviterID int64 `json:"inviter_id"`
|
||||||
|
InviterEmail string `json:"inviter_email"`
|
||||||
|
InviterUsername string `json:"inviter_username"`
|
||||||
|
InviteeID int64 `json:"invitee_id"`
|
||||||
|
InviteeEmail string `json:"invitee_email"`
|
||||||
|
InviteeUsername string `json:"invitee_username"`
|
||||||
|
AffCode string `json:"aff_code"`
|
||||||
|
TotalRebate float64 `json:"total_rebate"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AffiliateRebateRecord struct {
|
||||||
|
OrderID int64 `json:"order_id"`
|
||||||
|
OutTradeNo string `json:"out_trade_no"`
|
||||||
|
InviterID int64 `json:"inviter_id"`
|
||||||
|
InviterEmail string `json:"inviter_email"`
|
||||||
|
InviterUsername string `json:"inviter_username"`
|
||||||
|
InviteeID int64 `json:"invitee_id"`
|
||||||
|
InviteeEmail string `json:"invitee_email"`
|
||||||
|
InviteeUsername string `json:"invitee_username"`
|
||||||
|
OrderAmount float64 `json:"order_amount"`
|
||||||
|
PayAmount float64 `json:"pay_amount"`
|
||||||
|
RebateAmount float64 `json:"rebate_amount"`
|
||||||
|
PaymentType string `json:"payment_type"`
|
||||||
|
OrderStatus string `json:"order_status"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AffiliateTransferRecord struct {
|
||||||
|
LedgerID int64 `json:"ledger_id"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
UserEmail string `json:"user_email"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Amount float64 `json:"amount"`
|
||||||
|
BalanceAfter *float64 `json:"balance_after,omitempty"`
|
||||||
|
AvailableQuotaAfter *float64 `json:"available_quota_after,omitempty"`
|
||||||
|
FrozenQuotaAfter *float64 `json:"frozen_quota_after,omitempty"`
|
||||||
|
HistoryQuotaAfter *float64 `json:"history_quota_after,omitempty"`
|
||||||
|
SnapshotAvailable bool `json:"snapshot_available"`
|
||||||
|
CurrentBalance float64 `json:"-"`
|
||||||
|
RemainingQuota float64 `json:"-"`
|
||||||
|
FrozenQuota float64 `json:"-"`
|
||||||
|
HistoryQuota float64 `json:"-"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AffiliateUserOverview struct {
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
AffCode string `json:"aff_code"`
|
||||||
|
RebateRatePercent float64 `json:"rebate_rate_percent"`
|
||||||
|
RebateRateCustom bool `json:"-"`
|
||||||
|
InvitedCount int `json:"invited_count"`
|
||||||
|
RebatedInviteeCount int `json:"rebated_invitee_count"`
|
||||||
|
AvailableQuota float64 `json:"available_quota"`
|
||||||
|
HistoryQuota float64 `json:"history_quota"`
|
||||||
|
}
|
||||||
|
|
||||||
type AffiliateService struct {
|
type AffiliateService struct {
|
||||||
repo AffiliateRepository
|
repo AffiliateRepository
|
||||||
settingService *SettingService
|
settingService *SettingService
|
||||||
@ -238,6 +312,10 @@ func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) {
|
func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) {
|
||||||
|
return s.AccrueInviteRebateForOrder(ctx, inviteeUserID, baseRechargeAmount, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AffiliateService) AccrueInviteRebateForOrder(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64, sourceOrderID *int64) (float64, error) {
|
||||||
if s == nil || s.repo == nil {
|
if s == nil || s.repo == nil {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
@ -298,7 +376,7 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID
|
|||||||
freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx)
|
freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours)
|
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours, sourceOrderID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -488,3 +566,59 @@ func (s *AffiliateService) AdminListCustomUsers(ctx context.Context, filter Affi
|
|||||||
}
|
}
|
||||||
return s.repo.ListUsersWithCustomSettings(ctx, filter)
|
return s.repo.ListUsersWithCustomSettings(ctx, filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AffiliateService) AdminListInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error) {
|
||||||
|
if s == nil || s.repo == nil {
|
||||||
|
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||||
|
}
|
||||||
|
return s.repo.ListAffiliateInviteRecords(ctx, normalizeAffiliateRecordFilter(filter))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AffiliateService) AdminListRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error) {
|
||||||
|
if s == nil || s.repo == nil {
|
||||||
|
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||||
|
}
|
||||||
|
return s.repo.ListAffiliateRebateRecords(ctx, normalizeAffiliateRecordFilter(filter))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AffiliateService) AdminListTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error) {
|
||||||
|
if s == nil || s.repo == nil {
|
||||||
|
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||||
|
}
|
||||||
|
return s.repo.ListAffiliateTransferRecords(ctx, normalizeAffiliateRecordFilter(filter))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AffiliateService) AdminGetUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error) {
|
||||||
|
if userID <= 0 {
|
||||||
|
return nil, infraerrors.BadRequest("INVALID_USER", "invalid user")
|
||||||
|
}
|
||||||
|
if s == nil || s.repo == nil {
|
||||||
|
return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||||
|
}
|
||||||
|
overview, err := s.repo.GetAffiliateUserOverview(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if overview != nil {
|
||||||
|
if !overview.RebateRateCustom {
|
||||||
|
overview.RebateRatePercent = s.globalRebateRatePercent(ctx)
|
||||||
|
}
|
||||||
|
overview.RebateRatePercent = clampAffiliateRebateRate(overview.RebateRatePercent)
|
||||||
|
}
|
||||||
|
return overview, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeAffiliateRecordFilter(filter AffiliateRecordFilter) AffiliateRecordFilter {
|
||||||
|
if filter.Page <= 0 {
|
||||||
|
filter.Page = 1
|
||||||
|
}
|
||||||
|
if filter.PageSize <= 0 {
|
||||||
|
filter.PageSize = 20
|
||||||
|
}
|
||||||
|
if filter.PageSize > 100 {
|
||||||
|
filter.PageSize = 100
|
||||||
|
}
|
||||||
|
filter.Search = strings.TrimSpace(filter.Search)
|
||||||
|
filter.SortBy = strings.TrimSpace(filter.SortBy)
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
|||||||
@ -51,10 +51,11 @@ const (
|
|||||||
|
|
||||||
// Redeem type constants
|
// Redeem type constants
|
||||||
const (
|
const (
|
||||||
RedeemTypeBalance = domain.RedeemTypeBalance
|
RedeemTypeBalance = domain.RedeemTypeBalance
|
||||||
RedeemTypeConcurrency = domain.RedeemTypeConcurrency
|
RedeemTypeConcurrency = domain.RedeemTypeConcurrency
|
||||||
RedeemTypeSubscription = domain.RedeemTypeSubscription
|
RedeemTypeSubscription = domain.RedeemTypeSubscription
|
||||||
RedeemTypeInvitation = domain.RedeemTypeInvitation
|
RedeemTypeInvitation = domain.RedeemTypeInvitation
|
||||||
|
RedeemTypeAffiliateBalance = "affiliate_balance"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PromoCode status constants
|
// PromoCode status constants
|
||||||
|
|||||||
@ -8174,9 +8174,16 @@ func detachedBillingContext(ctx context.Context) (context.Context, context.Cance
|
|||||||
}
|
}
|
||||||
|
|
||||||
func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
|
func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
|
||||||
|
if ctx == nil {
|
||||||
|
return context.Background(), func() {}
|
||||||
|
}
|
||||||
if !stream {
|
if !stream {
|
||||||
return ctx, func() {}
|
return ctx, func() {}
|
||||||
}
|
}
|
||||||
|
return context.WithoutCancel(ctx), func() {}
|
||||||
|
}
|
||||||
|
|
||||||
|
func detachUpstreamContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
return context.Background(), func() {}
|
return context.Background(), func() {}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -13,6 +13,8 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type upstreamContextTestKey string
|
||||||
|
|
||||||
func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) {
|
func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
@ -50,3 +52,14 @@ func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testi
|
|||||||
require.Equal(t, 3, result.usage.InputTokens)
|
require.Equal(t, 3, result.usage.InputTokens)
|
||||||
require.Equal(t, 7, result.usage.OutputTokens)
|
require.Equal(t, 7, result.usage.OutputTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDetachUpstreamContextIgnoresClientCancel(t *testing.T) {
|
||||||
|
parent, cancel := context.WithCancel(context.WithValue(context.Background(), upstreamContextTestKey("test-key"), "test-value"))
|
||||||
|
upstreamCtx, release := detachUpstreamContext(parent)
|
||||||
|
defer release()
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
require.NoError(t, upstreamCtx.Err())
|
||||||
|
require.Equal(t, "test-value", upstreamCtx.Value(upstreamContextTestKey("test-key")))
|
||||||
|
}
|
||||||
|
|||||||
@ -3,13 +3,16 @@ package service
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||||
@ -18,6 +21,51 @@ import (
|
|||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type openAICompatFailingWriter struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
failAfter int
|
||||||
|
writes int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *openAICompatFailingWriter) Write(p []byte) (int, error) {
|
||||||
|
if w.writes >= w.failAfter {
|
||||||
|
return 0, errors.New("write failed: client disconnected")
|
||||||
|
}
|
||||||
|
w.writes++
|
||||||
|
return w.ResponseWriter.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAICompatBlockingReadCloser struct {
|
||||||
|
data []byte
|
||||||
|
offset int
|
||||||
|
closed chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func newOpenAICompatBlockingReadCloser(data []byte) *openAICompatBlockingReadCloser {
|
||||||
|
return &openAICompatBlockingReadCloser{
|
||||||
|
data: data,
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *openAICompatBlockingReadCloser) Read(p []byte) (int, error) {
|
||||||
|
if r.offset < len(r.data) {
|
||||||
|
n := copy(p, r.data[r.offset:])
|
||||||
|
r.offset += n
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
<-r.closed
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *openAICompatBlockingReadCloser) Close() error {
|
||||||
|
r.closeOnce.Do(func() {
|
||||||
|
close(r.closed)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestNormalizeOpenAICompatRequestedModel(t *testing.T) {
|
func TestNormalizeOpenAICompatRequestedModel(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@ -228,3 +276,242 @@ func TestForwardAsAnthropic_ForcedCodexInstructionsTemplateUsesCachedTemplateCon
|
|||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String())
|
require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestForwardAsAnthropic_ClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstreamBody := strings.Join([]string{
|
||||||
|
`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`,
|
||||||
|
"",
|
||||||
|
`data: {"type":"response.output_text.delta","delta":"ok"}`,
|
||||||
|
"",
|
||||||
|
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":9,"output_tokens":4,"total_tokens":13,"input_tokens_details":{"cached_tokens":3}}}}`,
|
||||||
|
"",
|
||||||
|
"data: [DONE]",
|
||||||
|
"",
|
||||||
|
}, "\n")
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_disconnect"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, 9, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 4, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardAsAnthropic_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n")
|
||||||
|
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, upstreamStream.Close())
|
||||||
|
}()
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_terminal_no_close"}},
|
||||||
|
Body: upstreamStream,
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
type forwardResult struct {
|
||||||
|
result *OpenAIForwardResult
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultCh := make(chan forwardResult, 1)
|
||||||
|
go func() {
|
||||||
|
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
|
||||||
|
resultCh <- forwardResult{result: result, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-resultCh:
|
||||||
|
require.NoError(t, got.err)
|
||||||
|
require.NotNil(t, got.result)
|
||||||
|
require.Equal(t, 15, got.result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 6, got.result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 5, got.result.Usage.CacheReadInputTokens)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
require.Fail(t, "ForwardAsAnthropic should return after terminal usage event even if upstream keeps the connection open")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardAsAnthropic_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n")
|
||||||
|
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, upstreamStream.Close())
|
||||||
|
}()
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_buffered_terminal_no_close"}},
|
||||||
|
Body: upstreamStream,
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
type forwardResult struct {
|
||||||
|
result *OpenAIForwardResult
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultCh := make(chan forwardResult, 1)
|
||||||
|
go func() {
|
||||||
|
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
|
||||||
|
resultCh <- forwardResult{result: result, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-resultCh:
|
||||||
|
require.NoError(t, got.err)
|
||||||
|
require.NotNil(t, got.result)
|
||||||
|
require.Equal(t, 15, got.result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 6, got.result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 5, got.result.Usage.CacheReadInputTokens)
|
||||||
|
require.Contains(t, rec.Body.String(), `"stop_reason":"end_turn"`)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
require.Fail(t, "ForwardAsAnthropic buffered response should return after terminal usage event even if upstream keeps the connection open")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardAsAnthropic_DoneSentinelWithoutTerminalReturnsError(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstreamBody := "data: [DONE]\n\n"
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_missing_terminal"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "missing terminal event")
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Zero(t, result.Usage.InputTokens)
|
||||||
|
require.Zero(t, result.Usage.OutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardAsAnthropic_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
reqCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)).WithContext(reqCtx)
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
upstreamBody := strings.Join([]string{
|
||||||
|
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
|
||||||
|
"",
|
||||||
|
"data: [DONE]",
|
||||||
|
"",
|
||||||
|
}, "\n")
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_ctx"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardAsAnthropic(reqCtx, c, account, body, "", "gpt-5.1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, upstream.lastReq)
|
||||||
|
require.NoError(t, upstream.lastReq.Context().Err())
|
||||||
|
}
|
||||||
|
|||||||
@ -20,20 +20,29 @@ func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accou
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) {
|
func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterForZeroUsage(t *testing.T) {
|
||||||
counter := &openAI403CounterResetStub{}
|
counter := &openAI403CounterResetStub{}
|
||||||
rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil)
|
rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil)
|
||||||
rateLimitSvc.SetOpenAI403CounterCache(counter)
|
rateLimitSvc.SetOpenAI403CounterCache(counter)
|
||||||
|
|
||||||
svc := &OpenAIGatewayService{
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||||
rateLimitService: rateLimitSvc,
|
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||||
}
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||||
|
svc.rateLimitService = rateLimitSvc
|
||||||
|
|
||||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
Result: &OpenAIForwardResult{},
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_zero_usage_reset_403",
|
||||||
|
Model: "gpt-5.1",
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 1001, Group: &Group{RateMultiplier: 1}},
|
||||||
|
User: &User{ID: 2001},
|
||||||
Account: &Account{ID: 777, Platform: PlatformOpenAI},
|
Account: &Account{ID: 777, Platform: PlatformOpenAI},
|
||||||
})
|
})
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, []int64{777}, counter.resetCalls)
|
require.Equal(t, []int64{777}, counter.resetCalls)
|
||||||
|
require.Equal(t, 1, usageRepo.calls)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||||
@ -189,7 +190,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 6. Build upstream request
|
// 6. Build upstream request
|
||||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false)
|
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||||
|
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, promptCacheKey, false)
|
||||||
|
releaseUpstreamCtx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||||
}
|
}
|
||||||
@ -348,59 +351,9 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
|
|||||||
) (*OpenAIForwardResult, error) {
|
) (*OpenAIForwardResult, error) {
|
||||||
requestID := resp.Header.Get("x-request-id")
|
requestID := resp.Header.Get("x-request-id")
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai chat_completions buffered", requestID)
|
||||||
maxLineSize := defaultMaxLineSize
|
if err != nil {
|
||||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
return nil, err
|
||||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
|
||||||
}
|
|
||||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
|
||||||
|
|
||||||
var finalResponse *apicompat.ResponsesResponse
|
|
||||||
var usage OpenAIUsage
|
|
||||||
acc := apicompat.NewBufferedResponseAccumulator()
|
|
||||||
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Text()
|
|
||||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
payload := line[6:]
|
|
||||||
|
|
||||||
var event apicompat.ResponsesStreamEvent
|
|
||||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
|
||||||
logger.L().Warn("openai chat_completions buffered: failed to parse event",
|
|
||||||
zap.Error(err),
|
|
||||||
zap.String("request_id", requestID),
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Accumulate delta content for fallback when terminal output is empty.
|
|
||||||
acc.ProcessEvent(&event)
|
|
||||||
|
|
||||||
if (event.Type == "response.completed" || event.Type == "response.done" ||
|
|
||||||
event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
|
||||||
event.Response != nil {
|
|
||||||
finalResponse = event.Response
|
|
||||||
if event.Response.Usage != nil {
|
|
||||||
usage = OpenAIUsage{
|
|
||||||
InputTokens: event.Response.Usage.InputTokens,
|
|
||||||
OutputTokens: event.Response.Usage.OutputTokens,
|
|
||||||
}
|
|
||||||
if event.Response.Usage.InputTokensDetails != nil {
|
|
||||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
|
||||||
logger.L().Warn("openai chat_completions buffered: read error",
|
|
||||||
zap.Error(err),
|
|
||||||
zap.String("request_id", requestID),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if finalResponse == nil {
|
if finalResponse == nil {
|
||||||
@ -459,6 +412,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
var usage OpenAIUsage
|
var usage OpenAIUsage
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
firstChunk := true
|
firstChunk := true
|
||||||
|
clientDisconnected := false
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
maxLineSize := defaultMaxLineSize
|
maxLineSize := defaultMaxLineSize
|
||||||
@ -467,6 +421,20 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
}
|
}
|
||||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||||
|
|
||||||
|
streamInterval := time.Duration(0)
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||||
|
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||||
|
}
|
||||||
|
var intervalTicker *time.Ticker
|
||||||
|
if streamInterval > 0 {
|
||||||
|
intervalTicker = time.NewTicker(streamInterval)
|
||||||
|
defer intervalTicker.Stop()
|
||||||
|
}
|
||||||
|
var intervalCh <-chan time.Time
|
||||||
|
if intervalTicker != nil {
|
||||||
|
intervalCh = intervalTicker.C
|
||||||
|
}
|
||||||
|
|
||||||
resultWithUsage := func() *OpenAIForwardResult {
|
resultWithUsage := func() *OpenAIForwardResult {
|
||||||
return &OpenAIForwardResult{
|
return &OpenAIForwardResult{
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
@ -496,54 +464,66 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract usage from completion events
|
// 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。
|
||||||
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
|
||||||
event.Response != nil && event.Response.Usage != nil {
|
if isTerminalEvent && event.Response != nil && event.Response.Usage != nil {
|
||||||
usage = OpenAIUsage{
|
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
|
||||||
InputTokens: event.Response.Usage.InputTokens,
|
|
||||||
OutputTokens: event.Response.Usage.OutputTokens,
|
|
||||||
}
|
|
||||||
if event.Response.Usage.InputTokensDetails != nil {
|
|
||||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
|
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
|
||||||
for _, chunk := range chunks {
|
if !clientDisconnected {
|
||||||
sse, err := apicompat.ChatChunkToSSE(chunk)
|
for _, chunk := range chunks {
|
||||||
if err != nil {
|
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||||
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
|
if err != nil {
|
||||||
zap.Error(err),
|
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
|
||||||
zap.String("request_id", requestID),
|
zap.Error(err),
|
||||||
)
|
zap.String("request_id", requestID),
|
||||||
continue
|
)
|
||||||
}
|
continue
|
||||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
}
|
||||||
logger.L().Info("openai chat_completions stream: client disconnected",
|
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||||
zap.String("request_id", requestID),
|
clientDisconnected = true
|
||||||
)
|
logger.L().Info("openai chat_completions stream: client disconnected, continuing to drain upstream for billing",
|
||||||
return true
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(chunks) > 0 {
|
if len(chunks) > 0 && !clientDisconnected {
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
}
|
}
|
||||||
return false
|
return isTerminalEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
finalizeStream := func() (*OpenAIForwardResult, error) {
|
finalizeStream := func() (*OpenAIForwardResult, error) {
|
||||||
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 {
|
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 && !clientDisconnected {
|
||||||
for _, chunk := range finalChunks {
|
for _, chunk := range finalChunks {
|
||||||
sse, err := apicompat.ChatChunkToSSE(chunk)
|
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||||
|
clientDisconnected = true
|
||||||
|
logger.L().Info("openai chat_completions stream: client disconnected during final flush",
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Send [DONE] sentinel
|
// Send [DONE] sentinel
|
||||||
fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
|
if !clientDisconnected {
|
||||||
c.Writer.Flush()
|
if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
|
||||||
|
clientDisconnected = true
|
||||||
|
logger.L().Info("openai chat_completions stream: client disconnected during done flush",
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !clientDisconnected {
|
||||||
|
c.Writer.Flush()
|
||||||
|
}
|
||||||
return resultWithUsage(), nil
|
return resultWithUsage(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -555,6 +535,9 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
missingTerminalErr := func() (*OpenAIForwardResult, error) {
|
||||||
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||||
|
}
|
||||||
|
|
||||||
// Determine keepalive interval
|
// Determine keepalive interval
|
||||||
keepaliveInterval := time.Duration(0)
|
keepaliveInterval := time.Duration(0)
|
||||||
@ -563,18 +546,25 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// No keepalive: fast synchronous path
|
// No keepalive: fast synchronous path
|
||||||
if keepaliveInterval <= 0 {
|
if streamInterval <= 0 && keepaliveInterval <= 0 {
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
payload, ok := extractOpenAISSEDataLine(line)
|
||||||
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if processDataLine(line[6:]) {
|
if strings.TrimSpace(payload) == "[DONE]" {
|
||||||
return resultWithUsage(), nil
|
return missingTerminalErr()
|
||||||
|
}
|
||||||
|
if processDataLine(payload) {
|
||||||
|
return finalizeStream()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
handleScanErr(scanner.Err())
|
if err := scanner.Err(); err != nil {
|
||||||
return finalizeStream()
|
handleScanErr(err)
|
||||||
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
|
||||||
|
}
|
||||||
|
return missingTerminalErr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// With keepalive: goroutine + channel + select
|
// With keepalive: goroutine + channel + select
|
||||||
@ -584,6 +574,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
}
|
}
|
||||||
events := make(chan scanEvent, 16)
|
events := make(chan scanEvent, 16)
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
var lastReadAt int64
|
||||||
|
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||||
sendEvent := func(ev scanEvent) bool {
|
sendEvent := func(ev scanEvent) bool {
|
||||||
select {
|
select {
|
||||||
case events <- ev:
|
case events <- ev:
|
||||||
@ -595,6 +587,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
go func() {
|
go func() {
|
||||||
defer close(events)
|
defer close(events)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
|
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -605,30 +598,59 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
}()
|
}()
|
||||||
defer close(done)
|
defer close(done)
|
||||||
|
|
||||||
keepaliveTicker := time.NewTicker(keepaliveInterval)
|
var keepaliveTicker *time.Ticker
|
||||||
defer keepaliveTicker.Stop()
|
if keepaliveInterval > 0 {
|
||||||
|
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||||
|
defer keepaliveTicker.Stop()
|
||||||
|
}
|
||||||
|
var keepaliveCh <-chan time.Time
|
||||||
|
if keepaliveTicker != nil {
|
||||||
|
keepaliveCh = keepaliveTicker.C
|
||||||
|
}
|
||||||
lastDataAt := time.Now()
|
lastDataAt := time.Now()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case ev, ok := <-events:
|
case ev, ok := <-events:
|
||||||
if !ok {
|
if !ok {
|
||||||
return finalizeStream()
|
return missingTerminalErr()
|
||||||
}
|
}
|
||||||
if ev.err != nil {
|
if ev.err != nil {
|
||||||
handleScanErr(ev.err)
|
handleScanErr(ev.err)
|
||||||
return finalizeStream()
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
|
||||||
}
|
}
|
||||||
lastDataAt = time.Now()
|
lastDataAt = time.Now()
|
||||||
line := ev.line
|
line := ev.line
|
||||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
payload, ok := extractOpenAISSEDataLine(line)
|
||||||
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if processDataLine(line[6:]) {
|
if strings.TrimSpace(payload) == "[DONE]" {
|
||||||
return resultWithUsage(), nil
|
return missingTerminalErr()
|
||||||
|
}
|
||||||
|
if processDataLine(payload) {
|
||||||
|
return finalizeStream()
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-keepaliveTicker.C:
|
case <-intervalCh:
|
||||||
|
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||||
|
if time.Since(lastRead) < streamInterval {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if clientDisconnected {
|
||||||
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
|
||||||
|
}
|
||||||
|
logger.L().Warn("openai chat_completions stream: data interval timeout",
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
zap.String("model", originalModel),
|
||||||
|
zap.Duration("interval", streamInterval),
|
||||||
|
)
|
||||||
|
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
|
||||||
|
|
||||||
|
case <-keepaliveCh:
|
||||||
|
if clientDisconnected {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if time.Since(lastDataAt) < keepaliveInterval {
|
if time.Since(lastDataAt) < keepaliveInterval {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -637,7 +659,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
|
logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
|
||||||
zap.String("request_id", requestID),
|
zap.String("request_id", requestID),
|
||||||
)
|
)
|
||||||
return resultWithUsage(), nil
|
clientDisconnected = true
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,13 +1,36 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type openAIChatFailingWriter struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
failAfter int
|
||||||
|
writes int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *openAIChatFailingWriter) Write(p []byte) (int, error) {
|
||||||
|
if w.writes >= w.failAfter {
|
||||||
|
return 0, errors.New("write failed: client disconnected")
|
||||||
|
}
|
||||||
|
w.writes++
|
||||||
|
return w.ResponseWriter.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
|
func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@ -73,3 +96,242 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
|
|||||||
require.Empty(t, tier)
|
require.Empty(t, tier)
|
||||||
require.False(t, gjson.GetBytes(body, "service_tier").Exists())
|
require.False(t, gjson.GetBytes(body, "service_tier").Exists())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstreamBody := strings.Join([]string{
|
||||||
|
`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`,
|
||||||
|
"",
|
||||||
|
`data: {"type":"response.output_text.delta","delta":"ok"}`,
|
||||||
|
"",
|
||||||
|
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":11,"output_tokens":5,"total_tokens":16,"input_tokens_details":{"cached_tokens":4}}}}`,
|
||||||
|
"",
|
||||||
|
"data: [DONE]",
|
||||||
|
"",
|
||||||
|
}, "\n")
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_disconnect"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, 11, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 5, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 4, result.Usage.CacheReadInputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n")
|
||||||
|
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, upstreamStream.Close())
|
||||||
|
}()
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_terminal_no_close"}},
|
||||||
|
Body: upstreamStream,
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
type forwardResult struct {
|
||||||
|
result *OpenAIForwardResult
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultCh := make(chan forwardResult, 1)
|
||||||
|
go func() {
|
||||||
|
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
|
||||||
|
resultCh <- forwardResult{result: result, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-resultCh:
|
||||||
|
require.NoError(t, got.err)
|
||||||
|
require.NotNil(t, got.result)
|
||||||
|
require.Equal(t, 17, got.result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 8, got.result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 6, got.result.Usage.CacheReadInputTokens)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
require.Fail(t, "ForwardAsChatCompletions should return after terminal usage event even if upstream keeps the connection open")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardAsChatCompletions_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n")
|
||||||
|
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, upstreamStream.Close())
|
||||||
|
}()
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_buffered_terminal_no_close"}},
|
||||||
|
Body: upstreamStream,
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
type forwardResult struct {
|
||||||
|
result *OpenAIForwardResult
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultCh := make(chan forwardResult, 1)
|
||||||
|
go func() {
|
||||||
|
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
|
||||||
|
resultCh <- forwardResult{result: result, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-resultCh:
|
||||||
|
require.NoError(t, got.err)
|
||||||
|
require.NotNil(t, got.result)
|
||||||
|
require.Equal(t, 17, got.result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 8, got.result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 6, got.result.Usage.CacheReadInputTokens)
|
||||||
|
require.Contains(t, rec.Body.String(), `"finish_reason":"stop"`)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
require.Fail(t, "ForwardAsChatCompletions buffered response should return after terminal usage event even if upstream keeps the connection open")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardAsChatCompletions_DoneSentinelWithoutTerminalReturnsError(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstreamBody := "data: [DONE]\n\n"
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_missing_terminal"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "missing terminal event")
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Zero(t, result.Usage.InputTokens)
|
||||||
|
require.Zero(t, result.Usage.OutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardAsChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
reqCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx)
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
upstreamBody := strings.Join([]string{
|
||||||
|
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
|
||||||
|
"",
|
||||||
|
"data: [DONE]",
|
||||||
|
"",
|
||||||
|
}, "\n")
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_ctx"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardAsChatCompletions(reqCtx, c, account, body, "", "gpt-5.1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, upstream.lastReq)
|
||||||
|
require.NoError(t, upstream.lastReq.Context().Err())
|
||||||
|
}
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||||
@ -163,7 +164,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 6. Build upstream request
|
// 6. Build upstream request
|
||||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, isStream, promptCacheKey, false)
|
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||||
|
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, isStream, promptCacheKey, false)
|
||||||
|
releaseUpstreamCtx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||||
}
|
}
|
||||||
@ -296,61 +299,9 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
|
|||||||
) (*OpenAIForwardResult, error) {
|
) (*OpenAIForwardResult, error) {
|
||||||
requestID := resp.Header.Get("x-request-id")
|
requestID := resp.Header.Get("x-request-id")
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai messages buffered", requestID)
|
||||||
maxLineSize := defaultMaxLineSize
|
if err != nil {
|
||||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
return nil, err
|
||||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
|
||||||
}
|
|
||||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
|
||||||
|
|
||||||
var finalResponse *apicompat.ResponsesResponse
|
|
||||||
var usage OpenAIUsage
|
|
||||||
acc := apicompat.NewBufferedResponseAccumulator()
|
|
||||||
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Text()
|
|
||||||
|
|
||||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
payload := line[6:]
|
|
||||||
|
|
||||||
var event apicompat.ResponsesStreamEvent
|
|
||||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
|
||||||
logger.L().Warn("openai messages buffered: failed to parse event",
|
|
||||||
zap.Error(err),
|
|
||||||
zap.String("request_id", requestID),
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Accumulate delta content for fallback when terminal output is empty.
|
|
||||||
acc.ProcessEvent(&event)
|
|
||||||
|
|
||||||
// Terminal events carry the complete ResponsesResponse with output + usage.
|
|
||||||
if (event.Type == "response.completed" || event.Type == "response.done" ||
|
|
||||||
event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
|
||||||
event.Response != nil {
|
|
||||||
finalResponse = event.Response
|
|
||||||
if event.Response.Usage != nil {
|
|
||||||
usage = OpenAIUsage{
|
|
||||||
InputTokens: event.Response.Usage.InputTokens,
|
|
||||||
OutputTokens: event.Response.Usage.OutputTokens,
|
|
||||||
}
|
|
||||||
if event.Response.Usage.InputTokensDetails != nil {
|
|
||||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
|
||||||
logger.L().Warn("openai messages buffered: read error",
|
|
||||||
zap.Error(err),
|
|
||||||
zap.String("request_id", requestID),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if finalResponse == nil {
|
if finalResponse == nil {
|
||||||
@ -380,6 +331,153 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isOpenAICompatResponsesTerminalEvent(eventType string) bool {
|
||||||
|
switch strings.TrimSpace(eventType) {
|
||||||
|
case "response.completed", "response.done", "response.incomplete", "response.failed":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isOpenAICompatDoneSentinelLine(line string) bool {
|
||||||
|
payload, ok := extractOpenAISSEDataLine(line)
|
||||||
|
return ok && strings.TrimSpace(payload) == "[DONE]"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) readOpenAICompatBufferedTerminal(
|
||||||
|
resp *http.Response,
|
||||||
|
logPrefix string,
|
||||||
|
requestID string,
|
||||||
|
) (*apicompat.ResponsesResponse, OpenAIUsage, *apicompat.BufferedResponseAccumulator, error) {
|
||||||
|
acc := apicompat.NewBufferedResponseAccumulator()
|
||||||
|
var usage OpenAIUsage
|
||||||
|
if resp == nil || resp.Body == nil {
|
||||||
|
return nil, usage, acc, errors.New("upstream response body is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
maxLineSize := defaultMaxLineSize
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||||
|
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||||
|
}
|
||||||
|
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||||
|
|
||||||
|
streamInterval := time.Duration(0)
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||||
|
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||||
|
}
|
||||||
|
var timeoutCh <-chan time.Time
|
||||||
|
var timeoutTimer *time.Timer
|
||||||
|
resetTimeout := func() {
|
||||||
|
if streamInterval <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if timeoutTimer == nil {
|
||||||
|
timeoutTimer = time.NewTimer(streamInterval)
|
||||||
|
timeoutCh = timeoutTimer.C
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !timeoutTimer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timeoutTimer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
timeoutTimer.Reset(streamInterval)
|
||||||
|
}
|
||||||
|
stopTimeout := func() {
|
||||||
|
if timeoutTimer == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !timeoutTimer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timeoutTimer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resetTimeout()
|
||||||
|
defer stopTimeout()
|
||||||
|
|
||||||
|
type scanEvent struct {
|
||||||
|
line string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
events := make(chan scanEvent, 16)
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(events)
|
||||||
|
for scanner.Scan() {
|
||||||
|
select {
|
||||||
|
case events <- scanEvent{line: scanner.Text()}:
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
select {
|
||||||
|
case events <- scanEvent{err: err}:
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case ev, ok := <-events:
|
||||||
|
if !ok {
|
||||||
|
return nil, usage, acc, nil
|
||||||
|
}
|
||||||
|
resetTimeout()
|
||||||
|
if ev.err != nil {
|
||||||
|
if !errors.Is(ev.err, context.Canceled) && !errors.Is(ev.err, context.DeadlineExceeded) {
|
||||||
|
logger.L().Warn(logPrefix+": read error",
|
||||||
|
zap.Error(ev.err),
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return nil, usage, acc, ev.err
|
||||||
|
}
|
||||||
|
|
||||||
|
if isOpenAICompatDoneSentinelLine(ev.line) {
|
||||||
|
return nil, usage, acc, nil
|
||||||
|
}
|
||||||
|
payload, ok := extractOpenAISSEDataLine(ev.line)
|
||||||
|
if !ok || payload == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var event apicompat.ResponsesStreamEvent
|
||||||
|
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||||
|
logger.L().Warn(logPrefix+": failed to parse event",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
acc.ProcessEvent(&event)
|
||||||
|
|
||||||
|
if isOpenAICompatResponsesTerminalEvent(event.Type) && event.Response != nil {
|
||||||
|
if event.Response.Usage != nil {
|
||||||
|
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
|
||||||
|
}
|
||||||
|
return event.Response, usage, acc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-timeoutCh:
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
logger.L().Warn(logPrefix+": data interval timeout",
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
zap.Duration("interval", streamInterval),
|
||||||
|
)
|
||||||
|
return nil, usage, acc, fmt.Errorf("stream data interval timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// handleAnthropicStreamingResponse reads Responses SSE events from upstream,
|
// handleAnthropicStreamingResponse reads Responses SSE events from upstream,
|
||||||
// converts each to Anthropic SSE events, and writes them to the client.
|
// converts each to Anthropic SSE events, and writes them to the client.
|
||||||
// When StreamKeepaliveInterval is configured, it uses a goroutine + channel
|
// When StreamKeepaliveInterval is configured, it uses a goroutine + channel
|
||||||
@ -409,6 +507,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
var usage OpenAIUsage
|
var usage OpenAIUsage
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
firstChunk := true
|
firstChunk := true
|
||||||
|
clientDisconnected := false
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
maxLineSize := defaultMaxLineSize
|
maxLineSize := defaultMaxLineSize
|
||||||
@ -417,6 +516,20 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
}
|
}
|
||||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||||
|
|
||||||
|
streamInterval := time.Duration(0)
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||||
|
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||||
|
}
|
||||||
|
var intervalTicker *time.Ticker
|
||||||
|
if streamInterval > 0 {
|
||||||
|
intervalTicker = time.NewTicker(streamInterval)
|
||||||
|
defer intervalTicker.Stop()
|
||||||
|
}
|
||||||
|
var intervalCh <-chan time.Time
|
||||||
|
if intervalTicker != nil {
|
||||||
|
intervalCh = intervalTicker.C
|
||||||
|
}
|
||||||
|
|
||||||
// resultWithUsage builds the final result snapshot.
|
// resultWithUsage builds the final result snapshot.
|
||||||
resultWithUsage := func() *OpenAIForwardResult {
|
resultWithUsage := func() *OpenAIForwardResult {
|
||||||
return &OpenAIForwardResult{
|
return &OpenAIForwardResult{
|
||||||
@ -432,7 +545,6 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// processDataLine handles a single "data: ..." SSE line from upstream.
|
// processDataLine handles a single "data: ..." SSE line from upstream.
|
||||||
// Returns (clientDisconnected bool).
|
|
||||||
processDataLine := func(payload string) bool {
|
processDataLine := func(payload string) bool {
|
||||||
if firstChunk {
|
if firstChunk {
|
||||||
firstChunk = false
|
firstChunk = false
|
||||||
@ -449,53 +561,58 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract usage from completion events
|
// 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。
|
||||||
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
|
||||||
event.Response != nil && event.Response.Usage != nil {
|
if isTerminalEvent && event.Response != nil && event.Response.Usage != nil {
|
||||||
usage = OpenAIUsage{
|
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
|
||||||
InputTokens: event.Response.Usage.InputTokens,
|
|
||||||
OutputTokens: event.Response.Usage.OutputTokens,
|
|
||||||
}
|
|
||||||
if event.Response.Usage.InputTokensDetails != nil {
|
|
||||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert to Anthropic events
|
// Convert to Anthropic events
|
||||||
events := apicompat.ResponsesEventToAnthropicEvents(&event, state)
|
events := apicompat.ResponsesEventToAnthropicEvents(&event, state)
|
||||||
for _, evt := range events {
|
if !clientDisconnected {
|
||||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
for _, evt := range events {
|
||||||
if err != nil {
|
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||||
logger.L().Warn("openai messages stream: failed to marshal event",
|
if err != nil {
|
||||||
zap.Error(err),
|
logger.L().Warn("openai messages stream: failed to marshal event",
|
||||||
zap.String("request_id", requestID),
|
zap.Error(err),
|
||||||
)
|
zap.String("request_id", requestID),
|
||||||
continue
|
)
|
||||||
}
|
continue
|
||||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
}
|
||||||
logger.L().Info("openai messages stream: client disconnected",
|
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||||
zap.String("request_id", requestID),
|
clientDisconnected = true
|
||||||
)
|
logger.L().Info("openai messages stream: client disconnected, continuing to drain upstream for billing",
|
||||||
return true
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(events) > 0 {
|
if len(events) > 0 && !clientDisconnected {
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
}
|
}
|
||||||
return false
|
return isTerminalEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
// finalizeStream sends any remaining Anthropic events and returns the result.
|
// finalizeStream sends any remaining Anthropic events and returns the result.
|
||||||
finalizeStream := func() (*OpenAIForwardResult, error) {
|
finalizeStream := func() (*OpenAIForwardResult, error) {
|
||||||
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 {
|
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 && !clientDisconnected {
|
||||||
for _, evt := range finalEvents {
|
for _, evt := range finalEvents {
|
||||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||||
|
clientDisconnected = true
|
||||||
|
logger.L().Info("openai messages stream: client disconnected during final flush",
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !clientDisconnected {
|
||||||
|
c.Writer.Flush()
|
||||||
}
|
}
|
||||||
c.Writer.Flush()
|
|
||||||
}
|
}
|
||||||
return resultWithUsage(), nil
|
return resultWithUsage(), nil
|
||||||
}
|
}
|
||||||
@ -509,6 +626,9 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
missingTerminalErr := func() (*OpenAIForwardResult, error) {
|
||||||
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||||
|
}
|
||||||
|
|
||||||
// ── Determine keepalive interval ──
|
// ── Determine keepalive interval ──
|
||||||
keepaliveInterval := time.Duration(0)
|
keepaliveInterval := time.Duration(0)
|
||||||
@ -517,18 +637,25 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ── No keepalive: fast synchronous path (no goroutine overhead) ──
|
// ── No keepalive: fast synchronous path (no goroutine overhead) ──
|
||||||
if keepaliveInterval <= 0 {
|
if streamInterval <= 0 && keepaliveInterval <= 0 {
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
if isOpenAICompatDoneSentinelLine(line) {
|
||||||
|
return missingTerminalErr()
|
||||||
|
}
|
||||||
|
payload, ok := extractOpenAISSEDataLine(line)
|
||||||
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if processDataLine(line[6:]) {
|
if processDataLine(payload) {
|
||||||
return resultWithUsage(), nil
|
return finalizeStream()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
handleScanErr(scanner.Err())
|
if err := scanner.Err(); err != nil {
|
||||||
return finalizeStream()
|
handleScanErr(err)
|
||||||
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
|
||||||
|
}
|
||||||
|
return missingTerminalErr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── With keepalive: goroutine + channel + select ──
|
// ── With keepalive: goroutine + channel + select ──
|
||||||
@ -538,6 +665,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
}
|
}
|
||||||
events := make(chan scanEvent, 16)
|
events := make(chan scanEvent, 16)
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
var lastReadAt int64
|
||||||
|
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||||
sendEvent := func(ev scanEvent) bool {
|
sendEvent := func(ev scanEvent) bool {
|
||||||
select {
|
select {
|
||||||
case events <- ev:
|
case events <- ev:
|
||||||
@ -549,6 +678,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
go func() {
|
go func() {
|
||||||
defer close(events)
|
defer close(events)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
|
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -559,8 +689,15 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
}()
|
}()
|
||||||
defer close(done)
|
defer close(done)
|
||||||
|
|
||||||
keepaliveTicker := time.NewTicker(keepaliveInterval)
|
var keepaliveTicker *time.Ticker
|
||||||
defer keepaliveTicker.Stop()
|
if keepaliveInterval > 0 {
|
||||||
|
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||||
|
defer keepaliveTicker.Stop()
|
||||||
|
}
|
||||||
|
var keepaliveCh <-chan time.Time
|
||||||
|
if keepaliveTicker != nil {
|
||||||
|
keepaliveCh = keepaliveTicker.C
|
||||||
|
}
|
||||||
lastDataAt := time.Now()
|
lastDataAt := time.Now()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@ -568,22 +705,44 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
case ev, ok := <-events:
|
case ev, ok := <-events:
|
||||||
if !ok {
|
if !ok {
|
||||||
// Upstream closed
|
// Upstream closed
|
||||||
return finalizeStream()
|
return missingTerminalErr()
|
||||||
}
|
}
|
||||||
if ev.err != nil {
|
if ev.err != nil {
|
||||||
handleScanErr(ev.err)
|
handleScanErr(ev.err)
|
||||||
return finalizeStream()
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
|
||||||
}
|
}
|
||||||
lastDataAt = time.Now()
|
lastDataAt = time.Now()
|
||||||
line := ev.line
|
line := ev.line
|
||||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
if isOpenAICompatDoneSentinelLine(line) {
|
||||||
|
return missingTerminalErr()
|
||||||
|
}
|
||||||
|
payload, ok := extractOpenAISSEDataLine(line)
|
||||||
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if processDataLine(line[6:]) {
|
if processDataLine(payload) {
|
||||||
return resultWithUsage(), nil
|
return finalizeStream()
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-keepaliveTicker.C:
|
case <-intervalCh:
|
||||||
|
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||||
|
if time.Since(lastRead) < streamInterval {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if clientDisconnected {
|
||||||
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
|
||||||
|
}
|
||||||
|
logger.L().Warn("openai messages stream: data interval timeout",
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
zap.String("model", originalModel),
|
||||||
|
zap.Duration("interval", streamInterval),
|
||||||
|
)
|
||||||
|
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
|
||||||
|
|
||||||
|
case <-keepaliveCh:
|
||||||
|
if clientDisconnected {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if time.Since(lastDataAt) < keepaliveInterval {
|
if time.Since(lastDataAt) < keepaliveInterval {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -593,7 +752,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
logger.L().Info("openai messages stream: client disconnected during keepalive",
|
logger.L().Info("openai messages stream: client disconnected during keepalive",
|
||||||
zap.String("request_id", requestID),
|
zap.String("request_id", requestID),
|
||||||
)
|
)
|
||||||
return resultWithUsage(), nil
|
clientDisconnected = true
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
}
|
}
|
||||||
@ -610,3 +770,17 @@ func writeAnthropicError(c *gin.Context, statusCode int, errType, message string
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func copyOpenAIUsageFromResponsesUsage(usage *apicompat.ResponsesUsage) OpenAIUsage {
|
||||||
|
if usage == nil {
|
||||||
|
return OpenAIUsage{}
|
||||||
|
}
|
||||||
|
result := OpenAIUsage{
|
||||||
|
InputTokens: usage.InputTokens,
|
||||||
|
OutputTokens: usage.OutputTokens,
|
||||||
|
}
|
||||||
|
if usage.InputTokensDetails != nil {
|
||||||
|
result.CacheReadInputTokens = usage.InputTokensDetails.CachedTokens
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|||||||
@ -186,6 +186,56 @@ func max(a, b int) int {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_ZeroUsageStillWritesUsageLog(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||||
|
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_zero_usage",
|
||||||
|
Usage: OpenAIUsage{},
|
||||||
|
Model: "gpt-5.1",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 1000, Quota: 100, Group: &Group{RateMultiplier: 1}},
|
||||||
|
User: &User{ID: 2000},
|
||||||
|
Account: &Account{ID: 3000, Type: AccountTypeAPIKey},
|
||||||
|
APIKeyService: quotaSvc,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, billingRepo.calls)
|
||||||
|
require.Equal(t, 1, usageRepo.calls)
|
||||||
|
require.Equal(t, 0, userRepo.deductCalls)
|
||||||
|
require.Equal(t, 0, subRepo.incrementCalls)
|
||||||
|
require.Equal(t, 0, quotaSvc.quotaCalls)
|
||||||
|
require.Equal(t, 0, quotaSvc.rateLimitCalls)
|
||||||
|
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, "resp_zero_usage", usageRepo.lastLog.RequestID)
|
||||||
|
require.Zero(t, usageRepo.lastLog.InputTokens)
|
||||||
|
require.Zero(t, usageRepo.lastLog.OutputTokens)
|
||||||
|
require.Zero(t, usageRepo.lastLog.CacheCreationTokens)
|
||||||
|
require.Zero(t, usageRepo.lastLog.CacheReadTokens)
|
||||||
|
require.Zero(t, usageRepo.lastLog.ImageOutputTokens)
|
||||||
|
require.Zero(t, usageRepo.lastLog.ImageCount)
|
||||||
|
require.Zero(t, usageRepo.lastLog.InputCost)
|
||||||
|
require.Zero(t, usageRepo.lastLog.OutputCost)
|
||||||
|
require.Zero(t, usageRepo.lastLog.TotalCost)
|
||||||
|
require.Zero(t, usageRepo.lastLog.ActualCost)
|
||||||
|
|
||||||
|
require.NotNil(t, billingRepo.lastCmd)
|
||||||
|
require.Zero(t, billingRepo.lastCmd.BalanceCost)
|
||||||
|
require.Zero(t, billingRepo.lastCmd.SubscriptionCost)
|
||||||
|
require.Zero(t, billingRepo.lastCmd.APIKeyQuotaCost)
|
||||||
|
require.Zero(t, billingRepo.lastCmd.APIKeyRateLimitCost)
|
||||||
|
require.Zero(t, billingRepo.lastCmd.AccountQuotaCost)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) {
|
func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) {
|
||||||
groupID := int64(11)
|
groupID := int64(11)
|
||||||
groupRate := 1.4
|
groupRate := 1.4
|
||||||
|
|||||||
@ -2601,7 +2601,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
httpInvalidEncryptedContentRetryTried := false
|
httpInvalidEncryptedContentRetryTried := false
|
||||||
for {
|
for {
|
||||||
// Build upstream request
|
// Build upstream request
|
||||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||||
releaseUpstreamCtx()
|
releaseUpstreamCtx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -2852,7 +2852,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||||
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
|
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
|
||||||
releaseUpstreamCtx()
|
releaseUpstreamCtx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -5041,13 +5041,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
|
s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
|
|
||||||
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
|
|
||||||
result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 &&
|
|
||||||
result.Usage.ImageOutputTokens == 0 && result.ImageCount == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
apiKey := input.APIKey
|
apiKey := input.APIKey
|
||||||
user := input.User
|
user := input.User
|
||||||
account := input.Account
|
account := input.Account
|
||||||
|
|||||||
@ -596,7 +596,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
|
|||||||
var usage OpenAIUsage
|
var usage OpenAIUsage
|
||||||
imageCount := parsed.N
|
imageCount := parsed.N
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
if parsed.Stream {
|
if parsed.Stream && isEventStreamResponse(resp.Header) {
|
||||||
streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime)
|
streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -811,6 +811,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
|
|||||||
usage := OpenAIUsage{}
|
usage := OpenAIUsage{}
|
||||||
imageCount := 0
|
imageCount := 0
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
|
var fallbackBody bytes.Buffer
|
||||||
|
fallbackBytes := int64(0)
|
||||||
|
fallbackLimit := resolveUpstreamResponseReadLimit(s.cfg)
|
||||||
|
seenSSEData := false
|
||||||
|
fallbackTooLarge := false
|
||||||
|
|
||||||
for {
|
for {
|
||||||
line, err := reader.ReadBytes('\n')
|
line, err := reader.ReadBytes('\n')
|
||||||
@ -824,11 +829,24 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
|
|||||||
}
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
|
|
||||||
if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok && data != "" && data != "[DONE]" {
|
if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok {
|
||||||
dataBytes := []byte(data)
|
if data != "" && data != "[DONE]" {
|
||||||
mergeOpenAIUsage(&usage, dataBytes)
|
seenSSEData = true
|
||||||
if count := extractOpenAIImageCountFromJSONBytes(dataBytes); count > imageCount {
|
fallbackBody.Reset()
|
||||||
imageCount = count
|
fallbackBytes = 0
|
||||||
|
dataBytes := []byte(data)
|
||||||
|
mergeOpenAIUsage(&usage, dataBytes)
|
||||||
|
if count := extractOpenAIImagesBillableCountFromJSONBytes(dataBytes); count > imageCount {
|
||||||
|
imageCount = count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if !seenSSEData && !fallbackTooLarge {
|
||||||
|
fallbackBytes += int64(len(line))
|
||||||
|
if fallbackBytes <= fallbackLimit {
|
||||||
|
_, _ = fallbackBody.Write(line)
|
||||||
|
} else {
|
||||||
|
fallbackTooLarge = true
|
||||||
|
fallbackBody.Reset()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -839,9 +857,41 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
|
|||||||
return OpenAIUsage{}, 0, firstTokenMs, err
|
return OpenAIUsage{}, 0, firstTokenMs, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if !seenSSEData && fallbackBody.Len() > 0 {
|
||||||
|
body := bytes.TrimSpace(fallbackBody.Bytes())
|
||||||
|
if len(body) > 0 {
|
||||||
|
mergeOpenAIUsage(&usage, body)
|
||||||
|
if count := extractOpenAIImagesBillableCountFromJSONBytes(body); count > imageCount {
|
||||||
|
imageCount = count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return usage, imageCount, firstTokenMs, nil
|
return usage, imageCount, firstTokenMs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extractOpenAIImagesBillableCountFromJSONBytes(body []byte) int {
|
||||||
|
if count := extractOpenAIImageCountFromJSONBytes(body); count > 0 {
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if count := int(gjson.GetBytes(body, "usage.images").Int()); count > 0 {
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
if count := int(gjson.GetBytes(body, "tool_usage.image_gen.images").Int()); count > 0 {
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
eventType := strings.TrimSpace(gjson.GetBytes(body, "type").String())
|
||||||
|
if eventType == "" || !strings.HasSuffix(eventType, ".completed") {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(body, "b64_json").Exists() || gjson.GetBytes(body, "url").Exists() {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) {
|
func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) {
|
||||||
if dst == nil {
|
if dst == nil {
|
||||||
return
|
return
|
||||||
|
|||||||
@ -446,6 +446,109 @@ func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseU
|
|||||||
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamJSONResponseBillsImage(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: &config.Config{},
|
||||||
|
httpUpstream: &httpUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"application/json"},
|
||||||
|
"X-Request-Id": []string{"req_img_stream_json"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"usage":{"input_tokens":12,"output_tokens":21,"output_tokens_details":{"image_tokens":9}},"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 7,
|
||||||
|
Name: "openai-apikey",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "test-api-key",
|
||||||
|
"base_url": "https://image-upstream.example/v1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.Stream)
|
||||||
|
require.Equal(t, 1, result.ImageCount)
|
||||||
|
require.Equal(t, 12, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 21, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 9, result.Usage.ImageOutputTokens)
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamRawJSONEventStreamFallbackBillsImage(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: &config.Config{},
|
||||||
|
httpUpstream: &httpUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"text/event-stream"},
|
||||||
|
"X-Request-Id": []string{"req_img_stream_json_mislabeled"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"created":1710000009,"usage":{"input_tokens":10,"output_tokens":18,"output_tokens_details":{"image_tokens":8}},"data":[{"b64_json":"ZmluYWw="}]}`)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 8,
|
||||||
|
Name: "openai-apikey",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "test-api-key",
|
||||||
|
"base_url": "https://image-upstream.example/v1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.Stream)
|
||||||
|
require.Equal(t, 1, result.ImageCount)
|
||||||
|
require.Equal(t, 10, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 18, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 8, result.Usage.ImageOutputTokens)
|
||||||
|
require.Equal(t, "ZmluYWw=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractOpenAIImagesBillableCountFromJSONBytes_CompletedEvent(t *testing.T) {
|
||||||
|
body := []byte(`{"type":"image_generation.completed","b64_json":"ZmluYWw=","usage":{"input_tokens":10,"output_tokens":18}}`)
|
||||||
|
|
||||||
|
require.Equal(t, 1, extractOpenAIImagesBillableCountFromJSONBytes(body))
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) {
|
func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@ -307,6 +307,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_CompactUsesJSONAndKeepsNonStreami
|
|||||||
require.Contains(t, rec.Body.String(), `"id":"cmp_123"`)
|
require.Contains(t, rec.Body.String(), `"id":"cmp_123"`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_OAuthPassthrough_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
reqCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx)
|
||||||
|
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_passthrough_ctx"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||||
|
`data: {"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1}}}`,
|
||||||
|
"",
|
||||||
|
"data: [DONE]",
|
||||||
|
"",
|
||||||
|
}, "\n"))),
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
account := &Account{
|
||||||
|
ID: 123,
|
||||||
|
Name: "acc",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||||
|
Extra: map[string]any{"openai_passthrough": true, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
RateMultiplier: f64p(1),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.Forward(reqCtx, c, account, originalBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, upstream.lastReq)
|
||||||
|
require.NoError(t, upstream.lastReq.Context().Err())
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) {
|
func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
logSink, restore := captureStructuredLog(t)
|
logSink, restore := captureStructuredLog(t)
|
||||||
@ -405,6 +451,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *te
|
|||||||
require.Contains(t, string(upstream.lastBody), `"stream":true`)
|
require.Contains(t, string(upstream.lastBody), `"stream":true`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_OAuthLegacy_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
reqCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx)
|
||||||
|
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_legacy_ctx"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||||
|
`data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`,
|
||||||
|
"",
|
||||||
|
"data: [DONE]",
|
||||||
|
"",
|
||||||
|
}, "\n"))),
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
account := &Account{
|
||||||
|
ID: 123,
|
||||||
|
Name: "acc",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||||
|
Extra: map[string]any{"openai_passthrough": false, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
RateMultiplier: f64p(1),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.Forward(reqCtx, c, account, originalBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, upstream.lastReq)
|
||||||
|
require.NoError(t, upstream.lastReq.Context().Err())
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) {
|
func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@ -394,7 +394,8 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rebateAmount, err := s.affiliateService.AccrueInviteRebate(txCtx, o.UserID, o.Amount)
|
sourceOrderID := o.ID
|
||||||
|
rebateAmount, err := s.affiliateService.AccrueInviteRebateForOrder(txCtx, o.UserID, o.Amount, &sourceOrderID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
|
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
|
||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
|
|||||||
85
backend/migrations/134_affiliate_ledger_audit_snapshots.sql
Normal file
85
backend/migrations/134_affiliate_ledger_audit_snapshots.sql
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
-- 邀请返利流水补充订单关联和转余额快照。
|
||||||
|
-- 这些字段只用于审计展示;历史旧流水无法可靠反推的字段保持 NULL,避免把当前状态误展示为历史状态。
|
||||||
|
|
||||||
|
ALTER TABLE user_affiliate_ledger
|
||||||
|
ADD COLUMN IF NOT EXISTS source_order_id BIGINT NULL REFERENCES payment_orders(id) ON DELETE SET NULL;
|
||||||
|
|
||||||
|
ALTER TABLE user_affiliate_ledger
|
||||||
|
ADD COLUMN IF NOT EXISTS balance_after DECIMAL(20,8) NULL;
|
||||||
|
|
||||||
|
ALTER TABLE user_affiliate_ledger
|
||||||
|
ADD COLUMN IF NOT EXISTS aff_quota_after DECIMAL(20,8) NULL;
|
||||||
|
|
||||||
|
ALTER TABLE user_affiliate_ledger
|
||||||
|
ADD COLUMN IF NOT EXISTS aff_frozen_quota_after DECIMAL(20,8) NULL;
|
||||||
|
|
||||||
|
ALTER TABLE user_affiliate_ledger
|
||||||
|
ADD COLUMN IF NOT EXISTS aff_history_quota_after DECIMAL(20,8) NULL;
|
||||||
|
|
||||||
|
COMMENT ON COLUMN user_affiliate_ledger.source_order_id IS '产生该返利流水的充值订单;转余额或无法可靠回填的历史数据为 NULL';
|
||||||
|
COMMENT ON COLUMN user_affiliate_ledger.balance_after IS '邀请返利转余额后的用户余额快照;无法取得时为 NULL';
|
||||||
|
COMMENT ON COLUMN user_affiliate_ledger.aff_quota_after IS '邀请返利转余额后的可用返利额度快照;无法取得时为 NULL';
|
||||||
|
COMMENT ON COLUMN user_affiliate_ledger.aff_frozen_quota_after IS '邀请返利转余额后的冻结返利额度快照;无法取得时为 NULL';
|
||||||
|
COMMENT ON COLUMN user_affiliate_ledger.aff_history_quota_after IS '邀请返利转余额后的历史返利总额快照;无法取得时为 NULL';
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_source_order_id
|
||||||
|
ON user_affiliate_ledger(source_order_id)
|
||||||
|
WHERE source_order_id IS NOT NULL;
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_rebate_lookup
|
||||||
|
ON user_affiliate_ledger(action, source_order_id, user_id, source_user_id, created_at)
|
||||||
|
WHERE action = 'accrue';
|
||||||
|
|
||||||
|
-- 尽力回填 PR #2169 合并后、该迁移前已经产生的返利流水。
|
||||||
|
-- 只有在同一订单只能匹配到一条返利流水时才回填,避免把多笔同额流水错误绑定到订单。
|
||||||
|
WITH rebate_audits AS (
|
||||||
|
SELECT po.id AS order_id,
|
||||||
|
po.user_id AS invitee_user_id,
|
||||||
|
invitee_aff.inviter_id,
|
||||||
|
rebate_detail.rebate_amount,
|
||||||
|
pal.created_at AS audit_created_at
|
||||||
|
FROM payment_audit_logs pal
|
||||||
|
CROSS JOIN LATERAL (
|
||||||
|
SELECT substring(
|
||||||
|
pal.detail
|
||||||
|
FROM '"rebateAmount"[[:space:]]*:[[:space:]]*(-?[0-9]+(\.[0-9]+)?)'
|
||||||
|
)::numeric AS rebate_amount
|
||||||
|
) rebate_detail
|
||||||
|
JOIN payment_orders po ON po.id::text = pal.order_id
|
||||||
|
JOIN user_affiliates invitee_aff ON invitee_aff.user_id = po.user_id
|
||||||
|
WHERE pal.action = 'AFFILIATE_REBATE_APPLIED'
|
||||||
|
AND rebate_detail.rebate_amount IS NOT NULL
|
||||||
|
),
|
||||||
|
ranked_matches AS (
|
||||||
|
SELECT ual.id AS ledger_id,
|
||||||
|
ra.order_id,
|
||||||
|
COUNT(*) OVER (PARTITION BY ra.order_id) AS order_match_count,
|
||||||
|
COUNT(*) OVER (PARTITION BY ual.id) AS ledger_match_count,
|
||||||
|
ROW_NUMBER() OVER (
|
||||||
|
PARTITION BY ual.id
|
||||||
|
ORDER BY ABS(EXTRACT(EPOCH FROM (ual.created_at - ra.audit_created_at))), ra.order_id
|
||||||
|
) AS ledger_rank
|
||||||
|
FROM rebate_audits ra
|
||||||
|
JOIN user_affiliate_ledger ual
|
||||||
|
ON ual.action = 'accrue'
|
||||||
|
AND ual.source_order_id IS NULL
|
||||||
|
AND ual.user_id = ra.inviter_id
|
||||||
|
AND ual.source_user_id = ra.invitee_user_id
|
||||||
|
AND ABS(ual.amount - ra.rebate_amount) < 0.00000001
|
||||||
|
AND ual.created_at BETWEEN ra.audit_created_at - INTERVAL '10 minutes'
|
||||||
|
AND ra.audit_created_at + INTERVAL '10 minutes'
|
||||||
|
)
|
||||||
|
UPDATE user_affiliate_ledger ual
|
||||||
|
SET source_order_id = ranked_matches.order_id,
|
||||||
|
updated_at = NOW()
|
||||||
|
FROM ranked_matches
|
||||||
|
WHERE ual.id = ranked_matches.ledger_id
|
||||||
|
AND ranked_matches.order_match_count = 1
|
||||||
|
AND ranked_matches.ledger_match_count = 1
|
||||||
|
AND ranked_matches.ledger_rank = 1
|
||||||
|
AND NOT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM user_affiliate_ledger existing
|
||||||
|
WHERE existing.source_order_id = ranked_matches.order_id
|
||||||
|
AND existing.action = 'accrue'
|
||||||
|
);
|
||||||
@ -127,3 +127,18 @@ func TestMigration124BackfillsLegacyOIDCSecurityFlagsSafely(t *testing.T) {
|
|||||||
require.Contains(t, sql, "oidc_connect_enabled")
|
require.Contains(t, sql, "oidc_connect_enabled")
|
||||||
require.Contains(t, sql, "'false'")
|
require.Contains(t, sql, "'false'")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMigration134AddsAffiliateLedgerAuditFieldsWithoutJSONCast(t *testing.T) {
|
||||||
|
content, err := FS.ReadFile("134_affiliate_ledger_audit_snapshots.sql")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
sql := string(content)
|
||||||
|
require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS source_order_id BIGINT")
|
||||||
|
require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS balance_after DECIMAL(20,8)")
|
||||||
|
require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS aff_quota_after DECIMAL(20,8)")
|
||||||
|
require.Contains(t, sql, "substring(")
|
||||||
|
require.Contains(t, sql, `"rebateAmount"`)
|
||||||
|
require.Contains(t, sql, "COUNT(*) OVER (PARTITION BY ra.order_id) AS order_match_count")
|
||||||
|
require.Contains(t, sql, "COUNT(*) OVER (PARTITION BY ual.id) AS ledger_match_count")
|
||||||
|
require.NotContains(t, sql, "detail::jsonb")
|
||||||
|
}
|
||||||
|
|||||||
@ -23,6 +23,72 @@ export interface ListAffiliateUsersParams {
|
|||||||
search?: string
|
search?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface ListAffiliateRecordsParams {
|
||||||
|
page?: number
|
||||||
|
page_size?: number
|
||||||
|
search?: string
|
||||||
|
start_at?: string
|
||||||
|
end_at?: string
|
||||||
|
sort_by?: string
|
||||||
|
sort_order?: 'asc' | 'desc'
|
||||||
|
timezone?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AffiliateInviteRecord {
|
||||||
|
inviter_id: number
|
||||||
|
inviter_email: string
|
||||||
|
inviter_username: string
|
||||||
|
invitee_id: number
|
||||||
|
invitee_email: string
|
||||||
|
invitee_username: string
|
||||||
|
aff_code: string
|
||||||
|
total_rebate: number
|
||||||
|
created_at: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AffiliateRebateRecord {
|
||||||
|
order_id: number
|
||||||
|
out_trade_no: string
|
||||||
|
inviter_id: number
|
||||||
|
inviter_email: string
|
||||||
|
inviter_username: string
|
||||||
|
invitee_id: number
|
||||||
|
invitee_email: string
|
||||||
|
invitee_username: string
|
||||||
|
order_amount: number
|
||||||
|
pay_amount: number
|
||||||
|
rebate_amount: number
|
||||||
|
payment_type: string
|
||||||
|
order_status: string
|
||||||
|
created_at: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AffiliateTransferRecord {
|
||||||
|
ledger_id: number
|
||||||
|
user_id: number
|
||||||
|
user_email: string
|
||||||
|
username: string
|
||||||
|
amount: number
|
||||||
|
balance_after?: number | null
|
||||||
|
available_quota_after?: number | null
|
||||||
|
frozen_quota_after?: number | null
|
||||||
|
history_quota_after?: number | null
|
||||||
|
snapshot_available: boolean
|
||||||
|
created_at: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AffiliateUserOverview {
|
||||||
|
user_id: number
|
||||||
|
email: string
|
||||||
|
username: string
|
||||||
|
aff_code: string
|
||||||
|
rebate_rate_percent: number
|
||||||
|
invited_count: number
|
||||||
|
rebated_invitee_count: number
|
||||||
|
available_quota: number
|
||||||
|
history_quota: number
|
||||||
|
}
|
||||||
|
|
||||||
export interface UpdateAffiliateUserRequest {
|
export interface UpdateAffiliateUserRequest {
|
||||||
aff_code?: string
|
aff_code?: string
|
||||||
aff_rebate_rate_percent?: number | null
|
aff_rebate_rate_percent?: number | null
|
||||||
@ -97,12 +163,68 @@ export async function batchSetRate(
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function recordParams(params: ListAffiliateRecordsParams = {}) {
|
||||||
|
return {
|
||||||
|
page: params.page ?? 1,
|
||||||
|
page_size: params.page_size ?? 20,
|
||||||
|
search: params.search ?? '',
|
||||||
|
start_at: params.start_at || undefined,
|
||||||
|
end_at: params.end_at || undefined,
|
||||||
|
sort_by: params.sort_by || undefined,
|
||||||
|
sort_order: params.sort_order || undefined,
|
||||||
|
timezone: params.timezone || undefined,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function listInviteRecords(
|
||||||
|
params: ListAffiliateRecordsParams = {},
|
||||||
|
): Promise<PaginatedResponse<AffiliateInviteRecord>> {
|
||||||
|
const { data } = await apiClient.get<PaginatedResponse<AffiliateInviteRecord>>(
|
||||||
|
'/admin/affiliates/invites',
|
||||||
|
{ params: recordParams(params) },
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function listRebateRecords(
|
||||||
|
params: ListAffiliateRecordsParams = {},
|
||||||
|
): Promise<PaginatedResponse<AffiliateRebateRecord>> {
|
||||||
|
const { data } = await apiClient.get<PaginatedResponse<AffiliateRebateRecord>>(
|
||||||
|
'/admin/affiliates/rebates',
|
||||||
|
{ params: recordParams(params) },
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function listTransferRecords(
|
||||||
|
params: ListAffiliateRecordsParams = {},
|
||||||
|
): Promise<PaginatedResponse<AffiliateTransferRecord>> {
|
||||||
|
const { data } = await apiClient.get<PaginatedResponse<AffiliateTransferRecord>>(
|
||||||
|
'/admin/affiliates/transfers',
|
||||||
|
{ params: recordParams(params) },
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function getUserOverview(
|
||||||
|
userId: number,
|
||||||
|
): Promise<AffiliateUserOverview> {
|
||||||
|
const { data } = await apiClient.get<AffiliateUserOverview>(
|
||||||
|
`/admin/affiliates/users/${userId}/overview`,
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
export const affiliatesAPI = {
|
export const affiliatesAPI = {
|
||||||
listUsers,
|
listUsers,
|
||||||
lookupUsers,
|
lookupUsers,
|
||||||
updateUserSettings,
|
updateUserSettings,
|
||||||
clearUserSettings,
|
clearUserSettings,
|
||||||
batchSetRate,
|
batchSetRate,
|
||||||
|
listInviteRecords,
|
||||||
|
listRebateRecords,
|
||||||
|
listTransferRecords,
|
||||||
|
getUserOverview,
|
||||||
}
|
}
|
||||||
|
|
||||||
export default affiliatesAPI
|
export default affiliatesAPI
|
||||||
|
|||||||
@ -249,7 +249,7 @@ export interface BalanceHistoryResponse extends PaginatedResponse<BalanceHistory
|
|||||||
* @param id - User ID
|
* @param id - User ID
|
||||||
* @param page - Page number
|
* @param page - Page number
|
||||||
* @param pageSize - Items per page
|
* @param pageSize - Items per page
|
||||||
* @param type - Optional type filter (balance, admin_balance, concurrency, admin_concurrency, subscription)
|
* @param type - Optional type filter (balance, affiliate_balance, admin_balance, concurrency, admin_concurrency, subscription)
|
||||||
* @returns Paginated balance history with total_recharged
|
* @returns Paginated balance history with total_recharged
|
||||||
*/
|
*/
|
||||||
export async function getUserBalanceHistory(
|
export async function getUserBalanceHistory(
|
||||||
|
|||||||
@ -196,6 +196,7 @@ const totalPages = computed(() => Math.ceil(total.value / pageSize) || 1)
|
|||||||
const typeOptions = computed(() => [
|
const typeOptions = computed(() => [
|
||||||
{ value: '', label: t('admin.users.allTypes') },
|
{ value: '', label: t('admin.users.allTypes') },
|
||||||
{ value: 'balance', label: t('admin.users.typeBalance') },
|
{ value: 'balance', label: t('admin.users.typeBalance') },
|
||||||
|
{ value: 'affiliate_balance', label: t('admin.users.typeAffiliateBalance') },
|
||||||
{ value: 'admin_balance', label: t('admin.users.typeAdminBalance') },
|
{ value: 'admin_balance', label: t('admin.users.typeAdminBalance') },
|
||||||
{ value: 'concurrency', label: t('admin.users.typeConcurrency') },
|
{ value: 'concurrency', label: t('admin.users.typeConcurrency') },
|
||||||
{ value: 'admin_concurrency', label: t('admin.users.typeAdminConcurrency') },
|
{ value: 'admin_concurrency', label: t('admin.users.typeAdminConcurrency') },
|
||||||
@ -235,7 +236,7 @@ const loadHistory = async (page: number) => {
|
|||||||
const isAdminType = (type: string) => type === 'admin_balance' || type === 'admin_concurrency'
|
const isAdminType = (type: string) => type === 'admin_balance' || type === 'admin_concurrency'
|
||||||
|
|
||||||
// Helper: check if balance type (includes admin_balance)
|
// Helper: check if balance type (includes admin_balance)
|
||||||
const isBalanceType = (type: string) => type === 'balance' || type === 'admin_balance'
|
const isBalanceType = (type: string) => type === 'balance' || type === 'admin_balance' || type === 'affiliate_balance'
|
||||||
|
|
||||||
// Helper: check if subscription type
|
// Helper: check if subscription type
|
||||||
const isSubscriptionType = (type: string) => type === 'subscription'
|
const isSubscriptionType = (type: string) => type === 'subscription'
|
||||||
@ -291,6 +292,8 @@ const getItemTitle = (item: BalanceHistoryItem) => {
|
|||||||
switch (item.type) {
|
switch (item.type) {
|
||||||
case 'balance':
|
case 'balance':
|
||||||
return t('redeem.balanceAddedRedeem')
|
return t('redeem.balanceAddedRedeem')
|
||||||
|
case 'affiliate_balance':
|
||||||
|
return t('redeem.balanceAddedAffiliate')
|
||||||
case 'admin_balance':
|
case 'admin_balance':
|
||||||
return item.value >= 0 ? t('redeem.balanceAddedAdmin') : t('redeem.balanceDeductedAdmin')
|
return item.value >= 0 ? t('redeem.balanceAddedAdmin') : t('redeem.balanceDeductedAdmin')
|
||||||
case 'concurrency':
|
case 'concurrency':
|
||||||
|
|||||||
@ -721,6 +721,19 @@ const adminNavItems = computed((): NavItem[] => {
|
|||||||
{ path: '/admin/proxies', label: t('nav.proxies'), icon: ServerIcon },
|
{ path: '/admin/proxies', label: t('nav.proxies'), icon: ServerIcon },
|
||||||
{ path: '/admin/redeem', label: t('nav.redeemCodes'), icon: TicketIcon, hideInSimpleMode: true },
|
{ path: '/admin/redeem', label: t('nav.redeemCodes'), icon: TicketIcon, hideInSimpleMode: true },
|
||||||
{ path: '/admin/promo-codes', label: t('nav.promoCodes'), icon: GiftIcon, hideInSimpleMode: true },
|
{ path: '/admin/promo-codes', label: t('nav.promoCodes'), icon: GiftIcon, hideInSimpleMode: true },
|
||||||
|
{
|
||||||
|
path: '/admin/affiliates',
|
||||||
|
label: t('nav.affiliateManagement'),
|
||||||
|
icon: UsersIcon,
|
||||||
|
hideInSimpleMode: true,
|
||||||
|
expandOnly: true,
|
||||||
|
featureFlag: flagAffiliate,
|
||||||
|
children: [
|
||||||
|
{ path: '/admin/affiliates/invites', label: t('nav.affiliateInviteRecords'), icon: UsersIcon },
|
||||||
|
{ path: '/admin/affiliates/rebates', label: t('nav.affiliateRebateRecords'), icon: OrderIcon },
|
||||||
|
{ path: '/admin/affiliates/transfers', label: t('nav.affiliateTransferRecords'), icon: CreditCardIcon },
|
||||||
|
],
|
||||||
|
},
|
||||||
{
|
{
|
||||||
path: '/admin/orders',
|
path: '/admin/orders',
|
||||||
label: t('nav.orderManagement'),
|
label: t('nav.orderManagement'),
|
||||||
|
|||||||
@ -347,6 +347,10 @@ export default {
|
|||||||
usage: 'Usage',
|
usage: 'Usage',
|
||||||
redeem: 'Redeem',
|
redeem: 'Redeem',
|
||||||
affiliate: 'Affiliate Rebates',
|
affiliate: 'Affiliate Rebates',
|
||||||
|
affiliateManagement: 'Affiliate Rebates',
|
||||||
|
affiliateInviteRecords: 'Invite Records',
|
||||||
|
affiliateRebateRecords: 'Rebate Records',
|
||||||
|
affiliateTransferRecords: 'Transfer Records',
|
||||||
profile: 'Profile',
|
profile: 'Profile',
|
||||||
users: 'Users',
|
users: 'Users',
|
||||||
groups: 'Groups',
|
groups: 'Groups',
|
||||||
@ -1046,6 +1050,7 @@ export default {
|
|||||||
recentActivity: 'Recent Activity',
|
recentActivity: 'Recent Activity',
|
||||||
historyWillAppear: 'Your redemption history will appear here',
|
historyWillAppear: 'Your redemption history will appear here',
|
||||||
balanceAddedRedeem: 'Balance Added (Redeem)',
|
balanceAddedRedeem: 'Balance Added (Redeem)',
|
||||||
|
balanceAddedAffiliate: 'Balance Added (Affiliate Transfer)',
|
||||||
balanceAddedAdmin: 'Balance Added (Admin)',
|
balanceAddedAdmin: 'Balance Added (Admin)',
|
||||||
balanceDeductedAdmin: 'Balance Deducted (Admin)',
|
balanceDeductedAdmin: 'Balance Deducted (Admin)',
|
||||||
concurrencyAddedRedeem: 'Concurrency Added (Redeem)',
|
concurrencyAddedRedeem: 'Concurrency Added (Redeem)',
|
||||||
@ -1635,6 +1640,49 @@ export default {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
affiliates: {
|
||||||
|
invitesDescription: 'View site-wide inviter and invitee relationships',
|
||||||
|
rebatesDescription: 'View recharge orders that generated affiliate rebates',
|
||||||
|
transfersDescription: 'View affiliate quota transfers into account balance',
|
||||||
|
errors: {
|
||||||
|
loadFailed: 'Failed to load affiliate records'
|
||||||
|
},
|
||||||
|
records: {
|
||||||
|
search: 'Search',
|
||||||
|
searchPlaceholder: 'Email, username, user ID, or order number',
|
||||||
|
startAt: 'Start date',
|
||||||
|
endAt: 'End date',
|
||||||
|
inviter: 'Inviter',
|
||||||
|
invitee: 'Invitee',
|
||||||
|
user: 'User',
|
||||||
|
affCode: 'Invite Code',
|
||||||
|
order: 'Order',
|
||||||
|
totalRebate: 'Total Rebate',
|
||||||
|
orderAmount: 'Top-up Amount',
|
||||||
|
payAmount: 'Paid Amount',
|
||||||
|
rebateAmount: 'Rebate Amount',
|
||||||
|
paymentType: 'Payment Method',
|
||||||
|
orderStatus: 'Order Status',
|
||||||
|
transferAmount: 'Transfer Amount',
|
||||||
|
balanceAfter: 'Balance After',
|
||||||
|
availableQuotaAfter: 'Available After',
|
||||||
|
frozenQuotaAfter: 'Frozen After',
|
||||||
|
historyQuotaAfter: 'Historical Rebate After',
|
||||||
|
invitedAt: 'Invited At',
|
||||||
|
rebatedAt: 'Rebated At',
|
||||||
|
transferredAt: 'Transferred At'
|
||||||
|
},
|
||||||
|
overview: {
|
||||||
|
title: 'Affiliate User Overview',
|
||||||
|
affCode: 'Invite Code',
|
||||||
|
rebateRate: 'Rebate Rate',
|
||||||
|
invitedCount: 'Invited Users',
|
||||||
|
rebatedInviteeCount: 'Rebated Invitees',
|
||||||
|
availableQuota: 'Available Quota',
|
||||||
|
historyQuota: 'Historical Rebate'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
// Users
|
// Users
|
||||||
users: {
|
users: {
|
||||||
title: 'User Management',
|
title: 'User Management',
|
||||||
@ -1787,6 +1835,7 @@ export default {
|
|||||||
noBalanceHistory: 'No records found for this user',
|
noBalanceHistory: 'No records found for this user',
|
||||||
allTypes: 'All Types',
|
allTypes: 'All Types',
|
||||||
typeBalance: 'Balance (Redeem)',
|
typeBalance: 'Balance (Redeem)',
|
||||||
|
typeAffiliateBalance: 'Balance (Affiliate Transfer)',
|
||||||
typeAdminBalance: 'Balance (Admin)',
|
typeAdminBalance: 'Balance (Admin)',
|
||||||
typeConcurrency: 'Concurrency (Redeem)',
|
typeConcurrency: 'Concurrency (Redeem)',
|
||||||
typeAdminConcurrency: 'Concurrency (Admin)',
|
typeAdminConcurrency: 'Concurrency (Admin)',
|
||||||
|
|||||||
@ -347,6 +347,10 @@ export default {
|
|||||||
usage: '使用记录',
|
usage: '使用记录',
|
||||||
redeem: '兑换',
|
redeem: '兑换',
|
||||||
affiliate: '邀请返利',
|
affiliate: '邀请返利',
|
||||||
|
affiliateManagement: '邀请返利',
|
||||||
|
affiliateInviteRecords: '邀请记录',
|
||||||
|
affiliateRebateRecords: '返利记录',
|
||||||
|
affiliateTransferRecords: '提取记录',
|
||||||
profile: '个人资料',
|
profile: '个人资料',
|
||||||
users: '用户管理',
|
users: '用户管理',
|
||||||
groups: '分组管理',
|
groups: '分组管理',
|
||||||
@ -1050,6 +1054,7 @@ export default {
|
|||||||
recentActivity: '最近活动',
|
recentActivity: '最近活动',
|
||||||
historyWillAppear: '您的兑换历史将显示在这里',
|
historyWillAppear: '您的兑换历史将显示在这里',
|
||||||
balanceAddedRedeem: '余额充值(兑换)',
|
balanceAddedRedeem: '余额充值(兑换)',
|
||||||
|
balanceAddedAffiliate: '余额充值(返利转入)',
|
||||||
balanceAddedAdmin: '余额充值(管理员)',
|
balanceAddedAdmin: '余额充值(管理员)',
|
||||||
balanceDeductedAdmin: '余额扣除(管理员)',
|
balanceDeductedAdmin: '余额扣除(管理员)',
|
||||||
concurrencyAddedRedeem: '并发增加(兑换)',
|
concurrencyAddedRedeem: '并发增加(兑换)',
|
||||||
@ -1656,6 +1661,49 @@ export default {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
affiliates: {
|
||||||
|
invitesDescription: '查看全站邀请关系和被邀请用户累计返利',
|
||||||
|
rebatesDescription: '查看每一笔产生返利的充值订单',
|
||||||
|
transfersDescription: '查看返利额度转入账户余额的提取流水',
|
||||||
|
errors: {
|
||||||
|
loadFailed: '加载邀请返利记录失败'
|
||||||
|
},
|
||||||
|
records: {
|
||||||
|
search: '搜索',
|
||||||
|
searchPlaceholder: '邮箱、用户名、用户 ID、订单号',
|
||||||
|
startAt: '开始日期',
|
||||||
|
endAt: '结束日期',
|
||||||
|
inviter: '邀请人',
|
||||||
|
invitee: '被邀请人',
|
||||||
|
user: '用户',
|
||||||
|
affCode: '邀请码',
|
||||||
|
order: '订单',
|
||||||
|
totalRebate: '累计返利',
|
||||||
|
orderAmount: '充值金额',
|
||||||
|
payAmount: '支付金额',
|
||||||
|
rebateAmount: '返利金额',
|
||||||
|
paymentType: '支付方式',
|
||||||
|
orderStatus: '订单状态',
|
||||||
|
transferAmount: '提取金额',
|
||||||
|
balanceAfter: '提取后余额',
|
||||||
|
availableQuotaAfter: '提取后可提',
|
||||||
|
frozenQuotaAfter: '提取后冻结',
|
||||||
|
historyQuotaAfter: '提取后历史返利',
|
||||||
|
invitedAt: '邀请时间',
|
||||||
|
rebatedAt: '返利时间',
|
||||||
|
transferredAt: '提取时间'
|
||||||
|
},
|
||||||
|
overview: {
|
||||||
|
title: '用户返利概览',
|
||||||
|
affCode: '邀请码',
|
||||||
|
rebateRate: '返利比例',
|
||||||
|
invitedCount: '邀请人数',
|
||||||
|
rebatedInviteeCount: '已产生返利人数',
|
||||||
|
availableQuota: '可提余额',
|
||||||
|
historyQuota: '历史返利'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
// Users Management
|
// Users Management
|
||||||
users: {
|
users: {
|
||||||
title: '用户管理',
|
title: '用户管理',
|
||||||
@ -1844,6 +1892,7 @@ export default {
|
|||||||
noBalanceHistory: '暂无变动记录',
|
noBalanceHistory: '暂无变动记录',
|
||||||
allTypes: '全部类型',
|
allTypes: '全部类型',
|
||||||
typeBalance: '余额(兑换码)',
|
typeBalance: '余额(兑换码)',
|
||||||
|
typeAffiliateBalance: '余额(返利转入)',
|
||||||
typeAdminBalance: '余额(管理员调整)',
|
typeAdminBalance: '余额(管理员调整)',
|
||||||
typeConcurrency: '并发(兑换码)',
|
typeConcurrency: '并发(兑换码)',
|
||||||
typeAdminConcurrency: '并发(管理员调整)',
|
typeAdminConcurrency: '并发(管理员调整)',
|
||||||
|
|||||||
@ -517,6 +517,46 @@ const routes: RouteRecordRaw[] = [
|
|||||||
descriptionKey: 'admin.usage.description'
|
descriptionKey: 'admin.usage.description'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
path: '/admin/affiliates',
|
||||||
|
redirect: '/admin/affiliates/invites'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: '/admin/affiliates/invites',
|
||||||
|
name: 'AdminAffiliateInvites',
|
||||||
|
component: () => import('@/views/admin/affiliates/AdminAffiliateInvitesView.vue'),
|
||||||
|
meta: {
|
||||||
|
requiresAuth: true,
|
||||||
|
requiresAdmin: true,
|
||||||
|
title: 'Affiliate Invite Records',
|
||||||
|
titleKey: 'nav.affiliateInviteRecords',
|
||||||
|
descriptionKey: 'admin.affiliates.invitesDescription'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: '/admin/affiliates/rebates',
|
||||||
|
name: 'AdminAffiliateRebates',
|
||||||
|
component: () => import('@/views/admin/affiliates/AdminAffiliateRebatesView.vue'),
|
||||||
|
meta: {
|
||||||
|
requiresAuth: true,
|
||||||
|
requiresAdmin: true,
|
||||||
|
title: 'Affiliate Rebate Records',
|
||||||
|
titleKey: 'nav.affiliateRebateRecords',
|
||||||
|
descriptionKey: 'admin.affiliates.rebatesDescription'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: '/admin/affiliates/transfers',
|
||||||
|
name: 'AdminAffiliateTransfers',
|
||||||
|
component: () => import('@/views/admin/affiliates/AdminAffiliateTransfersView.vue'),
|
||||||
|
meta: {
|
||||||
|
requiresAuth: true,
|
||||||
|
requiresAdmin: true,
|
||||||
|
title: 'Affiliate Transfer Records',
|
||||||
|
titleKey: 'nav.affiliateTransferRecords',
|
||||||
|
descriptionKey: 'admin.affiliates.transfersDescription'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
|
||||||
// ==================== Payment Admin Routes ====================
|
// ==================== Payment Admin Routes ====================
|
||||||
|
|||||||
@ -0,0 +1,7 @@
|
|||||||
|
<template>
|
||||||
|
<AdminAffiliateRecordsTable type="invites" />
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import AdminAffiliateRecordsTable from './AdminAffiliateRecordsTable.vue'
|
||||||
|
</script>
|
||||||
@ -0,0 +1,7 @@
|
|||||||
|
<template>
|
||||||
|
<AdminAffiliateRecordsTable type="rebates" />
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import AdminAffiliateRecordsTable from './AdminAffiliateRecordsTable.vue'
|
||||||
|
</script>
|
||||||
@ -0,0 +1,407 @@
|
|||||||
|
<template>
|
||||||
|
<AppLayout>
|
||||||
|
<TablePageLayout>
|
||||||
|
<template #filters>
|
||||||
|
<div class="flex flex-wrap items-center gap-3">
|
||||||
|
<div class="relative w-full md:w-80">
|
||||||
|
<Icon name="search" size="md" class="absolute left-3 top-1/2 -translate-y-1/2 text-gray-400" />
|
||||||
|
<input v-model="filters.search" type="text" class="input pl-10" :placeholder="t('admin.affiliates.records.searchPlaceholder')" @input="debounceLoad" />
|
||||||
|
</div>
|
||||||
|
<input v-model="filters.start_at" type="date" class="input w-full sm:w-44" :title="t('admin.affiliates.records.startAt')" @change="reloadFromFirstPage" />
|
||||||
|
<input v-model="filters.end_at" type="date" class="input w-full sm:w-44" :title="t('admin.affiliates.records.endAt')" @change="reloadFromFirstPage" />
|
||||||
|
<button class="btn btn-secondary px-2 md:px-3" :disabled="loading" :title="t('common.refresh')" @click="loadRecords">
|
||||||
|
<Icon name="refresh" size="md" :class="loading ? 'animate-spin' : ''" />
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<template #table>
|
||||||
|
<DataTable
|
||||||
|
:columns="columns"
|
||||||
|
:data="records"
|
||||||
|
:loading="loading"
|
||||||
|
:server-side-sort="true"
|
||||||
|
default-sort-key="created_at"
|
||||||
|
default-sort-order="desc"
|
||||||
|
:sort-storage-key="sortStorageKey"
|
||||||
|
@sort="handleSort"
|
||||||
|
>
|
||||||
|
<template #cell-inviter="{ row }">
|
||||||
|
<UserCell
|
||||||
|
:id="row.inviter_id"
|
||||||
|
:email="row.inviter_email"
|
||||||
|
:username="row.inviter_username"
|
||||||
|
:clickable="props.type !== 'transfers'"
|
||||||
|
@open="openUserOverview"
|
||||||
|
/>
|
||||||
|
</template>
|
||||||
|
<template #cell-invitee="{ row }">
|
||||||
|
<UserCell
|
||||||
|
:id="row.invitee_id"
|
||||||
|
:email="row.invitee_email"
|
||||||
|
:username="row.invitee_username"
|
||||||
|
:clickable="props.type !== 'transfers'"
|
||||||
|
@open="openUserOverview"
|
||||||
|
/>
|
||||||
|
</template>
|
||||||
|
<template #cell-user="{ row }">
|
||||||
|
<UserCell
|
||||||
|
:id="row.user_id"
|
||||||
|
:email="row.user_email"
|
||||||
|
:username="row.username"
|
||||||
|
:clickable="true"
|
||||||
|
@open="openUserOverview"
|
||||||
|
/>
|
||||||
|
</template>
|
||||||
|
<template #cell-aff_code="{ row }">
|
||||||
|
<span class="font-mono text-sm text-gray-700 dark:text-gray-300">{{ row.aff_code || '-' }}</span>
|
||||||
|
</template>
|
||||||
|
<template #cell-order="{ row }">
|
||||||
|
<div class="space-y-0.5">
|
||||||
|
<div class="font-mono text-sm text-gray-900 dark:text-white">#{{ row.order_id }}</div>
|
||||||
|
<div class="max-w-56 truncate text-sm text-gray-500 dark:text-dark-400">{{ row.out_trade_no }}</div>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<template #cell-payment_type="{ row }">
|
||||||
|
{{ t('payment.methods.' + row.payment_type, row.payment_type || '-') }}
|
||||||
|
</template>
|
||||||
|
<template #cell-order_status="{ row }">
|
||||||
|
<OrderStatusBadge :status="row.order_status" />
|
||||||
|
</template>
|
||||||
|
<template #cell-total_rebate="{ row }">
|
||||||
|
<AmountText :value="row.total_rebate" />
|
||||||
|
</template>
|
||||||
|
<template #cell-order_amount="{ row }">
|
||||||
|
<AmountText :value="row.order_amount" />
|
||||||
|
</template>
|
||||||
|
<template #cell-pay_amount="{ row }">
|
||||||
|
<span class="text-sm text-gray-900 dark:text-white">¥{{ formatAmount(row.pay_amount) }}</span>
|
||||||
|
</template>
|
||||||
|
<template #cell-rebate_amount="{ row }">
|
||||||
|
<AmountText :value="row.rebate_amount" strong />
|
||||||
|
</template>
|
||||||
|
<template #cell-amount="{ row }">
|
||||||
|
<AmountText :value="row.amount" strong />
|
||||||
|
</template>
|
||||||
|
<template #cell-balance_after="{ row }">
|
||||||
|
<NullableAmountText :value="row.balance_after" />
|
||||||
|
</template>
|
||||||
|
<template #cell-available_quota_after="{ row }">
|
||||||
|
<NullableAmountText :value="row.available_quota_after" />
|
||||||
|
</template>
|
||||||
|
<template #cell-frozen_quota_after="{ row }">
|
||||||
|
<NullableAmountText :value="row.frozen_quota_after" />
|
||||||
|
</template>
|
||||||
|
<template #cell-history_quota_after="{ row }">
|
||||||
|
<NullableAmountText :value="row.history_quota_after" />
|
||||||
|
</template>
|
||||||
|
<template #cell-created_at="{ row }">
|
||||||
|
<span class="text-sm text-gray-700 dark:text-gray-300">{{ formatDateTime(row.created_at) }}</span>
|
||||||
|
</template>
|
||||||
|
</DataTable>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<template #pagination>
|
||||||
|
<Pagination
|
||||||
|
v-if="pagination.total > 0"
|
||||||
|
:page="pagination.page"
|
||||||
|
:total="pagination.total"
|
||||||
|
:page-size="pagination.page_size"
|
||||||
|
@update:page="handlePageChange"
|
||||||
|
@update:pageSize="handlePageSizeChange"
|
||||||
|
/>
|
||||||
|
</template>
|
||||||
|
</TablePageLayout>
|
||||||
|
|
||||||
|
<BaseDialog
|
||||||
|
:show="overviewDialog"
|
||||||
|
:title="t('admin.affiliates.overview.title')"
|
||||||
|
width="normal"
|
||||||
|
@close="overviewDialog = false"
|
||||||
|
>
|
||||||
|
<div v-if="overviewLoading" class="flex justify-center py-8">
|
||||||
|
<div class="h-6 w-6 animate-spin rounded-full border-2 border-primary-500 border-t-transparent"></div>
|
||||||
|
</div>
|
||||||
|
<div v-else-if="selectedOverview" class="space-y-4">
|
||||||
|
<div class="rounded-lg border border-gray-100 bg-gray-50 p-4 dark:border-dark-700 dark:bg-dark-800">
|
||||||
|
<div class="font-mono text-sm text-gray-900 dark:text-white">#{{ selectedOverview.user_id }}</div>
|
||||||
|
<div class="mt-1 text-sm font-medium text-gray-900 dark:text-white">{{ selectedOverview.email || '-' }}</div>
|
||||||
|
<div class="mt-0.5 text-sm text-gray-500 dark:text-dark-400">{{ selectedOverview.username || '-' }}</div>
|
||||||
|
</div>
|
||||||
|
<div class="grid gap-3 sm:grid-cols-2">
|
||||||
|
<OverviewStat :label="t('admin.affiliates.overview.affCode')" :value="selectedOverview.aff_code || '-'" mono />
|
||||||
|
<OverviewStat :label="t('admin.affiliates.overview.rebateRate')" :value="formatPercent(selectedOverview.rebate_rate_percent)" />
|
||||||
|
<OverviewStat :label="t('admin.affiliates.overview.invitedCount')" :value="String(selectedOverview.invited_count)" />
|
||||||
|
<OverviewStat :label="t('admin.affiliates.overview.rebatedInviteeCount')" :value="String(selectedOverview.rebated_invitee_count)" />
|
||||||
|
<OverviewStat :label="t('admin.affiliates.overview.availableQuota')" :value="'$' + formatAmount(selectedOverview.available_quota)" />
|
||||||
|
<OverviewStat :label="t('admin.affiliates.overview.historyQuota')" :value="'$' + formatAmount(selectedOverview.history_quota)" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</BaseDialog>
|
||||||
|
</AppLayout>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import { computed, defineComponent, h, onMounted, reactive, ref, type PropType } from 'vue'
|
||||||
|
import { useI18n } from 'vue-i18n'
|
||||||
|
import AppLayout from '@/components/layout/AppLayout.vue'
|
||||||
|
import TablePageLayout from '@/components/layout/TablePageLayout.vue'
|
||||||
|
import DataTable from '@/components/common/DataTable.vue'
|
||||||
|
import Pagination from '@/components/common/Pagination.vue'
|
||||||
|
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||||
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
|
import OrderStatusBadge from '@/components/payment/OrderStatusBadge.vue'
|
||||||
|
import type { Column } from '@/components/common/types'
|
||||||
|
import { useAppStore } from '@/stores/app'
|
||||||
|
import { affiliatesAPI, type AffiliateInviteRecord, type AffiliateRebateRecord, type AffiliateTransferRecord, type AffiliateUserOverview, type ListAffiliateRecordsParams } from '@/api/admin/affiliates'
|
||||||
|
import type { PaginatedResponse } from '@/types'
|
||||||
|
import { extractI18nErrorMessage } from '@/utils/apiError'
|
||||||
|
import { formatDateTime as formatDisplayDateTime } from '@/utils/format'
|
||||||
|
|
||||||
|
type RecordType = 'invites' | 'rebates' | 'transfers'
|
||||||
|
type AffiliateRecord = AffiliateInviteRecord | AffiliateRebateRecord | AffiliateTransferRecord
|
||||||
|
|
||||||
|
const props = defineProps<{
|
||||||
|
type: RecordType
|
||||||
|
}>()
|
||||||
|
|
||||||
|
const { t } = useI18n()
|
||||||
|
const appStore = useAppStore()
|
||||||
|
const loading = ref(false)
|
||||||
|
const records = ref<AffiliateRecord[]>([])
|
||||||
|
const filters = reactive({ search: '', start_at: '', end_at: '' })
|
||||||
|
const pagination = reactive({ page: 1, page_size: 20, total: 0 })
|
||||||
|
const overviewDialog = ref(false)
|
||||||
|
const overviewLoading = ref(false)
|
||||||
|
const selectedOverview = ref<AffiliateUserOverview | null>(null)
|
||||||
|
let debounceTimer: ReturnType<typeof setTimeout> | null = null
|
||||||
|
|
||||||
|
const columns = computed<Column[]>(() => {
|
||||||
|
if (props.type === 'invites') {
|
||||||
|
return [
|
||||||
|
{ key: 'inviter', label: t('admin.affiliates.records.inviter'), sortable: true },
|
||||||
|
{ key: 'invitee', label: t('admin.affiliates.records.invitee'), sortable: true },
|
||||||
|
{ key: 'aff_code', label: t('admin.affiliates.records.affCode'), sortable: true },
|
||||||
|
{ key: 'total_rebate', label: t('admin.affiliates.records.totalRebate'), sortable: true },
|
||||||
|
{ key: 'created_at', label: t('admin.affiliates.records.invitedAt'), sortable: true },
|
||||||
|
]
|
||||||
|
}
|
||||||
|
if (props.type === 'rebates') {
|
||||||
|
return [
|
||||||
|
{ key: 'order', label: t('admin.affiliates.records.order'), sortable: true },
|
||||||
|
{ key: 'inviter', label: t('admin.affiliates.records.inviter'), sortable: true },
|
||||||
|
{ key: 'invitee', label: t('admin.affiliates.records.invitee'), sortable: true },
|
||||||
|
{ key: 'order_amount', label: t('admin.affiliates.records.orderAmount'), sortable: true },
|
||||||
|
{ key: 'pay_amount', label: t('admin.affiliates.records.payAmount'), sortable: true },
|
||||||
|
{ key: 'rebate_amount', label: t('admin.affiliates.records.rebateAmount') },
|
||||||
|
{ key: 'payment_type', label: t('admin.affiliates.records.paymentType'), sortable: true },
|
||||||
|
{ key: 'order_status', label: t('admin.affiliates.records.orderStatus'), sortable: true },
|
||||||
|
{ key: 'created_at', label: t('admin.affiliates.records.rebatedAt'), sortable: true },
|
||||||
|
]
|
||||||
|
}
|
||||||
|
return [
|
||||||
|
{ key: 'user', label: t('admin.affiliates.records.user'), sortable: true },
|
||||||
|
{ key: 'amount', label: t('admin.affiliates.records.transferAmount'), sortable: true },
|
||||||
|
{ key: 'balance_after', label: t('admin.affiliates.records.balanceAfter'), sortable: true },
|
||||||
|
{ key: 'available_quota_after', label: t('admin.affiliates.records.availableQuotaAfter'), sortable: true },
|
||||||
|
{ key: 'frozen_quota_after', label: t('admin.affiliates.records.frozenQuotaAfter'), sortable: true },
|
||||||
|
{ key: 'history_quota_after', label: t('admin.affiliates.records.historyQuotaAfter'), sortable: true },
|
||||||
|
{ key: 'created_at', label: t('admin.affiliates.records.transferredAt'), sortable: true },
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
const sortStorageKey = computed(() => `admin-affiliate-${props.type}-table-sort`)
|
||||||
|
|
||||||
|
function loadInitialSortState(): { sort_by: string; sort_order: 'asc' | 'desc' } {
|
||||||
|
const fallback = { sort_by: 'created_at', sort_order: 'desc' as 'asc' | 'desc' }
|
||||||
|
try {
|
||||||
|
const raw = localStorage.getItem(sortStorageKey.value)
|
||||||
|
if (!raw) return fallback
|
||||||
|
const parsed = JSON.parse(raw) as { key?: string; order?: string }
|
||||||
|
const key = typeof parsed.key === 'string' ? parsed.key : ''
|
||||||
|
if (!columns.value.some((column) => column.key === key && column.sortable)) return fallback
|
||||||
|
return {
|
||||||
|
sort_by: key,
|
||||||
|
sort_order: parsed.order === 'asc' ? 'asc' : 'desc',
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const sortState = reactive(loadInitialSortState())
|
||||||
|
|
||||||
|
function userTimezone(): string {
|
||||||
|
try {
|
||||||
|
return Intl.DateTimeFormat().resolvedOptions().timeZone
|
||||||
|
} catch {
|
||||||
|
return 'UTC'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildParams(): ListAffiliateRecordsParams {
|
||||||
|
return {
|
||||||
|
page: pagination.page,
|
||||||
|
page_size: pagination.page_size,
|
||||||
|
search: filters.search.trim() || undefined,
|
||||||
|
start_at: filters.start_at || undefined,
|
||||||
|
end_at: filters.end_at || undefined,
|
||||||
|
sort_by: sortState.sort_by,
|
||||||
|
sort_order: sortState.sort_order,
|
||||||
|
timezone: userTimezone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function fetchRecords(params: ListAffiliateRecordsParams): Promise<PaginatedResponse<AffiliateRecord>> {
|
||||||
|
if (props.type === 'invites') {
|
||||||
|
return affiliatesAPI.listInviteRecords(params)
|
||||||
|
}
|
||||||
|
if (props.type === 'rebates') {
|
||||||
|
return affiliatesAPI.listRebateRecords(params)
|
||||||
|
}
|
||||||
|
return affiliatesAPI.listTransferRecords(params)
|
||||||
|
}
|
||||||
|
|
||||||
|
async function loadRecords() {
|
||||||
|
loading.value = true
|
||||||
|
try {
|
||||||
|
const res = await fetchRecords(buildParams())
|
||||||
|
records.value = res.items || []
|
||||||
|
pagination.total = res.total || 0
|
||||||
|
} catch (error) {
|
||||||
|
appStore.showError(extractI18nErrorMessage(error, t, 'admin.affiliates.errors', t('common.error')))
|
||||||
|
} finally {
|
||||||
|
loading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function debounceLoad() {
|
||||||
|
if (debounceTimer) clearTimeout(debounceTimer)
|
||||||
|
debounceTimer = setTimeout(() => reloadFromFirstPage(), 300)
|
||||||
|
}
|
||||||
|
|
||||||
|
function reloadFromFirstPage() {
|
||||||
|
pagination.page = 1
|
||||||
|
void loadRecords()
|
||||||
|
}
|
||||||
|
|
||||||
|
function handlePageChange(page: number) {
|
||||||
|
pagination.page = page
|
||||||
|
void loadRecords()
|
||||||
|
}
|
||||||
|
|
||||||
|
function handlePageSizeChange(size: number) {
|
||||||
|
pagination.page_size = size
|
||||||
|
pagination.page = 1
|
||||||
|
void loadRecords()
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleSort(key: string, order: 'asc' | 'desc') {
|
||||||
|
sortState.sort_by = key
|
||||||
|
sortState.sort_order = order
|
||||||
|
pagination.page = 1
|
||||||
|
void loadRecords()
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatAmount(value: number | null | undefined): string {
|
||||||
|
return Number(value || 0).toFixed(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatPercent(value: number | null | undefined): string {
|
||||||
|
const rounded = Math.round(Number(value || 0) * 100) / 100
|
||||||
|
return `${Number.isInteger(rounded) ? rounded.toString() : rounded.toString()}%`
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatDateTime(value: string | null | undefined): string {
|
||||||
|
return value ? formatDisplayDateTime(value) : '-'
|
||||||
|
}
|
||||||
|
|
||||||
|
async function openUserOverview(userId: number) {
|
||||||
|
if (!userId) return
|
||||||
|
overviewDialog.value = true
|
||||||
|
overviewLoading.value = true
|
||||||
|
selectedOverview.value = null
|
||||||
|
try {
|
||||||
|
selectedOverview.value = await affiliatesAPI.getUserOverview(userId)
|
||||||
|
} catch (error) {
|
||||||
|
overviewDialog.value = false
|
||||||
|
appStore.showError(extractI18nErrorMessage(error, t, 'admin.affiliates.errors', t('common.error')))
|
||||||
|
} finally {
|
||||||
|
overviewLoading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const UserCell = defineComponent({
|
||||||
|
props: {
|
||||||
|
id: { type: Number, required: true },
|
||||||
|
email: { type: String, default: '' },
|
||||||
|
username: { type: String, default: '' },
|
||||||
|
clickable: { type: Boolean, default: false },
|
||||||
|
},
|
||||||
|
emits: ['open'],
|
||||||
|
setup(cellProps, { emit }) {
|
||||||
|
return () => h('div', { class: 'space-y-0.5' }, [
|
||||||
|
h('div', { class: 'font-mono text-sm text-gray-900 dark:text-white' }, `#${cellProps.id}`),
|
||||||
|
h(cellProps.clickable ? 'button' : 'div', {
|
||||||
|
class: cellProps.clickable
|
||||||
|
? 'max-w-56 truncate text-left text-sm font-medium text-primary-600 hover:text-primary-700 hover:underline dark:text-primary-400 dark:hover:text-primary-300'
|
||||||
|
: 'max-w-56 truncate text-sm text-gray-700 dark:text-gray-300',
|
||||||
|
type: cellProps.clickable ? 'button' : undefined,
|
||||||
|
onClick: cellProps.clickable ? () => emit('open', cellProps.id) : undefined,
|
||||||
|
}, cellProps.email || '-'),
|
||||||
|
h('div', { class: 'max-w-56 truncate text-sm text-gray-500 dark:text-dark-400' }, cellProps.username || '-'),
|
||||||
|
])
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
const AmountText = defineComponent({
|
||||||
|
props: {
|
||||||
|
value: { type: Number, default: 0 },
|
||||||
|
strong: { type: Boolean, default: false },
|
||||||
|
},
|
||||||
|
setup(amountProps) {
|
||||||
|
return () => h('span', {
|
||||||
|
class: amountProps.strong
|
||||||
|
? 'text-sm font-semibold text-emerald-600 dark:text-emerald-400'
|
||||||
|
: 'text-sm text-gray-900 dark:text-white',
|
||||||
|
}, `$${formatAmount(amountProps.value)}`)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
const NullableAmountText = defineComponent({
|
||||||
|
props: {
|
||||||
|
value: { type: Number as PropType<number | null | undefined>, default: null },
|
||||||
|
},
|
||||||
|
setup(amountProps) {
|
||||||
|
return () => {
|
||||||
|
const value = amountProps.value
|
||||||
|
if (value === null || value === undefined) {
|
||||||
|
return h('span', { class: 'text-sm text-gray-400 dark:text-dark-500' }, '-')
|
||||||
|
}
|
||||||
|
return h(AmountText, { value })
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
const OverviewStat = defineComponent({
|
||||||
|
props: {
|
||||||
|
label: { type: String, required: true },
|
||||||
|
value: { type: String, required: true },
|
||||||
|
mono: { type: Boolean, default: false },
|
||||||
|
},
|
||||||
|
setup(statProps) {
|
||||||
|
return () => h('div', { class: 'rounded-lg border border-gray-100 bg-white p-3 dark:border-dark-700 dark:bg-dark-900' }, [
|
||||||
|
h('div', { class: 'text-sm text-gray-500 dark:text-dark-400' }, statProps.label),
|
||||||
|
h('div', {
|
||||||
|
class: statProps.mono
|
||||||
|
? 'mt-1 font-mono text-base font-semibold text-gray-900 dark:text-white'
|
||||||
|
: 'mt-1 text-base font-semibold text-gray-900 dark:text-white',
|
||||||
|
}, statProps.value),
|
||||||
|
])
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
onMounted(() => {
|
||||||
|
void loadRecords()
|
||||||
|
})
|
||||||
|
</script>
|
||||||
@ -0,0 +1,7 @@
|
|||||||
|
<template>
|
||||||
|
<AdminAffiliateRecordsTable type="transfers" />
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import AdminAffiliateRecordsTable from './AdminAffiliateRecordsTable.vue'
|
||||||
|
</script>
|
||||||
Loading…
x
Reference in New Issue
Block a user