Merge pull request #2191 from 2ue/feat/openai-image-generation-controls
完善 GPT Image 访问控制、计费、并发与流式稳定性
This commit is contained in:
commit
ad9b88f0e3
@ -47,6 +47,12 @@ type Group struct {
|
||||
MonthlyLimitUsd *float64 `json:"monthly_limit_usd,omitempty"`
|
||||
// DefaultValidityDays holds the value of the "default_validity_days" field.
|
||||
DefaultValidityDays int `json:"default_validity_days,omitempty"`
|
||||
// 是否允许该分组使用图片生成能力
|
||||
AllowImageGeneration bool `json:"allow_image_generation,omitempty"`
|
||||
// 图片生成是否使用独立倍率;false 表示共享分组有效倍率
|
||||
ImageRateIndependent bool `json:"image_rate_independent,omitempty"`
|
||||
// 图片生成独立倍率,仅 image_rate_independent=true 时生效
|
||||
ImageRateMultiplier float64 `json:"image_rate_multiplier,omitempty"`
|
||||
// ImagePrice1k holds the value of the "image_price_1k" field.
|
||||
ImagePrice1k *float64 `json:"image_price_1k,omitempty"`
|
||||
// ImagePrice2k holds the value of the "image_price_2k" field.
|
||||
@ -189,9 +195,9 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
||||
switch columns[i] {
|
||||
case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig:
|
||||
values[i] = new([]byte)
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet:
|
||||
case group.FieldIsExclusive, group.FieldAllowImageGeneration, group.FieldImageRateIndependent, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet:
|
||||
values[i] = new(sql.NullBool)
|
||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImageRateMultiplier, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder, group.FieldRpmLimit:
|
||||
values[i] = new(sql.NullInt64)
|
||||
@ -309,6 +315,24 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.DefaultValidityDays = int(value.Int64)
|
||||
}
|
||||
case group.FieldAllowImageGeneration:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field allow_image_generation", values[i])
|
||||
} else if value.Valid {
|
||||
_m.AllowImageGeneration = value.Bool
|
||||
}
|
||||
case group.FieldImageRateIndependent:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_rate_independent", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ImageRateIndependent = value.Bool
|
||||
}
|
||||
case group.FieldImageRateMultiplier:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_rate_multiplier", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ImageRateMultiplier = value.Float64
|
||||
}
|
||||
case group.FieldImagePrice1k:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_price_1k", values[i])
|
||||
@ -550,6 +574,15 @@ func (_m *Group) String() string {
|
||||
builder.WriteString("default_validity_days=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.DefaultValidityDays))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("allow_image_generation=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.AllowImageGeneration))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("image_rate_independent=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.ImageRateIndependent))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("image_rate_multiplier=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.ImageRateMultiplier))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ImagePrice1k; v != nil {
|
||||
builder.WriteString("image_price_1k=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
|
||||
@ -44,6 +44,12 @@ const (
|
||||
FieldMonthlyLimitUsd = "monthly_limit_usd"
|
||||
// FieldDefaultValidityDays holds the string denoting the default_validity_days field in the database.
|
||||
FieldDefaultValidityDays = "default_validity_days"
|
||||
// FieldAllowImageGeneration holds the string denoting the allow_image_generation field in the database.
|
||||
FieldAllowImageGeneration = "allow_image_generation"
|
||||
// FieldImageRateIndependent holds the string denoting the image_rate_independent field in the database.
|
||||
FieldImageRateIndependent = "image_rate_independent"
|
||||
// FieldImageRateMultiplier holds the string denoting the image_rate_multiplier field in the database.
|
||||
FieldImageRateMultiplier = "image_rate_multiplier"
|
||||
// FieldImagePrice1k holds the string denoting the image_price_1k field in the database.
|
||||
FieldImagePrice1k = "image_price_1k"
|
||||
// FieldImagePrice2k holds the string denoting the image_price_2k field in the database.
|
||||
@ -167,6 +173,9 @@ var Columns = []string{
|
||||
FieldWeeklyLimitUsd,
|
||||
FieldMonthlyLimitUsd,
|
||||
FieldDefaultValidityDays,
|
||||
FieldAllowImageGeneration,
|
||||
FieldImageRateIndependent,
|
||||
FieldImageRateMultiplier,
|
||||
FieldImagePrice1k,
|
||||
FieldImagePrice2k,
|
||||
FieldImagePrice4k,
|
||||
@ -239,6 +248,12 @@ var (
|
||||
SubscriptionTypeValidator func(string) error
|
||||
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
|
||||
DefaultDefaultValidityDays int
|
||||
// DefaultAllowImageGeneration holds the default value on creation for the "allow_image_generation" field.
|
||||
DefaultAllowImageGeneration bool
|
||||
// DefaultImageRateIndependent holds the default value on creation for the "image_rate_independent" field.
|
||||
DefaultImageRateIndependent bool
|
||||
// DefaultImageRateMultiplier holds the default value on creation for the "image_rate_multiplier" field.
|
||||
DefaultImageRateMultiplier float64
|
||||
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
|
||||
DefaultClaudeCodeOnly bool
|
||||
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
|
||||
@ -343,6 +358,21 @@ func ByDefaultValidityDays(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDefaultValidityDays, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAllowImageGeneration orders the results by the allow_image_generation field.
|
||||
func ByAllowImageGeneration(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldAllowImageGeneration, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByImageRateIndependent orders the results by the image_rate_independent field.
|
||||
func ByImageRateIndependent(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImageRateIndependent, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByImageRateMultiplier orders the results by the image_rate_multiplier field.
|
||||
func ByImageRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImageRateMultiplier, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByImagePrice1k orders the results by the image_price_1k field.
|
||||
func ByImagePrice1k(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImagePrice1k, opts...).ToFunc()
|
||||
|
||||
@ -125,6 +125,21 @@ func DefaultValidityDays(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldDefaultValidityDays, v))
|
||||
}
|
||||
|
||||
// AllowImageGeneration applies equality check predicate on the "allow_image_generation" field. It's identical to AllowImageGenerationEQ.
|
||||
func AllowImageGeneration(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldAllowImageGeneration, v))
|
||||
}
|
||||
|
||||
// ImageRateIndependent applies equality check predicate on the "image_rate_independent" field. It's identical to ImageRateIndependentEQ.
|
||||
func ImageRateIndependent(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImageRateIndependent, v))
|
||||
}
|
||||
|
||||
// ImageRateMultiplier applies equality check predicate on the "image_rate_multiplier" field. It's identical to ImageRateMultiplierEQ.
|
||||
func ImageRateMultiplier(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImageRateMultiplier, v))
|
||||
}
|
||||
|
||||
// ImagePrice1k applies equality check predicate on the "image_price_1k" field. It's identical to ImagePrice1kEQ.
|
||||
func ImagePrice1k(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImagePrice1k, v))
|
||||
@ -900,6 +915,66 @@ func DefaultValidityDaysLTE(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldDefaultValidityDays, v))
|
||||
}
|
||||
|
||||
// AllowImageGenerationEQ applies the EQ predicate on the "allow_image_generation" field.
|
||||
func AllowImageGenerationEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldAllowImageGeneration, v))
|
||||
}
|
||||
|
||||
// AllowImageGenerationNEQ applies the NEQ predicate on the "allow_image_generation" field.
|
||||
func AllowImageGenerationNEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldAllowImageGeneration, v))
|
||||
}
|
||||
|
||||
// ImageRateIndependentEQ applies the EQ predicate on the "image_rate_independent" field.
|
||||
func ImageRateIndependentEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImageRateIndependent, v))
|
||||
}
|
||||
|
||||
// ImageRateIndependentNEQ applies the NEQ predicate on the "image_rate_independent" field.
|
||||
func ImageRateIndependentNEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldImageRateIndependent, v))
|
||||
}
|
||||
|
||||
// ImageRateMultiplierEQ applies the EQ predicate on the "image_rate_multiplier" field.
|
||||
func ImageRateMultiplierEQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImageRateMultiplier, v))
|
||||
}
|
||||
|
||||
// ImageRateMultiplierNEQ applies the NEQ predicate on the "image_rate_multiplier" field.
|
||||
func ImageRateMultiplierNEQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldImageRateMultiplier, v))
|
||||
}
|
||||
|
||||
// ImageRateMultiplierIn applies the In predicate on the "image_rate_multiplier" field.
|
||||
func ImageRateMultiplierIn(vs ...float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldIn(FieldImageRateMultiplier, vs...))
|
||||
}
|
||||
|
||||
// ImageRateMultiplierNotIn applies the NotIn predicate on the "image_rate_multiplier" field.
|
||||
func ImageRateMultiplierNotIn(vs ...float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNotIn(FieldImageRateMultiplier, vs...))
|
||||
}
|
||||
|
||||
// ImageRateMultiplierGT applies the GT predicate on the "image_rate_multiplier" field.
|
||||
func ImageRateMultiplierGT(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGT(FieldImageRateMultiplier, v))
|
||||
}
|
||||
|
||||
// ImageRateMultiplierGTE applies the GTE predicate on the "image_rate_multiplier" field.
|
||||
func ImageRateMultiplierGTE(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGTE(FieldImageRateMultiplier, v))
|
||||
}
|
||||
|
||||
// ImageRateMultiplierLT applies the LT predicate on the "image_rate_multiplier" field.
|
||||
func ImageRateMultiplierLT(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLT(FieldImageRateMultiplier, v))
|
||||
}
|
||||
|
||||
// ImageRateMultiplierLTE applies the LTE predicate on the "image_rate_multiplier" field.
|
||||
func ImageRateMultiplierLTE(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldImageRateMultiplier, v))
|
||||
}
|
||||
|
||||
// ImagePrice1kEQ applies the EQ predicate on the "image_price_1k" field.
|
||||
func ImagePrice1kEQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImagePrice1k, v))
|
||||
|
||||
@ -217,6 +217,48 @@ func (_c *GroupCreate) SetNillableDefaultValidityDays(v *int) *GroupCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetAllowImageGeneration sets the "allow_image_generation" field.
|
||||
func (_c *GroupCreate) SetAllowImageGeneration(v bool) *GroupCreate {
|
||||
_c.mutation.SetAllowImageGeneration(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableAllowImageGeneration sets the "allow_image_generation" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableAllowImageGeneration(v *bool) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetAllowImageGeneration(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImageRateIndependent sets the "image_rate_independent" field.
|
||||
func (_c *GroupCreate) SetImageRateIndependent(v bool) *GroupCreate {
|
||||
_c.mutation.SetImageRateIndependent(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableImageRateIndependent sets the "image_rate_independent" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableImageRateIndependent(v *bool) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetImageRateIndependent(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
|
||||
func (_c *GroupCreate) SetImageRateMultiplier(v float64) *GroupCreate {
|
||||
_c.mutation.SetImageRateMultiplier(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableImageRateMultiplier(v *float64) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetImageRateMultiplier(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (_c *GroupCreate) SetImagePrice1k(v float64) *GroupCreate {
|
||||
_c.mutation.SetImagePrice1k(v)
|
||||
@ -604,6 +646,18 @@ func (_c *GroupCreate) defaults() error {
|
||||
v := group.DefaultDefaultValidityDays
|
||||
_c.mutation.SetDefaultValidityDays(v)
|
||||
}
|
||||
if _, ok := _c.mutation.AllowImageGeneration(); !ok {
|
||||
v := group.DefaultAllowImageGeneration
|
||||
_c.mutation.SetAllowImageGeneration(v)
|
||||
}
|
||||
if _, ok := _c.mutation.ImageRateIndependent(); !ok {
|
||||
v := group.DefaultImageRateIndependent
|
||||
_c.mutation.SetImageRateIndependent(v)
|
||||
}
|
||||
if _, ok := _c.mutation.ImageRateMultiplier(); !ok {
|
||||
v := group.DefaultImageRateMultiplier
|
||||
_c.mutation.SetImageRateMultiplier(v)
|
||||
}
|
||||
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
||||
v := group.DefaultClaudeCodeOnly
|
||||
_c.mutation.SetClaudeCodeOnly(v)
|
||||
@ -700,6 +754,15 @@ func (_c *GroupCreate) check() error {
|
||||
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
|
||||
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.AllowImageGeneration(); !ok {
|
||||
return &ValidationError{Name: "allow_image_generation", err: errors.New(`ent: missing required field "Group.allow_image_generation"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.ImageRateIndependent(); !ok {
|
||||
return &ValidationError{Name: "image_rate_independent", err: errors.New(`ent: missing required field "Group.image_rate_independent"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.ImageRateMultiplier(); !ok {
|
||||
return &ValidationError{Name: "image_rate_multiplier", err: errors.New(`ent: missing required field "Group.image_rate_multiplier"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
||||
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
|
||||
}
|
||||
@ -821,6 +884,18 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(group.FieldDefaultValidityDays, field.TypeInt, value)
|
||||
_node.DefaultValidityDays = value
|
||||
}
|
||||
if value, ok := _c.mutation.AllowImageGeneration(); ok {
|
||||
_spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value)
|
||||
_node.AllowImageGeneration = value
|
||||
}
|
||||
if value, ok := _c.mutation.ImageRateIndependent(); ok {
|
||||
_spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value)
|
||||
_node.ImageRateIndependent = value
|
||||
}
|
||||
if value, ok := _c.mutation.ImageRateMultiplier(); ok {
|
||||
_spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
|
||||
_node.ImageRateMultiplier = value
|
||||
}
|
||||
if value, ok := _c.mutation.ImagePrice1k(); ok {
|
||||
_spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value)
|
||||
_node.ImagePrice1k = &value
|
||||
@ -1261,6 +1336,48 @@ func (u *GroupUpsert) AddDefaultValidityDays(v int) *GroupUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetAllowImageGeneration sets the "allow_image_generation" field.
|
||||
func (u *GroupUpsert) SetAllowImageGeneration(v bool) *GroupUpsert {
|
||||
u.Set(group.FieldAllowImageGeneration, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateAllowImageGeneration() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldAllowImageGeneration)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImageRateIndependent sets the "image_rate_independent" field.
|
||||
func (u *GroupUpsert) SetImageRateIndependent(v bool) *GroupUpsert {
|
||||
u.Set(group.FieldImageRateIndependent, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateImageRateIndependent() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldImageRateIndependent)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
|
||||
func (u *GroupUpsert) SetImageRateMultiplier(v float64) *GroupUpsert {
|
||||
u.Set(group.FieldImageRateMultiplier, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateImageRateMultiplier() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldImageRateMultiplier)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddImageRateMultiplier adds v to the "image_rate_multiplier" field.
|
||||
func (u *GroupUpsert) AddImageRateMultiplier(v float64) *GroupUpsert {
|
||||
u.Add(group.FieldImageRateMultiplier, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (u *GroupUpsert) SetImagePrice1k(v float64) *GroupUpsert {
|
||||
u.Set(group.FieldImagePrice1k, v)
|
||||
@ -1840,6 +1957,55 @@ func (u *GroupUpsertOne) UpdateDefaultValidityDays() *GroupUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetAllowImageGeneration sets the "allow_image_generation" field.
|
||||
func (u *GroupUpsertOne) SetAllowImageGeneration(v bool) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetAllowImageGeneration(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateAllowImageGeneration() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateAllowImageGeneration()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageRateIndependent sets the "image_rate_independent" field.
|
||||
func (u *GroupUpsertOne) SetImageRateIndependent(v bool) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetImageRateIndependent(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateImageRateIndependent() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateImageRateIndependent()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
|
||||
func (u *GroupUpsertOne) SetImageRateMultiplier(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetImageRateMultiplier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddImageRateMultiplier adds v to the "image_rate_multiplier" field.
|
||||
func (u *GroupUpsertOne) AddImageRateMultiplier(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddImageRateMultiplier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateImageRateMultiplier() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateImageRateMultiplier()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (u *GroupUpsertOne) SetImagePrice1k(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
@ -2632,6 +2798,55 @@ func (u *GroupUpsertBulk) UpdateDefaultValidityDays() *GroupUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetAllowImageGeneration sets the "allow_image_generation" field.
|
||||
func (u *GroupUpsertBulk) SetAllowImageGeneration(v bool) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetAllowImageGeneration(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateAllowImageGeneration() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateAllowImageGeneration()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageRateIndependent sets the "image_rate_independent" field.
|
||||
func (u *GroupUpsertBulk) SetImageRateIndependent(v bool) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetImageRateIndependent(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateImageRateIndependent() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateImageRateIndependent()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
|
||||
func (u *GroupUpsertBulk) SetImageRateMultiplier(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetImageRateMultiplier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddImageRateMultiplier adds v to the "image_rate_multiplier" field.
|
||||
func (u *GroupUpsertBulk) AddImageRateMultiplier(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddImageRateMultiplier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateImageRateMultiplier() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateImageRateMultiplier()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (u *GroupUpsertBulk) SetImagePrice1k(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
|
||||
@ -275,6 +275,55 @@ func (_u *GroupUpdate) AddDefaultValidityDays(v int) *GroupUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAllowImageGeneration sets the "allow_image_generation" field.
|
||||
func (_u *GroupUpdate) SetAllowImageGeneration(v bool) *GroupUpdate {
|
||||
_u.mutation.SetAllowImageGeneration(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAllowImageGeneration sets the "allow_image_generation" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableAllowImageGeneration(v *bool) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetAllowImageGeneration(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageRateIndependent sets the "image_rate_independent" field.
|
||||
func (_u *GroupUpdate) SetImageRateIndependent(v bool) *GroupUpdate {
|
||||
_u.mutation.SetImageRateIndependent(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageRateIndependent sets the "image_rate_independent" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableImageRateIndependent(v *bool) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetImageRateIndependent(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
|
||||
func (_u *GroupUpdate) SetImageRateMultiplier(v float64) *GroupUpdate {
|
||||
_u.mutation.ResetImageRateMultiplier()
|
||||
_u.mutation.SetImageRateMultiplier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableImageRateMultiplier(v *float64) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetImageRateMultiplier(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddImageRateMultiplier adds value to the "image_rate_multiplier" field.
|
||||
func (_u *GroupUpdate) AddImageRateMultiplier(v float64) *GroupUpdate {
|
||||
_u.mutation.AddImageRateMultiplier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (_u *GroupUpdate) SetImagePrice1k(v float64) *GroupUpdate {
|
||||
_u.mutation.ResetImagePrice1k()
|
||||
@ -962,6 +1011,18 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.AddedDefaultValidityDays(); ok {
|
||||
_spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AllowImageGeneration(); ok {
|
||||
_spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageRateIndependent(); ok {
|
||||
_spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageRateMultiplier(); ok {
|
||||
_spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedImageRateMultiplier(); ok {
|
||||
_spec.AddField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ImagePrice1k(); ok {
|
||||
_spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value)
|
||||
}
|
||||
@ -1610,6 +1671,55 @@ func (_u *GroupUpdateOne) AddDefaultValidityDays(v int) *GroupUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAllowImageGeneration sets the "allow_image_generation" field.
|
||||
func (_u *GroupUpdateOne) SetAllowImageGeneration(v bool) *GroupUpdateOne {
|
||||
_u.mutation.SetAllowImageGeneration(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAllowImageGeneration sets the "allow_image_generation" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableAllowImageGeneration(v *bool) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetAllowImageGeneration(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageRateIndependent sets the "image_rate_independent" field.
|
||||
func (_u *GroupUpdateOne) SetImageRateIndependent(v bool) *GroupUpdateOne {
|
||||
_u.mutation.SetImageRateIndependent(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageRateIndependent sets the "image_rate_independent" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableImageRateIndependent(v *bool) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetImageRateIndependent(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
|
||||
func (_u *GroupUpdateOne) SetImageRateMultiplier(v float64) *GroupUpdateOne {
|
||||
_u.mutation.ResetImageRateMultiplier()
|
||||
_u.mutation.SetImageRateMultiplier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableImageRateMultiplier(v *float64) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetImageRateMultiplier(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddImageRateMultiplier adds value to the "image_rate_multiplier" field.
|
||||
func (_u *GroupUpdateOne) AddImageRateMultiplier(v float64) *GroupUpdateOne {
|
||||
_u.mutation.AddImageRateMultiplier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (_u *GroupUpdateOne) SetImagePrice1k(v float64) *GroupUpdateOne {
|
||||
_u.mutation.ResetImagePrice1k()
|
||||
@ -2327,6 +2437,18 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
||||
if value, ok := _u.mutation.AddedDefaultValidityDays(); ok {
|
||||
_spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AllowImageGeneration(); ok {
|
||||
_spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageRateIndependent(); ok {
|
||||
_spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageRateMultiplier(); ok {
|
||||
_spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedImageRateMultiplier(); ok {
|
||||
_spec.AddField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ImagePrice1k(); ok {
|
||||
_spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value)
|
||||
}
|
||||
|
||||
@ -638,6 +638,9 @@ var (
|
||||
{Name: "weekly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "monthly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "default_validity_days", Type: field.TypeInt, Default: 30},
|
||||
{Name: "allow_image_generation", Type: field.TypeBool, Default: false},
|
||||
{Name: "image_rate_independent", Type: field.TypeBool, Default: false},
|
||||
{Name: "image_rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
|
||||
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
@ -690,7 +693,7 @@ var (
|
||||
{
|
||||
Name: "group_sort_order",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{GroupsColumns[25]},
|
||||
Columns: []*schema.Column{GroupsColumns[28]},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -14764,6 +14764,10 @@ type GroupMutation struct {
|
||||
addmonthly_limit_usd *float64
|
||||
default_validity_days *int
|
||||
adddefault_validity_days *int
|
||||
allow_image_generation *bool
|
||||
image_rate_independent *bool
|
||||
image_rate_multiplier *float64
|
||||
addimage_rate_multiplier *float64
|
||||
image_price_1k *float64
|
||||
addimage_price_1k *float64
|
||||
image_price_2k *float64
|
||||
@ -15583,6 +15587,134 @@ func (m *GroupMutation) ResetDefaultValidityDays() {
|
||||
m.adddefault_validity_days = nil
|
||||
}
|
||||
|
||||
// SetAllowImageGeneration sets the "allow_image_generation" field.
|
||||
func (m *GroupMutation) SetAllowImageGeneration(b bool) {
|
||||
m.allow_image_generation = &b
|
||||
}
|
||||
|
||||
// AllowImageGeneration returns the value of the "allow_image_generation" field in the mutation.
|
||||
func (m *GroupMutation) AllowImageGeneration() (r bool, exists bool) {
|
||||
v := m.allow_image_generation
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldAllowImageGeneration returns the old "allow_image_generation" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldAllowImageGeneration(ctx context.Context) (v bool, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldAllowImageGeneration is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldAllowImageGeneration requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldAllowImageGeneration: %w", err)
|
||||
}
|
||||
return oldValue.AllowImageGeneration, nil
|
||||
}
|
||||
|
||||
// ResetAllowImageGeneration resets all changes to the "allow_image_generation" field.
|
||||
func (m *GroupMutation) ResetAllowImageGeneration() {
|
||||
m.allow_image_generation = nil
|
||||
}
|
||||
|
||||
// SetImageRateIndependent sets the "image_rate_independent" field.
|
||||
func (m *GroupMutation) SetImageRateIndependent(b bool) {
|
||||
m.image_rate_independent = &b
|
||||
}
|
||||
|
||||
// ImageRateIndependent returns the value of the "image_rate_independent" field in the mutation.
|
||||
func (m *GroupMutation) ImageRateIndependent() (r bool, exists bool) {
|
||||
v := m.image_rate_independent
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldImageRateIndependent returns the old "image_rate_independent" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldImageRateIndependent(ctx context.Context) (v bool, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldImageRateIndependent is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldImageRateIndependent requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldImageRateIndependent: %w", err)
|
||||
}
|
||||
return oldValue.ImageRateIndependent, nil
|
||||
}
|
||||
|
||||
// ResetImageRateIndependent resets all changes to the "image_rate_independent" field.
|
||||
func (m *GroupMutation) ResetImageRateIndependent() {
|
||||
m.image_rate_independent = nil
|
||||
}
|
||||
|
||||
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
|
||||
func (m *GroupMutation) SetImageRateMultiplier(f float64) {
|
||||
m.image_rate_multiplier = &f
|
||||
m.addimage_rate_multiplier = nil
|
||||
}
|
||||
|
||||
// ImageRateMultiplier returns the value of the "image_rate_multiplier" field in the mutation.
|
||||
func (m *GroupMutation) ImageRateMultiplier() (r float64, exists bool) {
|
||||
v := m.image_rate_multiplier
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldImageRateMultiplier returns the old "image_rate_multiplier" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldImageRateMultiplier(ctx context.Context) (v float64, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldImageRateMultiplier is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldImageRateMultiplier requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldImageRateMultiplier: %w", err)
|
||||
}
|
||||
return oldValue.ImageRateMultiplier, nil
|
||||
}
|
||||
|
||||
// AddImageRateMultiplier adds f to the "image_rate_multiplier" field.
|
||||
func (m *GroupMutation) AddImageRateMultiplier(f float64) {
|
||||
if m.addimage_rate_multiplier != nil {
|
||||
*m.addimage_rate_multiplier += f
|
||||
} else {
|
||||
m.addimage_rate_multiplier = &f
|
||||
}
|
||||
}
|
||||
|
||||
// AddedImageRateMultiplier returns the value that was added to the "image_rate_multiplier" field in this mutation.
|
||||
func (m *GroupMutation) AddedImageRateMultiplier() (r float64, exists bool) {
|
||||
v := m.addimage_rate_multiplier
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// ResetImageRateMultiplier resets all changes to the "image_rate_multiplier" field.
|
||||
func (m *GroupMutation) ResetImageRateMultiplier() {
|
||||
m.image_rate_multiplier = nil
|
||||
m.addimage_rate_multiplier = nil
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (m *GroupMutation) SetImagePrice1k(f float64) {
|
||||
m.image_price_1k = &f
|
||||
@ -16791,7 +16923,7 @@ func (m *GroupMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *GroupMutation) Fields() []string {
|
||||
fields := make([]string, 0, 31)
|
||||
fields := make([]string, 0, 34)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, group.FieldCreatedAt)
|
||||
}
|
||||
@ -16834,6 +16966,15 @@ func (m *GroupMutation) Fields() []string {
|
||||
if m.default_validity_days != nil {
|
||||
fields = append(fields, group.FieldDefaultValidityDays)
|
||||
}
|
||||
if m.allow_image_generation != nil {
|
||||
fields = append(fields, group.FieldAllowImageGeneration)
|
||||
}
|
||||
if m.image_rate_independent != nil {
|
||||
fields = append(fields, group.FieldImageRateIndependent)
|
||||
}
|
||||
if m.image_rate_multiplier != nil {
|
||||
fields = append(fields, group.FieldImageRateMultiplier)
|
||||
}
|
||||
if m.image_price_1k != nil {
|
||||
fields = append(fields, group.FieldImagePrice1k)
|
||||
}
|
||||
@ -16921,6 +17062,12 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.MonthlyLimitUsd()
|
||||
case group.FieldDefaultValidityDays:
|
||||
return m.DefaultValidityDays()
|
||||
case group.FieldAllowImageGeneration:
|
||||
return m.AllowImageGeneration()
|
||||
case group.FieldImageRateIndependent:
|
||||
return m.ImageRateIndependent()
|
||||
case group.FieldImageRateMultiplier:
|
||||
return m.ImageRateMultiplier()
|
||||
case group.FieldImagePrice1k:
|
||||
return m.ImagePrice1k()
|
||||
case group.FieldImagePrice2k:
|
||||
@ -16992,6 +17139,12 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
||||
return m.OldMonthlyLimitUsd(ctx)
|
||||
case group.FieldDefaultValidityDays:
|
||||
return m.OldDefaultValidityDays(ctx)
|
||||
case group.FieldAllowImageGeneration:
|
||||
return m.OldAllowImageGeneration(ctx)
|
||||
case group.FieldImageRateIndependent:
|
||||
return m.OldImageRateIndependent(ctx)
|
||||
case group.FieldImageRateMultiplier:
|
||||
return m.OldImageRateMultiplier(ctx)
|
||||
case group.FieldImagePrice1k:
|
||||
return m.OldImagePrice1k(ctx)
|
||||
case group.FieldImagePrice2k:
|
||||
@ -17133,6 +17286,27 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetDefaultValidityDays(v)
|
||||
return nil
|
||||
case group.FieldAllowImageGeneration:
|
||||
v, ok := value.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetAllowImageGeneration(v)
|
||||
return nil
|
||||
case group.FieldImageRateIndependent:
|
||||
v, ok := value.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetImageRateIndependent(v)
|
||||
return nil
|
||||
case group.FieldImageRateMultiplier:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetImageRateMultiplier(v)
|
||||
return nil
|
||||
case group.FieldImagePrice1k:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
@ -17275,6 +17449,9 @@ func (m *GroupMutation) AddedFields() []string {
|
||||
if m.adddefault_validity_days != nil {
|
||||
fields = append(fields, group.FieldDefaultValidityDays)
|
||||
}
|
||||
if m.addimage_rate_multiplier != nil {
|
||||
fields = append(fields, group.FieldImageRateMultiplier)
|
||||
}
|
||||
if m.addimage_price_1k != nil {
|
||||
fields = append(fields, group.FieldImagePrice1k)
|
||||
}
|
||||
@ -17314,6 +17491,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
|
||||
return m.AddedMonthlyLimitUsd()
|
||||
case group.FieldDefaultValidityDays:
|
||||
return m.AddedDefaultValidityDays()
|
||||
case group.FieldImageRateMultiplier:
|
||||
return m.AddedImageRateMultiplier()
|
||||
case group.FieldImagePrice1k:
|
||||
return m.AddedImagePrice1k()
|
||||
case group.FieldImagePrice2k:
|
||||
@ -17372,6 +17551,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
|
||||
}
|
||||
m.AddDefaultValidityDays(v)
|
||||
return nil
|
||||
case group.FieldImageRateMultiplier:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.AddImageRateMultiplier(v)
|
||||
return nil
|
||||
case group.FieldImagePrice1k:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
@ -17559,6 +17745,15 @@ func (m *GroupMutation) ResetField(name string) error {
|
||||
case group.FieldDefaultValidityDays:
|
||||
m.ResetDefaultValidityDays()
|
||||
return nil
|
||||
case group.FieldAllowImageGeneration:
|
||||
m.ResetAllowImageGeneration()
|
||||
return nil
|
||||
case group.FieldImageRateIndependent:
|
||||
m.ResetImageRateIndependent()
|
||||
return nil
|
||||
case group.FieldImageRateMultiplier:
|
||||
m.ResetImageRateMultiplier()
|
||||
return nil
|
||||
case group.FieldImagePrice1k:
|
||||
m.ResetImagePrice1k()
|
||||
return nil
|
||||
|
||||
@ -803,50 +803,62 @@ func init() {
|
||||
groupDescDefaultValidityDays := groupFields[10].Descriptor()
|
||||
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
|
||||
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
|
||||
// groupDescAllowImageGeneration is the schema descriptor for allow_image_generation field.
|
||||
groupDescAllowImageGeneration := groupFields[11].Descriptor()
|
||||
// group.DefaultAllowImageGeneration holds the default value on creation for the allow_image_generation field.
|
||||
group.DefaultAllowImageGeneration = groupDescAllowImageGeneration.Default.(bool)
|
||||
// groupDescImageRateIndependent is the schema descriptor for image_rate_independent field.
|
||||
groupDescImageRateIndependent := groupFields[12].Descriptor()
|
||||
// group.DefaultImageRateIndependent holds the default value on creation for the image_rate_independent field.
|
||||
group.DefaultImageRateIndependent = groupDescImageRateIndependent.Default.(bool)
|
||||
// groupDescImageRateMultiplier is the schema descriptor for image_rate_multiplier field.
|
||||
groupDescImageRateMultiplier := groupFields[13].Descriptor()
|
||||
// group.DefaultImageRateMultiplier holds the default value on creation for the image_rate_multiplier field.
|
||||
group.DefaultImageRateMultiplier = groupDescImageRateMultiplier.Default.(float64)
|
||||
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
|
||||
groupDescClaudeCodeOnly := groupFields[14].Descriptor()
|
||||
groupDescClaudeCodeOnly := groupFields[17].Descriptor()
|
||||
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
|
||||
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
|
||||
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
|
||||
groupDescModelRoutingEnabled := groupFields[18].Descriptor()
|
||||
groupDescModelRoutingEnabled := groupFields[21].Descriptor()
|
||||
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
|
||||
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
|
||||
// groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field.
|
||||
groupDescMcpXMLInject := groupFields[19].Descriptor()
|
||||
groupDescMcpXMLInject := groupFields[22].Descriptor()
|
||||
// group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field.
|
||||
group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool)
|
||||
// groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field.
|
||||
groupDescSupportedModelScopes := groupFields[20].Descriptor()
|
||||
groupDescSupportedModelScopes := groupFields[23].Descriptor()
|
||||
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
|
||||
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
|
||||
// groupDescSortOrder is the schema descriptor for sort_order field.
|
||||
groupDescSortOrder := groupFields[21].Descriptor()
|
||||
groupDescSortOrder := groupFields[24].Descriptor()
|
||||
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
|
||||
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
|
||||
// groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field.
|
||||
groupDescAllowMessagesDispatch := groupFields[22].Descriptor()
|
||||
groupDescAllowMessagesDispatch := groupFields[25].Descriptor()
|
||||
// group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field.
|
||||
group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool)
|
||||
// groupDescRequireOauthOnly is the schema descriptor for require_oauth_only field.
|
||||
groupDescRequireOauthOnly := groupFields[23].Descriptor()
|
||||
groupDescRequireOauthOnly := groupFields[26].Descriptor()
|
||||
// group.DefaultRequireOauthOnly holds the default value on creation for the require_oauth_only field.
|
||||
group.DefaultRequireOauthOnly = groupDescRequireOauthOnly.Default.(bool)
|
||||
// groupDescRequirePrivacySet is the schema descriptor for require_privacy_set field.
|
||||
groupDescRequirePrivacySet := groupFields[24].Descriptor()
|
||||
groupDescRequirePrivacySet := groupFields[27].Descriptor()
|
||||
// group.DefaultRequirePrivacySet holds the default value on creation for the require_privacy_set field.
|
||||
group.DefaultRequirePrivacySet = groupDescRequirePrivacySet.Default.(bool)
|
||||
// groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field.
|
||||
groupDescDefaultMappedModel := groupFields[25].Descriptor()
|
||||
groupDescDefaultMappedModel := groupFields[28].Descriptor()
|
||||
// group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field.
|
||||
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
|
||||
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
|
||||
// groupDescMessagesDispatchModelConfig is the schema descriptor for messages_dispatch_model_config field.
|
||||
groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor()
|
||||
groupDescMessagesDispatchModelConfig := groupFields[29].Descriptor()
|
||||
// group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field.
|
||||
group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig)
|
||||
// groupDescRpmLimit is the schema descriptor for rpm_limit field.
|
||||
groupDescRpmLimit := groupFields[27].Descriptor()
|
||||
groupDescRpmLimit := groupFields[30].Descriptor()
|
||||
// group.DefaultRpmLimit holds the default value on creation for the rpm_limit field.
|
||||
group.DefaultRpmLimit = groupDescRpmLimit.Default.(int)
|
||||
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
||||
|
||||
@ -74,6 +74,16 @@ func (Group) Fields() []ent.Field {
|
||||
Default(30),
|
||||
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用)
|
||||
field.Bool("allow_image_generation").
|
||||
Default(false).
|
||||
Comment("是否允许该分组使用图片生成能力"),
|
||||
field.Bool("image_rate_independent").
|
||||
Default(false).
|
||||
Comment("图片生成是否使用独立倍率;false 表示共享分组有效倍率"),
|
||||
field.Float("image_rate_multiplier").
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}).
|
||||
Default(1.0).
|
||||
Comment("图片生成独立倍率,仅 image_rate_independent=true 时生效"),
|
||||
field.Float("image_price_1k").
|
||||
Optional().
|
||||
Nillable().
|
||||
|
||||
@ -575,6 +575,24 @@ type ConcurrencyConfig struct {
|
||||
PingInterval int `mapstructure:"ping_interval"`
|
||||
}
|
||||
|
||||
type ImageConcurrencyConfig struct {
|
||||
// Enabled: 是否启用图片生成独立并发限制,默认关闭以保持现有行为
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
// MaxConcurrentRequests: 当前进程允许同时处理的图片生成请求数,0表示不限制
|
||||
MaxConcurrentRequests int `mapstructure:"max_concurrent_requests"`
|
||||
// OverflowMode: 图片并发达到上限后的处理方式:reject/wait
|
||||
OverflowMode string `mapstructure:"overflow_mode"`
|
||||
// WaitTimeoutSeconds: overflow_mode=wait 时等待图片并发槽位的超时时间(秒)
|
||||
WaitTimeoutSeconds int `mapstructure:"wait_timeout_seconds"`
|
||||
// MaxWaitingRequests: overflow_mode=wait 时当前进程允许排队等待的图片请求数
|
||||
MaxWaitingRequests int `mapstructure:"max_waiting_requests"`
|
||||
}
|
||||
|
||||
const (
|
||||
ImageConcurrencyOverflowModeReject = "reject"
|
||||
ImageConcurrencyOverflowModeWait = "wait"
|
||||
)
|
||||
|
||||
// GatewayConfig API网关相关配置
|
||||
type GatewayConfig struct {
|
||||
// 等待上游响应头的超时时间(秒),0表示无超时
|
||||
@ -604,6 +622,8 @@ type GatewayConfig struct {
|
||||
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
|
||||
// OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP)
|
||||
OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"`
|
||||
// ImageConcurrency: 图片生成独立并发限制配置(默认关闭)
|
||||
ImageConcurrency ImageConcurrencyConfig `mapstructure:"image_concurrency"`
|
||||
|
||||
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
|
||||
// MaxIdleConns: 所有主机的最大空闲连接总数
|
||||
@ -635,6 +655,10 @@ type GatewayConfig struct {
|
||||
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
|
||||
// StreamKeepaliveInterval: 流式 keepalive 间隔(秒),0表示禁用
|
||||
StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"`
|
||||
// ImageStreamDataIntervalTimeout: 图片流数据间隔超时(秒),0表示禁用
|
||||
ImageStreamDataIntervalTimeout int `mapstructure:"image_stream_data_interval_timeout"`
|
||||
// ImageStreamKeepaliveInterval: 图片流式 keepalive 间隔(秒),0表示禁用
|
||||
ImageStreamKeepaliveInterval int `mapstructure:"image_stream_keepalive_interval"`
|
||||
// MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值)
|
||||
MaxLineSize int `mapstructure:"max_line_size"`
|
||||
|
||||
@ -1672,6 +1696,11 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7)
|
||||
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8)
|
||||
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5)
|
||||
viper.SetDefault("gateway.image_concurrency.enabled", false)
|
||||
viper.SetDefault("gateway.image_concurrency.max_concurrent_requests", 0)
|
||||
viper.SetDefault("gateway.image_concurrency.overflow_mode", ImageConcurrencyOverflowModeReject)
|
||||
viper.SetDefault("gateway.image_concurrency.wait_timeout_seconds", 30)
|
||||
viper.SetDefault("gateway.image_concurrency.max_waiting_requests", 100)
|
||||
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
||||
viper.SetDefault("gateway.antigravity_extra_retries", 10)
|
||||
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
|
||||
@ -1689,6 +1718,8 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
||||
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.image_stream_data_interval_timeout", 900)
|
||||
viper.SetDefault("gateway.image_stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
||||
@ -2239,6 +2270,21 @@ func (c *Config) Validate() error {
|
||||
ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy)
|
||||
}
|
||||
}
|
||||
if c.Gateway.ImageConcurrency.MaxConcurrentRequests < 0 {
|
||||
return fmt.Errorf("gateway.image_concurrency.max_concurrent_requests must be non-negative")
|
||||
}
|
||||
switch strings.TrimSpace(c.Gateway.ImageConcurrency.OverflowMode) {
|
||||
case "", ImageConcurrencyOverflowModeReject, ImageConcurrencyOverflowModeWait:
|
||||
default:
|
||||
return fmt.Errorf("gateway.image_concurrency.overflow_mode must be one of: %s/%s",
|
||||
ImageConcurrencyOverflowModeReject, ImageConcurrencyOverflowModeWait)
|
||||
}
|
||||
if c.Gateway.ImageConcurrency.WaitTimeoutSeconds < 0 {
|
||||
return fmt.Errorf("gateway.image_concurrency.wait_timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Gateway.ImageConcurrency.MaxWaitingRequests < 0 {
|
||||
return fmt.Errorf("gateway.image_concurrency.max_waiting_requests must be non-negative")
|
||||
}
|
||||
if c.Gateway.MaxIdleConns <= 0 {
|
||||
return fmt.Errorf("gateway.max_idle_conns must be positive")
|
||||
}
|
||||
@ -2277,6 +2323,20 @@ func (c *Config) Validate() error {
|
||||
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
|
||||
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
|
||||
}
|
||||
if c.Gateway.ImageStreamDataIntervalTimeout < 0 {
|
||||
return fmt.Errorf("gateway.image_stream_data_interval_timeout must be non-negative")
|
||||
}
|
||||
if c.Gateway.ImageStreamDataIntervalTimeout != 0 &&
|
||||
(c.Gateway.ImageStreamDataIntervalTimeout < 60 || c.Gateway.ImageStreamDataIntervalTimeout > 1800) {
|
||||
return fmt.Errorf("gateway.image_stream_data_interval_timeout must be 0 or between 60-1800 seconds")
|
||||
}
|
||||
if c.Gateway.ImageStreamKeepaliveInterval < 0 {
|
||||
return fmt.Errorf("gateway.image_stream_keepalive_interval must be non-negative")
|
||||
}
|
||||
if c.Gateway.ImageStreamKeepaliveInterval != 0 &&
|
||||
(c.Gateway.ImageStreamKeepaliveInterval < 5 || c.Gateway.ImageStreamKeepaliveInterval > 60) {
|
||||
return fmt.Errorf("gateway.image_stream_keepalive_interval must be 0 or between 5-60 seconds")
|
||||
}
|
||||
// 兼容旧键 sticky_previous_response_ttl_seconds
|
||||
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
|
||||
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
|
||||
|
||||
@ -1282,6 +1282,46 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = -1 },
|
||||
wantErr: "gateway.stream_data_interval_timeout must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "gateway image stream keepalive range",
|
||||
mutate: func(c *Config) { c.Gateway.ImageStreamKeepaliveInterval = 4 },
|
||||
wantErr: "gateway.image_stream_keepalive_interval",
|
||||
},
|
||||
{
|
||||
name: "gateway image stream keepalive negative",
|
||||
mutate: func(c *Config) { c.Gateway.ImageStreamKeepaliveInterval = -1 },
|
||||
wantErr: "gateway.image_stream_keepalive_interval must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "gateway image stream data interval range",
|
||||
mutate: func(c *Config) { c.Gateway.ImageStreamDataIntervalTimeout = 30 },
|
||||
wantErr: "gateway.image_stream_data_interval_timeout",
|
||||
},
|
||||
{
|
||||
name: "gateway image stream data interval negative",
|
||||
mutate: func(c *Config) { c.Gateway.ImageStreamDataIntervalTimeout = -1 },
|
||||
wantErr: "gateway.image_stream_data_interval_timeout must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "gateway image concurrency max negative",
|
||||
mutate: func(c *Config) { c.Gateway.ImageConcurrency.MaxConcurrentRequests = -1 },
|
||||
wantErr: "gateway.image_concurrency.max_concurrent_requests must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "gateway image concurrency overflow mode invalid",
|
||||
mutate: func(c *Config) { c.Gateway.ImageConcurrency.OverflowMode = "queue" },
|
||||
wantErr: "gateway.image_concurrency.overflow_mode",
|
||||
},
|
||||
{
|
||||
name: "gateway image concurrency wait timeout negative",
|
||||
mutate: func(c *Config) { c.Gateway.ImageConcurrency.WaitTimeoutSeconds = -1 },
|
||||
wantErr: "gateway.image_concurrency.wait_timeout_seconds must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "gateway image concurrency max waiting negative",
|
||||
mutate: func(c *Config) { c.Gateway.ImageConcurrency.MaxWaitingRequests = -1 },
|
||||
wantErr: "gateway.image_concurrency.max_waiting_requests must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "gateway max line size",
|
||||
mutate: func(c *Config) { c.Gateway.MaxLineSize = 1024 },
|
||||
@ -1754,3 +1794,41 @@ func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
|
||||
t.Fatalf("auto_scale_cooldown_seconds = %d, want 10", cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_DefaultGatewayImageStreamConfig(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
if cfg.Gateway.StreamDataIntervalTimeout != 180 {
|
||||
t.Fatalf("stream_data_interval_timeout = %d, want 180", cfg.Gateway.StreamDataIntervalTimeout)
|
||||
}
|
||||
if cfg.Gateway.StreamKeepaliveInterval != 10 {
|
||||
t.Fatalf("stream_keepalive_interval = %d, want 10", cfg.Gateway.StreamKeepaliveInterval)
|
||||
}
|
||||
if cfg.Gateway.ImageStreamDataIntervalTimeout != 900 {
|
||||
t.Fatalf("image_stream_data_interval_timeout = %d, want 900", cfg.Gateway.ImageStreamDataIntervalTimeout)
|
||||
}
|
||||
if cfg.Gateway.ImageStreamKeepaliveInterval != 10 {
|
||||
t.Fatalf("image_stream_keepalive_interval = %d, want 10", cfg.Gateway.ImageStreamKeepaliveInterval)
|
||||
}
|
||||
if cfg.Gateway.ImageConcurrency.Enabled {
|
||||
t.Fatalf("image_concurrency.enabled = true, want false")
|
||||
}
|
||||
if cfg.Gateway.ImageConcurrency.MaxConcurrentRequests != 0 {
|
||||
t.Fatalf("image_concurrency.max_concurrent_requests = %d, want 0", cfg.Gateway.ImageConcurrency.MaxConcurrentRequests)
|
||||
}
|
||||
if cfg.Gateway.ImageConcurrency.OverflowMode != ImageConcurrencyOverflowModeReject {
|
||||
t.Fatalf("image_concurrency.overflow_mode = %q, want %q", cfg.Gateway.ImageConcurrency.OverflowMode, ImageConcurrencyOverflowModeReject)
|
||||
}
|
||||
if cfg.Gateway.ImageConcurrency.WaitTimeoutSeconds != 30 {
|
||||
t.Fatalf("image_concurrency.wait_timeout_seconds = %d, want 30", cfg.Gateway.ImageConcurrency.WaitTimeoutSeconds)
|
||||
}
|
||||
if cfg.Gateway.ImageConcurrency.MaxWaitingRequests != 100 {
|
||||
t.Fatalf("image_concurrency.max_waiting_requests = %d, want 100", cfg.Gateway.ImageConcurrency.MaxWaitingRequests)
|
||||
}
|
||||
if cfg.Gateway.ImageStreamDataIntervalTimeout <= cfg.Gateway.StreamDataIntervalTimeout {
|
||||
t.Fatalf("image stream timeout = %d, want greater than ordinary stream timeout %d", cfg.Gateway.ImageStreamDataIntervalTimeout, cfg.Gateway.StreamDataIntervalTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
@ -92,6 +92,9 @@ type CreateGroupRequest struct {
|
||||
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
AllowImageGeneration bool `json:"allow_image_generation"`
|
||||
ImageRateIndependent bool `json:"image_rate_independent"`
|
||||
ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
@ -129,6 +132,9 @@ type UpdateGroupRequest struct {
|
||||
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
AllowImageGeneration *bool `json:"allow_image_generation"`
|
||||
ImageRateIndependent *bool `json:"image_rate_independent"`
|
||||
ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
@ -251,6 +257,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||
AllowImageGeneration: req.AllowImageGeneration,
|
||||
ImageRateIndependent: req.ImageRateIndependent,
|
||||
ImageRateMultiplier: req.ImageRateMultiplier,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
@ -303,6 +312,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||
AllowImageGeneration: req.AllowImageGeneration,
|
||||
ImageRateIndependent: req.ImageRateIndependent,
|
||||
ImageRateMultiplier: req.ImageRateMultiplier,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
|
||||
@ -176,6 +176,9 @@ func groupFromServiceBase(g *service.Group) Group {
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
AllowImageGeneration: g.AllowImageGeneration,
|
||||
ImageRateIndependent: g.ImageRateIndependent,
|
||||
ImageRateMultiplier: g.ImageRateMultiplier,
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
|
||||
@ -94,9 +94,12 @@ type Group struct {
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
AllowImageGeneration bool `json:"allow_image_generation"`
|
||||
ImageRateIndependent bool `json:"image_rate_independent"`
|
||||
ImageRateMultiplier float64 `json:"image_rate_multiplier"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
|
||||
126
backend/internal/handler/image_concurrency_limiter.go
Normal file
126
backend/internal/handler/image_concurrency_limiter.go
Normal file
@ -0,0 +1,126 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type imageConcurrencyLimiter struct {
|
||||
mu sync.Mutex
|
||||
notify chan struct{}
|
||||
limit int
|
||||
active int
|
||||
waiting int
|
||||
enabled bool
|
||||
}
|
||||
|
||||
func (l *imageConcurrencyLimiter) TryAcquire(enabled bool, limit int) (func(), bool) {
|
||||
return l.acquire(context.Background(), enabled, limit, false, 0, 0)
|
||||
}
|
||||
|
||||
func (l *imageConcurrencyLimiter) Acquire(ctx context.Context, enabled bool, limit int, wait bool, timeout time.Duration, maxWaiting int) (func(), bool) {
|
||||
return l.acquire(ctx, enabled, limit, wait, timeout, maxWaiting)
|
||||
}
|
||||
|
||||
func (l *imageConcurrencyLimiter) acquire(ctx context.Context, enabled bool, limit int, wait bool, timeout time.Duration, maxWaiting int) (func(), bool) {
|
||||
if !enabled || limit <= 0 {
|
||||
return nil, true
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if wait {
|
||||
if timeout <= 0 {
|
||||
return nil, false
|
||||
}
|
||||
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
ctx = waitCtx
|
||||
}
|
||||
if maxWaiting < 0 {
|
||||
maxWaiting = 0
|
||||
}
|
||||
for {
|
||||
release, acquired, waitRelease, notify := l.tryAcquireLocked(enabled, limit, wait, maxWaiting)
|
||||
if acquired {
|
||||
return release, acquired
|
||||
}
|
||||
if !wait || notify == nil {
|
||||
return nil, false
|
||||
}
|
||||
if !l.waitForSlot(ctx, notify) {
|
||||
if waitRelease != nil {
|
||||
waitRelease()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
if waitRelease != nil {
|
||||
waitRelease()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *imageConcurrencyLimiter) tryAcquireLocked(enabled bool, limit int, wait bool, maxWaiting int) (func(), bool, func(), <-chan struct{}) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if l.notify == nil {
|
||||
l.notify = make(chan struct{})
|
||||
}
|
||||
if l.enabled != enabled || l.limit != limit {
|
||||
l.enabled = enabled
|
||||
l.limit = limit
|
||||
}
|
||||
if l.active < l.limit {
|
||||
l.active++
|
||||
return l.releaseFunc(), true, nil, nil
|
||||
}
|
||||
if !wait {
|
||||
return nil, false, nil, nil
|
||||
}
|
||||
if maxWaiting > 0 && l.waiting >= maxWaiting {
|
||||
return nil, false, nil, nil
|
||||
}
|
||||
l.waiting++
|
||||
return nil, false, l.waiterReleaseFunc(), l.notify
|
||||
}
|
||||
|
||||
func (l *imageConcurrencyLimiter) waitForSlot(ctx context.Context, notify <-chan struct{}) bool {
|
||||
select {
|
||||
case <-notify:
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (l *imageConcurrencyLimiter) releaseFunc() func() {
|
||||
var once sync.Once
|
||||
return func() {
|
||||
once.Do(func() {
|
||||
l.mu.Lock()
|
||||
if l.active > 0 {
|
||||
l.active--
|
||||
}
|
||||
if l.notify != nil {
|
||||
close(l.notify)
|
||||
l.notify = make(chan struct{})
|
||||
}
|
||||
l.mu.Unlock()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (l *imageConcurrencyLimiter) waiterReleaseFunc() func() {
|
||||
var once sync.Once
|
||||
return func() {
|
||||
once.Do(func() {
|
||||
l.mu.Lock()
|
||||
if l.waiting > 0 {
|
||||
l.waiting--
|
||||
}
|
||||
l.mu.Unlock()
|
||||
})
|
||||
}
|
||||
}
|
||||
230
backend/internal/handler/image_concurrency_limiter_test.go
Normal file
230
backend/internal/handler/image_concurrency_limiter_test.go
Normal file
@ -0,0 +1,230 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestImageConcurrencyLimiter_DefaultDisabledAllowsRequests(t *testing.T) {
|
||||
limiter := &imageConcurrencyLimiter{}
|
||||
|
||||
release, acquired := limiter.TryAcquire(false, 1)
|
||||
|
||||
require.True(t, acquired)
|
||||
require.Nil(t, release)
|
||||
}
|
||||
|
||||
func TestImageConcurrencyLimiter_RejectsWhenLimitReachedAndAllowsAfterRelease(t *testing.T) {
|
||||
limiter := &imageConcurrencyLimiter{}
|
||||
|
||||
release, acquired := limiter.TryAcquire(true, 1)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
|
||||
secondRelease, secondAcquired := limiter.TryAcquire(true, 1)
|
||||
require.False(t, secondAcquired)
|
||||
require.Nil(t, secondRelease)
|
||||
|
||||
release()
|
||||
thirdRelease, thirdAcquired := limiter.TryAcquire(true, 1)
|
||||
require.True(t, thirdAcquired)
|
||||
require.NotNil(t, thirdRelease)
|
||||
thirdRelease()
|
||||
}
|
||||
|
||||
func TestImageConcurrencyLimiter_WaitsUntilSlotReleased(t *testing.T) {
|
||||
limiter := &imageConcurrencyLimiter{}
|
||||
release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
|
||||
acquiredCh := make(chan func(), 1)
|
||||
go func() {
|
||||
waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
|
||||
require.True(t, waitAcquired)
|
||||
acquiredCh <- waitRelease
|
||||
}()
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
release()
|
||||
|
||||
select {
|
||||
case waitRelease := <-acquiredCh:
|
||||
require.NotNil(t, waitRelease)
|
||||
waitRelease()
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for image concurrency slot")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageConcurrencyLimiter_WaitTimesOut(t *testing.T) {
|
||||
limiter := &imageConcurrencyLimiter{}
|
||||
release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
defer release()
|
||||
|
||||
waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, 10*time.Millisecond, 1)
|
||||
|
||||
require.False(t, waitAcquired)
|
||||
require.Nil(t, waitRelease)
|
||||
}
|
||||
|
||||
func TestImageConcurrencyLimiter_MaxWaitingRequestsRejectsOverflow(t *testing.T) {
|
||||
limiter := &imageConcurrencyLimiter{}
|
||||
release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
defer release()
|
||||
|
||||
waitingStarted := make(chan struct{})
|
||||
waitingDone := make(chan struct{})
|
||||
go func() {
|
||||
close(waitingStarted)
|
||||
waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
|
||||
if waitAcquired && waitRelease != nil {
|
||||
waitRelease()
|
||||
}
|
||||
close(waitingDone)
|
||||
}()
|
||||
<-waitingStarted
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
overflowRelease, overflowAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
|
||||
|
||||
require.False(t, overflowAcquired)
|
||||
require.Nil(t, overflowRelease)
|
||||
release()
|
||||
<-waitingDone
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerAcquireImageGenerationSlot_Returns429WhenFull(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
ImageConcurrency: config.ImageConcurrencyConfig{
|
||||
Enabled: true,
|
||||
MaxConcurrentRequests: 1,
|
||||
OverflowMode: config.ImageConcurrencyOverflowModeReject,
|
||||
},
|
||||
},
|
||||
},
|
||||
imageLimiter: &imageConcurrencyLimiter{},
|
||||
}
|
||||
release, acquired := h.acquireImageGenerationSlot(c, false)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
defer release()
|
||||
|
||||
blockedRelease, blocked := h.acquireImageGenerationSlot(c, false)
|
||||
|
||||
require.False(t, blocked)
|
||||
require.Nil(t, blockedRelease)
|
||||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
require.Equal(t, "rate_limit_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String())
|
||||
require.Contains(t, rec.Body.String(), "Image generation concurrency limit exceeded")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerResponses_ImageIntentRejectedByImageConcurrency(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := `{"model":"gpt-5.4","input":"draw","tools":[{"type":"image_generation"}]}`
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(body))
|
||||
groupID := int64(1)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
ID: 10,
|
||||
GroupID: &groupID,
|
||||
Group: &service.Group{
|
||||
ID: groupID,
|
||||
AllowImageGeneration: true,
|
||||
},
|
||||
User: &service.User{ID: 20},
|
||||
})
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 20, Concurrency: 1})
|
||||
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})},
|
||||
errorPassthroughService: nil,
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{
|
||||
Enabled: true,
|
||||
MaxConcurrentRequests: 1,
|
||||
OverflowMode: config.ImageConcurrencyOverflowModeReject,
|
||||
}}},
|
||||
imageLimiter: &imageConcurrencyLimiter{},
|
||||
}
|
||||
release, acquired := h.acquireImageGenerationSlot(c, false)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
defer release()
|
||||
rec.Body.Reset()
|
||||
rec.Code = 0
|
||||
|
||||
h.Responses(c)
|
||||
|
||||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
require.Equal(t, "rate_limit_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String())
|
||||
require.Contains(t, rec.Body.String(), "Image generation concurrency limit exceeded")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerResponses_TextOnlyNotRejectedByImageConcurrency(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := `{"model":"gpt-5.4","input":"write code"}`
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(body))
|
||||
groupID := int64(1)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
ID: 10,
|
||||
GroupID: &groupID,
|
||||
Group: &service.Group{
|
||||
ID: groupID,
|
||||
AllowImageGeneration: true,
|
||||
},
|
||||
User: &service.User{ID: 20},
|
||||
})
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 20, Concurrency: 1})
|
||||
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}),
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})},
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{
|
||||
Enabled: true,
|
||||
MaxConcurrentRequests: 1,
|
||||
OverflowMode: config.ImageConcurrencyOverflowModeReject,
|
||||
}}},
|
||||
imageLimiter: &imageConcurrencyLimiter{},
|
||||
}
|
||||
release, acquired := h.acquireImageGenerationSlot(c, false)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
defer release()
|
||||
rec.Body.Reset()
|
||||
rec.Code = 0
|
||||
|
||||
h.Responses(c)
|
||||
|
||||
require.NotEqual(t, http.StatusTooManyRequests, rec.Code)
|
||||
require.NotContains(t, rec.Body.String(), "Image generation concurrency limit exceeded")
|
||||
}
|
||||
@ -187,52 +187,60 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// Pool mode: retry on the same account
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
reqLog.Warn("openai_chat_completions.forward_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(err),
|
||||
)
|
||||
continue
|
||||
} else {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// Pool mode: retry on the same account
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_chat_completions.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_chat_completions.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
@ -242,16 +250,18 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := resolveRawCCUpstreamEndpoint(c, account)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: resolveRawCCUpstreamEndpoint(c, account),
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
|
||||
@ -33,6 +33,7 @@ type OpenAIGatewayHandler struct {
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
imageLimiter *imageConcurrencyLimiter
|
||||
maxAccountSwitches int
|
||||
cfg *config.Config
|
||||
}
|
||||
@ -69,6 +70,7 @@ func NewOpenAIGatewayHandler(
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
imageLimiter: &imageConcurrencyLimiter{},
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
cfg: cfg,
|
||||
}
|
||||
@ -187,6 +189,23 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
imageIntent := service.IsImageGenerationIntent("/v1/responses", reqModel, body)
|
||||
if imageIntent && !service.GroupAllowsImageGeneration(apiKey.Group) {
|
||||
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
|
||||
return
|
||||
}
|
||||
var imageReleaseFunc func()
|
||||
if imageIntent {
|
||||
var imageAcquired bool
|
||||
imageReleaseFunc, imageAcquired = h.acquireImageGenerationSlot(c, streamStarted)
|
||||
if !imageAcquired {
|
||||
return
|
||||
}
|
||||
if imageReleaseFunc != nil {
|
||||
defer imageReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
@ -318,57 +337,65 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
reqLog.Warn("openai.forward_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
}
|
||||
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
|
||||
reqLog.Warn("openai.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
}
|
||||
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
|
||||
reqLog.Warn("openai.forward_failed", fields...)
|
||||
reqLog.Error("openai.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
reqLog.Error("openai.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
@ -383,17 +410,19 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
@ -701,52 +730,60 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_messages.upstream_failover_switching",
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
reqLog.Warn("openai_messages.forward_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(err),
|
||||
)
|
||||
continue
|
||||
} else {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_messages.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_messages.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_messages.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
@ -757,16 +794,18 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
@ -1114,6 +1153,11 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
||||
|
||||
if service.IsImageGenerationIntent("/v1/responses", reqModel, firstMessage) && !service.GroupAllowsImageGeneration(apiKey.Group) {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, service.ImageGenerationPermissionMessage())
|
||||
return
|
||||
}
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
|
||||
|
||||
@ -1257,22 +1301,34 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
},
|
||||
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
|
||||
releaseTurnSlots()
|
||||
if turnErr != nil || result == nil {
|
||||
if turnErr != nil {
|
||||
if result == nil || result.ImageCount <= 0 {
|
||||
return
|
||||
}
|
||||
reqLog.Warn("openai.websocket_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(turnErr),
|
||||
)
|
||||
}
|
||||
if result == nil {
|
||||
return
|
||||
}
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
@ -1440,6 +1496,60 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) submitOpenAIUsageRecordTask(result *service.OpenAIForwardResult, task service.UsageRecordTask) {
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
h.submitMandatoryUsageRecordTask(task)
|
||||
return
|
||||
}
|
||||
h.submitUsageRecordTask(task)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) submitMandatoryUsageRecordTask(task service.UsageRecordTask) {
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
if h.usageRecordWorkerPool != nil {
|
||||
if mode := h.usageRecordWorkerPool.Submit(task); mode != service.UsageRecordSubmitModeDropped {
|
||||
return
|
||||
}
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.usage"),
|
||||
).Warn("openai.usage_record_task_mandatory_sync_fallback")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.usage"),
|
||||
zap.Any("panic", recovered),
|
||||
).Error("openai.usage_record_task_panic_recovered")
|
||||
}
|
||||
}()
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) acquireImageGenerationSlot(c *gin.Context, streamStarted bool) (func(), bool) {
|
||||
if h == nil || h.cfg == nil || h.imageLimiter == nil {
|
||||
return nil, true
|
||||
}
|
||||
imageConcurrency := h.cfg.Gateway.ImageConcurrency
|
||||
wait := strings.TrimSpace(imageConcurrency.OverflowMode) == config.ImageConcurrencyOverflowModeWait
|
||||
release, acquired := h.imageLimiter.Acquire(
|
||||
c.Request.Context(),
|
||||
imageConcurrency.Enabled,
|
||||
imageConcurrency.MaxConcurrentRequests,
|
||||
wait,
|
||||
time.Duration(imageConcurrency.WaitTimeoutSeconds)*time.Second,
|
||||
imageConcurrency.MaxWaitingRequests,
|
||||
)
|
||||
if acquired {
|
||||
return release, true
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Image generation concurrency limit exceeded, please retry later", streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
||||
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||
|
||||
@ -81,6 +81,18 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
zap.String("capability", string(parsed.RequiredCapability)),
|
||||
)
|
||||
|
||||
if !service.GroupAllowsImageGeneration(apiKey.Group) {
|
||||
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
|
||||
return
|
||||
}
|
||||
imageReleaseFunc, acquired := h.acquireImageGenerationSlot(c, streamStarted)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
if imageReleaseFunc != nil {
|
||||
defer imageReleaseFunc()
|
||||
}
|
||||
|
||||
if parsed.Multipart {
|
||||
setOpsRequestContext(c, parsed.Model, parsed.Stream, nil)
|
||||
} else {
|
||||
@ -188,62 +200,69 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||
}
|
||||
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||
if result != nil && result.FirstTokenMs != nil {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.images.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
reqLog.Warn("openai.images.forward_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.images.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai.images.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
}
|
||||
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
|
||||
reqLog.Warn("openai.images.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai.images.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
}
|
||||
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
|
||||
reqLog.Warn("openai.images.forward_failed", fields...)
|
||||
reqLog.Error("openai.images.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
reqLog.Error("openai.images.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders)
|
||||
@ -259,21 +278,27 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
if parsed.Multipart {
|
||||
requestPayloadHash = service.HashUsageRequestPayload([]byte(parsed.StickySessionSeed()))
|
||||
}
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
upstreamModel := ""
|
||||
if result != nil {
|
||||
upstreamModel = result.UpstreamModel
|
||||
}
|
||||
h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, result.UpstreamModel),
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, upstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.images"),
|
||||
|
||||
49
backend/internal/handler/openai_images_controls_test.go
Normal file
49
backend/internal/handler/openai_images_controls_test.go
Normal file
@ -0,0 +1,49 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestOpenAIGatewayHandlerImages_DisabledGroupRejectsBeforeScheduling(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw","size":"1024x1024"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
groupID := int64(111)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
ID: 222,
|
||||
GroupID: &groupID,
|
||||
Group: &service.Group{
|
||||
ID: groupID,
|
||||
AllowImageGeneration: false,
|
||||
},
|
||||
User: &service.User{ID: 333},
|
||||
})
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 333, Concurrency: 1})
|
||||
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: &ConcurrencyHelper{concurrencyService: &service.ConcurrencyService{}},
|
||||
}
|
||||
|
||||
h.Images(c)
|
||||
|
||||
require.Equal(t, http.StatusForbidden, rec.Code)
|
||||
require.Equal(t, "permission_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String())
|
||||
require.Contains(t, rec.Body.String(), service.ImageGenerationPermissionMessage())
|
||||
}
|
||||
@ -129,3 +129,63 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
|
||||
})
|
||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitMandatoryUsageRecordTask_DroppedTaskSyncFallback(t *testing.T) {
|
||||
pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 1,
|
||||
QueueSize: 1,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: "drop",
|
||||
OverflowSamplePercent: 0,
|
||||
AutoScaleEnabled: false,
|
||||
})
|
||||
t.Cleanup(pool.Stop)
|
||||
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
|
||||
|
||||
block := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
pool.Submit(func(ctx context.Context) {
|
||||
close(block)
|
||||
<-release
|
||||
})
|
||||
<-block
|
||||
pool.Submit(func(ctx context.Context) {})
|
||||
|
||||
var called atomic.Bool
|
||||
h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
close(release)
|
||||
|
||||
require.True(t, called.Load(), "mandatory usage task must run synchronously when async submit is dropped")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitOpenAIUsageRecordTask_ImageResultUsesMandatoryFallback(t *testing.T) {
|
||||
pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 1,
|
||||
QueueSize: 1,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: "drop",
|
||||
OverflowSamplePercent: 0,
|
||||
AutoScaleEnabled: false,
|
||||
})
|
||||
t.Cleanup(pool.Stop)
|
||||
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
|
||||
|
||||
block := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
pool.Submit(func(ctx context.Context) {
|
||||
close(block)
|
||||
<-release
|
||||
})
|
||||
<-block
|
||||
pool.Submit(func(ctx context.Context) {})
|
||||
|
||||
var called atomic.Bool
|
||||
h.submitOpenAIUsageRecordTask(&service.OpenAIForwardResult{ImageCount: 1}, func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
close(release)
|
||||
|
||||
require.True(t, called.Load(), "image usage task must be mandatory when async submit is dropped")
|
||||
}
|
||||
|
||||
@ -166,6 +166,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldDailyLimitUsd,
|
||||
group.FieldWeeklyLimitUsd,
|
||||
group.FieldMonthlyLimitUsd,
|
||||
group.FieldAllowImageGeneration,
|
||||
group.FieldImageRateIndependent,
|
||||
group.FieldImageRateMultiplier,
|
||||
group.FieldImagePrice1k,
|
||||
group.FieldImagePrice2k,
|
||||
group.FieldImagePrice4k,
|
||||
@ -699,6 +702,9 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
DailyLimitUSD: g.DailyLimitUsd,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
||||
AllowImageGeneration: g.AllowImageGeneration,
|
||||
ImageRateIndependent: g.ImageRateIndependent,
|
||||
ImageRateMultiplier: g.ImageRateMultiplier,
|
||||
ImagePrice1K: g.ImagePrice1k,
|
||||
ImagePrice2K: g.ImagePrice2k,
|
||||
ImagePrice4K: g.ImagePrice4k,
|
||||
|
||||
@ -50,6 +50,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
|
||||
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
|
||||
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
|
||||
SetAllowImageGeneration(groupIn.AllowImageGeneration).
|
||||
SetImageRateIndependent(groupIn.ImageRateIndependent).
|
||||
SetImageRateMultiplier(groupIn.ImageRateMultiplier).
|
||||
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
@ -120,6 +123,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
|
||||
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
|
||||
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
|
||||
SetAllowImageGeneration(groupIn.AllowImageGeneration).
|
||||
SetImageRateIndependent(groupIn.ImageRateIndependent).
|
||||
SetImageRateMultiplier(groupIn.ImageRateMultiplier).
|
||||
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
|
||||
@ -328,6 +328,9 @@ func TestAPIContracts(t *testing.T) {
|
||||
"image_price_1k": null,
|
||||
"image_price_2k": null,
|
||||
"image_price_4k": null,
|
||||
"allow_image_generation": false,
|
||||
"image_rate_independent": false,
|
||||
"image_rate_multiplier": 0,
|
||||
"claude_code_only": false,
|
||||
"allow_messages_dispatch": false,
|
||||
"fallback_group_id": null,
|
||||
|
||||
@ -230,7 +230,11 @@ func applyAccountStatsCost(
|
||||
if model == "" {
|
||||
model = requestedModel
|
||||
}
|
||||
requestCount := 1
|
||||
if usageLog != nil && usageLog.ImageCount > 0 {
|
||||
requestCount = usageLog.ImageCount
|
||||
}
|
||||
usageLog.AccountStatsCost = resolveAccountStatsCost(
|
||||
ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost,
|
||||
ctx, cs, bs, accountID, groupID, model, tokens, requestCount, totalCost,
|
||||
)
|
||||
}
|
||||
|
||||
@ -189,11 +189,14 @@ type CreateGroupInput struct {
|
||||
WeeklyLimitUSD *float64 // 周限额 (USD)
|
||||
MonthlyLimitUSD *float64 // 月限额 (USD)
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
AllowImageGeneration bool
|
||||
ImageRateIndependent bool
|
||||
ImageRateMultiplier *float64
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||
FallbackGroupIDOnInvalidRequest *int64
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
@ -226,11 +229,14 @@ type UpdateGroupInput struct {
|
||||
WeeklyLimitUSD *float64 // 周限额 (USD)
|
||||
MonthlyLimitUSD *float64 // 月限额 (USD)
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
AllowImageGeneration *bool
|
||||
ImageRateIndependent *bool
|
||||
ImageRateMultiplier *float64
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||
FallbackGroupIDOnInvalidRequest *int64
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
@ -1557,6 +1563,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
imagePrice1K := normalizePrice(input.ImagePrice1K)
|
||||
imagePrice2K := normalizePrice(input.ImagePrice2K)
|
||||
imagePrice4K := normalizePrice(input.ImagePrice4K)
|
||||
imageRateMultiplier := 1.0
|
||||
if input.ImageRateMultiplier != nil {
|
||||
if *input.ImageRateMultiplier < 0 {
|
||||
return nil, errors.New("image_rate_multiplier must be >= 0")
|
||||
}
|
||||
imageRateMultiplier = *input.ImageRateMultiplier
|
||||
}
|
||||
|
||||
// 校验降级分组
|
||||
if input.FallbackGroupID != nil {
|
||||
@ -1624,6 +1637,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
DailyLimitUSD: dailyLimit,
|
||||
WeeklyLimitUSD: weeklyLimit,
|
||||
MonthlyLimitUSD: monthlyLimit,
|
||||
AllowImageGeneration: input.AllowImageGeneration,
|
||||
ImageRateIndependent: input.ImageRateIndependent,
|
||||
ImageRateMultiplier: imageRateMultiplier,
|
||||
ImagePrice1K: imagePrice1K,
|
||||
ImagePrice2K: imagePrice2K,
|
||||
ImagePrice4K: imagePrice4K,
|
||||
@ -1800,6 +1816,18 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
|
||||
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
|
||||
// 图片生成计费配置:负数表示清除(使用默认价格)
|
||||
if input.AllowImageGeneration != nil {
|
||||
group.AllowImageGeneration = *input.AllowImageGeneration
|
||||
}
|
||||
if input.ImageRateIndependent != nil {
|
||||
group.ImageRateIndependent = *input.ImageRateIndependent
|
||||
}
|
||||
if input.ImageRateMultiplier != nil {
|
||||
if *input.ImageRateMultiplier < 0 {
|
||||
return nil, errors.New("image_rate_multiplier must be >= 0")
|
||||
}
|
||||
group.ImageRateMultiplier = *input.ImageRateMultiplier
|
||||
}
|
||||
if input.ImagePrice1K != nil {
|
||||
group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
|
||||
}
|
||||
|
||||
@ -266,6 +266,50 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
|
||||
require.Nil(t, repo.updated.ImagePrice4K)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_PreservesImageGenerationControlsWhenOmitted(t *testing.T) {
|
||||
imageMultiplier := 0.5
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
Name: "existing-group",
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
AllowImageGeneration: true,
|
||||
ImageRateIndependent: true,
|
||||
ImageRateMultiplier: imageMultiplier,
|
||||
}
|
||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
|
||||
Description: "updated",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.updated)
|
||||
require.True(t, repo.updated.AllowImageGeneration)
|
||||
require.True(t, repo.updated.ImageRateIndependent)
|
||||
require.InDelta(t, 0.5, repo.updated.ImageRateMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_RejectsNegativeImageRateMultiplier(t *testing.T) {
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
Name: "existing-group",
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
ImageRateMultiplier: 1,
|
||||
}
|
||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
negative := -0.1
|
||||
|
||||
_, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
|
||||
ImageRateMultiplier: &negative,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Nil(t, repo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
|
||||
@ -63,6 +63,9 @@ type APIKeyAuthGroupSnapshot struct {
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
||||
AllowImageGeneration bool `json:"allow_image_generation"`
|
||||
ImageRateIndependent bool `json:"image_rate_independent"`
|
||||
ImageRateMultiplier float64 `json:"image_rate_multiplier"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||
|
||||
@ -14,7 +14,7 @@ import (
|
||||
"github.com/dgraph-io/ristretto"
|
||||
)
|
||||
|
||||
const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot
|
||||
const apiKeyAuthSnapshotVersion = 8 // v8: added group image generation controls
|
||||
|
||||
type apiKeyAuthCacheConfig struct {
|
||||
l1Size int
|
||||
@ -255,6 +255,9 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey)
|
||||
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||
AllowImageGeneration: apiKey.Group.AllowImageGeneration,
|
||||
ImageRateIndependent: apiKey.Group.ImageRateIndependent,
|
||||
ImageRateMultiplier: apiKey.Group.ImageRateMultiplier,
|
||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||
@ -321,6 +324,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||
AllowImageGeneration: snapshot.Group.AllowImageGeneration,
|
||||
ImageRateIndependent: snapshot.Group.ImageRateIndependent,
|
||||
ImageRateMultiplier: snapshot.Group.ImageRateMultiplier,
|
||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||
|
||||
@ -294,8 +294,7 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||||
}
|
||||
|
||||
// OpenAI 仅匹配已知 GPT-5/Codex 族,避免未知 OpenAI 型号误计价。
|
||||
if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") {
|
||||
normalized := normalizeCodexModel(modelLower)
|
||||
if normalized := normalizeKnownOpenAICodexModel(modelLower); normalized != "" {
|
||||
switch normalized {
|
||||
case "gpt-5.5":
|
||||
return s.fallbackPrices["gpt-5.5"]
|
||||
@ -644,13 +643,10 @@ func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens
|
||||
}
|
||||
|
||||
func isOpenAIGPT54Model(model string) bool {
|
||||
trimmed := strings.TrimSpace(strings.ToLower(model))
|
||||
// 仅当模型字符串实际属于 GPT-5/Codex 族时才做归一判定,避免 normalizeCodexModel
|
||||
// 的默认兜底把非 OpenAI 模型(claude-*、gemini-*、gpt-4o)误识别为 gpt-5.4。
|
||||
if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
|
||||
return false
|
||||
}
|
||||
normalized := normalizeCodexModel(trimmed)
|
||||
// 仅当模型字符串实际属于已知 GPT-5/Codex 族时才做归一判定,避免
|
||||
// normalizeCodexModel 的默认兜底把非 OpenAI 模型(claude-*、gemini-*、gpt-4o)
|
||||
// 误识别为 gpt-5.4。
|
||||
normalized := normalizeKnownOpenAICodexModel(model)
|
||||
return normalized == "gpt-5.4" || normalized == "gpt-5.5"
|
||||
}
|
||||
|
||||
|
||||
@ -137,6 +137,35 @@ func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) {
|
||||
require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_OpenAICompactAliasesFallback(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tests := []struct {
|
||||
model string
|
||||
inputPrice float64
|
||||
outputPrice float64
|
||||
cacheRead float64
|
||||
longContext int
|
||||
}{
|
||||
{model: "gpt5.5", inputPrice: 2.5e-6, outputPrice: 15e-6, cacheRead: 0.25e-6, longContext: 272000},
|
||||
{model: "openai/gpt5.4", inputPrice: 2.5e-6, outputPrice: 15e-6, cacheRead: 0.25e-6, longContext: 272000},
|
||||
{model: "gpt5.4-mini", inputPrice: 7.5e-7, outputPrice: 4.5e-6, cacheRead: 7.5e-8, longContext: 0},
|
||||
{model: "gpt5.3codexspark", inputPrice: 1.5e-6, outputPrice: 12e-6, cacheRead: 0.15e-6, longContext: 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.model, func(t *testing.T) {
|
||||
pricing, err := svc.GetModelPricing(tt.model)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, tt.inputPrice, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, tt.outputPrice, pricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, tt.cacheRead, pricing.CacheReadPricePerToken, 1e-12)
|
||||
require.Equal(t, tt.longContext, pricing.LongContextInputThreshold)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
|
||||
@ -8367,6 +8367,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
groupDefault := apiKey.Group.RateMultiplier
|
||||
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
|
||||
}
|
||||
imageMultiplier := resolveImageRateMultiplier(apiKey, multiplier)
|
||||
|
||||
// 确定计费模型
|
||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||
@ -8384,7 +8385,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
}
|
||||
|
||||
// 计算费用
|
||||
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts)
|
||||
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, imageMultiplier, opts)
|
||||
|
||||
// 判断计费方式:订阅模式 vs 余额模式
|
||||
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
@ -8396,7 +8397,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
// 创建使用日志
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
|
||||
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
|
||||
requestedModel, multiplier, imageMultiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
|
||||
|
||||
// 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
|
||||
if apiKey.GroupID != nil {
|
||||
@ -8450,11 +8451,12 @@ func (s *GatewayService) calculateRecordUsageCost(
|
||||
apiKey *APIKey,
|
||||
billingModel string,
|
||||
multiplier float64,
|
||||
imageMultiplier float64,
|
||||
opts *recordUsageOpts,
|
||||
) *CostBreakdown {
|
||||
// 图片生成计费
|
||||
if result.ImageCount > 0 {
|
||||
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
|
||||
return s.calculateImageCost(ctx, result, apiKey, billingModel, imageMultiplier)
|
||||
}
|
||||
|
||||
// Token 计费
|
||||
@ -8495,7 +8497,8 @@ func (s *GatewayService) calculateImageCost(
|
||||
Model: billingModel,
|
||||
GroupID: &gid,
|
||||
Tokens: tokens,
|
||||
RequestCount: 1,
|
||||
RequestCount: result.ImageCount,
|
||||
SizeTier: result.ImageSize,
|
||||
RateMultiplier: multiplier,
|
||||
Resolver: s.resolver,
|
||||
Resolved: resolved,
|
||||
@ -8580,6 +8583,7 @@ func (s *GatewayService) buildRecordUsageLog(
|
||||
subscription *UserSubscription,
|
||||
requestedModel string,
|
||||
multiplier float64,
|
||||
imageMultiplier float64,
|
||||
accountRateMultiplier float64,
|
||||
billingType int8,
|
||||
cacheTTLOverridden bool,
|
||||
@ -8624,6 +8628,9 @@ func (s *GatewayService) buildRecordUsageLog(
|
||||
SubscriptionID: optionalSubscriptionID(subscription),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if result.ImageCount > 0 {
|
||||
usageLog.RateMultiplier = imageMultiplier
|
||||
}
|
||||
if cost != nil {
|
||||
usageLog.InputCost = cost.InputCost
|
||||
usageLog.OutputCost = cost.OutputCost
|
||||
|
||||
@ -26,9 +26,12 @@ type Group struct {
|
||||
DefaultValidityDays int
|
||||
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用)
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
AllowImageGeneration bool
|
||||
ImageRateIndependent bool
|
||||
ImageRateMultiplier float64
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool
|
||||
|
||||
@ -45,19 +45,25 @@ type GroupSortOrderUpdate struct {
|
||||
|
||||
// CreateGroupRequest 创建分组请求
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
AllowImageGeneration bool `json:"allow_image_generation"`
|
||||
ImageRateIndependent bool `json:"image_rate_independent"`
|
||||
ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
|
||||
}
|
||||
|
||||
// UpdateGroupRequest 更新分组请求
|
||||
type UpdateGroupRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status *string `json:"status"`
|
||||
Name *string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status *string `json:"status"`
|
||||
AllowImageGeneration *bool `json:"allow_image_generation"`
|
||||
ImageRateIndependent *bool `json:"image_rate_independent"`
|
||||
ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
|
||||
}
|
||||
|
||||
// GroupService 分组管理服务
|
||||
@ -76,6 +82,13 @@ func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthC
|
||||
|
||||
// Create 创建分组
|
||||
func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Group, error) {
|
||||
imageRateMultiplier := 1.0
|
||||
if req.ImageRateMultiplier != nil {
|
||||
if *req.ImageRateMultiplier < 0 {
|
||||
return nil, fmt.Errorf("image_rate_multiplier must be >= 0")
|
||||
}
|
||||
imageRateMultiplier = *req.ImageRateMultiplier
|
||||
}
|
||||
// 检查名称是否已存在
|
||||
exists, err := s.groupRepo.ExistsByName(ctx, req.Name)
|
||||
if err != nil {
|
||||
@ -87,13 +100,16 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Gro
|
||||
|
||||
// 创建分组
|
||||
group := &Group{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: PlatformAnthropic,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: PlatformAnthropic,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
AllowImageGeneration: req.AllowImageGeneration,
|
||||
ImageRateIndependent: req.ImageRateIndependent,
|
||||
ImageRateMultiplier: imageRateMultiplier,
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
@ -165,6 +181,18 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ
|
||||
if req.Status != nil {
|
||||
group.Status = *req.Status
|
||||
}
|
||||
if req.AllowImageGeneration != nil {
|
||||
group.AllowImageGeneration = *req.AllowImageGeneration
|
||||
}
|
||||
if req.ImageRateIndependent != nil {
|
||||
group.ImageRateIndependent = *req.ImageRateIndependent
|
||||
}
|
||||
if req.ImageRateMultiplier != nil {
|
||||
if *req.ImageRateMultiplier < 0 {
|
||||
return nil, fmt.Errorf("image_rate_multiplier must be >= 0")
|
||||
}
|
||||
group.ImageRateMultiplier = *req.ImageRateMultiplier
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, fmt.Errorf("update group: %w", err)
|
||||
|
||||
11
backend/internal/service/image_billing_multiplier.go
Normal file
11
backend/internal/service/image_billing_multiplier.go
Normal file
@ -0,0 +1,11 @@
|
||||
package service
|
||||
|
||||
func resolveImageRateMultiplier(apiKey *APIKey, effectiveGroupMultiplier float64) float64 {
|
||||
if apiKey != nil && apiKey.Group != nil && apiKey.Group.ImageRateIndependent {
|
||||
if apiKey.Group.ImageRateMultiplier < 0 {
|
||||
return 0
|
||||
}
|
||||
return apiKey.Group.ImageRateMultiplier
|
||||
}
|
||||
return effectiveGroupMultiplier
|
||||
}
|
||||
220
backend/internal/service/image_generation_intent.go
Normal file
220
backend/internal/service/image_generation_intent.go
Normal file
@ -0,0 +1,220 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
openAIResponsesEndpoint = "/v1/responses"
|
||||
openAIResponsesCompactEndpoint = "/v1/responses/compact"
|
||||
imageGenerationPermissionMessage = "Image generation is not enabled for this group"
|
||||
)
|
||||
|
||||
// ImageGenerationPermissionMessage returns the stable end-user error text for disabled groups.
|
||||
func ImageGenerationPermissionMessage() string {
|
||||
return imageGenerationPermissionMessage
|
||||
}
|
||||
|
||||
// GroupAllowsImageGeneration preserves ungrouped-key behavior and enforces the flag when a group is present.
|
||||
func GroupAllowsImageGeneration(group *Group) bool {
|
||||
return group == nil || group.AllowImageGeneration
|
||||
}
|
||||
|
||||
// IsImageGenerationIntent classifies requests that can produce generated images.
|
||||
func IsImageGenerationIntent(endpoint string, requestedModel string, body []byte) bool {
|
||||
if IsImageGenerationEndpoint(endpoint) {
|
||||
return true
|
||||
}
|
||||
if isOpenAIImageGenerationModel(requestedModel) {
|
||||
return true
|
||||
}
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
return false
|
||||
}
|
||||
if model := strings.TrimSpace(gjson.GetBytes(body, "model").String()); isOpenAIImageGenerationModel(model) {
|
||||
return true
|
||||
}
|
||||
if openAIJSONToolsContainImageGeneration(gjson.GetBytes(body, "tools")) {
|
||||
return true
|
||||
}
|
||||
return openAIJSONToolChoiceSelectsImageGeneration(gjson.GetBytes(body, "tool_choice"))
|
||||
}
|
||||
|
||||
// IsImageGenerationIntentMap is the map-backed variant used after service-side request mutation.
|
||||
func IsImageGenerationIntentMap(endpoint string, requestedModel string, reqBody map[string]any) bool {
|
||||
if IsImageGenerationEndpoint(endpoint) {
|
||||
return true
|
||||
}
|
||||
if isOpenAIImageGenerationModel(requestedModel) {
|
||||
return true
|
||||
}
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
if isOpenAIImageGenerationModel(firstNonEmptyString(reqBody["model"])) {
|
||||
return true
|
||||
}
|
||||
if hasOpenAIImageGenerationTool(reqBody) {
|
||||
return true
|
||||
}
|
||||
return openAIAnyToolChoiceSelectsImageGeneration(reqBody["tool_choice"])
|
||||
}
|
||||
|
||||
// IsImageGenerationEndpoint identifies dedicated generated-image endpoints.
|
||||
func IsImageGenerationEndpoint(endpoint string) bool {
|
||||
switch normalizeImageGenerationEndpoint(endpoint) {
|
||||
case "/v1/images/generations", "/v1/images/edits", "/images/generations", "/images/edits":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeImageGenerationEndpoint(endpoint string) string {
|
||||
endpoint = strings.TrimSpace(strings.ToLower(endpoint))
|
||||
if endpoint == "" {
|
||||
return ""
|
||||
}
|
||||
endpoint = strings.TrimPrefix(endpoint, "https://api.openai.com")
|
||||
if idx := strings.IndexByte(endpoint, '?'); idx >= 0 {
|
||||
endpoint = endpoint[:idx]
|
||||
}
|
||||
return strings.TrimRight(endpoint, "/")
|
||||
}
|
||||
|
||||
func openAIJSONToolsContainImageGeneration(tools gjson.Result) bool {
|
||||
if !tools.IsArray() {
|
||||
return false
|
||||
}
|
||||
found := false
|
||||
tools.ForEach(func(_, item gjson.Result) bool {
|
||||
if strings.TrimSpace(item.Get("type").String()) == "image_generation" {
|
||||
found = true
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
return found
|
||||
}
|
||||
|
||||
func openAIJSONToolChoiceSelectsImageGeneration(choice gjson.Result) bool {
|
||||
if !choice.Exists() {
|
||||
return false
|
||||
}
|
||||
if choice.Type == gjson.String {
|
||||
return strings.TrimSpace(choice.String()) == "image_generation"
|
||||
}
|
||||
if !choice.IsObject() {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(choice.Get("type").String()) == "image_generation" {
|
||||
return true
|
||||
}
|
||||
if strings.TrimSpace(choice.Get("tool.type").String()) == "image_generation" {
|
||||
return true
|
||||
}
|
||||
if strings.TrimSpace(choice.Get("function.name").String()) == "image_generation" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func openAIAnyToolChoiceSelectsImageGeneration(choice any) bool {
|
||||
switch v := choice.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v) == "image_generation"
|
||||
case map[string]any:
|
||||
if strings.TrimSpace(firstNonEmptyString(v["type"])) == "image_generation" {
|
||||
return true
|
||||
}
|
||||
if tool, ok := v["tool"].(map[string]any); ok && strings.TrimSpace(firstNonEmptyString(tool["type"])) == "image_generation" {
|
||||
return true
|
||||
}
|
||||
if fn, ok := v["function"].(map[string]any); ok && strings.TrimSpace(firstNonEmptyString(fn["name"])) == "image_generation" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getAPIKeyFromContext(c interface{ Get(string) (any, bool) }) *APIKey {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
v, exists := c.Get("api_key")
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
apiKey, _ := v.(*APIKey)
|
||||
return apiKey
|
||||
}
|
||||
|
||||
func apiKeyGroup(apiKey *APIKey) *Group {
|
||||
if apiKey == nil {
|
||||
return nil
|
||||
}
|
||||
return apiKey.Group
|
||||
}
|
||||
|
||||
func cloneRequestMapForImageIntent(body []byte) map[string]any {
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(body, &out); err != nil {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func resolveOpenAIResponsesImageBillingConfig(reqBody map[string]any, fallbackModel string) (string, string, error) {
|
||||
imageModel := ""
|
||||
imageSize := ""
|
||||
hasImageTool := false
|
||||
if reqBody != nil {
|
||||
rawTools, _ := reqBody["tools"].([]any)
|
||||
for _, rawTool := range rawTools {
|
||||
toolMap, ok := rawTool.(map[string]any)
|
||||
if !ok || strings.TrimSpace(firstNonEmptyString(toolMap["type"])) != "image_generation" {
|
||||
continue
|
||||
}
|
||||
hasImageTool = true
|
||||
imageModel = strings.TrimSpace(firstNonEmptyString(toolMap["model"]))
|
||||
imageSize = strings.TrimSpace(firstNonEmptyString(toolMap["size"]))
|
||||
break
|
||||
}
|
||||
if imageSize == "" {
|
||||
imageSize = strings.TrimSpace(firstNonEmptyString(reqBody["size"]))
|
||||
}
|
||||
}
|
||||
if imageModel == "" && reqBody != nil {
|
||||
bodyModel := strings.TrimSpace(firstNonEmptyString(reqBody["model"]))
|
||||
if isOpenAIImageBillingModelAlias(bodyModel) || !hasImageTool {
|
||||
imageModel = bodyModel
|
||||
}
|
||||
}
|
||||
if imageModel == "" && hasImageTool {
|
||||
imageModel = "gpt-image-2"
|
||||
}
|
||||
if imageModel == "" {
|
||||
imageModel = strings.TrimSpace(fallbackModel)
|
||||
}
|
||||
sizeTier := normalizeOpenAIImageSizeTier(imageSize)
|
||||
return imageModel, sizeTier, nil
|
||||
}
|
||||
|
||||
func resolveOpenAIResponsesImageBillingConfigFromBody(body []byte, fallbackModel string) (string, string, error) {
|
||||
reqBody := cloneRequestMapForImageIntent(body)
|
||||
return resolveOpenAIResponsesImageBillingConfig(reqBody, fallbackModel)
|
||||
}
|
||||
|
||||
func isOpenAIImageBillingModelAlias(model string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(model))
|
||||
if normalized == "" {
|
||||
return false
|
||||
}
|
||||
return isOpenAIImageGenerationModel(normalized) || strings.Contains(normalized, "image")
|
||||
}
|
||||
184
backend/internal/service/image_generation_intent_test.go
Normal file
184
backend/internal/service/image_generation_intent_test.go
Normal file
@ -0,0 +1,184 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsImageGenerationIntent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
endpoint string
|
||||
model string
|
||||
body []byte
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "images endpoint",
|
||||
endpoint: "/v1/images/generations",
|
||||
body: []byte(`{"model":"gpt-image-2"}`),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "image model",
|
||||
endpoint: "/v1/responses",
|
||||
model: "gpt-image-2",
|
||||
body: []byte(`{"model":"gpt-image-2"}`),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "image tool",
|
||||
endpoint: "/v1/responses",
|
||||
model: "gpt-5.4",
|
||||
body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation"}]}`),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "image tool choice",
|
||||
endpoint: "/v1/responses",
|
||||
model: "gpt-5.4",
|
||||
body: []byte(`{"model":"gpt-5.4","tool_choice":{"type":"image_generation"}}`),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "required tool choice alone is text",
|
||||
endpoint: "/v1/responses",
|
||||
model: "gpt-5.4",
|
||||
body: []byte(`{"model":"gpt-5.4","tool_choice":"required"}`),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "text only gpt 5.4",
|
||||
endpoint: "/v1/responses",
|
||||
model: "gpt-5.4",
|
||||
body: []byte(`{"model":"gpt-5.4","input":"write code"}`),
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, IsImageGenerationIntent(tt.endpoint, tt.model, tt.body))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOpenAIResponsesImageBillingConfigUsesCurrentBodyModel(t *testing.T) {
|
||||
imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(
|
||||
[]byte(`{"model":"mapped-image-model","tools":[{"type":"image_generation","size":"1024x1024"}]}`),
|
||||
"requested-model",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "mapped-image-model", imageModel)
|
||||
require.Equal(t, "1K", imageSize)
|
||||
}
|
||||
|
||||
func TestResolveOpenAIResponsesImageBillingConfigToolModelWins(t *testing.T) {
|
||||
imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(
|
||||
[]byte(`{"model":"mapped-text-model","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1536x1024"}]}`),
|
||||
"requested-model",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "gpt-image-2", imageModel)
|
||||
require.Equal(t, "2K", imageSize)
|
||||
}
|
||||
|
||||
func TestResolveOpenAIResponsesImageBillingConfigSupportsOfficialAndCustomSizes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
wantTier string
|
||||
}{
|
||||
{
|
||||
name: "official 2k landscape",
|
||||
body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-2","size":"2048x1152"}]}`),
|
||||
wantTier: "2K",
|
||||
},
|
||||
{
|
||||
name: "official 4k landscape",
|
||||
body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-2","size":"3840x2160"}]}`),
|
||||
wantTier: "4K",
|
||||
},
|
||||
{
|
||||
name: "custom valid 2k",
|
||||
body: []byte(`{"model":"gpt-5.5","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1280x768"}]}`),
|
||||
wantTier: "2K",
|
||||
},
|
||||
{
|
||||
name: "default image tool model supports flexible size",
|
||||
body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","size":"2048x1152"}]}`),
|
||||
wantTier: "2K",
|
||||
},
|
||||
{
|
||||
name: "top level image size is moved into billing",
|
||||
body: []byte(`{"model":"gpt-image-2","size":"2048x2048","tools":[{"type":"image_generation","model":"gpt-image-2"}]}`),
|
||||
wantTier: "2K",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(tt.body, "requested-model")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, imageModel)
|
||||
require.Equal(t, tt.wantTier, imageSize)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOpenAIResponsesImageBillingConfigDoesNotRejectUnknownSizes(t *testing.T) {
|
||||
imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(
|
||||
[]byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-1.5","size":"2048x1152"}]}`),
|
||||
"requested-model",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "gpt-image-1.5", imageModel)
|
||||
require.Equal(t, "2K", imageSize)
|
||||
}
|
||||
|
||||
func TestOpenAIImageOutputCounterDeduplicatesFinalImages(t *testing.T) {
|
||||
counter := newOpenAIImageOutputCounter()
|
||||
counter.AddSSEData([]byte(`{"type":"response.image_generation_call.partial_image","partial_image_b64":"abc"}`))
|
||||
counter.AddSSEData([]byte(`{"type":"response.output_item.done","item":{"id":"ig_1","type":"image_generation_call","result":"final-a"}}`))
|
||||
counter.AddSSEData([]byte(`{"type":"response.completed","response":{"output":[{"id":"ig_1","type":"image_generation_call","result":"final-a"},{"id":"ig_2","type":"image_generation_call","result":"final-b"}]}}`))
|
||||
require.Equal(t, 2, counter.Count())
|
||||
}
|
||||
|
||||
func TestOpenAIImageOutputCounterCountsImagesAPIStreamShapes(t *testing.T) {
|
||||
counter := newOpenAIImageOutputCounter()
|
||||
counter.AddSSEData([]byte(`{"type":"image_generation.completed","id":"ig_complete","b64_json":"final-a"}`))
|
||||
counter.AddSSEData([]byte(`{"type":"response.output_item.done","item":{"id":"ig_item","type":"image_generation_call","result":"final-b"}}`))
|
||||
counter.AddSSEData([]byte(`{"type":"response.completed","response":{"output":[{"id":"ig_done","type":"image_generation_call","result":"final-c"}]}}`))
|
||||
require.Equal(t, 3, counter.Count())
|
||||
|
||||
dataCounter := newOpenAIImageOutputCounter()
|
||||
dataCounter.AddSSEData([]byte(`{"data":[{"b64_json":"a"},{"b64_json":"b"}]}`))
|
||||
dataCounter.AddSSEData([]byte(`{"data":[{"b64_json":"a"},{"b64_json":"b"},{"b64_json":"c"}]}`))
|
||||
require.Equal(t, 3, dataCounter.Count())
|
||||
}
|
||||
|
||||
func TestOpenAIImageOutputCounterCountsMultilineSSEDataPayload(t *testing.T) {
|
||||
counter := newOpenAIImageOutputCounter()
|
||||
counter.AddSSEData([]byte("{\"type\":\"image_generation.completed\",\n\"b64_json\":\"final-a\"}"))
|
||||
require.Equal(t, 1, counter.Count())
|
||||
}
|
||||
|
||||
func TestOpenAIImageOutputCounterCountsMultilineSSEBodyPayload(t *testing.T) {
|
||||
counter := newOpenAIImageOutputCounter()
|
||||
counter.AddSSEBody(
|
||||
"data: {\"type\":\"image_generation.completed\",\n" +
|
||||
"data: \"b64_json\":\"final-a\"}\n\n" +
|
||||
"data: [DONE]\n\n",
|
||||
)
|
||||
require.Equal(t, 1, counter.Count())
|
||||
}
|
||||
|
||||
func TestOpenAIImageOutputCounterFallsBackForInvalidMultilineSSEBody(t *testing.T) {
|
||||
counter := newOpenAIImageOutputCounter()
|
||||
counter.AddSSEBody(
|
||||
"data: {\"type\":\"image_generation.completed\",\"b64_json\":\"final-a\"}\n" +
|
||||
"data: {\"type\":\"image_generation.completed\",\"b64_json\":\"final-b\"}\n\n",
|
||||
)
|
||||
require.Equal(t, 2, counter.Count())
|
||||
}
|
||||
149
backend/internal/service/image_output_accounting.go
Normal file
149
backend/internal/service/image_output_accounting.go
Normal file
@ -0,0 +1,149 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type openAIImageOutputCounter struct {
|
||||
seen map[string]struct{}
|
||||
count int
|
||||
maxDataCount int
|
||||
}
|
||||
|
||||
func newOpenAIImageOutputCounter() *openAIImageOutputCounter {
|
||||
return &openAIImageOutputCounter{seen: make(map[string]struct{})}
|
||||
}
|
||||
|
||||
func (c *openAIImageOutputCounter) Count() int {
|
||||
if c == nil {
|
||||
return 0
|
||||
}
|
||||
if c.maxDataCount > c.count {
|
||||
return c.maxDataCount
|
||||
}
|
||||
return c.count
|
||||
}
|
||||
|
||||
func (c *openAIImageOutputCounter) AddJSONResponse(body []byte) {
|
||||
if c == nil || len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
return
|
||||
}
|
||||
c.addDataArray(gjson.GetBytes(body, "data"))
|
||||
c.addOutputArray(gjson.GetBytes(body, "output"))
|
||||
c.addOutputArray(gjson.GetBytes(body, "response.output"))
|
||||
}
|
||||
|
||||
func (c *openAIImageOutputCounter) AddSSEData(data []byte) {
|
||||
if c == nil || len(data) == 0 || strings.TrimSpace(string(data)) == "[DONE]" || !gjson.ValidBytes(data) {
|
||||
return
|
||||
}
|
||||
root := gjson.ParseBytes(data)
|
||||
c.addDataArray(root.Get("data"))
|
||||
eventType := strings.TrimSpace(root.Get("type").String())
|
||||
switch eventType {
|
||||
case "response.output_item.done":
|
||||
c.addImageOutputItem(root.Get("item"))
|
||||
case "response.completed", "response.done":
|
||||
c.addOutputArray(root.Get("response.output"))
|
||||
case "image_generation.completed":
|
||||
if item := root.Get("item"); item.Exists() {
|
||||
c.addImageOutputItem(item)
|
||||
return
|
||||
}
|
||||
if output := root.Get("output"); output.Exists() {
|
||||
c.addImageOutputItem(output)
|
||||
return
|
||||
}
|
||||
c.addImageOutputItem(root)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *openAIImageOutputCounter) AddSSEBody(body string) {
|
||||
if c == nil || strings.TrimSpace(body) == "" {
|
||||
return
|
||||
}
|
||||
forEachOpenAISSEDataPayload(body, c.AddSSEData)
|
||||
}
|
||||
|
||||
func (c *openAIImageOutputCounter) addDataArray(data gjson.Result) {
|
||||
if !data.IsArray() {
|
||||
return
|
||||
}
|
||||
count := len(data.Array())
|
||||
if count > c.maxDataCount {
|
||||
c.maxDataCount = count
|
||||
}
|
||||
}
|
||||
|
||||
func (c *openAIImageOutputCounter) addOutputArray(output gjson.Result) {
|
||||
if !output.IsArray() {
|
||||
return
|
||||
}
|
||||
output.ForEach(func(_, item gjson.Result) bool {
|
||||
c.addImageOutputItem(item)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (c *openAIImageOutputCounter) addImageOutputItem(item gjson.Result) {
|
||||
if !item.Exists() || !item.IsObject() {
|
||||
return
|
||||
}
|
||||
itemType := strings.TrimSpace(item.Get("type").String())
|
||||
if itemType != "" && itemType != "image_generation_call" && itemType != "image_generation.completed" {
|
||||
return
|
||||
}
|
||||
if strings.Contains(strings.ToLower(item.Raw), "partial_image") {
|
||||
return
|
||||
}
|
||||
result := strings.TrimSpace(item.Get("result").String())
|
||||
if result == "" {
|
||||
result = strings.TrimSpace(item.Get("b64_json").String())
|
||||
}
|
||||
if result == "" {
|
||||
result = strings.TrimSpace(item.Get("url").String())
|
||||
}
|
||||
if result == "" && itemType != "image_generation.completed" {
|
||||
return
|
||||
}
|
||||
key := strings.TrimSpace(item.Get("id").String())
|
||||
if key == "" {
|
||||
key = strings.TrimSpace(item.Get("call_id").String())
|
||||
}
|
||||
if key == "" {
|
||||
key = hashOpenAIImageOutputResult(result)
|
||||
}
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
if _, exists := c.seen[key]; exists {
|
||||
return
|
||||
}
|
||||
c.seen[key] = struct{}{}
|
||||
c.count++
|
||||
}
|
||||
|
||||
func hashOpenAIImageOutputResult(result string) string {
|
||||
result = strings.TrimSpace(result)
|
||||
if result == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(result))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func countOpenAIResponseImageOutputsFromJSONBytes(body []byte) int {
|
||||
counter := newOpenAIImageOutputCounter()
|
||||
counter.AddJSONResponse(body)
|
||||
return counter.Count()
|
||||
}
|
||||
|
||||
func countOpenAIImageOutputsFromSSEBody(body string) int {
|
||||
counter := newOpenAIImageOutputCounter()
|
||||
counter.AddSSEBody(body)
|
||||
return counter.Count()
|
||||
}
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"log/slog"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
@ -345,7 +346,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
if !s.isAccountRequestCompatible(account, req) {
|
||||
if !s.isAccountRequestCompatible(ctx, account, req) {
|
||||
return nil, nil
|
||||
}
|
||||
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
|
||||
@ -621,7 +622,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
if !s.isAccountRequestCompatible(account, req) {
|
||||
if !s.isAccountRequestCompatible(ctx, account, req) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
|
||||
@ -822,11 +823,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
for i := 0; i < len(selectionOrder); i++ {
|
||||
candidate := selectionOrder[i]
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||
continue
|
||||
}
|
||||
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||
continue
|
||||
}
|
||||
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
|
||||
@ -853,11 +854,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
|
||||
for _, candidate := range selectionOrder {
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||
continue
|
||||
}
|
||||
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||
continue
|
||||
}
|
||||
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
|
||||
@ -888,13 +889,18 @@ func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Ac
|
||||
return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport)
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(account *Account, req OpenAIAccountScheduleRequest) bool {
|
||||
func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.Context, account *Account, req OpenAIAccountScheduleRequest) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
||||
return false
|
||||
}
|
||||
if req.GroupID != nil && s != nil && s.service != nil &&
|
||||
s.service.needsUpstreamChannelRestrictionCheck(ctx, req.GroupID) &&
|
||||
s.service.isUpstreamModelRestrictedByChannel(ctx, *req.GroupID, account, req.RequestedModel, req.RequireCompact) {
|
||||
return false
|
||||
}
|
||||
return account.SupportsOpenAIImageCapability(req.RequiredImageCapability)
|
||||
}
|
||||
|
||||
@ -1106,6 +1112,13 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
|
||||
}
|
||||
}
|
||||
|
||||
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||
slog.Warn("channel pricing restriction blocked request",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"model", requestedModel)
|
||||
return nil, decision, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
||||
}
|
||||
|
||||
var stickyAccountID int64
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && accountID > 0 {
|
||||
|
||||
@ -485,12 +485,14 @@ func normalizeKnownCodexModel(model string) (string, bool) {
|
||||
return model, true
|
||||
}
|
||||
|
||||
modelID := model
|
||||
if strings.Contains(modelID, "/") {
|
||||
parts := strings.Split(modelID, "/")
|
||||
modelID = parts[len(parts)-1]
|
||||
}
|
||||
modelID := lastOpenAIModelSegment(model)
|
||||
|
||||
if normalized := canonicalizeOpenAIModelAliasSpelling(modelID); normalized != "" {
|
||||
modelID = normalized
|
||||
}
|
||||
if mapped := normalizeKnownOpenAICodexModel(modelID); mapped != "" {
|
||||
return mapped, true
|
||||
}
|
||||
key := codexModelLookupKey(modelID)
|
||||
if key == "" {
|
||||
return "", false
|
||||
|
||||
@ -804,15 +804,25 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||
func TestNormalizeCodexModel_Gpt53(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
"gpt5.5": "gpt-5.5",
|
||||
"openai/gpt5.5": "gpt-5.5",
|
||||
"gpt5.4": "gpt-5.4",
|
||||
"gpt-5.4-high": "gpt-5.4",
|
||||
"gpt-5.4-chat-latest": "gpt-5.4",
|
||||
"gpt 5.4": "gpt-5.4",
|
||||
"gpt-5.4-mini": "gpt-5.4-mini",
|
||||
"gpt5.4-mini": "gpt-5.4-mini",
|
||||
"gpt5.4mini": "gpt-5.4-mini",
|
||||
"gpt 5.4 mini": "gpt-5.4-mini",
|
||||
"gpt-5.3": "gpt-5.3-codex",
|
||||
"gpt5.3": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex": "gpt-5.3-codex",
|
||||
"gpt5.3-codex": "gpt-5.3-codex",
|
||||
"gpt5.3codex": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
|
||||
"gpt5.3-codex-spark": "gpt-5.3-codex-spark",
|
||||
"gpt5.3codexspark": "gpt-5.3-codex-spark",
|
||||
"gpt 5.3 codex spark": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
|
||||
|
||||
@ -52,6 +52,12 @@ func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *Usage
|
||||
return &UsageBillingApplyResult{Applied: true}, nil
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_RejectsNilInput(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
require.Error(t, svc.RecordUsage(context.Background(), nil))
|
||||
require.Error(t, svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{}))
|
||||
}
|
||||
|
||||
type openAIRecordUsageUserRepoStub struct {
|
||||
UserRepository
|
||||
|
||||
@ -1081,6 +1087,101 @@ func TestOpenAIGatewayServiceRecordUsage_ChannelMappedOverridesBillingModelWhenM
|
||||
require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillsCompactOpenAIModelAlias(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost("gpt-5.5", UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_compact_openai_alias",
|
||||
Model: "gpt5.5",
|
||||
UpstreamModel: "gpt-5.4",
|
||||
Usage: usage,
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10},
|
||||
User: &User{ID: 20},
|
||||
Account: &Account{ID: 30},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "gpt5.5", usageRepo.lastLog.Model)
|
||||
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
|
||||
require.Equal(t, "gpt-5.4", *usageRepo.lastLog.UpstreamModel)
|
||||
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero")
|
||||
require.InDelta(t, expectedCost.ActualCost, userRepo.lastAmount, 1e-12)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_FallsBackToUpstreamModelWhenPrimaryUnpriceable(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_unpriceable_primary_upstream_fallback",
|
||||
Model: "not-priceable-alias",
|
||||
BillingModel: "not-priceable-alias",
|
||||
UpstreamModel: "gpt-5.4",
|
||||
Usage: usage,
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10},
|
||||
User: &User{ID: 20},
|
||||
Account: &Account{ID: 30},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero")
|
||||
require.InDelta(t, expectedCost.ActualCost, userRepo.lastAmount, 1e-12)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ReturnsErrorWhenTokenModelCannotBePriced(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_unpriceable_without_upstream",
|
||||
Model: "not-priceable-alias",
|
||||
Usage: OpenAIUsage{InputTokens: 20, OutputTokens: 10},
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10},
|
||||
User: &User{ID: 20},
|
||||
Account: &Account{ID: 30},
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "calculate OpenAI usage cost failed")
|
||||
require.Equal(t, 0, usageRepo.calls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
@ -1209,3 +1310,278 @@ func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTo
|
||||
require.InDelta(t, 0.0, usageRepo.lastLog.OutputCost, 1e-12)
|
||||
require.InDelta(t, 0.0, usageRepo.lastLog.ImageOutputCost, 1e-12)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ImageSharedMultiplierPreservesExistingBehavior(t *testing.T) {
|
||||
imagePrice := 0.2
|
||||
groupID := int64(121)
|
||||
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_image_shared_multiplier",
|
||||
Model: "gpt-image-2",
|
||||
ImageCount: 1,
|
||||
ImageSize: "1K",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10121,
|
||||
GroupID: i64p(groupID),
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
RateMultiplier: 0.15,
|
||||
ImageRateIndependent: false,
|
||||
ImageRateMultiplier: 1,
|
||||
ImagePrice1K: &imagePrice,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 20121},
|
||||
Account: &Account{ID: 30121},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.InDelta(t, 0.2, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
require.InDelta(t, 0.03, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, 0.15, usageRepo.lastLog.RateMultiplier, 1e-12)
|
||||
require.NotNil(t, usageRepo.lastLog.BillingMode)
|
||||
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ImageSharedMultiplierUsesUserGroupOverride(t *testing.T) {
|
||||
imagePrice := 0.5
|
||||
userRate := 0.2
|
||||
groupID := int64(125)
|
||||
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
svc := newOpenAIRecordUsageServiceForTest(
|
||||
usageRepo,
|
||||
&openAIRecordUsageUserRepoStub{},
|
||||
&openAIRecordUsageSubRepoStub{},
|
||||
&openAIUserGroupRateRepoStub{rate: &userRate},
|
||||
)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_image_user_group_override",
|
||||
Model: "gpt-image-2",
|
||||
ImageCount: 1,
|
||||
ImageSize: "1K",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10125,
|
||||
GroupID: i64p(groupID),
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
RateMultiplier: 0.15,
|
||||
ImageRateIndependent: false,
|
||||
ImageRateMultiplier: 1,
|
||||
ImagePrice1K: &imagePrice,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 20125},
|
||||
Account: &Account{ID: 30125},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.InDelta(t, 0.5, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
require.InDelta(t, 0.1, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, 0.2, usageRepo.lastLog.RateMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ImageIndependentMultiplierUsesImageRate(t *testing.T) {
|
||||
imagePrice := 0.2
|
||||
groupID := int64(122)
|
||||
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_image_independent_multiplier",
|
||||
Model: "gpt-image-2",
|
||||
ImageCount: 1,
|
||||
ImageSize: "1K",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10122,
|
||||
GroupID: i64p(groupID),
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
RateMultiplier: 0.15,
|
||||
ImageRateIndependent: true,
|
||||
ImageRateMultiplier: 1,
|
||||
ImagePrice1K: &imagePrice,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 20122},
|
||||
Account: &Account{ID: 30122},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.InDelta(t, 0.2, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
require.InDelta(t, 0.2, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, 1.0, usageRepo.lastLog.RateMultiplier, 1e-12)
|
||||
require.NotNil(t, usageRepo.lastLog.BillingMode)
|
||||
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ChannelImageBillingUsesImageCountAndSharedMultiplier(t *testing.T) {
|
||||
groupID := int64(123)
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
|
||||
svc.resolver = newOpenAIImageChannelPricingResolverForTest(t, groupID, "gpt-image-2", 0.25)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_image_channel_shared",
|
||||
Model: "gpt-image-2",
|
||||
ImageCount: 3,
|
||||
ImageSize: "1K",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10123,
|
||||
GroupID: i64p(groupID),
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
RateMultiplier: 0.15,
|
||||
ImageRateIndependent: false,
|
||||
ImageRateMultiplier: 1,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 20123},
|
||||
Account: &Account{ID: 30123},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.InDelta(t, 0.75, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
require.InDelta(t, 0.1125, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, 0.15, usageRepo.lastLog.RateMultiplier, 1e-12)
|
||||
require.Equal(t, 3, usageRepo.lastLog.ImageCount)
|
||||
require.NotNil(t, usageRepo.lastLog.BillingMode)
|
||||
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ChannelImageBillingUsesImageCountAndIndependentMultiplier(t *testing.T) {
|
||||
groupID := int64(124)
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
|
||||
svc.resolver = newOpenAIImageChannelPricingResolverForTest(t, groupID, "gpt-image-2", 0.25)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_image_channel_independent",
|
||||
Model: "gpt-image-2",
|
||||
ImageCount: 3,
|
||||
ImageSize: "1K",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10124,
|
||||
GroupID: i64p(groupID),
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
RateMultiplier: 0.15,
|
||||
ImageRateIndependent: true,
|
||||
ImageRateMultiplier: 1,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 20124},
|
||||
Account: &Account{ID: 30124},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.InDelta(t, 0.75, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
require.InDelta(t, 0.75, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, 1.0, usageRepo.lastLog.RateMultiplier, 1e-12)
|
||||
require.Equal(t, 3, usageRepo.lastLog.ImageCount)
|
||||
require.NotNil(t, usageRepo.lastLog.BillingMode)
|
||||
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
|
||||
}
|
||||
|
||||
func newOpenAIImageChannelPricingResolverForTest(t *testing.T, groupID int64, model string, price float64) *ModelPricingResolver {
|
||||
t.Helper()
|
||||
cache := newEmptyChannelCache()
|
||||
cache.pricingByGroupModel[channelModelKey{groupID: groupID, model: model}] = &ChannelModelPricing{
|
||||
BillingMode: BillingModeImage,
|
||||
PerRequestPrice: &price,
|
||||
}
|
||||
cache.channelByGroupID[groupID] = &Channel{ID: groupID, Status: StatusActive}
|
||||
cache.groupPlatform[groupID] = ""
|
||||
cache.loadedAt = time.Now()
|
||||
cs := &ChannelService{}
|
||||
cs.cache.Store(cache)
|
||||
return NewModelPricingResolver(cs, NewBillingService(&config.Config{}, nil))
|
||||
}
|
||||
|
||||
func TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesImageCount(t *testing.T) {
|
||||
groupID := int64(126)
|
||||
billingService := NewBillingService(&config.Config{}, nil)
|
||||
svc := &GatewayService{
|
||||
billingService: billingService,
|
||||
resolver: newOpenAIImageChannelPricingResolverForTest(t, groupID, "gemini-image", 0.25),
|
||||
}
|
||||
|
||||
cost := svc.calculateRecordUsageCost(
|
||||
context.Background(),
|
||||
&ForwardResult{Model: "gemini-image", ImageCount: 2, ImageSize: "1K"},
|
||||
&APIKey{GroupID: i64p(groupID), Group: &Group{ID: groupID}},
|
||||
"gemini-image",
|
||||
0.15,
|
||||
1.0,
|
||||
nil,
|
||||
)
|
||||
|
||||
require.NotNil(t, cost)
|
||||
require.Equal(t, string(BillingModeImage), cost.BillingMode)
|
||||
require.InDelta(t, 0.5, cost.TotalCost, 1e-12)
|
||||
require.InDelta(t, 0.5, cost.ActualCost, 1e-12)
|
||||
}
|
||||
|
||||
func TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesSizeTier(t *testing.T) {
|
||||
groupID := int64(127)
|
||||
defaultPrice := 0.10
|
||||
price4K := 0.40
|
||||
cache := newEmptyChannelCache()
|
||||
cache.pricingByGroupModel[channelModelKey{groupID: groupID, model: "gemini-image"}] = &ChannelModelPricing{
|
||||
BillingMode: BillingModeImage,
|
||||
PerRequestPrice: &defaultPrice,
|
||||
Intervals: []PricingInterval{{
|
||||
TierLabel: "4K",
|
||||
PerRequestPrice: &price4K,
|
||||
}},
|
||||
}
|
||||
cache.channelByGroupID[groupID] = &Channel{ID: groupID, Status: StatusActive}
|
||||
cache.loadedAt = time.Now()
|
||||
channelService := &ChannelService{}
|
||||
channelService.cache.Store(cache)
|
||||
|
||||
svc := &GatewayService{
|
||||
billingService: NewBillingService(&config.Config{}, nil),
|
||||
resolver: NewModelPricingResolver(channelService, NewBillingService(&config.Config{}, nil)),
|
||||
}
|
||||
|
||||
cost := svc.calculateRecordUsageCost(
|
||||
context.Background(),
|
||||
&ForwardResult{Model: "gemini-image", ImageCount: 2, ImageSize: "4K"},
|
||||
&APIKey{GroupID: i64p(groupID), Group: &Group{ID: groupID}},
|
||||
"gemini-image",
|
||||
1.0,
|
||||
1.0,
|
||||
nil,
|
||||
)
|
||||
|
||||
require.NotNil(t, cost)
|
||||
require.Equal(t, string(BillingModeImage), cost.BillingMode)
|
||||
require.InDelta(t, 0.80, cost.TotalCost, 1e-12)
|
||||
require.InDelta(t, 0.80, cost.ActualCost, 1e-12)
|
||||
}
|
||||
|
||||
@ -2049,6 +2049,21 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
promptCacheKey = strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
apiKey := getAPIKeyFromContext(c)
|
||||
imageGenerationAllowed := GroupAllowsImageGeneration(nil)
|
||||
if apiKey != nil {
|
||||
imageGenerationAllowed = GroupAllowsImageGeneration(apiKey.Group)
|
||||
}
|
||||
if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed {
|
||||
setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "permission_error",
|
||||
"message": ImageGenerationPermissionMessage(),
|
||||
},
|
||||
})
|
||||
return nil, errors.New("image generation disabled for group")
|
||||
}
|
||||
|
||||
// Track if body needs re-serialization
|
||||
bodyModified := false
|
||||
@ -2108,7 +2123,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
markPatchSet("instructions", "You are a helpful coding assistant.")
|
||||
}
|
||||
|
||||
if isCodexCLI && ensureOpenAIResponsesImageGenerationTool(reqBody) {
|
||||
if isCodexCLI && imageGenerationAllowed && ensureOpenAIResponsesImageGenerationTool(reqBody) {
|
||||
bodyModified = true
|
||||
disablePatch()
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Injected /responses image_generation tool for Codex client")
|
||||
@ -2119,7 +2134,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
disablePatch()
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload")
|
||||
}
|
||||
if isCodexCLI && applyCodexImageGenerationBridgeInstructions(reqBody) {
|
||||
if isCodexCLI && imageGenerationAllowed && applyCodexImageGenerationBridgeInstructions(reqBody) {
|
||||
bodyModified = true
|
||||
disablePatch()
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Added Codex image_generation bridge instructions")
|
||||
@ -2134,7 +2149,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
markPatchSet("model", billingModel)
|
||||
}
|
||||
upstreamModel := billingModel
|
||||
if normalizeOpenAIResponsesImageOnlyModel(reqBody) {
|
||||
if imageGenerationAllowed && normalizeOpenAIResponsesImageOnlyModel(reqBody) {
|
||||
bodyModified = true
|
||||
disablePatch()
|
||||
if model, ok := reqBody["model"].(string); ok {
|
||||
@ -2355,6 +2370,34 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
}
|
||||
|
||||
if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed {
|
||||
setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "permission_error",
|
||||
"message": ImageGenerationPermissionMessage(),
|
||||
},
|
||||
})
|
||||
return nil, errors.New("image generation disabled for group")
|
||||
}
|
||||
imageBillingModel := ""
|
||||
imageSizeTier := ""
|
||||
if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) {
|
||||
var imageCfgErr error
|
||||
imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfig(reqBody, billingModel)
|
||||
if imageCfgErr != nil {
|
||||
setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "")
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "invalid_request_error",
|
||||
"message": imageCfgErr.Error(),
|
||||
"param": "size",
|
||||
},
|
||||
})
|
||||
return nil, imageCfgErr
|
||||
}
|
||||
}
|
||||
|
||||
// Re-serialize body only if modified
|
||||
if bodyModified {
|
||||
serializedByPatch := false
|
||||
@ -2592,6 +2635,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
wsAttempts,
|
||||
)
|
||||
wsResult.UpstreamModel = upstreamModel
|
||||
if wsResult.ImageCount > 0 {
|
||||
wsResult.ImageSize = imageSizeTier
|
||||
wsResult.BillingModel = imageBillingModel
|
||||
}
|
||||
return wsResult, nil
|
||||
}
|
||||
s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr)
|
||||
@ -2695,6 +2742,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
// Handle normal response
|
||||
var usage *OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
imageCount := 0
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel)
|
||||
if err != nil {
|
||||
@ -2702,11 +2750,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
usage = streamResult.usage
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
imageCount = streamResult.imageCount
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel)
|
||||
nonStreamResult, err := s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = nonStreamResult.usage
|
||||
imageCount = nonStreamResult.imageCount
|
||||
}
|
||||
|
||||
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||||
@ -2723,7 +2774,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
|
||||
serviceTier := extractOpenAIServiceTier(reqBody)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
forwardResult := &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
@ -2734,7 +2785,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
OpenAIWSMode: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
if imageCount > 0 {
|
||||
forwardResult.ImageCount = imageCount
|
||||
forwardResult.ImageSize = imageSizeTier
|
||||
forwardResult.BillingModel = imageBillingModel
|
||||
}
|
||||
return forwardResult, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -2823,6 +2880,35 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
}
|
||||
body = updatedBody
|
||||
|
||||
apiKey := getAPIKeyFromContext(c)
|
||||
if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) && !GroupAllowsImageGeneration(apiKeyGroup(apiKey)) {
|
||||
setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "permission_error",
|
||||
"message": ImageGenerationPermissionMessage(),
|
||||
},
|
||||
})
|
||||
return nil, errors.New("image generation disabled for group")
|
||||
}
|
||||
imageBillingModel := ""
|
||||
imageSizeTier := ""
|
||||
if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) {
|
||||
var imageCfgErr error
|
||||
imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(body, reqModel)
|
||||
if imageCfgErr != nil {
|
||||
setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "")
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "invalid_request_error",
|
||||
"message": imageCfgErr.Error(),
|
||||
"param": "size",
|
||||
},
|
||||
})
|
||||
return nil, imageCfgErr
|
||||
}
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.openai_gateway",
|
||||
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
|
||||
account.ID,
|
||||
@ -2905,6 +2991,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
|
||||
var usage *OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
imageCount := 0
|
||||
if reqStream {
|
||||
result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime, reqModel, upstreamPassthroughModel)
|
||||
if err != nil {
|
||||
@ -2912,11 +2999,14 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
}
|
||||
usage = result.usage
|
||||
firstTokenMs = result.firstTokenMs
|
||||
imageCount = result.imageCount
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel)
|
||||
result, err := s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = result.usage
|
||||
imageCount = result.imageCount
|
||||
}
|
||||
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
@ -2927,7 +3017,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
usage = &OpenAIUsage{}
|
||||
}
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
forwardResult := &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: reqModel,
|
||||
@ -2938,7 +3028,13 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
OpenAIWSMode: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
if imageCount > 0 {
|
||||
forwardResult.ImageCount = imageCount
|
||||
forwardResult.ImageSize = imageSizeTier
|
||||
forwardResult.BillingModel = imageBillingModel
|
||||
}
|
||||
return forwardResult, nil
|
||||
}
|
||||
|
||||
func logOpenAIPassthroughInstructionsRejected(
|
||||
@ -3233,6 +3329,13 @@ func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string {
|
||||
type openaiStreamingResultPassthrough struct {
|
||||
usage *OpenAIUsage
|
||||
firstTokenMs *int
|
||||
imageCount int
|
||||
}
|
||||
|
||||
type openaiNonStreamingResultPassthrough struct {
|
||||
*OpenAIUsage
|
||||
usage *OpenAIUsage
|
||||
imageCount int
|
||||
}
|
||||
|
||||
func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool {
|
||||
@ -3369,6 +3472,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
}
|
||||
|
||||
usage := &OpenAIUsage{}
|
||||
imageCounter := newOpenAIImageOutputCounter()
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
sawDone := false
|
||||
@ -3400,6 +3504,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
|
||||
needModelReplace := strings.TrimSpace(originalModel) != "" && strings.TrimSpace(mappedModel) != "" && strings.TrimSpace(originalModel) != strings.TrimSpace(mappedModel)
|
||||
resultWithUsage := func() *openaiStreamingResultPassthrough {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()}
|
||||
}
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
@ -3419,7 +3526,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
if eventType == "response.failed" {
|
||||
failedMessage = extractOpenAISSEErrorMessage(dataBytes)
|
||||
if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
|
||||
return resultWithUsage(),
|
||||
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, dataBytes, failedMessage)
|
||||
}
|
||||
forceFlushFailedEvent = true
|
||||
@ -3431,6 +3538,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
if openAIStreamEventIsTerminal(trimmedData) {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
imageCounter.AddSSEData(dataBytes)
|
||||
lineStartsClientOutput = forceFlushFailedEvent || openAIStreamDataStartsClientOutput(trimmedData, eventType)
|
||||
if firstTokenMs == nil && lineStartsClientOutput && trimmedData != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
@ -3460,28 +3568,28 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
if sawTerminalEvent && !sawFailedEvent {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
if sawFailedEvent {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
|
||||
return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage)
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
|
||||
}
|
||||
if errors.Is(err, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
return resultWithUsage(), err
|
||||
}
|
||||
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
|
||||
msg := "OpenAI stream disconnected before completion"
|
||||
if errText := strings.TrimSpace(err.Error()); errText != "" {
|
||||
msg += ": " + errText
|
||||
}
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
|
||||
return resultWithUsage(),
|
||||
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, msg)
|
||||
}
|
||||
if clientDisconnected {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", err)
|
||||
}
|
||||
logger.LegacyPrintf("service.openai_gateway",
|
||||
"[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
|
||||
@ -3489,10 +3597,10 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
upstreamRequestID,
|
||||
err,
|
||||
)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
return resultWithUsage(), fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
if sawFailedEvent {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
|
||||
return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage)
|
||||
}
|
||||
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
|
||||
logger.FromContext(ctx).With(
|
||||
@ -3501,13 +3609,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
zap.String("upstream_request_id", upstreamRequestID),
|
||||
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
|
||||
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
|
||||
return resultWithUsage(),
|
||||
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, "OpenAI stream ended before a terminal event")
|
||||
}
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
|
||||
return resultWithUsage(), errors.New("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
|
||||
@ -3516,7 +3624,7 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
|
||||
c *gin.Context,
|
||||
originalModel string,
|
||||
mappedModel string,
|
||||
) (*OpenAIUsage, error) {
|
||||
) (*openaiNonStreamingResultPassthrough, error) {
|
||||
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -3553,14 +3661,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
return usage, nil
|
||||
return &openaiNonStreamingResultPassthrough{
|
||||
OpenAIUsage: usage,
|
||||
usage: usage,
|
||||
imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handlePassthroughSSEToJSON converts an SSE response body into a JSON
|
||||
// response for the passthrough path. It mirrors handleSSEToJSON while
|
||||
// preserving passthrough payloads, except compact-only model remapping may
|
||||
// rewrite model fields back to the original requested model.
|
||||
func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*OpenAIUsage, error) {
|
||||
func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*openaiNonStreamingResultPassthrough, error) {
|
||||
bodyText := string(body)
|
||||
finalResponse, ok := extractCodexFinalResponse(bodyText)
|
||||
|
||||
@ -3611,7 +3723,11 @@ func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
|
||||
return usage, nil
|
||||
return &openaiNonStreamingResultPassthrough{
|
||||
OpenAIUsage: usage,
|
||||
usage: usage,
|
||||
imageCount: countOpenAIImageOutputsFromSSEBody(bodyText),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) {
|
||||
@ -4025,6 +4141,13 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse(
|
||||
type openaiStreamingResult struct {
|
||||
usage *OpenAIUsage
|
||||
firstTokenMs *int
|
||||
imageCount int
|
||||
}
|
||||
|
||||
type openaiNonStreamingResult struct {
|
||||
*OpenAIUsage
|
||||
usage *OpenAIUsage
|
||||
imageCount int
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
|
||||
@ -4058,6 +4181,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
}
|
||||
|
||||
usage := &OpenAIUsage{}
|
||||
imageCounter := newOpenAIImageOutputCounter()
|
||||
var firstTokenMs *int
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
@ -4136,7 +4260,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
resultWithUsage := func() *openaiStreamingResult {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()}
|
||||
}
|
||||
finalizeStream := func() (*openaiStreamingResult, error) {
|
||||
if !sawTerminalEvent {
|
||||
@ -4231,6 +4355,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
forceFlushFailedEvent = true
|
||||
sawFailedEvent = true
|
||||
}
|
||||
imageCounter.AddSSEData(dataBytes)
|
||||
|
||||
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
||||
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
|
||||
@ -4496,7 +4621,7 @@ func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
|
||||
}, true
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*openaiNonStreamingResult, error) {
|
||||
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -4542,7 +4667,11 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
|
||||
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
|
||||
return usage, nil
|
||||
return &openaiNonStreamingResult{
|
||||
OpenAIUsage: usage,
|
||||
usage: usage,
|
||||
imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func isEventStreamResponse(header http.Header) bool {
|
||||
@ -4550,7 +4679,7 @@ func isEventStreamResponse(header http.Header) bool {
|
||||
return strings.Contains(contentType, "text/event-stream")
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||
func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*openaiNonStreamingResult, error) {
|
||||
bodyText := string(body)
|
||||
finalResponse, ok := extractCodexFinalResponse(bodyText)
|
||||
|
||||
@ -4602,21 +4731,29 @@ func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Conte
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
|
||||
return usage, nil
|
||||
return &openaiNonStreamingResult{
|
||||
OpenAIUsage: usage,
|
||||
usage: usage,
|
||||
imageCount: countOpenAIImageOutputsFromSSEBody(bodyText),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) {
|
||||
lines := strings.Split(body, "\n")
|
||||
for _, line := range lines {
|
||||
data, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok || data == "" || data == "[DONE]" {
|
||||
continue
|
||||
var terminalType string
|
||||
var terminalPayload []byte
|
||||
forEachOpenAISSEDataPayload(body, func(data []byte) {
|
||||
if terminalPayload != nil {
|
||||
return
|
||||
}
|
||||
eventType := strings.TrimSpace(gjson.Get(data, "type").String())
|
||||
eventType := strings.TrimSpace(gjson.GetBytes(data, "type").String())
|
||||
switch eventType {
|
||||
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
|
||||
return eventType, []byte(data), true
|
||||
terminalType = eventType
|
||||
terminalPayload = append([]byte(nil), data...)
|
||||
}
|
||||
})
|
||||
if terminalPayload != nil {
|
||||
return terminalType, terminalPayload, true
|
||||
}
|
||||
return "", nil, false
|
||||
}
|
||||
@ -4651,21 +4788,20 @@ func (s *OpenAIGatewayService) writeOpenAINonStreamingProtocolError(resp *http.R
|
||||
}
|
||||
|
||||
func extractCodexFinalResponse(body string) ([]byte, bool) {
|
||||
lines := strings.Split(body, "\n")
|
||||
for _, line := range lines {
|
||||
data, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
var finalResponse []byte
|
||||
forEachOpenAISSEDataPayload(body, func(data []byte) {
|
||||
if finalResponse != nil {
|
||||
return
|
||||
}
|
||||
if data == "" || data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
eventType := gjson.Get(data, "type").String()
|
||||
eventType := gjson.GetBytes(data, "type").String()
|
||||
if eventType == "response.done" || eventType == "response.completed" {
|
||||
if response := gjson.Get(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" {
|
||||
return []byte(response.Raw), true
|
||||
if response := gjson.GetBytes(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" {
|
||||
finalResponse = []byte(response.Raw)
|
||||
}
|
||||
}
|
||||
})
|
||||
if finalResponse != nil {
|
||||
return finalResponse, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@ -4677,21 +4813,15 @@ func reconstructResponseOutputFromSSE(bodyText string) ([]byte, bool) {
|
||||
acc := apicompat.NewBufferedResponseAccumulator()
|
||||
imageOutputs := make([]json.RawMessage, 0, 1)
|
||||
seenImages := make(map[string]struct{})
|
||||
lines := strings.Split(bodyText, "\n")
|
||||
for _, line := range lines {
|
||||
data, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok || data == "" || data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
if imageOutput, ok := extractImageGenerationOutputFromSSEData([]byte(data), seenImages); ok {
|
||||
forEachOpenAISSEDataPayload(bodyText, func(data []byte) {
|
||||
if imageOutput, ok := extractImageGenerationOutputFromSSEData(data, seenImages); ok {
|
||||
imageOutputs = append(imageOutputs, imageOutput)
|
||||
}
|
||||
var event apicompat.ResponsesStreamEvent
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
continue
|
||||
if err := json.Unmarshal(data, &event); err == nil {
|
||||
acc.ProcessEvent(&event)
|
||||
}
|
||||
acc.ProcessEvent(&event)
|
||||
}
|
||||
})
|
||||
if !acc.HasContent() && len(imageOutputs) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
@ -4744,17 +4874,9 @@ func extractImageGenerationOutputFromSSEData(data []byte, seen map[string]struct
|
||||
|
||||
func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
|
||||
usage := &OpenAIUsage{}
|
||||
lines := strings.Split(body, "\n")
|
||||
for _, line := range lines {
|
||||
data, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if data == "" || data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
s.parseSSEUsageBytes([]byte(data), usage)
|
||||
}
|
||||
forEachOpenAISSEDataPayload(body, func(data []byte) {
|
||||
s.parseSSEUsageBytes(data, usage)
|
||||
})
|
||||
return usage
|
||||
}
|
||||
|
||||
@ -5036,8 +5158,14 @@ type OpenAIRecordUsageInput struct {
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
|
||||
if input == nil {
|
||||
return errors.New("openai usage input is nil")
|
||||
}
|
||||
result := input.Result
|
||||
if s.rateLimitService != nil && input != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI {
|
||||
if result == nil {
|
||||
return errors.New("openai usage result is nil")
|
||||
}
|
||||
if s.rateLimitService != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI {
|
||||
s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
|
||||
}
|
||||
|
||||
@ -5074,6 +5202,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
}
|
||||
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
|
||||
}
|
||||
imageMultiplier := resolveImageRateMultiplier(apiKey, multiplier)
|
||||
|
||||
var cost *CostBreakdown
|
||||
var err error
|
||||
@ -5087,13 +5216,21 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
|
||||
billingModel = input.OriginalModel
|
||||
}
|
||||
billingModels := usageBillingModelCandidates(
|
||||
billingModel,
|
||||
result.BillingModel,
|
||||
input.ChannelMappedModel,
|
||||
input.OriginalModel,
|
||||
result.UpstreamModel,
|
||||
result.Model,
|
||||
)
|
||||
serviceTier := ""
|
||||
if result.ServiceTier != nil {
|
||||
serviceTier = strings.TrimSpace(*result.ServiceTier)
|
||||
}
|
||||
cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, tokens, serviceTier)
|
||||
cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModels, multiplier, imageMultiplier, tokens, serviceTier)
|
||||
if err != nil {
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
return err
|
||||
}
|
||||
|
||||
// Determine billing type
|
||||
@ -5143,7 +5280,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
usageLog.TotalCost = cost.TotalCost
|
||||
usageLog.ActualCost = cost.ActualCost
|
||||
}
|
||||
usageLog.RateMultiplier = multiplier
|
||||
if result.ImageCount > 0 {
|
||||
usageLog.RateMultiplier = imageMultiplier
|
||||
} else {
|
||||
usageLog.RateMultiplier = multiplier
|
||||
}
|
||||
usageLog.AccountRateMultiplier = &accountRateMultiplier
|
||||
usageLog.BillingType = billingType
|
||||
usageLog.Stream = result.Stream
|
||||
@ -5224,14 +5365,45 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
|
||||
ctx context.Context,
|
||||
result *OpenAIForwardResult,
|
||||
apiKey *APIKey,
|
||||
billingModels []string,
|
||||
multiplier float64,
|
||||
imageMultiplier float64,
|
||||
tokens UsageTokens,
|
||||
serviceTier string,
|
||||
) (*CostBreakdown, error) {
|
||||
billingModel := firstUsageBillingModel(billingModels)
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, imageMultiplier), nil
|
||||
}
|
||||
if len(billingModels) == 0 || billingModel == "" {
|
||||
return nil, errors.New("openai usage billing model is empty")
|
||||
}
|
||||
var lastErr error
|
||||
for _, candidate := range billingModels {
|
||||
candidate = strings.TrimSpace(candidate)
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
cost, err := s.calculateOpenAIRecordUsageTokenCost(ctx, apiKey, candidate, multiplier, tokens, serviceTier)
|
||||
if err == nil {
|
||||
return cost, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
if lastErr == nil {
|
||||
lastErr = errors.New("no non-empty billing model candidates")
|
||||
}
|
||||
return nil, fmt.Errorf("calculate OpenAI usage cost failed for billing models %s: %w", strings.Join(billingModels, ","), lastErr)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) calculateOpenAIRecordUsageTokenCost(
|
||||
ctx context.Context,
|
||||
apiKey *APIKey,
|
||||
billingModel string,
|
||||
multiplier float64,
|
||||
tokens UsageTokens,
|
||||
serviceTier string,
|
||||
) (*CostBreakdown, error) {
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil
|
||||
}
|
||||
if s.resolver != nil && apiKey.Group != nil {
|
||||
gid := apiKey.Group.ID
|
||||
return s.billingService.CalculateCostUnified(CostInput{
|
||||
@ -5262,7 +5434,7 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost(
|
||||
Ctx: ctx,
|
||||
Model: billingModel,
|
||||
GroupID: &gid,
|
||||
RequestCount: 1,
|
||||
RequestCount: result.ImageCount,
|
||||
SizeTier: result.ImageSize,
|
||||
RateMultiplier: multiplier,
|
||||
Resolver: s.resolver,
|
||||
|
||||
@ -0,0 +1,215 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestOpenAIGatewayServiceForward_RejectsDisabledImageGenerationIntents(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
}{
|
||||
{
|
||||
name: "image model",
|
||||
body: []byte(`{"model":"gpt-image-2","input":"draw"}`),
|
||||
},
|
||||
{
|
||||
name: "image tool",
|
||||
body: []byte(`{"model":"gpt-5.4","input":"draw","tools":[{"type":"image_generation"}]}`),
|
||||
},
|
||||
{
|
||||
name: "image tool choice",
|
||||
body: []byte(`{"model":"gpt-5.4","input":"draw","tool_choice":{"type":"image_generation"}}`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
upstream := &httpUpstreamRecorder{}
|
||||
svc := newOpenAIImageGenerationControlTestService(upstream)
|
||||
c, recorder := newOpenAIImageGenerationControlTestContext(false, "unit-test-agent/1.0")
|
||||
account := newOpenAIImageGenerationControlTestAccount()
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, tt.body)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, http.StatusForbidden, recorder.Code)
|
||||
require.Equal(t, "permission_error", gjson.GetBytes(recorder.Body.Bytes(), "error.type").String())
|
||||
require.Nil(t, upstream.lastReq, "disabled image request must not reach upstream")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForward_DisabledGroupAllowsTextOnlyResponses(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_text","model":"gpt-5.4","usage":{"input_tokens":3,"output_tokens":2}}`)),
|
||||
},
|
||||
}
|
||||
svc := newOpenAIImageGenerationControlTestService(upstream)
|
||||
c, recorder := newOpenAIImageGenerationControlTestContext(false, "unit-test-agent/1.0")
|
||||
account := newOpenAIImageGenerationControlTestAccount()
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, []byte(`{"model":"gpt-5.4","input":"write code","stream":false}`))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
require.Equal(t, 3, result.Usage.InputTokens)
|
||||
require.Equal(t, 2, result.Usage.OutputTokens)
|
||||
require.Equal(t, 0, result.ImageCount)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowImages bool
|
||||
wantInjected bool
|
||||
}{
|
||||
{name: "disabled group skips injection", allowImages: false, wantInjected: false},
|
||||
{name: "enabled group injects image tool", allowImages: true, wantInjected: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_codex","model":"gpt-5.4","usage":{"input_tokens":1,"output_tokens":1}}`)),
|
||||
},
|
||||
}
|
||||
svc := newOpenAIImageGenerationControlTestService(upstream)
|
||||
c, _ := newOpenAIImageGenerationControlTestContext(tt.allowImages, "codex_cli_rs/0.98.0")
|
||||
account := newOpenAIImageGenerationControlTestAccount()
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, []byte(`{"model":"gpt-5.4","input":"write code","stream":false}`))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
hasImageTool := gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation")`).Exists()
|
||||
require.Equal(t, tt.wantInjected, hasImageTool)
|
||||
instructions := gjson.GetBytes(upstream.lastBody, "instructions").String()
|
||||
require.Equal(t, tt.wantInjected, strings.Contains(instructions, "image_generation"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceHandleResponsesImageOutputs_NonStreaming(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
svc := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{})
|
||||
c, _ := newOpenAIImageGenerationControlTestContext(true, "unit-test-agent/1.0")
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{
|
||||
"id":"resp_image_json",
|
||||
"model":"gpt-5.4",
|
||||
"output":[{"id":"ig_json_1","type":"image_generation_call","result":"final-image"}],
|
||||
"usage":{"input_tokens":7,"output_tokens":3,"output_tokens_details":{"image_tokens":2}}
|
||||
}`)),
|
||||
}
|
||||
|
||||
result, err := svc.handleNonStreamingResponse(context.Background(), resp, c, &Account{ID: 1, Type: AccountTypeAPIKey}, "gpt-5.4", "gpt-5.4")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 1, result.imageCount)
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 7, result.usage.InputTokens)
|
||||
require.Equal(t, 3, result.usage.OutputTokens)
|
||||
require.Equal(t, 2, result.usage.ImageOutputTokens)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceHandleResponsesImageOutputs_Streaming(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
svc := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{})
|
||||
c, _ := newOpenAIImageGenerationControlTestContext(true, "unit-test-agent/1.0")
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
"data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_stream_1\",\"type\":\"image_generation_call\",\"result\":\"final-image\"}}\n\n" +
|
||||
"data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_image_stream\",\"model\":\"gpt-5.5\",\"output\":[{\"id\":\"ig_stream_1\",\"type\":\"image_generation_call\",\"result\":\"final-image\"}],\"usage\":{\"input_tokens\":11,\"output_tokens\":5,\"output_tokens_details\":{\"image_tokens\":4}}}}\n\n",
|
||||
)),
|
||||
}
|
||||
|
||||
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "gpt-5.5", "gpt-5.5")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 1, result.imageCount)
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 11, result.usage.InputTokens)
|
||||
require.Equal(t, 5, result.usage.OutputTokens)
|
||||
require.Equal(t, 4, result.usage.ImageOutputTokens)
|
||||
}
|
||||
|
||||
func newOpenAIImageGenerationControlTestService(upstream *httpUpstreamRecorder) *OpenAIGatewayService {
|
||||
cfg := &config.Config{}
|
||||
return &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: upstream,
|
||||
cache: &stubGatewayCache{},
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
}
|
||||
|
||||
func newOpenAIImageGenerationControlTestContext(allowImages bool, userAgent string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", userAgent)
|
||||
groupID := int64(4242)
|
||||
c.Set("api_key", &APIKey{
|
||||
ID: 2424,
|
||||
GroupID: &groupID,
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
AllowImageGeneration: allowImages,
|
||||
RateMultiplier: 1,
|
||||
ImageRateMultiplier: 1,
|
||||
},
|
||||
})
|
||||
return c, recorder
|
||||
}
|
||||
|
||||
func newOpenAIImageGenerationControlTestAccount() *Account {
|
||||
return &Account{
|
||||
ID: 5151,
|
||||
Name: "openai-image-controls",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -16,6 +16,7 @@ import (
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
@ -468,14 +469,54 @@ func isOpenAINativeImageOption(name string) bool {
|
||||
}
|
||||
|
||||
func normalizeOpenAIImageSizeTier(size string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(size)) {
|
||||
trimmed := strings.TrimSpace(size)
|
||||
normalized := strings.ToLower(trimmed)
|
||||
switch normalized {
|
||||
case "", "auto":
|
||||
return "2K"
|
||||
case "1024x1024":
|
||||
return "1K"
|
||||
case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "", "auto":
|
||||
case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "2048x2048", "2048x1152", "1152x2048":
|
||||
return "2K"
|
||||
default:
|
||||
case "3840x2160", "2160x3840":
|
||||
return "4K"
|
||||
}
|
||||
width, height, ok := parseOpenAIImageSizeDimensions(trimmed)
|
||||
if !ok {
|
||||
return "2K"
|
||||
}
|
||||
return classifyUnknownOpenAIImageSizeTier(width, height)
|
||||
}
|
||||
|
||||
const (
|
||||
openAIImage2KMaxPixels = 2560 * 1440
|
||||
)
|
||||
|
||||
func parseOpenAIImageSizeDimensions(size string) (int, int, bool) {
|
||||
trimmed := strings.TrimSpace(size)
|
||||
parts := strings.Split(strings.ToLower(trimmed), "x")
|
||||
if len(parts) != 2 {
|
||||
return 0, 0, false
|
||||
}
|
||||
width, err := strconv.Atoi(strings.TrimSpace(parts[0]))
|
||||
if err != nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
height, err := strconv.Atoi(strings.TrimSpace(parts[1]))
|
||||
if err != nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
if width <= 0 || height <= 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
return width, height, true
|
||||
}
|
||||
|
||||
func classifyUnknownOpenAIImageSizeTier(width int, height int) string {
|
||||
if height > 0 && width > openAIImage2KMaxPixels/height {
|
||||
return "4K"
|
||||
}
|
||||
return "2K"
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) ForwardImages(
|
||||
@ -535,11 +576,14 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
|
||||
setOpsUpstreamRequestBody(c, forwardBody)
|
||||
}
|
||||
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, parsed.Stream)
|
||||
defer releaseUpstreamCtx()
|
||||
|
||||
token, _, err := s.GetAccessToken(upstreamCtx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
upstreamReq, err := s.buildOpenAIImagesRequest(ctx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint)
|
||||
upstreamReq, err := s.buildOpenAIImagesRequest(upstreamCtx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -582,14 +626,14 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
})
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
s.handleFailoverSideEffects(upstreamCtx, resp, account)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account, forwardBody)
|
||||
return s.handleErrorResponse(upstreamCtx, resp, c, account, forwardBody)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
@ -599,6 +643,20 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
|
||||
if parsed.Stream && isEventStreamResponse(resp.Header) {
|
||||
streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime)
|
||||
if err != nil {
|
||||
if streamCount > 0 {
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: streamUsage,
|
||||
Model: requestModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
Stream: parsed.Stream,
|
||||
ResponseHeaders: resp.Header.Clone(),
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: ttft,
|
||||
ImageCount: streamCount,
|
||||
ImageSize: parsed.SizeTier,
|
||||
}, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
usage = streamUsage
|
||||
@ -807,66 +865,205 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
|
||||
return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer")
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
usage := OpenAIUsage{}
|
||||
imageCount := 0
|
||||
imageCounter := newOpenAIImageOutputCounter()
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
lastDownstreamWriteAt := time.Now()
|
||||
var fallbackBody bytes.Buffer
|
||||
fallbackBytes := int64(0)
|
||||
fallbackLimit := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
seenSSEData := false
|
||||
fallbackTooLarge := false
|
||||
var sseData openAISSEDataAccumulator
|
||||
|
||||
processSSEData := func(dataBytes []byte) {
|
||||
seenSSEData = true
|
||||
fallbackBody.Reset()
|
||||
fallbackBytes = 0
|
||||
mergeOpenAIUsage(&usage, dataBytes)
|
||||
imageCounter.AddSSEData(dataBytes)
|
||||
}
|
||||
|
||||
flushSSEEvent := func() {
|
||||
sseData.Flush(processSSEData)
|
||||
}
|
||||
|
||||
processLine := func(line []byte) {
|
||||
if len(line) == 0 {
|
||||
return
|
||||
}
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
if !clientDisconnected {
|
||||
if _, writeErr := c.Writer.Write(line); writeErr != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected, continue draining upstream for billing")
|
||||
} else {
|
||||
flusher.Flush()
|
||||
lastDownstreamWriteAt = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
trimmedLine := strings.TrimRight(string(line), "\r\n")
|
||||
if _, ok := extractOpenAISSEDataLine(trimmedLine); ok || strings.TrimSpace(trimmedLine) == "" {
|
||||
sseData.AddLine(trimmedLine, processSSEData)
|
||||
return
|
||||
}
|
||||
if !seenSSEData && !fallbackTooLarge {
|
||||
fallbackBytes += int64(len(line))
|
||||
if fallbackBytes <= fallbackLimit {
|
||||
_, _ = fallbackBody.Write(line)
|
||||
} else {
|
||||
fallbackTooLarge = true
|
||||
fallbackBody.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
finalizeFallbackBody := func() {
|
||||
if seenSSEData || fallbackBody.Len() == 0 {
|
||||
return
|
||||
}
|
||||
body := bytes.TrimSpace(fallbackBody.Bytes())
|
||||
if len(body) == 0 {
|
||||
return
|
||||
}
|
||||
mergeOpenAIUsage(&usage, body)
|
||||
imageCounter.AddJSONResponse(body)
|
||||
}
|
||||
|
||||
streamInterval := s.openAIImageStreamDataInterval()
|
||||
keepaliveInterval := s.openAIImageStreamKeepaliveInterval()
|
||||
if streamInterval <= 0 && keepaliveInterval <= 0 {
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
processLine(line)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
flushSSEEvent()
|
||||
return usage, imageCounter.Count(), firstTokenMs, err
|
||||
}
|
||||
}
|
||||
flushSSEEvent()
|
||||
finalizeFallbackBody()
|
||||
return usage, imageCounter.Count(), firstTokenMs, nil
|
||||
}
|
||||
|
||||
type readEvent struct {
|
||||
line []byte
|
||||
err error
|
||||
}
|
||||
events := make(chan readEvent, 16)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev readEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
defer close(events)
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if len(line) > 0 {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
}
|
||||
if len(line) > 0 && !sendEvent(readEvent{line: line}) {
|
||||
return
|
||||
}
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
_ = sendEvent(readEvent{err: err})
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
var intervalTicker *time.Ticker
|
||||
if streamInterval > 0 {
|
||||
intervalTicker = time.NewTicker(streamInterval)
|
||||
defer intervalTicker.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTicker != nil {
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
}
|
||||
var keepaliveCh <-chan time.Time
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if len(line) > 0 {
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
flushSSEEvent()
|
||||
finalizeFallbackBody()
|
||||
return usage, imageCounter.Count(), firstTokenMs, nil
|
||||
}
|
||||
if _, writeErr := c.Writer.Write(line); writeErr != nil {
|
||||
return OpenAIUsage{}, 0, firstTokenMs, writeErr
|
||||
if ev.err != nil {
|
||||
flushSSEEvent()
|
||||
return usage, imageCounter.Count(), firstTokenMs, ev.err
|
||||
}
|
||||
processLine(ev.line)
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
return usage, imageCounter.Count(), firstTokenMs, fmt.Errorf("image stream incomplete after timeout")
|
||||
}
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream data interval timeout: interval=%s", streamInterval)
|
||||
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(fmt.Sprintf("upstream image stream idle for %s", streamInterval)))
|
||||
return usage, imageCounter.Count(), firstTokenMs, fmt.Errorf("image stream data interval timeout")
|
||||
case <-keepaliveCh:
|
||||
if clientDisconnected || time.Since(lastDownstreamWriteAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
if _, writeErr := io.WriteString(c.Writer, ":\n\n"); writeErr != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected during keepalive, continue draining upstream for billing")
|
||||
continue
|
||||
}
|
||||
flusher.Flush()
|
||||
lastDownstreamWriteAt = time.Now()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok {
|
||||
if data != "" && data != "[DONE]" {
|
||||
seenSSEData = true
|
||||
fallbackBody.Reset()
|
||||
fallbackBytes = 0
|
||||
dataBytes := []byte(data)
|
||||
mergeOpenAIUsage(&usage, dataBytes)
|
||||
if count := extractOpenAIImagesBillableCountFromJSONBytes(dataBytes); count > imageCount {
|
||||
imageCount = count
|
||||
}
|
||||
}
|
||||
} else if !seenSSEData && !fallbackTooLarge {
|
||||
fallbackBytes += int64(len(line))
|
||||
if fallbackBytes <= fallbackLimit {
|
||||
_, _ = fallbackBody.Write(line)
|
||||
} else {
|
||||
fallbackTooLarge = true
|
||||
fallbackBody.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return OpenAIUsage{}, 0, firstTokenMs, err
|
||||
}
|
||||
func (s *OpenAIGatewayService) openAIImageStreamDataInterval() time.Duration {
|
||||
if s == nil || s.cfg == nil || s.cfg.Gateway.ImageStreamDataIntervalTimeout <= 0 {
|
||||
return 0
|
||||
}
|
||||
if !seenSSEData && fallbackBody.Len() > 0 {
|
||||
body := bytes.TrimSpace(fallbackBody.Bytes())
|
||||
if len(body) > 0 {
|
||||
mergeOpenAIUsage(&usage, body)
|
||||
if count := extractOpenAIImagesBillableCountFromJSONBytes(body); count > imageCount {
|
||||
imageCount = count
|
||||
}
|
||||
}
|
||||
return time.Duration(s.cfg.Gateway.ImageStreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAIImageStreamKeepaliveInterval() time.Duration {
|
||||
if s == nil || s.cfg == nil || s.cfg.Gateway.ImageStreamKeepaliveInterval <= 0 {
|
||||
return 0
|
||||
}
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
return time.Duration(s.cfg.Gateway.ImageStreamKeepaliveInterval) * time.Second
|
||||
}
|
||||
|
||||
func extractOpenAIImagesBillableCountFromJSONBytes(body []byte) int {
|
||||
@ -913,14 +1110,7 @@ func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) {
|
||||
}
|
||||
|
||||
func extractOpenAIImageCountFromJSONBytes(body []byte) int {
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
return 0
|
||||
}
|
||||
data := gjson.GetBytes(body, "data")
|
||||
if data.Exists() && data.IsArray() {
|
||||
return len(data.Array())
|
||||
}
|
||||
return 0
|
||||
return countOpenAIResponseImageOutputsFromJSONBytes(body)
|
||||
}
|
||||
|
||||
type openAIImagePointerInfo struct {
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
@ -361,21 +362,21 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
|
||||
var (
|
||||
fallbackResults []openAIResponsesImageResult
|
||||
fallbackSeen = make(map[string]struct{})
|
||||
finalResults []openAIResponsesImageResult
|
||||
finalMeta openAIResponsesImageResult
|
||||
collectErr error
|
||||
createdAt int64
|
||||
usageRaw []byte
|
||||
foundFinal bool
|
||||
responseMeta openAIResponsesImageResult
|
||||
)
|
||||
|
||||
for _, line := range bytes.Split(body, []byte("\n")) {
|
||||
line = bytes.TrimRight(line, "\r")
|
||||
data, ok := extractOpenAISSEDataLine(string(line))
|
||||
if !ok || data == "" || data == "[DONE]" {
|
||||
continue
|
||||
forEachOpenAISSEDataPayload(string(body), func(payload []byte) {
|
||||
if collectErr != nil || len(finalResults) > 0 {
|
||||
return
|
||||
}
|
||||
payload := []byte(data)
|
||||
if !gjson.ValidBytes(payload) {
|
||||
continue
|
||||
return
|
||||
}
|
||||
if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(payload); ok {
|
||||
mergeOpenAIResponsesImageMeta(&responseMeta, meta)
|
||||
@ -388,7 +389,8 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
|
||||
case "response.output_item.done":
|
||||
result, itemID, ok, err := extractOpenAIImageFromResponsesOutputItemDone(payload)
|
||||
if err != nil {
|
||||
return nil, 0, nil, openAIResponsesImageResult{}, false, err
|
||||
collectErr = err
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
mergeOpenAIResponsesImageMeta(&result, responseMeta)
|
||||
@ -397,7 +399,8 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
|
||||
case "response.completed":
|
||||
results, completedAt, completedUsageRaw, firstMeta, err := extractOpenAIImagesFromResponsesCompleted(payload)
|
||||
if err != nil {
|
||||
return nil, 0, nil, openAIResponsesImageResult{}, false, err
|
||||
collectErr = err
|
||||
return
|
||||
}
|
||||
foundFinal = true
|
||||
if completedAt > 0 {
|
||||
@ -408,14 +411,24 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
|
||||
}
|
||||
if len(results) > 0 {
|
||||
mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
|
||||
return results, createdAt, usageRaw, firstMeta, true, nil
|
||||
finalResults = results
|
||||
finalMeta = firstMeta
|
||||
return
|
||||
}
|
||||
if len(fallbackResults) > 0 {
|
||||
firstMeta = fallbackResults[0]
|
||||
mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
|
||||
return fallbackResults, createdAt, usageRaw, firstMeta, true, nil
|
||||
finalResults = fallbackResults
|
||||
finalMeta = firstMeta
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
if collectErr != nil {
|
||||
return nil, 0, nil, openAIResponsesImageResult{}, false, collectErr
|
||||
}
|
||||
if len(finalResults) > 0 {
|
||||
return finalResults, createdAt, usageRaw, finalMeta, true, nil
|
||||
}
|
||||
|
||||
if len(fallbackResults) > 0 {
|
||||
@ -505,6 +518,30 @@ func (s *OpenAIGatewayService) writeOpenAIImagesStreamEvent(c *gin.Context, flus
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) tryWriteOpenAIImagesStreamEvent(
|
||||
c *gin.Context,
|
||||
flusher http.Flusher,
|
||||
clientDisconnected *bool,
|
||||
lastWriteAt *time.Time,
|
||||
eventName string,
|
||||
payload []byte,
|
||||
) bool {
|
||||
if clientDisconnected != nil && *clientDisconnected {
|
||||
return false
|
||||
}
|
||||
if err := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); err != nil {
|
||||
if clientDisconnected != nil {
|
||||
*clientDisconnected = true
|
||||
}
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected, continue draining upstream for billing")
|
||||
return false
|
||||
}
|
||||
if lastWriteAt != nil {
|
||||
*lastWriteAt = time.Now()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
@ -517,15 +554,9 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
|
||||
}
|
||||
|
||||
var usage OpenAIUsage
|
||||
for _, line := range bytes.Split(body, []byte("\n")) {
|
||||
line = bytes.TrimRight(line, "\r")
|
||||
data, ok := extractOpenAISSEDataLine(string(line))
|
||||
if !ok || data == "" || data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
dataBytes := []byte(data)
|
||||
s.parseSSEUsageBytes(dataBytes, &usage)
|
||||
}
|
||||
forEachOpenAISSEDataPayload(string(body), func(data []byte) {
|
||||
s.parseSSEUsageBytes(data, &usage)
|
||||
})
|
||||
results, createdAt, usageRaw, firstMeta, _, err := collectOpenAIImagesFromResponsesBody(body)
|
||||
if err != nil {
|
||||
return OpenAIUsage{}, 0, err
|
||||
@ -570,7 +601,6 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
|
||||
format = "b64_json"
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
usage := OpenAIUsage{}
|
||||
imageCount := 0
|
||||
var firstTokenMs *int
|
||||
@ -579,141 +609,307 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
|
||||
pendingSeen := make(map[string]struct{})
|
||||
streamMeta := openAIResponsesImageResult{Model: strings.TrimSpace(fallbackModel)}
|
||||
var createdAt int64
|
||||
clientDisconnected := false
|
||||
lastDownstreamWriteAt := time.Now()
|
||||
var sseData openAISSEDataAccumulator
|
||||
var processDataErr error
|
||||
processDataDone := false
|
||||
|
||||
processData := func(dataBytes []byte) {
|
||||
if processDataDone || processDataErr != nil {
|
||||
return
|
||||
}
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsageBytes(dataBytes, &usage)
|
||||
if !gjson.ValidBytes(dataBytes) {
|
||||
return
|
||||
}
|
||||
if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok {
|
||||
mergeOpenAIResponsesImageMeta(&streamMeta, meta)
|
||||
if eventCreatedAt > 0 {
|
||||
createdAt = eventCreatedAt
|
||||
}
|
||||
}
|
||||
switch gjson.GetBytes(dataBytes, "type").String() {
|
||||
case "response.image_generation_call.partial_image":
|
||||
b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String())
|
||||
if b64 == "" {
|
||||
return
|
||||
}
|
||||
eventName := streamPrefix + ".partial_image"
|
||||
partialMeta := streamMeta
|
||||
mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{
|
||||
OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()),
|
||||
Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()),
|
||||
})
|
||||
payload := buildOpenAIImagesStreamPartialPayload(
|
||||
eventName,
|
||||
b64,
|
||||
gjson.GetBytes(dataBytes, "partial_image_index").Int(),
|
||||
format,
|
||||
createdAt,
|
||||
partialMeta,
|
||||
)
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload)
|
||||
case "response.output_item.done":
|
||||
img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes)
|
||||
if extractErr != nil {
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
|
||||
processDataErr = extractErr
|
||||
processDataDone = true
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
mergeOpenAIResponsesImageMeta(&streamMeta, img)
|
||||
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||
key := openAIResponsesImageResultKey(itemID, img)
|
||||
if _, exists := emitted[key]; exists {
|
||||
return
|
||||
}
|
||||
if _, exists := pendingSeen[key]; exists {
|
||||
return
|
||||
}
|
||||
pendingSeen[key] = struct{}{}
|
||||
pendingResults = append(pendingResults, img)
|
||||
case "response.completed":
|
||||
results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes)
|
||||
if extractErr != nil {
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
|
||||
processDataErr = extractErr
|
||||
processDataDone = true
|
||||
return
|
||||
}
|
||||
mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta)
|
||||
finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults))
|
||||
finalSeen := make(map[string]struct{})
|
||||
for _, img := range results {
|
||||
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
|
||||
}
|
||||
for _, img := range pendingResults {
|
||||
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
|
||||
}
|
||||
if len(finalResults) == 0 {
|
||||
outputErr := fmt.Errorf("upstream did not return image output")
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(outputErr.Error()))
|
||||
processDataErr = outputErr
|
||||
processDataDone = true
|
||||
return
|
||||
}
|
||||
eventName := streamPrefix + ".completed"
|
||||
for _, img := range finalResults {
|
||||
key := openAIResponsesImageResultKey("", img)
|
||||
if _, exists := emitted[key]; exists {
|
||||
continue
|
||||
}
|
||||
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw)
|
||||
emitted[key] = struct{}{}
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload)
|
||||
}
|
||||
imageCount = len(emitted)
|
||||
processDataDone = true
|
||||
}
|
||||
}
|
||||
|
||||
processLine := func(line []byte) (bool, error) {
|
||||
if len(line) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
sseData.AddLine(string(line), processData)
|
||||
if processDataErr != nil {
|
||||
return true, processDataErr
|
||||
}
|
||||
return processDataDone, nil
|
||||
}
|
||||
|
||||
flushData := func() (bool, error) {
|
||||
sseData.Flush(processData)
|
||||
if processDataErr != nil {
|
||||
return true, processDataErr
|
||||
}
|
||||
return processDataDone, nil
|
||||
}
|
||||
|
||||
finalizePending := func() error {
|
||||
if imageCount > 0 {
|
||||
return nil
|
||||
}
|
||||
if len(pendingResults) > 0 {
|
||||
eventName := streamPrefix + ".completed"
|
||||
for _, img := range pendingResults {
|
||||
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||
key := openAIResponsesImageResultKey("", img)
|
||||
if _, exists := emitted[key]; exists {
|
||||
continue
|
||||
}
|
||||
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil)
|
||||
emitted[key] = struct{}{}
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload)
|
||||
}
|
||||
imageCount = len(emitted)
|
||||
return nil
|
||||
}
|
||||
|
||||
streamErr := fmt.Errorf("stream disconnected before image generation completed")
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error()))
|
||||
return streamErr
|
||||
}
|
||||
|
||||
streamInterval := s.openAIImageStreamDataInterval()
|
||||
keepaliveInterval := s.openAIImageStreamKeepaliveInterval()
|
||||
if streamInterval <= 0 && keepaliveInterval <= 0 {
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
done, processErr := processLine(line)
|
||||
if processErr != nil {
|
||||
return usage, imageCount, firstTokenMs, processErr
|
||||
}
|
||||
if done {
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
if done, processErr := flushData(); processErr != nil {
|
||||
return usage, imageCount, firstTokenMs, processErr
|
||||
} else if done {
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
|
||||
return usage, imageCount, firstTokenMs, err
|
||||
}
|
||||
}
|
||||
if done, processErr := flushData(); processErr != nil {
|
||||
return usage, imageCount, firstTokenMs, processErr
|
||||
} else if done {
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
if err := finalizePending(); err != nil {
|
||||
return usage, imageCount, firstTokenMs, err
|
||||
}
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
|
||||
type readEvent struct {
|
||||
line []byte
|
||||
err error
|
||||
}
|
||||
events := make(chan readEvent, 16)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev readEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
defer close(events)
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if len(line) > 0 {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
}
|
||||
if len(line) > 0 && !sendEvent(readEvent{line: line}) {
|
||||
return
|
||||
}
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
_ = sendEvent(readEvent{err: err})
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
var intervalTicker *time.Ticker
|
||||
if streamInterval > 0 {
|
||||
intervalTicker = time.NewTicker(streamInterval)
|
||||
defer intervalTicker.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTicker != nil {
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
}
|
||||
var keepaliveCh <-chan time.Time
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if len(line) > 0 {
|
||||
trimmedLine := strings.TrimRight(string(line), "\r\n")
|
||||
data, ok := extractOpenAISSEDataLine(trimmedLine)
|
||||
if ok && data != "" && data != "[DONE]" {
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
if done, processErr := flushData(); processErr != nil {
|
||||
return usage, imageCount, firstTokenMs, processErr
|
||||
} else if done {
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
dataBytes := []byte(data)
|
||||
s.parseSSEUsageBytes(dataBytes, &usage)
|
||||
if gjson.ValidBytes(dataBytes) {
|
||||
if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok {
|
||||
mergeOpenAIResponsesImageMeta(&streamMeta, meta)
|
||||
if eventCreatedAt > 0 {
|
||||
createdAt = eventCreatedAt
|
||||
}
|
||||
}
|
||||
switch gjson.GetBytes(dataBytes, "type").String() {
|
||||
case "response.image_generation_call.partial_image":
|
||||
b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String())
|
||||
if b64 != "" {
|
||||
eventName := streamPrefix + ".partial_image"
|
||||
partialMeta := streamMeta
|
||||
mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{
|
||||
OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()),
|
||||
Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()),
|
||||
})
|
||||
payload := buildOpenAIImagesStreamPartialPayload(
|
||||
eventName,
|
||||
b64,
|
||||
gjson.GetBytes(dataBytes, "partial_image_index").Int(),
|
||||
format,
|
||||
createdAt,
|
||||
partialMeta,
|
||||
)
|
||||
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
|
||||
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
|
||||
}
|
||||
}
|
||||
case "response.output_item.done":
|
||||
img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes)
|
||||
if extractErr != nil {
|
||||
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
|
||||
return OpenAIUsage{}, imageCount, firstTokenMs, extractErr
|
||||
}
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
mergeOpenAIResponsesImageMeta(&streamMeta, img)
|
||||
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||
key := openAIResponsesImageResultKey(itemID, img)
|
||||
if _, exists := emitted[key]; exists {
|
||||
break
|
||||
}
|
||||
if _, exists := pendingSeen[key]; exists {
|
||||
break
|
||||
}
|
||||
pendingSeen[key] = struct{}{}
|
||||
pendingResults = append(pendingResults, img)
|
||||
case "response.completed":
|
||||
results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes)
|
||||
if extractErr != nil {
|
||||
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
|
||||
return OpenAIUsage{}, imageCount, firstTokenMs, extractErr
|
||||
}
|
||||
mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta)
|
||||
finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults))
|
||||
finalSeen := make(map[string]struct{})
|
||||
for _, img := range results {
|
||||
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
|
||||
}
|
||||
for _, img := range pendingResults {
|
||||
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
|
||||
}
|
||||
if len(finalResults) == 0 {
|
||||
err = fmt.Errorf("upstream did not return image output")
|
||||
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
|
||||
return OpenAIUsage{}, imageCount, firstTokenMs, err
|
||||
}
|
||||
eventName := streamPrefix + ".completed"
|
||||
for _, img := range finalResults {
|
||||
key := openAIResponsesImageResultKey("", img)
|
||||
if _, exists := emitted[key]; exists {
|
||||
continue
|
||||
}
|
||||
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw)
|
||||
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
|
||||
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
|
||||
}
|
||||
emitted[key] = struct{}{}
|
||||
}
|
||||
imageCount = len(emitted)
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
if err := finalizePending(); err != nil {
|
||||
return usage, imageCount, firstTokenMs, err
|
||||
}
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
|
||||
return OpenAIUsage{}, imageCount, firstTokenMs, err
|
||||
}
|
||||
}
|
||||
|
||||
if imageCount > 0 {
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
if len(pendingResults) > 0 {
|
||||
eventName := streamPrefix + ".completed"
|
||||
for _, img := range pendingResults {
|
||||
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||
key := openAIResponsesImageResultKey("", img)
|
||||
if _, exists := emitted[key]; exists {
|
||||
if ev.err != nil {
|
||||
if done, processErr := flushData(); processErr != nil {
|
||||
return usage, imageCount, firstTokenMs, processErr
|
||||
} else if done {
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(ev.err.Error()))
|
||||
return usage, imageCount, firstTokenMs, ev.err
|
||||
}
|
||||
done, processErr := processLine(ev.line)
|
||||
if processErr != nil {
|
||||
return usage, imageCount, firstTokenMs, processErr
|
||||
}
|
||||
if done {
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil)
|
||||
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
|
||||
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
|
||||
if clientDisconnected {
|
||||
return usage, imageCount, firstTokenMs, fmt.Errorf("image stream incomplete after timeout")
|
||||
}
|
||||
emitted[key] = struct{}{}
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images responses stream data interval timeout: interval=%s", streamInterval)
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(fmt.Sprintf("upstream image stream idle for %s", streamInterval)))
|
||||
return usage, imageCount, firstTokenMs, fmt.Errorf("image stream data interval timeout")
|
||||
case <-keepaliveCh:
|
||||
if clientDisconnected || time.Since(lastDownstreamWriteAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
if _, writeErr := io.WriteString(c.Writer, ":\n\n"); writeErr != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images responses stream client disconnected during keepalive, continue draining upstream for billing")
|
||||
continue
|
||||
}
|
||||
flusher.Flush()
|
||||
lastDownstreamWriteAt = time.Now()
|
||||
}
|
||||
imageCount = len(emitted)
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
|
||||
streamErr := fmt.Errorf("stream disconnected before image generation completed")
|
||||
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error()))
|
||||
return OpenAIUsage{}, imageCount, firstTokenMs, streamErr
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
@ -752,7 +948,10 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
)
|
||||
}
|
||||
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, parsed.Stream)
|
||||
defer releaseUpstreamCtx()
|
||||
|
||||
token, _, err := s.GetAccessToken(upstreamCtx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -763,7 +962,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
}
|
||||
setOpsUpstreamRequestBody(c, responsesBody)
|
||||
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -808,14 +1007,14 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
})
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
s.handleFailoverSideEffects(upstreamCtx, resp, account)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account, responsesBody)
|
||||
return s.handleErrorResponse(upstreamCtx, resp, c, account, responsesBody)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
@ -827,6 +1026,20 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
if parsed.Stream {
|
||||
usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel)
|
||||
if err != nil {
|
||||
if imageCount > 0 {
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: usage,
|
||||
Model: requestModel,
|
||||
UpstreamModel: requestModel,
|
||||
Stream: parsed.Stream,
|
||||
ResponseHeaders: resp.Header.Clone(),
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ImageCount: imageCount,
|
||||
ImageSize: parsed.SizeTier,
|
||||
}, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
|
||||
@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
@ -17,6 +18,20 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type failingOpenAIImageWriter struct {
|
||||
gin.ResponseWriter
|
||||
failAfter int
|
||||
writes int
|
||||
}
|
||||
|
||||
func (w *failingOpenAIImageWriter) Write(p []byte) (int, error) {
|
||||
if w.writes >= w.failAfter {
|
||||
return 0, errors.New("write failed: client disconnected")
|
||||
}
|
||||
w.writes++
|
||||
return w.ResponseWriter.Write(p)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","stream":true}`)
|
||||
@ -75,6 +90,100 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T
|
||||
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_NormalizesOfficialAndCustomSizes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
size string
|
||||
wantTier string
|
||||
}{
|
||||
{size: "1024x1024", wantTier: "1K"},
|
||||
{size: "1536x1024", wantTier: "2K"},
|
||||
{size: "1024x1536", wantTier: "2K"},
|
||||
{size: "2048x2048", wantTier: "2K"},
|
||||
{size: "2048x1152", wantTier: "2K"},
|
||||
{size: "3840x2160", wantTier: "4K"},
|
||||
{size: "2160x3840", wantTier: "4K"},
|
||||
{size: "1024X768", wantTier: "2K"},
|
||||
{size: "1280x768", wantTier: "2K"},
|
||||
{size: "2560x1440", wantTier: "2K"},
|
||||
{size: "2560x1600", wantTier: "4K"},
|
||||
{size: "auto", wantTier: "2K"},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.size, func(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"` + tt.size + `"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, parsed)
|
||||
require.Equal(t, tt.size, parsed.Size)
|
||||
require.Equal(t, tt.wantTier, parsed.SizeTier)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_UnknownSizesDoNotBlockPassthrough(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
size string
|
||||
wantTier string
|
||||
}{
|
||||
{size: "2048x1153", wantTier: "2K"},
|
||||
{size: "4096x1024", wantTier: "4K"},
|
||||
{size: "3840x1024", wantTier: "4K"},
|
||||
{size: "512x512", wantTier: "2K"},
|
||||
{size: "invalid", wantTier: "2K"},
|
||||
{size: "999999999999999999999999999x2", wantTier: "2K"},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.size, func(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"` + tt.size + `"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, parsed)
|
||||
require.Equal(t, tt.size, parsed.Size)
|
||||
require.Equal(t, tt.wantTier, parsed.SizeTier)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_LegacyImageModelUnknownSizePassthrough(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-1.5","prompt":"draw a cat","size":"2048x1152"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, parsed)
|
||||
require.Equal(t, "2048x1152", parsed.Size)
|
||||
require.Equal(t, "2K", parsed.SizeTier)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEditWithMaskAndNativeOptions(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@ -543,6 +652,57 @@ func TestOpenAIGatewayServiceForwardImages_APIKeyStreamRawJSONEventStreamFallbac
|
||||
require.Equal(t, "ZmluYWw=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamMultilineSSEDataBillsImage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{},
|
||||
httpUpstream: &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
"X-Request-Id": []string{"req_img_stream_multiline"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
"data: {\"type\":\"image_generation.completed\",\n" +
|
||||
"data: \"usage\":{\"input_tokens\":10,\"output_tokens\":18,\"output_tokens_details\":{\"image_tokens\":8}},\n" +
|
||||
"data: \"b64_json\":\"ZmluYWw=\",\"output_format\":\"png\"}\n\n" +
|
||||
"data: [DONE]\n\n",
|
||||
)),
|
||||
},
|
||||
},
|
||||
}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
|
||||
account := &Account{
|
||||
ID: 8,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-api-key",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Stream)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, 10, result.Usage.InputTokens)
|
||||
require.Equal(t, 18, result.Usage.OutputTokens)
|
||||
require.Equal(t, 8, result.Usage.ImageOutputTokens)
|
||||
}
|
||||
|
||||
func TestExtractOpenAIImagesBillableCountFromJSONBytes_CompletedEvent(t *testing.T) {
|
||||
body := []byte(`{"type":"image_generation.completed","b64_json":"ZmluYWw=","usage":{"input_tokens":10,"output_tokens":18}}`)
|
||||
|
||||
@ -686,6 +846,61 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *tes
|
||||
require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamingDrainsAfterClientDisconnect(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
c.Writer = &failingOpenAIImageWriter{ResponseWriter: c.Writer, failAfter: 1}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
ImageStreamDataIntervalTimeout: 1,
|
||||
ImageStreamKeepaliveInterval: 0,
|
||||
},
|
||||
},
|
||||
httpUpstream: &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
"X-Request-Id": []string{"req_img_stream_disconnect_apikey"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
"data: {\"type\":\"image_generation.partial_image\",\"b64_json\":\"cGFydGlhbA==\"}\n\n" +
|
||||
"data: {\"type\":\"image_generation.completed\",\"usage\":{\"input_tokens\":3,\"output_tokens\":4,\"output_tokens_details\":{\"image_tokens\":2}},\"b64_json\":\"ZmluYWw=\",\"output_format\":\"png\"}\n\n" +
|
||||
"data: [DONE]\n\n",
|
||||
)),
|
||||
},
|
||||
},
|
||||
}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
|
||||
account := &Account{
|
||||
ID: 8,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-api-key",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, 3, result.Usage.InputTokens)
|
||||
require.Equal(t, 4, result.Usage.OutputTokens)
|
||||
require.Equal(t, 2, result.Usage.ImageOutputTokens)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@ -901,6 +1116,23 @@ func TestCollectOpenAIImagesFromResponsesBody_FallsBackToOutputItemDone(t *testi
|
||||
require.JSONEq(t, `{"images":1}`, string(usageRaw))
|
||||
}
|
||||
|
||||
func TestCollectOpenAIImagesFromResponsesBody_MultilineSSE(t *testing.T) {
|
||||
body := []byte(
|
||||
"data: {\"type\":\"response.completed\",\n" +
|
||||
"data: \"response\":{\"created_at\":1710000010,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" +
|
||||
"data: [DONE]\n\n",
|
||||
)
|
||||
|
||||
results, createdAt, usageRaw, firstMeta, foundFinal, err := collectOpenAIImagesFromResponsesBody(body)
|
||||
require.NoError(t, err)
|
||||
require.True(t, foundFinal)
|
||||
require.Equal(t, int64(1710000010), createdAt)
|
||||
require.Len(t, results, 1)
|
||||
require.Equal(t, "ZmluYWw=", results[0].Result)
|
||||
require.Equal(t, "png", firstMeta.OutputFormat)
|
||||
require.JSONEq(t, `{"images":1}`, string(usageRaw))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFallback(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
|
||||
@ -957,3 +1189,116 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFa
|
||||
require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
|
||||
require.NotContains(t, rec.Body.String(), "event: error")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesMultilineSSE(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
|
||||
svc.httpUpstream = &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
"X-Request-Id": []string{"req_img_stream_multiline_oauth"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
"data: {\"type\":\"response.completed\",\n" +
|
||||
"data: \"response\":{\"created_at\":1710000011,\"usage\":{\"input_tokens\":6,\"output_tokens\":10,\"output_tokens_details\":{\"image_tokens\":5}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"TXVsdGlsaW5l\",\"output_format\":\"png\"}]}}\n\n" +
|
||||
"data: [DONE]\n\n",
|
||||
)),
|
||||
},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 11,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token-123",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Stream)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, 6, result.Usage.InputTokens)
|
||||
require.Equal(t, 10, result.Usage.OutputTokens)
|
||||
require.Equal(t, 5, result.Usage.ImageOutputTokens)
|
||||
events := parseOpenAIImageTestSSEEvents(rec.Body.String())
|
||||
completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "TXVsdGlsaW5l", gjson.Get(completed.Data, "b64_json").String())
|
||||
require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
|
||||
require.NotContains(t, rec.Body.String(), "event: error")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingDrainsAfterClientDisconnect(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
c.Writer = &failingOpenAIImageWriter{ResponseWriter: c.Writer, failAfter: 1}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
ImageStreamDataIntervalTimeout: 1,
|
||||
ImageStreamKeepaliveInterval: 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
"X-Request-Id": []string{"req_img_stream_disconnect_oauth"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
"data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"png\"}\n\n" +
|
||||
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000009,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" +
|
||||
"data: [DONE]\n\n",
|
||||
)),
|
||||
},
|
||||
}
|
||||
svc.httpUpstream = upstream
|
||||
|
||||
account := &Account{
|
||||
ID: 9,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token-123",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Stream)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, 5, result.Usage.InputTokens)
|
||||
require.Equal(t, 9, result.Usage.OutputTokens)
|
||||
require.Equal(t, 4, result.Usage.ImageOutputTokens)
|
||||
}
|
||||
|
||||
137
backend/internal/service/openai_model_alias.go
Normal file
137
backend/internal/service/openai_model_alias.go
Normal file
@ -0,0 +1,137 @@
|
||||
package service
|
||||
|
||||
import "strings"
|
||||
|
||||
func lastOpenAIModelSegment(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(model, "/") {
|
||||
parts := strings.Split(model, "/")
|
||||
model = parts[len(parts)-1]
|
||||
}
|
||||
return strings.TrimSpace(model)
|
||||
}
|
||||
|
||||
func canonicalizeOpenAIModelAliasSpelling(model string) string {
|
||||
model = strings.ToLower(lastOpenAIModelSegment(model))
|
||||
if model == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
normalized := strings.ReplaceAll(model, "_", "-")
|
||||
normalized = strings.Join(strings.Fields(normalized), "-")
|
||||
for strings.Contains(normalized, "--") {
|
||||
normalized = strings.ReplaceAll(normalized, "--", "-")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(normalized, "gpt5") {
|
||||
normalized = "gpt-5" + strings.TrimPrefix(normalized, "gpt5")
|
||||
}
|
||||
if !strings.HasPrefix(normalized, "gpt-") && !strings.Contains(normalized, "codex") {
|
||||
return ""
|
||||
}
|
||||
|
||||
replacements := []struct {
|
||||
from string
|
||||
to string
|
||||
}{
|
||||
{"gpt-5.4mini", "gpt-5.4-mini"},
|
||||
{"gpt-5.4nano", "gpt-5.4-nano"},
|
||||
{"gpt-5.3-codexspark", "gpt-5.3-codex-spark"},
|
||||
{"gpt-5.3codexspark", "gpt-5.3-codex-spark"},
|
||||
{"gpt-5.3codex", "gpt-5.3-codex"},
|
||||
}
|
||||
for _, replacement := range replacements {
|
||||
normalized = strings.ReplaceAll(normalized, replacement.from, replacement.to)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func normalizeKnownOpenAICodexModel(model string) string {
|
||||
normalized := canonicalizeOpenAIModelAliasSpelling(model)
|
||||
if normalized == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if mapped := getNormalizedCodexModel(normalized); mapped != "" {
|
||||
return mapped
|
||||
}
|
||||
if strings.HasSuffix(normalized, "-openai-compact") {
|
||||
if mapped := getNormalizedCodexModel(strings.TrimSuffix(normalized, "-openai-compact")); mapped != "" {
|
||||
return mapped
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.Contains(normalized, "gpt-5.5"):
|
||||
return "gpt-5.5"
|
||||
case strings.Contains(normalized, "gpt-5.4-mini"):
|
||||
return "gpt-5.4-mini"
|
||||
case strings.Contains(normalized, "gpt-5.4-nano"):
|
||||
return "gpt-5.4-nano"
|
||||
case strings.Contains(normalized, "gpt-5.4"):
|
||||
return "gpt-5.4"
|
||||
case strings.Contains(normalized, "gpt-5.2"):
|
||||
return "gpt-5.2"
|
||||
case strings.Contains(normalized, "gpt-5.3-codex-spark"):
|
||||
return "gpt-5.3-codex-spark"
|
||||
case strings.Contains(normalized, "gpt-5.3-codex"):
|
||||
return "gpt-5.3-codex"
|
||||
case strings.Contains(normalized, "gpt-5.3"):
|
||||
return "gpt-5.3-codex"
|
||||
case strings.Contains(normalized, "codex"):
|
||||
return "gpt-5.3-codex"
|
||||
case strings.Contains(normalized, "gpt-5"):
|
||||
return "gpt-5.4"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func appendUsageBillingModelCandidate(candidates []string, seen map[string]struct{}, model string) []string {
|
||||
trimmed := strings.TrimSpace(model)
|
||||
if trimmed == "" {
|
||||
return candidates
|
||||
}
|
||||
add := func(candidate string) {
|
||||
candidate = strings.TrimSpace(candidate)
|
||||
if candidate == "" {
|
||||
return
|
||||
}
|
||||
key := strings.ToLower(candidate)
|
||||
if _, ok := seen[key]; ok {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
candidates = append(candidates, candidate)
|
||||
}
|
||||
|
||||
add(trimmed)
|
||||
if canonical := canonicalizeOpenAIModelAliasSpelling(trimmed); canonical != "" {
|
||||
add(canonical)
|
||||
}
|
||||
if normalized := normalizeKnownOpenAICodexModel(trimmed); normalized != "" {
|
||||
add(normalized)
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
func usageBillingModelCandidates(primary string, alternates ...string) []string {
|
||||
seen := make(map[string]struct{}, 1+len(alternates))
|
||||
candidates := appendUsageBillingModelCandidate(nil, seen, primary)
|
||||
for _, alternate := range alternates {
|
||||
candidates = appendUsageBillingModelCandidate(candidates, seen, alternate)
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
func firstUsageBillingModel(candidates []string) string {
|
||||
for _, candidate := range candidates {
|
||||
if trimmed := strings.TrimSpace(candidate); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@ -94,6 +94,15 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
|
||||
defaultMappedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.5",
|
||||
},
|
||||
{
|
||||
name: "preserves compact-spelled gpt5.5 instead of group default",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
requestedModel: "gpt5.5",
|
||||
defaultMappedModel: "gpt-5.4",
|
||||
expectedModel: "gpt5.5",
|
||||
},
|
||||
{
|
||||
name: "preserves openai namespaced gpt-5.5 instead of group default",
|
||||
account: &Account{
|
||||
|
||||
70
backend/internal/service/openai_sse_data.go
Normal file
70
backend/internal/service/openai_sse_data.go
Normal file
@ -0,0 +1,70 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type openAISSEDataAccumulator struct {
|
||||
lines []string
|
||||
}
|
||||
|
||||
func (a *openAISSEDataAccumulator) AddLine(line string, fn func([]byte)) {
|
||||
if fn == nil {
|
||||
return
|
||||
}
|
||||
trimmedLine := strings.TrimRight(line, "\r\n")
|
||||
if data, ok := extractOpenAISSEDataLine(trimmedLine); ok {
|
||||
a.lines = append(a.lines, data)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(trimmedLine) == "" {
|
||||
a.Flush(fn)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *openAISSEDataAccumulator) Flush(fn func([]byte)) {
|
||||
if fn == nil || len(a.lines) == 0 {
|
||||
return
|
||||
}
|
||||
emitOpenAISSEDataPayloads(a.lines, fn)
|
||||
a.lines = a.lines[:0]
|
||||
}
|
||||
|
||||
func forEachOpenAISSEDataPayload(body string, fn func([]byte)) {
|
||||
if fn == nil || strings.TrimSpace(body) == "" {
|
||||
return
|
||||
}
|
||||
var acc openAISSEDataAccumulator
|
||||
for _, line := range strings.Split(body, "\n") {
|
||||
acc.AddLine(line, fn)
|
||||
}
|
||||
acc.Flush(fn)
|
||||
}
|
||||
|
||||
func emitOpenAISSEDataPayloads(lines []string, fn func([]byte)) {
|
||||
if fn == nil || len(lines) == 0 {
|
||||
return
|
||||
}
|
||||
if len(lines) == 1 {
|
||||
emitOpenAISSEDataPayload(lines[0], fn)
|
||||
return
|
||||
}
|
||||
joined := strings.Join(lines, "\n")
|
||||
if gjson.Valid(joined) {
|
||||
emitOpenAISSEDataPayload(joined, fn)
|
||||
return
|
||||
}
|
||||
for _, line := range lines {
|
||||
emitOpenAISSEDataPayload(line, fn)
|
||||
}
|
||||
}
|
||||
|
||||
func emitOpenAISSEDataPayload(data string, fn func([]byte)) {
|
||||
data = strings.TrimSpace(data)
|
||||
if data == "" || data == "[DONE]" {
|
||||
return
|
||||
}
|
||||
fn([]byte(data))
|
||||
}
|
||||
@ -1990,6 +1990,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
||||
}
|
||||
|
||||
usage := &OpenAIUsage{}
|
||||
imageCounter := newOpenAIImageOutputCounter()
|
||||
var firstTokenMs *int
|
||||
responseID := ""
|
||||
var finalResponse []byte
|
||||
@ -2171,6 +2172,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
||||
if openAIWSEventShouldParseUsage(eventType) {
|
||||
parseOpenAIWSResponseUsageFromCompletedEvent(message, usage)
|
||||
}
|
||||
imageCounter.AddSSEData(message)
|
||||
|
||||
if eventType == "error" {
|
||||
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
|
||||
@ -2343,6 +2345,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
UpstreamModel: mappedModel,
|
||||
ImageCount: imageCounter.Count(),
|
||||
ServiceTier: extractOpenAIServiceTier(reqBody),
|
||||
ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
|
||||
Stream: reqStream,
|
||||
@ -2449,6 +2452,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
promptCacheKey string
|
||||
previousResponseID string
|
||||
originalModel string
|
||||
imageBillingModel string
|
||||
imageSizeTier string
|
||||
payloadBytes int
|
||||
}
|
||||
|
||||
@ -2546,6 +2551,19 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
}
|
||||
normalized = next
|
||||
}
|
||||
imageIntent := IsImageGenerationIntent(openAIResponsesEndpoint, originalModel, normalized)
|
||||
if imageIntent && !GroupAllowsImageGeneration(apiKeyGroup(getAPIKeyFromContext(c))) {
|
||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, ImageGenerationPermissionMessage(), nil)
|
||||
}
|
||||
imageBillingModel := ""
|
||||
imageSizeTier := ""
|
||||
if imageIntent {
|
||||
var imageCfgErr error
|
||||
imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(normalized, originalModel)
|
||||
if imageCfgErr != nil {
|
||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, imageCfgErr.Error(), imageCfgErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply OpenAI Fast Policy on the response.create frame using the same
|
||||
// evaluator/normalize/scope rules as the HTTP entrypoints. This is the
|
||||
@ -2591,6 +2609,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
promptCacheKey: promptCacheKey,
|
||||
previousResponseID: previousResponseID,
|
||||
originalModel: originalModel,
|
||||
imageBillingModel: imageBillingModel,
|
||||
imageSizeTier: imageSizeTier,
|
||||
payloadBytes: len(normalized),
|
||||
}, nil
|
||||
}
|
||||
@ -2792,7 +2812,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string) (*OpenAIForwardResult, error) {
|
||||
sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string, imageBillingModel string, imageSizeTier string) (*OpenAIForwardResult, error) {
|
||||
if lease == nil {
|
||||
return nil, errors.New("upstream websocket lease is nil")
|
||||
}
|
||||
@ -2817,6 +2837,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
|
||||
responseID := ""
|
||||
usage := OpenAIUsage{}
|
||||
imageCounter := newOpenAIImageOutputCounter()
|
||||
var firstTokenMs *int
|
||||
reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true)
|
||||
turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
|
||||
@ -2938,6 +2959,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
if openAIWSEventShouldParseUsage(eventType) {
|
||||
parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage)
|
||||
}
|
||||
imageCounter.AddSSEData(upstreamMessage)
|
||||
|
||||
if !clientDisconnected {
|
||||
if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) {
|
||||
@ -2997,7 +3019,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
clientDisconnected,
|
||||
)
|
||||
}
|
||||
return &OpenAIForwardResult{
|
||||
imageCount := imageCounter.Count()
|
||||
result := &OpenAIForwardResult{
|
||||
RequestID: responseID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
@ -3009,13 +3032,21 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
ResponseHeaders: lease.HandshakeHeaders(),
|
||||
Duration: time.Since(turnStart),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
if imageCount > 0 {
|
||||
result.ImageCount = imageCount
|
||||
result.ImageSize = imageSizeTier
|
||||
result.BillingModel = imageBillingModel
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
currentPayload := firstPayload.payloadRaw
|
||||
currentOriginalModel := firstPayload.originalModel
|
||||
currentImageBillingModel := firstPayload.imageBillingModel
|
||||
currentImageSizeTier := firstPayload.imageSizeTier
|
||||
currentPayloadBytes := firstPayload.payloadBytes
|
||||
isStrictAffinityTurn := func(payload []byte) bool {
|
||||
if !storeDisabled {
|
||||
@ -3460,7 +3491,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
)
|
||||
}
|
||||
|
||||
result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel)
|
||||
result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel, currentImageBillingModel, currentImageSizeTier)
|
||||
if relayErr != nil {
|
||||
lastTurnClean = false
|
||||
if recoverIngressPrevResponseNotFound(relayErr, turn, connID) {
|
||||
@ -3582,6 +3613,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
}
|
||||
currentPayload = nextPayload.payloadRaw
|
||||
currentOriginalModel = nextPayload.originalModel
|
||||
currentImageBillingModel = nextPayload.imageBillingModel
|
||||
currentImageSizeTier = nextPayload.imageSizeTier
|
||||
currentPayloadBytes = nextPayload.payloadBytes
|
||||
storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account)
|
||||
if !storeDisabled {
|
||||
|
||||
@ -171,6 +171,127 @@ func TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky(t *testing.T) {
|
||||
require.Equal(t, "resp_new_1", gjson.GetBytes(responseBody, "id").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2_ImageGenerationCountsOutputs(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade websocket failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
var request map[string]any
|
||||
if err := conn.ReadJSON(&request); err != nil {
|
||||
t.Errorf("read ws request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.WriteJSON(map[string]any{
|
||||
"type": "response.output_item.done",
|
||||
"item": map[string]any{
|
||||
"id": "ig_ws_1",
|
||||
"type": "image_generation_call",
|
||||
"result": "final-image",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Errorf("write response.output_item.done failed: %v", err)
|
||||
return
|
||||
}
|
||||
if err := conn.WriteJSON(map[string]any{
|
||||
"type": "response.completed",
|
||||
"response": map[string]any{
|
||||
"id": "resp_ws_image_1",
|
||||
"model": "gpt-5.4",
|
||||
"output": []any{
|
||||
map[string]any{
|
||||
"id": "ig_ws_1",
|
||||
"type": "image_generation_call",
|
||||
"result": "final-image",
|
||||
},
|
||||
},
|
||||
"usage": map[string]any{
|
||||
"input_tokens": 9,
|
||||
"output_tokens": 4,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Errorf("write response.completed failed: %v", err)
|
||||
return
|
||||
}
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
groupID := int64(1010)
|
||||
c.Set("api_key", &APIKey{
|
||||
GroupID: &groupID,
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
AllowImageGeneration: true,
|
||||
},
|
||||
})
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: &httpUpstreamRecorder{},
|
||||
cache: &stubGatewayCache{},
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 10,
|
||||
Name: "openai-ws-image",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": wsServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.4","stream":false,"input":"draw","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1024x1024"}],"tool_choice":{"type":"image_generation"}}`)
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "resp_ws_image_1", result.RequestID)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, "1K", result.ImageSize)
|
||||
require.Equal(t, "gpt-image-2", result.BillingModel)
|
||||
require.Equal(t, 9, result.Usage.InputTokens)
|
||||
require.Equal(t, 4, result.Usage.OutputTokens)
|
||||
require.True(t, result.OpenAIWSMode)
|
||||
require.Equal(t, "resp_ws_image_1", gjson.GetBytes(rec.Body.Bytes(), "id").String())
|
||||
}
|
||||
|
||||
func requestToJSONString(payload map[string]any) string {
|
||||
if len(payload) == 0 {
|
||||
return "{}"
|
||||
|
||||
@ -625,6 +625,9 @@ func normalizeModelNameForPricing(model string) string {
|
||||
}
|
||||
|
||||
model = strings.TrimLeft(model, "/")
|
||||
if canonical := canonicalizeOpenAIModelAliasSpelling(model); canonical != "" {
|
||||
return canonical
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
|
||||
@ -98,6 +98,19 @@ func TestGetModelPricing_Gpt54UsesStaticFallbackWhenRemoteMissing(t *testing.T)
|
||||
require.InDelta(t, 1.5, got.LongContextOutputCostMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_OpenAICompactAliasUsesStaticFallback(t *testing.T) {
|
||||
svc := &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"gpt-5.1-codex": {InputCostPerToken: 1.25e-6},
|
||||
},
|
||||
}
|
||||
|
||||
got := svc.GetModelPricing("openai/gpt5.5")
|
||||
require.NotNil(t, got)
|
||||
require.InDelta(t, 2.5e-6, got.InputCostPerToken, 1e-12)
|
||||
require.InDelta(t, 1.5e-5, got.OutputCostPerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_Gpt54MiniUsesDedicatedStaticFallbackWhenRemoteMissing(t *testing.T) {
|
||||
svc := &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
|
||||
26
backend/migrations/134_image_generation_group_controls.sql
Normal file
26
backend/migrations/134_image_generation_group_controls.sql
Normal file
@ -0,0 +1,26 @@
|
||||
-- 生图能力与图片倍率模式控制
|
||||
-- 兼容性原则:
|
||||
-- 1. 不改写现有 image_price_1k/2k/4k,避免改变已配置分组的最终图片价格。
|
||||
-- 2. 现有 openai/gemini/antigravity 分组默认保持可生图,避免升级后中断已有图片业务。
|
||||
-- 3. 现有分组默认共享当前有效分组倍率,保持历史扣费公式。
|
||||
|
||||
ALTER TABLE groups
|
||||
ADD COLUMN IF NOT EXISTS allow_image_generation BOOLEAN NOT NULL DEFAULT false;
|
||||
|
||||
ALTER TABLE groups
|
||||
ADD COLUMN IF NOT EXISTS image_rate_independent BOOLEAN NOT NULL DEFAULT false;
|
||||
|
||||
ALTER TABLE groups
|
||||
ADD COLUMN IF NOT EXISTS image_rate_multiplier DECIMAL(10,4) NOT NULL DEFAULT 1.0;
|
||||
|
||||
UPDATE groups
|
||||
SET allow_image_generation = true
|
||||
WHERE platform IN ('openai', 'gemini', 'antigravity');
|
||||
|
||||
UPDATE groups
|
||||
SET image_rate_independent = false,
|
||||
image_rate_multiplier = 1.0;
|
||||
|
||||
COMMENT ON COLUMN groups.allow_image_generation IS '是否允许该分组使用图片生成能力';
|
||||
COMMENT ON COLUMN groups.image_rate_independent IS '图片生成是否使用独立倍率;false 表示共享分组有效倍率';
|
||||
COMMENT ON COLUMN groups.image_rate_multiplier IS '图片生成独立倍率,仅 image_rate_independent=true 时生效';
|
||||
@ -285,6 +285,25 @@ GATEWAY_SCHEDULING_OUTBOX_BACKLOG_REBUILD_ROWS=10000
|
||||
# 全量重建周期(秒)
|
||||
GATEWAY_SCHEDULING_FULL_REBUILD_INTERVAL_SECONDS=300
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Image Generation Stream & Concurrency (Optional)
|
||||
# 图片生成流式与并发隔离配置(可选)
|
||||
# -----------------------------------------------------------------------------
|
||||
# 图片流式上游数据间隔超时(秒)。0 表示禁用;非 0 时必须为 60-1800。
|
||||
GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=900
|
||||
# 图片流式 keepalive 间隔(秒)。0 表示禁用;非 0 时必须为 5-60。
|
||||
GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=10
|
||||
# 是否启用进程级图片生成并发限制。默认 false,保持历史行为。
|
||||
GATEWAY_IMAGE_CONCURRENCY_ENABLED=false
|
||||
# 当前进程允许同时处理的图片生成请求数。0 表示不限制。
|
||||
GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=0
|
||||
# 图片并发超限策略:reject 直接返回 429;wait 等待空闲槽位。
|
||||
GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=reject
|
||||
# wait 模式下等待空闲图片槽位的最长时间(秒)。
|
||||
GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=30
|
||||
# wait 模式下当前进程允许排队等待的最大图片请求数。0 表示不允许等待队列。
|
||||
GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=100
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Dashboard Aggregation (Optional)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@ -340,6 +340,30 @@ gateway:
|
||||
# Stream keepalive interval (seconds), 0=disable
|
||||
# 流式 keepalive 间隔(秒),0=禁用
|
||||
stream_keepalive_interval: 10
|
||||
# Image stream data interval timeout (seconds), 0=disable; independent from ordinary text streams
|
||||
# 图片流数据间隔超时(秒),0=禁用;独立于普通文本流式
|
||||
image_stream_data_interval_timeout: 900
|
||||
# Image stream keepalive interval (seconds), 0=disable; independent from ordinary text streams
|
||||
# 图片流式 keepalive 间隔(秒),0=禁用;独立于普通文本流式
|
||||
image_stream_keepalive_interval: 10
|
||||
# Image generation independent concurrency limiter (process-local, default disabled)
|
||||
# 图片生成独立并发限制(进程级,默认关闭;多实例总上限约为实例数×该值)
|
||||
image_concurrency:
|
||||
# Enable image-only concurrency protection; false keeps existing behavior unchanged
|
||||
# 是否启用图片独立并发保护;false 保持现有行为不变
|
||||
enabled: false
|
||||
# Max concurrent image generation requests in this process, 0=unlimited
|
||||
# 当前进程允许同时处理的图片生成请求数,0=不限制
|
||||
max_concurrent_requests: 0
|
||||
# Overflow mode when the image concurrency limit is full: reject/wait
|
||||
# 图片并发满时的处理方式:reject=立即拒绝,wait=等待槽位
|
||||
overflow_mode: "reject"
|
||||
# Wait timeout for overflow_mode=wait (seconds), 0=do not wait
|
||||
# wait 模式等待图片并发槽位的超时时间(秒),0=不等待
|
||||
wait_timeout_seconds: 30
|
||||
# Max image requests waiting in this process when overflow_mode=wait, 0=unlimited
|
||||
# wait 模式当前进程允许排队等待的图片请求数,0=不限制
|
||||
max_waiting_requests: 100
|
||||
# SSE max line size in bytes (default: 40MB)
|
||||
# SSE 单行最大字节数(默认 40MB)
|
||||
max_line_size: 41943040
|
||||
|
||||
@ -40,6 +40,13 @@ services:
|
||||
- JWT_SECRET=${JWT_SECRET:-}
|
||||
- TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY:-}
|
||||
- TZ=${TZ:-Asia/Shanghai}
|
||||
- GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900}
|
||||
- GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100}
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
|
||||
@ -146,6 +146,17 @@ services:
|
||||
# Proxy for accessing GitHub (online updates + pricing data)
|
||||
# Examples: http://host:port, socks5://host:port
|
||||
- UPDATE_PROXY_URL=${UPDATE_PROXY_URL:-}
|
||||
|
||||
# =======================================================================
|
||||
# Image Generation Stream & Concurrency
|
||||
# =======================================================================
|
||||
- GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900}
|
||||
- GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100}
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
|
||||
@ -93,6 +93,17 @@ services:
|
||||
# SECURITY: This repo does not embed third-party client_secret.
|
||||
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
|
||||
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
|
||||
|
||||
# =======================================================================
|
||||
# Image Generation Stream & Concurrency
|
||||
# =======================================================================
|
||||
- GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900}
|
||||
- GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100}
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"]
|
||||
interval: 30s
|
||||
|
||||
@ -142,6 +142,17 @@ services:
|
||||
# Proxy for accessing GitHub (online updates + pricing data)
|
||||
# Examples: http://host:port, socks5://host:port
|
||||
- UPDATE_PROXY_URL=${UPDATE_PROXY_URL:-}
|
||||
|
||||
# =======================================================================
|
||||
# Image Generation Stream & Concurrency
|
||||
# =======================================================================
|
||||
- GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900}
|
||||
- GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30}
|
||||
- GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100}
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
|
||||
@ -291,9 +291,23 @@
|
||||
</div>
|
||||
</template>
|
||||
<!-- Per-request / image billing: show unit price -->
|
||||
<template v-else-if="tooltipData?.billing_mode === BILLING_MODE_IMAGE">
|
||||
<div class="flex items-center justify-between gap-4">
|
||||
<span class="text-gray-400">{{ t('usage.imageCount') }}</span>
|
||||
<span class="font-medium text-white">{{ tooltipData.image_count }}{{ t('usage.imageUnit') }} ({{ tooltipData.image_size || '2K' }})</span>
|
||||
</div>
|
||||
<div class="flex items-center justify-between gap-4">
|
||||
<span class="text-gray-400">{{ t('usage.imageUnitPrice') }}</span>
|
||||
<span class="font-medium text-sky-300">${{ imageUnitPrice(tooltipData).toFixed(6) }}</span>
|
||||
</div>
|
||||
<div class="flex items-center justify-between gap-4">
|
||||
<span class="text-gray-400">{{ t('usage.imageTotalPrice') }}</span>
|
||||
<span class="font-medium text-white">${{ tooltipData.total_cost?.toFixed(6) || '0.000000' }}</span>
|
||||
</div>
|
||||
</template>
|
||||
<div v-else class="flex items-center justify-between gap-4">
|
||||
<span class="text-gray-400">{{ tooltipData.billing_mode === BILLING_MODE_IMAGE ? t('usage.imageUnitPrice') : t('usage.unitPrice') }}</span>
|
||||
<span class="font-medium text-sky-300">${{ tooltipData.total_cost?.toFixed(6) || '0.000000' }}</span>
|
||||
<span class="text-gray-400">{{ t('usage.unitPrice') }}</span>
|
||||
<span class="font-medium text-sky-300">${{ tooltipData?.total_cost?.toFixed(6) || '0.000000' }}</span>
|
||||
</div>
|
||||
<div v-if="tooltipData && tooltipData.cache_creation_cost > 0" class="flex items-center justify-between gap-4">
|
||||
<span class="text-gray-400">{{ t('admin.usage.cacheCreationCost') }}</span>
|
||||
@ -360,6 +374,13 @@ function accountBilled(row: { total_cost?: number | null; account_stats_cost?: n
|
||||
return Number.isNaN(result) ? 0 : result
|
||||
}
|
||||
|
||||
function imageUnitPrice(row: AdminUsageLog | null): number {
|
||||
if (!row || row.image_count <= 0) return 0
|
||||
const total = row.total_cost ?? 0
|
||||
const price = total / row.image_count
|
||||
return Number.isFinite(price) ? price : 0
|
||||
}
|
||||
|
||||
import DataTable from '@/components/common/DataTable.vue'
|
||||
import EmptyState from '@/components/common/EmptyState.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
|
||||
@ -844,6 +844,8 @@ export default {
|
||||
perMillionTokens: '/ 1M tokens',
|
||||
unitPrice: 'Per-request price',
|
||||
imageUnitPrice: 'Per-image price',
|
||||
imageTotalPrice: 'Image total price',
|
||||
imageCount: 'Image count',
|
||||
cacheRead: 'Read',
|
||||
cacheWrite: 'Write',
|
||||
serviceTier: 'Service tier',
|
||||
@ -2050,7 +2052,13 @@ export default {
|
||||
},
|
||||
imagePricing: {
|
||||
title: 'Image Generation Pricing',
|
||||
description: 'Configure pricing for image generation models. Leave empty to use default prices.'
|
||||
description: 'Configure image generation access and base image prices. Leave empty to use default prices.',
|
||||
allowImageGeneration: 'Allow image generation for this group',
|
||||
independentMultiplier: 'Use independent image multiplier',
|
||||
imageMultiplier: 'Image multiplier',
|
||||
modeHint: 'By default, image billing uses image price × current effective group multiplier. Independent mode uses image price × image multiplier.',
|
||||
finalPricePreview: 'Final per-image price preview',
|
||||
notConfigured: 'Not configured'
|
||||
},
|
||||
claudeCode: {
|
||||
title: 'Claude Code Client Restriction',
|
||||
|
||||
@ -848,6 +848,8 @@ export default {
|
||||
perMillionTokens: '/ 1M Token',
|
||||
unitPrice: '单次价格',
|
||||
imageUnitPrice: '单张价格',
|
||||
imageTotalPrice: '图片总价',
|
||||
imageCount: '图片张数',
|
||||
cacheRead: '读取',
|
||||
cacheWrite: '写入',
|
||||
serviceTier: '服务档位',
|
||||
@ -2133,7 +2135,13 @@ export default {
|
||||
},
|
||||
imagePricing: {
|
||||
title: '图片生成计费',
|
||||
description: '配置图片生成模型的图片生成价格,留空则使用默认价格'
|
||||
description: '配置图片生成能力和图片基础单价,留空则使用默认价格',
|
||||
allowImageGeneration: '允许当前分组生图',
|
||||
independentMultiplier: '生图倍率独立',
|
||||
imageMultiplier: '生图独立倍率',
|
||||
modeHint: '默认关闭独立倍率时,图片费用 = 图片价格 × 当前分组有效倍率;开启独立倍率后,图片费用 = 图片价格 × 生图独立倍率。',
|
||||
finalPricePreview: '最终单张价格预览',
|
||||
notConfigured: '未配置'
|
||||
},
|
||||
claudeCode: {
|
||||
title: 'Claude Code 客户端限制',
|
||||
|
||||
@ -492,7 +492,10 @@ export interface Group {
|
||||
daily_limit_usd: number | null
|
||||
weekly_limit_usd: number | null
|
||||
monthly_limit_usd: number | null
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
// 图片生成计费配置
|
||||
allow_image_generation: boolean
|
||||
image_rate_independent: boolean
|
||||
image_rate_multiplier: number
|
||||
image_price_1k: number | null
|
||||
image_price_2k: number | null
|
||||
image_price_4k: number | null
|
||||
@ -602,6 +605,9 @@ export interface CreateGroupRequest {
|
||||
daily_limit_usd?: number | null
|
||||
weekly_limit_usd?: number | null
|
||||
monthly_limit_usd?: number | null
|
||||
allow_image_generation?: boolean
|
||||
image_rate_independent?: boolean
|
||||
image_rate_multiplier?: number
|
||||
image_price_1k?: number | null
|
||||
image_price_2k?: number | null
|
||||
image_price_4k?: number | null
|
||||
@ -627,6 +633,9 @@ export interface UpdateGroupRequest {
|
||||
daily_limit_usd?: number | null
|
||||
weekly_limit_usd?: number | null
|
||||
monthly_limit_usd?: number | null
|
||||
allow_image_generation?: boolean
|
||||
image_rate_independent?: boolean
|
||||
image_rate_multiplier?: number
|
||||
image_price_1k?: number | null
|
||||
image_price_2k?: number | null
|
||||
image_price_4k?: number | null
|
||||
|
||||
@ -666,6 +666,40 @@
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400 mb-3">
|
||||
{{ t("admin.groups.imagePricing.description") }}
|
||||
</p>
|
||||
<div class="mb-4 grid grid-cols-1 gap-3 md:grid-cols-2">
|
||||
<label class="flex items-center gap-2 text-sm text-gray-700 dark:text-gray-300">
|
||||
<input
|
||||
v-model="createForm.allow_image_generation"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-blue-600 focus:ring-blue-500"
|
||||
/>
|
||||
{{ t("admin.groups.imagePricing.allowImageGeneration") }}
|
||||
</label>
|
||||
<label class="flex items-center gap-2 text-sm text-gray-700 dark:text-gray-300">
|
||||
<input
|
||||
v-model="createForm.image_rate_independent"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-blue-600 focus:ring-blue-500"
|
||||
/>
|
||||
{{ t("admin.groups.imagePricing.independentMultiplier") }}
|
||||
</label>
|
||||
</div>
|
||||
<div
|
||||
v-if="createForm.image_rate_independent"
|
||||
class="mb-4"
|
||||
>
|
||||
<label class="input-label">{{
|
||||
t("admin.groups.imagePricing.imageMultiplier")
|
||||
}}</label>
|
||||
<input
|
||||
v-model.number="createForm.image_rate_multiplier"
|
||||
type="number"
|
||||
step="0.0001"
|
||||
min="0"
|
||||
class="input"
|
||||
placeholder="1"
|
||||
/>
|
||||
</div>
|
||||
<div class="grid grid-cols-3 gap-3">
|
||||
<div>
|
||||
<label class="input-label">1K ($)</label>
|
||||
@ -701,6 +735,22 @@
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<p class="mt-3 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t("admin.groups.imagePricing.modeHint") }}
|
||||
</p>
|
||||
<div class="mt-2 rounded-lg bg-gray-50 p-3 text-xs text-gray-700 dark:bg-gray-800 dark:text-gray-300">
|
||||
<div class="mb-1 font-medium">
|
||||
{{ t("admin.groups.imagePricing.finalPricePreview") }}
|
||||
</div>
|
||||
<div class="grid grid-cols-3 gap-2">
|
||||
<div
|
||||
v-for="item in createImageFinalPricePreview"
|
||||
:key="item.label"
|
||||
>
|
||||
{{ item.label }}: {{ item.value }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 支持的模型系列(仅 antigravity 平台) -->
|
||||
@ -1801,6 +1851,40 @@
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400 mb-3">
|
||||
{{ t("admin.groups.imagePricing.description") }}
|
||||
</p>
|
||||
<div class="mb-4 grid grid-cols-1 gap-3 md:grid-cols-2">
|
||||
<label class="flex items-center gap-2 text-sm text-gray-700 dark:text-gray-300">
|
||||
<input
|
||||
v-model="editForm.allow_image_generation"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-blue-600 focus:ring-blue-500"
|
||||
/>
|
||||
{{ t("admin.groups.imagePricing.allowImageGeneration") }}
|
||||
</label>
|
||||
<label class="flex items-center gap-2 text-sm text-gray-700 dark:text-gray-300">
|
||||
<input
|
||||
v-model="editForm.image_rate_independent"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-blue-600 focus:ring-blue-500"
|
||||
/>
|
||||
{{ t("admin.groups.imagePricing.independentMultiplier") }}
|
||||
</label>
|
||||
</div>
|
||||
<div
|
||||
v-if="editForm.image_rate_independent"
|
||||
class="mb-4"
|
||||
>
|
||||
<label class="input-label">{{
|
||||
t("admin.groups.imagePricing.imageMultiplier")
|
||||
}}</label>
|
||||
<input
|
||||
v-model.number="editForm.image_rate_multiplier"
|
||||
type="number"
|
||||
step="0.0001"
|
||||
min="0"
|
||||
class="input"
|
||||
placeholder="1"
|
||||
/>
|
||||
</div>
|
||||
<div class="grid grid-cols-3 gap-3">
|
||||
<div>
|
||||
<label class="input-label">1K ($)</label>
|
||||
@ -1836,6 +1920,22 @@
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<p class="mt-3 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t("admin.groups.imagePricing.modeHint") }}
|
||||
</p>
|
||||
<div class="mt-2 rounded-lg bg-gray-50 p-3 text-xs text-gray-700 dark:bg-gray-800 dark:text-gray-300">
|
||||
<div class="mb-1 font-medium">
|
||||
{{ t("admin.groups.imagePricing.finalPricePreview") }}
|
||||
</div>
|
||||
<div class="grid grid-cols-3 gap-2">
|
||||
<div
|
||||
v-for="item in editImageFinalPricePreview"
|
||||
:key="item.label"
|
||||
>
|
||||
{{ item.label }}: {{ item.value }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 支持的模型系列(仅 antigravity 平台) -->
|
||||
@ -3009,7 +3109,10 @@ const createForm = reactive({
|
||||
daily_limit_usd: null as number | null,
|
||||
weekly_limit_usd: null as number | null,
|
||||
monthly_limit_usd: null as number | null,
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
// 图片生成计费配置
|
||||
allow_image_generation: false,
|
||||
image_rate_independent: false,
|
||||
image_rate_multiplier: 1,
|
||||
image_price_1k: null as number | null,
|
||||
image_price_2k: null as number | null,
|
||||
image_price_4k: null as number | null,
|
||||
@ -3291,7 +3394,10 @@ const editForm = reactive({
|
||||
daily_limit_usd: null as number | null,
|
||||
weekly_limit_usd: null as number | null,
|
||||
monthly_limit_usd: null as number | null,
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
// 图片生成计费配置
|
||||
allow_image_generation: false,
|
||||
image_rate_independent: false,
|
||||
image_rate_multiplier: 1,
|
||||
image_price_1k: null as number | null,
|
||||
image_price_2k: null as number | null,
|
||||
image_price_4k: null as number | null,
|
||||
@ -3321,6 +3427,62 @@ const editForm = reactive({
|
||||
rpm_limit: 0 as number,
|
||||
});
|
||||
|
||||
type ImagePricingFormState = {
|
||||
rate_multiplier: number;
|
||||
image_rate_independent: boolean;
|
||||
image_rate_multiplier: number;
|
||||
image_price_1k: number | string | null;
|
||||
image_price_2k: number | string | null;
|
||||
image_price_4k: number | string | null;
|
||||
};
|
||||
|
||||
const imagePricingTiers = [
|
||||
{ key: "image_price_1k", label: "1K" },
|
||||
{ key: "image_price_2k", label: "2K" },
|
||||
{ key: "image_price_4k", label: "4K" },
|
||||
] as const;
|
||||
|
||||
const normalizePreviewNumber = (value: number | string | null | undefined, fallback = 0) => {
|
||||
if (value === null || value === undefined || value === "") {
|
||||
return fallback;
|
||||
}
|
||||
const parsed = Number(value);
|
||||
return Number.isFinite(parsed) ? parsed : fallback;
|
||||
};
|
||||
|
||||
const formatImagePricePreview = (value: number | string | null | undefined) => {
|
||||
if (value === null || value === undefined || value === "") {
|
||||
return t("admin.groups.imagePricing.notConfigured");
|
||||
}
|
||||
const price = Number(value);
|
||||
if (!Number.isFinite(price) || price < 0) {
|
||||
return t("admin.groups.imagePricing.notConfigured");
|
||||
}
|
||||
return `$${price.toFixed(6).replace(/0+$/, "").replace(/\.$/, "")}`;
|
||||
};
|
||||
|
||||
const buildImageFinalPricePreview = (form: ImagePricingFormState) => {
|
||||
const multiplier = form.image_rate_independent
|
||||
? normalizePreviewNumber(form.image_rate_multiplier, 1)
|
||||
: normalizePreviewNumber(form.rate_multiplier, 1);
|
||||
return imagePricingTiers.map((tier) => {
|
||||
const basePrice = normalizePreviewNumber(form[tier.key]);
|
||||
return {
|
||||
label: tier.label,
|
||||
value: basePrice > 0
|
||||
? formatImagePricePreview(basePrice * multiplier)
|
||||
: t("admin.groups.imagePricing.notConfigured"),
|
||||
};
|
||||
});
|
||||
};
|
||||
|
||||
const createImageFinalPricePreview = computed(() =>
|
||||
buildImageFinalPricePreview(createForm),
|
||||
);
|
||||
const editImageFinalPricePreview = computed(() =>
|
||||
buildImageFinalPricePreview(editForm),
|
||||
);
|
||||
|
||||
// 根据分组类型返回不同的删除确认消息
|
||||
const deleteConfirmMessage = computed(() => {
|
||||
if (!deletingGroup.value) {
|
||||
@ -3479,6 +3641,9 @@ const closeCreateModal = () => {
|
||||
createForm.daily_limit_usd = null;
|
||||
createForm.weekly_limit_usd = null;
|
||||
createForm.monthly_limit_usd = null;
|
||||
createForm.allow_image_generation = false;
|
||||
createForm.image_rate_independent = false;
|
||||
createForm.image_rate_multiplier = 1;
|
||||
createForm.image_price_1k = null;
|
||||
createForm.image_price_2k = null;
|
||||
createForm.image_price_4k = null;
|
||||
@ -3513,6 +3678,16 @@ const normalizeOptionalLimit = (
|
||||
return Number.isFinite(value) && value > 0 ? value : null;
|
||||
};
|
||||
|
||||
const normalizeImageRateMultiplier = (
|
||||
value: number | string | null | undefined,
|
||||
): number => {
|
||||
if (value === null || value === undefined || value === "") {
|
||||
return 1;
|
||||
}
|
||||
const parsed = Number(value);
|
||||
return Number.isFinite(parsed) && parsed >= 0 ? parsed : 1;
|
||||
};
|
||||
|
||||
const handleCreateGroup = async () => {
|
||||
if (!createForm.name.trim()) {
|
||||
appStore.showError(t("admin.groups.nameRequired"));
|
||||
@ -3551,6 +3726,9 @@ const handleCreateGroup = async () => {
|
||||
requestData.daily_limit_usd = emptyToNull(requestData.daily_limit_usd);
|
||||
requestData.weekly_limit_usd = emptyToNull(requestData.weekly_limit_usd);
|
||||
requestData.monthly_limit_usd = emptyToNull(requestData.monthly_limit_usd);
|
||||
requestData.image_rate_multiplier = normalizeImageRateMultiplier(
|
||||
requestData.image_rate_multiplier,
|
||||
);
|
||||
await adminAPI.groups.create(requestData);
|
||||
appStore.showSuccess(t("admin.groups.groupCreated"));
|
||||
closeCreateModal();
|
||||
@ -3582,6 +3760,9 @@ const handleEdit = async (group: AdminGroup) => {
|
||||
editForm.daily_limit_usd = group.daily_limit_usd;
|
||||
editForm.weekly_limit_usd = group.weekly_limit_usd;
|
||||
editForm.monthly_limit_usd = group.monthly_limit_usd;
|
||||
editForm.allow_image_generation = group.allow_image_generation ?? false;
|
||||
editForm.image_rate_independent = group.image_rate_independent ?? false;
|
||||
editForm.image_rate_multiplier = group.image_rate_multiplier ?? 1;
|
||||
editForm.image_price_1k = group.image_price_1k;
|
||||
editForm.image_price_2k = group.image_price_2k;
|
||||
editForm.image_price_4k = group.image_price_4k;
|
||||
@ -3676,6 +3857,9 @@ const handleUpdateGroup = async () => {
|
||||
payload.daily_limit_usd = emptyToNull(payload.daily_limit_usd);
|
||||
payload.weekly_limit_usd = emptyToNull(payload.weekly_limit_usd);
|
||||
payload.monthly_limit_usd = emptyToNull(payload.monthly_limit_usd);
|
||||
payload.image_rate_multiplier = normalizeImageRateMultiplier(
|
||||
payload.image_rate_multiplier,
|
||||
);
|
||||
await adminAPI.groups.update(editingGroup.value.id, payload);
|
||||
appStore.showSuccess(t("admin.groups.groupUpdated"));
|
||||
closeEditModal();
|
||||
|
||||
@ -459,9 +459,23 @@
|
||||
</div>
|
||||
</template>
|
||||
<!-- Per-request / image billing: show unit price -->
|
||||
<template v-else-if="tooltipData?.billing_mode === 'image'">
|
||||
<div class="flex items-center justify-between gap-4">
|
||||
<span class="text-gray-400">{{ t('usage.imageCount') }}</span>
|
||||
<span class="font-medium text-white">{{ tooltipData.image_count }}{{ t('usage.imageUnit') }} ({{ tooltipData.image_size || '2K' }})</span>
|
||||
</div>
|
||||
<div class="flex items-center justify-between gap-4">
|
||||
<span class="text-gray-400">{{ t('usage.imageUnitPrice') }}</span>
|
||||
<span class="font-medium text-sky-300">${{ imageUnitPrice(tooltipData).toFixed(6) }}</span>
|
||||
</div>
|
||||
<div class="flex items-center justify-between gap-4">
|
||||
<span class="text-gray-400">{{ t('usage.imageTotalPrice') }}</span>
|
||||
<span class="font-medium text-white">${{ tooltipData.total_cost?.toFixed(6) || '0.000000' }}</span>
|
||||
</div>
|
||||
</template>
|
||||
<div v-else class="flex items-center justify-between gap-4">
|
||||
<span class="text-gray-400">{{ tooltipData.billing_mode === 'image' ? t('usage.imageUnitPrice') : t('usage.unitPrice') }}</span>
|
||||
<span class="font-medium text-sky-300">${{ tooltipData.total_cost?.toFixed(6) || '0.000000' }}</span>
|
||||
<span class="text-gray-400">{{ t('usage.unitPrice') }}</span>
|
||||
<span class="font-medium text-sky-300">${{ tooltipData?.total_cost?.toFixed(6) || '0.000000' }}</span>
|
||||
</div>
|
||||
<div v-if="tooltipData && tooltipData.cache_creation_cost > 0" class="flex items-center justify-between gap-4">
|
||||
<span class="text-gray-400">{{ t('admin.usage.cacheCreationCost') }}</span>
|
||||
@ -625,6 +639,13 @@ const formatDuration = (ms: number): string => {
|
||||
return `${(ms / 1000).toFixed(2)}s`
|
||||
}
|
||||
|
||||
const imageUnitPrice = (row: UsageLog | null): number => {
|
||||
if (!row || row.image_count <= 0) return 0
|
||||
const total = row.total_cost ?? 0
|
||||
const price = total / row.image_count
|
||||
return Number.isFinite(price) ? price : 0
|
||||
}
|
||||
|
||||
const formatUserAgent = (ua: string): string => {
|
||||
return ua
|
||||
}
|
||||
|
||||
@ -44,7 +44,6 @@ export default defineConfig(({ mode }) => {
|
||||
plugins: [
|
||||
vue(),
|
||||
checker({
|
||||
typescript: true,
|
||||
vueTsc: true
|
||||
}),
|
||||
injectPublicSettings(backendUrl)
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
schema: spec-driven
|
||||
created: 2026-04-29
|
||||
227
openspec/changes/add-image-generation-billing-controls/design.md
Normal file
227
openspec/changes/add-image-generation-billing-controls/design.md
Normal file
@ -0,0 +1,227 @@
|
||||
## Context
|
||||
|
||||
当前代码已经具备图片价格字段和部分图片转发能力,但边界不完整:
|
||||
|
||||
- `backend/ent/schema/group.go` 只有 `rate_multiplier` 和 `image_price_1k/2k/4k`,没有分组级生图能力开关,也没有“图片是否共享分组倍率”的开关。
|
||||
- `backend/internal/handler/openai_images.go` 在解析 `/v1/images/*` 后只做通用余额/订阅资格检查,没有检查分组是否允许生图。
|
||||
- `backend/internal/service/openai_gateway_service.go` 对 Codex CLI 会自动注入 `image_generation` tool;通用 `/v1/responses` 只记录日志,没有把图片工具产物数量写入 `OpenAIForwardResult.ImageCount`。
|
||||
- `backend/internal/service/billing_service.go` 的 `CalculateImageCost` 当前使用 `image_price_* * image_count * rate_multiplier`。这个行为本身可以作为默认兼容模式,但普通编码分组 `rate_multiplier=0.15` 且希望图片最终价为 `0.2/张` 时,管理员必须填写 `image_price=0.2/0.15`,不可读且不适合长期运营。
|
||||
- `backend/internal/service/openai_gateway_service.go` 和 `backend/internal/service/gateway_service.go` 的渠道图片计费路径当前传 `RequestCount: 1`,多图请求会按 1 次收费。
|
||||
- `backend/internal/service/openai_images.go` 的 OpenAI 图片尺寸分层此前只覆盖少量固定尺寸;`gpt-image-2` 官方文档已经支持满足约束的自定义 `size`,因此本地计费必须能够对未知尺寸做稳定分档,同时不能因为本地映射不认识就提前拦截请求。
|
||||
|
||||
用户澄清后的业务要求是:普通编码分组可以关闭生图,也可以开启生图;开启后默认继续共享现有分组倍率以保持兼容,但管理员可以打开“生图倍率独立”开关,改用单独的图片倍率输入框。图片分组是推荐的运营隔离方式,但不是唯一承载方式。
|
||||
|
||||
## Goals / Non-Goals
|
||||
|
||||
**Goals:**
|
||||
- 分组具备明确的 `allow_image_generation` 开关,所有已知生图入口在调度上游前执行同一个权限判断。
|
||||
- 分组具备“生图倍率是否独立”的开关;默认 `false`,即共享当前代码里的有效分组倍率。
|
||||
- 生图倍率独立开关打开后,图片费用使用单独的 `image_rate_multiplier`,不再使用普通编码分组的倍率。
|
||||
- 保留现有 `image_price_1k/2k/4k` 字段作为图片单价配置,不强制把它们迁移成新的语义。
|
||||
- 普通编码分组在 `allow_image_generation=false` 时仍可正常使用 `gpt-5.4` / `gpt-5.5` 文本能力,但不能使用图片工具。
|
||||
- 普通编码分组在 `allow_image_generation=true` 时可使用 `gpt-5.4` / `gpt-5.5 + image_generation`,且按实际图片数量收费。
|
||||
- 通用 `/v1/responses`、OpenAI Images API、流式、非流式、透传路径全部把成功产出的图片数量写入 `ImageCount`。
|
||||
- 渠道 `billing_mode=image` 使用真实 `ImageCount`,不再固定按 1 次收费。
|
||||
|
||||
**Non-Goals:**
|
||||
- 不引入新的第三方依赖。
|
||||
- 不改变 OpenAI 上游协议;只在现有请求转发、响应解析和计费归因层补齐控制。
|
||||
- 不把“图片分组”做成唯一安全边界;分组开关和图片计费逻辑必须适用于任意开启生图的分组。
|
||||
- 不在本变更中实现预扣费/资金冻结;失败请求仍不收费,成功请求按实际产物后扣费。
|
||||
- 不改变默认历史图片价格行为;默认共享现有有效倍率,历史 `图片价格 * 分组/用户有效倍率` 的扣费行为保持。
|
||||
- 不在本变更中新增用户级图片独立倍率覆盖;用户专属普通倍率只在共享倍率模式下继续影响图片。
|
||||
|
||||
## Decisions
|
||||
|
||||
### 0. 兼容性优先原则
|
||||
|
||||
本变更的默认行为必须以“不改变现有已配置分组的最终扣费”为优先级:
|
||||
|
||||
- 迁移不修改现有 `image_price_1k/2k/4k`。
|
||||
- 迁移把所有现有分组设置为 `image_rate_independent=false`,因此现有图片路径继续使用当前有效分组倍率。
|
||||
- 管理员不传新字段更新分组时,不得覆盖已保存的 `allow_image_generation`、`image_rate_independent`、`image_rate_multiplier`。
|
||||
- 前端编辑旧分组时必须回显服务端值;不能因为表单默认值把旧分组从共享倍率误改成独立倍率,或把允许生图误改成禁止生图。
|
||||
- 只有管理员显式打开 `image_rate_independent` 后,图片扣费才从共享倍率切换到图片独立倍率。
|
||||
|
||||
### 1. 分组字段与迁移策略
|
||||
|
||||
新增三个分组字段,对应“2 个开关 + 1 个输入框”:
|
||||
|
||||
- `allow_image_generation BOOLEAN NOT NULL DEFAULT false`
|
||||
- `image_rate_independent BOOLEAN NOT NULL DEFAULT false`
|
||||
- `image_rate_multiplier DECIMAL(10,4) NOT NULL DEFAULT 1.0`
|
||||
|
||||
字段语义:
|
||||
|
||||
- `allow_image_generation`:是否支持当前分组生图。
|
||||
- `image_rate_independent=false`:图片计费共享当前普通计费链路里的有效倍率,即当前 `userGroupRateResolver.Resolve(ctx, user.ID, groupID, group.RateMultiplier)` 得到的倍率;这保持现有行为。
|
||||
- `image_rate_independent=true`:图片计费使用 `group.image_rate_multiplier`;普通编码的 `rate_multiplier` 和用户专属普通倍率不参与图片扣费。
|
||||
- `image_price_1k/2k/4k`:继续表示图片基础单价,由选中的图片倍率模式继续相乘。
|
||||
|
||||
新建分组默认 `allow_image_generation=false`,避免新普通编码分组意外获得生图能力。为避免升级后立即打断已有图片业务,迁移对现有 `openai`、`gemini`、`antigravity` 分组回填 `allow_image_generation=true`,`anthropic` 分组保持 `false`。该回填只是兼容现状;上线后管理员必须按业务策略关闭不允许生图的普通编码分组。
|
||||
|
||||
迁移不改写已有 `image_price_1k/2k/4k`,并将所有现有分组设为 `image_rate_independent=false`、`image_rate_multiplier=1`。这样现有最终扣费公式保持不变:
|
||||
|
||||
```text
|
||||
历史/默认模式图片最终扣费 = image_price_* * image_count * 当前有效分组倍率
|
||||
```
|
||||
|
||||
普通编码分组 `rate_multiplier=0.15` 且希望图片 1K 最终扣费 `0.2/张` 时,管理员不再需要填写 `0.2/0.15`,而是设置:
|
||||
|
||||
```text
|
||||
image_rate_independent = true
|
||||
image_rate_multiplier = 1
|
||||
image_price_1k = 0.2
|
||||
```
|
||||
|
||||
如果希望图片也打折,例如图片标价 `0.2/张`、图片折扣 `0.8`,则设置:
|
||||
|
||||
```text
|
||||
image_rate_independent = true
|
||||
image_rate_multiplier = 0.8
|
||||
image_price_1k = 0.2
|
||||
```
|
||||
|
||||
### 2. 生图意图统一识别
|
||||
|
||||
新增一个服务层 helper,输入至少包含 endpoint、请求模型、请求体,输出是否为生图意图:
|
||||
|
||||
```text
|
||||
isImageGenerationIntent =
|
||||
endpoint 是 /v1/images/generations 或 /v1/images/edits
|
||||
OR requested model 以 gpt-image- 开头
|
||||
OR tools[] 存在 type == image_generation
|
||||
OR tool_choice 显式指向 image_generation
|
||||
```
|
||||
|
||||
生图意图判断必须在请求体被 Codex 注入、模型改写、渠道映射改写之前执行一次,并在这些改写之后再对最终请求体执行一次。原因是当前代码会在 `backend/internal/service/openai_gateway_service.go` 中注入 `image_generation` tool,也会在 `normalizeOpenAIResponsesImageOnlyModel` 中把 `gpt-image-*` 改写为文本模型 + 图片工具;只检查改写前或只检查改写后都可能漏掉场景。
|
||||
|
||||
`tool_choice` 判断只把明确指向 `image_generation` 的值视为生图意图;`auto`、`none`、`required` 本身不构成生图意图,但如果 `tools[]` 中存在 `image_generation`,仍由 `tools[]` 规则命中。
|
||||
|
||||
该判断必须在以下位置使用:
|
||||
|
||||
- `/v1/images/*` handler 解析请求后、账号调度前。
|
||||
- `/v1/responses` 解析 body 后、Codex 自动注入 `image_generation` tool 前。
|
||||
- `normalizeOpenAIResponsesImageOnlyModel` 把 `gpt-image-*` 改写为 Responses 文本模型前。
|
||||
- OpenAI 高级 scheduler 入口保留现有账号能力检查,同时补齐渠道 restriction 检查,避免启用高级调度时绕过渠道模型限制。
|
||||
|
||||
当 `allow_image_generation=false` 时:
|
||||
|
||||
- 显式生图意图返回 HTTP 403,错误类型使用现有 `permission_error` 风格。
|
||||
- Codex CLI 请求不自动注入 `image_generation` tool,也不追加图片桥接指令;如果请求没有显式生图意图,则继续按普通文本请求处理。
|
||||
|
||||
### 3. gpt-5.4 / gpt-5.5 生图承载方式
|
||||
|
||||
`gpt-5.4` / `gpt-5.5` 生图通过现有 OpenAI Responses API 的 `image_generation` tool 承载,不新增专用 endpoint:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-5.4",
|
||||
"input": "生成一张图片",
|
||||
"tools": [
|
||||
{
|
||||
"type": "image_generation",
|
||||
"model": "gpt-image-2",
|
||||
"size": "1024x1024",
|
||||
"output_format": "png"
|
||||
}
|
||||
],
|
||||
"tool_choice": { "type": "image_generation" }
|
||||
}
|
||||
```
|
||||
|
||||
`model=gpt-image-*` 发到 `/v1/responses` 时保留现有改写方向:主模型改为 Responses 文本模型,图片模型放入 `image_generation` tool。计费时如果能从工具配置得到 `gpt-image-*`,图片默认价格按该图片模型解析;如果工具未指定图片模型,则使用当前转发结果的 billing model,并优先使用分组/渠道配置价格。
|
||||
|
||||
### 4. 图片数量归因
|
||||
|
||||
新增统一图片输出解析 helper,返回去重后的图片数量和可用图片元信息。必须覆盖以下已有或可借鉴的事件形态:
|
||||
|
||||
- 非流式 Responses JSON:`output[]` 中 `type == image_generation_call` 且 `result` 非空。
|
||||
- Responses SSE:`response.output_item.done` 中 `item.type == image_generation_call` 且 `item.result` 非空。
|
||||
- Responses SSE 完成事件:`response.completed.response.output[]` 中图片工具结果。
|
||||
- Images API 非流式:顶层 `data[]`。
|
||||
- Images API 流式:顶层 `data[]`、`image_generation.completed`、`response.output_item.done`、`response.completed`。
|
||||
|
||||
去重键按优先级使用 `item.id`、`call_id`、`result` 内容 hash。只统计最终图片,不统计 `partial_image`。
|
||||
|
||||
`openaiStreamingResult` 增加 `imageCount`、`imageSize`、`imageBillingModel`。`handleStreamingResponse`、`handleStreamingResponsePassthrough`、`handleNonStreamingResponse`、`handleNonStreamingResponsePassthrough` 都必须把解析结果带回 `OpenAIForwardResult`。当 `ImageCount > 0` 时,即使上游 usage 为 0,也必须写 usage log 并进入图片计费。
|
||||
|
||||
### 5. 图片价格公式
|
||||
|
||||
图片计费先确定单价,再确定倍率:
|
||||
|
||||
```text
|
||||
unit_price = 渠道 image 模式价格 或 分组 image_price_* 或 默认图片价格
|
||||
image_multiplier =
|
||||
如果 group.image_rate_independent == true: group.image_rate_multiplier
|
||||
否则: 当前有效分组倍率
|
||||
total_cost = unit_price * image_count
|
||||
actual_cost = total_cost * image_multiplier
|
||||
```
|
||||
|
||||
“当前有效分组倍率”必须沿用当前代码的倍率解析方式:默认配置倍率 → 分组 `rate_multiplier` → 用户专属分组倍率覆盖。这样 `image_rate_independent=false` 时完全保留当前行为。
|
||||
|
||||
`billing_mode=image` 的渠道价格是图片单价来源之一,仍优先于分组图片价格。图片渠道价格也必须按 `ImageCount` 计数,并使用同一套 `image_multiplier` 选择逻辑。
|
||||
|
||||
`billing_mode=per_request` 的非图片请求保持当前普通按次语义,继续使用普通 token 倍率;只有已经识别为图片请求且 `ImageCount > 0` 的路径使用图片计费逻辑。
|
||||
|
||||
`usage_logs.rate_multiplier` 继续表示“本次扣费实际使用的倍率”。因此:
|
||||
|
||||
- token 日志记录普通 token 有效倍率。
|
||||
- image 日志在共享模式记录普通有效倍率。
|
||||
- image 日志在独立模式记录 `image_rate_multiplier`。
|
||||
|
||||
专用 `/v1/images/*` 仍按图片请求语义计费:当 `ImageCount > 0` 时,图片价格决定费用,伴随的上游 token usage 只记录不额外计 token 费用。这保持当前 Images API 的行为。
|
||||
|
||||
通用 `/v1/responses + image_generation` 的混合文本+图片输出存在一个明确取舍:如果继续沿用“`ImageCount > 0` 时只按图片计费”的当前计费分支,用户可以在一次图片请求中夹带大量文本输出而只付图片费用;如果改成“图片费用 + 非图片 token 费用”,会改变当前 `billing_mode=image` 的单一计费语义,并可能让渠道图片单价不再是全包价格。本变更为最大兼容性不引入混合计费模式,但必须在 usage log 中完整记录 token 与 image_count,便于后续按数据决定是否新增 `image_plus_token` 计费模式。
|
||||
|
||||
### 6. 尺寸档位与参数透传
|
||||
|
||||
OpenAI 图片请求的 `size` 参数必须透传给上游;本地只做计费分档,不做 OpenAI 尺寸合法性校验。无论尺寸是否满足官方约束,本地都不能因为未知尺寸或 provider-invalid 尺寸返回 400;如果上游不接受该尺寸,由上游响应错误。
|
||||
|
||||
官方 `gpt-image-2` 文档给出的常用尺寸与约束是本地计费分档的依据:
|
||||
|
||||
- 常用尺寸:`1024x1024`、`1536x1024`、`1024x1536`、`2048x2048`、`2048x1152`、`3840x2160`、`2160x3840`、`auto`。
|
||||
- 自定义尺寸:官方支持满足约束的任意 `size`,包括边长、16 像素倍数、长短边比例、总像素范围等约束。
|
||||
- `2560x1440` 是 2K/QHD 参考边界;超过 `2560x1440` 总像素的输出进入更高档位风险区。
|
||||
|
||||
OpenAI 图片尺寸分层必须按以下规则:
|
||||
|
||||
```text
|
||||
empty, auto => 2K
|
||||
1024x1024 => 1K
|
||||
1536x1024, 1024x1536 => 2K
|
||||
1792x1024, 1024x1792 => 2K
|
||||
2048x2048, 2048x1152, 1152x2048 => 2K
|
||||
3840x2160, 2160x3840 => 4K
|
||||
未知且无法解析为正整数 WIDTHxHEIGHT => 2K
|
||||
未知且 WIDTH * HEIGHT <= 2560*1440 => 2K
|
||||
未知且 WIDTH * HEIGHT > 2560*1440 => 4K
|
||||
```
|
||||
|
||||
这个规则只决定 `ImageSize` 和扣费档位,不修改请求体,不删除未知参数,不把未知尺寸改写成预设尺寸。
|
||||
|
||||
## Risks / Trade-offs
|
||||
|
||||
- 历史普通编码分组迁移后仍默认允许生图 → 通过管理员可见开关、上线核对清单和新建分组默认关闭来控制;代码无法可靠判断“普通编码分组”和“图片分组”的业务意图。
|
||||
- 默认共享现有有效倍率仍保留“图片最终价不直观”的问题 → 这是兼容性选择;需要直观设置图片最终价的分组必须打开 `image_rate_independent`。
|
||||
- 独立图片倍率不会读取用户专属普通倍率 → 这是目标行为;如需要用户级图片独立倍率,应作为后续独立需求实现。
|
||||
- 通用 Responses 图片工具可能同时输出文本和图片 → 本变更默认仍按图片请求语义计费并完整记录 token;若业务要求文本也收费,应新增独立的混合计费模式,不能混入本次兼容性变更。
|
||||
- 本地不再拦截未知或 provider-invalid OpenAI 尺寸 → 非法尺寸会消耗一次上游请求失败成本和用户体验往返,但这是为了保证参数透传、兼容官方新增尺寸和第三方兼容提供商;计费只在成功产出最终图片后发生。
|
||||
- Responses 流式解析需要在客户端断开后继续 drain 上游以完成计费 → 沿用当前流式处理“客户端断开后继续读取上游用于计费”的模式,并只新增轻量 JSON 路径提取。
|
||||
- 预扣费不在本变更中实现 → 继续使用现有成功后扣费模型,避免失败请求退款、流式中断退款和图片数量未确定时预估错误。
|
||||
|
||||
## Migration Plan
|
||||
|
||||
1. 新增数据库迁移,添加 `groups.allow_image_generation`、`groups.image_rate_independent` 和 `groups.image_rate_multiplier`。
|
||||
2. 回填现有分组:`openai`、`gemini`、`antigravity` 的 `allow_image_generation=true`,`anthropic=false`;所有现有分组 `image_rate_independent=false`、`image_rate_multiplier=1`。
|
||||
3. 不改写现有 `image_price_1k/2k/4k`,保持默认共享倍率模式下的历史扣费结果。
|
||||
4. 更新 Ent schema 与生成代码,更新后端 service/handler DTO 和前端类型。
|
||||
5. 先接入权限判断,确保未开启生图的分组不会到达上游。
|
||||
6. 再接入图片数量解析和图片计费倍率选择,确保开启生图的分组按图片数量收费。
|
||||
7. 最后更新前端管理界面、i18n、文档和测试。
|
||||
8. 回滚时只能通过新迁移回滚字段行为;不能修改已应用迁移文件。
|
||||
|
||||
## Open Questions
|
||||
|
||||
无。当前方案不依赖未确认的上游新尺寸、新模型或新 endpoint。
|
||||
@ -0,0 +1,29 @@
|
||||
## Why
|
||||
|
||||
当前代码把“能否生图”和“如何按图片收费”混在模型、分组倍率、渠道定价与 Responses 工具调用里,导致 OpenAI 普通编码分组在允许 `gpt-5.4` / `gpt-5.5` 时也能通过 `image_generation` tool 产图,并且通用 `/v1/responses` 产图不会稳定写入 `ImageCount`。需要把生图能力、图片倍率模式、图片产出数量归因拆成独立能力,保证普通编码分组可按业务开关生图,开启后既能沿用现有倍率行为,也能按需切换到图片独立倍率。
|
||||
|
||||
## What Changes
|
||||
|
||||
- 新增分组级生图能力开关,明确控制 `/v1/images/*`、`gpt-image-*`、显式 `image_generation` tool、Codex 自动注入图片工具等所有生图入口。
|
||||
- 新增分组级图片倍率模式开关,默认继续共享现有分组有效倍率;打开独立模式后使用图片独立倍率输入框。
|
||||
- 保留现有 `image_price_1k/2k/4k` 图片价格配置;图片最终扣费由“图片价格 × 当前倍率模式选出的倍率 × 图片数量”决定。
|
||||
- 统一统计 OpenAI Responses 图片工具产物数量,使 `gpt-5.4` / `gpt-5.5` 通过 `image_generation` tool 产图时进入图片计费,而不是退化成普通 token 计费或无 usage 时不计费。
|
||||
- 修正专用 Images API 与渠道图片计费场景,按实际图片数量和明确尺寸档位计费,避免固定 `RequestCount=1` 或未知尺寸静默落到 `2K`。
|
||||
- 更新后台分组配置、前端类型、使用说明和测试,覆盖普通编码分组关闭生图、普通编码分组开启生图、独立图片分组承载、生图流式/非流式等场景。
|
||||
|
||||
## Capabilities
|
||||
|
||||
### New Capabilities
|
||||
- `image-generation-access-control`: 定义分组级生图能力开关、所有生图意图识别规则、拒绝行为与 Codex 自动注入规则。
|
||||
- `image-generation-billing-accounting`: 定义图片倍率模式、图片数量归因、尺寸档位、渠道图片价格和用量日志要求。
|
||||
|
||||
### Modified Capabilities
|
||||
- 无。
|
||||
|
||||
## Impact
|
||||
|
||||
- Backend schema/API: `backend/ent/schema/group.go`、Ent 生成代码、数据库迁移、管理员分组 create/update/list DTO、分组缓存/序列化。
|
||||
- Backend request gates: `backend/internal/handler/openai_images.go`、`backend/internal/service/openai_gateway_service.go`、`backend/internal/service/openai_codex_transform.go`、OpenAI account scheduler 相关模型/图片能力调度入口。
|
||||
- Backend billing: `backend/internal/service/billing_service.go`、`backend/internal/service/openai_gateway_service.go`、`backend/internal/service/gateway_service.go`、usage log 与 account stats 成本计算路径。
|
||||
- Frontend admin: `frontend/src/types/index.ts`、`frontend/src/views/admin/GroupsView.vue`、相关 i18n 文案与图片计费展示。
|
||||
- Tests: OpenAI Images API、OpenAI Responses stream/non-stream/passthrough、分组开关、图片倍率模式、渠道图片计数、尺寸档位与 usage log 断言。
|
||||
@ -0,0 +1,118 @@
|
||||
## ADDED Requirements
|
||||
|
||||
### Requirement: Group image generation capability
|
||||
The system SHALL store a group-level `allow_image_generation` capability flag and SHALL expose it through admin group create, update, list, and detail APIs.
|
||||
|
||||
#### Scenario: New group defaults to image generation disabled
|
||||
- **WHEN** an admin creates a group without providing `allow_image_generation`
|
||||
- **THEN** the persisted group has `allow_image_generation=false`
|
||||
|
||||
#### Scenario: Existing image-capable platform groups are backfilled
|
||||
- **WHEN** the migration is applied to existing groups
|
||||
- **THEN** existing `openai`, `gemini`, and `antigravity` groups have `allow_image_generation=true`
|
||||
- **AND** existing `anthropic` groups have `allow_image_generation=false`
|
||||
|
||||
#### Scenario: Admin enables image generation on an ordinary coding group
|
||||
- **WHEN** an admin updates an `openai` group with `allow_image_generation=true`
|
||||
- **THEN** the group can use image generation paths subject to the billing requirements
|
||||
|
||||
#### Scenario: Admin disables image generation on an ordinary coding group
|
||||
- **WHEN** an admin updates an `openai` group with `allow_image_generation=false`
|
||||
- **THEN** the group can still use non-image text model requests
|
||||
- **AND** image generation intents are denied before upstream dispatch
|
||||
|
||||
### Requirement: Image generation intent detection
|
||||
The system SHALL classify a request as an image generation intent before upstream account scheduling when the endpoint or request body can produce generated images.
|
||||
|
||||
#### Scenario: Images endpoint is an image generation intent
|
||||
- **WHEN** a request targets `/v1/images/generations`, `/v1/images/edits`, `/images/generations`, or `/images/edits`
|
||||
- **THEN** the request is classified as an image generation intent
|
||||
|
||||
#### Scenario: Responses request with image-only model is an image generation intent
|
||||
- **WHEN** a `/v1/responses` request has a requested model whose normalized name starts with `gpt-image-`
|
||||
- **THEN** the request is classified as an image generation intent before any model rewrite
|
||||
|
||||
#### Scenario: Responses request with image_generation tool is an image generation intent
|
||||
- **WHEN** a `/v1/responses` request contains any `tools[]` entry with `type == "image_generation"`
|
||||
- **THEN** the request is classified as an image generation intent
|
||||
|
||||
#### Scenario: Responses request with image_generation tool_choice is an image generation intent
|
||||
- **WHEN** a `/v1/responses` request contains `tool_choice` that explicitly selects `image_generation`
|
||||
- **THEN** the request is classified as an image generation intent even if `tools[]` is malformed or absent
|
||||
|
||||
#### Scenario: Generic tool_choice required is not sufficient by itself
|
||||
- **WHEN** a `/v1/responses` request contains `tool_choice="required"`
|
||||
- **AND** the request does not contain an `image_generation` tool
|
||||
- **THEN** the request is not classified as an image generation intent because of `tool_choice` alone
|
||||
|
||||
#### Scenario: Text-only gpt-5.4 request is not an image generation intent
|
||||
- **WHEN** a `/v1/responses` request uses `model="gpt-5.4"` or `model="gpt-5.5"` without `image_generation` tool and without image `tool_choice`
|
||||
- **THEN** the request is not classified as an image generation intent
|
||||
|
||||
#### Scenario: Intent is checked before and after service-side mutation
|
||||
- **WHEN** the service mutates a `/v1/responses` request by injecting `image_generation` or rewriting `gpt-image-*` to a Responses text model plus image tool
|
||||
- **THEN** the final mutated request is checked against the same image generation intent rules before upstream dispatch
|
||||
|
||||
### Requirement: Disabled groups reject explicit image generation
|
||||
The system SHALL reject explicit image generation intents for groups with `allow_image_generation=false` before selecting or calling an upstream account.
|
||||
|
||||
#### Scenario: Disabled group rejects Images API
|
||||
- **WHEN** a group has `allow_image_generation=false`
|
||||
- **AND** a user calls `/v1/images/generations`
|
||||
- **THEN** the system returns HTTP 403 with error type `permission_error`
|
||||
- **AND** no upstream account is selected
|
||||
- **AND** no usage log is written
|
||||
|
||||
#### Scenario: Disabled group rejects Responses image tool
|
||||
- **WHEN** a group has `allow_image_generation=false`
|
||||
- **AND** a user calls `/v1/responses` with `tools:[{"type":"image_generation"}]`
|
||||
- **THEN** the system returns HTTP 403 with error type `permission_error`
|
||||
- **AND** no upstream account is selected
|
||||
- **AND** no usage log is written
|
||||
|
||||
#### Scenario: Disabled group rejects Responses image-only model rewrite
|
||||
- **WHEN** a group has `allow_image_generation=false`
|
||||
- **AND** a user calls `/v1/responses` with `model` starting with `gpt-image-`
|
||||
- **THEN** the system returns HTTP 403 with error type `permission_error`
|
||||
- **AND** the request is not rewritten to a text Responses model
|
||||
|
||||
#### Scenario: Disabled group permits normal coding request
|
||||
- **WHEN** a group has `allow_image_generation=false`
|
||||
- **AND** a user calls `/v1/responses` with `model="gpt-5.4"` and no image generation intent
|
||||
- **THEN** the request proceeds through the normal text forwarding path
|
||||
|
||||
### Requirement: Codex image tool injection respects group capability
|
||||
The system SHALL only inject the OpenAI Responses `image_generation` tool and bridge instructions for Codex clients when the request group has `allow_image_generation=true`.
|
||||
|
||||
#### Scenario: Codex request in enabled group receives image tool
|
||||
- **WHEN** a Codex CLI `/v1/responses` request belongs to a group with `allow_image_generation=true`
|
||||
- **AND** the request has no `image_generation` tool
|
||||
- **THEN** the system injects the existing `image_generation` tool payload
|
||||
- **AND** the system appends the existing Codex image bridge instructions
|
||||
|
||||
#### Scenario: Codex request in disabled group does not receive image tool
|
||||
- **WHEN** a Codex CLI `/v1/responses` request belongs to a group with `allow_image_generation=false`
|
||||
- **AND** the request has no explicit image generation intent
|
||||
- **THEN** the system does not inject `image_generation`
|
||||
- **AND** the system does not append image bridge instructions
|
||||
- **AND** the request proceeds as a text request
|
||||
|
||||
#### Scenario: Codex explicit image request in disabled group is denied
|
||||
- **WHEN** a Codex CLI `/v1/responses` request belongs to a group with `allow_image_generation=false`
|
||||
- **AND** the request explicitly contains `image_generation`
|
||||
- **THEN** the system returns HTTP 403 with error type `permission_error`
|
||||
|
||||
### Requirement: Channel model restrictions remain enforced
|
||||
The system SHALL keep existing channel model restriction behavior for image and non-image OpenAI requests, including when the advanced OpenAI account scheduler is enabled.
|
||||
|
||||
#### Scenario: Advanced scheduler blocks restricted requested model
|
||||
- **WHEN** a channel has `restrict_models=true`
|
||||
- **AND** the requested model is not allowed by channel pricing or mapping rules
|
||||
- **AND** the OpenAI advanced scheduler path is used
|
||||
- **THEN** the request is rejected before upstream account selection succeeds
|
||||
|
||||
#### Scenario: Image generation flag does not bypass channel restrictions
|
||||
- **WHEN** a group has `allow_image_generation=true`
|
||||
- **AND** the channel restriction rejects the requested or billing model
|
||||
- **THEN** the image generation request is rejected
|
||||
- **AND** no upstream image request is sent
|
||||
@ -0,0 +1,225 @@
|
||||
## ADDED Requirements
|
||||
|
||||
### Requirement: Image multiplier mode
|
||||
The system SHALL calculate image generation cost with group image prices and a selectable image multiplier mode. By default image billing SHALL share the existing effective group multiplier; when `image_rate_independent=true`, image billing SHALL use `image_rate_multiplier`.
|
||||
|
||||
#### Scenario: Default image billing shares current effective group multiplier
|
||||
- **WHEN** a group has `rate_multiplier=0.15`
|
||||
- **AND** `image_rate_independent=false`
|
||||
- **AND** `image_price_1k=0.2`
|
||||
- **AND** a successful image request produces one `1K` image
|
||||
- **THEN** `actual_cost` is `0.03`
|
||||
- **AND** the calculation matches current default behavior
|
||||
|
||||
#### Scenario: User-specific token multiplier still applies in shared mode
|
||||
- **WHEN** a user has a user-group token multiplier override of `0.2`
|
||||
- **AND** the group has `image_rate_independent=false`
|
||||
- **AND** `image_price_1k=0.5`
|
||||
- **AND** a successful image request produces one `1K` image
|
||||
- **THEN** `actual_cost` is `0.1`
|
||||
- **AND** the applied image multiplier is the same effective multiplier used by token billing
|
||||
|
||||
#### Scenario: Independent image multiplier allows direct final price
|
||||
- **WHEN** a group has `rate_multiplier=0.15`
|
||||
- **AND** `image_rate_independent=true`
|
||||
- **AND** `image_rate_multiplier=1`
|
||||
- **AND** `image_price_1k=0.2`
|
||||
- **AND** a successful image request produces one `1K` image
|
||||
- **THEN** `actual_cost` is `0.2`
|
||||
- **AND** ordinary `rate_multiplier=0.15` is not applied to the image cost
|
||||
|
||||
#### Scenario: Independent image multiplier supports image discounts
|
||||
- **WHEN** a group has `image_rate_independent=true`
|
||||
- **AND** `image_rate_multiplier=0.5`
|
||||
- **AND** `image_price_1k=0.2`
|
||||
- **AND** a successful image request produces two `1K` images
|
||||
- **THEN** `total_cost` is `0.4`
|
||||
- **AND** `actual_cost` is `0.2`
|
||||
|
||||
#### Scenario: Migration preserves existing image price behavior
|
||||
- **WHEN** an existing group has `rate_multiplier=0.15` and `image_price_1k=1.3333333333`
|
||||
- **AND** the migration is applied
|
||||
- **THEN** the stored `image_price_1k` remains `1.3333333333`
|
||||
- **AND** the stored `image_rate_independent` is `false`
|
||||
- **AND** the stored `image_rate_multiplier` is `1`
|
||||
- **AND** default-mode image billing still produces the historical final price within decimal precision
|
||||
|
||||
#### Scenario: Omitted update fields preserve existing multiplier mode
|
||||
- **WHEN** an admin updates a group without sending `image_rate_independent`
|
||||
- **AND** without sending `image_rate_multiplier`
|
||||
- **THEN** the stored image multiplier mode and image multiplier value remain unchanged
|
||||
|
||||
#### Scenario: Image multiplier can be zero only by explicit independent mode configuration
|
||||
- **WHEN** a group has `image_rate_independent=true`
|
||||
- **AND** `image_rate_multiplier=0`
|
||||
- **AND** a successful image request produces one image
|
||||
- **THEN** the image request is free
|
||||
- **AND** this free-image behavior does not occur unless the group explicitly enables independent image multiplier mode with zero multiplier
|
||||
|
||||
### Requirement: Responses image output accounting
|
||||
The system SHALL count generated image outputs from OpenAI Responses stream, non-stream, and passthrough paths and SHALL return the count in `OpenAIForwardResult.ImageCount`.
|
||||
|
||||
#### Scenario: Non-stream Responses image tool output is counted
|
||||
- **WHEN** a non-stream `/v1/responses` upstream response contains `output[]` item with `type == "image_generation_call"` and non-empty `result`
|
||||
- **THEN** `OpenAIForwardResult.ImageCount` equals the number of unique final image outputs
|
||||
- **AND** `OpenAIForwardResult.ImageSize` is the normalized image size tier
|
||||
|
||||
#### Scenario: Stream Responses output item is counted
|
||||
- **WHEN** a stream `/v1/responses` upstream SSE event has `type == "response.output_item.done"`
|
||||
- **AND** the event item has `type == "image_generation_call"` and non-empty `result`
|
||||
- **THEN** the streaming result increments the unique final image output count
|
||||
|
||||
#### Scenario: Stream Responses completed output is counted
|
||||
- **WHEN** a stream `/v1/responses` upstream SSE event has `type == "response.completed"`
|
||||
- **AND** `response.output[]` contains final image generation outputs
|
||||
- **THEN** the streaming result counts those images without double-counting images already seen in `response.output_item.done`
|
||||
|
||||
#### Scenario: Partial image events are not billed as completed images
|
||||
- **WHEN** a stream response contains `partial_image` events
|
||||
- **THEN** those partial events do not increment `ImageCount`
|
||||
- **AND** only final image generation outputs increment `ImageCount`
|
||||
|
||||
#### Scenario: gpt-5.4 image tool request is billed as image
|
||||
- **WHEN** a `/v1/responses` request uses `model="gpt-5.4"` or `model="gpt-5.5"`
|
||||
- **AND** the request includes an `image_generation` tool
|
||||
- **AND** the upstream response contains one final image output
|
||||
- **THEN** the usage log has `image_count=1`
|
||||
- **AND** the usage log has `billing_mode="image"`
|
||||
- **AND** image pricing, not token pricing, determines `actual_cost`
|
||||
|
||||
#### Scenario: Image output with zero usage is still billed
|
||||
- **WHEN** an upstream Responses result contains final image output
|
||||
- **AND** the upstream result has zero or missing token usage
|
||||
- **THEN** the system writes a usage log
|
||||
- **AND** the system bills using image pricing
|
||||
|
||||
#### Scenario: Responses image request records accompanying token usage
|
||||
- **WHEN** a `/v1/responses` image tool request returns final images and token usage
|
||||
- **THEN** the usage log records input tokens, output tokens, image output tokens, and image count
|
||||
- **AND** the applied billing mode remains `image`
|
||||
|
||||
#### Scenario: Responses image request does not introduce hybrid billing by default
|
||||
- **WHEN** a `/v1/responses` image tool request returns final images and text tokens
|
||||
- **THEN** the request is billed by image pricing under this change
|
||||
- **AND** non-image token charges are not added unless a future explicit hybrid billing mode is implemented
|
||||
|
||||
### Requirement: OpenAI Images API output accounting
|
||||
The system SHALL count generated images from dedicated OpenAI Images API stream and non-stream paths and SHALL set `ImageCount` for successful image responses.
|
||||
|
||||
#### Scenario: Images non-stream data array is counted
|
||||
- **WHEN** `/v1/images/generations` returns a non-stream JSON response with top-level `data[]`
|
||||
- **THEN** `ImageCount` equals the length of `data[]`
|
||||
|
||||
#### Scenario: Images stream data array is counted
|
||||
- **WHEN** `/v1/images/generations` stream response emits SSE data containing top-level `data[]`
|
||||
- **THEN** `ImageCount` equals the maximum final data array count observed for the request
|
||||
|
||||
#### Scenario: Images stream completed event is counted
|
||||
- **WHEN** `/v1/images/generations` stream response emits `image_generation.completed` with a final image payload
|
||||
- **THEN** the stream result counts one final image output
|
||||
|
||||
#### Scenario: Images stream Responses-form event is counted
|
||||
- **WHEN** an Images API upstream path emits Responses-form `response.output_item.done` or `response.completed` events with final image outputs
|
||||
- **THEN** the stream result counts final image outputs using the same de-duplication rules as Responses
|
||||
|
||||
### Requirement: Channel image billing uses actual image count
|
||||
The system SHALL use actual generated image count for channel `billing_mode=image` pricing and SHALL NOT bill multi-image requests as a single request.
|
||||
|
||||
#### Scenario: OpenAI channel image billing counts multiple images
|
||||
- **WHEN** a channel image pricing entry resolves to unit price `0.25`
|
||||
- **AND** an OpenAI image request produces three images
|
||||
- **THEN** `total_cost` is `0.75` before the selected image multiplier is applied
|
||||
- **AND** `RequestCount` passed into unified pricing is `3`
|
||||
|
||||
#### Scenario: Gateway channel image billing counts multiple images
|
||||
- **WHEN** a non-OpenAI gateway image path produces two images
|
||||
- **AND** channel image pricing resolves for the billing model
|
||||
- **THEN** `RequestCount` passed into unified pricing is `2`
|
||||
|
||||
#### Scenario: Channel image pricing uses shared multiplier by default
|
||||
- **WHEN** a channel image pricing entry resolves to unit price `0.25`
|
||||
- **AND** the group has ordinary effective multiplier `0.15`
|
||||
- **AND** the group has `image_rate_independent=false`
|
||||
- **AND** the image request produces one image
|
||||
- **THEN** `actual_cost` is `0.0375`
|
||||
|
||||
#### Scenario: Channel image pricing uses independent image multiplier when enabled
|
||||
- **WHEN** a channel image pricing entry resolves to unit price `0.25`
|
||||
- **AND** the group has ordinary effective multiplier `0.15`
|
||||
- **AND** the group has `image_rate_independent=true`
|
||||
- **AND** the group has `image_rate_multiplier=1`
|
||||
- **AND** the image request produces one image
|
||||
- **THEN** `actual_cost` is `0.25`
|
||||
- **AND** ordinary effective multiplier `0.15` is not applied
|
||||
|
||||
#### Scenario: Account stats image pricing receives image count
|
||||
- **WHEN** account stats pricing uses `billing_mode=image`
|
||||
- **AND** the request produces multiple images
|
||||
- **THEN** account stats cost is calculated with the actual image count
|
||||
|
||||
### Requirement: Image size tier normalization
|
||||
The system SHALL normalize OpenAI image sizes to explicit billing tiers for billing only. The system SHALL NOT reject requests locally because of an unknown or provider-invalid `size`; it SHALL forward the original size parameter upstream and let the official upstream API decide whether the request is valid.
|
||||
|
||||
#### Scenario: OpenAI 1024 square maps to 1K
|
||||
- **WHEN** an OpenAI image request specifies `size="1024x1024"`
|
||||
- **THEN** `ImageSize` is `1K`
|
||||
|
||||
#### Scenario: OpenAI landscape and portrait large sizes map to 2K
|
||||
- **WHEN** an OpenAI image request specifies `1536x1024`, `1024x1536`, `1792x1024`, `1024x1792`, `2048x2048`, `2048x1152`, or `1152x2048`
|
||||
- **THEN** `ImageSize` is `2K`
|
||||
|
||||
#### Scenario: OpenAI gpt-image-2 4K presets map to 4K
|
||||
- **WHEN** an OpenAI `gpt-image-2` image request specifies `3840x2160` or `2160x3840`
|
||||
- **THEN** `ImageSize` is `4K`
|
||||
|
||||
#### Scenario: OpenAI auto size maps to 2K
|
||||
- **WHEN** an OpenAI image request omits size or specifies `size="auto"`
|
||||
- **THEN** `ImageSize` is `2K`
|
||||
|
||||
#### Scenario: Custom OpenAI size is forwarded without local validation
|
||||
- **WHEN** an OpenAI image request specifies a custom explicit `WIDTHxHEIGHT` size
|
||||
- **THEN** the system forwards the request upstream
|
||||
- **AND** `ImageSize` is normalized to `2K` or `4K` for billing
|
||||
|
||||
#### Scenario: Responses image tool without model uses default image billing model
|
||||
- **WHEN** a `/v1/responses` request uses an `image_generation` tool without `tool.model`
|
||||
- **THEN** image size validation and image billing use `gpt-image-2` as the image billing model
|
||||
|
||||
#### Scenario: Invalid OpenAI size constraints are delegated upstream
|
||||
- **WHEN** an OpenAI image request specifies an explicit size that fails OpenAI size constraints
|
||||
- **THEN** the system forwards the request upstream
|
||||
- **AND** any invalid-size error comes from the upstream provider response
|
||||
|
||||
#### Scenario: Custom OpenAI size tier mapping
|
||||
- **WHEN** a custom size cannot be parsed as positive `WIDTHxHEIGHT`
|
||||
- **THEN** `ImageSize` is `2K`
|
||||
- **WHEN** a custom size parses as positive `WIDTHxHEIGHT`
|
||||
- **AND** `WIDTH * HEIGHT` is no more than `2560x1440`
|
||||
- **THEN** `ImageSize` is `2K`
|
||||
- **WHEN** a custom size parses as positive `WIDTHxHEIGHT`
|
||||
- **AND** `WIDTH * HEIGHT` exceeds `2560x1440`
|
||||
- **THEN** `ImageSize` is `4K`
|
||||
|
||||
### Requirement: Image usage log semantics
|
||||
The system SHALL write usage logs for successful image generation with image billing metadata that matches the applied image pricing path.
|
||||
|
||||
#### Scenario: Image usage log records image billing mode
|
||||
- **WHEN** a successful request has `ImageCount > 0`
|
||||
- **THEN** the usage log has `billing_mode="image"`
|
||||
- **AND** the usage log records `image_count`
|
||||
- **AND** the usage log records `image_size` when a normalized size tier is available
|
||||
|
||||
#### Scenario: Shared mode image usage log records shared multiplier
|
||||
- **WHEN** a successful image request is billed with `image_rate_independent=false`
|
||||
- **AND** the effective ordinary multiplier is `0.15`
|
||||
- **THEN** `usage_logs.rate_multiplier` is `0.15`
|
||||
|
||||
#### Scenario: Independent mode image usage log records image multiplier
|
||||
- **WHEN** a successful image request is billed with `image_rate_independent=true`
|
||||
- **AND** `image_rate_multiplier=0.5`
|
||||
- **THEN** `usage_logs.rate_multiplier` is `0.5`
|
||||
|
||||
#### Scenario: Token request usage log is unchanged
|
||||
- **WHEN** a successful non-image token request is billed
|
||||
- **THEN** `usage_logs.rate_multiplier` continues to record the ordinary token multiplier
|
||||
- **AND** `image_count` is `0`
|
||||
@ -0,0 +1,72 @@
|
||||
## 1. Data Model And Migration
|
||||
|
||||
- [x] 1.1 Add `allow_image_generation`, `image_rate_independent`, and `image_rate_multiplier` to `backend/ent/schema/group.go`.
|
||||
- [x] 1.2 Create a new idempotent SQL migration after `133_affiliate_rebate_freeze.sql` for the three group columns.
|
||||
- [x] 1.3 Backfill existing `openai`, `gemini`, and `antigravity` groups to `allow_image_generation=true` and `anthropic` groups to `false`.
|
||||
- [x] 1.4 Backfill all existing groups to `image_rate_independent=false` and `image_rate_multiplier=1` without changing existing `image_price_1k/2k/4k`.
|
||||
- [x] 1.5 Regenerate or update Ent generated group fields, predicates, create/update setters, and query projections.
|
||||
- [x] 1.6 Add the new fields to backend group domain/service structs, admin create/update inputs, admin responses, and group serialization.
|
||||
|
||||
## 2. Admin API And Frontend
|
||||
|
||||
- [x] 2.1 Add `allow_image_generation`, `image_rate_independent`, and `image_rate_multiplier` to `CreateGroupRequest` and `UpdateGroupRequest`.
|
||||
- [x] 2.2 Validate `image_rate_multiplier >= 0` and keep negative image prices using the existing clear-price behavior only for `image_price_*`.
|
||||
- [x] 2.3 Add the new fields to `frontend/src/types/index.ts` group, create, and update interfaces.
|
||||
- [x] 2.4 Ensure omitted update fields do not overwrite existing image generation and multiplier mode settings.
|
||||
- [x] 2.5 Update `frontend/src/views/admin/GroupsView.vue` create/edit forms with a 生图开关, 生图倍率是否独立开关, and conditional image multiplier input.
|
||||
- [x] 2.6 Add a live final-price preview for `image_price_1k/2k/4k` under shared and independent multiplier modes.
|
||||
- [x] 2.7 Update group form help text to state that default image billing shares the existing group effective multiplier and independent mode uses the image multiplier input.
|
||||
- [x] 2.8 Update i18n strings for the new controls and image multiplier mode explanation.
|
||||
|
||||
## 3. Image Generation Access Control
|
||||
|
||||
- [x] 3.1 Implement a shared helper that detects image generation intent from endpoint, requested model, `tools[]`, and `tool_choice`.
|
||||
- [x] 3.2 Gate `/v1/images/generations` and `/v1/images/edits` in `backend/internal/handler/openai_images.go` after request parsing and before billing eligibility/account scheduling.
|
||||
- [x] 3.3 Gate `/v1/responses` explicit `image_generation` tool requests in `backend/internal/service/openai_gateway_service.go` before upstream account scheduling.
|
||||
- [x] 3.4 Prevent `normalizeOpenAIResponsesImageOnlyModel` from rewriting `gpt-image-*` Responses requests when the group does not allow image generation.
|
||||
- [x] 3.5 Skip Codex `image_generation` auto-injection and image bridge instructions when the group does not allow image generation.
|
||||
- [x] 3.6 Re-run image intent detection after service-side request mutation and before upstream dispatch.
|
||||
- [x] 3.7 Ensure OpenAI advanced scheduler paths apply the same channel `RestrictModels` checks as the load-aware path.
|
||||
|
||||
## 4. Responses Image Output Accounting
|
||||
|
||||
- [x] 4.1 Add shared parsers for final `image_generation_call.result` outputs in non-stream JSON and SSE payloads.
|
||||
- [x] 4.2 Extend `openaiStreamingResult` with image count, image size tier, and image billing model fields.
|
||||
- [x] 4.3 Update `handleStreamingResponse` to count final image outputs while preserving existing stream forwarding and usage parsing.
|
||||
- [x] 4.4 Update `handleStreamingResponsePassthrough` with the same image output counting.
|
||||
- [x] 4.5 Update `handleNonStreamingResponse` to count final image outputs from `output[]`.
|
||||
- [x] 4.6 Update `handleNonStreamingResponsePassthrough` with the same non-stream image output counting.
|
||||
- [x] 4.7 Populate `OpenAIForwardResult.ImageCount`, `ImageSize`, and image billing model for `gpt-5.4` / `gpt-5.5 + image_generation` requests.
|
||||
|
||||
## 5. Images API Accounting And Size Tiers
|
||||
|
||||
- [x] 5.1 Extend OpenAI Images API-key stream counting to handle `image_generation.completed`, `response.output_item.done`, and `response.completed`.
|
||||
- [x] 5.2 Reuse the same final-image de-duplication rules across Images API and Responses API paths.
|
||||
- [x] 5.3 Keep unknown explicit OpenAI image sizes pass-through and delegate invalid-size errors to upstream.
|
||||
- [x] 5.4 Map documented OpenAI image sizes to `1K`/`2K`/`4K` billing tiers without rewriting request parameters.
|
||||
- [x] 5.5 Classify custom OpenAI `WIDTHxHEIGHT` sizes by `2560x1440` total-pixel boundary, falling back to `2K` when unparseable.
|
||||
|
||||
## 6. Billing And Usage Logs
|
||||
|
||||
- [x] 6.1 Add an image multiplier resolver: shared mode uses the current effective group multiplier, independent mode uses `apiKey.Group.ImageRateMultiplier`.
|
||||
- [x] 6.2 Update `CalculateImageCost` or its caller contract so image costs use the resolved image multiplier.
|
||||
- [x] 6.3 Set image usage log `RateMultiplier` to the applied image multiplier; keep token logs unchanged.
|
||||
- [x] 6.4 Change OpenAI channel image billing `RequestCount` from `1` to `result.ImageCount`.
|
||||
- [x] 6.5 Change non-OpenAI gateway channel image billing `RequestCount` from `1` to `result.ImageCount`.
|
||||
- [x] 6.6 Pass actual image count into account stats pricing for `billing_mode=image`.
|
||||
- [x] 6.7 Ensure `ImageCount > 0` writes a usage log and bills even when upstream token usage is zero.
|
||||
- [x] 6.8 Record accompanying token usage for Responses image tool requests while keeping default billing mode as `image`.
|
||||
|
||||
## 7. Tests And Documentation
|
||||
|
||||
- [x] 7.1 Add backend tests for disabled group rejecting `/v1/images/*`, `gpt-image-*` Responses, explicit `image_generation`, and image `tool_choice`.
|
||||
- [x] 7.2 Add backend tests proving disabled Codex groups do not receive injected image tools while enabled Codex groups still do.
|
||||
- [x] 7.3 Add backend tests proving omitted group update fields preserve existing image generation and multiplier mode settings.
|
||||
- [x] 7.4 Add Responses stream and non-stream tests for `gpt-5.4` / `gpt-5.5 + image_generation` image counting and image billing.
|
||||
- [x] 7.5 Add Images API stream tests for `image_generation.completed`, `response.output_item.done`, and `response.completed` counting.
|
||||
- [x] 7.6 Add billing tests for shared mode `rate_multiplier=0.15`, `image_price_1k=0.2`, final `actual_cost=0.03`.
|
||||
- [x] 7.7 Add billing tests for independent mode `rate_multiplier=0.15`, `image_rate_multiplier=1`, `image_price_1k=0.2`, final `actual_cost=0.2`.
|
||||
- [x] 7.8 Add channel image billing tests proving multi-image requests use `RequestCount=ImageCount` in both shared and independent multiplier modes.
|
||||
- [x] 7.9 Add size-tier tests for known OpenAI sizes and unknown explicit size pass-through.
|
||||
- [x] 7.10 Add Responses image tool tests proving token usage is recorded but default billing remains image-mode only.
|
||||
- [x] 7.11 Update `2ue/image-billing-risk-analysis.md` or add a linked follow-up note that points to this OpenSpec change as the normalized solution.
|
||||
@ -0,0 +1,2 @@
|
||||
schema: spec-driven
|
||||
created: 2026-05-03
|
||||
@ -0,0 +1,70 @@
|
||||
## Overview
|
||||
|
||||
本次只实现“图片独立并发开关”,不实现外部图片网关的运行时代码。目标是在最大程度不改变现有行为的前提下,为图片流式长连接提供服务级资源保护。
|
||||
|
||||
## Current Constraints
|
||||
|
||||
- 当前 Redis 并发槽位只有用户和账号维度,键语义是 `concurrency:user:*` 与 `concurrency:account:*`。
|
||||
- 图片接口和普通 Responses 在同一个 Go 服务内运行,共享进程、HTTP 上游连接池和账号调度。
|
||||
- Codex OAuth 路径会自动注入 `image_generation` tool;这个注入表示“模型具备工具能力”,不等价于当前请求一定会生图。
|
||||
- `/v1/responses` 在 handler 入口只能可靠识别显式图片意图:image 模型、请求体已有 image tool、或 tool_choice 明确选择 image_generation。
|
||||
- 图片实际产物计数与计费仍以 service 层的最终输出解析为准。
|
||||
|
||||
## Decisions
|
||||
|
||||
### 1. 默认关闭,保持兼容
|
||||
|
||||
新增配置:
|
||||
|
||||
- `gateway.image_concurrency.enabled`,默认 `false`。
|
||||
- `gateway.image_concurrency.max_concurrent_requests`,默认 `0`,表示不限制。
|
||||
- `gateway.image_concurrency.overflow_mode`,默认 `reject`,可选 `reject` / `wait`。
|
||||
- `gateway.image_concurrency.wait_timeout_seconds`,默认 `30`,仅 `overflow_mode=wait` 生效。
|
||||
- `gateway.image_concurrency.max_waiting_requests`,默认 `100`,仅 `overflow_mode=wait` 生效,限制当前进程内图片等待队列。
|
||||
|
||||
只有当 `enabled=true` 且 `max_concurrent_requests>0` 时才启用图片独立并发限制。默认配置不改变任何现有流量行为。
|
||||
|
||||
### 2. 进程级信号量作为第一阶段隔离
|
||||
|
||||
本次使用进程内有界信号量做服务级图片并发限制。原因:
|
||||
|
||||
- 不扩展现有 Redis `ConcurrencyCache` 接口,避免影响用户/账号并发的既有语义。
|
||||
- 不新增迁移,不改变分组已有字段。
|
||||
- 单实例部署可立即保护进程资源。
|
||||
- 多实例部署时该限制按实例生效;文档必须明确总图片并发约等于 `实例数 × max_concurrent_requests`。
|
||||
|
||||
### 3. 限制对象只包含明确图片意图
|
||||
|
||||
纳入限制:
|
||||
|
||||
- `/v1/images/generations`
|
||||
- `/v1/images/edits`
|
||||
- `/v1/responses` 中入口请求已明确包含图片意图:image 模型、`tools[].type=image_generation`、`tool_choice` 明确选择 image_generation。
|
||||
|
||||
暂不纳入限制:
|
||||
|
||||
- 普通 Codex 请求因为服务端自动注入 image tool 而具备生图能力,但入口请求本身未明确要求生图。
|
||||
|
||||
这样避免把普通编码请求错误算作图片并发。后续若要对“模型运行中动态调用 image tool”做更细粒度隔离,需要在工具调用实际发生时获得可阻塞的事件,目前当前代码没有这种入口级阻塞点。
|
||||
|
||||
### 4. 限流行为
|
||||
|
||||
- `overflow_mode=reject` 时,未开始流式响应直接返回 HTTP `429`,错误类型 `rate_limit_error`。
|
||||
- `overflow_mode=wait` 时,请求在当前进程内等待图片并发槽位,超过 `wait_timeout_seconds` 或超过 `max_waiting_requests` 后返回 HTTP `429`。
|
||||
- 已开始流式响应时,使用现有 `handleStreamingAwareError` 写 SSE 错误事件。
|
||||
- 图片并发限制命中或等待超时不触发账号 failover,不记录为上游账号失败。
|
||||
- `gateway.image_stream_data_interval_timeout` 是上游图片流数据空闲超时,不用于图片排队等待。
|
||||
|
||||
### 5. 与外部图片网关的关系
|
||||
|
||||
本次不实现外部图片网关代码。外部网关方案沉淀到 `2ue` 文档:
|
||||
|
||||
- 推荐由 Caddy/Nginx/API Gateway 按 `/v1/images/*` 分流。
|
||||
- `/v1/responses` 的图片 tool 请求不能仅靠 path 分流,必须在前置层读取 body 或保留主服务兜底。
|
||||
- 即使未来拆出图片网关,主网关仍保留图片 intent 检测、开关和计费兜底,避免直连或漏判绕过。
|
||||
|
||||
## Risks And Mitigations
|
||||
|
||||
- 风险:进程级限制在多实例部署下不是全局严格限制。缓解:文档明确容量计算,后续可基于 Redis 扩展为集群级图片并发。
|
||||
- 风险:Codex 自动注入 image tool 后,普通编码请求未被图片限流。缓解:这是有意选择,避免误伤普通请求;实际输出图片仍按图片计费。
|
||||
- 风险:图片请求在账号槽位前被拒绝可能改变排队体验。缓解:仅当独立开关启用时生效,默认关闭;429 明确提示图片并发达到上限。
|
||||
@ -0,0 +1,28 @@
|
||||
## Why
|
||||
|
||||
图片生成流式请求会比普通文本流式请求占用更长的连接、goroutine、HTTP 上游连接和账号/用户槽位。当前图片能力已经具备独立计费与更长流式超时,但仍缺少默认关闭的图片专属并发隔离开关,图片高并发时仍可能挤压普通文本流式接口。
|
||||
|
||||
## What Changes
|
||||
|
||||
- 新增服务级图片独立并发开关,默认关闭,不改变现有已部署分组和普通文本请求行为。
|
||||
- 新增图片全局并发上限配置;开启后仅限制已明确是图片生成意图的请求。
|
||||
- 新增图片并发满载后的溢出策略配置:默认立即拒绝,也可配置等待槽位和等待超时。
|
||||
- 将图片并发限制覆盖 `/v1/images/generations`、`/v1/images/edits` 和 `/v1/responses` 显式图片生成请求。
|
||||
- 保留当前图片生成开关、图片计费、图片流式续读与超时语义。
|
||||
- 不在本次代码实现外部独立图片网关;只把外部网关拆分方案沉淀到本地文档。
|
||||
|
||||
## Capabilities
|
||||
|
||||
### New Capabilities
|
||||
- `image-generation-concurrency-isolation`: 图片生成请求的独立并发开关、并发上限、429 行为和外部网关落地建议。
|
||||
|
||||
### Modified Capabilities
|
||||
- `image-stream-resilience`: 图片流式续读能力在独立并发开启时受到图片专属并发上限保护,但流式续读与计费契约不变。
|
||||
|
||||
## Impact
|
||||
|
||||
- 影响 `backend/internal/config/config.go` 的 gateway 配置字段、默认值和校验。
|
||||
- 影响 `backend/internal/handler/openai_images.go` 与 `backend/internal/handler/openai_gateway_handler.go` 的图片请求入口限流。
|
||||
- 影响 `deploy/config.example.yaml` 的示例配置与说明。
|
||||
- 影响后端测试:配置默认值/校验、图片接口限流、Responses 显式 image tool 限流。
|
||||
- 新增或更新 `2ue` 本地分析文档,记录外部独立图片网关只作为后续部署方案,不在本次代码落地。
|
||||
@ -0,0 +1,82 @@
|
||||
# image-generation-concurrency-isolation Specification
|
||||
|
||||
## ADDED Requirements
|
||||
|
||||
### Requirement: Image concurrency isolation is opt-in
|
||||
|
||||
The system SHALL keep image concurrency isolation disabled by default.
|
||||
|
||||
#### Scenario: default config keeps existing behavior
|
||||
- **GIVEN** the deployment does not set `gateway.image_concurrency.enabled`
|
||||
- **WHEN** image generation requests are received
|
||||
- **THEN** no new image-specific concurrency limit is applied
|
||||
- **AND** existing user/account concurrency and billing behavior remains unchanged
|
||||
|
||||
### Requirement: Dedicated image concurrency limit
|
||||
|
||||
The system SHALL provide an opt-in service-level image concurrency limit controlled by gateway configuration.
|
||||
|
||||
#### Scenario: explicit image endpoint is limited
|
||||
- **GIVEN** `gateway.image_concurrency.enabled=true`
|
||||
- **AND** `gateway.image_concurrency.max_concurrent_requests=1`
|
||||
- **AND** one image generation request is already active
|
||||
- **WHEN** another `/v1/images/generations` or `/v1/images/edits` request arrives
|
||||
- **THEN** the second request is rejected with HTTP `429`
|
||||
- **AND** the error type is `rate_limit_error`
|
||||
|
||||
#### Scenario: explicit Responses image generation request is limited
|
||||
- **GIVEN** `gateway.image_concurrency.enabled=true`
|
||||
- **AND** `gateway.image_concurrency.max_concurrent_requests=1`
|
||||
- **AND** `gateway.image_concurrency.overflow_mode=reject`
|
||||
- **AND** one image generation request is already active
|
||||
- **WHEN** a `/v1/responses` request explicitly contains `tools[].type=image_generation`, an image model, or `tool_choice` selecting `image_generation`
|
||||
- **THEN** the request is rejected with HTTP `429`
|
||||
- **AND** it is not retried through account failover
|
||||
|
||||
#### Scenario: image request waits for a slot
|
||||
- **GIVEN** `gateway.image_concurrency.enabled=true`
|
||||
- **AND** `gateway.image_concurrency.max_concurrent_requests=1`
|
||||
- **AND** `gateway.image_concurrency.overflow_mode=wait`
|
||||
- **AND** `gateway.image_concurrency.wait_timeout_seconds` is greater than zero
|
||||
- **AND** one image generation request is already active
|
||||
- **WHEN** another explicit image generation request arrives
|
||||
- **AND** the active image generation request releases its slot before the wait timeout
|
||||
- **THEN** the waiting image generation request acquires the slot and continues
|
||||
|
||||
#### Scenario: image wait times out
|
||||
- **GIVEN** `gateway.image_concurrency.enabled=true`
|
||||
- **AND** `gateway.image_concurrency.max_concurrent_requests=1`
|
||||
- **AND** `gateway.image_concurrency.overflow_mode=wait`
|
||||
- **AND** one image generation request is already active
|
||||
- **WHEN** another explicit image generation request waits longer than `gateway.image_concurrency.wait_timeout_seconds`
|
||||
- **THEN** the waiting request is rejected with HTTP `429`
|
||||
- **AND** the error type is `rate_limit_error`
|
||||
|
||||
#### Scenario: image waiting queue is full
|
||||
- **GIVEN** `gateway.image_concurrency.enabled=true`
|
||||
- **AND** `gateway.image_concurrency.overflow_mode=wait`
|
||||
- **AND** `gateway.image_concurrency.max_waiting_requests` is already reached
|
||||
- **WHEN** another explicit image generation request arrives
|
||||
- **THEN** the request is rejected with HTTP `429`
|
||||
- **AND** it does not wait for account scheduling
|
||||
|
||||
### Requirement: Text requests are not image-limited
|
||||
|
||||
The system SHALL NOT apply the image concurrency limit to requests without explicit image generation intent.
|
||||
|
||||
#### Scenario: normal coding request bypasses image limit
|
||||
- **GIVEN** `gateway.image_concurrency.enabled=true`
|
||||
- **AND** the image concurrency limit is full
|
||||
- **WHEN** a `/v1/responses` request uses a text model and does not explicitly contain image generation intent
|
||||
- **THEN** the image concurrency limiter does not reject it
|
||||
- **AND** normal user/account concurrency handling continues
|
||||
|
||||
### Requirement: External image gateway remains a deployment pattern
|
||||
|
||||
The system SHALL document external image gateway routing as a deployment option without adding runtime forwarding code in this change.
|
||||
|
||||
#### Scenario: operator reads local design note
|
||||
- **GIVEN** the repository documentation is available
|
||||
- **WHEN** an operator evaluates isolating image traffic into a separate service
|
||||
- **THEN** local `2ue` documentation describes which paths are safe to route by path
|
||||
- **AND** explains why `/v1/responses` image tool requests require body-aware routing or main-gateway fallback
|
||||
@ -0,0 +1,28 @@
|
||||
## 1. Spec and documentation
|
||||
|
||||
- [x] 1.1 Create OpenSpec proposal, design, tasks, and capability spec for image concurrency isolation.
|
||||
- [x] 1.2 Add a local `2ue` note for the external image gateway deployment pattern and current non-goals.
|
||||
|
||||
## 2. Config
|
||||
|
||||
- [x] 2.1 Add `gateway.image_concurrency.enabled` and `gateway.image_concurrency.max_concurrent_requests` config fields.
|
||||
- [x] 2.2 Register defaults that keep existing behavior unchanged.
|
||||
- [x] 2.3 Validate max concurrent requests as non-negative.
|
||||
- [x] 2.4 Update `deploy/config.example.yaml` with safe usage notes.
|
||||
- [x] 2.5 Add image concurrency overflow mode, wait timeout, and max waiting request config.
|
||||
|
||||
## 3. Runtime limiter
|
||||
|
||||
- [x] 3.1 Implement a process-level image concurrency limiter with resize-on-config-read behavior.
|
||||
- [x] 3.2 Acquire/release the limiter around `/v1/images/generations` and `/v1/images/edits` before account scheduling.
|
||||
- [x] 3.3 Acquire/release the limiter around explicit `/v1/responses` image generation intent before account scheduling.
|
||||
- [x] 3.4 Ensure limiter rejections return `429 rate_limit_error` and do not trigger account failover.
|
||||
- [x] 3.5 Support `reject` and `wait` overflow modes with bounded wait timeout and waiting queue size.
|
||||
|
||||
## 4. Tests and verification
|
||||
|
||||
- [x] 4.1 Add config default and validation tests.
|
||||
- [x] 4.2 Add handler tests for image endpoint limiter rejection.
|
||||
- [x] 4.3 Add handler tests proving text-only Responses requests are not rejected by the image limiter.
|
||||
- [x] 4.4 Run focused Go tests for config and OpenAI handler/service paths.
|
||||
- [x] 4.5 Add limiter tests for wait success, wait timeout, and waiting queue overflow.
|
||||
2
openspec/changes/image-stream-resilience/.openspec.yaml
Normal file
2
openspec/changes/image-stream-resilience/.openspec.yaml
Normal file
@ -0,0 +1,2 @@
|
||||
schema: spec-driven
|
||||
created: 2026-05-02
|
||||
46
openspec/changes/image-stream-resilience/design.md
Normal file
46
openspec/changes/image-stream-resilience/design.md
Normal file
@ -0,0 +1,46 @@
|
||||
## Context
|
||||
|
||||
现有普通 Responses 流式已经具备较完整的断连续写能力:客户端写失败后可继续 drain 上游,并且具备数据间隔超时和 keepalive。图片流式路径目前仍然采用更直接的读写方式,客户端写失败会立即返回,上游读取也更容易跟随客户端取消而结束。
|
||||
|
||||
本次变更只针对图片流式路径,不改变普通文本流式路径的配置和行为。系统已经存在普通流式的后端超时配置,因此这里不引入页面级超时设置;图片流式只需要独立的后端默认值,让图片生成有更长的容忍窗口。
|
||||
|
||||
## Goals / Non-Goals
|
||||
|
||||
**Goals:**
|
||||
- 图片流式在客户端断开后继续读取上游,尽量保留最终图片结果与计费结果。
|
||||
- 图片流式使用独立于普通流式的超时与 keepalive 默认值。
|
||||
- 不修改现有普通流式配置项的含义,不要求管理员新增页面配置。
|
||||
- 维持图片计费与图片结果计数的一致性。
|
||||
|
||||
**Non-Goals:**
|
||||
- 不设计新的前端配置页面。
|
||||
- 不修改普通文本流式的超时策略。
|
||||
- 不改变图片计费公式或分组倍率语义。
|
||||
|
||||
## Decisions
|
||||
|
||||
1. **使用独立的图片流式配置键**
|
||||
- 选择:在后端配置中增加图片流式专用 `image_stream_data_interval_timeout` / `image_stream_keepalive_interval`。
|
||||
- 原因:图片流式耗时显著更长,复用普通流式默认值会过早触发超时;独立键能避免影响现有文本流式。
|
||||
- 备选方案:直接复用普通流式配置并在代码里按路径放大倍数。这个方案会让普通流式和图片流式共享语义,后续难以维护。
|
||||
|
||||
2. **继续使用上下文 detach,而不是依赖客户端上下文**
|
||||
- 选择:图片流式请求向上游发起时使用 `context.WithoutCancel` 派生的上下文。
|
||||
- 原因:客户端断开时不应自动取消上游请求,否则无法收集最终图片结果,也无法完成图片计费。
|
||||
- 备选方案:仍使用 `c.Request.Context()` 并只在写失败后继续 drain。这个方案在客户端取消场景下无法保证上游读取继续进行。
|
||||
|
||||
3. **只改图片流式路径,不改普通流式路径**
|
||||
- 选择:`/v1/images/*` 与 `Responses + image_generation` 两条图片流式链路单独处理。
|
||||
- 原因:风险最小,避免回归普通文本流式和现有超时配置。
|
||||
- 备选方案:统一重构所有流式处理。这个方案范围更大,验证成本更高,不符合本次“尽量少改现有行为”的目标。
|
||||
|
||||
4. **不新增页面配置**
|
||||
- 选择:图片流式独立超时默认值写入后端配置,沿用当前配置加载方式。
|
||||
- 原因:用户明确要求和当前设置行为统一,不需要额外页面输入项。
|
||||
- 备选方案:前端增加图片超时配置项。这个方案会改变现有运维方式,也容易引入误配。
|
||||
|
||||
## Risks / Trade-offs
|
||||
|
||||
- [Risk] 图片流式继续 drain 上游后,客户端已经断开但服务端仍占用连接与协程资源。→ [Mitigation] 只对图片流式启用更长但仍有限的专用超时,并保持与普通流式同样的 keepalive/超时退出机制。
|
||||
- [Risk] 图片流式与普通流式的默认超时不同,运维如果只关注通用配置可能忽略图片专用值。→ [Mitigation] 在配置示例中明确标注图片流式专用默认值和用途。
|
||||
- [Risk] 断连后继续读取可能导致日志中出现“客户端断开但最终成功”的状态。→ [Mitigation] 保留现有图片计费结果返回语义,同时让调用方在结果与错误并存时优先使用结果对象。
|
||||
25
openspec/changes/image-stream-resilience/proposal.md
Normal file
25
openspec/changes/image-stream-resilience/proposal.md
Normal file
@ -0,0 +1,25 @@
|
||||
## Why
|
||||
|
||||
图片流式路径目前没有和普通 Responses 流式一致的断连续写策略,也没有独立于普通流式的超时控制。由于图片生成耗时更长,如果继续沿用普通流式处理方式,客户端断开时容易中断上游读取,影响图片产物收集与按图计费的准确性。
|
||||
|
||||
## What Changes
|
||||
|
||||
- 为 OpenAI Images API 和 `Responses + image_generation` 流式路径补充独立的上游续读策略,客户端断开后继续 drain 上游,尽量保留最终图片结果和计费结果。
|
||||
- 为图片流式路径使用独立的流数据间隔超时与 keepalive 策略,默认比普通流式更长,不新增页面配置项。
|
||||
- 保持现有普通流式配置与行为不变,避免影响已经配置好的普通文本分组。
|
||||
- 让图片流式路径在超时、断连、写入失败等场景下保持图片计费语义一致。
|
||||
|
||||
## Capabilities
|
||||
|
||||
### New Capabilities
|
||||
- `image-stream-resilience`: 图片流式路径的断连续读、独立超时和计费保留能力。
|
||||
|
||||
### Modified Capabilities
|
||||
- `image-generation-billing-accounting`: 图片流式结果计数和计费结果的稳定性行为发生改变,但计费契约不变。
|
||||
|
||||
## Impact
|
||||
|
||||
- 影响 `backend/internal/service/openai_images.go` 和 `backend/internal/service/openai_images_responses.go` 的流式实现。
|
||||
- 影响 `backend/internal/config/config.go` 与 `deploy/config.example.yaml` 中图片流式默认值和校验逻辑。
|
||||
- 影响 `backend/internal/service/openai_images_test.go`、`backend/internal/config/config_test.go` 以及新增的图片流式稳定性测试。
|
||||
- 不新增前端页面设置,不改变普通流式配置项名称和语义。
|
||||
@ -0,0 +1,53 @@
|
||||
## ADDED Requirements
|
||||
|
||||
### Requirement: Image stream resilience
|
||||
The system SHALL keep image generation stream processing active after downstream client disconnects so long as upstream reading can continue, in order to preserve final image outputs and billing results.
|
||||
|
||||
#### Scenario: Images API stream survives downstream disconnect
|
||||
- **WHEN** `/v1/images/generations` is streamed to a client
|
||||
- **AND** the downstream writer returns an error before the upstream stream completes
|
||||
- **THEN** the service continues draining the upstream stream
|
||||
- **AND** it still counts final image outputs if the upstream later emits them
|
||||
- **AND** the request can still complete with image billing metadata
|
||||
|
||||
#### Scenario: Responses image tool stream survives downstream disconnect
|
||||
- **WHEN** a `/v1/responses` request uses `image_generation` and is streamed to a client
|
||||
- **AND** the downstream writer returns an error before the upstream stream completes
|
||||
- **THEN** the service continues draining the upstream stream
|
||||
- **AND** it still counts final image outputs if the upstream later emits them
|
||||
- **AND** the request can still complete with image billing metadata
|
||||
|
||||
#### Scenario: Client disconnect does not force image stream to downgrade to text billing
|
||||
- **WHEN** a successful image stream request has already produced final image outputs
|
||||
- **AND** the downstream client disconnects before the final flush
|
||||
- **THEN** the request remains billed as an image request
|
||||
- **AND** the image count is preserved in the forward result
|
||||
|
||||
### Requirement: Image stream timeout isolation
|
||||
The system SHALL use image-specific streaming timeout settings for image generation stream paths, and these settings SHALL be independent from the ordinary text streaming timeout values.
|
||||
|
||||
#### Scenario: Image stream uses dedicated timeout defaults
|
||||
- **WHEN** an image generation stream path is executed
|
||||
- **THEN** it uses the image-specific data interval timeout and keepalive interval defaults
|
||||
- **AND** it does not rely on the ordinary text stream timeout defaults
|
||||
|
||||
#### Scenario: Ordinary stream settings remain unchanged
|
||||
- **WHEN** a normal non-image streaming request is executed
|
||||
- **THEN** the existing ordinary stream timeout configuration and behavior remain unchanged
|
||||
|
||||
#### Scenario: Image stream timeout is longer than ordinary stream timeout
|
||||
- **WHEN** the image streaming timeout defaults are compared with the ordinary streaming defaults
|
||||
- **THEN** the image streaming timeout is configured to allow a longer wait window than ordinary text streaming
|
||||
|
||||
### Requirement: Image stream billing consistency
|
||||
The system SHALL keep the image billing result consistent even when image stream handling uses retries, keepalive writes, or downstream disconnect recovery.
|
||||
|
||||
#### Scenario: Final image count is preserved after reconnect-unsafe downstream failure
|
||||
- **WHEN** the downstream client disconnects after at least one final image output has been observed upstream
|
||||
- **THEN** the forward result retains the final image count
|
||||
- **AND** usage recording can still proceed with image billing metadata
|
||||
|
||||
#### Scenario: Image stream timeout does not silently switch billing mode
|
||||
- **WHEN** an image stream times out before any final image output is observed
|
||||
- **THEN** the request is handled as a failed image stream
|
||||
- **AND** it does not fall back to ordinary text billing semantics
|
||||
20
openspec/changes/image-stream-resilience/tasks.md
Normal file
20
openspec/changes/image-stream-resilience/tasks.md
Normal file
@ -0,0 +1,20 @@
|
||||
## 1. Config and defaults
|
||||
|
||||
- [x] 1.1 Add image-specific stream timeout fields to gateway config.
|
||||
- [x] 1.2 Register image stream timeout defaults in the config loader.
|
||||
- [x] 1.3 Add config validation for image stream timeout ranges.
|
||||
- [x] 1.4 Expose image stream timeout defaults in `deploy/config.example.yaml`.
|
||||
|
||||
## 2. Image stream runtime behavior
|
||||
|
||||
- [x] 2.1 Detach image stream upstream contexts from client cancellation.
|
||||
- [x] 2.2 Add image-specific data interval timeout handling to `/v1/images/*` streaming.
|
||||
- [x] 2.3 Add image-specific data interval timeout handling to `Responses + image_generation` streaming.
|
||||
- [x] 2.4 Preserve upstream draining after downstream write failures in both image stream paths.
|
||||
|
||||
## 3. Tests and verification
|
||||
|
||||
- [x] 3.1 Add config tests for image stream timeout defaults and validation.
|
||||
- [x] 3.2 Add image streaming disconnect tests for the Images API path.
|
||||
- [x] 3.3 Add image streaming disconnect tests for the Responses image tool path.
|
||||
- [x] 3.4 Run focused Go tests for the touched config and image service paths.
|
||||
20
openspec/config.yaml
Normal file
20
openspec/config.yaml
Normal file
@ -0,0 +1,20 @@
|
||||
schema: spec-driven
|
||||
|
||||
# Project context (optional)
|
||||
# This is shown to AI when creating artifacts.
|
||||
# Add your tech stack, conventions, style guides, domain knowledge, etc.
|
||||
# Example:
|
||||
# context: |
|
||||
# Tech stack: TypeScript, React, Node.js
|
||||
# We use conventional commits
|
||||
# Domain: e-commerce platform
|
||||
|
||||
# Per-artifact rules (optional)
|
||||
# Add custom rules for specific artifacts.
|
||||
# Example:
|
||||
# rules:
|
||||
# proposal:
|
||||
# - Keep proposals under 500 words
|
||||
# - Always include a "Non-goals" section
|
||||
# tasks:
|
||||
# - Break tasks into chunks of max 2 hours
|
||||
Loading…
x
Reference in New Issue
Block a user