diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 025c3166..2c88ab23 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.121 +0.1.123 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index ff385516..6c9673e4 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -265,7 +265,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig) opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig) opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) - opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService) + opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService, settingRepository, opsService) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) diff --git a/backend/ent/group.go b/backend/ent/group.go index 5d9ae2ed..a4f52c73 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -47,6 +47,12 @@ type Group struct { MonthlyLimitUsd *float64 `json:"monthly_limit_usd,omitempty"` // DefaultValidityDays holds the value of the "default_validity_days" field. DefaultValidityDays int `json:"default_validity_days,omitempty"` + // 是否允许该分组使用图片生成能力 + AllowImageGeneration bool `json:"allow_image_generation,omitempty"` + // 图片生成是否使用独立倍率;false 表示共享分组有效倍率 + ImageRateIndependent bool `json:"image_rate_independent,omitempty"` + // 图片生成独立倍率,仅 image_rate_independent=true 时生效 + ImageRateMultiplier float64 `json:"image_rate_multiplier,omitempty"` // ImagePrice1k holds the value of the "image_price_1k" field. ImagePrice1k *float64 `json:"image_price_1k,omitempty"` // ImagePrice2k holds the value of the "image_price_2k" field. @@ -189,9 +195,9 @@ func (*Group) scanValues(columns []string) ([]any, error) { switch columns[i] { case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig: values[i] = new([]byte) - case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet: + case group.FieldIsExclusive, group.FieldAllowImageGeneration, group.FieldImageRateIndependent, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet: values[i] = new(sql.NullBool) - case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: + case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImageRateMultiplier, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: values[i] = new(sql.NullFloat64) case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder, group.FieldRpmLimit: values[i] = new(sql.NullInt64) @@ -309,6 +315,24 @@ func (_m *Group) assignValues(columns []string, values []any) error { } else if value.Valid { _m.DefaultValidityDays = int(value.Int64) } + case group.FieldAllowImageGeneration: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field allow_image_generation", values[i]) + } else if value.Valid { + _m.AllowImageGeneration = value.Bool + } + case group.FieldImageRateIndependent: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field image_rate_independent", values[i]) + } else if value.Valid { + _m.ImageRateIndependent = value.Bool + } + case group.FieldImageRateMultiplier: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field image_rate_multiplier", values[i]) + } else if value.Valid { + _m.ImageRateMultiplier = value.Float64 + } case group.FieldImagePrice1k: if value, ok := values[i].(*sql.NullFloat64); !ok { return fmt.Errorf("unexpected type %T for field image_price_1k", values[i]) @@ -550,6 +574,15 @@ func (_m *Group) String() string { builder.WriteString("default_validity_days=") builder.WriteString(fmt.Sprintf("%v", _m.DefaultValidityDays)) builder.WriteString(", ") + builder.WriteString("allow_image_generation=") + builder.WriteString(fmt.Sprintf("%v", _m.AllowImageGeneration)) + builder.WriteString(", ") + builder.WriteString("image_rate_independent=") + builder.WriteString(fmt.Sprintf("%v", _m.ImageRateIndependent)) + builder.WriteString(", ") + builder.WriteString("image_rate_multiplier=") + builder.WriteString(fmt.Sprintf("%v", _m.ImageRateMultiplier)) + builder.WriteString(", ") if v := _m.ImagePrice1k; v != nil { builder.WriteString("image_price_1k=") builder.WriteString(fmt.Sprintf("%v", *v)) diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 24bd9c13..4e9ba6b6 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -44,6 +44,12 @@ const ( FieldMonthlyLimitUsd = "monthly_limit_usd" // FieldDefaultValidityDays holds the string denoting the default_validity_days field in the database. FieldDefaultValidityDays = "default_validity_days" + // FieldAllowImageGeneration holds the string denoting the allow_image_generation field in the database. + FieldAllowImageGeneration = "allow_image_generation" + // FieldImageRateIndependent holds the string denoting the image_rate_independent field in the database. + FieldImageRateIndependent = "image_rate_independent" + // FieldImageRateMultiplier holds the string denoting the image_rate_multiplier field in the database. + FieldImageRateMultiplier = "image_rate_multiplier" // FieldImagePrice1k holds the string denoting the image_price_1k field in the database. FieldImagePrice1k = "image_price_1k" // FieldImagePrice2k holds the string denoting the image_price_2k field in the database. @@ -167,6 +173,9 @@ var Columns = []string{ FieldWeeklyLimitUsd, FieldMonthlyLimitUsd, FieldDefaultValidityDays, + FieldAllowImageGeneration, + FieldImageRateIndependent, + FieldImageRateMultiplier, FieldImagePrice1k, FieldImagePrice2k, FieldImagePrice4k, @@ -239,6 +248,12 @@ var ( SubscriptionTypeValidator func(string) error // DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field. DefaultDefaultValidityDays int + // DefaultAllowImageGeneration holds the default value on creation for the "allow_image_generation" field. + DefaultAllowImageGeneration bool + // DefaultImageRateIndependent holds the default value on creation for the "image_rate_independent" field. + DefaultImageRateIndependent bool + // DefaultImageRateMultiplier holds the default value on creation for the "image_rate_multiplier" field. + DefaultImageRateMultiplier float64 // DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field. DefaultClaudeCodeOnly bool // DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field. @@ -343,6 +358,21 @@ func ByDefaultValidityDays(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldDefaultValidityDays, opts...).ToFunc() } +// ByAllowImageGeneration orders the results by the allow_image_generation field. +func ByAllowImageGeneration(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAllowImageGeneration, opts...).ToFunc() +} + +// ByImageRateIndependent orders the results by the image_rate_independent field. +func ByImageRateIndependent(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldImageRateIndependent, opts...).ToFunc() +} + +// ByImageRateMultiplier orders the results by the image_rate_multiplier field. +func ByImageRateMultiplier(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldImageRateMultiplier, opts...).ToFunc() +} + // ByImagePrice1k orders the results by the image_price_1k field. func ByImagePrice1k(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImagePrice1k, opts...).ToFunc() diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 2814d130..d3223a92 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -125,6 +125,21 @@ func DefaultValidityDays(v int) predicate.Group { return predicate.Group(sql.FieldEQ(FieldDefaultValidityDays, v)) } +// AllowImageGeneration applies equality check predicate on the "allow_image_generation" field. It's identical to AllowImageGenerationEQ. +func AllowImageGeneration(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldAllowImageGeneration, v)) +} + +// ImageRateIndependent applies equality check predicate on the "image_rate_independent" field. It's identical to ImageRateIndependentEQ. +func ImageRateIndependent(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldImageRateIndependent, v)) +} + +// ImageRateMultiplier applies equality check predicate on the "image_rate_multiplier" field. It's identical to ImageRateMultiplierEQ. +func ImageRateMultiplier(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldImageRateMultiplier, v)) +} + // ImagePrice1k applies equality check predicate on the "image_price_1k" field. It's identical to ImagePrice1kEQ. func ImagePrice1k(v float64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldImagePrice1k, v)) @@ -900,6 +915,66 @@ func DefaultValidityDaysLTE(v int) predicate.Group { return predicate.Group(sql.FieldLTE(FieldDefaultValidityDays, v)) } +// AllowImageGenerationEQ applies the EQ predicate on the "allow_image_generation" field. +func AllowImageGenerationEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldAllowImageGeneration, v)) +} + +// AllowImageGenerationNEQ applies the NEQ predicate on the "allow_image_generation" field. +func AllowImageGenerationNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldAllowImageGeneration, v)) +} + +// ImageRateIndependentEQ applies the EQ predicate on the "image_rate_independent" field. +func ImageRateIndependentEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldImageRateIndependent, v)) +} + +// ImageRateIndependentNEQ applies the NEQ predicate on the "image_rate_independent" field. +func ImageRateIndependentNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldImageRateIndependent, v)) +} + +// ImageRateMultiplierEQ applies the EQ predicate on the "image_rate_multiplier" field. +func ImageRateMultiplierEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldImageRateMultiplier, v)) +} + +// ImageRateMultiplierNEQ applies the NEQ predicate on the "image_rate_multiplier" field. +func ImageRateMultiplierNEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldImageRateMultiplier, v)) +} + +// ImageRateMultiplierIn applies the In predicate on the "image_rate_multiplier" field. +func ImageRateMultiplierIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldImageRateMultiplier, vs...)) +} + +// ImageRateMultiplierNotIn applies the NotIn predicate on the "image_rate_multiplier" field. +func ImageRateMultiplierNotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldImageRateMultiplier, vs...)) +} + +// ImageRateMultiplierGT applies the GT predicate on the "image_rate_multiplier" field. +func ImageRateMultiplierGT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldImageRateMultiplier, v)) +} + +// ImageRateMultiplierGTE applies the GTE predicate on the "image_rate_multiplier" field. +func ImageRateMultiplierGTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldImageRateMultiplier, v)) +} + +// ImageRateMultiplierLT applies the LT predicate on the "image_rate_multiplier" field. +func ImageRateMultiplierLT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldImageRateMultiplier, v)) +} + +// ImageRateMultiplierLTE applies the LTE predicate on the "image_rate_multiplier" field. +func ImageRateMultiplierLTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldImageRateMultiplier, v)) +} + // ImagePrice1kEQ applies the EQ predicate on the "image_price_1k" field. func ImagePrice1kEQ(v float64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldImagePrice1k, v)) diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 20ea0a0f..44b905bd 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -217,6 +217,48 @@ func (_c *GroupCreate) SetNillableDefaultValidityDays(v *int) *GroupCreate { return _c } +// SetAllowImageGeneration sets the "allow_image_generation" field. +func (_c *GroupCreate) SetAllowImageGeneration(v bool) *GroupCreate { + _c.mutation.SetAllowImageGeneration(v) + return _c +} + +// SetNillableAllowImageGeneration sets the "allow_image_generation" field if the given value is not nil. +func (_c *GroupCreate) SetNillableAllowImageGeneration(v *bool) *GroupCreate { + if v != nil { + _c.SetAllowImageGeneration(*v) + } + return _c +} + +// SetImageRateIndependent sets the "image_rate_independent" field. +func (_c *GroupCreate) SetImageRateIndependent(v bool) *GroupCreate { + _c.mutation.SetImageRateIndependent(v) + return _c +} + +// SetNillableImageRateIndependent sets the "image_rate_independent" field if the given value is not nil. +func (_c *GroupCreate) SetNillableImageRateIndependent(v *bool) *GroupCreate { + if v != nil { + _c.SetImageRateIndependent(*v) + } + return _c +} + +// SetImageRateMultiplier sets the "image_rate_multiplier" field. +func (_c *GroupCreate) SetImageRateMultiplier(v float64) *GroupCreate { + _c.mutation.SetImageRateMultiplier(v) + return _c +} + +// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field if the given value is not nil. +func (_c *GroupCreate) SetNillableImageRateMultiplier(v *float64) *GroupCreate { + if v != nil { + _c.SetImageRateMultiplier(*v) + } + return _c +} + // SetImagePrice1k sets the "image_price_1k" field. func (_c *GroupCreate) SetImagePrice1k(v float64) *GroupCreate { _c.mutation.SetImagePrice1k(v) @@ -604,6 +646,18 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultDefaultValidityDays _c.mutation.SetDefaultValidityDays(v) } + if _, ok := _c.mutation.AllowImageGeneration(); !ok { + v := group.DefaultAllowImageGeneration + _c.mutation.SetAllowImageGeneration(v) + } + if _, ok := _c.mutation.ImageRateIndependent(); !ok { + v := group.DefaultImageRateIndependent + _c.mutation.SetImageRateIndependent(v) + } + if _, ok := _c.mutation.ImageRateMultiplier(); !ok { + v := group.DefaultImageRateMultiplier + _c.mutation.SetImageRateMultiplier(v) + } if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { v := group.DefaultClaudeCodeOnly _c.mutation.SetClaudeCodeOnly(v) @@ -700,6 +754,15 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.DefaultValidityDays(); !ok { return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)} } + if _, ok := _c.mutation.AllowImageGeneration(); !ok { + return &ValidationError{Name: "allow_image_generation", err: errors.New(`ent: missing required field "Group.allow_image_generation"`)} + } + if _, ok := _c.mutation.ImageRateIndependent(); !ok { + return &ValidationError{Name: "image_rate_independent", err: errors.New(`ent: missing required field "Group.image_rate_independent"`)} + } + if _, ok := _c.mutation.ImageRateMultiplier(); !ok { + return &ValidationError{Name: "image_rate_multiplier", err: errors.New(`ent: missing required field "Group.image_rate_multiplier"`)} + } if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)} } @@ -821,6 +884,18 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldDefaultValidityDays, field.TypeInt, value) _node.DefaultValidityDays = value } + if value, ok := _c.mutation.AllowImageGeneration(); ok { + _spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value) + _node.AllowImageGeneration = value + } + if value, ok := _c.mutation.ImageRateIndependent(); ok { + _spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value) + _node.ImageRateIndependent = value + } + if value, ok := _c.mutation.ImageRateMultiplier(); ok { + _spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value) + _node.ImageRateMultiplier = value + } if value, ok := _c.mutation.ImagePrice1k(); ok { _spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value) _node.ImagePrice1k = &value @@ -1261,6 +1336,48 @@ func (u *GroupUpsert) AddDefaultValidityDays(v int) *GroupUpsert { return u } +// SetAllowImageGeneration sets the "allow_image_generation" field. +func (u *GroupUpsert) SetAllowImageGeneration(v bool) *GroupUpsert { + u.Set(group.FieldAllowImageGeneration, v) + return u +} + +// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create. +func (u *GroupUpsert) UpdateAllowImageGeneration() *GroupUpsert { + u.SetExcluded(group.FieldAllowImageGeneration) + return u +} + +// SetImageRateIndependent sets the "image_rate_independent" field. +func (u *GroupUpsert) SetImageRateIndependent(v bool) *GroupUpsert { + u.Set(group.FieldImageRateIndependent, v) + return u +} + +// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create. +func (u *GroupUpsert) UpdateImageRateIndependent() *GroupUpsert { + u.SetExcluded(group.FieldImageRateIndependent) + return u +} + +// SetImageRateMultiplier sets the "image_rate_multiplier" field. +func (u *GroupUpsert) SetImageRateMultiplier(v float64) *GroupUpsert { + u.Set(group.FieldImageRateMultiplier, v) + return u +} + +// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create. +func (u *GroupUpsert) UpdateImageRateMultiplier() *GroupUpsert { + u.SetExcluded(group.FieldImageRateMultiplier) + return u +} + +// AddImageRateMultiplier adds v to the "image_rate_multiplier" field. +func (u *GroupUpsert) AddImageRateMultiplier(v float64) *GroupUpsert { + u.Add(group.FieldImageRateMultiplier, v) + return u +} + // SetImagePrice1k sets the "image_price_1k" field. func (u *GroupUpsert) SetImagePrice1k(v float64) *GroupUpsert { u.Set(group.FieldImagePrice1k, v) @@ -1840,6 +1957,55 @@ func (u *GroupUpsertOne) UpdateDefaultValidityDays() *GroupUpsertOne { }) } +// SetAllowImageGeneration sets the "allow_image_generation" field. +func (u *GroupUpsertOne) SetAllowImageGeneration(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetAllowImageGeneration(v) + }) +} + +// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateAllowImageGeneration() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateAllowImageGeneration() + }) +} + +// SetImageRateIndependent sets the "image_rate_independent" field. +func (u *GroupUpsertOne) SetImageRateIndependent(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetImageRateIndependent(v) + }) +} + +// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateImageRateIndependent() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateImageRateIndependent() + }) +} + +// SetImageRateMultiplier sets the "image_rate_multiplier" field. +func (u *GroupUpsertOne) SetImageRateMultiplier(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetImageRateMultiplier(v) + }) +} + +// AddImageRateMultiplier adds v to the "image_rate_multiplier" field. +func (u *GroupUpsertOne) AddImageRateMultiplier(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddImageRateMultiplier(v) + }) +} + +// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateImageRateMultiplier() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateImageRateMultiplier() + }) +} + // SetImagePrice1k sets the "image_price_1k" field. func (u *GroupUpsertOne) SetImagePrice1k(v float64) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -2632,6 +2798,55 @@ func (u *GroupUpsertBulk) UpdateDefaultValidityDays() *GroupUpsertBulk { }) } +// SetAllowImageGeneration sets the "allow_image_generation" field. +func (u *GroupUpsertBulk) SetAllowImageGeneration(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetAllowImageGeneration(v) + }) +} + +// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateAllowImageGeneration() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateAllowImageGeneration() + }) +} + +// SetImageRateIndependent sets the "image_rate_independent" field. +func (u *GroupUpsertBulk) SetImageRateIndependent(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetImageRateIndependent(v) + }) +} + +// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateImageRateIndependent() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateImageRateIndependent() + }) +} + +// SetImageRateMultiplier sets the "image_rate_multiplier" field. +func (u *GroupUpsertBulk) SetImageRateMultiplier(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetImageRateMultiplier(v) + }) +} + +// AddImageRateMultiplier adds v to the "image_rate_multiplier" field. +func (u *GroupUpsertBulk) AddImageRateMultiplier(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddImageRateMultiplier(v) + }) +} + +// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateImageRateMultiplier() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateImageRateMultiplier() + }) +} + // SetImagePrice1k sets the "image_price_1k" field. func (u *GroupUpsertBulk) SetImagePrice1k(v float64) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index cc14f897..fe55982c 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -275,6 +275,55 @@ func (_u *GroupUpdate) AddDefaultValidityDays(v int) *GroupUpdate { return _u } +// SetAllowImageGeneration sets the "allow_image_generation" field. +func (_u *GroupUpdate) SetAllowImageGeneration(v bool) *GroupUpdate { + _u.mutation.SetAllowImageGeneration(v) + return _u +} + +// SetNillableAllowImageGeneration sets the "allow_image_generation" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableAllowImageGeneration(v *bool) *GroupUpdate { + if v != nil { + _u.SetAllowImageGeneration(*v) + } + return _u +} + +// SetImageRateIndependent sets the "image_rate_independent" field. +func (_u *GroupUpdate) SetImageRateIndependent(v bool) *GroupUpdate { + _u.mutation.SetImageRateIndependent(v) + return _u +} + +// SetNillableImageRateIndependent sets the "image_rate_independent" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableImageRateIndependent(v *bool) *GroupUpdate { + if v != nil { + _u.SetImageRateIndependent(*v) + } + return _u +} + +// SetImageRateMultiplier sets the "image_rate_multiplier" field. +func (_u *GroupUpdate) SetImageRateMultiplier(v float64) *GroupUpdate { + _u.mutation.ResetImageRateMultiplier() + _u.mutation.SetImageRateMultiplier(v) + return _u +} + +// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableImageRateMultiplier(v *float64) *GroupUpdate { + if v != nil { + _u.SetImageRateMultiplier(*v) + } + return _u +} + +// AddImageRateMultiplier adds value to the "image_rate_multiplier" field. +func (_u *GroupUpdate) AddImageRateMultiplier(v float64) *GroupUpdate { + _u.mutation.AddImageRateMultiplier(v) + return _u +} + // SetImagePrice1k sets the "image_price_1k" field. func (_u *GroupUpdate) SetImagePrice1k(v float64) *GroupUpdate { _u.mutation.ResetImagePrice1k() @@ -962,6 +1011,18 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.AddedDefaultValidityDays(); ok { _spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value) } + if value, ok := _u.mutation.AllowImageGeneration(); ok { + _spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value) + } + if value, ok := _u.mutation.ImageRateIndependent(); ok { + _spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value) + } + if value, ok := _u.mutation.ImageRateMultiplier(); ok { + _spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedImageRateMultiplier(); ok { + _spec.AddField(group.FieldImageRateMultiplier, field.TypeFloat64, value) + } if value, ok := _u.mutation.ImagePrice1k(); ok { _spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value) } @@ -1610,6 +1671,55 @@ func (_u *GroupUpdateOne) AddDefaultValidityDays(v int) *GroupUpdateOne { return _u } +// SetAllowImageGeneration sets the "allow_image_generation" field. +func (_u *GroupUpdateOne) SetAllowImageGeneration(v bool) *GroupUpdateOne { + _u.mutation.SetAllowImageGeneration(v) + return _u +} + +// SetNillableAllowImageGeneration sets the "allow_image_generation" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableAllowImageGeneration(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetAllowImageGeneration(*v) + } + return _u +} + +// SetImageRateIndependent sets the "image_rate_independent" field. +func (_u *GroupUpdateOne) SetImageRateIndependent(v bool) *GroupUpdateOne { + _u.mutation.SetImageRateIndependent(v) + return _u +} + +// SetNillableImageRateIndependent sets the "image_rate_independent" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableImageRateIndependent(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetImageRateIndependent(*v) + } + return _u +} + +// SetImageRateMultiplier sets the "image_rate_multiplier" field. +func (_u *GroupUpdateOne) SetImageRateMultiplier(v float64) *GroupUpdateOne { + _u.mutation.ResetImageRateMultiplier() + _u.mutation.SetImageRateMultiplier(v) + return _u +} + +// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableImageRateMultiplier(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetImageRateMultiplier(*v) + } + return _u +} + +// AddImageRateMultiplier adds value to the "image_rate_multiplier" field. +func (_u *GroupUpdateOne) AddImageRateMultiplier(v float64) *GroupUpdateOne { + _u.mutation.AddImageRateMultiplier(v) + return _u +} + // SetImagePrice1k sets the "image_price_1k" field. func (_u *GroupUpdateOne) SetImagePrice1k(v float64) *GroupUpdateOne { _u.mutation.ResetImagePrice1k() @@ -2327,6 +2437,18 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if value, ok := _u.mutation.AddedDefaultValidityDays(); ok { _spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value) } + if value, ok := _u.mutation.AllowImageGeneration(); ok { + _spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value) + } + if value, ok := _u.mutation.ImageRateIndependent(); ok { + _spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value) + } + if value, ok := _u.mutation.ImageRateMultiplier(); ok { + _spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedImageRateMultiplier(); ok { + _spec.AddField(group.FieldImageRateMultiplier, field.TypeFloat64, value) + } if value, ok := _u.mutation.ImagePrice1k(); ok { _spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value) } diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 178ae170..525ff092 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -638,6 +638,9 @@ var ( {Name: "weekly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "monthly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "default_validity_days", Type: field.TypeInt, Default: 30}, + {Name: "allow_image_generation", Type: field.TypeBool, Default: false}, + {Name: "image_rate_independent", Type: field.TypeBool, Default: false}, + {Name: "image_rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}}, {Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, @@ -690,7 +693,7 @@ var ( { Name: "group_sort_order", Unique: false, - Columns: []*schema.Column{GroupsColumns[25]}, + Columns: []*schema.Column{GroupsColumns[28]}, }, }, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index d616e4ae..13f6193d 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -14764,6 +14764,10 @@ type GroupMutation struct { addmonthly_limit_usd *float64 default_validity_days *int adddefault_validity_days *int + allow_image_generation *bool + image_rate_independent *bool + image_rate_multiplier *float64 + addimage_rate_multiplier *float64 image_price_1k *float64 addimage_price_1k *float64 image_price_2k *float64 @@ -15583,6 +15587,134 @@ func (m *GroupMutation) ResetDefaultValidityDays() { m.adddefault_validity_days = nil } +// SetAllowImageGeneration sets the "allow_image_generation" field. +func (m *GroupMutation) SetAllowImageGeneration(b bool) { + m.allow_image_generation = &b +} + +// AllowImageGeneration returns the value of the "allow_image_generation" field in the mutation. +func (m *GroupMutation) AllowImageGeneration() (r bool, exists bool) { + v := m.allow_image_generation + if v == nil { + return + } + return *v, true +} + +// OldAllowImageGeneration returns the old "allow_image_generation" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldAllowImageGeneration(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAllowImageGeneration is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAllowImageGeneration requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAllowImageGeneration: %w", err) + } + return oldValue.AllowImageGeneration, nil +} + +// ResetAllowImageGeneration resets all changes to the "allow_image_generation" field. +func (m *GroupMutation) ResetAllowImageGeneration() { + m.allow_image_generation = nil +} + +// SetImageRateIndependent sets the "image_rate_independent" field. +func (m *GroupMutation) SetImageRateIndependent(b bool) { + m.image_rate_independent = &b +} + +// ImageRateIndependent returns the value of the "image_rate_independent" field in the mutation. +func (m *GroupMutation) ImageRateIndependent() (r bool, exists bool) { + v := m.image_rate_independent + if v == nil { + return + } + return *v, true +} + +// OldImageRateIndependent returns the old "image_rate_independent" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldImageRateIndependent(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldImageRateIndependent is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldImageRateIndependent requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldImageRateIndependent: %w", err) + } + return oldValue.ImageRateIndependent, nil +} + +// ResetImageRateIndependent resets all changes to the "image_rate_independent" field. +func (m *GroupMutation) ResetImageRateIndependent() { + m.image_rate_independent = nil +} + +// SetImageRateMultiplier sets the "image_rate_multiplier" field. +func (m *GroupMutation) SetImageRateMultiplier(f float64) { + m.image_rate_multiplier = &f + m.addimage_rate_multiplier = nil +} + +// ImageRateMultiplier returns the value of the "image_rate_multiplier" field in the mutation. +func (m *GroupMutation) ImageRateMultiplier() (r float64, exists bool) { + v := m.image_rate_multiplier + if v == nil { + return + } + return *v, true +} + +// OldImageRateMultiplier returns the old "image_rate_multiplier" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldImageRateMultiplier(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldImageRateMultiplier is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldImageRateMultiplier requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldImageRateMultiplier: %w", err) + } + return oldValue.ImageRateMultiplier, nil +} + +// AddImageRateMultiplier adds f to the "image_rate_multiplier" field. +func (m *GroupMutation) AddImageRateMultiplier(f float64) { + if m.addimage_rate_multiplier != nil { + *m.addimage_rate_multiplier += f + } else { + m.addimage_rate_multiplier = &f + } +} + +// AddedImageRateMultiplier returns the value that was added to the "image_rate_multiplier" field in this mutation. +func (m *GroupMutation) AddedImageRateMultiplier() (r float64, exists bool) { + v := m.addimage_rate_multiplier + if v == nil { + return + } + return *v, true +} + +// ResetImageRateMultiplier resets all changes to the "image_rate_multiplier" field. +func (m *GroupMutation) ResetImageRateMultiplier() { + m.image_rate_multiplier = nil + m.addimage_rate_multiplier = nil +} + // SetImagePrice1k sets the "image_price_1k" field. func (m *GroupMutation) SetImagePrice1k(f float64) { m.image_price_1k = &f @@ -16791,7 +16923,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 31) + fields := make([]string, 0, 34) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -16834,6 +16966,15 @@ func (m *GroupMutation) Fields() []string { if m.default_validity_days != nil { fields = append(fields, group.FieldDefaultValidityDays) } + if m.allow_image_generation != nil { + fields = append(fields, group.FieldAllowImageGeneration) + } + if m.image_rate_independent != nil { + fields = append(fields, group.FieldImageRateIndependent) + } + if m.image_rate_multiplier != nil { + fields = append(fields, group.FieldImageRateMultiplier) + } if m.image_price_1k != nil { fields = append(fields, group.FieldImagePrice1k) } @@ -16921,6 +17062,12 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.MonthlyLimitUsd() case group.FieldDefaultValidityDays: return m.DefaultValidityDays() + case group.FieldAllowImageGeneration: + return m.AllowImageGeneration() + case group.FieldImageRateIndependent: + return m.ImageRateIndependent() + case group.FieldImageRateMultiplier: + return m.ImageRateMultiplier() case group.FieldImagePrice1k: return m.ImagePrice1k() case group.FieldImagePrice2k: @@ -16992,6 +17139,12 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldMonthlyLimitUsd(ctx) case group.FieldDefaultValidityDays: return m.OldDefaultValidityDays(ctx) + case group.FieldAllowImageGeneration: + return m.OldAllowImageGeneration(ctx) + case group.FieldImageRateIndependent: + return m.OldImageRateIndependent(ctx) + case group.FieldImageRateMultiplier: + return m.OldImageRateMultiplier(ctx) case group.FieldImagePrice1k: return m.OldImagePrice1k(ctx) case group.FieldImagePrice2k: @@ -17133,6 +17286,27 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetDefaultValidityDays(v) return nil + case group.FieldAllowImageGeneration: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAllowImageGeneration(v) + return nil + case group.FieldImageRateIndependent: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImageRateIndependent(v) + return nil + case group.FieldImageRateMultiplier: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImageRateMultiplier(v) + return nil case group.FieldImagePrice1k: v, ok := value.(float64) if !ok { @@ -17275,6 +17449,9 @@ func (m *GroupMutation) AddedFields() []string { if m.adddefault_validity_days != nil { fields = append(fields, group.FieldDefaultValidityDays) } + if m.addimage_rate_multiplier != nil { + fields = append(fields, group.FieldImageRateMultiplier) + } if m.addimage_price_1k != nil { fields = append(fields, group.FieldImagePrice1k) } @@ -17314,6 +17491,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedMonthlyLimitUsd() case group.FieldDefaultValidityDays: return m.AddedDefaultValidityDays() + case group.FieldImageRateMultiplier: + return m.AddedImageRateMultiplier() case group.FieldImagePrice1k: return m.AddedImagePrice1k() case group.FieldImagePrice2k: @@ -17372,6 +17551,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddDefaultValidityDays(v) return nil + case group.FieldImageRateMultiplier: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddImageRateMultiplier(v) + return nil case group.FieldImagePrice1k: v, ok := value.(float64) if !ok { @@ -17559,6 +17745,15 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldDefaultValidityDays: m.ResetDefaultValidityDays() return nil + case group.FieldAllowImageGeneration: + m.ResetAllowImageGeneration() + return nil + case group.FieldImageRateIndependent: + m.ResetImageRateIndependent() + return nil + case group.FieldImageRateMultiplier: + m.ResetImageRateMultiplier() + return nil case group.FieldImagePrice1k: m.ResetImagePrice1k() return nil diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 6b344a55..a282d9ba 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -803,50 +803,62 @@ func init() { groupDescDefaultValidityDays := groupFields[10].Descriptor() // group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field. group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int) + // groupDescAllowImageGeneration is the schema descriptor for allow_image_generation field. + groupDescAllowImageGeneration := groupFields[11].Descriptor() + // group.DefaultAllowImageGeneration holds the default value on creation for the allow_image_generation field. + group.DefaultAllowImageGeneration = groupDescAllowImageGeneration.Default.(bool) + // groupDescImageRateIndependent is the schema descriptor for image_rate_independent field. + groupDescImageRateIndependent := groupFields[12].Descriptor() + // group.DefaultImageRateIndependent holds the default value on creation for the image_rate_independent field. + group.DefaultImageRateIndependent = groupDescImageRateIndependent.Default.(bool) + // groupDescImageRateMultiplier is the schema descriptor for image_rate_multiplier field. + groupDescImageRateMultiplier := groupFields[13].Descriptor() + // group.DefaultImageRateMultiplier holds the default value on creation for the image_rate_multiplier field. + group.DefaultImageRateMultiplier = groupDescImageRateMultiplier.Default.(float64) // groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field. - groupDescClaudeCodeOnly := groupFields[14].Descriptor() + groupDescClaudeCodeOnly := groupFields[17].Descriptor() // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. - groupDescModelRoutingEnabled := groupFields[18].Descriptor() + groupDescModelRoutingEnabled := groupFields[21].Descriptor() // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) // groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field. - groupDescMcpXMLInject := groupFields[19].Descriptor() + groupDescMcpXMLInject := groupFields[22].Descriptor() // group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field. group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool) // groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field. - groupDescSupportedModelScopes := groupFields[20].Descriptor() + groupDescSupportedModelScopes := groupFields[23].Descriptor() // group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field. group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string) // groupDescSortOrder is the schema descriptor for sort_order field. - groupDescSortOrder := groupFields[21].Descriptor() + groupDescSortOrder := groupFields[24].Descriptor() // group.DefaultSortOrder holds the default value on creation for the sort_order field. group.DefaultSortOrder = groupDescSortOrder.Default.(int) // groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field. - groupDescAllowMessagesDispatch := groupFields[22].Descriptor() + groupDescAllowMessagesDispatch := groupFields[25].Descriptor() // group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field. group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool) // groupDescRequireOauthOnly is the schema descriptor for require_oauth_only field. - groupDescRequireOauthOnly := groupFields[23].Descriptor() + groupDescRequireOauthOnly := groupFields[26].Descriptor() // group.DefaultRequireOauthOnly holds the default value on creation for the require_oauth_only field. group.DefaultRequireOauthOnly = groupDescRequireOauthOnly.Default.(bool) // groupDescRequirePrivacySet is the schema descriptor for require_privacy_set field. - groupDescRequirePrivacySet := groupFields[24].Descriptor() + groupDescRequirePrivacySet := groupFields[27].Descriptor() // group.DefaultRequirePrivacySet holds the default value on creation for the require_privacy_set field. group.DefaultRequirePrivacySet = groupDescRequirePrivacySet.Default.(bool) // groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field. - groupDescDefaultMappedModel := groupFields[25].Descriptor() + groupDescDefaultMappedModel := groupFields[28].Descriptor() // group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field. group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string) // group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error) // groupDescMessagesDispatchModelConfig is the schema descriptor for messages_dispatch_model_config field. - groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor() + groupDescMessagesDispatchModelConfig := groupFields[29].Descriptor() // group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field. group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig) // groupDescRpmLimit is the schema descriptor for rpm_limit field. - groupDescRpmLimit := groupFields[27].Descriptor() + groupDescRpmLimit := groupFields[30].Descriptor() // group.DefaultRpmLimit holds the default value on creation for the rpm_limit field. group.DefaultRpmLimit = groupDescRpmLimit.Default.(int) idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 11f38d66..d47e8710 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -74,6 +74,16 @@ func (Group) Fields() []ent.Field { Default(30), // 图片生成计费配置(antigravity 和 gemini 平台使用) + field.Bool("allow_image_generation"). + Default(false). + Comment("是否允许该分组使用图片生成能力"), + field.Bool("image_rate_independent"). + Default(false). + Comment("图片生成是否使用独立倍率;false 表示共享分组有效倍率"), + field.Float("image_rate_multiplier"). + SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}). + Default(1.0). + Comment("图片生成独立倍率,仅 image_rate_independent=true 时生效"), field.Float("image_price_1k"). Optional(). Nillable(). diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 112527e2..30d6db3f 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -576,6 +576,24 @@ type ConcurrencyConfig struct { PingInterval int `mapstructure:"ping_interval"` } +type ImageConcurrencyConfig struct { + // Enabled: 是否启用图片生成独立并发限制,默认关闭以保持现有行为 + Enabled bool `mapstructure:"enabled"` + // MaxConcurrentRequests: 当前进程允许同时处理的图片生成请求数,0表示不限制 + MaxConcurrentRequests int `mapstructure:"max_concurrent_requests"` + // OverflowMode: 图片并发达到上限后的处理方式:reject/wait + OverflowMode string `mapstructure:"overflow_mode"` + // WaitTimeoutSeconds: overflow_mode=wait 时等待图片并发槽位的超时时间(秒) + WaitTimeoutSeconds int `mapstructure:"wait_timeout_seconds"` + // MaxWaitingRequests: overflow_mode=wait 时当前进程允许排队等待的图片请求数 + MaxWaitingRequests int `mapstructure:"max_waiting_requests"` +} + +const ( + ImageConcurrencyOverflowModeReject = "reject" + ImageConcurrencyOverflowModeWait = "wait" +) + // GatewayConfig API网关相关配置 type GatewayConfig struct { // 等待上游响应头的超时时间(秒),0表示无超时 @@ -609,6 +627,8 @@ type GatewayConfig struct { AntigravityLSWorker GatewayAntigravityLSWorkerConfig `mapstructure:"antigravity_ls_worker"` // NodeTLSProxy: Node.js TLS 代理配置 NodeTLSProxy NodeTLSProxyConfig `mapstructure:"node_tls_proxy"` + // ImageConcurrency: 图片生成独立并发限制配置(默认关闭) + ImageConcurrency ImageConcurrencyConfig `mapstructure:"image_concurrency"` // HTTP 上游连接池配置(性能优化:支持高并发场景调优) // MaxIdleConns: 所有主机的最大空闲连接总数 @@ -640,6 +660,10 @@ type GatewayConfig struct { StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"` // StreamKeepaliveInterval: 流式 keepalive 间隔(秒),0表示禁用 StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"` + // ImageStreamDataIntervalTimeout: 图片流数据间隔超时(秒),0表示禁用 + ImageStreamDataIntervalTimeout int `mapstructure:"image_stream_data_interval_timeout"` + // ImageStreamKeepaliveInterval: 图片流式 keepalive 间隔(秒),0表示禁用 + ImageStreamKeepaliveInterval int `mapstructure:"image_stream_keepalive_interval"` // MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值) MaxLineSize int `mapstructure:"max_line_size"` @@ -1789,6 +1813,11 @@ func setDefaults() { viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7) viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8) viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5) + viper.SetDefault("gateway.image_concurrency.enabled", false) + viper.SetDefault("gateway.image_concurrency.max_concurrent_requests", 0) + viper.SetDefault("gateway.image_concurrency.overflow_mode", ImageConcurrencyOverflowModeReject) + viper.SetDefault("gateway.image_concurrency.wait_timeout_seconds", 30) + viper.SetDefault("gateway.image_concurrency.max_waiting_requests", 100) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.antigravity_extra_retries", 10) viper.SetDefault("gateway.max_body_size", int64(256*1024*1024)) @@ -1806,6 +1835,8 @@ func setDefaults() { viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求) viper.SetDefault("gateway.stream_data_interval_timeout", 180) viper.SetDefault("gateway.stream_keepalive_interval", 10) + viper.SetDefault("gateway.image_stream_data_interval_timeout", 900) + viper.SetDefault("gateway.image_stream_keepalive_interval", 10) viper.SetDefault("gateway.max_line_size", 500*1024*1024) viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second) @@ -2410,6 +2441,21 @@ func (c *Config) Validate() error { ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy) } } + if c.Gateway.ImageConcurrency.MaxConcurrentRequests < 0 { + return fmt.Errorf("gateway.image_concurrency.max_concurrent_requests must be non-negative") + } + switch strings.TrimSpace(c.Gateway.ImageConcurrency.OverflowMode) { + case "", ImageConcurrencyOverflowModeReject, ImageConcurrencyOverflowModeWait: + default: + return fmt.Errorf("gateway.image_concurrency.overflow_mode must be one of: %s/%s", + ImageConcurrencyOverflowModeReject, ImageConcurrencyOverflowModeWait) + } + if c.Gateway.ImageConcurrency.WaitTimeoutSeconds < 0 { + return fmt.Errorf("gateway.image_concurrency.wait_timeout_seconds must be non-negative") + } + if c.Gateway.ImageConcurrency.MaxWaitingRequests < 0 { + return fmt.Errorf("gateway.image_concurrency.max_waiting_requests must be non-negative") + } if c.Gateway.MaxIdleConns <= 0 { return fmt.Errorf("gateway.max_idle_conns must be positive") } @@ -2448,6 +2494,20 @@ func (c *Config) Validate() error { (c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) { return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds") } + if c.Gateway.ImageStreamDataIntervalTimeout < 0 { + return fmt.Errorf("gateway.image_stream_data_interval_timeout must be non-negative") + } + if c.Gateway.ImageStreamDataIntervalTimeout != 0 && + (c.Gateway.ImageStreamDataIntervalTimeout < 60 || c.Gateway.ImageStreamDataIntervalTimeout > 1800) { + return fmt.Errorf("gateway.image_stream_data_interval_timeout must be 0 or between 60-1800 seconds") + } + if c.Gateway.ImageStreamKeepaliveInterval < 0 { + return fmt.Errorf("gateway.image_stream_keepalive_interval must be non-negative") + } + if c.Gateway.ImageStreamKeepaliveInterval != 0 && + (c.Gateway.ImageStreamKeepaliveInterval < 5 || c.Gateway.ImageStreamKeepaliveInterval > 60) { + return fmt.Errorf("gateway.image_stream_keepalive_interval must be 0 or between 5-60 seconds") + } // 兼容旧键 sticky_previous_response_ttl_seconds if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 { c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 6ba86aa1..a47de2f8 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -1282,6 +1282,46 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = -1 }, wantErr: "gateway.stream_data_interval_timeout must be non-negative", }, + { + name: "gateway image stream keepalive range", + mutate: func(c *Config) { c.Gateway.ImageStreamKeepaliveInterval = 4 }, + wantErr: "gateway.image_stream_keepalive_interval", + }, + { + name: "gateway image stream keepalive negative", + mutate: func(c *Config) { c.Gateway.ImageStreamKeepaliveInterval = -1 }, + wantErr: "gateway.image_stream_keepalive_interval must be non-negative", + }, + { + name: "gateway image stream data interval range", + mutate: func(c *Config) { c.Gateway.ImageStreamDataIntervalTimeout = 30 }, + wantErr: "gateway.image_stream_data_interval_timeout", + }, + { + name: "gateway image stream data interval negative", + mutate: func(c *Config) { c.Gateway.ImageStreamDataIntervalTimeout = -1 }, + wantErr: "gateway.image_stream_data_interval_timeout must be non-negative", + }, + { + name: "gateway image concurrency max negative", + mutate: func(c *Config) { c.Gateway.ImageConcurrency.MaxConcurrentRequests = -1 }, + wantErr: "gateway.image_concurrency.max_concurrent_requests must be non-negative", + }, + { + name: "gateway image concurrency overflow mode invalid", + mutate: func(c *Config) { c.Gateway.ImageConcurrency.OverflowMode = "queue" }, + wantErr: "gateway.image_concurrency.overflow_mode", + }, + { + name: "gateway image concurrency wait timeout negative", + mutate: func(c *Config) { c.Gateway.ImageConcurrency.WaitTimeoutSeconds = -1 }, + wantErr: "gateway.image_concurrency.wait_timeout_seconds must be non-negative", + }, + { + name: "gateway image concurrency max waiting negative", + mutate: func(c *Config) { c.Gateway.ImageConcurrency.MaxWaitingRequests = -1 }, + wantErr: "gateway.image_concurrency.max_waiting_requests must be non-negative", + }, { name: "gateway max line size", mutate: func(c *Config) { c.Gateway.MaxLineSize = 1024 }, @@ -1754,3 +1794,41 @@ func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) { t.Fatalf("auto_scale_cooldown_seconds = %d, want 10", cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds) } } + +func TestLoad_DefaultGatewayImageStreamConfig(t *testing.T) { + resetViperWithJWTSecret(t) + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + if cfg.Gateway.StreamDataIntervalTimeout != 180 { + t.Fatalf("stream_data_interval_timeout = %d, want 180", cfg.Gateway.StreamDataIntervalTimeout) + } + if cfg.Gateway.StreamKeepaliveInterval != 10 { + t.Fatalf("stream_keepalive_interval = %d, want 10", cfg.Gateway.StreamKeepaliveInterval) + } + if cfg.Gateway.ImageStreamDataIntervalTimeout != 900 { + t.Fatalf("image_stream_data_interval_timeout = %d, want 900", cfg.Gateway.ImageStreamDataIntervalTimeout) + } + if cfg.Gateway.ImageStreamKeepaliveInterval != 10 { + t.Fatalf("image_stream_keepalive_interval = %d, want 10", cfg.Gateway.ImageStreamKeepaliveInterval) + } + if cfg.Gateway.ImageConcurrency.Enabled { + t.Fatalf("image_concurrency.enabled = true, want false") + } + if cfg.Gateway.ImageConcurrency.MaxConcurrentRequests != 0 { + t.Fatalf("image_concurrency.max_concurrent_requests = %d, want 0", cfg.Gateway.ImageConcurrency.MaxConcurrentRequests) + } + if cfg.Gateway.ImageConcurrency.OverflowMode != ImageConcurrencyOverflowModeReject { + t.Fatalf("image_concurrency.overflow_mode = %q, want %q", cfg.Gateway.ImageConcurrency.OverflowMode, ImageConcurrencyOverflowModeReject) + } + if cfg.Gateway.ImageConcurrency.WaitTimeoutSeconds != 30 { + t.Fatalf("image_concurrency.wait_timeout_seconds = %d, want 30", cfg.Gateway.ImageConcurrency.WaitTimeoutSeconds) + } + if cfg.Gateway.ImageConcurrency.MaxWaitingRequests != 100 { + t.Fatalf("image_concurrency.max_waiting_requests = %d, want 100", cfg.Gateway.ImageConcurrency.MaxWaitingRequests) + } + if cfg.Gateway.ImageStreamDataIntervalTimeout <= cfg.Gateway.StreamDataIntervalTimeout { + t.Fatalf("image stream timeout = %d, want greater than ordinary stream timeout %d", cfg.Gateway.ImageStreamDataIntervalTimeout, cfg.Gateway.StreamDataIntervalTimeout) + } +} diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index c107c329..eade363e 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -529,6 +529,10 @@ func (h *AccountHandler) Create(c *gin.Context) { // 确定是否跳过混合渠道检查 skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk + // 捕获闭包内创建的账号引用,用于创建成功后触发异步探测。 + // 幂等重放时闭包不会执行 → createdAccount 为 nil → 不重复调度。 + var createdAccount *service.Account + result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ Name: req.Name, @@ -550,6 +554,7 @@ func (h *AccountHandler) Create(c *gin.Context) { if execErr != nil { return nil, execErr } + createdAccount = account // Antigravity OAuth: 新账号直接设置隐私 h.adminService.ForceAntigravityPrivacy(ctx, account) // OpenAI OAuth: 新账号直接设置隐私 @@ -578,6 +583,9 @@ func (h *AccountHandler) Create(c *gin.Context) { if result != nil && result.Replayed { c.Header("X-Idempotency-Replayed", "true") } + // OpenAI APIKey 账号创建后异步探测上游 /v1/responses 能力。 + // 探测失败不影响账号创建响应。 + h.scheduleOpenAIResponsesProbe(createdAccount) response.Success(c, result.Data) } @@ -638,9 +646,39 @@ func (h *AccountHandler) Update(c *gin.Context) { return } + // OpenAI APIKey: credentials 修改后重新探测上游能力(base_url/api_key 可能变更)。 + // 异步执行,探测失败不影响账号更新响应。 + if len(req.Credentials) > 0 { + h.scheduleOpenAIResponsesProbe(account) + } + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } +// scheduleOpenAIResponsesProbe 异步触发 OpenAI APIKey 账号的 Responses API 能力探测。 +// +// 仅对 platform=openai && type=apikey 账号生效;其他账号无操作。 +// 探测本身在 goroutine 中执行(会发一次 HTTP 请求到上游),不会阻塞 +// 当前请求。探测错误仅记录日志,不向上下文传播:探测失败时标记保持缺失, +// 网关会按"现状即证据"默认走 Responses。 +func (h *AccountHandler) scheduleOpenAIResponsesProbe(account *service.Account) { + if account == nil || account.Platform != service.PlatformOpenAI || account.Type != service.AccountTypeAPIKey { + return + } + if h.accountTestService == nil { + return + } + accountID := account.ID + go func() { + defer func() { + if r := recover(); r != nil { + slog.Error("openai_responses_probe_panic", "account_id", accountID, "recover", r) + } + }() + h.accountTestService.ProbeOpenAIAPIKeyResponsesSupport(context.Background(), accountID) + }() +} + // Delete handles deleting an account // DELETE /api/v1/admin/accounts/:id func (h *AccountHandler) Delete(c *gin.Context) { @@ -1232,6 +1270,8 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) { openaiPrivacyAccounts = append(openaiPrivacyAccounts, account) } } + // OpenAI APIKey 账号异步探测 /v1/responses 能力。 + h.scheduleOpenAIResponsesProbe(account) success++ results = append(results, gin.H{ "name": item.Name, diff --git a/backend/internal/handler/admin/affiliate_handler.go b/backend/internal/handler/admin/affiliate_handler.go index 97e649ec..d443d344 100644 --- a/backend/internal/handler/admin/affiliate_handler.go +++ b/backend/internal/handler/admin/affiliate_handler.go @@ -2,8 +2,11 @@ package admin import ( "strconv" + "strings" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -181,3 +184,108 @@ func (h *AffiliateHandler) LookupUsers(c *gin.Context) { } response.Success(c, result) } + +// GetUserOverview returns one user's affiliate overview. +// GET /api/v1/admin/affiliates/users/:user_id/overview +func (h *AffiliateHandler) GetUserOverview(c *gin.Context) { + userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64) + if err != nil || userID <= 0 { + response.BadRequest(c, "Invalid user_id") + return + } + overview, err := h.affiliateService.AdminGetUserOverview(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, overview) +} + +// ListInviteRecords returns all inviter-invitee relationships. +// GET /api/v1/admin/affiliates/invites +func (h *AffiliateHandler) ListInviteRecords(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + filter := parseAffiliateRecordFilter(c, page, pageSize) + items, total, err := h.affiliateService.AdminListInviteRecords(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Paginated(c, items, total, filter.Page, filter.PageSize) +} + +// ListRebateRecords returns all order-level affiliate rebate records. +// GET /api/v1/admin/affiliates/rebates +func (h *AffiliateHandler) ListRebateRecords(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + filter := parseAffiliateRecordFilter(c, page, pageSize) + items, total, err := h.affiliateService.AdminListRebateRecords(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Paginated(c, items, total, filter.Page, filter.PageSize) +} + +// ListTransferRecords returns all affiliate quota-to-balance transfer records. +// GET /api/v1/admin/affiliates/transfers +func (h *AffiliateHandler) ListTransferRecords(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + filter := parseAffiliateRecordFilter(c, page, pageSize) + items, total, err := h.affiliateService.AdminListTransferRecords(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Paginated(c, items, total, filter.Page, filter.PageSize) +} + +func parseAffiliateRecordFilter(c *gin.Context, page, pageSize int) service.AffiliateRecordFilter { + filter := service.AffiliateRecordFilter{ + Search: c.Query("search"), + Page: page, + PageSize: pageSize, + SortBy: c.Query("sort_by"), + SortDesc: c.Query("sort_order") != "asc", + } + if filter.PageSize > 100 { + filter.PageSize = 100 + } + userTZ := c.Query("timezone") + if t := parseAffiliateRecordStartTime(c.Query("start_at"), userTZ); t != nil { + filter.StartAt = t + } + if t := parseAffiliateRecordEndTime(c.Query("end_at"), userTZ); t != nil { + filter.EndAt = t + } + return filter +} + +func parseAffiliateRecordStartTime(raw string, userTZ string) *time.Time { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + if parsed, err := time.Parse(time.RFC3339, raw); err == nil { + return &parsed + } + if parsed, err := timezone.ParseInUserLocation("2006-01-02", raw, userTZ); err == nil { + return &parsed + } + return nil +} + +func parseAffiliateRecordEndTime(raw string, userTZ string) *time.Time { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + if parsed, err := time.Parse(time.RFC3339, raw); err == nil { + return &parsed + } + if parsed, err := timezone.ParseInUserLocation("2006-01-02", raw, userTZ); err == nil { + end := parsed.AddDate(0, 0, 1).Add(-time.Nanosecond) + return &end + } + return nil +} diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index a1de536b..17b9555f 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -92,6 +92,9 @@ type CreateGroupRequest struct { WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"` MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) + AllowImageGeneration bool `json:"allow_image_generation"` + ImageRateIndependent bool `json:"image_rate_independent"` + ImageRateMultiplier *float64 `json:"image_rate_multiplier"` ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` @@ -129,6 +132,9 @@ type UpdateGroupRequest struct { WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"` MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) + AllowImageGeneration *bool `json:"allow_image_generation"` + ImageRateIndependent *bool `json:"image_rate_independent"` + ImageRateMultiplier *float64 `json:"image_rate_multiplier"` ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` @@ -251,6 +257,9 @@ func (h *GroupHandler) Create(c *gin.Context) { DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(), WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(), MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(), + AllowImageGeneration: req.AllowImageGeneration, + ImageRateIndependent: req.ImageRateIndependent, + ImageRateMultiplier: req.ImageRateMultiplier, ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, @@ -303,6 +312,9 @@ func (h *GroupHandler) Update(c *gin.Context) { DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(), WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(), MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(), + AllowImageGeneration: req.AllowImageGeneration, + ImageRateIndependent: req.ImageRateIndependent, + ImageRateMultiplier: req.ImageRateMultiplier, ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 59f4fe85..0cec89aa 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -2462,6 +2462,58 @@ func (h *SettingHandler) UpdateOverloadCooldownSettings(c *gin.Context) { }) } +// GetRateLimit429CooldownSettings 获取429默认回避配置 +// GET /api/v1/admin/settings/rate-limit-429-cooldown +func (h *SettingHandler) GetRateLimit429CooldownSettings(c *gin.Context) { + settings, err := h.settingService.GetRateLimit429CooldownSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, dto.RateLimit429CooldownSettings{ + Enabled: settings.Enabled, + CooldownSeconds: settings.CooldownSeconds, + }) +} + +// UpdateRateLimit429CooldownSettingsRequest 更新429默认回避配置请求 +type UpdateRateLimit429CooldownSettingsRequest struct { + Enabled bool `json:"enabled"` + CooldownSeconds int `json:"cooldown_seconds"` +} + +// UpdateRateLimit429CooldownSettings 更新429默认回避配置 +// PUT /api/v1/admin/settings/rate-limit-429-cooldown +func (h *SettingHandler) UpdateRateLimit429CooldownSettings(c *gin.Context) { + var req UpdateRateLimit429CooldownSettingsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + settings := &service.RateLimit429CooldownSettings{ + Enabled: req.Enabled, + CooldownSeconds: req.CooldownSeconds, + } + + if err := h.settingService.SetRateLimit429CooldownSettings(c.Request.Context(), settings); err != nil { + response.BadRequest(c, err.Error()) + return + } + + updatedSettings, err := h.settingService.GetRateLimit429CooldownSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, dto.RateLimit429CooldownSettings{ + Enabled: updatedSettings.Enabled, + CooldownSeconds: updatedSettings.CooldownSeconds, + }) +} + // GetStreamTimeoutSettings 获取流超时处理配置 // GET /api/v1/admin/settings/stream-timeout func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) { diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 3d80107f..a297c56c 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -390,7 +390,7 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) { // GetBalanceHistory handles getting user's balance/concurrency change history // GET /api/v1/admin/users/:id/balance-history // Query params: -// - type: filter by record type (balance, admin_balance, concurrency, admin_concurrency, subscription) +// - type: filter by record type (balance, affiliate_balance, admin_balance, concurrency, admin_concurrency, subscription) func (h *UserHandler) GetBalanceHistory(c *gin.Context) { userID, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index f7503c2e..2559b112 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -176,6 +176,9 @@ func groupFromServiceBase(g *service.Group) Group { DailyLimitUSD: g.DailyLimitUSD, WeeklyLimitUSD: g.WeeklyLimitUSD, MonthlyLimitUSD: g.MonthlyLimitUSD, + AllowImageGeneration: g.AllowImageGeneration, + ImageRateIndependent: g.ImageRateIndependent, + ImageRateMultiplier: g.ImageRateMultiplier, ImagePrice1K: g.ImagePrice1K, ImagePrice2K: g.ImagePrice2K, ImagePrice4K: g.ImagePrice4K, diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 492be170..0bc834fe 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -264,6 +264,12 @@ type OverloadCooldownSettings struct { CooldownMinutes int `json:"cooldown_minutes"` } +// RateLimit429CooldownSettings 429默认回避配置 DTO +type RateLimit429CooldownSettings struct { + Enabled bool `json:"enabled"` + CooldownSeconds int `json:"cooldown_seconds"` +} + // StreamTimeoutSettings 流超时处理配置 DTO type StreamTimeoutSettings struct { Enabled bool `json:"enabled"` diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 5cc2f8e4..e15a916e 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -94,9 +94,12 @@ type Group struct { MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 `json:"image_price_1k"` - ImagePrice2K *float64 `json:"image_price_2k"` - ImagePrice4K *float64 `json:"image_price_4k"` + AllowImageGeneration bool `json:"allow_image_generation"` + ImageRateIndependent bool `json:"image_rate_independent"` + ImageRateMultiplier float64 `json:"image_rate_multiplier"` + ImagePrice1K *float64 `json:"image_price_1k"` + ImagePrice2K *float64 `json:"image_price_2k"` + ImagePrice4K *float64 `json:"image_price_4k"` // Claude Code 客户端限制 ClaudeCodeOnly bool `json:"claude_code_only"` diff --git a/backend/internal/handler/image_concurrency_limiter.go b/backend/internal/handler/image_concurrency_limiter.go new file mode 100644 index 00000000..6e7cbb67 --- /dev/null +++ b/backend/internal/handler/image_concurrency_limiter.go @@ -0,0 +1,126 @@ +package handler + +import ( + "context" + "sync" + "time" +) + +type imageConcurrencyLimiter struct { + mu sync.Mutex + notify chan struct{} + limit int + active int + waiting int + enabled bool +} + +func (l *imageConcurrencyLimiter) TryAcquire(enabled bool, limit int) (func(), bool) { + return l.acquire(context.Background(), enabled, limit, false, 0, 0) +} + +func (l *imageConcurrencyLimiter) Acquire(ctx context.Context, enabled bool, limit int, wait bool, timeout time.Duration, maxWaiting int) (func(), bool) { + return l.acquire(ctx, enabled, limit, wait, timeout, maxWaiting) +} + +func (l *imageConcurrencyLimiter) acquire(ctx context.Context, enabled bool, limit int, wait bool, timeout time.Duration, maxWaiting int) (func(), bool) { + if !enabled || limit <= 0 { + return nil, true + } + if ctx == nil { + ctx = context.Background() + } + if wait { + if timeout <= 0 { + return nil, false + } + waitCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + ctx = waitCtx + } + if maxWaiting < 0 { + maxWaiting = 0 + } + for { + release, acquired, waitRelease, notify := l.tryAcquireLocked(enabled, limit, wait, maxWaiting) + if acquired { + return release, acquired + } + if !wait || notify == nil { + return nil, false + } + if !l.waitForSlot(ctx, notify) { + if waitRelease != nil { + waitRelease() + } + return nil, false + } + if waitRelease != nil { + waitRelease() + } + } +} + +func (l *imageConcurrencyLimiter) tryAcquireLocked(enabled bool, limit int, wait bool, maxWaiting int) (func(), bool, func(), <-chan struct{}) { + l.mu.Lock() + defer l.mu.Unlock() + + if l.notify == nil { + l.notify = make(chan struct{}) + } + if l.enabled != enabled || l.limit != limit { + l.enabled = enabled + l.limit = limit + } + if l.active < l.limit { + l.active++ + return l.releaseFunc(), true, nil, nil + } + if !wait { + return nil, false, nil, nil + } + if maxWaiting > 0 && l.waiting >= maxWaiting { + return nil, false, nil, nil + } + l.waiting++ + return nil, false, l.waiterReleaseFunc(), l.notify +} + +func (l *imageConcurrencyLimiter) waitForSlot(ctx context.Context, notify <-chan struct{}) bool { + select { + case <-notify: + return true + case <-ctx.Done(): + return false + } +} + +func (l *imageConcurrencyLimiter) releaseFunc() func() { + var once sync.Once + return func() { + once.Do(func() { + l.mu.Lock() + if l.active > 0 { + l.active-- + } + if l.notify != nil { + close(l.notify) + l.notify = make(chan struct{}) + } + l.mu.Unlock() + }) + } +} + +func (l *imageConcurrencyLimiter) waiterReleaseFunc() func() { + var once sync.Once + return func() { + once.Do(func() { + l.mu.Lock() + if l.waiting > 0 { + l.waiting-- + } + l.mu.Unlock() + }) + } +} diff --git a/backend/internal/handler/image_concurrency_limiter_test.go b/backend/internal/handler/image_concurrency_limiter_test.go new file mode 100644 index 00000000..20147f16 --- /dev/null +++ b/backend/internal/handler/image_concurrency_limiter_test.go @@ -0,0 +1,230 @@ +package handler + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestImageConcurrencyLimiter_DefaultDisabledAllowsRequests(t *testing.T) { + limiter := &imageConcurrencyLimiter{} + + release, acquired := limiter.TryAcquire(false, 1) + + require.True(t, acquired) + require.Nil(t, release) +} + +func TestImageConcurrencyLimiter_RejectsWhenLimitReachedAndAllowsAfterRelease(t *testing.T) { + limiter := &imageConcurrencyLimiter{} + + release, acquired := limiter.TryAcquire(true, 1) + require.True(t, acquired) + require.NotNil(t, release) + + secondRelease, secondAcquired := limiter.TryAcquire(true, 1) + require.False(t, secondAcquired) + require.Nil(t, secondRelease) + + release() + thirdRelease, thirdAcquired := limiter.TryAcquire(true, 1) + require.True(t, thirdAcquired) + require.NotNil(t, thirdRelease) + thirdRelease() +} + +func TestImageConcurrencyLimiter_WaitsUntilSlotReleased(t *testing.T) { + limiter := &imageConcurrencyLimiter{} + release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1) + require.True(t, acquired) + require.NotNil(t, release) + + acquiredCh := make(chan func(), 1) + go func() { + waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1) + require.True(t, waitAcquired) + acquiredCh <- waitRelease + }() + + time.Sleep(20 * time.Millisecond) + release() + + select { + case waitRelease := <-acquiredCh: + require.NotNil(t, waitRelease) + waitRelease() + case <-time.After(time.Second): + t.Fatal("timed out waiting for image concurrency slot") + } +} + +func TestImageConcurrencyLimiter_WaitTimesOut(t *testing.T) { + limiter := &imageConcurrencyLimiter{} + release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1) + require.True(t, acquired) + require.NotNil(t, release) + defer release() + + waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, 10*time.Millisecond, 1) + + require.False(t, waitAcquired) + require.Nil(t, waitRelease) +} + +func TestImageConcurrencyLimiter_MaxWaitingRequestsRejectsOverflow(t *testing.T) { + limiter := &imageConcurrencyLimiter{} + release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1) + require.True(t, acquired) + require.NotNil(t, release) + defer release() + + waitingStarted := make(chan struct{}) + waitingDone := make(chan struct{}) + go func() { + close(waitingStarted) + waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1) + if waitAcquired && waitRelease != nil { + waitRelease() + } + close(waitingDone) + }() + <-waitingStarted + time.Sleep(20 * time.Millisecond) + + overflowRelease, overflowAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1) + + require.False(t, overflowAcquired) + require.Nil(t, overflowRelease) + release() + <-waitingDone +} + +func TestOpenAIGatewayHandlerAcquireImageGenerationSlot_Returns429WhenFull(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) + + h := &OpenAIGatewayHandler{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + ImageConcurrency: config.ImageConcurrencyConfig{ + Enabled: true, + MaxConcurrentRequests: 1, + OverflowMode: config.ImageConcurrencyOverflowModeReject, + }, + }, + }, + imageLimiter: &imageConcurrencyLimiter{}, + } + release, acquired := h.acquireImageGenerationSlot(c, false) + require.True(t, acquired) + require.NotNil(t, release) + defer release() + + blockedRelease, blocked := h.acquireImageGenerationSlot(c, false) + + require.False(t, blocked) + require.Nil(t, blockedRelease) + require.Equal(t, http.StatusTooManyRequests, rec.Code) + require.Equal(t, "rate_limit_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String()) + require.Contains(t, rec.Body.String(), "Image generation concurrency limit exceeded") +} + +func TestOpenAIGatewayHandlerResponses_ImageIntentRejectedByImageConcurrency(t *testing.T) { + gin.SetMode(gin.TestMode) + body := `{"model":"gpt-5.4","input":"draw","tools":[{"type":"image_generation"}]}` + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(body)) + groupID := int64(1) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + ID: 10, + GroupID: &groupID, + Group: &service.Group{ + ID: groupID, + AllowImageGeneration: true, + }, + User: &service.User{ID: 20}, + }) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 20, Concurrency: 1}) + + h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})}, + errorPassthroughService: nil, + cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{ + Enabled: true, + MaxConcurrentRequests: 1, + OverflowMode: config.ImageConcurrencyOverflowModeReject, + }}}, + imageLimiter: &imageConcurrencyLimiter{}, + } + release, acquired := h.acquireImageGenerationSlot(c, false) + require.True(t, acquired) + require.NotNil(t, release) + defer release() + rec.Body.Reset() + rec.Code = 0 + + h.Responses(c) + + require.Equal(t, http.StatusTooManyRequests, rec.Code) + require.Equal(t, "rate_limit_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String()) + require.Contains(t, rec.Body.String(), "Image generation concurrency limit exceeded") +} + +func TestOpenAIGatewayHandlerResponses_TextOnlyNotRejectedByImageConcurrency(t *testing.T) { + gin.SetMode(gin.TestMode) + body := `{"model":"gpt-5.4","input":"write code"}` + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(body)) + groupID := int64(1) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + ID: 10, + GroupID: &groupID, + Group: &service.Group{ + ID: groupID, + AllowImageGeneration: true, + }, + User: &service.User{ID: 20}, + }) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 20, Concurrency: 1}) + + h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}), + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})}, + cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{ + Enabled: true, + MaxConcurrentRequests: 1, + OverflowMode: config.ImageConcurrencyOverflowModeReject, + }}}, + imageLimiter: &imageConcurrencyLimiter{}, + } + release, acquired := h.acquireImageGenerationSlot(c, false) + require.True(t, acquired) + require.NotNil(t, release) + defer release() + rec.Body.Reset() + rec.Code = 0 + + h.Responses(c) + + require.NotEqual(t, http.StatusTooManyRequests, rec.Code) + require.NotContains(t, rec.Body.String(), "Image generation concurrency limit exceeded") +} diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index f395970a..06ab9d52 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -10,6 +10,7 @@ import ( pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -120,7 +121,6 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { var lastFailoverErr *service.UpstreamFailoverError for { - c.Set("openai_chat_completions_fallback_model", "") reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( c.Request.Context(), @@ -138,32 +138,8 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { zap.Int("excluded_account_count", len(failedAccountIDs)), ) if len(failedAccountIDs) == 0 { - defaultModel := "" - if apiKey.Group != nil { - defaultModel = apiKey.Group.DefaultMappedModel - } - if defaultModel != "" && defaultModel != reqModel { - reqLog.Info("openai_chat_completions.fallback_to_default_model", - zap.String("default_mapped_model", defaultModel), - ) - selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler( - c.Request.Context(), - apiKey.GroupID, - "", - sessionHash, - defaultModel, - failedAccountIDs, - service.OpenAIUpstreamTransportAny, - false, - ) - if err == nil && selection != nil { - c.Set("openai_chat_completions_fallback_model", defaultModel) - } - } - if err != nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) - return - } + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) + return } else { if lastFailoverErr != nil { h.handleFailoverExhausted(c, lastFailoverErr, streamStarted) @@ -191,12 +167,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) forwardStart := time.Now() - defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model")) forwardBody := body if channelMapping.Mapped { forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) } - result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel) + result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "") forwardDurationMs := time.Since(forwardStart).Milliseconds() if accountReleaseFunc != nil { @@ -212,52 +187,60 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) } if err != nil { - var failoverErr *service.UpstreamFailoverError - if errors.As(err, &failoverErr) { - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - // Pool mode: retry on the same account - if failoverErr.RetryableOnSameAccount { - retryLimit := account.GetPoolModeRetryCount() - if sameAccountRetryCount[account.ID] < retryLimit { - sameAccountRetryCount[account.ID]++ - reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry", - zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("retry_limit", retryLimit), - zap.Int("retry_count", sameAccountRetryCount[account.ID]), - ) - select { - case <-c.Request.Context().Done(): - return - case <-time.After(sameAccountRetryDelay): - } - continue - } - } - h.gatewayService.RecordOpenAIAccountSwitch() - failedAccountIDs[account.ID] = struct{}{} - lastFailoverErr = failoverErr - if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, failoverErr, streamStarted) - return - } - switchCount++ - reqLog.Warn("openai_chat_completions.upstream_failover_switching", + if result != nil && result.ImageCount > 0 { + reqLog.Warn("openai_chat_completions.forward_partial_error_with_image_result", zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), + zap.Int("image_count", result.ImageCount), + zap.Error(err), ) - continue + } else { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + // Pool mode: retry on the same account + if failoverErr.RetryableOnSameAccount { + retryLimit := account.GetPoolModeRetryCount() + if sameAccountRetryCount[account.ID] < retryLimit { + sameAccountRetryCount[account.ID]++ + reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("retry_limit", retryLimit), + zap.Int("retry_count", sameAccountRetryCount[account.ID]), + ) + select { + case <-c.Request.Context().Done(): + return + case <-time.After(sameAccountRetryDelay): + } + continue + } + } + h.gatewayService.RecordOpenAIAccountSwitch() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, failoverErr, streamStarted) + return + } + switchCount++ + reqLog.Warn("openai_chat_completions.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + continue + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Warn("openai_chat_completions.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) + return } - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) - reqLog.Warn("openai_chat_completions.forward_failed", - zap.Int64("account_id", account.ID), - zap.Bool("fallback_error_response_written", wroteFallback), - zap.Error(err), - ) - return } if result != nil { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) @@ -267,16 +250,18 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := resolveRawCCUpstreamEndpoint(c, account) - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, User: apiKey.User, Account: account, Subscription: subscription, - InboundEndpoint: GetInboundEndpoint(c), - UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, UserAgent: userAgent, IPAddress: clientIP, APIKeyService: h.apiKeyService, @@ -299,3 +284,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { return } } + +// resolveRawCCUpstreamEndpoint returns the actual upstream endpoint for +// OpenAI Chat Completions requests. For APIKey accounts whose upstream +// has been probed to not support the Responses API, the request is +// forwarded directly to /v1/chat/completions — not through the default +// CC→Responses conversion path. +func resolveRawCCUpstreamEndpoint(c *gin.Context, account *service.Account) string { + if account != nil && account.Type == service.AccountTypeAPIKey && + !openai_compat.ShouldUseResponsesAPI(account.Extra) { + return "/v1/chat/completions" + } + return GetUpstreamEndpoint(c, account.Platform) +} diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 7676ffa3..3997a0ee 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -33,20 +33,11 @@ type OpenAIGatewayHandler struct { usageRecordWorkerPool *service.UsageRecordWorkerPool errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper + imageLimiter *imageConcurrencyLimiter maxAccountSwitches int cfg *config.Config } -func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackModel string) string { - if fallbackModel = strings.TrimSpace(fallbackModel); fallbackModel != "" { - return fallbackModel - } - if apiKey == nil || apiKey.Group == nil { - return "" - } - return strings.TrimSpace(apiKey.Group.DefaultMappedModel) -} - func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string { if apiKey == nil || apiKey.Group == nil { return "" @@ -79,6 +70,7 @@ func NewOpenAIGatewayHandler( usageRecordWorkerPool: usageRecordWorkerPool, errorPassthroughService: errorPassthroughService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), + imageLimiter: &imageConcurrencyLimiter{}, maxAccountSwitches: maxAccountSwitches, cfg: cfg, } @@ -197,6 +189,23 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + imageIntent := service.IsImageGenerationIntent("/v1/responses", reqModel, body) + if imageIntent && !service.GroupAllowsImageGeneration(apiKey.Group) { + h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage()) + return + } + var imageReleaseFunc func() + if imageIntent { + var imageAcquired bool + imageReleaseFunc, imageAcquired = h.acquireImageGenerationSlot(c, streamStarted) + if !imageAcquired { + return + } + if imageReleaseFunc != nil { + defer imageReleaseFunc() + } + } + // 解析渠道级模型映射 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) @@ -328,57 +337,65 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) } if err != nil { - var failoverErr *service.UpstreamFailoverError - if errors.As(err, &failoverErr) { - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - // 池模式:同账号重试 - if failoverErr.RetryableOnSameAccount { - retryLimit := account.GetPoolModeRetryCount() - if sameAccountRetryCount[account.ID] < retryLimit { - sameAccountRetryCount[account.ID]++ - reqLog.Warn("openai.pool_mode_same_account_retry", - zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("retry_limit", retryLimit), - zap.Int("retry_count", sameAccountRetryCount[account.ID]), - ) - select { - case <-c.Request.Context().Done(): - return - case <-time.After(sameAccountRetryDelay): + if result != nil && result.ImageCount > 0 { + reqLog.Warn("openai.forward_partial_error_with_image_result", + zap.Int64("account_id", account.ID), + zap.Int("image_count", result.ImageCount), + zap.Error(err), + ) + } else { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + // 池模式:同账号重试 + if failoverErr.RetryableOnSameAccount { + retryLimit := account.GetPoolModeRetryCount() + if sameAccountRetryCount[account.ID] < retryLimit { + sameAccountRetryCount[account.ID]++ + reqLog.Warn("openai.pool_mode_same_account_retry", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("retry_limit", retryLimit), + zap.Int("retry_count", sameAccountRetryCount[account.ID]), + ) + select { + case <-c.Request.Context().Done(): + return + case <-time.After(sameAccountRetryDelay): + } + continue } - continue } + h.gatewayService.RecordOpenAIAccountSwitch() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, failoverErr, streamStarted) + return + } + switchCount++ + reqLog.Warn("openai.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + continue } - h.gatewayService.RecordOpenAIAccountSwitch() - failedAccountIDs[account.ID] = struct{}{} - lastFailoverErr = failoverErr - if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, failoverErr, streamStarted) + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + fields := []zap.Field{ + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + } + if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) { + reqLog.Warn("openai.forward_failed", fields...) return } - switchCount++ - reqLog.Warn("openai.upstream_failover_switching", - zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), - ) - continue - } - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) - fields := []zap.Field{ - zap.Int64("account_id", account.ID), - zap.Bool("fallback_error_response_written", wroteFallback), - zap.Error(err), - } - if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) { - reqLog.Warn("openai.forward_failed", fields...) + reqLog.Error("openai.forward_failed", fields...) return } - reqLog.Error("openai.forward_failed", fields...) - return } if result != nil { if account.Type == service.AccountTypeOAuth { @@ -393,17 +410,19 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, User: apiKey.User, Account: account, Subscription: subscription, - InboundEndpoint: GetInboundEndpoint(c), - UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, UserAgent: userAgent, IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, @@ -613,21 +632,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { sessionHash := h.gatewayService.GenerateSessionHash(c, body) promptCacheKey := h.gatewayService.ExtractSessionID(c, body) - - // Anthropic 格式的请求在 metadata.user_id 中携带 session 标识, - // 而非 OpenAI 的 session_id/conversation_id headers。 - // 从中派生 sessionHash(sticky session)和 promptCacheKey(upstream cache)。 - if sessionHash == "" || promptCacheKey == "" { - if userID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()); userID != "" { - seed := reqModel + "-" + userID - if promptCacheKey == "" { - promptCacheKey = service.GenerateSessionUUID(seed) - } - if sessionHash == "" { - sessionHash = service.DeriveSessionHashFromSeed(seed) - } - } - } + sessionHash, promptCacheKey = resolveOpenAIMessagesMetadataSession(sessionHash, promptCacheKey, reqModel, body) maxAccountSwitches := h.maxAccountSwitches switchCount := 0 @@ -711,52 +716,60 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) } if err != nil { - var failoverErr *service.UpstreamFailoverError - if errors.As(err, &failoverErr) { - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - // 池模式:同账号重试 - if failoverErr.RetryableOnSameAccount { - retryLimit := account.GetPoolModeRetryCount() - if sameAccountRetryCount[account.ID] < retryLimit { - sameAccountRetryCount[account.ID]++ - reqLog.Warn("openai_messages.pool_mode_same_account_retry", - zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("retry_limit", retryLimit), - zap.Int("retry_count", sameAccountRetryCount[account.ID]), - ) - select { - case <-c.Request.Context().Done(): - return - case <-time.After(sameAccountRetryDelay): - } - continue - } - } - h.gatewayService.RecordOpenAIAccountSwitch() - failedAccountIDs[account.ID] = struct{}{} - lastFailoverErr = failoverErr - if switchCount >= maxAccountSwitches { - h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted) - return - } - switchCount++ - reqLog.Warn("openai_messages.upstream_failover_switching", + if result != nil && result.ImageCount > 0 { + reqLog.Warn("openai_messages.forward_partial_error_with_image_result", zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), + zap.Int("image_count", result.ImageCount), + zap.Error(err), ) - continue + } else { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + // 池模式:同账号重试 + if failoverErr.RetryableOnSameAccount { + retryLimit := account.GetPoolModeRetryCount() + if sameAccountRetryCount[account.ID] < retryLimit { + sameAccountRetryCount[account.ID]++ + reqLog.Warn("openai_messages.pool_mode_same_account_retry", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("retry_limit", retryLimit), + zap.Int("retry_count", sameAccountRetryCount[account.ID]), + ) + select { + case <-c.Request.Context().Done(): + return + case <-time.After(sameAccountRetryDelay): + } + continue + } + } + h.gatewayService.RecordOpenAIAccountSwitch() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted) + return + } + switchCount++ + reqLog.Warn("openai_messages.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + continue + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted) + reqLog.Warn("openai_messages.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) + return } - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted) - reqLog.Warn("openai_messages.forward_failed", - zap.Int64("account_id", account.ID), - zap.Bool("fallback_error_response_written", wroteFallback), - zap.Error(err), - ) - return } if result != nil { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) @@ -767,16 +780,18 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, User: apiKey.User, Account: account, Subscription: subscription, - InboundEndpoint: GetInboundEndpoint(c), - UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, UserAgent: userAgent, IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, @@ -801,6 +816,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { } } +func resolveOpenAIMessagesMetadataSession(sessionHash, promptCacheKey, reqModel string, body []byte) (string, string) { + // Anthropic metadata.user_id 只作为账号粘性信号。上游 GPT/Codex 缓存键 + // 交给 ForwardAsAnthropic 从 cache_control 或完整消息 digest 派生,避免 + // 固定 metadata key 压住后续 turn 的缓存滚动。 + if sessionHash != "" { + return sessionHash, promptCacheKey + } + if userID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()); userID != "" { + seed := reqModel + "-" + userID + sessionHash = service.DeriveSessionHashFromSeed(seed) + } + return sessionHash, promptCacheKey +} + // anthropicErrorResponse writes an error in Anthropic Messages API format. func (h *OpenAIGatewayHandler) anthropicErrorResponse(c *gin.Context, status int, errType, message string) { c.JSON(status, gin.H{ @@ -1124,6 +1153,11 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { setOpsRequestContext(c, reqModel, true, firstMessage) setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) + if service.IsImageGenerationIntent("/v1/responses", reqModel, firstMessage) && !service.GroupAllowsImageGeneration(apiKey.Group) { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, service.ImageGenerationPermissionMessage()) + return + } + // 解析渠道级模型映射 channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel) @@ -1233,6 +1267,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { ) hooks := &service.OpenAIWSIngressHooks{ + InitialRequestModel: reqModel, BeforeTurn: func(turn int) error { if turn == 1 { return nil @@ -1266,22 +1301,34 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { }, AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) { releaseTurnSlots() - if turnErr != nil || result == nil { + if turnErr != nil { + if result == nil || result.ImageCount <= 0 { + return + } + reqLog.Warn("openai.websocket_partial_error_with_image_result", + zap.Int64("account_id", account.ID), + zap.Int("image_count", result.ImageCount), + zap.Error(turnErr), + ) + } + if result == nil { return } if account.Type == service.AccountTypeOAuth { h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders) } h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) - h.submitUsageRecordTask(func(taskCtx context.Context) { + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) { if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, User: apiKey.User, Account: account, Subscription: subscription, - InboundEndpoint: GetInboundEndpoint(c), - UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, UserAgent: userAgent, IPAddress: clientIP, RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), @@ -1449,6 +1496,60 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas task(ctx) } +func (h *OpenAIGatewayHandler) submitOpenAIUsageRecordTask(result *service.OpenAIForwardResult, task service.UsageRecordTask) { + if result != nil && result.ImageCount > 0 { + h.submitMandatoryUsageRecordTask(task) + return + } + h.submitUsageRecordTask(task) +} + +func (h *OpenAIGatewayHandler) submitMandatoryUsageRecordTask(task service.UsageRecordTask) { + if task == nil { + return + } + if h.usageRecordWorkerPool != nil { + if mode := h.usageRecordWorkerPool.Submit(task); mode != service.UsageRecordSubmitModeDropped { + return + } + logger.L().With( + zap.String("component", "handler.openai_gateway.usage"), + ).Warn("openai.usage_record_task_mandatory_sync_fallback") + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + defer func() { + if recovered := recover(); recovered != nil { + logger.L().With( + zap.String("component", "handler.openai_gateway.usage"), + zap.Any("panic", recovered), + ).Error("openai.usage_record_task_panic_recovered") + } + }() + task(ctx) +} + +func (h *OpenAIGatewayHandler) acquireImageGenerationSlot(c *gin.Context, streamStarted bool) (func(), bool) { + if h == nil || h.cfg == nil || h.imageLimiter == nil { + return nil, true + } + imageConcurrency := h.cfg.Gateway.ImageConcurrency + wait := strings.TrimSpace(imageConcurrency.OverflowMode) == config.ImageConcurrencyOverflowModeWait + release, acquired := h.imageLimiter.Acquire( + c.Request.Context(), + imageConcurrency.Enabled, + imageConcurrency.MaxConcurrentRequests, + wait, + time.Duration(imageConcurrency.WaitTimeoutSeconds)*time.Second, + imageConcurrency.MaxWaitingRequests, + ) + if acquired { + return release, true + } + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Image generation concurrency limit exceeded, please retry later", streamStarted) + return nil, false +} + // handleConcurrencyError handles concurrency-related errors with proper 429 response func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index 8ecee59a..c560350e 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/config" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -91,6 +92,24 @@ func TestOpenAIHandleStreamingAwareError_JSONEscaping(t *testing.T) { } } +func TestResolveOpenAIMessagesMetadataSession_DoesNotDerivePromptCacheKey(t *testing.T) { + body := []byte(`{"model":"claude-sonnet-4-5","metadata":{"user_id":"claude-code-session"},"messages":[{"role":"user","content":"hello"}]}`) + + sessionHash, promptCacheKey := resolveOpenAIMessagesMetadataSession("", "", "claude-sonnet-4-5", body) + + require.NotEmpty(t, sessionHash) + require.Empty(t, promptCacheKey) +} + +func TestResolveOpenAIMessagesMetadataSession_PreservesExplicitPromptCacheKey(t *testing.T) { + body := []byte(`{"metadata":{"user_id":"claude-code-session"}}`) + + sessionHash, promptCacheKey := resolveOpenAIMessagesMetadataSession("", "explicit-cache", "claude-sonnet-4-5", body) + + require.NotEmpty(t, sessionHash) + require.Equal(t, "explicit-cache", promptCacheKey) +} + func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) { gin.SetMode(gin.TestMode) w := httptest.NewRecorder() @@ -352,30 +371,6 @@ func TestOpenAIEnsureResponsesDependencies(t *testing.T) { }) } -func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) { - t.Run("prefers_explicit_fallback_model", func(t *testing.T) { - apiKey := &service.APIKey{ - Group: &service.Group{DefaultMappedModel: "gpt-5.4"}, - } - require.Equal(t, "gpt-5.2", resolveOpenAIForwardDefaultMappedModel(apiKey, " gpt-5.2 ")) - }) - - t.Run("uses_group_default_when_explicit_fallback_absent", func(t *testing.T) { - apiKey := &service.APIKey{ - Group: &service.Group{DefaultMappedModel: "gpt-5.4"}, - } - require.Equal(t, "gpt-5.4", resolveOpenAIForwardDefaultMappedModel(apiKey, "")) - }) - - t.Run("returns_empty_without_group_default", func(t *testing.T) { - require.Empty(t, resolveOpenAIForwardDefaultMappedModel(nil, "")) - require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{}, "")) - require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{ - Group: &service.Group{}, - }, "")) - }) -} - func TestResolveOpenAIMessagesDispatchMappedModel(t *testing.T) { t.Run("exact_claude_model_override_wins", func(t *testing.T) { apiKey := &service.APIKey{ @@ -651,6 +646,46 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot") } +func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) { + got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{ + firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`, + userAgent: testStringPtr("codex_cli_rs/0.125.0 test"), + }) + + require.NotNil(t, got.log.UserAgent) + require.Equal(t, "codex_cli_rs/0.125.0 test", *got.log.UserAgent) + require.NotNil(t, got.log.ReasoningEffort) + require.Equal(t, "high", *got.log.ReasoningEffort) + require.True(t, got.log.OpenAIWSMode) +} + +func TestOpenAIResponsesWebSocket_PassthroughUsageLogInfersReasoningFromInitialRequestModel(t *testing.T) { + got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{ + firstPayload: `{"type":"response.create","model":"gpt-5.4-xhigh","stream":false}`, + userAgent: testStringPtr("codex_cli_rs/0.125.0 mapped"), + channelMapping: map[string]string{ + "gpt-5.4-xhigh": "gpt-5.4", + }, + }) + + require.Equal(t, "gpt-5.4", gjson.GetBytes(got.upstreamFirstPayload, "model").String(), + "上游首帧应使用渠道映射后的模型") + require.NotNil(t, got.log.ReasoningEffort) + require.Equal(t, "xhigh", *got.log.ReasoningEffort, + "usage log reasoning effort 必须使用渠道映射前首帧模型后缀推导") +} + +func TestOpenAIResponsesWebSocket_PassthroughUsageLogLeavesUserAgentNilWhenMissing(t *testing.T) { + got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{ + firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"medium"}}`, + userAgent: testStringPtr(""), + }) + + require.Nil(t, got.log.UserAgent, "空入站 User-Agent 不应由上游握手 UA 或默认 UA 兜底") + require.NotNil(t, got.log.ReasoningEffort) + require.Equal(t, "medium", *got.log.ReasoningEffort) +} + func TestSetOpenAIClientTransportHTTP(t *testing.T) { gin.SetMode(gin.TestMode) @@ -796,3 +831,278 @@ func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject router.GET("/openai/v1/responses", h.ResponsesWebSocket) return httptest.NewServer(router) } + +type openAIResponsesWSUsageLogCase struct { + firstPayload string + userAgent *string + channelMapping map[string]string +} + +type openAIResponsesWSUsageLogResult struct { + log *service.UsageLog + upstreamFirstPayload []byte +} + +type openAIWSUsageHandlerAccountRepoStub struct { + service.AccountRepository + account service.Account +} + +func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + if s.account.Platform != platform { + return nil, nil + } + return []service.Account{s.account}, nil +} + +func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) { + return s.ListSchedulableByPlatform(ctx, platform) +} + +func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) { + if s.account.ID != id { + return nil, nil + } + account := s.account + return &account, nil +} + +type openAIWSUsageHandlerUsageLogRepoStub struct { + service.UsageLogRepository + created chan *service.UsageLog +} + +func (s *openAIWSUsageHandlerUsageLogRepoStub) Create(ctx context.Context, log *service.UsageLog) (bool, error) { + if s.created != nil { + s.created <- log + } + return true, nil +} + +type openAIWSUsageHandlerChannelRepoStub struct { + service.ChannelRepository + channels []service.Channel + groupPlatforms map[int64]string +} + +func (s *openAIWSUsageHandlerChannelRepoStub) ListAll(ctx context.Context) ([]service.Channel, error) { + return s.channels, nil +} + +func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) { + out := make(map[int64]string, len(groupIDs)) + for _, groupID := range groupIDs { + if platform := strings.TrimSpace(s.groupPlatforms[groupID]); platform != "" { + out[groupID] = platform + } + } + return out, nil +} + +func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult { + t.Helper() + gin.SetMode(gin.TestMode) + + upstreamPayloadCh := make(chan []byte, 1) + upstreamErrCh := make(chan error, 1) + upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + upstreamErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second) + msgType, payload, readErr := conn.Read(readCtx) + cancelRead() + if readErr != nil { + upstreamErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + upstreamErrCh <- errors.New("unexpected upstream websocket message type") + return + } + upstreamPayloadCh <- payload + + writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second) + writeErr := conn.Write(writeCtx, coderws.MessageText, []byte( + `{"type":"response.completed","response":{"id":"resp_usage_e2e","model":"gpt-5.4","usage":{"input_tokens":2,"output_tokens":1}}}`, + )) + cancelWrite() + if writeErr != nil { + upstreamErrCh <- writeErr + return + } + _ = conn.Close(coderws.StatusNormalClosure, "done") + upstreamErrCh <- nil + })) + defer upstreamServer.Close() + + groupID := int64(4201) + account := service.Account{ + ID: 9901, + Name: "openai-ws-passthrough-usage-e2e", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeAPIKey, + Status: service.StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": upstreamServer.URL, + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + "openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough, + }, + } + + cfg := &config.Config{} + cfg.RunMode = config.RunModeSimple + cfg.Default.RateMultiplier = 1 + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + accountRepo := &openAIWSUsageHandlerAccountRepoStub{account: account} + usageRepo := &openAIWSUsageHandlerUsageLogRepoStub{created: make(chan *service.UsageLog, 1)} + + var channelSvc *service.ChannelService + if len(tc.channelMapping) > 0 { + channelSvc = service.NewChannelService(&openAIWSUsageHandlerChannelRepoStub{ + channels: []service.Channel{{ + ID: 7701, + Name: "openai-ws-e2e-channel", + Status: service.StatusActive, + GroupIDs: []int64{groupID}, + ModelMapping: map[string]map[string]string{service.PlatformOpenAI: tc.channelMapping}, + }}, + groupPlatforms: map[int64]string{groupID: service.PlatformOpenAI}, + }, nil, nil, nil) + } + + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg) + gatewaySvc := service.NewOpenAIGatewayService( + accountRepo, + usageRepo, + nil, + nil, + nil, + nil, + nil, + cfg, + nil, + nil, + service.NewBillingService(cfg, nil), + nil, + billingCacheSvc, + nil, + &service.DeferredService{}, + nil, + nil, + channelSvc, + nil, + nil, + ) + + cache := &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + } + h := &OpenAIGatewayHandler{ + gatewayService: gatewaySvc, + billingCacheService: billingCacheSvc, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second), + } + + apiKey := &service.APIKey{ + ID: 1801, + GroupID: &groupID, + User: &service.User{ID: 1701, Status: service.StatusActive}, + } + router := gin.New() + router.Use(func(c *gin.Context) { + c.Set(string(middleware.ContextKeyAPIKey), apiKey) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1}) + c.Next() + }) + router.GET("/openai/v1/responses", h.ResponsesWebSocket) + handlerServer := httptest.NewServer(router) + defer handlerServer.Close() + + headers := http.Header{} + if tc.userAgent != nil { + headers.Set("User-Agent", *tc.userAgent) + } + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial( + dialCtx, + "ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses", + &coderws.DialOptions{HTTPHeader: headers, CompressionMode: coderws.CompressionContextTakeover}, + ) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(tc.firstPayload)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, event, err := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, err) + require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String()) + _ = clientConn.Close(coderws.StatusNormalClosure, "done") + + var usageLog *service.UsageLog + select { + case usageLog = <-usageRepo.created: + require.NotNil(t, usageLog) + case <-time.After(3 * time.Second): + t.Fatal("等待 WebSocket usage log 写入超时") + } + + var upstreamFirstPayload []byte + select { + case upstreamFirstPayload = <-upstreamPayloadCh: + case <-time.After(3 * time.Second): + t.Fatal("等待上游 WebSocket 首帧超时") + } + + select { + case upstreamErr := <-upstreamErrCh: + require.NoError(t, upstreamErr) + case <-time.After(3 * time.Second): + t.Fatal("等待上游 WebSocket 结束超时") + } + + return openAIResponsesWSUsageLogResult{ + log: usageLog, + upstreamFirstPayload: upstreamFirstPayload, + } +} + +func testStringPtr(v string) *string { + return &v +} diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go index 4d0078a7..eba701f1 100644 --- a/backend/internal/handler/openai_images.go +++ b/backend/internal/handler/openai_images.go @@ -81,6 +81,18 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { zap.String("capability", string(parsed.RequiredCapability)), ) + if !service.GroupAllowsImageGeneration(apiKey.Group) { + h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage()) + return + } + imageReleaseFunc, acquired := h.acquireImageGenerationSlot(c, streamStarted) + if !acquired { + return + } + if imageReleaseFunc != nil { + defer imageReleaseFunc() + } + if parsed.Multipart { setOpsRequestContext(c, parsed.Model, parsed.Stream, nil) } else { @@ -188,62 +200,69 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { responseLatencyMs = forwardDurationMs - upstreamLatencyMs } service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs) - if err == nil && result != nil && result.FirstTokenMs != nil { + if result != nil && result.FirstTokenMs != nil { service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) } if err != nil { - var failoverErr *service.UpstreamFailoverError - if errors.As(err, &failoverErr) { - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - if failoverErr.RetryableOnSameAccount { - retryLimit := account.GetPoolModeRetryCount() - if sameAccountRetryCount[account.ID] < retryLimit { - sameAccountRetryCount[account.ID]++ - reqLog.Warn("openai.images.pool_mode_same_account_retry", - zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("retry_limit", retryLimit), - zap.Int("retry_count", sameAccountRetryCount[account.ID]), - ) - select { - case <-c.Request.Context().Done(): - return - case <-time.After(sameAccountRetryDelay): + if result != nil && result.ImageCount > 0 { + reqLog.Warn("openai.images.forward_partial_error_with_image_result", + zap.Int64("account_id", account.ID), + zap.Int("image_count", result.ImageCount), + zap.Error(err), + ) + } else { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + if failoverErr.RetryableOnSameAccount { + retryLimit := account.GetPoolModeRetryCount() + if sameAccountRetryCount[account.ID] < retryLimit { + sameAccountRetryCount[account.ID]++ + reqLog.Warn("openai.images.pool_mode_same_account_retry", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("retry_limit", retryLimit), + zap.Int("retry_count", sameAccountRetryCount[account.ID]), + ) + select { + case <-c.Request.Context().Done(): + return + case <-time.After(sameAccountRetryDelay): + } + continue } - continue } + h.gatewayService.RecordOpenAIAccountSwitch() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, failoverErr, streamStarted) + return + } + switchCount++ + reqLog.Warn("openai.images.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + continue } - h.gatewayService.RecordOpenAIAccountSwitch() - failedAccountIDs[account.ID] = struct{}{} - lastFailoverErr = failoverErr - if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, failoverErr, streamStarted) + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + fields := []zap.Field{ + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + } + if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) { + reqLog.Warn("openai.images.forward_failed", fields...) return } - switchCount++ - reqLog.Warn("openai.images.upstream_failover_switching", - zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), - ) - continue - } - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) - fields := []zap.Field{ - zap.Int64("account_id", account.ID), - zap.Bool("fallback_error_response_written", wroteFallback), - zap.Error(err), - } - if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) { - reqLog.Warn("openai.images.forward_failed", fields...) + reqLog.Error("openai.images.forward_failed", fields...) return } - reqLog.Error("openai.images.forward_failed", fields...) - return } - if result != nil { if account.Type == service.AccountTypeOAuth { h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders) @@ -259,21 +278,27 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { if parsed.Multipart { requestPayloadHash = service.HashUsageRequestPayload([]byte(parsed.StickySessionSeed())) } + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) - h.submitUsageRecordTask(func(ctx context.Context) { + upstreamModel := "" + if result != nil { + upstreamModel = result.UpstreamModel + } + h.submitMandatoryUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, User: apiKey.User, Account: account, Subscription: subscription, - InboundEndpoint: GetInboundEndpoint(c), - UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, UserAgent: userAgent, IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, - ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, result.UpstreamModel), + ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, upstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.images"), diff --git a/backend/internal/handler/openai_images_controls_test.go b/backend/internal/handler/openai_images_controls_test.go new file mode 100644 index 00000000..cebcccac --- /dev/null +++ b/backend/internal/handler/openai_images_controls_test.go @@ -0,0 +1,49 @@ +package handler + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestOpenAIGatewayHandlerImages_DisabledGroupRejectsBeforeScheduling(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := []byte(`{"model":"gpt-image-2","prompt":"draw","size":"1024x1024"}`) + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + groupID := int64(111) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + ID: 222, + GroupID: &groupID, + Group: &service.Group{ + ID: groupID, + AllowImageGeneration: false, + }, + User: &service.User{ID: 333}, + }) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 333, Concurrency: 1}) + + h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: &ConcurrencyHelper{concurrencyService: &service.ConcurrencyService{}}, + } + + h.Images(c) + + require.Equal(t, http.StatusForbidden, rec.Code) + require.Equal(t, "permission_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String()) + require.Contains(t, rec.Body.String(), service.ImageGenerationPermissionMessage()) +} diff --git a/backend/internal/handler/usage_record_submit_task_test.go b/backend/internal/handler/usage_record_submit_task_test.go index 5c945815..e4c2837a 100644 --- a/backend/internal/handler/usage_record_submit_task_test.go +++ b/backend/internal/handler/usage_record_submit_task_test.go @@ -129,3 +129,63 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere }) require.True(t, called.Load(), "panic 后后续任务应仍可执行") } + +func TestOpenAIGatewayHandlerSubmitMandatoryUsageRecordTask_DroppedTaskSyncFallback(t *testing.T) { + pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{ + WorkerCount: 1, + QueueSize: 1, + TaskTimeout: time.Second, + OverflowPolicy: "drop", + OverflowSamplePercent: 0, + AutoScaleEnabled: false, + }) + t.Cleanup(pool.Stop) + h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool} + + block := make(chan struct{}) + release := make(chan struct{}) + pool.Submit(func(ctx context.Context) { + close(block) + <-release + }) + <-block + pool.Submit(func(ctx context.Context) {}) + + var called atomic.Bool + h.submitMandatoryUsageRecordTask(func(ctx context.Context) { + called.Store(true) + }) + close(release) + + require.True(t, called.Load(), "mandatory usage task must run synchronously when async submit is dropped") +} + +func TestOpenAIGatewayHandlerSubmitOpenAIUsageRecordTask_ImageResultUsesMandatoryFallback(t *testing.T) { + pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{ + WorkerCount: 1, + QueueSize: 1, + TaskTimeout: time.Second, + OverflowPolicy: "drop", + OverflowSamplePercent: 0, + AutoScaleEnabled: false, + }) + t.Cleanup(pool.Stop) + h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool} + + block := make(chan struct{}) + release := make(chan struct{}) + pool.Submit(func(ctx context.Context) { + close(block) + <-release + }) + <-block + pool.Submit(func(ctx context.Context) {}) + + var called atomic.Bool + h.submitOpenAIUsageRecordTask(&service.OpenAIForwardResult{ImageCount: 1}, func(ctx context.Context) { + called.Store(true) + }) + close(release) + + require.True(t, called.Load(), "image usage task must be mandatory when async submit is dropped") +} diff --git a/backend/internal/pkg/apicompat/anthropic_responses_test.go b/backend/internal/pkg/apicompat/anthropic_responses_test.go index e8b25c2b..aa36ef0b 100644 --- a/backend/internal/pkg/apicompat/anthropic_responses_test.go +++ b/backend/internal/pkg/apicompat/anthropic_responses_test.go @@ -32,7 +32,13 @@ func TestAnthropicToResponses_BasicText(t *testing.T) { var items []ResponsesInputItem require.NoError(t, json.Unmarshal(resp.Input, &items)) require.Len(t, items, 1) + assert.Equal(t, "message", items[0].Type) assert.Equal(t, "user", items[0].Role) + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[0].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "input_text", parts[0].Type) + assert.Equal(t, "Hello", parts[0].Text) } func TestAnthropicToResponses_SystemPrompt(t *testing.T) { @@ -49,7 +55,12 @@ func TestAnthropicToResponses_SystemPrompt(t *testing.T) { var items []ResponsesInputItem require.NoError(t, json.Unmarshal(resp.Input, &items)) require.Len(t, items, 2) - assert.Equal(t, "system", items[0].Role) + assert.Equal(t, "developer", items[0].Role) + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[0].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "input_text", parts[0].Type) + assert.Equal(t, "You are helpful.", parts[0].Text) }) t.Run("array", func(t *testing.T) { @@ -65,11 +76,33 @@ func TestAnthropicToResponses_SystemPrompt(t *testing.T) { var items []ResponsesInputItem require.NoError(t, json.Unmarshal(resp.Input, &items)) require.Len(t, items, 2) - assert.Equal(t, "system", items[0].Role) - // System text should be joined with double newline. - var text string - require.NoError(t, json.Unmarshal(items[0].Content, &text)) - assert.Equal(t, "Part 1\n\nPart 2", text) + assert.Equal(t, "developer", items[0].Role) + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[0].Content, &parts)) + require.Len(t, parts, 2) + assert.Equal(t, "input_text", parts[0].Type) + assert.Equal(t, "Part 1", parts[0].Text) + assert.Equal(t, "input_text", parts[1].Type) + assert.Equal(t, "Part 2", parts[1].Text) + }) + + t.Run("billing header skipped", func(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 100, + System: json.RawMessage(`[{"type":"text","text":"x-anthropic-billing-header: cc_version=1;"},{"type":"text","text":"Project prompt"}]`), + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 2) + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[0].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "Project prompt", parts[0].Text) }) } @@ -94,6 +127,8 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) { require.Len(t, resp.Tools, 1) assert.Equal(t, "function", resp.Tools[0].Type) assert.Equal(t, "get_weather", resp.Tools[0].Name) + require.NotNil(t, resp.Tools[0].Strict) + assert.False(t, *resp.Tools[0].Strict) // Check input items var items []ResponsesInputItem @@ -104,10 +139,10 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) { assert.Equal(t, "user", items[0].Role) assert.Equal(t, "assistant", items[1].Role) assert.Equal(t, "function_call", items[2].Type) - assert.Equal(t, "fc_call_1", items[2].CallID) + assert.Equal(t, "call_1", items[2].CallID) assert.Empty(t, items[2].ID) assert.Equal(t, "function_call_output", items[3].Type) - assert.Equal(t, "fc_call_1", items[3].CallID) + assert.Equal(t, "call_1", items[3].CallID) assert.Equal(t, "Sunny, 72°F", items[3].Output) } @@ -261,6 +296,34 @@ func TestResponsesToAnthropic_ToolUse(t *testing.T) { assert.JSONEq(t, `{"city":"NYC"}`, string(anth.Content[1].Input)) } +func TestResponsesToAnthropic_ToolUseStopReasonDoesNotDependOnLastBlock(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_tool_then_text", + Model: "gpt-5.5", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "function_call", + CallID: "call_todo", + Name: "TodoWrite", + Arguments: `{"todos":[{"content":"review changes","status":"in_progress"}]}`, + }, + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "Task list updated."}, + }, + }, + }, + } + + anth := ResponsesToAnthropic(resp, "claude-opus-4-6") + assert.Equal(t, "tool_use", anth.StopReason) + require.Len(t, anth.Content, 2) + assert.Equal(t, "tool_use", anth.Content[0].Type) + assert.Equal(t, "text", anth.Content[1].Type) +} + func TestResponsesToAnthropic_ReadToolDropsEmptyPages(t *testing.T) { resp := &ResponsesResponse{ ID: "resp_read", @@ -434,6 +497,45 @@ func TestStreamingTextOnly(t *testing.T) { assert.Equal(t, "message_stop", events[1].Type) } +func TestResponsesEventToAnthropicEvents_ResponseDone(t *testing.T) { + state := NewResponsesEventToAnthropicState() + state.Model = "gpt-4o" + + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.done", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4}, + }, + }, state) + require.Len(t, events, 2) + assert.Equal(t, "message_delta", events[0].Type) + assert.Equal(t, "end_turn", events[0].Delta.StopReason) + assert.Equal(t, 12, events[0].Usage.InputTokens) + assert.Equal(t, 4, events[0].Usage.OutputTokens) + assert.Equal(t, "message_stop", events[1].Type) + assert.Nil(t, FinalizeResponsesAnthropicStream(state)) +} + +func TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete(t *testing.T) { + state := NewResponsesEventToAnthropicState() + state.Model = "gpt-4o" + + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.done", + Response: &ResponsesResponse{ + Status: "incomplete", + IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"}, + Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4}, + }, + }, state) + require.Len(t, events, 2) + assert.Equal(t, "message_delta", events[0].Type) + assert.Equal(t, "max_tokens", events[0].Delta.StopReason) + assert.Equal(t, "message_stop", events[1].Type) + assert.Nil(t, FinalizeResponsesAnthropicStream(state)) +} + func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) { state := NewResponsesEventToAnthropicState() ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ @@ -514,6 +616,81 @@ func TestStreamingToolCall(t *testing.T) { assert.Equal(t, "tool_use", events[0].Delta.StopReason) } +func TestStreamingToolCallStopReasonSurvivesLaterText(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_tool_then_text", Model: "gpt-5.5"}, + }, state) + + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 0, + Item: &ResponsesOutput{Type: "function_call", CallID: "call_todo", Name: "TodoWrite"}, + }, state) + require.Len(t, events, 1) + assert.Equal(t, "content_block_start", events[0].Type) + + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.done", + OutputIndex: 0, + Arguments: `{"todos":[{"content":"review changes","status":"in_progress","activeForm":"reviewing changes"}]}`, + }, state) + require.Len(t, events, 2) + assert.Equal(t, "content_block_delta", events[0].Type) + assert.Equal(t, "content_block_stop", events[1].Type) + + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + OutputIndex: 1, + Delta: "I will continue after the task list updates.", + }, state) + require.Len(t, events, 2) + assert.Equal(t, "content_block_start", events[0].Type) + assert.Equal(t, "content_block_delta", events[1].Type) + + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{InputTokens: 20, OutputTokens: 10}, + }, + }, state) + require.Len(t, events, 3) + assert.Equal(t, "content_block_stop", events[0].Type) + assert.Equal(t, "tool_use", events[1].Delta.StopReason) + assert.Equal(t, "message_stop", events[2].Type) +} + +func TestStreamingToolCallDoneWithoutDeltaEmitsArguments(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_bash", Model: "gpt-5.5"}, + }, state) + + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 0, + Item: &ResponsesOutput{Type: "function_call", CallID: "call_bash", Name: "Bash"}, + }, state) + require.Len(t, events, 1) + assert.Equal(t, "content_block_start", events[0].Type) + + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.done", + OutputIndex: 0, + Arguments: `{"command":"git -C \"/mnt/d/nodejs/other/edmt\" status --short --ignored"}`, + }, state) + require.Len(t, events, 2) + assert.Equal(t, "content_block_delta", events[0].Type) + assert.Equal(t, "input_json_delta", events[0].Delta.Type) + assert.JSONEq(t, `{"command":"git -C \"/mnt/d/nodejs/other/edmt\" status --short --ignored"}`, events[0].Delta.PartialJSON) + assert.Equal(t, "content_block_stop", events[1].Type) +} + func TestStreamingReadToolDropsEmptyPages(t *testing.T) { state := NewResponsesEventToAnthropicState() @@ -653,6 +830,27 @@ func TestFinalizeStream_AbnormalTermination(t *testing.T) { assert.Equal(t, "message_stop", events[2].Type) } +func TestFinalizeStream_ToolCallAbnormalTerminationKeepsToolUseStopReason(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_tool_interrupted", Model: "gpt-5.5"}, + }, state) + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 0, + Item: &ResponsesOutput{Type: "function_call", CallID: "call_todo", Name: "TodoWrite"}, + }, state) + + events := FinalizeResponsesAnthropicStream(state) + require.Len(t, events, 3) + assert.Equal(t, "content_block_stop", events[0].Type) + assert.Equal(t, "message_delta", events[1].Type) + assert.Equal(t, "tool_use", events[1].Delta.StopReason) + assert.Equal(t, "message_stop", events[2].Type) +} + func TestStreamingEmptyResponse(t *testing.T) { state := NewResponsesEventToAnthropicState() @@ -788,8 +986,8 @@ func TestAnthropicToResponses_ThinkingEnabled(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) require.NotNil(t, resp.Reasoning) - // thinking.type is ignored for effort; default high applies. - assert.Equal(t, "high", resp.Reasoning.Effort) + // thinking.type is ignored for effort; Codex bridge default medium applies. + assert.Equal(t, "medium", resp.Reasoning.Effort) assert.Equal(t, "auto", resp.Reasoning.Summary) assert.Contains(t, resp.Include, "reasoning.encrypted_content") assert.NotContains(t, resp.Include, "reasoning.summary") @@ -806,8 +1004,8 @@ func TestAnthropicToResponses_ThinkingAdaptive(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) require.NotNil(t, resp.Reasoning) - // thinking.type is ignored for effort; default high applies. - assert.Equal(t, "high", resp.Reasoning.Effort) + // thinking.type is ignored for effort; Codex bridge default medium applies. + assert.Equal(t, "medium", resp.Reasoning.Effort) assert.Equal(t, "auto", resp.Reasoning.Summary) assert.NotContains(t, resp.Include, "reasoning.summary") } @@ -822,9 +1020,9 @@ func TestAnthropicToResponses_ThinkingDisabled(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) - // Default effort applies (high → high) even when thinking is disabled. + // Default effort applies (medium) even when thinking is disabled. require.NotNil(t, resp.Reasoning) - assert.Equal(t, "high", resp.Reasoning.Effort) + assert.Equal(t, "medium", resp.Reasoning.Effort) } func TestAnthropicToResponses_NoThinking(t *testing.T) { @@ -836,9 +1034,9 @@ func TestAnthropicToResponses_NoThinking(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) - // Default effort applies (high → high) when no thinking/output_config is set. + // Default effort applies (medium) when no thinking/output_config is set. require.NotNil(t, resp.Reasoning) - assert.Equal(t, "high", resp.Reasoning.Effort) + assert.Equal(t, "medium", resp.Reasoning.Effort) } // --------------------------------------------------------------------------- @@ -846,7 +1044,7 @@ func TestAnthropicToResponses_NoThinking(t *testing.T) { // --------------------------------------------------------------------------- func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) { - // Default is high, but output_config.effort="low" overrides. low→low after mapping. + // Default is medium, but output_config.effort="low" overrides. low→low after mapping. req := &AnthropicRequest{ Model: "gpt-5.2", MaxTokens: 1024, @@ -880,7 +1078,7 @@ func TestAnthropicToResponses_OutputConfigWithoutThinking(t *testing.T) { } func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) { - // output_config.effort="high" → mapped to "high" (1:1, both sides' default). + // output_config.effort="high" → mapped to "high" (1:1). req := &AnthropicRequest{ Model: "gpt-5.2", MaxTokens: 1024, @@ -912,7 +1110,7 @@ func TestAnthropicToResponses_OutputConfigMax(t *testing.T) { } func TestAnthropicToResponses_NoOutputConfig(t *testing.T) { - // No output_config → default high regardless of thinking.type. + // No output_config → default medium regardless of thinking.type. req := &AnthropicRequest{ Model: "gpt-5.2", MaxTokens: 1024, @@ -923,11 +1121,11 @@ func TestAnthropicToResponses_NoOutputConfig(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) require.NotNil(t, resp.Reasoning) - assert.Equal(t, "high", resp.Reasoning.Effort) + assert.Equal(t, "medium", resp.Reasoning.Effort) } func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) { - // output_config present but effort empty (e.g. only format set) → default high. + // output_config present but effort empty (e.g. only format set) → default medium. req := &AnthropicRequest{ Model: "gpt-5.2", MaxTokens: 1024, @@ -938,7 +1136,7 @@ func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) require.NotNil(t, resp.Reasoning) - assert.Equal(t, "high", resp.Reasoning.Effort) + assert.Equal(t, "medium", resp.Reasoning.Effort) } // --------------------------------------------------------------------------- @@ -1110,7 +1308,7 @@ func TestAnthropicToResponses_ToolResultWithImage(t *testing.T) { // function_call_output should have text-only output (no image). assert.Equal(t, "function_call_output", items[2].Type) - assert.Equal(t, "fc_toolu_1", items[2].CallID) + assert.Equal(t, "toolu_1", items[2].CallID) assert.Equal(t, "(empty)", items[2].Output) // Image should be in a separate user message. diff --git a/backend/internal/pkg/apicompat/anthropic_to_responses.go b/backend/internal/pkg/apicompat/anthropic_to_responses.go index 268f9f22..5f04004d 100644 --- a/backend/internal/pkg/apicompat/anthropic_to_responses.go +++ b/backend/internal/pkg/apicompat/anthropic_to_responses.go @@ -32,6 +32,9 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) { storeFalse := false out.Store = &storeFalse + parallelToolCalls := true + out.ParallelToolCalls = ¶llelToolCalls + out.Text = &ResponsesText{Verbosity: "medium"} if req.MaxTokens > 0 { v := req.MaxTokens @@ -46,10 +49,10 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) { } // Determine reasoning effort: only output_config.effort controls the - // level; thinking.type is ignored. Default is high when unset (both - // Anthropic and OpenAI default to high). + // level; thinking.type is ignored. Default follows Codex CLI / airgate's + // Anthropic bridge shape, which uses medium when unset. // Anthropic levels map 1:1 to OpenAI: low→low, medium→medium, high→high, max→xhigh. - effort := "high" // default → both sides' default + effort := "medium" if req.OutputConfig != nil && req.OutputConfig.Effort != "" { effort = req.OutputConfig.Effort } @@ -108,16 +111,19 @@ func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage func convertAnthropicToResponsesInput(system json.RawMessage, msgs []AnthropicMessage) ([]ResponsesInputItem, error) { var out []ResponsesInputItem - // System prompt → system role input item. + // System prompt → developer role input item. ChatGPT Codex SSE behaves like + // Codex CLI here: keeping Anthropic system text in input preserves the + // conversation/cache shape better than moving it into instructions. if len(system) > 0 { - sysText, err := parseAnthropicSystemPrompt(system) + sysParts, err := parseAnthropicSystemContentParts(system) if err != nil { return nil, err } - if sysText != "" { - content, _ := json.Marshal(sysText) + if len(sysParts) > 0 { + content, _ := json.Marshal(sysParts) out = append(out, ResponsesInputItem{ - Role: "system", + Type: "message", + Role: "developer", Content: content, }) } @@ -133,24 +139,32 @@ func convertAnthropicToResponsesInput(system json.RawMessage, msgs []AnthropicMe return out, nil } -// parseAnthropicSystemPrompt handles the Anthropic system field which can be -// a plain string or an array of text blocks. -func parseAnthropicSystemPrompt(raw json.RawMessage) (string, error) { +// parseAnthropicSystemContentParts handles the Anthropic system field which can +// be a plain string or an array of text blocks. Claude Code may include an +// x-anthropic-billing-header block; airgate drops it before sending to Codex. +func parseAnthropicSystemContentParts(raw json.RawMessage) ([]ResponsesContentPart, error) { var s string if err := json.Unmarshal(raw, &s); err == nil { - return s, nil + if isAnthropicBillingHeaderText(s) || s == "" { + return nil, nil + } + return []ResponsesContentPart{{Type: "input_text", Text: s}}, nil } var blocks []AnthropicContentBlock if err := json.Unmarshal(raw, &blocks); err != nil { - return "", err + return nil, err } - var parts []string + var parts []ResponsesContentPart for _, b := range blocks { - if b.Type == "text" && b.Text != "" { - parts = append(parts, b.Text) + if b.Type == "text" && b.Text != "" && !isAnthropicBillingHeaderText(b.Text) { + parts = append(parts, ResponsesContentPart{Type: "input_text", Text: b.Text}) } } - return strings.Join(parts, "\n\n"), nil + return parts, nil +} + +func isAnthropicBillingHeaderText(text string) bool { + return strings.HasPrefix(text, "x-anthropic-billing-header: ") } // anthropicMsgToResponsesItems converts a single Anthropic message into one @@ -173,8 +187,12 @@ func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error) // Try plain string. var s string if err := json.Unmarshal(raw, &s); err == nil { - content, _ := json.Marshal(s) - return []ResponsesInputItem{{Role: "user", Content: content}}, nil + parts := []ResponsesContentPart{{Type: "input_text", Text: s}} + partsJSON, err := json.Marshal(parts) + if err != nil { + return nil, err + } + return []ResponsesInputItem{{Type: "message", Role: "user", Content: partsJSON}}, nil } var blocks []AnthropicContentBlock @@ -223,7 +241,7 @@ func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error) if err != nil { return nil, err } - out = append(out, ResponsesInputItem{Role: "user", Content: content}) + out = append(out, ResponsesInputItem{Type: "message", Role: "user", Content: content}) } return out, nil @@ -242,7 +260,7 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e if err != nil { return nil, err } - return []ResponsesInputItem{{Role: "assistant", Content: partsJSON}}, nil + return []ResponsesInputItem{{Type: "message", Role: "assistant", Content: partsJSON}}, nil } var blocks []AnthropicContentBlock @@ -260,7 +278,7 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e if err != nil { return nil, err } - items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON}) + items = append(items, ResponsesInputItem{Type: "message", Role: "assistant", Content: partsJSON}) } // tool_use → function_call items. @@ -284,17 +302,14 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e return items, nil } -// toResponsesCallID converts an Anthropic tool ID (toolu_xxx / call_xxx) to a -// Responses API function_call ID that starts with "fc_". +// toResponsesCallID preserves Anthropic tool IDs as Responses call_id values. +// Claude Code sends tool_result.tool_use_id back verbatim, and ChatGPT Codex +// continuation expects that call_id to match the original tool_use id. func toResponsesCallID(id string) string { - if strings.HasPrefix(id, "fc_") { - return id - } - return "fc_" + id + return id } -// fromResponsesCallID reverses toResponsesCallID, stripping the "fc_" prefix -// that was added during request conversion. +// fromResponsesCallID reverses old prefixed IDs while preserving current IDs. func fromResponsesCallID(id string) string { if after, ok := strings.CutPrefix(id, "fc_"); ok { // Only strip if the remainder doesn't look like it was already "fc_" prefixed. @@ -412,11 +427,16 @@ func convertAnthropicToolsToResponses(tools []AnthropicTool) []ResponsesTool { Name: t.Name, Description: t.Description, Parameters: normalizeToolParameters(t.InputSchema), + Strict: boolPtr(false), }) } return out } +func boolPtr(v bool) *bool { + return &v +} + // normalizeToolParameters ensures the tool parameter schema is valid for // OpenAI's Responses API, which requires "properties" on object schemas. // diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go index 35d42999..bf5c23d5 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go @@ -720,6 +720,49 @@ func TestResponsesEventToChatChunks_Completed(t *testing.T) { assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens) } +func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.done", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7}, + }, + }, state) + require.Len(t, chunks, 2) + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason) + require.NotNil(t, chunks[1].Usage) + assert.Equal(t, 13, chunks[1].Usage.PromptTokens) + assert.Equal(t, 7, chunks[1].Usage.CompletionTokens) + assert.Nil(t, FinalizeResponsesChatStream(state)) +} + +func TestResponsesEventToChatChunks_ResponseDoneIncomplete(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.done", + Response: &ResponsesResponse{ + Status: "incomplete", + IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"}, + Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7}, + }, + }, state) + require.Len(t, chunks, 2) + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "length", *chunks[0].Choices[0].FinishReason) + require.NotNil(t, chunks[1].Usage) + assert.Equal(t, 13, chunks[1].Usage.PromptTokens) + assert.Equal(t, 7, chunks[1].Usage.CompletionTokens) + assert.Nil(t, FinalizeResponsesChatStream(state)) +} + func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) { state := NewResponsesEventToChatState() state.Model = "gpt-4o" diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic.go b/backend/internal/pkg/apicompat/responses_to_anthropic.go index 489ed238..d7ef0145 100644 --- a/backend/internal/pkg/apicompat/responses_to_anthropic.go +++ b/backend/internal/pkg/apicompat/responses_to_anthropic.go @@ -120,7 +120,7 @@ func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncom } return "end_turn" case "completed": - if len(blocks) > 0 && blocks[len(blocks)-1].Type == "tool_use" { + if containsAnthropicToolUseBlock(blocks) { return "tool_use" } return "end_turn" @@ -129,6 +129,15 @@ func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncom } } +func containsAnthropicToolUseBlock(blocks []AnthropicContentBlock) bool { + for _, block := range blocks { + if block.Type == "tool_use" { + return true + } + } + return false +} + func sanitizeAnthropicToolUseInput(name string, raw string) json.RawMessage { if name != "Read" || raw == "" { return json.RawMessage(raw) @@ -161,11 +170,13 @@ type ResponsesEventToAnthropicState struct { MessageStartSent bool MessageStopSent bool - ContentBlockIndex int - ContentBlockOpen bool - CurrentBlockType string // "text" | "thinking" | "tool_use" - CurrentToolName string - CurrentToolArgs string + ContentBlockIndex int + ContentBlockOpen bool + CurrentBlockType string // "text" | "thinking" | "tool_use" + CurrentToolName string + CurrentToolArgs string + CurrentToolHadDelta bool + HasToolCall bool // OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index. OutputIndexToBlockIdx map[int]int @@ -212,7 +223,9 @@ func ResponsesEventToAnthropicEvents( return resToAnthHandleReasoningDelta(evt, state) case "response.reasoning_summary_text.done": return resToAnthHandleBlockDone(state) - case "response.completed", "response.incomplete", "response.failed": + // response.done 是 Realtime/WS 与项目透传路径使用的终止别名; + // 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。 + case "response.completed", "response.done", "response.incomplete", "response.failed": return resToAnthHandleCompleted(evt, state) default: return nil @@ -229,11 +242,16 @@ func FinalizeResponsesAnthropicStream(state *ResponsesEventToAnthropicState) []A var events []AnthropicStreamEvent events = append(events, closeCurrentBlock(state)...) + stopReason := "end_turn" + if state.HasToolCall { + stopReason = "tool_use" + } + events = append(events, AnthropicStreamEvent{ Type: "message_delta", Delta: &AnthropicDelta{ - StopReason: "end_turn", + StopReason: stopReason, }, Usage: &AnthropicUsage{ InputTokens: state.InputTokens, @@ -304,6 +322,8 @@ func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesE state.CurrentBlockType = "tool_use" state.CurrentToolName = evt.Item.Name state.CurrentToolArgs = "" + state.CurrentToolHadDelta = false + state.HasToolCall = true events = append(events, AnthropicStreamEvent{ Type: "content_block_start", @@ -388,6 +408,9 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve state.CurrentToolArgs += evt.Delta return nil } + if state.CurrentBlockType == "tool_use" { + state.CurrentToolHadDelta = true + } blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex] if !ok { @@ -405,7 +428,7 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve } func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { - if state.CurrentBlockType != "tool_use" || state.CurrentToolName != "Read" { + if state.CurrentBlockType != "tool_use" { return resToAnthHandleBlockDone(state) } @@ -413,10 +436,16 @@ func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEven if raw == "" { raw = state.CurrentToolArgs } - sanitized := sanitizeAnthropicToolUseInput(state.CurrentToolName, raw) - if len(sanitized) == 0 { + if raw == "" || state.CurrentToolHadDelta { return closeCurrentBlock(state) } + if state.CurrentToolName == "Read" { + sanitized := sanitizeAnthropicToolUseInput(state.CurrentToolName, raw) + if len(sanitized) == 0 { + return closeCurrentBlock(state) + } + raw = string(sanitized) + } idx := state.ContentBlockIndex events := []AnthropicStreamEvent{{ @@ -424,7 +453,7 @@ func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEven Index: &idx, Delta: &AnthropicDelta{ Type: "input_json_delta", - PartialJSON: string(sanitized), + PartialJSON: raw, }, }} events = append(events, closeCurrentBlock(state)...) @@ -551,7 +580,7 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo stopReason = "max_tokens" } case "completed": - if state.ContentBlockIndex > 0 && state.CurrentBlockType == "tool_use" { + if state.HasToolCall { stopReason = "tool_use" } } @@ -584,6 +613,7 @@ func closeCurrentBlock(state *ResponsesEventToAnthropicState) []AnthropicStreamE state.ContentBlockIndex++ state.CurrentToolName = "" state.CurrentToolArgs = "" + state.CurrentToolHadDelta = false return []AnthropicStreamEvent{{ Type: "content_block_stop", Index: &idx, diff --git a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go index 61b3bf9c..2386771d 100644 --- a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go +++ b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go @@ -160,7 +160,9 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent return resToChatHandleReasoningDelta(evt, state) case "response.reasoning_summary_text.done": return nil - case "response.completed", "response.incomplete", "response.failed": + // response.done 是 Realtime/WS 与项目透传路径使用的终止别名; + // 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。 + case "response.completed", "response.done", "response.incomplete", "response.failed": return resToChatHandleCompleted(evt, state) default: return nil diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go index f8c6b75f..f9cd5a1c 100644 --- a/backend/internal/pkg/apicompat/types.go +++ b/backend/internal/pkg/apicompat/types.go @@ -53,6 +53,8 @@ type AnthropicMessage struct { type AnthropicContentBlock struct { Type string `json:"type"` + CacheControl *AnthropicCacheControl `json:"cache_control,omitempty"` + // type=text Text string `json:"text,omitempty"` @@ -165,19 +167,23 @@ type AnthropicDelta struct { // ResponsesRequest is the request body for POST /v1/responses. type ResponsesRequest struct { - Model string `json:"model"` - Instructions string `json:"instructions,omitempty"` - Input json.RawMessage `json:"input"` // string or []ResponsesInputItem - MaxOutputTokens *int `json:"max_output_tokens,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP *float64 `json:"top_p,omitempty"` - Stream bool `json:"stream,omitempty"` - Tools []ResponsesTool `json:"tools,omitempty"` - Include []string `json:"include,omitempty"` - Store *bool `json:"store,omitempty"` - Reasoning *ResponsesReasoning `json:"reasoning,omitempty"` - ToolChoice json.RawMessage `json:"tool_choice,omitempty"` - ServiceTier string `json:"service_tier,omitempty"` + Model string `json:"model"` + Instructions string `json:"instructions,omitempty"` + Input json.RawMessage `json:"input"` // string or []ResponsesInputItem + MaxOutputTokens *int `json:"max_output_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + Stream bool `json:"stream,omitempty"` + Tools []ResponsesTool `json:"tools,omitempty"` + Include []string `json:"include,omitempty"` + Store *bool `json:"store,omitempty"` + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + Reasoning *ResponsesReasoning `json:"reasoning,omitempty"` + Text *ResponsesText `json:"text,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` + PromptCacheKey string `json:"prompt_cache_key,omitempty"` + PreviousResponseID string `json:"previous_response_id,omitempty"` } // ResponsesReasoning configures reasoning effort in the Responses API. @@ -186,13 +192,18 @@ type ResponsesReasoning struct { Summary string `json:"summary,omitempty"` // "auto" | "concise" | "detailed" } +// ResponsesText configures text output options in the Responses API. +type ResponsesText struct { + Verbosity string `json:"verbosity,omitempty"` // "low" | "medium" | "high" +} + // ResponsesInputItem is one item in the Responses API input array. // The Type field determines which other fields are populated. type ResponsesInputItem struct { // Common Type string `json:"type,omitempty"` // "" for role-based messages - // Role-based messages (system/user/assistant) + // Role-based messages (developer/system/user/assistant) Role string `json:"role,omitempty"` Content json.RawMessage `json:"content,omitempty"` // string or []ResponsesContentPart @@ -314,7 +325,7 @@ type ResponsesOutputTokensDetails struct { type ResponsesStreamEvent struct { Type string `json:"type"` - // response.created / response.completed / response.failed / response.incomplete + // response.created / response.completed / response.done / response.failed / response.incomplete Response *ResponsesResponse `json:"response,omitempty"` // response.output_item.added / response.output_item.done diff --git a/backend/internal/pkg/openai_compat/upstream_capability.go b/backend/internal/pkg/openai_compat/upstream_capability.go new file mode 100644 index 00000000..ff05afe5 --- /dev/null +++ b/backend/internal/pkg/openai_compat/upstream_capability.go @@ -0,0 +1,75 @@ +// Package openai_compat 提供 OpenAI 协议族在不同上游间的能力差异判定工具。 +// +// 背景:sub2api 的 OpenAI APIKey 账号通过 base_url 接入多种第三方 OpenAI 兼容上游 +// (DeepSeek、Kimi、GLM、Qwen 等)。这些上游普遍只支持 /v1/chat/completions, +// 不存在 /v1/responses 端点。但网关历史代码无差别走 CC→Responses 转换并打到 +// /v1/responses,导致兼容上游 404。 +// +// 本包提供基于"账号探测标记"的能力判定,配合 +// internal/service/openai_apikey_responses_probe.go 在创建/修改账号时一次性 +// 探测并落标。 +// +// 设计取舍: +// - 不维护静态 host 白名单——避免新增厂商时必须改代码(讨论沉淀于 +// pensieve/short-term/knowledge/upstream-capability-detection-design-tradeoffs) +// - 标记缺失时默认 true(即"走 Responses"),保持与重构前老代码完全一致的存量 +// 账号行为("现状即证据"原则;详见 +// pensieve/short-term/maxims/preserve-existing-runtime-behavior-when-replacing-logic-in-stateful-systems) +package openai_compat + +// AccountResponsesSupport 描述账号上游对 OpenAI Responses API 的支持状态。 +// +// 仅用于 platform=openai + type=apikey 的账号;其他账号类型不应调用本包判定。 +type AccountResponsesSupport int + +const ( + // ResponsesSupportUnknown 表示账号尚未完成能力探测(extra 字段缺失)。 + // 上游路由层应按"现状即证据"原则默认走 Responses,保持与重构前一致。 + ResponsesSupportUnknown AccountResponsesSupport = iota + + // ResponsesSupportYes 探测确认上游支持 /v1/responses。 + ResponsesSupportYes + + // ResponsesSupportNo 探测确认上游不支持 /v1/responses,应走 + // /v1/chat/completions 直转路径。 + ResponsesSupportNo +) + +// ExtraKeyResponsesSupported 是 accounts.extra JSON 中存储探测结果的键名。 +// 值类型为 bool:true=支持、false=不支持、键缺失=未探测。 +const ExtraKeyResponsesSupported = "openai_responses_supported" + +// ResolveResponsesSupport 从账号的 extra map 中读取探测标记。 +// +// 标记缺失或类型不匹配时返回 ResponsesSupportUnknown——调用方应按 +// "未探测=保留旧行为=走 Responses" 处理(参见 ShouldUseResponsesAPI)。 +func ResolveResponsesSupport(extra map[string]any) AccountResponsesSupport { + if extra == nil { + return ResponsesSupportUnknown + } + v, ok := extra[ExtraKeyResponsesSupported] + if !ok { + return ResponsesSupportUnknown + } + supported, ok := v.(bool) + if !ok { + return ResponsesSupportUnknown + } + if supported { + return ResponsesSupportYes + } + return ResponsesSupportNo +} + +// ShouldUseResponsesAPI 判断 OpenAI APIKey 账号的入站 /v1/chat/completions 请求 +// 是否应走"CC→Responses 转换 + 上游 /v1/responses"路径。 +// +// 返回 true 的两种情况: +// 1. 账号已探测确认支持 Responses +// 2. 账号未探测(标记缺失)——按"现状即证据"原则保留旧行为 +// +// 仅当账号已探测且确认不支持时返回 false,此时调用方应走 CC 直转路径 +// (详见 internal/service/openai_gateway_chat_completions_raw.go)。 +func ShouldUseResponsesAPI(extra map[string]any) bool { + return ResolveResponsesSupport(extra) != ResponsesSupportNo +} diff --git a/backend/internal/pkg/openai_compat/upstream_capability_test.go b/backend/internal/pkg/openai_compat/upstream_capability_test.go new file mode 100644 index 00000000..d650daa4 --- /dev/null +++ b/backend/internal/pkg/openai_compat/upstream_capability_test.go @@ -0,0 +1,55 @@ +package openai_compat + +import "testing" + +func TestResolveResponsesSupport(t *testing.T) { + tests := []struct { + name string + extra map[string]any + want AccountResponsesSupport + }{ + {"nil extra", nil, ResponsesSupportUnknown}, + {"empty extra", map[string]any{}, ResponsesSupportUnknown}, + {"key missing", map[string]any{"other": "value"}, ResponsesSupportUnknown}, + {"value true", map[string]any{ExtraKeyResponsesSupported: true}, ResponsesSupportYes}, + {"value false", map[string]any{ExtraKeyResponsesSupported: false}, ResponsesSupportNo}, + {"value wrong type string", map[string]any{ExtraKeyResponsesSupported: "true"}, ResponsesSupportUnknown}, + {"value wrong type number", map[string]any{ExtraKeyResponsesSupported: 1}, ResponsesSupportUnknown}, + {"value nil", map[string]any{ExtraKeyResponsesSupported: nil}, ResponsesSupportUnknown}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := ResolveResponsesSupport(tc.extra) + if got != tc.want { + t.Errorf("ResolveResponsesSupport(%v) = %v, want %v", tc.extra, got, tc.want) + } + }) + } +} + +func TestShouldUseResponsesAPI(t *testing.T) { + tests := []struct { + name string + extra map[string]any + want bool + }{ + // 关键不变量:未探测必须返回 true(保留旧行为) + {"unknown defaults to true (preserve old behavior)", nil, true}, + {"unknown empty defaults to true", map[string]any{}, true}, + {"unknown wrong type defaults to true", map[string]any{ExtraKeyResponsesSupported: "yes"}, true}, + + // 已探测:标记决定 + {"explicitly supported", map[string]any{ExtraKeyResponsesSupported: true}, true}, + {"explicitly unsupported", map[string]any{ExtraKeyResponsesSupported: false}, false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := ShouldUseResponsesAPI(tc.extra) + if got != tc.want { + t.Errorf("ShouldUseResponsesAPI(%v) = %v, want %v", tc.extra, got, tc.want) + } + }) + } +} diff --git a/backend/internal/repository/affiliate_repo.go b/backend/internal/repository/affiliate_repo.go index ef89e5b6..61da539e 100644 --- a/backend/internal/repository/affiliate_repo.go +++ b/backend/internal/repository/affiliate_repo.go @@ -22,6 +22,34 @@ const ( var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789") +const affiliateUserOverviewSQL = ` +SELECT ua.user_id, + COALESCE(u.email, ''), + COALESCE(u.username, ''), + ua.aff_code, + COALESCE(ua.aff_rebate_rate_percent, 0)::double precision, + (ua.aff_rebate_rate_percent IS NOT NULL) AS has_custom_rate, + ua.aff_count, + COALESCE(rebated.rebated_invitee_count, 0), + (ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0))::double precision, + ua.aff_history_quota::double precision +FROM user_affiliates ua +JOIN users u ON u.id = ua.user_id +LEFT JOIN ( + SELECT user_id, COUNT(DISTINCT source_user_id)::integer AS rebated_invitee_count + FROM user_affiliate_ledger + WHERE action = 'accrue' AND source_user_id IS NOT NULL + GROUP BY user_id +) rebated ON rebated.user_id = ua.user_id +LEFT JOIN ( + SELECT user_id, COALESCE(SUM(amount), 0)::double precision AS matured_frozen_quota + FROM user_affiliate_ledger + WHERE action = 'accrue' AND frozen_until IS NOT NULL AND frozen_until <= NOW() + GROUP BY user_id +) matured ON matured.user_id = ua.user_id +WHERE ua.user_id = $1 +LIMIT 1` + type affiliateQueryExecer interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) @@ -86,7 +114,7 @@ func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID return bound, nil } -func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) { +func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error) { if amount <= 0 { return false, nil } @@ -112,15 +140,15 @@ func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, invite if freezeHours > 0 { if _, err = txClient.ExecContext(txCtx, ` -INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, frozen_until, created_at, updated_at) -VALUES ($1, 'accrue', $2, $3, NOW() + make_interval(hours => $4), NOW(), NOW())`, - inviterID, amount, inviteeUserID, freezeHours); err != nil { +INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, source_order_id, frozen_until, created_at, updated_at) +VALUES ($1, 'accrue', $2, $3, $4, NOW() + make_interval(hours => $5), NOW(), NOW())`, + inviterID, amount, inviteeUserID, nullableInt64Arg(sourceOrderID), freezeHours); err != nil { return fmt.Errorf("insert affiliate accrue ledger: %w", err) } } else { if _, err = txClient.ExecContext(txCtx, ` -INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at) -VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil { +INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, source_order_id, created_at, updated_at) +VALUES ($1, 'accrue', $2, $3, $4, NOW(), NOW())`, inviterID, amount, inviteeUserID, nullableInt64Arg(sourceOrderID)); err != nil { return fmt.Errorf("insert affiliate accrue ledger: %w", err) } } @@ -275,9 +303,32 @@ FROM cleared`, userID) return err } + snapshot, err := queryAffiliateTransferSnapshot(txCtx, txClient, userID) + if err != nil { + return err + } + if _, err = txClient.ExecContext(txCtx, ` -INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at) -VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred); err != nil { +INSERT INTO user_affiliate_ledger ( + user_id, + action, + amount, + source_user_id, + balance_after, + aff_quota_after, + aff_frozen_quota_after, + aff_history_quota_after, + created_at, + updated_at +) +VALUES ($1, 'transfer', $2, NULL, $3, $4, $5, $6, NOW(), NOW())`, + userID, + transferred, + snapshot.BalanceAfter, + snapshot.AvailableQuotaAfter, + snapshot.FrozenQuotaAfter, + snapshot.HistoryQuotaAfter, + ); err != nil { return fmt.Errorf("insert affiliate transfer ledger: %w", err) } @@ -332,6 +383,349 @@ LIMIT $2`, inviterID, limit) return invitees, nil } +func (r *affiliateRepository) ListAffiliateInviteRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateInviteRecord, int64, error) { + client := clientFromContext(ctx, r.client) + where, args := buildAffiliateRecordWhere(filter, "ua.created_at", []string{ + "inviter.email", "inviter.username", "invitee.email", "invitee.username", + "ua.inviter_id::text", "ua.user_id::text", "inviter_aff.aff_code", + }) + + total, err := queryAffiliateRecordCount(ctx, client, ` +SELECT COUNT(*) +FROM user_affiliates ua +JOIN users invitee ON invitee.id = ua.user_id +JOIN users inviter ON inviter.id = ua.inviter_id +JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id +`+where, args...) + if err != nil { + return nil, 0, err + } + + orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{ + "inviter": "inviter.email", + "invitee": "invitee.email", + "aff_code": "inviter_aff.aff_code", + "total_rebate": "total_rebate", + "created_at": "ua.created_at", + }, "ua.created_at") + args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize) + rows, err := client.QueryContext(ctx, ` +SELECT ua.inviter_id, + COALESCE(inviter.email, ''), + COALESCE(inviter.username, ''), + ua.user_id, + COALESCE(invitee.email, ''), + COALESCE(invitee.username, ''), + COALESCE(inviter_aff.aff_code, ''), + COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate, + ua.created_at +FROM user_affiliates ua +JOIN users invitee ON invitee.id = ua.user_id +JOIN users inviter ON inviter.id = ua.inviter_id +JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id +LEFT JOIN user_affiliate_ledger ual + ON ual.user_id = ua.inviter_id + AND ual.source_user_id = ua.user_id + AND ual.action = 'accrue' +`+where+` +GROUP BY ua.inviter_id, inviter.email, inviter.username, ua.user_id, invitee.email, invitee.username, inviter_aff.aff_code, ua.created_at +`+orderBy+` +LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...) + if err != nil { + return nil, 0, err + } + defer func() { _ = rows.Close() }() + + items := make([]service.AffiliateInviteRecord, 0) + for rows.Next() { + var item service.AffiliateInviteRecord + if err := rows.Scan( + &item.InviterID, + &item.InviterEmail, + &item.InviterUsername, + &item.InviteeID, + &item.InviteeEmail, + &item.InviteeUsername, + &item.AffCode, + &item.TotalRebate, + &item.CreatedAt, + ); err != nil { + return nil, 0, err + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, 0, err + } + return items, total, nil +} + +func (r *affiliateRepository) ListAffiliateRebateRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateRebateRecord, int64, error) { + client := clientFromContext(ctx, r.client) + where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{ + "inviter.email", "inviter.username", "invitee.email", "invitee.username", + "po.id::text", "po.out_trade_no", "po.payment_type", "po.status", + }) + baseJoin := ` +FROM user_affiliate_ledger ual +JOIN payment_orders po ON po.id = ual.source_order_id +JOIN users invitee ON invitee.id = ual.source_user_id +JOIN users inviter ON inviter.id = ual.user_id +WHERE ual.action = 'accrue' + AND ual.source_order_id IS NOT NULL` + if where != "" { + where = strings.Replace(where, "WHERE ", " AND ", 1) + } + + total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...) + if err != nil { + return nil, 0, err + } + + orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{ + "order": "po.id", + "inviter": "inviter.email", + "invitee": "invitee.email", + "order_amount": "po.amount", + "pay_amount": "po.pay_amount", + "rebate_amount": "ual.amount", + "payment_type": "po.payment_type", + "order_status": "po.status", + "created_at": "ual.created_at", + }, "ual.created_at") + args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize) + rows, err := client.QueryContext(ctx, ` +SELECT po.id, + po.out_trade_no, + ual.user_id, + COALESCE(inviter.email, ''), + COALESCE(inviter.username, ''), + ual.source_user_id, + COALESCE(invitee.email, ''), + COALESCE(invitee.username, ''), + po.amount::double precision, + po.pay_amount::double precision, + ual.amount::double precision, + po.payment_type, + po.status, + ual.created_at +`+baseJoin+where+` +`+orderBy+` +LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...) + if err != nil { + return nil, 0, err + } + defer func() { _ = rows.Close() }() + + items := make([]service.AffiliateRebateRecord, 0) + for rows.Next() { + var item service.AffiliateRebateRecord + if err := rows.Scan( + &item.OrderID, + &item.OutTradeNo, + &item.InviterID, + &item.InviterEmail, + &item.InviterUsername, + &item.InviteeID, + &item.InviteeEmail, + &item.InviteeUsername, + &item.OrderAmount, + &item.PayAmount, + &item.RebateAmount, + &item.PaymentType, + &item.OrderStatus, + &item.CreatedAt, + ); err != nil { + return nil, 0, err + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, 0, err + } + return items, total, nil +} + +func (r *affiliateRepository) ListAffiliateTransferRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateTransferRecord, int64, error) { + client := clientFromContext(ctx, r.client) + where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{ + "u.email", "u.username", "u.id::text", + }) + baseJoin := ` +FROM user_affiliate_ledger ual +JOIN users u ON u.id = ual.user_id +WHERE ual.action = 'transfer'` + if where != "" { + where = strings.Replace(where, "WHERE ", " AND ", 1) + } + + total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...) + if err != nil { + return nil, 0, err + } + + orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{ + "user": "u.email", + "amount": "ual.amount", + "balance_after": "ual.balance_after", + "available_quota_after": "ual.aff_quota_after", + "frozen_quota_after": "ual.aff_frozen_quota_after", + "history_quota_after": "ual.aff_history_quota_after", + "created_at": "ual.created_at", + }, "ual.created_at") + args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize) + rows, err := client.QueryContext(ctx, ` +SELECT ual.id, + ual.user_id, + COALESCE(u.email, ''), + COALESCE(u.username, ''), + ual.amount::double precision, + ual.balance_after::double precision, + ual.aff_quota_after::double precision, + ual.aff_frozen_quota_after::double precision, + ual.aff_history_quota_after::double precision, + ual.created_at +`+baseJoin+where+` +`+orderBy+` +LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...) + if err != nil { + return nil, 0, err + } + defer func() { _ = rows.Close() }() + + items := make([]service.AffiliateTransferRecord, 0) + for rows.Next() { + var item service.AffiliateTransferRecord + var balanceAfter sql.NullFloat64 + var availableQuotaAfter sql.NullFloat64 + var frozenQuotaAfter sql.NullFloat64 + var historyQuotaAfter sql.NullFloat64 + if err := rows.Scan( + &item.LedgerID, + &item.UserID, + &item.UserEmail, + &item.Username, + &item.Amount, + &balanceAfter, + &availableQuotaAfter, + &frozenQuotaAfter, + &historyQuotaAfter, + &item.CreatedAt, + ); err != nil { + return nil, 0, err + } + item.BalanceAfter = nullableFloat64Ptr(balanceAfter) + item.AvailableQuotaAfter = nullableFloat64Ptr(availableQuotaAfter) + item.FrozenQuotaAfter = nullableFloat64Ptr(frozenQuotaAfter) + item.HistoryQuotaAfter = nullableFloat64Ptr(historyQuotaAfter) + item.SnapshotAvailable = balanceAfter.Valid && + availableQuotaAfter.Valid && + frozenQuotaAfter.Valid && + historyQuotaAfter.Valid + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, 0, err + } + return items, total, nil +} + +func (r *affiliateRepository) GetAffiliateUserOverview(ctx context.Context, userID int64) (*service.AffiliateUserOverview, error) { + if userID <= 0 { + return nil, service.ErrUserNotFound + } + client := clientFromContext(ctx, r.client) + rows, err := client.QueryContext(ctx, affiliateUserOverviewSQL, userID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + if err := rows.Err(); err != nil { + return nil, err + } + return nil, service.ErrUserNotFound + } + + var overview service.AffiliateUserOverview + var customRate float64 + var hasCustomRate bool + if err := rows.Scan( + &overview.UserID, + &overview.Email, + &overview.Username, + &overview.AffCode, + &customRate, + &hasCustomRate, + &overview.InvitedCount, + &overview.RebatedInviteeCount, + &overview.AvailableQuota, + &overview.HistoryQuota, + ); err != nil { + return nil, err + } + if hasCustomRate { + overview.RebateRatePercent = customRate + overview.RebateRateCustom = true + } + return &overview, rows.Err() +} + +func buildAffiliateRecordWhere(filter service.AffiliateRecordFilter, timeColumn string, searchColumns []string) (string, []any) { + clauses := make([]string, 0, 3) + args := make([]any, 0, 3) + if filter.StartAt != nil { + args = append(args, *filter.StartAt) + clauses = append(clauses, fmt.Sprintf("%s >= $%d", timeColumn, len(args))) + } + if filter.EndAt != nil { + args = append(args, *filter.EndAt) + clauses = append(clauses, fmt.Sprintf("%s <= $%d", timeColumn, len(args))) + } + search := strings.TrimSpace(filter.Search) + if search != "" && len(searchColumns) > 0 { + args = append(args, "%"+strings.ToLower(search)+"%") + parts := make([]string, 0, len(searchColumns)) + for _, col := range searchColumns { + parts = append(parts, fmt.Sprintf("LOWER(%s) LIKE $%d", col, len(args))) + } + clauses = append(clauses, "("+strings.Join(parts, " OR ")+")") + } + if len(clauses) == 0 { + return "", args + } + return "WHERE " + strings.Join(clauses, " AND "), args +} + +func buildAffiliateRecordOrderBy(filter service.AffiliateRecordFilter, sortColumns map[string]string, fallbackColumn string) string { + column := sortColumns[filter.SortBy] + if column == "" { + column = fallbackColumn + } + direction := "DESC" + if !filter.SortDesc { + direction = "ASC" + } + return "ORDER BY " + column + " " + direction + " NULLS LAST" +} + +func queryAffiliateRecordCount(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (int64, error) { + rows, err := client.QueryContext(ctx, query, args...) + if err != nil { + return 0, err + } + defer func() { _ = rows.Close() }() + if !rows.Next() { + return 0, rows.Err() + } + var total int64 + if err := rows.Scan(&total); err != nil { + return 0, err + } + return total, rows.Err() +} + func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error { if tx := dbent.TxFromContext(ctx); tx != nil { return fn(ctx, tx.Client()) @@ -516,6 +910,54 @@ func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID i return balance, nil } +type affiliateTransferSnapshot struct { + BalanceAfter float64 + AvailableQuotaAfter float64 + FrozenQuotaAfter float64 + HistoryQuotaAfter float64 +} + +func queryAffiliateTransferSnapshot(ctx context.Context, client affiliateQueryExecer, userID int64) (*affiliateTransferSnapshot, error) { + rows, err := client.QueryContext(ctx, ` +SELECT u.balance::double precision, + ua.aff_quota::double precision, + ua.aff_frozen_quota::double precision, + ua.aff_history_quota::double precision +FROM users u +JOIN user_affiliates ua ON ua.user_id = u.id +WHERE u.id = $1 +LIMIT 1`, userID) + if err != nil { + return nil, fmt.Errorf("query affiliate transfer snapshot: %w", err) + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + if err := rows.Err(); err != nil { + return nil, err + } + return nil, service.ErrUserNotFound + } + + var snapshot affiliateTransferSnapshot + if err := rows.Scan( + &snapshot.BalanceAfter, + &snapshot.AvailableQuotaAfter, + &snapshot.FrozenQuotaAfter, + &snapshot.HistoryQuotaAfter, + ); err != nil { + return nil, err + } + return &snapshot, rows.Err() +} + +func nullableFloat64Ptr(v sql.NullFloat64) *float64 { + if !v.Valid { + return nil + } + return &v.Float64 +} + func generateAffiliateCode() (string, error) { buf := make([]byte, affiliateCodeLength) if _, err := rand.Read(buf); err != nil { @@ -674,6 +1116,13 @@ func nullableArg(v *float64) any { return *v } +func nullableInt64Arg(v *int64) any { + if v == nil { + return nil + } + return *v +} + // ListUsersWithCustomSettings 列出有专属配置(自定义码或专属比例)的用户。 // // 单一查询同时处理"无搜索"与"按邮箱/用户名模糊搜索": diff --git a/backend/internal/repository/affiliate_repo_integration_test.go b/backend/internal/repository/affiliate_repo_integration_test.go index 697a193b..b01ed528 100644 --- a/backend/internal/repository/affiliate_repo_integration_test.go +++ b/backend/internal/repository/affiliate_repo_integration_test.go @@ -78,6 +78,26 @@ VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34) ledgerCount := querySingleInt(t, txCtx, client, "SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID) require.Equal(t, 1, ledgerCount) + + rows, err := client.QueryContext(txCtx, ` +SELECT amount::double precision, + balance_after::double precision, + aff_quota_after::double precision, + aff_frozen_quota_after::double precision, + aff_history_quota_after::double precision +FROM user_affiliate_ledger +WHERE user_id = $1 AND action = 'transfer' +LIMIT 1`, u.ID) + require.NoError(t, err) + defer func() { _ = rows.Close() }() + require.True(t, rows.Next(), "expected transfer ledger") + var amount, balanceAfter, quotaAfter, frozenAfter, historyAfter float64 + require.NoError(t, rows.Scan(&amount, &balanceAfter, "aAfter, &frozenAfter, &historyAfter)) + require.InDelta(t, 12.34, amount, 1e-9) + require.InDelta(t, 17.84, balanceAfter, 1e-9) + require.InDelta(t, 0.0, quotaAfter, 1e-9) + require.InDelta(t, 0.0, frozenAfter, 1e-9) + require.InDelta(t, 12.34, historyAfter, 1e-9) } // TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the @@ -125,7 +145,7 @@ func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) { require.NoError(t, err) require.True(t, bound, "invitee must bind to inviter") - applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0) + applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0, nil) require.NoError(t, err) require.True(t, applied, "AccrueQuota must report applied=true") diff --git a/backend/internal/repository/affiliate_repo_test.go b/backend/internal/repository/affiliate_repo_test.go new file mode 100644 index 00000000..ccb7bb3d --- /dev/null +++ b/backend/internal/repository/affiliate_repo_test.go @@ -0,0 +1,28 @@ +package repository + +import ( + "os" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAffiliateUserOverviewSQLIncludesMaturedFrozenQuota(t *testing.T) { + query := strings.Join(strings.Fields(affiliateUserOverviewSQL), " ") + + require.Contains(t, query, "ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0)") + require.Contains(t, query, "frozen_until <= NOW()") +} + +func TestAffiliateRecordQueriesUseLedgerAuditFields(t *testing.T) { + source, err := os.ReadFile("affiliate_repo.go") + require.NoError(t, err) + content := string(source) + + require.Contains(t, content, "JOIN payment_orders po ON po.id = ual.source_order_id") + require.Contains(t, content, "ual.amount::double precision") + require.Contains(t, content, "ual.balance_after::double precision") + require.NotContains(t, content, "parseAffiliateRebateAmount") + require.NotContains(t, content, `"current_balance": "u.balance"`) +} diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 3a527405..68895475 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -166,6 +166,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, + group.FieldAllowImageGeneration, + group.FieldImageRateIndependent, + group.FieldImageRateMultiplier, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, @@ -699,6 +702,9 @@ func groupEntityToService(g *dbent.Group) *service.Group { DailyLimitUSD: g.DailyLimitUsd, WeeklyLimitUSD: g.WeeklyLimitUsd, MonthlyLimitUSD: g.MonthlyLimitUsd, + AllowImageGeneration: g.AllowImageGeneration, + ImageRateIndependent: g.ImageRateIndependent, + ImageRateMultiplier: g.ImageRateMultiplier, ImagePrice1K: g.ImagePrice1k, ImagePrice2K: g.ImagePrice2k, ImagePrice4K: g.ImagePrice4k, diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 5e16475a..112575f4 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -50,6 +50,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetNillableDailyLimitUsd(groupIn.DailyLimitUSD). SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD). SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD). + SetAllowImageGeneration(groupIn.AllowImageGeneration). + SetImageRateIndependent(groupIn.ImageRateIndependent). + SetImageRateMultiplier(groupIn.ImageRateMultiplier). SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). @@ -120,6 +123,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetNillableDailyLimitUsd(groupIn.DailyLimitUSD). SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD). SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD). + SetAllowImageGeneration(groupIn.AllowImageGeneration). + SetImageRateIndependent(groupIn.ImageRateIndependent). + SetImageRateMultiplier(groupIn.ImageRateMultiplier). SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index fabf3b5d..0c7248d2 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -328,6 +328,9 @@ func TestAPIContracts(t *testing.T) { "image_price_1k": null, "image_price_2k": null, "image_price_4k": null, + "allow_image_generation": false, + "image_rate_independent": false, + "image_rate_multiplier": 0, "claude_code_only": false, "allow_messages_dispatch": false, "fallback_group_id": null, diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 36bafb86..53734ce8 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -412,6 +412,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { // 529过载冷却配置 adminSettings.GET("/overload-cooldown", h.Admin.Setting.GetOverloadCooldownSettings) adminSettings.PUT("/overload-cooldown", h.Admin.Setting.UpdateOverloadCooldownSettings) + // 429默认回避配置 + adminSettings.GET("/rate-limit-429-cooldown", h.Admin.Setting.GetRateLimit429CooldownSettings) + adminSettings.PUT("/rate-limit-429-cooldown", h.Admin.Setting.UpdateRateLimit429CooldownSettings) // 流超时处理配置 adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings) adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings) @@ -624,11 +627,16 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) { affiliates := admin.Group("/affiliates") { + affiliates.GET("/invites", h.Admin.Affiliate.ListInviteRecords) + affiliates.GET("/rebates", h.Admin.Affiliate.ListRebateRecords) + affiliates.GET("/transfers", h.Admin.Affiliate.ListTransferRecords) + users := affiliates.Group("/users") { users.GET("", h.Admin.Affiliate.ListUsers) users.GET("/lookup", h.Admin.Affiliate.LookupUsers) users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate) + users.GET("/:user_id/overview", h.Admin.Affiliate.GetUserOverview) users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings) users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings) } diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go index 90ff450f..221021d8 100644 --- a/backend/internal/service/account_stats_pricing.go +++ b/backend/internal/service/account_stats_pricing.go @@ -230,7 +230,11 @@ func applyAccountStatsCost( if model == "" { model = requestedModel } + requestCount := 1 + if usageLog != nil && usageLog.ImageCount > 0 { + requestCount = usageLog.ImageCount + } usageLog.AccountStatsCost = resolveAccountStatsCost( - ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost, + ctx, cs, bs, accountID, groupID, model, tokens, requestCount, totalCost, ) } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index cab0215b..bf9f71ff 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -21,6 +21,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat" "github.com/Wei-Shaw/sub2api/internal/pkg/windsurf" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" @@ -571,7 +572,16 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error())) } - apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses" + // 账号已被探测为不支持 Responses(如 DeepSeek/Kimi 等)时,丢出明确提示。 + // 账号本身可用(网关会走 CC 直转),仅测试入口需要补齐 CC SSE 处理逻辑。 + // TODO:实现 CC 格式的账号测试路径(需专门的 CC SSE handler)。 + if !openai_compat.ShouldUseResponsesAPI(account.Extra) { + return s.sendErrorAndEnd(c, + "账号已被探测为不支持 OpenAI Responses API(如 DeepSeek/Kimi 等三方兼容上游),"+ + "账号本身可正常使用,但当前测试接口仅支持 Responses API 路径。请直接通过实际 API 调用验证。", + ) + } + apiURL = buildOpenAIResponsesURL(normalizedBaseURL) } else { return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) } diff --git a/backend/internal/service/admin_balance_history_test.go b/backend/internal/service/admin_balance_history_test.go new file mode 100644 index 00000000..291d3f7b --- /dev/null +++ b/backend/internal/service/admin_balance_history_test.go @@ -0,0 +1,86 @@ +package service + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +func TestMergeBalanceHistoryCodesIncludesAffiliateTransfersByDefault(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC) + older := now.Add(-2 * time.Hour) + newer := now.Add(time.Hour) + + usedBy := int64(10) + redeemCodes := []RedeemCode{ + { + ID: 1, + Type: RedeemTypeBalance, + Value: 8, + Status: StatusUsed, + UsedBy: &usedBy, + UsedAt: &now, + CreatedAt: now, + }, + { + ID: 2, + Type: RedeemTypeConcurrency, + Value: 1, + Status: StatusUsed, + UsedBy: &usedBy, + UsedAt: &older, + CreatedAt: older, + }, + } + affiliateCodes := []RedeemCode{ + { + ID: -20, + Type: RedeemTypeAffiliateBalance, + Value: 3.5, + Status: StatusUsed, + UsedBy: &usedBy, + UsedAt: &newer, + CreatedAt: newer, + }, + } + + got := mergeBalanceHistoryCodes(redeemCodes, affiliateCodes, pagination.PaginationParams{ + Page: 1, + PageSize: 2, + }) + + require.Len(t, got, 2) + require.Equal(t, RedeemTypeAffiliateBalance, got[0].Type) + require.Equal(t, RedeemTypeBalance, got[1].Type) +} + +func TestMergeBalanceHistoryCodesPaginatesAfterCombiningSources(t *testing.T) { + t.Parallel() + + base := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC) + usedBy := int64(10) + at := func(hours int) *time.Time { + v := base.Add(time.Duration(hours) * time.Hour) + return &v + } + + got := mergeBalanceHistoryCodes( + []RedeemCode{ + {ID: 1, Type: RedeemTypeBalance, UsedBy: &usedBy, UsedAt: at(4), CreatedAt: *at(4)}, + {ID: 2, Type: RedeemTypeConcurrency, UsedBy: &usedBy, UsedAt: at(2), CreatedAt: *at(2)}, + }, + []RedeemCode{ + {ID: -3, Type: RedeemTypeAffiliateBalance, UsedBy: &usedBy, UsedAt: at(3), CreatedAt: *at(3)}, + {ID: -4, Type: RedeemTypeAffiliateBalance, UsedBy: &usedBy, UsedAt: at(1), CreatedAt: *at(1)}, + }, + pagination.PaginationParams{Page: 2, PageSize: 2}, + ) + + require.Len(t, got, 2) + require.Equal(t, RedeemTypeConcurrency, got[0].Type) + require.Equal(t, int64(-4), got[1].ID) +} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index b854c16e..1bf44218 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -2,6 +2,7 @@ package service import ( "context" + "database/sql" "encoding/json" "errors" "fmt" @@ -188,11 +189,14 @@ type CreateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID + AllowImageGeneration bool + ImageRateIndependent bool + ImageRateMultiplier *float64 + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID // 无效请求兜底分组 ID(仅 anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) @@ -225,11 +229,14 @@ type UpdateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID + AllowImageGeneration *bool + ImageRateIndependent *bool + ImageRateMultiplier *float64 + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID // 无效请求兜底分组 ID(仅 anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) @@ -973,16 +980,213 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, // GetUserBalanceHistory returns paginated balance/concurrency change records for a user. func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} + if codeType == RedeemTypeAffiliateBalance { + codes, total, err := s.listAffiliateBalanceHistory(ctx, userID, params) + if err != nil { + return nil, 0, 0, err + } + totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID) + if err != nil { + return nil, 0, 0, err + } + return codes, total, totalRecharged, nil + } + + if codeType == "" { + return s.getAllUserBalanceHistory(ctx, userID, params) + } + codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType) if err != nil { return nil, 0, 0, err } + total := result.Total // Aggregate total recharged amount (only once, regardless of type filter) totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID) if err != nil { return nil, 0, 0, err } - return codes, result.Total, totalRecharged, nil + return codes, total, totalRecharged, nil +} + +func (s *adminServiceImpl) getAllUserBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, float64, error) { + needed := params.Offset() + params.Limit() + if needed < params.Limit() { + needed = params.Limit() + } + + redeemCodes, redeemTotal, err := s.listRedeemBalanceHistoryForMerge(ctx, userID, needed) + if err != nil { + return nil, 0, 0, err + } + affiliateCodes, affiliateTotal, err := s.listAffiliateBalanceHistoryForMerge(ctx, userID, needed) + if err != nil { + return nil, 0, 0, err + } + codes := mergeBalanceHistoryCodes(redeemCodes, affiliateCodes, params) + + totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID) + if err != nil { + return nil, 0, 0, err + } + return codes, redeemTotal + affiliateTotal, totalRecharged, nil +} + +func (s *adminServiceImpl) listRedeemBalanceHistoryForMerge(ctx context.Context, userID int64, needed int) ([]RedeemCode, int64, error) { + if needed <= 0 { + return nil, 0, nil + } + + var ( + out []RedeemCode + total int64 + ) + for page := 1; len(out) < needed; page++ { + params := pagination.PaginationParams{Page: page, PageSize: 1000} + codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, "") + if err != nil { + return nil, 0, err + } + if result != nil { + total = result.Total + } + out = append(out, codes...) + if len(codes) < params.Limit() || int64(len(out)) >= total { + break + } + } + if len(out) > needed { + out = out[:needed] + } + return out, total, nil +} + +func (s *adminServiceImpl) listAffiliateBalanceHistoryForMerge(ctx context.Context, userID int64, needed int) ([]RedeemCode, int64, error) { + if needed <= 0 { + return nil, 0, nil + } + + var ( + out []RedeemCode + total int64 + ) + for page := 1; len(out) < needed; page++ { + params := pagination.PaginationParams{Page: page, PageSize: 1000} + codes, currentTotal, err := s.listAffiliateBalanceHistory(ctx, userID, params) + if err != nil { + return nil, 0, err + } + total = currentTotal + out = append(out, codes...) + if len(codes) < params.Limit() || int64(len(out)) >= total { + break + } + } + if len(out) > needed { + out = out[:needed] + } + return out, total, nil +} + +func (s *adminServiceImpl) listAffiliateBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, error) { + if s == nil || s.entClient == nil || userID <= 0 { + return nil, 0, nil + } + + rows, err := s.entClient.QueryContext(ctx, ` +SELECT id, + amount::double precision, + created_at +FROM user_affiliate_ledger +WHERE user_id = $1 + AND action = 'transfer' +ORDER BY created_at DESC, id DESC +OFFSET $2 +LIMIT $3`, userID, params.Offset(), params.Limit()) + if err != nil { + return nil, 0, err + } + defer func() { _ = rows.Close() }() + + codes := make([]RedeemCode, 0, params.Limit()) + for rows.Next() { + var id int64 + var amount float64 + var createdAt time.Time + if err := rows.Scan(&id, &amount, &createdAt); err != nil { + return nil, 0, err + } + usedBy := userID + usedAt := createdAt + codes = append(codes, RedeemCode{ + ID: -id, + Code: fmt.Sprintf("AFF-%d", id), + Type: RedeemTypeAffiliateBalance, + Value: amount, + Status: StatusUsed, + UsedBy: &usedBy, + UsedAt: &usedAt, + CreatedAt: createdAt, + }) + } + if err := rows.Err(); err != nil { + return nil, 0, err + } + + total, err := countAffiliateBalanceHistory(ctx, s.entClient, userID) + if err != nil { + return nil, 0, err + } + return codes, total, nil +} + +func countAffiliateBalanceHistory(ctx context.Context, client *dbent.Client, userID int64) (int64, error) { + rows, err := client.QueryContext(ctx, ` +SELECT COUNT(*) +FROM user_affiliate_ledger +WHERE user_id = $1 + AND action = 'transfer'`, userID) + if err != nil { + return 0, err + } + defer func() { _ = rows.Close() }() + + var total sql.NullInt64 + if rows.Next() { + if err := rows.Scan(&total); err != nil { + return 0, err + } + } + if err := rows.Err(); err != nil { + return 0, err + } + if !total.Valid { + return 0, nil + } + return total.Int64, nil +} + +func mergeBalanceHistoryCodes(redeemCodes, affiliateCodes []RedeemCode, params pagination.PaginationParams) []RedeemCode { + combined := append(append([]RedeemCode{}, redeemCodes...), affiliateCodes...) + sort.SliceStable(combined, func(i, j int) bool { + return redeemCodeHistoryTime(combined[i]).After(redeemCodeHistoryTime(combined[j])) + }) + offset := params.Offset() + if offset >= len(combined) { + return []RedeemCode{} + } + end := offset + params.Limit() + if end > len(combined) { + end = len(combined) + } + return combined[offset:end] +} + +func redeemCodeHistoryTime(code RedeemCode) time.Time { + if code.UsedAt != nil { + return *code.UsedAt + } + return code.CreatedAt } func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) { @@ -1359,6 +1563,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn imagePrice1K := normalizePrice(input.ImagePrice1K) imagePrice2K := normalizePrice(input.ImagePrice2K) imagePrice4K := normalizePrice(input.ImagePrice4K) + imageRateMultiplier := 1.0 + if input.ImageRateMultiplier != nil { + if *input.ImageRateMultiplier < 0 { + return nil, errors.New("image_rate_multiplier must be >= 0") + } + imageRateMultiplier = *input.ImageRateMultiplier + } // 校验降级分组 if input.FallbackGroupID != nil { @@ -1426,6 +1637,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn DailyLimitUSD: dailyLimit, WeeklyLimitUSD: weeklyLimit, MonthlyLimitUSD: monthlyLimit, + AllowImageGeneration: input.AllowImageGeneration, + ImageRateIndependent: input.ImageRateIndependent, + ImageRateMultiplier: imageRateMultiplier, ImagePrice1K: imagePrice1K, ImagePrice2K: imagePrice2K, ImagePrice4K: imagePrice4K, @@ -1602,6 +1816,18 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD) group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD) // 图片生成计费配置:负数表示清除(使用默认价格) + if input.AllowImageGeneration != nil { + group.AllowImageGeneration = *input.AllowImageGeneration + } + if input.ImageRateIndependent != nil { + group.ImageRateIndependent = *input.ImageRateIndependent + } + if input.ImageRateMultiplier != nil { + if *input.ImageRateMultiplier < 0 { + return nil, errors.New("image_rate_multiplier must be >= 0") + } + group.ImageRateMultiplier = *input.ImageRateMultiplier + } if input.ImagePrice1K != nil { group.ImagePrice1K = normalizePrice(input.ImagePrice1K) } diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index eef02240..0a2020ea 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -266,6 +266,50 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { require.Nil(t, repo.updated.ImagePrice4K) } +func TestAdminService_UpdateGroup_PreservesImageGenerationControlsWhenOmitted(t *testing.T) { + imageMultiplier := 0.5 + existingGroup := &Group{ + ID: 1, + Name: "existing-group", + Platform: PlatformOpenAI, + Status: StatusActive, + AllowImageGeneration: true, + ImageRateIndependent: true, + ImageRateMultiplier: imageMultiplier, + } + repo := &groupRepoStubForAdmin{getByID: existingGroup} + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{ + Description: "updated", + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.True(t, repo.updated.AllowImageGeneration) + require.True(t, repo.updated.ImageRateIndependent) + require.InDelta(t, 0.5, repo.updated.ImageRateMultiplier, 1e-12) +} + +func TestAdminService_UpdateGroup_RejectsNegativeImageRateMultiplier(t *testing.T) { + existingGroup := &Group{ + ID: 1, + Name: "existing-group", + Platform: PlatformOpenAI, + Status: StatusActive, + ImageRateMultiplier: 1, + } + repo := &groupRepoStubForAdmin{getByID: existingGroup} + svc := &adminServiceImpl{groupRepo: repo} + negative := -0.1 + + _, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{ + ImageRateMultiplier: &negative, + }) + require.Error(t, err) + require.Nil(t, repo.updated) +} + func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) { existingGroup := &Group{ ID: 1, diff --git a/backend/internal/service/affiliate_service.go b/backend/internal/service/affiliate_service.go index 5a4e91e7..91cca5e2 100644 --- a/backend/internal/service/affiliate_service.go +++ b/backend/internal/service/affiliate_service.go @@ -98,7 +98,7 @@ type AffiliateRepository interface { EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error) BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) - AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) + AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error) GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) ThawFrozenQuota(ctx context.Context, userID int64) (float64, error) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) @@ -110,6 +110,10 @@ type AffiliateRepository interface { SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error) + ListAffiliateInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error) + ListAffiliateRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error) + ListAffiliateTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error) + GetAffiliateUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error) } // AffiliateAdminFilter 列表筛选条件 @@ -130,6 +134,76 @@ type AffiliateAdminEntry struct { AffCount int `json:"aff_count"` } +type AffiliateRecordFilter struct { + Search string + Page int + PageSize int + StartAt *time.Time + EndAt *time.Time + SortBy string + SortDesc bool +} + +type AffiliateInviteRecord struct { + InviterID int64 `json:"inviter_id"` + InviterEmail string `json:"inviter_email"` + InviterUsername string `json:"inviter_username"` + InviteeID int64 `json:"invitee_id"` + InviteeEmail string `json:"invitee_email"` + InviteeUsername string `json:"invitee_username"` + AffCode string `json:"aff_code"` + TotalRebate float64 `json:"total_rebate"` + CreatedAt time.Time `json:"created_at"` +} + +type AffiliateRebateRecord struct { + OrderID int64 `json:"order_id"` + OutTradeNo string `json:"out_trade_no"` + InviterID int64 `json:"inviter_id"` + InviterEmail string `json:"inviter_email"` + InviterUsername string `json:"inviter_username"` + InviteeID int64 `json:"invitee_id"` + InviteeEmail string `json:"invitee_email"` + InviteeUsername string `json:"invitee_username"` + OrderAmount float64 `json:"order_amount"` + PayAmount float64 `json:"pay_amount"` + RebateAmount float64 `json:"rebate_amount"` + PaymentType string `json:"payment_type"` + OrderStatus string `json:"order_status"` + CreatedAt time.Time `json:"created_at"` +} + +type AffiliateTransferRecord struct { + LedgerID int64 `json:"ledger_id"` + UserID int64 `json:"user_id"` + UserEmail string `json:"user_email"` + Username string `json:"username"` + Amount float64 `json:"amount"` + BalanceAfter *float64 `json:"balance_after,omitempty"` + AvailableQuotaAfter *float64 `json:"available_quota_after,omitempty"` + FrozenQuotaAfter *float64 `json:"frozen_quota_after,omitempty"` + HistoryQuotaAfter *float64 `json:"history_quota_after,omitempty"` + SnapshotAvailable bool `json:"snapshot_available"` + CurrentBalance float64 `json:"-"` + RemainingQuota float64 `json:"-"` + FrozenQuota float64 `json:"-"` + HistoryQuota float64 `json:"-"` + CreatedAt time.Time `json:"created_at"` +} + +type AffiliateUserOverview struct { + UserID int64 `json:"user_id"` + Email string `json:"email"` + Username string `json:"username"` + AffCode string `json:"aff_code"` + RebateRatePercent float64 `json:"rebate_rate_percent"` + RebateRateCustom bool `json:"-"` + InvitedCount int `json:"invited_count"` + RebatedInviteeCount int `json:"rebated_invitee_count"` + AvailableQuota float64 `json:"available_quota"` + HistoryQuota float64 `json:"history_quota"` +} + type AffiliateService struct { repo AffiliateRepository settingService *SettingService @@ -238,6 +312,10 @@ func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64, } func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) { + return s.AccrueInviteRebateForOrder(ctx, inviteeUserID, baseRechargeAmount, nil) +} + +func (s *AffiliateService) AccrueInviteRebateForOrder(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64, sourceOrderID *int64) (float64, error) { if s == nil || s.repo == nil { return 0, nil } @@ -298,7 +376,7 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx) } - applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours) + applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours, sourceOrderID) if err != nil { return 0, err } @@ -488,3 +566,59 @@ func (s *AffiliateService) AdminListCustomUsers(ctx context.Context, filter Affi } return s.repo.ListUsersWithCustomSettings(ctx, filter) } + +func (s *AffiliateService) AdminListInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error) { + if s == nil || s.repo == nil { + return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + return s.repo.ListAffiliateInviteRecords(ctx, normalizeAffiliateRecordFilter(filter)) +} + +func (s *AffiliateService) AdminListRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error) { + if s == nil || s.repo == nil { + return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + return s.repo.ListAffiliateRebateRecords(ctx, normalizeAffiliateRecordFilter(filter)) +} + +func (s *AffiliateService) AdminListTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error) { + if s == nil || s.repo == nil { + return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + return s.repo.ListAffiliateTransferRecords(ctx, normalizeAffiliateRecordFilter(filter)) +} + +func (s *AffiliateService) AdminGetUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error) { + if userID <= 0 { + return nil, infraerrors.BadRequest("INVALID_USER", "invalid user") + } + if s == nil || s.repo == nil { + return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + overview, err := s.repo.GetAffiliateUserOverview(ctx, userID) + if err != nil { + return nil, err + } + if overview != nil { + if !overview.RebateRateCustom { + overview.RebateRatePercent = s.globalRebateRatePercent(ctx) + } + overview.RebateRatePercent = clampAffiliateRebateRate(overview.RebateRatePercent) + } + return overview, nil +} + +func normalizeAffiliateRecordFilter(filter AffiliateRecordFilter) AffiliateRecordFilter { + if filter.Page <= 0 { + filter.Page = 1 + } + if filter.PageSize <= 0 { + filter.PageSize = 20 + } + if filter.PageSize > 100 { + filter.PageSize = 100 + } + filter.Search = strings.TrimSpace(filter.Search) + filter.SortBy = strings.TrimSpace(filter.SortBy) + return filter +} diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 1a1c78b8..4432ad7d 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -63,6 +63,9 @@ type APIKeyAuthGroupSnapshot struct { DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` + AllowImageGeneration bool `json:"allow_image_generation"` + ImageRateIndependent bool `json:"image_rate_independent"` + ImageRateMultiplier float64 `json:"image_rate_multiplier"` ImagePrice1K *float64 `json:"image_price_1k,omitempty"` ImagePrice2K *float64 `json:"image_price_2k,omitempty"` ImagePrice4K *float64 `json:"image_price_4k,omitempty"` diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 974ea66e..0f9d4214 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -14,7 +14,7 @@ import ( "github.com/dgraph-io/ristretto" ) -const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot +const apiKeyAuthSnapshotVersion = 8 // v8: added group image generation controls type apiKeyAuthCacheConfig struct { l1Size int @@ -255,6 +255,9 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) DailyLimitUSD: apiKey.Group.DailyLimitUSD, WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, + AllowImageGeneration: apiKey.Group.AllowImageGeneration, + ImageRateIndependent: apiKey.Group.ImageRateIndependent, + ImageRateMultiplier: apiKey.Group.ImageRateMultiplier, ImagePrice1K: apiKey.Group.ImagePrice1K, ImagePrice2K: apiKey.Group.ImagePrice2K, ImagePrice4K: apiKey.Group.ImagePrice4K, @@ -321,6 +324,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho DailyLimitUSD: snapshot.Group.DailyLimitUSD, WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, + AllowImageGeneration: snapshot.Group.AllowImageGeneration, + ImageRateIndependent: snapshot.Group.ImageRateIndependent, + ImageRateMultiplier: snapshot.Group.ImageRateMultiplier, ImagePrice1K: snapshot.Group.ImagePrice1K, ImagePrice2K: snapshot.Group.ImagePrice2K, ImagePrice4K: snapshot.Group.ImagePrice4K, diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 392b3e0b..a9c21884 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -226,6 +226,12 @@ func (s *BillingService) initFallbackPricing() { CacheReadPricePerToken: 7.5e-8, SupportsCacheBreakdown: false, } + s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{ + InputPricePerToken: 2e-7, + OutputPricePerToken: 1.25e-6, + CacheReadPricePerToken: 2e-8, + SupportsCacheBreakdown: false, + } // OpenAI GPT-5.2(本地兜底) s.fallbackPrices["gpt-5.2"] = &ModelPricing{ InputPricePerToken: 1.75e-6, @@ -288,13 +294,14 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { } // OpenAI 仅匹配已知 GPT-5/Codex 族,避免未知 OpenAI 型号误计价。 - if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") { - normalized := normalizeCodexModel(modelLower) + if normalized := normalizeKnownOpenAICodexModel(modelLower); normalized != "" { switch normalized { case "gpt-5.5": return s.fallbackPrices["gpt-5.5"] case "gpt-5.4-mini": return s.fallbackPrices["gpt-5.4-mini"] + case "gpt-5.4-nano": + return s.fallbackPrices["gpt-5.4-nano"] case "gpt-5.4": return s.fallbackPrices["gpt-5.4"] case "gpt-5.2": @@ -636,13 +643,10 @@ func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens } func isOpenAIGPT54Model(model string) bool { - trimmed := strings.TrimSpace(strings.ToLower(model)) - // 仅当模型字符串实际属于 GPT-5/Codex 族时才做归一判定,避免 normalizeCodexModel - // 的默认兜底把非 OpenAI 模型(claude-*、gemini-*、gpt-4o)误识别为 gpt-5.4。 - if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") { - return false - } - normalized := normalizeCodexModel(trimmed) + // 仅当模型字符串实际属于已知 GPT-5/Codex 族时才做归一判定,避免 + // normalizeCodexModel 的默认兜底把非 OpenAI 模型(claude-*、gemini-*、gpt-4o) + // 误识别为 gpt-5.4。 + normalized := normalizeKnownOpenAICodexModel(model) return normalized == "gpt-5.4" || normalized == "gpt-5.5" } diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index 222abd69..df3e3a0a 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -137,6 +137,35 @@ func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) { require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12) } +func TestGetModelPricing_OpenAICompactAliasesFallback(t *testing.T) { + svc := newTestBillingService() + + tests := []struct { + model string + inputPrice float64 + outputPrice float64 + cacheRead float64 + longContext int + }{ + {model: "gpt5.5", inputPrice: 2.5e-6, outputPrice: 15e-6, cacheRead: 0.25e-6, longContext: 272000}, + {model: "openai/gpt5.4", inputPrice: 2.5e-6, outputPrice: 15e-6, cacheRead: 0.25e-6, longContext: 272000}, + {model: "gpt5.4-mini", inputPrice: 7.5e-7, outputPrice: 4.5e-6, cacheRead: 7.5e-8, longContext: 0}, + {model: "gpt5.3codexspark", inputPrice: 1.5e-6, outputPrice: 12e-6, cacheRead: 0.15e-6, longContext: 0}, + } + + for _, tt := range tests { + t.Run(tt.model, func(t *testing.T) { + pricing, err := svc.GetModelPricing(tt.model) + require.NoError(t, err) + require.NotNil(t, pricing) + require.InDelta(t, tt.inputPrice, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, tt.outputPrice, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, tt.cacheRead, pricing.CacheReadPricePerToken, 1e-12) + require.Equal(t, tt.longContext, pricing.LongContextInputThreshold) + }) + } +} + func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) { svc := newTestBillingService() diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 9793b6b2..6d0c5fda 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -52,10 +52,11 @@ const ( // Redeem type constants const ( - RedeemTypeBalance = domain.RedeemTypeBalance - RedeemTypeConcurrency = domain.RedeemTypeConcurrency - RedeemTypeSubscription = domain.RedeemTypeSubscription - RedeemTypeInvitation = domain.RedeemTypeInvitation + RedeemTypeBalance = domain.RedeemTypeBalance + RedeemTypeConcurrency = domain.RedeemTypeConcurrency + RedeemTypeSubscription = domain.RedeemTypeSubscription + RedeemTypeInvitation = domain.RedeemTypeInvitation + RedeemTypeAffiliateBalance = "affiliate_balance" ) // PromoCode status constants @@ -287,6 +288,9 @@ const ( // SettingKeyOverloadCooldownSettings stores JSON config for 529 overload cooldown handling. SettingKeyOverloadCooldownSettings = "overload_cooldown_settings" + // SettingKeyRateLimit429CooldownSettings stores JSON config for 429 fallback cooldown handling. + SettingKeyRateLimit429CooldownSettings = "rate_limit_429_cooldown_settings" + // ========================= // Stream Timeout Handling // ========================= diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 40d82fe3..a1f3f353 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -8297,9 +8297,16 @@ func detachedBillingContext(ctx context.Context) (context.Context, context.Cance } func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { + if ctx == nil { + return context.Background(), func() {} + } if !stream { return ctx, func() {} } + return context.WithoutCancel(ctx), func() {} +} + +func detachUpstreamContext(ctx context.Context) (context.Context, context.CancelFunc) { if ctx == nil { return context.Background(), func() {} } @@ -8483,6 +8490,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage groupDefault := apiKey.Group.RateMultiplier multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) } + imageMultiplier := resolveImageRateMultiplier(apiKey, multiplier) // 确定计费模型 billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) @@ -8500,7 +8508,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage } // 计算费用 - cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts) + cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, imageMultiplier, opts) // 判断计费方式:订阅模式 vs 余额模式 isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() @@ -8512,7 +8520,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage // 创建使用日志 accountRateMultiplier := account.BillingRateMultiplier() usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription, - requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts) + requestedModel, multiplier, imageMultiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts) // 计算账号统计定价费用(使用最终上游模型匹配自定义规则) if apiKey.GroupID != nil { @@ -8566,11 +8574,12 @@ func (s *GatewayService) calculateRecordUsageCost( apiKey *APIKey, billingModel string, multiplier float64, + imageMultiplier float64, opts *recordUsageOpts, ) *CostBreakdown { // 图片生成计费 if result.ImageCount > 0 { - return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier) + return s.calculateImageCost(ctx, result, apiKey, billingModel, imageMultiplier) } // Token 计费 @@ -8611,7 +8620,8 @@ func (s *GatewayService) calculateImageCost( Model: billingModel, GroupID: &gid, Tokens: tokens, - RequestCount: 1, + RequestCount: result.ImageCount, + SizeTier: result.ImageSize, RateMultiplier: multiplier, Resolver: s.resolver, Resolved: resolved, @@ -8696,6 +8706,7 @@ func (s *GatewayService) buildRecordUsageLog( subscription *UserSubscription, requestedModel string, multiplier float64, + imageMultiplier float64, accountRateMultiplier float64, billingType int8, cacheTTLOverridden bool, @@ -8740,6 +8751,9 @@ func (s *GatewayService) buildRecordUsageLog( SubscriptionID: optionalSubscriptionID(subscription), CreatedAt: time.Now(), } + if result.ImageCount > 0 { + usageLog.RateMultiplier = imageMultiplier + } if cost != nil { usageLog.InputCost = cost.InputCost usageLog.OutputCost = cost.OutputCost diff --git a/backend/internal/service/gateway_service_streaming_test.go b/backend/internal/service/gateway_service_streaming_test.go index c8803d39..39a7d3b0 100644 --- a/backend/internal/service/gateway_service_streaming_test.go +++ b/backend/internal/service/gateway_service_streaming_test.go @@ -13,6 +13,8 @@ import ( "github.com/stretchr/testify/require" ) +type upstreamContextTestKey string + func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ @@ -50,3 +52,14 @@ func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testi require.Equal(t, 3, result.usage.InputTokens) require.Equal(t, 7, result.usage.OutputTokens) } + +func TestDetachUpstreamContextIgnoresClientCancel(t *testing.T) { + parent, cancel := context.WithCancel(context.WithValue(context.Background(), upstreamContextTestKey("test-key"), "test-value")) + upstreamCtx, release := detachUpstreamContext(parent) + defer release() + + cancel() + + require.NoError(t, upstreamCtx.Err()) + require.Equal(t, "test-value", upstreamCtx.Value(upstreamContextTestKey("test-key"))) +} diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index bb4c5aa1..f6155352 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -26,9 +26,12 @@ type Group struct { DefaultValidityDays int // 图片生成计费配置(antigravity 和 gemini 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 + AllowImageGeneration bool + ImageRateIndependent bool + ImageRateMultiplier float64 + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 // Claude Code 客户端限制 ClaudeCodeOnly bool diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index 87174e03..93078aa6 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -45,19 +45,25 @@ type GroupSortOrderUpdate struct { // CreateGroupRequest 创建分组请求 type CreateGroupRequest struct { - Name string `json:"name"` - Description string `json:"description"` - RateMultiplier float64 `json:"rate_multiplier"` - IsExclusive bool `json:"is_exclusive"` + Name string `json:"name"` + Description string `json:"description"` + RateMultiplier float64 `json:"rate_multiplier"` + IsExclusive bool `json:"is_exclusive"` + AllowImageGeneration bool `json:"allow_image_generation"` + ImageRateIndependent bool `json:"image_rate_independent"` + ImageRateMultiplier *float64 `json:"image_rate_multiplier"` } // UpdateGroupRequest 更新分组请求 type UpdateGroupRequest struct { - Name *string `json:"name"` - Description *string `json:"description"` - RateMultiplier *float64 `json:"rate_multiplier"` - IsExclusive *bool `json:"is_exclusive"` - Status *string `json:"status"` + Name *string `json:"name"` + Description *string `json:"description"` + RateMultiplier *float64 `json:"rate_multiplier"` + IsExclusive *bool `json:"is_exclusive"` + Status *string `json:"status"` + AllowImageGeneration *bool `json:"allow_image_generation"` + ImageRateIndependent *bool `json:"image_rate_independent"` + ImageRateMultiplier *float64 `json:"image_rate_multiplier"` } // GroupService 分组管理服务 @@ -76,6 +82,13 @@ func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthC // Create 创建分组 func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Group, error) { + imageRateMultiplier := 1.0 + if req.ImageRateMultiplier != nil { + if *req.ImageRateMultiplier < 0 { + return nil, fmt.Errorf("image_rate_multiplier must be >= 0") + } + imageRateMultiplier = *req.ImageRateMultiplier + } // 检查名称是否已存在 exists, err := s.groupRepo.ExistsByName(ctx, req.Name) if err != nil { @@ -87,13 +100,16 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Gro // 创建分组 group := &Group{ - Name: req.Name, - Description: req.Description, - Platform: PlatformAnthropic, - RateMultiplier: req.RateMultiplier, - IsExclusive: req.IsExclusive, - Status: StatusActive, - SubscriptionType: SubscriptionTypeStandard, + Name: req.Name, + Description: req.Description, + Platform: PlatformAnthropic, + RateMultiplier: req.RateMultiplier, + IsExclusive: req.IsExclusive, + Status: StatusActive, + SubscriptionType: SubscriptionTypeStandard, + AllowImageGeneration: req.AllowImageGeneration, + ImageRateIndependent: req.ImageRateIndependent, + ImageRateMultiplier: imageRateMultiplier, } if err := s.groupRepo.Create(ctx, group); err != nil { @@ -165,6 +181,18 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ if req.Status != nil { group.Status = *req.Status } + if req.AllowImageGeneration != nil { + group.AllowImageGeneration = *req.AllowImageGeneration + } + if req.ImageRateIndependent != nil { + group.ImageRateIndependent = *req.ImageRateIndependent + } + if req.ImageRateMultiplier != nil { + if *req.ImageRateMultiplier < 0 { + return nil, fmt.Errorf("image_rate_multiplier must be >= 0") + } + group.ImageRateMultiplier = *req.ImageRateMultiplier + } if err := s.groupRepo.Update(ctx, group); err != nil { return nil, fmt.Errorf("update group: %w", err) diff --git a/backend/internal/service/image_billing_multiplier.go b/backend/internal/service/image_billing_multiplier.go new file mode 100644 index 00000000..23ec5ac1 --- /dev/null +++ b/backend/internal/service/image_billing_multiplier.go @@ -0,0 +1,11 @@ +package service + +func resolveImageRateMultiplier(apiKey *APIKey, effectiveGroupMultiplier float64) float64 { + if apiKey != nil && apiKey.Group != nil && apiKey.Group.ImageRateIndependent { + if apiKey.Group.ImageRateMultiplier < 0 { + return 0 + } + return apiKey.Group.ImageRateMultiplier + } + return effectiveGroupMultiplier +} diff --git a/backend/internal/service/image_generation_intent.go b/backend/internal/service/image_generation_intent.go new file mode 100644 index 00000000..b6ef1065 --- /dev/null +++ b/backend/internal/service/image_generation_intent.go @@ -0,0 +1,220 @@ +package service + +import ( + "encoding/json" + "strings" + + "github.com/tidwall/gjson" +) + +const ( + openAIResponsesEndpoint = "/v1/responses" + openAIResponsesCompactEndpoint = "/v1/responses/compact" + imageGenerationPermissionMessage = "Image generation is not enabled for this group" +) + +// ImageGenerationPermissionMessage returns the stable end-user error text for disabled groups. +func ImageGenerationPermissionMessage() string { + return imageGenerationPermissionMessage +} + +// GroupAllowsImageGeneration preserves ungrouped-key behavior and enforces the flag when a group is present. +func GroupAllowsImageGeneration(group *Group) bool { + return group == nil || group.AllowImageGeneration +} + +// IsImageGenerationIntent classifies requests that can produce generated images. +func IsImageGenerationIntent(endpoint string, requestedModel string, body []byte) bool { + if IsImageGenerationEndpoint(endpoint) { + return true + } + if isOpenAIImageGenerationModel(requestedModel) { + return true + } + if len(body) == 0 || !gjson.ValidBytes(body) { + return false + } + if model := strings.TrimSpace(gjson.GetBytes(body, "model").String()); isOpenAIImageGenerationModel(model) { + return true + } + if openAIJSONToolsContainImageGeneration(gjson.GetBytes(body, "tools")) { + return true + } + return openAIJSONToolChoiceSelectsImageGeneration(gjson.GetBytes(body, "tool_choice")) +} + +// IsImageGenerationIntentMap is the map-backed variant used after service-side request mutation. +func IsImageGenerationIntentMap(endpoint string, requestedModel string, reqBody map[string]any) bool { + if IsImageGenerationEndpoint(endpoint) { + return true + } + if isOpenAIImageGenerationModel(requestedModel) { + return true + } + if reqBody == nil { + return false + } + if isOpenAIImageGenerationModel(firstNonEmptyString(reqBody["model"])) { + return true + } + if hasOpenAIImageGenerationTool(reqBody) { + return true + } + return openAIAnyToolChoiceSelectsImageGeneration(reqBody["tool_choice"]) +} + +// IsImageGenerationEndpoint identifies dedicated generated-image endpoints. +func IsImageGenerationEndpoint(endpoint string) bool { + switch normalizeImageGenerationEndpoint(endpoint) { + case "/v1/images/generations", "/v1/images/edits", "/images/generations", "/images/edits": + return true + default: + return false + } +} + +func normalizeImageGenerationEndpoint(endpoint string) string { + endpoint = strings.TrimSpace(strings.ToLower(endpoint)) + if endpoint == "" { + return "" + } + endpoint = strings.TrimPrefix(endpoint, "https://api.openai.com") + if idx := strings.IndexByte(endpoint, '?'); idx >= 0 { + endpoint = endpoint[:idx] + } + return strings.TrimRight(endpoint, "/") +} + +func openAIJSONToolsContainImageGeneration(tools gjson.Result) bool { + if !tools.IsArray() { + return false + } + found := false + tools.ForEach(func(_, item gjson.Result) bool { + if strings.TrimSpace(item.Get("type").String()) == "image_generation" { + found = true + return false + } + return true + }) + return found +} + +func openAIJSONToolChoiceSelectsImageGeneration(choice gjson.Result) bool { + if !choice.Exists() { + return false + } + if choice.Type == gjson.String { + return strings.TrimSpace(choice.String()) == "image_generation" + } + if !choice.IsObject() { + return false + } + if strings.TrimSpace(choice.Get("type").String()) == "image_generation" { + return true + } + if strings.TrimSpace(choice.Get("tool.type").String()) == "image_generation" { + return true + } + if strings.TrimSpace(choice.Get("function.name").String()) == "image_generation" { + return true + } + return false +} + +func openAIAnyToolChoiceSelectsImageGeneration(choice any) bool { + switch v := choice.(type) { + case string: + return strings.TrimSpace(v) == "image_generation" + case map[string]any: + if strings.TrimSpace(firstNonEmptyString(v["type"])) == "image_generation" { + return true + } + if tool, ok := v["tool"].(map[string]any); ok && strings.TrimSpace(firstNonEmptyString(tool["type"])) == "image_generation" { + return true + } + if fn, ok := v["function"].(map[string]any); ok && strings.TrimSpace(firstNonEmptyString(fn["name"])) == "image_generation" { + return true + } + } + return false +} + +func getAPIKeyFromContext(c interface{ Get(string) (any, bool) }) *APIKey { + if c == nil { + return nil + } + v, exists := c.Get("api_key") + if !exists { + return nil + } + apiKey, _ := v.(*APIKey) + return apiKey +} + +func apiKeyGroup(apiKey *APIKey) *Group { + if apiKey == nil { + return nil + } + return apiKey.Group +} + +func cloneRequestMapForImageIntent(body []byte) map[string]any { + if len(body) == 0 { + return nil + } + var out map[string]any + if err := json.Unmarshal(body, &out); err != nil { + return nil + } + return out +} + +func resolveOpenAIResponsesImageBillingConfig(reqBody map[string]any, fallbackModel string) (string, string, error) { + imageModel := "" + imageSize := "" + hasImageTool := false + if reqBody != nil { + rawTools, _ := reqBody["tools"].([]any) + for _, rawTool := range rawTools { + toolMap, ok := rawTool.(map[string]any) + if !ok || strings.TrimSpace(firstNonEmptyString(toolMap["type"])) != "image_generation" { + continue + } + hasImageTool = true + imageModel = strings.TrimSpace(firstNonEmptyString(toolMap["model"])) + imageSize = strings.TrimSpace(firstNonEmptyString(toolMap["size"])) + break + } + if imageSize == "" { + imageSize = strings.TrimSpace(firstNonEmptyString(reqBody["size"])) + } + } + if imageModel == "" && reqBody != nil { + bodyModel := strings.TrimSpace(firstNonEmptyString(reqBody["model"])) + if isOpenAIImageBillingModelAlias(bodyModel) || !hasImageTool { + imageModel = bodyModel + } + } + if imageModel == "" && hasImageTool { + imageModel = "gpt-image-2" + } + if imageModel == "" { + imageModel = strings.TrimSpace(fallbackModel) + } + sizeTier := normalizeOpenAIImageSizeTier(imageSize) + return imageModel, sizeTier, nil +} + +func resolveOpenAIResponsesImageBillingConfigFromBody(body []byte, fallbackModel string) (string, string, error) { + reqBody := cloneRequestMapForImageIntent(body) + return resolveOpenAIResponsesImageBillingConfig(reqBody, fallbackModel) +} + +func isOpenAIImageBillingModelAlias(model string) bool { + normalized := strings.ToLower(strings.TrimSpace(model)) + if normalized == "" { + return false + } + return isOpenAIImageGenerationModel(normalized) || strings.Contains(normalized, "image") +} diff --git a/backend/internal/service/image_generation_intent_test.go b/backend/internal/service/image_generation_intent_test.go new file mode 100644 index 00000000..5e7bec79 --- /dev/null +++ b/backend/internal/service/image_generation_intent_test.go @@ -0,0 +1,184 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsImageGenerationIntent(t *testing.T) { + tests := []struct { + name string + endpoint string + model string + body []byte + want bool + }{ + { + name: "images endpoint", + endpoint: "/v1/images/generations", + body: []byte(`{"model":"gpt-image-2"}`), + want: true, + }, + { + name: "image model", + endpoint: "/v1/responses", + model: "gpt-image-2", + body: []byte(`{"model":"gpt-image-2"}`), + want: true, + }, + { + name: "image tool", + endpoint: "/v1/responses", + model: "gpt-5.4", + body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation"}]}`), + want: true, + }, + { + name: "image tool choice", + endpoint: "/v1/responses", + model: "gpt-5.4", + body: []byte(`{"model":"gpt-5.4","tool_choice":{"type":"image_generation"}}`), + want: true, + }, + { + name: "required tool choice alone is text", + endpoint: "/v1/responses", + model: "gpt-5.4", + body: []byte(`{"model":"gpt-5.4","tool_choice":"required"}`), + want: false, + }, + { + name: "text only gpt 5.4", + endpoint: "/v1/responses", + model: "gpt-5.4", + body: []byte(`{"model":"gpt-5.4","input":"write code"}`), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, IsImageGenerationIntent(tt.endpoint, tt.model, tt.body)) + }) + } +} + +func TestResolveOpenAIResponsesImageBillingConfigUsesCurrentBodyModel(t *testing.T) { + imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody( + []byte(`{"model":"mapped-image-model","tools":[{"type":"image_generation","size":"1024x1024"}]}`), + "requested-model", + ) + require.NoError(t, err) + require.Equal(t, "mapped-image-model", imageModel) + require.Equal(t, "1K", imageSize) +} + +func TestResolveOpenAIResponsesImageBillingConfigToolModelWins(t *testing.T) { + imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody( + []byte(`{"model":"mapped-text-model","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1536x1024"}]}`), + "requested-model", + ) + require.NoError(t, err) + require.Equal(t, "gpt-image-2", imageModel) + require.Equal(t, "2K", imageSize) +} + +func TestResolveOpenAIResponsesImageBillingConfigSupportsOfficialAndCustomSizes(t *testing.T) { + tests := []struct { + name string + body []byte + wantTier string + }{ + { + name: "official 2k landscape", + body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-2","size":"2048x1152"}]}`), + wantTier: "2K", + }, + { + name: "official 4k landscape", + body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-2","size":"3840x2160"}]}`), + wantTier: "4K", + }, + { + name: "custom valid 2k", + body: []byte(`{"model":"gpt-5.5","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1280x768"}]}`), + wantTier: "2K", + }, + { + name: "default image tool model supports flexible size", + body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","size":"2048x1152"}]}`), + wantTier: "2K", + }, + { + name: "top level image size is moved into billing", + body: []byte(`{"model":"gpt-image-2","size":"2048x2048","tools":[{"type":"image_generation","model":"gpt-image-2"}]}`), + wantTier: "2K", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(tt.body, "requested-model") + require.NoError(t, err) + require.NotEmpty(t, imageModel) + require.Equal(t, tt.wantTier, imageSize) + }) + } +} + +func TestResolveOpenAIResponsesImageBillingConfigDoesNotRejectUnknownSizes(t *testing.T) { + imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody( + []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-1.5","size":"2048x1152"}]}`), + "requested-model", + ) + require.NoError(t, err) + require.Equal(t, "gpt-image-1.5", imageModel) + require.Equal(t, "2K", imageSize) +} + +func TestOpenAIImageOutputCounterDeduplicatesFinalImages(t *testing.T) { + counter := newOpenAIImageOutputCounter() + counter.AddSSEData([]byte(`{"type":"response.image_generation_call.partial_image","partial_image_b64":"abc"}`)) + counter.AddSSEData([]byte(`{"type":"response.output_item.done","item":{"id":"ig_1","type":"image_generation_call","result":"final-a"}}`)) + counter.AddSSEData([]byte(`{"type":"response.completed","response":{"output":[{"id":"ig_1","type":"image_generation_call","result":"final-a"},{"id":"ig_2","type":"image_generation_call","result":"final-b"}]}}`)) + require.Equal(t, 2, counter.Count()) +} + +func TestOpenAIImageOutputCounterCountsImagesAPIStreamShapes(t *testing.T) { + counter := newOpenAIImageOutputCounter() + counter.AddSSEData([]byte(`{"type":"image_generation.completed","id":"ig_complete","b64_json":"final-a"}`)) + counter.AddSSEData([]byte(`{"type":"response.output_item.done","item":{"id":"ig_item","type":"image_generation_call","result":"final-b"}}`)) + counter.AddSSEData([]byte(`{"type":"response.completed","response":{"output":[{"id":"ig_done","type":"image_generation_call","result":"final-c"}]}}`)) + require.Equal(t, 3, counter.Count()) + + dataCounter := newOpenAIImageOutputCounter() + dataCounter.AddSSEData([]byte(`{"data":[{"b64_json":"a"},{"b64_json":"b"}]}`)) + dataCounter.AddSSEData([]byte(`{"data":[{"b64_json":"a"},{"b64_json":"b"},{"b64_json":"c"}]}`)) + require.Equal(t, 3, dataCounter.Count()) +} + +func TestOpenAIImageOutputCounterCountsMultilineSSEDataPayload(t *testing.T) { + counter := newOpenAIImageOutputCounter() + counter.AddSSEData([]byte("{\"type\":\"image_generation.completed\",\n\"b64_json\":\"final-a\"}")) + require.Equal(t, 1, counter.Count()) +} + +func TestOpenAIImageOutputCounterCountsMultilineSSEBodyPayload(t *testing.T) { + counter := newOpenAIImageOutputCounter() + counter.AddSSEBody( + "data: {\"type\":\"image_generation.completed\",\n" + + "data: \"b64_json\":\"final-a\"}\n\n" + + "data: [DONE]\n\n", + ) + require.Equal(t, 1, counter.Count()) +} + +func TestOpenAIImageOutputCounterFallsBackForInvalidMultilineSSEBody(t *testing.T) { + counter := newOpenAIImageOutputCounter() + counter.AddSSEBody( + "data: {\"type\":\"image_generation.completed\",\"b64_json\":\"final-a\"}\n" + + "data: {\"type\":\"image_generation.completed\",\"b64_json\":\"final-b\"}\n\n", + ) + require.Equal(t, 2, counter.Count()) +} diff --git a/backend/internal/service/image_output_accounting.go b/backend/internal/service/image_output_accounting.go new file mode 100644 index 00000000..219c0c59 --- /dev/null +++ b/backend/internal/service/image_output_accounting.go @@ -0,0 +1,149 @@ +package service + +import ( + "crypto/sha256" + "encoding/hex" + "strings" + + "github.com/tidwall/gjson" +) + +type openAIImageOutputCounter struct { + seen map[string]struct{} + count int + maxDataCount int +} + +func newOpenAIImageOutputCounter() *openAIImageOutputCounter { + return &openAIImageOutputCounter{seen: make(map[string]struct{})} +} + +func (c *openAIImageOutputCounter) Count() int { + if c == nil { + return 0 + } + if c.maxDataCount > c.count { + return c.maxDataCount + } + return c.count +} + +func (c *openAIImageOutputCounter) AddJSONResponse(body []byte) { + if c == nil || len(body) == 0 || !gjson.ValidBytes(body) { + return + } + c.addDataArray(gjson.GetBytes(body, "data")) + c.addOutputArray(gjson.GetBytes(body, "output")) + c.addOutputArray(gjson.GetBytes(body, "response.output")) +} + +func (c *openAIImageOutputCounter) AddSSEData(data []byte) { + if c == nil || len(data) == 0 || strings.TrimSpace(string(data)) == "[DONE]" || !gjson.ValidBytes(data) { + return + } + root := gjson.ParseBytes(data) + c.addDataArray(root.Get("data")) + eventType := strings.TrimSpace(root.Get("type").String()) + switch eventType { + case "response.output_item.done": + c.addImageOutputItem(root.Get("item")) + case "response.completed", "response.done": + c.addOutputArray(root.Get("response.output")) + case "image_generation.completed": + if item := root.Get("item"); item.Exists() { + c.addImageOutputItem(item) + return + } + if output := root.Get("output"); output.Exists() { + c.addImageOutputItem(output) + return + } + c.addImageOutputItem(root) + } +} + +func (c *openAIImageOutputCounter) AddSSEBody(body string) { + if c == nil || strings.TrimSpace(body) == "" { + return + } + forEachOpenAISSEDataPayload(body, c.AddSSEData) +} + +func (c *openAIImageOutputCounter) addDataArray(data gjson.Result) { + if !data.IsArray() { + return + } + count := len(data.Array()) + if count > c.maxDataCount { + c.maxDataCount = count + } +} + +func (c *openAIImageOutputCounter) addOutputArray(output gjson.Result) { + if !output.IsArray() { + return + } + output.ForEach(func(_, item gjson.Result) bool { + c.addImageOutputItem(item) + return true + }) +} + +func (c *openAIImageOutputCounter) addImageOutputItem(item gjson.Result) { + if !item.Exists() || !item.IsObject() { + return + } + itemType := strings.TrimSpace(item.Get("type").String()) + if itemType != "" && itemType != "image_generation_call" && itemType != "image_generation.completed" { + return + } + if strings.Contains(strings.ToLower(item.Raw), "partial_image") { + return + } + result := strings.TrimSpace(item.Get("result").String()) + if result == "" { + result = strings.TrimSpace(item.Get("b64_json").String()) + } + if result == "" { + result = strings.TrimSpace(item.Get("url").String()) + } + if result == "" && itemType != "image_generation.completed" { + return + } + key := strings.TrimSpace(item.Get("id").String()) + if key == "" { + key = strings.TrimSpace(item.Get("call_id").String()) + } + if key == "" { + key = hashOpenAIImageOutputResult(result) + } + if key == "" { + return + } + if _, exists := c.seen[key]; exists { + return + } + c.seen[key] = struct{}{} + c.count++ +} + +func hashOpenAIImageOutputResult(result string) string { + result = strings.TrimSpace(result) + if result == "" { + return "" + } + sum := sha256.Sum256([]byte(result)) + return hex.EncodeToString(sum[:]) +} + +func countOpenAIResponseImageOutputsFromJSONBytes(body []byte) int { + counter := newOpenAIImageOutputCounter() + counter.AddJSONResponse(body) + return counter.Count() +} + +func countOpenAIImageOutputsFromSSEBody(body string) int { + counter := newOpenAIImageOutputCounter() + counter.AddSSEBody(body) + return counter.Count() +} diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 5fe96243..c63151ae 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "hash/fnv" + "log/slog" "math" "sort" "strconv" @@ -345,7 +346,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil } - if !s.isAccountRequestCompatible(account, req) { + if !s.isAccountRequestCompatible(ctx, account, req) { return nil, nil } if !s.isAccountTransportCompatible(account, req.RequiredTransport) { @@ -621,7 +622,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) continue } - if !s.isAccountRequestCompatible(account, req) { + if !s.isAccountRequestCompatible(ctx, account, req) { continue } if !s.isAccountTransportCompatible(account, req.RequiredTransport) { @@ -828,11 +829,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( for i := 0; i < len(selectionOrder); i++ { candidate := selectionOrder[i] fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false) - if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { continue } fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false) - if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { continue } if req.RequireCompact && openAICompactSupportTier(fresh) == 0 { @@ -859,11 +860,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( // WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。 for _, candidate := range selectionOrder { fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false) - if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { continue } fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false) - if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { continue } if req.RequireCompact && openAICompactSupportTier(fresh) == 0 { @@ -894,13 +895,18 @@ func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Ac return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport) } -func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(account *Account, req OpenAIAccountScheduleRequest) bool { +func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.Context, account *Account, req OpenAIAccountScheduleRequest) bool { if account == nil { return false } if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { return false } + if req.GroupID != nil && s != nil && s.service != nil && + s.service.needsUpstreamChannelRestrictionCheck(ctx, req.GroupID) && + s.service.isUpstreamModelRestrictedByChannel(ctx, *req.GroupID, account, req.RequestedModel, req.RequireCompact) { + return false + } return account.SupportsOpenAIImageCapability(req.RequiredImageCapability) } @@ -1112,6 +1118,13 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler( } } + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + slog.Warn("channel pricing restriction blocked request", + "group_id", derefGroupID(groupID), + "model", requestedModel) + return nil, decision, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + var stickyAccountID int64 if sessionHash != "" && s.cache != nil { if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && accountID > 0 { diff --git a/backend/internal/service/openai_apikey_responses_probe.go b/backend/internal/service/openai_apikey_responses_probe.go new file mode 100644 index 00000000..a4eb9252 --- /dev/null +++ b/backend/internal/service/openai_apikey_responses_probe.go @@ -0,0 +1,149 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat" +) + +// openaiResponsesProbeTimeout 是探测请求的超时时长。 +// 探测必须快速失败——超时不应阻塞账号创建/更新流程。 +const openaiResponsesProbeTimeout = 8 * time.Second + +// openaiResponsesProbePayload 是探测使用的最小 Responses 请求体。 +// 仅作能力探测,不期望响应内容质量;Stream=false 减少 SSE 解析开销。 +// +// 注意:探测的目标是区分"端点存在"与"端点不存在"——只要上游返回非 404 的 +// 4xx/5xx(如 400 invalid_request_error / 401 unauthorized / 422 等), +// 都视为"端点存在 → 支持 Responses"。仅 404 / 405 视为"端点不存在"。 +func openaiResponsesProbePayload(modelID string) []byte { + if strings.TrimSpace(modelID) == "" { + modelID = openai.DefaultTestModel + } + body, _ := json.Marshal(map[string]any{ + "model": modelID, + "input": []map[string]any{ + { + "role": "user", + "content": []map[string]any{ + {"type": "input_text", "text": "hi"}, + }, + }, + }, + "instructions": openai.DefaultInstructions, + "stream": false, + }) + return body +} + +// ProbeOpenAIAPIKeyResponsesSupport 探测 OpenAI APIKey 账号上游是否支持 +// /v1/responses 端点,并将结果持久化到 accounts.extra.openai_responses_supported。 +// +// 调用时机:账号创建/更新后,且仅当 platform=openai && type=apikey 时。 +// +// 探测策略(参见包文档 internal/pkg/openai_compat): +// - 上游 404 / 405 → 不支持,写 false +// - 上游 2xx / 其他 4xx(401/422/400 等)/ 5xx → 支持,写 true +// - 网络层失败(连接错误、超时)→ 不写标记,保持 unknown +// (后续请求仍按"现状即证据"默认走 Responses) +// +// 该方法是幂等的:重复调用会以最新探测结果覆盖标记。 +// +// 关于失败处理:探测本身的失败不应阻塞账号创建——账号能创建/更新成功就够了, +// 探测结果只影响后续路由优化。所有错误都仅记录日志,不向调用方传播。 +func (s *AccountTestService) ProbeOpenAIAPIKeyResponsesSupport(ctx context.Context, accountID int64) { + account, err := s.accountRepo.GetByID(ctx, accountID) + if err != nil { + logger.LegacyPrintf("service.openai_probe", "probe_load_account_failed: account_id=%d err=%v", accountID, err) + return + } + if account.Platform != PlatformOpenAI || account.Type != AccountTypeAPIKey { + // 仅 OpenAI APIKey 账号需要探测;其他账号类型无能力差异。 + return + } + + apiKey := account.GetOpenAIApiKey() + if apiKey == "" { + logger.LegacyPrintf("service.openai_probe", "probe_skip_no_apikey: account_id=%d", accountID) + return + } + baseURL := account.GetOpenAIBaseURL() + if baseURL == "" { + baseURL = "https://api.openai.com" + } + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + logger.LegacyPrintf("service.openai_probe", "probe_invalid_baseurl: account_id=%d base_url=%q err=%v", accountID, baseURL, err) + return + } + + probeURL := buildOpenAIResponsesURL(normalizedBaseURL) + + probeCtx, cancel := context.WithTimeout(ctx, openaiResponsesProbeTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(probeCtx, http.MethodPost, probeURL, bytes.NewReader(openaiResponsesProbePayload(""))) + if err != nil { + logger.LegacyPrintf("service.openai_probe", "probe_build_request_failed: account_id=%d err=%v", accountID, err) + return + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("Accept", "application/json") + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) + if err != nil { + // 网络层失败:不写标记,保持 unknown,下次重试或由网关 fallback 处理 + logger.LegacyPrintf("service.openai_probe", "probe_request_failed: account_id=%d url=%s err=%v", accountID, probeURL, err) + return + } + defer func() { + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1<<20)) + _ = resp.Body.Close() + }() + + supported := isResponsesEndpointSupportedByStatus(resp.StatusCode) + + if err := s.accountRepo.UpdateExtra(ctx, accountID, map[string]any{ + openai_compat.ExtraKeyResponsesSupported: supported, + }); err != nil { + logger.LegacyPrintf("service.openai_probe", "probe_persist_failed: account_id=%d supported=%v err=%v", accountID, supported, err) + return + } + + logger.LegacyPrintf("service.openai_probe", + "probe_done: account_id=%d base_url=%s status=%d supported=%v", + accountID, normalizedBaseURL, resp.StatusCode, supported, + ) +} + +// isResponsesEndpointSupportedByStatus 根据探测响应的 HTTP 状态码判定上游 +// 是否暴露 /v1/responses 端点。 +// +// 关键观察:第三方 OpenAI 兼容上游(DeepSeek/Kimi 等)对未知端点统一返回 404 +// 或 405;而 OpenAI 官方/有 Responses 实现的上游会因为请求体最简(缺字段) +// 返回 400/422 等业务错误,但端点本身存在。 +// +// 因此:仅 404 和 405 视为"端点不存在",其他 status 视为"端点存在"。 +// +// 5xx 也视为"端点存在"——上游偶发故障不应误判为不支持。 +func isResponsesEndpointSupportedByStatus(status int) bool { + switch status { + case http.StatusNotFound, http.StatusMethodNotAllowed: + return false + } + return true +} diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index b256f1c7..a3b69dee 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -38,6 +38,29 @@ var codexModelMap = map[string]string{ "gpt-5.2-medium": "gpt-5.2", "gpt-5.2-high": "gpt-5.2", "gpt-5.2-xhigh": "gpt-5.2", + "gpt-5": "gpt-5.4", + "gpt-5-mini": "gpt-5.4", + "gpt-5-nano": "gpt-5.4", + "gpt-5.1": "gpt-5.4", + "gpt-5.1-codex": "gpt-5.3-codex", + "gpt-5.1-codex-max": "gpt-5.3-codex", + "gpt-5.1-codex-mini": "gpt-5.3-codex", + "gpt-5.2-codex": "gpt-5.2", + "codex-mini-latest": "gpt-5.3-codex", + "gpt-5-codex": "gpt-5.3-codex", +} + +var codexVersionModelPrefixes = []struct { + prefix string + target string +}{ + {prefix: "gpt-5.3-codex-spark", target: "gpt-5.3-codex-spark"}, + {prefix: "gpt-5.3-codex", target: "gpt-5.3-codex"}, + {prefix: "gpt-5.4-mini", target: "gpt-5.4-mini"}, + {prefix: "gpt-5.4-nano", target: "gpt-5.4-nano"}, + {prefix: "gpt-5.5", target: "gpt-5.5"}, + {prefix: "gpt-5.4", target: "gpt-5.4"}, + {prefix: "gpt-5.2", target: "gpt-5.2"}, } type codexTransformResult struct { @@ -46,6 +69,13 @@ type codexTransformResult struct { PromptCacheKey string } +type codexOAuthTransformOptions struct { + IsCodexCLI bool + IsCompact bool + SkipDefaultInstructions bool + PreserveToolCallIDs bool +} + const ( codexImageGenerationBridgeMarker = "" codexImageGenerationBridgeText = codexImageGenerationBridgeMarker + "\nWhen the user asks for raster image generation or editing, use the OpenAI Responses native `image_generation` tool attached to this request. The local Codex client may not expose an `image_gen` namespace, but that does not mean image generation is unavailable. Do not ask the user to switch to CLI fallback solely because `image_gen` is absent.\n" @@ -71,6 +101,13 @@ var openAICodexOAuthUnsupportedFields = append([]string{ }, openAIChatGPTInternalUnsupportedFields...) func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult { + return applyCodexOAuthTransformWithOptions(reqBody, codexOAuthTransformOptions{ + IsCodexCLI: isCodexCLI, + IsCompact: isCompact, + }) +} + +func applyCodexOAuthTransformWithOptions(reqBody map[string]any, opts codexOAuthTransformOptions) codexTransformResult { result := codexTransformResult{} // 工具续链需求会影响存储策略与 input 过滤逻辑。 needsToolContinuation := NeedsToolContinuation(reqBody) @@ -88,7 +125,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact result.NormalizedModel = normalizedModel } - if isCompact { + if opts.IsCompact { if _, ok := reqBody["store"]; ok { delete(reqBody, "store") result.Modified = true @@ -160,6 +197,10 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact if v, ok := reqBody["prompt_cache_key"].(string); ok { result.PromptCacheKey = strings.TrimSpace(v) + if isOpenAICompatMessagesBridgeRequestBody(reqBody) { + delete(reqBody, "prompt_cache_key") + result.Modified = true + } } // 提取 input 中 role:"system" 消息至 instructions(OAuth 上游不支持 system role)。 @@ -168,7 +209,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact } // instructions 处理逻辑:根据是否是 Codex CLI 分别调用不同方法 - if applyInstructions(reqBody, isCodexCLI) { + if !opts.SkipDefaultInstructions && applyInstructions(reqBody, opts.IsCodexCLI) { result.Modified = true } if isCodexSparkModel(normalizedModel) && applyCodexSparkImageUnsupportedInstructions(reqBody) { @@ -185,7 +226,10 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact input = normalizedInput result.Modified = true } - input = filterCodexInput(input, needsToolContinuation) + input = filterCodexInputWithOptions(input, codexInputFilterOptions{ + PreserveReferences: needsToolContinuation, + PreserveCallIDs: opts.PreserveToolCallIDs, + }) reqBody["input"] = input result.Modified = true } else if inputStr, ok := reqBody["input"].(string); ok { @@ -447,51 +491,81 @@ func normalizeCodexModel(model string) string { if model == "" { return "gpt-5.4" } + if mapped, ok := normalizeKnownCodexModel(model); ok { + return mapped + } + return model +} + +func normalizeKnownCodexModel(model string) (string, bool) { + model = strings.TrimSpace(model) + if model == "" { + return "", false + } if isOpenAIImageGenerationModel(model) { - return model + return model, true } - modelID := model + modelID := lastOpenAIModelSegment(model) + + if normalized := canonicalizeOpenAIModelAliasSpelling(modelID); normalized != "" { + modelID = normalized + } + if mapped := normalizeKnownOpenAICodexModel(modelID); mapped != "" { + return mapped, true + } + key := codexModelLookupKey(modelID) + if key == "" { + return "", false + } + if mapped := getNormalizedCodexModel(key); mapped != "" { + return mapped, true + } + for _, item := range codexVersionModelPrefixes { + if key == item.prefix { + return item.target, true + } + suffix, ok := strings.CutPrefix(key, item.prefix+"-") + if ok && isKnownCodexModelSuffix(suffix) { + return item.target, true + } + } + return "", false +} + +func codexModelLookupKey(modelID string) string { + modelID = strings.TrimSpace(modelID) + if modelID == "" { + return "" + } if strings.Contains(modelID, "/") { parts := strings.Split(modelID, "/") modelID = parts[len(parts)-1] } + return strings.ToLower(strings.Join(strings.Fields(modelID), "-")) +} - if mapped := getNormalizedCodexModel(modelID); mapped != "" { - return mapped +func isKnownCodexModelSuffix(suffix string) bool { + switch suffix { + case "none", "minimal", "low", "medium", "high", "xhigh": + return true } + return isCodexDateSuffix(suffix) +} - normalized := strings.ToLower(modelID) - - if strings.Contains(normalized, "gpt-5.5") || strings.Contains(normalized, "gpt 5.5") { - return "gpt-5.5" +func isCodexDateSuffix(suffix string) bool { + parts := strings.Split(suffix, "-") + if len(parts) != 3 || len(parts[0]) != 4 || len(parts[1]) != 2 || len(parts[2]) != 2 { + return false } - if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") { - return "gpt-5.4-mini" + for _, part := range parts { + for _, r := range part { + if r < '0' || r > '9' { + return false + } + } } - if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") { - return "gpt-5.4" - } - if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") { - return "gpt-5.2" - } - if strings.Contains(normalized, "gpt-5.3-codex-spark") || strings.Contains(normalized, "gpt 5.3 codex spark") { - return "gpt-5.3-codex-spark" - } - if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") { - return "gpt-5.3-codex" - } - if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") { - return "gpt-5.3-codex" - } - if strings.Contains(normalized, "codex") { - return "gpt-5.3-codex" - } - if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") { - return "gpt-5.4" - } - - return "gpt-5.4" + return true } func isCodexSparkModel(model string) bool { @@ -789,23 +863,18 @@ func SupportsVerbosity(model string) bool { } func getNormalizedCodexModel(modelID string) string { - if modelID == "" { + key := codexModelLookupKey(modelID) + if key == "" { return "" } - if mapped, ok := codexModelMap[modelID]; ok { + if mapped, ok := codexModelMap[key]; ok { return mapped } - lower := strings.ToLower(modelID) - for key, value := range codexModelMap { - if strings.ToLower(key) == lower { - return value - } - } return "" } // extractTextFromContent extracts plain text from a content value that is either -// a Go string or a []any of content-part maps with type:"text". +// a Go string or a []any of text-like content-part maps. func extractTextFromContent(content any) string { switch v := content.(type) { case string: @@ -817,7 +886,8 @@ func extractTextFromContent(content any) string { if !ok { continue } - if t, _ := m["type"].(string); t == "text" { + switch t, _ := m["type"].(string); t { + case "text", "input_text", "output_text": if text, ok := m["text"].(string); ok { parts = append(parts, text) } @@ -871,6 +941,28 @@ func extractSystemMessagesFromInput(reqBody map[string]any) bool { return true } +func extractPromptLikeInstructionsFromInput(reqBody map[string]any) string { + input, ok := reqBody["input"].([]any) + if !ok || len(input) == 0 { + return "" + } + var texts []string + for _, item := range input { + m, ok := item.(map[string]any) + if !ok { + continue + } + role, _ := m["role"].(string) + switch role { + case "developer", "system": + if text := strings.TrimSpace(extractTextFromContent(m["content"])); text != "" { + texts = append(texts, text) + } + } + } + return strings.Join(texts, "\n\n") +} + // applyInstructions 处理 instructions 字段:仅在 instructions 为空时填充默认值。 func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool { if !isInstructionsEmpty(reqBody) { @@ -897,9 +989,20 @@ func isInstructionsEmpty(reqBody map[string]any) bool { return strings.TrimSpace(str) == "" } +type codexInputFilterOptions struct { + PreserveReferences bool + PreserveCallIDs bool +} + // filterCodexInput 按需过滤 item_reference 与 id。 // preserveReferences 为 true 时保持引用与 id,以满足续链请求对上下文的依赖。 func filterCodexInput(input []any, preserveReferences bool) []any { + return filterCodexInputWithOptions(input, codexInputFilterOptions{ + PreserveReferences: preserveReferences, + }) +} + +func filterCodexInputWithOptions(input []any, opts codexInputFilterOptions) []any { filtered := make([]any, 0, len(input)) for _, item := range input { m, ok := item.(map[string]any) @@ -920,6 +1023,9 @@ func filterCodexInput(input []any, preserveReferences bool) []any { // 仅修正真正的 tool/function call 标识,避免误改普通 message/reasoning id; // 若 item_reference 指向 legacy call_* 标识,则仅修正该引用本身。 fixCallIDPrefix := func(id string) string { + if opts.PreserveCallIDs { + return id + } if id == "" || strings.HasPrefix(id, "fc") { return id } @@ -930,7 +1036,7 @@ func filterCodexInput(input []any, preserveReferences bool) []any { } if typ == "item_reference" { - if !preserveReferences { + if !opts.PreserveReferences { continue } newItem := make(map[string]any, len(m)) @@ -998,7 +1104,7 @@ func filterCodexInput(input []any, preserveReferences bool) []any { } } - if !preserveReferences { + if !opts.PreserveReferences { ensureCopy() delete(newItem, "id") } diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 87bb7162..9c72760a 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -44,6 +44,39 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { require.Equal(t, "fc1", second["call_id"]) } +func TestApplyCodexOAuthTransform_MessagesBridgePromptCacheKeyIsHeaderOnly(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.5", + "prompt_cache_key": "anthropic-metadata-session-1", + "input": []any{ + map[string]any{ + "type": "message", + "role": "developer", + "content": []any{ + map[string]any{ + "type": "input_text", + "text": openAICompatClaudeCodeTodoGuardMarker, + }, + }, + }, + map[string]any{ + "type": "message", + "role": "user", + "content": "hello", + }, + }, + } + + result := applyCodexOAuthTransformWithOptions(reqBody, codexOAuthTransformOptions{ + SkipDefaultInstructions: true, + PreserveToolCallIDs: true, + }) + + require.Equal(t, "anthropic-metadata-session-1", result.PromptCacheKey) + require.True(t, result.Modified) + require.NotContains(t, reqBody, "prompt_cache_key") +} + func TestApplyCodexOAuthTransform_ToolContinuationPreservesNativeMessageAndReasoningIDs(t *testing.T) { reqBody := map[string]any{ "model": "gpt-5.2", @@ -804,15 +837,25 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { func TestNormalizeCodexModel_Gpt53(t *testing.T) { cases := map[string]string{ "gpt-5.4": "gpt-5.4", + "gpt5.5": "gpt-5.5", + "openai/gpt5.5": "gpt-5.5", + "gpt5.4": "gpt-5.4", "gpt-5.4-high": "gpt-5.4", "gpt-5.4-chat-latest": "gpt-5.4", "gpt 5.4": "gpt-5.4", "gpt-5.4-mini": "gpt-5.4-mini", + "gpt5.4-mini": "gpt-5.4-mini", + "gpt5.4mini": "gpt-5.4-mini", "gpt 5.4 mini": "gpt-5.4-mini", "gpt-5.3": "gpt-5.3-codex", + "gpt5.3": "gpt-5.3-codex", "gpt-5.3-codex": "gpt-5.3-codex", + "gpt5.3-codex": "gpt-5.3-codex", + "gpt5.3codex": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex", "gpt-5.3-codex-spark": "gpt-5.3-codex-spark", + "gpt5.3-codex-spark": "gpt-5.3-codex-spark", + "gpt5.3codexspark": "gpt-5.3-codex-spark", "gpt 5.3 codex spark": "gpt-5.3-codex-spark", "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark", "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark", diff --git a/backend/internal/service/openai_compat_model_test.go b/backend/internal/service/openai_compat_model_test.go index 4396c15f..a897e219 100644 --- a/backend/internal/service/openai_compat_model_test.go +++ b/backend/internal/service/openai_compat_model_test.go @@ -3,13 +3,17 @@ package service import ( "bytes" "context" + "errors" + "fmt" "io" "net/http" "net/http/httptest" "os" "path/filepath" "strings" + "sync" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" @@ -18,6 +22,51 @@ import ( "github.com/tidwall/gjson" ) +type openAICompatFailingWriter struct { + gin.ResponseWriter + failAfter int + writes int +} + +func (w *openAICompatFailingWriter) Write(p []byte) (int, error) { + if w.writes >= w.failAfter { + return 0, errors.New("write failed: client disconnected") + } + w.writes++ + return w.ResponseWriter.Write(p) +} + +type openAICompatBlockingReadCloser struct { + data []byte + offset int + closed chan struct{} + closeOnce sync.Once +} + +func newOpenAICompatBlockingReadCloser(data []byte) *openAICompatBlockingReadCloser { + return &openAICompatBlockingReadCloser{ + data: data, + closed: make(chan struct{}), + } +} + +func (r *openAICompatBlockingReadCloser) Read(p []byte) (int, error) { + if r.offset < len(r.data) { + n := copy(p, r.data[r.offset:]) + r.offset += n + return n, nil + } + <-r.closed + return 0, io.EOF +} + +func (r *openAICompatBlockingReadCloser) Close() error { + r.closeOnce.Do(func() { + close(r.closed) + }) + return nil +} + func TestNormalizeOpenAICompatRequestedModel(t *testing.T) { t.Parallel() @@ -97,7 +146,10 @@ func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T Body: io.NopCloser(strings.NewReader(upstreamBody)), }} - svc := &OpenAIGatewayService{httpUpstream: upstream} + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } account := &Account{ ID: 1, Name: "openai-oauth", @@ -131,6 +183,927 @@ func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T t.Logf("response body: %s", rec.Body.String()) } +func TestForwardAsAnthropic_InjectsPromptCacheKeyForAPIKeyMessagesDispatch(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"metadata":{"user_id":"claude-session-1"},"messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.3-codex","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7,"input_tokens_details":{"cached_tokens":3}}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_cache_key"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1", + }, + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "stable-cache-key", "gpt-5.3-codex") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "stable-cache-key", gjson.GetBytes(upstream.lastBody, "prompt_cache_key").String()) + require.Equal(t, "gpt-5.3-codex", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) +} + +func TestForwardAsAnthropic_AutoDerivesPromptCacheKeyWhenMessagesDispatchHasNoSessionID(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"system":"You are helpful.","messages":[{"role":"user","content":"open repo"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.3-codex","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7,"input_tokens_details":{"cached_tokens":3}}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_auto_cache_key"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1", + }, + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.3-codex") + require.NoError(t, err) + require.NotNil(t, result) + cacheKey := gjson.GetBytes(upstream.lastBody, "prompt_cache_key").String() + require.NotEmpty(t, cacheKey) + require.True(t, strings.HasPrefix(cacheKey, "anthropic-digest-")) + require.Equal(t, generateSessionUUID(isolateOpenAISessionID(0, cacheKey)), upstream.lastReq.Header.Get("session_id")) +} + +func TestForwardAsAnthropic_DoesNotAutoDerivePromptCacheKeyForNonCodexModel(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-4o","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_no_cache_key"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1", + }, + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-4o") + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, gjson.GetBytes(upstream.lastBody, "prompt_cache_key").Exists()) + require.Empty(t, upstream.lastReq.Header.Get("session_id")) +} + +func TestForwardAsAnthropic_TrimsFullReplayOnlyForCodexCompatModels(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + messages := make([]string, 0, openAICompatAnthropicReplayMaxTailMessages+3) + for i := 0; i < openAICompatAnthropicReplayMaxTailMessages+3; i++ { + messages = append(messages, `{"role":"user","content":"message-`+fmt.Sprintf("%02d", i)+`"}`) + } + body := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[` + strings.Join(messages, ",") + `],"stream":false}`) + + run := func(t *testing.T, mappedModel string) []byte { + t.Helper() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"` + mappedModel + `","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_trim"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1", + }, + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", mappedModel) + require.NoError(t, err) + require.NotNil(t, result) + return upstream.lastBody + } + + codexBody := run(t, "gpt-5.3-codex") + require.Equal(t, int64(openAICompatAnthropicReplayMaxTailMessages+1), gjson.GetBytes(codexBody, "input.#").Int()) + require.Equal(t, "developer", gjson.GetBytes(codexBody, "input.0.role").String()) + require.Contains(t, gjson.GetBytes(codexBody, "input.0.content.0.text").String(), "") + require.Equal(t, "message-03", gjson.GetBytes(codexBody, "input.1.content.0.text").String()) + require.Equal(t, "message-14", gjson.GetBytes(codexBody, "input.12.content.0.text").String()) + + nonCompatBody := run(t, "gpt-4o") + require.Equal(t, int64(openAICompatAnthropicReplayMaxTailMessages+3), gjson.GetBytes(nonCompatBody, "input.#").Int()) + require.Equal(t, "message-00", gjson.GetBytes(nonCompatBody, "input.0.content.0.text").String()) +} + +func TestForwardAsAnthropic_OAuthCompatKeepsFullReplayForCacheGrowth(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + messages := make([]string, 0, openAICompatAnthropicReplayMaxTailMessages+3) + for i := 0; i < openAICompatAnthropicReplayMaxTailMessages+3; i++ { + messages = append(messages, `{"role":"user","content":"message-`+fmt.Sprintf("%02d", i)+`"}`) + } + body := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[` + strings.Join(messages, ",") + `],"stream":false}`) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{resp: openAICompatSSECompletedResponse("resp_oauth_trim", "gpt-5.4")} + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.4") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, int64(openAICompatAnthropicReplayMaxTailMessages+4), gjson.GetBytes(upstream.lastBody, "input.#").Int()) + require.Equal(t, "developer", gjson.GetBytes(upstream.lastBody, "input.0.role").String()) + require.Contains(t, gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String(), "") + require.Equal(t, "message-00", gjson.GetBytes(upstream.lastBody, "input.1.content.0.text").String()) + require.Equal(t, "message-14", gjson.GetBytes(upstream.lastBody, "input.15.content.0.text").String()) + require.False(t, gjson.GetBytes(upstream.lastBody, "prompt_cache_key").Exists()) +} + +func TestForwardAsAnthropic_AttachesPreviousResponseIDForCompatContinuation(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + upstream := &httpUpstreamRecorder{} + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1", + }, + } + + firstBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"first"}],"stream":false}`) + upstream.resp = openAICompatSSECompletedResponse("resp_first", "gpt-5.3-codex") + firstRec := httptest.NewRecorder() + firstCtx, _ := gin.CreateTestContext(firstRec) + firstCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(firstBody)) + firstCtx.Request.Header.Set("Content-Type", "application/json") + + firstResult, err := svc.ForwardAsAnthropic(context.Background(), firstCtx, account, firstBody, "stable-cache-key", "gpt-5.3-codex") + require.NoError(t, err) + require.NotNil(t, firstResult) + require.Equal(t, "resp_first", firstResult.ResponseID) + require.False(t, gjson.GetBytes(upstream.lastBody, "previous_response_id").Exists()) + + secondBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"first"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}],"stream":false}`) + upstream.resp = openAICompatSSECompletedResponse("resp_second", "gpt-5.3-codex") + secondRec := httptest.NewRecorder() + secondCtx, _ := gin.CreateTestContext(secondRec) + secondCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(secondBody)) + secondCtx.Request.Header.Set("Content-Type", "application/json") + + secondResult, err := svc.ForwardAsAnthropic(context.Background(), secondCtx, account, secondBody, "stable-cache-key", "gpt-5.3-codex") + require.NoError(t, err) + require.NotNil(t, secondResult) + require.Equal(t, "resp_second", secondResult.ResponseID) + require.Equal(t, "resp_first", gjson.GetBytes(upstream.lastBody, "previous_response_id").String()) + require.Equal(t, int64(2), gjson.GetBytes(upstream.lastBody, "input.#").Int()) + require.Equal(t, "developer", gjson.GetBytes(upstream.lastBody, "input.0.role").String()) + require.Contains(t, gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String(), "") + require.Equal(t, "second", gjson.GetBytes(upstream.lastBody, "input.1.content.0.text").String()) +} + +func TestForwardAsAnthropic_ReplaysWithoutContinuationWhenPreviousResponseMissing(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + upstream := &httpUpstreamRecorder{} + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1", + }, + } + + svc.bindOpenAICompatSessionResponseID(context.Background(), nil, account, "stable-cache-key", "resp_missing") + secondBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"first"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}],"stream":false}`) + upstream.responses = []*http.Response{ + { + StatusCode: http.StatusBadRequest, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_prev_missing"}}, + Body: io.NopCloser(strings.NewReader(`{"error":{"code":"previous_response_not_found","message":"previous response not found"}}`)), + }, + openAICompatSSECompletedResponse("resp_replayed", "gpt-5.3-codex"), + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(secondBody)) + c.Request.Header.Set("Content-Type", "application/json") + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, secondBody, "stable-cache-key", "gpt-5.3-codex") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_replayed", result.ResponseID) + require.Len(t, upstream.requests, 2) + require.Equal(t, "resp_missing", gjson.GetBytes(upstream.bodies[0], "previous_response_id").String()) + require.False(t, gjson.GetBytes(upstream.bodies[1], "previous_response_id").Exists()) + require.Equal(t, int64(4), gjson.GetBytes(upstream.bodies[1], "input.#").Int()) + require.Equal(t, "developer", gjson.GetBytes(upstream.bodies[1], "input.0.role").String()) + require.Contains(t, gjson.GetBytes(upstream.bodies[1], "input.0.content.0.text").String(), "") + require.Equal(t, "first", gjson.GetBytes(upstream.bodies[1], "input.1.content.0.text").String()) + require.Equal(t, "second", gjson.GetBytes(upstream.bodies[1], "input.3.content.0.text").String()) +} + +func TestForwardAsAnthropic_DisablesAPIKeyContinuationWhenUpstreamRequiresWebSocketV2(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + upstream := &httpUpstreamRecorder{} + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1", + }, + } + + svc.bindOpenAICompatSessionResponseID(context.Background(), nil, account, "stable-cache-key", "resp_http_unsupported") + body := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"first"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}],"stream":false}`) + upstream.responses = []*http.Response{ + { + StatusCode: http.StatusBadRequest, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_prev_http_unsupported"}}, + Body: io.NopCloser(strings.NewReader(`{"error":{"message":"previous_response_id is only supported on Responses WebSocket v2","type":"invalid_request_error"}}`)), + }, + openAICompatSSECompletedResponse("resp_replayed", "gpt-5.5"), + openAICompatSSECompletedResponse("resp_later", "gpt-5.5"), + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "stable-cache-key", "gpt-5.5") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_replayed", result.ResponseID) + require.Len(t, upstream.requests, 2) + require.Equal(t, "resp_http_unsupported", gjson.GetBytes(upstream.bodies[0], "previous_response_id").String()) + require.False(t, gjson.GetBytes(upstream.bodies[1], "previous_response_id").Exists()) + + laterRec := httptest.NewRecorder() + laterCtx, _ := gin.CreateTestContext(laterRec) + laterCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + laterCtx.Request.Header.Set("Content-Type", "application/json") + + laterResult, err := svc.ForwardAsAnthropic(context.Background(), laterCtx, account, body, "stable-cache-key", "gpt-5.5") + require.NoError(t, err) + require.NotNil(t, laterResult) + require.Equal(t, "resp_later", laterResult.ResponseID) + require.Len(t, upstream.requests, 3) + require.False(t, gjson.GetBytes(upstream.bodies[2], "previous_response_id").Exists()) +} + +func TestForwardAsAnthropic_APIKeyMetadataSessionSurvivesChangingCacheControlAnchorAfterContinuationDisabled(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + metadata := `{"user_id":"{\"device_id\":\"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\",\"account_uuid\":\"\",\"session_id\":\"aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa\"}"}` + firstBody := []byte(`{"model":"claude-haiku-4-5-20251001","max_tokens":16,"metadata":` + metadata + `,"system":[{"type":"text","text":"project docs","cache_control":{"type":"ephemeral"}}],"messages":[{"role":"user","content":"first"}],"stream":false}`) + messages := make([]string, 0, openAICompatAnthropicReplayMaxTailMessages+4) + messages = append(messages, `{"role":"user","content":[{"type":"text","text":"rewritten context","cache_control":{"type":"ephemeral"}}]}`) + for i := 1; i < openAICompatAnthropicReplayMaxTailMessages+4; i++ { + messages = append(messages, `{"role":"user","content":"message-`+fmt.Sprintf("%02d", i)+`"}`) + } + secondBody := []byte(`{"model":"claude-haiku-4-5-20251001","max_tokens":16,"metadata":` + metadata + `,"messages":[` + strings.Join(messages, ",") + `],"stream":false}`) + + upstream := &httpUpstreamRecorder{responses: []*http.Response{ + openAICompatSSECompletedResponse("resp_first", "gpt-5.4-mini"), + openAICompatSSECompletedResponse("resp_second", "gpt-5.4-mini"), + }} + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1", + }, + } + + firstRec := httptest.NewRecorder() + firstCtx, _ := gin.CreateTestContext(firstRec) + firstCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(firstBody)) + firstCtx.Request.Header.Set("Content-Type", "application/json") + + firstResult, err := svc.ForwardAsAnthropic(context.Background(), firstCtx, account, firstBody, "", "gpt-5.4-mini") + require.NoError(t, err) + require.NotNil(t, firstResult) + firstKey := gjson.GetBytes(upstream.bodies[0], "prompt_cache_key").String() + require.NotEmpty(t, firstKey) + require.True(t, strings.HasPrefix(firstKey, "anthropic-metadata-")) + + svc.disableOpenAICompatSessionContinuation(context.Background(), nil, account, firstKey) + + secondRec := httptest.NewRecorder() + secondCtx, _ := gin.CreateTestContext(secondRec) + secondCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(secondBody)) + secondCtx.Request.Header.Set("Content-Type", "application/json") + + secondResult, err := svc.ForwardAsAnthropic(context.Background(), secondCtx, account, secondBody, "", "gpt-5.4-mini") + require.NoError(t, err) + require.NotNil(t, secondResult) + require.Len(t, upstream.requests, 2) + require.Equal(t, firstKey, gjson.GetBytes(upstream.bodies[1], "prompt_cache_key").String()) + require.False(t, gjson.GetBytes(upstream.bodies[1], "previous_response_id").Exists()) + require.Equal(t, int64(openAICompatAnthropicReplayMaxTailMessages+5), gjson.GetBytes(upstream.bodies[1], "input.#").Int()) + require.Equal(t, "developer", gjson.GetBytes(upstream.bodies[1], "input.0.role").String()) + require.Contains(t, gjson.GetBytes(upstream.bodies[1], "input.0.content.0.text").String(), "") + require.Equal(t, "rewritten context", gjson.GetBytes(upstream.bodies[1], "input.1.content.0.text").String()) + require.Equal(t, "message-15", gjson.GetBytes(upstream.bodies[1], "input.16.content.0.text").String()) +} + +func TestForwardAsAnthropic_DoesNotAttachPreviousResponseIDForOAuthCompat(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + upstream := &httpUpstreamRecorder{resp: openAICompatSSECompletedResponse("resp_oauth_next", "gpt-5.4")} + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + svc.bindOpenAICompatSessionResponseID(context.Background(), nil, account, "stable-cache-key", "resp_oauth_prev") + + body := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"first"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}],"stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "stable-cache-key", "gpt-5.4") + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, gjson.GetBytes(upstream.lastBody, "previous_response_id").Exists()) +} + +func TestForwardAsAnthropic_ReusesOAuthCodexTurnState(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + firstResp := openAICompatSSECompletedResponse("resp_oauth_first", "gpt-5.4") + firstResp.Header.Set("x-codex-turn-state", "turn_state_first") + upstream := &httpUpstreamRecorder{responses: []*http.Response{ + firstResp, + openAICompatSSECompletedResponse("resp_oauth_second", "gpt-5.4"), + }} + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + firstBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"first"}],"stream":false}`) + firstRec := httptest.NewRecorder() + firstCtx, _ := gin.CreateTestContext(firstRec) + firstCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(firstBody)) + firstCtx.Request.Header.Set("Content-Type", "application/json") + + firstResult, err := svc.ForwardAsAnthropic(context.Background(), firstCtx, account, firstBody, "stable-cache-key", "gpt-5.4") + require.NoError(t, err) + require.NotNil(t, firstResult) + require.Empty(t, upstream.requests[0].Header.Get("x-codex-turn-state")) + require.Empty(t, upstream.requests[0].Header.Get("OpenAI-Beta")) + require.Empty(t, upstream.requests[0].Header.Get("originator")) + + secondBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"first"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}],"stream":false}`) + secondRec := httptest.NewRecorder() + secondCtx, _ := gin.CreateTestContext(secondRec) + secondCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(secondBody)) + secondCtx.Request.Header.Set("Content-Type", "application/json") + + secondResult, err := svc.ForwardAsAnthropic(context.Background(), secondCtx, account, secondBody, "stable-cache-key", "gpt-5.4") + require.NoError(t, err) + require.NotNil(t, secondResult) + require.Equal(t, "turn_state_first", upstream.requests[1].Header.Get("x-codex-turn-state")) + require.Equal(t, generateSessionUUID(isolateOpenAISessionID(0, "stable-cache-key")), upstream.requests[1].Header.Get("session_id")) + require.Empty(t, upstream.requests[1].Header.Get("conversation_id")) + require.Empty(t, upstream.requests[1].Header.Get("OpenAI-Beta")) + require.Empty(t, upstream.requests[1].Header.Get("originator")) + require.False(t, gjson.GetBytes(upstream.bodies[1], "prompt_cache_key").Exists()) + require.False(t, gjson.GetBytes(upstream.bodies[1], "previous_response_id").Exists()) +} + +func TestForwardAsAnthropic_OAuthDigestFallbackReusesTurnStateWithoutExplicitKey(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + firstResp := openAICompatSSECompletedResponse("resp_oauth_digest_first", "gpt-5.4") + firstResp.Header.Set("x-codex-turn-state", "turn_state_digest_first") + upstream := &httpUpstreamRecorder{responses: []*http.Response{ + firstResp, + openAICompatSSECompletedResponse("resp_oauth_digest_second", "gpt-5.4"), + }} + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + firstBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"first"}],"stream":false}`) + firstRec := httptest.NewRecorder() + firstCtx, _ := gin.CreateTestContext(firstRec) + firstCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(firstBody)) + firstCtx.Request.Header.Set("Content-Type", "application/json") + + firstResult, err := svc.ForwardAsAnthropic(context.Background(), firstCtx, account, firstBody, "", "gpt-5.4") + require.NoError(t, err) + require.NotNil(t, firstResult) + firstSessionID := upstream.requests[0].Header.Get("session_id") + require.NotEmpty(t, firstSessionID) + require.Empty(t, upstream.requests[0].Header.Get("x-codex-turn-state")) + require.False(t, gjson.GetBytes(upstream.bodies[0], "prompt_cache_key").Exists()) + + secondBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"first"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}],"stream":false}`) + secondRec := httptest.NewRecorder() + secondCtx, _ := gin.CreateTestContext(secondRec) + secondCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(secondBody)) + secondCtx.Request.Header.Set("Content-Type", "application/json") + + secondResult, err := svc.ForwardAsAnthropic(context.Background(), secondCtx, account, secondBody, "", "gpt-5.4") + require.NoError(t, err) + require.NotNil(t, secondResult) + require.Equal(t, firstSessionID, upstream.requests[1].Header.Get("session_id")) + require.Equal(t, "turn_state_digest_first", upstream.requests[1].Header.Get("x-codex-turn-state")) + require.Empty(t, upstream.requests[1].Header.Get("conversation_id")) + require.False(t, gjson.GetBytes(upstream.bodies[1], "prompt_cache_key").Exists()) + require.False(t, gjson.GetBytes(upstream.bodies[1], "previous_response_id").Exists()) +} + +func TestForwardAsAnthropic_OAuthMetadataSessionSurvivesDigestPrefixRewrite(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + firstResp := openAICompatSSECompletedResponse("resp_oauth_metadata_first", "gpt-5.5") + firstResp.Header.Set("x-codex-turn-state", "turn_state_metadata_first") + upstream := &httpUpstreamRecorder{responses: []*http.Response{ + firstResp, + openAICompatSSECompletedResponse("resp_oauth_metadata_second", "gpt-5.5"), + }} + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + metadata := `{"user_id":"{\"device_id\":\"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\",\"account_uuid\":\"\",\"session_id\":\"aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa\"}"}` + + firstBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"metadata":` + metadata + `,"messages":[{"role":"user","content":"first plan"}],"stream":false}`) + firstRec := httptest.NewRecorder() + firstCtx, _ := gin.CreateTestContext(firstRec) + firstCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(firstBody)) + firstCtx.Request.Header.Set("Content-Type", "application/json") + + firstResult, err := svc.ForwardAsAnthropic(context.Background(), firstCtx, account, firstBody, "", "gpt-5.5") + require.NoError(t, err) + require.NotNil(t, firstResult) + firstSessionID := upstream.requests[0].Header.Get("session_id") + require.NotEmpty(t, firstSessionID) + require.Empty(t, upstream.requests[0].Header.Get("x-codex-turn-state")) + require.False(t, gjson.GetBytes(upstream.bodies[0], "prompt_cache_key").Exists()) + + secondBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"metadata":` + metadata + `,"messages":[{"role":"user","content":"rewritten plan"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}],"stream":false}`) + secondRec := httptest.NewRecorder() + secondCtx, _ := gin.CreateTestContext(secondRec) + secondCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(secondBody)) + secondCtx.Request.Header.Set("Content-Type", "application/json") + + secondResult, err := svc.ForwardAsAnthropic(context.Background(), secondCtx, account, secondBody, "", "gpt-5.5") + require.NoError(t, err) + require.NotNil(t, secondResult) + require.Equal(t, firstSessionID, upstream.requests[1].Header.Get("session_id")) + require.Equal(t, "turn_state_metadata_first", upstream.requests[1].Header.Get("x-codex-turn-state")) + require.Empty(t, upstream.requests[1].Header.Get("conversation_id")) + require.False(t, gjson.GetBytes(upstream.bodies[1], "prompt_cache_key").Exists()) + require.False(t, gjson.GetBytes(upstream.bodies[1], "previous_response_id").Exists()) +} + +func TestForwardAsAnthropic_OAuthMetadataSessionSurvivesChangingCacheControlAnchor(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + firstResp := openAICompatSSECompletedResponse("resp_oauth_cache_anchor_first", "gpt-5.5") + firstResp.Header.Set("x-codex-turn-state", "turn_state_cache_anchor_first") + upstream := &httpUpstreamRecorder{responses: []*http.Response{ + firstResp, + openAICompatSSECompletedResponse("resp_oauth_cache_anchor_second", "gpt-5.5"), + }} + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + metadata := `{"user_id":"{\"device_id\":\"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb\",\"account_uuid\":\"\",\"session_id\":\"bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb\"}"}` + + firstBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"metadata":` + metadata + `,"system":[{"type":"text","text":"anchor one","cache_control":{"type":"ephemeral"}}],"messages":[{"role":"user","content":"first"}],"stream":false}`) + firstRec := httptest.NewRecorder() + firstCtx, _ := gin.CreateTestContext(firstRec) + firstCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(firstBody)) + firstCtx.Request.Header.Set("Content-Type", "application/json") + + firstResult, err := svc.ForwardAsAnthropic(context.Background(), firstCtx, account, firstBody, "", "gpt-5.5") + require.NoError(t, err) + require.NotNil(t, firstResult) + firstSessionID := upstream.requests[0].Header.Get("session_id") + require.NotEmpty(t, firstSessionID) + require.Empty(t, upstream.requests[0].Header.Get("x-codex-turn-state")) + require.False(t, gjson.GetBytes(upstream.bodies[0], "prompt_cache_key").Exists()) + + secondBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"metadata":` + metadata + `,"system":[{"type":"text","text":"anchor two","cache_control":{"type":"ephemeral"}}],"messages":[{"role":"user","content":"first"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}],"stream":false}`) + secondRec := httptest.NewRecorder() + secondCtx, _ := gin.CreateTestContext(secondRec) + secondCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(secondBody)) + secondCtx.Request.Header.Set("Content-Type", "application/json") + + secondResult, err := svc.ForwardAsAnthropic(context.Background(), secondCtx, account, secondBody, "", "gpt-5.5") + require.NoError(t, err) + require.NotNil(t, secondResult) + require.Equal(t, firstSessionID, upstream.requests[1].Header.Get("session_id")) + require.Equal(t, "turn_state_cache_anchor_first", upstream.requests[1].Header.Get("x-codex-turn-state")) + require.Empty(t, upstream.requests[1].Header.Get("conversation_id")) + require.False(t, gjson.GetBytes(upstream.bodies[1], "prompt_cache_key").Exists()) + require.False(t, gjson.GetBytes(upstream.bodies[1], "previous_response_id").Exists()) +} + +func TestForwardAsAnthropic_OAuthKeepsSystemAsDeveloperInput(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + upstream := &httpUpstreamRecorder{resp: openAICompatSSECompletedResponse("resp_oauth_system", "gpt-5.4")} + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + body := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"system":[{"type":"text","text":"project instructions","cache_control":{"type":"ephemeral"}}],"messages":[{"role":"user","content":"first"}],"stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.4") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "developer", gjson.GetBytes(upstream.lastBody, "input.0.role").String()) + require.Equal(t, "input_text", gjson.GetBytes(upstream.lastBody, "input.0.content.0.type").String()) + require.Equal(t, "project instructions", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String()) + instructions := gjson.GetBytes(upstream.lastBody, "instructions") + require.True(t, instructions.Exists()) + require.Empty(t, instructions.String()) + require.Empty(t, upstream.requests[0].Header.Get("OpenAI-Beta")) + require.Empty(t, upstream.requests[0].Header.Get("originator")) +} + +func TestForwardAsAnthropic_OAuthAddsClaudeCodeTodoGuardForCompatModel(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + upstream := &httpUpstreamRecorder{resp: openAICompatSSECompletedResponse("resp_oauth_todo_guard", "gpt-5.5")} + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + body := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"system":"project instructions","messages":[{"role":"user","content":"review files"}],"stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.5") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "developer", gjson.GetBytes(upstream.lastBody, "input.0.role").String()) + require.Equal(t, "project instructions", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String()) + require.Equal(t, "developer", gjson.GetBytes(upstream.lastBody, "input.1.role").String()) + require.Contains(t, gjson.GetBytes(upstream.lastBody, "input.1.content.0.text").String(), "") + require.Equal(t, "user", gjson.GetBytes(upstream.lastBody, "input.2.role").String()) +} + +func TestForwardAsAnthropic_OAuthPreservesClaudeCodeToolCallID(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + upstream := &httpUpstreamRecorder{resp: openAICompatSSECompletedResponse("resp_oauth_tool", "gpt-5.4")} + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + body := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"list files"},{"role":"assistant","content":[{"type":"tool_use","id":"toolu_123","name":"Bash","input":{"command":"ls"}}]},{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":"ok"}]}],"tools":[{"name":"Bash","description":"run shell","input_schema":{"type":"object","properties":{"command":{"type":"string"}}}}],"stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "stable-cache-key", "gpt-5.4") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "toolu_123", gjson.GetBytes(upstream.lastBody, `input.#(type=="function_call").call_id`).String()) + require.Equal(t, "toolu_123", gjson.GetBytes(upstream.lastBody, `input.#(type=="function_call_output").call_id`).String()) + require.True(t, gjson.GetBytes(upstream.lastBody, "parallel_tool_calls").Bool()) + require.Equal(t, "medium", gjson.GetBytes(upstream.lastBody, "text.verbosity").String()) + require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.strict").Bool()) +} + +func TestForwardAsAnthropic_StoresStreamingResponseIDWithoutUsage(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + upstream := &httpUpstreamRecorder{} + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1", + }, + } + + firstBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"first"}],"stream":true}`) + upstream.resp = openAICompatSSEResponseWithoutUsage("resp_stream_first", "gpt-5.3-codex") + firstRec := httptest.NewRecorder() + firstCtx, _ := gin.CreateTestContext(firstRec) + firstCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(firstBody)) + firstCtx.Request.Header.Set("Content-Type", "application/json") + + firstResult, err := svc.ForwardAsAnthropic(context.Background(), firstCtx, account, firstBody, "stable-cache-key", "gpt-5.3-codex") + require.NoError(t, err) + require.NotNil(t, firstResult) + require.Equal(t, "resp_stream_first", firstResult.ResponseID) + + secondBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"first"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}],"stream":false}`) + upstream.resp = openAICompatSSECompletedResponse("resp_stream_second", "gpt-5.3-codex") + secondRec := httptest.NewRecorder() + secondCtx, _ := gin.CreateTestContext(secondRec) + secondCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(secondBody)) + secondCtx.Request.Header.Set("Content-Type", "application/json") + + secondResult, err := svc.ForwardAsAnthropic(context.Background(), secondCtx, account, secondBody, "stable-cache-key", "gpt-5.3-codex") + require.NoError(t, err) + require.NotNil(t, secondResult) + require.Equal(t, "resp_stream_first", gjson.GetBytes(upstream.lastBody, "previous_response_id").String()) +} + +func openAICompatSSECompletedResponse(responseID, model string) *http.Response { + body := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"` + responseID + `","object":"response","model":"` + model + `","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`, + "", + "data: [DONE]", + "", + }, "\n") + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_continuation"}}, + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func openAICompatSSEResponseWithoutUsage(responseID, model string) *http.Response { + body := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"` + responseID + `","object":"response","model":"` + model + `","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}]}}`, + "", + "data: [DONE]", + "", + }, "\n") + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_" + responseID}}, + Body: io.NopCloser(strings.NewReader(body)), + } +} + func TestForwardAsAnthropic_ForcedCodexInstructionsTemplatePrependsRenderedInstructions(t *testing.T) { t.Parallel() gin.SetMode(gin.TestMode) @@ -228,3 +1201,242 @@ func TestForwardAsAnthropic_ForcedCodexInstructionsTemplateUsesCachedTemplateCon require.NotNil(t, result) require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String()) } + +func TestForwardAsAnthropic_ClientDisconnectDrainsUpstreamUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`, + "", + `data: {"type":"response.output_text.delta","delta":"ok"}`, + "", + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":9,"output_tokens":4,"total_tokens":13,"input_tokens_details":{"cached_tokens":3}}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_disconnect"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 9, result.Usage.InputTokens) + require.Equal(t, 4, result.Usage.OutputTokens) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) +} + +func TestForwardAsAnthropic_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n") + upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) + defer func() { + require.NoError(t, upstreamStream.Close()) + }() + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_terminal_no_close"}}, + Body: upstreamStream, + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + type forwardResult struct { + result *OpenAIForwardResult + err error + } + resultCh := make(chan forwardResult, 1) + go func() { + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + resultCh <- forwardResult{result: result, err: err} + }() + + select { + case got := <-resultCh: + require.NoError(t, got.err) + require.NotNil(t, got.result) + require.Equal(t, 15, got.result.Usage.InputTokens) + require.Equal(t, 6, got.result.Usage.OutputTokens) + require.Equal(t, 5, got.result.Usage.CacheReadInputTokens) + case <-time.After(time.Second): + require.Fail(t, "ForwardAsAnthropic should return after terminal usage event even if upstream keeps the connection open") + } +} + +func TestForwardAsAnthropic_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n") + upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) + defer func() { + require.NoError(t, upstreamStream.Close()) + }() + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_buffered_terminal_no_close"}}, + Body: upstreamStream, + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + type forwardResult struct { + result *OpenAIForwardResult + err error + } + resultCh := make(chan forwardResult, 1) + go func() { + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + resultCh <- forwardResult{result: result, err: err} + }() + + select { + case got := <-resultCh: + require.NoError(t, got.err) + require.NotNil(t, got.result) + require.Equal(t, 15, got.result.Usage.InputTokens) + require.Equal(t, 6, got.result.Usage.OutputTokens) + require.Equal(t, 5, got.result.Usage.CacheReadInputTokens) + require.Contains(t, rec.Body.String(), `"stop_reason":"end_turn"`) + case <-time.After(time.Second): + require.Fail(t, "ForwardAsAnthropic buffered response should return after terminal usage event even if upstream keeps the connection open") + } +} + +func TestForwardAsAnthropic_DoneSentinelWithoutTerminalReturnsError(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := "data: [DONE]\n\n" + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_missing_terminal"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + require.Error(t, err) + require.Contains(t, err.Error(), "missing terminal event") + require.NotNil(t, result) + require.Zero(t, result.Usage.InputTokens) + require.Zero(t, result.Usage.OutputTokens) +} + +func TestForwardAsAnthropic_UpstreamRequestIgnoresClientCancel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + reqCtx, cancel := context.WithCancel(context.Background()) + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)).WithContext(reqCtx) + c.Request.Header.Set("Content-Type", "application/json") + cancel() + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_ctx"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsAnthropic(reqCtx, c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) +} diff --git a/backend/internal/service/openai_compat_prompt_cache_key.go b/backend/internal/service/openai_compat_prompt_cache_key.go index fcd27f19..de227ff1 100644 --- a/backend/internal/service/openai_compat_prompt_cache_key.go +++ b/backend/internal/service/openai_compat_prompt_cache_key.go @@ -1,7 +1,9 @@ package service import ( + "crypto/sha256" "encoding/json" + "fmt" "strings" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" @@ -16,12 +18,8 @@ func shouldAutoInjectPromptCacheKeyForCompat(model string) bool { if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") { return false } - switch normalizeCodexModel(trimmed) { - case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark": - return true - default: - return false - } + normalized := strings.TrimSpace(strings.ToLower(normalizeCodexModel(trimmed))) + return strings.HasPrefix(normalized, "gpt-5") || strings.Contains(normalized, "codex") } func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedModel string) string { @@ -71,6 +69,102 @@ func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedMod return compatPromptCacheKeyPrefix + hashSensitiveValueForLog(strings.Join(seedParts, "|")) } +func deriveAnthropicCompatPromptCacheKey(req *apicompat.AnthropicRequest, mappedModel string) string { + if req == nil { + return "" + } + if anchorKey := deriveAnthropicCacheControlPromptCacheKey(req); anchorKey != "" { + return anchorKey + } + + normalizedModel := normalizeCodexModel(strings.TrimSpace(mappedModel)) + if normalizedModel == "" { + normalizedModel = normalizeCodexModel(strings.TrimSpace(req.Model)) + } + if normalizedModel == "" { + normalizedModel = strings.TrimSpace(req.Model) + } + + seedParts := []string{"model=" + normalizedModel} + if req.OutputConfig != nil && strings.TrimSpace(req.OutputConfig.Effort) != "" { + seedParts = append(seedParts, "effort="+strings.TrimSpace(req.OutputConfig.Effort)) + } + if len(req.ToolChoice) > 0 { + seedParts = append(seedParts, "tool_choice="+normalizeCompatSeedJSON(req.ToolChoice)) + } + if len(req.Tools) > 0 { + if raw, err := json.Marshal(req.Tools); err == nil { + seedParts = append(seedParts, "tools="+normalizeCompatSeedJSON(raw)) + } + } + if len(req.System) > 0 { + seedParts = append(seedParts, "system="+normalizeCompatSeedJSON(req.System)) + } + + firstUserCaptured := false + for _, msg := range req.Messages { + if strings.TrimSpace(msg.Role) != "user" || firstUserCaptured { + continue + } + seedParts = append(seedParts, "first_user="+normalizeCompatSeedJSON(msg.Content)) + firstUserCaptured = true + } + + return compatPromptCacheKeyPrefix + hashSensitiveValueForLog(strings.Join(seedParts, "|")) +} + +func deriveAnthropicCacheControlPromptCacheKey(req *apicompat.AnthropicRequest) string { + if req == nil { + return "" + } + + var parts []string + var systemBlocks []apicompat.AnthropicContentBlock + if len(req.System) > 0 && json.Unmarshal(req.System, &systemBlocks) == nil { + for _, block := range systemBlocks { + if block.Type == "text" && + block.CacheControl != nil && + strings.TrimSpace(block.CacheControl.Type) == "ephemeral" && + strings.TrimSpace(block.Text) != "" { + parts = append(parts, "system:"+strings.TrimSpace(block.Text)) + } + } + } + + firstUserAnchor := "" + for _, msg := range req.Messages { + var blocks []apicompat.AnthropicContentBlock + if len(msg.Content) == 0 || json.Unmarshal(msg.Content, &blocks) != nil { + continue + } + role := strings.TrimSpace(msg.Role) + for _, block := range blocks { + if block.Type != "text" || + block.CacheControl == nil || + strings.TrimSpace(block.CacheControl.Type) != "ephemeral" || + strings.TrimSpace(block.Text) == "" { + continue + } + switch role { + case "user": + if firstUserAnchor == "" { + firstUserAnchor = strings.TrimSpace(block.Text) + } + case "assistant": + parts = append(parts, "assistant:"+strings.TrimSpace(block.Text)) + } + } + } + if firstUserAnchor != "" { + parts = append(parts, "user_anchor:"+firstUserAnchor) + } + if len(parts) == 0 { + return "" + } + sum := sha256.Sum256([]byte("anthropic-cache:" + strings.Join(parts, "\n"))) + return fmt.Sprintf("anthropic-cache-%x", sum[:16]) +} + func normalizeCompatSeedJSON(v json.RawMessage) string { if len(v) == 0 { return "" diff --git a/backend/internal/service/openai_compat_prompt_cache_key_test.go b/backend/internal/service/openai_compat_prompt_cache_key_test.go index 6ca3e85c..3fe7db6e 100644 --- a/backend/internal/service/openai_compat_prompt_cache_key_test.go +++ b/backend/internal/service/openai_compat_prompt_cache_key_test.go @@ -2,6 +2,7 @@ package service import ( "encoding/json" + "strings" "testing" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" @@ -14,7 +15,10 @@ func mustRawJSON(t *testing.T, s string) json.RawMessage { } func TestShouldAutoInjectPromptCacheKeyForCompat(t *testing.T) { + require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.5")) require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4")) + require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4-mini")) + require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.2")) require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3")) require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex")) require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex-spark")) @@ -77,3 +81,57 @@ func TestDeriveCompatPromptCacheKey_UsesResolvedSparkFamily(t *testing.T) { require.NotEmpty(t, k1) require.Equal(t, k1, k2, "resolved spark family should derive a stable compat cache key") } + +func TestDeriveAnthropicCompatPromptCacheKey_StableAcrossLaterTurns(t *testing.T) { + base := &apicompat.AnthropicRequest{ + Model: "claude-sonnet-4-5", + System: mustRawJSON(t, `"You are helpful."`), + Messages: []apicompat.AnthropicMessage{ + {Role: "user", Content: mustRawJSON(t, `"Open repo"`)}, + }, + } + extended := &apicompat.AnthropicRequest{ + Model: "claude-sonnet-4-5", + System: mustRawJSON(t, `"You are helpful."`), + Messages: []apicompat.AnthropicMessage{ + {Role: "user", Content: mustRawJSON(t, `"Open repo"`)}, + {Role: "assistant", Content: mustRawJSON(t, `"Opened."`)}, + {Role: "user", Content: mustRawJSON(t, `"Run tests"`)}, + }, + } + + k1 := deriveAnthropicCompatPromptCacheKey(base, "gpt-5.3-codex") + k2 := deriveAnthropicCompatPromptCacheKey(extended, "gpt-5.3-codex") + require.NotEmpty(t, k1) + require.Equal(t, k1, k2, "cache key should stay stable as later Claude Code turns append history") +} + +func TestDeriveAnthropicCompatPromptCacheKey_UsesCacheControlAnchors(t *testing.T) { + base := &apicompat.AnthropicRequest{ + Model: "claude-sonnet-4-5", + System: mustRawJSON(t, `[ + {"type":"text","text":"project instructions","cache_control":{"type":"ephemeral"}} + ]`), + Messages: []apicompat.AnthropicMessage{ + {Role: "user", Content: mustRawJSON(t, `[ + {"type":"text","text":"repo anchor","cache_control":{"type":"ephemeral"}} + ]`)}, + }, + } + extended := &apicompat.AnthropicRequest{ + Model: base.Model, + System: base.System, + Messages: []apicompat.AnthropicMessage{ + base.Messages[0], + {Role: "assistant", Content: mustRawJSON(t, `[{"type":"text","text":"Opened."}]`)}, + {Role: "user", Content: mustRawJSON(t, `[{"type":"text","text":"Run tests"}]`)}, + }, + } + + k1 := deriveAnthropicCompatPromptCacheKey(base, "gpt-5.4") + k2 := deriveAnthropicCompatPromptCacheKey(extended, "gpt-5.4") + require.NotEmpty(t, k1) + require.Equal(t, k1, k2) + require.True(t, strings.HasPrefix(k1, "anthropic-cache-")) + require.False(t, strings.HasPrefix(k1, compatPromptCacheKeyPrefix)) +} diff --git a/backend/internal/service/openai_fast_policy_ws_test.go b/backend/internal/service/openai_fast_policy_ws_test.go index 3316a242..7c8341b2 100644 --- a/backend/internal/service/openai_fast_policy_ws_test.go +++ b/backend/internal/service/openai_fast_policy_ws_test.go @@ -972,6 +972,62 @@ func TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames(t *testing "turn 3: response.create without service_tier overwrites billing to nil to match upstream default") } +func TestPassthroughUsageMeta_TracksReasoningEffortAcrossTurns(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","reasoning":{"effort":"medium"},"service_tier":"priority"}`) + meta := newOpenAIWSPassthroughUsageMeta("", firstFrame) + capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstFrame) + firstOut, firstBlocked, firstErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, capturedSessionModel, firstFrame) + require.NoError(t, firstErr) + require.Nil(t, firstBlocked) + meta.initFromFirstFrame(firstOut) + require.NotNil(t, meta.reasoningEffort.Load()) + require.Equal(t, "medium", *meta.reasoningEffort.Load()) + + process := func(payload []byte) ([]byte, *OpenAIFastBlockedError, error) { + if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" { + capturedSessionModel = updated + } + meta.updateSessionRequestModel(payload) + requestModelForThisFrame := meta.requestModelForFrame(payload) + model := openAIWSPassthroughPolicyModelForFrame(account, payload) + if model == "" { + model = capturedSessionModel + } + out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload) + if policyErr == nil && blocked == nil && + strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" { + meta.updateFromResponseCreate(out, requestModelForThisFrame) + } + return out, blocked, policyErr + } + + _, blockedSession, errSession := process([]byte(`{"type":"session.update","session":{"model":"gpt-5-high"}}`)) + require.NoError(t, errSession) + require.Nil(t, blockedSession) + require.NotNil(t, meta.reasoningEffort.Load()) + require.Equal(t, "medium", *meta.reasoningEffort.Load(), "session.update 只刷新后续 fallback model,不覆盖当前 turn metadata") + + _, blockedCancel, errCancel := process([]byte(`{"type":"response.cancel","reasoning_effort":"x-high"}`)) + require.NoError(t, errCancel) + require.Nil(t, blockedCancel) + require.NotNil(t, meta.reasoningEffort.Load()) + require.Equal(t, "medium", *meta.reasoningEffort.Load(), "非 response.create 帧不能污染当前 turn metadata") + + _, blockedFlat, errFlat := process([]byte(`{"type":"response.create","reasoning_effort":"x-high"}`)) + require.NoError(t, errFlat) + require.Nil(t, blockedFlat) + require.NotNil(t, meta.reasoningEffort.Load()) + require.Equal(t, "xhigh", *meta.reasoningEffort.Load(), "flat reasoning_effort 必须进入 passthrough usage metadata") + + _, blockedClear, errClear := process([]byte(`{"type":"response.create","model":"gpt-4o"}`)) + require.NoError(t, errClear) + require.Nil(t, blockedClear) + require.Nil(t, meta.reasoningEffort.Load(), "新的 response.create 无 effort 且无可推导后缀时必须清空旧值") +} + // TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier locks in the // "block keeps previous" semantic: when policy returns block on a // response.create frame, that frame is never sent upstream, so billing tier diff --git a/backend/internal/service/openai_gateway_403_reset_test.go b/backend/internal/service/openai_gateway_403_reset_test.go index c6805464..440b94a9 100644 --- a/backend/internal/service/openai_gateway_403_reset_test.go +++ b/backend/internal/service/openai_gateway_403_reset_test.go @@ -20,20 +20,29 @@ func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accou return nil } -func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) { +func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterForZeroUsage(t *testing.T) { counter := &openAI403CounterResetStub{} rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil) rateLimitSvc.SetOpenAI403CounterCache(counter) - svc := &OpenAIGatewayService{ - rateLimitService: rateLimitSvc, - } + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + svc.rateLimitService = rateLimitSvc err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ - Result: &OpenAIForwardResult{}, + Result: &OpenAIForwardResult{ + RequestID: "resp_zero_usage_reset_403", + Model: "gpt-5.1", + }, + APIKey: &APIKey{ID: 1001, Group: &Group{RateMultiplier: 1}}, + User: &User{ID: 2001}, Account: &Account{ID: 777, Platform: PlatformOpenAI}, }) require.NoError(t, err) require.Equal(t, []int64{777}, counter.resetCalls) + require.Equal(t, 1, usageRepo.calls) } diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index 5822ae4c..84d85c74 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -10,10 +10,12 @@ import ( "io" "net/http" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/gin-gonic/gin" "github.com/tidwall/gjson" @@ -39,9 +41,18 @@ var cursorResponsesUnsupportedFields = []string{ // ForwardAsChatCompletions accepts a Chat Completions request body, converts it // to OpenAI Responses API format, forwards to the OpenAI upstream, and converts -// the response back to Chat Completions format. All account types (OAuth and API -// Key) go through the Responses API conversion path since the upstream only -// exposes the /v1/responses endpoint. +// the response back to Chat Completions format. +// +// 历史背景:该函数原本对所有 OpenAI 账号无差别走 CC→Responses 转换 + /v1/responses +// 端点——这在 OAuth(ChatGPT 内部 API 仅支持 Responses)和官方 APIKey 账号上是 +// 正确的,但 sub2api 接入 DeepSeek/Kimi/GLM 等第三方 OpenAI 兼容上游后假设破裂: +// 这些上游普遍只支持 /v1/chat/completions,无 /v1/responses 端点。 +// +// 当前路由策略(基于账号探测标记,详见 openai_compat.ShouldUseResponsesAPI): +// - APIKey 账号 + 探测确认不支持 Responses → 走 forwardAsRawChatCompletions +// 直转上游 /v1/chat/completions,不做协议转换 +// - 其他所有情况(OAuth、APIKey 探测确认支持、未探测)→ 走原有 CC→Responses +// 转换路径(保留旧行为,存量未探测账号零兼容破坏) func (s *OpenAIGatewayService) ForwardAsChatCompletions( ctx context.Context, c *gin.Context, @@ -50,6 +61,12 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( promptCacheKey string, defaultMappedModel string, ) (*OpenAIForwardResult, error) { + // 入口分流:APIKey 账号 + 已探测且确认上游不支持 Responses,走 CC 直转。 + // 标记缺失(未探测)按"现状即证据"原则继续走下方原 Responses 转换路径。 + if account.Type == AccountTypeAPIKey && !openai_compat.ShouldUseResponsesAPI(account.Extra) { + return s.forwardAsRawChatCompletions(ctx, c, account, body, defaultMappedModel) + } + startTime := time.Now() // 1. Parse Chat Completions request @@ -189,7 +206,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( } // 6. Build upstream request - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, promptCacheKey, false) + releaseUpstreamCtx() if err != nil { return nil, fmt.Errorf("build upstream request: %w", err) } @@ -348,59 +367,9 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") - scanner := bufio.NewScanner(resp.Body) - maxLineSize := defaultMaxLineSize - if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { - maxLineSize = s.cfg.Gateway.MaxLineSize - } - scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) - - var finalResponse *apicompat.ResponsesResponse - var usage OpenAIUsage - acc := apicompat.NewBufferedResponseAccumulator() - - for scanner.Scan() { - line := scanner.Text() - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { - continue - } - payload := line[6:] - - var event apicompat.ResponsesStreamEvent - if err := json.Unmarshal([]byte(payload), &event); err != nil { - logger.L().Warn("openai chat_completions buffered: failed to parse event", - zap.Error(err), - zap.String("request_id", requestID), - ) - continue - } - - // Accumulate delta content for fallback when terminal output is empty. - acc.ProcessEvent(&event) - - if (event.Type == "response.completed" || event.Type == "response.done" || - event.Type == "response.incomplete" || event.Type == "response.failed") && - event.Response != nil { - finalResponse = event.Response - if event.Response.Usage != nil { - usage = OpenAIUsage{ - InputTokens: event.Response.Usage.InputTokens, - OutputTokens: event.Response.Usage.OutputTokens, - } - if event.Response.Usage.InputTokensDetails != nil { - usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens - } - } - } - } - - if err := scanner.Err(); err != nil { - if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { - logger.L().Warn("openai chat_completions buffered: read error", - zap.Error(err), - zap.String("request_id", requestID), - ) - } + finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai chat_completions buffered", requestID) + if err != nil { + return nil, err } if finalResponse == nil { @@ -459,6 +428,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( var usage OpenAIUsage var firstTokenMs *int firstChunk := true + clientDisconnected := false scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -467,6 +437,20 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( } scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + resultWithUsage := func() *OpenAIForwardResult { return &OpenAIForwardResult{ RequestID: requestID, @@ -496,54 +480,66 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( return false } - // Extract usage from completion events - if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && - event.Response != nil && event.Response.Usage != nil { - usage = OpenAIUsage{ - InputTokens: event.Response.Usage.InputTokens, - OutputTokens: event.Response.Usage.OutputTokens, - } - if event.Response.Usage.InputTokensDetails != nil { - usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens - } + // 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。 + isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type) + if isTerminalEvent && event.Response != nil && event.Response.Usage != nil { + usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage) } chunks := apicompat.ResponsesEventToChatChunks(&event, state) - for _, chunk := range chunks { - sse, err := apicompat.ChatChunkToSSE(chunk) - if err != nil { - logger.L().Warn("openai chat_completions stream: failed to marshal chunk", - zap.Error(err), - zap.String("request_id", requestID), - ) - continue - } - if _, err := fmt.Fprint(c.Writer, sse); err != nil { - logger.L().Info("openai chat_completions stream: client disconnected", - zap.String("request_id", requestID), - ) - return true + if !clientDisconnected { + for _, chunk := range chunks { + sse, err := apicompat.ChatChunkToSSE(chunk) + if err != nil { + logger.L().Warn("openai chat_completions stream: failed to marshal chunk", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + clientDisconnected = true + logger.L().Info("openai chat_completions stream: client disconnected, continuing to drain upstream for billing", + zap.String("request_id", requestID), + ) + break + } } } - if len(chunks) > 0 { + if len(chunks) > 0 && !clientDisconnected { c.Writer.Flush() } - return false + return isTerminalEvent } finalizeStream := func() (*OpenAIForwardResult, error) { - if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 { + if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 && !clientDisconnected { for _, chunk := range finalChunks { sse, err := apicompat.ChatChunkToSSE(chunk) if err != nil { continue } - fmt.Fprint(c.Writer, sse) //nolint:errcheck + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + clientDisconnected = true + logger.L().Info("openai chat_completions stream: client disconnected during final flush", + zap.String("request_id", requestID), + ) + break + } } } // Send [DONE] sentinel - fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck - c.Writer.Flush() + if !clientDisconnected { + if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil { + clientDisconnected = true + logger.L().Info("openai chat_completions stream: client disconnected during done flush", + zap.String("request_id", requestID), + ) + } + } + if !clientDisconnected { + c.Writer.Flush() + } return resultWithUsage(), nil } @@ -555,6 +551,9 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( ) } } + missingTerminalErr := func() (*OpenAIForwardResult, error) { + return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event") + } // Determine keepalive interval keepaliveInterval := time.Duration(0) @@ -563,18 +562,25 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( } // No keepalive: fast synchronous path - if keepaliveInterval <= 0 { + if streamInterval <= 0 && keepaliveInterval <= 0 { for scanner.Scan() { line := scanner.Text() - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + payload, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - if processDataLine(line[6:]) { - return resultWithUsage(), nil + if strings.TrimSpace(payload) == "[DONE]" { + return missingTerminalErr() + } + if processDataLine(payload) { + return finalizeStream() } } - handleScanErr(scanner.Err()) - return finalizeStream() + if err := scanner.Err(); err != nil { + handleScanErr(err) + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err) + } + return missingTerminalErr() } // With keepalive: goroutine + channel + select @@ -584,6 +590,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( } events := make(chan scanEvent, 16) done := make(chan struct{}) + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) sendEvent := func(ev scanEvent) bool { select { case events <- ev: @@ -595,6 +603,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( go func() { defer close(events) for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) if !sendEvent(scanEvent{line: scanner.Text()}) { return } @@ -605,30 +614,59 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( }() defer close(done) - keepaliveTicker := time.NewTicker(keepaliveInterval) - defer keepaliveTicker.Stop() + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } lastDataAt := time.Now() for { select { case ev, ok := <-events: if !ok { - return finalizeStream() + return missingTerminalErr() } if ev.err != nil { handleScanErr(ev.err) - return finalizeStream() + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err) } lastDataAt = time.Now() line := ev.line - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + payload, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - if processDataLine(line[6:]) { - return resultWithUsage(), nil + if strings.TrimSpace(payload) == "[DONE]" { + return missingTerminalErr() + } + if processDataLine(payload) { + return finalizeStream() } - case <-keepaliveTicker.C: + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout") + } + logger.L().Warn("openai chat_completions stream: data interval timeout", + zap.String("request_id", requestID), + zap.String("model", originalModel), + zap.Duration("interval", streamInterval), + ) + return resultWithUsage(), fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if clientDisconnected { + continue + } if time.Since(lastDataAt) < keepaliveInterval { continue } @@ -637,7 +675,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( logger.L().Info("openai chat_completions stream: client disconnected during keepalive", zap.String("request_id", requestID), ) - return resultWithUsage(), nil + clientDisconnected = true + continue } c.Writer.Flush() } diff --git a/backend/internal/service/openai_gateway_chat_completions_raw.go b/backend/internal/service/openai_gateway_chat_completions_raw.go new file mode 100644 index 00000000..3be765a2 --- /dev/null +++ b/backend/internal/service/openai_gateway_chat_completions_raw.go @@ -0,0 +1,437 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.uber.org/zap" +) + +// openaiCCRawAllowedHeaders 是 CC 直转路径专用的客户端 header 透传白名单。 +// +// **关键**:不能复用 openaiAllowedHeaders——后者含 Codex 客户端专属 header +// (originator / session_id / x-codex-turn-state / x-codex-turn-metadata / conversation_id), +// 这些在 ChatGPT OAuth 上游是必需的,但透传给 DeepSeek/Kimi/GLM 等第三方 +// OpenAI 兼容上游会造成: +// - 完全忽略(多数友好厂商)——隐性污染上游统计 +// - 400 "unknown parameter"(严格上游)——可见错误 +// +// 这里仅放行通用 HTTP header;content-type / authorization / accept 由上下文 +// 显式设置,不依赖透传。 +// +// 参见决策记录: +// pensieve/short-term/maxims/dont-reuse-shared-headers-whitelist-across-different-upstream-trust-domains +var openaiCCRawAllowedHeaders = map[string]bool{ + "accept-language": true, + "user-agent": true, +} + +// forwardAsRawChatCompletions 直转客户端的 Chat Completions 请求到上游 +// `{base_url}/v1/chat/completions`,**不**做 CC↔Responses 协议转换。 +// +// 适用场景:account.platform=openai && account.type=apikey && 上游已被探测确认 +// 不支持 /v1/responses 端点(如 DeepSeek/Kimi/GLM/Qwen 等第三方 OpenAI 兼容上游)。 +// +// 与 ForwardAsChatCompletions 的关键差异: +// +// - 不调用 apicompat.ChatCompletionsToResponses,body 仅做模型 ID 改写 +// - 上游 URL 拼到 /v1/chat/completions 而非 /v1/responses +// - 流式响应 SSE 直接透传给客户端(上游 chunk 已是 CC 格式) +// - 非流式响应 JSON 直接透传,仅按需提取 usage +// - 不应用 codex OAuth transform(APIKey 路径无 OAuth) +// - 不注入 prompt_cache_key(OAuth 专属机制) +// +// 调用入口:openai_gateway_chat_completions.go::ForwardAsChatCompletions +// 在函数顶部按 openai_compat.ShouldUseResponsesAPI 分流。 +func (s *OpenAIGatewayService) forwardAsRawChatCompletions( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + defaultMappedModel string, +) (*OpenAIForwardResult, error) { + startTime := time.Now() + + // 1. Parse minimal fields needed for routing/billing + originalModel := gjson.GetBytes(body, "model").String() + if originalModel == "" { + writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return nil, fmt.Errorf("missing model in request") + } + clientStream := gjson.GetBytes(body, "stream").Bool() + + // 1b. Extract reasoning effort and service tier from the raw body before any transformation. + reasoningEffort := extractOpenAIReasoningEffortFromBody(body, originalModel) + serviceTier := extractOpenAIServiceTierFromBody(body) + + // 2. Resolve model mapping (same as ForwardAsChatCompletions) + billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) + upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel) + + // 3. Rewrite model in body (no protocol conversion) + upstreamBody := body + if upstreamModel != originalModel { + upstreamBody = ReplaceModelInBody(body, upstreamModel) + } + + // 4. Apply OpenAI fast policy on the CC body + updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, upstreamBody) + if policyErr != nil { + var blocked *OpenAIFastBlockedError + if errors.As(policyErr, &blocked) { + writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message) + } + return nil, policyErr + } + upstreamBody = updatedBody + if clientStream { + var usageErr error + upstreamBody, usageErr = ensureOpenAIChatStreamUsage(upstreamBody) + if usageErr != nil { + return nil, fmt.Errorf("enable stream usage: %w", usageErr) + } + } + + logger.L().Debug("openai chat_completions raw: forwarding without protocol conversion", + zap.Int64("account_id", account.ID), + zap.String("original_model", originalModel), + zap.String("billing_model", billingModel), + zap.String("upstream_model", upstreamModel), + zap.Bool("stream", clientStream), + ) + + // 5. Build upstream request + apiKey := account.GetOpenAIApiKey() + if apiKey == "" { + return nil, fmt.Errorf("account %d missing api_key", account.ID) + } + baseURL := account.GetOpenAIBaseURL() + if baseURL == "" { + baseURL = "https://api.openai.com" + } + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid base_url: %w", err) + } + targetURL := buildOpenAIChatCompletionsURL(validatedURL) + + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody)) + releaseUpstreamCtx() + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+apiKey) + if clientStream { + upstreamReq.Header.Set("Accept", "text/event-stream") + } else { + upstreamReq.Header.Set("Accept", "application/json") + } + + // 透传白名单中的客户端 header。详见 openaiCCRawAllowedHeaders 的设计说明。 + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(key) + if openaiCCRawAllowedHeaders[lowerKey] { + for _, v := range values { + upstreamReq.Header.Add(key, v) + } + } + } + customUA := account.GetOpenAIUserAgent() + if customUA != "" { + upstreamReq.Header.Set("user-agent", customUA) + } + + // 6. Send request + proxyURL := "" + if account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + // 7. Handle error response with failover + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), + } + } + return s.handleChatCompletionsErrorResponse(resp, c, account) + } + + // 8. Forward response + if clientStream { + return s.streamRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime) + } + return s.bufferRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime) +} + +// streamRawChatCompletions 透传上游 CC SSE 流到客户端,并提取 usage(包括 +// 末尾 [DONE] 之前的 chunk 中的 usage 字段,按 OpenAI CC 协议)。 +// +// usage 字段仅在客户端请求 stream_options.include_usage=true 时出现于上游响应中。 +// 网关会对上游强制打开 include_usage 以保证计费完整,并原样向下游透传 usage, +// 让级联代理或下游计费系统也能拿到完整用量。 +func (s *OpenAIGatewayService) streamRawChatCompletions( + c *gin.Context, + resp *http.Response, + originalModel string, + billingModel string, + upstreamModel string, + reasoningEffort *string, + serviceTier *string, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + var usage OpenAIUsage + var firstTokenMs *int + clientDisconnected := false + + for scanner.Scan() { + line := scanner.Text() + if payload, ok := extractOpenAISSEDataLine(line); ok { + trimmedPayload := strings.TrimSpace(payload) + if trimmedPayload != "[DONE]" { + usageOnlyChunk := isOpenAIChatUsageOnlyStreamChunk(payload) + if u := extractCCStreamUsage(payload); u != nil { + usage = *u + } + if firstTokenMs == nil && !usageOnlyChunk { + elapsed := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &elapsed + } + } + } + + if !clientDisconnected { + if _, werr := c.Writer.WriteString(line + "\n"); werr != nil { + clientDisconnected = true + logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing", + zap.Error(werr), + zap.String("request_id", requestID), + ) + } + } + if line == "" { + if !clientDisconnected { + c.Writer.Flush() + } + continue + } + if !clientDisconnected { + c.Writer.Flush() + } + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("openai chat_completions raw: stream read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, + ReasoningEffort: reasoningEffort, + ServiceTier: serviceTier, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +// ensureOpenAIChatStreamUsage 确保 raw Chat Completions 流式请求会让上游返回 usage。 +// usage 也会继续向下游透传,支持级联代理和下游计费系统。 +func ensureOpenAIChatStreamUsage(body []byte) ([]byte, error) { + updated, err := sjson.SetBytes(body, "stream_options.include_usage", true) + if err != nil { + return body, err + } + return updated, nil +} + +func isOpenAIChatUsageOnlyStreamChunk(payload string) bool { + if strings.TrimSpace(payload) == "" { + return false + } + if !gjson.Get(payload, "usage").Exists() { + return false + } + choices := gjson.Get(payload, "choices") + return choices.Exists() && choices.IsArray() && len(choices.Array()) == 0 +} + +// extractCCStreamUsage 从单个 CC 流式 chunk 的 payload 中提取 usage 字段。 +// CC 协议中 usage 仅出现在末尾 chunk(且仅当 include_usage 生效时), +// 但上游可能在多个 chunk 中重复——总是用最新值。 +func extractCCStreamUsage(payload string) *OpenAIUsage { + usageResult := gjson.Get(payload, "usage") + if !usageResult.Exists() || !usageResult.IsObject() { + return nil + } + u := OpenAIUsage{ + InputTokens: int(gjson.Get(payload, "usage.prompt_tokens").Int()), + OutputTokens: int(gjson.Get(payload, "usage.completion_tokens").Int()), + } + if cached := gjson.Get(payload, "usage.prompt_tokens_details.cached_tokens"); cached.Exists() { + u.CacheReadInputTokens = int(cached.Int()) + } + return &u +} + +// bufferRawChatCompletions 透传上游 CC 非流式 JSON 响应。 +func (s *OpenAIGatewayService) bufferRawChatCompletions( + c *gin.Context, + resp *http.Response, + originalModel string, + billingModel string, + upstreamModel string, + reasoningEffort *string, + serviceTier *string, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) + if err != nil { + if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response") + } + return nil, fmt.Errorf("read upstream body: %w", err) + } + + var ccResp apicompat.ChatCompletionsResponse + var usage OpenAIUsage + if err := json.Unmarshal(respBody, &ccResp); err == nil && ccResp.Usage != nil { + usage = OpenAIUsage{ + InputTokens: ccResp.Usage.PromptTokens, + OutputTokens: ccResp.Usage.CompletionTokens, + } + if ccResp.Usage.PromptTokensDetails != nil { + usage.CacheReadInputTokens = ccResp.Usage.PromptTokensDetails.CachedTokens + } + } + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + if ct := resp.Header.Get("Content-Type"); ct != "" { + c.Writer.Header().Set("Content-Type", ct) + } else { + c.Writer.Header().Set("Content-Type", "application/json") + } + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(respBody) + + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, + ReasoningEffort: reasoningEffort, + ServiceTier: serviceTier, + Stream: false, + Duration: time.Since(startTime), + }, nil +} + +// buildOpenAIChatCompletionsURL 拼接上游 Chat Completions 端点 URL。 +// +// - base 已是 /chat/completions:原样返回 +// - base 以 /v1 结尾:追加 /chat/completions +// - 其他情况:追加 /v1/chat/completions +// +// 与 buildOpenAIResponsesURL 是姐妹函数。 +func buildOpenAIChatCompletionsURL(base string) string { + normalized := strings.TrimRight(strings.TrimSpace(base), "/") + if strings.HasSuffix(normalized, "/chat/completions") { + return normalized + } + if strings.HasSuffix(normalized, "/v1") { + return normalized + "/chat/completions" + } + return normalized + "/v1/chat/completions" +} diff --git a/backend/internal/service/openai_gateway_chat_completions_raw_test.go b/backend/internal/service/openai_gateway_chat_completions_raw_test.go new file mode 100644 index 00000000..1be07fd7 --- /dev/null +++ b/backend/internal/service/openai_gateway_chat_completions_raw_test.go @@ -0,0 +1,260 @@ +//go:build unit + +package service + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestBuildOpenAIChatCompletionsURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + base string + want string + }{ + // 已是 /chat/completions:原样返回 + {"already chat/completions", "https://api.openai.com/v1/chat/completions", "https://api.openai.com/v1/chat/completions"}, + // 以 /v1 结尾:追加 /chat/completions + {"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/chat/completions"}, + // 其他情况:追加 /v1/chat/completions + {"bare domain", "https://api.openai.com", "https://api.openai.com/v1/chat/completions"}, + {"domain with trailing slash", "https://api.openai.com/", "https://api.openai.com/v1/chat/completions"}, + // 第三方上游常见形式 + {"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/chat/completions"}, + {"third-party with path prefix", "https://api.gptgod.online/api", "https://api.gptgod.online/api/v1/chat/completions"}, + // 带空白字符 + {"whitespace trimmed", " https://api.openai.com/v1 ", "https://api.openai.com/v1/chat/completions"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := buildOpenAIChatCompletionsURL(tt.base) + require.Equal(t, tt.want, got) + }) + } +} + +// TestBuildOpenAIResponsesURL_ProbeURL 锁定 probe/测试端点使用的 URL 构建逻辑, +// 确保 buildOpenAIResponsesURL 对标准 OpenAI base_url 格式均拼出 `/v1/responses`。 +func TestBuildOpenAIResponsesURL_ProbeURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + base string + want string + }{ + {"bare domain", "https://api.openai.com", "https://api.openai.com/v1/responses"}, + {"domain trailing slash", "https://api.openai.com/", "https://api.openai.com/v1/responses"}, + {"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/responses"}, + {"already /responses", "https://api.openai.com/v1/responses", "https://api.openai.com/v1/responses"}, + {"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/responses"}, + {"only domain, no scheme", "api.gptgod.online", "api.gptgod.online/v1/responses"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := buildOpenAIResponsesURL(tt.base) + require.Equal(t, tt.want, got) + }) + } +} + +func TestForwardAsRawChatCompletions_ForcesStreamUsageUpstreamAndPassesUsageDownstream(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`, + "", + `data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":4,"total_tokens":13,"prompt_tokens_details":{"cached_tokens":3}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_usage"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + account := rawChatCompletionsTestAccount() + + result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 9, result.Usage.InputTokens) + require.Equal(t, 4, result.Usage.OutputTokens) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) + require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool()) + require.Contains(t, rec.Body.String(), `"usage"`) + require.Contains(t, rec.Body.String(), "data: [DONE]") +} + +func TestForwardAsRawChatCompletions_ClientDisconnectDrainsUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`, + "", + `data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":17,"completion_tokens":8,"total_tokens":25,"prompt_tokens_details":{"cached_tokens":6}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_disconnect"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + account := rawChatCompletionsTestAccount() + + result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 17, result.Usage.InputTokens) + require.Equal(t, 8, result.Usage.OutputTokens) + require.Equal(t, 6, result.Usage.CacheReadInputTokens) + require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool()) +} + +func TestForwardAsRawChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) { + gin.SetMode(gin.TestMode) + + reqCtx, cancel := context.WithCancel(context.Background()) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx) + c.Request.Header.Set("Content-Type", "application/json") + cancel() + + upstreamBody := strings.Join([]string{ + `data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_ctx"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + account := rawChatCompletionsTestAccount() + + result, err := svc.forwardAsRawChatCompletions(reqCtx, c, account, body, "") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) +} + +func TestIsOpenAIChatUsageOnlyStreamChunk(t *testing.T) { + t.Parallel() + + require.True(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[],"usage":{"prompt_tokens":1,"completion_tokens":2}}`)) + require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[{"index":0}],"usage":{"prompt_tokens":1,"completion_tokens":2}}`)) + require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[]}`)) + require.False(t, isOpenAIChatUsageOnlyStreamChunk(``)) +} + +func TestEnsureOpenAIChatStreamUsage(t *testing.T) { + t.Parallel() + + body, err := ensureOpenAIChatStreamUsage([]byte(`{"model":"gpt-5.4"}`)) + require.NoError(t, err) + require.True(t, gjson.GetBytes(body, "stream_options.include_usage").Bool()) + + body, err = ensureOpenAIChatStreamUsage([]byte(`{"model":"gpt-5.4","stream_options":{"include_usage":false}}`)) + require.NoError(t, err) + require.True(t, gjson.GetBytes(body, "stream_options.include_usage").Bool()) +} + +func TestBufferRawChatCompletions_RejectsOversizedResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader("toolong")), + } + svc := &OpenAIGatewayService{cfg: rawChatCompletionsTestConfig()} + svc.cfg.Gateway.UpstreamResponseReadMaxBytes = 3 + + result, err := svc.bufferRawChatCompletions(c, resp, "gpt-5.4", "gpt-5.4", "gpt-5.4", nil, nil, time.Now()) + require.ErrorIs(t, err, ErrUpstreamResponseBodyTooLarge) + require.Nil(t, result) + require.Equal(t, http.StatusBadGateway, rec.Code) +} + +func rawChatCompletionsTestConfig() *config.Config { + return &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{ + Enabled: false, + AllowInsecureHTTP: true, + }, + }, + } +} + +func rawChatCompletionsTestAccount() *Account { + return &Account{ + ID: 101, + Name: "raw-openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "http://upstream.example", + }, + } +} diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go index 6846e03a..b0d1fa31 100644 --- a/backend/internal/service/openai_gateway_chat_completions_test.go +++ b/backend/internal/service/openai_gateway_chat_completions_test.go @@ -1,13 +1,36 @@ package service import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" ) +type openAIChatFailingWriter struct { + gin.ResponseWriter + failAfter int + writes int +} + +func (w *openAIChatFailingWriter) Write(p []byte) (int, error) { + if w.writes >= w.failAfter { + return 0, errors.New("write failed: client disconnected") + } + w.writes++ + return w.ResponseWriter.Write(p) +} + func TestNormalizeResponsesRequestServiceTier(t *testing.T) { t.Parallel() @@ -73,3 +96,278 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) { require.Empty(t, tier) require.False(t, gjson.GetBytes(body, "service_tier").Exists()) } + +func TestForwardAsChatCompletions_UnknownModelDoesNotUseDefaultMappedModel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt6","messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_chat_unknown_model"}}, + Body: io.NopCloser(strings.NewReader(`{"error":{"type":"invalid_request_error","message":"model not found"}}`)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.4") + require.Error(t, err) + require.Nil(t, result) + require.Equal(t, "gpt6", gjson.GetBytes(upstream.lastBody, "model").String()) + require.NotEqual(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`, + "", + `data: {"type":"response.output_text.delta","delta":"ok"}`, + "", + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":11,"output_tokens":5,"total_tokens":16,"input_tokens_details":{"cached_tokens":4}}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_disconnect"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 11, result.Usage.InputTokens) + require.Equal(t, 5, result.Usage.OutputTokens) + require.Equal(t, 4, result.Usage.CacheReadInputTokens) +} + +func TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n") + upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) + defer func() { + require.NoError(t, upstreamStream.Close()) + }() + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_terminal_no_close"}}, + Body: upstreamStream, + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + type forwardResult struct { + result *OpenAIForwardResult + err error + } + resultCh := make(chan forwardResult, 1) + go func() { + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + resultCh <- forwardResult{result: result, err: err} + }() + + select { + case got := <-resultCh: + require.NoError(t, got.err) + require.NotNil(t, got.result) + require.Equal(t, 17, got.result.Usage.InputTokens) + require.Equal(t, 8, got.result.Usage.OutputTokens) + require.Equal(t, 6, got.result.Usage.CacheReadInputTokens) + case <-time.After(time.Second): + require.Fail(t, "ForwardAsChatCompletions should return after terminal usage event even if upstream keeps the connection open") + } +} + +func TestForwardAsChatCompletions_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n") + upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) + defer func() { + require.NoError(t, upstreamStream.Close()) + }() + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_buffered_terminal_no_close"}}, + Body: upstreamStream, + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + type forwardResult struct { + result *OpenAIForwardResult + err error + } + resultCh := make(chan forwardResult, 1) + go func() { + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + resultCh <- forwardResult{result: result, err: err} + }() + + select { + case got := <-resultCh: + require.NoError(t, got.err) + require.NotNil(t, got.result) + require.Equal(t, 17, got.result.Usage.InputTokens) + require.Equal(t, 8, got.result.Usage.OutputTokens) + require.Equal(t, 6, got.result.Usage.CacheReadInputTokens) + require.Contains(t, rec.Body.String(), `"finish_reason":"stop"`) + case <-time.After(time.Second): + require.Fail(t, "ForwardAsChatCompletions buffered response should return after terminal usage event even if upstream keeps the connection open") + } +} + +func TestForwardAsChatCompletions_DoneSentinelWithoutTerminalReturnsError(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := "data: [DONE]\n\n" + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_missing_terminal"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + require.Error(t, err) + require.Contains(t, err.Error(), "missing terminal event") + require.NotNil(t, result) + require.Zero(t, result.Usage.InputTokens) + require.Zero(t, result.Usage.OutputTokens) +} + +func TestForwardAsChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + reqCtx, cancel := context.WithCancel(context.Background()) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx) + c.Request.Header.Set("Content-Type", "application/json") + cancel() + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_ctx"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(reqCtx, c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) +} diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 4e0ebb2e..aefa8fd2 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" @@ -39,12 +40,54 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( if err := json.Unmarshal(body, &anthropicReq); err != nil { return nil, fmt.Errorf("parse anthropic request: %w", err) } + anthropicDigestReq := cloneAnthropicRequestForDigest(&anthropicReq) originalModel := anthropicReq.Model applyOpenAICompatModelNormalization(&anthropicReq) normalizedModel := anthropicReq.Model clientStream := anthropicReq.Stream // client's original stream preference - // 2. Convert Anthropic → Responses + // 2. Model mapping + billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel) + upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel) + promptCacheKey = strings.TrimSpace(promptCacheKey) + apiKeyID := getAPIKeyIDFromContext(c) + anthropicDigestChain := "" + anthropicMatchedDigestChain := "" + compatPromptCacheInjected := false + if promptCacheKey == "" && shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) { + promptCacheKey = promptCacheKeyFromAnthropicMetadataSession(&anthropicReq) + if promptCacheKey == "" { + promptCacheKey = deriveAnthropicCacheControlPromptCacheKey(&anthropicReq) + } + if promptCacheKey == "" { + anthropicDigestChain = buildOpenAICompatAnthropicDigestChain(anthropicDigestReq) + if reusedKey, matchedChain := s.findOpenAICompatAnthropicDigestPromptCacheKey(account, apiKeyID, anthropicDigestChain); reusedKey != "" { + promptCacheKey = reusedKey + anthropicMatchedDigestChain = matchedChain + } else { + promptCacheKey = promptCacheKeyFromAnthropicDigest(anthropicDigestChain) + } + } + compatPromptCacheInjected = promptCacheKey != "" + } + compatReplayTrimmed := false + compatReplayGuardEnabled := shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) + compatContinuationEnabled := openAICompatContinuationEnabled(account, upstreamModel) + previousResponseID := "" + if compatContinuationEnabled { + previousResponseID = s.getOpenAICompatSessionResponseID(ctx, c, account, promptCacheKey) + } + compatContinuationDisabled := compatContinuationEnabled && + s.isOpenAICompatSessionContinuationDisabled(ctx, c, account, promptCacheKey) + compatTurnState := "" + // OAuth/Plus relies on session_id + x-codex-turn-state; trimming to a + // sliding 12-message window makes the cached prefix stall at system/tools. + // Keep full replay there so upstream prompt caching can grow turn by turn. + if compatReplayGuardEnabled && account.Type != AccountTypeOAuth && previousResponseID == "" && !compatContinuationDisabled { + compatReplayTrimmed = applyAnthropicCompatFullReplayGuard(&anthropicReq) + } + + // 3. Convert Anthropic → Responses after compatibility-only replay guard. responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq) if err != nil { return nil, fmt.Errorf("convert anthropic to responses: %w", err) @@ -55,24 +98,50 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( responsesReq.Stream = true isStream := true - // 2b. Handle BetaFastMode → service_tier: "priority" + // 3b. Handle BetaFastMode → service_tier: "priority" if containsBetaToken(c.GetHeader("anthropic-beta"), claude.BetaFastMode) { responsesReq.ServiceTier = "priority" } - // 3. Model mapping - billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel) - upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel) responsesReq.Model = upstreamModel + if previousResponseID != "" { + responsesReq.PreviousResponseID = previousResponseID + trimAnthropicCompatResponsesInputToLatestTurn(responsesReq) + } + if compatReplayGuardEnabled && account.Type != AccountTypeOAuth { + appendOpenAICompatClaudeCodeTodoGuard(responsesReq) + } - logger.L().Debug("openai messages: model mapping applied", + logFields := []zap.Field{ zap.Int64("account_id", account.ID), zap.String("original_model", originalModel), zap.String("normalized_model", normalizedModel), zap.String("billing_model", billingModel), zap.String("upstream_model", upstreamModel), zap.Bool("stream", isStream), - ) + } + if compatPromptCacheInjected { + logFields = append(logFields, + zap.Bool("compat_prompt_cache_key_injected", true), + zap.String("compat_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey)), + ) + } + if compatReplayTrimmed { + logFields = append(logFields, + zap.Bool("compat_full_replay_trimmed", true), + zap.Int("compat_messages_after_trim", len(anthropicReq.Messages)), + ) + } + if previousResponseID != "" { + logFields = append(logFields, + zap.Bool("compat_previous_response_id_attached", true), + zap.String("compat_previous_response_id", truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen)), + ) + } + if compatTurnState != "" { + logFields = append(logFields, zap.Bool("compat_turn_state_attached", true)) + } + logger.L().Debug("openai messages: model mapping applied", logFields...) // 4. Marshal Responses request body, then apply OAuth codex transform responsesBody, err := json.Marshal(responsesReq) @@ -85,7 +154,10 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( if err := json.Unmarshal(responsesBody, &reqBody); err != nil { return nil, fmt.Errorf("unmarshal for codex transform: %w", err) } - codexResult := applyCodexOAuthTransform(reqBody, false, false) + codexResult := applyCodexOAuthTransformWithOptions(reqBody, codexOAuthTransformOptions{ + SkipDefaultInstructions: true, + PreserveToolCallIDs: true, + }) forcedTemplateText := "" if s.cfg != nil { forcedTemplateText = s.cfg.Gateway.ForcedCodexInstructionsTemplate @@ -95,6 +167,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( templateUpstreamModel = codexResult.NormalizedModel } existingInstructions, _ := reqBody["instructions"].(string) + if strings.TrimSpace(existingInstructions) == "" { + existingInstructions = extractPromptLikeInstructionsFromInput(reqBody) + } if _, err := applyForcedCodexInstructionsTemplate(reqBody, forcedTemplateText, forcedCodexInstructionsTemplateData{ ExistingInstructions: strings.TrimSpace(existingInstructions), OriginalModel: originalModel, @@ -104,13 +179,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( }); err != nil { return nil, err } + ensureCodexOAuthInstructionsField(reqBody) + if shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) { + appendOpenAICompatClaudeCodeTodoGuardToRequestBody(reqBody) + } if codexResult.NormalizedModel != "" { upstreamModel = codexResult.NormalizedModel } if codexResult.PromptCacheKey != "" { promptCacheKey = codexResult.PromptCacheKey - } else if promptCacheKey != "" { - reqBody["prompt_cache_key"] = promptCacheKey + } + delete(reqBody, "prompt_cache_key") + if shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) { + compatTurnState = s.getOpenAICompatSessionTurnState(ctx, c, account, promptCacheKey) } // OAuth codex transform forces stream=true upstream, so always use // the streaming response handler regardless of what the client asked. @@ -163,7 +244,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( } // 6. Build upstream request - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, isStream, promptCacheKey, false) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, isStream, promptCacheKey, false) + releaseUpstreamCtx() if err != nil { return nil, fmt.Errorf("build upstream request: %w", err) } @@ -171,8 +254,25 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( // Override session_id with a deterministic UUID derived from the isolated // session key, ensuring different API keys produce different upstream sessions. if promptCacheKey != "" { - apiKeyID := getAPIKeyIDFromContext(c) - upstreamReq.Header.Set("session_id", generateSessionUUID(isolateOpenAISessionID(apiKeyID, promptCacheKey))) + isolatedSessionID := generateSessionUUID(isolateOpenAISessionID(apiKeyID, promptCacheKey)) + upstreamReq.Header.Set("session_id", isolatedSessionID) + if upstreamReq.Header.Get("conversation_id") != "" { + upstreamReq.Header.Set("conversation_id", isolatedSessionID) + } + } + if account.Type == AccountTypeOAuth { + // Anthropic Messages compatibility uses the ChatGPT Codex SSE endpoint. + // Match airgate-openai's request shape: the SSE endpoint does not need + // the Responses experimental beta header, and forcing originator can make + // ChatGPT select a different internal continuation path. + upstreamReq.Header.Del("OpenAI-Beta") + upstreamReq.Header.Del("originator") + } + if account.Type == AccountTypeOAuth && promptCacheKey != "" && strings.TrimSpace(c.GetHeader("conversation_id")) == "" { + upstreamReq.Header.Del("conversation_id") + } + if compatTurnState != "" && upstreamReq.Header.Get("x-codex-turn-state") == "" { + upstreamReq.Header.Set("x-codex-turn-state", compatTurnState) } // 7. Send request @@ -205,6 +305,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if previousResponseID != "" && (isOpenAICompatPreviousResponseNotFound(resp.StatusCode, upstreamMsg, respBody) || isOpenAICompatPreviousResponseUnsupported(resp.StatusCode, upstreamMsg, respBody)) { + if isOpenAICompatPreviousResponseUnsupported(resp.StatusCode, upstreamMsg, respBody) { + s.disableOpenAICompatSessionContinuation(ctx, c, account, promptCacheKey) + } else { + s.deleteOpenAICompatSessionResponseID(ctx, c, account, promptCacheKey) + } + logger.L().Info("openai messages: previous_response_id unavailable, retrying without continuation", + zap.Int64("account_id", account.ID), + zap.String("previous_response_id", truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen)), + zap.String("upstream_model", upstreamModel), + ) + return s.ForwardAsAnthropic(ctx, c, account, body, promptCacheKey, defaultMappedModel) + } if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { upstreamDetail := "" if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { @@ -237,6 +350,12 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( return s.handleAnthropicErrorResponse(resp, c, account) } + if account.Type == AccountTypeOAuth && promptCacheKey != "" { + if turnState := strings.TrimSpace(resp.Header.Get("x-codex-turn-state")); turnState != "" { + s.bindOpenAICompatSessionTurnState(ctx, c, account, promptCacheKey, turnState) + } + } + // 9. Handle normal response // Upstream is always streaming; choose response format based on client preference. var result *OpenAIForwardResult @@ -250,6 +369,12 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( // Propagate ServiceTier and ReasoningEffort to result for billing if handleErr == nil && result != nil { + if compatContinuationEnabled && promptCacheKey != "" && result.ResponseID != "" { + s.bindOpenAICompatSessionResponseID(ctx, c, account, promptCacheKey, result.ResponseID) + } + if promptCacheKey != "" && anthropicDigestChain != "" { + s.bindOpenAICompatAnthropicDigestPromptCacheKey(account, apiKeyID, anthropicDigestChain, promptCacheKey, anthropicMatchedDigestChain) + } if responsesReq.ServiceTier != "" { st := responsesReq.ServiceTier result.ServiceTier = &st @@ -270,6 +395,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( return result, handleErr } +func ensureCodexOAuthInstructionsField(reqBody map[string]any) { + if reqBody == nil { + return + } + if value, ok := reqBody["instructions"]; !ok || value == nil { + reqBody["instructions"] = "" + return + } + if _, ok := reqBody["instructions"].(string); !ok { + reqBody["instructions"] = "" + } +} + // handleAnthropicErrorResponse reads an upstream error and returns it in // Anthropic error format. func (s *OpenAIGatewayService) handleAnthropicErrorResponse( @@ -296,61 +434,9 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") - scanner := bufio.NewScanner(resp.Body) - maxLineSize := defaultMaxLineSize - if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { - maxLineSize = s.cfg.Gateway.MaxLineSize - } - scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) - - var finalResponse *apicompat.ResponsesResponse - var usage OpenAIUsage - acc := apicompat.NewBufferedResponseAccumulator() - - for scanner.Scan() { - line := scanner.Text() - - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { - continue - } - payload := line[6:] - - var event apicompat.ResponsesStreamEvent - if err := json.Unmarshal([]byte(payload), &event); err != nil { - logger.L().Warn("openai messages buffered: failed to parse event", - zap.Error(err), - zap.String("request_id", requestID), - ) - continue - } - - // Accumulate delta content for fallback when terminal output is empty. - acc.ProcessEvent(&event) - - // Terminal events carry the complete ResponsesResponse with output + usage. - if (event.Type == "response.completed" || event.Type == "response.done" || - event.Type == "response.incomplete" || event.Type == "response.failed") && - event.Response != nil { - finalResponse = event.Response - if event.Response.Usage != nil { - usage = OpenAIUsage{ - InputTokens: event.Response.Usage.InputTokens, - OutputTokens: event.Response.Usage.OutputTokens, - } - if event.Response.Usage.InputTokensDetails != nil { - usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens - } - } - } - } - - if err := scanner.Err(); err != nil { - if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { - logger.L().Warn("openai messages buffered: read error", - zap.Error(err), - zap.String("request_id", requestID), - ) - } + finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai messages buffered", requestID) + if err != nil { + return nil, err } if finalResponse == nil { @@ -371,6 +457,7 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( return &OpenAIForwardResult{ RequestID: requestID, + ResponseID: finalResponse.ID, Usage: usage, Model: originalModel, BillingModel: billingModel, @@ -380,6 +467,153 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( }, nil } +func isOpenAICompatResponsesTerminalEvent(eventType string) bool { + switch strings.TrimSpace(eventType) { + case "response.completed", "response.done", "response.incomplete", "response.failed": + return true + default: + return false + } +} + +func isOpenAICompatDoneSentinelLine(line string) bool { + payload, ok := extractOpenAISSEDataLine(line) + return ok && strings.TrimSpace(payload) == "[DONE]" +} + +func (s *OpenAIGatewayService) readOpenAICompatBufferedTerminal( + resp *http.Response, + logPrefix string, + requestID string, +) (*apicompat.ResponsesResponse, OpenAIUsage, *apicompat.BufferedResponseAccumulator, error) { + acc := apicompat.NewBufferedResponseAccumulator() + var usage OpenAIUsage + if resp == nil || resp.Body == nil { + return nil, usage, acc, errors.New("upstream response body is nil") + } + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var timeoutCh <-chan time.Time + var timeoutTimer *time.Timer + resetTimeout := func() { + if streamInterval <= 0 { + return + } + if timeoutTimer == nil { + timeoutTimer = time.NewTimer(streamInterval) + timeoutCh = timeoutTimer.C + return + } + if !timeoutTimer.Stop() { + select { + case <-timeoutTimer.C: + default: + } + } + timeoutTimer.Reset(streamInterval) + } + stopTimeout := func() { + if timeoutTimer == nil { + return + } + if !timeoutTimer.Stop() { + select { + case <-timeoutTimer.C: + default: + } + } + } + resetTimeout() + defer stopTimeout() + + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + go func() { + defer close(events) + for scanner.Scan() { + select { + case events <- scanEvent{line: scanner.Text()}: + case <-done: + return + } + } + if err := scanner.Err(); err != nil { + select { + case events <- scanEvent{err: err}: + case <-done: + } + } + }() + defer close(done) + + for { + select { + case ev, ok := <-events: + if !ok { + return nil, usage, acc, nil + } + resetTimeout() + if ev.err != nil { + if !errors.Is(ev.err, context.Canceled) && !errors.Is(ev.err, context.DeadlineExceeded) { + logger.L().Warn(logPrefix+": read error", + zap.Error(ev.err), + zap.String("request_id", requestID), + ) + } + return nil, usage, acc, ev.err + } + + if isOpenAICompatDoneSentinelLine(ev.line) { + return nil, usage, acc, nil + } + payload, ok := extractOpenAISSEDataLine(ev.line) + if !ok || payload == "" { + continue + } + + var event apicompat.ResponsesStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + logger.L().Warn(logPrefix+": failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + + acc.ProcessEvent(&event) + + if isOpenAICompatResponsesTerminalEvent(event.Type) && event.Response != nil { + if event.Response.Usage != nil { + usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage) + } + return event.Response, usage, acc, nil + } + + case <-timeoutCh: + _ = resp.Body.Close() + logger.L().Warn(logPrefix+": data interval timeout", + zap.String("request_id", requestID), + zap.Duration("interval", streamInterval), + ) + return nil, usage, acc, fmt.Errorf("stream data interval timeout") + } + } +} + // handleAnthropicStreamingResponse reads Responses SSE events from upstream, // converts each to Anthropic SSE events, and writes them to the client. // When StreamKeepaliveInterval is configured, it uses a goroutine + channel @@ -407,8 +641,10 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( state := apicompat.NewResponsesEventToAnthropicState() state.Model = originalModel var usage OpenAIUsage + responseID := "" var firstTokenMs *int firstChunk := true + clientDisconnected := false scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -417,10 +653,25 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( } scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + // resultWithUsage builds the final result snapshot. resultWithUsage := func() *OpenAIForwardResult { return &OpenAIForwardResult{ RequestID: requestID, + ResponseID: responseID, Usage: usage, Model: originalModel, BillingModel: billingModel, @@ -432,7 +683,6 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( } // processDataLine handles a single "data: ..." SSE line from upstream. - // Returns (clientDisconnected bool). processDataLine := func(payload string) bool { if firstChunk { firstChunk = false @@ -449,53 +699,63 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( return false } - // Extract usage from completion events - if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && - event.Response != nil && event.Response.Usage != nil { - usage = OpenAIUsage{ - InputTokens: event.Response.Usage.InputTokens, - OutputTokens: event.Response.Usage.OutputTokens, + // 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。 + isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type) + if isTerminalEvent && event.Response != nil { + if id := strings.TrimSpace(event.Response.ID); id != "" { + responseID = id } - if event.Response.Usage.InputTokensDetails != nil { - usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens + if event.Response.Usage != nil { + usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage) } } // Convert to Anthropic events events := apicompat.ResponsesEventToAnthropicEvents(&event, state) - for _, evt := range events { - sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) - if err != nil { - logger.L().Warn("openai messages stream: failed to marshal event", - zap.Error(err), - zap.String("request_id", requestID), - ) - continue - } - if _, err := fmt.Fprint(c.Writer, sse); err != nil { - logger.L().Info("openai messages stream: client disconnected", - zap.String("request_id", requestID), - ) - return true + if !clientDisconnected { + for _, evt := range events { + sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) + if err != nil { + logger.L().Warn("openai messages stream: failed to marshal event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + clientDisconnected = true + logger.L().Info("openai messages stream: client disconnected, continuing to drain upstream for billing", + zap.String("request_id", requestID), + ) + break + } } } - if len(events) > 0 { + if len(events) > 0 && !clientDisconnected { c.Writer.Flush() } - return false + return isTerminalEvent } // finalizeStream sends any remaining Anthropic events and returns the result. finalizeStream := func() (*OpenAIForwardResult, error) { - if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 { + if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 && !clientDisconnected { for _, evt := range finalEvents { sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) if err != nil { continue } - fmt.Fprint(c.Writer, sse) //nolint:errcheck + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + clientDisconnected = true + logger.L().Info("openai messages stream: client disconnected during final flush", + zap.String("request_id", requestID), + ) + break + } + } + if !clientDisconnected { + c.Writer.Flush() } - c.Writer.Flush() } return resultWithUsage(), nil } @@ -509,6 +769,9 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( ) } } + missingTerminalErr := func() (*OpenAIForwardResult, error) { + return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event") + } // ── Determine keepalive interval ── keepaliveInterval := time.Duration(0) @@ -517,18 +780,25 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( } // ── No keepalive: fast synchronous path (no goroutine overhead) ── - if keepaliveInterval <= 0 { + if streamInterval <= 0 && keepaliveInterval <= 0 { for scanner.Scan() { line := scanner.Text() - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + if isOpenAICompatDoneSentinelLine(line) { + return missingTerminalErr() + } + payload, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - if processDataLine(line[6:]) { - return resultWithUsage(), nil + if processDataLine(payload) { + return finalizeStream() } } - handleScanErr(scanner.Err()) - return finalizeStream() + if err := scanner.Err(); err != nil { + handleScanErr(err) + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err) + } + return missingTerminalErr() } // ── With keepalive: goroutine + channel + select ── @@ -538,6 +808,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( } events := make(chan scanEvent, 16) done := make(chan struct{}) + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) sendEvent := func(ev scanEvent) bool { select { case events <- ev: @@ -549,6 +821,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( go func() { defer close(events) for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) if !sendEvent(scanEvent{line: scanner.Text()}) { return } @@ -559,8 +832,15 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( }() defer close(done) - keepaliveTicker := time.NewTicker(keepaliveInterval) - defer keepaliveTicker.Stop() + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } lastDataAt := time.Now() for { @@ -568,22 +848,44 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( case ev, ok := <-events: if !ok { // Upstream closed - return finalizeStream() + return missingTerminalErr() } if ev.err != nil { handleScanErr(ev.err) - return finalizeStream() + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err) } lastDataAt = time.Now() line := ev.line - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + if isOpenAICompatDoneSentinelLine(line) { + return missingTerminalErr() + } + payload, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - if processDataLine(line[6:]) { - return resultWithUsage(), nil + if processDataLine(payload) { + return finalizeStream() } - case <-keepaliveTicker.C: + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout") + } + logger.L().Warn("openai messages stream: data interval timeout", + zap.String("request_id", requestID), + zap.String("model", originalModel), + zap.Duration("interval", streamInterval), + ) + return resultWithUsage(), fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if clientDisconnected { + continue + } if time.Since(lastDataAt) < keepaliveInterval { continue } @@ -593,7 +895,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( logger.L().Info("openai messages stream: client disconnected during keepalive", zap.String("request_id", requestID), ) - return resultWithUsage(), nil + clientDisconnected = true + continue } c.Writer.Flush() } @@ -610,3 +913,17 @@ func writeAnthropicError(c *gin.Context, statusCode int, errType, message string }, }) } + +func copyOpenAIUsageFromResponsesUsage(usage *apicompat.ResponsesUsage) OpenAIUsage { + if usage == nil { + return OpenAIUsage{} + } + result := OpenAIUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + } + if usage.InputTokensDetails != nil { + result.CacheReadInputTokens = usage.InputTokensDetails.CachedTokens + } + return result +} diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 47ff4e3b..3791c5a8 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -52,6 +52,12 @@ func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *Usage return &UsageBillingApplyResult{Applied: true}, nil } +func TestOpenAIGatewayServiceRecordUsage_RejectsNilInput(t *testing.T) { + svc := &OpenAIGatewayService{} + require.Error(t, svc.RecordUsage(context.Background(), nil)) + require.Error(t, svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{})) +} + type openAIRecordUsageUserRepoStub struct { UserRepository @@ -186,6 +192,56 @@ func max(a, b int) int { return b } +func TestOpenAIGatewayServiceRecordUsage_ZeroUsageStillWritesUsageLog(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_zero_usage", + Usage: OpenAIUsage{}, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1000, Quota: 100, Group: &Group{RateMultiplier: 1}}, + User: &User{ID: 2000}, + Account: &Account{ID: 3000, Type: AccountTypeAPIKey}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, billingRepo.calls) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) + require.Equal(t, 0, quotaSvc.quotaCalls) + require.Equal(t, 0, quotaSvc.rateLimitCalls) + + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "resp_zero_usage", usageRepo.lastLog.RequestID) + require.Zero(t, usageRepo.lastLog.InputTokens) + require.Zero(t, usageRepo.lastLog.OutputTokens) + require.Zero(t, usageRepo.lastLog.CacheCreationTokens) + require.Zero(t, usageRepo.lastLog.CacheReadTokens) + require.Zero(t, usageRepo.lastLog.ImageOutputTokens) + require.Zero(t, usageRepo.lastLog.ImageCount) + require.Zero(t, usageRepo.lastLog.InputCost) + require.Zero(t, usageRepo.lastLog.OutputCost) + require.Zero(t, usageRepo.lastLog.TotalCost) + require.Zero(t, usageRepo.lastLog.ActualCost) + + require.NotNil(t, billingRepo.lastCmd) + require.Zero(t, billingRepo.lastCmd.BalanceCost) + require.Zero(t, billingRepo.lastCmd.SubscriptionCost) + require.Zero(t, billingRepo.lastCmd.APIKeyQuotaCost) + require.Zero(t, billingRepo.lastCmd.APIKeyRateLimitCost) + require.Zero(t, billingRepo.lastCmd.AccountQuotaCost) +} + func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) { groupID := int64(11) groupRate := 1.4 @@ -956,9 +1012,8 @@ func TestOpenAIGatewayServiceRecordUsage_ChannelMappedDoesNotOverrideBillingMode svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10} - // When channel did NOT map the model (ChannelMappedModel == OriginalModel), - // billing should use result.BillingModel (the actual model used after group - // DefaultMappedModel resolution), not the unmapped original model. + // 渠道未发生模型映射时,应使用 result.BillingModel 中记录的实际上游计费模型, + // 而不是未映射的原始请求模型。 expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{ InputTokens: 20, OutputTokens: 10, @@ -1032,6 +1087,101 @@ func TestOpenAIGatewayServiceRecordUsage_ChannelMappedOverridesBillingModelWhenM require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero") } +func TestOpenAIGatewayServiceRecordUsage_BillsCompactOpenAIModelAlias(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10} + + expectedCost, err := svc.billingService.CalculateCost("gpt-5.5", UsageTokens{ + InputTokens: 20, + OutputTokens: 10, + }, 1.1) + require.NoError(t, err) + + err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_compact_openai_alias", + Model: "gpt5.5", + UpstreamModel: "gpt-5.4", + Usage: usage, + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10}, + User: &User{ID: 20}, + Account: &Account{ID: 30}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "gpt5.5", usageRepo.lastLog.Model) + require.NotNil(t, usageRepo.lastLog.UpstreamModel) + require.Equal(t, "gpt-5.4", *usageRepo.lastLog.UpstreamModel) + require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12) + require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero") + require.InDelta(t, expectedCost.ActualCost, userRepo.lastAmount, 1e-12) +} + +func TestOpenAIGatewayServiceRecordUsage_FallsBackToUpstreamModelWhenPrimaryUnpriceable(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10} + + expectedCost, err := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{ + InputTokens: 20, + OutputTokens: 10, + }, 1.1) + require.NoError(t, err) + + err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_unpriceable_primary_upstream_fallback", + Model: "not-priceable-alias", + BillingModel: "not-priceable-alias", + UpstreamModel: "gpt-5.4", + Usage: usage, + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10}, + User: &User{ID: 20}, + Account: &Account{ID: 30}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12) + require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero") + require.InDelta(t, expectedCost.ActualCost, userRepo.lastAmount, 1e-12) +} + +func TestOpenAIGatewayServiceRecordUsage_ReturnsErrorWhenTokenModelCannotBePriced(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_unpriceable_without_upstream", + Model: "not-priceable-alias", + Usage: OpenAIUsage{InputTokens: 20, OutputTokens: 10}, + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10}, + User: &User{ID: 20}, + Account: &Account{ID: 30}, + }) + + require.Error(t, err) + require.Contains(t, err.Error(), "calculate OpenAI usage cost failed") + require.Equal(t, 0, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} userRepo := &openAIRecordUsageUserRepoStub{} @@ -1160,3 +1310,278 @@ func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTo require.InDelta(t, 0.0, usageRepo.lastLog.OutputCost, 1e-12) require.InDelta(t, 0.0, usageRepo.lastLog.ImageOutputCost, 1e-12) } + +func TestOpenAIGatewayServiceRecordUsage_ImageSharedMultiplierPreservesExistingBehavior(t *testing.T) { + imagePrice := 0.2 + groupID := int64(121) + + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_image_shared_multiplier", + Model: "gpt-image-2", + ImageCount: 1, + ImageSize: "1K", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10121, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: 0.15, + ImageRateIndependent: false, + ImageRateMultiplier: 1, + ImagePrice1K: &imagePrice, + }, + }, + User: &User{ID: 20121}, + Account: &Account{ID: 30121}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.InDelta(t, 0.2, usageRepo.lastLog.TotalCost, 1e-12) + require.InDelta(t, 0.03, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, 0.15, usageRepo.lastLog.RateMultiplier, 1e-12) + require.NotNil(t, usageRepo.lastLog.BillingMode) + require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) +} + +func TestOpenAIGatewayServiceRecordUsage_ImageSharedMultiplierUsesUserGroupOverride(t *testing.T) { + imagePrice := 0.5 + userRate := 0.2 + groupID := int64(125) + + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + svc := newOpenAIRecordUsageServiceForTest( + usageRepo, + &openAIRecordUsageUserRepoStub{}, + &openAIRecordUsageSubRepoStub{}, + &openAIUserGroupRateRepoStub{rate: &userRate}, + ) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_image_user_group_override", + Model: "gpt-image-2", + ImageCount: 1, + ImageSize: "1K", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10125, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: 0.15, + ImageRateIndependent: false, + ImageRateMultiplier: 1, + ImagePrice1K: &imagePrice, + }, + }, + User: &User{ID: 20125}, + Account: &Account{ID: 30125}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.InDelta(t, 0.5, usageRepo.lastLog.TotalCost, 1e-12) + require.InDelta(t, 0.1, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, 0.2, usageRepo.lastLog.RateMultiplier, 1e-12) +} + +func TestOpenAIGatewayServiceRecordUsage_ImageIndependentMultiplierUsesImageRate(t *testing.T) { + imagePrice := 0.2 + groupID := int64(122) + + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_image_independent_multiplier", + Model: "gpt-image-2", + ImageCount: 1, + ImageSize: "1K", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10122, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: 0.15, + ImageRateIndependent: true, + ImageRateMultiplier: 1, + ImagePrice1K: &imagePrice, + }, + }, + User: &User{ID: 20122}, + Account: &Account{ID: 30122}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.InDelta(t, 0.2, usageRepo.lastLog.TotalCost, 1e-12) + require.InDelta(t, 0.2, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, 1.0, usageRepo.lastLog.RateMultiplier, 1e-12) + require.NotNil(t, usageRepo.lastLog.BillingMode) + require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) +} + +func TestOpenAIGatewayServiceRecordUsage_ChannelImageBillingUsesImageCountAndSharedMultiplier(t *testing.T) { + groupID := int64(123) + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil) + svc.resolver = newOpenAIImageChannelPricingResolverForTest(t, groupID, "gpt-image-2", 0.25) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_image_channel_shared", + Model: "gpt-image-2", + ImageCount: 3, + ImageSize: "1K", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10123, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: 0.15, + ImageRateIndependent: false, + ImageRateMultiplier: 1, + }, + }, + User: &User{ID: 20123}, + Account: &Account{ID: 30123}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.InDelta(t, 0.75, usageRepo.lastLog.TotalCost, 1e-12) + require.InDelta(t, 0.1125, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, 0.15, usageRepo.lastLog.RateMultiplier, 1e-12) + require.Equal(t, 3, usageRepo.lastLog.ImageCount) + require.NotNil(t, usageRepo.lastLog.BillingMode) + require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) +} + +func TestOpenAIGatewayServiceRecordUsage_ChannelImageBillingUsesImageCountAndIndependentMultiplier(t *testing.T) { + groupID := int64(124) + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil) + svc.resolver = newOpenAIImageChannelPricingResolverForTest(t, groupID, "gpt-image-2", 0.25) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_image_channel_independent", + Model: "gpt-image-2", + ImageCount: 3, + ImageSize: "1K", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10124, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: 0.15, + ImageRateIndependent: true, + ImageRateMultiplier: 1, + }, + }, + User: &User{ID: 20124}, + Account: &Account{ID: 30124}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.InDelta(t, 0.75, usageRepo.lastLog.TotalCost, 1e-12) + require.InDelta(t, 0.75, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, 1.0, usageRepo.lastLog.RateMultiplier, 1e-12) + require.Equal(t, 3, usageRepo.lastLog.ImageCount) + require.NotNil(t, usageRepo.lastLog.BillingMode) + require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) +} + +func newOpenAIImageChannelPricingResolverForTest(t *testing.T, groupID int64, model string, price float64) *ModelPricingResolver { + t.Helper() + cache := newEmptyChannelCache() + cache.pricingByGroupModel[channelModelKey{groupID: groupID, model: model}] = &ChannelModelPricing{ + BillingMode: BillingModeImage, + PerRequestPrice: &price, + } + cache.channelByGroupID[groupID] = &Channel{ID: groupID, Status: StatusActive} + cache.groupPlatform[groupID] = "" + cache.loadedAt = time.Now() + cs := &ChannelService{} + cs.cache.Store(cache) + return NewModelPricingResolver(cs, NewBillingService(&config.Config{}, nil)) +} + +func TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesImageCount(t *testing.T) { + groupID := int64(126) + billingService := NewBillingService(&config.Config{}, nil) + svc := &GatewayService{ + billingService: billingService, + resolver: newOpenAIImageChannelPricingResolverForTest(t, groupID, "gemini-image", 0.25), + } + + cost := svc.calculateRecordUsageCost( + context.Background(), + &ForwardResult{Model: "gemini-image", ImageCount: 2, ImageSize: "1K"}, + &APIKey{GroupID: i64p(groupID), Group: &Group{ID: groupID}}, + "gemini-image", + 0.15, + 1.0, + nil, + ) + + require.NotNil(t, cost) + require.Equal(t, string(BillingModeImage), cost.BillingMode) + require.InDelta(t, 0.5, cost.TotalCost, 1e-12) + require.InDelta(t, 0.5, cost.ActualCost, 1e-12) +} + +func TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesSizeTier(t *testing.T) { + groupID := int64(127) + defaultPrice := 0.10 + price4K := 0.40 + cache := newEmptyChannelCache() + cache.pricingByGroupModel[channelModelKey{groupID: groupID, model: "gemini-image"}] = &ChannelModelPricing{ + BillingMode: BillingModeImage, + PerRequestPrice: &defaultPrice, + Intervals: []PricingInterval{{ + TierLabel: "4K", + PerRequestPrice: &price4K, + }}, + } + cache.channelByGroupID[groupID] = &Channel{ID: groupID, Status: StatusActive} + cache.loadedAt = time.Now() + channelService := &ChannelService{} + channelService.cache.Store(cache) + + svc := &GatewayService{ + billingService: NewBillingService(&config.Config{}, nil), + resolver: NewModelPricingResolver(channelService, NewBillingService(&config.Config{}, nil)), + } + + cost := svc.calculateRecordUsageCost( + context.Background(), + &ForwardResult{Model: "gemini-image", ImageCount: 2, ImageSize: "4K"}, + &APIKey{GroupID: i64p(groupID), Group: &Group{ID: groupID}}, + "gemini-image", + 1.0, + 1.0, + nil, + ) + + require.NotNil(t, cost) + require.Equal(t, string(BillingModeImage), cost.BillingMode) + require.InDelta(t, 0.80, cost.TotalCost, 1e-12) + require.InDelta(t, 0.80, cost.ActualCost, 1e-12) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index ed69730c..a5fe707d 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -211,9 +211,10 @@ type OpenAIUsage struct { // OpenAIForwardResult represents the result of forwarding type OpenAIForwardResult struct { - RequestID string - Usage OpenAIUsage - Model string // 原始模型(用于响应和日志显示) + RequestID string + ResponseID string + Usage OpenAIUsage + Model string // 原始模型(用于响应和日志显示) // BillingModel is the model used for cost calculation. // When non-empty, CalculateCost uses this instead of Model. // This is set by the Anthropic Messages conversion path where @@ -346,10 +347,12 @@ type OpenAIGatewayService struct { openaiWSPassthroughDialer openAIWSClientDialer openaiAccountStats *openAIAccountRuntimeStats - openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time - openaiWSRetryMetrics openAIWSRetryMetrics - responseHeaderFilter *responseheaders.CompiledHeaderFilter - codexSnapshotThrottle *accountWriteThrottle + openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time + openaiWSRetryMetrics openAIWSRetryMetrics + responseHeaderFilter *responseheaders.CompiledHeaderFilter + codexSnapshotThrottle *accountWriteThrottle + openaiCompatSessionResponses sync.Map + openaiCompatAnthropicDigestSessions sync.Map } // NewOpenAIGatewayService creates a new OpenAIGatewayService @@ -1992,6 +1995,8 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco originalBody := body reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) originalModel := reqModel + compatMessagesBridge := isOpenAICompatMessagesBridgeBody(body) + setOpenAICompatMessagesBridgeContext(c, compatMessagesBridge) isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) @@ -2049,6 +2054,21 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco promptCacheKey = strings.TrimSpace(v) } } + apiKey := getAPIKeyFromContext(c) + imageGenerationAllowed := GroupAllowsImageGeneration(nil) + if apiKey != nil { + imageGenerationAllowed = GroupAllowsImageGeneration(apiKey.Group) + } + if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed { + setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "") + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "permission_error", + "message": ImageGenerationPermissionMessage(), + }, + }) + return nil, errors.New("image generation disabled for group") + } // Track if body needs re-serialization bodyModified := false @@ -2102,13 +2122,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } // 非透传模式下,instructions 为空时注入默认指令。 - if isInstructionsEmpty(reqBody) { + if isInstructionsEmpty(reqBody) && !compatMessagesBridge { reqBody["instructions"] = "You are a helpful coding assistant." bodyModified = true markPatchSet("instructions", "You are a helpful coding assistant.") } - if isCodexCLI && ensureOpenAIResponsesImageGenerationTool(reqBody) { + if isCodexCLI && imageGenerationAllowed && ensureOpenAIResponsesImageGenerationTool(reqBody) { bodyModified = true disablePatch() logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Injected /responses image_generation tool for Codex client") @@ -2119,7 +2139,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco disablePatch() logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload") } - if isCodexCLI && applyCodexImageGenerationBridgeInstructions(reqBody) { + if isCodexCLI && imageGenerationAllowed && applyCodexImageGenerationBridgeInstructions(reqBody) { bodyModified = true disablePatch() logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Added Codex image_generation bridge instructions") @@ -2134,7 +2154,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco markPatchSet("model", billingModel) } upstreamModel := billingModel - if normalizeOpenAIResponsesImageOnlyModel(reqBody) { + if imageGenerationAllowed && normalizeOpenAIResponsesImageOnlyModel(reqBody) { bodyModified = true disablePatch() if model, ok := reqBody["model"].(string); ok { @@ -2231,7 +2251,20 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } if account.Type == AccountTypeOAuth { - codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI, isCompactRequest) + codexResult := codexTransformResult{} + if compatMessagesBridge { + codexResult = applyCodexOAuthTransformWithOptions(reqBody, codexOAuthTransformOptions{ + IsCodexCLI: isCodexCLI, + IsCompact: isCompactRequest, + SkipDefaultInstructions: true, + PreserveToolCallIDs: true, + }) + ensureCodexOAuthInstructionsField(reqBody) + bodyModified = true + disablePatch() + } else { + codexResult = applyCodexOAuthTransform(reqBody, isCodexCLI, isCompactRequest) + } if codexResult.Modified { bodyModified = true disablePatch() @@ -2355,6 +2388,34 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } } + if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed { + setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "") + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "permission_error", + "message": ImageGenerationPermissionMessage(), + }, + }) + return nil, errors.New("image generation disabled for group") + } + imageBillingModel := "" + imageSizeTier := "" + if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) { + var imageCfgErr error + imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfig(reqBody, billingModel) + if imageCfgErr != nil { + setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "") + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": imageCfgErr.Error(), + "param": "size", + }, + }) + return nil, imageCfgErr + } + } + // Re-serialize body only if modified if bodyModified { serializedByPatch := false @@ -2592,6 +2653,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco wsAttempts, ) wsResult.UpstreamModel = upstreamModel + if wsResult.ImageCount > 0 { + wsResult.ImageSize = imageSizeTier + wsResult.BillingModel = imageBillingModel + } return wsResult, nil } s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) @@ -2601,7 +2666,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco httpInvalidEncryptedContentRetryTried := false for { // Build upstream request - upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) releaseUpstreamCtx() if err != nil { @@ -2695,6 +2760,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco // Handle normal response var usage *OpenAIUsage var firstTokenMs *int + imageCount := 0 if reqStream { streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel) if err != nil { @@ -2702,11 +2768,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } usage = streamResult.usage firstTokenMs = streamResult.firstTokenMs + imageCount = streamResult.imageCount } else { - usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel) + nonStreamResult, err := s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel) if err != nil { return nil, err } + usage = nonStreamResult.usage + imageCount = nonStreamResult.imageCount } // Extract and save Codex usage snapshot from response headers (for OAuth accounts) @@ -2723,7 +2792,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel) serviceTier := extractOpenAIServiceTier(reqBody) - return &OpenAIForwardResult{ + forwardResult := &OpenAIForwardResult{ RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, @@ -2734,7 +2803,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco OpenAIWSMode: false, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, - }, nil + } + if imageCount > 0 { + forwardResult.ImageCount = imageCount + forwardResult.ImageSize = imageSizeTier + forwardResult.BillingModel = imageBillingModel + } + return forwardResult, nil } } @@ -2823,6 +2898,35 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( } body = updatedBody + apiKey := getAPIKeyFromContext(c) + if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) && !GroupAllowsImageGeneration(apiKeyGroup(apiKey)) { + setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "") + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "permission_error", + "message": ImageGenerationPermissionMessage(), + }, + }) + return nil, errors.New("image generation disabled for group") + } + imageBillingModel := "" + imageSizeTier := "" + if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) { + var imageCfgErr error + imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(body, reqModel) + if imageCfgErr != nil { + setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "") + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": imageCfgErr.Error(), + "param": "size", + }, + }) + return nil, imageCfgErr + } + } + logger.LegacyPrintf("service.openai_gateway", "[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v", account.ID, @@ -2852,7 +2956,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( return nil, err } - upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token) releaseUpstreamCtx() if err != nil { @@ -2905,6 +3009,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( var usage *OpenAIUsage var firstTokenMs *int + imageCount := 0 if reqStream { result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime, reqModel, upstreamPassthroughModel) if err != nil { @@ -2912,11 +3017,14 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( } usage = result.usage firstTokenMs = result.firstTokenMs + imageCount = result.imageCount } else { - usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel) + result, err := s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel) if err != nil { return nil, err } + usage = result.usage + imageCount = result.imageCount } if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { @@ -2927,7 +3035,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( usage = &OpenAIUsage{} } - return &OpenAIForwardResult{ + forwardResult := &OpenAIForwardResult{ RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: reqModel, @@ -2938,7 +3046,13 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( OpenAIWSMode: false, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, - }, nil + } + if imageCount > 0 { + forwardResult.ImageCount = imageCount + forwardResult.ImageSize = imageSizeTier + forwardResult.BillingModel = imageBillingModel + } + return forwardResult, nil } func logOpenAIPassthroughInstructionsRejected( @@ -3233,6 +3347,13 @@ func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string { type openaiStreamingResultPassthrough struct { usage *OpenAIUsage firstTokenMs *int + imageCount int +} + +type openaiNonStreamingResultPassthrough struct { + *OpenAIUsage + usage *OpenAIUsage + imageCount int } func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool { @@ -3369,6 +3490,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( } usage := &OpenAIUsage{} + imageCounter := newOpenAIImageOutputCounter() var firstTokenMs *int clientDisconnected := false sawDone := false @@ -3400,6 +3522,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( defer putSSEScannerBuf64K(scanBuf) needModelReplace := strings.TrimSpace(originalModel) != "" && strings.TrimSpace(mappedModel) != "" && strings.TrimSpace(originalModel) != strings.TrimSpace(mappedModel) + resultWithUsage := func() *openaiStreamingResultPassthrough { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()} + } for scanner.Scan() { line := scanner.Text() @@ -3419,7 +3544,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( if eventType == "response.failed" { failedMessage = extractOpenAISSEErrorMessage(dataBytes) if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) { - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, + return resultWithUsage(), s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, dataBytes, failedMessage) } forceFlushFailedEvent = true @@ -3431,6 +3556,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( if openAIStreamEventIsTerminal(trimmedData) { sawTerminalEvent = true } + imageCounter.AddSSEData(dataBytes) lineStartsClientOutput = forceFlushFailedEvent || openAIStreamDataStartsClientOutput(trimmedData, eventType) if firstTokenMs == nil && lineStartsClientOutput && trimmedData != "[DONE]" { ms := int(time.Since(startTime).Milliseconds()) @@ -3460,28 +3586,28 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( } if err := scanner.Err(); err != nil { if sawTerminalEvent && !sawFailedEvent { - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil + return resultWithUsage(), nil } if sawFailedEvent { - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage) + return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage) } if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err) + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err) } if errors.Is(err, bufio.ErrTooLong) { logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err) - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err + return resultWithUsage(), err } if !openAIStreamClientOutputStarted(c, clientOutputStarted) { msg := "OpenAI stream disconnected before completion" if errText := strings.TrimSpace(err.Error()); errText != "" { msg += ": " + errText } - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, + return resultWithUsage(), s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, msg) } if clientDisconnected { - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err) + return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", err) } logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v", @@ -3489,10 +3615,10 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( upstreamRequestID, err, ) - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) + return resultWithUsage(), fmt.Errorf("stream read error: %w", err) } if sawFailedEvent { - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage) + return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage) } if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil { logger.FromContext(ctx).With( @@ -3501,13 +3627,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( zap.String("upstream_request_id", upstreamRequestID), ).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流") if !openAIStreamClientOutputStarted(c, clientOutputStarted) { - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, + return resultWithUsage(), s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, "OpenAI stream ended before a terminal event") } - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event") + return resultWithUsage(), errors.New("stream usage incomplete: missing terminal event") } - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil + return resultWithUsage(), nil } func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( @@ -3516,7 +3642,7 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( c *gin.Context, originalModel string, mappedModel string, -) (*OpenAIUsage, error) { +) (*openaiNonStreamingResultPassthrough, error) { body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) if err != nil { return nil, err @@ -3553,14 +3679,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } c.Data(resp.StatusCode, contentType, body) - return usage, nil + return &openaiNonStreamingResultPassthrough{ + OpenAIUsage: usage, + usage: usage, + imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body), + }, nil } // handlePassthroughSSEToJSON converts an SSE response body into a JSON // response for the passthrough path. It mirrors handleSSEToJSON while // preserving passthrough payloads, except compact-only model remapping may // rewrite model fields back to the original requested model. -func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*OpenAIUsage, error) { +func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*openaiNonStreamingResultPassthrough, error) { bodyText := string(body) finalResponse, ok := extractCodexFinalResponse(bodyText) @@ -3611,7 +3741,11 @@ func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c } c.Data(resp.StatusCode, contentType, body) - return usage, nil + return &openaiNonStreamingResultPassthrough{ + OpenAIUsage: usage, + usage: usage, + imageCount: countOpenAIImageOutputsFromSSEBody(bodyText), + }, nil } func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) { @@ -3715,12 +3849,19 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. } } if account.Type == AccountTypeOAuth { + compatMessagesBridge := isOpenAICompatMessagesBridgeContext(c) || isOpenAICompatMessagesBridgeBody(body) // 清除客户端透传的 session 头,后续用隔离后的值重新设置,防止跨用户会话碰撞。 + clientConversationID := strings.TrimSpace(req.Header.Get("conversation_id")) req.Header.Del("conversation_id") req.Header.Del("session_id") - req.Header.Set("OpenAI-Beta", "responses=experimental") - req.Header.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI)) + if compatMessagesBridge { + req.Header.Del("OpenAI-Beta") + req.Header.Del("originator") + } else { + req.Header.Set("OpenAI-Beta", "responses=experimental") + req.Header.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI)) + } apiKeyID := getAPIKeyIDFromContext(c) if isOpenAIResponsesCompactPath(c) { req.Header.Set("accept", "application/json") @@ -3734,8 +3875,10 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. } if promptCacheKey != "" { isolated := isolateOpenAISessionID(apiKeyID, promptCacheKey) - req.Header.Set("conversation_id", isolated) req.Header.Set("session_id", isolated) + if !compatMessagesBridge || clientConversationID != "" { + req.Header.Set("conversation_id", isolated) + } } } @@ -4025,6 +4168,13 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse( type openaiStreamingResult struct { usage *OpenAIUsage firstTokenMs *int + imageCount int +} + +type openaiNonStreamingResult struct { + *OpenAIUsage + usage *OpenAIUsage + imageCount int } func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) { @@ -4058,6 +4208,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } usage := &OpenAIUsage{} + imageCounter := newOpenAIImageOutputCounter() var firstTokenMs *int scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -4136,7 +4287,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp needModelReplace := originalModel != mappedModel resultWithUsage := func() *openaiStreamingResult { - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs} + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()} } finalizeStream := func() (*openaiStreamingResult, error) { if !sawTerminalEvent { @@ -4231,6 +4382,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp forceFlushFailedEvent = true sawFailedEvent = true } + imageCounter.AddSSEData(dataBytes) // Correct Codex tool calls if needed (apply_patch -> edit, etc.) if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected { @@ -4496,7 +4648,7 @@ func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { }, true } -func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) { +func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*openaiNonStreamingResult, error) { body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) if err != nil { return nil, err @@ -4542,7 +4694,11 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r c.Data(resp.StatusCode, contentType, body) - return usage, nil + return &openaiNonStreamingResult{ + OpenAIUsage: usage, + usage: usage, + imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body), + }, nil } func isEventStreamResponse(header http.Header) bool { @@ -4550,7 +4706,7 @@ func isEventStreamResponse(header http.Header) bool { return strings.Contains(contentType, "text/event-stream") } -func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) { +func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*openaiNonStreamingResult, error) { bodyText := string(body) finalResponse, ok := extractCodexFinalResponse(bodyText) @@ -4602,21 +4758,29 @@ func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Conte } c.Data(resp.StatusCode, contentType, body) - return usage, nil + return &openaiNonStreamingResult{ + OpenAIUsage: usage, + usage: usage, + imageCount: countOpenAIImageOutputsFromSSEBody(bodyText), + }, nil } func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) { - lines := strings.Split(body, "\n") - for _, line := range lines { - data, ok := extractOpenAISSEDataLine(line) - if !ok || data == "" || data == "[DONE]" { - continue + var terminalType string + var terminalPayload []byte + forEachOpenAISSEDataPayload(body, func(data []byte) { + if terminalPayload != nil { + return } - eventType := strings.TrimSpace(gjson.Get(data, "type").String()) + eventType := strings.TrimSpace(gjson.GetBytes(data, "type").String()) switch eventType { case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": - return eventType, []byte(data), true + terminalType = eventType + terminalPayload = append([]byte(nil), data...) } + }) + if terminalPayload != nil { + return terminalType, terminalPayload, true } return "", nil, false } @@ -4651,21 +4815,20 @@ func (s *OpenAIGatewayService) writeOpenAINonStreamingProtocolError(resp *http.R } func extractCodexFinalResponse(body string) ([]byte, bool) { - lines := strings.Split(body, "\n") - for _, line := range lines { - data, ok := extractOpenAISSEDataLine(line) - if !ok { - continue + var finalResponse []byte + forEachOpenAISSEDataPayload(body, func(data []byte) { + if finalResponse != nil { + return } - if data == "" || data == "[DONE]" { - continue - } - eventType := gjson.Get(data, "type").String() + eventType := gjson.GetBytes(data, "type").String() if eventType == "response.done" || eventType == "response.completed" { - if response := gjson.Get(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" { - return []byte(response.Raw), true + if response := gjson.GetBytes(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" { + finalResponse = []byte(response.Raw) } } + }) + if finalResponse != nil { + return finalResponse, true } return nil, false } @@ -4677,21 +4840,15 @@ func reconstructResponseOutputFromSSE(bodyText string) ([]byte, bool) { acc := apicompat.NewBufferedResponseAccumulator() imageOutputs := make([]json.RawMessage, 0, 1) seenImages := make(map[string]struct{}) - lines := strings.Split(bodyText, "\n") - for _, line := range lines { - data, ok := extractOpenAISSEDataLine(line) - if !ok || data == "" || data == "[DONE]" { - continue - } - if imageOutput, ok := extractImageGenerationOutputFromSSEData([]byte(data), seenImages); ok { + forEachOpenAISSEDataPayload(bodyText, func(data []byte) { + if imageOutput, ok := extractImageGenerationOutputFromSSEData(data, seenImages); ok { imageOutputs = append(imageOutputs, imageOutput) } var event apicompat.ResponsesStreamEvent - if err := json.Unmarshal([]byte(data), &event); err != nil { - continue + if err := json.Unmarshal(data, &event); err == nil { + acc.ProcessEvent(&event) } - acc.ProcessEvent(&event) - } + }) if !acc.HasContent() && len(imageOutputs) == 0 { return nil, false } @@ -4744,17 +4901,9 @@ func extractImageGenerationOutputFromSSEData(data []byte, seen map[string]struct func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { usage := &OpenAIUsage{} - lines := strings.Split(body, "\n") - for _, line := range lines { - data, ok := extractOpenAISSEDataLine(line) - if !ok { - continue - } - if data == "" || data == "[DONE]" { - continue - } - s.parseSSEUsageBytes([]byte(data), usage) - } + forEachOpenAISSEDataPayload(body, func(data []byte) { + s.parseSSEUsageBytes(data, usage) + }) return usage } @@ -5036,16 +5185,15 @@ type OpenAIRecordUsageInput struct { // RecordUsage records usage and deducts balance func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error { - result := input.Result - if s.rateLimitService != nil && input != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI { - s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID) + if input == nil { + return errors.New("openai usage input is nil") } - - // 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库 - if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 && - result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 && - result.Usage.ImageOutputTokens == 0 && result.ImageCount == 0 { - return nil + result := input.Result + if result == nil { + return errors.New("openai usage result is nil") + } + if s.rateLimitService != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI { + s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID) } apiKey := input.APIKey @@ -5081,6 +5229,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec } multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier) } + imageMultiplier := resolveImageRateMultiplier(apiKey, multiplier) var cost *CostBreakdown var err error @@ -5094,13 +5243,21 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" { billingModel = input.OriginalModel } + billingModels := usageBillingModelCandidates( + billingModel, + result.BillingModel, + input.ChannelMappedModel, + input.OriginalModel, + result.UpstreamModel, + result.Model, + ) serviceTier := "" if result.ServiceTier != nil { serviceTier = strings.TrimSpace(*result.ServiceTier) } - cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, tokens, serviceTier) + cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModels, multiplier, imageMultiplier, tokens, serviceTier) if err != nil { - cost = &CostBreakdown{ActualCost: 0} + return err } // Determine billing type @@ -5150,7 +5307,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec usageLog.TotalCost = cost.TotalCost usageLog.ActualCost = cost.ActualCost } - usageLog.RateMultiplier = multiplier + if result.ImageCount > 0 { + usageLog.RateMultiplier = imageMultiplier + } else { + usageLog.RateMultiplier = multiplier + } usageLog.AccountRateMultiplier = &accountRateMultiplier usageLog.BillingType = billingType usageLog.Stream = result.Stream @@ -5231,14 +5392,45 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost( ctx context.Context, result *OpenAIForwardResult, apiKey *APIKey, + billingModels []string, + multiplier float64, + imageMultiplier float64, + tokens UsageTokens, + serviceTier string, +) (*CostBreakdown, error) { + billingModel := firstUsageBillingModel(billingModels) + if result != nil && result.ImageCount > 0 { + return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, imageMultiplier), nil + } + if len(billingModels) == 0 || billingModel == "" { + return nil, errors.New("openai usage billing model is empty") + } + var lastErr error + for _, candidate := range billingModels { + candidate = strings.TrimSpace(candidate) + if candidate == "" { + continue + } + cost, err := s.calculateOpenAIRecordUsageTokenCost(ctx, apiKey, candidate, multiplier, tokens, serviceTier) + if err == nil { + return cost, nil + } + lastErr = err + } + if lastErr == nil { + lastErr = errors.New("no non-empty billing model candidates") + } + return nil, fmt.Errorf("calculate OpenAI usage cost failed for billing models %s: %w", strings.Join(billingModels, ","), lastErr) +} + +func (s *OpenAIGatewayService) calculateOpenAIRecordUsageTokenCost( + ctx context.Context, + apiKey *APIKey, billingModel string, multiplier float64, tokens UsageTokens, serviceTier string, ) (*CostBreakdown, error) { - if result != nil && result.ImageCount > 0 { - return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil - } if s.resolver != nil && apiKey.Group != nil { gid := apiKey.Group.ID return s.billingService.CalculateCostUnified(CostInput{ @@ -5269,7 +5461,7 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost( Ctx: ctx, Model: billingModel, GroupID: &gid, - RequestCount: 1, + RequestCount: result.ImageCount, SizeTier: result.ImageSize, RateMultiplier: multiplier, Resolver: s.resolver, diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 5d1c6fc6..24095f2b 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -1846,6 +1846,29 @@ func TestOpenAIBuildUpstreamRequestCompactForcesJSONAcceptForOAuth(t *testing.T) require.NotEmpty(t, req.Header.Get("Session_Id")) } +func TestOpenAIBuildUpstreamRequestOAuthMessagesBridgeUsesSessionOnly(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.5","prompt_cache_key":"anthropic-metadata-session-1","input":[{"type":"message","role":"developer","content":[{"type":"input_text","text":""}]},{"type":"message","role":"user","content":"hello"}]}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body)) + c.Request.Header.Set("OpenAI-Beta", "responses=experimental") + c.Request.Header.Set("originator", "codex_cli_rs") + + svc := &OpenAIGatewayService{} + account := &Account{ + Type: AccountTypeOAuth, + Credentials: map[string]any{"chatgpt_account_id": "chatgpt-acc"}, + } + + req, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, body, "token", true, "anthropic-metadata-session-1", false) + require.NoError(t, err) + require.NotEmpty(t, req.Header.Get("Session_Id")) + require.Empty(t, req.Header.Get("Conversation_Id")) + require.Empty(t, req.Header.Get("OpenAI-Beta")) + require.Empty(t, req.Header.Get("originator")) +} + func TestOpenAIBuildUpstreamRequestPreservesCompactPathForAPIKeyBaseURL(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() diff --git a/backend/internal/service/openai_image_generation_controls_test.go b/backend/internal/service/openai_image_generation_controls_test.go new file mode 100644 index 00000000..76dc8053 --- /dev/null +++ b/backend/internal/service/openai_image_generation_controls_test.go @@ -0,0 +1,215 @@ +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestOpenAIGatewayServiceForward_RejectsDisabledImageGenerationIntents(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + body []byte + }{ + { + name: "image model", + body: []byte(`{"model":"gpt-image-2","input":"draw"}`), + }, + { + name: "image tool", + body: []byte(`{"model":"gpt-5.4","input":"draw","tools":[{"type":"image_generation"}]}`), + }, + { + name: "image tool choice", + body: []byte(`{"model":"gpt-5.4","input":"draw","tool_choice":{"type":"image_generation"}}`), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + upstream := &httpUpstreamRecorder{} + svc := newOpenAIImageGenerationControlTestService(upstream) + c, recorder := newOpenAIImageGenerationControlTestContext(false, "unit-test-agent/1.0") + account := newOpenAIImageGenerationControlTestAccount() + + result, err := svc.Forward(context.Background(), c, account, tt.body) + + require.Error(t, err) + require.Nil(t, result) + require.Equal(t, http.StatusForbidden, recorder.Code) + require.Equal(t, "permission_error", gjson.GetBytes(recorder.Body.Bytes(), "error.type").String()) + require.Nil(t, upstream.lastReq, "disabled image request must not reach upstream") + }) + } +} + +func TestOpenAIGatewayServiceForward_DisabledGroupAllowsTextOnlyResponses(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_text","model":"gpt-5.4","usage":{"input_tokens":3,"output_tokens":2}}`)), + }, + } + svc := newOpenAIImageGenerationControlTestService(upstream) + c, recorder := newOpenAIImageGenerationControlTestContext(false, "unit-test-agent/1.0") + account := newOpenAIImageGenerationControlTestAccount() + + result, err := svc.Forward(context.Background(), c, account, []byte(`{"model":"gpt-5.4","input":"write code","stream":false}`)) + + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, 3, result.Usage.InputTokens) + require.Equal(t, 2, result.Usage.OutputTokens) + require.Equal(t, 0, result.ImageCount) + require.NotNil(t, upstream.lastReq) +} + +func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + allowImages bool + wantInjected bool + }{ + {name: "disabled group skips injection", allowImages: false, wantInjected: false}, + {name: "enabled group injects image tool", allowImages: true, wantInjected: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_codex","model":"gpt-5.4","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + svc := newOpenAIImageGenerationControlTestService(upstream) + c, _ := newOpenAIImageGenerationControlTestContext(tt.allowImages, "codex_cli_rs/0.98.0") + account := newOpenAIImageGenerationControlTestAccount() + + result, err := svc.Forward(context.Background(), c, account, []byte(`{"model":"gpt-5.4","input":"write code","stream":false}`)) + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + hasImageTool := gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation")`).Exists() + require.Equal(t, tt.wantInjected, hasImageTool) + instructions := gjson.GetBytes(upstream.lastBody, "instructions").String() + require.Equal(t, tt.wantInjected, strings.Contains(instructions, "image_generation")) + }) + } +} + +func TestOpenAIGatewayServiceHandleResponsesImageOutputs_NonStreaming(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{}) + c, _ := newOpenAIImageGenerationControlTestContext(true, "unit-test-agent/1.0") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{ + "id":"resp_image_json", + "model":"gpt-5.4", + "output":[{"id":"ig_json_1","type":"image_generation_call","result":"final-image"}], + "usage":{"input_tokens":7,"output_tokens":3,"output_tokens_details":{"image_tokens":2}} + }`)), + } + + result, err := svc.handleNonStreamingResponse(context.Background(), resp, c, &Account{ID: 1, Type: AccountTypeAPIKey}, "gpt-5.4", "gpt-5.4") + + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 1, result.imageCount) + require.NotNil(t, result.usage) + require.Equal(t, 7, result.usage.InputTokens) + require.Equal(t, 3, result.usage.OutputTokens) + require.Equal(t, 2, result.usage.ImageOutputTokens) +} + +func TestOpenAIGatewayServiceHandleResponsesImageOutputs_Streaming(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{}) + c, _ := newOpenAIImageGenerationControlTestContext(true, "unit-test-agent/1.0") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_stream_1\",\"type\":\"image_generation_call\",\"result\":\"final-image\"}}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_image_stream\",\"model\":\"gpt-5.5\",\"output\":[{\"id\":\"ig_stream_1\",\"type\":\"image_generation_call\",\"result\":\"final-image\"}],\"usage\":{\"input_tokens\":11,\"output_tokens\":5,\"output_tokens_details\":{\"image_tokens\":4}}}}\n\n", + )), + } + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "gpt-5.5", "gpt-5.5") + + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 1, result.imageCount) + require.NotNil(t, result.usage) + require.Equal(t, 11, result.usage.InputTokens) + require.Equal(t, 5, result.usage.OutputTokens) + require.Equal(t, 4, result.usage.ImageOutputTokens) +} + +func newOpenAIImageGenerationControlTestService(upstream *httpUpstreamRecorder) *OpenAIGatewayService { + cfg := &config.Config{} + return &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } +} + +func newOpenAIImageGenerationControlTestContext(allowImages bool, userAgent string) (*gin.Context, *httptest.ResponseRecorder) { + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", userAgent) + groupID := int64(4242) + c.Set("api_key", &APIKey{ + ID: 2424, + GroupID: &groupID, + Group: &Group{ + ID: groupID, + AllowImageGeneration: allowImages, + RateMultiplier: 1, + ImageRateMultiplier: 1, + }, + }) + return c, recorder +} + +func newOpenAIImageGenerationControlTestAccount() *Account { + return &Account{ + ID: 5151, + Name: "openai-image-controls", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + } +} diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go index 4badcb1c..04be5164 100644 --- a/backend/internal/service/openai_images.go +++ b/backend/internal/service/openai_images.go @@ -16,6 +16,7 @@ import ( "net/textproto" "strconv" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -468,14 +469,54 @@ func isOpenAINativeImageOption(name string) bool { } func normalizeOpenAIImageSizeTier(size string) string { - switch strings.ToLower(strings.TrimSpace(size)) { + trimmed := strings.TrimSpace(size) + normalized := strings.ToLower(trimmed) + switch normalized { + case "", "auto": + return "2K" case "1024x1024": return "1K" - case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "", "auto": + case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "2048x2048", "2048x1152", "1152x2048": return "2K" - default: + case "3840x2160", "2160x3840": + return "4K" + } + width, height, ok := parseOpenAIImageSizeDimensions(trimmed) + if !ok { return "2K" } + return classifyUnknownOpenAIImageSizeTier(width, height) +} + +const ( + openAIImage2KMaxPixels = 2560 * 1440 +) + +func parseOpenAIImageSizeDimensions(size string) (int, int, bool) { + trimmed := strings.TrimSpace(size) + parts := strings.Split(strings.ToLower(trimmed), "x") + if len(parts) != 2 { + return 0, 0, false + } + width, err := strconv.Atoi(strings.TrimSpace(parts[0])) + if err != nil { + return 0, 0, false + } + height, err := strconv.Atoi(strings.TrimSpace(parts[1])) + if err != nil { + return 0, 0, false + } + if width <= 0 || height <= 0 { + return 0, 0, false + } + return width, height, true +} + +func classifyUnknownOpenAIImageSizeTier(width int, height int) string { + if height > 0 && width > openAIImage2KMaxPixels/height { + return "4K" + } + return "2K" } func (s *OpenAIGatewayService) ForwardImages( @@ -535,11 +576,14 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey( setOpsUpstreamRequestBody(c, forwardBody) } - token, _, err := s.GetAccessToken(ctx, account) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, parsed.Stream) + defer releaseUpstreamCtx() + + token, _, err := s.GetAccessToken(upstreamCtx, account) if err != nil { return nil, err } - upstreamReq, err := s.buildOpenAIImagesRequest(ctx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint) + upstreamReq, err := s.buildOpenAIImagesRequest(upstreamCtx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint) if err != nil { return nil, err } @@ -582,23 +626,37 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey( Kind: "failover", Message: upstreamMsg, }) - s.handleFailoverSideEffects(ctx, resp, account) + s.handleFailoverSideEffects(upstreamCtx, resp, account) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), } } - return s.handleErrorResponse(ctx, resp, c, account, forwardBody) + return s.handleErrorResponse(upstreamCtx, resp, c, account, forwardBody) } defer func() { _ = resp.Body.Close() }() var usage OpenAIUsage imageCount := parsed.N var firstTokenMs *int - if parsed.Stream { + if parsed.Stream && isEventStreamResponse(resp.Header) { streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime) if err != nil { + if streamCount > 0 { + return &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: streamUsage, + Model: requestModel, + UpstreamModel: upstreamModel, + Stream: parsed.Stream, + ResponseHeaders: resp.Header.Clone(), + Duration: time.Since(startTime), + FirstTokenMs: ttft, + ImageCount: streamCount, + ImageSize: parsed.SizeTier, + }, err + } return nil, err } usage = streamUsage @@ -807,39 +865,228 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer") } - reader := bufio.NewReader(resp.Body) usage := OpenAIUsage{} - imageCount := 0 + imageCounter := newOpenAIImageOutputCounter() var firstTokenMs *int + clientDisconnected := false + lastDownstreamWriteAt := time.Now() + var fallbackBody bytes.Buffer + fallbackBytes := int64(0) + fallbackLimit := resolveUpstreamResponseReadLimit(s.cfg) + seenSSEData := false + fallbackTooLarge := false + var sseData openAISSEDataAccumulator - for { - line, err := reader.ReadBytes('\n') - if len(line) > 0 { - if firstTokenMs == nil { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - } + processSSEData := func(dataBytes []byte) { + seenSSEData = true + fallbackBody.Reset() + fallbackBytes = 0 + mergeOpenAIUsage(&usage, dataBytes) + imageCounter.AddSSEData(dataBytes) + } + + flushSSEEvent := func() { + sseData.Flush(processSSEData) + } + + processLine := func(line []byte) { + if len(line) == 0 { + return + } + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + if !clientDisconnected { if _, writeErr := c.Writer.Write(line); writeErr != nil { - return OpenAIUsage{}, 0, firstTokenMs, writeErr + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected, continue draining upstream for billing") + } else { + flusher.Flush() + lastDownstreamWriteAt = time.Now() } - flusher.Flush() + } - if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok && data != "" && data != "[DONE]" { - dataBytes := []byte(data) - mergeOpenAIUsage(&usage, dataBytes) - if count := extractOpenAIImageCountFromJSONBytes(dataBytes); count > imageCount { - imageCount = count - } + trimmedLine := strings.TrimRight(string(line), "\r\n") + if _, ok := extractOpenAISSEDataLine(trimmedLine); ok || strings.TrimSpace(trimmedLine) == "" { + sseData.AddLine(trimmedLine, processSSEData) + return + } + if !seenSSEData && !fallbackTooLarge { + fallbackBytes += int64(len(line)) + if fallbackBytes <= fallbackLimit { + _, _ = fallbackBody.Write(line) + } else { + fallbackTooLarge = true + fallbackBody.Reset() } } - if err == io.EOF { - break - } - if err != nil { - return OpenAIUsage{}, 0, firstTokenMs, err - } } - return usage, imageCount, firstTokenMs, nil + + finalizeFallbackBody := func() { + if seenSSEData || fallbackBody.Len() == 0 { + return + } + body := bytes.TrimSpace(fallbackBody.Bytes()) + if len(body) == 0 { + return + } + mergeOpenAIUsage(&usage, body) + imageCounter.AddJSONResponse(body) + } + + streamInterval := s.openAIImageStreamDataInterval() + keepaliveInterval := s.openAIImageStreamKeepaliveInterval() + if streamInterval <= 0 && keepaliveInterval <= 0 { + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadBytes('\n') + processLine(line) + if err == io.EOF { + break + } + if err != nil { + flushSSEEvent() + return usage, imageCounter.Count(), firstTokenMs, err + } + } + flushSSEEvent() + finalizeFallbackBody() + return usage, imageCounter.Count(), firstTokenMs, nil + } + + type readEvent struct { + line []byte + err error + } + events := make(chan readEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev readEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func() { + defer close(events) + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadBytes('\n') + if len(line) > 0 { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + } + if len(line) > 0 && !sendEvent(readEvent{line: line}) { + return + } + if err == io.EOF { + return + } + if err != nil { + _ = sendEvent(readEvent{err: err}) + return + } + } + }() + defer close(done) + + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } + + for { + select { + case ev, ok := <-events: + if !ok { + flushSSEEvent() + finalizeFallbackBody() + return usage, imageCounter.Count(), firstTokenMs, nil + } + if ev.err != nil { + flushSSEEvent() + return usage, imageCounter.Count(), firstTokenMs, ev.err + } + processLine(ev.line) + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + return usage, imageCounter.Count(), firstTokenMs, fmt.Errorf("image stream incomplete after timeout") + } + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream data interval timeout: interval=%s", streamInterval) + _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(fmt.Sprintf("upstream image stream idle for %s", streamInterval))) + return usage, imageCounter.Count(), firstTokenMs, fmt.Errorf("image stream data interval timeout") + case <-keepaliveCh: + if clientDisconnected || time.Since(lastDownstreamWriteAt) < keepaliveInterval { + continue + } + if _, writeErr := io.WriteString(c.Writer, ":\n\n"); writeErr != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected during keepalive, continue draining upstream for billing") + continue + } + flusher.Flush() + lastDownstreamWriteAt = time.Now() + } + } +} + +func (s *OpenAIGatewayService) openAIImageStreamDataInterval() time.Duration { + if s == nil || s.cfg == nil || s.cfg.Gateway.ImageStreamDataIntervalTimeout <= 0 { + return 0 + } + return time.Duration(s.cfg.Gateway.ImageStreamDataIntervalTimeout) * time.Second +} + +func (s *OpenAIGatewayService) openAIImageStreamKeepaliveInterval() time.Duration { + if s == nil || s.cfg == nil || s.cfg.Gateway.ImageStreamKeepaliveInterval <= 0 { + return 0 + } + return time.Duration(s.cfg.Gateway.ImageStreamKeepaliveInterval) * time.Second +} + +func extractOpenAIImagesBillableCountFromJSONBytes(body []byte) int { + if count := extractOpenAIImageCountFromJSONBytes(body); count > 0 { + return count + } + if len(body) == 0 || !gjson.ValidBytes(body) { + return 0 + } + if count := int(gjson.GetBytes(body, "usage.images").Int()); count > 0 { + return count + } + if count := int(gjson.GetBytes(body, "tool_usage.image_gen.images").Int()); count > 0 { + return count + } + eventType := strings.TrimSpace(gjson.GetBytes(body, "type").String()) + if eventType == "" || !strings.HasSuffix(eventType, ".completed") { + return 0 + } + if gjson.GetBytes(body, "b64_json").Exists() || gjson.GetBytes(body, "url").Exists() { + return 1 + } + return 0 } func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) { @@ -863,14 +1110,7 @@ func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) { } func extractOpenAIImageCountFromJSONBytes(body []byte) int { - if len(body) == 0 || !gjson.ValidBytes(body) { - return 0 - } - data := gjson.GetBytes(body, "data") - if data.Exists() && data.IsArray() { - return len(data.Array()) - } - return 0 + return countOpenAIResponseImageOutputsFromJSONBytes(body) } type openAIImagePointerInfo struct { diff --git a/backend/internal/service/openai_images_responses.go b/backend/internal/service/openai_images_responses.go index 64d995e1..25cd8228 100644 --- a/backend/internal/service/openai_images_responses.go +++ b/backend/internal/service/openai_images_responses.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -361,21 +362,21 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe var ( fallbackResults []openAIResponsesImageResult fallbackSeen = make(map[string]struct{}) + finalResults []openAIResponsesImageResult + finalMeta openAIResponsesImageResult + collectErr error createdAt int64 usageRaw []byte foundFinal bool responseMeta openAIResponsesImageResult ) - for _, line := range bytes.Split(body, []byte("\n")) { - line = bytes.TrimRight(line, "\r") - data, ok := extractOpenAISSEDataLine(string(line)) - if !ok || data == "" || data == "[DONE]" { - continue + forEachOpenAISSEDataPayload(string(body), func(payload []byte) { + if collectErr != nil || len(finalResults) > 0 { + return } - payload := []byte(data) if !gjson.ValidBytes(payload) { - continue + return } if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(payload); ok { mergeOpenAIResponsesImageMeta(&responseMeta, meta) @@ -388,7 +389,8 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe case "response.output_item.done": result, itemID, ok, err := extractOpenAIImageFromResponsesOutputItemDone(payload) if err != nil { - return nil, 0, nil, openAIResponsesImageResult{}, false, err + collectErr = err + return } if ok { mergeOpenAIResponsesImageMeta(&result, responseMeta) @@ -397,7 +399,8 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe case "response.completed": results, completedAt, completedUsageRaw, firstMeta, err := extractOpenAIImagesFromResponsesCompleted(payload) if err != nil { - return nil, 0, nil, openAIResponsesImageResult{}, false, err + collectErr = err + return } foundFinal = true if completedAt > 0 { @@ -408,14 +411,24 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe } if len(results) > 0 { mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta) - return results, createdAt, usageRaw, firstMeta, true, nil + finalResults = results + finalMeta = firstMeta + return } if len(fallbackResults) > 0 { firstMeta = fallbackResults[0] mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta) - return fallbackResults, createdAt, usageRaw, firstMeta, true, nil + finalResults = fallbackResults + finalMeta = firstMeta + return } } + }) + if collectErr != nil { + return nil, 0, nil, openAIResponsesImageResult{}, false, collectErr + } + if len(finalResults) > 0 { + return finalResults, createdAt, usageRaw, finalMeta, true, nil } if len(fallbackResults) > 0 { @@ -505,6 +518,30 @@ func (s *OpenAIGatewayService) writeOpenAIImagesStreamEvent(c *gin.Context, flus return nil } +func (s *OpenAIGatewayService) tryWriteOpenAIImagesStreamEvent( + c *gin.Context, + flusher http.Flusher, + clientDisconnected *bool, + lastWriteAt *time.Time, + eventName string, + payload []byte, +) bool { + if clientDisconnected != nil && *clientDisconnected { + return false + } + if err := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); err != nil { + if clientDisconnected != nil { + *clientDisconnected = true + } + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected, continue draining upstream for billing") + return false + } + if lastWriteAt != nil { + *lastWriteAt = time.Now() + } + return true +} + func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse( resp *http.Response, c *gin.Context, @@ -517,15 +554,9 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse( } var usage OpenAIUsage - for _, line := range bytes.Split(body, []byte("\n")) { - line = bytes.TrimRight(line, "\r") - data, ok := extractOpenAISSEDataLine(string(line)) - if !ok || data == "" || data == "[DONE]" { - continue - } - dataBytes := []byte(data) - s.parseSSEUsageBytes(dataBytes, &usage) - } + forEachOpenAISSEDataPayload(string(body), func(data []byte) { + s.parseSSEUsageBytes(data, &usage) + }) results, createdAt, usageRaw, firstMeta, _, err := collectOpenAIImagesFromResponsesBody(body) if err != nil { return OpenAIUsage{}, 0, err @@ -570,7 +601,6 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( format = "b64_json" } - reader := bufio.NewReader(resp.Body) usage := OpenAIUsage{} imageCount := 0 var firstTokenMs *int @@ -579,141 +609,307 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( pendingSeen := make(map[string]struct{}) streamMeta := openAIResponsesImageResult{Model: strings.TrimSpace(fallbackModel)} var createdAt int64 + clientDisconnected := false + lastDownstreamWriteAt := time.Now() + var sseData openAISSEDataAccumulator + var processDataErr error + processDataDone := false + + processData := func(dataBytes []byte) { + if processDataDone || processDataErr != nil { + return + } + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsageBytes(dataBytes, &usage) + if !gjson.ValidBytes(dataBytes) { + return + } + if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok { + mergeOpenAIResponsesImageMeta(&streamMeta, meta) + if eventCreatedAt > 0 { + createdAt = eventCreatedAt + } + } + switch gjson.GetBytes(dataBytes, "type").String() { + case "response.image_generation_call.partial_image": + b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String()) + if b64 == "" { + return + } + eventName := streamPrefix + ".partial_image" + partialMeta := streamMeta + mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{ + OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()), + Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()), + }) + payload := buildOpenAIImagesStreamPartialPayload( + eventName, + b64, + gjson.GetBytes(dataBytes, "partial_image_index").Int(), + format, + createdAt, + partialMeta, + ) + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload) + case "response.output_item.done": + img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes) + if extractErr != nil { + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error())) + processDataErr = extractErr + processDataDone = true + return + } + if !ok { + return + } + mergeOpenAIResponsesImageMeta(&streamMeta, img) + mergeOpenAIResponsesImageMeta(&img, streamMeta) + key := openAIResponsesImageResultKey(itemID, img) + if _, exists := emitted[key]; exists { + return + } + if _, exists := pendingSeen[key]; exists { + return + } + pendingSeen[key] = struct{}{} + pendingResults = append(pendingResults, img) + case "response.completed": + results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes) + if extractErr != nil { + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error())) + processDataErr = extractErr + processDataDone = true + return + } + mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta) + finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults)) + finalSeen := make(map[string]struct{}) + for _, img := range results { + mergeOpenAIResponsesImageMeta(&img, streamMeta) + appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img) + } + for _, img := range pendingResults { + mergeOpenAIResponsesImageMeta(&img, streamMeta) + appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img) + } + if len(finalResults) == 0 { + outputErr := fmt.Errorf("upstream did not return image output") + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(outputErr.Error())) + processDataErr = outputErr + processDataDone = true + return + } + eventName := streamPrefix + ".completed" + for _, img := range finalResults { + key := openAIResponsesImageResultKey("", img) + if _, exists := emitted[key]; exists { + continue + } + payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw) + emitted[key] = struct{}{} + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload) + } + imageCount = len(emitted) + processDataDone = true + } + } + + processLine := func(line []byte) (bool, error) { + if len(line) == 0 { + return false, nil + } + sseData.AddLine(string(line), processData) + if processDataErr != nil { + return true, processDataErr + } + return processDataDone, nil + } + + flushData := func() (bool, error) { + sseData.Flush(processData) + if processDataErr != nil { + return true, processDataErr + } + return processDataDone, nil + } + + finalizePending := func() error { + if imageCount > 0 { + return nil + } + if len(pendingResults) > 0 { + eventName := streamPrefix + ".completed" + for _, img := range pendingResults { + mergeOpenAIResponsesImageMeta(&img, streamMeta) + key := openAIResponsesImageResultKey("", img) + if _, exists := emitted[key]; exists { + continue + } + payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil) + emitted[key] = struct{}{} + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload) + } + imageCount = len(emitted) + return nil + } + + streamErr := fmt.Errorf("stream disconnected before image generation completed") + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error())) + return streamErr + } + + streamInterval := s.openAIImageStreamDataInterval() + keepaliveInterval := s.openAIImageStreamKeepaliveInterval() + if streamInterval <= 0 && keepaliveInterval <= 0 { + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadBytes('\n') + done, processErr := processLine(line) + if processErr != nil { + return usage, imageCount, firstTokenMs, processErr + } + if done { + return usage, imageCount, firstTokenMs, nil + } + if err == io.EOF { + break + } + if err != nil { + if done, processErr := flushData(); processErr != nil { + return usage, imageCount, firstTokenMs, processErr + } else if done { + return usage, imageCount, firstTokenMs, nil + } + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(err.Error())) + return usage, imageCount, firstTokenMs, err + } + } + if done, processErr := flushData(); processErr != nil { + return usage, imageCount, firstTokenMs, processErr + } else if done { + return usage, imageCount, firstTokenMs, nil + } + if err := finalizePending(); err != nil { + return usage, imageCount, firstTokenMs, err + } + return usage, imageCount, firstTokenMs, nil + } + + type readEvent struct { + line []byte + err error + } + events := make(chan readEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev readEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func() { + defer close(events) + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadBytes('\n') + if len(line) > 0 { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + } + if len(line) > 0 && !sendEvent(readEvent{line: line}) { + return + } + if err == io.EOF { + return + } + if err != nil { + _ = sendEvent(readEvent{err: err}) + return + } + } + }() + defer close(done) + + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } for { - line, err := reader.ReadBytes('\n') - if len(line) > 0 { - trimmedLine := strings.TrimRight(string(line), "\r\n") - data, ok := extractOpenAISSEDataLine(trimmedLine) - if ok && data != "" && data != "[DONE]" { - if firstTokenMs == nil { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms + select { + case ev, ok := <-events: + if !ok { + if done, processErr := flushData(); processErr != nil { + return usage, imageCount, firstTokenMs, processErr + } else if done { + return usage, imageCount, firstTokenMs, nil } - dataBytes := []byte(data) - s.parseSSEUsageBytes(dataBytes, &usage) - if gjson.ValidBytes(dataBytes) { - if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok { - mergeOpenAIResponsesImageMeta(&streamMeta, meta) - if eventCreatedAt > 0 { - createdAt = eventCreatedAt - } - } - switch gjson.GetBytes(dataBytes, "type").String() { - case "response.image_generation_call.partial_image": - b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String()) - if b64 != "" { - eventName := streamPrefix + ".partial_image" - partialMeta := streamMeta - mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{ - OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()), - Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()), - }) - payload := buildOpenAIImagesStreamPartialPayload( - eventName, - b64, - gjson.GetBytes(dataBytes, "partial_image_index").Int(), - format, - createdAt, - partialMeta, - ) - if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil { - return OpenAIUsage{}, imageCount, firstTokenMs, writeErr - } - } - case "response.output_item.done": - img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes) - if extractErr != nil { - _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error())) - return OpenAIUsage{}, imageCount, firstTokenMs, extractErr - } - if !ok { - break - } - mergeOpenAIResponsesImageMeta(&streamMeta, img) - mergeOpenAIResponsesImageMeta(&img, streamMeta) - key := openAIResponsesImageResultKey(itemID, img) - if _, exists := emitted[key]; exists { - break - } - if _, exists := pendingSeen[key]; exists { - break - } - pendingSeen[key] = struct{}{} - pendingResults = append(pendingResults, img) - case "response.completed": - results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes) - if extractErr != nil { - _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error())) - return OpenAIUsage{}, imageCount, firstTokenMs, extractErr - } - mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta) - finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults)) - finalSeen := make(map[string]struct{}) - for _, img := range results { - mergeOpenAIResponsesImageMeta(&img, streamMeta) - appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img) - } - for _, img := range pendingResults { - mergeOpenAIResponsesImageMeta(&img, streamMeta) - appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img) - } - if len(finalResults) == 0 { - err = fmt.Errorf("upstream did not return image output") - _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error())) - return OpenAIUsage{}, imageCount, firstTokenMs, err - } - eventName := streamPrefix + ".completed" - for _, img := range finalResults { - key := openAIResponsesImageResultKey("", img) - if _, exists := emitted[key]; exists { - continue - } - payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw) - if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil { - return OpenAIUsage{}, imageCount, firstTokenMs, writeErr - } - emitted[key] = struct{}{} - } - imageCount = len(emitted) - return usage, imageCount, firstTokenMs, nil - } + if err := finalizePending(); err != nil { + return usage, imageCount, firstTokenMs, err } + return usage, imageCount, firstTokenMs, nil } - } - if err == io.EOF { - break - } - if err != nil { - _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error())) - return OpenAIUsage{}, imageCount, firstTokenMs, err - } - } - - if imageCount > 0 { - return usage, imageCount, firstTokenMs, nil - } - if len(pendingResults) > 0 { - eventName := streamPrefix + ".completed" - for _, img := range pendingResults { - mergeOpenAIResponsesImageMeta(&img, streamMeta) - key := openAIResponsesImageResultKey("", img) - if _, exists := emitted[key]; exists { + if ev.err != nil { + if done, processErr := flushData(); processErr != nil { + return usage, imageCount, firstTokenMs, processErr + } else if done { + return usage, imageCount, firstTokenMs, nil + } + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(ev.err.Error())) + return usage, imageCount, firstTokenMs, ev.err + } + done, processErr := processLine(ev.line) + if processErr != nil { + return usage, imageCount, firstTokenMs, processErr + } + if done { + return usage, imageCount, firstTokenMs, nil + } + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { continue } - payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil) - if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil { - return OpenAIUsage{}, imageCount, firstTokenMs, writeErr + if clientDisconnected { + return usage, imageCount, firstTokenMs, fmt.Errorf("image stream incomplete after timeout") } - emitted[key] = struct{}{} + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images responses stream data interval timeout: interval=%s", streamInterval) + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(fmt.Sprintf("upstream image stream idle for %s", streamInterval))) + return usage, imageCount, firstTokenMs, fmt.Errorf("image stream data interval timeout") + case <-keepaliveCh: + if clientDisconnected || time.Since(lastDownstreamWriteAt) < keepaliveInterval { + continue + } + if _, writeErr := io.WriteString(c.Writer, ":\n\n"); writeErr != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images responses stream client disconnected during keepalive, continue draining upstream for billing") + continue + } + flusher.Flush() + lastDownstreamWriteAt = time.Now() } - imageCount = len(emitted) - return usage, imageCount, firstTokenMs, nil } - - streamErr := fmt.Errorf("stream disconnected before image generation completed") - _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error())) - return OpenAIUsage{}, imageCount, firstTokenMs, streamErr } func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( @@ -752,7 +948,10 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( ) } - token, _, err := s.GetAccessToken(ctx, account) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, parsed.Stream) + defer releaseUpstreamCtx() + + token, _, err := s.GetAccessToken(upstreamCtx, account) if err != nil { return nil, err } @@ -763,7 +962,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( } setOpsUpstreamRequestBody(c, responsesBody) - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false) if err != nil { return nil, err } @@ -808,14 +1007,14 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( Kind: "failover", Message: upstreamMsg, }) - s.handleFailoverSideEffects(ctx, resp, account) + s.handleFailoverSideEffects(upstreamCtx, resp, account) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), } } - return s.handleErrorResponse(ctx, resp, c, account, responsesBody) + return s.handleErrorResponse(upstreamCtx, resp, c, account, responsesBody) } defer func() { _ = resp.Body.Close() }() @@ -827,6 +1026,20 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( if parsed.Stream { usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel) if err != nil { + if imageCount > 0 { + return &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: usage, + Model: requestModel, + UpstreamModel: requestModel, + Stream: parsed.Stream, + ResponseHeaders: resp.Header.Clone(), + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: parsed.SizeTier, + }, err + } return nil, err } } else { diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index 47113d4d..fa4a4415 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -3,6 +3,7 @@ package service import ( "bytes" "context" + "errors" "io" "mime/multipart" "net/http" @@ -17,6 +18,20 @@ import ( "github.com/tidwall/gjson" ) +type failingOpenAIImageWriter struct { + gin.ResponseWriter + failAfter int + writes int +} + +func (w *failingOpenAIImageWriter) Write(p []byte) (int, error) { + if w.writes >= w.failAfter { + return 0, errors.New("write failed: client disconnected") + } + w.writes++ + return w.ResponseWriter.Write(p) +} + func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) { gin.SetMode(gin.TestMode) body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","stream":true}`) @@ -75,6 +90,100 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) } +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_NormalizesOfficialAndCustomSizes(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + size string + wantTier string + }{ + {size: "1024x1024", wantTier: "1K"}, + {size: "1536x1024", wantTier: "2K"}, + {size: "1024x1536", wantTier: "2K"}, + {size: "2048x2048", wantTier: "2K"}, + {size: "2048x1152", wantTier: "2K"}, + {size: "3840x2160", wantTier: "4K"}, + {size: "2160x3840", wantTier: "4K"}, + {size: "1024X768", wantTier: "2K"}, + {size: "1280x768", wantTier: "2K"}, + {size: "2560x1440", wantTier: "2K"}, + {size: "2560x1600", wantTier: "4K"}, + {size: "auto", wantTier: "2K"}, + } + + svc := &OpenAIGatewayService{} + for _, tt := range tests { + t.Run(tt.size, func(t *testing.T) { + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"` + tt.size + `"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Equal(t, tt.size, parsed.Size) + require.Equal(t, tt.wantTier, parsed.SizeTier) + }) + } +} + +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_UnknownSizesDoNotBlockPassthrough(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + size string + wantTier string + }{ + {size: "2048x1153", wantTier: "2K"}, + {size: "4096x1024", wantTier: "4K"}, + {size: "3840x1024", wantTier: "4K"}, + {size: "512x512", wantTier: "2K"}, + {size: "invalid", wantTier: "2K"}, + {size: "999999999999999999999999999x2", wantTier: "2K"}, + } + + svc := &OpenAIGatewayService{} + for _, tt := range tests { + t.Run(tt.size, func(t *testing.T) { + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"` + tt.size + `"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Equal(t, tt.size, parsed.Size) + require.Equal(t, tt.wantTier, parsed.SizeTier) + }) + } +} + +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_LegacyImageModelUnknownSizePassthrough(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-1.5","prompt":"draw a cat","size":"2048x1152"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Equal(t, "2048x1152", parsed.Size) + require.Equal(t, "2K", parsed.SizeTier) +} + func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEditWithMaskAndNativeOptions(t *testing.T) { gin.SetMode(gin.TestMode) @@ -446,6 +555,160 @@ func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseU require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) } +func TestOpenAIGatewayServiceForwardImages_APIKeyStreamJSONResponseBillsImage(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "X-Request-Id": []string{"req_img_stream_json"}, + }, + Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"usage":{"input_tokens":12,"output_tokens":21,"output_tokens_details":{"image_tokens":9}},"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)), + }, + }, + } + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + account := &Account{ + ID: 7, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "test-api-key", + "base_url": "https://image-upstream.example/v1", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 12, result.Usage.InputTokens) + require.Equal(t, 21, result.Usage.OutputTokens) + require.Equal(t, 9, result.Usage.ImageOutputTokens) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) +} + +func TestOpenAIGatewayServiceForwardImages_APIKeyStreamRawJSONEventStreamFallbackBillsImage(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_stream_json_mislabeled"}, + }, + Body: io.NopCloser(strings.NewReader(`{"created":1710000009,"usage":{"input_tokens":10,"output_tokens":18,"output_tokens_details":{"image_tokens":8}},"data":[{"b64_json":"ZmluYWw="}]}`)), + }, + }, + } + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + account := &Account{ + ID: 8, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "test-api-key", + "base_url": "https://image-upstream.example/v1", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 10, result.Usage.InputTokens) + require.Equal(t, 18, result.Usage.OutputTokens) + require.Equal(t, 8, result.Usage.ImageOutputTokens) + require.Equal(t, "ZmluYWw=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) +} + +func TestOpenAIGatewayServiceForwardImages_APIKeyStreamMultilineSSEDataBillsImage(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_stream_multiline"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"image_generation.completed\",\n" + + "data: \"usage\":{\"input_tokens\":10,\"output_tokens\":18,\"output_tokens_details\":{\"image_tokens\":8}},\n" + + "data: \"b64_json\":\"ZmluYWw=\",\"output_format\":\"png\"}\n\n" + + "data: [DONE]\n\n", + )), + }, + }, + } + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + account := &Account{ + ID: 8, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "test-api-key", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 10, result.Usage.InputTokens) + require.Equal(t, 18, result.Usage.OutputTokens) + require.Equal(t, 8, result.Usage.ImageOutputTokens) +} + +func TestExtractOpenAIImagesBillableCountFromJSONBytes_CompletedEvent(t *testing.T) { + body := []byte(`{"type":"image_generation.completed","b64_json":"ZmluYWw=","usage":{"input_tokens":10,"output_tokens":18}}`) + + require.Equal(t, 1, extractOpenAIImagesBillableCountFromJSONBytes(body)) +} + func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) { gin.SetMode(gin.TestMode) @@ -583,6 +846,61 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *tes require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists()) } +func TestOpenAIGatewayServiceForwardImages_APIKeyStreamingDrainsAfterClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + c.Writer = &failingOpenAIImageWriter{ResponseWriter: c.Writer, failAfter: 1} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + ImageStreamDataIntervalTimeout: 1, + ImageStreamKeepaliveInterval: 0, + }, + }, + httpUpstream: &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_stream_disconnect_apikey"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"image_generation.partial_image\",\"b64_json\":\"cGFydGlhbA==\"}\n\n" + + "data: {\"type\":\"image_generation.completed\",\"usage\":{\"input_tokens\":3,\"output_tokens\":4,\"output_tokens_details\":{\"image_tokens\":2}},\"b64_json\":\"ZmluYWw=\",\"output_format\":\"png\"}\n\n" + + "data: [DONE]\n\n", + )), + }, + }, + } + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + account := &Account{ + ID: 8, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "test-api-key", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 3, result.Usage.InputTokens) + require.Equal(t, 4, result.Usage.OutputTokens) + require.Equal(t, 2, result.Usage.ImageOutputTokens) +} + func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t *testing.T) { gin.SetMode(gin.TestMode) @@ -798,6 +1116,23 @@ func TestCollectOpenAIImagesFromResponsesBody_FallsBackToOutputItemDone(t *testi require.JSONEq(t, `{"images":1}`, string(usageRaw)) } +func TestCollectOpenAIImagesFromResponsesBody_MultilineSSE(t *testing.T) { + body := []byte( + "data: {\"type\":\"response.completed\",\n" + + "data: \"response\":{\"created_at\":1710000010,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" + + "data: [DONE]\n\n", + ) + + results, createdAt, usageRaw, firstMeta, foundFinal, err := collectOpenAIImagesFromResponsesBody(body) + require.NoError(t, err) + require.True(t, foundFinal) + require.Equal(t, int64(1710000010), createdAt) + require.Len(t, results, 1) + require.Equal(t, "ZmluYWw=", results[0].Result) + require.Equal(t, "png", firstMeta.OutputFormat) + require.JSONEq(t, `{"images":1}`, string(usageRaw)) +} + func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFallback(t *testing.T) { gin.SetMode(gin.TestMode) body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`) @@ -854,3 +1189,116 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFa require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw) require.NotContains(t, rec.Body.String(), "event: error") } + +func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesMultilineSSE(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + svc.httpUpstream = &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_stream_multiline_oauth"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.completed\",\n" + + "data: \"response\":{\"created_at\":1710000011,\"usage\":{\"input_tokens\":6,\"output_tokens\":10,\"output_tokens_details\":{\"image_tokens\":5}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"TXVsdGlsaW5l\",\"output_format\":\"png\"}]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + + account := &Account{ + ID: 11, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 6, result.Usage.InputTokens) + require.Equal(t, 10, result.Usage.OutputTokens) + require.Equal(t, 5, result.Usage.ImageOutputTokens) + events := parseOpenAIImageTestSSEEvents(rec.Body.String()) + completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed") + require.True(t, ok) + require.Equal(t, "TXVsdGlsaW5l", gjson.Get(completed.Data, "b64_json").String()) + require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw) + require.NotContains(t, rec.Body.String(), "event: error") +} + +func TestOpenAIGatewayServiceForwardImages_OAuthStreamingDrainsAfterClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + c.Writer = &failingOpenAIImageWriter{ResponseWriter: c.Writer, failAfter: 1} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + ImageStreamDataIntervalTimeout: 1, + ImageStreamKeepaliveInterval: 0, + }, + }, + } + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_stream_disconnect_oauth"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"png\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000009,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc.httpUpstream = upstream + + account := &Account{ + ID: 9, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 5, result.Usage.InputTokens) + require.Equal(t, 9, result.Usage.OutputTokens) + require.Equal(t, 4, result.Usage.ImageOutputTokens) +} diff --git a/backend/internal/service/openai_messages_bridge.go b/backend/internal/service/openai_messages_bridge.go new file mode 100644 index 00000000..d67b4b1e --- /dev/null +++ b/backend/internal/service/openai_messages_bridge.go @@ -0,0 +1,57 @@ +package service + +import ( + "bytes" + "strings" + + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" +) + +const openAICompatMessagesBridgeContextKey = "openai_compat_messages_bridge" + +func isOpenAICompatMessagesBridgeBody(body []byte) bool { + if len(body) == 0 { + return false + } + if bytes.Contains(body, []byte(openAICompatClaudeCodeTodoGuardMarker)) { + return true + } + return isOpenAICompatMessagesBridgePromptCacheKey(gjson.GetBytes(body, "prompt_cache_key").String()) +} + +func isOpenAICompatMessagesBridgeRequestBody(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + if input, ok := reqBody["input"].([]any); ok && inputContainsText(input, openAICompatClaudeCodeTodoGuardMarker) { + return true + } + return isOpenAICompatMessagesBridgePromptCacheKey(firstNonEmptyString(reqBody["prompt_cache_key"])) +} + +func isOpenAICompatMessagesBridgePromptCacheKey(key string) bool { + key = strings.TrimSpace(key) + return strings.HasPrefix(key, "anthropic-metadata-") || + strings.HasPrefix(key, "anthropic-cache-") || + strings.HasPrefix(key, "anthropic-digest-") +} + +func setOpenAICompatMessagesBridgeContext(c *gin.Context, enabled bool) { + if c == nil || !enabled { + return + } + c.Set(openAICompatMessagesBridgeContextKey, true) +} + +func isOpenAICompatMessagesBridgeContext(c *gin.Context) bool { + if c == nil { + return false + } + value, ok := c.Get(openAICompatMessagesBridgeContextKey) + if !ok { + return false + } + enabled, ok := value.(bool) + return ok && enabled +} diff --git a/backend/internal/service/openai_messages_continuation.go b/backend/internal/service/openai_messages_continuation.go new file mode 100644 index 00000000..57d04784 --- /dev/null +++ b/backend/internal/service/openai_messages_continuation.go @@ -0,0 +1,277 @@ +package service + +import ( + "context" + "encoding/json" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" +) + +type openAICompatSessionResponseBinding struct { + ResponseID string + TurnState string + ContinuationDisabled bool + ExpiresAt time.Time +} + +func openAICompatContinuationEnabled(account *Account, model string) bool { + if account == nil || account.Type != AccountTypeAPIKey { + return false + } + return shouldAutoInjectPromptCacheKeyForCompat(model) +} + +func trimAnthropicCompatResponsesInputToLatestTurn(req *apicompat.ResponsesRequest) { + if req == nil || len(req.Input) == 0 { + return + } + + var items []apicompat.ResponsesInputItem + if err := json.Unmarshal(req.Input, &items); err != nil || len(items) == 0 { + return + } + + start := len(items) - 1 + for start > 0 && items[start].Type == "function_call_output" { + start-- + } + trimmed := append([]apicompat.ResponsesInputItem(nil), items[start:]...) + if len(trimmed) == len(items) { + return + } + if input, err := json.Marshal(trimmed); err == nil { + req.Input = input + } +} + +func isOpenAICompatPreviousResponseNotFound(statusCode int, upstreamMsg string, upstreamBody []byte) bool { + if statusCode != http.StatusBadRequest && statusCode != http.StatusNotFound { + return false + } + check := func(s string) bool { + lower := strings.ToLower(strings.TrimSpace(s)) + return strings.Contains(lower, "previous_response_not_found") || + (strings.Contains(lower, "previous response") && strings.Contains(lower, "not found")) || + (strings.Contains(lower, "unsupported parameter") && strings.Contains(lower, "previous_response_id")) + } + if check(upstreamMsg) || check(string(upstreamBody)) { + return true + } + return check(gjson.GetBytes(upstreamBody, "error.code").String()) || + check(gjson.GetBytes(upstreamBody, "error.message").String()) +} + +func isOpenAICompatPreviousResponseUnsupported(statusCode int, upstreamMsg string, upstreamBody []byte) bool { + if statusCode != http.StatusBadRequest { + return false + } + check := func(s string) bool { + lower := strings.ToLower(strings.TrimSpace(s)) + if !strings.Contains(lower, "previous_response_id") { + return false + } + return strings.Contains(lower, "unsupported parameter") || + strings.Contains(lower, "only supported on responses websocket") || + strings.Contains(lower, "not supported") + } + if check(upstreamMsg) || check(string(upstreamBody)) { + return true + } + return check(gjson.GetBytes(upstreamBody, "error.code").String()) || + check(gjson.GetBytes(upstreamBody, "error.message").String()) +} + +func openAICompatSessionResponseKey(c *gin.Context, account *Account, promptCacheKey string) string { + key := strings.TrimSpace(promptCacheKey) + if account == nil || key == "" { + return "" + } + apiKeyID := int64(0) + if c != nil { + apiKeyID = getAPIKeyIDFromContext(c) + } + return strings.Join([]string{ + strconv.FormatInt(account.ID, 10), + strconv.FormatInt(apiKeyID, 10), + key, + }, "\x00") +} + +func (s *OpenAIGatewayService) getOpenAICompatSessionResponseID(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) string { + if s == nil { + return "" + } + key := openAICompatSessionResponseKey(c, account, promptCacheKey) + if key == "" { + return "" + } + raw, ok := s.openaiCompatSessionResponses.Load(key) + if !ok { + return "" + } + binding, ok := raw.(openAICompatSessionResponseBinding) + if !ok { + s.openaiCompatSessionResponses.Delete(key) + return "" + } + if !binding.ExpiresAt.IsZero() && time.Now().After(binding.ExpiresAt) { + s.openaiCompatSessionResponses.Delete(key) + return "" + } + if binding.ContinuationDisabled { + return "" + } + if strings.TrimSpace(binding.ResponseID) == "" { + s.openaiCompatSessionResponses.Delete(key) + return "" + } + return strings.TrimSpace(binding.ResponseID) +} + +func (s *OpenAIGatewayService) bindOpenAICompatSessionResponseID(_ context.Context, c *gin.Context, account *Account, promptCacheKey, responseID string) { + if s == nil { + return + } + key := openAICompatSessionResponseKey(c, account, promptCacheKey) + id := strings.TrimSpace(responseID) + if key == "" || id == "" { + return + } + binding := openAICompatSessionResponseBinding{ + ResponseID: id, + ExpiresAt: time.Now().Add(s.openAIWSResponseStickyTTL()), + } + if raw, ok := s.openaiCompatSessionResponses.Load(key); ok { + if existing, ok := raw.(openAICompatSessionResponseBinding); ok { + if existing.ContinuationDisabled { + existing.ResponseID = "" + existing.ExpiresAt = time.Now().Add(s.openAIWSResponseStickyTTL()) + s.openaiCompatSessionResponses.Store(key, existing) + return + } + binding.TurnState = existing.TurnState + } + } + s.openaiCompatSessionResponses.Store(key, binding) +} + +func (s *OpenAIGatewayService) deleteOpenAICompatSessionResponseID(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) { + if s == nil { + return + } + key := openAICompatSessionResponseKey(c, account, promptCacheKey) + if key == "" { + return + } + raw, ok := s.openaiCompatSessionResponses.Load(key) + if !ok { + return + } + binding, ok := raw.(openAICompatSessionResponseBinding) + if !ok { + s.openaiCompatSessionResponses.Delete(key) + return + } + binding.ResponseID = "" + if strings.TrimSpace(binding.TurnState) == "" && !binding.ContinuationDisabled { + s.openaiCompatSessionResponses.Delete(key) + return + } + binding.ExpiresAt = time.Now().Add(s.openAIWSResponseStickyTTL()) + s.openaiCompatSessionResponses.Store(key, binding) +} + +func (s *OpenAIGatewayService) disableOpenAICompatSessionContinuation(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) { + if s == nil { + return + } + key := openAICompatSessionResponseKey(c, account, promptCacheKey) + if key == "" { + return + } + binding := openAICompatSessionResponseBinding{ + ContinuationDisabled: true, + ExpiresAt: time.Now().Add(s.openAIWSResponseStickyTTL()), + } + if raw, ok := s.openaiCompatSessionResponses.Load(key); ok { + if existing, ok := raw.(openAICompatSessionResponseBinding); ok { + binding.TurnState = existing.TurnState + } + } + s.openaiCompatSessionResponses.Store(key, binding) +} + +func (s *OpenAIGatewayService) isOpenAICompatSessionContinuationDisabled(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) bool { + if s == nil { + return false + } + key := openAICompatSessionResponseKey(c, account, promptCacheKey) + if key == "" { + return false + } + raw, ok := s.openaiCompatSessionResponses.Load(key) + if !ok { + return false + } + binding, ok := raw.(openAICompatSessionResponseBinding) + if !ok { + s.openaiCompatSessionResponses.Delete(key) + return false + } + if !binding.ExpiresAt.IsZero() && time.Now().After(binding.ExpiresAt) { + s.openaiCompatSessionResponses.Delete(key) + return false + } + return binding.ContinuationDisabled +} + +func (s *OpenAIGatewayService) getOpenAICompatSessionTurnState(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) string { + if s == nil { + return "" + } + key := openAICompatSessionResponseKey(c, account, promptCacheKey) + if key == "" { + return "" + } + raw, ok := s.openaiCompatSessionResponses.Load(key) + if !ok { + return "" + } + binding, ok := raw.(openAICompatSessionResponseBinding) + if !ok || strings.TrimSpace(binding.TurnState) == "" { + return "" + } + if !binding.ExpiresAt.IsZero() && time.Now().After(binding.ExpiresAt) { + s.openaiCompatSessionResponses.Delete(key) + return "" + } + return strings.TrimSpace(binding.TurnState) +} + +func (s *OpenAIGatewayService) bindOpenAICompatSessionTurnState(_ context.Context, c *gin.Context, account *Account, promptCacheKey, turnState string) { + if s == nil { + return + } + key := openAICompatSessionResponseKey(c, account, promptCacheKey) + state := strings.TrimSpace(turnState) + if key == "" || state == "" { + return + } + binding := openAICompatSessionResponseBinding{ + TurnState: state, + ExpiresAt: time.Now().Add(s.openAIWSResponseStickyTTL()), + } + if raw, ok := s.openaiCompatSessionResponses.Load(key); ok { + if existing, ok := raw.(openAICompatSessionResponseBinding); ok { + binding.ResponseID = existing.ResponseID + binding.ContinuationDisabled = existing.ContinuationDisabled + } + } + s.openaiCompatSessionResponses.Store(key, binding) +} diff --git a/backend/internal/service/openai_messages_digest_session.go b/backend/internal/service/openai_messages_digest_session.go new file mode 100644 index 00000000..44a49d1e --- /dev/null +++ b/backend/internal/service/openai_messages_digest_session.go @@ -0,0 +1,135 @@ +package service + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" +) + +type openAICompatAnthropicDigestBinding struct { + PromptCacheKey string + ExpiresAt time.Time +} + +func buildOpenAICompatAnthropicDigestChain(req *apicompat.AnthropicRequest) string { + if req == nil { + return "" + } + + parts := make([]string, 0, len(req.Messages)+1) + if len(req.System) > 0 && strings.TrimSpace(string(req.System)) != "" && strings.TrimSpace(string(req.System)) != "null" { + parts = append(parts, "s:"+shortHash(req.System)) + } + for _, msg := range req.Messages { + content := msg.Content + if len(content) == 0 || strings.TrimSpace(string(content)) == "" { + continue + } + prefix := "u" + if strings.TrimSpace(msg.Role) == "assistant" { + prefix = "a" + } + parts = append(parts, prefix+":"+shortHash(content)) + } + return strings.Join(parts, "-") +} + +func openAICompatAnthropicDigestNamespace(account *Account, cAPIKeyID int64) string { + if account == nil || account.ID <= 0 { + return "" + } + return fmt.Sprintf("%d|%d|", account.ID, cAPIKeyID) +} + +func (s *OpenAIGatewayService) findOpenAICompatAnthropicDigestPromptCacheKey(account *Account, cAPIKeyID int64, digestChain string) (promptCacheKey string, matchedChain string) { + if s == nil || digestChain == "" { + return "", "" + } + ns := openAICompatAnthropicDigestNamespace(account, cAPIKeyID) + if ns == "" { + return "", "" + } + chain := digestChain + for { + if raw, ok := s.openaiCompatAnthropicDigestSessions.Load(ns + chain); ok { + if binding, ok := raw.(openAICompatAnthropicDigestBinding); ok { + if binding.ExpiresAt.IsZero() || time.Now().Before(binding.ExpiresAt) { + if key := strings.TrimSpace(binding.PromptCacheKey); key != "" { + return key, chain + } + } + } + s.openaiCompatAnthropicDigestSessions.Delete(ns + chain) + } + i := strings.LastIndex(chain, "-") + if i < 0 { + return "", "" + } + chain = chain[:i] + } +} + +func (s *OpenAIGatewayService) bindOpenAICompatAnthropicDigestPromptCacheKey(account *Account, cAPIKeyID int64, digestChain, promptCacheKey, oldDigestChain string) { + if s == nil || digestChain == "" || strings.TrimSpace(promptCacheKey) == "" { + return + } + ns := openAICompatAnthropicDigestNamespace(account, cAPIKeyID) + if ns == "" { + return + } + binding := openAICompatAnthropicDigestBinding{ + PromptCacheKey: strings.TrimSpace(promptCacheKey), + ExpiresAt: time.Now().Add(s.openAIWSResponseStickyTTL()), + } + s.openaiCompatAnthropicDigestSessions.Store(ns+digestChain, binding) + if oldDigestChain != "" && oldDigestChain != digestChain { + s.openaiCompatAnthropicDigestSessions.Delete(ns + oldDigestChain) + } +} + +func promptCacheKeyFromAnthropicDigest(digestChain string) string { + if strings.TrimSpace(digestChain) == "" { + return "" + } + return "anthropic-digest-" + hashSensitiveValueForLog(digestChain) +} + +func promptCacheKeyFromAnthropicMetadataSession(req *apicompat.AnthropicRequest) string { + if req == nil || len(req.Metadata) == 0 { + return "" + } + var metadata struct { + UserID string `json:"user_id"` + } + if err := json.Unmarshal(req.Metadata, &metadata); err != nil { + return "" + } + parsed := ParseMetadataUserID(metadata.UserID) + if parsed == nil || strings.TrimSpace(parsed.SessionID) == "" { + return "" + } + seed := strings.Join([]string{ + "anthropic-metadata", + strings.TrimSpace(parsed.DeviceID), + strings.TrimSpace(parsed.AccountUUID), + strings.TrimSpace(parsed.SessionID), + }, "|") + return "anthropic-metadata-" + hashSensitiveValueForLog(seed) +} + +func cloneAnthropicRequestForDigest(req *apicompat.AnthropicRequest) *apicompat.AnthropicRequest { + if req == nil { + return nil + } + cp := *req + if len(req.System) > 0 { + cp.System = append(json.RawMessage(nil), req.System...) + } + if len(req.Messages) > 0 { + cp.Messages = append([]apicompat.AnthropicMessage(nil), req.Messages...) + } + return &cp +} diff --git a/backend/internal/service/openai_messages_replay_guard.go b/backend/internal/service/openai_messages_replay_guard.go new file mode 100644 index 00000000..2ad9b6bc --- /dev/null +++ b/backend/internal/service/openai_messages_replay_guard.go @@ -0,0 +1,90 @@ +package service + +import ( + "encoding/json" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" +) + +const openAICompatAnthropicReplayMaxTailMessages = 12 + +func applyAnthropicCompatFullReplayGuard(req *apicompat.AnthropicRequest) bool { + if req == nil || len(req.Messages) <= openAICompatAnthropicReplayMaxTailMessages { + return false + } + + start := len(req.Messages) - openAICompatAnthropicReplayMaxTailMessages + start = expandAnthropicCompatTrimBoundary(req.Messages, start) + if start <= 0 { + return false + } + + req.Messages = append([]apicompat.AnthropicMessage(nil), req.Messages[start:]...) + return true +} + +func expandAnthropicCompatTrimBoundary(messages []apicompat.AnthropicMessage, start int) int { + if start <= 0 || start >= len(messages) { + return start + } + + toolUseIndex := make(map[string]int) + toolResultIndex := make(map[string]int) + for i, msg := range messages { + uses, results := anthropicCompatMessageToolIDs(msg) + for _, id := range uses { + if _, exists := toolUseIndex[id]; !exists { + toolUseIndex[id] = i + } + } + for _, id := range results { + if _, exists := toolResultIndex[id]; !exists { + toolResultIndex[id] = i + } + } + } + + for { + next := start + for i := start; i < len(messages); i++ { + uses, results := anthropicCompatMessageToolIDs(messages[i]) + for _, id := range results { + if useIdx, ok := toolUseIndex[id]; ok && useIdx < next { + next = useIdx + } + } + for _, id := range uses { + if resultIdx, ok := toolResultIndex[id]; ok && resultIdx < next { + next = resultIdx + } + } + } + if next == start { + return start + } + start = next + } +} + +func anthropicCompatMessageToolIDs(msg apicompat.AnthropicMessage) ([]string, []string) { + var blocks []apicompat.AnthropicContentBlock + if err := json.Unmarshal(msg.Content, &blocks); err != nil { + return nil, nil + } + + uses := make([]string, 0, 1) + results := make([]string, 0, 1) + for _, block := range blocks { + switch block.Type { + case "tool_use": + if block.ID != "" { + uses = append(uses, block.ID) + } + case "tool_result": + if block.ToolUseID != "" { + results = append(results, block.ToolUseID) + } + } + } + return uses, results +} diff --git a/backend/internal/service/openai_messages_replay_guard_test.go b/backend/internal/service/openai_messages_replay_guard_test.go new file mode 100644 index 00000000..6176beec --- /dev/null +++ b/backend/internal/service/openai_messages_replay_guard_test.go @@ -0,0 +1,58 @@ +package service + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/stretchr/testify/require" +) + +func TestApplyAnthropicCompatFullReplayGuard_TrimsOldMessages(t *testing.T) { + t.Parallel() + + req := &apicompat.AnthropicRequest{Messages: make([]apicompat.AnthropicMessage, 0, openAICompatAnthropicReplayMaxTailMessages+3)} + for i := 0; i < openAICompatAnthropicReplayMaxTailMessages+3; i++ { + req.Messages = append(req.Messages, apicompat.AnthropicMessage{ + Role: "user", + Content: json.RawMessage(fmt.Sprintf(`"message-%02d"`, i)), + }) + } + + trimmed := applyAnthropicCompatFullReplayGuard(req) + + require.True(t, trimmed) + require.Len(t, req.Messages, openAICompatAnthropicReplayMaxTailMessages) + require.JSONEq(t, `"message-03"`, string(req.Messages[0].Content)) + require.JSONEq(t, `"message-14"`, string(req.Messages[len(req.Messages)-1].Content)) +} + +func TestApplyAnthropicCompatFullReplayGuard_KeepsToolBoundaryIntact(t *testing.T) { + t.Parallel() + + req := &apicompat.AnthropicRequest{Messages: make([]apicompat.AnthropicMessage, 0, openAICompatAnthropicReplayMaxTailMessages+3)} + for i := 0; i < openAICompatAnthropicReplayMaxTailMessages+3; i++ { + role := "user" + content := json.RawMessage(fmt.Sprintf(`"message-%02d"`, i)) + if i == 1 { + role = "assistant" + content = json.RawMessage(`[{"type":"tool_use","id":"toolu_keep","name":"Read","input":{"file_path":"main.go"}}]`) + } + if i == 3 { + content = json.RawMessage(`[{"type":"tool_result","tool_use_id":"toolu_keep","content":"ok"}]`) + } + req.Messages = append(req.Messages, apicompat.AnthropicMessage{ + Role: role, + Content: content, + }) + } + + trimmed := applyAnthropicCompatFullReplayGuard(req) + + require.True(t, trimmed) + require.Len(t, req.Messages, openAICompatAnthropicReplayMaxTailMessages+2) + require.Equal(t, "assistant", req.Messages[0].Role) + require.Contains(t, string(req.Messages[0].Content), `"toolu_keep"`) + require.Contains(t, string(req.Messages[2].Content), `"tool_result"`) +} diff --git a/backend/internal/service/openai_messages_todo_guard.go b/backend/internal/service/openai_messages_todo_guard.go new file mode 100644 index 00000000..96fc90cb --- /dev/null +++ b/backend/internal/service/openai_messages_todo_guard.go @@ -0,0 +1,121 @@ +package service + +import ( + "encoding/json" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" +) + +const ( + openAICompatClaudeCodeTodoGuardMarker = "" + openAICompatClaudeCodeTodoGuardText = openAICompatClaudeCodeTodoGuardMarker + "\nWhen using Claude Code todo or task tracking tools, keep the visible task list consistent. Do not send final or summary text while any item remains in_progress. Before finishing, asking the user to choose, or reporting a blocker, update the todo list so completed work is completed and deferred work is pending/open; leave an item in_progress only when active work will continue in the same turn.\n" +) + +func appendOpenAICompatClaudeCodeTodoGuard(req *apicompat.ResponsesRequest) bool { + if req == nil || len(req.Input) == 0 { + return false + } + + var items []apicompat.ResponsesInputItem + if err := json.Unmarshal(req.Input, &items); err != nil { + return false + } + if len(items) == 0 || responsesInputItemsContainText(items, openAICompatClaudeCodeTodoGuardMarker) { + return false + } + + content, err := json.Marshal([]apicompat.ResponsesContentPart{{ + Type: "input_text", + Text: openAICompatClaudeCodeTodoGuardText, + }}) + if err != nil { + return false + } + + guard := apicompat.ResponsesInputItem{ + Type: "message", + Role: "developer", + Content: content, + } + + insertAt := 0 + for insertAt < len(items) && items[insertAt].Type == "message" && items[insertAt].Role == "developer" { + insertAt++ + } + + items = append(items, apicompat.ResponsesInputItem{}) + copy(items[insertAt+1:], items[insertAt:]) + items[insertAt] = guard + + input, err := json.Marshal(items) + if err != nil { + return false + } + req.Input = input + return true +} + +func appendOpenAICompatClaudeCodeTodoGuardToRequestBody(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + + input, ok := reqBody["input"].([]any) + if !ok || len(input) == 0 || inputContainsText(input, openAICompatClaudeCodeTodoGuardMarker) { + return false + } + + guard := map[string]any{ + "type": "message", + "role": "developer", + "content": []any{ + map[string]any{ + "type": "input_text", + "text": openAICompatClaudeCodeTodoGuardText, + }, + }, + } + + insertAt := 0 + for insertAt < len(input) { + item, ok := input[insertAt].(map[string]any) + if !ok || strings.TrimSpace(firstNonEmptyString(item["type"])) != "message" || strings.TrimSpace(firstNonEmptyString(item["role"])) != "developer" { + break + } + insertAt++ + } + + input = append(input, nil) + copy(input[insertAt+1:], input[insertAt:]) + input[insertAt] = guard + reqBody["input"] = input + return true +} + +func responsesInputItemsContainText(items []apicompat.ResponsesInputItem, needle string) bool { + needle = strings.TrimSpace(needle) + if needle == "" { + return false + } + for _, item := range items { + if strings.Contains(string(item.Content), needle) { + return true + } + } + return false +} + +func inputContainsText(input []any, needle string) bool { + needle = strings.TrimSpace(needle) + if needle == "" { + return false + } + for _, item := range input { + b, err := json.Marshal(item) + if err == nil && strings.Contains(string(b), needle) { + return true + } + } + return false +} diff --git a/backend/internal/service/openai_model_alias.go b/backend/internal/service/openai_model_alias.go new file mode 100644 index 00000000..2fa2c90e --- /dev/null +++ b/backend/internal/service/openai_model_alias.go @@ -0,0 +1,137 @@ +package service + +import "strings" + +func lastOpenAIModelSegment(model string) string { + model = strings.TrimSpace(model) + if model == "" { + return "" + } + if strings.Contains(model, "/") { + parts := strings.Split(model, "/") + model = parts[len(parts)-1] + } + return strings.TrimSpace(model) +} + +func canonicalizeOpenAIModelAliasSpelling(model string) string { + model = strings.ToLower(lastOpenAIModelSegment(model)) + if model == "" { + return "" + } + + normalized := strings.ReplaceAll(model, "_", "-") + normalized = strings.Join(strings.Fields(normalized), "-") + for strings.Contains(normalized, "--") { + normalized = strings.ReplaceAll(normalized, "--", "-") + } + + if strings.HasPrefix(normalized, "gpt5") { + normalized = "gpt-5" + strings.TrimPrefix(normalized, "gpt5") + } + if !strings.HasPrefix(normalized, "gpt-") && !strings.Contains(normalized, "codex") { + return "" + } + + replacements := []struct { + from string + to string + }{ + {"gpt-5.4mini", "gpt-5.4-mini"}, + {"gpt-5.4nano", "gpt-5.4-nano"}, + {"gpt-5.3-codexspark", "gpt-5.3-codex-spark"}, + {"gpt-5.3codexspark", "gpt-5.3-codex-spark"}, + {"gpt-5.3codex", "gpt-5.3-codex"}, + } + for _, replacement := range replacements { + normalized = strings.ReplaceAll(normalized, replacement.from, replacement.to) + } + return normalized +} + +func normalizeKnownOpenAICodexModel(model string) string { + normalized := canonicalizeOpenAIModelAliasSpelling(model) + if normalized == "" { + return "" + } + + if mapped := getNormalizedCodexModel(normalized); mapped != "" { + return mapped + } + if strings.HasSuffix(normalized, "-openai-compact") { + if mapped := getNormalizedCodexModel(strings.TrimSuffix(normalized, "-openai-compact")); mapped != "" { + return mapped + } + } + + switch { + case strings.Contains(normalized, "gpt-5.5"): + return "gpt-5.5" + case strings.Contains(normalized, "gpt-5.4-mini"): + return "gpt-5.4-mini" + case strings.Contains(normalized, "gpt-5.4-nano"): + return "gpt-5.4-nano" + case strings.Contains(normalized, "gpt-5.4"): + return "gpt-5.4" + case strings.Contains(normalized, "gpt-5.2"): + return "gpt-5.2" + case strings.Contains(normalized, "gpt-5.3-codex-spark"): + return "gpt-5.3-codex-spark" + case strings.Contains(normalized, "gpt-5.3-codex"): + return "gpt-5.3-codex" + case strings.Contains(normalized, "gpt-5.3"): + return "gpt-5.3-codex" + case strings.Contains(normalized, "codex"): + return "gpt-5.3-codex" + case strings.Contains(normalized, "gpt-5"): + return "gpt-5.4" + default: + return "" + } +} + +func appendUsageBillingModelCandidate(candidates []string, seen map[string]struct{}, model string) []string { + trimmed := strings.TrimSpace(model) + if trimmed == "" { + return candidates + } + add := func(candidate string) { + candidate = strings.TrimSpace(candidate) + if candidate == "" { + return + } + key := strings.ToLower(candidate) + if _, ok := seen[key]; ok { + return + } + seen[key] = struct{}{} + candidates = append(candidates, candidate) + } + + add(trimmed) + if canonical := canonicalizeOpenAIModelAliasSpelling(trimmed); canonical != "" { + add(canonical) + } + if normalized := normalizeKnownOpenAICodexModel(trimmed); normalized != "" { + add(normalized) + } + return candidates +} + +func usageBillingModelCandidates(primary string, alternates ...string) []string { + seen := make(map[string]struct{}, 1+len(alternates)) + candidates := appendUsageBillingModelCandidate(nil, seen, primary) + for _, alternate := range alternates { + candidates = appendUsageBillingModelCandidate(candidates, seen, alternate) + } + return candidates +} + +func firstUsageBillingModel(candidates []string) string { + for _, candidate := range candidates { + if trimmed := strings.TrimSpace(candidate); trimmed != "" { + return trimmed + } + } + return "" +} diff --git a/backend/internal/service/openai_model_mapping.go b/backend/internal/service/openai_model_mapping.go index f332633c..7cec5212 100644 --- a/backend/internal/service/openai_model_mapping.go +++ b/backend/internal/service/openai_model_mapping.go @@ -2,44 +2,24 @@ package service import "strings" -// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible -// forwarding. Group-level default mapping only applies when the account itself -// did not match any explicit model_mapping rule. +// resolveOpenAIForwardModel 解析 OpenAI 兼容转发使用的模型。 +// defaultMappedModel 只服务于 /v1/messages 的 Claude 系列显式调度映射, +// 不作为普通 OpenAI 请求的未知模型兜底。 func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string { if account == nil { - if defaultMappedModel != "" { + if defaultMappedModel != "" && claudeMessagesDispatchFamily(requestedModel) != "" { return defaultMappedModel } return requestedModel } mappedModel, matched := account.ResolveMappedModel(requestedModel) - if !matched && defaultMappedModel != "" && !isExplicitCodexModel(requestedModel) { + if !matched && defaultMappedModel != "" && claudeMessagesDispatchFamily(requestedModel) != "" { return defaultMappedModel } return mappedModel } -func isExplicitCodexModel(model string) bool { - model = strings.TrimSpace(model) - if model == "" { - return false - } - if strings.Contains(model, "/") { - parts := strings.Split(model, "/") - model = parts[len(parts)-1] - } - model = strings.ToLower(strings.TrimSpace(model)) - if getNormalizedCodexModel(model) != "" { - return true - } - if strings.HasSuffix(model, "-openai-compact") { - base := strings.TrimSuffix(model, "-openai-compact") - return getNormalizedCodexModel(base) != "" - } - return false -} - // resolveOpenAICompactForwardModel determines the compact-only upstream model // for /responses/compact requests. It never affects normal /responses traffic. // When no compact-specific mapping matches, the input model is returned as-is. diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go index 4802c089..f087ac32 100644 --- a/backend/internal/service/openai_model_mapping_test.go +++ b/backend/internal/service/openai_model_mapping_test.go @@ -11,7 +11,7 @@ func TestResolveOpenAIForwardModel(t *testing.T) { expectedModel string }{ { - name: "falls back to group default when account has no mapping", + name: "uses messages dispatch default for claude model", account: &Account{ Credentials: map[string]any{}, }, @@ -19,6 +19,15 @@ func TestResolveOpenAIForwardModel(t *testing.T) { defaultMappedModel: "gpt-4o-mini", expectedModel: "gpt-4o-mini", }, + { + name: "does not fall back to group default for invalid gpt model", + account: &Account{ + Credentials: map[string]any{}, + }, + requestedModel: "gpt6", + defaultMappedModel: "gpt-5.4", + expectedModel: "gpt6", + }, { name: "preserves explicit gpt-5.4 instead of group default", account: &Account{ @@ -85,6 +94,15 @@ func TestResolveOpenAIForwardModel(t *testing.T) { defaultMappedModel: "gpt-5.4", expectedModel: "gpt-5.5", }, + { + name: "preserves compact-spelled gpt5.5 instead of group default", + account: &Account{ + Credentials: map[string]any{}, + }, + requestedModel: "gpt5.5", + defaultMappedModel: "gpt-5.4", + expectedModel: "gpt5.5", + }, { name: "preserves openai namespaced gpt-5.5 instead of group default", account: &Account{ @@ -119,14 +137,14 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt54(t * Credentials: map[string]any{}, } - withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "")) - if withoutDefault != "gpt-5.4" { - t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.4") + withoutDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "") + if withoutDefault != "claude-opus-4-6" { + t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", withoutDefault, "claude-opus-4-6") } - withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")) + withDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4") if withDefault != "gpt-5.4" { - t.Fatalf("normalizeCodexModel(...) = %q, want %q", withDefault, "gpt-5.4") + t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", withDefault, "gpt-5.4") } } @@ -205,6 +223,10 @@ func TestNormalizeCodexModel(t *testing.T) { "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark", "gpt-5.3": "gpt-5.3-codex", "gpt-image-2": "gpt-image-2", + "gpt-5.4-nano": "gpt-5.4-nano", + "gpt-5.4-nano-high": "gpt-5.4-nano", + "gpt6": "gpt6", + "claude-opus-4-6": "claude-opus-4-6", } for input, expected := range cases { @@ -222,9 +244,21 @@ func TestNormalizeOpenAIModelForUpstream(t *testing.T) { want string }{ { - name: "oauth keeps codex normalization behavior", + name: "oauth preserves unknown non codex model", account: &Account{Type: AccountTypeOAuth}, model: "gemini-3-flash-preview", + want: "gemini-3-flash-preview", + }, + { + name: "oauth preserves invalid gpt model", + account: &Account{Type: AccountTypeOAuth}, + model: "gpt6", + want: "gpt6", + }, + { + name: "oauth normalizes known codex alias", + account: &Account{Type: AccountTypeOAuth}, + model: "gpt-5.4-high", want: "gpt-5.4", }, { diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 049ffdd8..398cbb85 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -25,9 +25,12 @@ func f64p(v float64) *float64 { return &v } type httpUpstreamRecorder struct { lastReq *http.Request lastBody []byte + requests []*http.Request + bodies [][]byte - resp *http.Response - err error + resp *http.Response + responses []*http.Response + err error } func (u *httpUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { @@ -35,12 +38,19 @@ func (u *httpUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID if req != nil && req.Body != nil { b, _ := io.ReadAll(req.Body) u.lastBody = b + u.bodies = append(u.bodies, append([]byte(nil), b...)) _ = req.Body.Close() req.Body = io.NopCloser(bytes.NewReader(b)) } + u.requests = append(u.requests, req) if u.err != nil { return nil, u.err } + if len(u.responses) > 0 { + resp := u.responses[0] + u.responses = u.responses[1:] + return resp, nil + } return u.resp, nil } @@ -48,6 +58,93 @@ func (u *httpUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, acc return u.Do(req, proxyURL, accountID, accountConcurrency) } +func TestOpenAIGatewayService_ResponsesUnknownModelDoesNotFallbackToGPT54(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + originalBody := []byte(`{"model":"gpt6","stream":false,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(originalBody)) + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_unknown_model"}}, + Body: io.NopCloser(strings.NewReader(`{"error":{"type":"invalid_request_error","message":"model not found"}}`)), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: upstream, + } + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + Status: StatusActive, + Schedulable: true, + } + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.Error(t, err) + require.Nil(t, result) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "https://chatgpt.com/backend-api/codex/responses", upstream.lastReq.URL.String()) + require.Equal(t, "gpt6", gjson.GetBytes(upstream.lastBody, "model").String()) + require.NotEqual(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String()) + require.True(t, rec.Code >= http.StatusBadRequest) +} + +func TestOpenAIGatewayService_OAuthMessagesBridgeDoesNotInjectDefaultInstructions(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + originalBody := []byte(`{"model":"gpt-5.5","stream":true,"prompt_cache_key":"anthropic-metadata-session-1","input":[{"type":"message","role":"developer","content":[{"type":"input_text","text":""}]},{"type":"message","role":"user","content":"hello"}]}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(originalBody)) + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_bridge"}}, + Body: io.NopCloser(strings.NewReader(`{"error":{"type":"invalid_request_error","message":"bridge stop"}}`)), + }} + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: upstream, + } + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + Status: StatusActive, + Schedulable: true, + } + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.Error(t, err) + require.Nil(t, result) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "", gjson.GetBytes(upstream.lastBody, "instructions").String()) + require.False(t, gjson.GetBytes(upstream.lastBody, "prompt_cache_key").Exists()) + require.NotEmpty(t, upstream.lastReq.Header.Get("Session_Id")) + require.Empty(t, upstream.lastReq.Header.Get("Conversation_Id")) + require.Empty(t, upstream.lastReq.Header.Get("OpenAI-Beta")) + require.Empty(t, upstream.lastReq.Header.Get("originator")) +} + type openAIPassthroughFailoverRepo struct { stubOpenAIAccountRepo rateLimitCalls []time.Time @@ -307,6 +404,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_CompactUsesJSONAndKeepsNonStreami require.Contains(t, rec.Body.String(), `"id":"cmp_123"`) } +func TestOpenAIGatewayService_OAuthPassthrough_UpstreamRequestIgnoresClientCancel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + reqCtx, cancel := context.WithCancel(context.Background()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + cancel() + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`) + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_passthrough_ctx"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1}}}`, + "", + "data: [DONE]", + "", + }, "\n"))), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(reqCtx, c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) +} + func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) { gin.SetMode(gin.TestMode) logSink, restore := captureStructuredLog(t) @@ -405,6 +548,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *te require.Contains(t, string(upstream.lastBody), `"stream":true`) } +func TestOpenAIGatewayService_OAuthLegacy_UpstreamRequestIgnoresClientCancel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + reqCtx, cancel := context.WithCancel(context.Background()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + cancel() + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_legacy_ctx"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`, + "", + "data: [DONE]", + "", + }, "\n"))), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": false, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(reqCtx, c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) +} + func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_sse_data.go b/backend/internal/service/openai_sse_data.go new file mode 100644 index 00000000..61b813b6 --- /dev/null +++ b/backend/internal/service/openai_sse_data.go @@ -0,0 +1,70 @@ +package service + +import ( + "strings" + + "github.com/tidwall/gjson" +) + +type openAISSEDataAccumulator struct { + lines []string +} + +func (a *openAISSEDataAccumulator) AddLine(line string, fn func([]byte)) { + if fn == nil { + return + } + trimmedLine := strings.TrimRight(line, "\r\n") + if data, ok := extractOpenAISSEDataLine(trimmedLine); ok { + a.lines = append(a.lines, data) + return + } + if strings.TrimSpace(trimmedLine) == "" { + a.Flush(fn) + } +} + +func (a *openAISSEDataAccumulator) Flush(fn func([]byte)) { + if fn == nil || len(a.lines) == 0 { + return + } + emitOpenAISSEDataPayloads(a.lines, fn) + a.lines = a.lines[:0] +} + +func forEachOpenAISSEDataPayload(body string, fn func([]byte)) { + if fn == nil || strings.TrimSpace(body) == "" { + return + } + var acc openAISSEDataAccumulator + for _, line := range strings.Split(body, "\n") { + acc.AddLine(line, fn) + } + acc.Flush(fn) +} + +func emitOpenAISSEDataPayloads(lines []string, fn func([]byte)) { + if fn == nil || len(lines) == 0 { + return + } + if len(lines) == 1 { + emitOpenAISSEDataPayload(lines[0], fn) + return + } + joined := strings.Join(lines, "\n") + if gjson.Valid(joined) { + emitOpenAISSEDataPayload(joined, fn) + return + } + for _, line := range lines { + emitOpenAISSEDataPayload(line, fn) + } +} + +func emitOpenAISSEDataPayload(data string, fn func([]byte)) { + data = strings.TrimSpace(data) + if data == "" || data == "[DONE]" { + return + } + fn([]byte(data)) +} diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index d1386b1b..784cdbe5 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -219,8 +219,11 @@ func (e *OpenAIWSClientCloseError) Reason() string { // OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。 type OpenAIWSIngressHooks struct { - BeforeTurn func(turn int) error - AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error) + // InitialRequestModel 是首帧渠道映射前的请求模型,只用于 usage metadata + // 的 reasoning effort 后缀推导,禁止用于上游请求或计费模型。 + InitialRequestModel string + BeforeTurn func(turn int) error + AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error) } func normalizeOpenAIWSLogValue(value string) string { @@ -1987,6 +1990,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( } usage := &OpenAIUsage{} + imageCounter := newOpenAIImageOutputCounter() var firstTokenMs *int responseID := "" var finalResponse []byte @@ -2168,6 +2172,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( if openAIWSEventShouldParseUsage(eventType) { parseOpenAIWSResponseUsageFromCompletedEvent(message, usage) } + imageCounter.AddSSEData(message) if eventType == "error" { errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) @@ -2340,6 +2345,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( Usage: *usage, Model: originalModel, UpstreamModel: mappedModel, + ImageCount: imageCounter.Count(), ServiceTier: extractOpenAIServiceTier(reqBody), ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), Stream: reqStream, @@ -2446,6 +2452,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( promptCacheKey string previousResponseID string originalModel string + imageBillingModel string + imageSizeTier string payloadBytes int } @@ -2543,6 +2551,19 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } normalized = next } + imageIntent := IsImageGenerationIntent(openAIResponsesEndpoint, originalModel, normalized) + if imageIntent && !GroupAllowsImageGeneration(apiKeyGroup(getAPIKeyFromContext(c))) { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, ImageGenerationPermissionMessage(), nil) + } + imageBillingModel := "" + imageSizeTier := "" + if imageIntent { + var imageCfgErr error + imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(normalized, originalModel) + if imageCfgErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, imageCfgErr.Error(), imageCfgErr) + } + } // Apply OpenAI Fast Policy on the response.create frame using the same // evaluator/normalize/scope rules as the HTTP entrypoints. This is the @@ -2588,6 +2609,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( promptCacheKey: promptCacheKey, previousResponseID: previousResponseID, originalModel: originalModel, + imageBillingModel: imageBillingModel, + imageSizeTier: imageSizeTier, payloadBytes: len(normalized), }, nil } @@ -2789,7 +2812,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( return payload, nil } - sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string) (*OpenAIForwardResult, error) { + sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string, imageBillingModel string, imageSizeTier string) (*OpenAIForwardResult, error) { if lease == nil { return nil, errors.New("upstream websocket lease is nil") } @@ -2814,6 +2837,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( responseID := "" usage := OpenAIUsage{} + imageCounter := newOpenAIImageOutputCounter() var firstTokenMs *int reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true) turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id") @@ -2935,6 +2959,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( if openAIWSEventShouldParseUsage(eventType) { parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage) } + imageCounter.AddSSEData(upstreamMessage) if !clientDisconnected { if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) { @@ -2994,7 +3019,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( clientDisconnected, ) } - return &OpenAIForwardResult{ + imageCount := imageCounter.Count() + result := &OpenAIForwardResult{ RequestID: responseID, Usage: usage, Model: originalModel, @@ -3006,13 +3032,21 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( ResponseHeaders: lease.HandshakeHeaders(), Duration: time.Since(turnStart), FirstTokenMs: firstTokenMs, - }, nil + } + if imageCount > 0 { + result.ImageCount = imageCount + result.ImageSize = imageSizeTier + result.BillingModel = imageBillingModel + } + return result, nil } } } currentPayload := firstPayload.payloadRaw currentOriginalModel := firstPayload.originalModel + currentImageBillingModel := firstPayload.imageBillingModel + currentImageSizeTier := firstPayload.imageSizeTier currentPayloadBytes := firstPayload.payloadBytes isStrictAffinityTurn := func(payload []byte) bool { if !storeDisabled { @@ -3101,6 +3135,12 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( if turnPrevRecoveryTried || !s.openAIWSIngressPreviousResponseRecoveryEnabled() { return false } + // 携带 function_call_output 的请求不能丢弃 previous_response_id: + // 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use, + // 丢弃后会导致 "No tool call found for function call output" 400 错误。 + if gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() { + return false + } if isStrictAffinityTurn(currentPayload) { // Layer 2:严格亲和链路命中 previous_response_not_found 时,降级为“去掉 previous_response_id 后重放一次”。 // 该错误说明续链锚点已失效,继续 strict fail-close 只会直接中断本轮请求。 @@ -3367,7 +3407,11 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( truncateOpenAIWSLogValue(pingErr.Error(), openAIWSLogValueMaxLen), ) if forcePreferredConn { - if !turnPrevRecoveryTried && currentPreviousResponseID != "" { + // 携带 function_call_output 的请求不能丢弃 previous_response_id: + // 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use, + // 丢弃后会导致 "No tool call found for function call output" 400 错误。 + hasFCOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() + if !turnPrevRecoveryTried && currentPreviousResponseID != "" && !hasFCOutput { updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) if dropErr != nil || !removed { reason := "not_removed" @@ -3457,7 +3501,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( ) } - result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel) + result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel, currentImageBillingModel, currentImageSizeTier) if relayErr != nil { lastTurnClean = false if recoverIngressPrevResponseNotFound(relayErr, turn, connID) { @@ -3579,6 +3623,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } currentPayload = nextPayload.payloadRaw currentOriginalModel = nextPayload.originalModel + currentImageBillingModel = nextPayload.imageBillingModel + currentImageSizeTier = nextPayload.imageSizeTier currentPayloadBytes = nextPayload.payloadBytes storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account) if !storeDisabled { diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go index 30fd4142..5246d37d 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -399,7 +399,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR }() writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) - err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast"}`)) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast","reasoning":{"effort":"HIGH"}}`)) cancelWrite() require.NoError(t, err) @@ -431,6 +431,8 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR require.Equal(t, 3, result.Usage.OutputTokens) require.NotNil(t, result.ServiceTier) require.Equal(t, "priority", *result.ServiceTier) + require.NotNil(t, result.ReasoningEffort) + require.Equal(t, "high", *result.ReasoningEffort) case <-time.After(2 * time.Second): t.Fatal("未收到 passthrough turn 结果回调") } diff --git a/backend/internal/service/openai_ws_forwarder_success_test.go b/backend/internal/service/openai_ws_forwarder_success_test.go index 7a76c385..cd816533 100644 --- a/backend/internal/service/openai_ws_forwarder_success_test.go +++ b/backend/internal/service/openai_ws_forwarder_success_test.go @@ -171,6 +171,127 @@ func TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky(t *testing.T) { require.Equal(t, "resp_new_1", gjson.GetBytes(responseBody, "id").String()) } +func TestOpenAIGatewayService_Forward_WSv2_ImageGenerationCountsOutputs(t *testing.T) { + gin.SetMode(gin.TestMode) + + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var request map[string]any + if err := conn.ReadJSON(&request); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + + if err := conn.WriteJSON(map[string]any{ + "type": "response.output_item.done", + "item": map[string]any{ + "id": "ig_ws_1", + "type": "image_generation_call", + "result": "final-image", + }, + }); err != nil { + t.Errorf("write response.output_item.done failed: %v", err) + return + } + if err := conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": "resp_ws_image_1", + "model": "gpt-5.4", + "output": []any{ + map[string]any{ + "id": "ig_ws_1", + "type": "image_generation_call", + "result": "final-image", + }, + }, + "usage": map[string]any{ + "input_tokens": 9, + "output_tokens": 4, + }, + }, + }); err != nil { + t.Errorf("write response.completed failed: %v", err) + return + } + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + groupID := int64(1010) + c.Set("api_key", &APIKey{ + GroupID: &groupID, + Group: &Group{ + ID: groupID, + AllowImageGeneration: true, + }, + }) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 10, + Name: "openai-ws-image", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.4","stream":false,"input":"draw","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1024x1024"}],"tool_choice":{"type":"image_generation"}}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_ws_image_1", result.RequestID) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, "1K", result.ImageSize) + require.Equal(t, "gpt-image-2", result.BillingModel) + require.Equal(t, 9, result.Usage.InputTokens) + require.Equal(t, 4, result.Usage.OutputTokens) + require.True(t, result.OpenAIWSMode) + require.Equal(t, "resp_ws_image_1", gjson.GetBytes(rec.Body.Bytes(), "id").String()) +} + func requestToJSONString(payload map[string]any) string { if len(payload) == 0 { return "{}" diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go index 3dbb199a..8bc17d42 100644 --- a/backend/internal/service/openai_ws_v2_passthrough_adapter.go +++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go @@ -124,6 +124,73 @@ func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload [] return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original)) } +type openAIWSPassthroughUsageMeta struct { + serviceTier atomic.Pointer[string] + reasoningEffort atomic.Pointer[string] + + // 仅在 client->upstream filter goroutine 中读写;Load 侧通过上方原子指针同步。 + sessionRequestModel string +} + +func newOpenAIWSPassthroughUsageMeta(initialRequestModel string, firstFrame []byte) *openAIWSPassthroughUsageMeta { + meta := &openAIWSPassthroughUsageMeta{ + sessionRequestModel: strings.TrimSpace(initialRequestModel), + } + if meta.sessionRequestModel == "" { + meta.sessionRequestModel = openAIWSPassthroughRequestModelForFrame(firstFrame) + } + return meta +} + +func (m *openAIWSPassthroughUsageMeta) initFromFirstFrame(policyOutput []byte) { + if m == nil { + return + } + m.serviceTier.Store(extractOpenAIServiceTierFromBody(policyOutput)) + m.reasoningEffort.Store(extractOpenAIReasoningEffortFromBody(policyOutput, m.sessionRequestModel)) +} + +func (m *openAIWSPassthroughUsageMeta) updateSessionRequestModel(payload []byte) { + if m == nil { + return + } + if model := openAIWSPassthroughRequestModelFromSessionFrame(payload); model != "" { + m.sessionRequestModel = model + } +} + +func (m *openAIWSPassthroughUsageMeta) requestModelForFrame(payload []byte) string { + if m == nil { + return openAIWSPassthroughRequestModelForFrame(payload) + } + if model := openAIWSPassthroughRequestModelForFrame(payload); model != "" { + return model + } + return m.sessionRequestModel +} + +func (m *openAIWSPassthroughUsageMeta) updateFromResponseCreate(policyOutput []byte, requestModelForFrame string) { + if m == nil { + return + } + m.serviceTier.Store(extractOpenAIServiceTierFromBody(policyOutput)) + m.reasoningEffort.Store(extractOpenAIReasoningEffortFromBody(policyOutput, requestModelForFrame)) +} + +func openAIWSPassthroughRequestModelForFrame(payload []byte) string { + if len(payload) == 0 || strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" { + return "" + } + return strings.TrimSpace(gjson.GetBytes(payload, "model").String()) +} + +func openAIWSPassthroughRequestModelFromSessionFrame(payload []byte) string { + if len(payload) == 0 || strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "session.update" { + return "" + } + return strings.TrimSpace(gjson.GetBytes(payload, "session.model").String()) +} + const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2" var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil) @@ -204,6 +271,11 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( // silently passed through, defeating the policy on every frame after // the first. capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage) + initialRequestModel := "" + if hooks != nil { + initialRequestModel = hooks.InitialRequestModel + } + usageMeta := newOpenAIWSPassthroughUsageMeta(initialRequestModel, firstClientMessage) updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage) if policyErr != nil { return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr) @@ -226,7 +298,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( } firstClientMessage = updatedFirst - // 在 policy filter 之后再提取 service_tier 用于 billing 上报:filter + // 在 policy filter 之后再提取 service_tier / reasoning_effort 用于 + // usage 上报:filter // 命中时 service_tier 已经从 firstClientMessage 中删除,billing 应当 // 反映上游实际处理的 tier(nil = default),而不是用户最初请求的 // "priority"。HTTP 入口(line ~2728 extractOpenAIServiceTier(reqBody)) @@ -237,11 +310,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( // codex-rs/core/src/client.rs build_responses_request 每次重新填值)。 // 因此使用 atomic.Pointer[string] 在 filter(runClientToUpstream // goroutine)和 OnTurnComplete / final result(runUpstreamToClient - // goroutine)之间同步当前 turn 的 service_tier。 - // extractOpenAIServiceTierFromBody 返回 *string,本身是指针类型, - // 可直接 Store/Load 而无需额外封装。 - var requestServiceTierPtr atomic.Pointer[string] - requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage)) + // goroutine)之间同步当前 turn 的 usage metadata。 + usageMeta.initFromFirstFrame(firstClientMessage) wsURL, err := s.buildOpenAIResponsesWSURL(account) if err != nil { @@ -327,6 +397,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" { capturedSessionModel = updated } + usageMeta.updateSessionRequestModel(payload) + requestModelForThisFrame := usageMeta.requestModelForFrame(payload) // Per-frame model first; if the client omits "model" on a // follow-up frame (legal in Realtime), fall back to the // session-level model captured from the first frame so the @@ -337,14 +409,14 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( model = capturedSessionModel } out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload) - // 多轮 passthrough billing:仅在成功(non-block / non-err) - // 的 response.create 帧上更新 requestServiceTierPtr,使用 + // 多轮 passthrough usage:仅在成功(non-block / non-err) + // 的 response.create 帧上更新 usageMeta,使用 // filter 处理后的 payload,与首帧 policy-after-extract 语义 // 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。 // - 非 response.create 帧(response.cancel / // conversation.item.create / session.update 等)不携带 - // per-response service_tier,不应覆盖前一轮值。 - // - blocked != nil:该帧不会发送上游,billing tier 应保持 + // per-response metadata,不应覆盖前一轮值。 + // - blocked != nil:该帧不会发送上游,usage metadata 应保持 // 上一轮值。 // - policyErr != nil:异常路径,保持上一轮值。 // - 不带 service_tier 的 response.create 会让 @@ -353,7 +425,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( // service_tier 时按 default 处理,billing 应如实反映。 if policyErr == nil && blocked == nil && strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" { - requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out)) + usageMeta.updateFromResponseCreate(out, requestModelForThisFrame) } return out, blocked, policyErr }, @@ -397,7 +469,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( CacheReadInputTokens: turn.Usage.CacheReadInputTokens, }, Model: turn.RequestModel, - ServiceTier: requestServiceTierPtr.Load(), + ServiceTier: usageMeta.serviceTier.Load(), + ReasoningEffort: usageMeta.reasoningEffort.Load(), Stream: true, OpenAIWSMode: true, ResponseHeaders: cloneHeader(handshakeHeaders), @@ -445,7 +518,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens, }, Model: relayResult.RequestModel, - ServiceTier: requestServiceTierPtr.Load(), + ServiceTier: usageMeta.serviceTier.Load(), + ReasoningEffort: usageMeta.reasoningEffort.Load(), Stream: true, OpenAIWSMode: true, ResponseHeaders: cloneHeader(handshakeHeaders), diff --git a/backend/internal/service/ops_cleanup_executor.go b/backend/internal/service/ops_cleanup_executor.go new file mode 100644 index 00000000..63a7367f --- /dev/null +++ b/backend/internal/service/ops_cleanup_executor.go @@ -0,0 +1,164 @@ +package service + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" +) + +const ( + opsCleanupDefaultSchedule = "0 2 * * *" + opsCleanupBatchSize = 5000 + opsCleanupCronStopTimeout = 3 * time.Second + opsCleanupRunTimeout = 30 * time.Minute + opsCleanupHeartbeatTimeout = 2 * time.Second +) + +type opsCleanupTarget struct { + retentionDays int + table string + timeCol string + castDate bool + counter *int64 +} + +type opsCleanupDeletedCounts struct { + errorLogs int64 + retryAttempts int64 + alertEvents int64 + systemLogs int64 + logAudits int64 + systemMetrics int64 + hourlyPreagg int64 + dailyPreagg int64 +} + +func (c opsCleanupDeletedCounts) String() string { + return fmt.Sprintf( + "error_logs=%d retry_attempts=%d alert_events=%d system_logs=%d log_audits=%d system_metrics=%d hourly_preagg=%d daily_preagg=%d", + c.errorLogs, + c.retryAttempts, + c.alertEvents, + c.systemLogs, + c.logAudits, + c.systemMetrics, + c.hourlyPreagg, + c.dailyPreagg, + ) +} + +// opsCleanupPlan 把"保留天数"翻译成具体的清理动作。 +// - days < 0 → 跳过该项清理(ok=false),保留兼容老数据 +// - days == 0 → TRUNCATE TABLE(O(1) 全清),truncate=true +// - days > 0 → 批量 DELETE 早于 now-N天 的行,cutoff = now - N 天 +func opsCleanupPlan(now time.Time, days int) (cutoff time.Time, truncate, ok bool) { + if days < 0 { + return time.Time{}, false, false + } + if days == 0 { + return time.Time{}, true, true + } + return now.AddDate(0, 0, -days), false, true +} + +func opsCleanupRunOne( + ctx context.Context, + db *sql.DB, + truncate bool, + cutoff time.Time, + table, timeCol string, + castDate bool, + batchSize int, +) (int64, error) { + if truncate { + return truncateOpsTable(ctx, db, table) + } + return deleteOldRowsByID(ctx, db, table, timeCol, cutoff, batchSize, castDate) +} + +func deleteOldRowsByID( + ctx context.Context, + db *sql.DB, + table string, + timeColumn string, + cutoff time.Time, + batchSize int, + castCutoffToDate bool, +) (int64, error) { + if db == nil { + return 0, nil + } + if batchSize <= 0 { + batchSize = opsCleanupBatchSize + } + + where := fmt.Sprintf("%s < $1", timeColumn) + if castCutoffToDate { + where = fmt.Sprintf("%s < $1::date", timeColumn) + } + + q := fmt.Sprintf(` +WITH batch AS ( + SELECT id FROM %s + WHERE %s + ORDER BY id + LIMIT $2 +) +DELETE FROM %s +WHERE id IN (SELECT id FROM batch) +`, table, where, table) + + var total int64 + for { + res, err := db.ExecContext(ctx, q, cutoff, batchSize) + if err != nil { + if isMissingRelationError(err) { + return total, nil + } + return total, err + } + affected, err := res.RowsAffected() + if err != nil { + return total, err + } + total += affected + if affected == 0 { + break + } + } + return total, nil +} + +// truncateOpsTable 用 TRUNCATE TABLE 清空指定表,先 SELECT COUNT(*) 取得清空前行数用于 heartbeat。 +func truncateOpsTable(ctx context.Context, db *sql.DB, table string) (int64, error) { + if db == nil { + return 0, nil + } + var count int64 + if err := db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", table)).Scan(&count); err != nil { + if isMissingRelationError(err) { + return 0, nil + } + return 0, fmt.Errorf("count %s: %w", table, err) + } + if count == 0 { + return 0, nil + } + if _, err := db.ExecContext(ctx, fmt.Sprintf("TRUNCATE TABLE %s", table)); err != nil { + if isMissingRelationError(err) { + return 0, nil + } + return 0, fmt.Errorf("truncate %s: %w", table, err) + } + return count, nil +} + +func isMissingRelationError(err error) bool { + if err == nil { + return false + } + s := strings.ToLower(err.Error()) + return strings.Contains(s, "does not exist") && strings.Contains(s, "relation") +} diff --git a/backend/internal/service/ops_cleanup_overlay_test.go b/backend/internal/service/ops_cleanup_overlay_test.go new file mode 100644 index 00000000..f751a426 --- /dev/null +++ b/backend/internal/service/ops_cleanup_overlay_test.go @@ -0,0 +1,257 @@ +//go:build unit + +package service + +import ( + "context" + "encoding/json" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +// makeOverlayService 构造一个没有 cron / db 的 cleanup service,仅用来测试 effective overlay。 +func makeOverlayService(repo SettingRepository, base config.OpsCleanupConfig) *OpsCleanupService { + cfg := &config.Config{} + cfg.Ops.Cleanup = base + return &OpsCleanupService{ + cfg: cfg, + settingRepo: repo, + } +} + +func writeAdvancedSettings(t *testing.T, repo *runtimeSettingRepoStub, dr OpsDataRetentionSettings) { + t.Helper() + adv := OpsAdvancedSettings{DataRetention: dr} + raw, err := json.Marshal(adv) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if err := repo.Set(context.Background(), SettingKeyOpsAdvancedSettings, string(raw)); err != nil { + t.Fatalf("set: %v", err) + } +} + +func TestComputeEffective_FallbackToCfgWhenSettingsAbsent(t *testing.T) { + repo := newRuntimeSettingRepoStub() + base := config.OpsCleanupConfig{ + Enabled: false, + Schedule: "0 2 * * *", + ErrorLogRetentionDays: 30, + MinuteMetricsRetentionDays: 30, + HourlyMetricsRetentionDays: 30, + } + svc := makeOverlayService(repo, base) + + svc.computeEffectiveLocked(context.Background()) + + if svc.effective != base { + t.Fatalf("expected effective == cfg base, got %#v", svc.effective) + } +} + +func TestComputeEffective_SettingsOverridesAll(t *testing.T) { + repo := newRuntimeSettingRepoStub() + writeAdvancedSettings(t, repo, OpsDataRetentionSettings{ + CleanupEnabled: true, + CleanupSchedule: "0 * * * *", + ErrorLogRetentionDays: 0, + MinuteMetricsRetentionDays: 7, + HourlyMetricsRetentionDays: 14, + }) + base := config.OpsCleanupConfig{ + Enabled: false, + Schedule: "0 2 * * *", + ErrorLogRetentionDays: 30, + MinuteMetricsRetentionDays: 30, + HourlyMetricsRetentionDays: 30, + } + svc := makeOverlayService(repo, base) + + svc.computeEffectiveLocked(context.Background()) + + want := config.OpsCleanupConfig{ + Enabled: true, + Schedule: "0 * * * *", + ErrorLogRetentionDays: 0, + MinuteMetricsRetentionDays: 7, + HourlyMetricsRetentionDays: 14, + } + if svc.effective != want { + t.Fatalf("effective mismatch:\nwant %#v\n got %#v", want, svc.effective) + } +} + +func TestComputeEffective_EmptyScheduleFallbackToCfg(t *testing.T) { + repo := newRuntimeSettingRepoStub() + writeAdvancedSettings(t, repo, OpsDataRetentionSettings{ + CleanupEnabled: true, + CleanupSchedule: " ", // 空白被 trim 后视为空 + ErrorLogRetentionDays: 5, + MinuteMetricsRetentionDays: 5, + HourlyMetricsRetentionDays: 5, + }) + base := config.OpsCleanupConfig{ + Enabled: false, + Schedule: "0 2 * * *", + ErrorLogRetentionDays: 30, + MinuteMetricsRetentionDays: 30, + HourlyMetricsRetentionDays: 30, + } + svc := makeOverlayService(repo, base) + + svc.computeEffectiveLocked(context.Background()) + + if svc.effective.Schedule != "0 2 * * *" { + t.Fatalf("expected schedule fallback to cfg, got %q", svc.effective.Schedule) + } + if !svc.effective.Enabled { + t.Fatalf("expected enabled=true from settings") + } + if svc.effective.ErrorLogRetentionDays != 5 { + t.Fatalf("expected retention=5 from settings, got %d", svc.effective.ErrorLogRetentionDays) + } +} + +func TestComputeEffective_NegativeRetentionFallsBackToCfg(t *testing.T) { + repo := newRuntimeSettingRepoStub() + writeAdvancedSettings(t, repo, OpsDataRetentionSettings{ + CleanupEnabled: true, + CleanupSchedule: "0 * * * *", + ErrorLogRetentionDays: -1, + MinuteMetricsRetentionDays: -1, + HourlyMetricsRetentionDays: -1, + }) + base := config.OpsCleanupConfig{ + Enabled: false, + Schedule: "0 2 * * *", + ErrorLogRetentionDays: 30, + MinuteMetricsRetentionDays: 60, + HourlyMetricsRetentionDays: 90, + } + svc := makeOverlayService(repo, base) + + svc.computeEffectiveLocked(context.Background()) + + if svc.effective.ErrorLogRetentionDays != 30 || + svc.effective.MinuteMetricsRetentionDays != 60 || + svc.effective.HourlyMetricsRetentionDays != 90 { + t.Fatalf("expected retention fallback to cfg, got %#v", svc.effective) + } +} + +func TestComputeEffective_BadJSONFallsBackToCfg(t *testing.T) { + repo := newRuntimeSettingRepoStub() + if err := repo.Set(context.Background(), SettingKeyOpsAdvancedSettings, "{not json"); err != nil { + t.Fatalf("set: %v", err) + } + base := config.OpsCleanupConfig{ + Enabled: true, + Schedule: "0 3 * * *", + ErrorLogRetentionDays: 30, + MinuteMetricsRetentionDays: 30, + HourlyMetricsRetentionDays: 30, + } + svc := makeOverlayService(repo, base) + + svc.computeEffectiveLocked(context.Background()) + + if svc.effective != base { + t.Fatalf("expected fallback to cfg on bad JSON, got %#v", svc.effective) + } +} + +// 验证 OpsService.UpdateOpsAdvancedSettings 写入后会调用 cleanupReloader.Reload。 +type fakeCleanupReloader struct { + calls int + last context.Context + err error +} + +func (f *fakeCleanupReloader) Reload(ctx context.Context) error { + f.calls++ + f.last = ctx + return f.err +} + +func TestUpdateOpsAdvancedSettings_TriggersReload(t *testing.T) { + repo := newRuntimeSettingRepoStub() + reloader := &fakeCleanupReloader{} + svc := &OpsService{settingRepo: repo} + svc.SetCleanupReloader(reloader) + + cfg := defaultOpsAdvancedSettings() + cfg.DataRetention.CleanupEnabled = true + cfg.DataRetention.CleanupSchedule = "0 * * * *" + cfg.DataRetention.ErrorLogRetentionDays = 3 + cfg.DataRetention.MinuteMetricsRetentionDays = 3 + cfg.DataRetention.HourlyMetricsRetentionDays = 3 + + if _, err := svc.UpdateOpsAdvancedSettings(context.Background(), cfg); err != nil { + t.Fatalf("update: %v", err) + } + if reloader.calls != 1 { + t.Fatalf("expected reloader.Reload called once, got %d", reloader.calls) + } +} + +func TestReload_BeforeStart_IsNoop(t *testing.T) { + svc := &OpsCleanupService{} + if err := svc.Reload(context.Background()); err != nil { + t.Fatalf("Reload before Start should return nil, got %v", err) + } +} + +func TestReload_AfterStop_IsNoop(t *testing.T) { + svc := &OpsCleanupService{started: true, stopped: true} + if err := svc.Reload(context.Background()); err != nil { + t.Fatalf("Reload after Stop should return nil, got %v", err) + } +} + +func TestUpdateOpsAdvancedSettings_NilReloader_NoPanic(t *testing.T) { + repo := newRuntimeSettingRepoStub() + svc := &OpsService{settingRepo: repo} + // cleanupReloader intentionally nil + + cfg := defaultOpsAdvancedSettings() + cfg.DataRetention.ErrorLogRetentionDays = 7 + + // should not panic + if _, err := svc.UpdateOpsAdvancedSettings(context.Background(), cfg); err != nil { + t.Fatalf("update with nil reloader: %v", err) + } +} + +func TestStart_IdempotentSecondCall(t *testing.T) { + svc := &OpsCleanupService{started: true} + svc.Start() // second call should be noop, not panic +} + +func TestRefreshEffectiveBeforeRun_UpdatesSnapshot(t *testing.T) { + repo := newRuntimeSettingRepoStub() + base := config.OpsCleanupConfig{ + Enabled: true, + Schedule: "0 2 * * *", + ErrorLogRetentionDays: 30, + } + svc := makeOverlayService(repo, base) + svc.computeEffectiveLocked(context.Background()) + + if svc.effective.ErrorLogRetentionDays != 30 { + t.Fatalf("initial retention should be 30, got %d", svc.effective.ErrorLogRetentionDays) + } + + // simulate UI change + writeAdvancedSettings(t, repo, OpsDataRetentionSettings{ + CleanupEnabled: true, + CleanupSchedule: "0 * * * *", + ErrorLogRetentionDays: 7, + }) + + svc.refreshEffectiveBeforeRun(context.Background()) + snap := svc.snapshotEffective() + if snap.ErrorLogRetentionDays != 7 { + t.Fatalf("after refresh, retention should be 7, got %d", snap.ErrorLogRetentionDays) + } +} diff --git a/backend/internal/service/ops_cleanup_service.go b/backend/internal/service/ops_cleanup_service.go index 44ec1ad1..60a690f3 100644 --- a/backend/internal/service/ops_cleanup_service.go +++ b/backend/internal/service/ops_cleanup_service.go @@ -3,6 +3,8 @@ package service import ( "context" "database/sql" + "encoding/json" + "errors" "fmt" "strings" "sync" @@ -45,13 +47,18 @@ type OpsCleanupService struct { redisClient *redis.Client cfg *config.Config channelMonitorSvc *ChannelMonitorService + settingRepo SettingRepository instanceID string - cron *cron.Cron - - startOnce sync.Once - stopOnce sync.Once + // mu 守护 cron 实例切换 + effective 配置切换。 + // 这里不再用 startOnce/stopOnce,是因为 Reload 需要"停旧 cron 重启新 cron", + // 而 Once 一旦触发就无法再次执行;改为 started/stopped 布尔配合 mu。 + mu sync.Mutex + cron *cron.Cron + started bool + stopped bool + effective config.OpsCleanupConfig warnNoRedisOnce sync.Once } @@ -62,6 +69,7 @@ func NewOpsCleanupService( redisClient *redis.Client, cfg *config.Config, channelMonitorSvc *ChannelMonitorService, + settingRepo SettingRepository, ) *OpsCleanupService { return &OpsCleanupService{ opsRepo: opsRepo, @@ -69,10 +77,13 @@ func NewOpsCleanupService( redisClient: redisClient, cfg: cfg, channelMonitorSvc: channelMonitorSvc, + settingRepo: settingRepo, instanceID: uuid.NewString(), } } +// Start 首次启动 cron 调度。Enabled / Schedule 由 effective 配置决定(settings 优先 cfg)。 +// 重复调用幂等。 func (s *OpsCleanupService) Start() { if s == nil { return @@ -80,54 +91,169 @@ func (s *OpsCleanupService) Start() { if s.cfg != nil && !s.cfg.Ops.Enabled { return } - if s.cfg != nil && !s.cfg.Ops.Cleanup.Enabled { - logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] not started (disabled)") - return - } if s.opsRepo == nil || s.db == nil { logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] not started (missing deps)") return } - s.startOnce.Do(func() { - schedule := "0 2 * * *" - if s.cfg != nil && strings.TrimSpace(s.cfg.Ops.Cleanup.Schedule) != "" { - schedule = strings.TrimSpace(s.cfg.Ops.Cleanup.Schedule) - } - - loc := time.Local - if s.cfg != nil && strings.TrimSpace(s.cfg.Timezone) != "" { - if parsed, err := time.LoadLocation(strings.TrimSpace(s.cfg.Timezone)); err == nil && parsed != nil { - loc = parsed - } - } - - c := cron.New(cron.WithParser(opsCleanupCronParser), cron.WithLocation(loc)) - _, err := c.AddFunc(schedule, func() { s.runScheduled() }) - if err != nil { - logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] not started (invalid schedule=%q): %v", schedule, err) - return - } - s.cron = c - s.cron.Start() - logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] started (schedule=%q tz=%s)", schedule, loc.String()) - }) + s.mu.Lock() + defer s.mu.Unlock() + if s.started || s.stopped { + return + } + s.started = true + if err := s.applyScheduleLocked(context.Background()); err != nil { + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] not started: %v", err) + } } +// Stop 关闭 cron。幂等。 func (s *OpsCleanupService) Stop() { if s == nil { return } - s.stopOnce.Do(func() { - if s.cron != nil { - ctx := s.cron.Stop() - select { - case <-ctx.Done(): - case <-time.After(3 * time.Second): - logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] cron stop timed out") - } + s.mu.Lock() + defer s.mu.Unlock() + if s.stopped { + return + } + s.stopped = true + s.stopCronLocked() +} + +// stopCronLocked 停掉当前 cron 实例(带 3s 超时)。调用方持锁。 +func (s *OpsCleanupService) stopCronLocked() { + if s.cron == nil { + return + } + ctx := s.cron.Stop() + select { + case <-ctx.Done(): + case <-time.After(opsCleanupCronStopTimeout): + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] cron stop timed out") + } + s.cron = nil +} + +// applyScheduleLocked 重新计算 effective 配置并按其 schedule 重建 cron。调用方持锁。 +// 若 effective.Enabled=false(用户在 UI 关闭清理),停旧 cron 后直接返回,不创建新 cron。 +func (s *OpsCleanupService) applyScheduleLocked(ctx context.Context) error { + s.computeEffectiveLocked(ctx) + s.stopCronLocked() + + if !s.effective.Enabled { + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] cron disabled by settings") + return nil + } + + schedule := strings.TrimSpace(s.effective.Schedule) + if schedule == "" { + schedule = opsCleanupDefaultSchedule + } + + loc := time.Local + if s.cfg != nil && strings.TrimSpace(s.cfg.Timezone) != "" { + if parsed, err := time.LoadLocation(strings.TrimSpace(s.cfg.Timezone)); err == nil && parsed != nil { + loc = parsed } - }) + } + + c := cron.New(cron.WithParser(opsCleanupCronParser), cron.WithLocation(loc)) + if _, err := c.AddFunc(schedule, func() { s.runScheduled() }); err != nil { + return fmt.Errorf("invalid schedule %q: %w", schedule, err) + } + c.Start() + s.cron = c + logger.LegacyPrintf("service.ops_cleanup", + "[OpsCleanup] scheduled (schedule=%q tz=%s retention_days=err:%d/min:%d/hour:%d)", + schedule, loc.String(), + s.effective.ErrorLogRetentionDays, + s.effective.MinuteMetricsRetentionDays, + s.effective.HourlyMetricsRetentionDays, + ) + return nil +} + +// Reload 重新读取 ops_advanced_settings.data_retention 并按新配置重建 cron。 +// 适用于 admin 在 UI 修改清理设置后立即生效(schedule / enabled 改动需要 Reload; +// retention 改动 runScheduled 顶部也会刷新,下一次触发即生效)。 +// 若 service 还未 Start 或已 Stop,Reload 不做任何事。 +func (s *OpsCleanupService) Reload(ctx context.Context) error { + if s == nil { + return nil + } + s.mu.Lock() + defer s.mu.Unlock() + if !s.started || s.stopped { + return nil + } + return s.applyScheduleLocked(ctx) +} + +// computeEffectiveLocked 计算"生效配置"并写入 s.effective。调用方持锁。 +// +// 优先级:UI 写入的 settings.ops_advanced_settings.data_retention(权威)覆盖 cfg.Ops.Cleanup 的副本。 +// - Enabled:settings 直接覆盖 +// - Schedule:settings 非空时覆盖,否则保留 cfg +// - *RetentionDays:settings >=0 时覆盖(包括 0=TRUNCATE),<0 沿用 cfg +// +// 若 settings 表无该 key(ErrSettingNotFound)或解析失败,整体 fallback 到 cfg.Ops.Cleanup。 +func (s *OpsCleanupService) computeEffectiveLocked(ctx context.Context) { + base := config.OpsCleanupConfig{} + if s.cfg != nil { + base = s.cfg.Ops.Cleanup + } + defer func() { s.effective = base }() + + if s.settingRepo == nil { + return + } + if ctx == nil { + ctx = context.Background() + } + raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpsAdvancedSettings) + if err != nil { + if !errors.Is(err, ErrSettingNotFound) { + logger.LegacyPrintf("service.ops_cleanup", + "[OpsCleanup] read advanced settings failed, using cfg: %v", err) + } + return + } + var adv OpsAdvancedSettings + if err := json.Unmarshal([]byte(raw), &adv); err != nil { + logger.LegacyPrintf("service.ops_cleanup", + "[OpsCleanup] parse advanced settings failed, using cfg: %v", err) + return + } + dr := adv.DataRetention + base.Enabled = dr.CleanupEnabled + if sched := strings.TrimSpace(dr.CleanupSchedule); sched != "" { + base.Schedule = sched + } + if dr.ErrorLogRetentionDays >= 0 { + base.ErrorLogRetentionDays = dr.ErrorLogRetentionDays + } + if dr.MinuteMetricsRetentionDays >= 0 { + base.MinuteMetricsRetentionDays = dr.MinuteMetricsRetentionDays + } + if dr.HourlyMetricsRetentionDays >= 0 { + base.HourlyMetricsRetentionDays = dr.HourlyMetricsRetentionDays + } +} + +// snapshotEffective 取一份 effective 副本(runCleanupOnce 等读路径使用)。 +func (s *OpsCleanupService) snapshotEffective() config.OpsCleanupConfig { + s.mu.Lock() + defer s.mu.Unlock() + return s.effective +} + +// refreshEffectiveBeforeRun 在 cron 触发时刷新 effective,让 retention 改动当次即生效。 +// schedule 改动不影响当次(cron 调度由库管理,需要 Reload 才换 schedule)。 +func (s *OpsCleanupService) refreshEffectiveBeforeRun(ctx context.Context) { + s.mu.Lock() + defer s.mu.Unlock() + s.computeEffectiveLocked(ctx) } func (s *OpsCleanupService) runScheduled() { @@ -135,9 +261,12 @@ func (s *OpsCleanupService) runScheduled() { return } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), opsCleanupRunTimeout) defer cancel() + // 让 retention 改动当次生效(schedule/enabled 改动需要 Reload)。 + s.refreshEffectiveBeforeRun(ctx) + release, ok := s.tryAcquireLeaderLock(ctx) if !ok { return @@ -159,124 +288,36 @@ func (s *OpsCleanupService) runScheduled() { logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] cleanup complete: %s", counts) } -type opsCleanupDeletedCounts struct { - errorLogs int64 - retryAttempts int64 - alertEvents int64 - systemLogs int64 - logAudits int64 - systemMetrics int64 - hourlyPreagg int64 - dailyPreagg int64 -} - -func (c opsCleanupDeletedCounts) String() string { - return fmt.Sprintf( - "error_logs=%d retry_attempts=%d alert_events=%d system_logs=%d log_audits=%d system_metrics=%d hourly_preagg=%d daily_preagg=%d", - c.errorLogs, - c.retryAttempts, - c.alertEvents, - c.systemLogs, - c.logAudits, - c.systemMetrics, - c.hourlyPreagg, - c.dailyPreagg, - ) -} - -// opsCleanupPlan 把"保留天数"翻译成具体的清理动作。 -// - days < 0 → 跳过该项清理(ok=false),保留兼容老数据 -// - days == 0 → TRUNCATE TABLE(O(1) 全清),truncate=true -// - days > 0 → 批量 DELETE 早于 now-N天 的行,cutoff = now - N 天 -// -// 之所以 days==0 走 TRUNCATE 而非"now+24h cutoff + DELETE": -// - 速度从 O(N) 降到 O(1),对百万行级表毫秒完成 -// - 无 WAL 写入、无后续 VACUUM 压力 -// - 这些 ops 表只有 cleanup 任务自己写,TRUNCATE 的 ACCESS EXCLUSIVE 锁影响可忽略 -func opsCleanupPlan(now time.Time, days int) (cutoff time.Time, truncate, ok bool) { - if days < 0 { - return time.Time{}, false, false - } - if days == 0 { - return time.Time{}, true, true - } - return now.AddDate(0, 0, -days), false, true -} - func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDeletedCounts, error) { out := opsCleanupDeletedCounts{} if s == nil || s.db == nil || s.cfg == nil { return out, nil } - batchSize := 5000 - + effective := s.snapshotEffective() now := time.Now().UTC() - // runOne 把"truncate? cutoff? batched delete?"封装到一处, - // 让三组清理(错误日志类 / 分钟指标 / 小时+日预聚合)调用方只关心表名和列名。 - runOne := func(truncate bool, cutoff time.Time, table, timeCol string, castDate bool) (int64, error) { - if truncate { - return truncateOpsTable(ctx, s.db, table) - } - return deleteOldRowsByID(ctx, s.db, table, timeCol, cutoff, batchSize, castDate) + targets := []opsCleanupTarget{ + {effective.ErrorLogRetentionDays, "ops_error_logs", "created_at", false, &out.errorLogs}, + {effective.ErrorLogRetentionDays, "ops_retry_attempts", "created_at", false, &out.retryAttempts}, + {effective.ErrorLogRetentionDays, "ops_alert_events", "created_at", false, &out.alertEvents}, + {effective.ErrorLogRetentionDays, "ops_system_logs", "created_at", false, &out.systemLogs}, + {effective.ErrorLogRetentionDays, "ops_system_log_cleanup_audits", "created_at", false, &out.logAudits}, + {effective.MinuteMetricsRetentionDays, "ops_system_metrics", "created_at", false, &out.systemMetrics}, + {effective.HourlyMetricsRetentionDays, "ops_metrics_hourly", "bucket_start", false, &out.hourlyPreagg}, + {effective.HourlyMetricsRetentionDays, "ops_metrics_daily", "bucket_date", true, &out.dailyPreagg}, } - // Error-like tables: error logs / retry attempts / alert events / system logs / cleanup audits. - if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.ErrorLogRetentionDays); ok { - n, err := runOne(truncate, cutoff, "ops_error_logs", "created_at", false) + for _, t := range targets { + cutoff, truncate, ok := opsCleanupPlan(now, t.retentionDays) + if !ok { + continue + } + n, err := opsCleanupRunOne(ctx, s.db, truncate, cutoff, t.table, t.timeCol, t.castDate, opsCleanupBatchSize) if err != nil { return out, err } - out.errorLogs = n - - n, err = runOne(truncate, cutoff, "ops_retry_attempts", "created_at", false) - if err != nil { - return out, err - } - out.retryAttempts = n - - n, err = runOne(truncate, cutoff, "ops_alert_events", "created_at", false) - if err != nil { - return out, err - } - out.alertEvents = n - - n, err = runOne(truncate, cutoff, "ops_system_logs", "created_at", false) - if err != nil { - return out, err - } - out.systemLogs = n - - n, err = runOne(truncate, cutoff, "ops_system_log_cleanup_audits", "created_at", false) - if err != nil { - return out, err - } - out.logAudits = n - } - - // Minute-level metrics snapshots. - if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays); ok { - n, err := runOne(truncate, cutoff, "ops_system_metrics", "created_at", false) - if err != nil { - return out, err - } - out.systemMetrics = n - } - - // Pre-aggregation tables (hourly/daily). - if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays); ok { - n, err := runOne(truncate, cutoff, "ops_metrics_hourly", "bucket_start", false) - if err != nil { - return out, err - } - out.hourlyPreagg = n - - n, err = runOne(truncate, cutoff, "ops_metrics_daily", "bucket_date", true) - if err != nil { - return out, err - } - out.dailyPreagg = n + *t.counter = n } // Channel monitor 每日维护(聚合昨日明细 + 软删过期明细/聚合)。 @@ -291,100 +332,6 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet return out, nil } -func deleteOldRowsByID( - ctx context.Context, - db *sql.DB, - table string, - timeColumn string, - cutoff time.Time, - batchSize int, - castCutoffToDate bool, -) (int64, error) { - if db == nil { - return 0, nil - } - if batchSize <= 0 { - batchSize = 5000 - } - - where := fmt.Sprintf("%s < $1", timeColumn) - if castCutoffToDate { - where = fmt.Sprintf("%s < $1::date", timeColumn) - } - - q := fmt.Sprintf(` -WITH batch AS ( - SELECT id FROM %s - WHERE %s - ORDER BY id - LIMIT $2 -) -DELETE FROM %s -WHERE id IN (SELECT id FROM batch) -`, table, where, table) - - var total int64 - for { - res, err := db.ExecContext(ctx, q, cutoff, batchSize) - if err != nil { - // If ops tables aren't present yet (partial deployments), treat as no-op. - if isMissingRelationError(err) { - return total, nil - } - return total, err - } - affected, err := res.RowsAffected() - if err != nil { - return total, err - } - total += affected - if affected == 0 { - break - } - } - return total, nil -} - -// truncateOpsTable 用 TRUNCATE TABLE 清空指定表,先 SELECT COUNT(*) 取得清空前行数用于 heartbeat。 -// -// 与 deleteOldRowsByID 的差异: -// - 不可指定 WHERE 条件,仅用于 days==0 的"清空全部"语义 -// - O(1) 释放表的物理存储页,毫秒级完成,无 WAL 写入、无 VACUUM 压力 -// - 需要 ACCESS EXCLUSIVE 锁,但 ops 表只有清理任务自己写入,瞬间锁影响可忽略 -// -// 表不存在(部分部署)静默返回 0,与 deleteOldRowsByID 保持一致。 -func truncateOpsTable(ctx context.Context, db *sql.DB, table string) (int64, error) { - if db == nil { - return 0, nil - } - var count int64 - if err := db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", table)).Scan(&count); err != nil { - if isMissingRelationError(err) { - return 0, nil - } - return 0, fmt.Errorf("count %s: %w", table, err) - } - if count == 0 { - return 0, nil - } - if _, err := db.ExecContext(ctx, fmt.Sprintf("TRUNCATE TABLE %s", table)); err != nil { - if isMissingRelationError(err) { - return 0, nil - } - return 0, fmt.Errorf("truncate %s: %w", table, err) - } - return count, nil -} - -// isMissingRelationError 判断 PG 报错是否为"表不存在",用于让清理任务在部分部署场景静默跳过。 -func isMissingRelationError(err error) bool { - if err == nil { - return false - } - s := strings.ToLower(err.Error()) - return strings.Contains(s, "does not exist") && strings.Contains(s, "relation") -} - func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), bool) { if s == nil { return nil, false @@ -433,7 +380,7 @@ func (s *OpsCleanupService) recordHeartbeatSuccess(runAt time.Time, duration tim now := time.Now().UTC() durMs := duration.Milliseconds() result := truncateString(counts.String(), 2048) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), opsCleanupHeartbeatTimeout) defer cancel() _ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{ JobName: opsCleanupJobName, @@ -451,7 +398,7 @@ func (s *OpsCleanupService) recordHeartbeatError(runAt time.Time, duration time. now := time.Now().UTC() durMs := duration.Milliseconds() msg := truncateString(err.Error(), 2048) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), opsCleanupHeartbeatTimeout) defer cancel() _ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{ JobName: opsCleanupJobName, diff --git a/backend/internal/service/ops_service.go b/backend/internal/service/ops_service.go index cd3974a0..11afc6f9 100644 --- a/backend/internal/service/ops_service.go +++ b/backend/internal/service/ops_service.go @@ -54,6 +54,24 @@ type OpsService struct { geminiCompatService *GeminiMessagesCompatService antigravityGatewayService *AntigravityGatewayService systemLogSink *OpsSystemLogSink + + // cleanupReloader 由 wire 在 OpsCleanupService 构造完成后通过 SetCleanupReloader 注入。 + // 解耦避免 OpsService -> OpsCleanupService 的硬依赖(cleanup 也读 settings,会循环)。 + cleanupReloader CleanupReloader +} + +// CleanupReloader 由 OpsCleanupService 实现。 +// UpdateOpsAdvancedSettings 写入新配置后调用 Reload,让 schedule/enabled 改动立刻生效。 +type CleanupReloader interface { + Reload(ctx context.Context) error +} + +// SetCleanupReloader 由 wire 注入 cleanup hook(构造期循环依赖的解耦点)。 +func (s *OpsService) SetCleanupReloader(r CleanupReloader) { + if s == nil { + return + } + s.cleanupReloader = r } func NewOpsService( diff --git a/backend/internal/service/ops_settings.go b/backend/internal/service/ops_settings.go index ecc3a94b..68c1d9dd 100644 --- a/backend/internal/service/ops_settings.go +++ b/backend/internal/service/ops_settings.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "strings" "time" ) @@ -360,7 +361,7 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings { return &OpsAdvancedSettings{ DataRetention: OpsDataRetentionSettings{ CleanupEnabled: false, - CleanupSchedule: "0 2 * * *", + CleanupSchedule: opsCleanupDefaultSchedule, ErrorLogRetentionDays: 30, MinuteMetricsRetentionDays: 30, HourlyMetricsRetentionDays: 30, @@ -385,7 +386,7 @@ func normalizeOpsAdvancedSettings(cfg *OpsAdvancedSettings) { } cfg.DataRetention.CleanupSchedule = strings.TrimSpace(cfg.DataRetention.CleanupSchedule) if cfg.DataRetention.CleanupSchedule == "" { - cfg.DataRetention.CleanupSchedule = "0 2 * * *" + cfg.DataRetention.CleanupSchedule = opsCleanupDefaultSchedule } // 保留天数:0 表示每次定时清理全部(清空所有),> 0 表示按天数保留; // 仅在拿到非法的负数时回填默认值,避免覆盖用户主动设的 0。 @@ -477,6 +478,14 @@ func (s *OpsService) UpdateOpsAdvancedSettings(ctx context.Context, cfg *OpsAdva return nil, err } + // notify cleanup service to reload schedule/enabled. + if s.cleanupReloader != nil { + if rerr := s.cleanupReloader.Reload(ctx); rerr != nil { + logger.LegacyPrintf("service.ops_settings", + "[OpsSettings] cleanup reload after advanced-settings update failed: %v", rerr) + } + } + updated := &OpsAdvancedSettings{} _ = json.Unmarshal(raw, updated) return updated, nil diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 5df69aea..4ae6d134 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -394,7 +394,8 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db return nil } - rebateAmount, err := s.affiliateService.AccrueInviteRebate(txCtx, o.UserID, o.Amount) + sourceOrderID := o.ID + rebateAmount, err := s.affiliateService.AccrueInviteRebateForOrder(txCtx, o.UserID, o.Amount, &sourceOrderID) if err != nil { s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ "error": err.Error(), diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 91a02901..8a033710 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -625,6 +625,9 @@ func normalizeModelNameForPricing(model string) string { } model = strings.TrimLeft(model, "/") + if canonical := canonicalizeOpenAIModelAliasSpelling(model); canonical != "" { + return canonical + } return model } diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go index e2bd7cf3..3c3e2c5b 100644 --- a/backend/internal/service/pricing_service_test.go +++ b/backend/internal/service/pricing_service_test.go @@ -98,6 +98,19 @@ func TestGetModelPricing_Gpt54UsesStaticFallbackWhenRemoteMissing(t *testing.T) require.InDelta(t, 1.5, got.LongContextOutputCostMultiplier, 1e-12) } +func TestGetModelPricing_OpenAICompactAliasUsesStaticFallback(t *testing.T) { + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.1-codex": {InputCostPerToken: 1.25e-6}, + }, + } + + got := svc.GetModelPricing("openai/gpt5.5") + require.NotNil(t, got) + require.InDelta(t, 2.5e-6, got.InputCostPerToken, 1e-12) + require.InDelta(t, 1.5e-5, got.OutputCostPerToken, 1e-12) +} + func TestGetModelPricing_Gpt54MiniUsesDedicatedStaticFallbackWhenRemoteMissing(t *testing.T) { svc := &PricingService{ pricingData: map[string]*LiteLLMModelPricing{ diff --git a/backend/internal/service/rate_limit_429_cooldown_test.go b/backend/internal/service/rate_limit_429_cooldown_test.go new file mode 100644 index 00000000..fb7e0dd7 --- /dev/null +++ b/backend/internal/service/rate_limit_429_cooldown_test.go @@ -0,0 +1,113 @@ +//go:build unit + +package service + +import ( + "context" + "encoding/json" + "net/http" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type rateLimit429AccountRepoStub struct { + mockAccountRepoForGemini + rateLimitCalls int + lastRateLimitID int64 + lastRateLimitReset time.Time +} + +func (r *rateLimit429AccountRepoStub) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error { + r.rateLimitCalls++ + r.lastRateLimitID = id + r.lastRateLimitReset = resetAt + return nil +} + +func TestGetRateLimit429CooldownSettings_DefaultsWhenNotSet(t *testing.T) { + repo := newMockSettingRepo() + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetRateLimit429CooldownSettings(context.Background()) + require.NoError(t, err) + require.True(t, settings.Enabled) + require.Equal(t, 5, settings.CooldownSeconds) +} + +func TestGetRateLimit429CooldownSettings_ReadsFromDB(t *testing.T) { + repo := newMockSettingRepo() + data, _ := json.Marshal(RateLimit429CooldownSettings{Enabled: false, CooldownSeconds: 12}) + repo.data[SettingKeyRateLimit429CooldownSettings] = string(data) + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetRateLimit429CooldownSettings(context.Background()) + require.NoError(t, err) + require.False(t, settings.Enabled) + require.Equal(t, 12, settings.CooldownSeconds) +} + +func TestSetRateLimit429CooldownSettings_EnabledRejectsOutOfRange(t *testing.T) { + svc := NewSettingService(newMockSettingRepo(), &config.Config{}) + + for _, seconds := range []int{0, -1, 7201, 99999} { + err := svc.SetRateLimit429CooldownSettings(context.Background(), &RateLimit429CooldownSettings{ + Enabled: true, CooldownSeconds: seconds, + }) + require.Error(t, err, "should reject enabled=true + cooldown_seconds=%d", seconds) + require.Contains(t, err.Error(), "cooldown_seconds must be between 1-7200") + } +} + +func TestHandle429_FallbackUsesDBSeconds(t *testing.T) { + accountRepo := &rateLimit429AccountRepoStub{} + settingRepo := newMockSettingRepo() + data, _ := json.Marshal(RateLimit429CooldownSettings{Enabled: true, CooldownSeconds: 12}) + settingRepo.data[SettingKeyRateLimit429CooldownSettings] = string(data) + + settingSvc := NewSettingService(settingRepo, &config.Config{}) + svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil) + svc.SetSettingService(settingSvc) + + account := &Account{ID: 42, Platform: PlatformOpenAI, Type: AccountTypeOAuth} + before := time.Now() + svc.handle429(context.Background(), account, http.Header{}, []byte(`{"error":{"type":"rate_limit_error","message":"slow down"}}`)) + after := time.Now() + + require.Equal(t, 1, accountRepo.rateLimitCalls) + require.Equal(t, int64(42), accountRepo.lastRateLimitID) + require.True(t, !accountRepo.lastRateLimitReset.Before(before.Add(12*time.Second)) && !accountRepo.lastRateLimitReset.After(after.Add(12*time.Second))) +} + +func TestHandle429_FallbackDisabledSkipsLocalMark(t *testing.T) { + accountRepo := &rateLimit429AccountRepoStub{} + settingRepo := newMockSettingRepo() + data, _ := json.Marshal(RateLimit429CooldownSettings{Enabled: false, CooldownSeconds: 12}) + settingRepo.data[SettingKeyRateLimit429CooldownSettings] = string(data) + + settingSvc := NewSettingService(settingRepo, &config.Config{}) + svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil) + svc.SetSettingService(settingSvc) + + account := &Account{ID: 43, Platform: PlatformOpenAI, Type: AccountTypeOAuth} + svc.handle429(context.Background(), account, http.Header{}, []byte(`{"error":{"type":"rate_limit_error","message":"slow down"}}`)) + + require.Zero(t, accountRepo.rateLimitCalls) +} + +func TestHandle429_FallbackUsesDefaultSecondsWhenSettingServiceMissing(t *testing.T) { + accountRepo := &rateLimit429AccountRepoStub{} + cfg := &config.Config{} + svc := NewRateLimitService(accountRepo, nil, cfg, nil, nil) + + account := &Account{ID: 44, Platform: PlatformGemini, Type: AccountTypeAPIKey} + before := time.Now() + svc.handle429(context.Background(), account, http.Header{}, []byte(`{"error":{"message":"slow down"}}`)) + after := time.Now() + + require.Equal(t, 1, accountRepo.rateLimitCalls) + require.Equal(t, int64(44), accountRepo.lastRateLimitID) + require.True(t, !accountRepo.lastRateLimitReset.Before(before.Add(5*time.Second)) && !accountRepo.lastRateLimitReset.After(after.Add(5*time.Second))) +} diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 9344de47..a53cb0e9 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -55,6 +55,11 @@ type geminiUsageTotalsBatchProvider interface { const geminiPrecheckCacheTTL = time.Minute +const ( + defaultRateLimit429CooldownSeconds = 5 + maxRateLimit429CooldownSeconds = 7200 +) + const ( openAI403CooldownMinutesDefault = 10 openAI403DisableThreshold = 3 @@ -891,12 +896,8 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head return } - // 其他平台:没有重置时间,使用默认5分钟 - resetAt := time.Now().Add(5 * time.Minute) - slog.Warn("rate_limit_no_reset_time", "account_id", account.ID, "platform", account.Platform, "using_default", "5m") - if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { - slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) - } + // 其他平台:没有重置时间,使用可配置的秒级默认回避,避免误伤长时间不可调度。 + s.apply429FallbackRateLimit(ctx, account, "no_reset_time") return } @@ -904,10 +905,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head ts, err := strconv.ParseInt(resetTimestamp, 10, 64) if err != nil { slog.Warn("rate_limit_reset_parse_failed", "reset_timestamp", resetTimestamp, "error", err) - resetAt := time.Now().Add(5 * time.Minute) - if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { - slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) - } + s.apply429FallbackRateLimit(ctx, account, "reset_parse_failed") return } @@ -929,6 +927,48 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head slog.Info("account_rate_limited", "account_id", account.ID, "reset_at", resetAt) } +func (s *RateLimitService) apply429FallbackRateLimit(ctx context.Context, account *Account, reason string) { + cooldown, enabled := s.get429FallbackCooldown(ctx, account) + if !enabled { + slog.Info("rate_limit_429_fallback_ignored", "account_id", account.ID, "platform", account.Platform, "reason", reason) + return + } + + resetAt := time.Now().Add(cooldown) + slog.Warn("rate_limit_429_fallback_used", "account_id", account.ID, "platform", account.Platform, "reason", reason, "using_default", cooldown.String()) + if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { + slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) + } +} + +func (s *RateLimitService) get429FallbackCooldown(ctx context.Context, account *Account) (time.Duration, bool) { + if s.settingService != nil { + settings, err := s.settingService.GetRateLimit429CooldownSettings(ctx) + if err == nil && settings != nil { + if !settings.Enabled { + return 0, false + } + seconds := clampRateLimit429CooldownSeconds(settings.CooldownSeconds) + return time.Duration(seconds) * time.Second, true + } + slog.Warn("rate_limit_429_settings_read_failed", "account_id", account.ID, "error", err) + } + + seconds := defaultRateLimit429CooldownSeconds + seconds = clampRateLimit429CooldownSeconds(seconds) + return time.Duration(seconds) * time.Second, true +} + +func clampRateLimit429CooldownSeconds(seconds int) int { + if seconds < 1 { + return 1 + } + if seconds > maxRateLimit429CooldownSeconds { + return maxRateLimit429CooldownSeconds + } + return seconds +} + // calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间 // 返回 nil 表示无法从响应头中确定重置时间 func calculateOpenAI429ResetTime(headers http.Header) *time.Time { diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 2bae686a..a5d65ad7 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -2778,6 +2778,55 @@ func (s *SettingService) SetOverloadCooldownSettings(ctx context.Context, settin return s.settingRepo.Set(ctx, SettingKeyOverloadCooldownSettings, string(data)) } +// GetRateLimit429CooldownSettings 获取429默认回避配置 +func (s *SettingService) GetRateLimit429CooldownSettings(ctx context.Context) (*RateLimit429CooldownSettings, error) { + value, err := s.settingRepo.GetValue(ctx, SettingKeyRateLimit429CooldownSettings) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + return DefaultRateLimit429CooldownSettings(), nil + } + return nil, fmt.Errorf("get 429 cooldown settings: %w", err) + } + if value == "" { + return DefaultRateLimit429CooldownSettings(), nil + } + + var settings RateLimit429CooldownSettings + if err := json.Unmarshal([]byte(value), &settings); err != nil { + return DefaultRateLimit429CooldownSettings(), nil + } + + if settings.CooldownSeconds < 1 { + settings.CooldownSeconds = 1 + } + if settings.CooldownSeconds > 7200 { + settings.CooldownSeconds = 7200 + } + + return &settings, nil +} + +// SetRateLimit429CooldownSettings 设置429默认回避配置 +func (s *SettingService) SetRateLimit429CooldownSettings(ctx context.Context, settings *RateLimit429CooldownSettings) error { + if settings == nil { + return fmt.Errorf("settings cannot be nil") + } + + if settings.CooldownSeconds < 1 || settings.CooldownSeconds > 7200 { + if settings.Enabled { + return fmt.Errorf("cooldown_seconds must be between 1-7200") + } + settings.CooldownSeconds = 5 + } + + data, err := json.Marshal(settings) + if err != nil { + return fmt.Errorf("marshal 429 cooldown settings: %w", err) + } + + return s.settingRepo.Set(ctx, SettingKeyRateLimit429CooldownSettings, string(data)) +} + // GetOIDCConnectOAuthConfig 返回用于登录的“最终生效” OIDC 配置。 // // 优先级: diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 41c01cca..aaf837bd 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -381,6 +381,14 @@ type OverloadCooldownSettings struct { CooldownMinutes int `json:"cooldown_minutes"` } +// RateLimit429CooldownSettings 429默认回避配置 +type RateLimit429CooldownSettings struct { + // Enabled 是否在无法解析上游重置时间时应用默认429回避 + Enabled bool `json:"enabled"` + // CooldownSeconds 默认回避时长(秒) + CooldownSeconds int `json:"cooldown_seconds"` +} + // DefaultOverloadCooldownSettings 返回默认的过载冷却配置(启用,10分钟) func DefaultOverloadCooldownSettings() *OverloadCooldownSettings { return &OverloadCooldownSettings{ @@ -389,6 +397,14 @@ func DefaultOverloadCooldownSettings() *OverloadCooldownSettings { } } +// DefaultRateLimit429CooldownSettings 返回默认的429回避配置(启用,5秒) +func DefaultRateLimit429CooldownSettings() *RateLimit429CooldownSettings { + return &RateLimit429CooldownSettings{ + Enabled: true, + CooldownSeconds: 5, + } +} + // DefaultBetaPolicySettings 返回默认的 Beta 策略配置 func DefaultBetaPolicySettings() *BetaPolicySettings { return &BetaPolicySettings{ diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 1b7cb7ac..dfafa94e 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -274,15 +274,22 @@ func ProvideOpsAlertEvaluatorService( // ProvideOpsCleanupService creates and starts OpsCleanupService (cron scheduled). // channelMonitorSvc 让维护任务(聚合 + 历史/聚合软删)跟随 ops 清理 cron 一起跑, // 共享 leader lock + heartbeat。 +// settingRepo 让 cleanup service 自己读 ops_advanced_settings.data_retention 覆盖 cfg; +// opsService 用来反向注入 cleanup hook,以便 UI 改清理设置时能 Reload cron。 func ProvideOpsCleanupService( opsRepo OpsRepository, db *sql.DB, redisClient *redis.Client, cfg *config.Config, channelMonitorSvc *ChannelMonitorService, + settingRepo SettingRepository, + opsService *OpsService, ) *OpsCleanupService { - svc := NewOpsCleanupService(opsRepo, db, redisClient, cfg, channelMonitorSvc) + svc := NewOpsCleanupService(opsRepo, db, redisClient, cfg, channelMonitorSvc, settingRepo) svc.Start() + if opsService != nil { + opsService.SetCleanupReloader(svc) + } return svc } diff --git a/backend/migrations/134_affiliate_ledger_audit_snapshots.sql b/backend/migrations/134_affiliate_ledger_audit_snapshots.sql new file mode 100644 index 00000000..8a87ed1f --- /dev/null +++ b/backend/migrations/134_affiliate_ledger_audit_snapshots.sql @@ -0,0 +1,85 @@ +-- 邀请返利流水补充订单关联和转余额快照。 +-- 这些字段只用于审计展示;历史旧流水无法可靠反推的字段保持 NULL,避免把当前状态误展示为历史状态。 + +ALTER TABLE user_affiliate_ledger + ADD COLUMN IF NOT EXISTS source_order_id BIGINT NULL REFERENCES payment_orders(id) ON DELETE SET NULL; + +ALTER TABLE user_affiliate_ledger + ADD COLUMN IF NOT EXISTS balance_after DECIMAL(20,8) NULL; + +ALTER TABLE user_affiliate_ledger + ADD COLUMN IF NOT EXISTS aff_quota_after DECIMAL(20,8) NULL; + +ALTER TABLE user_affiliate_ledger + ADD COLUMN IF NOT EXISTS aff_frozen_quota_after DECIMAL(20,8) NULL; + +ALTER TABLE user_affiliate_ledger + ADD COLUMN IF NOT EXISTS aff_history_quota_after DECIMAL(20,8) NULL; + +COMMENT ON COLUMN user_affiliate_ledger.source_order_id IS '产生该返利流水的充值订单;转余额或无法可靠回填的历史数据为 NULL'; +COMMENT ON COLUMN user_affiliate_ledger.balance_after IS '邀请返利转余额后的用户余额快照;无法取得时为 NULL'; +COMMENT ON COLUMN user_affiliate_ledger.aff_quota_after IS '邀请返利转余额后的可用返利额度快照;无法取得时为 NULL'; +COMMENT ON COLUMN user_affiliate_ledger.aff_frozen_quota_after IS '邀请返利转余额后的冻结返利额度快照;无法取得时为 NULL'; +COMMENT ON COLUMN user_affiliate_ledger.aff_history_quota_after IS '邀请返利转余额后的历史返利总额快照;无法取得时为 NULL'; + +CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_source_order_id + ON user_affiliate_ledger(source_order_id) + WHERE source_order_id IS NOT NULL; + +CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_rebate_lookup + ON user_affiliate_ledger(action, source_order_id, user_id, source_user_id, created_at) + WHERE action = 'accrue'; + +-- 尽力回填 PR #2169 合并后、该迁移前已经产生的返利流水。 +-- 只有在同一订单只能匹配到一条返利流水时才回填,避免把多笔同额流水错误绑定到订单。 +WITH rebate_audits AS ( + SELECT po.id AS order_id, + po.user_id AS invitee_user_id, + invitee_aff.inviter_id, + rebate_detail.rebate_amount, + pal.created_at AS audit_created_at + FROM payment_audit_logs pal + CROSS JOIN LATERAL ( + SELECT substring( + pal.detail + FROM '"rebateAmount"[[:space:]]*:[[:space:]]*(-?[0-9]+(\.[0-9]+)?)' + )::numeric AS rebate_amount + ) rebate_detail + JOIN payment_orders po ON po.id::text = pal.order_id + JOIN user_affiliates invitee_aff ON invitee_aff.user_id = po.user_id + WHERE pal.action = 'AFFILIATE_REBATE_APPLIED' + AND rebate_detail.rebate_amount IS NOT NULL +), +ranked_matches AS ( + SELECT ual.id AS ledger_id, + ra.order_id, + COUNT(*) OVER (PARTITION BY ra.order_id) AS order_match_count, + COUNT(*) OVER (PARTITION BY ual.id) AS ledger_match_count, + ROW_NUMBER() OVER ( + PARTITION BY ual.id + ORDER BY ABS(EXTRACT(EPOCH FROM (ual.created_at - ra.audit_created_at))), ra.order_id + ) AS ledger_rank + FROM rebate_audits ra + JOIN user_affiliate_ledger ual + ON ual.action = 'accrue' + AND ual.source_order_id IS NULL + AND ual.user_id = ra.inviter_id + AND ual.source_user_id = ra.invitee_user_id + AND ABS(ual.amount - ra.rebate_amount) < 0.00000001 + AND ual.created_at BETWEEN ra.audit_created_at - INTERVAL '10 minutes' + AND ra.audit_created_at + INTERVAL '10 minutes' +) +UPDATE user_affiliate_ledger ual +SET source_order_id = ranked_matches.order_id, + updated_at = NOW() +FROM ranked_matches +WHERE ual.id = ranked_matches.ledger_id + AND ranked_matches.order_match_count = 1 + AND ranked_matches.ledger_match_count = 1 + AND ranked_matches.ledger_rank = 1 + AND NOT EXISTS ( + SELECT 1 + FROM user_affiliate_ledger existing + WHERE existing.source_order_id = ranked_matches.order_id + AND existing.action = 'accrue' + ); diff --git a/backend/migrations/134_image_generation_group_controls.sql b/backend/migrations/134_image_generation_group_controls.sql new file mode 100644 index 00000000..37941c00 --- /dev/null +++ b/backend/migrations/134_image_generation_group_controls.sql @@ -0,0 +1,26 @@ +-- 生图能力与图片倍率模式控制 +-- 兼容性原则: +-- 1. 不改写现有 image_price_1k/2k/4k,避免改变已配置分组的最终图片价格。 +-- 2. 现有 openai/gemini/antigravity 分组默认保持可生图,避免升级后中断已有图片业务。 +-- 3. 现有分组默认共享当前有效分组倍率,保持历史扣费公式。 + +ALTER TABLE groups + ADD COLUMN IF NOT EXISTS allow_image_generation BOOLEAN NOT NULL DEFAULT false; + +ALTER TABLE groups + ADD COLUMN IF NOT EXISTS image_rate_independent BOOLEAN NOT NULL DEFAULT false; + +ALTER TABLE groups + ADD COLUMN IF NOT EXISTS image_rate_multiplier DECIMAL(10,4) NOT NULL DEFAULT 1.0; + +UPDATE groups +SET allow_image_generation = true +WHERE platform IN ('openai', 'gemini', 'antigravity'); + +UPDATE groups +SET image_rate_independent = false, + image_rate_multiplier = 1.0; + +COMMENT ON COLUMN groups.allow_image_generation IS '是否允许该分组使用图片生成能力'; +COMMENT ON COLUMN groups.image_rate_independent IS '图片生成是否使用独立倍率;false 表示共享分组有效倍率'; +COMMENT ON COLUMN groups.image_rate_multiplier IS '图片生成独立倍率,仅 image_rate_independent=true 时生效'; diff --git a/backend/migrations/auth_identity_payment_migrations_regression_test.go b/backend/migrations/auth_identity_payment_migrations_regression_test.go index 798ae0fe..99216296 100644 --- a/backend/migrations/auth_identity_payment_migrations_regression_test.go +++ b/backend/migrations/auth_identity_payment_migrations_regression_test.go @@ -127,3 +127,18 @@ func TestMigration124BackfillsLegacyOIDCSecurityFlagsSafely(t *testing.T) { require.Contains(t, sql, "oidc_connect_enabled") require.Contains(t, sql, "'false'") } + +func TestMigration134AddsAffiliateLedgerAuditFieldsWithoutJSONCast(t *testing.T) { + content, err := FS.ReadFile("134_affiliate_ledger_audit_snapshots.sql") + require.NoError(t, err) + + sql := string(content) + require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS source_order_id BIGINT") + require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS balance_after DECIMAL(20,8)") + require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS aff_quota_after DECIMAL(20,8)") + require.Contains(t, sql, "substring(") + require.Contains(t, sql, `"rebateAmount"`) + require.Contains(t, sql, "COUNT(*) OVER (PARTITION BY ra.order_id) AS order_match_count") + require.Contains(t, sql, "COUNT(*) OVER (PARTITION BY ual.id) AS ledger_match_count") + require.NotContains(t, sql, "detail::jsonb") +} diff --git a/deploy/.env.example b/deploy/.env.example index e0126bcb..9df40dcd 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -288,6 +288,25 @@ GATEWAY_SCHEDULING_OUTBOX_BACKLOG_REBUILD_ROWS=10000 # 全量重建周期(秒) GATEWAY_SCHEDULING_FULL_REBUILD_INTERVAL_SECONDS=300 +# ----------------------------------------------------------------------------- +# Image Generation Stream & Concurrency (Optional) +# 图片生成流式与并发隔离配置(可选) +# ----------------------------------------------------------------------------- +# 图片流式上游数据间隔超时(秒)。0 表示禁用;非 0 时必须为 60-1800。 +GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=900 +# 图片流式 keepalive 间隔(秒)。0 表示禁用;非 0 时必须为 5-60。 +GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=10 +# 是否启用进程级图片生成并发限制。默认 false,保持历史行为。 +GATEWAY_IMAGE_CONCURRENCY_ENABLED=false +# 当前进程允许同时处理的图片生成请求数。0 表示不限制。 +GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=0 +# 图片并发超限策略:reject 直接返回 429;wait 等待空闲槽位。 +GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=reject +# wait 模式下等待空闲图片槽位的最长时间(秒)。 +GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=30 +# wait 模式下当前进程允许排队等待的最大图片请求数。0 表示不允许等待队列。 +GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=100 + # ----------------------------------------------------------------------------- # Dashboard Aggregation (Optional) # ----------------------------------------------------------------------------- diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index c53c430e..dc5d3cc9 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -364,6 +364,30 @@ gateway: # Stream keepalive interval (seconds), 0=disable # 流式 keepalive 间隔(秒),0=禁用 stream_keepalive_interval: 10 + # Image stream data interval timeout (seconds), 0=disable; independent from ordinary text streams + # 图片流数据间隔超时(秒),0=禁用;独立于普通文本流式 + image_stream_data_interval_timeout: 900 + # Image stream keepalive interval (seconds), 0=disable; independent from ordinary text streams + # 图片流式 keepalive 间隔(秒),0=禁用;独立于普通文本流式 + image_stream_keepalive_interval: 10 + # Image generation independent concurrency limiter (process-local, default disabled) + # 图片生成独立并发限制(进程级,默认关闭;多实例总上限约为实例数×该值) + image_concurrency: + # Enable image-only concurrency protection; false keeps existing behavior unchanged + # 是否启用图片独立并发保护;false 保持现有行为不变 + enabled: false + # Max concurrent image generation requests in this process, 0=unlimited + # 当前进程允许同时处理的图片生成请求数,0=不限制 + max_concurrent_requests: 0 + # Overflow mode when the image concurrency limit is full: reject/wait + # 图片并发满时的处理方式:reject=立即拒绝,wait=等待槽位 + overflow_mode: "reject" + # Wait timeout for overflow_mode=wait (seconds), 0=do not wait + # wait 模式等待图片并发槽位的超时时间(秒),0=不等待 + wait_timeout_seconds: 30 + # Max image requests waiting in this process when overflow_mode=wait, 0=unlimited + # wait 模式当前进程允许排队等待的图片请求数,0=不限制 + max_waiting_requests: 100 # SSE max line size in bytes (default: 40MB) # SSE 单行最大字节数(默认 40MB) max_line_size: 41943040 diff --git a/deploy/docker-compose.dev.yml b/deploy/docker-compose.dev.yml index 7793e424..b7a805b5 100644 --- a/deploy/docker-compose.dev.yml +++ b/deploy/docker-compose.dev.yml @@ -40,6 +40,13 @@ services: - JWT_SECRET=${JWT_SECRET:-} - TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY:-} - TZ=${TZ:-Asia/Shanghai} + - GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900} + - GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10} + - GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false} + - GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0} + - GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject} + - GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30} + - GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100} depends_on: postgres: condition: service_healthy diff --git a/deploy/docker-compose.local.yml b/deploy/docker-compose.local.yml index 902740cb..f0ab401e 100644 --- a/deploy/docker-compose.local.yml +++ b/deploy/docker-compose.local.yml @@ -163,6 +163,17 @@ services: # Proxy for accessing GitHub (online updates + pricing data) # Examples: http://host:port, socks5://host:port - UPDATE_PROXY_URL=${UPDATE_PROXY_URL:-} + + # ======================================================================= + # Image Generation Stream & Concurrency + # ======================================================================= + - GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900} + - GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10} + - GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false} + - GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0} + - GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject} + - GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30} + - GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100} depends_on: postgres: condition: service_healthy diff --git a/deploy/docker-compose.standalone.yml b/deploy/docker-compose.standalone.yml index df0ccfcc..438d0a8a 100644 --- a/deploy/docker-compose.standalone.yml +++ b/deploy/docker-compose.standalone.yml @@ -93,6 +93,17 @@ services: # SECURITY: This repo does not embed third-party client_secret. - GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-} - ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-} + + # ======================================================================= + # Image Generation Stream & Concurrency + # ======================================================================= + - GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900} + - GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10} + - GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false} + - GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0} + - GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject} + - GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30} + - GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100} healthcheck: test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"] interval: 30s diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 6f210d56..212a6a7b 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -109,6 +109,16 @@ services: - WINDSURF_DOCKER_PORT=${WINDSURF_DOCKER_PORT:-42099} - WINDSURF_DOCKER_CSRF_TOKEN=${WINDSURF_DOCKER_CSRF_TOKEN:-} + # ======================================================================= + # Image Generation Stream & Concurrency + # ======================================================================= + - GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900} + - GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10} + - GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false} + - GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0} + - GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject} + - GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30} + - GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100} depends_on: postgres: condition: service_healthy diff --git a/frontend/package.json b/frontend/package.json index 098b0979..a4681577 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -18,7 +18,7 @@ "@lobehub/icons": "^4.0.2", "@tanstack/vue-virtual": "^3.13.23", "@vueuse/core": "^10.7.0", - "axios": "^1.15.0", + "axios": "^1.16.0", "chart.js": "^4.4.1", "dompurify": "^3.3.1", "driver.js": "^1.4.0", diff --git a/frontend/pnpm-lock.yaml b/frontend/pnpm-lock.yaml index 67d2a9b1..31d929e3 100644 --- a/frontend/pnpm-lock.yaml +++ b/frontend/pnpm-lock.yaml @@ -18,8 +18,8 @@ importers: specifier: ^10.7.0 version: 10.11.1(vue@3.5.26(typescript@5.6.3)) axios: - specifier: ^1.15.0 - version: 1.15.0 + specifier: ^1.16.0 + version: 1.16.0 chart.js: specifier: ^4.4.1 version: 4.5.1 @@ -1858,8 +1858,8 @@ packages: peerDependencies: postcss: ^8.1.0 - axios@1.15.0: - resolution: {integrity: sha512-wWyJDlAatxk30ZJer+GeCWS209sA42X+N5jU2jy6oHTp7ufw8uzUTVFBX9+wTfAlhiJXGS0Bq7X6efruWjuK9Q==} + axios@1.16.0: + resolution: {integrity: sha512-6hp5CwvTPlN2A31g5dxnwAX0orzM7pmCRDLnZSX772mv8WDqICwFjowHuPs04Mc8deIld1+ejhtaMn5vp6b+1w==} babel-plugin-macros@3.1.0: resolution: {integrity: sha512-Cg7TFGpIr01vOQNODXOOaGz2NpCU5gl8x1qJFbb6hbZxR7XrcE2vtbAsTAbJ7/xwJtUuJEw8K8Zr/AE0LHlesg==} @@ -2534,8 +2534,8 @@ packages: flatted@3.3.3: resolution: {integrity: sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==} - follow-redirects@1.15.11: - resolution: {integrity: sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==} + follow-redirects@1.16.0: + resolution: {integrity: sha512-y5rN/uOsadFT/JfYwhxRS5R7Qce+g3zG97+JrtFZlC9klX/W5hD7iiLzScI4nZqUS7DNUdhPgw4xI8W2LuXlUw==} engines: {node: '>=4.0'} peerDependencies: debug: '*' @@ -6484,9 +6484,9 @@ snapshots: postcss: 8.5.6 postcss-value-parser: 4.2.0 - axios@1.15.0: + axios@1.16.0: dependencies: - follow-redirects: 1.15.11 + follow-redirects: 1.16.0 form-data: 4.0.5 proxy-from-env: 2.1.0 transitivePeerDependencies: @@ -7228,7 +7228,7 @@ snapshots: flatted@3.3.3: {} - follow-redirects@1.15.11: {} + follow-redirects@1.16.0: {} for-in@1.0.2: {} diff --git a/frontend/src/api/admin/affiliates.ts b/frontend/src/api/admin/affiliates.ts index 22639bd2..dadb0ae9 100644 --- a/frontend/src/api/admin/affiliates.ts +++ b/frontend/src/api/admin/affiliates.ts @@ -23,6 +23,72 @@ export interface ListAffiliateUsersParams { search?: string } +export interface ListAffiliateRecordsParams { + page?: number + page_size?: number + search?: string + start_at?: string + end_at?: string + sort_by?: string + sort_order?: 'asc' | 'desc' + timezone?: string +} + +export interface AffiliateInviteRecord { + inviter_id: number + inviter_email: string + inviter_username: string + invitee_id: number + invitee_email: string + invitee_username: string + aff_code: string + total_rebate: number + created_at: string +} + +export interface AffiliateRebateRecord { + order_id: number + out_trade_no: string + inviter_id: number + inviter_email: string + inviter_username: string + invitee_id: number + invitee_email: string + invitee_username: string + order_amount: number + pay_amount: number + rebate_amount: number + payment_type: string + order_status: string + created_at: string +} + +export interface AffiliateTransferRecord { + ledger_id: number + user_id: number + user_email: string + username: string + amount: number + balance_after?: number | null + available_quota_after?: number | null + frozen_quota_after?: number | null + history_quota_after?: number | null + snapshot_available: boolean + created_at: string +} + +export interface AffiliateUserOverview { + user_id: number + email: string + username: string + aff_code: string + rebate_rate_percent: number + invited_count: number + rebated_invitee_count: number + available_quota: number + history_quota: number +} + export interface UpdateAffiliateUserRequest { aff_code?: string aff_rebate_rate_percent?: number | null @@ -97,12 +163,68 @@ export async function batchSetRate( return data } +function recordParams(params: ListAffiliateRecordsParams = {}) { + return { + page: params.page ?? 1, + page_size: params.page_size ?? 20, + search: params.search ?? '', + start_at: params.start_at || undefined, + end_at: params.end_at || undefined, + sort_by: params.sort_by || undefined, + sort_order: params.sort_order || undefined, + timezone: params.timezone || undefined, + } +} + +export async function listInviteRecords( + params: ListAffiliateRecordsParams = {}, +): Promise> { + const { data } = await apiClient.get>( + '/admin/affiliates/invites', + { params: recordParams(params) }, + ) + return data +} + +export async function listRebateRecords( + params: ListAffiliateRecordsParams = {}, +): Promise> { + const { data } = await apiClient.get>( + '/admin/affiliates/rebates', + { params: recordParams(params) }, + ) + return data +} + +export async function listTransferRecords( + params: ListAffiliateRecordsParams = {}, +): Promise> { + const { data } = await apiClient.get>( + '/admin/affiliates/transfers', + { params: recordParams(params) }, + ) + return data +} + +export async function getUserOverview( + userId: number, +): Promise { + const { data } = await apiClient.get( + `/admin/affiliates/users/${userId}/overview`, + ) + return data +} + export const affiliatesAPI = { listUsers, lookupUsers, updateUserSettings, clearUserSettings, batchSetRate, + listInviteRecords, + listRebateRecords, + listTransferRecords, + getUserOverview, } export default affiliatesAPI diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index b887355a..057a85e8 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -807,6 +807,30 @@ export async function updateOverloadCooldownSettings( return data; } +// ==================== 429 Rate Limit Cooldown Settings ==================== + +export interface RateLimit429CooldownSettings { + enabled: boolean; + cooldown_seconds: number; +} + +export async function getRateLimit429CooldownSettings(): Promise { + const { data } = await apiClient.get( + "/admin/settings/rate-limit-429-cooldown", + ); + return data; +} + +export async function updateRateLimit429CooldownSettings( + settings: RateLimit429CooldownSettings, +): Promise { + const { data } = await apiClient.put( + "/admin/settings/rate-limit-429-cooldown", + settings, + ); + return data; +} + // ==================== Stream Timeout Settings ==================== /** @@ -1026,6 +1050,8 @@ export const settingsAPI = { deleteAdminApiKey, getOverloadCooldownSettings, updateOverloadCooldownSettings, + getRateLimit429CooldownSettings, + updateRateLimit429CooldownSettings, getStreamTimeoutSettings, updateStreamTimeoutSettings, getRectifierSettings, diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts index 3c75a6c4..fabc69bc 100644 --- a/frontend/src/api/admin/users.ts +++ b/frontend/src/api/admin/users.ts @@ -249,7 +249,7 @@ export interface BalanceHistoryResponse extends PaginatedResponse + +
+
+
+ +

+ {{ t('admin.accounts.openai.compactModeDesc') }} +

+
+ +
+
+ +
+
+
+
+ + + + +
+
+ +
+
+
@@ -989,7 +1093,7 @@ import { ref, watch, computed } from 'vue' import { useI18n } from 'vue-i18n' import { useAppStore } from '@/stores/app' import { adminAPI } from '@/api/admin' -import type { Proxy as ProxyConfig, AdminGroup, AccountPlatform, AccountType } from '@/types' +import type { Proxy as ProxyConfig, AdminGroup, AccountPlatform, AccountType, OpenAICompactMode } from '@/types' import BaseDialog from '@/components/common/BaseDialog.vue' import ConfirmDialog from '@/components/common/ConfirmDialog.vue' import Select from '@/components/common/Select.vue' @@ -1115,6 +1219,8 @@ const enableOpenAIPassthrough = ref(false) const enableOpenAIWSMode = ref(false) const enableOpenAIAPIKeyWSMode = ref(false) const enableCodexCLIOnly = ref(false) +const enableOpenAICompactMode = ref(false) +const enableOpenAICompactModelMapping = ref(false) const enableRpmLimit = ref(false) // State - field values @@ -1140,6 +1246,8 @@ const openaiPassthroughEnabled = ref(false) const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF) const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF) const codexCLIOnlyEnabled = ref(false) +const openAICompactMode = ref('auto') +const openAICompactModelMappings = ref([]) const rpmLimitEnabled = ref(false) const bulkBaseRpm = ref(null) const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered') @@ -1178,6 +1286,11 @@ const openAIWSModeOptions = computed(() => [ { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') }, { value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') } ]) +const openAICompactModeOptions = computed(() => [ + { value: 'auto', label: t('admin.accounts.openai.compactModeAuto') }, + { value: 'force_on', label: t('admin.accounts.openai.compactModeForceOn') }, + { value: 'force_off', label: t('admin.accounts.openai.compactModeForceOff') } +]) const openAIWSModeConcurrencyHintKey = computed(() => resolveOpenAIWSModeConcurrencyHintKey(openaiOAuthResponsesWebSocketV2Mode.value) ) @@ -1194,6 +1307,14 @@ const removeModelMapping = (index: number) => { modelMappings.value.splice(index, 1) } +const addOpenAICompactModelMapping = () => { + openAICompactModelMappings.value.push({ from: '', to: '' }) +} + +const removeOpenAICompactModelMapping = (index: number) => { + openAICompactModelMappings.value.splice(index, 1) +} + const addPresetMapping = (from: string, to: string) => { const exists = modelMappings.value.some((m) => m.from === from) if (exists) { @@ -1262,6 +1383,10 @@ const buildModelMappingObject = (): Record | null => { ) } +const buildOpenAICompactModelMapping = (): Record | null => { + return buildModelMappingPayload('mapping', [], openAICompactModelMappings.value) +} + const buildUpdatePayload = (): Record | null => { const updates: Record = {} const credentials: Record = {} @@ -1350,10 +1475,6 @@ const buildUpdatePayload = (): Record | null => { credentialsChanged = true } - if (credentialsChanged) { - updates.credentials = credentials - } - if (enableOpenAIWSMode.value) { const extra = ensureExtra() extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value @@ -1375,6 +1496,16 @@ const buildUpdatePayload = (): Record | null => { extra.codex_cli_only = codexCLIOnlyEnabled.value } + if (enableOpenAICompactMode.value) { + const extra = ensureExtra() + extra.openai_compact_mode = openAICompactMode.value + } + + if (enableOpenAICompactModelMapping.value) { + credentials.compact_model_mapping = buildOpenAICompactModelMapping() ?? {} + credentialsChanged = true + } + // RPM limit settings (写入 extra 字段) if (enableRpmLimit.value) { const extra = ensureExtra() @@ -1402,6 +1533,10 @@ const buildUpdatePayload = (): Record | null => { umqExtra.user_msg_queue_enabled = false // 清理旧字段(JSONB merge) } + if (credentialsChanged) { + updates.credentials = credentials + } + return Object.keys(updates).length > 0 ? updates : null } @@ -1467,6 +1602,8 @@ const handleSubmit = async () => { enableOpenAIWSMode.value || enableOpenAIAPIKeyWSMode.value || enableCodexCLIOnly.value || + enableOpenAICompactMode.value || + enableOpenAICompactModelMapping.value || enableRpmLimit.value || userMsgQueueMode.value !== null @@ -1567,6 +1704,8 @@ watch( enableOpenAIWSMode.value = false enableOpenAIAPIKeyWSMode.value = false enableCodexCLIOnly.value = false + enableOpenAICompactMode.value = false + enableOpenAICompactModelMapping.value = false enableRpmLimit.value = false // Reset all values @@ -1588,6 +1727,8 @@ watch( openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF codexCLIOnlyEnabled.value = false + openAICompactMode.value = 'auto' + openAICompactModelMappings.value = [] rpmLimitEnabled.value = false bulkBaseRpm.value = null bulkRpmStrategy.value = 'tiered' diff --git a/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts b/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts index 50d170da..caa307fc 100644 --- a/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts +++ b/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts @@ -217,6 +217,44 @@ describe('BulkEditAccountModal', () => { }) }) + it('筛选 OpenAI 账号批量编辑应提交 Compact 模式和专属模型映射', async () => { + const wrapper = mountModal({ + accountIds: [], + selectedPlatforms: [], + selectedTypes: [], + target: { + mode: 'filtered', + filters: { platform: 'openai' }, + previewCount: 12, + selectedPlatforms: ['openai'], + selectedTypes: ['oauth', 'apikey'] + } + }) + + await wrapper.get('#bulk-edit-openai-compact-mode-enabled').setValue(true) + await wrapper.get('[data-testid="bulk-edit-openai-compact-mode-select"]').setValue('force_on') + await wrapper.get('#bulk-edit-openai-compact-model-mapping-enabled').setValue(true) + await wrapper.get('[data-testid="bulk-edit-openai-compact-model-mapping-add"]').trigger('click') + const inputs = wrapper.findAll('[data-testid="bulk-edit-openai-compact-model-mapping-input"]') + await inputs[0].setValue('gpt-5.4') + await inputs[1].setValue('gpt-5.4-openai-compact') + await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent') + await flushPromises() + + expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1) + expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith({ + filters: { platform: 'openai' }, + extra: { + openai_compact_mode: 'force_on' + }, + credentials: { + compact_model_mapping: { + 'gpt-5.4': 'gpt-5.4-openai-compact' + } + } + }) + }) + it('OpenAI 账号批量编辑可关闭自动透传', async () => { const wrapper = mountModal({ selectedPlatforms: ['openai'], diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index adcb3cc6..629e6aa2 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -291,9 +291,23 @@
+
- {{ tooltipData.billing_mode === BILLING_MODE_IMAGE ? t('usage.imageUnitPrice') : t('usage.unitPrice') }} - ${{ tooltipData.total_cost?.toFixed(6) || '0.000000' }} + {{ t('usage.unitPrice') }} + ${{ tooltipData?.total_cost?.toFixed(6) || '0.000000' }}
{{ t('admin.usage.cacheCreationCost') }} @@ -360,6 +374,13 @@ function accountBilled(row: { total_cost?: number | null; account_stats_cost?: n return Number.isNaN(result) ? 0 : result } +function imageUnitPrice(row: AdminUsageLog | null): number { + if (!row || row.image_count <= 0) return 0 + const total = row.total_cost ?? 0 + const price = total / row.image_count + return Number.isFinite(price) ? price : 0 +} + import DataTable from '@/components/common/DataTable.vue' import EmptyState from '@/components/common/EmptyState.vue' import Icon from '@/components/icons/Icon.vue' diff --git a/frontend/src/components/admin/user/UserBalanceHistoryModal.vue b/frontend/src/components/admin/user/UserBalanceHistoryModal.vue index 1a79e4e3..6d48ed77 100644 --- a/frontend/src/components/admin/user/UserBalanceHistoryModal.vue +++ b/frontend/src/components/admin/user/UserBalanceHistoryModal.vue @@ -196,6 +196,7 @@ const totalPages = computed(() => Math.ceil(total.value / pageSize) || 1) const typeOptions = computed(() => [ { value: '', label: t('admin.users.allTypes') }, { value: 'balance', label: t('admin.users.typeBalance') }, + { value: 'affiliate_balance', label: t('admin.users.typeAffiliateBalance') }, { value: 'admin_balance', label: t('admin.users.typeAdminBalance') }, { value: 'concurrency', label: t('admin.users.typeConcurrency') }, { value: 'admin_concurrency', label: t('admin.users.typeAdminConcurrency') }, @@ -235,7 +236,7 @@ const loadHistory = async (page: number) => { const isAdminType = (type: string) => type === 'admin_balance' || type === 'admin_concurrency' // Helper: check if balance type (includes admin_balance) -const isBalanceType = (type: string) => type === 'balance' || type === 'admin_balance' +const isBalanceType = (type: string) => type === 'balance' || type === 'admin_balance' || type === 'affiliate_balance' // Helper: check if subscription type const isSubscriptionType = (type: string) => type === 'subscription' @@ -291,6 +292,8 @@ const getItemTitle = (item: BalanceHistoryItem) => { switch (item.type) { case 'balance': return t('redeem.balanceAddedRedeem') + case 'affiliate_balance': + return t('redeem.balanceAddedAffiliate') case 'admin_balance': return item.value >= 0 ? t('redeem.balanceAddedAdmin') : t('redeem.balanceDeductedAdmin') case 'concurrency': diff --git a/frontend/src/components/common/GroupSelector.vue b/frontend/src/components/common/GroupSelector.vue index 582b6f0b..e5980ef5 100644 --- a/frontend/src/components/common/GroupSelector.vue +++ b/frontend/src/components/common/GroupSelector.vue @@ -5,7 +5,24 @@ {{ t('common.selectedCount', { count: modelValue.length }) }}
+ + +
+
+
- {{ tooltipData.billing_mode === 'image' ? t('usage.imageUnitPrice') : t('usage.unitPrice') }} - ${{ tooltipData.total_cost?.toFixed(6) || '0.000000' }} + {{ t('usage.unitPrice') }} + ${{ tooltipData?.total_cost?.toFixed(6) || '0.000000' }}
{{ t('admin.usage.cacheCreationCost') }} @@ -625,6 +639,13 @@ const formatDuration = (ms: number): string => { return `${(ms / 1000).toFixed(2)}s` } +const imageUnitPrice = (row: UsageLog | null): number => { + if (!row || row.image_count <= 0) return 0 + const total = row.total_cost ?? 0 + const price = total / row.image_count + return Number.isFinite(price) ? price : 0 +} + const formatUserAgent = (ua: string): string => { return ua } diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts index b71f9d58..38770704 100644 --- a/frontend/vite.config.ts +++ b/frontend/vite.config.ts @@ -44,7 +44,6 @@ export default defineConfig(({ mode }) => { plugins: [ vue(), checker({ - typescript: true, vueTsc: true }), injectPublicSettings(backendUrl)