feat: add OpenAI image generation controls

This commit is contained in:
2ue 2026-05-05 03:26:54 +08:00
parent 4de28fec8c
commit 6faa344916
85 changed files with 6086 additions and 568 deletions

View File

@ -47,6 +47,12 @@ type Group struct {
MonthlyLimitUsd *float64 `json:"monthly_limit_usd,omitempty"`
// DefaultValidityDays holds the value of the "default_validity_days" field.
DefaultValidityDays int `json:"default_validity_days,omitempty"`
// 是否允许该分组使用图片生成能力
AllowImageGeneration bool `json:"allow_image_generation,omitempty"`
// 图片生成是否使用独立倍率false 表示共享分组有效倍率
ImageRateIndependent bool `json:"image_rate_independent,omitempty"`
// 图片生成独立倍率,仅 image_rate_independent=true 时生效
ImageRateMultiplier float64 `json:"image_rate_multiplier,omitempty"`
// ImagePrice1k holds the value of the "image_price_1k" field.
ImagePrice1k *float64 `json:"image_price_1k,omitempty"`
// ImagePrice2k holds the value of the "image_price_2k" field.
@ -189,9 +195,9 @@ func (*Group) scanValues(columns []string) ([]any, error) {
switch columns[i] {
case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig:
values[i] = new([]byte)
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet:
case group.FieldIsExclusive, group.FieldAllowImageGeneration, group.FieldImageRateIndependent, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet:
values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImageRateMultiplier, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
values[i] = new(sql.NullFloat64)
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder, group.FieldRpmLimit:
values[i] = new(sql.NullInt64)
@ -309,6 +315,24 @@ func (_m *Group) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.DefaultValidityDays = int(value.Int64)
}
case group.FieldAllowImageGeneration:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field allow_image_generation", values[i])
} else if value.Valid {
_m.AllowImageGeneration = value.Bool
}
case group.FieldImageRateIndependent:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field image_rate_independent", values[i])
} else if value.Valid {
_m.ImageRateIndependent = value.Bool
}
case group.FieldImageRateMultiplier:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field image_rate_multiplier", values[i])
} else if value.Valid {
_m.ImageRateMultiplier = value.Float64
}
case group.FieldImagePrice1k:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field image_price_1k", values[i])
@ -550,6 +574,15 @@ func (_m *Group) String() string {
builder.WriteString("default_validity_days=")
builder.WriteString(fmt.Sprintf("%v", _m.DefaultValidityDays))
builder.WriteString(", ")
builder.WriteString("allow_image_generation=")
builder.WriteString(fmt.Sprintf("%v", _m.AllowImageGeneration))
builder.WriteString(", ")
builder.WriteString("image_rate_independent=")
builder.WriteString(fmt.Sprintf("%v", _m.ImageRateIndependent))
builder.WriteString(", ")
builder.WriteString("image_rate_multiplier=")
builder.WriteString(fmt.Sprintf("%v", _m.ImageRateMultiplier))
builder.WriteString(", ")
if v := _m.ImagePrice1k; v != nil {
builder.WriteString("image_price_1k=")
builder.WriteString(fmt.Sprintf("%v", *v))

View File

@ -44,6 +44,12 @@ const (
FieldMonthlyLimitUsd = "monthly_limit_usd"
// FieldDefaultValidityDays holds the string denoting the default_validity_days field in the database.
FieldDefaultValidityDays = "default_validity_days"
// FieldAllowImageGeneration holds the string denoting the allow_image_generation field in the database.
FieldAllowImageGeneration = "allow_image_generation"
// FieldImageRateIndependent holds the string denoting the image_rate_independent field in the database.
FieldImageRateIndependent = "image_rate_independent"
// FieldImageRateMultiplier holds the string denoting the image_rate_multiplier field in the database.
FieldImageRateMultiplier = "image_rate_multiplier"
// FieldImagePrice1k holds the string denoting the image_price_1k field in the database.
FieldImagePrice1k = "image_price_1k"
// FieldImagePrice2k holds the string denoting the image_price_2k field in the database.
@ -167,6 +173,9 @@ var Columns = []string{
FieldWeeklyLimitUsd,
FieldMonthlyLimitUsd,
FieldDefaultValidityDays,
FieldAllowImageGeneration,
FieldImageRateIndependent,
FieldImageRateMultiplier,
FieldImagePrice1k,
FieldImagePrice2k,
FieldImagePrice4k,
@ -239,6 +248,12 @@ var (
SubscriptionTypeValidator func(string) error
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
DefaultDefaultValidityDays int
// DefaultAllowImageGeneration holds the default value on creation for the "allow_image_generation" field.
DefaultAllowImageGeneration bool
// DefaultImageRateIndependent holds the default value on creation for the "image_rate_independent" field.
DefaultImageRateIndependent bool
// DefaultImageRateMultiplier holds the default value on creation for the "image_rate_multiplier" field.
DefaultImageRateMultiplier float64
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
DefaultClaudeCodeOnly bool
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
@ -343,6 +358,21 @@ func ByDefaultValidityDays(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDefaultValidityDays, opts...).ToFunc()
}
// ByAllowImageGeneration orders the results by the allow_image_generation field.
func ByAllowImageGeneration(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAllowImageGeneration, opts...).ToFunc()
}
// ByImageRateIndependent orders the results by the image_rate_independent field.
func ByImageRateIndependent(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageRateIndependent, opts...).ToFunc()
}
// ByImageRateMultiplier orders the results by the image_rate_multiplier field.
func ByImageRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageRateMultiplier, opts...).ToFunc()
}
// ByImagePrice1k orders the results by the image_price_1k field.
func ByImagePrice1k(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImagePrice1k, opts...).ToFunc()

View File

@ -125,6 +125,21 @@ func DefaultValidityDays(v int) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldDefaultValidityDays, v))
}
// AllowImageGeneration applies equality check predicate on the "allow_image_generation" field. It's identical to AllowImageGenerationEQ.
func AllowImageGeneration(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldAllowImageGeneration, v))
}
// ImageRateIndependent applies equality check predicate on the "image_rate_independent" field. It's identical to ImageRateIndependentEQ.
func ImageRateIndependent(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldImageRateIndependent, v))
}
// ImageRateMultiplier applies equality check predicate on the "image_rate_multiplier" field. It's identical to ImageRateMultiplierEQ.
func ImageRateMultiplier(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldImageRateMultiplier, v))
}
// ImagePrice1k applies equality check predicate on the "image_price_1k" field. It's identical to ImagePrice1kEQ.
func ImagePrice1k(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldImagePrice1k, v))
@ -900,6 +915,66 @@ func DefaultValidityDaysLTE(v int) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldDefaultValidityDays, v))
}
// AllowImageGenerationEQ applies the EQ predicate on the "allow_image_generation" field.
func AllowImageGenerationEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldAllowImageGeneration, v))
}
// AllowImageGenerationNEQ applies the NEQ predicate on the "allow_image_generation" field.
func AllowImageGenerationNEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldAllowImageGeneration, v))
}
// ImageRateIndependentEQ applies the EQ predicate on the "image_rate_independent" field.
func ImageRateIndependentEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldImageRateIndependent, v))
}
// ImageRateIndependentNEQ applies the NEQ predicate on the "image_rate_independent" field.
func ImageRateIndependentNEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldImageRateIndependent, v))
}
// ImageRateMultiplierEQ applies the EQ predicate on the "image_rate_multiplier" field.
func ImageRateMultiplierEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldImageRateMultiplier, v))
}
// ImageRateMultiplierNEQ applies the NEQ predicate on the "image_rate_multiplier" field.
func ImageRateMultiplierNEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldImageRateMultiplier, v))
}
// ImageRateMultiplierIn applies the In predicate on the "image_rate_multiplier" field.
func ImageRateMultiplierIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldIn(FieldImageRateMultiplier, vs...))
}
// ImageRateMultiplierNotIn applies the NotIn predicate on the "image_rate_multiplier" field.
func ImageRateMultiplierNotIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldImageRateMultiplier, vs...))
}
// ImageRateMultiplierGT applies the GT predicate on the "image_rate_multiplier" field.
func ImageRateMultiplierGT(v float64) predicate.Group {
return predicate.Group(sql.FieldGT(FieldImageRateMultiplier, v))
}
// ImageRateMultiplierGTE applies the GTE predicate on the "image_rate_multiplier" field.
func ImageRateMultiplierGTE(v float64) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldImageRateMultiplier, v))
}
// ImageRateMultiplierLT applies the LT predicate on the "image_rate_multiplier" field.
func ImageRateMultiplierLT(v float64) predicate.Group {
return predicate.Group(sql.FieldLT(FieldImageRateMultiplier, v))
}
// ImageRateMultiplierLTE applies the LTE predicate on the "image_rate_multiplier" field.
func ImageRateMultiplierLTE(v float64) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldImageRateMultiplier, v))
}
// ImagePrice1kEQ applies the EQ predicate on the "image_price_1k" field.
func ImagePrice1kEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldImagePrice1k, v))

View File

@ -217,6 +217,48 @@ func (_c *GroupCreate) SetNillableDefaultValidityDays(v *int) *GroupCreate {
return _c
}
// SetAllowImageGeneration sets the "allow_image_generation" field.
func (_c *GroupCreate) SetAllowImageGeneration(v bool) *GroupCreate {
_c.mutation.SetAllowImageGeneration(v)
return _c
}
// SetNillableAllowImageGeneration sets the "allow_image_generation" field if the given value is not nil.
func (_c *GroupCreate) SetNillableAllowImageGeneration(v *bool) *GroupCreate {
if v != nil {
_c.SetAllowImageGeneration(*v)
}
return _c
}
// SetImageRateIndependent sets the "image_rate_independent" field.
func (_c *GroupCreate) SetImageRateIndependent(v bool) *GroupCreate {
_c.mutation.SetImageRateIndependent(v)
return _c
}
// SetNillableImageRateIndependent sets the "image_rate_independent" field if the given value is not nil.
func (_c *GroupCreate) SetNillableImageRateIndependent(v *bool) *GroupCreate {
if v != nil {
_c.SetImageRateIndependent(*v)
}
return _c
}
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
func (_c *GroupCreate) SetImageRateMultiplier(v float64) *GroupCreate {
_c.mutation.SetImageRateMultiplier(v)
return _c
}
// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field if the given value is not nil.
func (_c *GroupCreate) SetNillableImageRateMultiplier(v *float64) *GroupCreate {
if v != nil {
_c.SetImageRateMultiplier(*v)
}
return _c
}
// SetImagePrice1k sets the "image_price_1k" field.
func (_c *GroupCreate) SetImagePrice1k(v float64) *GroupCreate {
_c.mutation.SetImagePrice1k(v)
@ -604,6 +646,18 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultDefaultValidityDays
_c.mutation.SetDefaultValidityDays(v)
}
if _, ok := _c.mutation.AllowImageGeneration(); !ok {
v := group.DefaultAllowImageGeneration
_c.mutation.SetAllowImageGeneration(v)
}
if _, ok := _c.mutation.ImageRateIndependent(); !ok {
v := group.DefaultImageRateIndependent
_c.mutation.SetImageRateIndependent(v)
}
if _, ok := _c.mutation.ImageRateMultiplier(); !ok {
v := group.DefaultImageRateMultiplier
_c.mutation.SetImageRateMultiplier(v)
}
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
v := group.DefaultClaudeCodeOnly
_c.mutation.SetClaudeCodeOnly(v)
@ -700,6 +754,15 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
}
if _, ok := _c.mutation.AllowImageGeneration(); !ok {
return &ValidationError{Name: "allow_image_generation", err: errors.New(`ent: missing required field "Group.allow_image_generation"`)}
}
if _, ok := _c.mutation.ImageRateIndependent(); !ok {
return &ValidationError{Name: "image_rate_independent", err: errors.New(`ent: missing required field "Group.image_rate_independent"`)}
}
if _, ok := _c.mutation.ImageRateMultiplier(); !ok {
return &ValidationError{Name: "image_rate_multiplier", err: errors.New(`ent: missing required field "Group.image_rate_multiplier"`)}
}
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
}
@ -821,6 +884,18 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldDefaultValidityDays, field.TypeInt, value)
_node.DefaultValidityDays = value
}
if value, ok := _c.mutation.AllowImageGeneration(); ok {
_spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value)
_node.AllowImageGeneration = value
}
if value, ok := _c.mutation.ImageRateIndependent(); ok {
_spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value)
_node.ImageRateIndependent = value
}
if value, ok := _c.mutation.ImageRateMultiplier(); ok {
_spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
_node.ImageRateMultiplier = value
}
if value, ok := _c.mutation.ImagePrice1k(); ok {
_spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value)
_node.ImagePrice1k = &value
@ -1261,6 +1336,48 @@ func (u *GroupUpsert) AddDefaultValidityDays(v int) *GroupUpsert {
return u
}
// SetAllowImageGeneration sets the "allow_image_generation" field.
func (u *GroupUpsert) SetAllowImageGeneration(v bool) *GroupUpsert {
u.Set(group.FieldAllowImageGeneration, v)
return u
}
// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create.
func (u *GroupUpsert) UpdateAllowImageGeneration() *GroupUpsert {
u.SetExcluded(group.FieldAllowImageGeneration)
return u
}
// SetImageRateIndependent sets the "image_rate_independent" field.
func (u *GroupUpsert) SetImageRateIndependent(v bool) *GroupUpsert {
u.Set(group.FieldImageRateIndependent, v)
return u
}
// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create.
func (u *GroupUpsert) UpdateImageRateIndependent() *GroupUpsert {
u.SetExcluded(group.FieldImageRateIndependent)
return u
}
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
func (u *GroupUpsert) SetImageRateMultiplier(v float64) *GroupUpsert {
u.Set(group.FieldImageRateMultiplier, v)
return u
}
// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create.
func (u *GroupUpsert) UpdateImageRateMultiplier() *GroupUpsert {
u.SetExcluded(group.FieldImageRateMultiplier)
return u
}
// AddImageRateMultiplier adds v to the "image_rate_multiplier" field.
func (u *GroupUpsert) AddImageRateMultiplier(v float64) *GroupUpsert {
u.Add(group.FieldImageRateMultiplier, v)
return u
}
// SetImagePrice1k sets the "image_price_1k" field.
func (u *GroupUpsert) SetImagePrice1k(v float64) *GroupUpsert {
u.Set(group.FieldImagePrice1k, v)
@ -1840,6 +1957,55 @@ func (u *GroupUpsertOne) UpdateDefaultValidityDays() *GroupUpsertOne {
})
}
// SetAllowImageGeneration sets the "allow_image_generation" field.
func (u *GroupUpsertOne) SetAllowImageGeneration(v bool) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetAllowImageGeneration(v)
})
}
// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateAllowImageGeneration() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateAllowImageGeneration()
})
}
// SetImageRateIndependent sets the "image_rate_independent" field.
func (u *GroupUpsertOne) SetImageRateIndependent(v bool) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetImageRateIndependent(v)
})
}
// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateImageRateIndependent() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateImageRateIndependent()
})
}
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
func (u *GroupUpsertOne) SetImageRateMultiplier(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetImageRateMultiplier(v)
})
}
// AddImageRateMultiplier adds v to the "image_rate_multiplier" field.
func (u *GroupUpsertOne) AddImageRateMultiplier(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddImageRateMultiplier(v)
})
}
// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateImageRateMultiplier() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateImageRateMultiplier()
})
}
// SetImagePrice1k sets the "image_price_1k" field.
func (u *GroupUpsertOne) SetImagePrice1k(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
@ -2632,6 +2798,55 @@ func (u *GroupUpsertBulk) UpdateDefaultValidityDays() *GroupUpsertBulk {
})
}
// SetAllowImageGeneration sets the "allow_image_generation" field.
func (u *GroupUpsertBulk) SetAllowImageGeneration(v bool) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetAllowImageGeneration(v)
})
}
// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateAllowImageGeneration() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateAllowImageGeneration()
})
}
// SetImageRateIndependent sets the "image_rate_independent" field.
func (u *GroupUpsertBulk) SetImageRateIndependent(v bool) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetImageRateIndependent(v)
})
}
// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateImageRateIndependent() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateImageRateIndependent()
})
}
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
func (u *GroupUpsertBulk) SetImageRateMultiplier(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetImageRateMultiplier(v)
})
}
// AddImageRateMultiplier adds v to the "image_rate_multiplier" field.
func (u *GroupUpsertBulk) AddImageRateMultiplier(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddImageRateMultiplier(v)
})
}
// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateImageRateMultiplier() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateImageRateMultiplier()
})
}
// SetImagePrice1k sets the "image_price_1k" field.
func (u *GroupUpsertBulk) SetImagePrice1k(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {

View File

@ -275,6 +275,55 @@ func (_u *GroupUpdate) AddDefaultValidityDays(v int) *GroupUpdate {
return _u
}
// SetAllowImageGeneration sets the "allow_image_generation" field.
func (_u *GroupUpdate) SetAllowImageGeneration(v bool) *GroupUpdate {
_u.mutation.SetAllowImageGeneration(v)
return _u
}
// SetNillableAllowImageGeneration sets the "allow_image_generation" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableAllowImageGeneration(v *bool) *GroupUpdate {
if v != nil {
_u.SetAllowImageGeneration(*v)
}
return _u
}
// SetImageRateIndependent sets the "image_rate_independent" field.
func (_u *GroupUpdate) SetImageRateIndependent(v bool) *GroupUpdate {
_u.mutation.SetImageRateIndependent(v)
return _u
}
// SetNillableImageRateIndependent sets the "image_rate_independent" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableImageRateIndependent(v *bool) *GroupUpdate {
if v != nil {
_u.SetImageRateIndependent(*v)
}
return _u
}
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
func (_u *GroupUpdate) SetImageRateMultiplier(v float64) *GroupUpdate {
_u.mutation.ResetImageRateMultiplier()
_u.mutation.SetImageRateMultiplier(v)
return _u
}
// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableImageRateMultiplier(v *float64) *GroupUpdate {
if v != nil {
_u.SetImageRateMultiplier(*v)
}
return _u
}
// AddImageRateMultiplier adds value to the "image_rate_multiplier" field.
func (_u *GroupUpdate) AddImageRateMultiplier(v float64) *GroupUpdate {
_u.mutation.AddImageRateMultiplier(v)
return _u
}
// SetImagePrice1k sets the "image_price_1k" field.
func (_u *GroupUpdate) SetImagePrice1k(v float64) *GroupUpdate {
_u.mutation.ResetImagePrice1k()
@ -962,6 +1011,18 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedDefaultValidityDays(); ok {
_spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value)
}
if value, ok := _u.mutation.AllowImageGeneration(); ok {
_spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value)
}
if value, ok := _u.mutation.ImageRateIndependent(); ok {
_spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value)
}
if value, ok := _u.mutation.ImageRateMultiplier(); ok {
_spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedImageRateMultiplier(); ok {
_spec.AddField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.ImagePrice1k(); ok {
_spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value)
}
@ -1610,6 +1671,55 @@ func (_u *GroupUpdateOne) AddDefaultValidityDays(v int) *GroupUpdateOne {
return _u
}
// SetAllowImageGeneration sets the "allow_image_generation" field.
func (_u *GroupUpdateOne) SetAllowImageGeneration(v bool) *GroupUpdateOne {
_u.mutation.SetAllowImageGeneration(v)
return _u
}
// SetNillableAllowImageGeneration sets the "allow_image_generation" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableAllowImageGeneration(v *bool) *GroupUpdateOne {
if v != nil {
_u.SetAllowImageGeneration(*v)
}
return _u
}
// SetImageRateIndependent sets the "image_rate_independent" field.
func (_u *GroupUpdateOne) SetImageRateIndependent(v bool) *GroupUpdateOne {
_u.mutation.SetImageRateIndependent(v)
return _u
}
// SetNillableImageRateIndependent sets the "image_rate_independent" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableImageRateIndependent(v *bool) *GroupUpdateOne {
if v != nil {
_u.SetImageRateIndependent(*v)
}
return _u
}
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
func (_u *GroupUpdateOne) SetImageRateMultiplier(v float64) *GroupUpdateOne {
_u.mutation.ResetImageRateMultiplier()
_u.mutation.SetImageRateMultiplier(v)
return _u
}
// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableImageRateMultiplier(v *float64) *GroupUpdateOne {
if v != nil {
_u.SetImageRateMultiplier(*v)
}
return _u
}
// AddImageRateMultiplier adds value to the "image_rate_multiplier" field.
func (_u *GroupUpdateOne) AddImageRateMultiplier(v float64) *GroupUpdateOne {
_u.mutation.AddImageRateMultiplier(v)
return _u
}
// SetImagePrice1k sets the "image_price_1k" field.
func (_u *GroupUpdateOne) SetImagePrice1k(v float64) *GroupUpdateOne {
_u.mutation.ResetImagePrice1k()
@ -2327,6 +2437,18 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if value, ok := _u.mutation.AddedDefaultValidityDays(); ok {
_spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value)
}
if value, ok := _u.mutation.AllowImageGeneration(); ok {
_spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value)
}
if value, ok := _u.mutation.ImageRateIndependent(); ok {
_spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value)
}
if value, ok := _u.mutation.ImageRateMultiplier(); ok {
_spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedImageRateMultiplier(); ok {
_spec.AddField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.ImagePrice1k(); ok {
_spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value)
}

View File

@ -638,6 +638,9 @@ var (
{Name: "weekly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "monthly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "default_validity_days", Type: field.TypeInt, Default: 30},
{Name: "allow_image_generation", Type: field.TypeBool, Default: false},
{Name: "image_rate_independent", Type: field.TypeBool, Default: false},
{Name: "image_rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
@ -690,7 +693,7 @@ var (
{
Name: "group_sort_order",
Unique: false,
Columns: []*schema.Column{GroupsColumns[25]},
Columns: []*schema.Column{GroupsColumns[28]},
},
},
}

View File

@ -14764,6 +14764,10 @@ type GroupMutation struct {
addmonthly_limit_usd *float64
default_validity_days *int
adddefault_validity_days *int
allow_image_generation *bool
image_rate_independent *bool
image_rate_multiplier *float64
addimage_rate_multiplier *float64
image_price_1k *float64
addimage_price_1k *float64
image_price_2k *float64
@ -15583,6 +15587,134 @@ func (m *GroupMutation) ResetDefaultValidityDays() {
m.adddefault_validity_days = nil
}
// SetAllowImageGeneration sets the "allow_image_generation" field.
func (m *GroupMutation) SetAllowImageGeneration(b bool) {
m.allow_image_generation = &b
}
// AllowImageGeneration returns the value of the "allow_image_generation" field in the mutation.
func (m *GroupMutation) AllowImageGeneration() (r bool, exists bool) {
v := m.allow_image_generation
if v == nil {
return
}
return *v, true
}
// OldAllowImageGeneration returns the old "allow_image_generation" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldAllowImageGeneration(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldAllowImageGeneration is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldAllowImageGeneration requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldAllowImageGeneration: %w", err)
}
return oldValue.AllowImageGeneration, nil
}
// ResetAllowImageGeneration resets all changes to the "allow_image_generation" field.
func (m *GroupMutation) ResetAllowImageGeneration() {
m.allow_image_generation = nil
}
// SetImageRateIndependent sets the "image_rate_independent" field.
func (m *GroupMutation) SetImageRateIndependent(b bool) {
m.image_rate_independent = &b
}
// ImageRateIndependent returns the value of the "image_rate_independent" field in the mutation.
func (m *GroupMutation) ImageRateIndependent() (r bool, exists bool) {
v := m.image_rate_independent
if v == nil {
return
}
return *v, true
}
// OldImageRateIndependent returns the old "image_rate_independent" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldImageRateIndependent(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldImageRateIndependent is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldImageRateIndependent requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldImageRateIndependent: %w", err)
}
return oldValue.ImageRateIndependent, nil
}
// ResetImageRateIndependent resets all changes to the "image_rate_independent" field.
func (m *GroupMutation) ResetImageRateIndependent() {
m.image_rate_independent = nil
}
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
func (m *GroupMutation) SetImageRateMultiplier(f float64) {
m.image_rate_multiplier = &f
m.addimage_rate_multiplier = nil
}
// ImageRateMultiplier returns the value of the "image_rate_multiplier" field in the mutation.
func (m *GroupMutation) ImageRateMultiplier() (r float64, exists bool) {
v := m.image_rate_multiplier
if v == nil {
return
}
return *v, true
}
// OldImageRateMultiplier returns the old "image_rate_multiplier" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldImageRateMultiplier(ctx context.Context) (v float64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldImageRateMultiplier is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldImageRateMultiplier requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldImageRateMultiplier: %w", err)
}
return oldValue.ImageRateMultiplier, nil
}
// AddImageRateMultiplier adds f to the "image_rate_multiplier" field.
func (m *GroupMutation) AddImageRateMultiplier(f float64) {
if m.addimage_rate_multiplier != nil {
*m.addimage_rate_multiplier += f
} else {
m.addimage_rate_multiplier = &f
}
}
// AddedImageRateMultiplier returns the value that was added to the "image_rate_multiplier" field in this mutation.
func (m *GroupMutation) AddedImageRateMultiplier() (r float64, exists bool) {
v := m.addimage_rate_multiplier
if v == nil {
return
}
return *v, true
}
// ResetImageRateMultiplier resets all changes to the "image_rate_multiplier" field.
func (m *GroupMutation) ResetImageRateMultiplier() {
m.image_rate_multiplier = nil
m.addimage_rate_multiplier = nil
}
// SetImagePrice1k sets the "image_price_1k" field.
func (m *GroupMutation) SetImagePrice1k(f float64) {
m.image_price_1k = &f
@ -16791,7 +16923,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *GroupMutation) Fields() []string {
fields := make([]string, 0, 31)
fields := make([]string, 0, 34)
if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt)
}
@ -16834,6 +16966,15 @@ func (m *GroupMutation) Fields() []string {
if m.default_validity_days != nil {
fields = append(fields, group.FieldDefaultValidityDays)
}
if m.allow_image_generation != nil {
fields = append(fields, group.FieldAllowImageGeneration)
}
if m.image_rate_independent != nil {
fields = append(fields, group.FieldImageRateIndependent)
}
if m.image_rate_multiplier != nil {
fields = append(fields, group.FieldImageRateMultiplier)
}
if m.image_price_1k != nil {
fields = append(fields, group.FieldImagePrice1k)
}
@ -16921,6 +17062,12 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.MonthlyLimitUsd()
case group.FieldDefaultValidityDays:
return m.DefaultValidityDays()
case group.FieldAllowImageGeneration:
return m.AllowImageGeneration()
case group.FieldImageRateIndependent:
return m.ImageRateIndependent()
case group.FieldImageRateMultiplier:
return m.ImageRateMultiplier()
case group.FieldImagePrice1k:
return m.ImagePrice1k()
case group.FieldImagePrice2k:
@ -16992,6 +17139,12 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldMonthlyLimitUsd(ctx)
case group.FieldDefaultValidityDays:
return m.OldDefaultValidityDays(ctx)
case group.FieldAllowImageGeneration:
return m.OldAllowImageGeneration(ctx)
case group.FieldImageRateIndependent:
return m.OldImageRateIndependent(ctx)
case group.FieldImageRateMultiplier:
return m.OldImageRateMultiplier(ctx)
case group.FieldImagePrice1k:
return m.OldImagePrice1k(ctx)
case group.FieldImagePrice2k:
@ -17133,6 +17286,27 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
}
m.SetDefaultValidityDays(v)
return nil
case group.FieldAllowImageGeneration:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetAllowImageGeneration(v)
return nil
case group.FieldImageRateIndependent:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetImageRateIndependent(v)
return nil
case group.FieldImageRateMultiplier:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetImageRateMultiplier(v)
return nil
case group.FieldImagePrice1k:
v, ok := value.(float64)
if !ok {
@ -17275,6 +17449,9 @@ func (m *GroupMutation) AddedFields() []string {
if m.adddefault_validity_days != nil {
fields = append(fields, group.FieldDefaultValidityDays)
}
if m.addimage_rate_multiplier != nil {
fields = append(fields, group.FieldImageRateMultiplier)
}
if m.addimage_price_1k != nil {
fields = append(fields, group.FieldImagePrice1k)
}
@ -17314,6 +17491,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedMonthlyLimitUsd()
case group.FieldDefaultValidityDays:
return m.AddedDefaultValidityDays()
case group.FieldImageRateMultiplier:
return m.AddedImageRateMultiplier()
case group.FieldImagePrice1k:
return m.AddedImagePrice1k()
case group.FieldImagePrice2k:
@ -17372,6 +17551,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
}
m.AddDefaultValidityDays(v)
return nil
case group.FieldImageRateMultiplier:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddImageRateMultiplier(v)
return nil
case group.FieldImagePrice1k:
v, ok := value.(float64)
if !ok {
@ -17559,6 +17745,15 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldDefaultValidityDays:
m.ResetDefaultValidityDays()
return nil
case group.FieldAllowImageGeneration:
m.ResetAllowImageGeneration()
return nil
case group.FieldImageRateIndependent:
m.ResetImageRateIndependent()
return nil
case group.FieldImageRateMultiplier:
m.ResetImageRateMultiplier()
return nil
case group.FieldImagePrice1k:
m.ResetImagePrice1k()
return nil

View File

@ -803,50 +803,62 @@ func init() {
groupDescDefaultValidityDays := groupFields[10].Descriptor()
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
// groupDescAllowImageGeneration is the schema descriptor for allow_image_generation field.
groupDescAllowImageGeneration := groupFields[11].Descriptor()
// group.DefaultAllowImageGeneration holds the default value on creation for the allow_image_generation field.
group.DefaultAllowImageGeneration = groupDescAllowImageGeneration.Default.(bool)
// groupDescImageRateIndependent is the schema descriptor for image_rate_independent field.
groupDescImageRateIndependent := groupFields[12].Descriptor()
// group.DefaultImageRateIndependent holds the default value on creation for the image_rate_independent field.
group.DefaultImageRateIndependent = groupDescImageRateIndependent.Default.(bool)
// groupDescImageRateMultiplier is the schema descriptor for image_rate_multiplier field.
groupDescImageRateMultiplier := groupFields[13].Descriptor()
// group.DefaultImageRateMultiplier holds the default value on creation for the image_rate_multiplier field.
group.DefaultImageRateMultiplier = groupDescImageRateMultiplier.Default.(float64)
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
groupDescClaudeCodeOnly := groupFields[14].Descriptor()
groupDescClaudeCodeOnly := groupFields[17].Descriptor()
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
groupDescModelRoutingEnabled := groupFields[18].Descriptor()
groupDescModelRoutingEnabled := groupFields[21].Descriptor()
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
// groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field.
groupDescMcpXMLInject := groupFields[19].Descriptor()
groupDescMcpXMLInject := groupFields[22].Descriptor()
// group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field.
group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool)
// groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field.
groupDescSupportedModelScopes := groupFields[20].Descriptor()
groupDescSupportedModelScopes := groupFields[23].Descriptor()
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
// groupDescSortOrder is the schema descriptor for sort_order field.
groupDescSortOrder := groupFields[21].Descriptor()
groupDescSortOrder := groupFields[24].Descriptor()
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
// groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field.
groupDescAllowMessagesDispatch := groupFields[22].Descriptor()
groupDescAllowMessagesDispatch := groupFields[25].Descriptor()
// group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field.
group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool)
// groupDescRequireOauthOnly is the schema descriptor for require_oauth_only field.
groupDescRequireOauthOnly := groupFields[23].Descriptor()
groupDescRequireOauthOnly := groupFields[26].Descriptor()
// group.DefaultRequireOauthOnly holds the default value on creation for the require_oauth_only field.
group.DefaultRequireOauthOnly = groupDescRequireOauthOnly.Default.(bool)
// groupDescRequirePrivacySet is the schema descriptor for require_privacy_set field.
groupDescRequirePrivacySet := groupFields[24].Descriptor()
groupDescRequirePrivacySet := groupFields[27].Descriptor()
// group.DefaultRequirePrivacySet holds the default value on creation for the require_privacy_set field.
group.DefaultRequirePrivacySet = groupDescRequirePrivacySet.Default.(bool)
// groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field.
groupDescDefaultMappedModel := groupFields[25].Descriptor()
groupDescDefaultMappedModel := groupFields[28].Descriptor()
// group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field.
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
// groupDescMessagesDispatchModelConfig is the schema descriptor for messages_dispatch_model_config field.
groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor()
groupDescMessagesDispatchModelConfig := groupFields[29].Descriptor()
// group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field.
group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig)
// groupDescRpmLimit is the schema descriptor for rpm_limit field.
groupDescRpmLimit := groupFields[27].Descriptor()
groupDescRpmLimit := groupFields[30].Descriptor()
// group.DefaultRpmLimit holds the default value on creation for the rpm_limit field.
group.DefaultRpmLimit = groupDescRpmLimit.Default.(int)
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()

View File

@ -74,6 +74,16 @@ func (Group) Fields() []ent.Field {
Default(30),
// 图片生成计费配置antigravity 和 gemini 平台使用)
field.Bool("allow_image_generation").
Default(false).
Comment("是否允许该分组使用图片生成能力"),
field.Bool("image_rate_independent").
Default(false).
Comment("图片生成是否使用独立倍率false 表示共享分组有效倍率"),
field.Float("image_rate_multiplier").
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}).
Default(1.0).
Comment("图片生成独立倍率,仅 image_rate_independent=true 时生效"),
field.Float("image_price_1k").
Optional().
Nillable().

View File

@ -575,6 +575,24 @@ type ConcurrencyConfig struct {
PingInterval int `mapstructure:"ping_interval"`
}
type ImageConcurrencyConfig struct {
// Enabled: 是否启用图片生成独立并发限制,默认关闭以保持现有行为
Enabled bool `mapstructure:"enabled"`
// MaxConcurrentRequests: 当前进程允许同时处理的图片生成请求数0表示不限制
MaxConcurrentRequests int `mapstructure:"max_concurrent_requests"`
// OverflowMode: 图片并发达到上限后的处理方式reject/wait
OverflowMode string `mapstructure:"overflow_mode"`
// WaitTimeoutSeconds: overflow_mode=wait 时等待图片并发槽位的超时时间(秒)
WaitTimeoutSeconds int `mapstructure:"wait_timeout_seconds"`
// MaxWaitingRequests: overflow_mode=wait 时当前进程允许排队等待的图片请求数
MaxWaitingRequests int `mapstructure:"max_waiting_requests"`
}
const (
ImageConcurrencyOverflowModeReject = "reject"
ImageConcurrencyOverflowModeWait = "wait"
)
// GatewayConfig API网关相关配置
type GatewayConfig struct {
// 等待上游响应头的超时时间0表示无超时
@ -604,6 +622,8 @@ type GatewayConfig struct {
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
// OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP
OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"`
// ImageConcurrency: 图片生成独立并发限制配置(默认关闭)
ImageConcurrency ImageConcurrencyConfig `mapstructure:"image_concurrency"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数
@ -635,6 +655,10 @@ type GatewayConfig struct {
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
// StreamKeepaliveInterval: 流式 keepalive 间隔0表示禁用
StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"`
// ImageStreamDataIntervalTimeout: 图片流数据间隔超时0表示禁用
ImageStreamDataIntervalTimeout int `mapstructure:"image_stream_data_interval_timeout"`
// ImageStreamKeepaliveInterval: 图片流式 keepalive 间隔0表示禁用
ImageStreamKeepaliveInterval int `mapstructure:"image_stream_keepalive_interval"`
// MaxLineSize: 上游 SSE 单行最大字节数0使用默认值
MaxLineSize int `mapstructure:"max_line_size"`
@ -1672,6 +1696,11 @@ func setDefaults() {
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5)
viper.SetDefault("gateway.image_concurrency.enabled", false)
viper.SetDefault("gateway.image_concurrency.max_concurrent_requests", 0)
viper.SetDefault("gateway.image_concurrency.overflow_mode", ImageConcurrencyOverflowModeReject)
viper.SetDefault("gateway.image_concurrency.wait_timeout_seconds", 30)
viper.SetDefault("gateway.image_concurrency.max_waiting_requests", 100)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
@ -1689,6 +1718,8 @@ func setDefaults() {
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
viper.SetDefault("gateway.stream_keepalive_interval", 10)
viper.SetDefault("gateway.image_stream_data_interval_timeout", 900)
viper.SetDefault("gateway.image_stream_keepalive_interval", 10)
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
@ -2239,6 +2270,21 @@ func (c *Config) Validate() error {
ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy)
}
}
if c.Gateway.ImageConcurrency.MaxConcurrentRequests < 0 {
return fmt.Errorf("gateway.image_concurrency.max_concurrent_requests must be non-negative")
}
switch strings.TrimSpace(c.Gateway.ImageConcurrency.OverflowMode) {
case "", ImageConcurrencyOverflowModeReject, ImageConcurrencyOverflowModeWait:
default:
return fmt.Errorf("gateway.image_concurrency.overflow_mode must be one of: %s/%s",
ImageConcurrencyOverflowModeReject, ImageConcurrencyOverflowModeWait)
}
if c.Gateway.ImageConcurrency.WaitTimeoutSeconds < 0 {
return fmt.Errorf("gateway.image_concurrency.wait_timeout_seconds must be non-negative")
}
if c.Gateway.ImageConcurrency.MaxWaitingRequests < 0 {
return fmt.Errorf("gateway.image_concurrency.max_waiting_requests must be non-negative")
}
if c.Gateway.MaxIdleConns <= 0 {
return fmt.Errorf("gateway.max_idle_conns must be positive")
}
@ -2277,6 +2323,20 @@ func (c *Config) Validate() error {
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
}
if c.Gateway.ImageStreamDataIntervalTimeout < 0 {
return fmt.Errorf("gateway.image_stream_data_interval_timeout must be non-negative")
}
if c.Gateway.ImageStreamDataIntervalTimeout != 0 &&
(c.Gateway.ImageStreamDataIntervalTimeout < 60 || c.Gateway.ImageStreamDataIntervalTimeout > 1800) {
return fmt.Errorf("gateway.image_stream_data_interval_timeout must be 0 or between 60-1800 seconds")
}
if c.Gateway.ImageStreamKeepaliveInterval < 0 {
return fmt.Errorf("gateway.image_stream_keepalive_interval must be non-negative")
}
if c.Gateway.ImageStreamKeepaliveInterval != 0 &&
(c.Gateway.ImageStreamKeepaliveInterval < 5 || c.Gateway.ImageStreamKeepaliveInterval > 60) {
return fmt.Errorf("gateway.image_stream_keepalive_interval must be 0 or between 5-60 seconds")
}
// 兼容旧键 sticky_previous_response_ttl_seconds
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds

View File

@ -1282,6 +1282,46 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = -1 },
wantErr: "gateway.stream_data_interval_timeout must be non-negative",
},
{
name: "gateway image stream keepalive range",
mutate: func(c *Config) { c.Gateway.ImageStreamKeepaliveInterval = 4 },
wantErr: "gateway.image_stream_keepalive_interval",
},
{
name: "gateway image stream keepalive negative",
mutate: func(c *Config) { c.Gateway.ImageStreamKeepaliveInterval = -1 },
wantErr: "gateway.image_stream_keepalive_interval must be non-negative",
},
{
name: "gateway image stream data interval range",
mutate: func(c *Config) { c.Gateway.ImageStreamDataIntervalTimeout = 30 },
wantErr: "gateway.image_stream_data_interval_timeout",
},
{
name: "gateway image stream data interval negative",
mutate: func(c *Config) { c.Gateway.ImageStreamDataIntervalTimeout = -1 },
wantErr: "gateway.image_stream_data_interval_timeout must be non-negative",
},
{
name: "gateway image concurrency max negative",
mutate: func(c *Config) { c.Gateway.ImageConcurrency.MaxConcurrentRequests = -1 },
wantErr: "gateway.image_concurrency.max_concurrent_requests must be non-negative",
},
{
name: "gateway image concurrency overflow mode invalid",
mutate: func(c *Config) { c.Gateway.ImageConcurrency.OverflowMode = "queue" },
wantErr: "gateway.image_concurrency.overflow_mode",
},
{
name: "gateway image concurrency wait timeout negative",
mutate: func(c *Config) { c.Gateway.ImageConcurrency.WaitTimeoutSeconds = -1 },
wantErr: "gateway.image_concurrency.wait_timeout_seconds must be non-negative",
},
{
name: "gateway image concurrency max waiting negative",
mutate: func(c *Config) { c.Gateway.ImageConcurrency.MaxWaitingRequests = -1 },
wantErr: "gateway.image_concurrency.max_waiting_requests must be non-negative",
},
{
name: "gateway max line size",
mutate: func(c *Config) { c.Gateway.MaxLineSize = 1024 },
@ -1754,3 +1794,41 @@ func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
t.Fatalf("auto_scale_cooldown_seconds = %d, want 10", cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds)
}
}
func TestLoad_DefaultGatewayImageStreamConfig(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Gateway.StreamDataIntervalTimeout != 180 {
t.Fatalf("stream_data_interval_timeout = %d, want 180", cfg.Gateway.StreamDataIntervalTimeout)
}
if cfg.Gateway.StreamKeepaliveInterval != 10 {
t.Fatalf("stream_keepalive_interval = %d, want 10", cfg.Gateway.StreamKeepaliveInterval)
}
if cfg.Gateway.ImageStreamDataIntervalTimeout != 900 {
t.Fatalf("image_stream_data_interval_timeout = %d, want 900", cfg.Gateway.ImageStreamDataIntervalTimeout)
}
if cfg.Gateway.ImageStreamKeepaliveInterval != 10 {
t.Fatalf("image_stream_keepalive_interval = %d, want 10", cfg.Gateway.ImageStreamKeepaliveInterval)
}
if cfg.Gateway.ImageConcurrency.Enabled {
t.Fatalf("image_concurrency.enabled = true, want false")
}
if cfg.Gateway.ImageConcurrency.MaxConcurrentRequests != 0 {
t.Fatalf("image_concurrency.max_concurrent_requests = %d, want 0", cfg.Gateway.ImageConcurrency.MaxConcurrentRequests)
}
if cfg.Gateway.ImageConcurrency.OverflowMode != ImageConcurrencyOverflowModeReject {
t.Fatalf("image_concurrency.overflow_mode = %q, want %q", cfg.Gateway.ImageConcurrency.OverflowMode, ImageConcurrencyOverflowModeReject)
}
if cfg.Gateway.ImageConcurrency.WaitTimeoutSeconds != 30 {
t.Fatalf("image_concurrency.wait_timeout_seconds = %d, want 30", cfg.Gateway.ImageConcurrency.WaitTimeoutSeconds)
}
if cfg.Gateway.ImageConcurrency.MaxWaitingRequests != 100 {
t.Fatalf("image_concurrency.max_waiting_requests = %d, want 100", cfg.Gateway.ImageConcurrency.MaxWaitingRequests)
}
if cfg.Gateway.ImageStreamDataIntervalTimeout <= cfg.Gateway.StreamDataIntervalTimeout {
t.Fatalf("image stream timeout = %d, want greater than ordinary stream timeout %d", cfg.Gateway.ImageStreamDataIntervalTimeout, cfg.Gateway.StreamDataIntervalTimeout)
}
}

View File

@ -92,6 +92,9 @@ type CreateGroupRequest struct {
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
// 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置)
AllowImageGeneration bool `json:"allow_image_generation"`
ImageRateIndependent bool `json:"image_rate_independent"`
ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
@ -129,6 +132,9 @@ type UpdateGroupRequest struct {
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
// 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置)
AllowImageGeneration *bool `json:"allow_image_generation"`
ImageRateIndependent *bool `json:"image_rate_independent"`
ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
@ -251,6 +257,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
AllowImageGeneration: req.AllowImageGeneration,
ImageRateIndependent: req.ImageRateIndependent,
ImageRateMultiplier: req.ImageRateMultiplier,
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
@ -303,6 +312,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
AllowImageGeneration: req.AllowImageGeneration,
ImageRateIndependent: req.ImageRateIndependent,
ImageRateMultiplier: req.ImageRateMultiplier,
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,

View File

@ -176,6 +176,9 @@ func groupFromServiceBase(g *service.Group) Group {
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
AllowImageGeneration: g.AllowImageGeneration,
ImageRateIndependent: g.ImageRateIndependent,
ImageRateMultiplier: g.ImageRateMultiplier,
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,

View File

@ -94,9 +94,12 @@ type Group struct {
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
// 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
AllowImageGeneration bool `json:"allow_image_generation"`
ImageRateIndependent bool `json:"image_rate_independent"`
ImageRateMultiplier float64 `json:"image_rate_multiplier"`
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
// Claude Code 客户端限制
ClaudeCodeOnly bool `json:"claude_code_only"`

View File

@ -0,0 +1,126 @@
package handler
import (
"context"
"sync"
"time"
)
type imageConcurrencyLimiter struct {
mu sync.Mutex
notify chan struct{}
limit int
active int
waiting int
enabled bool
}
func (l *imageConcurrencyLimiter) TryAcquire(enabled bool, limit int) (func(), bool) {
return l.acquire(context.Background(), enabled, limit, false, 0, 0)
}
func (l *imageConcurrencyLimiter) Acquire(ctx context.Context, enabled bool, limit int, wait bool, timeout time.Duration, maxWaiting int) (func(), bool) {
return l.acquire(ctx, enabled, limit, wait, timeout, maxWaiting)
}
func (l *imageConcurrencyLimiter) acquire(ctx context.Context, enabled bool, limit int, wait bool, timeout time.Duration, maxWaiting int) (func(), bool) {
if !enabled || limit <= 0 {
return nil, true
}
if ctx == nil {
ctx = context.Background()
}
if wait {
if timeout <= 0 {
return nil, false
}
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
ctx = waitCtx
}
if maxWaiting < 0 {
maxWaiting = 0
}
for {
release, acquired, waitRelease, notify := l.tryAcquireLocked(enabled, limit, wait, maxWaiting)
if acquired {
return release, acquired
}
if !wait || notify == nil {
return nil, false
}
if !l.waitForSlot(ctx, notify) {
if waitRelease != nil {
waitRelease()
}
return nil, false
}
if waitRelease != nil {
waitRelease()
}
}
}
func (l *imageConcurrencyLimiter) tryAcquireLocked(enabled bool, limit int, wait bool, maxWaiting int) (func(), bool, func(), <-chan struct{}) {
l.mu.Lock()
defer l.mu.Unlock()
if l.notify == nil {
l.notify = make(chan struct{})
}
if l.enabled != enabled || l.limit != limit {
l.enabled = enabled
l.limit = limit
}
if l.active < l.limit {
l.active++
return l.releaseFunc(), true, nil, nil
}
if !wait {
return nil, false, nil, nil
}
if maxWaiting > 0 && l.waiting >= maxWaiting {
return nil, false, nil, nil
}
l.waiting++
return nil, false, l.waiterReleaseFunc(), l.notify
}
func (l *imageConcurrencyLimiter) waitForSlot(ctx context.Context, notify <-chan struct{}) bool {
select {
case <-notify:
return true
case <-ctx.Done():
return false
}
}
func (l *imageConcurrencyLimiter) releaseFunc() func() {
var once sync.Once
return func() {
once.Do(func() {
l.mu.Lock()
if l.active > 0 {
l.active--
}
if l.notify != nil {
close(l.notify)
l.notify = make(chan struct{})
}
l.mu.Unlock()
})
}
}
func (l *imageConcurrencyLimiter) waiterReleaseFunc() func() {
var once sync.Once
return func() {
once.Do(func() {
l.mu.Lock()
if l.waiting > 0 {
l.waiting--
}
l.mu.Unlock()
})
}
}

View File

@ -0,0 +1,230 @@
package handler
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestImageConcurrencyLimiter_DefaultDisabledAllowsRequests(t *testing.T) {
limiter := &imageConcurrencyLimiter{}
release, acquired := limiter.TryAcquire(false, 1)
require.True(t, acquired)
require.Nil(t, release)
}
func TestImageConcurrencyLimiter_RejectsWhenLimitReachedAndAllowsAfterRelease(t *testing.T) {
limiter := &imageConcurrencyLimiter{}
release, acquired := limiter.TryAcquire(true, 1)
require.True(t, acquired)
require.NotNil(t, release)
secondRelease, secondAcquired := limiter.TryAcquire(true, 1)
require.False(t, secondAcquired)
require.Nil(t, secondRelease)
release()
thirdRelease, thirdAcquired := limiter.TryAcquire(true, 1)
require.True(t, thirdAcquired)
require.NotNil(t, thirdRelease)
thirdRelease()
}
func TestImageConcurrencyLimiter_WaitsUntilSlotReleased(t *testing.T) {
limiter := &imageConcurrencyLimiter{}
release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
require.True(t, acquired)
require.NotNil(t, release)
acquiredCh := make(chan func(), 1)
go func() {
waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
require.True(t, waitAcquired)
acquiredCh <- waitRelease
}()
time.Sleep(20 * time.Millisecond)
release()
select {
case waitRelease := <-acquiredCh:
require.NotNil(t, waitRelease)
waitRelease()
case <-time.After(time.Second):
t.Fatal("timed out waiting for image concurrency slot")
}
}
func TestImageConcurrencyLimiter_WaitTimesOut(t *testing.T) {
limiter := &imageConcurrencyLimiter{}
release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
require.True(t, acquired)
require.NotNil(t, release)
defer release()
waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, 10*time.Millisecond, 1)
require.False(t, waitAcquired)
require.Nil(t, waitRelease)
}
func TestImageConcurrencyLimiter_MaxWaitingRequestsRejectsOverflow(t *testing.T) {
limiter := &imageConcurrencyLimiter{}
release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
require.True(t, acquired)
require.NotNil(t, release)
defer release()
waitingStarted := make(chan struct{})
waitingDone := make(chan struct{})
go func() {
close(waitingStarted)
waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
if waitAcquired && waitRelease != nil {
waitRelease()
}
close(waitingDone)
}()
<-waitingStarted
time.Sleep(20 * time.Millisecond)
overflowRelease, overflowAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
require.False(t, overflowAcquired)
require.Nil(t, overflowRelease)
release()
<-waitingDone
}
func TestOpenAIGatewayHandlerAcquireImageGenerationSlot_Returns429WhenFull(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
h := &OpenAIGatewayHandler{
cfg: &config.Config{
Gateway: config.GatewayConfig{
ImageConcurrency: config.ImageConcurrencyConfig{
Enabled: true,
MaxConcurrentRequests: 1,
OverflowMode: config.ImageConcurrencyOverflowModeReject,
},
},
},
imageLimiter: &imageConcurrencyLimiter{},
}
release, acquired := h.acquireImageGenerationSlot(c, false)
require.True(t, acquired)
require.NotNil(t, release)
defer release()
blockedRelease, blocked := h.acquireImageGenerationSlot(c, false)
require.False(t, blocked)
require.Nil(t, blockedRelease)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
require.Equal(t, "rate_limit_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String())
require.Contains(t, rec.Body.String(), "Image generation concurrency limit exceeded")
}
func TestOpenAIGatewayHandlerResponses_ImageIntentRejectedByImageConcurrency(t *testing.T) {
gin.SetMode(gin.TestMode)
body := `{"model":"gpt-5.4","input":"draw","tools":[{"type":"image_generation"}]}`
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(body))
groupID := int64(1)
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
ID: 10,
GroupID: &groupID,
Group: &service.Group{
ID: groupID,
AllowImageGeneration: true,
},
User: &service.User{ID: 20},
})
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 20, Concurrency: 1})
h := &OpenAIGatewayHandler{
gatewayService: &service.OpenAIGatewayService{},
billingCacheService: &service.BillingCacheService{},
apiKeyService: &service.APIKeyService{},
concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})},
errorPassthroughService: nil,
cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{
Enabled: true,
MaxConcurrentRequests: 1,
OverflowMode: config.ImageConcurrencyOverflowModeReject,
}}},
imageLimiter: &imageConcurrencyLimiter{},
}
release, acquired := h.acquireImageGenerationSlot(c, false)
require.True(t, acquired)
require.NotNil(t, release)
defer release()
rec.Body.Reset()
rec.Code = 0
h.Responses(c)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
require.Equal(t, "rate_limit_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String())
require.Contains(t, rec.Body.String(), "Image generation concurrency limit exceeded")
}
func TestOpenAIGatewayHandlerResponses_TextOnlyNotRejectedByImageConcurrency(t *testing.T) {
gin.SetMode(gin.TestMode)
body := `{"model":"gpt-5.4","input":"write code"}`
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(body))
groupID := int64(1)
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
ID: 10,
GroupID: &groupID,
Group: &service.Group{
ID: groupID,
AllowImageGeneration: true,
},
User: &service.User{ID: 20},
})
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 20, Concurrency: 1})
h := &OpenAIGatewayHandler{
gatewayService: &service.OpenAIGatewayService{},
billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}),
apiKeyService: &service.APIKeyService{},
concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})},
cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{
Enabled: true,
MaxConcurrentRequests: 1,
OverflowMode: config.ImageConcurrencyOverflowModeReject,
}}},
imageLimiter: &imageConcurrencyLimiter{},
}
release, acquired := h.acquireImageGenerationSlot(c, false)
require.True(t, acquired)
require.NotNil(t, release)
defer release()
rec.Body.Reset()
rec.Code = 0
h.Responses(c)
require.NotEqual(t, http.StatusTooManyRequests, rec.Code)
require.NotContains(t, rec.Body.String(), "Image generation concurrency limit exceeded")
}

View File

@ -187,52 +187,60 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// Pool mode: retry on the same account
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
}
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
switchCount++
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
if result != nil && result.ImageCount > 0 {
reqLog.Warn("openai_chat_completions.forward_partial_error_with_image_result",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
zap.Int("image_count", result.ImageCount),
zap.Error(err),
)
continue
} else {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// Pool mode: retry on the same account
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
}
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
switchCount++
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Warn("openai_chat_completions.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Warn("openai_chat_completions.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return
}
if result != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
@ -242,16 +250,18 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := resolveRawCCUpstreamEndpoint(c, account)
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: resolveRawCCUpstreamEndpoint(c, account),
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,

View File

@ -33,6 +33,7 @@ type OpenAIGatewayHandler struct {
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
imageLimiter *imageConcurrencyLimiter
maxAccountSwitches int
cfg *config.Config
}
@ -69,6 +70,7 @@ func NewOpenAIGatewayHandler(
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
imageLimiter: &imageConcurrencyLimiter{},
maxAccountSwitches: maxAccountSwitches,
cfg: cfg,
}
@ -187,6 +189,23 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
imageIntent := service.IsImageGenerationIntent("/v1/responses", reqModel, body)
if imageIntent && !service.GroupAllowsImageGeneration(apiKey.Group) {
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
return
}
var imageReleaseFunc func()
if imageIntent {
var imageAcquired bool
imageReleaseFunc, imageAcquired = h.acquireImageGenerationSlot(c, streamStarted)
if !imageAcquired {
return
}
if imageReleaseFunc != nil {
defer imageReleaseFunc()
}
}
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
@ -318,57 +337,65 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// 池模式:同账号重试
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
if result != nil && result.ImageCount > 0 {
reqLog.Warn("openai.forward_partial_error_with_image_result",
zap.Int64("account_id", account.ID),
zap.Int("image_count", result.ImageCount),
zap.Error(err),
)
} else {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// 池模式:同账号重试
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
}
continue
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
switchCount++
reqLog.Warn("openai.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
fields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
}
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
reqLog.Warn("openai.forward_failed", fields...)
return
}
switchCount++
reqLog.Warn("openai.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
fields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
}
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
reqLog.Warn("openai.forward_failed", fields...)
reqLog.Error("openai.forward_failed", fields...)
return
}
reqLog.Error("openai.forward_failed", fields...)
return
}
if result != nil {
if account.Type == service.AccountTypeOAuth {
@ -383,17 +410,19 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
@ -701,52 +730,60 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// 池模式:同账号重试
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
}
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
return
}
switchCount++
reqLog.Warn("openai_messages.upstream_failover_switching",
if result != nil && result.ImageCount > 0 {
reqLog.Warn("openai_messages.forward_partial_error_with_image_result",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
zap.Int("image_count", result.ImageCount),
zap.Error(err),
)
continue
} else {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// 池模式:同账号重试
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
}
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
return
}
switchCount++
reqLog.Warn("openai_messages.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted)
reqLog.Warn("openai_messages.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted)
reqLog.Warn("openai_messages.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return
}
if result != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
@ -757,16 +794,18 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
@ -1114,6 +1153,11 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
if service.IsImageGenerationIntent("/v1/responses", reqModel, firstMessage) && !service.GroupAllowsImageGeneration(apiKey.Group) {
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, service.ImageGenerationPermissionMessage())
return
}
// 解析渠道级模型映射
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
@ -1257,22 +1301,34 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
},
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
releaseTurnSlots()
if turnErr != nil || result == nil {
if turnErr != nil {
if result == nil || result.ImageCount <= 0 {
return
}
reqLog.Warn("openai.websocket_partial_error_with_image_result",
zap.Int64("account_id", account.ID),
zap.Int("image_count", result.ImageCount),
zap.Error(turnErr),
)
}
if result == nil {
return
}
if account.Type == service.AccountTypeOAuth {
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
h.submitUsageRecordTask(func(taskCtx context.Context) {
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
@ -1440,6 +1496,60 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas
task(ctx)
}
func (h *OpenAIGatewayHandler) submitOpenAIUsageRecordTask(result *service.OpenAIForwardResult, task service.UsageRecordTask) {
if result != nil && result.ImageCount > 0 {
h.submitMandatoryUsageRecordTask(task)
return
}
h.submitUsageRecordTask(task)
}
func (h *OpenAIGatewayHandler) submitMandatoryUsageRecordTask(task service.UsageRecordTask) {
if task == nil {
return
}
if h.usageRecordWorkerPool != nil {
if mode := h.usageRecordWorkerPool.Submit(task); mode != service.UsageRecordSubmitModeDropped {
return
}
logger.L().With(
zap.String("component", "handler.openai_gateway.usage"),
).Warn("openai.usage_record_task_mandatory_sync_fallback")
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
defer func() {
if recovered := recover(); recovered != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.usage"),
zap.Any("panic", recovered),
).Error("openai.usage_record_task_panic_recovered")
}
}()
task(ctx)
}
func (h *OpenAIGatewayHandler) acquireImageGenerationSlot(c *gin.Context, streamStarted bool) (func(), bool) {
if h == nil || h.cfg == nil || h.imageLimiter == nil {
return nil, true
}
imageConcurrency := h.cfg.Gateway.ImageConcurrency
wait := strings.TrimSpace(imageConcurrency.OverflowMode) == config.ImageConcurrencyOverflowModeWait
release, acquired := h.imageLimiter.Acquire(
c.Request.Context(),
imageConcurrency.Enabled,
imageConcurrency.MaxConcurrentRequests,
wait,
time.Duration(imageConcurrency.WaitTimeoutSeconds)*time.Second,
imageConcurrency.MaxWaitingRequests,
)
if acquired {
return release, true
}
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Image generation concurrency limit exceeded, please retry later", streamStarted)
return nil, false
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",

View File

@ -81,6 +81,18 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
zap.String("capability", string(parsed.RequiredCapability)),
)
if !service.GroupAllowsImageGeneration(apiKey.Group) {
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
return
}
imageReleaseFunc, acquired := h.acquireImageGenerationSlot(c, streamStarted)
if !acquired {
return
}
if imageReleaseFunc != nil {
defer imageReleaseFunc()
}
if parsed.Multipart {
setOpsRequestContext(c, parsed.Model, parsed.Stream, nil)
} else {
@ -188,62 +200,69 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
}
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
if err == nil && result != nil && result.FirstTokenMs != nil {
if result != nil && result.FirstTokenMs != nil {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai.images.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
if result != nil && result.ImageCount > 0 {
reqLog.Warn("openai.images.forward_partial_error_with_image_result",
zap.Int64("account_id", account.ID),
zap.Int("image_count", result.ImageCount),
zap.Error(err),
)
} else {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai.images.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
}
continue
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
switchCount++
reqLog.Warn("openai.images.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
fields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
}
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
reqLog.Warn("openai.images.forward_failed", fields...)
return
}
switchCount++
reqLog.Warn("openai.images.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
fields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
}
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
reqLog.Warn("openai.images.forward_failed", fields...)
reqLog.Error("openai.images.forward_failed", fields...)
return
}
reqLog.Error("openai.images.forward_failed", fields...)
return
}
if result != nil {
if account.Type == service.AccountTypeOAuth {
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders)
@ -259,21 +278,27 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
if parsed.Multipart {
requestPayloadHash = service.HashUsageRequestPayload([]byte(parsed.StickySessionSeed()))
}
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitUsageRecordTask(func(ctx context.Context) {
upstreamModel := ""
if result != nil {
upstreamModel = result.UpstreamModel
}
h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, result.UpstreamModel),
ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, upstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.images"),

View File

@ -0,0 +1,49 @@
package handler
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestOpenAIGatewayHandlerImages_DisabledGroupRejectsBeforeScheduling(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw","size":"1024x1024"}`)
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = req
groupID := int64(111)
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
ID: 222,
GroupID: &groupID,
Group: &service.Group{
ID: groupID,
AllowImageGeneration: false,
},
User: &service.User{ID: 333},
})
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 333, Concurrency: 1})
h := &OpenAIGatewayHandler{
gatewayService: &service.OpenAIGatewayService{},
billingCacheService: &service.BillingCacheService{},
apiKeyService: &service.APIKeyService{},
concurrencyHelper: &ConcurrencyHelper{concurrencyService: &service.ConcurrencyService{}},
}
h.Images(c)
require.Equal(t, http.StatusForbidden, rec.Code)
require.Equal(t, "permission_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String())
require.Contains(t, rec.Body.String(), service.ImageGenerationPermissionMessage())
}

View File

@ -129,3 +129,63 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
})
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
}
func TestOpenAIGatewayHandlerSubmitMandatoryUsageRecordTask_DroppedTaskSyncFallback(t *testing.T) {
pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
WorkerCount: 1,
QueueSize: 1,
TaskTimeout: time.Second,
OverflowPolicy: "drop",
OverflowSamplePercent: 0,
AutoScaleEnabled: false,
})
t.Cleanup(pool.Stop)
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
block := make(chan struct{})
release := make(chan struct{})
pool.Submit(func(ctx context.Context) {
close(block)
<-release
})
<-block
pool.Submit(func(ctx context.Context) {})
var called atomic.Bool
h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
called.Store(true)
})
close(release)
require.True(t, called.Load(), "mandatory usage task must run synchronously when async submit is dropped")
}
func TestOpenAIGatewayHandlerSubmitOpenAIUsageRecordTask_ImageResultUsesMandatoryFallback(t *testing.T) {
pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
WorkerCount: 1,
QueueSize: 1,
TaskTimeout: time.Second,
OverflowPolicy: "drop",
OverflowSamplePercent: 0,
AutoScaleEnabled: false,
})
t.Cleanup(pool.Stop)
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
block := make(chan struct{})
release := make(chan struct{})
pool.Submit(func(ctx context.Context) {
close(block)
<-release
})
<-block
pool.Submit(func(ctx context.Context) {})
var called atomic.Bool
h.submitOpenAIUsageRecordTask(&service.OpenAIForwardResult{ImageCount: 1}, func(ctx context.Context) {
called.Store(true)
})
close(release)
require.True(t, called.Load(), "image usage task must be mandatory when async submit is dropped")
}

View File

@ -166,6 +166,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldDailyLimitUsd,
group.FieldWeeklyLimitUsd,
group.FieldMonthlyLimitUsd,
group.FieldAllowImageGeneration,
group.FieldImageRateIndependent,
group.FieldImageRateMultiplier,
group.FieldImagePrice1k,
group.FieldImagePrice2k,
group.FieldImagePrice4k,
@ -699,6 +702,9 @@ func groupEntityToService(g *dbent.Group) *service.Group {
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
AllowImageGeneration: g.AllowImageGeneration,
ImageRateIndependent: g.ImageRateIndependent,
ImageRateMultiplier: g.ImageRateMultiplier,
ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k,

View File

@ -50,6 +50,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
SetAllowImageGeneration(groupIn.AllowImageGeneration).
SetImageRateIndependent(groupIn.ImageRateIndependent).
SetImageRateMultiplier(groupIn.ImageRateMultiplier).
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
@ -120,6 +123,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
SetAllowImageGeneration(groupIn.AllowImageGeneration).
SetImageRateIndependent(groupIn.ImageRateIndependent).
SetImageRateMultiplier(groupIn.ImageRateMultiplier).
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).

View File

@ -328,6 +328,9 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
"allow_image_generation": false,
"image_rate_independent": false,
"image_rate_multiplier": 0,
"claude_code_only": false,
"allow_messages_dispatch": false,
"fallback_group_id": null,

View File

@ -230,7 +230,11 @@ func applyAccountStatsCost(
if model == "" {
model = requestedModel
}
requestCount := 1
if usageLog != nil && usageLog.ImageCount > 0 {
requestCount = usageLog.ImageCount
}
usageLog.AccountStatsCost = resolveAccountStatsCost(
ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost,
ctx, cs, bs, accountID, groupID, model, tokens, requestCount, totalCost,
)
}

View File

@ -189,11 +189,14 @@ type CreateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
AllowImageGeneration bool
ImageRateIndependent bool
ImageRateMultiplier *float64
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用)
@ -226,11 +229,14 @@ type UpdateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
AllowImageGeneration *bool
ImageRateIndependent *bool
ImageRateMultiplier *float64
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用)
@ -1557,6 +1563,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
imagePrice1K := normalizePrice(input.ImagePrice1K)
imagePrice2K := normalizePrice(input.ImagePrice2K)
imagePrice4K := normalizePrice(input.ImagePrice4K)
imageRateMultiplier := 1.0
if input.ImageRateMultiplier != nil {
if *input.ImageRateMultiplier < 0 {
return nil, errors.New("image_rate_multiplier must be >= 0")
}
imageRateMultiplier = *input.ImageRateMultiplier
}
// 校验降级分组
if input.FallbackGroupID != nil {
@ -1624,6 +1637,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
DailyLimitUSD: dailyLimit,
WeeklyLimitUSD: weeklyLimit,
MonthlyLimitUSD: monthlyLimit,
AllowImageGeneration: input.AllowImageGeneration,
ImageRateIndependent: input.ImageRateIndependent,
ImageRateMultiplier: imageRateMultiplier,
ImagePrice1K: imagePrice1K,
ImagePrice2K: imagePrice2K,
ImagePrice4K: imagePrice4K,
@ -1800,6 +1816,18 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
// 图片生成计费配置:负数表示清除(使用默认价格)
if input.AllowImageGeneration != nil {
group.AllowImageGeneration = *input.AllowImageGeneration
}
if input.ImageRateIndependent != nil {
group.ImageRateIndependent = *input.ImageRateIndependent
}
if input.ImageRateMultiplier != nil {
if *input.ImageRateMultiplier < 0 {
return nil, errors.New("image_rate_multiplier must be >= 0")
}
group.ImageRateMultiplier = *input.ImageRateMultiplier
}
if input.ImagePrice1K != nil {
group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
}

View File

@ -266,6 +266,50 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require.Nil(t, repo.updated.ImagePrice4K)
}
func TestAdminService_UpdateGroup_PreservesImageGenerationControlsWhenOmitted(t *testing.T) {
imageMultiplier := 0.5
existingGroup := &Group{
ID: 1,
Name: "existing-group",
Platform: PlatformOpenAI,
Status: StatusActive,
AllowImageGeneration: true,
ImageRateIndependent: true,
ImageRateMultiplier: imageMultiplier,
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
Description: "updated",
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.True(t, repo.updated.AllowImageGeneration)
require.True(t, repo.updated.ImageRateIndependent)
require.InDelta(t, 0.5, repo.updated.ImageRateMultiplier, 1e-12)
}
func TestAdminService_UpdateGroup_RejectsNegativeImageRateMultiplier(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "existing-group",
Platform: PlatformOpenAI,
Status: StatusActive,
ImageRateMultiplier: 1,
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
negative := -0.1
_, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
ImageRateMultiplier: &negative,
})
require.Error(t, err)
require.Nil(t, repo.updated)
}
func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
existingGroup := &Group{
ID: 1,

View File

@ -63,6 +63,9 @@ type APIKeyAuthGroupSnapshot struct {
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
AllowImageGeneration bool `json:"allow_image_generation"`
ImageRateIndependent bool `json:"image_rate_independent"`
ImageRateMultiplier float64 `json:"image_rate_multiplier"`
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`

View File

@ -14,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto"
)
const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot
const apiKeyAuthSnapshotVersion = 8 // v8: added group image generation controls
type apiKeyAuthCacheConfig struct {
l1Size int
@ -255,6 +255,9 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey)
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
AllowImageGeneration: apiKey.Group.AllowImageGeneration,
ImageRateIndependent: apiKey.Group.ImageRateIndependent,
ImageRateMultiplier: apiKey.Group.ImageRateMultiplier,
ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K,
@ -321,6 +324,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
AllowImageGeneration: snapshot.Group.AllowImageGeneration,
ImageRateIndependent: snapshot.Group.ImageRateIndependent,
ImageRateMultiplier: snapshot.Group.ImageRateMultiplier,
ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K,

View File

@ -294,8 +294,7 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
}
// OpenAI 仅匹配已知 GPT-5/Codex 族,避免未知 OpenAI 型号误计价。
if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") {
normalized := normalizeCodexModel(modelLower)
if normalized := normalizeKnownOpenAICodexModel(modelLower); normalized != "" {
switch normalized {
case "gpt-5.5":
return s.fallbackPrices["gpt-5.5"]
@ -644,13 +643,10 @@ func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens
}
func isOpenAIGPT54Model(model string) bool {
trimmed := strings.TrimSpace(strings.ToLower(model))
// 仅当模型字符串实际属于 GPT-5/Codex 族时才做归一判定,避免 normalizeCodexModel
// 的默认兜底把非 OpenAI 模型claude-*、gemini-*、gpt-4o误识别为 gpt-5.4。
if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
return false
}
normalized := normalizeCodexModel(trimmed)
// 仅当模型字符串实际属于已知 GPT-5/Codex 族时才做归一判定,避免
// normalizeCodexModel 的默认兜底把非 OpenAI 模型claude-*、gemini-*、gpt-4o
// 误识别为 gpt-5.4。
normalized := normalizeKnownOpenAICodexModel(model)
return normalized == "gpt-5.4" || normalized == "gpt-5.5"
}

View File

@ -137,6 +137,35 @@ func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) {
require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12)
}
func TestGetModelPricing_OpenAICompactAliasesFallback(t *testing.T) {
svc := newTestBillingService()
tests := []struct {
model string
inputPrice float64
outputPrice float64
cacheRead float64
longContext int
}{
{model: "gpt5.5", inputPrice: 2.5e-6, outputPrice: 15e-6, cacheRead: 0.25e-6, longContext: 272000},
{model: "openai/gpt5.4", inputPrice: 2.5e-6, outputPrice: 15e-6, cacheRead: 0.25e-6, longContext: 272000},
{model: "gpt5.4-mini", inputPrice: 7.5e-7, outputPrice: 4.5e-6, cacheRead: 7.5e-8, longContext: 0},
{model: "gpt5.3codexspark", inputPrice: 1.5e-6, outputPrice: 12e-6, cacheRead: 0.15e-6, longContext: 0},
}
for _, tt := range tests {
t.Run(tt.model, func(t *testing.T) {
pricing, err := svc.GetModelPricing(tt.model)
require.NoError(t, err)
require.NotNil(t, pricing)
require.InDelta(t, tt.inputPrice, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, tt.outputPrice, pricing.OutputPricePerToken, 1e-12)
require.InDelta(t, tt.cacheRead, pricing.CacheReadPricePerToken, 1e-12)
require.Equal(t, tt.longContext, pricing.LongContextInputThreshold)
})
}
}
func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) {
svc := newTestBillingService()

View File

@ -8367,6 +8367,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
groupDefault := apiKey.Group.RateMultiplier
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
}
imageMultiplier := resolveImageRateMultiplier(apiKey, multiplier)
// 确定计费模型
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
@ -8384,7 +8385,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
}
// 计算费用
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts)
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, imageMultiplier, opts)
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
@ -8396,7 +8397,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
// 创建使用日志
accountRateMultiplier := account.BillingRateMultiplier()
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
requestedModel, multiplier, imageMultiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
// 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
if apiKey.GroupID != nil {
@ -8450,11 +8451,12 @@ func (s *GatewayService) calculateRecordUsageCost(
apiKey *APIKey,
billingModel string,
multiplier float64,
imageMultiplier float64,
opts *recordUsageOpts,
) *CostBreakdown {
// 图片生成计费
if result.ImageCount > 0 {
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
return s.calculateImageCost(ctx, result, apiKey, billingModel, imageMultiplier)
}
// Token 计费
@ -8495,7 +8497,8 @@ func (s *GatewayService) calculateImageCost(
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RequestCount: result.ImageCount,
SizeTier: result.ImageSize,
RateMultiplier: multiplier,
Resolver: s.resolver,
Resolved: resolved,
@ -8580,6 +8583,7 @@ func (s *GatewayService) buildRecordUsageLog(
subscription *UserSubscription,
requestedModel string,
multiplier float64,
imageMultiplier float64,
accountRateMultiplier float64,
billingType int8,
cacheTTLOverridden bool,
@ -8624,6 +8628,9 @@ func (s *GatewayService) buildRecordUsageLog(
SubscriptionID: optionalSubscriptionID(subscription),
CreatedAt: time.Now(),
}
if result.ImageCount > 0 {
usageLog.RateMultiplier = imageMultiplier
}
if cost != nil {
usageLog.InputCost = cost.InputCost
usageLog.OutputCost = cost.OutputCost

View File

@ -26,9 +26,12 @@ type Group struct {
DefaultValidityDays int
// 图片生成计费配置antigravity 和 gemini 平台使用)
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
AllowImageGeneration bool
ImageRateIndependent bool
ImageRateMultiplier float64
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
// Claude Code 客户端限制
ClaudeCodeOnly bool

View File

@ -45,19 +45,25 @@ type GroupSortOrderUpdate struct {
// CreateGroupRequest 创建分组请求
type CreateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
Name string `json:"name"`
Description string `json:"description"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
AllowImageGeneration bool `json:"allow_image_generation"`
ImageRateIndependent bool `json:"image_rate_independent"`
ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
}
// UpdateGroupRequest 更新分组请求
type UpdateGroupRequest struct {
Name *string `json:"name"`
Description *string `json:"description"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status *string `json:"status"`
Name *string `json:"name"`
Description *string `json:"description"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status *string `json:"status"`
AllowImageGeneration *bool `json:"allow_image_generation"`
ImageRateIndependent *bool `json:"image_rate_independent"`
ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
}
// GroupService 分组管理服务
@ -76,6 +82,13 @@ func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthC
// Create 创建分组
func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Group, error) {
imageRateMultiplier := 1.0
if req.ImageRateMultiplier != nil {
if *req.ImageRateMultiplier < 0 {
return nil, fmt.Errorf("image_rate_multiplier must be >= 0")
}
imageRateMultiplier = *req.ImageRateMultiplier
}
// 检查名称是否已存在
exists, err := s.groupRepo.ExistsByName(ctx, req.Name)
if err != nil {
@ -87,13 +100,16 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Gro
// 创建分组
group := &Group{
Name: req.Name,
Description: req.Description,
Platform: PlatformAnthropic,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard,
Name: req.Name,
Description: req.Description,
Platform: PlatformAnthropic,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard,
AllowImageGeneration: req.AllowImageGeneration,
ImageRateIndependent: req.ImageRateIndependent,
ImageRateMultiplier: imageRateMultiplier,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
@ -165,6 +181,18 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ
if req.Status != nil {
group.Status = *req.Status
}
if req.AllowImageGeneration != nil {
group.AllowImageGeneration = *req.AllowImageGeneration
}
if req.ImageRateIndependent != nil {
group.ImageRateIndependent = *req.ImageRateIndependent
}
if req.ImageRateMultiplier != nil {
if *req.ImageRateMultiplier < 0 {
return nil, fmt.Errorf("image_rate_multiplier must be >= 0")
}
group.ImageRateMultiplier = *req.ImageRateMultiplier
}
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, fmt.Errorf("update group: %w", err)

View File

@ -0,0 +1,11 @@
package service
func resolveImageRateMultiplier(apiKey *APIKey, effectiveGroupMultiplier float64) float64 {
if apiKey != nil && apiKey.Group != nil && apiKey.Group.ImageRateIndependent {
if apiKey.Group.ImageRateMultiplier < 0 {
return 0
}
return apiKey.Group.ImageRateMultiplier
}
return effectiveGroupMultiplier
}

View File

@ -0,0 +1,220 @@
package service
import (
"encoding/json"
"strings"
"github.com/tidwall/gjson"
)
const (
openAIResponsesEndpoint = "/v1/responses"
openAIResponsesCompactEndpoint = "/v1/responses/compact"
imageGenerationPermissionMessage = "Image generation is not enabled for this group"
)
// ImageGenerationPermissionMessage returns the stable end-user error text for disabled groups.
func ImageGenerationPermissionMessage() string {
return imageGenerationPermissionMessage
}
// GroupAllowsImageGeneration preserves ungrouped-key behavior and enforces the flag when a group is present.
func GroupAllowsImageGeneration(group *Group) bool {
return group == nil || group.AllowImageGeneration
}
// IsImageGenerationIntent classifies requests that can produce generated images.
func IsImageGenerationIntent(endpoint string, requestedModel string, body []byte) bool {
if IsImageGenerationEndpoint(endpoint) {
return true
}
if isOpenAIImageGenerationModel(requestedModel) {
return true
}
if len(body) == 0 || !gjson.ValidBytes(body) {
return false
}
if model := strings.TrimSpace(gjson.GetBytes(body, "model").String()); isOpenAIImageGenerationModel(model) {
return true
}
if openAIJSONToolsContainImageGeneration(gjson.GetBytes(body, "tools")) {
return true
}
return openAIJSONToolChoiceSelectsImageGeneration(gjson.GetBytes(body, "tool_choice"))
}
// IsImageGenerationIntentMap is the map-backed variant used after service-side request mutation.
func IsImageGenerationIntentMap(endpoint string, requestedModel string, reqBody map[string]any) bool {
if IsImageGenerationEndpoint(endpoint) {
return true
}
if isOpenAIImageGenerationModel(requestedModel) {
return true
}
if reqBody == nil {
return false
}
if isOpenAIImageGenerationModel(firstNonEmptyString(reqBody["model"])) {
return true
}
if hasOpenAIImageGenerationTool(reqBody) {
return true
}
return openAIAnyToolChoiceSelectsImageGeneration(reqBody["tool_choice"])
}
// IsImageGenerationEndpoint identifies dedicated generated-image endpoints.
func IsImageGenerationEndpoint(endpoint string) bool {
switch normalizeImageGenerationEndpoint(endpoint) {
case "/v1/images/generations", "/v1/images/edits", "/images/generations", "/images/edits":
return true
default:
return false
}
}
func normalizeImageGenerationEndpoint(endpoint string) string {
endpoint = strings.TrimSpace(strings.ToLower(endpoint))
if endpoint == "" {
return ""
}
endpoint = strings.TrimPrefix(endpoint, "https://api.openai.com")
if idx := strings.IndexByte(endpoint, '?'); idx >= 0 {
endpoint = endpoint[:idx]
}
return strings.TrimRight(endpoint, "/")
}
func openAIJSONToolsContainImageGeneration(tools gjson.Result) bool {
if !tools.IsArray() {
return false
}
found := false
tools.ForEach(func(_, item gjson.Result) bool {
if strings.TrimSpace(item.Get("type").String()) == "image_generation" {
found = true
return false
}
return true
})
return found
}
func openAIJSONToolChoiceSelectsImageGeneration(choice gjson.Result) bool {
if !choice.Exists() {
return false
}
if choice.Type == gjson.String {
return strings.TrimSpace(choice.String()) == "image_generation"
}
if !choice.IsObject() {
return false
}
if strings.TrimSpace(choice.Get("type").String()) == "image_generation" {
return true
}
if strings.TrimSpace(choice.Get("tool.type").String()) == "image_generation" {
return true
}
if strings.TrimSpace(choice.Get("function.name").String()) == "image_generation" {
return true
}
return false
}
func openAIAnyToolChoiceSelectsImageGeneration(choice any) bool {
switch v := choice.(type) {
case string:
return strings.TrimSpace(v) == "image_generation"
case map[string]any:
if strings.TrimSpace(firstNonEmptyString(v["type"])) == "image_generation" {
return true
}
if tool, ok := v["tool"].(map[string]any); ok && strings.TrimSpace(firstNonEmptyString(tool["type"])) == "image_generation" {
return true
}
if fn, ok := v["function"].(map[string]any); ok && strings.TrimSpace(firstNonEmptyString(fn["name"])) == "image_generation" {
return true
}
}
return false
}
func getAPIKeyFromContext(c interface{ Get(string) (any, bool) }) *APIKey {
if c == nil {
return nil
}
v, exists := c.Get("api_key")
if !exists {
return nil
}
apiKey, _ := v.(*APIKey)
return apiKey
}
func apiKeyGroup(apiKey *APIKey) *Group {
if apiKey == nil {
return nil
}
return apiKey.Group
}
func cloneRequestMapForImageIntent(body []byte) map[string]any {
if len(body) == 0 {
return nil
}
var out map[string]any
if err := json.Unmarshal(body, &out); err != nil {
return nil
}
return out
}
func resolveOpenAIResponsesImageBillingConfig(reqBody map[string]any, fallbackModel string) (string, string, error) {
imageModel := ""
imageSize := ""
hasImageTool := false
if reqBody != nil {
rawTools, _ := reqBody["tools"].([]any)
for _, rawTool := range rawTools {
toolMap, ok := rawTool.(map[string]any)
if !ok || strings.TrimSpace(firstNonEmptyString(toolMap["type"])) != "image_generation" {
continue
}
hasImageTool = true
imageModel = strings.TrimSpace(firstNonEmptyString(toolMap["model"]))
imageSize = strings.TrimSpace(firstNonEmptyString(toolMap["size"]))
break
}
if imageSize == "" {
imageSize = strings.TrimSpace(firstNonEmptyString(reqBody["size"]))
}
}
if imageModel == "" && reqBody != nil {
bodyModel := strings.TrimSpace(firstNonEmptyString(reqBody["model"]))
if isOpenAIImageBillingModelAlias(bodyModel) || !hasImageTool {
imageModel = bodyModel
}
}
if imageModel == "" && hasImageTool {
imageModel = "gpt-image-2"
}
if imageModel == "" {
imageModel = strings.TrimSpace(fallbackModel)
}
sizeTier := normalizeOpenAIImageSizeTier(imageSize)
return imageModel, sizeTier, nil
}
func resolveOpenAIResponsesImageBillingConfigFromBody(body []byte, fallbackModel string) (string, string, error) {
reqBody := cloneRequestMapForImageIntent(body)
return resolveOpenAIResponsesImageBillingConfig(reqBody, fallbackModel)
}
func isOpenAIImageBillingModelAlias(model string) bool {
normalized := strings.ToLower(strings.TrimSpace(model))
if normalized == "" {
return false
}
return isOpenAIImageGenerationModel(normalized) || strings.Contains(normalized, "image")
}

View File

@ -0,0 +1,184 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestIsImageGenerationIntent(t *testing.T) {
tests := []struct {
name string
endpoint string
model string
body []byte
want bool
}{
{
name: "images endpoint",
endpoint: "/v1/images/generations",
body: []byte(`{"model":"gpt-image-2"}`),
want: true,
},
{
name: "image model",
endpoint: "/v1/responses",
model: "gpt-image-2",
body: []byte(`{"model":"gpt-image-2"}`),
want: true,
},
{
name: "image tool",
endpoint: "/v1/responses",
model: "gpt-5.4",
body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation"}]}`),
want: true,
},
{
name: "image tool choice",
endpoint: "/v1/responses",
model: "gpt-5.4",
body: []byte(`{"model":"gpt-5.4","tool_choice":{"type":"image_generation"}}`),
want: true,
},
{
name: "required tool choice alone is text",
endpoint: "/v1/responses",
model: "gpt-5.4",
body: []byte(`{"model":"gpt-5.4","tool_choice":"required"}`),
want: false,
},
{
name: "text only gpt 5.4",
endpoint: "/v1/responses",
model: "gpt-5.4",
body: []byte(`{"model":"gpt-5.4","input":"write code"}`),
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, IsImageGenerationIntent(tt.endpoint, tt.model, tt.body))
})
}
}
func TestResolveOpenAIResponsesImageBillingConfigUsesCurrentBodyModel(t *testing.T) {
imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(
[]byte(`{"model":"mapped-image-model","tools":[{"type":"image_generation","size":"1024x1024"}]}`),
"requested-model",
)
require.NoError(t, err)
require.Equal(t, "mapped-image-model", imageModel)
require.Equal(t, "1K", imageSize)
}
func TestResolveOpenAIResponsesImageBillingConfigToolModelWins(t *testing.T) {
imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(
[]byte(`{"model":"mapped-text-model","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1536x1024"}]}`),
"requested-model",
)
require.NoError(t, err)
require.Equal(t, "gpt-image-2", imageModel)
require.Equal(t, "2K", imageSize)
}
func TestResolveOpenAIResponsesImageBillingConfigSupportsOfficialAndCustomSizes(t *testing.T) {
tests := []struct {
name string
body []byte
wantTier string
}{
{
name: "official 2k landscape",
body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-2","size":"2048x1152"}]}`),
wantTier: "2K",
},
{
name: "official 4k landscape",
body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-2","size":"3840x2160"}]}`),
wantTier: "4K",
},
{
name: "custom valid 2k",
body: []byte(`{"model":"gpt-5.5","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1280x768"}]}`),
wantTier: "2K",
},
{
name: "default image tool model supports flexible size",
body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","size":"2048x1152"}]}`),
wantTier: "2K",
},
{
name: "top level image size is moved into billing",
body: []byte(`{"model":"gpt-image-2","size":"2048x2048","tools":[{"type":"image_generation","model":"gpt-image-2"}]}`),
wantTier: "2K",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(tt.body, "requested-model")
require.NoError(t, err)
require.NotEmpty(t, imageModel)
require.Equal(t, tt.wantTier, imageSize)
})
}
}
func TestResolveOpenAIResponsesImageBillingConfigDoesNotRejectUnknownSizes(t *testing.T) {
imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(
[]byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-1.5","size":"2048x1152"}]}`),
"requested-model",
)
require.NoError(t, err)
require.Equal(t, "gpt-image-1.5", imageModel)
require.Equal(t, "2K", imageSize)
}
func TestOpenAIImageOutputCounterDeduplicatesFinalImages(t *testing.T) {
counter := newOpenAIImageOutputCounter()
counter.AddSSEData([]byte(`{"type":"response.image_generation_call.partial_image","partial_image_b64":"abc"}`))
counter.AddSSEData([]byte(`{"type":"response.output_item.done","item":{"id":"ig_1","type":"image_generation_call","result":"final-a"}}`))
counter.AddSSEData([]byte(`{"type":"response.completed","response":{"output":[{"id":"ig_1","type":"image_generation_call","result":"final-a"},{"id":"ig_2","type":"image_generation_call","result":"final-b"}]}}`))
require.Equal(t, 2, counter.Count())
}
func TestOpenAIImageOutputCounterCountsImagesAPIStreamShapes(t *testing.T) {
counter := newOpenAIImageOutputCounter()
counter.AddSSEData([]byte(`{"type":"image_generation.completed","id":"ig_complete","b64_json":"final-a"}`))
counter.AddSSEData([]byte(`{"type":"response.output_item.done","item":{"id":"ig_item","type":"image_generation_call","result":"final-b"}}`))
counter.AddSSEData([]byte(`{"type":"response.completed","response":{"output":[{"id":"ig_done","type":"image_generation_call","result":"final-c"}]}}`))
require.Equal(t, 3, counter.Count())
dataCounter := newOpenAIImageOutputCounter()
dataCounter.AddSSEData([]byte(`{"data":[{"b64_json":"a"},{"b64_json":"b"}]}`))
dataCounter.AddSSEData([]byte(`{"data":[{"b64_json":"a"},{"b64_json":"b"},{"b64_json":"c"}]}`))
require.Equal(t, 3, dataCounter.Count())
}
func TestOpenAIImageOutputCounterCountsMultilineSSEDataPayload(t *testing.T) {
counter := newOpenAIImageOutputCounter()
counter.AddSSEData([]byte("{\"type\":\"image_generation.completed\",\n\"b64_json\":\"final-a\"}"))
require.Equal(t, 1, counter.Count())
}
func TestOpenAIImageOutputCounterCountsMultilineSSEBodyPayload(t *testing.T) {
counter := newOpenAIImageOutputCounter()
counter.AddSSEBody(
"data: {\"type\":\"image_generation.completed\",\n" +
"data: \"b64_json\":\"final-a\"}\n\n" +
"data: [DONE]\n\n",
)
require.Equal(t, 1, counter.Count())
}
func TestOpenAIImageOutputCounterFallsBackForInvalidMultilineSSEBody(t *testing.T) {
counter := newOpenAIImageOutputCounter()
counter.AddSSEBody(
"data: {\"type\":\"image_generation.completed\",\"b64_json\":\"final-a\"}\n" +
"data: {\"type\":\"image_generation.completed\",\"b64_json\":\"final-b\"}\n\n",
)
require.Equal(t, 2, counter.Count())
}

View File

@ -0,0 +1,149 @@
package service
import (
"crypto/sha256"
"encoding/hex"
"strings"
"github.com/tidwall/gjson"
)
type openAIImageOutputCounter struct {
seen map[string]struct{}
count int
maxDataCount int
}
func newOpenAIImageOutputCounter() *openAIImageOutputCounter {
return &openAIImageOutputCounter{seen: make(map[string]struct{})}
}
func (c *openAIImageOutputCounter) Count() int {
if c == nil {
return 0
}
if c.maxDataCount > c.count {
return c.maxDataCount
}
return c.count
}
func (c *openAIImageOutputCounter) AddJSONResponse(body []byte) {
if c == nil || len(body) == 0 || !gjson.ValidBytes(body) {
return
}
c.addDataArray(gjson.GetBytes(body, "data"))
c.addOutputArray(gjson.GetBytes(body, "output"))
c.addOutputArray(gjson.GetBytes(body, "response.output"))
}
func (c *openAIImageOutputCounter) AddSSEData(data []byte) {
if c == nil || len(data) == 0 || strings.TrimSpace(string(data)) == "[DONE]" || !gjson.ValidBytes(data) {
return
}
root := gjson.ParseBytes(data)
c.addDataArray(root.Get("data"))
eventType := strings.TrimSpace(root.Get("type").String())
switch eventType {
case "response.output_item.done":
c.addImageOutputItem(root.Get("item"))
case "response.completed", "response.done":
c.addOutputArray(root.Get("response.output"))
case "image_generation.completed":
if item := root.Get("item"); item.Exists() {
c.addImageOutputItem(item)
return
}
if output := root.Get("output"); output.Exists() {
c.addImageOutputItem(output)
return
}
c.addImageOutputItem(root)
}
}
func (c *openAIImageOutputCounter) AddSSEBody(body string) {
if c == nil || strings.TrimSpace(body) == "" {
return
}
forEachOpenAISSEDataPayload(body, c.AddSSEData)
}
func (c *openAIImageOutputCounter) addDataArray(data gjson.Result) {
if !data.IsArray() {
return
}
count := len(data.Array())
if count > c.maxDataCount {
c.maxDataCount = count
}
}
func (c *openAIImageOutputCounter) addOutputArray(output gjson.Result) {
if !output.IsArray() {
return
}
output.ForEach(func(_, item gjson.Result) bool {
c.addImageOutputItem(item)
return true
})
}
func (c *openAIImageOutputCounter) addImageOutputItem(item gjson.Result) {
if !item.Exists() || !item.IsObject() {
return
}
itemType := strings.TrimSpace(item.Get("type").String())
if itemType != "" && itemType != "image_generation_call" && itemType != "image_generation.completed" {
return
}
if strings.Contains(strings.ToLower(item.Raw), "partial_image") {
return
}
result := strings.TrimSpace(item.Get("result").String())
if result == "" {
result = strings.TrimSpace(item.Get("b64_json").String())
}
if result == "" {
result = strings.TrimSpace(item.Get("url").String())
}
if result == "" && itemType != "image_generation.completed" {
return
}
key := strings.TrimSpace(item.Get("id").String())
if key == "" {
key = strings.TrimSpace(item.Get("call_id").String())
}
if key == "" {
key = hashOpenAIImageOutputResult(result)
}
if key == "" {
return
}
if _, exists := c.seen[key]; exists {
return
}
c.seen[key] = struct{}{}
c.count++
}
func hashOpenAIImageOutputResult(result string) string {
result = strings.TrimSpace(result)
if result == "" {
return ""
}
sum := sha256.Sum256([]byte(result))
return hex.EncodeToString(sum[:])
}
func countOpenAIResponseImageOutputsFromJSONBytes(body []byte) int {
counter := newOpenAIImageOutputCounter()
counter.AddJSONResponse(body)
return counter.Count()
}
func countOpenAIImageOutputsFromSSEBody(body string) int {
counter := newOpenAIImageOutputCounter()
counter.AddSSEBody(body)
return counter.Count()
}

View File

@ -5,6 +5,7 @@ import (
"context"
"fmt"
"hash/fnv"
"log/slog"
"math"
"sort"
"strconv"
@ -345,7 +346,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
if !s.isAccountRequestCompatible(account, req) {
if !s.isAccountRequestCompatible(ctx, account, req) {
return nil, nil
}
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
@ -621,7 +622,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
if !s.isAccountRequestCompatible(account, req) {
if !s.isAccountRequestCompatible(ctx, account, req) {
continue
}
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
@ -822,11 +823,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i]
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
@ -853,11 +854,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
// WaitPlan.MaxConcurrency 使用 Concurrency非 EffectiveLoadFactor因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
for _, candidate := range selectionOrder {
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
@ -888,13 +889,18 @@ func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Ac
return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport)
}
func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(account *Account, req OpenAIAccountScheduleRequest) bool {
func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.Context, account *Account, req OpenAIAccountScheduleRequest) bool {
if account == nil {
return false
}
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
return false
}
if req.GroupID != nil && s != nil && s.service != nil &&
s.service.needsUpstreamChannelRestrictionCheck(ctx, req.GroupID) &&
s.service.isUpstreamModelRestrictedByChannel(ctx, *req.GroupID, account, req.RequestedModel, req.RequireCompact) {
return false
}
return account.SupportsOpenAIImageCapability(req.RequiredImageCapability)
}
@ -1106,6 +1112,13 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
}
}
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
"model", requestedModel)
return nil, decision, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && accountID > 0 {

View File

@ -485,12 +485,14 @@ func normalizeKnownCodexModel(model string) (string, bool) {
return model, true
}
modelID := model
if strings.Contains(modelID, "/") {
parts := strings.Split(modelID, "/")
modelID = parts[len(parts)-1]
}
modelID := lastOpenAIModelSegment(model)
if normalized := canonicalizeOpenAIModelAliasSpelling(modelID); normalized != "" {
modelID = normalized
}
if mapped := normalizeKnownOpenAICodexModel(modelID); mapped != "" {
return mapped, true
}
key := codexModelLookupKey(modelID)
if key == "" {
return "", false

View File

@ -804,15 +804,25 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
func TestNormalizeCodexModel_Gpt53(t *testing.T) {
cases := map[string]string{
"gpt-5.4": "gpt-5.4",
"gpt5.5": "gpt-5.5",
"openai/gpt5.5": "gpt-5.5",
"gpt5.4": "gpt-5.4",
"gpt-5.4-high": "gpt-5.4",
"gpt-5.4-chat-latest": "gpt-5.4",
"gpt 5.4": "gpt-5.4",
"gpt-5.4-mini": "gpt-5.4-mini",
"gpt5.4-mini": "gpt-5.4-mini",
"gpt5.4mini": "gpt-5.4-mini",
"gpt 5.4 mini": "gpt-5.4-mini",
"gpt-5.3": "gpt-5.3-codex",
"gpt5.3": "gpt-5.3-codex",
"gpt-5.3-codex": "gpt-5.3-codex",
"gpt5.3-codex": "gpt-5.3-codex",
"gpt5.3codex": "gpt-5.3-codex",
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
"gpt5.3-codex-spark": "gpt-5.3-codex-spark",
"gpt5.3codexspark": "gpt-5.3-codex-spark",
"gpt 5.3 codex spark": "gpt-5.3-codex-spark",
"gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",

View File

@ -52,6 +52,12 @@ func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *Usage
return &UsageBillingApplyResult{Applied: true}, nil
}
func TestOpenAIGatewayServiceRecordUsage_RejectsNilInput(t *testing.T) {
svc := &OpenAIGatewayService{}
require.Error(t, svc.RecordUsage(context.Background(), nil))
require.Error(t, svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{}))
}
type openAIRecordUsageUserRepoStub struct {
UserRepository
@ -1081,6 +1087,101 @@ func TestOpenAIGatewayServiceRecordUsage_ChannelMappedOverridesBillingModelWhenM
require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero")
}
func TestOpenAIGatewayServiceRecordUsage_BillsCompactOpenAIModelAlias(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
expectedCost, err := svc.billingService.CalculateCost("gpt-5.5", UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_compact_openai_alias",
Model: "gpt5.5",
UpstreamModel: "gpt-5.4",
Usage: usage,
Duration: time.Second,
},
APIKey: &APIKey{ID: 10},
User: &User{ID: 20},
Account: &Account{ID: 30},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "gpt5.5", usageRepo.lastLog.Model)
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
require.Equal(t, "gpt-5.4", *usageRepo.lastLog.UpstreamModel)
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero")
require.InDelta(t, expectedCost.ActualCost, userRepo.lastAmount, 1e-12)
}
func TestOpenAIGatewayServiceRecordUsage_FallsBackToUpstreamModelWhenPrimaryUnpriceable(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
expectedCost, err := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_unpriceable_primary_upstream_fallback",
Model: "not-priceable-alias",
BillingModel: "not-priceable-alias",
UpstreamModel: "gpt-5.4",
Usage: usage,
Duration: time.Second,
},
APIKey: &APIKey{ID: 10},
User: &User{ID: 20},
Account: &Account{ID: 30},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero")
require.InDelta(t, expectedCost.ActualCost, userRepo.lastAmount, 1e-12)
}
func TestOpenAIGatewayServiceRecordUsage_ReturnsErrorWhenTokenModelCannotBePriced(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_unpriceable_without_upstream",
Model: "not-priceable-alias",
Usage: OpenAIUsage{InputTokens: 20, OutputTokens: 10},
Duration: time.Second,
},
APIKey: &APIKey{ID: 10},
User: &User{ID: 20},
Account: &Account{ID: 30},
})
require.Error(t, err)
require.Contains(t, err.Error(), "calculate OpenAI usage cost failed")
require.Equal(t, 0, usageRepo.calls)
require.Equal(t, 0, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
}
func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
@ -1209,3 +1310,278 @@ func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTo
require.InDelta(t, 0.0, usageRepo.lastLog.OutputCost, 1e-12)
require.InDelta(t, 0.0, usageRepo.lastLog.ImageOutputCost, 1e-12)
}
func TestOpenAIGatewayServiceRecordUsage_ImageSharedMultiplierPreservesExistingBehavior(t *testing.T) {
imagePrice := 0.2
groupID := int64(121)
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_image_shared_multiplier",
Model: "gpt-image-2",
ImageCount: 1,
ImageSize: "1K",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 10121,
GroupID: i64p(groupID),
Group: &Group{
ID: groupID,
RateMultiplier: 0.15,
ImageRateIndependent: false,
ImageRateMultiplier: 1,
ImagePrice1K: &imagePrice,
},
},
User: &User{ID: 20121},
Account: &Account{ID: 30121},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.InDelta(t, 0.2, usageRepo.lastLog.TotalCost, 1e-12)
require.InDelta(t, 0.03, usageRepo.lastLog.ActualCost, 1e-12)
require.InDelta(t, 0.15, usageRepo.lastLog.RateMultiplier, 1e-12)
require.NotNil(t, usageRepo.lastLog.BillingMode)
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
}
func TestOpenAIGatewayServiceRecordUsage_ImageSharedMultiplierUsesUserGroupOverride(t *testing.T) {
imagePrice := 0.5
userRate := 0.2
groupID := int64(125)
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
svc := newOpenAIRecordUsageServiceForTest(
usageRepo,
&openAIRecordUsageUserRepoStub{},
&openAIRecordUsageSubRepoStub{},
&openAIUserGroupRateRepoStub{rate: &userRate},
)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_image_user_group_override",
Model: "gpt-image-2",
ImageCount: 1,
ImageSize: "1K",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 10125,
GroupID: i64p(groupID),
Group: &Group{
ID: groupID,
RateMultiplier: 0.15,
ImageRateIndependent: false,
ImageRateMultiplier: 1,
ImagePrice1K: &imagePrice,
},
},
User: &User{ID: 20125},
Account: &Account{ID: 30125},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.InDelta(t, 0.5, usageRepo.lastLog.TotalCost, 1e-12)
require.InDelta(t, 0.1, usageRepo.lastLog.ActualCost, 1e-12)
require.InDelta(t, 0.2, usageRepo.lastLog.RateMultiplier, 1e-12)
}
func TestOpenAIGatewayServiceRecordUsage_ImageIndependentMultiplierUsesImageRate(t *testing.T) {
imagePrice := 0.2
groupID := int64(122)
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_image_independent_multiplier",
Model: "gpt-image-2",
ImageCount: 1,
ImageSize: "1K",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 10122,
GroupID: i64p(groupID),
Group: &Group{
ID: groupID,
RateMultiplier: 0.15,
ImageRateIndependent: true,
ImageRateMultiplier: 1,
ImagePrice1K: &imagePrice,
},
},
User: &User{ID: 20122},
Account: &Account{ID: 30122},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.InDelta(t, 0.2, usageRepo.lastLog.TotalCost, 1e-12)
require.InDelta(t, 0.2, usageRepo.lastLog.ActualCost, 1e-12)
require.InDelta(t, 1.0, usageRepo.lastLog.RateMultiplier, 1e-12)
require.NotNil(t, usageRepo.lastLog.BillingMode)
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
}
func TestOpenAIGatewayServiceRecordUsage_ChannelImageBillingUsesImageCountAndSharedMultiplier(t *testing.T) {
groupID := int64(123)
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
svc.resolver = newOpenAIImageChannelPricingResolverForTest(t, groupID, "gpt-image-2", 0.25)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_image_channel_shared",
Model: "gpt-image-2",
ImageCount: 3,
ImageSize: "1K",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 10123,
GroupID: i64p(groupID),
Group: &Group{
ID: groupID,
RateMultiplier: 0.15,
ImageRateIndependent: false,
ImageRateMultiplier: 1,
},
},
User: &User{ID: 20123},
Account: &Account{ID: 30123},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.InDelta(t, 0.75, usageRepo.lastLog.TotalCost, 1e-12)
require.InDelta(t, 0.1125, usageRepo.lastLog.ActualCost, 1e-12)
require.InDelta(t, 0.15, usageRepo.lastLog.RateMultiplier, 1e-12)
require.Equal(t, 3, usageRepo.lastLog.ImageCount)
require.NotNil(t, usageRepo.lastLog.BillingMode)
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
}
func TestOpenAIGatewayServiceRecordUsage_ChannelImageBillingUsesImageCountAndIndependentMultiplier(t *testing.T) {
groupID := int64(124)
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
svc.resolver = newOpenAIImageChannelPricingResolverForTest(t, groupID, "gpt-image-2", 0.25)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_image_channel_independent",
Model: "gpt-image-2",
ImageCount: 3,
ImageSize: "1K",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 10124,
GroupID: i64p(groupID),
Group: &Group{
ID: groupID,
RateMultiplier: 0.15,
ImageRateIndependent: true,
ImageRateMultiplier: 1,
},
},
User: &User{ID: 20124},
Account: &Account{ID: 30124},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.InDelta(t, 0.75, usageRepo.lastLog.TotalCost, 1e-12)
require.InDelta(t, 0.75, usageRepo.lastLog.ActualCost, 1e-12)
require.InDelta(t, 1.0, usageRepo.lastLog.RateMultiplier, 1e-12)
require.Equal(t, 3, usageRepo.lastLog.ImageCount)
require.NotNil(t, usageRepo.lastLog.BillingMode)
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
}
func newOpenAIImageChannelPricingResolverForTest(t *testing.T, groupID int64, model string, price float64) *ModelPricingResolver {
t.Helper()
cache := newEmptyChannelCache()
cache.pricingByGroupModel[channelModelKey{groupID: groupID, model: model}] = &ChannelModelPricing{
BillingMode: BillingModeImage,
PerRequestPrice: &price,
}
cache.channelByGroupID[groupID] = &Channel{ID: groupID, Status: StatusActive}
cache.groupPlatform[groupID] = ""
cache.loadedAt = time.Now()
cs := &ChannelService{}
cs.cache.Store(cache)
return NewModelPricingResolver(cs, NewBillingService(&config.Config{}, nil))
}
func TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesImageCount(t *testing.T) {
groupID := int64(126)
billingService := NewBillingService(&config.Config{}, nil)
svc := &GatewayService{
billingService: billingService,
resolver: newOpenAIImageChannelPricingResolverForTest(t, groupID, "gemini-image", 0.25),
}
cost := svc.calculateRecordUsageCost(
context.Background(),
&ForwardResult{Model: "gemini-image", ImageCount: 2, ImageSize: "1K"},
&APIKey{GroupID: i64p(groupID), Group: &Group{ID: groupID}},
"gemini-image",
0.15,
1.0,
nil,
)
require.NotNil(t, cost)
require.Equal(t, string(BillingModeImage), cost.BillingMode)
require.InDelta(t, 0.5, cost.TotalCost, 1e-12)
require.InDelta(t, 0.5, cost.ActualCost, 1e-12)
}
func TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesSizeTier(t *testing.T) {
groupID := int64(127)
defaultPrice := 0.10
price4K := 0.40
cache := newEmptyChannelCache()
cache.pricingByGroupModel[channelModelKey{groupID: groupID, model: "gemini-image"}] = &ChannelModelPricing{
BillingMode: BillingModeImage,
PerRequestPrice: &defaultPrice,
Intervals: []PricingInterval{{
TierLabel: "4K",
PerRequestPrice: &price4K,
}},
}
cache.channelByGroupID[groupID] = &Channel{ID: groupID, Status: StatusActive}
cache.loadedAt = time.Now()
channelService := &ChannelService{}
channelService.cache.Store(cache)
svc := &GatewayService{
billingService: NewBillingService(&config.Config{}, nil),
resolver: NewModelPricingResolver(channelService, NewBillingService(&config.Config{}, nil)),
}
cost := svc.calculateRecordUsageCost(
context.Background(),
&ForwardResult{Model: "gemini-image", ImageCount: 2, ImageSize: "4K"},
&APIKey{GroupID: i64p(groupID), Group: &Group{ID: groupID}},
"gemini-image",
1.0,
1.0,
nil,
)
require.NotNil(t, cost)
require.Equal(t, string(BillingModeImage), cost.BillingMode)
require.InDelta(t, 0.80, cost.TotalCost, 1e-12)
require.InDelta(t, 0.80, cost.ActualCost, 1e-12)
}

View File

@ -2049,6 +2049,21 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
promptCacheKey = strings.TrimSpace(v)
}
}
apiKey := getAPIKeyFromContext(c)
imageGenerationAllowed := GroupAllowsImageGeneration(nil)
if apiKey != nil {
imageGenerationAllowed = GroupAllowsImageGeneration(apiKey.Group)
}
if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed {
setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"type": "permission_error",
"message": ImageGenerationPermissionMessage(),
},
})
return nil, errors.New("image generation disabled for group")
}
// Track if body needs re-serialization
bodyModified := false
@ -2108,7 +2123,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
markPatchSet("instructions", "You are a helpful coding assistant.")
}
if isCodexCLI && ensureOpenAIResponsesImageGenerationTool(reqBody) {
if isCodexCLI && imageGenerationAllowed && ensureOpenAIResponsesImageGenerationTool(reqBody) {
bodyModified = true
disablePatch()
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Injected /responses image_generation tool for Codex client")
@ -2119,7 +2134,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
disablePatch()
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload")
}
if isCodexCLI && applyCodexImageGenerationBridgeInstructions(reqBody) {
if isCodexCLI && imageGenerationAllowed && applyCodexImageGenerationBridgeInstructions(reqBody) {
bodyModified = true
disablePatch()
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Added Codex image_generation bridge instructions")
@ -2134,7 +2149,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
markPatchSet("model", billingModel)
}
upstreamModel := billingModel
if normalizeOpenAIResponsesImageOnlyModel(reqBody) {
if imageGenerationAllowed && normalizeOpenAIResponsesImageOnlyModel(reqBody) {
bodyModified = true
disablePatch()
if model, ok := reqBody["model"].(string); ok {
@ -2355,6 +2370,34 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
}
if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed {
setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"type": "permission_error",
"message": ImageGenerationPermissionMessage(),
},
})
return nil, errors.New("image generation disabled for group")
}
imageBillingModel := ""
imageSizeTier := ""
if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) {
var imageCfgErr error
imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfig(reqBody, billingModel)
if imageCfgErr != nil {
setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "")
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"type": "invalid_request_error",
"message": imageCfgErr.Error(),
"param": "size",
},
})
return nil, imageCfgErr
}
}
// Re-serialize body only if modified
if bodyModified {
serializedByPatch := false
@ -2592,6 +2635,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
wsAttempts,
)
wsResult.UpstreamModel = upstreamModel
if wsResult.ImageCount > 0 {
wsResult.ImageSize = imageSizeTier
wsResult.BillingModel = imageBillingModel
}
return wsResult, nil
}
s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr)
@ -2695,6 +2742,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Handle normal response
var usage *OpenAIUsage
var firstTokenMs *int
imageCount := 0
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel)
if err != nil {
@ -2702,11 +2750,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
imageCount = streamResult.imageCount
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel)
nonStreamResult, err := s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel)
if err != nil {
return nil, err
}
usage = nonStreamResult.usage
imageCount = nonStreamResult.imageCount
}
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
@ -2723,7 +2774,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
serviceTier := extractOpenAIServiceTier(reqBody)
return &OpenAIForwardResult{
forwardResult := &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel,
@ -2734,7 +2785,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
OpenAIWSMode: false,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
if imageCount > 0 {
forwardResult.ImageCount = imageCount
forwardResult.ImageSize = imageSizeTier
forwardResult.BillingModel = imageBillingModel
}
return forwardResult, nil
}
}
@ -2823,6 +2880,35 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
}
body = updatedBody
apiKey := getAPIKeyFromContext(c)
if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) && !GroupAllowsImageGeneration(apiKeyGroup(apiKey)) {
setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"type": "permission_error",
"message": ImageGenerationPermissionMessage(),
},
})
return nil, errors.New("image generation disabled for group")
}
imageBillingModel := ""
imageSizeTier := ""
if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) {
var imageCfgErr error
imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(body, reqModel)
if imageCfgErr != nil {
setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "")
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"type": "invalid_request_error",
"message": imageCfgErr.Error(),
"param": "size",
},
})
return nil, imageCfgErr
}
}
logger.LegacyPrintf("service.openai_gateway",
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
account.ID,
@ -2905,6 +2991,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
var usage *OpenAIUsage
var firstTokenMs *int
imageCount := 0
if reqStream {
result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime, reqModel, upstreamPassthroughModel)
if err != nil {
@ -2912,11 +2999,14 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
}
usage = result.usage
firstTokenMs = result.firstTokenMs
imageCount = result.imageCount
} else {
usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel)
result, err := s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel)
if err != nil {
return nil, err
}
usage = result.usage
imageCount = result.imageCount
}
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
@ -2927,7 +3017,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
usage = &OpenAIUsage{}
}
return &OpenAIForwardResult{
forwardResult := &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: reqModel,
@ -2938,7 +3028,13 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
OpenAIWSMode: false,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
if imageCount > 0 {
forwardResult.ImageCount = imageCount
forwardResult.ImageSize = imageSizeTier
forwardResult.BillingModel = imageBillingModel
}
return forwardResult, nil
}
func logOpenAIPassthroughInstructionsRejected(
@ -3233,6 +3329,13 @@ func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string {
type openaiStreamingResultPassthrough struct {
usage *OpenAIUsage
firstTokenMs *int
imageCount int
}
type openaiNonStreamingResultPassthrough struct {
*OpenAIUsage
usage *OpenAIUsage
imageCount int
}
func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool {
@ -3369,6 +3472,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
}
usage := &OpenAIUsage{}
imageCounter := newOpenAIImageOutputCounter()
var firstTokenMs *int
clientDisconnected := false
sawDone := false
@ -3400,6 +3504,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
defer putSSEScannerBuf64K(scanBuf)
needModelReplace := strings.TrimSpace(originalModel) != "" && strings.TrimSpace(mappedModel) != "" && strings.TrimSpace(originalModel) != strings.TrimSpace(mappedModel)
resultWithUsage := func() *openaiStreamingResultPassthrough {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()}
}
for scanner.Scan() {
line := scanner.Text()
@ -3419,7 +3526,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
if eventType == "response.failed" {
failedMessage = extractOpenAISSEErrorMessage(dataBytes)
if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
return resultWithUsage(),
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, dataBytes, failedMessage)
}
forceFlushFailedEvent = true
@ -3431,6 +3538,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
if openAIStreamEventIsTerminal(trimmedData) {
sawTerminalEvent = true
}
imageCounter.AddSSEData(dataBytes)
lineStartsClientOutput = forceFlushFailedEvent || openAIStreamDataStartsClientOutput(trimmedData, eventType)
if firstTokenMs == nil && lineStartsClientOutput && trimmedData != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
@ -3460,28 +3568,28 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
}
if err := scanner.Err(); err != nil {
if sawTerminalEvent && !sawFailedEvent {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
return resultWithUsage(), nil
}
if sawFailedEvent {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage)
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
}
if errors.Is(err, bufio.ErrTooLong) {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
return resultWithUsage(), err
}
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
msg := "OpenAI stream disconnected before completion"
if errText := strings.TrimSpace(err.Error()); errText != "" {
msg += ": " + errText
}
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
return resultWithUsage(),
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, msg)
}
if clientDisconnected {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", err)
}
logger.LegacyPrintf("service.openai_gateway",
"[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
@ -3489,10 +3597,10 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
upstreamRequestID,
err,
)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
return resultWithUsage(), fmt.Errorf("stream read error: %w", err)
}
if sawFailedEvent {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage)
}
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
logger.FromContext(ctx).With(
@ -3501,13 +3609,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
zap.String("upstream_request_id", upstreamRequestID),
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
return resultWithUsage(),
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, "OpenAI stream ended before a terminal event")
}
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
return resultWithUsage(), errors.New("stream usage incomplete: missing terminal event")
}
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
return resultWithUsage(), nil
}
func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
@ -3516,7 +3624,7 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
c *gin.Context,
originalModel string,
mappedModel string,
) (*OpenAIUsage, error) {
) (*openaiNonStreamingResultPassthrough, error) {
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
if err != nil {
return nil, err
@ -3553,14 +3661,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
c.Data(resp.StatusCode, contentType, body)
return usage, nil
return &openaiNonStreamingResultPassthrough{
OpenAIUsage: usage,
usage: usage,
imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body),
}, nil
}
// handlePassthroughSSEToJSON converts an SSE response body into a JSON
// response for the passthrough path. It mirrors handleSSEToJSON while
// preserving passthrough payloads, except compact-only model remapping may
// rewrite model fields back to the original requested model.
func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*OpenAIUsage, error) {
func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*openaiNonStreamingResultPassthrough, error) {
bodyText := string(body)
finalResponse, ok := extractCodexFinalResponse(bodyText)
@ -3611,7 +3723,11 @@ func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c
}
c.Data(resp.StatusCode, contentType, body)
return usage, nil
return &openaiNonStreamingResultPassthrough{
OpenAIUsage: usage,
usage: usage,
imageCount: countOpenAIImageOutputsFromSSEBody(bodyText),
}, nil
}
func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) {
@ -4025,6 +4141,13 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse(
type openaiStreamingResult struct {
usage *OpenAIUsage
firstTokenMs *int
imageCount int
}
type openaiNonStreamingResult struct {
*OpenAIUsage
usage *OpenAIUsage
imageCount int
}
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
@ -4058,6 +4181,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
usage := &OpenAIUsage{}
imageCounter := newOpenAIImageOutputCounter()
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
@ -4136,7 +4260,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
needModelReplace := originalModel != mappedModel
resultWithUsage := func() *openaiStreamingResult {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()}
}
finalizeStream := func() (*openaiStreamingResult, error) {
if !sawTerminalEvent {
@ -4231,6 +4355,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
forceFlushFailedEvent = true
sawFailedEvent = true
}
imageCounter.AddSSEData(dataBytes)
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
@ -4496,7 +4621,7 @@ func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
}, true
}
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*openaiNonStreamingResult, error) {
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
if err != nil {
return nil, err
@ -4542,7 +4667,11 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
c.Data(resp.StatusCode, contentType, body)
return usage, nil
return &openaiNonStreamingResult{
OpenAIUsage: usage,
usage: usage,
imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body),
}, nil
}
func isEventStreamResponse(header http.Header) bool {
@ -4550,7 +4679,7 @@ func isEventStreamResponse(header http.Header) bool {
return strings.Contains(contentType, "text/event-stream")
}
func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) {
func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*openaiNonStreamingResult, error) {
bodyText := string(body)
finalResponse, ok := extractCodexFinalResponse(bodyText)
@ -4602,21 +4731,29 @@ func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Conte
}
c.Data(resp.StatusCode, contentType, body)
return usage, nil
return &openaiNonStreamingResult{
OpenAIUsage: usage,
usage: usage,
imageCount: countOpenAIImageOutputsFromSSEBody(bodyText),
}, nil
}
func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) {
lines := strings.Split(body, "\n")
for _, line := range lines {
data, ok := extractOpenAISSEDataLine(line)
if !ok || data == "" || data == "[DONE]" {
continue
var terminalType string
var terminalPayload []byte
forEachOpenAISSEDataPayload(body, func(data []byte) {
if terminalPayload != nil {
return
}
eventType := strings.TrimSpace(gjson.Get(data, "type").String())
eventType := strings.TrimSpace(gjson.GetBytes(data, "type").String())
switch eventType {
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
return eventType, []byte(data), true
terminalType = eventType
terminalPayload = append([]byte(nil), data...)
}
})
if terminalPayload != nil {
return terminalType, terminalPayload, true
}
return "", nil, false
}
@ -4651,21 +4788,20 @@ func (s *OpenAIGatewayService) writeOpenAINonStreamingProtocolError(resp *http.R
}
func extractCodexFinalResponse(body string) ([]byte, bool) {
lines := strings.Split(body, "\n")
for _, line := range lines {
data, ok := extractOpenAISSEDataLine(line)
if !ok {
continue
var finalResponse []byte
forEachOpenAISSEDataPayload(body, func(data []byte) {
if finalResponse != nil {
return
}
if data == "" || data == "[DONE]" {
continue
}
eventType := gjson.Get(data, "type").String()
eventType := gjson.GetBytes(data, "type").String()
if eventType == "response.done" || eventType == "response.completed" {
if response := gjson.Get(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" {
return []byte(response.Raw), true
if response := gjson.GetBytes(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" {
finalResponse = []byte(response.Raw)
}
}
})
if finalResponse != nil {
return finalResponse, true
}
return nil, false
}
@ -4677,21 +4813,15 @@ func reconstructResponseOutputFromSSE(bodyText string) ([]byte, bool) {
acc := apicompat.NewBufferedResponseAccumulator()
imageOutputs := make([]json.RawMessage, 0, 1)
seenImages := make(map[string]struct{})
lines := strings.Split(bodyText, "\n")
for _, line := range lines {
data, ok := extractOpenAISSEDataLine(line)
if !ok || data == "" || data == "[DONE]" {
continue
}
if imageOutput, ok := extractImageGenerationOutputFromSSEData([]byte(data), seenImages); ok {
forEachOpenAISSEDataPayload(bodyText, func(data []byte) {
if imageOutput, ok := extractImageGenerationOutputFromSSEData(data, seenImages); ok {
imageOutputs = append(imageOutputs, imageOutput)
}
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(data), &event); err != nil {
continue
if err := json.Unmarshal(data, &event); err == nil {
acc.ProcessEvent(&event)
}
acc.ProcessEvent(&event)
}
})
if !acc.HasContent() && len(imageOutputs) == 0 {
return nil, false
}
@ -4744,17 +4874,9 @@ func extractImageGenerationOutputFromSSEData(data []byte, seen map[string]struct
func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
usage := &OpenAIUsage{}
lines := strings.Split(body, "\n")
for _, line := range lines {
data, ok := extractOpenAISSEDataLine(line)
if !ok {
continue
}
if data == "" || data == "[DONE]" {
continue
}
s.parseSSEUsageBytes([]byte(data), usage)
}
forEachOpenAISSEDataPayload(body, func(data []byte) {
s.parseSSEUsageBytes(data, usage)
})
return usage
}
@ -5036,8 +5158,14 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
if input == nil {
return errors.New("openai usage input is nil")
}
result := input.Result
if s.rateLimitService != nil && input != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI {
if result == nil {
return errors.New("openai usage result is nil")
}
if s.rateLimitService != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI {
s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
}
@ -5074,6 +5202,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
}
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
}
imageMultiplier := resolveImageRateMultiplier(apiKey, multiplier)
var cost *CostBreakdown
var err error
@ -5087,13 +5216,21 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel
}
billingModels := usageBillingModelCandidates(
billingModel,
result.BillingModel,
input.ChannelMappedModel,
input.OriginalModel,
result.UpstreamModel,
result.Model,
)
serviceTier := ""
if result.ServiceTier != nil {
serviceTier = strings.TrimSpace(*result.ServiceTier)
}
cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, tokens, serviceTier)
cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModels, multiplier, imageMultiplier, tokens, serviceTier)
if err != nil {
cost = &CostBreakdown{ActualCost: 0}
return err
}
// Determine billing type
@ -5143,7 +5280,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.TotalCost = cost.TotalCost
usageLog.ActualCost = cost.ActualCost
}
usageLog.RateMultiplier = multiplier
if result.ImageCount > 0 {
usageLog.RateMultiplier = imageMultiplier
} else {
usageLog.RateMultiplier = multiplier
}
usageLog.AccountRateMultiplier = &accountRateMultiplier
usageLog.BillingType = billingType
usageLog.Stream = result.Stream
@ -5224,14 +5365,45 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
ctx context.Context,
result *OpenAIForwardResult,
apiKey *APIKey,
billingModels []string,
multiplier float64,
imageMultiplier float64,
tokens UsageTokens,
serviceTier string,
) (*CostBreakdown, error) {
billingModel := firstUsageBillingModel(billingModels)
if result != nil && result.ImageCount > 0 {
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, imageMultiplier), nil
}
if len(billingModels) == 0 || billingModel == "" {
return nil, errors.New("openai usage billing model is empty")
}
var lastErr error
for _, candidate := range billingModels {
candidate = strings.TrimSpace(candidate)
if candidate == "" {
continue
}
cost, err := s.calculateOpenAIRecordUsageTokenCost(ctx, apiKey, candidate, multiplier, tokens, serviceTier)
if err == nil {
return cost, nil
}
lastErr = err
}
if lastErr == nil {
lastErr = errors.New("no non-empty billing model candidates")
}
return nil, fmt.Errorf("calculate OpenAI usage cost failed for billing models %s: %w", strings.Join(billingModels, ","), lastErr)
}
func (s *OpenAIGatewayService) calculateOpenAIRecordUsageTokenCost(
ctx context.Context,
apiKey *APIKey,
billingModel string,
multiplier float64,
tokens UsageTokens,
serviceTier string,
) (*CostBreakdown, error) {
if result != nil && result.ImageCount > 0 {
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil
}
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
return s.billingService.CalculateCostUnified(CostInput{
@ -5262,7 +5434,7 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost(
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
RequestCount: 1,
RequestCount: result.ImageCount,
SizeTier: result.ImageSize,
RateMultiplier: multiplier,
Resolver: s.resolver,

View File

@ -0,0 +1,215 @@
package service
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestOpenAIGatewayServiceForward_RejectsDisabledImageGenerationIntents(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
body []byte
}{
{
name: "image model",
body: []byte(`{"model":"gpt-image-2","input":"draw"}`),
},
{
name: "image tool",
body: []byte(`{"model":"gpt-5.4","input":"draw","tools":[{"type":"image_generation"}]}`),
},
{
name: "image tool choice",
body: []byte(`{"model":"gpt-5.4","input":"draw","tool_choice":{"type":"image_generation"}}`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
upstream := &httpUpstreamRecorder{}
svc := newOpenAIImageGenerationControlTestService(upstream)
c, recorder := newOpenAIImageGenerationControlTestContext(false, "unit-test-agent/1.0")
account := newOpenAIImageGenerationControlTestAccount()
result, err := svc.Forward(context.Background(), c, account, tt.body)
require.Error(t, err)
require.Nil(t, result)
require.Equal(t, http.StatusForbidden, recorder.Code)
require.Equal(t, "permission_error", gjson.GetBytes(recorder.Body.Bytes(), "error.type").String())
require.Nil(t, upstream.lastReq, "disabled image request must not reach upstream")
})
}
}
func TestOpenAIGatewayServiceForward_DisabledGroupAllowsTextOnlyResponses(t *testing.T) {
gin.SetMode(gin.TestMode)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_text","model":"gpt-5.4","usage":{"input_tokens":3,"output_tokens":2}}`)),
},
}
svc := newOpenAIImageGenerationControlTestService(upstream)
c, recorder := newOpenAIImageGenerationControlTestContext(false, "unit-test-agent/1.0")
account := newOpenAIImageGenerationControlTestAccount()
result, err := svc.Forward(context.Background(), c, account, []byte(`{"model":"gpt-5.4","input":"write code","stream":false}`))
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, 3, result.Usage.InputTokens)
require.Equal(t, 2, result.Usage.OutputTokens)
require.Equal(t, 0, result.ImageCount)
require.NotNil(t, upstream.lastReq)
}
func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
allowImages bool
wantInjected bool
}{
{name: "disabled group skips injection", allowImages: false, wantInjected: false},
{name: "enabled group injects image tool", allowImages: true, wantInjected: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_codex","model":"gpt-5.4","usage":{"input_tokens":1,"output_tokens":1}}`)),
},
}
svc := newOpenAIImageGenerationControlTestService(upstream)
c, _ := newOpenAIImageGenerationControlTestContext(tt.allowImages, "codex_cli_rs/0.98.0")
account := newOpenAIImageGenerationControlTestAccount()
result, err := svc.Forward(context.Background(), c, account, []byte(`{"model":"gpt-5.4","input":"write code","stream":false}`))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, upstream.lastReq)
hasImageTool := gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation")`).Exists()
require.Equal(t, tt.wantInjected, hasImageTool)
instructions := gjson.GetBytes(upstream.lastBody, "instructions").String()
require.Equal(t, tt.wantInjected, strings.Contains(instructions, "image_generation"))
})
}
}
func TestOpenAIGatewayServiceHandleResponsesImageOutputs_NonStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{})
c, _ := newOpenAIImageGenerationControlTestContext(true, "unit-test-agent/1.0")
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{
"id":"resp_image_json",
"model":"gpt-5.4",
"output":[{"id":"ig_json_1","type":"image_generation_call","result":"final-image"}],
"usage":{"input_tokens":7,"output_tokens":3,"output_tokens_details":{"image_tokens":2}}
}`)),
}
result, err := svc.handleNonStreamingResponse(context.Background(), resp, c, &Account{ID: 1, Type: AccountTypeAPIKey}, "gpt-5.4", "gpt-5.4")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 1, result.imageCount)
require.NotNil(t, result.usage)
require.Equal(t, 7, result.usage.InputTokens)
require.Equal(t, 3, result.usage.OutputTokens)
require.Equal(t, 2, result.usage.ImageOutputTokens)
}
func TestOpenAIGatewayServiceHandleResponsesImageOutputs_Streaming(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{})
c, _ := newOpenAIImageGenerationControlTestContext(true, "unit-test-agent/1.0")
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
Body: io.NopCloser(strings.NewReader(
"data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_stream_1\",\"type\":\"image_generation_call\",\"result\":\"final-image\"}}\n\n" +
"data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_image_stream\",\"model\":\"gpt-5.5\",\"output\":[{\"id\":\"ig_stream_1\",\"type\":\"image_generation_call\",\"result\":\"final-image\"}],\"usage\":{\"input_tokens\":11,\"output_tokens\":5,\"output_tokens_details\":{\"image_tokens\":4}}}}\n\n",
)),
}
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "gpt-5.5", "gpt-5.5")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 1, result.imageCount)
require.NotNil(t, result.usage)
require.Equal(t, 11, result.usage.InputTokens)
require.Equal(t, 5, result.usage.OutputTokens)
require.Equal(t, 4, result.usage.ImageOutputTokens)
}
func newOpenAIImageGenerationControlTestService(upstream *httpUpstreamRecorder) *OpenAIGatewayService {
cfg := &config.Config{}
return &OpenAIGatewayService{
cfg: cfg,
httpUpstream: upstream,
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
}
func newOpenAIImageGenerationControlTestContext(allowImages bool, userAgent string) (*gin.Context, *httptest.ResponseRecorder) {
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", userAgent)
groupID := int64(4242)
c.Set("api_key", &APIKey{
ID: 2424,
GroupID: &groupID,
Group: &Group{
ID: groupID,
AllowImageGeneration: allowImages,
RateMultiplier: 1,
ImageRateMultiplier: 1,
},
})
return c, recorder
}
func newOpenAIImageGenerationControlTestAccount() *Account {
return &Account{
ID: 5151,
Name: "openai-image-controls",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
}
}

View File

@ -16,6 +16,7 @@ import (
"net/textproto"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@ -468,14 +469,54 @@ func isOpenAINativeImageOption(name string) bool {
}
func normalizeOpenAIImageSizeTier(size string) string {
switch strings.ToLower(strings.TrimSpace(size)) {
trimmed := strings.TrimSpace(size)
normalized := strings.ToLower(trimmed)
switch normalized {
case "", "auto":
return "2K"
case "1024x1024":
return "1K"
case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "", "auto":
case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "2048x2048", "2048x1152", "1152x2048":
return "2K"
default:
case "3840x2160", "2160x3840":
return "4K"
}
width, height, ok := parseOpenAIImageSizeDimensions(trimmed)
if !ok {
return "2K"
}
return classifyUnknownOpenAIImageSizeTier(width, height)
}
const (
openAIImage2KMaxPixels = 2560 * 1440
)
func parseOpenAIImageSizeDimensions(size string) (int, int, bool) {
trimmed := strings.TrimSpace(size)
parts := strings.Split(strings.ToLower(trimmed), "x")
if len(parts) != 2 {
return 0, 0, false
}
width, err := strconv.Atoi(strings.TrimSpace(parts[0]))
if err != nil {
return 0, 0, false
}
height, err := strconv.Atoi(strings.TrimSpace(parts[1]))
if err != nil {
return 0, 0, false
}
if width <= 0 || height <= 0 {
return 0, 0, false
}
return width, height, true
}
func classifyUnknownOpenAIImageSizeTier(width int, height int) string {
if height > 0 && width > openAIImage2KMaxPixels/height {
return "4K"
}
return "2K"
}
func (s *OpenAIGatewayService) ForwardImages(
@ -535,11 +576,14 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
setOpsUpstreamRequestBody(c, forwardBody)
}
token, _, err := s.GetAccessToken(ctx, account)
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, parsed.Stream)
defer releaseUpstreamCtx()
token, _, err := s.GetAccessToken(upstreamCtx, account)
if err != nil {
return nil, err
}
upstreamReq, err := s.buildOpenAIImagesRequest(ctx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint)
upstreamReq, err := s.buildOpenAIImagesRequest(upstreamCtx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint)
if err != nil {
return nil, err
}
@ -582,14 +626,14 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
Kind: "failover",
Message: upstreamMsg,
})
s.handleFailoverSideEffects(ctx, resp, account)
s.handleFailoverSideEffects(upstreamCtx, resp, account)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
return s.handleErrorResponse(ctx, resp, c, account, forwardBody)
return s.handleErrorResponse(upstreamCtx, resp, c, account, forwardBody)
}
defer func() { _ = resp.Body.Close() }()
@ -599,6 +643,20 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
if parsed.Stream && isEventStreamResponse(resp.Header) {
streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime)
if err != nil {
if streamCount > 0 {
return &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: streamUsage,
Model: requestModel,
UpstreamModel: upstreamModel,
Stream: parsed.Stream,
ResponseHeaders: resp.Header.Clone(),
Duration: time.Since(startTime),
FirstTokenMs: ttft,
ImageCount: streamCount,
ImageSize: parsed.SizeTier,
}, err
}
return nil, err
}
usage = streamUsage
@ -807,66 +865,205 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer")
}
reader := bufio.NewReader(resp.Body)
usage := OpenAIUsage{}
imageCount := 0
imageCounter := newOpenAIImageOutputCounter()
var firstTokenMs *int
clientDisconnected := false
lastDownstreamWriteAt := time.Now()
var fallbackBody bytes.Buffer
fallbackBytes := int64(0)
fallbackLimit := resolveUpstreamResponseReadLimit(s.cfg)
seenSSEData := false
fallbackTooLarge := false
var sseData openAISSEDataAccumulator
processSSEData := func(dataBytes []byte) {
seenSSEData = true
fallbackBody.Reset()
fallbackBytes = 0
mergeOpenAIUsage(&usage, dataBytes)
imageCounter.AddSSEData(dataBytes)
}
flushSSEEvent := func() {
sseData.Flush(processSSEData)
}
processLine := func(line []byte) {
if len(line) == 0 {
return
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
if !clientDisconnected {
if _, writeErr := c.Writer.Write(line); writeErr != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected, continue draining upstream for billing")
} else {
flusher.Flush()
lastDownstreamWriteAt = time.Now()
}
}
trimmedLine := strings.TrimRight(string(line), "\r\n")
if _, ok := extractOpenAISSEDataLine(trimmedLine); ok || strings.TrimSpace(trimmedLine) == "" {
sseData.AddLine(trimmedLine, processSSEData)
return
}
if !seenSSEData && !fallbackTooLarge {
fallbackBytes += int64(len(line))
if fallbackBytes <= fallbackLimit {
_, _ = fallbackBody.Write(line)
} else {
fallbackTooLarge = true
fallbackBody.Reset()
}
}
}
finalizeFallbackBody := func() {
if seenSSEData || fallbackBody.Len() == 0 {
return
}
body := bytes.TrimSpace(fallbackBody.Bytes())
if len(body) == 0 {
return
}
mergeOpenAIUsage(&usage, body)
imageCounter.AddJSONResponse(body)
}
streamInterval := s.openAIImageStreamDataInterval()
keepaliveInterval := s.openAIImageStreamKeepaliveInterval()
if streamInterval <= 0 && keepaliveInterval <= 0 {
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadBytes('\n')
processLine(line)
if err == io.EOF {
break
}
if err != nil {
flushSSEEvent()
return usage, imageCounter.Count(), firstTokenMs, err
}
}
flushSSEEvent()
finalizeFallbackBody()
return usage, imageCounter.Count(), firstTokenMs, nil
}
type readEvent struct {
line []byte
err error
}
events := make(chan readEvent, 16)
done := make(chan struct{})
sendEvent := func(ev readEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadBytes('\n')
if len(line) > 0 {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
}
if len(line) > 0 && !sendEvent(readEvent{line: line}) {
return
}
if err == io.EOF {
return
}
if err != nil {
_ = sendEvent(readEvent{err: err})
return
}
}
}()
defer close(done)
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
for {
line, err := reader.ReadBytes('\n')
if len(line) > 0 {
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
select {
case ev, ok := <-events:
if !ok {
flushSSEEvent()
finalizeFallbackBody()
return usage, imageCounter.Count(), firstTokenMs, nil
}
if _, writeErr := c.Writer.Write(line); writeErr != nil {
return OpenAIUsage{}, 0, firstTokenMs, writeErr
if ev.err != nil {
flushSSEEvent()
return usage, imageCounter.Count(), firstTokenMs, ev.err
}
processLine(ev.line)
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
return usage, imageCounter.Count(), firstTokenMs, fmt.Errorf("image stream incomplete after timeout")
}
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream data interval timeout: interval=%s", streamInterval)
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(fmt.Sprintf("upstream image stream idle for %s", streamInterval)))
return usage, imageCounter.Count(), firstTokenMs, fmt.Errorf("image stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected || time.Since(lastDownstreamWriteAt) < keepaliveInterval {
continue
}
if _, writeErr := io.WriteString(c.Writer, ":\n\n"); writeErr != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected during keepalive, continue draining upstream for billing")
continue
}
flusher.Flush()
lastDownstreamWriteAt = time.Now()
}
}
}
if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok {
if data != "" && data != "[DONE]" {
seenSSEData = true
fallbackBody.Reset()
fallbackBytes = 0
dataBytes := []byte(data)
mergeOpenAIUsage(&usage, dataBytes)
if count := extractOpenAIImagesBillableCountFromJSONBytes(dataBytes); count > imageCount {
imageCount = count
}
}
} else if !seenSSEData && !fallbackTooLarge {
fallbackBytes += int64(len(line))
if fallbackBytes <= fallbackLimit {
_, _ = fallbackBody.Write(line)
} else {
fallbackTooLarge = true
fallbackBody.Reset()
}
}
}
if err == io.EOF {
break
}
if err != nil {
return OpenAIUsage{}, 0, firstTokenMs, err
}
func (s *OpenAIGatewayService) openAIImageStreamDataInterval() time.Duration {
if s == nil || s.cfg == nil || s.cfg.Gateway.ImageStreamDataIntervalTimeout <= 0 {
return 0
}
if !seenSSEData && fallbackBody.Len() > 0 {
body := bytes.TrimSpace(fallbackBody.Bytes())
if len(body) > 0 {
mergeOpenAIUsage(&usage, body)
if count := extractOpenAIImagesBillableCountFromJSONBytes(body); count > imageCount {
imageCount = count
}
}
return time.Duration(s.cfg.Gateway.ImageStreamDataIntervalTimeout) * time.Second
}
func (s *OpenAIGatewayService) openAIImageStreamKeepaliveInterval() time.Duration {
if s == nil || s.cfg == nil || s.cfg.Gateway.ImageStreamKeepaliveInterval <= 0 {
return 0
}
return usage, imageCount, firstTokenMs, nil
return time.Duration(s.cfg.Gateway.ImageStreamKeepaliveInterval) * time.Second
}
func extractOpenAIImagesBillableCountFromJSONBytes(body []byte) int {
@ -913,14 +1110,7 @@ func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) {
}
func extractOpenAIImageCountFromJSONBytes(body []byte) int {
if len(body) == 0 || !gjson.ValidBytes(body) {
return 0
}
data := gjson.GetBytes(body, "data")
if data.Exists() && data.IsArray() {
return len(data.Array())
}
return 0
return countOpenAIResponseImageOutputsFromJSONBytes(body)
}
type openAIImagePointerInfo struct {

View File

@ -9,6 +9,7 @@ import (
"io"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@ -361,21 +362,21 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
var (
fallbackResults []openAIResponsesImageResult
fallbackSeen = make(map[string]struct{})
finalResults []openAIResponsesImageResult
finalMeta openAIResponsesImageResult
collectErr error
createdAt int64
usageRaw []byte
foundFinal bool
responseMeta openAIResponsesImageResult
)
for _, line := range bytes.Split(body, []byte("\n")) {
line = bytes.TrimRight(line, "\r")
data, ok := extractOpenAISSEDataLine(string(line))
if !ok || data == "" || data == "[DONE]" {
continue
forEachOpenAISSEDataPayload(string(body), func(payload []byte) {
if collectErr != nil || len(finalResults) > 0 {
return
}
payload := []byte(data)
if !gjson.ValidBytes(payload) {
continue
return
}
if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(payload); ok {
mergeOpenAIResponsesImageMeta(&responseMeta, meta)
@ -388,7 +389,8 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
case "response.output_item.done":
result, itemID, ok, err := extractOpenAIImageFromResponsesOutputItemDone(payload)
if err != nil {
return nil, 0, nil, openAIResponsesImageResult{}, false, err
collectErr = err
return
}
if ok {
mergeOpenAIResponsesImageMeta(&result, responseMeta)
@ -397,7 +399,8 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
case "response.completed":
results, completedAt, completedUsageRaw, firstMeta, err := extractOpenAIImagesFromResponsesCompleted(payload)
if err != nil {
return nil, 0, nil, openAIResponsesImageResult{}, false, err
collectErr = err
return
}
foundFinal = true
if completedAt > 0 {
@ -408,14 +411,24 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
}
if len(results) > 0 {
mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
return results, createdAt, usageRaw, firstMeta, true, nil
finalResults = results
finalMeta = firstMeta
return
}
if len(fallbackResults) > 0 {
firstMeta = fallbackResults[0]
mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
return fallbackResults, createdAt, usageRaw, firstMeta, true, nil
finalResults = fallbackResults
finalMeta = firstMeta
return
}
}
})
if collectErr != nil {
return nil, 0, nil, openAIResponsesImageResult{}, false, collectErr
}
if len(finalResults) > 0 {
return finalResults, createdAt, usageRaw, finalMeta, true, nil
}
if len(fallbackResults) > 0 {
@ -505,6 +518,30 @@ func (s *OpenAIGatewayService) writeOpenAIImagesStreamEvent(c *gin.Context, flus
return nil
}
func (s *OpenAIGatewayService) tryWriteOpenAIImagesStreamEvent(
c *gin.Context,
flusher http.Flusher,
clientDisconnected *bool,
lastWriteAt *time.Time,
eventName string,
payload []byte,
) bool {
if clientDisconnected != nil && *clientDisconnected {
return false
}
if err := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); err != nil {
if clientDisconnected != nil {
*clientDisconnected = true
}
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected, continue draining upstream for billing")
return false
}
if lastWriteAt != nil {
*lastWriteAt = time.Now()
}
return true
}
func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
resp *http.Response,
c *gin.Context,
@ -517,15 +554,9 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
}
var usage OpenAIUsage
for _, line := range bytes.Split(body, []byte("\n")) {
line = bytes.TrimRight(line, "\r")
data, ok := extractOpenAISSEDataLine(string(line))
if !ok || data == "" || data == "[DONE]" {
continue
}
dataBytes := []byte(data)
s.parseSSEUsageBytes(dataBytes, &usage)
}
forEachOpenAISSEDataPayload(string(body), func(data []byte) {
s.parseSSEUsageBytes(data, &usage)
})
results, createdAt, usageRaw, firstMeta, _, err := collectOpenAIImagesFromResponsesBody(body)
if err != nil {
return OpenAIUsage{}, 0, err
@ -570,7 +601,6 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
format = "b64_json"
}
reader := bufio.NewReader(resp.Body)
usage := OpenAIUsage{}
imageCount := 0
var firstTokenMs *int
@ -579,141 +609,307 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
pendingSeen := make(map[string]struct{})
streamMeta := openAIResponsesImageResult{Model: strings.TrimSpace(fallbackModel)}
var createdAt int64
clientDisconnected := false
lastDownstreamWriteAt := time.Now()
var sseData openAISSEDataAccumulator
var processDataErr error
processDataDone := false
processData := func(dataBytes []byte) {
if processDataDone || processDataErr != nil {
return
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsageBytes(dataBytes, &usage)
if !gjson.ValidBytes(dataBytes) {
return
}
if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok {
mergeOpenAIResponsesImageMeta(&streamMeta, meta)
if eventCreatedAt > 0 {
createdAt = eventCreatedAt
}
}
switch gjson.GetBytes(dataBytes, "type").String() {
case "response.image_generation_call.partial_image":
b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String())
if b64 == "" {
return
}
eventName := streamPrefix + ".partial_image"
partialMeta := streamMeta
mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{
OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()),
Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()),
})
payload := buildOpenAIImagesStreamPartialPayload(
eventName,
b64,
gjson.GetBytes(dataBytes, "partial_image_index").Int(),
format,
createdAt,
partialMeta,
)
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload)
case "response.output_item.done":
img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes)
if extractErr != nil {
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
processDataErr = extractErr
processDataDone = true
return
}
if !ok {
return
}
mergeOpenAIResponsesImageMeta(&streamMeta, img)
mergeOpenAIResponsesImageMeta(&img, streamMeta)
key := openAIResponsesImageResultKey(itemID, img)
if _, exists := emitted[key]; exists {
return
}
if _, exists := pendingSeen[key]; exists {
return
}
pendingSeen[key] = struct{}{}
pendingResults = append(pendingResults, img)
case "response.completed":
results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes)
if extractErr != nil {
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
processDataErr = extractErr
processDataDone = true
return
}
mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta)
finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults))
finalSeen := make(map[string]struct{})
for _, img := range results {
mergeOpenAIResponsesImageMeta(&img, streamMeta)
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
}
for _, img := range pendingResults {
mergeOpenAIResponsesImageMeta(&img, streamMeta)
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
}
if len(finalResults) == 0 {
outputErr := fmt.Errorf("upstream did not return image output")
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(outputErr.Error()))
processDataErr = outputErr
processDataDone = true
return
}
eventName := streamPrefix + ".completed"
for _, img := range finalResults {
key := openAIResponsesImageResultKey("", img)
if _, exists := emitted[key]; exists {
continue
}
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw)
emitted[key] = struct{}{}
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload)
}
imageCount = len(emitted)
processDataDone = true
}
}
processLine := func(line []byte) (bool, error) {
if len(line) == 0 {
return false, nil
}
sseData.AddLine(string(line), processData)
if processDataErr != nil {
return true, processDataErr
}
return processDataDone, nil
}
flushData := func() (bool, error) {
sseData.Flush(processData)
if processDataErr != nil {
return true, processDataErr
}
return processDataDone, nil
}
finalizePending := func() error {
if imageCount > 0 {
return nil
}
if len(pendingResults) > 0 {
eventName := streamPrefix + ".completed"
for _, img := range pendingResults {
mergeOpenAIResponsesImageMeta(&img, streamMeta)
key := openAIResponsesImageResultKey("", img)
if _, exists := emitted[key]; exists {
continue
}
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil)
emitted[key] = struct{}{}
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload)
}
imageCount = len(emitted)
return nil
}
streamErr := fmt.Errorf("stream disconnected before image generation completed")
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error()))
return streamErr
}
streamInterval := s.openAIImageStreamDataInterval()
keepaliveInterval := s.openAIImageStreamKeepaliveInterval()
if streamInterval <= 0 && keepaliveInterval <= 0 {
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadBytes('\n')
done, processErr := processLine(line)
if processErr != nil {
return usage, imageCount, firstTokenMs, processErr
}
if done {
return usage, imageCount, firstTokenMs, nil
}
if err == io.EOF {
break
}
if err != nil {
if done, processErr := flushData(); processErr != nil {
return usage, imageCount, firstTokenMs, processErr
} else if done {
return usage, imageCount, firstTokenMs, nil
}
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
return usage, imageCount, firstTokenMs, err
}
}
if done, processErr := flushData(); processErr != nil {
return usage, imageCount, firstTokenMs, processErr
} else if done {
return usage, imageCount, firstTokenMs, nil
}
if err := finalizePending(); err != nil {
return usage, imageCount, firstTokenMs, err
}
return usage, imageCount, firstTokenMs, nil
}
type readEvent struct {
line []byte
err error
}
events := make(chan readEvent, 16)
done := make(chan struct{})
sendEvent := func(ev readEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadBytes('\n')
if len(line) > 0 {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
}
if len(line) > 0 && !sendEvent(readEvent{line: line}) {
return
}
if err == io.EOF {
return
}
if err != nil {
_ = sendEvent(readEvent{err: err})
return
}
}
}()
defer close(done)
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
for {
line, err := reader.ReadBytes('\n')
if len(line) > 0 {
trimmedLine := strings.TrimRight(string(line), "\r\n")
data, ok := extractOpenAISSEDataLine(trimmedLine)
if ok && data != "" && data != "[DONE]" {
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
select {
case ev, ok := <-events:
if !ok {
if done, processErr := flushData(); processErr != nil {
return usage, imageCount, firstTokenMs, processErr
} else if done {
return usage, imageCount, firstTokenMs, nil
}
dataBytes := []byte(data)
s.parseSSEUsageBytes(dataBytes, &usage)
if gjson.ValidBytes(dataBytes) {
if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok {
mergeOpenAIResponsesImageMeta(&streamMeta, meta)
if eventCreatedAt > 0 {
createdAt = eventCreatedAt
}
}
switch gjson.GetBytes(dataBytes, "type").String() {
case "response.image_generation_call.partial_image":
b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String())
if b64 != "" {
eventName := streamPrefix + ".partial_image"
partialMeta := streamMeta
mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{
OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()),
Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()),
})
payload := buildOpenAIImagesStreamPartialPayload(
eventName,
b64,
gjson.GetBytes(dataBytes, "partial_image_index").Int(),
format,
createdAt,
partialMeta,
)
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
}
}
case "response.output_item.done":
img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes)
if extractErr != nil {
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
return OpenAIUsage{}, imageCount, firstTokenMs, extractErr
}
if !ok {
break
}
mergeOpenAIResponsesImageMeta(&streamMeta, img)
mergeOpenAIResponsesImageMeta(&img, streamMeta)
key := openAIResponsesImageResultKey(itemID, img)
if _, exists := emitted[key]; exists {
break
}
if _, exists := pendingSeen[key]; exists {
break
}
pendingSeen[key] = struct{}{}
pendingResults = append(pendingResults, img)
case "response.completed":
results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes)
if extractErr != nil {
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
return OpenAIUsage{}, imageCount, firstTokenMs, extractErr
}
mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta)
finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults))
finalSeen := make(map[string]struct{})
for _, img := range results {
mergeOpenAIResponsesImageMeta(&img, streamMeta)
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
}
for _, img := range pendingResults {
mergeOpenAIResponsesImageMeta(&img, streamMeta)
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
}
if len(finalResults) == 0 {
err = fmt.Errorf("upstream did not return image output")
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
return OpenAIUsage{}, imageCount, firstTokenMs, err
}
eventName := streamPrefix + ".completed"
for _, img := range finalResults {
key := openAIResponsesImageResultKey("", img)
if _, exists := emitted[key]; exists {
continue
}
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw)
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
}
emitted[key] = struct{}{}
}
imageCount = len(emitted)
return usage, imageCount, firstTokenMs, nil
}
if err := finalizePending(); err != nil {
return usage, imageCount, firstTokenMs, err
}
return usage, imageCount, firstTokenMs, nil
}
}
if err == io.EOF {
break
}
if err != nil {
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
return OpenAIUsage{}, imageCount, firstTokenMs, err
}
}
if imageCount > 0 {
return usage, imageCount, firstTokenMs, nil
}
if len(pendingResults) > 0 {
eventName := streamPrefix + ".completed"
for _, img := range pendingResults {
mergeOpenAIResponsesImageMeta(&img, streamMeta)
key := openAIResponsesImageResultKey("", img)
if _, exists := emitted[key]; exists {
if ev.err != nil {
if done, processErr := flushData(); processErr != nil {
return usage, imageCount, firstTokenMs, processErr
} else if done {
return usage, imageCount, firstTokenMs, nil
}
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(ev.err.Error()))
return usage, imageCount, firstTokenMs, ev.err
}
done, processErr := processLine(ev.line)
if processErr != nil {
return usage, imageCount, firstTokenMs, processErr
}
if done {
return usage, imageCount, firstTokenMs, nil
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil)
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
if clientDisconnected {
return usage, imageCount, firstTokenMs, fmt.Errorf("image stream incomplete after timeout")
}
emitted[key] = struct{}{}
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images responses stream data interval timeout: interval=%s", streamInterval)
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(fmt.Sprintf("upstream image stream idle for %s", streamInterval)))
return usage, imageCount, firstTokenMs, fmt.Errorf("image stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected || time.Since(lastDownstreamWriteAt) < keepaliveInterval {
continue
}
if _, writeErr := io.WriteString(c.Writer, ":\n\n"); writeErr != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images responses stream client disconnected during keepalive, continue draining upstream for billing")
continue
}
flusher.Flush()
lastDownstreamWriteAt = time.Now()
}
imageCount = len(emitted)
return usage, imageCount, firstTokenMs, nil
}
streamErr := fmt.Errorf("stream disconnected before image generation completed")
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error()))
return OpenAIUsage{}, imageCount, firstTokenMs, streamErr
}
func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
@ -752,7 +948,10 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
)
}
token, _, err := s.GetAccessToken(ctx, account)
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, parsed.Stream)
defer releaseUpstreamCtx()
token, _, err := s.GetAccessToken(upstreamCtx, account)
if err != nil {
return nil, err
}
@ -763,7 +962,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
}
setOpsUpstreamRequestBody(c, responsesBody)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false)
if err != nil {
return nil, err
}
@ -808,14 +1007,14 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
Kind: "failover",
Message: upstreamMsg,
})
s.handleFailoverSideEffects(ctx, resp, account)
s.handleFailoverSideEffects(upstreamCtx, resp, account)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
return s.handleErrorResponse(ctx, resp, c, account, responsesBody)
return s.handleErrorResponse(upstreamCtx, resp, c, account, responsesBody)
}
defer func() { _ = resp.Body.Close() }()
@ -827,6 +1026,20 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
if parsed.Stream {
usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel)
if err != nil {
if imageCount > 0 {
return &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: usage,
Model: requestModel,
UpstreamModel: requestModel,
Stream: parsed.Stream,
ResponseHeaders: resp.Header.Clone(),
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ImageCount: imageCount,
ImageSize: parsed.SizeTier,
}, err
}
return nil, err
}
} else {

View File

@ -3,6 +3,7 @@ package service
import (
"bytes"
"context"
"errors"
"io"
"mime/multipart"
"net/http"
@ -17,6 +18,20 @@ import (
"github.com/tidwall/gjson"
)
type failingOpenAIImageWriter struct {
gin.ResponseWriter
failAfter int
writes int
}
func (w *failingOpenAIImageWriter) Write(p []byte) (int, error) {
if w.writes >= w.failAfter {
return 0, errors.New("write failed: client disconnected")
}
w.writes++
return w.ResponseWriter.Write(p)
}
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","stream":true}`)
@ -75,6 +90,100 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
}
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_NormalizesOfficialAndCustomSizes(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
size string
wantTier string
}{
{size: "1024x1024", wantTier: "1K"},
{size: "1536x1024", wantTier: "2K"},
{size: "1024x1536", wantTier: "2K"},
{size: "2048x2048", wantTier: "2K"},
{size: "2048x1152", wantTier: "2K"},
{size: "3840x2160", wantTier: "4K"},
{size: "2160x3840", wantTier: "4K"},
{size: "1024X768", wantTier: "2K"},
{size: "1280x768", wantTier: "2K"},
{size: "2560x1440", wantTier: "2K"},
{size: "2560x1600", wantTier: "4K"},
{size: "auto", wantTier: "2K"},
}
svc := &OpenAIGatewayService{}
for _, tt := range tests {
t.Run(tt.size, func(t *testing.T) {
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"` + tt.size + `"}`)
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = req
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
require.NoError(t, err)
require.NotNil(t, parsed)
require.Equal(t, tt.size, parsed.Size)
require.Equal(t, tt.wantTier, parsed.SizeTier)
})
}
}
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_UnknownSizesDoNotBlockPassthrough(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
size string
wantTier string
}{
{size: "2048x1153", wantTier: "2K"},
{size: "4096x1024", wantTier: "4K"},
{size: "3840x1024", wantTier: "4K"},
{size: "512x512", wantTier: "2K"},
{size: "invalid", wantTier: "2K"},
{size: "999999999999999999999999999x2", wantTier: "2K"},
}
svc := &OpenAIGatewayService{}
for _, tt := range tests {
t.Run(tt.size, func(t *testing.T) {
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"` + tt.size + `"}`)
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = req
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
require.NoError(t, err)
require.NotNil(t, parsed)
require.Equal(t, tt.size, parsed.Size)
require.Equal(t, tt.wantTier, parsed.SizeTier)
})
}
}
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_LegacyImageModelUnknownSizePassthrough(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-1.5","prompt":"draw a cat","size":"2048x1152"}`)
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = req
svc := &OpenAIGatewayService{}
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
require.NoError(t, err)
require.NotNil(t, parsed)
require.Equal(t, "2048x1152", parsed.Size)
require.Equal(t, "2K", parsed.SizeTier)
}
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEditWithMaskAndNativeOptions(t *testing.T) {
gin.SetMode(gin.TestMode)
@ -543,6 +652,57 @@ func TestOpenAIGatewayServiceForwardImages_APIKeyStreamRawJSONEventStreamFallbac
require.Equal(t, "ZmluYWw=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
}
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamMultilineSSEDataBillsImage(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = req
svc := &OpenAIGatewayService{
cfg: &config.Config{},
httpUpstream: &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/event-stream"},
"X-Request-Id": []string{"req_img_stream_multiline"},
},
Body: io.NopCloser(strings.NewReader(
"data: {\"type\":\"image_generation.completed\",\n" +
"data: \"usage\":{\"input_tokens\":10,\"output_tokens\":18,\"output_tokens_details\":{\"image_tokens\":8}},\n" +
"data: \"b64_json\":\"ZmluYWw=\",\"output_format\":\"png\"}\n\n" +
"data: [DONE]\n\n",
)),
},
},
}
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
require.NoError(t, err)
account := &Account{
ID: 8,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"api_key": "test-api-key",
},
}
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.Stream)
require.Equal(t, 1, result.ImageCount)
require.Equal(t, 10, result.Usage.InputTokens)
require.Equal(t, 18, result.Usage.OutputTokens)
require.Equal(t, 8, result.Usage.ImageOutputTokens)
}
func TestExtractOpenAIImagesBillableCountFromJSONBytes_CompletedEvent(t *testing.T) {
body := []byte(`{"type":"image_generation.completed","b64_json":"ZmluYWw=","usage":{"input_tokens":10,"output_tokens":18}}`)
@ -686,6 +846,61 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *tes
require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
}
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamingDrainsAfterClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true}`)
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = req
c.Writer = &failingOpenAIImageWriter{ResponseWriter: c.Writer, failAfter: 1}
svc := &OpenAIGatewayService{
cfg: &config.Config{
Gateway: config.GatewayConfig{
ImageStreamDataIntervalTimeout: 1,
ImageStreamKeepaliveInterval: 0,
},
},
httpUpstream: &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/event-stream"},
"X-Request-Id": []string{"req_img_stream_disconnect_apikey"},
},
Body: io.NopCloser(strings.NewReader(
"data: {\"type\":\"image_generation.partial_image\",\"b64_json\":\"cGFydGlhbA==\"}\n\n" +
"data: {\"type\":\"image_generation.completed\",\"usage\":{\"input_tokens\":3,\"output_tokens\":4,\"output_tokens_details\":{\"image_tokens\":2}},\"b64_json\":\"ZmluYWw=\",\"output_format\":\"png\"}\n\n" +
"data: [DONE]\n\n",
)),
},
},
}
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
require.NoError(t, err)
account := &Account{
ID: 8,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"api_key": "test-api-key",
},
}
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 1, result.ImageCount)
require.Equal(t, 3, result.Usage.InputTokens)
require.Equal(t, 4, result.Usage.OutputTokens)
require.Equal(t, 2, result.Usage.ImageOutputTokens)
}
func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t *testing.T) {
gin.SetMode(gin.TestMode)
@ -901,6 +1116,23 @@ func TestCollectOpenAIImagesFromResponsesBody_FallsBackToOutputItemDone(t *testi
require.JSONEq(t, `{"images":1}`, string(usageRaw))
}
func TestCollectOpenAIImagesFromResponsesBody_MultilineSSE(t *testing.T) {
body := []byte(
"data: {\"type\":\"response.completed\",\n" +
"data: \"response\":{\"created_at\":1710000010,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" +
"data: [DONE]\n\n",
)
results, createdAt, usageRaw, firstMeta, foundFinal, err := collectOpenAIImagesFromResponsesBody(body)
require.NoError(t, err)
require.True(t, foundFinal)
require.Equal(t, int64(1710000010), createdAt)
require.Len(t, results, 1)
require.Equal(t, "ZmluYWw=", results[0].Result)
require.Equal(t, "png", firstMeta.OutputFormat)
require.JSONEq(t, `{"images":1}`, string(usageRaw))
}
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
@ -957,3 +1189,116 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFa
require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
require.NotContains(t, rec.Body.String(), "event: error")
}
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesMultilineSSE(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = req
svc := &OpenAIGatewayService{}
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
require.NoError(t, err)
svc.httpUpstream = &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/event-stream"},
"X-Request-Id": []string{"req_img_stream_multiline_oauth"},
},
Body: io.NopCloser(strings.NewReader(
"data: {\"type\":\"response.completed\",\n" +
"data: \"response\":{\"created_at\":1710000011,\"usage\":{\"input_tokens\":6,\"output_tokens\":10,\"output_tokens_details\":{\"image_tokens\":5}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"TXVsdGlsaW5l\",\"output_format\":\"png\"}]}}\n\n" +
"data: [DONE]\n\n",
)),
},
}
account := &Account{
ID: 11,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "token-123",
},
}
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.Stream)
require.Equal(t, 1, result.ImageCount)
require.Equal(t, 6, result.Usage.InputTokens)
require.Equal(t, 10, result.Usage.OutputTokens)
require.Equal(t, 5, result.Usage.ImageOutputTokens)
events := parseOpenAIImageTestSSEEvents(rec.Body.String())
completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed")
require.True(t, ok)
require.Equal(t, "TXVsdGlsaW5l", gjson.Get(completed.Data, "b64_json").String())
require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
require.NotContains(t, rec.Body.String(), "event: error")
}
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingDrainsAfterClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = req
c.Writer = &failingOpenAIImageWriter{ResponseWriter: c.Writer, failAfter: 1}
svc := &OpenAIGatewayService{
cfg: &config.Config{
Gateway: config.GatewayConfig{
ImageStreamDataIntervalTimeout: 1,
ImageStreamKeepaliveInterval: 0,
},
},
}
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
require.NoError(t, err)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/event-stream"},
"X-Request-Id": []string{"req_img_stream_disconnect_oauth"},
},
Body: io.NopCloser(strings.NewReader(
"data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"png\"}\n\n" +
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000009,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" +
"data: [DONE]\n\n",
)),
},
}
svc.httpUpstream = upstream
account := &Account{
ID: 9,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "token-123",
},
}
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.Stream)
require.Equal(t, 1, result.ImageCount)
require.Equal(t, 5, result.Usage.InputTokens)
require.Equal(t, 9, result.Usage.OutputTokens)
require.Equal(t, 4, result.Usage.ImageOutputTokens)
}

View File

@ -0,0 +1,137 @@
package service
import "strings"
func lastOpenAIModelSegment(model string) string {
model = strings.TrimSpace(model)
if model == "" {
return ""
}
if strings.Contains(model, "/") {
parts := strings.Split(model, "/")
model = parts[len(parts)-1]
}
return strings.TrimSpace(model)
}
func canonicalizeOpenAIModelAliasSpelling(model string) string {
model = strings.ToLower(lastOpenAIModelSegment(model))
if model == "" {
return ""
}
normalized := strings.ReplaceAll(model, "_", "-")
normalized = strings.Join(strings.Fields(normalized), "-")
for strings.Contains(normalized, "--") {
normalized = strings.ReplaceAll(normalized, "--", "-")
}
if strings.HasPrefix(normalized, "gpt5") {
normalized = "gpt-5" + strings.TrimPrefix(normalized, "gpt5")
}
if !strings.HasPrefix(normalized, "gpt-") && !strings.Contains(normalized, "codex") {
return ""
}
replacements := []struct {
from string
to string
}{
{"gpt-5.4mini", "gpt-5.4-mini"},
{"gpt-5.4nano", "gpt-5.4-nano"},
{"gpt-5.3-codexspark", "gpt-5.3-codex-spark"},
{"gpt-5.3codexspark", "gpt-5.3-codex-spark"},
{"gpt-5.3codex", "gpt-5.3-codex"},
}
for _, replacement := range replacements {
normalized = strings.ReplaceAll(normalized, replacement.from, replacement.to)
}
return normalized
}
func normalizeKnownOpenAICodexModel(model string) string {
normalized := canonicalizeOpenAIModelAliasSpelling(model)
if normalized == "" {
return ""
}
if mapped := getNormalizedCodexModel(normalized); mapped != "" {
return mapped
}
if strings.HasSuffix(normalized, "-openai-compact") {
if mapped := getNormalizedCodexModel(strings.TrimSuffix(normalized, "-openai-compact")); mapped != "" {
return mapped
}
}
switch {
case strings.Contains(normalized, "gpt-5.5"):
return "gpt-5.5"
case strings.Contains(normalized, "gpt-5.4-mini"):
return "gpt-5.4-mini"
case strings.Contains(normalized, "gpt-5.4-nano"):
return "gpt-5.4-nano"
case strings.Contains(normalized, "gpt-5.4"):
return "gpt-5.4"
case strings.Contains(normalized, "gpt-5.2"):
return "gpt-5.2"
case strings.Contains(normalized, "gpt-5.3-codex-spark"):
return "gpt-5.3-codex-spark"
case strings.Contains(normalized, "gpt-5.3-codex"):
return "gpt-5.3-codex"
case strings.Contains(normalized, "gpt-5.3"):
return "gpt-5.3-codex"
case strings.Contains(normalized, "codex"):
return "gpt-5.3-codex"
case strings.Contains(normalized, "gpt-5"):
return "gpt-5.4"
default:
return ""
}
}
func appendUsageBillingModelCandidate(candidates []string, seen map[string]struct{}, model string) []string {
trimmed := strings.TrimSpace(model)
if trimmed == "" {
return candidates
}
add := func(candidate string) {
candidate = strings.TrimSpace(candidate)
if candidate == "" {
return
}
key := strings.ToLower(candidate)
if _, ok := seen[key]; ok {
return
}
seen[key] = struct{}{}
candidates = append(candidates, candidate)
}
add(trimmed)
if canonical := canonicalizeOpenAIModelAliasSpelling(trimmed); canonical != "" {
add(canonical)
}
if normalized := normalizeKnownOpenAICodexModel(trimmed); normalized != "" {
add(normalized)
}
return candidates
}
func usageBillingModelCandidates(primary string, alternates ...string) []string {
seen := make(map[string]struct{}, 1+len(alternates))
candidates := appendUsageBillingModelCandidate(nil, seen, primary)
for _, alternate := range alternates {
candidates = appendUsageBillingModelCandidate(candidates, seen, alternate)
}
return candidates
}
func firstUsageBillingModel(candidates []string) string {
for _, candidate := range candidates {
if trimmed := strings.TrimSpace(candidate); trimmed != "" {
return trimmed
}
}
return ""
}

View File

@ -94,6 +94,15 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
defaultMappedModel: "gpt-5.4",
expectedModel: "gpt-5.5",
},
{
name: "preserves compact-spelled gpt5.5 instead of group default",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt5.5",
defaultMappedModel: "gpt-5.4",
expectedModel: "gpt5.5",
},
{
name: "preserves openai namespaced gpt-5.5 instead of group default",
account: &Account{

View File

@ -0,0 +1,70 @@
package service
import (
"strings"
"github.com/tidwall/gjson"
)
type openAISSEDataAccumulator struct {
lines []string
}
func (a *openAISSEDataAccumulator) AddLine(line string, fn func([]byte)) {
if fn == nil {
return
}
trimmedLine := strings.TrimRight(line, "\r\n")
if data, ok := extractOpenAISSEDataLine(trimmedLine); ok {
a.lines = append(a.lines, data)
return
}
if strings.TrimSpace(trimmedLine) == "" {
a.Flush(fn)
}
}
func (a *openAISSEDataAccumulator) Flush(fn func([]byte)) {
if fn == nil || len(a.lines) == 0 {
return
}
emitOpenAISSEDataPayloads(a.lines, fn)
a.lines = a.lines[:0]
}
func forEachOpenAISSEDataPayload(body string, fn func([]byte)) {
if fn == nil || strings.TrimSpace(body) == "" {
return
}
var acc openAISSEDataAccumulator
for _, line := range strings.Split(body, "\n") {
acc.AddLine(line, fn)
}
acc.Flush(fn)
}
func emitOpenAISSEDataPayloads(lines []string, fn func([]byte)) {
if fn == nil || len(lines) == 0 {
return
}
if len(lines) == 1 {
emitOpenAISSEDataPayload(lines[0], fn)
return
}
joined := strings.Join(lines, "\n")
if gjson.Valid(joined) {
emitOpenAISSEDataPayload(joined, fn)
return
}
for _, line := range lines {
emitOpenAISSEDataPayload(line, fn)
}
}
func emitOpenAISSEDataPayload(data string, fn func([]byte)) {
data = strings.TrimSpace(data)
if data == "" || data == "[DONE]" {
return
}
fn([]byte(data))
}

View File

@ -1990,6 +1990,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
}
usage := &OpenAIUsage{}
imageCounter := newOpenAIImageOutputCounter()
var firstTokenMs *int
responseID := ""
var finalResponse []byte
@ -2171,6 +2172,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
if openAIWSEventShouldParseUsage(eventType) {
parseOpenAIWSResponseUsageFromCompletedEvent(message, usage)
}
imageCounter.AddSSEData(message)
if eventType == "error" {
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
@ -2343,6 +2345,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
Usage: *usage,
Model: originalModel,
UpstreamModel: mappedModel,
ImageCount: imageCounter.Count(),
ServiceTier: extractOpenAIServiceTier(reqBody),
ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
Stream: reqStream,
@ -2449,6 +2452,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
promptCacheKey string
previousResponseID string
originalModel string
imageBillingModel string
imageSizeTier string
payloadBytes int
}
@ -2546,6 +2551,19 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
normalized = next
}
imageIntent := IsImageGenerationIntent(openAIResponsesEndpoint, originalModel, normalized)
if imageIntent && !GroupAllowsImageGeneration(apiKeyGroup(getAPIKeyFromContext(c))) {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, ImageGenerationPermissionMessage(), nil)
}
imageBillingModel := ""
imageSizeTier := ""
if imageIntent {
var imageCfgErr error
imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(normalized, originalModel)
if imageCfgErr != nil {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, imageCfgErr.Error(), imageCfgErr)
}
}
// Apply OpenAI Fast Policy on the response.create frame using the same
// evaluator/normalize/scope rules as the HTTP entrypoints. This is the
@ -2591,6 +2609,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
promptCacheKey: promptCacheKey,
previousResponseID: previousResponseID,
originalModel: originalModel,
imageBillingModel: imageBillingModel,
imageSizeTier: imageSizeTier,
payloadBytes: len(normalized),
}, nil
}
@ -2792,7 +2812,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
return payload, nil
}
sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string) (*OpenAIForwardResult, error) {
sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string, imageBillingModel string, imageSizeTier string) (*OpenAIForwardResult, error) {
if lease == nil {
return nil, errors.New("upstream websocket lease is nil")
}
@ -2817,6 +2837,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
responseID := ""
usage := OpenAIUsage{}
imageCounter := newOpenAIImageOutputCounter()
var firstTokenMs *int
reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true)
turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
@ -2938,6 +2959,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
if openAIWSEventShouldParseUsage(eventType) {
parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage)
}
imageCounter.AddSSEData(upstreamMessage)
if !clientDisconnected {
if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) {
@ -2997,7 +3019,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
clientDisconnected,
)
}
return &OpenAIForwardResult{
imageCount := imageCounter.Count()
result := &OpenAIForwardResult{
RequestID: responseID,
Usage: usage,
Model: originalModel,
@ -3009,13 +3032,21 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
ResponseHeaders: lease.HandshakeHeaders(),
Duration: time.Since(turnStart),
FirstTokenMs: firstTokenMs,
}, nil
}
if imageCount > 0 {
result.ImageCount = imageCount
result.ImageSize = imageSizeTier
result.BillingModel = imageBillingModel
}
return result, nil
}
}
}
currentPayload := firstPayload.payloadRaw
currentOriginalModel := firstPayload.originalModel
currentImageBillingModel := firstPayload.imageBillingModel
currentImageSizeTier := firstPayload.imageSizeTier
currentPayloadBytes := firstPayload.payloadBytes
isStrictAffinityTurn := func(payload []byte) bool {
if !storeDisabled {
@ -3460,7 +3491,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
)
}
result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel)
result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel, currentImageBillingModel, currentImageSizeTier)
if relayErr != nil {
lastTurnClean = false
if recoverIngressPrevResponseNotFound(relayErr, turn, connID) {
@ -3582,6 +3613,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
currentPayload = nextPayload.payloadRaw
currentOriginalModel = nextPayload.originalModel
currentImageBillingModel = nextPayload.imageBillingModel
currentImageSizeTier = nextPayload.imageSizeTier
currentPayloadBytes = nextPayload.payloadBytes
storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account)
if !storeDisabled {

View File

@ -171,6 +171,127 @@ func TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky(t *testing.T) {
require.Equal(t, "resp_new_1", gjson.GetBytes(responseBody, "id").String())
}
func TestOpenAIGatewayService_Forward_WSv2_ImageGenerationCountsOutputs(t *testing.T) {
gin.SetMode(gin.TestMode)
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket failed: %v", err)
return
}
defer func() {
_ = conn.Close()
}()
var request map[string]any
if err := conn.ReadJSON(&request); err != nil {
t.Errorf("read ws request failed: %v", err)
return
}
if err := conn.WriteJSON(map[string]any{
"type": "response.output_item.done",
"item": map[string]any{
"id": "ig_ws_1",
"type": "image_generation_call",
"result": "final-image",
},
}); err != nil {
t.Errorf("write response.output_item.done failed: %v", err)
return
}
if err := conn.WriteJSON(map[string]any{
"type": "response.completed",
"response": map[string]any{
"id": "resp_ws_image_1",
"model": "gpt-5.4",
"output": []any{
map[string]any{
"id": "ig_ws_1",
"type": "image_generation_call",
"result": "final-image",
},
},
"usage": map[string]any{
"input_tokens": 9,
"output_tokens": 4,
},
},
}); err != nil {
t.Errorf("write response.completed failed: %v", err)
return
}
}))
defer wsServer.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
groupID := int64(1010)
c.Set("api_key", &APIKey{
GroupID: &groupID,
Group: &Group{
ID: groupID,
AllowImageGeneration: true,
},
})
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
account := &Account{
ID: 10,
Name: "openai-ws-image",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.4","stream":false,"input":"draw","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1024x1024"}],"tool_choice":{"type":"image_generation"}}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "resp_ws_image_1", result.RequestID)
require.Equal(t, 1, result.ImageCount)
require.Equal(t, "1K", result.ImageSize)
require.Equal(t, "gpt-image-2", result.BillingModel)
require.Equal(t, 9, result.Usage.InputTokens)
require.Equal(t, 4, result.Usage.OutputTokens)
require.True(t, result.OpenAIWSMode)
require.Equal(t, "resp_ws_image_1", gjson.GetBytes(rec.Body.Bytes(), "id").String())
}
func requestToJSONString(payload map[string]any) string {
if len(payload) == 0 {
return "{}"

View File

@ -625,6 +625,9 @@ func normalizeModelNameForPricing(model string) string {
}
model = strings.TrimLeft(model, "/")
if canonical := canonicalizeOpenAIModelAliasSpelling(model); canonical != "" {
return canonical
}
return model
}

View File

@ -98,6 +98,19 @@ func TestGetModelPricing_Gpt54UsesStaticFallbackWhenRemoteMissing(t *testing.T)
require.InDelta(t, 1.5, got.LongContextOutputCostMultiplier, 1e-12)
}
func TestGetModelPricing_OpenAICompactAliasUsesStaticFallback(t *testing.T) {
svc := &PricingService{
pricingData: map[string]*LiteLLMModelPricing{
"gpt-5.1-codex": {InputCostPerToken: 1.25e-6},
},
}
got := svc.GetModelPricing("openai/gpt5.5")
require.NotNil(t, got)
require.InDelta(t, 2.5e-6, got.InputCostPerToken, 1e-12)
require.InDelta(t, 1.5e-5, got.OutputCostPerToken, 1e-12)
}
func TestGetModelPricing_Gpt54MiniUsesDedicatedStaticFallbackWhenRemoteMissing(t *testing.T) {
svc := &PricingService{
pricingData: map[string]*LiteLLMModelPricing{

View File

@ -0,0 +1,26 @@
-- 生图能力与图片倍率模式控制
-- 兼容性原则:
-- 1. 不改写现有 image_price_1k/2k/4k避免改变已配置分组的最终图片价格。
-- 2. 现有 openai/gemini/antigravity 分组默认保持可生图,避免升级后中断已有图片业务。
-- 3. 现有分组默认共享当前有效分组倍率,保持历史扣费公式。
ALTER TABLE groups
ADD COLUMN IF NOT EXISTS allow_image_generation BOOLEAN NOT NULL DEFAULT false;
ALTER TABLE groups
ADD COLUMN IF NOT EXISTS image_rate_independent BOOLEAN NOT NULL DEFAULT false;
ALTER TABLE groups
ADD COLUMN IF NOT EXISTS image_rate_multiplier DECIMAL(10,4) NOT NULL DEFAULT 1.0;
UPDATE groups
SET allow_image_generation = true
WHERE platform IN ('openai', 'gemini', 'antigravity');
UPDATE groups
SET image_rate_independent = false,
image_rate_multiplier = 1.0;
COMMENT ON COLUMN groups.allow_image_generation IS '是否允许该分组使用图片生成能力';
COMMENT ON COLUMN groups.image_rate_independent IS '图片生成是否使用独立倍率false 表示共享分组有效倍率';
COMMENT ON COLUMN groups.image_rate_multiplier IS '图片生成独立倍率,仅 image_rate_independent=true 时生效';

View File

@ -285,6 +285,25 @@ GATEWAY_SCHEDULING_OUTBOX_BACKLOG_REBUILD_ROWS=10000
# 全量重建周期(秒)
GATEWAY_SCHEDULING_FULL_REBUILD_INTERVAL_SECONDS=300
# -----------------------------------------------------------------------------
# Image Generation Stream & Concurrency (Optional)
# 图片生成流式与并发隔离配置(可选)
# -----------------------------------------------------------------------------
# 图片流式上游数据间隔超时。0 表示禁用;非 0 时必须为 60-1800。
GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=900
# 图片流式 keepalive 间隔。0 表示禁用;非 0 时必须为 5-60。
GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=10
# 是否启用进程级图片生成并发限制。默认 false保持历史行为。
GATEWAY_IMAGE_CONCURRENCY_ENABLED=false
# 当前进程允许同时处理的图片生成请求数。0 表示不限制。
GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=0
# 图片并发超限策略reject 直接返回 429wait 等待空闲槽位。
GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=reject
# wait 模式下等待空闲图片槽位的最长时间(秒)。
GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=30
# wait 模式下当前进程允许排队等待的最大图片请求数。0 表示不允许等待队列。
GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=100
# -----------------------------------------------------------------------------
# Dashboard Aggregation (Optional)
# -----------------------------------------------------------------------------

View File

@ -340,6 +340,30 @@ gateway:
# Stream keepalive interval (seconds), 0=disable
# 流式 keepalive 间隔0=禁用
stream_keepalive_interval: 10
# Image stream data interval timeout (seconds), 0=disable; independent from ordinary text streams
# 图片流数据间隔超时0=禁用;独立于普通文本流式
image_stream_data_interval_timeout: 900
# Image stream keepalive interval (seconds), 0=disable; independent from ordinary text streams
# 图片流式 keepalive 间隔0=禁用;独立于普通文本流式
image_stream_keepalive_interval: 10
# Image generation independent concurrency limiter (process-local, default disabled)
# 图片生成独立并发限制(进程级,默认关闭;多实例总上限约为实例数×该值)
image_concurrency:
# Enable image-only concurrency protection; false keeps existing behavior unchanged
# 是否启用图片独立并发保护false 保持现有行为不变
enabled: false
# Max concurrent image generation requests in this process, 0=unlimited
# 当前进程允许同时处理的图片生成请求数0=不限制
max_concurrent_requests: 0
# Overflow mode when the image concurrency limit is full: reject/wait
# 图片并发满时的处理方式reject=立即拒绝wait=等待槽位
overflow_mode: "reject"
# Wait timeout for overflow_mode=wait (seconds), 0=do not wait
# wait 模式等待图片并发槽位的超时时间0=不等待
wait_timeout_seconds: 30
# Max image requests waiting in this process when overflow_mode=wait, 0=unlimited
# wait 模式当前进程允许排队等待的图片请求数0=不限制
max_waiting_requests: 100
# SSE max line size in bytes (default: 40MB)
# SSE 单行最大字节数(默认 40MB
max_line_size: 41943040

View File

@ -40,6 +40,13 @@ services:
- JWT_SECRET=${JWT_SECRET:-}
- TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY:-}
- TZ=${TZ:-Asia/Shanghai}
- GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900}
- GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10}
- GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false}
- GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0}
- GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject}
- GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30}
- GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100}
depends_on:
postgres:
condition: service_healthy

View File

@ -146,6 +146,17 @@ services:
# Proxy for accessing GitHub (online updates + pricing data)
# Examples: http://host:port, socks5://host:port
- UPDATE_PROXY_URL=${UPDATE_PROXY_URL:-}
# =======================================================================
# Image Generation Stream & Concurrency
# =======================================================================
- GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900}
- GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10}
- GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false}
- GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0}
- GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject}
- GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30}
- GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100}
depends_on:
postgres:
condition: service_healthy

View File

@ -93,6 +93,17 @@ services:
# SECURITY: This repo does not embed third-party client_secret.
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
# =======================================================================
# Image Generation Stream & Concurrency
# =======================================================================
- GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900}
- GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10}
- GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false}
- GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0}
- GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject}
- GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30}
- GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100}
healthcheck:
test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"]
interval: 30s

View File

@ -142,6 +142,17 @@ services:
# Proxy for accessing GitHub (online updates + pricing data)
# Examples: http://host:port, socks5://host:port
- UPDATE_PROXY_URL=${UPDATE_PROXY_URL:-}
# =======================================================================
# Image Generation Stream & Concurrency
# =======================================================================
- GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900}
- GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10}
- GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false}
- GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0}
- GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject}
- GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30}
- GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100}
depends_on:
postgres:
condition: service_healthy

View File

@ -291,9 +291,23 @@
</div>
</template>
<!-- Per-request / image billing: show unit price -->
<template v-else-if="tooltipData?.billing_mode === BILLING_MODE_IMAGE">
<div class="flex items-center justify-between gap-4">
<span class="text-gray-400">{{ t('usage.imageCount') }}</span>
<span class="font-medium text-white">{{ tooltipData.image_count }}{{ t('usage.imageUnit') }} ({{ tooltipData.image_size || '2K' }})</span>
</div>
<div class="flex items-center justify-between gap-4">
<span class="text-gray-400">{{ t('usage.imageUnitPrice') }}</span>
<span class="font-medium text-sky-300">${{ imageUnitPrice(tooltipData).toFixed(6) }}</span>
</div>
<div class="flex items-center justify-between gap-4">
<span class="text-gray-400">{{ t('usage.imageTotalPrice') }}</span>
<span class="font-medium text-white">${{ tooltipData.total_cost?.toFixed(6) || '0.000000' }}</span>
</div>
</template>
<div v-else class="flex items-center justify-between gap-4">
<span class="text-gray-400">{{ tooltipData.billing_mode === BILLING_MODE_IMAGE ? t('usage.imageUnitPrice') : t('usage.unitPrice') }}</span>
<span class="font-medium text-sky-300">${{ tooltipData.total_cost?.toFixed(6) || '0.000000' }}</span>
<span class="text-gray-400">{{ t('usage.unitPrice') }}</span>
<span class="font-medium text-sky-300">${{ tooltipData?.total_cost?.toFixed(6) || '0.000000' }}</span>
</div>
<div v-if="tooltipData && tooltipData.cache_creation_cost > 0" class="flex items-center justify-between gap-4">
<span class="text-gray-400">{{ t('admin.usage.cacheCreationCost') }}</span>
@ -360,6 +374,13 @@ function accountBilled(row: { total_cost?: number | null; account_stats_cost?: n
return Number.isNaN(result) ? 0 : result
}
function imageUnitPrice(row: AdminUsageLog | null): number {
if (!row || row.image_count <= 0) return 0
const total = row.total_cost ?? 0
const price = total / row.image_count
return Number.isFinite(price) ? price : 0
}
import DataTable from '@/components/common/DataTable.vue'
import EmptyState from '@/components/common/EmptyState.vue'
import Icon from '@/components/icons/Icon.vue'

View File

@ -844,6 +844,8 @@ export default {
perMillionTokens: '/ 1M tokens',
unitPrice: 'Per-request price',
imageUnitPrice: 'Per-image price',
imageTotalPrice: 'Image total price',
imageCount: 'Image count',
cacheRead: 'Read',
cacheWrite: 'Write',
serviceTier: 'Service tier',
@ -2050,7 +2052,13 @@ export default {
},
imagePricing: {
title: 'Image Generation Pricing',
description: 'Configure pricing for image generation models. Leave empty to use default prices.'
description: 'Configure image generation access and base image prices. Leave empty to use default prices.',
allowImageGeneration: 'Allow image generation for this group',
independentMultiplier: 'Use independent image multiplier',
imageMultiplier: 'Image multiplier',
modeHint: 'By default, image billing uses image price × current effective group multiplier. Independent mode uses image price × image multiplier.',
finalPricePreview: 'Final per-image price preview',
notConfigured: 'Not configured'
},
claudeCode: {
title: 'Claude Code Client Restriction',

View File

@ -848,6 +848,8 @@ export default {
perMillionTokens: '/ 1M Token',
unitPrice: '单次价格',
imageUnitPrice: '单张价格',
imageTotalPrice: '图片总价',
imageCount: '图片张数',
cacheRead: '读取',
cacheWrite: '写入',
serviceTier: '服务档位',
@ -2133,7 +2135,13 @@ export default {
},
imagePricing: {
title: '图片生成计费',
description: '配置图片生成模型的图片生成价格,留空则使用默认价格'
description: '配置图片生成能力和图片基础单价,留空则使用默认价格',
allowImageGeneration: '允许当前分组生图',
independentMultiplier: '生图倍率独立',
imageMultiplier: '生图独立倍率',
modeHint: '默认关闭独立倍率时,图片费用 = 图片价格 × 当前分组有效倍率;开启独立倍率后,图片费用 = 图片价格 × 生图独立倍率。',
finalPricePreview: '最终单张价格预览',
notConfigured: '未配置'
},
claudeCode: {
title: 'Claude Code 客户端限制',

View File

@ -492,7 +492,10 @@ export interface Group {
daily_limit_usd: number | null
weekly_limit_usd: number | null
monthly_limit_usd: number | null
// 图片生成计费配置(仅 antigravity 平台使用)
// 图片生成计费配置
allow_image_generation: boolean
image_rate_independent: boolean
image_rate_multiplier: number
image_price_1k: number | null
image_price_2k: number | null
image_price_4k: number | null
@ -602,6 +605,9 @@ export interface CreateGroupRequest {
daily_limit_usd?: number | null
weekly_limit_usd?: number | null
monthly_limit_usd?: number | null
allow_image_generation?: boolean
image_rate_independent?: boolean
image_rate_multiplier?: number
image_price_1k?: number | null
image_price_2k?: number | null
image_price_4k?: number | null
@ -627,6 +633,9 @@ export interface UpdateGroupRequest {
daily_limit_usd?: number | null
weekly_limit_usd?: number | null
monthly_limit_usd?: number | null
allow_image_generation?: boolean
image_rate_independent?: boolean
image_rate_multiplier?: number
image_price_1k?: number | null
image_price_2k?: number | null
image_price_4k?: number | null

View File

@ -666,6 +666,40 @@
<p class="text-xs text-gray-500 dark:text-gray-400 mb-3">
{{ t("admin.groups.imagePricing.description") }}
</p>
<div class="mb-4 grid grid-cols-1 gap-3 md:grid-cols-2">
<label class="flex items-center gap-2 text-sm text-gray-700 dark:text-gray-300">
<input
v-model="createForm.allow_image_generation"
type="checkbox"
class="rounded border-gray-300 text-blue-600 focus:ring-blue-500"
/>
{{ t("admin.groups.imagePricing.allowImageGeneration") }}
</label>
<label class="flex items-center gap-2 text-sm text-gray-700 dark:text-gray-300">
<input
v-model="createForm.image_rate_independent"
type="checkbox"
class="rounded border-gray-300 text-blue-600 focus:ring-blue-500"
/>
{{ t("admin.groups.imagePricing.independentMultiplier") }}
</label>
</div>
<div
v-if="createForm.image_rate_independent"
class="mb-4"
>
<label class="input-label">{{
t("admin.groups.imagePricing.imageMultiplier")
}}</label>
<input
v-model.number="createForm.image_rate_multiplier"
type="number"
step="0.0001"
min="0"
class="input"
placeholder="1"
/>
</div>
<div class="grid grid-cols-3 gap-3">
<div>
<label class="input-label">1K ($)</label>
@ -701,6 +735,22 @@
/>
</div>
</div>
<p class="mt-3 text-xs text-gray-500 dark:text-gray-400">
{{ t("admin.groups.imagePricing.modeHint") }}
</p>
<div class="mt-2 rounded-lg bg-gray-50 p-3 text-xs text-gray-700 dark:bg-gray-800 dark:text-gray-300">
<div class="mb-1 font-medium">
{{ t("admin.groups.imagePricing.finalPricePreview") }}
</div>
<div class="grid grid-cols-3 gap-2">
<div
v-for="item in createImageFinalPricePreview"
:key="item.label"
>
{{ item.label }}: {{ item.value }}
</div>
</div>
</div>
</div>
<!-- 支持的模型系列 antigravity 平台 -->
@ -1801,6 +1851,40 @@
<p class="text-xs text-gray-500 dark:text-gray-400 mb-3">
{{ t("admin.groups.imagePricing.description") }}
</p>
<div class="mb-4 grid grid-cols-1 gap-3 md:grid-cols-2">
<label class="flex items-center gap-2 text-sm text-gray-700 dark:text-gray-300">
<input
v-model="editForm.allow_image_generation"
type="checkbox"
class="rounded border-gray-300 text-blue-600 focus:ring-blue-500"
/>
{{ t("admin.groups.imagePricing.allowImageGeneration") }}
</label>
<label class="flex items-center gap-2 text-sm text-gray-700 dark:text-gray-300">
<input
v-model="editForm.image_rate_independent"
type="checkbox"
class="rounded border-gray-300 text-blue-600 focus:ring-blue-500"
/>
{{ t("admin.groups.imagePricing.independentMultiplier") }}
</label>
</div>
<div
v-if="editForm.image_rate_independent"
class="mb-4"
>
<label class="input-label">{{
t("admin.groups.imagePricing.imageMultiplier")
}}</label>
<input
v-model.number="editForm.image_rate_multiplier"
type="number"
step="0.0001"
min="0"
class="input"
placeholder="1"
/>
</div>
<div class="grid grid-cols-3 gap-3">
<div>
<label class="input-label">1K ($)</label>
@ -1836,6 +1920,22 @@
/>
</div>
</div>
<p class="mt-3 text-xs text-gray-500 dark:text-gray-400">
{{ t("admin.groups.imagePricing.modeHint") }}
</p>
<div class="mt-2 rounded-lg bg-gray-50 p-3 text-xs text-gray-700 dark:bg-gray-800 dark:text-gray-300">
<div class="mb-1 font-medium">
{{ t("admin.groups.imagePricing.finalPricePreview") }}
</div>
<div class="grid grid-cols-3 gap-2">
<div
v-for="item in editImageFinalPricePreview"
:key="item.label"
>
{{ item.label }}: {{ item.value }}
</div>
</div>
</div>
</div>
<!-- 支持的模型系列 antigravity 平台 -->
@ -3009,7 +3109,10 @@ const createForm = reactive({
daily_limit_usd: null as number | null,
weekly_limit_usd: null as number | null,
monthly_limit_usd: null as number | null,
// antigravity 使
//
allow_image_generation: false,
image_rate_independent: false,
image_rate_multiplier: 1,
image_price_1k: null as number | null,
image_price_2k: null as number | null,
image_price_4k: null as number | null,
@ -3291,7 +3394,10 @@ const editForm = reactive({
daily_limit_usd: null as number | null,
weekly_limit_usd: null as number | null,
monthly_limit_usd: null as number | null,
// antigravity 使
//
allow_image_generation: false,
image_rate_independent: false,
image_rate_multiplier: 1,
image_price_1k: null as number | null,
image_price_2k: null as number | null,
image_price_4k: null as number | null,
@ -3321,6 +3427,62 @@ const editForm = reactive({
rpm_limit: 0 as number,
});
type ImagePricingFormState = {
rate_multiplier: number;
image_rate_independent: boolean;
image_rate_multiplier: number;
image_price_1k: number | string | null;
image_price_2k: number | string | null;
image_price_4k: number | string | null;
};
const imagePricingTiers = [
{ key: "image_price_1k", label: "1K" },
{ key: "image_price_2k", label: "2K" },
{ key: "image_price_4k", label: "4K" },
] as const;
const normalizePreviewNumber = (value: number | string | null | undefined, fallback = 0) => {
if (value === null || value === undefined || value === "") {
return fallback;
}
const parsed = Number(value);
return Number.isFinite(parsed) ? parsed : fallback;
};
const formatImagePricePreview = (value: number | string | null | undefined) => {
if (value === null || value === undefined || value === "") {
return t("admin.groups.imagePricing.notConfigured");
}
const price = Number(value);
if (!Number.isFinite(price) || price < 0) {
return t("admin.groups.imagePricing.notConfigured");
}
return `$${price.toFixed(6).replace(/0+$/, "").replace(/\.$/, "")}`;
};
const buildImageFinalPricePreview = (form: ImagePricingFormState) => {
const multiplier = form.image_rate_independent
? normalizePreviewNumber(form.image_rate_multiplier, 1)
: normalizePreviewNumber(form.rate_multiplier, 1);
return imagePricingTiers.map((tier) => {
const basePrice = normalizePreviewNumber(form[tier.key]);
return {
label: tier.label,
value: basePrice > 0
? formatImagePricePreview(basePrice * multiplier)
: t("admin.groups.imagePricing.notConfigured"),
};
});
};
const createImageFinalPricePreview = computed(() =>
buildImageFinalPricePreview(createForm),
);
const editImageFinalPricePreview = computed(() =>
buildImageFinalPricePreview(editForm),
);
//
const deleteConfirmMessage = computed(() => {
if (!deletingGroup.value) {
@ -3479,6 +3641,9 @@ const closeCreateModal = () => {
createForm.daily_limit_usd = null;
createForm.weekly_limit_usd = null;
createForm.monthly_limit_usd = null;
createForm.allow_image_generation = false;
createForm.image_rate_independent = false;
createForm.image_rate_multiplier = 1;
createForm.image_price_1k = null;
createForm.image_price_2k = null;
createForm.image_price_4k = null;
@ -3513,6 +3678,16 @@ const normalizeOptionalLimit = (
return Number.isFinite(value) && value > 0 ? value : null;
};
const normalizeImageRateMultiplier = (
value: number | string | null | undefined,
): number => {
if (value === null || value === undefined || value === "") {
return 1;
}
const parsed = Number(value);
return Number.isFinite(parsed) && parsed >= 0 ? parsed : 1;
};
const handleCreateGroup = async () => {
if (!createForm.name.trim()) {
appStore.showError(t("admin.groups.nameRequired"));
@ -3551,6 +3726,9 @@ const handleCreateGroup = async () => {
requestData.daily_limit_usd = emptyToNull(requestData.daily_limit_usd);
requestData.weekly_limit_usd = emptyToNull(requestData.weekly_limit_usd);
requestData.monthly_limit_usd = emptyToNull(requestData.monthly_limit_usd);
requestData.image_rate_multiplier = normalizeImageRateMultiplier(
requestData.image_rate_multiplier,
);
await adminAPI.groups.create(requestData);
appStore.showSuccess(t("admin.groups.groupCreated"));
closeCreateModal();
@ -3582,6 +3760,9 @@ const handleEdit = async (group: AdminGroup) => {
editForm.daily_limit_usd = group.daily_limit_usd;
editForm.weekly_limit_usd = group.weekly_limit_usd;
editForm.monthly_limit_usd = group.monthly_limit_usd;
editForm.allow_image_generation = group.allow_image_generation ?? false;
editForm.image_rate_independent = group.image_rate_independent ?? false;
editForm.image_rate_multiplier = group.image_rate_multiplier ?? 1;
editForm.image_price_1k = group.image_price_1k;
editForm.image_price_2k = group.image_price_2k;
editForm.image_price_4k = group.image_price_4k;
@ -3676,6 +3857,9 @@ const handleUpdateGroup = async () => {
payload.daily_limit_usd = emptyToNull(payload.daily_limit_usd);
payload.weekly_limit_usd = emptyToNull(payload.weekly_limit_usd);
payload.monthly_limit_usd = emptyToNull(payload.monthly_limit_usd);
payload.image_rate_multiplier = normalizeImageRateMultiplier(
payload.image_rate_multiplier,
);
await adminAPI.groups.update(editingGroup.value.id, payload);
appStore.showSuccess(t("admin.groups.groupUpdated"));
closeEditModal();

View File

@ -459,9 +459,23 @@
</div>
</template>
<!-- Per-request / image billing: show unit price -->
<template v-else-if="tooltipData?.billing_mode === 'image'">
<div class="flex items-center justify-between gap-4">
<span class="text-gray-400">{{ t('usage.imageCount') }}</span>
<span class="font-medium text-white">{{ tooltipData.image_count }}{{ t('usage.imageUnit') }} ({{ tooltipData.image_size || '2K' }})</span>
</div>
<div class="flex items-center justify-between gap-4">
<span class="text-gray-400">{{ t('usage.imageUnitPrice') }}</span>
<span class="font-medium text-sky-300">${{ imageUnitPrice(tooltipData).toFixed(6) }}</span>
</div>
<div class="flex items-center justify-between gap-4">
<span class="text-gray-400">{{ t('usage.imageTotalPrice') }}</span>
<span class="font-medium text-white">${{ tooltipData.total_cost?.toFixed(6) || '0.000000' }}</span>
</div>
</template>
<div v-else class="flex items-center justify-between gap-4">
<span class="text-gray-400">{{ tooltipData.billing_mode === 'image' ? t('usage.imageUnitPrice') : t('usage.unitPrice') }}</span>
<span class="font-medium text-sky-300">${{ tooltipData.total_cost?.toFixed(6) || '0.000000' }}</span>
<span class="text-gray-400">{{ t('usage.unitPrice') }}</span>
<span class="font-medium text-sky-300">${{ tooltipData?.total_cost?.toFixed(6) || '0.000000' }}</span>
</div>
<div v-if="tooltipData && tooltipData.cache_creation_cost > 0" class="flex items-center justify-between gap-4">
<span class="text-gray-400">{{ t('admin.usage.cacheCreationCost') }}</span>
@ -625,6 +639,13 @@ const formatDuration = (ms: number): string => {
return `${(ms / 1000).toFixed(2)}s`
}
const imageUnitPrice = (row: UsageLog | null): number => {
if (!row || row.image_count <= 0) return 0
const total = row.total_cost ?? 0
const price = total / row.image_count
return Number.isFinite(price) ? price : 0
}
const formatUserAgent = (ua: string): string => {
return ua
}

View File

@ -44,7 +44,6 @@ export default defineConfig(({ mode }) => {
plugins: [
vue(),
checker({
typescript: true,
vueTsc: true
}),
injectPublicSettings(backendUrl)

View File

@ -0,0 +1,2 @@
schema: spec-driven
created: 2026-04-29

View File

@ -0,0 +1,227 @@
## Context
当前代码已经具备图片价格字段和部分图片转发能力,但边界不完整:
- `backend/ent/schema/group.go` 只有 `rate_multiplier``image_price_1k/2k/4k`,没有分组级生图能力开关,也没有“图片是否共享分组倍率”的开关。
- `backend/internal/handler/openai_images.go` 在解析 `/v1/images/*` 后只做通用余额/订阅资格检查,没有检查分组是否允许生图。
- `backend/internal/service/openai_gateway_service.go` 对 Codex CLI 会自动注入 `image_generation` tool通用 `/v1/responses` 只记录日志,没有把图片工具产物数量写入 `OpenAIForwardResult.ImageCount`
- `backend/internal/service/billing_service.go``CalculateImageCost` 当前使用 `image_price_* * image_count * rate_multiplier`。这个行为本身可以作为默认兼容模式,但普通编码分组 `rate_multiplier=0.15` 且希望图片最终价为 `0.2/张` 时,管理员必须填写 `image_price=0.2/0.15`,不可读且不适合长期运营。
- `backend/internal/service/openai_gateway_service.go``backend/internal/service/gateway_service.go` 的渠道图片计费路径当前传 `RequestCount: 1`,多图请求会按 1 次收费。
- `backend/internal/service/openai_images.go` 的 OpenAI 图片尺寸分层此前只覆盖少量固定尺寸;`gpt-image-2` 官方文档已经支持满足约束的自定义 `size`,因此本地计费必须能够对未知尺寸做稳定分档,同时不能因为本地映射不认识就提前拦截请求。
用户澄清后的业务要求是:普通编码分组可以关闭生图,也可以开启生图;开启后默认继续共享现有分组倍率以保持兼容,但管理员可以打开“生图倍率独立”开关,改用单独的图片倍率输入框。图片分组是推荐的运营隔离方式,但不是唯一承载方式。
## Goals / Non-Goals
**Goals:**
- 分组具备明确的 `allow_image_generation` 开关,所有已知生图入口在调度上游前执行同一个权限判断。
- 分组具备“生图倍率是否独立”的开关;默认 `false`,即共享当前代码里的有效分组倍率。
- 生图倍率独立开关打开后,图片费用使用单独的 `image_rate_multiplier`,不再使用普通编码分组的倍率。
- 保留现有 `image_price_1k/2k/4k` 字段作为图片单价配置,不强制把它们迁移成新的语义。
- 普通编码分组在 `allow_image_generation=false` 时仍可正常使用 `gpt-5.4` / `gpt-5.5` 文本能力,但不能使用图片工具。
- 普通编码分组在 `allow_image_generation=true` 时可使用 `gpt-5.4` / `gpt-5.5 + image_generation`,且按实际图片数量收费。
- 通用 `/v1/responses`、OpenAI Images API、流式、非流式、透传路径全部把成功产出的图片数量写入 `ImageCount`
- 渠道 `billing_mode=image` 使用真实 `ImageCount`,不再固定按 1 次收费。
**Non-Goals:**
- 不引入新的第三方依赖。
- 不改变 OpenAI 上游协议;只在现有请求转发、响应解析和计费归因层补齐控制。
- 不把“图片分组”做成唯一安全边界;分组开关和图片计费逻辑必须适用于任意开启生图的分组。
- 不在本变更中实现预扣费/资金冻结;失败请求仍不收费,成功请求按实际产物后扣费。
- 不改变默认历史图片价格行为;默认共享现有有效倍率,历史 `图片价格 * 分组/用户有效倍率` 的扣费行为保持。
- 不在本变更中新增用户级图片独立倍率覆盖;用户专属普通倍率只在共享倍率模式下继续影响图片。
## Decisions
### 0. 兼容性优先原则
本变更的默认行为必须以“不改变现有已配置分组的最终扣费”为优先级:
- 迁移不修改现有 `image_price_1k/2k/4k`
- 迁移把所有现有分组设置为 `image_rate_independent=false`,因此现有图片路径继续使用当前有效分组倍率。
- 管理员不传新字段更新分组时,不得覆盖已保存的 `allow_image_generation``image_rate_independent``image_rate_multiplier`
- 前端编辑旧分组时必须回显服务端值;不能因为表单默认值把旧分组从共享倍率误改成独立倍率,或把允许生图误改成禁止生图。
- 只有管理员显式打开 `image_rate_independent` 后,图片扣费才从共享倍率切换到图片独立倍率。
### 1. 分组字段与迁移策略
新增三个分组字段对应“2 个开关 + 1 个输入框”:
- `allow_image_generation BOOLEAN NOT NULL DEFAULT false`
- `image_rate_independent BOOLEAN NOT NULL DEFAULT false`
- `image_rate_multiplier DECIMAL(10,4) NOT NULL DEFAULT 1.0`
字段语义:
- `allow_image_generation`:是否支持当前分组生图。
- `image_rate_independent=false`:图片计费共享当前普通计费链路里的有效倍率,即当前 `userGroupRateResolver.Resolve(ctx, user.ID, groupID, group.RateMultiplier)` 得到的倍率;这保持现有行为。
- `image_rate_independent=true`:图片计费使用 `group.image_rate_multiplier`;普通编码的 `rate_multiplier` 和用户专属普通倍率不参与图片扣费。
- `image_price_1k/2k/4k`:继续表示图片基础单价,由选中的图片倍率模式继续相乘。
新建分组默认 `allow_image_generation=false`,避免新普通编码分组意外获得生图能力。为避免升级后立即打断已有图片业务,迁移对现有 `openai``gemini``antigravity` 分组回填 `allow_image_generation=true``anthropic` 分组保持 `false`。该回填只是兼容现状;上线后管理员必须按业务策略关闭不允许生图的普通编码分组。
迁移不改写已有 `image_price_1k/2k/4k`,并将所有现有分组设为 `image_rate_independent=false``image_rate_multiplier=1`。这样现有最终扣费公式保持不变:
```text
历史/默认模式图片最终扣费 = image_price_* * image_count * 当前有效分组倍率
```
普通编码分组 `rate_multiplier=0.15` 且希望图片 1K 最终扣费 `0.2/张` 时,管理员不再需要填写 `0.2/0.15`,而是设置:
```text
image_rate_independent = true
image_rate_multiplier = 1
image_price_1k = 0.2
```
如果希望图片也打折,例如图片标价 `0.2/张`、图片折扣 `0.8`,则设置:
```text
image_rate_independent = true
image_rate_multiplier = 0.8
image_price_1k = 0.2
```
### 2. 生图意图统一识别
新增一个服务层 helper输入至少包含 endpoint、请求模型、请求体输出是否为生图意图
```text
isImageGenerationIntent =
endpoint 是 /v1/images/generations 或 /v1/images/edits
OR requested model 以 gpt-image- 开头
OR tools[] 存在 type == image_generation
OR tool_choice 显式指向 image_generation
```
生图意图判断必须在请求体被 Codex 注入、模型改写、渠道映射改写之前执行一次,并在这些改写之后再对最终请求体执行一次。原因是当前代码会在 `backend/internal/service/openai_gateway_service.go` 中注入 `image_generation` tool也会在 `normalizeOpenAIResponsesImageOnlyModel` 中把 `gpt-image-*` 改写为文本模型 + 图片工具;只检查改写前或只检查改写后都可能漏掉场景。
`tool_choice` 判断只把明确指向 `image_generation` 的值视为生图意图;`auto``none``required` 本身不构成生图意图,但如果 `tools[]` 中存在 `image_generation`,仍由 `tools[]` 规则命中。
该判断必须在以下位置使用:
- `/v1/images/*` handler 解析请求后、账号调度前。
- `/v1/responses` 解析 body 后、Codex 自动注入 `image_generation` tool 前。
- `normalizeOpenAIResponsesImageOnlyModel``gpt-image-*` 改写为 Responses 文本模型前。
- OpenAI 高级 scheduler 入口保留现有账号能力检查,同时补齐渠道 restriction 检查,避免启用高级调度时绕过渠道模型限制。
`allow_image_generation=false` 时:
- 显式生图意图返回 HTTP 403错误类型使用现有 `permission_error` 风格。
- Codex CLI 请求不自动注入 `image_generation` tool也不追加图片桥接指令如果请求没有显式生图意图则继续按普通文本请求处理。
### 3. gpt-5.4 / gpt-5.5 生图承载方式
`gpt-5.4` / `gpt-5.5` 生图通过现有 OpenAI Responses API 的 `image_generation` tool 承载,不新增专用 endpoint
```json
{
"model": "gpt-5.4",
"input": "生成一张图片",
"tools": [
{
"type": "image_generation",
"model": "gpt-image-2",
"size": "1024x1024",
"output_format": "png"
}
],
"tool_choice": { "type": "image_generation" }
}
```
`model=gpt-image-*` 发到 `/v1/responses` 时保留现有改写方向:主模型改为 Responses 文本模型,图片模型放入 `image_generation` tool。计费时如果能从工具配置得到 `gpt-image-*`,图片默认价格按该图片模型解析;如果工具未指定图片模型,则使用当前转发结果的 billing model并优先使用分组/渠道配置价格。
### 4. 图片数量归因
新增统一图片输出解析 helper返回去重后的图片数量和可用图片元信息。必须覆盖以下已有或可借鉴的事件形态
- 非流式 Responses JSON`output[]``type == image_generation_call``result` 非空。
- Responses SSE`response.output_item.done``item.type == image_generation_call``item.result` 非空。
- Responses SSE 完成事件:`response.completed.response.output[]` 中图片工具结果。
- Images API 非流式:顶层 `data[]`
- Images API 流式:顶层 `data[]``image_generation.completed``response.output_item.done``response.completed`
去重键按优先级使用 `item.id``call_id``result` 内容 hash。只统计最终图片不统计 `partial_image`
`openaiStreamingResult` 增加 `imageCount``imageSize``imageBillingModel``handleStreamingResponse``handleStreamingResponsePassthrough``handleNonStreamingResponse``handleNonStreamingResponsePassthrough` 都必须把解析结果带回 `OpenAIForwardResult`。当 `ImageCount > 0` 时,即使上游 usage 为 0也必须写 usage log 并进入图片计费。
### 5. 图片价格公式
图片计费先确定单价,再确定倍率:
```text
unit_price = 渠道 image 模式价格 或 分组 image_price_* 或 默认图片价格
image_multiplier =
如果 group.image_rate_independent == true: group.image_rate_multiplier
否则: 当前有效分组倍率
total_cost = unit_price * image_count
actual_cost = total_cost * image_multiplier
```
“当前有效分组倍率”必须沿用当前代码的倍率解析方式:默认配置倍率 → 分组 `rate_multiplier` → 用户专属分组倍率覆盖。这样 `image_rate_independent=false` 时完全保留当前行为。
`billing_mode=image` 的渠道价格是图片单价来源之一,仍优先于分组图片价格。图片渠道价格也必须按 `ImageCount` 计数,并使用同一套 `image_multiplier` 选择逻辑。
`billing_mode=per_request` 的非图片请求保持当前普通按次语义,继续使用普通 token 倍率;只有已经识别为图片请求且 `ImageCount > 0` 的路径使用图片计费逻辑。
`usage_logs.rate_multiplier` 继续表示“本次扣费实际使用的倍率”。因此:
- token 日志记录普通 token 有效倍率。
- image 日志在共享模式记录普通有效倍率。
- image 日志在独立模式记录 `image_rate_multiplier`
专用 `/v1/images/*` 仍按图片请求语义计费:当 `ImageCount > 0` 时,图片价格决定费用,伴随的上游 token usage 只记录不额外计 token 费用。这保持当前 Images API 的行为。
通用 `/v1/responses + image_generation` 的混合文本+图片输出存在一个明确取舍:如果继续沿用“`ImageCount > 0` 时只按图片计费”的当前计费分支,用户可以在一次图片请求中夹带大量文本输出而只付图片费用;如果改成“图片费用 + 非图片 token 费用”,会改变当前 `billing_mode=image` 的单一计费语义,并可能让渠道图片单价不再是全包价格。本变更为最大兼容性不引入混合计费模式,但必须在 usage log 中完整记录 token 与 image_count便于后续按数据决定是否新增 `image_plus_token` 计费模式。
### 6. 尺寸档位与参数透传
OpenAI 图片请求的 `size` 参数必须透传给上游;本地只做计费分档,不做 OpenAI 尺寸合法性校验。无论尺寸是否满足官方约束,本地都不能因为未知尺寸或 provider-invalid 尺寸返回 400如果上游不接受该尺寸由上游响应错误。
官方 `gpt-image-2` 文档给出的常用尺寸与约束是本地计费分档的依据:
- 常用尺寸:`1024x1024``1536x1024``1024x1536``2048x2048``2048x1152``3840x2160``2160x3840``auto`
- 自定义尺寸:官方支持满足约束的任意 `size`包括边长、16 像素倍数、长短边比例、总像素范围等约束。
- `2560x1440` 是 2K/QHD 参考边界;超过 `2560x1440` 总像素的输出进入更高档位风险区。
OpenAI 图片尺寸分层必须按以下规则:
```text
empty, auto => 2K
1024x1024 => 1K
1536x1024, 1024x1536 => 2K
1792x1024, 1024x1792 => 2K
2048x2048, 2048x1152, 1152x2048 => 2K
3840x2160, 2160x3840 => 4K
未知且无法解析为正整数 WIDTHxHEIGHT => 2K
未知且 WIDTH * HEIGHT <= 2560*1440 => 2K
未知且 WIDTH * HEIGHT > 2560*1440 => 4K
```
这个规则只决定 `ImageSize` 和扣费档位,不修改请求体,不删除未知参数,不把未知尺寸改写成预设尺寸。
## Risks / Trade-offs
- 历史普通编码分组迁移后仍默认允许生图 → 通过管理员可见开关、上线核对清单和新建分组默认关闭来控制;代码无法可靠判断“普通编码分组”和“图片分组”的业务意图。
- 默认共享现有有效倍率仍保留“图片最终价不直观”的问题 → 这是兼容性选择;需要直观设置图片最终价的分组必须打开 `image_rate_independent`
- 独立图片倍率不会读取用户专属普通倍率 → 这是目标行为;如需要用户级图片独立倍率,应作为后续独立需求实现。
- 通用 Responses 图片工具可能同时输出文本和图片 → 本变更默认仍按图片请求语义计费并完整记录 token若业务要求文本也收费应新增独立的混合计费模式不能混入本次兼容性变更。
- 本地不再拦截未知或 provider-invalid OpenAI 尺寸 → 非法尺寸会消耗一次上游请求失败成本和用户体验往返,但这是为了保证参数透传、兼容官方新增尺寸和第三方兼容提供商;计费只在成功产出最终图片后发生。
- Responses 流式解析需要在客户端断开后继续 drain 上游以完成计费 → 沿用当前流式处理“客户端断开后继续读取上游用于计费”的模式,并只新增轻量 JSON 路径提取。
- 预扣费不在本变更中实现 → 继续使用现有成功后扣费模型,避免失败请求退款、流式中断退款和图片数量未确定时预估错误。
## Migration Plan
1. 新增数据库迁移,添加 `groups.allow_image_generation``groups.image_rate_independent``groups.image_rate_multiplier`
2. 回填现有分组:`openai``gemini``antigravity``allow_image_generation=true``anthropic=false`;所有现有分组 `image_rate_independent=false``image_rate_multiplier=1`
3. 不改写现有 `image_price_1k/2k/4k`,保持默认共享倍率模式下的历史扣费结果。
4. 更新 Ent schema 与生成代码,更新后端 service/handler DTO 和前端类型。
5. 先接入权限判断,确保未开启生图的分组不会到达上游。
6. 再接入图片数量解析和图片计费倍率选择,确保开启生图的分组按图片数量收费。
7. 最后更新前端管理界面、i18n、文档和测试。
8. 回滚时只能通过新迁移回滚字段行为;不能修改已应用迁移文件。
## Open Questions
无。当前方案不依赖未确认的上游新尺寸、新模型或新 endpoint。

View File

@ -0,0 +1,29 @@
## Why
当前代码把“能否生图”和“如何按图片收费”混在模型、分组倍率、渠道定价与 Responses 工具调用里,导致 OpenAI 普通编码分组在允许 `gpt-5.4` / `gpt-5.5` 时也能通过 `image_generation` tool 产图,并且通用 `/v1/responses` 产图不会稳定写入 `ImageCount`。需要把生图能力、图片倍率模式、图片产出数量归因拆成独立能力,保证普通编码分组可按业务开关生图,开启后既能沿用现有倍率行为,也能按需切换到图片独立倍率。
## What Changes
- 新增分组级生图能力开关,明确控制 `/v1/images/*``gpt-image-*`、显式 `image_generation` tool、Codex 自动注入图片工具等所有生图入口。
- 新增分组级图片倍率模式开关,默认继续共享现有分组有效倍率;打开独立模式后使用图片独立倍率输入框。
- 保留现有 `image_price_1k/2k/4k` 图片价格配置;图片最终扣费由“图片价格 × 当前倍率模式选出的倍率 × 图片数量”决定。
- 统一统计 OpenAI Responses 图片工具产物数量,使 `gpt-5.4` / `gpt-5.5` 通过 `image_generation` tool 产图时进入图片计费,而不是退化成普通 token 计费或无 usage 时不计费。
- 修正专用 Images API 与渠道图片计费场景,按实际图片数量和明确尺寸档位计费,避免固定 `RequestCount=1` 或未知尺寸静默落到 `2K`
- 更新后台分组配置、前端类型、使用说明和测试,覆盖普通编码分组关闭生图、普通编码分组开启生图、独立图片分组承载、生图流式/非流式等场景。
## Capabilities
### New Capabilities
- `image-generation-access-control`: 定义分组级生图能力开关、所有生图意图识别规则、拒绝行为与 Codex 自动注入规则。
- `image-generation-billing-accounting`: 定义图片倍率模式、图片数量归因、尺寸档位、渠道图片价格和用量日志要求。
### Modified Capabilities
- 无。
## Impact
- Backend schema/API: `backend/ent/schema/group.go`、Ent 生成代码、数据库迁移、管理员分组 create/update/list DTO、分组缓存/序列化。
- Backend request gates: `backend/internal/handler/openai_images.go``backend/internal/service/openai_gateway_service.go``backend/internal/service/openai_codex_transform.go`、OpenAI account scheduler 相关模型/图片能力调度入口。
- Backend billing: `backend/internal/service/billing_service.go``backend/internal/service/openai_gateway_service.go``backend/internal/service/gateway_service.go`、usage log 与 account stats 成本计算路径。
- Frontend admin: `frontend/src/types/index.ts``frontend/src/views/admin/GroupsView.vue`、相关 i18n 文案与图片计费展示。
- Tests: OpenAI Images API、OpenAI Responses stream/non-stream/passthrough、分组开关、图片倍率模式、渠道图片计数、尺寸档位与 usage log 断言。

View File

@ -0,0 +1,118 @@
## ADDED Requirements
### Requirement: Group image generation capability
The system SHALL store a group-level `allow_image_generation` capability flag and SHALL expose it through admin group create, update, list, and detail APIs.
#### Scenario: New group defaults to image generation disabled
- **WHEN** an admin creates a group without providing `allow_image_generation`
- **THEN** the persisted group has `allow_image_generation=false`
#### Scenario: Existing image-capable platform groups are backfilled
- **WHEN** the migration is applied to existing groups
- **THEN** existing `openai`, `gemini`, and `antigravity` groups have `allow_image_generation=true`
- **AND** existing `anthropic` groups have `allow_image_generation=false`
#### Scenario: Admin enables image generation on an ordinary coding group
- **WHEN** an admin updates an `openai` group with `allow_image_generation=true`
- **THEN** the group can use image generation paths subject to the billing requirements
#### Scenario: Admin disables image generation on an ordinary coding group
- **WHEN** an admin updates an `openai` group with `allow_image_generation=false`
- **THEN** the group can still use non-image text model requests
- **AND** image generation intents are denied before upstream dispatch
### Requirement: Image generation intent detection
The system SHALL classify a request as an image generation intent before upstream account scheduling when the endpoint or request body can produce generated images.
#### Scenario: Images endpoint is an image generation intent
- **WHEN** a request targets `/v1/images/generations`, `/v1/images/edits`, `/images/generations`, or `/images/edits`
- **THEN** the request is classified as an image generation intent
#### Scenario: Responses request with image-only model is an image generation intent
- **WHEN** a `/v1/responses` request has a requested model whose normalized name starts with `gpt-image-`
- **THEN** the request is classified as an image generation intent before any model rewrite
#### Scenario: Responses request with image_generation tool is an image generation intent
- **WHEN** a `/v1/responses` request contains any `tools[]` entry with `type == "image_generation"`
- **THEN** the request is classified as an image generation intent
#### Scenario: Responses request with image_generation tool_choice is an image generation intent
- **WHEN** a `/v1/responses` request contains `tool_choice` that explicitly selects `image_generation`
- **THEN** the request is classified as an image generation intent even if `tools[]` is malformed or absent
#### Scenario: Generic tool_choice required is not sufficient by itself
- **WHEN** a `/v1/responses` request contains `tool_choice="required"`
- **AND** the request does not contain an `image_generation` tool
- **THEN** the request is not classified as an image generation intent because of `tool_choice` alone
#### Scenario: Text-only gpt-5.4 request is not an image generation intent
- **WHEN** a `/v1/responses` request uses `model="gpt-5.4"` or `model="gpt-5.5"` without `image_generation` tool and without image `tool_choice`
- **THEN** the request is not classified as an image generation intent
#### Scenario: Intent is checked before and after service-side mutation
- **WHEN** the service mutates a `/v1/responses` request by injecting `image_generation` or rewriting `gpt-image-*` to a Responses text model plus image tool
- **THEN** the final mutated request is checked against the same image generation intent rules before upstream dispatch
### Requirement: Disabled groups reject explicit image generation
The system SHALL reject explicit image generation intents for groups with `allow_image_generation=false` before selecting or calling an upstream account.
#### Scenario: Disabled group rejects Images API
- **WHEN** a group has `allow_image_generation=false`
- **AND** a user calls `/v1/images/generations`
- **THEN** the system returns HTTP 403 with error type `permission_error`
- **AND** no upstream account is selected
- **AND** no usage log is written
#### Scenario: Disabled group rejects Responses image tool
- **WHEN** a group has `allow_image_generation=false`
- **AND** a user calls `/v1/responses` with `tools:[{"type":"image_generation"}]`
- **THEN** the system returns HTTP 403 with error type `permission_error`
- **AND** no upstream account is selected
- **AND** no usage log is written
#### Scenario: Disabled group rejects Responses image-only model rewrite
- **WHEN** a group has `allow_image_generation=false`
- **AND** a user calls `/v1/responses` with `model` starting with `gpt-image-`
- **THEN** the system returns HTTP 403 with error type `permission_error`
- **AND** the request is not rewritten to a text Responses model
#### Scenario: Disabled group permits normal coding request
- **WHEN** a group has `allow_image_generation=false`
- **AND** a user calls `/v1/responses` with `model="gpt-5.4"` and no image generation intent
- **THEN** the request proceeds through the normal text forwarding path
### Requirement: Codex image tool injection respects group capability
The system SHALL only inject the OpenAI Responses `image_generation` tool and bridge instructions for Codex clients when the request group has `allow_image_generation=true`.
#### Scenario: Codex request in enabled group receives image tool
- **WHEN** a Codex CLI `/v1/responses` request belongs to a group with `allow_image_generation=true`
- **AND** the request has no `image_generation` tool
- **THEN** the system injects the existing `image_generation` tool payload
- **AND** the system appends the existing Codex image bridge instructions
#### Scenario: Codex request in disabled group does not receive image tool
- **WHEN** a Codex CLI `/v1/responses` request belongs to a group with `allow_image_generation=false`
- **AND** the request has no explicit image generation intent
- **THEN** the system does not inject `image_generation`
- **AND** the system does not append image bridge instructions
- **AND** the request proceeds as a text request
#### Scenario: Codex explicit image request in disabled group is denied
- **WHEN** a Codex CLI `/v1/responses` request belongs to a group with `allow_image_generation=false`
- **AND** the request explicitly contains `image_generation`
- **THEN** the system returns HTTP 403 with error type `permission_error`
### Requirement: Channel model restrictions remain enforced
The system SHALL keep existing channel model restriction behavior for image and non-image OpenAI requests, including when the advanced OpenAI account scheduler is enabled.
#### Scenario: Advanced scheduler blocks restricted requested model
- **WHEN** a channel has `restrict_models=true`
- **AND** the requested model is not allowed by channel pricing or mapping rules
- **AND** the OpenAI advanced scheduler path is used
- **THEN** the request is rejected before upstream account selection succeeds
#### Scenario: Image generation flag does not bypass channel restrictions
- **WHEN** a group has `allow_image_generation=true`
- **AND** the channel restriction rejects the requested or billing model
- **THEN** the image generation request is rejected
- **AND** no upstream image request is sent

View File

@ -0,0 +1,225 @@
## ADDED Requirements
### Requirement: Image multiplier mode
The system SHALL calculate image generation cost with group image prices and a selectable image multiplier mode. By default image billing SHALL share the existing effective group multiplier; when `image_rate_independent=true`, image billing SHALL use `image_rate_multiplier`.
#### Scenario: Default image billing shares current effective group multiplier
- **WHEN** a group has `rate_multiplier=0.15`
- **AND** `image_rate_independent=false`
- **AND** `image_price_1k=0.2`
- **AND** a successful image request produces one `1K` image
- **THEN** `actual_cost` is `0.03`
- **AND** the calculation matches current default behavior
#### Scenario: User-specific token multiplier still applies in shared mode
- **WHEN** a user has a user-group token multiplier override of `0.2`
- **AND** the group has `image_rate_independent=false`
- **AND** `image_price_1k=0.5`
- **AND** a successful image request produces one `1K` image
- **THEN** `actual_cost` is `0.1`
- **AND** the applied image multiplier is the same effective multiplier used by token billing
#### Scenario: Independent image multiplier allows direct final price
- **WHEN** a group has `rate_multiplier=0.15`
- **AND** `image_rate_independent=true`
- **AND** `image_rate_multiplier=1`
- **AND** `image_price_1k=0.2`
- **AND** a successful image request produces one `1K` image
- **THEN** `actual_cost` is `0.2`
- **AND** ordinary `rate_multiplier=0.15` is not applied to the image cost
#### Scenario: Independent image multiplier supports image discounts
- **WHEN** a group has `image_rate_independent=true`
- **AND** `image_rate_multiplier=0.5`
- **AND** `image_price_1k=0.2`
- **AND** a successful image request produces two `1K` images
- **THEN** `total_cost` is `0.4`
- **AND** `actual_cost` is `0.2`
#### Scenario: Migration preserves existing image price behavior
- **WHEN** an existing group has `rate_multiplier=0.15` and `image_price_1k=1.3333333333`
- **AND** the migration is applied
- **THEN** the stored `image_price_1k` remains `1.3333333333`
- **AND** the stored `image_rate_independent` is `false`
- **AND** the stored `image_rate_multiplier` is `1`
- **AND** default-mode image billing still produces the historical final price within decimal precision
#### Scenario: Omitted update fields preserve existing multiplier mode
- **WHEN** an admin updates a group without sending `image_rate_independent`
- **AND** without sending `image_rate_multiplier`
- **THEN** the stored image multiplier mode and image multiplier value remain unchanged
#### Scenario: Image multiplier can be zero only by explicit independent mode configuration
- **WHEN** a group has `image_rate_independent=true`
- **AND** `image_rate_multiplier=0`
- **AND** a successful image request produces one image
- **THEN** the image request is free
- **AND** this free-image behavior does not occur unless the group explicitly enables independent image multiplier mode with zero multiplier
### Requirement: Responses image output accounting
The system SHALL count generated image outputs from OpenAI Responses stream, non-stream, and passthrough paths and SHALL return the count in `OpenAIForwardResult.ImageCount`.
#### Scenario: Non-stream Responses image tool output is counted
- **WHEN** a non-stream `/v1/responses` upstream response contains `output[]` item with `type == "image_generation_call"` and non-empty `result`
- **THEN** `OpenAIForwardResult.ImageCount` equals the number of unique final image outputs
- **AND** `OpenAIForwardResult.ImageSize` is the normalized image size tier
#### Scenario: Stream Responses output item is counted
- **WHEN** a stream `/v1/responses` upstream SSE event has `type == "response.output_item.done"`
- **AND** the event item has `type == "image_generation_call"` and non-empty `result`
- **THEN** the streaming result increments the unique final image output count
#### Scenario: Stream Responses completed output is counted
- **WHEN** a stream `/v1/responses` upstream SSE event has `type == "response.completed"`
- **AND** `response.output[]` contains final image generation outputs
- **THEN** the streaming result counts those images without double-counting images already seen in `response.output_item.done`
#### Scenario: Partial image events are not billed as completed images
- **WHEN** a stream response contains `partial_image` events
- **THEN** those partial events do not increment `ImageCount`
- **AND** only final image generation outputs increment `ImageCount`
#### Scenario: gpt-5.4 image tool request is billed as image
- **WHEN** a `/v1/responses` request uses `model="gpt-5.4"` or `model="gpt-5.5"`
- **AND** the request includes an `image_generation` tool
- **AND** the upstream response contains one final image output
- **THEN** the usage log has `image_count=1`
- **AND** the usage log has `billing_mode="image"`
- **AND** image pricing, not token pricing, determines `actual_cost`
#### Scenario: Image output with zero usage is still billed
- **WHEN** an upstream Responses result contains final image output
- **AND** the upstream result has zero or missing token usage
- **THEN** the system writes a usage log
- **AND** the system bills using image pricing
#### Scenario: Responses image request records accompanying token usage
- **WHEN** a `/v1/responses` image tool request returns final images and token usage
- **THEN** the usage log records input tokens, output tokens, image output tokens, and image count
- **AND** the applied billing mode remains `image`
#### Scenario: Responses image request does not introduce hybrid billing by default
- **WHEN** a `/v1/responses` image tool request returns final images and text tokens
- **THEN** the request is billed by image pricing under this change
- **AND** non-image token charges are not added unless a future explicit hybrid billing mode is implemented
### Requirement: OpenAI Images API output accounting
The system SHALL count generated images from dedicated OpenAI Images API stream and non-stream paths and SHALL set `ImageCount` for successful image responses.
#### Scenario: Images non-stream data array is counted
- **WHEN** `/v1/images/generations` returns a non-stream JSON response with top-level `data[]`
- **THEN** `ImageCount` equals the length of `data[]`
#### Scenario: Images stream data array is counted
- **WHEN** `/v1/images/generations` stream response emits SSE data containing top-level `data[]`
- **THEN** `ImageCount` equals the maximum final data array count observed for the request
#### Scenario: Images stream completed event is counted
- **WHEN** `/v1/images/generations` stream response emits `image_generation.completed` with a final image payload
- **THEN** the stream result counts one final image output
#### Scenario: Images stream Responses-form event is counted
- **WHEN** an Images API upstream path emits Responses-form `response.output_item.done` or `response.completed` events with final image outputs
- **THEN** the stream result counts final image outputs using the same de-duplication rules as Responses
### Requirement: Channel image billing uses actual image count
The system SHALL use actual generated image count for channel `billing_mode=image` pricing and SHALL NOT bill multi-image requests as a single request.
#### Scenario: OpenAI channel image billing counts multiple images
- **WHEN** a channel image pricing entry resolves to unit price `0.25`
- **AND** an OpenAI image request produces three images
- **THEN** `total_cost` is `0.75` before the selected image multiplier is applied
- **AND** `RequestCount` passed into unified pricing is `3`
#### Scenario: Gateway channel image billing counts multiple images
- **WHEN** a non-OpenAI gateway image path produces two images
- **AND** channel image pricing resolves for the billing model
- **THEN** `RequestCount` passed into unified pricing is `2`
#### Scenario: Channel image pricing uses shared multiplier by default
- **WHEN** a channel image pricing entry resolves to unit price `0.25`
- **AND** the group has ordinary effective multiplier `0.15`
- **AND** the group has `image_rate_independent=false`
- **AND** the image request produces one image
- **THEN** `actual_cost` is `0.0375`
#### Scenario: Channel image pricing uses independent image multiplier when enabled
- **WHEN** a channel image pricing entry resolves to unit price `0.25`
- **AND** the group has ordinary effective multiplier `0.15`
- **AND** the group has `image_rate_independent=true`
- **AND** the group has `image_rate_multiplier=1`
- **AND** the image request produces one image
- **THEN** `actual_cost` is `0.25`
- **AND** ordinary effective multiplier `0.15` is not applied
#### Scenario: Account stats image pricing receives image count
- **WHEN** account stats pricing uses `billing_mode=image`
- **AND** the request produces multiple images
- **THEN** account stats cost is calculated with the actual image count
### Requirement: Image size tier normalization
The system SHALL normalize OpenAI image sizes to explicit billing tiers for billing only. The system SHALL NOT reject requests locally because of an unknown or provider-invalid `size`; it SHALL forward the original size parameter upstream and let the official upstream API decide whether the request is valid.
#### Scenario: OpenAI 1024 square maps to 1K
- **WHEN** an OpenAI image request specifies `size="1024x1024"`
- **THEN** `ImageSize` is `1K`
#### Scenario: OpenAI landscape and portrait large sizes map to 2K
- **WHEN** an OpenAI image request specifies `1536x1024`, `1024x1536`, `1792x1024`, `1024x1792`, `2048x2048`, `2048x1152`, or `1152x2048`
- **THEN** `ImageSize` is `2K`
#### Scenario: OpenAI gpt-image-2 4K presets map to 4K
- **WHEN** an OpenAI `gpt-image-2` image request specifies `3840x2160` or `2160x3840`
- **THEN** `ImageSize` is `4K`
#### Scenario: OpenAI auto size maps to 2K
- **WHEN** an OpenAI image request omits size or specifies `size="auto"`
- **THEN** `ImageSize` is `2K`
#### Scenario: Custom OpenAI size is forwarded without local validation
- **WHEN** an OpenAI image request specifies a custom explicit `WIDTHxHEIGHT` size
- **THEN** the system forwards the request upstream
- **AND** `ImageSize` is normalized to `2K` or `4K` for billing
#### Scenario: Responses image tool without model uses default image billing model
- **WHEN** a `/v1/responses` request uses an `image_generation` tool without `tool.model`
- **THEN** image size validation and image billing use `gpt-image-2` as the image billing model
#### Scenario: Invalid OpenAI size constraints are delegated upstream
- **WHEN** an OpenAI image request specifies an explicit size that fails OpenAI size constraints
- **THEN** the system forwards the request upstream
- **AND** any invalid-size error comes from the upstream provider response
#### Scenario: Custom OpenAI size tier mapping
- **WHEN** a custom size cannot be parsed as positive `WIDTHxHEIGHT`
- **THEN** `ImageSize` is `2K`
- **WHEN** a custom size parses as positive `WIDTHxHEIGHT`
- **AND** `WIDTH * HEIGHT` is no more than `2560x1440`
- **THEN** `ImageSize` is `2K`
- **WHEN** a custom size parses as positive `WIDTHxHEIGHT`
- **AND** `WIDTH * HEIGHT` exceeds `2560x1440`
- **THEN** `ImageSize` is `4K`
### Requirement: Image usage log semantics
The system SHALL write usage logs for successful image generation with image billing metadata that matches the applied image pricing path.
#### Scenario: Image usage log records image billing mode
- **WHEN** a successful request has `ImageCount > 0`
- **THEN** the usage log has `billing_mode="image"`
- **AND** the usage log records `image_count`
- **AND** the usage log records `image_size` when a normalized size tier is available
#### Scenario: Shared mode image usage log records shared multiplier
- **WHEN** a successful image request is billed with `image_rate_independent=false`
- **AND** the effective ordinary multiplier is `0.15`
- **THEN** `usage_logs.rate_multiplier` is `0.15`
#### Scenario: Independent mode image usage log records image multiplier
- **WHEN** a successful image request is billed with `image_rate_independent=true`
- **AND** `image_rate_multiplier=0.5`
- **THEN** `usage_logs.rate_multiplier` is `0.5`
#### Scenario: Token request usage log is unchanged
- **WHEN** a successful non-image token request is billed
- **THEN** `usage_logs.rate_multiplier` continues to record the ordinary token multiplier
- **AND** `image_count` is `0`

View File

@ -0,0 +1,72 @@
## 1. Data Model And Migration
- [x] 1.1 Add `allow_image_generation`, `image_rate_independent`, and `image_rate_multiplier` to `backend/ent/schema/group.go`.
- [x] 1.2 Create a new idempotent SQL migration after `133_affiliate_rebate_freeze.sql` for the three group columns.
- [x] 1.3 Backfill existing `openai`, `gemini`, and `antigravity` groups to `allow_image_generation=true` and `anthropic` groups to `false`.
- [x] 1.4 Backfill all existing groups to `image_rate_independent=false` and `image_rate_multiplier=1` without changing existing `image_price_1k/2k/4k`.
- [x] 1.5 Regenerate or update Ent generated group fields, predicates, create/update setters, and query projections.
- [x] 1.6 Add the new fields to backend group domain/service structs, admin create/update inputs, admin responses, and group serialization.
## 2. Admin API And Frontend
- [x] 2.1 Add `allow_image_generation`, `image_rate_independent`, and `image_rate_multiplier` to `CreateGroupRequest` and `UpdateGroupRequest`.
- [x] 2.2 Validate `image_rate_multiplier >= 0` and keep negative image prices using the existing clear-price behavior only for `image_price_*`.
- [x] 2.3 Add the new fields to `frontend/src/types/index.ts` group, create, and update interfaces.
- [x] 2.4 Ensure omitted update fields do not overwrite existing image generation and multiplier mode settings.
- [x] 2.5 Update `frontend/src/views/admin/GroupsView.vue` create/edit forms with a 生图开关, 生图倍率是否独立开关, and conditional image multiplier input.
- [x] 2.6 Add a live final-price preview for `image_price_1k/2k/4k` under shared and independent multiplier modes.
- [x] 2.7 Update group form help text to state that default image billing shares the existing group effective multiplier and independent mode uses the image multiplier input.
- [x] 2.8 Update i18n strings for the new controls and image multiplier mode explanation.
## 3. Image Generation Access Control
- [x] 3.1 Implement a shared helper that detects image generation intent from endpoint, requested model, `tools[]`, and `tool_choice`.
- [x] 3.2 Gate `/v1/images/generations` and `/v1/images/edits` in `backend/internal/handler/openai_images.go` after request parsing and before billing eligibility/account scheduling.
- [x] 3.3 Gate `/v1/responses` explicit `image_generation` tool requests in `backend/internal/service/openai_gateway_service.go` before upstream account scheduling.
- [x] 3.4 Prevent `normalizeOpenAIResponsesImageOnlyModel` from rewriting `gpt-image-*` Responses requests when the group does not allow image generation.
- [x] 3.5 Skip Codex `image_generation` auto-injection and image bridge instructions when the group does not allow image generation.
- [x] 3.6 Re-run image intent detection after service-side request mutation and before upstream dispatch.
- [x] 3.7 Ensure OpenAI advanced scheduler paths apply the same channel `RestrictModels` checks as the load-aware path.
## 4. Responses Image Output Accounting
- [x] 4.1 Add shared parsers for final `image_generation_call.result` outputs in non-stream JSON and SSE payloads.
- [x] 4.2 Extend `openaiStreamingResult` with image count, image size tier, and image billing model fields.
- [x] 4.3 Update `handleStreamingResponse` to count final image outputs while preserving existing stream forwarding and usage parsing.
- [x] 4.4 Update `handleStreamingResponsePassthrough` with the same image output counting.
- [x] 4.5 Update `handleNonStreamingResponse` to count final image outputs from `output[]`.
- [x] 4.6 Update `handleNonStreamingResponsePassthrough` with the same non-stream image output counting.
- [x] 4.7 Populate `OpenAIForwardResult.ImageCount`, `ImageSize`, and image billing model for `gpt-5.4` / `gpt-5.5 + image_generation` requests.
## 5. Images API Accounting And Size Tiers
- [x] 5.1 Extend OpenAI Images API-key stream counting to handle `image_generation.completed`, `response.output_item.done`, and `response.completed`.
- [x] 5.2 Reuse the same final-image de-duplication rules across Images API and Responses API paths.
- [x] 5.3 Keep unknown explicit OpenAI image sizes pass-through and delegate invalid-size errors to upstream.
- [x] 5.4 Map documented OpenAI image sizes to `1K`/`2K`/`4K` billing tiers without rewriting request parameters.
- [x] 5.5 Classify custom OpenAI `WIDTHxHEIGHT` sizes by `2560x1440` total-pixel boundary, falling back to `2K` when unparseable.
## 6. Billing And Usage Logs
- [x] 6.1 Add an image multiplier resolver: shared mode uses the current effective group multiplier, independent mode uses `apiKey.Group.ImageRateMultiplier`.
- [x] 6.2 Update `CalculateImageCost` or its caller contract so image costs use the resolved image multiplier.
- [x] 6.3 Set image usage log `RateMultiplier` to the applied image multiplier; keep token logs unchanged.
- [x] 6.4 Change OpenAI channel image billing `RequestCount` from `1` to `result.ImageCount`.
- [x] 6.5 Change non-OpenAI gateway channel image billing `RequestCount` from `1` to `result.ImageCount`.
- [x] 6.6 Pass actual image count into account stats pricing for `billing_mode=image`.
- [x] 6.7 Ensure `ImageCount > 0` writes a usage log and bills even when upstream token usage is zero.
- [x] 6.8 Record accompanying token usage for Responses image tool requests while keeping default billing mode as `image`.
## 7. Tests And Documentation
- [x] 7.1 Add backend tests for disabled group rejecting `/v1/images/*`, `gpt-image-*` Responses, explicit `image_generation`, and image `tool_choice`.
- [x] 7.2 Add backend tests proving disabled Codex groups do not receive injected image tools while enabled Codex groups still do.
- [x] 7.3 Add backend tests proving omitted group update fields preserve existing image generation and multiplier mode settings.
- [x] 7.4 Add Responses stream and non-stream tests for `gpt-5.4` / `gpt-5.5 + image_generation` image counting and image billing.
- [x] 7.5 Add Images API stream tests for `image_generation.completed`, `response.output_item.done`, and `response.completed` counting.
- [x] 7.6 Add billing tests for shared mode `rate_multiplier=0.15`, `image_price_1k=0.2`, final `actual_cost=0.03`.
- [x] 7.7 Add billing tests for independent mode `rate_multiplier=0.15`, `image_rate_multiplier=1`, `image_price_1k=0.2`, final `actual_cost=0.2`.
- [x] 7.8 Add channel image billing tests proving multi-image requests use `RequestCount=ImageCount` in both shared and independent multiplier modes.
- [x] 7.9 Add size-tier tests for known OpenAI sizes and unknown explicit size pass-through.
- [x] 7.10 Add Responses image tool tests proving token usage is recorded but default billing remains image-mode only.
- [x] 7.11 Update `2ue/image-billing-risk-analysis.md` or add a linked follow-up note that points to this OpenSpec change as the normalized solution.

View File

@ -0,0 +1,2 @@
schema: spec-driven
created: 2026-05-03

View File

@ -0,0 +1,70 @@
## Overview
本次只实现“图片独立并发开关”,不实现外部图片网关的运行时代码。目标是在最大程度不改变现有行为的前提下,为图片流式长连接提供服务级资源保护。
## Current Constraints
- 当前 Redis 并发槽位只有用户和账号维度,键语义是 `concurrency:user:*``concurrency:account:*`
- 图片接口和普通 Responses 在同一个 Go 服务内运行共享进程、HTTP 上游连接池和账号调度。
- Codex OAuth 路径会自动注入 `image_generation` tool这个注入表示“模型具备工具能力”不等价于当前请求一定会生图。
- `/v1/responses` 在 handler 入口只能可靠识别显式图片意图image 模型、请求体已有 image tool、或 tool_choice 明确选择 image_generation。
- 图片实际产物计数与计费仍以 service 层的最终输出解析为准。
## Decisions
### 1. 默认关闭,保持兼容
新增配置:
- `gateway.image_concurrency.enabled`,默认 `false`
- `gateway.image_concurrency.max_concurrent_requests`,默认 `0`,表示不限制。
- `gateway.image_concurrency.overflow_mode`,默认 `reject`,可选 `reject` / `wait`
- `gateway.image_concurrency.wait_timeout_seconds`,默认 `30`,仅 `overflow_mode=wait` 生效。
- `gateway.image_concurrency.max_waiting_requests`,默认 `100`,仅 `overflow_mode=wait` 生效,限制当前进程内图片等待队列。
只有当 `enabled=true``max_concurrent_requests>0` 时才启用图片独立并发限制。默认配置不改变任何现有流量行为。
### 2. 进程级信号量作为第一阶段隔离
本次使用进程内有界信号量做服务级图片并发限制。原因:
- 不扩展现有 Redis `ConcurrencyCache` 接口,避免影响用户/账号并发的既有语义。
- 不新增迁移,不改变分组已有字段。
- 单实例部署可立即保护进程资源。
- 多实例部署时该限制按实例生效;文档必须明确总图片并发约等于 `实例数 × max_concurrent_requests`
### 3. 限制对象只包含明确图片意图
纳入限制:
- `/v1/images/generations`
- `/v1/images/edits`
- `/v1/responses` 中入口请求已明确包含图片意图image 模型、`tools[].type=image_generation``tool_choice` 明确选择 image_generation。
暂不纳入限制:
- 普通 Codex 请求因为服务端自动注入 image tool 而具备生图能力,但入口请求本身未明确要求生图。
这样避免把普通编码请求错误算作图片并发。后续若要对“模型运行中动态调用 image tool”做更细粒度隔离需要在工具调用实际发生时获得可阻塞的事件目前当前代码没有这种入口级阻塞点。
### 4. 限流行为
- `overflow_mode=reject` 时,未开始流式响应直接返回 HTTP `429`,错误类型 `rate_limit_error`
- `overflow_mode=wait` 时,请求在当前进程内等待图片并发槽位,超过 `wait_timeout_seconds` 或超过 `max_waiting_requests` 后返回 HTTP `429`
- 已开始流式响应时,使用现有 `handleStreamingAwareError` 写 SSE 错误事件。
- 图片并发限制命中或等待超时不触发账号 failover不记录为上游账号失败。
- `gateway.image_stream_data_interval_timeout` 是上游图片流数据空闲超时,不用于图片排队等待。
### 5. 与外部图片网关的关系
本次不实现外部图片网关代码。外部网关方案沉淀到 `2ue` 文档:
- 推荐由 Caddy/Nginx/API Gateway 按 `/v1/images/*` 分流。
- `/v1/responses` 的图片 tool 请求不能仅靠 path 分流,必须在前置层读取 body 或保留主服务兜底。
- 即使未来拆出图片网关,主网关仍保留图片 intent 检测、开关和计费兜底,避免直连或漏判绕过。
## Risks And Mitigations
- 风险:进程级限制在多实例部署下不是全局严格限制。缓解:文档明确容量计算,后续可基于 Redis 扩展为集群级图片并发。
- 风险Codex 自动注入 image tool 后,普通编码请求未被图片限流。缓解:这是有意选择,避免误伤普通请求;实际输出图片仍按图片计费。
- 风险图片请求在账号槽位前被拒绝可能改变排队体验。缓解仅当独立开关启用时生效默认关闭429 明确提示图片并发达到上限。

View File

@ -0,0 +1,28 @@
## Why
图片生成流式请求会比普通文本流式请求占用更长的连接、goroutine、HTTP 上游连接和账号/用户槽位。当前图片能力已经具备独立计费与更长流式超时,但仍缺少默认关闭的图片专属并发隔离开关,图片高并发时仍可能挤压普通文本流式接口。
## What Changes
- 新增服务级图片独立并发开关,默认关闭,不改变现有已部署分组和普通文本请求行为。
- 新增图片全局并发上限配置;开启后仅限制已明确是图片生成意图的请求。
- 新增图片并发满载后的溢出策略配置:默认立即拒绝,也可配置等待槽位和等待超时。
- 将图片并发限制覆盖 `/v1/images/generations``/v1/images/edits``/v1/responses` 显式图片生成请求。
- 保留当前图片生成开关、图片计费、图片流式续读与超时语义。
- 不在本次代码实现外部独立图片网关;只把外部网关拆分方案沉淀到本地文档。
## Capabilities
### New Capabilities
- `image-generation-concurrency-isolation`: 图片生成请求的独立并发开关、并发上限、429 行为和外部网关落地建议。
### Modified Capabilities
- `image-stream-resilience`: 图片流式续读能力在独立并发开启时受到图片专属并发上限保护,但流式续读与计费契约不变。
## Impact
- 影响 `backend/internal/config/config.go` 的 gateway 配置字段、默认值和校验。
- 影响 `backend/internal/handler/openai_images.go``backend/internal/handler/openai_gateway_handler.go` 的图片请求入口限流。
- 影响 `deploy/config.example.yaml` 的示例配置与说明。
- 影响后端测试:配置默认值/校验、图片接口限流、Responses 显式 image tool 限流。
- 新增或更新 `2ue` 本地分析文档,记录外部独立图片网关只作为后续部署方案,不在本次代码落地。

View File

@ -0,0 +1,82 @@
# image-generation-concurrency-isolation Specification
## ADDED Requirements
### Requirement: Image concurrency isolation is opt-in
The system SHALL keep image concurrency isolation disabled by default.
#### Scenario: default config keeps existing behavior
- **GIVEN** the deployment does not set `gateway.image_concurrency.enabled`
- **WHEN** image generation requests are received
- **THEN** no new image-specific concurrency limit is applied
- **AND** existing user/account concurrency and billing behavior remains unchanged
### Requirement: Dedicated image concurrency limit
The system SHALL provide an opt-in service-level image concurrency limit controlled by gateway configuration.
#### Scenario: explicit image endpoint is limited
- **GIVEN** `gateway.image_concurrency.enabled=true`
- **AND** `gateway.image_concurrency.max_concurrent_requests=1`
- **AND** one image generation request is already active
- **WHEN** another `/v1/images/generations` or `/v1/images/edits` request arrives
- **THEN** the second request is rejected with HTTP `429`
- **AND** the error type is `rate_limit_error`
#### Scenario: explicit Responses image generation request is limited
- **GIVEN** `gateway.image_concurrency.enabled=true`
- **AND** `gateway.image_concurrency.max_concurrent_requests=1`
- **AND** `gateway.image_concurrency.overflow_mode=reject`
- **AND** one image generation request is already active
- **WHEN** a `/v1/responses` request explicitly contains `tools[].type=image_generation`, an image model, or `tool_choice` selecting `image_generation`
- **THEN** the request is rejected with HTTP `429`
- **AND** it is not retried through account failover
#### Scenario: image request waits for a slot
- **GIVEN** `gateway.image_concurrency.enabled=true`
- **AND** `gateway.image_concurrency.max_concurrent_requests=1`
- **AND** `gateway.image_concurrency.overflow_mode=wait`
- **AND** `gateway.image_concurrency.wait_timeout_seconds` is greater than zero
- **AND** one image generation request is already active
- **WHEN** another explicit image generation request arrives
- **AND** the active image generation request releases its slot before the wait timeout
- **THEN** the waiting image generation request acquires the slot and continues
#### Scenario: image wait times out
- **GIVEN** `gateway.image_concurrency.enabled=true`
- **AND** `gateway.image_concurrency.max_concurrent_requests=1`
- **AND** `gateway.image_concurrency.overflow_mode=wait`
- **AND** one image generation request is already active
- **WHEN** another explicit image generation request waits longer than `gateway.image_concurrency.wait_timeout_seconds`
- **THEN** the waiting request is rejected with HTTP `429`
- **AND** the error type is `rate_limit_error`
#### Scenario: image waiting queue is full
- **GIVEN** `gateway.image_concurrency.enabled=true`
- **AND** `gateway.image_concurrency.overflow_mode=wait`
- **AND** `gateway.image_concurrency.max_waiting_requests` is already reached
- **WHEN** another explicit image generation request arrives
- **THEN** the request is rejected with HTTP `429`
- **AND** it does not wait for account scheduling
### Requirement: Text requests are not image-limited
The system SHALL NOT apply the image concurrency limit to requests without explicit image generation intent.
#### Scenario: normal coding request bypasses image limit
- **GIVEN** `gateway.image_concurrency.enabled=true`
- **AND** the image concurrency limit is full
- **WHEN** a `/v1/responses` request uses a text model and does not explicitly contain image generation intent
- **THEN** the image concurrency limiter does not reject it
- **AND** normal user/account concurrency handling continues
### Requirement: External image gateway remains a deployment pattern
The system SHALL document external image gateway routing as a deployment option without adding runtime forwarding code in this change.
#### Scenario: operator reads local design note
- **GIVEN** the repository documentation is available
- **WHEN** an operator evaluates isolating image traffic into a separate service
- **THEN** local `2ue` documentation describes which paths are safe to route by path
- **AND** explains why `/v1/responses` image tool requests require body-aware routing or main-gateway fallback

View File

@ -0,0 +1,28 @@
## 1. Spec and documentation
- [x] 1.1 Create OpenSpec proposal, design, tasks, and capability spec for image concurrency isolation.
- [x] 1.2 Add a local `2ue` note for the external image gateway deployment pattern and current non-goals.
## 2. Config
- [x] 2.1 Add `gateway.image_concurrency.enabled` and `gateway.image_concurrency.max_concurrent_requests` config fields.
- [x] 2.2 Register defaults that keep existing behavior unchanged.
- [x] 2.3 Validate max concurrent requests as non-negative.
- [x] 2.4 Update `deploy/config.example.yaml` with safe usage notes.
- [x] 2.5 Add image concurrency overflow mode, wait timeout, and max waiting request config.
## 3. Runtime limiter
- [x] 3.1 Implement a process-level image concurrency limiter with resize-on-config-read behavior.
- [x] 3.2 Acquire/release the limiter around `/v1/images/generations` and `/v1/images/edits` before account scheduling.
- [x] 3.3 Acquire/release the limiter around explicit `/v1/responses` image generation intent before account scheduling.
- [x] 3.4 Ensure limiter rejections return `429 rate_limit_error` and do not trigger account failover.
- [x] 3.5 Support `reject` and `wait` overflow modes with bounded wait timeout and waiting queue size.
## 4. Tests and verification
- [x] 4.1 Add config default and validation tests.
- [x] 4.2 Add handler tests for image endpoint limiter rejection.
- [x] 4.3 Add handler tests proving text-only Responses requests are not rejected by the image limiter.
- [x] 4.4 Run focused Go tests for config and OpenAI handler/service paths.
- [x] 4.5 Add limiter tests for wait success, wait timeout, and waiting queue overflow.

View File

@ -0,0 +1,2 @@
schema: spec-driven
created: 2026-05-02

View File

@ -0,0 +1,46 @@
## Context
现有普通 Responses 流式已经具备较完整的断连续写能力:客户端写失败后可继续 drain 上游,并且具备数据间隔超时和 keepalive。图片流式路径目前仍然采用更直接的读写方式客户端写失败会立即返回上游读取也更容易跟随客户端取消而结束。
本次变更只针对图片流式路径,不改变普通文本流式路径的配置和行为。系统已经存在普通流式的后端超时配置,因此这里不引入页面级超时设置;图片流式只需要独立的后端默认值,让图片生成有更长的容忍窗口。
## Goals / Non-Goals
**Goals:**
- 图片流式在客户端断开后继续读取上游,尽量保留最终图片结果与计费结果。
- 图片流式使用独立于普通流式的超时与 keepalive 默认值。
- 不修改现有普通流式配置项的含义,不要求管理员新增页面配置。
- 维持图片计费与图片结果计数的一致性。
**Non-Goals:**
- 不设计新的前端配置页面。
- 不修改普通文本流式的超时策略。
- 不改变图片计费公式或分组倍率语义。
## Decisions
1. **使用独立的图片流式配置键**
- 选择:在后端配置中增加图片流式专用 `image_stream_data_interval_timeout` / `image_stream_keepalive_interval`
- 原因:图片流式耗时显著更长,复用普通流式默认值会过早触发超时;独立键能避免影响现有文本流式。
- 备选方案:直接复用普通流式配置并在代码里按路径放大倍数。这个方案会让普通流式和图片流式共享语义,后续难以维护。
2. **继续使用上下文 detach而不是依赖客户端上下文**
- 选择:图片流式请求向上游发起时使用 `context.WithoutCancel` 派生的上下文。
- 原因:客户端断开时不应自动取消上游请求,否则无法收集最终图片结果,也无法完成图片计费。
- 备选方案:仍使用 `c.Request.Context()` 并只在写失败后继续 drain。这个方案在客户端取消场景下无法保证上游读取继续进行。
3. **只改图片流式路径,不改普通流式路径**
- 选择:`/v1/images/*``Responses + image_generation` 两条图片流式链路单独处理。
- 原因:风险最小,避免回归普通文本流式和现有超时配置。
- 备选方案:统一重构所有流式处理。这个方案范围更大,验证成本更高,不符合本次“尽量少改现有行为”的目标。
4. **不新增页面配置**
- 选择:图片流式独立超时默认值写入后端配置,沿用当前配置加载方式。
- 原因:用户明确要求和当前设置行为统一,不需要额外页面输入项。
- 备选方案:前端增加图片超时配置项。这个方案会改变现有运维方式,也容易引入误配。
## Risks / Trade-offs
- [Risk] 图片流式继续 drain 上游后,客户端已经断开但服务端仍占用连接与协程资源。→ [Mitigation] 只对图片流式启用更长但仍有限的专用超时,并保持与普通流式同样的 keepalive/超时退出机制。
- [Risk] 图片流式与普通流式的默认超时不同,运维如果只关注通用配置可能忽略图片专用值。→ [Mitigation] 在配置示例中明确标注图片流式专用默认值和用途。
- [Risk] 断连后继续读取可能导致日志中出现“客户端断开但最终成功”的状态。→ [Mitigation] 保留现有图片计费结果返回语义,同时让调用方在结果与错误并存时优先使用结果对象。

View File

@ -0,0 +1,25 @@
## Why
图片流式路径目前没有和普通 Responses 流式一致的断连续写策略,也没有独立于普通流式的超时控制。由于图片生成耗时更长,如果继续沿用普通流式处理方式,客户端断开时容易中断上游读取,影响图片产物收集与按图计费的准确性。
## What Changes
- 为 OpenAI Images API 和 `Responses + image_generation` 流式路径补充独立的上游续读策略,客户端断开后继续 drain 上游,尽量保留最终图片结果和计费结果。
- 为图片流式路径使用独立的流数据间隔超时与 keepalive 策略,默认比普通流式更长,不新增页面配置项。
- 保持现有普通流式配置与行为不变,避免影响已经配置好的普通文本分组。
- 让图片流式路径在超时、断连、写入失败等场景下保持图片计费语义一致。
## Capabilities
### New Capabilities
- `image-stream-resilience`: 图片流式路径的断连续读、独立超时和计费保留能力。
### Modified Capabilities
- `image-generation-billing-accounting`: 图片流式结果计数和计费结果的稳定性行为发生改变,但计费契约不变。
## Impact
- 影响 `backend/internal/service/openai_images.go``backend/internal/service/openai_images_responses.go` 的流式实现。
- 影响 `backend/internal/config/config.go``deploy/config.example.yaml` 中图片流式默认值和校验逻辑。
- 影响 `backend/internal/service/openai_images_test.go``backend/internal/config/config_test.go` 以及新增的图片流式稳定性测试。
- 不新增前端页面设置,不改变普通流式配置项名称和语义。

View File

@ -0,0 +1,53 @@
## ADDED Requirements
### Requirement: Image stream resilience
The system SHALL keep image generation stream processing active after downstream client disconnects so long as upstream reading can continue, in order to preserve final image outputs and billing results.
#### Scenario: Images API stream survives downstream disconnect
- **WHEN** `/v1/images/generations` is streamed to a client
- **AND** the downstream writer returns an error before the upstream stream completes
- **THEN** the service continues draining the upstream stream
- **AND** it still counts final image outputs if the upstream later emits them
- **AND** the request can still complete with image billing metadata
#### Scenario: Responses image tool stream survives downstream disconnect
- **WHEN** a `/v1/responses` request uses `image_generation` and is streamed to a client
- **AND** the downstream writer returns an error before the upstream stream completes
- **THEN** the service continues draining the upstream stream
- **AND** it still counts final image outputs if the upstream later emits them
- **AND** the request can still complete with image billing metadata
#### Scenario: Client disconnect does not force image stream to downgrade to text billing
- **WHEN** a successful image stream request has already produced final image outputs
- **AND** the downstream client disconnects before the final flush
- **THEN** the request remains billed as an image request
- **AND** the image count is preserved in the forward result
### Requirement: Image stream timeout isolation
The system SHALL use image-specific streaming timeout settings for image generation stream paths, and these settings SHALL be independent from the ordinary text streaming timeout values.
#### Scenario: Image stream uses dedicated timeout defaults
- **WHEN** an image generation stream path is executed
- **THEN** it uses the image-specific data interval timeout and keepalive interval defaults
- **AND** it does not rely on the ordinary text stream timeout defaults
#### Scenario: Ordinary stream settings remain unchanged
- **WHEN** a normal non-image streaming request is executed
- **THEN** the existing ordinary stream timeout configuration and behavior remain unchanged
#### Scenario: Image stream timeout is longer than ordinary stream timeout
- **WHEN** the image streaming timeout defaults are compared with the ordinary streaming defaults
- **THEN** the image streaming timeout is configured to allow a longer wait window than ordinary text streaming
### Requirement: Image stream billing consistency
The system SHALL keep the image billing result consistent even when image stream handling uses retries, keepalive writes, or downstream disconnect recovery.
#### Scenario: Final image count is preserved after reconnect-unsafe downstream failure
- **WHEN** the downstream client disconnects after at least one final image output has been observed upstream
- **THEN** the forward result retains the final image count
- **AND** usage recording can still proceed with image billing metadata
#### Scenario: Image stream timeout does not silently switch billing mode
- **WHEN** an image stream times out before any final image output is observed
- **THEN** the request is handled as a failed image stream
- **AND** it does not fall back to ordinary text billing semantics

View File

@ -0,0 +1,20 @@
## 1. Config and defaults
- [x] 1.1 Add image-specific stream timeout fields to gateway config.
- [x] 1.2 Register image stream timeout defaults in the config loader.
- [x] 1.3 Add config validation for image stream timeout ranges.
- [x] 1.4 Expose image stream timeout defaults in `deploy/config.example.yaml`.
## 2. Image stream runtime behavior
- [x] 2.1 Detach image stream upstream contexts from client cancellation.
- [x] 2.2 Add image-specific data interval timeout handling to `/v1/images/*` streaming.
- [x] 2.3 Add image-specific data interval timeout handling to `Responses + image_generation` streaming.
- [x] 2.4 Preserve upstream draining after downstream write failures in both image stream paths.
## 3. Tests and verification
- [x] 3.1 Add config tests for image stream timeout defaults and validation.
- [x] 3.2 Add image streaming disconnect tests for the Images API path.
- [x] 3.3 Add image streaming disconnect tests for the Responses image tool path.
- [x] 3.4 Run focused Go tests for the touched config and image service paths.

20
openspec/config.yaml Normal file
View File

@ -0,0 +1,20 @@
schema: spec-driven
# Project context (optional)
# This is shown to AI when creating artifacts.
# Add your tech stack, conventions, style guides, domain knowledge, etc.
# Example:
# context: |
# Tech stack: TypeScript, React, Node.js
# We use conventional commits
# Domain: e-commerce platform
# Per-artifact rules (optional)
# Add custom rules for specific artifacts.
# Example:
# rules:
# proposal:
# - Keep proposals under 500 words
# - Always include a "Non-goals" section
# tasks:
# - Break tasks into chunks of max 2 hours