chore: merge upstream v0.1.122-123, keep Windsurf/Antigravity customizations
New upstream features: - feat: improve OpenAI messages compatibility for Claude Code - feat: image generation stream & concurrency controls - fix(rate-limit): remove 429 cooldown config option - fix: skip previous_response_id recovery when payload has function_call_output - feat: support select search in group/account views - fix: ops cleanup settings - chore: remove openspec and update axios Conflict resolutions: - config.go: kept AntigravityLSWorker+NodeTLSProxy AND added ImageConcurrency - account_test_service.go: kept windsurf import AND added openai_compat import - docker-compose.yml: kept Windsurf env vars AND added image concurrency env vars
This commit is contained in:
commit
3fe228d143
@ -1 +1 @@
|
||||
0.1.121
|
||||
0.1.123
|
||||
|
||||
@ -265,7 +265,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig)
|
||||
opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig)
|
||||
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService)
|
||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService, settingRepository, opsService)
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
|
||||
@ -47,6 +47,12 @@ type Group struct {
|
||||
MonthlyLimitUsd *float64 `json:"monthly_limit_usd,omitempty"`
|
||||
// DefaultValidityDays holds the value of the "default_validity_days" field.
|
||||
DefaultValidityDays int `json:"default_validity_days,omitempty"`
|
||||
// 是否允许该分组使用图片生成能力
|
||||
AllowImageGeneration bool `json:"allow_image_generation,omitempty"`
|
||||
// 图片生成是否使用独立倍率;false 表示共享分组有效倍率
|
||||
ImageRateIndependent bool `json:"image_rate_independent,omitempty"`
|
||||
// 图片生成独立倍率,仅 image_rate_independent=true 时生效
|
||||
ImageRateMultiplier float64 `json:"image_rate_multiplier,omitempty"`
|
||||
// ImagePrice1k holds the value of the "image_price_1k" field.
|
||||
ImagePrice1k *float64 `json:"image_price_1k,omitempty"`
|
||||
// ImagePrice2k holds the value of the "image_price_2k" field.
|
||||
@ -189,9 +195,9 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
||||
switch columns[i] {
|
||||
case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig:
|
||||
values[i] = new([]byte)
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet:
|
||||
case group.FieldIsExclusive, group.FieldAllowImageGeneration, group.FieldImageRateIndependent, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet:
|
||||
values[i] = new(sql.NullBool)
|
||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImageRateMultiplier, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder, group.FieldRpmLimit:
|
||||
values[i] = new(sql.NullInt64)
|
||||
@ -309,6 +315,24 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.DefaultValidityDays = int(value.Int64)
|
||||
}
|
||||
case group.FieldAllowImageGeneration:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field allow_image_generation", values[i])
|
||||
} else if value.Valid {
|
||||
_m.AllowImageGeneration = value.Bool
|
||||
}
|
||||
case group.FieldImageRateIndependent:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_rate_independent", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ImageRateIndependent = value.Bool
|
||||
}
|
||||
case group.FieldImageRateMultiplier:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_rate_multiplier", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ImageRateMultiplier = value.Float64
|
||||
}
|
||||
case group.FieldImagePrice1k:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_price_1k", values[i])
|
||||
@ -550,6 +574,15 @@ func (_m *Group) String() string {
|
||||
builder.WriteString("default_validity_days=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.DefaultValidityDays))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("allow_image_generation=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.AllowImageGeneration))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("image_rate_independent=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.ImageRateIndependent))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("image_rate_multiplier=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.ImageRateMultiplier))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ImagePrice1k; v != nil {
|
||||
builder.WriteString("image_price_1k=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
|
||||
@ -44,6 +44,12 @@ const (
|
||||
FieldMonthlyLimitUsd = "monthly_limit_usd"
|
||||
// FieldDefaultValidityDays holds the string denoting the default_validity_days field in the database.
|
||||
FieldDefaultValidityDays = "default_validity_days"
|
||||
// FieldAllowImageGeneration holds the string denoting the allow_image_generation field in the database.
|
||||
FieldAllowImageGeneration = "allow_image_generation"
|
||||
// FieldImageRateIndependent holds the string denoting the image_rate_independent field in the database.
|
||||
FieldImageRateIndependent = "image_rate_independent"
|
||||
// FieldImageRateMultiplier holds the string denoting the image_rate_multiplier field in the database.
|
||||
FieldImageRateMultiplier = "image_rate_multiplier"
|
||||
// FieldImagePrice1k holds the string denoting the image_price_1k field in the database.
|
||||
FieldImagePrice1k = "image_price_1k"
|
||||
// FieldImagePrice2k holds the string denoting the image_price_2k field in the database.
|
||||
@ -167,6 +173,9 @@ var Columns = []string{
|
||||
FieldWeeklyLimitUsd,
|
||||
FieldMonthlyLimitUsd,
|
||||
FieldDefaultValidityDays,
|
||||
FieldAllowImageGeneration,
|
||||
FieldImageRateIndependent,
|
||||
FieldImageRateMultiplier,
|
||||
FieldImagePrice1k,
|
||||
FieldImagePrice2k,
|
||||
FieldImagePrice4k,
|
||||
@ -239,6 +248,12 @@ var (
|
||||
SubscriptionTypeValidator func(string) error
|
||||
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
|
||||
DefaultDefaultValidityDays int
|
||||
// DefaultAllowImageGeneration holds the default value on creation for the "allow_image_generation" field.
|
||||
DefaultAllowImageGeneration bool
|
||||
// DefaultImageRateIndependent holds the default value on creation for the "image_rate_independent" field.
|
||||
DefaultImageRateIndependent bool
|
||||
// DefaultImageRateMultiplier holds the default value on creation for the "image_rate_multiplier" field.
|
||||
DefaultImageRateMultiplier float64
|
||||
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
|
||||
DefaultClaudeCodeOnly bool
|
||||
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
|
||||
@ -343,6 +358,21 @@ func ByDefaultValidityDays(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDefaultValidityDays, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAllowImageGeneration orders the results by the allow_image_generation field.
|
||||
func ByAllowImageGeneration(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldAllowImageGeneration, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByImageRateIndependent orders the results by the image_rate_independent field.
|
||||
func ByImageRateIndependent(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImageRateIndependent, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByImageRateMultiplier orders the results by the image_rate_multiplier field.
|
||||
func ByImageRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImageRateMultiplier, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByImagePrice1k orders the results by the image_price_1k field.
|
||||
func ByImagePrice1k(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImagePrice1k, opts...).ToFunc()
|
||||
|
||||
@ -125,6 +125,21 @@ func DefaultValidityDays(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldDefaultValidityDays, v))
|
||||
}
|
||||
|
||||
// AllowImageGeneration applies equality check predicate on the "allow_image_generation" field. It's identical to AllowImageGenerationEQ.
|
||||
func AllowImageGeneration(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldAllowImageGeneration, v))
|
||||
}
|
||||
|
||||
// ImageRateIndependent applies equality check predicate on the "image_rate_independent" field. It's identical to ImageRateIndependentEQ.
|
||||
func ImageRateIndependent(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImageRateIndependent, v))
|
||||
}
|
||||
|
||||
// ImageRateMultiplier applies equality check predicate on the "image_rate_multiplier" field. It's identical to ImageRateMultiplierEQ.
|
||||
func ImageRateMultiplier(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImageRateMultiplier, v))
|
||||
}
|
||||
|
||||
// ImagePrice1k applies equality check predicate on the "image_price_1k" field. It's identical to ImagePrice1kEQ.
|
||||
func ImagePrice1k(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImagePrice1k, v))
|
||||
@ -900,6 +915,66 @@ func DefaultValidityDaysLTE(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldDefaultValidityDays, v))
|
||||
}
|
||||
|
||||
// AllowImageGenerationEQ applies the EQ predicate on the "allow_image_generation" field.
|
||||
func AllowImageGenerationEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldAllowImageGeneration, v))
|
||||
}
|
||||
|
||||
// AllowImageGenerationNEQ applies the NEQ predicate on the "allow_image_generation" field.
|
||||
func AllowImageGenerationNEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldAllowImageGeneration, v))
|
||||
}
|
||||
|
||||
// ImageRateIndependentEQ applies the EQ predicate on the "image_rate_independent" field.
|
||||
func ImageRateIndependentEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImageRateIndependent, v))
|
||||
}
|
||||
|
||||
// ImageRateIndependentNEQ applies the NEQ predicate on the "image_rate_independent" field.
|
||||
func ImageRateIndependentNEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldImageRateIndependent, v))
|
||||
}
|
||||
|
||||
// ImageRateMultiplierEQ applies the EQ predicate on the "image_rate_multiplier" field.
|
||||
func ImageRateMultiplierEQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImageRateMultiplier, v))
|
||||
}
|
||||
|
||||
// ImageRateMultiplierNEQ applies the NEQ predicate on the "image_rate_multiplier" field.
|
||||
func ImageRateMultiplierNEQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldImageRateMultiplier, v))
|
||||
}
|
||||
|
||||
// ImageRateMultiplierIn applies the In predicate on the "image_rate_multiplier" field.
|
||||
func ImageRateMultiplierIn(vs ...float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldIn(FieldImageRateMultiplier, vs...))
|
||||
}
|
||||
|
||||
// ImageRateMultiplierNotIn applies the NotIn predicate on the "image_rate_multiplier" field.
|
||||
func ImageRateMultiplierNotIn(vs ...float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNotIn(FieldImageRateMultiplier, vs...))
|
||||
}
|
||||
|
||||
// ImageRateMultiplierGT applies the GT predicate on the "image_rate_multiplier" field.
|
||||
func ImageRateMultiplierGT(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGT(FieldImageRateMultiplier, v))
|
||||
}
|
||||
|
||||
// ImageRateMultiplierGTE applies the GTE predicate on the "image_rate_multiplier" field.
|
||||
func ImageRateMultiplierGTE(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGTE(FieldImageRateMultiplier, v))
|
||||
}
|
||||
|
||||
// ImageRateMultiplierLT applies the LT predicate on the "image_rate_multiplier" field.
|
||||
func ImageRateMultiplierLT(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLT(FieldImageRateMultiplier, v))
|
||||
}
|
||||
|
||||
// ImageRateMultiplierLTE applies the LTE predicate on the "image_rate_multiplier" field.
|
||||
func ImageRateMultiplierLTE(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldImageRateMultiplier, v))
|
||||
}
|
||||
|
||||
// ImagePrice1kEQ applies the EQ predicate on the "image_price_1k" field.
|
||||
func ImagePrice1kEQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImagePrice1k, v))
|
||||
|
||||
@ -217,6 +217,48 @@ func (_c *GroupCreate) SetNillableDefaultValidityDays(v *int) *GroupCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetAllowImageGeneration sets the "allow_image_generation" field.
|
||||
func (_c *GroupCreate) SetAllowImageGeneration(v bool) *GroupCreate {
|
||||
_c.mutation.SetAllowImageGeneration(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableAllowImageGeneration sets the "allow_image_generation" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableAllowImageGeneration(v *bool) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetAllowImageGeneration(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImageRateIndependent sets the "image_rate_independent" field.
|
||||
func (_c *GroupCreate) SetImageRateIndependent(v bool) *GroupCreate {
|
||||
_c.mutation.SetImageRateIndependent(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableImageRateIndependent sets the "image_rate_independent" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableImageRateIndependent(v *bool) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetImageRateIndependent(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
|
||||
func (_c *GroupCreate) SetImageRateMultiplier(v float64) *GroupCreate {
|
||||
_c.mutation.SetImageRateMultiplier(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableImageRateMultiplier(v *float64) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetImageRateMultiplier(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (_c *GroupCreate) SetImagePrice1k(v float64) *GroupCreate {
|
||||
_c.mutation.SetImagePrice1k(v)
|
||||
@ -604,6 +646,18 @@ func (_c *GroupCreate) defaults() error {
|
||||
v := group.DefaultDefaultValidityDays
|
||||
_c.mutation.SetDefaultValidityDays(v)
|
||||
}
|
||||
if _, ok := _c.mutation.AllowImageGeneration(); !ok {
|
||||
v := group.DefaultAllowImageGeneration
|
||||
_c.mutation.SetAllowImageGeneration(v)
|
||||
}
|
||||
if _, ok := _c.mutation.ImageRateIndependent(); !ok {
|
||||
v := group.DefaultImageRateIndependent
|
||||
_c.mutation.SetImageRateIndependent(v)
|
||||
}
|
||||
if _, ok := _c.mutation.ImageRateMultiplier(); !ok {
|
||||
v := group.DefaultImageRateMultiplier
|
||||
_c.mutation.SetImageRateMultiplier(v)
|
||||
}
|
||||
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
||||
v := group.DefaultClaudeCodeOnly
|
||||
_c.mutation.SetClaudeCodeOnly(v)
|
||||
@ -700,6 +754,15 @@ func (_c *GroupCreate) check() error {
|
||||
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
|
||||
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.AllowImageGeneration(); !ok {
|
||||
return &ValidationError{Name: "allow_image_generation", err: errors.New(`ent: missing required field "Group.allow_image_generation"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.ImageRateIndependent(); !ok {
|
||||
return &ValidationError{Name: "image_rate_independent", err: errors.New(`ent: missing required field "Group.image_rate_independent"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.ImageRateMultiplier(); !ok {
|
||||
return &ValidationError{Name: "image_rate_multiplier", err: errors.New(`ent: missing required field "Group.image_rate_multiplier"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
||||
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
|
||||
}
|
||||
@ -821,6 +884,18 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(group.FieldDefaultValidityDays, field.TypeInt, value)
|
||||
_node.DefaultValidityDays = value
|
||||
}
|
||||
if value, ok := _c.mutation.AllowImageGeneration(); ok {
|
||||
_spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value)
|
||||
_node.AllowImageGeneration = value
|
||||
}
|
||||
if value, ok := _c.mutation.ImageRateIndependent(); ok {
|
||||
_spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value)
|
||||
_node.ImageRateIndependent = value
|
||||
}
|
||||
if value, ok := _c.mutation.ImageRateMultiplier(); ok {
|
||||
_spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
|
||||
_node.ImageRateMultiplier = value
|
||||
}
|
||||
if value, ok := _c.mutation.ImagePrice1k(); ok {
|
||||
_spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value)
|
||||
_node.ImagePrice1k = &value
|
||||
@ -1261,6 +1336,48 @@ func (u *GroupUpsert) AddDefaultValidityDays(v int) *GroupUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetAllowImageGeneration sets the "allow_image_generation" field.
|
||||
func (u *GroupUpsert) SetAllowImageGeneration(v bool) *GroupUpsert {
|
||||
u.Set(group.FieldAllowImageGeneration, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateAllowImageGeneration() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldAllowImageGeneration)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImageRateIndependent sets the "image_rate_independent" field.
|
||||
func (u *GroupUpsert) SetImageRateIndependent(v bool) *GroupUpsert {
|
||||
u.Set(group.FieldImageRateIndependent, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateImageRateIndependent() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldImageRateIndependent)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
|
||||
func (u *GroupUpsert) SetImageRateMultiplier(v float64) *GroupUpsert {
|
||||
u.Set(group.FieldImageRateMultiplier, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateImageRateMultiplier() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldImageRateMultiplier)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddImageRateMultiplier adds v to the "image_rate_multiplier" field.
|
||||
func (u *GroupUpsert) AddImageRateMultiplier(v float64) *GroupUpsert {
|
||||
u.Add(group.FieldImageRateMultiplier, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (u *GroupUpsert) SetImagePrice1k(v float64) *GroupUpsert {
|
||||
u.Set(group.FieldImagePrice1k, v)
|
||||
@ -1840,6 +1957,55 @@ func (u *GroupUpsertOne) UpdateDefaultValidityDays() *GroupUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetAllowImageGeneration sets the "allow_image_generation" field.
|
||||
func (u *GroupUpsertOne) SetAllowImageGeneration(v bool) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetAllowImageGeneration(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateAllowImageGeneration() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateAllowImageGeneration()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageRateIndependent sets the "image_rate_independent" field.
|
||||
func (u *GroupUpsertOne) SetImageRateIndependent(v bool) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetImageRateIndependent(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateImageRateIndependent() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateImageRateIndependent()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
|
||||
func (u *GroupUpsertOne) SetImageRateMultiplier(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetImageRateMultiplier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddImageRateMultiplier adds v to the "image_rate_multiplier" field.
|
||||
func (u *GroupUpsertOne) AddImageRateMultiplier(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddImageRateMultiplier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateImageRateMultiplier() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateImageRateMultiplier()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (u *GroupUpsertOne) SetImagePrice1k(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
@ -2632,6 +2798,55 @@ func (u *GroupUpsertBulk) UpdateDefaultValidityDays() *GroupUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetAllowImageGeneration sets the "allow_image_generation" field.
|
||||
func (u *GroupUpsertBulk) SetAllowImageGeneration(v bool) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetAllowImageGeneration(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateAllowImageGeneration() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateAllowImageGeneration()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageRateIndependent sets the "image_rate_independent" field.
|
||||
func (u *GroupUpsertBulk) SetImageRateIndependent(v bool) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetImageRateIndependent(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateImageRateIndependent() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateImageRateIndependent()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
|
||||
func (u *GroupUpsertBulk) SetImageRateMultiplier(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetImageRateMultiplier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddImageRateMultiplier adds v to the "image_rate_multiplier" field.
|
||||
func (u *GroupUpsertBulk) AddImageRateMultiplier(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddImageRateMultiplier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateImageRateMultiplier() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateImageRateMultiplier()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (u *GroupUpsertBulk) SetImagePrice1k(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
|
||||
@ -275,6 +275,55 @@ func (_u *GroupUpdate) AddDefaultValidityDays(v int) *GroupUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAllowImageGeneration sets the "allow_image_generation" field.
|
||||
func (_u *GroupUpdate) SetAllowImageGeneration(v bool) *GroupUpdate {
|
||||
_u.mutation.SetAllowImageGeneration(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAllowImageGeneration sets the "allow_image_generation" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableAllowImageGeneration(v *bool) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetAllowImageGeneration(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageRateIndependent sets the "image_rate_independent" field.
|
||||
func (_u *GroupUpdate) SetImageRateIndependent(v bool) *GroupUpdate {
|
||||
_u.mutation.SetImageRateIndependent(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageRateIndependent sets the "image_rate_independent" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableImageRateIndependent(v *bool) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetImageRateIndependent(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
|
||||
func (_u *GroupUpdate) SetImageRateMultiplier(v float64) *GroupUpdate {
|
||||
_u.mutation.ResetImageRateMultiplier()
|
||||
_u.mutation.SetImageRateMultiplier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableImageRateMultiplier(v *float64) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetImageRateMultiplier(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddImageRateMultiplier adds value to the "image_rate_multiplier" field.
|
||||
func (_u *GroupUpdate) AddImageRateMultiplier(v float64) *GroupUpdate {
|
||||
_u.mutation.AddImageRateMultiplier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (_u *GroupUpdate) SetImagePrice1k(v float64) *GroupUpdate {
|
||||
_u.mutation.ResetImagePrice1k()
|
||||
@ -962,6 +1011,18 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.AddedDefaultValidityDays(); ok {
|
||||
_spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AllowImageGeneration(); ok {
|
||||
_spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageRateIndependent(); ok {
|
||||
_spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageRateMultiplier(); ok {
|
||||
_spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedImageRateMultiplier(); ok {
|
||||
_spec.AddField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ImagePrice1k(); ok {
|
||||
_spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value)
|
||||
}
|
||||
@ -1610,6 +1671,55 @@ func (_u *GroupUpdateOne) AddDefaultValidityDays(v int) *GroupUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAllowImageGeneration sets the "allow_image_generation" field.
|
||||
func (_u *GroupUpdateOne) SetAllowImageGeneration(v bool) *GroupUpdateOne {
|
||||
_u.mutation.SetAllowImageGeneration(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAllowImageGeneration sets the "allow_image_generation" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableAllowImageGeneration(v *bool) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetAllowImageGeneration(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageRateIndependent sets the "image_rate_independent" field.
|
||||
func (_u *GroupUpdateOne) SetImageRateIndependent(v bool) *GroupUpdateOne {
|
||||
_u.mutation.SetImageRateIndependent(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageRateIndependent sets the "image_rate_independent" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableImageRateIndependent(v *bool) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetImageRateIndependent(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
|
||||
func (_u *GroupUpdateOne) SetImageRateMultiplier(v float64) *GroupUpdateOne {
|
||||
_u.mutation.ResetImageRateMultiplier()
|
||||
_u.mutation.SetImageRateMultiplier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableImageRateMultiplier(v *float64) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetImageRateMultiplier(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddImageRateMultiplier adds value to the "image_rate_multiplier" field.
|
||||
func (_u *GroupUpdateOne) AddImageRateMultiplier(v float64) *GroupUpdateOne {
|
||||
_u.mutation.AddImageRateMultiplier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (_u *GroupUpdateOne) SetImagePrice1k(v float64) *GroupUpdateOne {
|
||||
_u.mutation.ResetImagePrice1k()
|
||||
@ -2327,6 +2437,18 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
||||
if value, ok := _u.mutation.AddedDefaultValidityDays(); ok {
|
||||
_spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AllowImageGeneration(); ok {
|
||||
_spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageRateIndependent(); ok {
|
||||
_spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageRateMultiplier(); ok {
|
||||
_spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedImageRateMultiplier(); ok {
|
||||
_spec.AddField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ImagePrice1k(); ok {
|
||||
_spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value)
|
||||
}
|
||||
|
||||
@ -638,6 +638,9 @@ var (
|
||||
{Name: "weekly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "monthly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "default_validity_days", Type: field.TypeInt, Default: 30},
|
||||
{Name: "allow_image_generation", Type: field.TypeBool, Default: false},
|
||||
{Name: "image_rate_independent", Type: field.TypeBool, Default: false},
|
||||
{Name: "image_rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
|
||||
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
@ -690,7 +693,7 @@ var (
|
||||
{
|
||||
Name: "group_sort_order",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{GroupsColumns[25]},
|
||||
Columns: []*schema.Column{GroupsColumns[28]},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -14764,6 +14764,10 @@ type GroupMutation struct {
|
||||
addmonthly_limit_usd *float64
|
||||
default_validity_days *int
|
||||
adddefault_validity_days *int
|
||||
allow_image_generation *bool
|
||||
image_rate_independent *bool
|
||||
image_rate_multiplier *float64
|
||||
addimage_rate_multiplier *float64
|
||||
image_price_1k *float64
|
||||
addimage_price_1k *float64
|
||||
image_price_2k *float64
|
||||
@ -15583,6 +15587,134 @@ func (m *GroupMutation) ResetDefaultValidityDays() {
|
||||
m.adddefault_validity_days = nil
|
||||
}
|
||||
|
||||
// SetAllowImageGeneration sets the "allow_image_generation" field.
|
||||
func (m *GroupMutation) SetAllowImageGeneration(b bool) {
|
||||
m.allow_image_generation = &b
|
||||
}
|
||||
|
||||
// AllowImageGeneration returns the value of the "allow_image_generation" field in the mutation.
|
||||
func (m *GroupMutation) AllowImageGeneration() (r bool, exists bool) {
|
||||
v := m.allow_image_generation
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldAllowImageGeneration returns the old "allow_image_generation" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldAllowImageGeneration(ctx context.Context) (v bool, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldAllowImageGeneration is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldAllowImageGeneration requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldAllowImageGeneration: %w", err)
|
||||
}
|
||||
return oldValue.AllowImageGeneration, nil
|
||||
}
|
||||
|
||||
// ResetAllowImageGeneration resets all changes to the "allow_image_generation" field.
|
||||
func (m *GroupMutation) ResetAllowImageGeneration() {
|
||||
m.allow_image_generation = nil
|
||||
}
|
||||
|
||||
// SetImageRateIndependent sets the "image_rate_independent" field.
|
||||
func (m *GroupMutation) SetImageRateIndependent(b bool) {
|
||||
m.image_rate_independent = &b
|
||||
}
|
||||
|
||||
// ImageRateIndependent returns the value of the "image_rate_independent" field in the mutation.
|
||||
func (m *GroupMutation) ImageRateIndependent() (r bool, exists bool) {
|
||||
v := m.image_rate_independent
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldImageRateIndependent returns the old "image_rate_independent" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldImageRateIndependent(ctx context.Context) (v bool, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldImageRateIndependent is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldImageRateIndependent requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldImageRateIndependent: %w", err)
|
||||
}
|
||||
return oldValue.ImageRateIndependent, nil
|
||||
}
|
||||
|
||||
// ResetImageRateIndependent resets all changes to the "image_rate_independent" field.
|
||||
func (m *GroupMutation) ResetImageRateIndependent() {
|
||||
m.image_rate_independent = nil
|
||||
}
|
||||
|
||||
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
|
||||
func (m *GroupMutation) SetImageRateMultiplier(f float64) {
|
||||
m.image_rate_multiplier = &f
|
||||
m.addimage_rate_multiplier = nil
|
||||
}
|
||||
|
||||
// ImageRateMultiplier returns the value of the "image_rate_multiplier" field in the mutation.
|
||||
func (m *GroupMutation) ImageRateMultiplier() (r float64, exists bool) {
|
||||
v := m.image_rate_multiplier
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldImageRateMultiplier returns the old "image_rate_multiplier" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldImageRateMultiplier(ctx context.Context) (v float64, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldImageRateMultiplier is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldImageRateMultiplier requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldImageRateMultiplier: %w", err)
|
||||
}
|
||||
return oldValue.ImageRateMultiplier, nil
|
||||
}
|
||||
|
||||
// AddImageRateMultiplier adds f to the "image_rate_multiplier" field.
|
||||
func (m *GroupMutation) AddImageRateMultiplier(f float64) {
|
||||
if m.addimage_rate_multiplier != nil {
|
||||
*m.addimage_rate_multiplier += f
|
||||
} else {
|
||||
m.addimage_rate_multiplier = &f
|
||||
}
|
||||
}
|
||||
|
||||
// AddedImageRateMultiplier returns the value that was added to the "image_rate_multiplier" field in this mutation.
|
||||
func (m *GroupMutation) AddedImageRateMultiplier() (r float64, exists bool) {
|
||||
v := m.addimage_rate_multiplier
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// ResetImageRateMultiplier resets all changes to the "image_rate_multiplier" field.
|
||||
func (m *GroupMutation) ResetImageRateMultiplier() {
|
||||
m.image_rate_multiplier = nil
|
||||
m.addimage_rate_multiplier = nil
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (m *GroupMutation) SetImagePrice1k(f float64) {
|
||||
m.image_price_1k = &f
|
||||
@ -16791,7 +16923,7 @@ func (m *GroupMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *GroupMutation) Fields() []string {
|
||||
fields := make([]string, 0, 31)
|
||||
fields := make([]string, 0, 34)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, group.FieldCreatedAt)
|
||||
}
|
||||
@ -16834,6 +16966,15 @@ func (m *GroupMutation) Fields() []string {
|
||||
if m.default_validity_days != nil {
|
||||
fields = append(fields, group.FieldDefaultValidityDays)
|
||||
}
|
||||
if m.allow_image_generation != nil {
|
||||
fields = append(fields, group.FieldAllowImageGeneration)
|
||||
}
|
||||
if m.image_rate_independent != nil {
|
||||
fields = append(fields, group.FieldImageRateIndependent)
|
||||
}
|
||||
if m.image_rate_multiplier != nil {
|
||||
fields = append(fields, group.FieldImageRateMultiplier)
|
||||
}
|
||||
if m.image_price_1k != nil {
|
||||
fields = append(fields, group.FieldImagePrice1k)
|
||||
}
|
||||
@ -16921,6 +17062,12 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.MonthlyLimitUsd()
|
||||
case group.FieldDefaultValidityDays:
|
||||
return m.DefaultValidityDays()
|
||||
case group.FieldAllowImageGeneration:
|
||||
return m.AllowImageGeneration()
|
||||
case group.FieldImageRateIndependent:
|
||||
return m.ImageRateIndependent()
|
||||
case group.FieldImageRateMultiplier:
|
||||
return m.ImageRateMultiplier()
|
||||
case group.FieldImagePrice1k:
|
||||
return m.ImagePrice1k()
|
||||
case group.FieldImagePrice2k:
|
||||
@ -16992,6 +17139,12 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
||||
return m.OldMonthlyLimitUsd(ctx)
|
||||
case group.FieldDefaultValidityDays:
|
||||
return m.OldDefaultValidityDays(ctx)
|
||||
case group.FieldAllowImageGeneration:
|
||||
return m.OldAllowImageGeneration(ctx)
|
||||
case group.FieldImageRateIndependent:
|
||||
return m.OldImageRateIndependent(ctx)
|
||||
case group.FieldImageRateMultiplier:
|
||||
return m.OldImageRateMultiplier(ctx)
|
||||
case group.FieldImagePrice1k:
|
||||
return m.OldImagePrice1k(ctx)
|
||||
case group.FieldImagePrice2k:
|
||||
@ -17133,6 +17286,27 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetDefaultValidityDays(v)
|
||||
return nil
|
||||
case group.FieldAllowImageGeneration:
|
||||
v, ok := value.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetAllowImageGeneration(v)
|
||||
return nil
|
||||
case group.FieldImageRateIndependent:
|
||||
v, ok := value.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetImageRateIndependent(v)
|
||||
return nil
|
||||
case group.FieldImageRateMultiplier:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetImageRateMultiplier(v)
|
||||
return nil
|
||||
case group.FieldImagePrice1k:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
@ -17275,6 +17449,9 @@ func (m *GroupMutation) AddedFields() []string {
|
||||
if m.adddefault_validity_days != nil {
|
||||
fields = append(fields, group.FieldDefaultValidityDays)
|
||||
}
|
||||
if m.addimage_rate_multiplier != nil {
|
||||
fields = append(fields, group.FieldImageRateMultiplier)
|
||||
}
|
||||
if m.addimage_price_1k != nil {
|
||||
fields = append(fields, group.FieldImagePrice1k)
|
||||
}
|
||||
@ -17314,6 +17491,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
|
||||
return m.AddedMonthlyLimitUsd()
|
||||
case group.FieldDefaultValidityDays:
|
||||
return m.AddedDefaultValidityDays()
|
||||
case group.FieldImageRateMultiplier:
|
||||
return m.AddedImageRateMultiplier()
|
||||
case group.FieldImagePrice1k:
|
||||
return m.AddedImagePrice1k()
|
||||
case group.FieldImagePrice2k:
|
||||
@ -17372,6 +17551,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
|
||||
}
|
||||
m.AddDefaultValidityDays(v)
|
||||
return nil
|
||||
case group.FieldImageRateMultiplier:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.AddImageRateMultiplier(v)
|
||||
return nil
|
||||
case group.FieldImagePrice1k:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
@ -17559,6 +17745,15 @@ func (m *GroupMutation) ResetField(name string) error {
|
||||
case group.FieldDefaultValidityDays:
|
||||
m.ResetDefaultValidityDays()
|
||||
return nil
|
||||
case group.FieldAllowImageGeneration:
|
||||
m.ResetAllowImageGeneration()
|
||||
return nil
|
||||
case group.FieldImageRateIndependent:
|
||||
m.ResetImageRateIndependent()
|
||||
return nil
|
||||
case group.FieldImageRateMultiplier:
|
||||
m.ResetImageRateMultiplier()
|
||||
return nil
|
||||
case group.FieldImagePrice1k:
|
||||
m.ResetImagePrice1k()
|
||||
return nil
|
||||
|
||||
@ -803,50 +803,62 @@ func init() {
|
||||
groupDescDefaultValidityDays := groupFields[10].Descriptor()
|
||||
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
|
||||
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
|
||||
// groupDescAllowImageGeneration is the schema descriptor for allow_image_generation field.
|
||||
groupDescAllowImageGeneration := groupFields[11].Descriptor()
|
||||
// group.DefaultAllowImageGeneration holds the default value on creation for the allow_image_generation field.
|
||||
group.DefaultAllowImageGeneration = groupDescAllowImageGeneration.Default.(bool)
|
||||
// groupDescImageRateIndependent is the schema descriptor for image_rate_independent field.
|
||||
groupDescImageRateIndependent := groupFields[12].Descriptor()
|
||||
// group.DefaultImageRateIndependent holds the default value on creation for the image_rate_independent field.
|
||||
group.DefaultImageRateIndependent = groupDescImageRateIndependent.Default.(bool)
|
||||
// groupDescImageRateMultiplier is the schema descriptor for image_rate_multiplier field.
|
||||
groupDescImageRateMultiplier := groupFields[13].Descriptor()
|
||||
// group.DefaultImageRateMultiplier holds the default value on creation for the image_rate_multiplier field.
|
||||
group.DefaultImageRateMultiplier = groupDescImageRateMultiplier.Default.(float64)
|
||||
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
|
||||
groupDescClaudeCodeOnly := groupFields[14].Descriptor()
|
||||
groupDescClaudeCodeOnly := groupFields[17].Descriptor()
|
||||
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
|
||||
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
|
||||
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
|
||||
groupDescModelRoutingEnabled := groupFields[18].Descriptor()
|
||||
groupDescModelRoutingEnabled := groupFields[21].Descriptor()
|
||||
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
|
||||
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
|
||||
// groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field.
|
||||
groupDescMcpXMLInject := groupFields[19].Descriptor()
|
||||
groupDescMcpXMLInject := groupFields[22].Descriptor()
|
||||
// group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field.
|
||||
group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool)
|
||||
// groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field.
|
||||
groupDescSupportedModelScopes := groupFields[20].Descriptor()
|
||||
groupDescSupportedModelScopes := groupFields[23].Descriptor()
|
||||
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
|
||||
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
|
||||
// groupDescSortOrder is the schema descriptor for sort_order field.
|
||||
groupDescSortOrder := groupFields[21].Descriptor()
|
||||
groupDescSortOrder := groupFields[24].Descriptor()
|
||||
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
|
||||
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
|
||||
// groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field.
|
||||
groupDescAllowMessagesDispatch := groupFields[22].Descriptor()
|
||||
groupDescAllowMessagesDispatch := groupFields[25].Descriptor()
|
||||
// group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field.
|
||||
group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool)
|
||||
// groupDescRequireOauthOnly is the schema descriptor for require_oauth_only field.
|
||||
groupDescRequireOauthOnly := groupFields[23].Descriptor()
|
||||
groupDescRequireOauthOnly := groupFields[26].Descriptor()
|
||||
// group.DefaultRequireOauthOnly holds the default value on creation for the require_oauth_only field.
|
||||
group.DefaultRequireOauthOnly = groupDescRequireOauthOnly.Default.(bool)
|
||||
// groupDescRequirePrivacySet is the schema descriptor for require_privacy_set field.
|
||||
groupDescRequirePrivacySet := groupFields[24].Descriptor()
|
||||
groupDescRequirePrivacySet := groupFields[27].Descriptor()
|
||||
// group.DefaultRequirePrivacySet holds the default value on creation for the require_privacy_set field.
|
||||
group.DefaultRequirePrivacySet = groupDescRequirePrivacySet.Default.(bool)
|
||||
// groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field.
|
||||
groupDescDefaultMappedModel := groupFields[25].Descriptor()
|
||||
groupDescDefaultMappedModel := groupFields[28].Descriptor()
|
||||
// group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field.
|
||||
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
|
||||
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
|
||||
// groupDescMessagesDispatchModelConfig is the schema descriptor for messages_dispatch_model_config field.
|
||||
groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor()
|
||||
groupDescMessagesDispatchModelConfig := groupFields[29].Descriptor()
|
||||
// group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field.
|
||||
group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig)
|
||||
// groupDescRpmLimit is the schema descriptor for rpm_limit field.
|
||||
groupDescRpmLimit := groupFields[27].Descriptor()
|
||||
groupDescRpmLimit := groupFields[30].Descriptor()
|
||||
// group.DefaultRpmLimit holds the default value on creation for the rpm_limit field.
|
||||
group.DefaultRpmLimit = groupDescRpmLimit.Default.(int)
|
||||
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
||||
|
||||
@ -74,6 +74,16 @@ func (Group) Fields() []ent.Field {
|
||||
Default(30),
|
||||
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用)
|
||||
field.Bool("allow_image_generation").
|
||||
Default(false).
|
||||
Comment("是否允许该分组使用图片生成能力"),
|
||||
field.Bool("image_rate_independent").
|
||||
Default(false).
|
||||
Comment("图片生成是否使用独立倍率;false 表示共享分组有效倍率"),
|
||||
field.Float("image_rate_multiplier").
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}).
|
||||
Default(1.0).
|
||||
Comment("图片生成独立倍率,仅 image_rate_independent=true 时生效"),
|
||||
field.Float("image_price_1k").
|
||||
Optional().
|
||||
Nillable().
|
||||
|
||||
@ -576,6 +576,24 @@ type ConcurrencyConfig struct {
|
||||
PingInterval int `mapstructure:"ping_interval"`
|
||||
}
|
||||
|
||||
type ImageConcurrencyConfig struct {
|
||||
// Enabled: 是否启用图片生成独立并发限制,默认关闭以保持现有行为
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
// MaxConcurrentRequests: 当前进程允许同时处理的图片生成请求数,0表示不限制
|
||||
MaxConcurrentRequests int `mapstructure:"max_concurrent_requests"`
|
||||
// OverflowMode: 图片并发达到上限后的处理方式:reject/wait
|
||||
OverflowMode string `mapstructure:"overflow_mode"`
|
||||
// WaitTimeoutSeconds: overflow_mode=wait 时等待图片并发槽位的超时时间(秒)
|
||||
WaitTimeoutSeconds int `mapstructure:"wait_timeout_seconds"`
|
||||
// MaxWaitingRequests: overflow_mode=wait 时当前进程允许排队等待的图片请求数
|
||||
MaxWaitingRequests int `mapstructure:"max_waiting_requests"`
|
||||
}
|
||||
|
||||
const (
|
||||
ImageConcurrencyOverflowModeReject = "reject"
|
||||
ImageConcurrencyOverflowModeWait = "wait"
|
||||
)
|
||||
|
||||
// GatewayConfig API网关相关配置
|
||||
type GatewayConfig struct {
|
||||
// 等待上游响应头的超时时间(秒),0表示无超时
|
||||
@ -609,6 +627,8 @@ type GatewayConfig struct {
|
||||
AntigravityLSWorker GatewayAntigravityLSWorkerConfig `mapstructure:"antigravity_ls_worker"`
|
||||
// NodeTLSProxy: Node.js TLS 代理配置
|
||||
NodeTLSProxy NodeTLSProxyConfig `mapstructure:"node_tls_proxy"`
|
||||
// ImageConcurrency: 图片生成独立并发限制配置(默认关闭)
|
||||
ImageConcurrency ImageConcurrencyConfig `mapstructure:"image_concurrency"`
|
||||
|
||||
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
|
||||
// MaxIdleConns: 所有主机的最大空闲连接总数
|
||||
@ -640,6 +660,10 @@ type GatewayConfig struct {
|
||||
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
|
||||
// StreamKeepaliveInterval: 流式 keepalive 间隔(秒),0表示禁用
|
||||
StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"`
|
||||
// ImageStreamDataIntervalTimeout: 图片流数据间隔超时(秒),0表示禁用
|
||||
ImageStreamDataIntervalTimeout int `mapstructure:"image_stream_data_interval_timeout"`
|
||||
// ImageStreamKeepaliveInterval: 图片流式 keepalive 间隔(秒),0表示禁用
|
||||
ImageStreamKeepaliveInterval int `mapstructure:"image_stream_keepalive_interval"`
|
||||
// MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值)
|
||||
MaxLineSize int `mapstructure:"max_line_size"`
|
||||
|
||||
@ -1789,6 +1813,11 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7)
|
||||
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8)
|
||||
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5)
|
||||
viper.SetDefault("gateway.image_concurrency.enabled", false)
|
||||
viper.SetDefault("gateway.image_concurrency.max_concurrent_requests", 0)
|
||||
viper.SetDefault("gateway.image_concurrency.overflow_mode", ImageConcurrencyOverflowModeReject)
|
||||
viper.SetDefault("gateway.image_concurrency.wait_timeout_seconds", 30)
|
||||
viper.SetDefault("gateway.image_concurrency.max_waiting_requests", 100)
|
||||
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
||||
viper.SetDefault("gateway.antigravity_extra_retries", 10)
|
||||
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
|
||||
@ -1806,6 +1835,8 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
||||
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.image_stream_data_interval_timeout", 900)
|
||||
viper.SetDefault("gateway.image_stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
||||
@ -2410,6 +2441,21 @@ func (c *Config) Validate() error {
|
||||
ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy)
|
||||
}
|
||||
}
|
||||
if c.Gateway.ImageConcurrency.MaxConcurrentRequests < 0 {
|
||||
return fmt.Errorf("gateway.image_concurrency.max_concurrent_requests must be non-negative")
|
||||
}
|
||||
switch strings.TrimSpace(c.Gateway.ImageConcurrency.OverflowMode) {
|
||||
case "", ImageConcurrencyOverflowModeReject, ImageConcurrencyOverflowModeWait:
|
||||
default:
|
||||
return fmt.Errorf("gateway.image_concurrency.overflow_mode must be one of: %s/%s",
|
||||
ImageConcurrencyOverflowModeReject, ImageConcurrencyOverflowModeWait)
|
||||
}
|
||||
if c.Gateway.ImageConcurrency.WaitTimeoutSeconds < 0 {
|
||||
return fmt.Errorf("gateway.image_concurrency.wait_timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Gateway.ImageConcurrency.MaxWaitingRequests < 0 {
|
||||
return fmt.Errorf("gateway.image_concurrency.max_waiting_requests must be non-negative")
|
||||
}
|
||||
if c.Gateway.MaxIdleConns <= 0 {
|
||||
return fmt.Errorf("gateway.max_idle_conns must be positive")
|
||||
}
|
||||
@ -2448,6 +2494,20 @@ func (c *Config) Validate() error {
|
||||
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
|
||||
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
|
||||
}
|
||||
if c.Gateway.ImageStreamDataIntervalTimeout < 0 {
|
||||
return fmt.Errorf("gateway.image_stream_data_interval_timeout must be non-negative")
|
||||
}
|
||||
if c.Gateway.ImageStreamDataIntervalTimeout != 0 &&
|
||||
(c.Gateway.ImageStreamDataIntervalTimeout < 60 || c.Gateway.ImageStreamDataIntervalTimeout > 1800) {
|
||||
return fmt.Errorf("gateway.image_stream_data_interval_timeout must be 0 or between 60-1800 seconds")
|
||||
}
|
||||
if c.Gateway.ImageStreamKeepaliveInterval < 0 {
|
||||
return fmt.Errorf("gateway.image_stream_keepalive_interval must be non-negative")
|
||||
}
|
||||
if c.Gateway.ImageStreamKeepaliveInterval != 0 &&
|
||||
(c.Gateway.ImageStreamKeepaliveInterval < 5 || c.Gateway.ImageStreamKeepaliveInterval > 60) {
|
||||
return fmt.Errorf("gateway.image_stream_keepalive_interval must be 0 or between 5-60 seconds")
|
||||
}
|
||||
// 兼容旧键 sticky_previous_response_ttl_seconds
|
||||
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
|
||||
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
|
||||
|
||||
@ -1282,6 +1282,46 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = -1 },
|
||||
wantErr: "gateway.stream_data_interval_timeout must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "gateway image stream keepalive range",
|
||||
mutate: func(c *Config) { c.Gateway.ImageStreamKeepaliveInterval = 4 },
|
||||
wantErr: "gateway.image_stream_keepalive_interval",
|
||||
},
|
||||
{
|
||||
name: "gateway image stream keepalive negative",
|
||||
mutate: func(c *Config) { c.Gateway.ImageStreamKeepaliveInterval = -1 },
|
||||
wantErr: "gateway.image_stream_keepalive_interval must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "gateway image stream data interval range",
|
||||
mutate: func(c *Config) { c.Gateway.ImageStreamDataIntervalTimeout = 30 },
|
||||
wantErr: "gateway.image_stream_data_interval_timeout",
|
||||
},
|
||||
{
|
||||
name: "gateway image stream data interval negative",
|
||||
mutate: func(c *Config) { c.Gateway.ImageStreamDataIntervalTimeout = -1 },
|
||||
wantErr: "gateway.image_stream_data_interval_timeout must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "gateway image concurrency max negative",
|
||||
mutate: func(c *Config) { c.Gateway.ImageConcurrency.MaxConcurrentRequests = -1 },
|
||||
wantErr: "gateway.image_concurrency.max_concurrent_requests must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "gateway image concurrency overflow mode invalid",
|
||||
mutate: func(c *Config) { c.Gateway.ImageConcurrency.OverflowMode = "queue" },
|
||||
wantErr: "gateway.image_concurrency.overflow_mode",
|
||||
},
|
||||
{
|
||||
name: "gateway image concurrency wait timeout negative",
|
||||
mutate: func(c *Config) { c.Gateway.ImageConcurrency.WaitTimeoutSeconds = -1 },
|
||||
wantErr: "gateway.image_concurrency.wait_timeout_seconds must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "gateway image concurrency max waiting negative",
|
||||
mutate: func(c *Config) { c.Gateway.ImageConcurrency.MaxWaitingRequests = -1 },
|
||||
wantErr: "gateway.image_concurrency.max_waiting_requests must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "gateway max line size",
|
||||
mutate: func(c *Config) { c.Gateway.MaxLineSize = 1024 },
|
||||
@ -1754,3 +1794,41 @@ func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
|
||||
t.Fatalf("auto_scale_cooldown_seconds = %d, want 10", cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_DefaultGatewayImageStreamConfig(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
if cfg.Gateway.StreamDataIntervalTimeout != 180 {
|
||||
t.Fatalf("stream_data_interval_timeout = %d, want 180", cfg.Gateway.StreamDataIntervalTimeout)
|
||||
}
|
||||
if cfg.Gateway.StreamKeepaliveInterval != 10 {
|
||||
t.Fatalf("stream_keepalive_interval = %d, want 10", cfg.Gateway.StreamKeepaliveInterval)
|
||||
}
|
||||
if cfg.Gateway.ImageStreamDataIntervalTimeout != 900 {
|
||||
t.Fatalf("image_stream_data_interval_timeout = %d, want 900", cfg.Gateway.ImageStreamDataIntervalTimeout)
|
||||
}
|
||||
if cfg.Gateway.ImageStreamKeepaliveInterval != 10 {
|
||||
t.Fatalf("image_stream_keepalive_interval = %d, want 10", cfg.Gateway.ImageStreamKeepaliveInterval)
|
||||
}
|
||||
if cfg.Gateway.ImageConcurrency.Enabled {
|
||||
t.Fatalf("image_concurrency.enabled = true, want false")
|
||||
}
|
||||
if cfg.Gateway.ImageConcurrency.MaxConcurrentRequests != 0 {
|
||||
t.Fatalf("image_concurrency.max_concurrent_requests = %d, want 0", cfg.Gateway.ImageConcurrency.MaxConcurrentRequests)
|
||||
}
|
||||
if cfg.Gateway.ImageConcurrency.OverflowMode != ImageConcurrencyOverflowModeReject {
|
||||
t.Fatalf("image_concurrency.overflow_mode = %q, want %q", cfg.Gateway.ImageConcurrency.OverflowMode, ImageConcurrencyOverflowModeReject)
|
||||
}
|
||||
if cfg.Gateway.ImageConcurrency.WaitTimeoutSeconds != 30 {
|
||||
t.Fatalf("image_concurrency.wait_timeout_seconds = %d, want 30", cfg.Gateway.ImageConcurrency.WaitTimeoutSeconds)
|
||||
}
|
||||
if cfg.Gateway.ImageConcurrency.MaxWaitingRequests != 100 {
|
||||
t.Fatalf("image_concurrency.max_waiting_requests = %d, want 100", cfg.Gateway.ImageConcurrency.MaxWaitingRequests)
|
||||
}
|
||||
if cfg.Gateway.ImageStreamDataIntervalTimeout <= cfg.Gateway.StreamDataIntervalTimeout {
|
||||
t.Fatalf("image stream timeout = %d, want greater than ordinary stream timeout %d", cfg.Gateway.ImageStreamDataIntervalTimeout, cfg.Gateway.StreamDataIntervalTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
@ -529,6 +529,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
// 确定是否跳过混合渠道检查
|
||||
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
||||
|
||||
// 捕获闭包内创建的账号引用,用于创建成功后触发异步探测。
|
||||
// 幂等重放时闭包不会执行 → createdAccount 为 nil → 不重复调度。
|
||||
var createdAccount *service.Account
|
||||
|
||||
result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
|
||||
Name: req.Name,
|
||||
@ -550,6 +554,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
}
|
||||
createdAccount = account
|
||||
// Antigravity OAuth: 新账号直接设置隐私
|
||||
h.adminService.ForceAntigravityPrivacy(ctx, account)
|
||||
// OpenAI OAuth: 新账号直接设置隐私
|
||||
@ -578,6 +583,9 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
if result != nil && result.Replayed {
|
||||
c.Header("X-Idempotency-Replayed", "true")
|
||||
}
|
||||
// OpenAI APIKey 账号创建后异步探测上游 /v1/responses 能力。
|
||||
// 探测失败不影响账号创建响应。
|
||||
h.scheduleOpenAIResponsesProbe(createdAccount)
|
||||
response.Success(c, result.Data)
|
||||
}
|
||||
|
||||
@ -638,9 +646,39 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// OpenAI APIKey: credentials 修改后重新探测上游能力(base_url/api_key 可能变更)。
|
||||
// 异步执行,探测失败不影响账号更新响应。
|
||||
if len(req.Credentials) > 0 {
|
||||
h.scheduleOpenAIResponsesProbe(account)
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// scheduleOpenAIResponsesProbe 异步触发 OpenAI APIKey 账号的 Responses API 能力探测。
|
||||
//
|
||||
// 仅对 platform=openai && type=apikey 账号生效;其他账号无操作。
|
||||
// 探测本身在 goroutine 中执行(会发一次 HTTP 请求到上游),不会阻塞
|
||||
// 当前请求。探测错误仅记录日志,不向上下文传播:探测失败时标记保持缺失,
|
||||
// 网关会按"现状即证据"默认走 Responses。
|
||||
func (h *AccountHandler) scheduleOpenAIResponsesProbe(account *service.Account) {
|
||||
if account == nil || account.Platform != service.PlatformOpenAI || account.Type != service.AccountTypeAPIKey {
|
||||
return
|
||||
}
|
||||
if h.accountTestService == nil {
|
||||
return
|
||||
}
|
||||
accountID := account.ID
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("openai_responses_probe_panic", "account_id", accountID, "recover", r)
|
||||
}
|
||||
}()
|
||||
h.accountTestService.ProbeOpenAIAPIKeyResponsesSupport(context.Background(), accountID)
|
||||
}()
|
||||
}
|
||||
|
||||
// Delete handles deleting an account
|
||||
// DELETE /api/v1/admin/accounts/:id
|
||||
func (h *AccountHandler) Delete(c *gin.Context) {
|
||||
@ -1232,6 +1270,8 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
openaiPrivacyAccounts = append(openaiPrivacyAccounts, account)
|
||||
}
|
||||
}
|
||||
// OpenAI APIKey 账号异步探测 /v1/responses 能力。
|
||||
h.scheduleOpenAIResponsesProbe(account)
|
||||
success++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
|
||||
@ -2,8 +2,11 @@ package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@ -181,3 +184,108 @@ func (h *AffiliateHandler) LookupUsers(c *gin.Context) {
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// GetUserOverview returns one user's affiliate overview.
|
||||
// GET /api/v1/admin/affiliates/users/:user_id/overview
|
||||
func (h *AffiliateHandler) GetUserOverview(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64)
|
||||
if err != nil || userID <= 0 {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
overview, err := h.affiliateService.AdminGetUserOverview(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, overview)
|
||||
}
|
||||
|
||||
// ListInviteRecords returns all inviter-invitee relationships.
|
||||
// GET /api/v1/admin/affiliates/invites
|
||||
func (h *AffiliateHandler) ListInviteRecords(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
filter := parseAffiliateRecordFilter(c, page, pageSize)
|
||||
items, total, err := h.affiliateService.AdminListInviteRecords(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, items, total, filter.Page, filter.PageSize)
|
||||
}
|
||||
|
||||
// ListRebateRecords returns all order-level affiliate rebate records.
|
||||
// GET /api/v1/admin/affiliates/rebates
|
||||
func (h *AffiliateHandler) ListRebateRecords(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
filter := parseAffiliateRecordFilter(c, page, pageSize)
|
||||
items, total, err := h.affiliateService.AdminListRebateRecords(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, items, total, filter.Page, filter.PageSize)
|
||||
}
|
||||
|
||||
// ListTransferRecords returns all affiliate quota-to-balance transfer records.
|
||||
// GET /api/v1/admin/affiliates/transfers
|
||||
func (h *AffiliateHandler) ListTransferRecords(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
filter := parseAffiliateRecordFilter(c, page, pageSize)
|
||||
items, total, err := h.affiliateService.AdminListTransferRecords(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, items, total, filter.Page, filter.PageSize)
|
||||
}
|
||||
|
||||
func parseAffiliateRecordFilter(c *gin.Context, page, pageSize int) service.AffiliateRecordFilter {
|
||||
filter := service.AffiliateRecordFilter{
|
||||
Search: c.Query("search"),
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
SortBy: c.Query("sort_by"),
|
||||
SortDesc: c.Query("sort_order") != "asc",
|
||||
}
|
||||
if filter.PageSize > 100 {
|
||||
filter.PageSize = 100
|
||||
}
|
||||
userTZ := c.Query("timezone")
|
||||
if t := parseAffiliateRecordStartTime(c.Query("start_at"), userTZ); t != nil {
|
||||
filter.StartAt = t
|
||||
}
|
||||
if t := parseAffiliateRecordEndTime(c.Query("end_at"), userTZ); t != nil {
|
||||
filter.EndAt = t
|
||||
}
|
||||
return filter
|
||||
}
|
||||
|
||||
func parseAffiliateRecordStartTime(raw string, userTZ string) *time.Time {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
if parsed, err := time.Parse(time.RFC3339, raw); err == nil {
|
||||
return &parsed
|
||||
}
|
||||
if parsed, err := timezone.ParseInUserLocation("2006-01-02", raw, userTZ); err == nil {
|
||||
return &parsed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseAffiliateRecordEndTime(raw string, userTZ string) *time.Time {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
if parsed, err := time.Parse(time.RFC3339, raw); err == nil {
|
||||
return &parsed
|
||||
}
|
||||
if parsed, err := timezone.ParseInUserLocation("2006-01-02", raw, userTZ); err == nil {
|
||||
end := parsed.AddDate(0, 0, 1).Add(-time.Nanosecond)
|
||||
return &end
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -92,6 +92,9 @@ type CreateGroupRequest struct {
|
||||
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
AllowImageGeneration bool `json:"allow_image_generation"`
|
||||
ImageRateIndependent bool `json:"image_rate_independent"`
|
||||
ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
@ -129,6 +132,9 @@ type UpdateGroupRequest struct {
|
||||
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
AllowImageGeneration *bool `json:"allow_image_generation"`
|
||||
ImageRateIndependent *bool `json:"image_rate_independent"`
|
||||
ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
@ -251,6 +257,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||
AllowImageGeneration: req.AllowImageGeneration,
|
||||
ImageRateIndependent: req.ImageRateIndependent,
|
||||
ImageRateMultiplier: req.ImageRateMultiplier,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
@ -303,6 +312,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||
AllowImageGeneration: req.AllowImageGeneration,
|
||||
ImageRateIndependent: req.ImageRateIndependent,
|
||||
ImageRateMultiplier: req.ImageRateMultiplier,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
|
||||
@ -2462,6 +2462,58 @@ func (h *SettingHandler) UpdateOverloadCooldownSettings(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// GetRateLimit429CooldownSettings 获取429默认回避配置
|
||||
// GET /api/v1/admin/settings/rate-limit-429-cooldown
|
||||
func (h *SettingHandler) GetRateLimit429CooldownSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetRateLimit429CooldownSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RateLimit429CooldownSettings{
|
||||
Enabled: settings.Enabled,
|
||||
CooldownSeconds: settings.CooldownSeconds,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRateLimit429CooldownSettingsRequest 更新429默认回避配置请求
|
||||
type UpdateRateLimit429CooldownSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
CooldownSeconds int `json:"cooldown_seconds"`
|
||||
}
|
||||
|
||||
// UpdateRateLimit429CooldownSettings 更新429默认回避配置
|
||||
// PUT /api/v1/admin/settings/rate-limit-429-cooldown
|
||||
func (h *SettingHandler) UpdateRateLimit429CooldownSettings(c *gin.Context) {
|
||||
var req UpdateRateLimit429CooldownSettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
settings := &service.RateLimit429CooldownSettings{
|
||||
Enabled: req.Enabled,
|
||||
CooldownSeconds: req.CooldownSeconds,
|
||||
}
|
||||
|
||||
if err := h.settingService.SetRateLimit429CooldownSettings(c.Request.Context(), settings); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
updatedSettings, err := h.settingService.GetRateLimit429CooldownSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RateLimit429CooldownSettings{
|
||||
Enabled: updatedSettings.Enabled,
|
||||
CooldownSeconds: updatedSettings.CooldownSeconds,
|
||||
})
|
||||
}
|
||||
|
||||
// GetStreamTimeoutSettings 获取流超时处理配置
|
||||
// GET /api/v1/admin/settings/stream-timeout
|
||||
func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
|
||||
|
||||
@ -390,7 +390,7 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
|
||||
// GetBalanceHistory handles getting user's balance/concurrency change history
|
||||
// GET /api/v1/admin/users/:id/balance-history
|
||||
// Query params:
|
||||
// - type: filter by record type (balance, admin_balance, concurrency, admin_concurrency, subscription)
|
||||
// - type: filter by record type (balance, affiliate_balance, admin_balance, concurrency, admin_concurrency, subscription)
|
||||
func (h *UserHandler) GetBalanceHistory(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
|
||||
@ -176,6 +176,9 @@ func groupFromServiceBase(g *service.Group) Group {
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
AllowImageGeneration: g.AllowImageGeneration,
|
||||
ImageRateIndependent: g.ImageRateIndependent,
|
||||
ImageRateMultiplier: g.ImageRateMultiplier,
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
|
||||
@ -264,6 +264,12 @@ type OverloadCooldownSettings struct {
|
||||
CooldownMinutes int `json:"cooldown_minutes"`
|
||||
}
|
||||
|
||||
// RateLimit429CooldownSettings 429默认回避配置 DTO
|
||||
type RateLimit429CooldownSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
CooldownSeconds int `json:"cooldown_seconds"`
|
||||
}
|
||||
|
||||
// StreamTimeoutSettings 流超时处理配置 DTO
|
||||
type StreamTimeoutSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
@ -94,9 +94,12 @@ type Group struct {
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
AllowImageGeneration bool `json:"allow_image_generation"`
|
||||
ImageRateIndependent bool `json:"image_rate_independent"`
|
||||
ImageRateMultiplier float64 `json:"image_rate_multiplier"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
|
||||
126
backend/internal/handler/image_concurrency_limiter.go
Normal file
126
backend/internal/handler/image_concurrency_limiter.go
Normal file
@ -0,0 +1,126 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type imageConcurrencyLimiter struct {
|
||||
mu sync.Mutex
|
||||
notify chan struct{}
|
||||
limit int
|
||||
active int
|
||||
waiting int
|
||||
enabled bool
|
||||
}
|
||||
|
||||
func (l *imageConcurrencyLimiter) TryAcquire(enabled bool, limit int) (func(), bool) {
|
||||
return l.acquire(context.Background(), enabled, limit, false, 0, 0)
|
||||
}
|
||||
|
||||
func (l *imageConcurrencyLimiter) Acquire(ctx context.Context, enabled bool, limit int, wait bool, timeout time.Duration, maxWaiting int) (func(), bool) {
|
||||
return l.acquire(ctx, enabled, limit, wait, timeout, maxWaiting)
|
||||
}
|
||||
|
||||
func (l *imageConcurrencyLimiter) acquire(ctx context.Context, enabled bool, limit int, wait bool, timeout time.Duration, maxWaiting int) (func(), bool) {
|
||||
if !enabled || limit <= 0 {
|
||||
return nil, true
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if wait {
|
||||
if timeout <= 0 {
|
||||
return nil, false
|
||||
}
|
||||
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
ctx = waitCtx
|
||||
}
|
||||
if maxWaiting < 0 {
|
||||
maxWaiting = 0
|
||||
}
|
||||
for {
|
||||
release, acquired, waitRelease, notify := l.tryAcquireLocked(enabled, limit, wait, maxWaiting)
|
||||
if acquired {
|
||||
return release, acquired
|
||||
}
|
||||
if !wait || notify == nil {
|
||||
return nil, false
|
||||
}
|
||||
if !l.waitForSlot(ctx, notify) {
|
||||
if waitRelease != nil {
|
||||
waitRelease()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
if waitRelease != nil {
|
||||
waitRelease()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *imageConcurrencyLimiter) tryAcquireLocked(enabled bool, limit int, wait bool, maxWaiting int) (func(), bool, func(), <-chan struct{}) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if l.notify == nil {
|
||||
l.notify = make(chan struct{})
|
||||
}
|
||||
if l.enabled != enabled || l.limit != limit {
|
||||
l.enabled = enabled
|
||||
l.limit = limit
|
||||
}
|
||||
if l.active < l.limit {
|
||||
l.active++
|
||||
return l.releaseFunc(), true, nil, nil
|
||||
}
|
||||
if !wait {
|
||||
return nil, false, nil, nil
|
||||
}
|
||||
if maxWaiting > 0 && l.waiting >= maxWaiting {
|
||||
return nil, false, nil, nil
|
||||
}
|
||||
l.waiting++
|
||||
return nil, false, l.waiterReleaseFunc(), l.notify
|
||||
}
|
||||
|
||||
func (l *imageConcurrencyLimiter) waitForSlot(ctx context.Context, notify <-chan struct{}) bool {
|
||||
select {
|
||||
case <-notify:
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (l *imageConcurrencyLimiter) releaseFunc() func() {
|
||||
var once sync.Once
|
||||
return func() {
|
||||
once.Do(func() {
|
||||
l.mu.Lock()
|
||||
if l.active > 0 {
|
||||
l.active--
|
||||
}
|
||||
if l.notify != nil {
|
||||
close(l.notify)
|
||||
l.notify = make(chan struct{})
|
||||
}
|
||||
l.mu.Unlock()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (l *imageConcurrencyLimiter) waiterReleaseFunc() func() {
|
||||
var once sync.Once
|
||||
return func() {
|
||||
once.Do(func() {
|
||||
l.mu.Lock()
|
||||
if l.waiting > 0 {
|
||||
l.waiting--
|
||||
}
|
||||
l.mu.Unlock()
|
||||
})
|
||||
}
|
||||
}
|
||||
230
backend/internal/handler/image_concurrency_limiter_test.go
Normal file
230
backend/internal/handler/image_concurrency_limiter_test.go
Normal file
@ -0,0 +1,230 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestImageConcurrencyLimiter_DefaultDisabledAllowsRequests(t *testing.T) {
|
||||
limiter := &imageConcurrencyLimiter{}
|
||||
|
||||
release, acquired := limiter.TryAcquire(false, 1)
|
||||
|
||||
require.True(t, acquired)
|
||||
require.Nil(t, release)
|
||||
}
|
||||
|
||||
func TestImageConcurrencyLimiter_RejectsWhenLimitReachedAndAllowsAfterRelease(t *testing.T) {
|
||||
limiter := &imageConcurrencyLimiter{}
|
||||
|
||||
release, acquired := limiter.TryAcquire(true, 1)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
|
||||
secondRelease, secondAcquired := limiter.TryAcquire(true, 1)
|
||||
require.False(t, secondAcquired)
|
||||
require.Nil(t, secondRelease)
|
||||
|
||||
release()
|
||||
thirdRelease, thirdAcquired := limiter.TryAcquire(true, 1)
|
||||
require.True(t, thirdAcquired)
|
||||
require.NotNil(t, thirdRelease)
|
||||
thirdRelease()
|
||||
}
|
||||
|
||||
func TestImageConcurrencyLimiter_WaitsUntilSlotReleased(t *testing.T) {
|
||||
limiter := &imageConcurrencyLimiter{}
|
||||
release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
|
||||
acquiredCh := make(chan func(), 1)
|
||||
go func() {
|
||||
waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
|
||||
require.True(t, waitAcquired)
|
||||
acquiredCh <- waitRelease
|
||||
}()
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
release()
|
||||
|
||||
select {
|
||||
case waitRelease := <-acquiredCh:
|
||||
require.NotNil(t, waitRelease)
|
||||
waitRelease()
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for image concurrency slot")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageConcurrencyLimiter_WaitTimesOut(t *testing.T) {
|
||||
limiter := &imageConcurrencyLimiter{}
|
||||
release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
defer release()
|
||||
|
||||
waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, 10*time.Millisecond, 1)
|
||||
|
||||
require.False(t, waitAcquired)
|
||||
require.Nil(t, waitRelease)
|
||||
}
|
||||
|
||||
func TestImageConcurrencyLimiter_MaxWaitingRequestsRejectsOverflow(t *testing.T) {
|
||||
limiter := &imageConcurrencyLimiter{}
|
||||
release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
defer release()
|
||||
|
||||
waitingStarted := make(chan struct{})
|
||||
waitingDone := make(chan struct{})
|
||||
go func() {
|
||||
close(waitingStarted)
|
||||
waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
|
||||
if waitAcquired && waitRelease != nil {
|
||||
waitRelease()
|
||||
}
|
||||
close(waitingDone)
|
||||
}()
|
||||
<-waitingStarted
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
overflowRelease, overflowAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
|
||||
|
||||
require.False(t, overflowAcquired)
|
||||
require.Nil(t, overflowRelease)
|
||||
release()
|
||||
<-waitingDone
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerAcquireImageGenerationSlot_Returns429WhenFull(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
ImageConcurrency: config.ImageConcurrencyConfig{
|
||||
Enabled: true,
|
||||
MaxConcurrentRequests: 1,
|
||||
OverflowMode: config.ImageConcurrencyOverflowModeReject,
|
||||
},
|
||||
},
|
||||
},
|
||||
imageLimiter: &imageConcurrencyLimiter{},
|
||||
}
|
||||
release, acquired := h.acquireImageGenerationSlot(c, false)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
defer release()
|
||||
|
||||
blockedRelease, blocked := h.acquireImageGenerationSlot(c, false)
|
||||
|
||||
require.False(t, blocked)
|
||||
require.Nil(t, blockedRelease)
|
||||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
require.Equal(t, "rate_limit_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String())
|
||||
require.Contains(t, rec.Body.String(), "Image generation concurrency limit exceeded")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerResponses_ImageIntentRejectedByImageConcurrency(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := `{"model":"gpt-5.4","input":"draw","tools":[{"type":"image_generation"}]}`
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(body))
|
||||
groupID := int64(1)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
ID: 10,
|
||||
GroupID: &groupID,
|
||||
Group: &service.Group{
|
||||
ID: groupID,
|
||||
AllowImageGeneration: true,
|
||||
},
|
||||
User: &service.User{ID: 20},
|
||||
})
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 20, Concurrency: 1})
|
||||
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})},
|
||||
errorPassthroughService: nil,
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{
|
||||
Enabled: true,
|
||||
MaxConcurrentRequests: 1,
|
||||
OverflowMode: config.ImageConcurrencyOverflowModeReject,
|
||||
}}},
|
||||
imageLimiter: &imageConcurrencyLimiter{},
|
||||
}
|
||||
release, acquired := h.acquireImageGenerationSlot(c, false)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
defer release()
|
||||
rec.Body.Reset()
|
||||
rec.Code = 0
|
||||
|
||||
h.Responses(c)
|
||||
|
||||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
require.Equal(t, "rate_limit_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String())
|
||||
require.Contains(t, rec.Body.String(), "Image generation concurrency limit exceeded")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerResponses_TextOnlyNotRejectedByImageConcurrency(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := `{"model":"gpt-5.4","input":"write code"}`
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(body))
|
||||
groupID := int64(1)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
ID: 10,
|
||||
GroupID: &groupID,
|
||||
Group: &service.Group{
|
||||
ID: groupID,
|
||||
AllowImageGeneration: true,
|
||||
},
|
||||
User: &service.User{ID: 20},
|
||||
})
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 20, Concurrency: 1})
|
||||
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}),
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})},
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{
|
||||
Enabled: true,
|
||||
MaxConcurrentRequests: 1,
|
||||
OverflowMode: config.ImageConcurrencyOverflowModeReject,
|
||||
}}},
|
||||
imageLimiter: &imageConcurrencyLimiter{},
|
||||
}
|
||||
release, acquired := h.acquireImageGenerationSlot(c, false)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
defer release()
|
||||
rec.Body.Reset()
|
||||
rec.Code = 0
|
||||
|
||||
h.Responses(c)
|
||||
|
||||
require.NotEqual(t, http.StatusTooManyRequests, rec.Code)
|
||||
require.NotContains(t, rec.Body.String(), "Image generation concurrency limit exceeded")
|
||||
}
|
||||
@ -10,6 +10,7 @@ import (
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
@ -120,7 +121,6 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
c.Set("openai_chat_completions_fallback_model", "")
|
||||
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
@ -138,32 +138,8 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
defaultModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if defaultModel != "" && defaultModel != reqModel {
|
||||
reqLog.Info("openai_chat_completions.fallback_to_default_model",
|
||||
zap.String("default_mapped_model", defaultModel),
|
||||
)
|
||||
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
defaultModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
false,
|
||||
)
|
||||
if err == nil && selection != nil {
|
||||
c.Set("openai_chat_completions_fallback_model", defaultModel)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
} else {
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||
@ -191,12 +167,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model"))
|
||||
forwardBody := body
|
||||
if channelMapping.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "")
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
@ -212,52 +187,60 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// Pool mode: retry on the same account
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
reqLog.Warn("openai_chat_completions.forward_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(err),
|
||||
)
|
||||
continue
|
||||
} else {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// Pool mode: retry on the same account
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_chat_completions.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_chat_completions.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
@ -267,16 +250,18 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := resolveRawCCUpstreamEndpoint(c, account)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
@ -299,3 +284,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// resolveRawCCUpstreamEndpoint returns the actual upstream endpoint for
|
||||
// OpenAI Chat Completions requests. For APIKey accounts whose upstream
|
||||
// has been probed to not support the Responses API, the request is
|
||||
// forwarded directly to /v1/chat/completions — not through the default
|
||||
// CC→Responses conversion path.
|
||||
func resolveRawCCUpstreamEndpoint(c *gin.Context, account *service.Account) string {
|
||||
if account != nil && account.Type == service.AccountTypeAPIKey &&
|
||||
!openai_compat.ShouldUseResponsesAPI(account.Extra) {
|
||||
return "/v1/chat/completions"
|
||||
}
|
||||
return GetUpstreamEndpoint(c, account.Platform)
|
||||
}
|
||||
|
||||
@ -33,20 +33,11 @@ type OpenAIGatewayHandler struct {
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
imageLimiter *imageConcurrencyLimiter
|
||||
maxAccountSwitches int
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackModel string) string {
|
||||
if fallbackModel = strings.TrimSpace(fallbackModel); fallbackModel != "" {
|
||||
return fallbackModel
|
||||
}
|
||||
if apiKey == nil || apiKey.Group == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(apiKey.Group.DefaultMappedModel)
|
||||
}
|
||||
|
||||
func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string {
|
||||
if apiKey == nil || apiKey.Group == nil {
|
||||
return ""
|
||||
@ -79,6 +70,7 @@ func NewOpenAIGatewayHandler(
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
imageLimiter: &imageConcurrencyLimiter{},
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
cfg: cfg,
|
||||
}
|
||||
@ -197,6 +189,23 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
imageIntent := service.IsImageGenerationIntent("/v1/responses", reqModel, body)
|
||||
if imageIntent && !service.GroupAllowsImageGeneration(apiKey.Group) {
|
||||
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
|
||||
return
|
||||
}
|
||||
var imageReleaseFunc func()
|
||||
if imageIntent {
|
||||
var imageAcquired bool
|
||||
imageReleaseFunc, imageAcquired = h.acquireImageGenerationSlot(c, streamStarted)
|
||||
if !imageAcquired {
|
||||
return
|
||||
}
|
||||
if imageReleaseFunc != nil {
|
||||
defer imageReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
@ -328,57 +337,65 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
reqLog.Warn("openai.forward_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
}
|
||||
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
|
||||
reqLog.Warn("openai.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
}
|
||||
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
|
||||
reqLog.Warn("openai.forward_failed", fields...)
|
||||
reqLog.Error("openai.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
reqLog.Error("openai.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
@ -393,17 +410,19 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
@ -613,21 +632,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||
|
||||
// Anthropic 格式的请求在 metadata.user_id 中携带 session 标识,
|
||||
// 而非 OpenAI 的 session_id/conversation_id headers。
|
||||
// 从中派生 sessionHash(sticky session)和 promptCacheKey(upstream cache)。
|
||||
if sessionHash == "" || promptCacheKey == "" {
|
||||
if userID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()); userID != "" {
|
||||
seed := reqModel + "-" + userID
|
||||
if promptCacheKey == "" {
|
||||
promptCacheKey = service.GenerateSessionUUID(seed)
|
||||
}
|
||||
if sessionHash == "" {
|
||||
sessionHash = service.DeriveSessionHashFromSeed(seed)
|
||||
}
|
||||
}
|
||||
}
|
||||
sessionHash, promptCacheKey = resolveOpenAIMessagesMetadataSession(sessionHash, promptCacheKey, reqModel, body)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
@ -711,52 +716,60 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_messages.upstream_failover_switching",
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
reqLog.Warn("openai_messages.forward_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(err),
|
||||
)
|
||||
continue
|
||||
} else {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_messages.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_messages.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_messages.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
@ -767,16 +780,18 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
@ -801,6 +816,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func resolveOpenAIMessagesMetadataSession(sessionHash, promptCacheKey, reqModel string, body []byte) (string, string) {
|
||||
// Anthropic metadata.user_id 只作为账号粘性信号。上游 GPT/Codex 缓存键
|
||||
// 交给 ForwardAsAnthropic 从 cache_control 或完整消息 digest 派生,避免
|
||||
// 固定 metadata key 压住后续 turn 的缓存滚动。
|
||||
if sessionHash != "" {
|
||||
return sessionHash, promptCacheKey
|
||||
}
|
||||
if userID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()); userID != "" {
|
||||
seed := reqModel + "-" + userID
|
||||
sessionHash = service.DeriveSessionHashFromSeed(seed)
|
||||
}
|
||||
return sessionHash, promptCacheKey
|
||||
}
|
||||
|
||||
// anthropicErrorResponse writes an error in Anthropic Messages API format.
|
||||
func (h *OpenAIGatewayHandler) anthropicErrorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
@ -1124,6 +1153,11 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
||||
|
||||
if service.IsImageGenerationIntent("/v1/responses", reqModel, firstMessage) && !service.GroupAllowsImageGeneration(apiKey.Group) {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, service.ImageGenerationPermissionMessage())
|
||||
return
|
||||
}
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
|
||||
|
||||
@ -1233,6 +1267,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
)
|
||||
|
||||
hooks := &service.OpenAIWSIngressHooks{
|
||||
InitialRequestModel: reqModel,
|
||||
BeforeTurn: func(turn int) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
@ -1266,22 +1301,34 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
},
|
||||
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
|
||||
releaseTurnSlots()
|
||||
if turnErr != nil || result == nil {
|
||||
if turnErr != nil {
|
||||
if result == nil || result.ImageCount <= 0 {
|
||||
return
|
||||
}
|
||||
reqLog.Warn("openai.websocket_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(turnErr),
|
||||
)
|
||||
}
|
||||
if result == nil {
|
||||
return
|
||||
}
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
@ -1449,6 +1496,60 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) submitOpenAIUsageRecordTask(result *service.OpenAIForwardResult, task service.UsageRecordTask) {
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
h.submitMandatoryUsageRecordTask(task)
|
||||
return
|
||||
}
|
||||
h.submitUsageRecordTask(task)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) submitMandatoryUsageRecordTask(task service.UsageRecordTask) {
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
if h.usageRecordWorkerPool != nil {
|
||||
if mode := h.usageRecordWorkerPool.Submit(task); mode != service.UsageRecordSubmitModeDropped {
|
||||
return
|
||||
}
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.usage"),
|
||||
).Warn("openai.usage_record_task_mandatory_sync_fallback")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.usage"),
|
||||
zap.Any("panic", recovered),
|
||||
).Error("openai.usage_record_task_panic_recovered")
|
||||
}
|
||||
}()
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) acquireImageGenerationSlot(c *gin.Context, streamStarted bool) (func(), bool) {
|
||||
if h == nil || h.cfg == nil || h.imageLimiter == nil {
|
||||
return nil, true
|
||||
}
|
||||
imageConcurrency := h.cfg.Gateway.ImageConcurrency
|
||||
wait := strings.TrimSpace(imageConcurrency.OverflowMode) == config.ImageConcurrencyOverflowModeWait
|
||||
release, acquired := h.imageLimiter.Acquire(
|
||||
c.Request.Context(),
|
||||
imageConcurrency.Enabled,
|
||||
imageConcurrency.MaxConcurrentRequests,
|
||||
wait,
|
||||
time.Duration(imageConcurrency.WaitTimeoutSeconds)*time.Second,
|
||||
imageConcurrency.MaxWaitingRequests,
|
||||
)
|
||||
if acquired {
|
||||
return release, true
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Image generation concurrency limit exceeded, please retry later", streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
||||
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@ -91,6 +92,24 @@ func TestOpenAIHandleStreamingAwareError_JSONEscaping(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOpenAIMessagesMetadataSession_DoesNotDerivePromptCacheKey(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-sonnet-4-5","metadata":{"user_id":"claude-code-session"},"messages":[{"role":"user","content":"hello"}]}`)
|
||||
|
||||
sessionHash, promptCacheKey := resolveOpenAIMessagesMetadataSession("", "", "claude-sonnet-4-5", body)
|
||||
|
||||
require.NotEmpty(t, sessionHash)
|
||||
require.Empty(t, promptCacheKey)
|
||||
}
|
||||
|
||||
func TestResolveOpenAIMessagesMetadataSession_PreservesExplicitPromptCacheKey(t *testing.T) {
|
||||
body := []byte(`{"metadata":{"user_id":"claude-code-session"}}`)
|
||||
|
||||
sessionHash, promptCacheKey := resolveOpenAIMessagesMetadataSession("", "explicit-cache", "claude-sonnet-4-5", body)
|
||||
|
||||
require.NotEmpty(t, sessionHash)
|
||||
require.Equal(t, "explicit-cache", promptCacheKey)
|
||||
}
|
||||
|
||||
func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
@ -352,30 +371,6 @@ func TestOpenAIEnsureResponsesDependencies(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
|
||||
t.Run("prefers_explicit_fallback_model", func(t *testing.T) {
|
||||
apiKey := &service.APIKey{
|
||||
Group: &service.Group{DefaultMappedModel: "gpt-5.4"},
|
||||
}
|
||||
require.Equal(t, "gpt-5.2", resolveOpenAIForwardDefaultMappedModel(apiKey, " gpt-5.2 "))
|
||||
})
|
||||
|
||||
t.Run("uses_group_default_when_explicit_fallback_absent", func(t *testing.T) {
|
||||
apiKey := &service.APIKey{
|
||||
Group: &service.Group{DefaultMappedModel: "gpt-5.4"},
|
||||
}
|
||||
require.Equal(t, "gpt-5.4", resolveOpenAIForwardDefaultMappedModel(apiKey, ""))
|
||||
})
|
||||
|
||||
t.Run("returns_empty_without_group_default", func(t *testing.T) {
|
||||
require.Empty(t, resolveOpenAIForwardDefaultMappedModel(nil, ""))
|
||||
require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{}, ""))
|
||||
require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{
|
||||
Group: &service.Group{},
|
||||
}, ""))
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveOpenAIMessagesDispatchMappedModel(t *testing.T) {
|
||||
t.Run("exact_claude_model_override_wins", func(t *testing.T) {
|
||||
apiKey := &service.APIKey{
|
||||
@ -651,6 +646,46 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu
|
||||
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
|
||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`,
|
||||
userAgent: testStringPtr("codex_cli_rs/0.125.0 test"),
|
||||
})
|
||||
|
||||
require.NotNil(t, got.log.UserAgent)
|
||||
require.Equal(t, "codex_cli_rs/0.125.0 test", *got.log.UserAgent)
|
||||
require.NotNil(t, got.log.ReasoningEffort)
|
||||
require.Equal(t, "high", *got.log.ReasoningEffort)
|
||||
require.True(t, got.log.OpenAIWSMode)
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogInfersReasoningFromInitialRequestModel(t *testing.T) {
|
||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||
firstPayload: `{"type":"response.create","model":"gpt-5.4-xhigh","stream":false}`,
|
||||
userAgent: testStringPtr("codex_cli_rs/0.125.0 mapped"),
|
||||
channelMapping: map[string]string{
|
||||
"gpt-5.4-xhigh": "gpt-5.4",
|
||||
},
|
||||
})
|
||||
|
||||
require.Equal(t, "gpt-5.4", gjson.GetBytes(got.upstreamFirstPayload, "model").String(),
|
||||
"上游首帧应使用渠道映射后的模型")
|
||||
require.NotNil(t, got.log.ReasoningEffort)
|
||||
require.Equal(t, "xhigh", *got.log.ReasoningEffort,
|
||||
"usage log reasoning effort 必须使用渠道映射前首帧模型后缀推导")
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogLeavesUserAgentNilWhenMissing(t *testing.T) {
|
||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"medium"}}`,
|
||||
userAgent: testStringPtr(""),
|
||||
})
|
||||
|
||||
require.Nil(t, got.log.UserAgent, "空入站 User-Agent 不应由上游握手 UA 或默认 UA 兜底")
|
||||
require.NotNil(t, got.log.ReasoningEffort)
|
||||
require.Equal(t, "medium", *got.log.ReasoningEffort)
|
||||
}
|
||||
|
||||
func TestSetOpenAIClientTransportHTTP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@ -796,3 +831,278 @@ func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject
|
||||
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||
return httptest.NewServer(router)
|
||||
}
|
||||
|
||||
type openAIResponsesWSUsageLogCase struct {
|
||||
firstPayload string
|
||||
userAgent *string
|
||||
channelMapping map[string]string
|
||||
}
|
||||
|
||||
type openAIResponsesWSUsageLogResult struct {
|
||||
log *service.UsageLog
|
||||
upstreamFirstPayload []byte
|
||||
}
|
||||
|
||||
type openAIWSUsageHandlerAccountRepoStub struct {
|
||||
service.AccountRepository
|
||||
account service.Account
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
if s.account.Platform != platform {
|
||||
return nil, nil
|
||||
}
|
||||
return []service.Account{s.account}, nil
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
|
||||
return s.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) {
|
||||
if s.account.ID != id {
|
||||
return nil, nil
|
||||
}
|
||||
account := s.account
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
type openAIWSUsageHandlerUsageLogRepoStub struct {
|
||||
service.UsageLogRepository
|
||||
created chan *service.UsageLog
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerUsageLogRepoStub) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
|
||||
if s.created != nil {
|
||||
s.created <- log
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
type openAIWSUsageHandlerChannelRepoStub struct {
|
||||
service.ChannelRepository
|
||||
channels []service.Channel
|
||||
groupPlatforms map[int64]string
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerChannelRepoStub) ListAll(ctx context.Context) ([]service.Channel, error) {
|
||||
return s.channels, nil
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||||
out := make(map[int64]string, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
if platform := strings.TrimSpace(s.groupPlatforms[groupID]); platform != "" {
|
||||
out[groupID] = platform
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
upstreamPayloadCh := make(chan []byte, 1)
|
||||
upstreamErrCh := make(chan error, 1)
|
||||
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
upstreamErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.CloseNow()
|
||||
}()
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
msgType, payload, readErr := conn.Read(readCtx)
|
||||
cancelRead()
|
||||
if readErr != nil {
|
||||
upstreamErrCh <- readErr
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
upstreamErrCh <- errors.New("unexpected upstream websocket message type")
|
||||
return
|
||||
}
|
||||
upstreamPayloadCh <- payload
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
writeErr := conn.Write(writeCtx, coderws.MessageText, []byte(
|
||||
`{"type":"response.completed","response":{"id":"resp_usage_e2e","model":"gpt-5.4","usage":{"input_tokens":2,"output_tokens":1}}}`,
|
||||
))
|
||||
cancelWrite()
|
||||
if writeErr != nil {
|
||||
upstreamErrCh <- writeErr
|
||||
return
|
||||
}
|
||||
_ = conn.Close(coderws.StatusNormalClosure, "done")
|
||||
upstreamErrCh <- nil
|
||||
}))
|
||||
defer upstreamServer.Close()
|
||||
|
||||
groupID := int64(4201)
|
||||
account := service.Account{
|
||||
ID: 9901,
|
||||
Name: "openai-ws-passthrough-usage-e2e",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": upstreamServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.RunMode = config.RunModeSimple
|
||||
cfg.Default.RateMultiplier = 1
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
accountRepo := &openAIWSUsageHandlerAccountRepoStub{account: account}
|
||||
usageRepo := &openAIWSUsageHandlerUsageLogRepoStub{created: make(chan *service.UsageLog, 1)}
|
||||
|
||||
var channelSvc *service.ChannelService
|
||||
if len(tc.channelMapping) > 0 {
|
||||
channelSvc = service.NewChannelService(&openAIWSUsageHandlerChannelRepoStub{
|
||||
channels: []service.Channel{{
|
||||
ID: 7701,
|
||||
Name: "openai-ws-e2e-channel",
|
||||
Status: service.StatusActive,
|
||||
GroupIDs: []int64{groupID},
|
||||
ModelMapping: map[string]map[string]string{service.PlatformOpenAI: tc.channelMapping},
|
||||
}},
|
||||
groupPlatforms: map[int64]string{groupID: service.PlatformOpenAI},
|
||||
}, nil, nil, nil)
|
||||
}
|
||||
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
|
||||
gatewaySvc := service.NewOpenAIGatewayService(
|
||||
accountRepo,
|
||||
usageRepo,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
nil,
|
||||
service.NewBillingService(cfg, nil),
|
||||
nil,
|
||||
billingCacheSvc,
|
||||
nil,
|
||||
&service.DeferredService{},
|
||||
nil,
|
||||
nil,
|
||||
channelSvc,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
}
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: gatewaySvc,
|
||||
billingCacheService: billingCacheSvc,
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
|
||||
}
|
||||
|
||||
apiKey := &service.APIKey{
|
||||
ID: 1801,
|
||||
GroupID: &groupID,
|
||||
User: &service.User{ID: 1701, Status: service.StatusActive},
|
||||
}
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1})
|
||||
c.Next()
|
||||
})
|
||||
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||
handlerServer := httptest.NewServer(router)
|
||||
defer handlerServer.Close()
|
||||
|
||||
headers := http.Header{}
|
||||
if tc.userAgent != nil {
|
||||
headers.Set("User-Agent", *tc.userAgent)
|
||||
}
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(
|
||||
dialCtx,
|
||||
"ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses",
|
||||
&coderws.DialOptions{HTTPHeader: headers, CompressionMode: coderws.CompressionContextTakeover},
|
||||
)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(tc.firstPayload))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, event, err := clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
|
||||
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
|
||||
|
||||
var usageLog *service.UsageLog
|
||||
select {
|
||||
case usageLog = <-usageRepo.created:
|
||||
require.NotNil(t, usageLog)
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待 WebSocket usage log 写入超时")
|
||||
}
|
||||
|
||||
var upstreamFirstPayload []byte
|
||||
select {
|
||||
case upstreamFirstPayload = <-upstreamPayloadCh:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待上游 WebSocket 首帧超时")
|
||||
}
|
||||
|
||||
select {
|
||||
case upstreamErr := <-upstreamErrCh:
|
||||
require.NoError(t, upstreamErr)
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待上游 WebSocket 结束超时")
|
||||
}
|
||||
|
||||
return openAIResponsesWSUsageLogResult{
|
||||
log: usageLog,
|
||||
upstreamFirstPayload: upstreamFirstPayload,
|
||||
}
|
||||
}
|
||||
|
||||
func testStringPtr(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
@ -81,6 +81,18 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
zap.String("capability", string(parsed.RequiredCapability)),
|
||||
)
|
||||
|
||||
if !service.GroupAllowsImageGeneration(apiKey.Group) {
|
||||
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
|
||||
return
|
||||
}
|
||||
imageReleaseFunc, acquired := h.acquireImageGenerationSlot(c, streamStarted)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
if imageReleaseFunc != nil {
|
||||
defer imageReleaseFunc()
|
||||
}
|
||||
|
||||
if parsed.Multipart {
|
||||
setOpsRequestContext(c, parsed.Model, parsed.Stream, nil)
|
||||
} else {
|
||||
@ -188,62 +200,69 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||
}
|
||||
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||
if result != nil && result.FirstTokenMs != nil {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.images.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
reqLog.Warn("openai.images.forward_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.images.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai.images.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
}
|
||||
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
|
||||
reqLog.Warn("openai.images.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai.images.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
}
|
||||
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
|
||||
reqLog.Warn("openai.images.forward_failed", fields...)
|
||||
reqLog.Error("openai.images.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
reqLog.Error("openai.images.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders)
|
||||
@ -259,21 +278,27 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
if parsed.Multipart {
|
||||
requestPayloadHash = service.HashUsageRequestPayload([]byte(parsed.StickySessionSeed()))
|
||||
}
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
upstreamModel := ""
|
||||
if result != nil {
|
||||
upstreamModel = result.UpstreamModel
|
||||
}
|
||||
h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, result.UpstreamModel),
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, upstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.images"),
|
||||
|
||||
49
backend/internal/handler/openai_images_controls_test.go
Normal file
49
backend/internal/handler/openai_images_controls_test.go
Normal file
@ -0,0 +1,49 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestOpenAIGatewayHandlerImages_DisabledGroupRejectsBeforeScheduling(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw","size":"1024x1024"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
groupID := int64(111)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
ID: 222,
|
||||
GroupID: &groupID,
|
||||
Group: &service.Group{
|
||||
ID: groupID,
|
||||
AllowImageGeneration: false,
|
||||
},
|
||||
User: &service.User{ID: 333},
|
||||
})
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 333, Concurrency: 1})
|
||||
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: &ConcurrencyHelper{concurrencyService: &service.ConcurrencyService{}},
|
||||
}
|
||||
|
||||
h.Images(c)
|
||||
|
||||
require.Equal(t, http.StatusForbidden, rec.Code)
|
||||
require.Equal(t, "permission_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String())
|
||||
require.Contains(t, rec.Body.String(), service.ImageGenerationPermissionMessage())
|
||||
}
|
||||
@ -129,3 +129,63 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
|
||||
})
|
||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitMandatoryUsageRecordTask_DroppedTaskSyncFallback(t *testing.T) {
|
||||
pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 1,
|
||||
QueueSize: 1,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: "drop",
|
||||
OverflowSamplePercent: 0,
|
||||
AutoScaleEnabled: false,
|
||||
})
|
||||
t.Cleanup(pool.Stop)
|
||||
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
|
||||
|
||||
block := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
pool.Submit(func(ctx context.Context) {
|
||||
close(block)
|
||||
<-release
|
||||
})
|
||||
<-block
|
||||
pool.Submit(func(ctx context.Context) {})
|
||||
|
||||
var called atomic.Bool
|
||||
h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
close(release)
|
||||
|
||||
require.True(t, called.Load(), "mandatory usage task must run synchronously when async submit is dropped")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitOpenAIUsageRecordTask_ImageResultUsesMandatoryFallback(t *testing.T) {
|
||||
pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 1,
|
||||
QueueSize: 1,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: "drop",
|
||||
OverflowSamplePercent: 0,
|
||||
AutoScaleEnabled: false,
|
||||
})
|
||||
t.Cleanup(pool.Stop)
|
||||
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
|
||||
|
||||
block := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
pool.Submit(func(ctx context.Context) {
|
||||
close(block)
|
||||
<-release
|
||||
})
|
||||
<-block
|
||||
pool.Submit(func(ctx context.Context) {})
|
||||
|
||||
var called atomic.Bool
|
||||
h.submitOpenAIUsageRecordTask(&service.OpenAIForwardResult{ImageCount: 1}, func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
close(release)
|
||||
|
||||
require.True(t, called.Load(), "image usage task must be mandatory when async submit is dropped")
|
||||
}
|
||||
|
||||
@ -32,7 +32,13 @@ func TestAnthropicToResponses_BasicText(t *testing.T) {
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
assert.Equal(t, "message", items[0].Type)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "input_text", parts[0].Type)
|
||||
assert.Equal(t, "Hello", parts[0].Text)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_SystemPrompt(t *testing.T) {
|
||||
@ -49,7 +55,12 @@ func TestAnthropicToResponses_SystemPrompt(t *testing.T) {
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
assert.Equal(t, "system", items[0].Role)
|
||||
assert.Equal(t, "developer", items[0].Role)
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "input_text", parts[0].Type)
|
||||
assert.Equal(t, "You are helpful.", parts[0].Text)
|
||||
})
|
||||
|
||||
t.Run("array", func(t *testing.T) {
|
||||
@ -65,11 +76,33 @@ func TestAnthropicToResponses_SystemPrompt(t *testing.T) {
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
assert.Equal(t, "system", items[0].Role)
|
||||
// System text should be joined with double newline.
|
||||
var text string
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &text))
|
||||
assert.Equal(t, "Part 1\n\nPart 2", text)
|
||||
assert.Equal(t, "developer", items[0].Role)
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 2)
|
||||
assert.Equal(t, "input_text", parts[0].Type)
|
||||
assert.Equal(t, "Part 1", parts[0].Text)
|
||||
assert.Equal(t, "input_text", parts[1].Type)
|
||||
assert.Equal(t, "Part 2", parts[1].Text)
|
||||
})
|
||||
|
||||
t.Run("billing header skipped", func(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 100,
|
||||
System: json.RawMessage(`[{"type":"text","text":"x-anthropic-billing-header: cc_version=1;"},{"type":"text","text":"Project prompt"}]`),
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "Project prompt", parts[0].Text)
|
||||
})
|
||||
}
|
||||
|
||||
@ -94,6 +127,8 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) {
|
||||
require.Len(t, resp.Tools, 1)
|
||||
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||
assert.Equal(t, "get_weather", resp.Tools[0].Name)
|
||||
require.NotNil(t, resp.Tools[0].Strict)
|
||||
assert.False(t, *resp.Tools[0].Strict)
|
||||
|
||||
// Check input items
|
||||
var items []ResponsesInputItem
|
||||
@ -104,10 +139,10 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) {
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
assert.Equal(t, "function_call", items[2].Type)
|
||||
assert.Equal(t, "fc_call_1", items[2].CallID)
|
||||
assert.Equal(t, "call_1", items[2].CallID)
|
||||
assert.Empty(t, items[2].ID)
|
||||
assert.Equal(t, "function_call_output", items[3].Type)
|
||||
assert.Equal(t, "fc_call_1", items[3].CallID)
|
||||
assert.Equal(t, "call_1", items[3].CallID)
|
||||
assert.Equal(t, "Sunny, 72°F", items[3].Output)
|
||||
}
|
||||
|
||||
@ -261,6 +296,34 @@ func TestResponsesToAnthropic_ToolUse(t *testing.T) {
|
||||
assert.JSONEq(t, `{"city":"NYC"}`, string(anth.Content[1].Input))
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_ToolUseStopReasonDoesNotDependOnLastBlock(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_tool_then_text",
|
||||
Model: "gpt-5.5",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "function_call",
|
||||
CallID: "call_todo",
|
||||
Name: "TodoWrite",
|
||||
Arguments: `{"todos":[{"content":"review changes","status":"in_progress"}]}`,
|
||||
},
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "Task list updated."},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
|
||||
assert.Equal(t, "tool_use", anth.StopReason)
|
||||
require.Len(t, anth.Content, 2)
|
||||
assert.Equal(t, "tool_use", anth.Content[0].Type)
|
||||
assert.Equal(t, "text", anth.Content[1].Type)
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_ReadToolDropsEmptyPages(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_read",
|
||||
@ -434,6 +497,45 @@ func TestStreamingTextOnly(t *testing.T) {
|
||||
assert.Equal(t, "message_stop", events[1].Type)
|
||||
}
|
||||
|
||||
func TestResponsesEventToAnthropicEvents_ResponseDone(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
state.Model = "gpt-4o"
|
||||
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.done",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4},
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, events, 2)
|
||||
assert.Equal(t, "message_delta", events[0].Type)
|
||||
assert.Equal(t, "end_turn", events[0].Delta.StopReason)
|
||||
assert.Equal(t, 12, events[0].Usage.InputTokens)
|
||||
assert.Equal(t, 4, events[0].Usage.OutputTokens)
|
||||
assert.Equal(t, "message_stop", events[1].Type)
|
||||
assert.Nil(t, FinalizeResponsesAnthropicStream(state))
|
||||
}
|
||||
|
||||
func TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
state.Model = "gpt-4o"
|
||||
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.done",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "incomplete",
|
||||
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
|
||||
Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4},
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, events, 2)
|
||||
assert.Equal(t, "message_delta", events[0].Type)
|
||||
assert.Equal(t, "max_tokens", events[0].Delta.StopReason)
|
||||
assert.Equal(t, "message_stop", events[1].Type)
|
||||
assert.Nil(t, FinalizeResponsesAnthropicStream(state))
|
||||
}
|
||||
|
||||
func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
@ -514,6 +616,81 @@ func TestStreamingToolCall(t *testing.T) {
|
||||
assert.Equal(t, "tool_use", events[0].Delta.StopReason)
|
||||
}
|
||||
|
||||
func TestStreamingToolCallStopReasonSurvivesLaterText(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_tool_then_text", Model: "gpt-5.5"},
|
||||
}, state)
|
||||
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: 0,
|
||||
Item: &ResponsesOutput{Type: "function_call", CallID: "call_todo", Name: "TodoWrite"},
|
||||
}, state)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, "content_block_start", events[0].Type)
|
||||
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.done",
|
||||
OutputIndex: 0,
|
||||
Arguments: `{"todos":[{"content":"review changes","status":"in_progress","activeForm":"reviewing changes"}]}`,
|
||||
}, state)
|
||||
require.Len(t, events, 2)
|
||||
assert.Equal(t, "content_block_delta", events[0].Type)
|
||||
assert.Equal(t, "content_block_stop", events[1].Type)
|
||||
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
OutputIndex: 1,
|
||||
Delta: "I will continue after the task list updates.",
|
||||
}, state)
|
||||
require.Len(t, events, 2)
|
||||
assert.Equal(t, "content_block_start", events[0].Type)
|
||||
assert.Equal(t, "content_block_delta", events[1].Type)
|
||||
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{InputTokens: 20, OutputTokens: 10},
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, events, 3)
|
||||
assert.Equal(t, "content_block_stop", events[0].Type)
|
||||
assert.Equal(t, "tool_use", events[1].Delta.StopReason)
|
||||
assert.Equal(t, "message_stop", events[2].Type)
|
||||
}
|
||||
|
||||
func TestStreamingToolCallDoneWithoutDeltaEmitsArguments(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_bash", Model: "gpt-5.5"},
|
||||
}, state)
|
||||
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: 0,
|
||||
Item: &ResponsesOutput{Type: "function_call", CallID: "call_bash", Name: "Bash"},
|
||||
}, state)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, "content_block_start", events[0].Type)
|
||||
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.done",
|
||||
OutputIndex: 0,
|
||||
Arguments: `{"command":"git -C \"/mnt/d/nodejs/other/edmt\" status --short --ignored"}`,
|
||||
}, state)
|
||||
require.Len(t, events, 2)
|
||||
assert.Equal(t, "content_block_delta", events[0].Type)
|
||||
assert.Equal(t, "input_json_delta", events[0].Delta.Type)
|
||||
assert.JSONEq(t, `{"command":"git -C \"/mnt/d/nodejs/other/edmt\" status --short --ignored"}`, events[0].Delta.PartialJSON)
|
||||
assert.Equal(t, "content_block_stop", events[1].Type)
|
||||
}
|
||||
|
||||
func TestStreamingReadToolDropsEmptyPages(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
@ -653,6 +830,27 @@ func TestFinalizeStream_AbnormalTermination(t *testing.T) {
|
||||
assert.Equal(t, "message_stop", events[2].Type)
|
||||
}
|
||||
|
||||
func TestFinalizeStream_ToolCallAbnormalTerminationKeepsToolUseStopReason(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_tool_interrupted", Model: "gpt-5.5"},
|
||||
}, state)
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: 0,
|
||||
Item: &ResponsesOutput{Type: "function_call", CallID: "call_todo", Name: "TodoWrite"},
|
||||
}, state)
|
||||
|
||||
events := FinalizeResponsesAnthropicStream(state)
|
||||
require.Len(t, events, 3)
|
||||
assert.Equal(t, "content_block_stop", events[0].Type)
|
||||
assert.Equal(t, "message_delta", events[1].Type)
|
||||
assert.Equal(t, "tool_use", events[1].Delta.StopReason)
|
||||
assert.Equal(t, "message_stop", events[2].Type)
|
||||
}
|
||||
|
||||
func TestStreamingEmptyResponse(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
@ -788,8 +986,8 @@ func TestAnthropicToResponses_ThinkingEnabled(t *testing.T) {
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
// thinking.type is ignored for effort; default high applies.
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
// thinking.type is ignored for effort; Codex bridge default medium applies.
|
||||
assert.Equal(t, "medium", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
assert.Contains(t, resp.Include, "reasoning.encrypted_content")
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
@ -806,8 +1004,8 @@ func TestAnthropicToResponses_ThinkingAdaptive(t *testing.T) {
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
// thinking.type is ignored for effort; default high applies.
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
// thinking.type is ignored for effort; Codex bridge default medium applies.
|
||||
assert.Equal(t, "medium", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
}
|
||||
@ -822,9 +1020,9 @@ func TestAnthropicToResponses_ThinkingDisabled(t *testing.T) {
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
// Default effort applies (high → high) even when thinking is disabled.
|
||||
// Default effort applies (medium) even when thinking is disabled.
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "medium", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_NoThinking(t *testing.T) {
|
||||
@ -836,9 +1034,9 @@ func TestAnthropicToResponses_NoThinking(t *testing.T) {
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
// Default effort applies (high → high) when no thinking/output_config is set.
|
||||
// Default effort applies (medium) when no thinking/output_config is set.
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "medium", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@ -846,7 +1044,7 @@ func TestAnthropicToResponses_NoThinking(t *testing.T) {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) {
|
||||
// Default is high, but output_config.effort="low" overrides. low→low after mapping.
|
||||
// Default is medium, but output_config.effort="low" overrides. low→low after mapping.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
@ -880,7 +1078,7 @@ func TestAnthropicToResponses_OutputConfigWithoutThinking(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) {
|
||||
// output_config.effort="high" → mapped to "high" (1:1, both sides' default).
|
||||
// output_config.effort="high" → mapped to "high" (1:1).
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
@ -912,7 +1110,7 @@ func TestAnthropicToResponses_OutputConfigMax(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_NoOutputConfig(t *testing.T) {
|
||||
// No output_config → default high regardless of thinking.type.
|
||||
// No output_config → default medium regardless of thinking.type.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
@ -923,11 +1121,11 @@ func TestAnthropicToResponses_NoOutputConfig(t *testing.T) {
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "medium", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) {
|
||||
// output_config present but effort empty (e.g. only format set) → default high.
|
||||
// output_config present but effort empty (e.g. only format set) → default medium.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
@ -938,7 +1136,7 @@ func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) {
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "medium", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@ -1110,7 +1308,7 @@ func TestAnthropicToResponses_ToolResultWithImage(t *testing.T) {
|
||||
|
||||
// function_call_output should have text-only output (no image).
|
||||
assert.Equal(t, "function_call_output", items[2].Type)
|
||||
assert.Equal(t, "fc_toolu_1", items[2].CallID)
|
||||
assert.Equal(t, "toolu_1", items[2].CallID)
|
||||
assert.Equal(t, "(empty)", items[2].Output)
|
||||
|
||||
// Image should be in a separate user message.
|
||||
|
||||
@ -32,6 +32,9 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
|
||||
|
||||
storeFalse := false
|
||||
out.Store = &storeFalse
|
||||
parallelToolCalls := true
|
||||
out.ParallelToolCalls = ¶llelToolCalls
|
||||
out.Text = &ResponsesText{Verbosity: "medium"}
|
||||
|
||||
if req.MaxTokens > 0 {
|
||||
v := req.MaxTokens
|
||||
@ -46,10 +49,10 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
|
||||
}
|
||||
|
||||
// Determine reasoning effort: only output_config.effort controls the
|
||||
// level; thinking.type is ignored. Default is high when unset (both
|
||||
// Anthropic and OpenAI default to high).
|
||||
// level; thinking.type is ignored. Default follows Codex CLI / airgate's
|
||||
// Anthropic bridge shape, which uses medium when unset.
|
||||
// Anthropic levels map 1:1 to OpenAI: low→low, medium→medium, high→high, max→xhigh.
|
||||
effort := "high" // default → both sides' default
|
||||
effort := "medium"
|
||||
if req.OutputConfig != nil && req.OutputConfig.Effort != "" {
|
||||
effort = req.OutputConfig.Effort
|
||||
}
|
||||
@ -108,16 +111,19 @@ func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage
|
||||
func convertAnthropicToResponsesInput(system json.RawMessage, msgs []AnthropicMessage) ([]ResponsesInputItem, error) {
|
||||
var out []ResponsesInputItem
|
||||
|
||||
// System prompt → system role input item.
|
||||
// System prompt → developer role input item. ChatGPT Codex SSE behaves like
|
||||
// Codex CLI here: keeping Anthropic system text in input preserves the
|
||||
// conversation/cache shape better than moving it into instructions.
|
||||
if len(system) > 0 {
|
||||
sysText, err := parseAnthropicSystemPrompt(system)
|
||||
sysParts, err := parseAnthropicSystemContentParts(system)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sysText != "" {
|
||||
content, _ := json.Marshal(sysText)
|
||||
if len(sysParts) > 0 {
|
||||
content, _ := json.Marshal(sysParts)
|
||||
out = append(out, ResponsesInputItem{
|
||||
Role: "system",
|
||||
Type: "message",
|
||||
Role: "developer",
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
@ -133,24 +139,32 @@ func convertAnthropicToResponsesInput(system json.RawMessage, msgs []AnthropicMe
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// parseAnthropicSystemPrompt handles the Anthropic system field which can be
|
||||
// a plain string or an array of text blocks.
|
||||
func parseAnthropicSystemPrompt(raw json.RawMessage) (string, error) {
|
||||
// parseAnthropicSystemContentParts handles the Anthropic system field which can
|
||||
// be a plain string or an array of text blocks. Claude Code may include an
|
||||
// x-anthropic-billing-header block; airgate drops it before sending to Codex.
|
||||
func parseAnthropicSystemContentParts(raw json.RawMessage) ([]ResponsesContentPart, error) {
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s, nil
|
||||
if isAnthropicBillingHeaderText(s) || s == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return []ResponsesContentPart{{Type: "input_text", Text: s}}, nil
|
||||
}
|
||||
var blocks []AnthropicContentBlock
|
||||
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
var parts []string
|
||||
var parts []ResponsesContentPart
|
||||
for _, b := range blocks {
|
||||
if b.Type == "text" && b.Text != "" {
|
||||
parts = append(parts, b.Text)
|
||||
if b.Type == "text" && b.Text != "" && !isAnthropicBillingHeaderText(b.Text) {
|
||||
parts = append(parts, ResponsesContentPart{Type: "input_text", Text: b.Text})
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n\n"), nil
|
||||
return parts, nil
|
||||
}
|
||||
|
||||
func isAnthropicBillingHeaderText(text string) bool {
|
||||
return strings.HasPrefix(text, "x-anthropic-billing-header: ")
|
||||
}
|
||||
|
||||
// anthropicMsgToResponsesItems converts a single Anthropic message into one
|
||||
@ -173,8 +187,12 @@ func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error)
|
||||
// Try plain string.
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
content, _ := json.Marshal(s)
|
||||
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||
parts := []ResponsesContentPart{{Type: "input_text", Text: s}}
|
||||
partsJSON, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []ResponsesInputItem{{Type: "message", Role: "user", Content: partsJSON}}, nil
|
||||
}
|
||||
|
||||
var blocks []AnthropicContentBlock
|
||||
@ -223,7 +241,7 @@ func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, ResponsesInputItem{Role: "user", Content: content})
|
||||
out = append(out, ResponsesInputItem{Type: "message", Role: "user", Content: content})
|
||||
}
|
||||
|
||||
return out, nil
|
||||
@ -242,7 +260,7 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []ResponsesInputItem{{Role: "assistant", Content: partsJSON}}, nil
|
||||
return []ResponsesInputItem{{Type: "message", Role: "assistant", Content: partsJSON}}, nil
|
||||
}
|
||||
|
||||
var blocks []AnthropicContentBlock
|
||||
@ -260,7 +278,7 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
|
||||
items = append(items, ResponsesInputItem{Type: "message", Role: "assistant", Content: partsJSON})
|
||||
}
|
||||
|
||||
// tool_use → function_call items.
|
||||
@ -284,17 +302,14 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// toResponsesCallID converts an Anthropic tool ID (toolu_xxx / call_xxx) to a
|
||||
// Responses API function_call ID that starts with "fc_".
|
||||
// toResponsesCallID preserves Anthropic tool IDs as Responses call_id values.
|
||||
// Claude Code sends tool_result.tool_use_id back verbatim, and ChatGPT Codex
|
||||
// continuation expects that call_id to match the original tool_use id.
|
||||
func toResponsesCallID(id string) string {
|
||||
if strings.HasPrefix(id, "fc_") {
|
||||
return id
|
||||
}
|
||||
return "fc_" + id
|
||||
return id
|
||||
}
|
||||
|
||||
// fromResponsesCallID reverses toResponsesCallID, stripping the "fc_" prefix
|
||||
// that was added during request conversion.
|
||||
// fromResponsesCallID reverses old prefixed IDs while preserving current IDs.
|
||||
func fromResponsesCallID(id string) string {
|
||||
if after, ok := strings.CutPrefix(id, "fc_"); ok {
|
||||
// Only strip if the remainder doesn't look like it was already "fc_" prefixed.
|
||||
@ -412,11 +427,16 @@ func convertAnthropicToolsToResponses(tools []AnthropicTool) []ResponsesTool {
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Parameters: normalizeToolParameters(t.InputSchema),
|
||||
Strict: boolPtr(false),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func boolPtr(v bool) *bool {
|
||||
return &v
|
||||
}
|
||||
|
||||
// normalizeToolParameters ensures the tool parameter schema is valid for
|
||||
// OpenAI's Responses API, which requires "properties" on object schemas.
|
||||
//
|
||||
|
||||
@ -720,6 +720,49 @@ func TestResponsesEventToChatChunks_Completed(t *testing.T) {
|
||||
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.done",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7},
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 2)
|
||||
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
|
||||
require.NotNil(t, chunks[1].Usage)
|
||||
assert.Equal(t, 13, chunks[1].Usage.PromptTokens)
|
||||
assert.Equal(t, 7, chunks[1].Usage.CompletionTokens)
|
||||
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ResponseDoneIncomplete(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.done",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "incomplete",
|
||||
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
|
||||
Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7},
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 2)
|
||||
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||
assert.Equal(t, "length", *chunks[0].Choices[0].FinishReason)
|
||||
require.NotNil(t, chunks[1].Usage)
|
||||
assert.Equal(t, 13, chunks[1].Usage.PromptTokens)
|
||||
assert.Equal(t, 7, chunks[1].Usage.CompletionTokens)
|
||||
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
|
||||
@ -120,7 +120,7 @@ func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncom
|
||||
}
|
||||
return "end_turn"
|
||||
case "completed":
|
||||
if len(blocks) > 0 && blocks[len(blocks)-1].Type == "tool_use" {
|
||||
if containsAnthropicToolUseBlock(blocks) {
|
||||
return "tool_use"
|
||||
}
|
||||
return "end_turn"
|
||||
@ -129,6 +129,15 @@ func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncom
|
||||
}
|
||||
}
|
||||
|
||||
func containsAnthropicToolUseBlock(blocks []AnthropicContentBlock) bool {
|
||||
for _, block := range blocks {
|
||||
if block.Type == "tool_use" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func sanitizeAnthropicToolUseInput(name string, raw string) json.RawMessage {
|
||||
if name != "Read" || raw == "" {
|
||||
return json.RawMessage(raw)
|
||||
@ -161,11 +170,13 @@ type ResponsesEventToAnthropicState struct {
|
||||
MessageStartSent bool
|
||||
MessageStopSent bool
|
||||
|
||||
ContentBlockIndex int
|
||||
ContentBlockOpen bool
|
||||
CurrentBlockType string // "text" | "thinking" | "tool_use"
|
||||
CurrentToolName string
|
||||
CurrentToolArgs string
|
||||
ContentBlockIndex int
|
||||
ContentBlockOpen bool
|
||||
CurrentBlockType string // "text" | "thinking" | "tool_use"
|
||||
CurrentToolName string
|
||||
CurrentToolArgs string
|
||||
CurrentToolHadDelta bool
|
||||
HasToolCall bool
|
||||
|
||||
// OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index.
|
||||
OutputIndexToBlockIdx map[int]int
|
||||
@ -212,7 +223,9 @@ func ResponsesEventToAnthropicEvents(
|
||||
return resToAnthHandleReasoningDelta(evt, state)
|
||||
case "response.reasoning_summary_text.done":
|
||||
return resToAnthHandleBlockDone(state)
|
||||
case "response.completed", "response.incomplete", "response.failed":
|
||||
// response.done 是 Realtime/WS 与项目透传路径使用的终止别名;
|
||||
// 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。
|
||||
case "response.completed", "response.done", "response.incomplete", "response.failed":
|
||||
return resToAnthHandleCompleted(evt, state)
|
||||
default:
|
||||
return nil
|
||||
@ -229,11 +242,16 @@ func FinalizeResponsesAnthropicStream(state *ResponsesEventToAnthropicState) []A
|
||||
var events []AnthropicStreamEvent
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
stopReason := "end_turn"
|
||||
if state.HasToolCall {
|
||||
stopReason = "tool_use"
|
||||
}
|
||||
|
||||
events = append(events,
|
||||
AnthropicStreamEvent{
|
||||
Type: "message_delta",
|
||||
Delta: &AnthropicDelta{
|
||||
StopReason: "end_turn",
|
||||
StopReason: stopReason,
|
||||
},
|
||||
Usage: &AnthropicUsage{
|
||||
InputTokens: state.InputTokens,
|
||||
@ -304,6 +322,8 @@ func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesE
|
||||
state.CurrentBlockType = "tool_use"
|
||||
state.CurrentToolName = evt.Item.Name
|
||||
state.CurrentToolArgs = ""
|
||||
state.CurrentToolHadDelta = false
|
||||
state.HasToolCall = true
|
||||
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
@ -388,6 +408,9 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
|
||||
state.CurrentToolArgs += evt.Delta
|
||||
return nil
|
||||
}
|
||||
if state.CurrentBlockType == "tool_use" {
|
||||
state.CurrentToolHadDelta = true
|
||||
}
|
||||
|
||||
blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
|
||||
if !ok {
|
||||
@ -405,7 +428,7 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
|
||||
}
|
||||
|
||||
func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if state.CurrentBlockType != "tool_use" || state.CurrentToolName != "Read" {
|
||||
if state.CurrentBlockType != "tool_use" {
|
||||
return resToAnthHandleBlockDone(state)
|
||||
}
|
||||
|
||||
@ -413,10 +436,16 @@ func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEven
|
||||
if raw == "" {
|
||||
raw = state.CurrentToolArgs
|
||||
}
|
||||
sanitized := sanitizeAnthropicToolUseInput(state.CurrentToolName, raw)
|
||||
if len(sanitized) == 0 {
|
||||
if raw == "" || state.CurrentToolHadDelta {
|
||||
return closeCurrentBlock(state)
|
||||
}
|
||||
if state.CurrentToolName == "Read" {
|
||||
sanitized := sanitizeAnthropicToolUseInput(state.CurrentToolName, raw)
|
||||
if len(sanitized) == 0 {
|
||||
return closeCurrentBlock(state)
|
||||
}
|
||||
raw = string(sanitized)
|
||||
}
|
||||
|
||||
idx := state.ContentBlockIndex
|
||||
events := []AnthropicStreamEvent{{
|
||||
@ -424,7 +453,7 @@ func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEven
|
||||
Index: &idx,
|
||||
Delta: &AnthropicDelta{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: string(sanitized),
|
||||
PartialJSON: raw,
|
||||
},
|
||||
}}
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
@ -551,7 +580,7 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo
|
||||
stopReason = "max_tokens"
|
||||
}
|
||||
case "completed":
|
||||
if state.ContentBlockIndex > 0 && state.CurrentBlockType == "tool_use" {
|
||||
if state.HasToolCall {
|
||||
stopReason = "tool_use"
|
||||
}
|
||||
}
|
||||
@ -584,6 +613,7 @@ func closeCurrentBlock(state *ResponsesEventToAnthropicState) []AnthropicStreamE
|
||||
state.ContentBlockIndex++
|
||||
state.CurrentToolName = ""
|
||||
state.CurrentToolArgs = ""
|
||||
state.CurrentToolHadDelta = false
|
||||
return []AnthropicStreamEvent{{
|
||||
Type: "content_block_stop",
|
||||
Index: &idx,
|
||||
|
||||
@ -160,7 +160,9 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
|
||||
return resToChatHandleReasoningDelta(evt, state)
|
||||
case "response.reasoning_summary_text.done":
|
||||
return nil
|
||||
case "response.completed", "response.incomplete", "response.failed":
|
||||
// response.done 是 Realtime/WS 与项目透传路径使用的终止别名;
|
||||
// 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。
|
||||
case "response.completed", "response.done", "response.incomplete", "response.failed":
|
||||
return resToChatHandleCompleted(evt, state)
|
||||
default:
|
||||
return nil
|
||||
|
||||
@ -53,6 +53,8 @@ type AnthropicMessage struct {
|
||||
type AnthropicContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
CacheControl *AnthropicCacheControl `json:"cache_control,omitempty"`
|
||||
|
||||
// type=text
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
@ -165,19 +167,23 @@ type AnthropicDelta struct {
|
||||
|
||||
// ResponsesRequest is the request body for POST /v1/responses.
|
||||
type ResponsesRequest struct {
|
||||
Model string `json:"model"`
|
||||
Instructions string `json:"instructions,omitempty"`
|
||||
Input json.RawMessage `json:"input"` // string or []ResponsesInputItem
|
||||
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools []ResponsesTool `json:"tools,omitempty"`
|
||||
Include []string `json:"include,omitempty"`
|
||||
Store *bool `json:"store,omitempty"`
|
||||
Reasoning *ResponsesReasoning `json:"reasoning,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Instructions string `json:"instructions,omitempty"`
|
||||
Input json.RawMessage `json:"input"` // string or []ResponsesInputItem
|
||||
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools []ResponsesTool `json:"tools,omitempty"`
|
||||
Include []string `json:"include,omitempty"`
|
||||
Store *bool `json:"store,omitempty"`
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
Reasoning *ResponsesReasoning `json:"reasoning,omitempty"`
|
||||
Text *ResponsesText `json:"text,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
PromptCacheKey string `json:"prompt_cache_key,omitempty"`
|
||||
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesReasoning configures reasoning effort in the Responses API.
|
||||
@ -186,13 +192,18 @@ type ResponsesReasoning struct {
|
||||
Summary string `json:"summary,omitempty"` // "auto" | "concise" | "detailed"
|
||||
}
|
||||
|
||||
// ResponsesText configures text output options in the Responses API.
|
||||
type ResponsesText struct {
|
||||
Verbosity string `json:"verbosity,omitempty"` // "low" | "medium" | "high"
|
||||
}
|
||||
|
||||
// ResponsesInputItem is one item in the Responses API input array.
|
||||
// The Type field determines which other fields are populated.
|
||||
type ResponsesInputItem struct {
|
||||
// Common
|
||||
Type string `json:"type,omitempty"` // "" for role-based messages
|
||||
|
||||
// Role-based messages (system/user/assistant)
|
||||
// Role-based messages (developer/system/user/assistant)
|
||||
Role string `json:"role,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"` // string or []ResponsesContentPart
|
||||
|
||||
@ -314,7 +325,7 @@ type ResponsesOutputTokensDetails struct {
|
||||
type ResponsesStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
// response.created / response.completed / response.failed / response.incomplete
|
||||
// response.created / response.completed / response.done / response.failed / response.incomplete
|
||||
Response *ResponsesResponse `json:"response,omitempty"`
|
||||
|
||||
// response.output_item.added / response.output_item.done
|
||||
|
||||
75
backend/internal/pkg/openai_compat/upstream_capability.go
Normal file
75
backend/internal/pkg/openai_compat/upstream_capability.go
Normal file
@ -0,0 +1,75 @@
|
||||
// Package openai_compat 提供 OpenAI 协议族在不同上游间的能力差异判定工具。
|
||||
//
|
||||
// 背景:sub2api 的 OpenAI APIKey 账号通过 base_url 接入多种第三方 OpenAI 兼容上游
|
||||
// (DeepSeek、Kimi、GLM、Qwen 等)。这些上游普遍只支持 /v1/chat/completions,
|
||||
// 不存在 /v1/responses 端点。但网关历史代码无差别走 CC→Responses 转换并打到
|
||||
// /v1/responses,导致兼容上游 404。
|
||||
//
|
||||
// 本包提供基于"账号探测标记"的能力判定,配合
|
||||
// internal/service/openai_apikey_responses_probe.go 在创建/修改账号时一次性
|
||||
// 探测并落标。
|
||||
//
|
||||
// 设计取舍:
|
||||
// - 不维护静态 host 白名单——避免新增厂商时必须改代码(讨论沉淀于
|
||||
// pensieve/short-term/knowledge/upstream-capability-detection-design-tradeoffs)
|
||||
// - 标记缺失时默认 true(即"走 Responses"),保持与重构前老代码完全一致的存量
|
||||
// 账号行为("现状即证据"原则;详见
|
||||
// pensieve/short-term/maxims/preserve-existing-runtime-behavior-when-replacing-logic-in-stateful-systems)
|
||||
package openai_compat
|
||||
|
||||
// AccountResponsesSupport 描述账号上游对 OpenAI Responses API 的支持状态。
|
||||
//
|
||||
// 仅用于 platform=openai + type=apikey 的账号;其他账号类型不应调用本包判定。
|
||||
type AccountResponsesSupport int
|
||||
|
||||
const (
|
||||
// ResponsesSupportUnknown 表示账号尚未完成能力探测(extra 字段缺失)。
|
||||
// 上游路由层应按"现状即证据"原则默认走 Responses,保持与重构前一致。
|
||||
ResponsesSupportUnknown AccountResponsesSupport = iota
|
||||
|
||||
// ResponsesSupportYes 探测确认上游支持 /v1/responses。
|
||||
ResponsesSupportYes
|
||||
|
||||
// ResponsesSupportNo 探测确认上游不支持 /v1/responses,应走
|
||||
// /v1/chat/completions 直转路径。
|
||||
ResponsesSupportNo
|
||||
)
|
||||
|
||||
// ExtraKeyResponsesSupported 是 accounts.extra JSON 中存储探测结果的键名。
|
||||
// 值类型为 bool:true=支持、false=不支持、键缺失=未探测。
|
||||
const ExtraKeyResponsesSupported = "openai_responses_supported"
|
||||
|
||||
// ResolveResponsesSupport 从账号的 extra map 中读取探测标记。
|
||||
//
|
||||
// 标记缺失或类型不匹配时返回 ResponsesSupportUnknown——调用方应按
|
||||
// "未探测=保留旧行为=走 Responses" 处理(参见 ShouldUseResponsesAPI)。
|
||||
func ResolveResponsesSupport(extra map[string]any) AccountResponsesSupport {
|
||||
if extra == nil {
|
||||
return ResponsesSupportUnknown
|
||||
}
|
||||
v, ok := extra[ExtraKeyResponsesSupported]
|
||||
if !ok {
|
||||
return ResponsesSupportUnknown
|
||||
}
|
||||
supported, ok := v.(bool)
|
||||
if !ok {
|
||||
return ResponsesSupportUnknown
|
||||
}
|
||||
if supported {
|
||||
return ResponsesSupportYes
|
||||
}
|
||||
return ResponsesSupportNo
|
||||
}
|
||||
|
||||
// ShouldUseResponsesAPI 判断 OpenAI APIKey 账号的入站 /v1/chat/completions 请求
|
||||
// 是否应走"CC→Responses 转换 + 上游 /v1/responses"路径。
|
||||
//
|
||||
// 返回 true 的两种情况:
|
||||
// 1. 账号已探测确认支持 Responses
|
||||
// 2. 账号未探测(标记缺失)——按"现状即证据"原则保留旧行为
|
||||
//
|
||||
// 仅当账号已探测且确认不支持时返回 false,此时调用方应走 CC 直转路径
|
||||
// (详见 internal/service/openai_gateway_chat_completions_raw.go)。
|
||||
func ShouldUseResponsesAPI(extra map[string]any) bool {
|
||||
return ResolveResponsesSupport(extra) != ResponsesSupportNo
|
||||
}
|
||||
@ -0,0 +1,55 @@
|
||||
package openai_compat
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestResolveResponsesSupport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
extra map[string]any
|
||||
want AccountResponsesSupport
|
||||
}{
|
||||
{"nil extra", nil, ResponsesSupportUnknown},
|
||||
{"empty extra", map[string]any{}, ResponsesSupportUnknown},
|
||||
{"key missing", map[string]any{"other": "value"}, ResponsesSupportUnknown},
|
||||
{"value true", map[string]any{ExtraKeyResponsesSupported: true}, ResponsesSupportYes},
|
||||
{"value false", map[string]any{ExtraKeyResponsesSupported: false}, ResponsesSupportNo},
|
||||
{"value wrong type string", map[string]any{ExtraKeyResponsesSupported: "true"}, ResponsesSupportUnknown},
|
||||
{"value wrong type number", map[string]any{ExtraKeyResponsesSupported: 1}, ResponsesSupportUnknown},
|
||||
{"value nil", map[string]any{ExtraKeyResponsesSupported: nil}, ResponsesSupportUnknown},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ResolveResponsesSupport(tc.extra)
|
||||
if got != tc.want {
|
||||
t.Errorf("ResolveResponsesSupport(%v) = %v, want %v", tc.extra, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldUseResponsesAPI(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
extra map[string]any
|
||||
want bool
|
||||
}{
|
||||
// 关键不变量:未探测必须返回 true(保留旧行为)
|
||||
{"unknown defaults to true (preserve old behavior)", nil, true},
|
||||
{"unknown empty defaults to true", map[string]any{}, true},
|
||||
{"unknown wrong type defaults to true", map[string]any{ExtraKeyResponsesSupported: "yes"}, true},
|
||||
|
||||
// 已探测:标记决定
|
||||
{"explicitly supported", map[string]any{ExtraKeyResponsesSupported: true}, true},
|
||||
{"explicitly unsupported", map[string]any{ExtraKeyResponsesSupported: false}, false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ShouldUseResponsesAPI(tc.extra)
|
||||
if got != tc.want {
|
||||
t.Errorf("ShouldUseResponsesAPI(%v) = %v, want %v", tc.extra, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -22,6 +22,34 @@ const (
|
||||
|
||||
var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789")
|
||||
|
||||
const affiliateUserOverviewSQL = `
|
||||
SELECT ua.user_id,
|
||||
COALESCE(u.email, ''),
|
||||
COALESCE(u.username, ''),
|
||||
ua.aff_code,
|
||||
COALESCE(ua.aff_rebate_rate_percent, 0)::double precision,
|
||||
(ua.aff_rebate_rate_percent IS NOT NULL) AS has_custom_rate,
|
||||
ua.aff_count,
|
||||
COALESCE(rebated.rebated_invitee_count, 0),
|
||||
(ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0))::double precision,
|
||||
ua.aff_history_quota::double precision
|
||||
FROM user_affiliates ua
|
||||
JOIN users u ON u.id = ua.user_id
|
||||
LEFT JOIN (
|
||||
SELECT user_id, COUNT(DISTINCT source_user_id)::integer AS rebated_invitee_count
|
||||
FROM user_affiliate_ledger
|
||||
WHERE action = 'accrue' AND source_user_id IS NOT NULL
|
||||
GROUP BY user_id
|
||||
) rebated ON rebated.user_id = ua.user_id
|
||||
LEFT JOIN (
|
||||
SELECT user_id, COALESCE(SUM(amount), 0)::double precision AS matured_frozen_quota
|
||||
FROM user_affiliate_ledger
|
||||
WHERE action = 'accrue' AND frozen_until IS NOT NULL AND frozen_until <= NOW()
|
||||
GROUP BY user_id
|
||||
) matured ON matured.user_id = ua.user_id
|
||||
WHERE ua.user_id = $1
|
||||
LIMIT 1`
|
||||
|
||||
type affiliateQueryExecer interface {
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
@ -86,7 +114,7 @@ func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID
|
||||
return bound, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) {
|
||||
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error) {
|
||||
if amount <= 0 {
|
||||
return false, nil
|
||||
}
|
||||
@ -112,15 +140,15 @@ func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, invite
|
||||
|
||||
if freezeHours > 0 {
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, frozen_until, created_at, updated_at)
|
||||
VALUES ($1, 'accrue', $2, $3, NOW() + make_interval(hours => $4), NOW(), NOW())`,
|
||||
inviterID, amount, inviteeUserID, freezeHours); err != nil {
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, source_order_id, frozen_until, created_at, updated_at)
|
||||
VALUES ($1, 'accrue', $2, $3, $4, NOW() + make_interval(hours => $5), NOW(), NOW())`,
|
||||
inviterID, amount, inviteeUserID, nullableInt64Arg(sourceOrderID), freezeHours); err != nil {
|
||||
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
|
||||
}
|
||||
} else {
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
|
||||
VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil {
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, source_order_id, created_at, updated_at)
|
||||
VALUES ($1, 'accrue', $2, $3, $4, NOW(), NOW())`, inviterID, amount, inviteeUserID, nullableInt64Arg(sourceOrderID)); err != nil {
|
||||
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
|
||||
}
|
||||
}
|
||||
@ -275,9 +303,32 @@ FROM cleared`, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
snapshot, err := queryAffiliateTransferSnapshot(txCtx, txClient, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
|
||||
VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred); err != nil {
|
||||
INSERT INTO user_affiliate_ledger (
|
||||
user_id,
|
||||
action,
|
||||
amount,
|
||||
source_user_id,
|
||||
balance_after,
|
||||
aff_quota_after,
|
||||
aff_frozen_quota_after,
|
||||
aff_history_quota_after,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
VALUES ($1, 'transfer', $2, NULL, $3, $4, $5, $6, NOW(), NOW())`,
|
||||
userID,
|
||||
transferred,
|
||||
snapshot.BalanceAfter,
|
||||
snapshot.AvailableQuotaAfter,
|
||||
snapshot.FrozenQuotaAfter,
|
||||
snapshot.HistoryQuotaAfter,
|
||||
); err != nil {
|
||||
return fmt.Errorf("insert affiliate transfer ledger: %w", err)
|
||||
}
|
||||
|
||||
@ -332,6 +383,349 @@ LIMIT $2`, inviterID, limit)
|
||||
return invitees, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) ListAffiliateInviteRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateInviteRecord, int64, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
where, args := buildAffiliateRecordWhere(filter, "ua.created_at", []string{
|
||||
"inviter.email", "inviter.username", "invitee.email", "invitee.username",
|
||||
"ua.inviter_id::text", "ua.user_id::text", "inviter_aff.aff_code",
|
||||
})
|
||||
|
||||
total, err := queryAffiliateRecordCount(ctx, client, `
|
||||
SELECT COUNT(*)
|
||||
FROM user_affiliates ua
|
||||
JOIN users invitee ON invitee.id = ua.user_id
|
||||
JOIN users inviter ON inviter.id = ua.inviter_id
|
||||
JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id
|
||||
`+where, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
|
||||
"inviter": "inviter.email",
|
||||
"invitee": "invitee.email",
|
||||
"aff_code": "inviter_aff.aff_code",
|
||||
"total_rebate": "total_rebate",
|
||||
"created_at": "ua.created_at",
|
||||
}, "ua.created_at")
|
||||
args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT ua.inviter_id,
|
||||
COALESCE(inviter.email, ''),
|
||||
COALESCE(inviter.username, ''),
|
||||
ua.user_id,
|
||||
COALESCE(invitee.email, ''),
|
||||
COALESCE(invitee.username, ''),
|
||||
COALESCE(inviter_aff.aff_code, ''),
|
||||
COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate,
|
||||
ua.created_at
|
||||
FROM user_affiliates ua
|
||||
JOIN users invitee ON invitee.id = ua.user_id
|
||||
JOIN users inviter ON inviter.id = ua.inviter_id
|
||||
JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id
|
||||
LEFT JOIN user_affiliate_ledger ual
|
||||
ON ual.user_id = ua.inviter_id
|
||||
AND ual.source_user_id = ua.user_id
|
||||
AND ual.action = 'accrue'
|
||||
`+where+`
|
||||
GROUP BY ua.inviter_id, inviter.email, inviter.username, ua.user_id, invitee.email, invitee.username, inviter_aff.aff_code, ua.created_at
|
||||
`+orderBy+`
|
||||
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
items := make([]service.AffiliateInviteRecord, 0)
|
||||
for rows.Next() {
|
||||
var item service.AffiliateInviteRecord
|
||||
if err := rows.Scan(
|
||||
&item.InviterID,
|
||||
&item.InviterEmail,
|
||||
&item.InviterUsername,
|
||||
&item.InviteeID,
|
||||
&item.InviteeEmail,
|
||||
&item.InviteeUsername,
|
||||
&item.AffCode,
|
||||
&item.TotalRebate,
|
||||
&item.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return items, total, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) ListAffiliateRebateRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateRebateRecord, int64, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{
|
||||
"inviter.email", "inviter.username", "invitee.email", "invitee.username",
|
||||
"po.id::text", "po.out_trade_no", "po.payment_type", "po.status",
|
||||
})
|
||||
baseJoin := `
|
||||
FROM user_affiliate_ledger ual
|
||||
JOIN payment_orders po ON po.id = ual.source_order_id
|
||||
JOIN users invitee ON invitee.id = ual.source_user_id
|
||||
JOIN users inviter ON inviter.id = ual.user_id
|
||||
WHERE ual.action = 'accrue'
|
||||
AND ual.source_order_id IS NOT NULL`
|
||||
if where != "" {
|
||||
where = strings.Replace(where, "WHERE ", " AND ", 1)
|
||||
}
|
||||
|
||||
total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
|
||||
"order": "po.id",
|
||||
"inviter": "inviter.email",
|
||||
"invitee": "invitee.email",
|
||||
"order_amount": "po.amount",
|
||||
"pay_amount": "po.pay_amount",
|
||||
"rebate_amount": "ual.amount",
|
||||
"payment_type": "po.payment_type",
|
||||
"order_status": "po.status",
|
||||
"created_at": "ual.created_at",
|
||||
}, "ual.created_at")
|
||||
args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT po.id,
|
||||
po.out_trade_no,
|
||||
ual.user_id,
|
||||
COALESCE(inviter.email, ''),
|
||||
COALESCE(inviter.username, ''),
|
||||
ual.source_user_id,
|
||||
COALESCE(invitee.email, ''),
|
||||
COALESCE(invitee.username, ''),
|
||||
po.amount::double precision,
|
||||
po.pay_amount::double precision,
|
||||
ual.amount::double precision,
|
||||
po.payment_type,
|
||||
po.status,
|
||||
ual.created_at
|
||||
`+baseJoin+where+`
|
||||
`+orderBy+`
|
||||
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
items := make([]service.AffiliateRebateRecord, 0)
|
||||
for rows.Next() {
|
||||
var item service.AffiliateRebateRecord
|
||||
if err := rows.Scan(
|
||||
&item.OrderID,
|
||||
&item.OutTradeNo,
|
||||
&item.InviterID,
|
||||
&item.InviterEmail,
|
||||
&item.InviterUsername,
|
||||
&item.InviteeID,
|
||||
&item.InviteeEmail,
|
||||
&item.InviteeUsername,
|
||||
&item.OrderAmount,
|
||||
&item.PayAmount,
|
||||
&item.RebateAmount,
|
||||
&item.PaymentType,
|
||||
&item.OrderStatus,
|
||||
&item.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return items, total, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) ListAffiliateTransferRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateTransferRecord, int64, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{
|
||||
"u.email", "u.username", "u.id::text",
|
||||
})
|
||||
baseJoin := `
|
||||
FROM user_affiliate_ledger ual
|
||||
JOIN users u ON u.id = ual.user_id
|
||||
WHERE ual.action = 'transfer'`
|
||||
if where != "" {
|
||||
where = strings.Replace(where, "WHERE ", " AND ", 1)
|
||||
}
|
||||
|
||||
total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
|
||||
"user": "u.email",
|
||||
"amount": "ual.amount",
|
||||
"balance_after": "ual.balance_after",
|
||||
"available_quota_after": "ual.aff_quota_after",
|
||||
"frozen_quota_after": "ual.aff_frozen_quota_after",
|
||||
"history_quota_after": "ual.aff_history_quota_after",
|
||||
"created_at": "ual.created_at",
|
||||
}, "ual.created_at")
|
||||
args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT ual.id,
|
||||
ual.user_id,
|
||||
COALESCE(u.email, ''),
|
||||
COALESCE(u.username, ''),
|
||||
ual.amount::double precision,
|
||||
ual.balance_after::double precision,
|
||||
ual.aff_quota_after::double precision,
|
||||
ual.aff_frozen_quota_after::double precision,
|
||||
ual.aff_history_quota_after::double precision,
|
||||
ual.created_at
|
||||
`+baseJoin+where+`
|
||||
`+orderBy+`
|
||||
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
items := make([]service.AffiliateTransferRecord, 0)
|
||||
for rows.Next() {
|
||||
var item service.AffiliateTransferRecord
|
||||
var balanceAfter sql.NullFloat64
|
||||
var availableQuotaAfter sql.NullFloat64
|
||||
var frozenQuotaAfter sql.NullFloat64
|
||||
var historyQuotaAfter sql.NullFloat64
|
||||
if err := rows.Scan(
|
||||
&item.LedgerID,
|
||||
&item.UserID,
|
||||
&item.UserEmail,
|
||||
&item.Username,
|
||||
&item.Amount,
|
||||
&balanceAfter,
|
||||
&availableQuotaAfter,
|
||||
&frozenQuotaAfter,
|
||||
&historyQuotaAfter,
|
||||
&item.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
item.BalanceAfter = nullableFloat64Ptr(balanceAfter)
|
||||
item.AvailableQuotaAfter = nullableFloat64Ptr(availableQuotaAfter)
|
||||
item.FrozenQuotaAfter = nullableFloat64Ptr(frozenQuotaAfter)
|
||||
item.HistoryQuotaAfter = nullableFloat64Ptr(historyQuotaAfter)
|
||||
item.SnapshotAvailable = balanceAfter.Valid &&
|
||||
availableQuotaAfter.Valid &&
|
||||
frozenQuotaAfter.Valid &&
|
||||
historyQuotaAfter.Valid
|
||||
items = append(items, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return items, total, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) GetAffiliateUserOverview(ctx context.Context, userID int64) (*service.AffiliateUserOverview, error) {
|
||||
if userID <= 0 {
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
client := clientFromContext(ctx, r.client)
|
||||
rows, err := client.QueryContext(ctx, affiliateUserOverviewSQL, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
|
||||
var overview service.AffiliateUserOverview
|
||||
var customRate float64
|
||||
var hasCustomRate bool
|
||||
if err := rows.Scan(
|
||||
&overview.UserID,
|
||||
&overview.Email,
|
||||
&overview.Username,
|
||||
&overview.AffCode,
|
||||
&customRate,
|
||||
&hasCustomRate,
|
||||
&overview.InvitedCount,
|
||||
&overview.RebatedInviteeCount,
|
||||
&overview.AvailableQuota,
|
||||
&overview.HistoryQuota,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hasCustomRate {
|
||||
overview.RebateRatePercent = customRate
|
||||
overview.RebateRateCustom = true
|
||||
}
|
||||
return &overview, rows.Err()
|
||||
}
|
||||
|
||||
func buildAffiliateRecordWhere(filter service.AffiliateRecordFilter, timeColumn string, searchColumns []string) (string, []any) {
|
||||
clauses := make([]string, 0, 3)
|
||||
args := make([]any, 0, 3)
|
||||
if filter.StartAt != nil {
|
||||
args = append(args, *filter.StartAt)
|
||||
clauses = append(clauses, fmt.Sprintf("%s >= $%d", timeColumn, len(args)))
|
||||
}
|
||||
if filter.EndAt != nil {
|
||||
args = append(args, *filter.EndAt)
|
||||
clauses = append(clauses, fmt.Sprintf("%s <= $%d", timeColumn, len(args)))
|
||||
}
|
||||
search := strings.TrimSpace(filter.Search)
|
||||
if search != "" && len(searchColumns) > 0 {
|
||||
args = append(args, "%"+strings.ToLower(search)+"%")
|
||||
parts := make([]string, 0, len(searchColumns))
|
||||
for _, col := range searchColumns {
|
||||
parts = append(parts, fmt.Sprintf("LOWER(%s) LIKE $%d", col, len(args)))
|
||||
}
|
||||
clauses = append(clauses, "("+strings.Join(parts, " OR ")+")")
|
||||
}
|
||||
if len(clauses) == 0 {
|
||||
return "", args
|
||||
}
|
||||
return "WHERE " + strings.Join(clauses, " AND "), args
|
||||
}
|
||||
|
||||
func buildAffiliateRecordOrderBy(filter service.AffiliateRecordFilter, sortColumns map[string]string, fallbackColumn string) string {
|
||||
column := sortColumns[filter.SortBy]
|
||||
if column == "" {
|
||||
column = fallbackColumn
|
||||
}
|
||||
direction := "DESC"
|
||||
if !filter.SortDesc {
|
||||
direction = "ASC"
|
||||
}
|
||||
return "ORDER BY " + column + " " + direction + " NULLS LAST"
|
||||
}
|
||||
|
||||
func queryAffiliateRecordCount(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (int64, error) {
|
||||
rows, err := client.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
if !rows.Next() {
|
||||
return 0, rows.Err()
|
||||
}
|
||||
var total int64
|
||||
if err := rows.Scan(&total); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return total, rows.Err()
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
return fn(ctx, tx.Client())
|
||||
@ -516,6 +910,54 @@ func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID i
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
type affiliateTransferSnapshot struct {
|
||||
BalanceAfter float64
|
||||
AvailableQuotaAfter float64
|
||||
FrozenQuotaAfter float64
|
||||
HistoryQuotaAfter float64
|
||||
}
|
||||
|
||||
func queryAffiliateTransferSnapshot(ctx context.Context, client affiliateQueryExecer, userID int64) (*affiliateTransferSnapshot, error) {
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT u.balance::double precision,
|
||||
ua.aff_quota::double precision,
|
||||
ua.aff_frozen_quota::double precision,
|
||||
ua.aff_history_quota::double precision
|
||||
FROM users u
|
||||
JOIN user_affiliates ua ON ua.user_id = u.id
|
||||
WHERE u.id = $1
|
||||
LIMIT 1`, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query affiliate transfer snapshot: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
|
||||
var snapshot affiliateTransferSnapshot
|
||||
if err := rows.Scan(
|
||||
&snapshot.BalanceAfter,
|
||||
&snapshot.AvailableQuotaAfter,
|
||||
&snapshot.FrozenQuotaAfter,
|
||||
&snapshot.HistoryQuotaAfter,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &snapshot, rows.Err()
|
||||
}
|
||||
|
||||
func nullableFloat64Ptr(v sql.NullFloat64) *float64 {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
}
|
||||
return &v.Float64
|
||||
}
|
||||
|
||||
func generateAffiliateCode() (string, error) {
|
||||
buf := make([]byte, affiliateCodeLength)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
@ -674,6 +1116,13 @@ func nullableArg(v *float64) any {
|
||||
return *v
|
||||
}
|
||||
|
||||
func nullableInt64Arg(v *int64) any {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
// ListUsersWithCustomSettings 列出有专属配置(自定义码或专属比例)的用户。
|
||||
//
|
||||
// 单一查询同时处理"无搜索"与"按邮箱/用户名模糊搜索":
|
||||
|
||||
@ -78,6 +78,26 @@ VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34)
|
||||
ledgerCount := querySingleInt(t, txCtx, client,
|
||||
"SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID)
|
||||
require.Equal(t, 1, ledgerCount)
|
||||
|
||||
rows, err := client.QueryContext(txCtx, `
|
||||
SELECT amount::double precision,
|
||||
balance_after::double precision,
|
||||
aff_quota_after::double precision,
|
||||
aff_frozen_quota_after::double precision,
|
||||
aff_history_quota_after::double precision
|
||||
FROM user_affiliate_ledger
|
||||
WHERE user_id = $1 AND action = 'transfer'
|
||||
LIMIT 1`, u.ID)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = rows.Close() }()
|
||||
require.True(t, rows.Next(), "expected transfer ledger")
|
||||
var amount, balanceAfter, quotaAfter, frozenAfter, historyAfter float64
|
||||
require.NoError(t, rows.Scan(&amount, &balanceAfter, "aAfter, &frozenAfter, &historyAfter))
|
||||
require.InDelta(t, 12.34, amount, 1e-9)
|
||||
require.InDelta(t, 17.84, balanceAfter, 1e-9)
|
||||
require.InDelta(t, 0.0, quotaAfter, 1e-9)
|
||||
require.InDelta(t, 0.0, frozenAfter, 1e-9)
|
||||
require.InDelta(t, 12.34, historyAfter, 1e-9)
|
||||
}
|
||||
|
||||
// TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the
|
||||
@ -125,7 +145,7 @@ func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, bound, "invitee must bind to inviter")
|
||||
|
||||
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0)
|
||||
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0, nil)
|
||||
require.NoError(t, err)
|
||||
require.True(t, applied, "AccrueQuota must report applied=true")
|
||||
|
||||
|
||||
28
backend/internal/repository/affiliate_repo_test.go
Normal file
28
backend/internal/repository/affiliate_repo_test.go
Normal file
@ -0,0 +1,28 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAffiliateUserOverviewSQLIncludesMaturedFrozenQuota(t *testing.T) {
|
||||
query := strings.Join(strings.Fields(affiliateUserOverviewSQL), " ")
|
||||
|
||||
require.Contains(t, query, "ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0)")
|
||||
require.Contains(t, query, "frozen_until <= NOW()")
|
||||
}
|
||||
|
||||
func TestAffiliateRecordQueriesUseLedgerAuditFields(t *testing.T) {
|
||||
source, err := os.ReadFile("affiliate_repo.go")
|
||||
require.NoError(t, err)
|
||||
content := string(source)
|
||||
|
||||
require.Contains(t, content, "JOIN payment_orders po ON po.id = ual.source_order_id")
|
||||
require.Contains(t, content, "ual.amount::double precision")
|
||||
require.Contains(t, content, "ual.balance_after::double precision")
|
||||
require.NotContains(t, content, "parseAffiliateRebateAmount")
|
||||
require.NotContains(t, content, `"current_balance": "u.balance"`)
|
||||
}
|
||||
@ -166,6 +166,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldDailyLimitUsd,
|
||||
group.FieldWeeklyLimitUsd,
|
||||
group.FieldMonthlyLimitUsd,
|
||||
group.FieldAllowImageGeneration,
|
||||
group.FieldImageRateIndependent,
|
||||
group.FieldImageRateMultiplier,
|
||||
group.FieldImagePrice1k,
|
||||
group.FieldImagePrice2k,
|
||||
group.FieldImagePrice4k,
|
||||
@ -699,6 +702,9 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
DailyLimitUSD: g.DailyLimitUsd,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
||||
AllowImageGeneration: g.AllowImageGeneration,
|
||||
ImageRateIndependent: g.ImageRateIndependent,
|
||||
ImageRateMultiplier: g.ImageRateMultiplier,
|
||||
ImagePrice1K: g.ImagePrice1k,
|
||||
ImagePrice2K: g.ImagePrice2k,
|
||||
ImagePrice4K: g.ImagePrice4k,
|
||||
|
||||
@ -50,6 +50,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
|
||||
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
|
||||
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
|
||||
SetAllowImageGeneration(groupIn.AllowImageGeneration).
|
||||
SetImageRateIndependent(groupIn.ImageRateIndependent).
|
||||
SetImageRateMultiplier(groupIn.ImageRateMultiplier).
|
||||
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
@ -120,6 +123,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
|
||||
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
|
||||
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
|
||||
SetAllowImageGeneration(groupIn.AllowImageGeneration).
|
||||
SetImageRateIndependent(groupIn.ImageRateIndependent).
|
||||
SetImageRateMultiplier(groupIn.ImageRateMultiplier).
|
||||
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
|
||||
@ -328,6 +328,9 @@ func TestAPIContracts(t *testing.T) {
|
||||
"image_price_1k": null,
|
||||
"image_price_2k": null,
|
||||
"image_price_4k": null,
|
||||
"allow_image_generation": false,
|
||||
"image_rate_independent": false,
|
||||
"image_rate_multiplier": 0,
|
||||
"claude_code_only": false,
|
||||
"allow_messages_dispatch": false,
|
||||
"fallback_group_id": null,
|
||||
|
||||
@ -412,6 +412,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
// 529过载冷却配置
|
||||
adminSettings.GET("/overload-cooldown", h.Admin.Setting.GetOverloadCooldownSettings)
|
||||
adminSettings.PUT("/overload-cooldown", h.Admin.Setting.UpdateOverloadCooldownSettings)
|
||||
// 429默认回避配置
|
||||
adminSettings.GET("/rate-limit-429-cooldown", h.Admin.Setting.GetRateLimit429CooldownSettings)
|
||||
adminSettings.PUT("/rate-limit-429-cooldown", h.Admin.Setting.UpdateRateLimit429CooldownSettings)
|
||||
// 流超时处理配置
|
||||
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
|
||||
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
|
||||
@ -624,11 +627,16 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
affiliates := admin.Group("/affiliates")
|
||||
{
|
||||
affiliates.GET("/invites", h.Admin.Affiliate.ListInviteRecords)
|
||||
affiliates.GET("/rebates", h.Admin.Affiliate.ListRebateRecords)
|
||||
affiliates.GET("/transfers", h.Admin.Affiliate.ListTransferRecords)
|
||||
|
||||
users := affiliates.Group("/users")
|
||||
{
|
||||
users.GET("", h.Admin.Affiliate.ListUsers)
|
||||
users.GET("/lookup", h.Admin.Affiliate.LookupUsers)
|
||||
users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate)
|
||||
users.GET("/:user_id/overview", h.Admin.Affiliate.GetUserOverview)
|
||||
users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings)
|
||||
users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings)
|
||||
}
|
||||
|
||||
@ -230,7 +230,11 @@ func applyAccountStatsCost(
|
||||
if model == "" {
|
||||
model = requestedModel
|
||||
}
|
||||
requestCount := 1
|
||||
if usageLog != nil && usageLog.ImageCount > 0 {
|
||||
requestCount = usageLog.ImageCount
|
||||
}
|
||||
usageLog.AccountStatsCost = resolveAccountStatsCost(
|
||||
ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost,
|
||||
ctx, cs, bs, accountID, groupID, model, tokens, requestCount, totalCost,
|
||||
)
|
||||
}
|
||||
|
||||
@ -21,6 +21,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/windsurf"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
@ -571,7 +572,16 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||
}
|
||||
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses"
|
||||
// 账号已被探测为不支持 Responses(如 DeepSeek/Kimi 等)时,丢出明确提示。
|
||||
// 账号本身可用(网关会走 CC 直转),仅测试入口需要补齐 CC SSE 处理逻辑。
|
||||
// TODO:实现 CC 格式的账号测试路径(需专门的 CC SSE handler)。
|
||||
if !openai_compat.ShouldUseResponsesAPI(account.Extra) {
|
||||
return s.sendErrorAndEnd(c,
|
||||
"账号已被探测为不支持 OpenAI Responses API(如 DeepSeek/Kimi 等三方兼容上游),"+
|
||||
"账号本身可正常使用,但当前测试接口仅支持 Responses API 路径。请直接通过实际 API 调用验证。",
|
||||
)
|
||||
}
|
||||
apiURL = buildOpenAIResponsesURL(normalizedBaseURL)
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
|
||||
86
backend/internal/service/admin_balance_history_test.go
Normal file
86
backend/internal/service/admin_balance_history_test.go
Normal file
@ -0,0 +1,86 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMergeBalanceHistoryCodesIncludesAffiliateTransfersByDefault(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC)
|
||||
older := now.Add(-2 * time.Hour)
|
||||
newer := now.Add(time.Hour)
|
||||
|
||||
usedBy := int64(10)
|
||||
redeemCodes := []RedeemCode{
|
||||
{
|
||||
ID: 1,
|
||||
Type: RedeemTypeBalance,
|
||||
Value: 8,
|
||||
Status: StatusUsed,
|
||||
UsedBy: &usedBy,
|
||||
UsedAt: &now,
|
||||
CreatedAt: now,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Type: RedeemTypeConcurrency,
|
||||
Value: 1,
|
||||
Status: StatusUsed,
|
||||
UsedBy: &usedBy,
|
||||
UsedAt: &older,
|
||||
CreatedAt: older,
|
||||
},
|
||||
}
|
||||
affiliateCodes := []RedeemCode{
|
||||
{
|
||||
ID: -20,
|
||||
Type: RedeemTypeAffiliateBalance,
|
||||
Value: 3.5,
|
||||
Status: StatusUsed,
|
||||
UsedBy: &usedBy,
|
||||
UsedAt: &newer,
|
||||
CreatedAt: newer,
|
||||
},
|
||||
}
|
||||
|
||||
got := mergeBalanceHistoryCodes(redeemCodes, affiliateCodes, pagination.PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 2,
|
||||
})
|
||||
|
||||
require.Len(t, got, 2)
|
||||
require.Equal(t, RedeemTypeAffiliateBalance, got[0].Type)
|
||||
require.Equal(t, RedeemTypeBalance, got[1].Type)
|
||||
}
|
||||
|
||||
func TestMergeBalanceHistoryCodesPaginatesAfterCombiningSources(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC)
|
||||
usedBy := int64(10)
|
||||
at := func(hours int) *time.Time {
|
||||
v := base.Add(time.Duration(hours) * time.Hour)
|
||||
return &v
|
||||
}
|
||||
|
||||
got := mergeBalanceHistoryCodes(
|
||||
[]RedeemCode{
|
||||
{ID: 1, Type: RedeemTypeBalance, UsedBy: &usedBy, UsedAt: at(4), CreatedAt: *at(4)},
|
||||
{ID: 2, Type: RedeemTypeConcurrency, UsedBy: &usedBy, UsedAt: at(2), CreatedAt: *at(2)},
|
||||
},
|
||||
[]RedeemCode{
|
||||
{ID: -3, Type: RedeemTypeAffiliateBalance, UsedBy: &usedBy, UsedAt: at(3), CreatedAt: *at(3)},
|
||||
{ID: -4, Type: RedeemTypeAffiliateBalance, UsedBy: &usedBy, UsedAt: at(1), CreatedAt: *at(1)},
|
||||
},
|
||||
pagination.PaginationParams{Page: 2, PageSize: 2},
|
||||
)
|
||||
|
||||
require.Len(t, got, 2)
|
||||
require.Equal(t, RedeemTypeConcurrency, got[0].Type)
|
||||
require.Equal(t, int64(-4), got[1].ID)
|
||||
}
|
||||
@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -188,11 +189,14 @@ type CreateGroupInput struct {
|
||||
WeeklyLimitUSD *float64 // 周限额 (USD)
|
||||
MonthlyLimitUSD *float64 // 月限额 (USD)
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
AllowImageGeneration bool
|
||||
ImageRateIndependent bool
|
||||
ImageRateMultiplier *float64
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||
FallbackGroupIDOnInvalidRequest *int64
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
@ -225,11 +229,14 @@ type UpdateGroupInput struct {
|
||||
WeeklyLimitUSD *float64 // 周限额 (USD)
|
||||
MonthlyLimitUSD *float64 // 月限额 (USD)
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
AllowImageGeneration *bool
|
||||
ImageRateIndependent *bool
|
||||
ImageRateMultiplier *float64
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||
FallbackGroupIDOnInvalidRequest *int64
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
@ -973,16 +980,213 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
|
||||
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
|
||||
func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
if codeType == RedeemTypeAffiliateBalance {
|
||||
codes, total, err := s.listAffiliateBalanceHistory(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
return codes, total, totalRecharged, nil
|
||||
}
|
||||
|
||||
if codeType == "" {
|
||||
return s.getAllUserBalanceHistory(ctx, userID, params)
|
||||
}
|
||||
|
||||
codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
total := result.Total
|
||||
// Aggregate total recharged amount (only once, regardless of type filter)
|
||||
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
return codes, result.Total, totalRecharged, nil
|
||||
return codes, total, totalRecharged, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) getAllUserBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, float64, error) {
|
||||
needed := params.Offset() + params.Limit()
|
||||
if needed < params.Limit() {
|
||||
needed = params.Limit()
|
||||
}
|
||||
|
||||
redeemCodes, redeemTotal, err := s.listRedeemBalanceHistoryForMerge(ctx, userID, needed)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
affiliateCodes, affiliateTotal, err := s.listAffiliateBalanceHistoryForMerge(ctx, userID, needed)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
codes := mergeBalanceHistoryCodes(redeemCodes, affiliateCodes, params)
|
||||
|
||||
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
return codes, redeemTotal + affiliateTotal, totalRecharged, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) listRedeemBalanceHistoryForMerge(ctx context.Context, userID int64, needed int) ([]RedeemCode, int64, error) {
|
||||
if needed <= 0 {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
var (
|
||||
out []RedeemCode
|
||||
total int64
|
||||
)
|
||||
for page := 1; len(out) < needed; page++ {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: 1000}
|
||||
codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, "")
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if result != nil {
|
||||
total = result.Total
|
||||
}
|
||||
out = append(out, codes...)
|
||||
if len(codes) < params.Limit() || int64(len(out)) >= total {
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(out) > needed {
|
||||
out = out[:needed]
|
||||
}
|
||||
return out, total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) listAffiliateBalanceHistoryForMerge(ctx context.Context, userID int64, needed int) ([]RedeemCode, int64, error) {
|
||||
if needed <= 0 {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
var (
|
||||
out []RedeemCode
|
||||
total int64
|
||||
)
|
||||
for page := 1; len(out) < needed; page++ {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: 1000}
|
||||
codes, currentTotal, err := s.listAffiliateBalanceHistory(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = currentTotal
|
||||
out = append(out, codes...)
|
||||
if len(codes) < params.Limit() || int64(len(out)) >= total {
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(out) > needed {
|
||||
out = out[:needed]
|
||||
}
|
||||
return out, total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) listAffiliateBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, error) {
|
||||
if s == nil || s.entClient == nil || userID <= 0 {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
rows, err := s.entClient.QueryContext(ctx, `
|
||||
SELECT id,
|
||||
amount::double precision,
|
||||
created_at
|
||||
FROM user_affiliate_ledger
|
||||
WHERE user_id = $1
|
||||
AND action = 'transfer'
|
||||
ORDER BY created_at DESC, id DESC
|
||||
OFFSET $2
|
||||
LIMIT $3`, userID, params.Offset(), params.Limit())
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
codes := make([]RedeemCode, 0, params.Limit())
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var amount float64
|
||||
var createdAt time.Time
|
||||
if err := rows.Scan(&id, &amount, &createdAt); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
usedBy := userID
|
||||
usedAt := createdAt
|
||||
codes = append(codes, RedeemCode{
|
||||
ID: -id,
|
||||
Code: fmt.Sprintf("AFF-%d", id),
|
||||
Type: RedeemTypeAffiliateBalance,
|
||||
Value: amount,
|
||||
Status: StatusUsed,
|
||||
UsedBy: &usedBy,
|
||||
UsedAt: &usedAt,
|
||||
CreatedAt: createdAt,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
total, err := countAffiliateBalanceHistory(ctx, s.entClient, userID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return codes, total, nil
|
||||
}
|
||||
|
||||
func countAffiliateBalanceHistory(ctx context.Context, client *dbent.Client, userID int64) (int64, error) {
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM user_affiliate_ledger
|
||||
WHERE user_id = $1
|
||||
AND action = 'transfer'`, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var total sql.NullInt64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&total); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !total.Valid {
|
||||
return 0, nil
|
||||
}
|
||||
return total.Int64, nil
|
||||
}
|
||||
|
||||
func mergeBalanceHistoryCodes(redeemCodes, affiliateCodes []RedeemCode, params pagination.PaginationParams) []RedeemCode {
|
||||
combined := append(append([]RedeemCode{}, redeemCodes...), affiliateCodes...)
|
||||
sort.SliceStable(combined, func(i, j int) bool {
|
||||
return redeemCodeHistoryTime(combined[i]).After(redeemCodeHistoryTime(combined[j]))
|
||||
})
|
||||
offset := params.Offset()
|
||||
if offset >= len(combined) {
|
||||
return []RedeemCode{}
|
||||
}
|
||||
end := offset + params.Limit()
|
||||
if end > len(combined) {
|
||||
end = len(combined)
|
||||
}
|
||||
return combined[offset:end]
|
||||
}
|
||||
|
||||
func redeemCodeHistoryTime(code RedeemCode) time.Time {
|
||||
if code.UsedAt != nil {
|
||||
return *code.UsedAt
|
||||
}
|
||||
return code.CreatedAt
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) {
|
||||
@ -1359,6 +1563,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
imagePrice1K := normalizePrice(input.ImagePrice1K)
|
||||
imagePrice2K := normalizePrice(input.ImagePrice2K)
|
||||
imagePrice4K := normalizePrice(input.ImagePrice4K)
|
||||
imageRateMultiplier := 1.0
|
||||
if input.ImageRateMultiplier != nil {
|
||||
if *input.ImageRateMultiplier < 0 {
|
||||
return nil, errors.New("image_rate_multiplier must be >= 0")
|
||||
}
|
||||
imageRateMultiplier = *input.ImageRateMultiplier
|
||||
}
|
||||
|
||||
// 校验降级分组
|
||||
if input.FallbackGroupID != nil {
|
||||
@ -1426,6 +1637,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
DailyLimitUSD: dailyLimit,
|
||||
WeeklyLimitUSD: weeklyLimit,
|
||||
MonthlyLimitUSD: monthlyLimit,
|
||||
AllowImageGeneration: input.AllowImageGeneration,
|
||||
ImageRateIndependent: input.ImageRateIndependent,
|
||||
ImageRateMultiplier: imageRateMultiplier,
|
||||
ImagePrice1K: imagePrice1K,
|
||||
ImagePrice2K: imagePrice2K,
|
||||
ImagePrice4K: imagePrice4K,
|
||||
@ -1602,6 +1816,18 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
|
||||
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
|
||||
// 图片生成计费配置:负数表示清除(使用默认价格)
|
||||
if input.AllowImageGeneration != nil {
|
||||
group.AllowImageGeneration = *input.AllowImageGeneration
|
||||
}
|
||||
if input.ImageRateIndependent != nil {
|
||||
group.ImageRateIndependent = *input.ImageRateIndependent
|
||||
}
|
||||
if input.ImageRateMultiplier != nil {
|
||||
if *input.ImageRateMultiplier < 0 {
|
||||
return nil, errors.New("image_rate_multiplier must be >= 0")
|
||||
}
|
||||
group.ImageRateMultiplier = *input.ImageRateMultiplier
|
||||
}
|
||||
if input.ImagePrice1K != nil {
|
||||
group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
|
||||
}
|
||||
|
||||
@ -266,6 +266,50 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
|
||||
require.Nil(t, repo.updated.ImagePrice4K)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_PreservesImageGenerationControlsWhenOmitted(t *testing.T) {
|
||||
imageMultiplier := 0.5
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
Name: "existing-group",
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
AllowImageGeneration: true,
|
||||
ImageRateIndependent: true,
|
||||
ImageRateMultiplier: imageMultiplier,
|
||||
}
|
||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
|
||||
Description: "updated",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.updated)
|
||||
require.True(t, repo.updated.AllowImageGeneration)
|
||||
require.True(t, repo.updated.ImageRateIndependent)
|
||||
require.InDelta(t, 0.5, repo.updated.ImageRateMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_RejectsNegativeImageRateMultiplier(t *testing.T) {
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
Name: "existing-group",
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
ImageRateMultiplier: 1,
|
||||
}
|
||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
negative := -0.1
|
||||
|
||||
_, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
|
||||
ImageRateMultiplier: &negative,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Nil(t, repo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
|
||||
@ -98,7 +98,7 @@ type AffiliateRepository interface {
|
||||
EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
|
||||
GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
|
||||
BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
|
||||
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error)
|
||||
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error)
|
||||
GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
|
||||
ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
|
||||
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
|
||||
@ -110,6 +110,10 @@ type AffiliateRepository interface {
|
||||
SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error
|
||||
BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error
|
||||
ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error)
|
||||
ListAffiliateInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error)
|
||||
ListAffiliateRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error)
|
||||
ListAffiliateTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error)
|
||||
GetAffiliateUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error)
|
||||
}
|
||||
|
||||
// AffiliateAdminFilter 列表筛选条件
|
||||
@ -130,6 +134,76 @@ type AffiliateAdminEntry struct {
|
||||
AffCount int `json:"aff_count"`
|
||||
}
|
||||
|
||||
type AffiliateRecordFilter struct {
|
||||
Search string
|
||||
Page int
|
||||
PageSize int
|
||||
StartAt *time.Time
|
||||
EndAt *time.Time
|
||||
SortBy string
|
||||
SortDesc bool
|
||||
}
|
||||
|
||||
type AffiliateInviteRecord struct {
|
||||
InviterID int64 `json:"inviter_id"`
|
||||
InviterEmail string `json:"inviter_email"`
|
||||
InviterUsername string `json:"inviter_username"`
|
||||
InviteeID int64 `json:"invitee_id"`
|
||||
InviteeEmail string `json:"invitee_email"`
|
||||
InviteeUsername string `json:"invitee_username"`
|
||||
AffCode string `json:"aff_code"`
|
||||
TotalRebate float64 `json:"total_rebate"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type AffiliateRebateRecord struct {
|
||||
OrderID int64 `json:"order_id"`
|
||||
OutTradeNo string `json:"out_trade_no"`
|
||||
InviterID int64 `json:"inviter_id"`
|
||||
InviterEmail string `json:"inviter_email"`
|
||||
InviterUsername string `json:"inviter_username"`
|
||||
InviteeID int64 `json:"invitee_id"`
|
||||
InviteeEmail string `json:"invitee_email"`
|
||||
InviteeUsername string `json:"invitee_username"`
|
||||
OrderAmount float64 `json:"order_amount"`
|
||||
PayAmount float64 `json:"pay_amount"`
|
||||
RebateAmount float64 `json:"rebate_amount"`
|
||||
PaymentType string `json:"payment_type"`
|
||||
OrderStatus string `json:"order_status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type AffiliateTransferRecord struct {
|
||||
LedgerID int64 `json:"ledger_id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
UserEmail string `json:"user_email"`
|
||||
Username string `json:"username"`
|
||||
Amount float64 `json:"amount"`
|
||||
BalanceAfter *float64 `json:"balance_after,omitempty"`
|
||||
AvailableQuotaAfter *float64 `json:"available_quota_after,omitempty"`
|
||||
FrozenQuotaAfter *float64 `json:"frozen_quota_after,omitempty"`
|
||||
HistoryQuotaAfter *float64 `json:"history_quota_after,omitempty"`
|
||||
SnapshotAvailable bool `json:"snapshot_available"`
|
||||
CurrentBalance float64 `json:"-"`
|
||||
RemainingQuota float64 `json:"-"`
|
||||
FrozenQuota float64 `json:"-"`
|
||||
HistoryQuota float64 `json:"-"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type AffiliateUserOverview struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
AffCode string `json:"aff_code"`
|
||||
RebateRatePercent float64 `json:"rebate_rate_percent"`
|
||||
RebateRateCustom bool `json:"-"`
|
||||
InvitedCount int `json:"invited_count"`
|
||||
RebatedInviteeCount int `json:"rebated_invitee_count"`
|
||||
AvailableQuota float64 `json:"available_quota"`
|
||||
HistoryQuota float64 `json:"history_quota"`
|
||||
}
|
||||
|
||||
type AffiliateService struct {
|
||||
repo AffiliateRepository
|
||||
settingService *SettingService
|
||||
@ -238,6 +312,10 @@ func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64,
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) {
|
||||
return s.AccrueInviteRebateForOrder(ctx, inviteeUserID, baseRechargeAmount, nil)
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AccrueInviteRebateForOrder(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64, sourceOrderID *int64) (float64, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return 0, nil
|
||||
}
|
||||
@ -298,7 +376,7 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID
|
||||
freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx)
|
||||
}
|
||||
|
||||
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours)
|
||||
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours, sourceOrderID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -488,3 +566,59 @@ func (s *AffiliateService) AdminListCustomUsers(ctx context.Context, filter Affi
|
||||
}
|
||||
return s.repo.ListUsersWithCustomSettings(ctx, filter)
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AdminListInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
return s.repo.ListAffiliateInviteRecords(ctx, normalizeAffiliateRecordFilter(filter))
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AdminListRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
return s.repo.ListAffiliateRebateRecords(ctx, normalizeAffiliateRecordFilter(filter))
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AdminListTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
return s.repo.ListAffiliateTransferRecords(ctx, normalizeAffiliateRecordFilter(filter))
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AdminGetUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error) {
|
||||
if userID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_USER", "invalid user")
|
||||
}
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
overview, err := s.repo.GetAffiliateUserOverview(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if overview != nil {
|
||||
if !overview.RebateRateCustom {
|
||||
overview.RebateRatePercent = s.globalRebateRatePercent(ctx)
|
||||
}
|
||||
overview.RebateRatePercent = clampAffiliateRebateRate(overview.RebateRatePercent)
|
||||
}
|
||||
return overview, nil
|
||||
}
|
||||
|
||||
func normalizeAffiliateRecordFilter(filter AffiliateRecordFilter) AffiliateRecordFilter {
|
||||
if filter.Page <= 0 {
|
||||
filter.Page = 1
|
||||
}
|
||||
if filter.PageSize <= 0 {
|
||||
filter.PageSize = 20
|
||||
}
|
||||
if filter.PageSize > 100 {
|
||||
filter.PageSize = 100
|
||||
}
|
||||
filter.Search = strings.TrimSpace(filter.Search)
|
||||
filter.SortBy = strings.TrimSpace(filter.SortBy)
|
||||
return filter
|
||||
}
|
||||
|
||||
@ -63,6 +63,9 @@ type APIKeyAuthGroupSnapshot struct {
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
||||
AllowImageGeneration bool `json:"allow_image_generation"`
|
||||
ImageRateIndependent bool `json:"image_rate_independent"`
|
||||
ImageRateMultiplier float64 `json:"image_rate_multiplier"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||
|
||||
@ -14,7 +14,7 @@ import (
|
||||
"github.com/dgraph-io/ristretto"
|
||||
)
|
||||
|
||||
const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot
|
||||
const apiKeyAuthSnapshotVersion = 8 // v8: added group image generation controls
|
||||
|
||||
type apiKeyAuthCacheConfig struct {
|
||||
l1Size int
|
||||
@ -255,6 +255,9 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey)
|
||||
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||
AllowImageGeneration: apiKey.Group.AllowImageGeneration,
|
||||
ImageRateIndependent: apiKey.Group.ImageRateIndependent,
|
||||
ImageRateMultiplier: apiKey.Group.ImageRateMultiplier,
|
||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||
@ -321,6 +324,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||
AllowImageGeneration: snapshot.Group.AllowImageGeneration,
|
||||
ImageRateIndependent: snapshot.Group.ImageRateIndependent,
|
||||
ImageRateMultiplier: snapshot.Group.ImageRateMultiplier,
|
||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||
|
||||
@ -226,6 +226,12 @@ func (s *BillingService) initFallbackPricing() {
|
||||
CacheReadPricePerToken: 7.5e-8,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{
|
||||
InputPricePerToken: 2e-7,
|
||||
OutputPricePerToken: 1.25e-6,
|
||||
CacheReadPricePerToken: 2e-8,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
// OpenAI GPT-5.2(本地兜底)
|
||||
s.fallbackPrices["gpt-5.2"] = &ModelPricing{
|
||||
InputPricePerToken: 1.75e-6,
|
||||
@ -288,13 +294,14 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||||
}
|
||||
|
||||
// OpenAI 仅匹配已知 GPT-5/Codex 族,避免未知 OpenAI 型号误计价。
|
||||
if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") {
|
||||
normalized := normalizeCodexModel(modelLower)
|
||||
if normalized := normalizeKnownOpenAICodexModel(modelLower); normalized != "" {
|
||||
switch normalized {
|
||||
case "gpt-5.5":
|
||||
return s.fallbackPrices["gpt-5.5"]
|
||||
case "gpt-5.4-mini":
|
||||
return s.fallbackPrices["gpt-5.4-mini"]
|
||||
case "gpt-5.4-nano":
|
||||
return s.fallbackPrices["gpt-5.4-nano"]
|
||||
case "gpt-5.4":
|
||||
return s.fallbackPrices["gpt-5.4"]
|
||||
case "gpt-5.2":
|
||||
@ -636,13 +643,10 @@ func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens
|
||||
}
|
||||
|
||||
func isOpenAIGPT54Model(model string) bool {
|
||||
trimmed := strings.TrimSpace(strings.ToLower(model))
|
||||
// 仅当模型字符串实际属于 GPT-5/Codex 族时才做归一判定,避免 normalizeCodexModel
|
||||
// 的默认兜底把非 OpenAI 模型(claude-*、gemini-*、gpt-4o)误识别为 gpt-5.4。
|
||||
if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
|
||||
return false
|
||||
}
|
||||
normalized := normalizeCodexModel(trimmed)
|
||||
// 仅当模型字符串实际属于已知 GPT-5/Codex 族时才做归一判定,避免
|
||||
// normalizeCodexModel 的默认兜底把非 OpenAI 模型(claude-*、gemini-*、gpt-4o)
|
||||
// 误识别为 gpt-5.4。
|
||||
normalized := normalizeKnownOpenAICodexModel(model)
|
||||
return normalized == "gpt-5.4" || normalized == "gpt-5.5"
|
||||
}
|
||||
|
||||
|
||||
@ -137,6 +137,35 @@ func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) {
|
||||
require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_OpenAICompactAliasesFallback(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tests := []struct {
|
||||
model string
|
||||
inputPrice float64
|
||||
outputPrice float64
|
||||
cacheRead float64
|
||||
longContext int
|
||||
}{
|
||||
{model: "gpt5.5", inputPrice: 2.5e-6, outputPrice: 15e-6, cacheRead: 0.25e-6, longContext: 272000},
|
||||
{model: "openai/gpt5.4", inputPrice: 2.5e-6, outputPrice: 15e-6, cacheRead: 0.25e-6, longContext: 272000},
|
||||
{model: "gpt5.4-mini", inputPrice: 7.5e-7, outputPrice: 4.5e-6, cacheRead: 7.5e-8, longContext: 0},
|
||||
{model: "gpt5.3codexspark", inputPrice: 1.5e-6, outputPrice: 12e-6, cacheRead: 0.15e-6, longContext: 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.model, func(t *testing.T) {
|
||||
pricing, err := svc.GetModelPricing(tt.model)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, tt.inputPrice, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, tt.outputPrice, pricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, tt.cacheRead, pricing.CacheReadPricePerToken, 1e-12)
|
||||
require.Equal(t, tt.longContext, pricing.LongContextInputThreshold)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
|
||||
@ -52,10 +52,11 @@ const (
|
||||
|
||||
// Redeem type constants
|
||||
const (
|
||||
RedeemTypeBalance = domain.RedeemTypeBalance
|
||||
RedeemTypeConcurrency = domain.RedeemTypeConcurrency
|
||||
RedeemTypeSubscription = domain.RedeemTypeSubscription
|
||||
RedeemTypeInvitation = domain.RedeemTypeInvitation
|
||||
RedeemTypeBalance = domain.RedeemTypeBalance
|
||||
RedeemTypeConcurrency = domain.RedeemTypeConcurrency
|
||||
RedeemTypeSubscription = domain.RedeemTypeSubscription
|
||||
RedeemTypeInvitation = domain.RedeemTypeInvitation
|
||||
RedeemTypeAffiliateBalance = "affiliate_balance"
|
||||
)
|
||||
|
||||
// PromoCode status constants
|
||||
@ -287,6 +288,9 @@ const (
|
||||
// SettingKeyOverloadCooldownSettings stores JSON config for 529 overload cooldown handling.
|
||||
SettingKeyOverloadCooldownSettings = "overload_cooldown_settings"
|
||||
|
||||
// SettingKeyRateLimit429CooldownSettings stores JSON config for 429 fallback cooldown handling.
|
||||
SettingKeyRateLimit429CooldownSettings = "rate_limit_429_cooldown_settings"
|
||||
|
||||
// =========================
|
||||
// Stream Timeout Handling
|
||||
// =========================
|
||||
|
||||
@ -8297,9 +8297,16 @@ func detachedBillingContext(ctx context.Context) (context.Context, context.Cance
|
||||
}
|
||||
|
||||
func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
return context.Background(), func() {}
|
||||
}
|
||||
if !stream {
|
||||
return ctx, func() {}
|
||||
}
|
||||
return context.WithoutCancel(ctx), func() {}
|
||||
}
|
||||
|
||||
func detachUpstreamContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
return context.Background(), func() {}
|
||||
}
|
||||
@ -8483,6 +8490,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
groupDefault := apiKey.Group.RateMultiplier
|
||||
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
|
||||
}
|
||||
imageMultiplier := resolveImageRateMultiplier(apiKey, multiplier)
|
||||
|
||||
// 确定计费模型
|
||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||
@ -8500,7 +8508,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
}
|
||||
|
||||
// 计算费用
|
||||
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts)
|
||||
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, imageMultiplier, opts)
|
||||
|
||||
// 判断计费方式:订阅模式 vs 余额模式
|
||||
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
@ -8512,7 +8520,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
// 创建使用日志
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
|
||||
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
|
||||
requestedModel, multiplier, imageMultiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
|
||||
|
||||
// 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
|
||||
if apiKey.GroupID != nil {
|
||||
@ -8566,11 +8574,12 @@ func (s *GatewayService) calculateRecordUsageCost(
|
||||
apiKey *APIKey,
|
||||
billingModel string,
|
||||
multiplier float64,
|
||||
imageMultiplier float64,
|
||||
opts *recordUsageOpts,
|
||||
) *CostBreakdown {
|
||||
// 图片生成计费
|
||||
if result.ImageCount > 0 {
|
||||
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
|
||||
return s.calculateImageCost(ctx, result, apiKey, billingModel, imageMultiplier)
|
||||
}
|
||||
|
||||
// Token 计费
|
||||
@ -8611,7 +8620,8 @@ func (s *GatewayService) calculateImageCost(
|
||||
Model: billingModel,
|
||||
GroupID: &gid,
|
||||
Tokens: tokens,
|
||||
RequestCount: 1,
|
||||
RequestCount: result.ImageCount,
|
||||
SizeTier: result.ImageSize,
|
||||
RateMultiplier: multiplier,
|
||||
Resolver: s.resolver,
|
||||
Resolved: resolved,
|
||||
@ -8696,6 +8706,7 @@ func (s *GatewayService) buildRecordUsageLog(
|
||||
subscription *UserSubscription,
|
||||
requestedModel string,
|
||||
multiplier float64,
|
||||
imageMultiplier float64,
|
||||
accountRateMultiplier float64,
|
||||
billingType int8,
|
||||
cacheTTLOverridden bool,
|
||||
@ -8740,6 +8751,9 @@ func (s *GatewayService) buildRecordUsageLog(
|
||||
SubscriptionID: optionalSubscriptionID(subscription),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if result.ImageCount > 0 {
|
||||
usageLog.RateMultiplier = imageMultiplier
|
||||
}
|
||||
if cost != nil {
|
||||
usageLog.InputCost = cost.InputCost
|
||||
usageLog.OutputCost = cost.OutputCost
|
||||
|
||||
@ -13,6 +13,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type upstreamContextTestKey string
|
||||
|
||||
func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
@ -50,3 +52,14 @@ func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testi
|
||||
require.Equal(t, 3, result.usage.InputTokens)
|
||||
require.Equal(t, 7, result.usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestDetachUpstreamContextIgnoresClientCancel(t *testing.T) {
|
||||
parent, cancel := context.WithCancel(context.WithValue(context.Background(), upstreamContextTestKey("test-key"), "test-value"))
|
||||
upstreamCtx, release := detachUpstreamContext(parent)
|
||||
defer release()
|
||||
|
||||
cancel()
|
||||
|
||||
require.NoError(t, upstreamCtx.Err())
|
||||
require.Equal(t, "test-value", upstreamCtx.Value(upstreamContextTestKey("test-key")))
|
||||
}
|
||||
|
||||
@ -26,9 +26,12 @@ type Group struct {
|
||||
DefaultValidityDays int
|
||||
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用)
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
AllowImageGeneration bool
|
||||
ImageRateIndependent bool
|
||||
ImageRateMultiplier float64
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool
|
||||
|
||||
@ -45,19 +45,25 @@ type GroupSortOrderUpdate struct {
|
||||
|
||||
// CreateGroupRequest 创建分组请求
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
AllowImageGeneration bool `json:"allow_image_generation"`
|
||||
ImageRateIndependent bool `json:"image_rate_independent"`
|
||||
ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
|
||||
}
|
||||
|
||||
// UpdateGroupRequest 更新分组请求
|
||||
type UpdateGroupRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status *string `json:"status"`
|
||||
Name *string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status *string `json:"status"`
|
||||
AllowImageGeneration *bool `json:"allow_image_generation"`
|
||||
ImageRateIndependent *bool `json:"image_rate_independent"`
|
||||
ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
|
||||
}
|
||||
|
||||
// GroupService 分组管理服务
|
||||
@ -76,6 +82,13 @@ func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthC
|
||||
|
||||
// Create 创建分组
|
||||
func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Group, error) {
|
||||
imageRateMultiplier := 1.0
|
||||
if req.ImageRateMultiplier != nil {
|
||||
if *req.ImageRateMultiplier < 0 {
|
||||
return nil, fmt.Errorf("image_rate_multiplier must be >= 0")
|
||||
}
|
||||
imageRateMultiplier = *req.ImageRateMultiplier
|
||||
}
|
||||
// 检查名称是否已存在
|
||||
exists, err := s.groupRepo.ExistsByName(ctx, req.Name)
|
||||
if err != nil {
|
||||
@ -87,13 +100,16 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Gro
|
||||
|
||||
// 创建分组
|
||||
group := &Group{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: PlatformAnthropic,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: PlatformAnthropic,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
AllowImageGeneration: req.AllowImageGeneration,
|
||||
ImageRateIndependent: req.ImageRateIndependent,
|
||||
ImageRateMultiplier: imageRateMultiplier,
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
@ -165,6 +181,18 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ
|
||||
if req.Status != nil {
|
||||
group.Status = *req.Status
|
||||
}
|
||||
if req.AllowImageGeneration != nil {
|
||||
group.AllowImageGeneration = *req.AllowImageGeneration
|
||||
}
|
||||
if req.ImageRateIndependent != nil {
|
||||
group.ImageRateIndependent = *req.ImageRateIndependent
|
||||
}
|
||||
if req.ImageRateMultiplier != nil {
|
||||
if *req.ImageRateMultiplier < 0 {
|
||||
return nil, fmt.Errorf("image_rate_multiplier must be >= 0")
|
||||
}
|
||||
group.ImageRateMultiplier = *req.ImageRateMultiplier
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, fmt.Errorf("update group: %w", err)
|
||||
|
||||
11
backend/internal/service/image_billing_multiplier.go
Normal file
11
backend/internal/service/image_billing_multiplier.go
Normal file
@ -0,0 +1,11 @@
|
||||
package service
|
||||
|
||||
func resolveImageRateMultiplier(apiKey *APIKey, effectiveGroupMultiplier float64) float64 {
|
||||
if apiKey != nil && apiKey.Group != nil && apiKey.Group.ImageRateIndependent {
|
||||
if apiKey.Group.ImageRateMultiplier < 0 {
|
||||
return 0
|
||||
}
|
||||
return apiKey.Group.ImageRateMultiplier
|
||||
}
|
||||
return effectiveGroupMultiplier
|
||||
}
|
||||
220
backend/internal/service/image_generation_intent.go
Normal file
220
backend/internal/service/image_generation_intent.go
Normal file
@ -0,0 +1,220 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
openAIResponsesEndpoint = "/v1/responses"
|
||||
openAIResponsesCompactEndpoint = "/v1/responses/compact"
|
||||
imageGenerationPermissionMessage = "Image generation is not enabled for this group"
|
||||
)
|
||||
|
||||
// ImageGenerationPermissionMessage returns the stable end-user error text for disabled groups.
|
||||
func ImageGenerationPermissionMessage() string {
|
||||
return imageGenerationPermissionMessage
|
||||
}
|
||||
|
||||
// GroupAllowsImageGeneration preserves ungrouped-key behavior and enforces the flag when a group is present.
|
||||
func GroupAllowsImageGeneration(group *Group) bool {
|
||||
return group == nil || group.AllowImageGeneration
|
||||
}
|
||||
|
||||
// IsImageGenerationIntent classifies requests that can produce generated images.
|
||||
func IsImageGenerationIntent(endpoint string, requestedModel string, body []byte) bool {
|
||||
if IsImageGenerationEndpoint(endpoint) {
|
||||
return true
|
||||
}
|
||||
if isOpenAIImageGenerationModel(requestedModel) {
|
||||
return true
|
||||
}
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
return false
|
||||
}
|
||||
if model := strings.TrimSpace(gjson.GetBytes(body, "model").String()); isOpenAIImageGenerationModel(model) {
|
||||
return true
|
||||
}
|
||||
if openAIJSONToolsContainImageGeneration(gjson.GetBytes(body, "tools")) {
|
||||
return true
|
||||
}
|
||||
return openAIJSONToolChoiceSelectsImageGeneration(gjson.GetBytes(body, "tool_choice"))
|
||||
}
|
||||
|
||||
// IsImageGenerationIntentMap is the map-backed variant used after service-side request mutation.
|
||||
func IsImageGenerationIntentMap(endpoint string, requestedModel string, reqBody map[string]any) bool {
|
||||
if IsImageGenerationEndpoint(endpoint) {
|
||||
return true
|
||||
}
|
||||
if isOpenAIImageGenerationModel(requestedModel) {
|
||||
return true
|
||||
}
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
if isOpenAIImageGenerationModel(firstNonEmptyString(reqBody["model"])) {
|
||||
return true
|
||||
}
|
||||
if hasOpenAIImageGenerationTool(reqBody) {
|
||||
return true
|
||||
}
|
||||
return openAIAnyToolChoiceSelectsImageGeneration(reqBody["tool_choice"])
|
||||
}
|
||||
|
||||
// IsImageGenerationEndpoint identifies dedicated generated-image endpoints.
|
||||
func IsImageGenerationEndpoint(endpoint string) bool {
|
||||
switch normalizeImageGenerationEndpoint(endpoint) {
|
||||
case "/v1/images/generations", "/v1/images/edits", "/images/generations", "/images/edits":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeImageGenerationEndpoint(endpoint string) string {
|
||||
endpoint = strings.TrimSpace(strings.ToLower(endpoint))
|
||||
if endpoint == "" {
|
||||
return ""
|
||||
}
|
||||
endpoint = strings.TrimPrefix(endpoint, "https://api.openai.com")
|
||||
if idx := strings.IndexByte(endpoint, '?'); idx >= 0 {
|
||||
endpoint = endpoint[:idx]
|
||||
}
|
||||
return strings.TrimRight(endpoint, "/")
|
||||
}
|
||||
|
||||
func openAIJSONToolsContainImageGeneration(tools gjson.Result) bool {
|
||||
if !tools.IsArray() {
|
||||
return false
|
||||
}
|
||||
found := false
|
||||
tools.ForEach(func(_, item gjson.Result) bool {
|
||||
if strings.TrimSpace(item.Get("type").String()) == "image_generation" {
|
||||
found = true
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
return found
|
||||
}
|
||||
|
||||
func openAIJSONToolChoiceSelectsImageGeneration(choice gjson.Result) bool {
|
||||
if !choice.Exists() {
|
||||
return false
|
||||
}
|
||||
if choice.Type == gjson.String {
|
||||
return strings.TrimSpace(choice.String()) == "image_generation"
|
||||
}
|
||||
if !choice.IsObject() {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(choice.Get("type").String()) == "image_generation" {
|
||||
return true
|
||||
}
|
||||
if strings.TrimSpace(choice.Get("tool.type").String()) == "image_generation" {
|
||||
return true
|
||||
}
|
||||
if strings.TrimSpace(choice.Get("function.name").String()) == "image_generation" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func openAIAnyToolChoiceSelectsImageGeneration(choice any) bool {
|
||||
switch v := choice.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v) == "image_generation"
|
||||
case map[string]any:
|
||||
if strings.TrimSpace(firstNonEmptyString(v["type"])) == "image_generation" {
|
||||
return true
|
||||
}
|
||||
if tool, ok := v["tool"].(map[string]any); ok && strings.TrimSpace(firstNonEmptyString(tool["type"])) == "image_generation" {
|
||||
return true
|
||||
}
|
||||
if fn, ok := v["function"].(map[string]any); ok && strings.TrimSpace(firstNonEmptyString(fn["name"])) == "image_generation" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getAPIKeyFromContext(c interface{ Get(string) (any, bool) }) *APIKey {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
v, exists := c.Get("api_key")
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
apiKey, _ := v.(*APIKey)
|
||||
return apiKey
|
||||
}
|
||||
|
||||
func apiKeyGroup(apiKey *APIKey) *Group {
|
||||
if apiKey == nil {
|
||||
return nil
|
||||
}
|
||||
return apiKey.Group
|
||||
}
|
||||
|
||||
func cloneRequestMapForImageIntent(body []byte) map[string]any {
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(body, &out); err != nil {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func resolveOpenAIResponsesImageBillingConfig(reqBody map[string]any, fallbackModel string) (string, string, error) {
|
||||
imageModel := ""
|
||||
imageSize := ""
|
||||
hasImageTool := false
|
||||
if reqBody != nil {
|
||||
rawTools, _ := reqBody["tools"].([]any)
|
||||
for _, rawTool := range rawTools {
|
||||
toolMap, ok := rawTool.(map[string]any)
|
||||
if !ok || strings.TrimSpace(firstNonEmptyString(toolMap["type"])) != "image_generation" {
|
||||
continue
|
||||
}
|
||||
hasImageTool = true
|
||||
imageModel = strings.TrimSpace(firstNonEmptyString(toolMap["model"]))
|
||||
imageSize = strings.TrimSpace(firstNonEmptyString(toolMap["size"]))
|
||||
break
|
||||
}
|
||||
if imageSize == "" {
|
||||
imageSize = strings.TrimSpace(firstNonEmptyString(reqBody["size"]))
|
||||
}
|
||||
}
|
||||
if imageModel == "" && reqBody != nil {
|
||||
bodyModel := strings.TrimSpace(firstNonEmptyString(reqBody["model"]))
|
||||
if isOpenAIImageBillingModelAlias(bodyModel) || !hasImageTool {
|
||||
imageModel = bodyModel
|
||||
}
|
||||
}
|
||||
if imageModel == "" && hasImageTool {
|
||||
imageModel = "gpt-image-2"
|
||||
}
|
||||
if imageModel == "" {
|
||||
imageModel = strings.TrimSpace(fallbackModel)
|
||||
}
|
||||
sizeTier := normalizeOpenAIImageSizeTier(imageSize)
|
||||
return imageModel, sizeTier, nil
|
||||
}
|
||||
|
||||
func resolveOpenAIResponsesImageBillingConfigFromBody(body []byte, fallbackModel string) (string, string, error) {
|
||||
reqBody := cloneRequestMapForImageIntent(body)
|
||||
return resolveOpenAIResponsesImageBillingConfig(reqBody, fallbackModel)
|
||||
}
|
||||
|
||||
func isOpenAIImageBillingModelAlias(model string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(model))
|
||||
if normalized == "" {
|
||||
return false
|
||||
}
|
||||
return isOpenAIImageGenerationModel(normalized) || strings.Contains(normalized, "image")
|
||||
}
|
||||
184
backend/internal/service/image_generation_intent_test.go
Normal file
184
backend/internal/service/image_generation_intent_test.go
Normal file
@ -0,0 +1,184 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsImageGenerationIntent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
endpoint string
|
||||
model string
|
||||
body []byte
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "images endpoint",
|
||||
endpoint: "/v1/images/generations",
|
||||
body: []byte(`{"model":"gpt-image-2"}`),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "image model",
|
||||
endpoint: "/v1/responses",
|
||||
model: "gpt-image-2",
|
||||
body: []byte(`{"model":"gpt-image-2"}`),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "image tool",
|
||||
endpoint: "/v1/responses",
|
||||
model: "gpt-5.4",
|
||||
body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation"}]}`),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "image tool choice",
|
||||
endpoint: "/v1/responses",
|
||||
model: "gpt-5.4",
|
||||
body: []byte(`{"model":"gpt-5.4","tool_choice":{"type":"image_generation"}}`),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "required tool choice alone is text",
|
||||
endpoint: "/v1/responses",
|
||||
model: "gpt-5.4",
|
||||
body: []byte(`{"model":"gpt-5.4","tool_choice":"required"}`),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "text only gpt 5.4",
|
||||
endpoint: "/v1/responses",
|
||||
model: "gpt-5.4",
|
||||
body: []byte(`{"model":"gpt-5.4","input":"write code"}`),
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, IsImageGenerationIntent(tt.endpoint, tt.model, tt.body))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOpenAIResponsesImageBillingConfigUsesCurrentBodyModel(t *testing.T) {
|
||||
imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(
|
||||
[]byte(`{"model":"mapped-image-model","tools":[{"type":"image_generation","size":"1024x1024"}]}`),
|
||||
"requested-model",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "mapped-image-model", imageModel)
|
||||
require.Equal(t, "1K", imageSize)
|
||||
}
|
||||
|
||||
func TestResolveOpenAIResponsesImageBillingConfigToolModelWins(t *testing.T) {
|
||||
imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(
|
||||
[]byte(`{"model":"mapped-text-model","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1536x1024"}]}`),
|
||||
"requested-model",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "gpt-image-2", imageModel)
|
||||
require.Equal(t, "2K", imageSize)
|
||||
}
|
||||
|
||||
func TestResolveOpenAIResponsesImageBillingConfigSupportsOfficialAndCustomSizes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
wantTier string
|
||||
}{
|
||||
{
|
||||
name: "official 2k landscape",
|
||||
body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-2","size":"2048x1152"}]}`),
|
||||
wantTier: "2K",
|
||||
},
|
||||
{
|
||||
name: "official 4k landscape",
|
||||
body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-2","size":"3840x2160"}]}`),
|
||||
wantTier: "4K",
|
||||
},
|
||||
{
|
||||
name: "custom valid 2k",
|
||||
body: []byte(`{"model":"gpt-5.5","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1280x768"}]}`),
|
||||
wantTier: "2K",
|
||||
},
|
||||
{
|
||||
name: "default image tool model supports flexible size",
|
||||
body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","size":"2048x1152"}]}`),
|
||||
wantTier: "2K",
|
||||
},
|
||||
{
|
||||
name: "top level image size is moved into billing",
|
||||
body: []byte(`{"model":"gpt-image-2","size":"2048x2048","tools":[{"type":"image_generation","model":"gpt-image-2"}]}`),
|
||||
wantTier: "2K",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(tt.body, "requested-model")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, imageModel)
|
||||
require.Equal(t, tt.wantTier, imageSize)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOpenAIResponsesImageBillingConfigDoesNotRejectUnknownSizes(t *testing.T) {
|
||||
imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(
|
||||
[]byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-1.5","size":"2048x1152"}]}`),
|
||||
"requested-model",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "gpt-image-1.5", imageModel)
|
||||
require.Equal(t, "2K", imageSize)
|
||||
}
|
||||
|
||||
func TestOpenAIImageOutputCounterDeduplicatesFinalImages(t *testing.T) {
|
||||
counter := newOpenAIImageOutputCounter()
|
||||
counter.AddSSEData([]byte(`{"type":"response.image_generation_call.partial_image","partial_image_b64":"abc"}`))
|
||||
counter.AddSSEData([]byte(`{"type":"response.output_item.done","item":{"id":"ig_1","type":"image_generation_call","result":"final-a"}}`))
|
||||
counter.AddSSEData([]byte(`{"type":"response.completed","response":{"output":[{"id":"ig_1","type":"image_generation_call","result":"final-a"},{"id":"ig_2","type":"image_generation_call","result":"final-b"}]}}`))
|
||||
require.Equal(t, 2, counter.Count())
|
||||
}
|
||||
|
||||
func TestOpenAIImageOutputCounterCountsImagesAPIStreamShapes(t *testing.T) {
|
||||
counter := newOpenAIImageOutputCounter()
|
||||
counter.AddSSEData([]byte(`{"type":"image_generation.completed","id":"ig_complete","b64_json":"final-a"}`))
|
||||
counter.AddSSEData([]byte(`{"type":"response.output_item.done","item":{"id":"ig_item","type":"image_generation_call","result":"final-b"}}`))
|
||||
counter.AddSSEData([]byte(`{"type":"response.completed","response":{"output":[{"id":"ig_done","type":"image_generation_call","result":"final-c"}]}}`))
|
||||
require.Equal(t, 3, counter.Count())
|
||||
|
||||
dataCounter := newOpenAIImageOutputCounter()
|
||||
dataCounter.AddSSEData([]byte(`{"data":[{"b64_json":"a"},{"b64_json":"b"}]}`))
|
||||
dataCounter.AddSSEData([]byte(`{"data":[{"b64_json":"a"},{"b64_json":"b"},{"b64_json":"c"}]}`))
|
||||
require.Equal(t, 3, dataCounter.Count())
|
||||
}
|
||||
|
||||
func TestOpenAIImageOutputCounterCountsMultilineSSEDataPayload(t *testing.T) {
|
||||
counter := newOpenAIImageOutputCounter()
|
||||
counter.AddSSEData([]byte("{\"type\":\"image_generation.completed\",\n\"b64_json\":\"final-a\"}"))
|
||||
require.Equal(t, 1, counter.Count())
|
||||
}
|
||||
|
||||
func TestOpenAIImageOutputCounterCountsMultilineSSEBodyPayload(t *testing.T) {
|
||||
counter := newOpenAIImageOutputCounter()
|
||||
counter.AddSSEBody(
|
||||
"data: {\"type\":\"image_generation.completed\",\n" +
|
||||
"data: \"b64_json\":\"final-a\"}\n\n" +
|
||||
"data: [DONE]\n\n",
|
||||
)
|
||||
require.Equal(t, 1, counter.Count())
|
||||
}
|
||||
|
||||
func TestOpenAIImageOutputCounterFallsBackForInvalidMultilineSSEBody(t *testing.T) {
|
||||
counter := newOpenAIImageOutputCounter()
|
||||
counter.AddSSEBody(
|
||||
"data: {\"type\":\"image_generation.completed\",\"b64_json\":\"final-a\"}\n" +
|
||||
"data: {\"type\":\"image_generation.completed\",\"b64_json\":\"final-b\"}\n\n",
|
||||
)
|
||||
require.Equal(t, 2, counter.Count())
|
||||
}
|
||||
149
backend/internal/service/image_output_accounting.go
Normal file
149
backend/internal/service/image_output_accounting.go
Normal file
@ -0,0 +1,149 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type openAIImageOutputCounter struct {
|
||||
seen map[string]struct{}
|
||||
count int
|
||||
maxDataCount int
|
||||
}
|
||||
|
||||
func newOpenAIImageOutputCounter() *openAIImageOutputCounter {
|
||||
return &openAIImageOutputCounter{seen: make(map[string]struct{})}
|
||||
}
|
||||
|
||||
func (c *openAIImageOutputCounter) Count() int {
|
||||
if c == nil {
|
||||
return 0
|
||||
}
|
||||
if c.maxDataCount > c.count {
|
||||
return c.maxDataCount
|
||||
}
|
||||
return c.count
|
||||
}
|
||||
|
||||
func (c *openAIImageOutputCounter) AddJSONResponse(body []byte) {
|
||||
if c == nil || len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
return
|
||||
}
|
||||
c.addDataArray(gjson.GetBytes(body, "data"))
|
||||
c.addOutputArray(gjson.GetBytes(body, "output"))
|
||||
c.addOutputArray(gjson.GetBytes(body, "response.output"))
|
||||
}
|
||||
|
||||
func (c *openAIImageOutputCounter) AddSSEData(data []byte) {
|
||||
if c == nil || len(data) == 0 || strings.TrimSpace(string(data)) == "[DONE]" || !gjson.ValidBytes(data) {
|
||||
return
|
||||
}
|
||||
root := gjson.ParseBytes(data)
|
||||
c.addDataArray(root.Get("data"))
|
||||
eventType := strings.TrimSpace(root.Get("type").String())
|
||||
switch eventType {
|
||||
case "response.output_item.done":
|
||||
c.addImageOutputItem(root.Get("item"))
|
||||
case "response.completed", "response.done":
|
||||
c.addOutputArray(root.Get("response.output"))
|
||||
case "image_generation.completed":
|
||||
if item := root.Get("item"); item.Exists() {
|
||||
c.addImageOutputItem(item)
|
||||
return
|
||||
}
|
||||
if output := root.Get("output"); output.Exists() {
|
||||
c.addImageOutputItem(output)
|
||||
return
|
||||
}
|
||||
c.addImageOutputItem(root)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *openAIImageOutputCounter) AddSSEBody(body string) {
|
||||
if c == nil || strings.TrimSpace(body) == "" {
|
||||
return
|
||||
}
|
||||
forEachOpenAISSEDataPayload(body, c.AddSSEData)
|
||||
}
|
||||
|
||||
func (c *openAIImageOutputCounter) addDataArray(data gjson.Result) {
|
||||
if !data.IsArray() {
|
||||
return
|
||||
}
|
||||
count := len(data.Array())
|
||||
if count > c.maxDataCount {
|
||||
c.maxDataCount = count
|
||||
}
|
||||
}
|
||||
|
||||
func (c *openAIImageOutputCounter) addOutputArray(output gjson.Result) {
|
||||
if !output.IsArray() {
|
||||
return
|
||||
}
|
||||
output.ForEach(func(_, item gjson.Result) bool {
|
||||
c.addImageOutputItem(item)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (c *openAIImageOutputCounter) addImageOutputItem(item gjson.Result) {
|
||||
if !item.Exists() || !item.IsObject() {
|
||||
return
|
||||
}
|
||||
itemType := strings.TrimSpace(item.Get("type").String())
|
||||
if itemType != "" && itemType != "image_generation_call" && itemType != "image_generation.completed" {
|
||||
return
|
||||
}
|
||||
if strings.Contains(strings.ToLower(item.Raw), "partial_image") {
|
||||
return
|
||||
}
|
||||
result := strings.TrimSpace(item.Get("result").String())
|
||||
if result == "" {
|
||||
result = strings.TrimSpace(item.Get("b64_json").String())
|
||||
}
|
||||
if result == "" {
|
||||
result = strings.TrimSpace(item.Get("url").String())
|
||||
}
|
||||
if result == "" && itemType != "image_generation.completed" {
|
||||
return
|
||||
}
|
||||
key := strings.TrimSpace(item.Get("id").String())
|
||||
if key == "" {
|
||||
key = strings.TrimSpace(item.Get("call_id").String())
|
||||
}
|
||||
if key == "" {
|
||||
key = hashOpenAIImageOutputResult(result)
|
||||
}
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
if _, exists := c.seen[key]; exists {
|
||||
return
|
||||
}
|
||||
c.seen[key] = struct{}{}
|
||||
c.count++
|
||||
}
|
||||
|
||||
func hashOpenAIImageOutputResult(result string) string {
|
||||
result = strings.TrimSpace(result)
|
||||
if result == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(result))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func countOpenAIResponseImageOutputsFromJSONBytes(body []byte) int {
|
||||
counter := newOpenAIImageOutputCounter()
|
||||
counter.AddJSONResponse(body)
|
||||
return counter.Count()
|
||||
}
|
||||
|
||||
func countOpenAIImageOutputsFromSSEBody(body string) int {
|
||||
counter := newOpenAIImageOutputCounter()
|
||||
counter.AddSSEBody(body)
|
||||
return counter.Count()
|
||||
}
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"log/slog"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
@ -345,7 +346,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
if !s.isAccountRequestCompatible(account, req) {
|
||||
if !s.isAccountRequestCompatible(ctx, account, req) {
|
||||
return nil, nil
|
||||
}
|
||||
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
|
||||
@ -621,7 +622,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
if !s.isAccountRequestCompatible(account, req) {
|
||||
if !s.isAccountRequestCompatible(ctx, account, req) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
|
||||
@ -828,11 +829,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
for i := 0; i < len(selectionOrder); i++ {
|
||||
candidate := selectionOrder[i]
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||
continue
|
||||
}
|
||||
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||
continue
|
||||
}
|
||||
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
|
||||
@ -859,11 +860,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
|
||||
for _, candidate := range selectionOrder {
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||
continue
|
||||
}
|
||||
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||
continue
|
||||
}
|
||||
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
|
||||
@ -894,13 +895,18 @@ func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Ac
|
||||
return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport)
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(account *Account, req OpenAIAccountScheduleRequest) bool {
|
||||
func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.Context, account *Account, req OpenAIAccountScheduleRequest) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
||||
return false
|
||||
}
|
||||
if req.GroupID != nil && s != nil && s.service != nil &&
|
||||
s.service.needsUpstreamChannelRestrictionCheck(ctx, req.GroupID) &&
|
||||
s.service.isUpstreamModelRestrictedByChannel(ctx, *req.GroupID, account, req.RequestedModel, req.RequireCompact) {
|
||||
return false
|
||||
}
|
||||
return account.SupportsOpenAIImageCapability(req.RequiredImageCapability)
|
||||
}
|
||||
|
||||
@ -1112,6 +1118,13 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
|
||||
}
|
||||
}
|
||||
|
||||
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||
slog.Warn("channel pricing restriction blocked request",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"model", requestedModel)
|
||||
return nil, decision, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
||||
}
|
||||
|
||||
var stickyAccountID int64
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && accountID > 0 {
|
||||
|
||||
149
backend/internal/service/openai_apikey_responses_probe.go
Normal file
149
backend/internal/service/openai_apikey_responses_probe.go
Normal file
@ -0,0 +1,149 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||
)
|
||||
|
||||
// openaiResponsesProbeTimeout 是探测请求的超时时长。
|
||||
// 探测必须快速失败——超时不应阻塞账号创建/更新流程。
|
||||
const openaiResponsesProbeTimeout = 8 * time.Second
|
||||
|
||||
// openaiResponsesProbePayload 是探测使用的最小 Responses 请求体。
|
||||
// 仅作能力探测,不期望响应内容质量;Stream=false 减少 SSE 解析开销。
|
||||
//
|
||||
// 注意:探测的目标是区分"端点存在"与"端点不存在"——只要上游返回非 404 的
|
||||
// 4xx/5xx(如 400 invalid_request_error / 401 unauthorized / 422 等),
|
||||
// 都视为"端点存在 → 支持 Responses"。仅 404 / 405 视为"端点不存在"。
|
||||
func openaiResponsesProbePayload(modelID string) []byte {
|
||||
if strings.TrimSpace(modelID) == "" {
|
||||
modelID = openai.DefaultTestModel
|
||||
}
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": modelID,
|
||||
"input": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]any{
|
||||
{"type": "input_text", "text": "hi"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"instructions": openai.DefaultInstructions,
|
||||
"stream": false,
|
||||
})
|
||||
return body
|
||||
}
|
||||
|
||||
// ProbeOpenAIAPIKeyResponsesSupport 探测 OpenAI APIKey 账号上游是否支持
|
||||
// /v1/responses 端点,并将结果持久化到 accounts.extra.openai_responses_supported。
|
||||
//
|
||||
// 调用时机:账号创建/更新后,且仅当 platform=openai && type=apikey 时。
|
||||
//
|
||||
// 探测策略(参见包文档 internal/pkg/openai_compat):
|
||||
// - 上游 404 / 405 → 不支持,写 false
|
||||
// - 上游 2xx / 其他 4xx(401/422/400 等)/ 5xx → 支持,写 true
|
||||
// - 网络层失败(连接错误、超时)→ 不写标记,保持 unknown
|
||||
// (后续请求仍按"现状即证据"默认走 Responses)
|
||||
//
|
||||
// 该方法是幂等的:重复调用会以最新探测结果覆盖标记。
|
||||
//
|
||||
// 关于失败处理:探测本身的失败不应阻塞账号创建——账号能创建/更新成功就够了,
|
||||
// 探测结果只影响后续路由优化。所有错误都仅记录日志,不向调用方传播。
|
||||
func (s *AccountTestService) ProbeOpenAIAPIKeyResponsesSupport(ctx context.Context, accountID int64) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.openai_probe", "probe_load_account_failed: account_id=%d err=%v", accountID, err)
|
||||
return
|
||||
}
|
||||
if account.Platform != PlatformOpenAI || account.Type != AccountTypeAPIKey {
|
||||
// 仅 OpenAI APIKey 账号需要探测;其他账号类型无能力差异。
|
||||
return
|
||||
}
|
||||
|
||||
apiKey := account.GetOpenAIApiKey()
|
||||
if apiKey == "" {
|
||||
logger.LegacyPrintf("service.openai_probe", "probe_skip_no_apikey: account_id=%d", accountID)
|
||||
return
|
||||
}
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.openai_probe", "probe_invalid_baseurl: account_id=%d base_url=%q err=%v", accountID, baseURL, err)
|
||||
return
|
||||
}
|
||||
|
||||
probeURL := buildOpenAIResponsesURL(normalizedBaseURL)
|
||||
|
||||
probeCtx, cancel := context.WithTimeout(ctx, openaiResponsesProbeTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(probeCtx, http.MethodPost, probeURL, bytes.NewReader(openaiResponsesProbePayload("")))
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.openai_probe", "probe_build_request_failed: account_id=%d err=%v", accountID, err)
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if err != nil {
|
||||
// 网络层失败:不写标记,保持 unknown,下次重试或由网关 fallback 处理
|
||||
logger.LegacyPrintf("service.openai_probe", "probe_request_failed: account_id=%d url=%s err=%v", accountID, probeURL, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1<<20))
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
supported := isResponsesEndpointSupportedByStatus(resp.StatusCode)
|
||||
|
||||
if err := s.accountRepo.UpdateExtra(ctx, accountID, map[string]any{
|
||||
openai_compat.ExtraKeyResponsesSupported: supported,
|
||||
}); err != nil {
|
||||
logger.LegacyPrintf("service.openai_probe", "probe_persist_failed: account_id=%d supported=%v err=%v", accountID, supported, err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.openai_probe",
|
||||
"probe_done: account_id=%d base_url=%s status=%d supported=%v",
|
||||
accountID, normalizedBaseURL, resp.StatusCode, supported,
|
||||
)
|
||||
}
|
||||
|
||||
// isResponsesEndpointSupportedByStatus 根据探测响应的 HTTP 状态码判定上游
|
||||
// 是否暴露 /v1/responses 端点。
|
||||
//
|
||||
// 关键观察:第三方 OpenAI 兼容上游(DeepSeek/Kimi 等)对未知端点统一返回 404
|
||||
// 或 405;而 OpenAI 官方/有 Responses 实现的上游会因为请求体最简(缺字段)
|
||||
// 返回 400/422 等业务错误,但端点本身存在。
|
||||
//
|
||||
// 因此:仅 404 和 405 视为"端点不存在",其他 status 视为"端点存在"。
|
||||
//
|
||||
// 5xx 也视为"端点存在"——上游偶发故障不应误判为不支持。
|
||||
func isResponsesEndpointSupportedByStatus(status int) bool {
|
||||
switch status {
|
||||
case http.StatusNotFound, http.StatusMethodNotAllowed:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@ -38,6 +38,29 @@ var codexModelMap = map[string]string{
|
||||
"gpt-5.2-medium": "gpt-5.2",
|
||||
"gpt-5.2-high": "gpt-5.2",
|
||||
"gpt-5.2-xhigh": "gpt-5.2",
|
||||
"gpt-5": "gpt-5.4",
|
||||
"gpt-5-mini": "gpt-5.4",
|
||||
"gpt-5-nano": "gpt-5.4",
|
||||
"gpt-5.1": "gpt-5.4",
|
||||
"gpt-5.1-codex": "gpt-5.3-codex",
|
||||
"gpt-5.1-codex-max": "gpt-5.3-codex",
|
||||
"gpt-5.1-codex-mini": "gpt-5.3-codex",
|
||||
"gpt-5.2-codex": "gpt-5.2",
|
||||
"codex-mini-latest": "gpt-5.3-codex",
|
||||
"gpt-5-codex": "gpt-5.3-codex",
|
||||
}
|
||||
|
||||
var codexVersionModelPrefixes = []struct {
|
||||
prefix string
|
||||
target string
|
||||
}{
|
||||
{prefix: "gpt-5.3-codex-spark", target: "gpt-5.3-codex-spark"},
|
||||
{prefix: "gpt-5.3-codex", target: "gpt-5.3-codex"},
|
||||
{prefix: "gpt-5.4-mini", target: "gpt-5.4-mini"},
|
||||
{prefix: "gpt-5.4-nano", target: "gpt-5.4-nano"},
|
||||
{prefix: "gpt-5.5", target: "gpt-5.5"},
|
||||
{prefix: "gpt-5.4", target: "gpt-5.4"},
|
||||
{prefix: "gpt-5.2", target: "gpt-5.2"},
|
||||
}
|
||||
|
||||
type codexTransformResult struct {
|
||||
@ -46,6 +69,13 @@ type codexTransformResult struct {
|
||||
PromptCacheKey string
|
||||
}
|
||||
|
||||
type codexOAuthTransformOptions struct {
|
||||
IsCodexCLI bool
|
||||
IsCompact bool
|
||||
SkipDefaultInstructions bool
|
||||
PreserveToolCallIDs bool
|
||||
}
|
||||
|
||||
const (
|
||||
codexImageGenerationBridgeMarker = "<sub2api-codex-image-generation>"
|
||||
codexImageGenerationBridgeText = codexImageGenerationBridgeMarker + "\nWhen the user asks for raster image generation or editing, use the OpenAI Responses native `image_generation` tool attached to this request. The local Codex client may not expose an `image_gen` namespace, but that does not mean image generation is unavailable. Do not ask the user to switch to CLI fallback solely because `image_gen` is absent.\n</sub2api-codex-image-generation>"
|
||||
@ -71,6 +101,13 @@ var openAICodexOAuthUnsupportedFields = append([]string{
|
||||
}, openAIChatGPTInternalUnsupportedFields...)
|
||||
|
||||
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult {
|
||||
return applyCodexOAuthTransformWithOptions(reqBody, codexOAuthTransformOptions{
|
||||
IsCodexCLI: isCodexCLI,
|
||||
IsCompact: isCompact,
|
||||
})
|
||||
}
|
||||
|
||||
func applyCodexOAuthTransformWithOptions(reqBody map[string]any, opts codexOAuthTransformOptions) codexTransformResult {
|
||||
result := codexTransformResult{}
|
||||
// 工具续链需求会影响存储策略与 input 过滤逻辑。
|
||||
needsToolContinuation := NeedsToolContinuation(reqBody)
|
||||
@ -88,7 +125,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
||||
result.NormalizedModel = normalizedModel
|
||||
}
|
||||
|
||||
if isCompact {
|
||||
if opts.IsCompact {
|
||||
if _, ok := reqBody["store"]; ok {
|
||||
delete(reqBody, "store")
|
||||
result.Modified = true
|
||||
@ -160,6 +197,10 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
||||
|
||||
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
||||
result.PromptCacheKey = strings.TrimSpace(v)
|
||||
if isOpenAICompatMessagesBridgeRequestBody(reqBody) {
|
||||
delete(reqBody, "prompt_cache_key")
|
||||
result.Modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// 提取 input 中 role:"system" 消息至 instructions(OAuth 上游不支持 system role)。
|
||||
@ -168,7 +209,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
||||
}
|
||||
|
||||
// instructions 处理逻辑:根据是否是 Codex CLI 分别调用不同方法
|
||||
if applyInstructions(reqBody, isCodexCLI) {
|
||||
if !opts.SkipDefaultInstructions && applyInstructions(reqBody, opts.IsCodexCLI) {
|
||||
result.Modified = true
|
||||
}
|
||||
if isCodexSparkModel(normalizedModel) && applyCodexSparkImageUnsupportedInstructions(reqBody) {
|
||||
@ -185,7 +226,10 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
||||
input = normalizedInput
|
||||
result.Modified = true
|
||||
}
|
||||
input = filterCodexInput(input, needsToolContinuation)
|
||||
input = filterCodexInputWithOptions(input, codexInputFilterOptions{
|
||||
PreserveReferences: needsToolContinuation,
|
||||
PreserveCallIDs: opts.PreserveToolCallIDs,
|
||||
})
|
||||
reqBody["input"] = input
|
||||
result.Modified = true
|
||||
} else if inputStr, ok := reqBody["input"].(string); ok {
|
||||
@ -447,51 +491,81 @@ func normalizeCodexModel(model string) string {
|
||||
if model == "" {
|
||||
return "gpt-5.4"
|
||||
}
|
||||
if mapped, ok := normalizeKnownCodexModel(model); ok {
|
||||
return mapped
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
func normalizeKnownCodexModel(model string) (string, bool) {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return "", false
|
||||
}
|
||||
if isOpenAIImageGenerationModel(model) {
|
||||
return model
|
||||
return model, true
|
||||
}
|
||||
|
||||
modelID := model
|
||||
modelID := lastOpenAIModelSegment(model)
|
||||
|
||||
if normalized := canonicalizeOpenAIModelAliasSpelling(modelID); normalized != "" {
|
||||
modelID = normalized
|
||||
}
|
||||
if mapped := normalizeKnownOpenAICodexModel(modelID); mapped != "" {
|
||||
return mapped, true
|
||||
}
|
||||
key := codexModelLookupKey(modelID)
|
||||
if key == "" {
|
||||
return "", false
|
||||
}
|
||||
if mapped := getNormalizedCodexModel(key); mapped != "" {
|
||||
return mapped, true
|
||||
}
|
||||
for _, item := range codexVersionModelPrefixes {
|
||||
if key == item.prefix {
|
||||
return item.target, true
|
||||
}
|
||||
suffix, ok := strings.CutPrefix(key, item.prefix+"-")
|
||||
if ok && isKnownCodexModelSuffix(suffix) {
|
||||
return item.target, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func codexModelLookupKey(modelID string) string {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(modelID, "/") {
|
||||
parts := strings.Split(modelID, "/")
|
||||
modelID = parts[len(parts)-1]
|
||||
}
|
||||
return strings.ToLower(strings.Join(strings.Fields(modelID), "-"))
|
||||
}
|
||||
|
||||
if mapped := getNormalizedCodexModel(modelID); mapped != "" {
|
||||
return mapped
|
||||
func isKnownCodexModelSuffix(suffix string) bool {
|
||||
switch suffix {
|
||||
case "none", "minimal", "low", "medium", "high", "xhigh":
|
||||
return true
|
||||
}
|
||||
return isCodexDateSuffix(suffix)
|
||||
}
|
||||
|
||||
normalized := strings.ToLower(modelID)
|
||||
|
||||
if strings.Contains(normalized, "gpt-5.5") || strings.Contains(normalized, "gpt 5.5") {
|
||||
return "gpt-5.5"
|
||||
func isCodexDateSuffix(suffix string) bool {
|
||||
parts := strings.Split(suffix, "-")
|
||||
if len(parts) != 3 || len(parts[0]) != 4 || len(parts[1]) != 2 || len(parts[2]) != 2 {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") {
|
||||
return "gpt-5.4-mini"
|
||||
for _, part := range parts {
|
||||
for _, r := range part {
|
||||
if r < '0' || r > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") {
|
||||
return "gpt-5.4"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
|
||||
return "gpt-5.2"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.3-codex-spark") || strings.Contains(normalized, "gpt 5.3 codex spark") {
|
||||
return "gpt-5.3-codex-spark"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") {
|
||||
return "gpt-5.3-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") {
|
||||
return "gpt-5.3-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "codex") {
|
||||
return "gpt-5.3-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") {
|
||||
return "gpt-5.4"
|
||||
}
|
||||
|
||||
return "gpt-5.4"
|
||||
return true
|
||||
}
|
||||
|
||||
func isCodexSparkModel(model string) bool {
|
||||
@ -789,23 +863,18 @@ func SupportsVerbosity(model string) bool {
|
||||
}
|
||||
|
||||
func getNormalizedCodexModel(modelID string) string {
|
||||
if modelID == "" {
|
||||
key := codexModelLookupKey(modelID)
|
||||
if key == "" {
|
||||
return ""
|
||||
}
|
||||
if mapped, ok := codexModelMap[modelID]; ok {
|
||||
if mapped, ok := codexModelMap[key]; ok {
|
||||
return mapped
|
||||
}
|
||||
lower := strings.ToLower(modelID)
|
||||
for key, value := range codexModelMap {
|
||||
if strings.ToLower(key) == lower {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractTextFromContent extracts plain text from a content value that is either
|
||||
// a Go string or a []any of content-part maps with type:"text".
|
||||
// a Go string or a []any of text-like content-part maps.
|
||||
func extractTextFromContent(content any) string {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
@ -817,7 +886,8 @@ func extractTextFromContent(content any) string {
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if t, _ := m["type"].(string); t == "text" {
|
||||
switch t, _ := m["type"].(string); t {
|
||||
case "text", "input_text", "output_text":
|
||||
if text, ok := m["text"].(string); ok {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
@ -871,6 +941,28 @@ func extractSystemMessagesFromInput(reqBody map[string]any) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func extractPromptLikeInstructionsFromInput(reqBody map[string]any) string {
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok || len(input) == 0 {
|
||||
return ""
|
||||
}
|
||||
var texts []string
|
||||
for _, item := range input {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role, _ := m["role"].(string)
|
||||
switch role {
|
||||
case "developer", "system":
|
||||
if text := strings.TrimSpace(extractTextFromContent(m["content"])); text != "" {
|
||||
texts = append(texts, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(texts, "\n\n")
|
||||
}
|
||||
|
||||
// applyInstructions 处理 instructions 字段:仅在 instructions 为空时填充默认值。
|
||||
func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
|
||||
if !isInstructionsEmpty(reqBody) {
|
||||
@ -897,9 +989,20 @@ func isInstructionsEmpty(reqBody map[string]any) bool {
|
||||
return strings.TrimSpace(str) == ""
|
||||
}
|
||||
|
||||
type codexInputFilterOptions struct {
|
||||
PreserveReferences bool
|
||||
PreserveCallIDs bool
|
||||
}
|
||||
|
||||
// filterCodexInput 按需过滤 item_reference 与 id。
|
||||
// preserveReferences 为 true 时保持引用与 id,以满足续链请求对上下文的依赖。
|
||||
func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
return filterCodexInputWithOptions(input, codexInputFilterOptions{
|
||||
PreserveReferences: preserveReferences,
|
||||
})
|
||||
}
|
||||
|
||||
func filterCodexInputWithOptions(input []any, opts codexInputFilterOptions) []any {
|
||||
filtered := make([]any, 0, len(input))
|
||||
for _, item := range input {
|
||||
m, ok := item.(map[string]any)
|
||||
@ -920,6 +1023,9 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
// 仅修正真正的 tool/function call 标识,避免误改普通 message/reasoning id;
|
||||
// 若 item_reference 指向 legacy call_* 标识,则仅修正该引用本身。
|
||||
fixCallIDPrefix := func(id string) string {
|
||||
if opts.PreserveCallIDs {
|
||||
return id
|
||||
}
|
||||
if id == "" || strings.HasPrefix(id, "fc") {
|
||||
return id
|
||||
}
|
||||
@ -930,7 +1036,7 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
}
|
||||
|
||||
if typ == "item_reference" {
|
||||
if !preserveReferences {
|
||||
if !opts.PreserveReferences {
|
||||
continue
|
||||
}
|
||||
newItem := make(map[string]any, len(m))
|
||||
@ -998,7 +1104,7 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
}
|
||||
}
|
||||
|
||||
if !preserveReferences {
|
||||
if !opts.PreserveReferences {
|
||||
ensureCopy()
|
||||
delete(newItem, "id")
|
||||
}
|
||||
|
||||
@ -44,6 +44,39 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
|
||||
require.Equal(t, "fc1", second["call_id"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_MessagesBridgePromptCacheKeyIsHeaderOnly(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.5",
|
||||
"prompt_cache_key": "anthropic-metadata-session-1",
|
||||
"input": []any{
|
||||
map[string]any{
|
||||
"type": "message",
|
||||
"role": "developer",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "input_text",
|
||||
"text": openAICompatClaudeCodeTodoGuardMarker,
|
||||
},
|
||||
},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": "hello",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := applyCodexOAuthTransformWithOptions(reqBody, codexOAuthTransformOptions{
|
||||
SkipDefaultInstructions: true,
|
||||
PreserveToolCallIDs: true,
|
||||
})
|
||||
|
||||
require.Equal(t, "anthropic-metadata-session-1", result.PromptCacheKey)
|
||||
require.True(t, result.Modified)
|
||||
require.NotContains(t, reqBody, "prompt_cache_key")
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ToolContinuationPreservesNativeMessageAndReasoningIDs(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.2",
|
||||
@ -804,15 +837,25 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||
func TestNormalizeCodexModel_Gpt53(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
"gpt5.5": "gpt-5.5",
|
||||
"openai/gpt5.5": "gpt-5.5",
|
||||
"gpt5.4": "gpt-5.4",
|
||||
"gpt-5.4-high": "gpt-5.4",
|
||||
"gpt-5.4-chat-latest": "gpt-5.4",
|
||||
"gpt 5.4": "gpt-5.4",
|
||||
"gpt-5.4-mini": "gpt-5.4-mini",
|
||||
"gpt5.4-mini": "gpt-5.4-mini",
|
||||
"gpt5.4mini": "gpt-5.4-mini",
|
||||
"gpt 5.4 mini": "gpt-5.4-mini",
|
||||
"gpt-5.3": "gpt-5.3-codex",
|
||||
"gpt5.3": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex": "gpt-5.3-codex",
|
||||
"gpt5.3-codex": "gpt-5.3-codex",
|
||||
"gpt5.3codex": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
|
||||
"gpt5.3-codex-spark": "gpt-5.3-codex-spark",
|
||||
"gpt5.3codexspark": "gpt-5.3-codex-spark",
|
||||
"gpt 5.3 codex spark": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,7 +1,9 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
@ -16,12 +18,8 @@ func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
|
||||
if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
|
||||
return false
|
||||
}
|
||||
switch normalizeCodexModel(trimmed) {
|
||||
case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
normalized := strings.TrimSpace(strings.ToLower(normalizeCodexModel(trimmed)))
|
||||
return strings.HasPrefix(normalized, "gpt-5") || strings.Contains(normalized, "codex")
|
||||
}
|
||||
|
||||
func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedModel string) string {
|
||||
@ -71,6 +69,102 @@ func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedMod
|
||||
return compatPromptCacheKeyPrefix + hashSensitiveValueForLog(strings.Join(seedParts, "|"))
|
||||
}
|
||||
|
||||
func deriveAnthropicCompatPromptCacheKey(req *apicompat.AnthropicRequest, mappedModel string) string {
|
||||
if req == nil {
|
||||
return ""
|
||||
}
|
||||
if anchorKey := deriveAnthropicCacheControlPromptCacheKey(req); anchorKey != "" {
|
||||
return anchorKey
|
||||
}
|
||||
|
||||
normalizedModel := normalizeCodexModel(strings.TrimSpace(mappedModel))
|
||||
if normalizedModel == "" {
|
||||
normalizedModel = normalizeCodexModel(strings.TrimSpace(req.Model))
|
||||
}
|
||||
if normalizedModel == "" {
|
||||
normalizedModel = strings.TrimSpace(req.Model)
|
||||
}
|
||||
|
||||
seedParts := []string{"model=" + normalizedModel}
|
||||
if req.OutputConfig != nil && strings.TrimSpace(req.OutputConfig.Effort) != "" {
|
||||
seedParts = append(seedParts, "effort="+strings.TrimSpace(req.OutputConfig.Effort))
|
||||
}
|
||||
if len(req.ToolChoice) > 0 {
|
||||
seedParts = append(seedParts, "tool_choice="+normalizeCompatSeedJSON(req.ToolChoice))
|
||||
}
|
||||
if len(req.Tools) > 0 {
|
||||
if raw, err := json.Marshal(req.Tools); err == nil {
|
||||
seedParts = append(seedParts, "tools="+normalizeCompatSeedJSON(raw))
|
||||
}
|
||||
}
|
||||
if len(req.System) > 0 {
|
||||
seedParts = append(seedParts, "system="+normalizeCompatSeedJSON(req.System))
|
||||
}
|
||||
|
||||
firstUserCaptured := false
|
||||
for _, msg := range req.Messages {
|
||||
if strings.TrimSpace(msg.Role) != "user" || firstUserCaptured {
|
||||
continue
|
||||
}
|
||||
seedParts = append(seedParts, "first_user="+normalizeCompatSeedJSON(msg.Content))
|
||||
firstUserCaptured = true
|
||||
}
|
||||
|
||||
return compatPromptCacheKeyPrefix + hashSensitiveValueForLog(strings.Join(seedParts, "|"))
|
||||
}
|
||||
|
||||
func deriveAnthropicCacheControlPromptCacheKey(req *apicompat.AnthropicRequest) string {
|
||||
if req == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var parts []string
|
||||
var systemBlocks []apicompat.AnthropicContentBlock
|
||||
if len(req.System) > 0 && json.Unmarshal(req.System, &systemBlocks) == nil {
|
||||
for _, block := range systemBlocks {
|
||||
if block.Type == "text" &&
|
||||
block.CacheControl != nil &&
|
||||
strings.TrimSpace(block.CacheControl.Type) == "ephemeral" &&
|
||||
strings.TrimSpace(block.Text) != "" {
|
||||
parts = append(parts, "system:"+strings.TrimSpace(block.Text))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
firstUserAnchor := ""
|
||||
for _, msg := range req.Messages {
|
||||
var blocks []apicompat.AnthropicContentBlock
|
||||
if len(msg.Content) == 0 || json.Unmarshal(msg.Content, &blocks) != nil {
|
||||
continue
|
||||
}
|
||||
role := strings.TrimSpace(msg.Role)
|
||||
for _, block := range blocks {
|
||||
if block.Type != "text" ||
|
||||
block.CacheControl == nil ||
|
||||
strings.TrimSpace(block.CacheControl.Type) != "ephemeral" ||
|
||||
strings.TrimSpace(block.Text) == "" {
|
||||
continue
|
||||
}
|
||||
switch role {
|
||||
case "user":
|
||||
if firstUserAnchor == "" {
|
||||
firstUserAnchor = strings.TrimSpace(block.Text)
|
||||
}
|
||||
case "assistant":
|
||||
parts = append(parts, "assistant:"+strings.TrimSpace(block.Text))
|
||||
}
|
||||
}
|
||||
}
|
||||
if firstUserAnchor != "" {
|
||||
parts = append(parts, "user_anchor:"+firstUserAnchor)
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte("anthropic-cache:" + strings.Join(parts, "\n")))
|
||||
return fmt.Sprintf("anthropic-cache-%x", sum[:16])
|
||||
}
|
||||
|
||||
func normalizeCompatSeedJSON(v json.RawMessage) string {
|
||||
if len(v) == 0 {
|
||||
return ""
|
||||
|
||||
@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
@ -14,7 +15,10 @@ func mustRawJSON(t *testing.T, s string) json.RawMessage {
|
||||
}
|
||||
|
||||
func TestShouldAutoInjectPromptCacheKeyForCompat(t *testing.T) {
|
||||
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.5"))
|
||||
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4"))
|
||||
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4-mini"))
|
||||
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.2"))
|
||||
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3"))
|
||||
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex"))
|
||||
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex-spark"))
|
||||
@ -77,3 +81,57 @@ func TestDeriveCompatPromptCacheKey_UsesResolvedSparkFamily(t *testing.T) {
|
||||
require.NotEmpty(t, k1)
|
||||
require.Equal(t, k1, k2, "resolved spark family should derive a stable compat cache key")
|
||||
}
|
||||
|
||||
func TestDeriveAnthropicCompatPromptCacheKey_StableAcrossLaterTurns(t *testing.T) {
|
||||
base := &apicompat.AnthropicRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
System: mustRawJSON(t, `"You are helpful."`),
|
||||
Messages: []apicompat.AnthropicMessage{
|
||||
{Role: "user", Content: mustRawJSON(t, `"Open repo"`)},
|
||||
},
|
||||
}
|
||||
extended := &apicompat.AnthropicRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
System: mustRawJSON(t, `"You are helpful."`),
|
||||
Messages: []apicompat.AnthropicMessage{
|
||||
{Role: "user", Content: mustRawJSON(t, `"Open repo"`)},
|
||||
{Role: "assistant", Content: mustRawJSON(t, `"Opened."`)},
|
||||
{Role: "user", Content: mustRawJSON(t, `"Run tests"`)},
|
||||
},
|
||||
}
|
||||
|
||||
k1 := deriveAnthropicCompatPromptCacheKey(base, "gpt-5.3-codex")
|
||||
k2 := deriveAnthropicCompatPromptCacheKey(extended, "gpt-5.3-codex")
|
||||
require.NotEmpty(t, k1)
|
||||
require.Equal(t, k1, k2, "cache key should stay stable as later Claude Code turns append history")
|
||||
}
|
||||
|
||||
func TestDeriveAnthropicCompatPromptCacheKey_UsesCacheControlAnchors(t *testing.T) {
|
||||
base := &apicompat.AnthropicRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
System: mustRawJSON(t, `[
|
||||
{"type":"text","text":"project instructions","cache_control":{"type":"ephemeral"}}
|
||||
]`),
|
||||
Messages: []apicompat.AnthropicMessage{
|
||||
{Role: "user", Content: mustRawJSON(t, `[
|
||||
{"type":"text","text":"repo anchor","cache_control":{"type":"ephemeral"}}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
extended := &apicompat.AnthropicRequest{
|
||||
Model: base.Model,
|
||||
System: base.System,
|
||||
Messages: []apicompat.AnthropicMessage{
|
||||
base.Messages[0],
|
||||
{Role: "assistant", Content: mustRawJSON(t, `[{"type":"text","text":"Opened."}]`)},
|
||||
{Role: "user", Content: mustRawJSON(t, `[{"type":"text","text":"Run tests"}]`)},
|
||||
},
|
||||
}
|
||||
|
||||
k1 := deriveAnthropicCompatPromptCacheKey(base, "gpt-5.4")
|
||||
k2 := deriveAnthropicCompatPromptCacheKey(extended, "gpt-5.4")
|
||||
require.NotEmpty(t, k1)
|
||||
require.Equal(t, k1, k2)
|
||||
require.True(t, strings.HasPrefix(k1, "anthropic-cache-"))
|
||||
require.False(t, strings.HasPrefix(k1, compatPromptCacheKeyPrefix))
|
||||
}
|
||||
|
||||
@ -972,6 +972,62 @@ func TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames(t *testing
|
||||
"turn 3: response.create without service_tier overwrites billing to nil to match upstream default")
|
||||
}
|
||||
|
||||
func TestPassthroughUsageMeta_TracksReasoningEffortAcrossTurns(t *testing.T) {
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","reasoning":{"effort":"medium"},"service_tier":"priority"}`)
|
||||
meta := newOpenAIWSPassthroughUsageMeta("", firstFrame)
|
||||
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstFrame)
|
||||
firstOut, firstBlocked, firstErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, capturedSessionModel, firstFrame)
|
||||
require.NoError(t, firstErr)
|
||||
require.Nil(t, firstBlocked)
|
||||
meta.initFromFirstFrame(firstOut)
|
||||
require.NotNil(t, meta.reasoningEffort.Load())
|
||||
require.Equal(t, "medium", *meta.reasoningEffort.Load())
|
||||
|
||||
process := func(payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
|
||||
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
|
||||
capturedSessionModel = updated
|
||||
}
|
||||
meta.updateSessionRequestModel(payload)
|
||||
requestModelForThisFrame := meta.requestModelForFrame(payload)
|
||||
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
|
||||
if model == "" {
|
||||
model = capturedSessionModel
|
||||
}
|
||||
out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
|
||||
if policyErr == nil && blocked == nil &&
|
||||
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
|
||||
meta.updateFromResponseCreate(out, requestModelForThisFrame)
|
||||
}
|
||||
return out, blocked, policyErr
|
||||
}
|
||||
|
||||
_, blockedSession, errSession := process([]byte(`{"type":"session.update","session":{"model":"gpt-5-high"}}`))
|
||||
require.NoError(t, errSession)
|
||||
require.Nil(t, blockedSession)
|
||||
require.NotNil(t, meta.reasoningEffort.Load())
|
||||
require.Equal(t, "medium", *meta.reasoningEffort.Load(), "session.update 只刷新后续 fallback model,不覆盖当前 turn metadata")
|
||||
|
||||
_, blockedCancel, errCancel := process([]byte(`{"type":"response.cancel","reasoning_effort":"x-high"}`))
|
||||
require.NoError(t, errCancel)
|
||||
require.Nil(t, blockedCancel)
|
||||
require.NotNil(t, meta.reasoningEffort.Load())
|
||||
require.Equal(t, "medium", *meta.reasoningEffort.Load(), "非 response.create 帧不能污染当前 turn metadata")
|
||||
|
||||
_, blockedFlat, errFlat := process([]byte(`{"type":"response.create","reasoning_effort":"x-high"}`))
|
||||
require.NoError(t, errFlat)
|
||||
require.Nil(t, blockedFlat)
|
||||
require.NotNil(t, meta.reasoningEffort.Load())
|
||||
require.Equal(t, "xhigh", *meta.reasoningEffort.Load(), "flat reasoning_effort 必须进入 passthrough usage metadata")
|
||||
|
||||
_, blockedClear, errClear := process([]byte(`{"type":"response.create","model":"gpt-4o"}`))
|
||||
require.NoError(t, errClear)
|
||||
require.Nil(t, blockedClear)
|
||||
require.Nil(t, meta.reasoningEffort.Load(), "新的 response.create 无 effort 且无可推导后缀时必须清空旧值")
|
||||
}
|
||||
|
||||
// TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier locks in the
|
||||
// "block keeps previous" semantic: when policy returns block on a
|
||||
// response.create frame, that frame is never sent upstream, so billing tier
|
||||
|
||||
@ -20,20 +20,29 @@ func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accou
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) {
|
||||
func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterForZeroUsage(t *testing.T) {
|
||||
counter := &openAI403CounterResetStub{}
|
||||
rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil)
|
||||
rateLimitSvc.SetOpenAI403CounterCache(counter)
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
rateLimitService: rateLimitSvc,
|
||||
}
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
svc.rateLimitService = rateLimitSvc
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{},
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_zero_usage_reset_403",
|
||||
Model: "gpt-5.1",
|
||||
},
|
||||
APIKey: &APIKey{ID: 1001, Group: &Group{RateMultiplier: 1}},
|
||||
User: &User{ID: 2001},
|
||||
Account: &Account{ID: 777, Platform: PlatformOpenAI},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{777}, counter.resetCalls)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
}
|
||||
|
||||
@ -10,10 +10,12 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
@ -39,9 +41,18 @@ var cursorResponsesUnsupportedFields = []string{
|
||||
|
||||
// ForwardAsChatCompletions accepts a Chat Completions request body, converts it
|
||||
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
|
||||
// the response back to Chat Completions format. All account types (OAuth and API
|
||||
// Key) go through the Responses API conversion path since the upstream only
|
||||
// exposes the /v1/responses endpoint.
|
||||
// the response back to Chat Completions format.
|
||||
//
|
||||
// 历史背景:该函数原本对所有 OpenAI 账号无差别走 CC→Responses 转换 + /v1/responses
|
||||
// 端点——这在 OAuth(ChatGPT 内部 API 仅支持 Responses)和官方 APIKey 账号上是
|
||||
// 正确的,但 sub2api 接入 DeepSeek/Kimi/GLM 等第三方 OpenAI 兼容上游后假设破裂:
|
||||
// 这些上游普遍只支持 /v1/chat/completions,无 /v1/responses 端点。
|
||||
//
|
||||
// 当前路由策略(基于账号探测标记,详见 openai_compat.ShouldUseResponsesAPI):
|
||||
// - APIKey 账号 + 探测确认不支持 Responses → 走 forwardAsRawChatCompletions
|
||||
// 直转上游 /v1/chat/completions,不做协议转换
|
||||
// - 其他所有情况(OAuth、APIKey 探测确认支持、未探测)→ 走原有 CC→Responses
|
||||
// 转换路径(保留旧行为,存量未探测账号零兼容破坏)
|
||||
func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
@ -50,6 +61,12 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
promptCacheKey string,
|
||||
defaultMappedModel string,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
// 入口分流:APIKey 账号 + 已探测且确认上游不支持 Responses,走 CC 直转。
|
||||
// 标记缺失(未探测)按"现状即证据"原则继续走下方原 Responses 转换路径。
|
||||
if account.Type == AccountTypeAPIKey && !openai_compat.ShouldUseResponsesAPI(account.Extra) {
|
||||
return s.forwardAsRawChatCompletions(ctx, c, account, body, defaultMappedModel)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. Parse Chat Completions request
|
||||
@ -189,7 +206,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
}
|
||||
|
||||
// 6. Build upstream request
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false)
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, promptCacheKey, false)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
@ -348,59 +367,9 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
var finalResponse *apicompat.ResponsesResponse
|
||||
var usage OpenAIUsage
|
||||
acc := apicompat.NewBufferedResponseAccumulator()
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
}
|
||||
payload := line[6:]
|
||||
|
||||
var event apicompat.ResponsesStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
logger.L().Warn("openai chat_completions buffered: failed to parse event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// Accumulate delta content for fallback when terminal output is empty.
|
||||
acc.ProcessEvent(&event)
|
||||
|
||||
if (event.Type == "response.completed" || event.Type == "response.done" ||
|
||||
event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||
event.Response != nil {
|
||||
finalResponse = event.Response
|
||||
if event.Response.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: event.Response.Usage.InputTokens,
|
||||
OutputTokens: event.Response.Usage.OutputTokens,
|
||||
}
|
||||
if event.Response.Usage.InputTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("openai chat_completions buffered: read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai chat_completions buffered", requestID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if finalResponse == nil {
|
||||
@ -459,6 +428,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
var usage OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
firstChunk := true
|
||||
clientDisconnected := false
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
@ -467,6 +437,20 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
var intervalTicker *time.Ticker
|
||||
if streamInterval > 0 {
|
||||
intervalTicker = time.NewTicker(streamInterval)
|
||||
defer intervalTicker.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTicker != nil {
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
resultWithUsage := func() *OpenAIForwardResult {
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
@ -496,54 +480,66 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
return false
|
||||
}
|
||||
|
||||
// Extract usage from completion events
|
||||
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||
event.Response != nil && event.Response.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: event.Response.Usage.InputTokens,
|
||||
OutputTokens: event.Response.Usage.OutputTokens,
|
||||
}
|
||||
if event.Response.Usage.InputTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
// 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。
|
||||
isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
|
||||
if isTerminalEvent && event.Response != nil && event.Response.Usage != nil {
|
||||
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
|
||||
}
|
||||
|
||||
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
|
||||
for _, chunk := range chunks {
|
||||
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||
if err != nil {
|
||||
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
logger.L().Info("openai chat_completions stream: client disconnected",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return true
|
||||
if !clientDisconnected {
|
||||
for _, chunk := range chunks {
|
||||
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||
if err != nil {
|
||||
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Info("openai chat_completions stream: client disconnected, continuing to drain upstream for billing",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(chunks) > 0 {
|
||||
if len(chunks) > 0 && !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return false
|
||||
return isTerminalEvent
|
||||
}
|
||||
|
||||
finalizeStream := func() (*OpenAIForwardResult, error) {
|
||||
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 {
|
||||
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 && !clientDisconnected {
|
||||
for _, chunk := range finalChunks {
|
||||
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Info("openai chat_completions stream: client disconnected during final flush",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
// Send [DONE] sentinel
|
||||
fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
|
||||
c.Writer.Flush()
|
||||
if !clientDisconnected {
|
||||
if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Info("openai chat_completions stream: client disconnected during done flush",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
if !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
|
||||
@ -555,6 +551,9 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
)
|
||||
}
|
||||
}
|
||||
missingTerminalErr := func() (*OpenAIForwardResult, error) {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
|
||||
// Determine keepalive interval
|
||||
keepaliveInterval := time.Duration(0)
|
||||
@ -563,18 +562,25 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
}
|
||||
|
||||
// No keepalive: fast synchronous path
|
||||
if keepaliveInterval <= 0 {
|
||||
if streamInterval <= 0 && keepaliveInterval <= 0 {
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
payload, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if processDataLine(line[6:]) {
|
||||
return resultWithUsage(), nil
|
||||
if strings.TrimSpace(payload) == "[DONE]" {
|
||||
return missingTerminalErr()
|
||||
}
|
||||
if processDataLine(payload) {
|
||||
return finalizeStream()
|
||||
}
|
||||
}
|
||||
handleScanErr(scanner.Err())
|
||||
return finalizeStream()
|
||||
if err := scanner.Err(); err != nil {
|
||||
handleScanErr(err)
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
|
||||
}
|
||||
return missingTerminalErr()
|
||||
}
|
||||
|
||||
// With keepalive: goroutine + channel + select
|
||||
@ -584,6 +590,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
}
|
||||
events := make(chan scanEvent, 16)
|
||||
done := make(chan struct{})
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
@ -595,6 +603,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
@ -605,30 +614,59 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
keepaliveTicker := time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
}
|
||||
var keepaliveCh <-chan time.Time
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
lastDataAt := time.Now()
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return finalizeStream()
|
||||
return missingTerminalErr()
|
||||
}
|
||||
if ev.err != nil {
|
||||
handleScanErr(ev.err)
|
||||
return finalizeStream()
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
|
||||
}
|
||||
lastDataAt = time.Now()
|
||||
line := ev.line
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
payload, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if processDataLine(line[6:]) {
|
||||
return resultWithUsage(), nil
|
||||
if strings.TrimSpace(payload) == "[DONE]" {
|
||||
return missingTerminalErr()
|
||||
}
|
||||
if processDataLine(payload) {
|
||||
return finalizeStream()
|
||||
}
|
||||
|
||||
case <-keepaliveTicker.C:
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
|
||||
}
|
||||
logger.L().Warn("openai chat_completions stream: data interval timeout",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("model", originalModel),
|
||||
zap.Duration("interval", streamInterval),
|
||||
)
|
||||
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
if clientDisconnected {
|
||||
continue
|
||||
}
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
@ -637,7 +675,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return resultWithUsage(), nil
|
||||
clientDisconnected = true
|
||||
continue
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
437
backend/internal/service/openai_gateway_chat_completions_raw.go
Normal file
437
backend/internal/service/openai_gateway_chat_completions_raw.go
Normal file
@ -0,0 +1,437 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// openaiCCRawAllowedHeaders 是 CC 直转路径专用的客户端 header 透传白名单。
|
||||
//
|
||||
// **关键**:不能复用 openaiAllowedHeaders——后者含 Codex 客户端专属 header
|
||||
// (originator / session_id / x-codex-turn-state / x-codex-turn-metadata / conversation_id),
|
||||
// 这些在 ChatGPT OAuth 上游是必需的,但透传给 DeepSeek/Kimi/GLM 等第三方
|
||||
// OpenAI 兼容上游会造成:
|
||||
// - 完全忽略(多数友好厂商)——隐性污染上游统计
|
||||
// - 400 "unknown parameter"(严格上游)——可见错误
|
||||
//
|
||||
// 这里仅放行通用 HTTP header;content-type / authorization / accept 由上下文
|
||||
// 显式设置,不依赖透传。
|
||||
//
|
||||
// 参见决策记录:
|
||||
// pensieve/short-term/maxims/dont-reuse-shared-headers-whitelist-across-different-upstream-trust-domains
|
||||
var openaiCCRawAllowedHeaders = map[string]bool{
|
||||
"accept-language": true,
|
||||
"user-agent": true,
|
||||
}
|
||||
|
||||
// forwardAsRawChatCompletions 直转客户端的 Chat Completions 请求到上游
|
||||
// `{base_url}/v1/chat/completions`,**不**做 CC↔Responses 协议转换。
|
||||
//
|
||||
// 适用场景:account.platform=openai && account.type=apikey && 上游已被探测确认
|
||||
// 不支持 /v1/responses 端点(如 DeepSeek/Kimi/GLM/Qwen 等第三方 OpenAI 兼容上游)。
|
||||
//
|
||||
// 与 ForwardAsChatCompletions 的关键差异:
|
||||
//
|
||||
// - 不调用 apicompat.ChatCompletionsToResponses,body 仅做模型 ID 改写
|
||||
// - 上游 URL 拼到 /v1/chat/completions 而非 /v1/responses
|
||||
// - 流式响应 SSE 直接透传给客户端(上游 chunk 已是 CC 格式)
|
||||
// - 非流式响应 JSON 直接透传,仅按需提取 usage
|
||||
// - 不应用 codex OAuth transform(APIKey 路径无 OAuth)
|
||||
// - 不注入 prompt_cache_key(OAuth 专属机制)
|
||||
//
|
||||
// 调用入口:openai_gateway_chat_completions.go::ForwardAsChatCompletions
|
||||
// 在函数顶部按 openai_compat.ShouldUseResponsesAPI 分流。
|
||||
func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
defaultMappedModel string,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. Parse minimal fields needed for routing/billing
|
||||
originalModel := gjson.GetBytes(body, "model").String()
|
||||
if originalModel == "" {
|
||||
writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return nil, fmt.Errorf("missing model in request")
|
||||
}
|
||||
clientStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
// 1b. Extract reasoning effort and service tier from the raw body before any transformation.
|
||||
reasoningEffort := extractOpenAIReasoningEffortFromBody(body, originalModel)
|
||||
serviceTier := extractOpenAIServiceTierFromBody(body)
|
||||
|
||||
// 2. Resolve model mapping (same as ForwardAsChatCompletions)
|
||||
billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||||
upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
|
||||
|
||||
// 3. Rewrite model in body (no protocol conversion)
|
||||
upstreamBody := body
|
||||
if upstreamModel != originalModel {
|
||||
upstreamBody = ReplaceModelInBody(body, upstreamModel)
|
||||
}
|
||||
|
||||
// 4. Apply OpenAI fast policy on the CC body
|
||||
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, upstreamBody)
|
||||
if policyErr != nil {
|
||||
var blocked *OpenAIFastBlockedError
|
||||
if errors.As(policyErr, &blocked) {
|
||||
writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message)
|
||||
}
|
||||
return nil, policyErr
|
||||
}
|
||||
upstreamBody = updatedBody
|
||||
if clientStream {
|
||||
var usageErr error
|
||||
upstreamBody, usageErr = ensureOpenAIChatStreamUsage(upstreamBody)
|
||||
if usageErr != nil {
|
||||
return nil, fmt.Errorf("enable stream usage: %w", usageErr)
|
||||
}
|
||||
}
|
||||
|
||||
logger.L().Debug("openai chat_completions raw: forwarding without protocol conversion",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("original_model", originalModel),
|
||||
zap.String("billing_model", billingModel),
|
||||
zap.String("upstream_model", upstreamModel),
|
||||
zap.Bool("stream", clientStream),
|
||||
)
|
||||
|
||||
// 5. Build upstream request
|
||||
apiKey := account.GetOpenAIApiKey()
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("account %d missing api_key", account.ID)
|
||||
}
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
targetURL := buildOpenAIChatCompletionsURL(validatedURL)
|
||||
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody))
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
if clientStream {
|
||||
upstreamReq.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
upstreamReq.Header.Set("Accept", "application/json")
|
||||
}
|
||||
|
||||
// 透传白名单中的客户端 header。详见 openaiCCRawAllowedHeaders 的设计说明。
|
||||
for key, values := range c.Request.Header {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if openaiCCRawAllowedHeaders[lowerKey] {
|
||||
for _, v := range values {
|
||||
upstreamReq.Header.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
customUA := account.GetOpenAIUserAgent()
|
||||
if customUA != "" {
|
||||
upstreamReq.Header.Set("user-agent", customUA)
|
||||
}
|
||||
|
||||
// 6. Send request
|
||||
proxyURL := ""
|
||||
if account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 7. Handle error response with failover
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||
}
|
||||
}
|
||||
return s.handleChatCompletionsErrorResponse(resp, c, account)
|
||||
}
|
||||
|
||||
// 8. Forward response
|
||||
if clientStream {
|
||||
return s.streamRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
|
||||
}
|
||||
return s.bufferRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
|
||||
}
|
||||
|
||||
// streamRawChatCompletions 透传上游 CC SSE 流到客户端,并提取 usage(包括
|
||||
// 末尾 [DONE] 之前的 chunk 中的 usage 字段,按 OpenAI CC 协议)。
|
||||
//
|
||||
// usage 字段仅在客户端请求 stream_options.include_usage=true 时出现于上游响应中。
|
||||
// 网关会对上游强制打开 include_usage 以保证计费完整,并原样向下游透传 usage,
|
||||
// 让级联代理或下游计费系统也能拿到完整用量。
|
||||
func (s *OpenAIGatewayService) streamRawChatCompletions(
|
||||
c *gin.Context,
|
||||
resp *http.Response,
|
||||
originalModel string,
|
||||
billingModel string,
|
||||
upstreamModel string,
|
||||
reasoningEffort *string,
|
||||
serviceTier *string,
|
||||
startTime time.Time,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
var usage OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if payload, ok := extractOpenAISSEDataLine(line); ok {
|
||||
trimmedPayload := strings.TrimSpace(payload)
|
||||
if trimmedPayload != "[DONE]" {
|
||||
usageOnlyChunk := isOpenAIChatUsageOnlyStreamChunk(payload)
|
||||
if u := extractCCStreamUsage(payload); u != nil {
|
||||
usage = *u
|
||||
}
|
||||
if firstTokenMs == nil && !usageOnlyChunk {
|
||||
elapsed := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &elapsed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !clientDisconnected {
|
||||
if _, werr := c.Writer.WriteString(line + "\n"); werr != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing",
|
||||
zap.Error(werr),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
if line == "" {
|
||||
if !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("openai chat_completions raw: stream read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: billingModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
ServiceTier: serviceTier,
|
||||
Stream: true,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ensureOpenAIChatStreamUsage 确保 raw Chat Completions 流式请求会让上游返回 usage。
|
||||
// usage 也会继续向下游透传,支持级联代理和下游计费系统。
|
||||
func ensureOpenAIChatStreamUsage(body []byte) ([]byte, error) {
|
||||
updated, err := sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||
if err != nil {
|
||||
return body, err
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func isOpenAIChatUsageOnlyStreamChunk(payload string) bool {
|
||||
if strings.TrimSpace(payload) == "" {
|
||||
return false
|
||||
}
|
||||
if !gjson.Get(payload, "usage").Exists() {
|
||||
return false
|
||||
}
|
||||
choices := gjson.Get(payload, "choices")
|
||||
return choices.Exists() && choices.IsArray() && len(choices.Array()) == 0
|
||||
}
|
||||
|
||||
// extractCCStreamUsage 从单个 CC 流式 chunk 的 payload 中提取 usage 字段。
|
||||
// CC 协议中 usage 仅出现在末尾 chunk(且仅当 include_usage 生效时),
|
||||
// 但上游可能在多个 chunk 中重复——总是用最新值。
|
||||
func extractCCStreamUsage(payload string) *OpenAIUsage {
|
||||
usageResult := gjson.Get(payload, "usage")
|
||||
if !usageResult.Exists() || !usageResult.IsObject() {
|
||||
return nil
|
||||
}
|
||||
u := OpenAIUsage{
|
||||
InputTokens: int(gjson.Get(payload, "usage.prompt_tokens").Int()),
|
||||
OutputTokens: int(gjson.Get(payload, "usage.completion_tokens").Int()),
|
||||
}
|
||||
if cached := gjson.Get(payload, "usage.prompt_tokens_details.cached_tokens"); cached.Exists() {
|
||||
u.CacheReadInputTokens = int(cached.Int())
|
||||
}
|
||||
return &u
|
||||
}
|
||||
|
||||
// bufferRawChatCompletions 透传上游 CC 非流式 JSON 响应。
|
||||
func (s *OpenAIGatewayService) bufferRawChatCompletions(
|
||||
c *gin.Context,
|
||||
resp *http.Response,
|
||||
originalModel string,
|
||||
billingModel string,
|
||||
upstreamModel string,
|
||||
reasoningEffort *string,
|
||||
serviceTier *string,
|
||||
startTime time.Time,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response")
|
||||
}
|
||||
return nil, fmt.Errorf("read upstream body: %w", err)
|
||||
}
|
||||
|
||||
var ccResp apicompat.ChatCompletionsResponse
|
||||
var usage OpenAIUsage
|
||||
if err := json.Unmarshal(respBody, &ccResp); err == nil && ccResp.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: ccResp.Usage.PromptTokens,
|
||||
OutputTokens: ccResp.Usage.CompletionTokens,
|
||||
}
|
||||
if ccResp.Usage.PromptTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = ccResp.Usage.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
if ct := resp.Header.Get("Content-Type"); ct != "" {
|
||||
c.Writer.Header().Set("Content-Type", ct)
|
||||
} else {
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
_, _ = c.Writer.Write(respBody)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: billingModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
ServiceTier: serviceTier,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildOpenAIChatCompletionsURL 拼接上游 Chat Completions 端点 URL。
|
||||
//
|
||||
// - base 已是 /chat/completions:原样返回
|
||||
// - base 以 /v1 结尾:追加 /chat/completions
|
||||
// - 其他情况:追加 /v1/chat/completions
|
||||
//
|
||||
// 与 buildOpenAIResponsesURL 是姐妹函数。
|
||||
func buildOpenAIChatCompletionsURL(base string) string {
|
||||
normalized := strings.TrimRight(strings.TrimSpace(base), "/")
|
||||
if strings.HasSuffix(normalized, "/chat/completions") {
|
||||
return normalized
|
||||
}
|
||||
if strings.HasSuffix(normalized, "/v1") {
|
||||
return normalized + "/chat/completions"
|
||||
}
|
||||
return normalized + "/v1/chat/completions"
|
||||
}
|
||||
@ -0,0 +1,260 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestBuildOpenAIChatCompletionsURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
base string
|
||||
want string
|
||||
}{
|
||||
// 已是 /chat/completions:原样返回
|
||||
{"already chat/completions", "https://api.openai.com/v1/chat/completions", "https://api.openai.com/v1/chat/completions"},
|
||||
// 以 /v1 结尾:追加 /chat/completions
|
||||
{"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/chat/completions"},
|
||||
// 其他情况:追加 /v1/chat/completions
|
||||
{"bare domain", "https://api.openai.com", "https://api.openai.com/v1/chat/completions"},
|
||||
{"domain with trailing slash", "https://api.openai.com/", "https://api.openai.com/v1/chat/completions"},
|
||||
// 第三方上游常见形式
|
||||
{"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/chat/completions"},
|
||||
{"third-party with path prefix", "https://api.gptgod.online/api", "https://api.gptgod.online/api/v1/chat/completions"},
|
||||
// 带空白字符
|
||||
{"whitespace trimmed", " https://api.openai.com/v1 ", "https://api.openai.com/v1/chat/completions"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := buildOpenAIChatCompletionsURL(tt.base)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildOpenAIResponsesURL_ProbeURL 锁定 probe/测试端点使用的 URL 构建逻辑,
|
||||
// 确保 buildOpenAIResponsesURL 对标准 OpenAI base_url 格式均拼出 `/v1/responses`。
|
||||
func TestBuildOpenAIResponsesURL_ProbeURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
base string
|
||||
want string
|
||||
}{
|
||||
{"bare domain", "https://api.openai.com", "https://api.openai.com/v1/responses"},
|
||||
{"domain trailing slash", "https://api.openai.com/", "https://api.openai.com/v1/responses"},
|
||||
{"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/responses"},
|
||||
{"already /responses", "https://api.openai.com/v1/responses", "https://api.openai.com/v1/responses"},
|
||||
{"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/responses"},
|
||||
{"only domain, no scheme", "api.gptgod.online", "api.gptgod.online/v1/responses"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := buildOpenAIResponsesURL(tt.base)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardAsRawChatCompletions_ForcesStreamUsageUpstreamAndPassesUsageDownstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`,
|
||||
"",
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":4,"total_tokens":13,"prompt_tokens_details":{"cached_tokens":3}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_usage"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: rawChatCompletionsTestConfig(),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := rawChatCompletionsTestAccount()
|
||||
|
||||
result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 9, result.Usage.InputTokens)
|
||||
require.Equal(t, 4, result.Usage.OutputTokens)
|
||||
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.NoError(t, upstream.lastReq.Context().Err())
|
||||
require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool())
|
||||
require.Contains(t, rec.Body.String(), `"usage"`)
|
||||
require.Contains(t, rec.Body.String(), "data: [DONE]")
|
||||
}
|
||||
|
||||
func TestForwardAsRawChatCompletions_ClientDisconnectDrainsUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`,
|
||||
"",
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":17,"completion_tokens":8,"total_tokens":25,"prompt_tokens_details":{"cached_tokens":6}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_disconnect"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: rawChatCompletionsTestConfig(),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := rawChatCompletionsTestAccount()
|
||||
|
||||
result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 17, result.Usage.InputTokens)
|
||||
require.Equal(t, 8, result.Usage.OutputTokens)
|
||||
require.Equal(t, 6, result.Usage.CacheReadInputTokens)
|
||||
require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool())
|
||||
}
|
||||
|
||||
func TestForwardAsRawChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
cancel()
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_ctx"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: rawChatCompletionsTestConfig(),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := rawChatCompletionsTestAccount()
|
||||
|
||||
result, err := svc.forwardAsRawChatCompletions(reqCtx, c, account, body, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.NoError(t, upstream.lastReq.Context().Err())
|
||||
}
|
||||
|
||||
func TestIsOpenAIChatUsageOnlyStreamChunk(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.True(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[],"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
|
||||
require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[{"index":0}],"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
|
||||
require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[]}`))
|
||||
require.False(t, isOpenAIChatUsageOnlyStreamChunk(``))
|
||||
}
|
||||
|
||||
func TestEnsureOpenAIChatStreamUsage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body, err := ensureOpenAIChatStreamUsage([]byte(`{"model":"gpt-5.4"}`))
|
||||
require.NoError(t, err)
|
||||
require.True(t, gjson.GetBytes(body, "stream_options.include_usage").Bool())
|
||||
|
||||
body, err = ensureOpenAIChatStreamUsage([]byte(`{"model":"gpt-5.4","stream_options":{"include_usage":false}}`))
|
||||
require.NoError(t, err)
|
||||
require.True(t, gjson.GetBytes(body, "stream_options.include_usage").Bool())
|
||||
}
|
||||
|
||||
func TestBufferRawChatCompletions_RejectsOversizedResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader("toolong")),
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: rawChatCompletionsTestConfig()}
|
||||
svc.cfg.Gateway.UpstreamResponseReadMaxBytes = 3
|
||||
|
||||
result, err := svc.bufferRawChatCompletions(c, resp, "gpt-5.4", "gpt-5.4", "gpt-5.4", nil, nil, time.Now())
|
||||
require.ErrorIs(t, err, ErrUpstreamResponseBodyTooLarge)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, http.StatusBadGateway, rec.Code)
|
||||
}
|
||||
|
||||
func rawChatCompletionsTestConfig() *config.Config {
|
||||
return &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{
|
||||
Enabled: false,
|
||||
AllowInsecureHTTP: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func rawChatCompletionsTestAccount() *Account {
|
||||
return &Account{
|
||||
ID: 101,
|
||||
Name: "raw-openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": "http://upstream.example",
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -1,13 +1,36 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type openAIChatFailingWriter struct {
|
||||
gin.ResponseWriter
|
||||
failAfter int
|
||||
writes int
|
||||
}
|
||||
|
||||
func (w *openAIChatFailingWriter) Write(p []byte) (int, error) {
|
||||
if w.writes >= w.failAfter {
|
||||
return 0, errors.New("write failed: client disconnected")
|
||||
}
|
||||
w.writes++
|
||||
return w.ResponseWriter.Write(p)
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -73,3 +96,278 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
|
||||
require.Empty(t, tier)
|
||||
require.False(t, gjson.GetBytes(body, "service_tier").Exists())
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_UnknownModelDoesNotUseDefaultMappedModel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := []byte(`{"model":"gpt6","messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_chat_unknown_model"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":{"type":"invalid_request_error","message":"model not found"}}`)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.4")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, "gpt6", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.NotEqual(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`,
|
||||
"",
|
||||
`data: {"type":"response.output_text.delta","delta":"ok"}`,
|
||||
"",
|
||||
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":11,"output_tokens":5,"total_tokens":16,"input_tokens_details":{"cached_tokens":4}}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_disconnect"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 11, result.Usage.InputTokens)
|
||||
require.Equal(t, 5, result.Usage.OutputTokens)
|
||||
require.Equal(t, 4, result.Usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n")
|
||||
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
|
||||
defer func() {
|
||||
require.NoError(t, upstreamStream.Close())
|
||||
}()
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_terminal_no_close"}},
|
||||
Body: upstreamStream,
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
type forwardResult struct {
|
||||
result *OpenAIForwardResult
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan forwardResult, 1)
|
||||
go func() {
|
||||
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
|
||||
resultCh <- forwardResult{result: result, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case got := <-resultCh:
|
||||
require.NoError(t, got.err)
|
||||
require.NotNil(t, got.result)
|
||||
require.Equal(t, 17, got.result.Usage.InputTokens)
|
||||
require.Equal(t, 8, got.result.Usage.OutputTokens)
|
||||
require.Equal(t, 6, got.result.Usage.CacheReadInputTokens)
|
||||
case <-time.After(time.Second):
|
||||
require.Fail(t, "ForwardAsChatCompletions should return after terminal usage event even if upstream keeps the connection open")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n")
|
||||
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
|
||||
defer func() {
|
||||
require.NoError(t, upstreamStream.Close())
|
||||
}()
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_buffered_terminal_no_close"}},
|
||||
Body: upstreamStream,
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
type forwardResult struct {
|
||||
result *OpenAIForwardResult
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan forwardResult, 1)
|
||||
go func() {
|
||||
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
|
||||
resultCh <- forwardResult{result: result, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case got := <-resultCh:
|
||||
require.NoError(t, got.err)
|
||||
require.NotNil(t, got.result)
|
||||
require.Equal(t, 17, got.result.Usage.InputTokens)
|
||||
require.Equal(t, 8, got.result.Usage.OutputTokens)
|
||||
require.Equal(t, 6, got.result.Usage.CacheReadInputTokens)
|
||||
require.Contains(t, rec.Body.String(), `"finish_reason":"stop"`)
|
||||
case <-time.After(time.Second):
|
||||
require.Fail(t, "ForwardAsChatCompletions buffered response should return after terminal usage event even if upstream keeps the connection open")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_DoneSentinelWithoutTerminalReturnsError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := "data: [DONE]\n\n"
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_missing_terminal"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing terminal event")
|
||||
require.NotNil(t, result)
|
||||
require.Zero(t, result.Usage.InputTokens)
|
||||
require.Zero(t, result.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
cancel()
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_ctx"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsChatCompletions(reqCtx, c, account, body, "", "gpt-5.1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.NoError(t, upstream.lastReq.Context().Err())
|
||||
}
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
@ -39,12 +40,54 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
if err := json.Unmarshal(body, &anthropicReq); err != nil {
|
||||
return nil, fmt.Errorf("parse anthropic request: %w", err)
|
||||
}
|
||||
anthropicDigestReq := cloneAnthropicRequestForDigest(&anthropicReq)
|
||||
originalModel := anthropicReq.Model
|
||||
applyOpenAICompatModelNormalization(&anthropicReq)
|
||||
normalizedModel := anthropicReq.Model
|
||||
clientStream := anthropicReq.Stream // client's original stream preference
|
||||
|
||||
// 2. Convert Anthropic → Responses
|
||||
// 2. Model mapping
|
||||
billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel)
|
||||
upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
|
||||
promptCacheKey = strings.TrimSpace(promptCacheKey)
|
||||
apiKeyID := getAPIKeyIDFromContext(c)
|
||||
anthropicDigestChain := ""
|
||||
anthropicMatchedDigestChain := ""
|
||||
compatPromptCacheInjected := false
|
||||
if promptCacheKey == "" && shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) {
|
||||
promptCacheKey = promptCacheKeyFromAnthropicMetadataSession(&anthropicReq)
|
||||
if promptCacheKey == "" {
|
||||
promptCacheKey = deriveAnthropicCacheControlPromptCacheKey(&anthropicReq)
|
||||
}
|
||||
if promptCacheKey == "" {
|
||||
anthropicDigestChain = buildOpenAICompatAnthropicDigestChain(anthropicDigestReq)
|
||||
if reusedKey, matchedChain := s.findOpenAICompatAnthropicDigestPromptCacheKey(account, apiKeyID, anthropicDigestChain); reusedKey != "" {
|
||||
promptCacheKey = reusedKey
|
||||
anthropicMatchedDigestChain = matchedChain
|
||||
} else {
|
||||
promptCacheKey = promptCacheKeyFromAnthropicDigest(anthropicDigestChain)
|
||||
}
|
||||
}
|
||||
compatPromptCacheInjected = promptCacheKey != ""
|
||||
}
|
||||
compatReplayTrimmed := false
|
||||
compatReplayGuardEnabled := shouldAutoInjectPromptCacheKeyForCompat(upstreamModel)
|
||||
compatContinuationEnabled := openAICompatContinuationEnabled(account, upstreamModel)
|
||||
previousResponseID := ""
|
||||
if compatContinuationEnabled {
|
||||
previousResponseID = s.getOpenAICompatSessionResponseID(ctx, c, account, promptCacheKey)
|
||||
}
|
||||
compatContinuationDisabled := compatContinuationEnabled &&
|
||||
s.isOpenAICompatSessionContinuationDisabled(ctx, c, account, promptCacheKey)
|
||||
compatTurnState := ""
|
||||
// OAuth/Plus relies on session_id + x-codex-turn-state; trimming to a
|
||||
// sliding 12-message window makes the cached prefix stall at system/tools.
|
||||
// Keep full replay there so upstream prompt caching can grow turn by turn.
|
||||
if compatReplayGuardEnabled && account.Type != AccountTypeOAuth && previousResponseID == "" && !compatContinuationDisabled {
|
||||
compatReplayTrimmed = applyAnthropicCompatFullReplayGuard(&anthropicReq)
|
||||
}
|
||||
|
||||
// 3. Convert Anthropic → Responses after compatibility-only replay guard.
|
||||
responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert anthropic to responses: %w", err)
|
||||
@ -55,24 +98,50 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
responsesReq.Stream = true
|
||||
isStream := true
|
||||
|
||||
// 2b. Handle BetaFastMode → service_tier: "priority"
|
||||
// 3b. Handle BetaFastMode → service_tier: "priority"
|
||||
if containsBetaToken(c.GetHeader("anthropic-beta"), claude.BetaFastMode) {
|
||||
responsesReq.ServiceTier = "priority"
|
||||
}
|
||||
|
||||
// 3. Model mapping
|
||||
billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel)
|
||||
upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
|
||||
responsesReq.Model = upstreamModel
|
||||
if previousResponseID != "" {
|
||||
responsesReq.PreviousResponseID = previousResponseID
|
||||
trimAnthropicCompatResponsesInputToLatestTurn(responsesReq)
|
||||
}
|
||||
if compatReplayGuardEnabled && account.Type != AccountTypeOAuth {
|
||||
appendOpenAICompatClaudeCodeTodoGuard(responsesReq)
|
||||
}
|
||||
|
||||
logger.L().Debug("openai messages: model mapping applied",
|
||||
logFields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("original_model", originalModel),
|
||||
zap.String("normalized_model", normalizedModel),
|
||||
zap.String("billing_model", billingModel),
|
||||
zap.String("upstream_model", upstreamModel),
|
||||
zap.Bool("stream", isStream),
|
||||
)
|
||||
}
|
||||
if compatPromptCacheInjected {
|
||||
logFields = append(logFields,
|
||||
zap.Bool("compat_prompt_cache_key_injected", true),
|
||||
zap.String("compat_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey)),
|
||||
)
|
||||
}
|
||||
if compatReplayTrimmed {
|
||||
logFields = append(logFields,
|
||||
zap.Bool("compat_full_replay_trimmed", true),
|
||||
zap.Int("compat_messages_after_trim", len(anthropicReq.Messages)),
|
||||
)
|
||||
}
|
||||
if previousResponseID != "" {
|
||||
logFields = append(logFields,
|
||||
zap.Bool("compat_previous_response_id_attached", true),
|
||||
zap.String("compat_previous_response_id", truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen)),
|
||||
)
|
||||
}
|
||||
if compatTurnState != "" {
|
||||
logFields = append(logFields, zap.Bool("compat_turn_state_attached", true))
|
||||
}
|
||||
logger.L().Debug("openai messages: model mapping applied", logFields...)
|
||||
|
||||
// 4. Marshal Responses request body, then apply OAuth codex transform
|
||||
responsesBody, err := json.Marshal(responsesReq)
|
||||
@ -85,7 +154,10 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
|
||||
}
|
||||
codexResult := applyCodexOAuthTransform(reqBody, false, false)
|
||||
codexResult := applyCodexOAuthTransformWithOptions(reqBody, codexOAuthTransformOptions{
|
||||
SkipDefaultInstructions: true,
|
||||
PreserveToolCallIDs: true,
|
||||
})
|
||||
forcedTemplateText := ""
|
||||
if s.cfg != nil {
|
||||
forcedTemplateText = s.cfg.Gateway.ForcedCodexInstructionsTemplate
|
||||
@ -95,6 +167,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
templateUpstreamModel = codexResult.NormalizedModel
|
||||
}
|
||||
existingInstructions, _ := reqBody["instructions"].(string)
|
||||
if strings.TrimSpace(existingInstructions) == "" {
|
||||
existingInstructions = extractPromptLikeInstructionsFromInput(reqBody)
|
||||
}
|
||||
if _, err := applyForcedCodexInstructionsTemplate(reqBody, forcedTemplateText, forcedCodexInstructionsTemplateData{
|
||||
ExistingInstructions: strings.TrimSpace(existingInstructions),
|
||||
OriginalModel: originalModel,
|
||||
@ -104,13 +179,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ensureCodexOAuthInstructionsField(reqBody)
|
||||
if shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) {
|
||||
appendOpenAICompatClaudeCodeTodoGuardToRequestBody(reqBody)
|
||||
}
|
||||
if codexResult.NormalizedModel != "" {
|
||||
upstreamModel = codexResult.NormalizedModel
|
||||
}
|
||||
if codexResult.PromptCacheKey != "" {
|
||||
promptCacheKey = codexResult.PromptCacheKey
|
||||
} else if promptCacheKey != "" {
|
||||
reqBody["prompt_cache_key"] = promptCacheKey
|
||||
}
|
||||
delete(reqBody, "prompt_cache_key")
|
||||
if shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) {
|
||||
compatTurnState = s.getOpenAICompatSessionTurnState(ctx, c, account, promptCacheKey)
|
||||
}
|
||||
// OAuth codex transform forces stream=true upstream, so always use
|
||||
// the streaming response handler regardless of what the client asked.
|
||||
@ -163,7 +244,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
}
|
||||
|
||||
// 6. Build upstream request
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, isStream, promptCacheKey, false)
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, isStream, promptCacheKey, false)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
@ -171,8 +254,25 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
// Override session_id with a deterministic UUID derived from the isolated
|
||||
// session key, ensuring different API keys produce different upstream sessions.
|
||||
if promptCacheKey != "" {
|
||||
apiKeyID := getAPIKeyIDFromContext(c)
|
||||
upstreamReq.Header.Set("session_id", generateSessionUUID(isolateOpenAISessionID(apiKeyID, promptCacheKey)))
|
||||
isolatedSessionID := generateSessionUUID(isolateOpenAISessionID(apiKeyID, promptCacheKey))
|
||||
upstreamReq.Header.Set("session_id", isolatedSessionID)
|
||||
if upstreamReq.Header.Get("conversation_id") != "" {
|
||||
upstreamReq.Header.Set("conversation_id", isolatedSessionID)
|
||||
}
|
||||
}
|
||||
if account.Type == AccountTypeOAuth {
|
||||
// Anthropic Messages compatibility uses the ChatGPT Codex SSE endpoint.
|
||||
// Match airgate-openai's request shape: the SSE endpoint does not need
|
||||
// the Responses experimental beta header, and forcing originator can make
|
||||
// ChatGPT select a different internal continuation path.
|
||||
upstreamReq.Header.Del("OpenAI-Beta")
|
||||
upstreamReq.Header.Del("originator")
|
||||
}
|
||||
if account.Type == AccountTypeOAuth && promptCacheKey != "" && strings.TrimSpace(c.GetHeader("conversation_id")) == "" {
|
||||
upstreamReq.Header.Del("conversation_id")
|
||||
}
|
||||
if compatTurnState != "" && upstreamReq.Header.Get("x-codex-turn-state") == "" {
|
||||
upstreamReq.Header.Set("x-codex-turn-state", compatTurnState)
|
||||
}
|
||||
|
||||
// 7. Send request
|
||||
@ -205,6 +305,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if previousResponseID != "" && (isOpenAICompatPreviousResponseNotFound(resp.StatusCode, upstreamMsg, respBody) || isOpenAICompatPreviousResponseUnsupported(resp.StatusCode, upstreamMsg, respBody)) {
|
||||
if isOpenAICompatPreviousResponseUnsupported(resp.StatusCode, upstreamMsg, respBody) {
|
||||
s.disableOpenAICompatSessionContinuation(ctx, c, account, promptCacheKey)
|
||||
} else {
|
||||
s.deleteOpenAICompatSessionResponseID(ctx, c, account, promptCacheKey)
|
||||
}
|
||||
logger.L().Info("openai messages: previous_response_id unavailable, retrying without continuation",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("previous_response_id", truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen)),
|
||||
zap.String("upstream_model", upstreamModel),
|
||||
)
|
||||
return s.ForwardAsAnthropic(ctx, c, account, body, promptCacheKey, defaultMappedModel)
|
||||
}
|
||||
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
@ -237,6 +350,12 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
return s.handleAnthropicErrorResponse(resp, c, account)
|
||||
}
|
||||
|
||||
if account.Type == AccountTypeOAuth && promptCacheKey != "" {
|
||||
if turnState := strings.TrimSpace(resp.Header.Get("x-codex-turn-state")); turnState != "" {
|
||||
s.bindOpenAICompatSessionTurnState(ctx, c, account, promptCacheKey, turnState)
|
||||
}
|
||||
}
|
||||
|
||||
// 9. Handle normal response
|
||||
// Upstream is always streaming; choose response format based on client preference.
|
||||
var result *OpenAIForwardResult
|
||||
@ -250,6 +369,12 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
|
||||
// Propagate ServiceTier and ReasoningEffort to result for billing
|
||||
if handleErr == nil && result != nil {
|
||||
if compatContinuationEnabled && promptCacheKey != "" && result.ResponseID != "" {
|
||||
s.bindOpenAICompatSessionResponseID(ctx, c, account, promptCacheKey, result.ResponseID)
|
||||
}
|
||||
if promptCacheKey != "" && anthropicDigestChain != "" {
|
||||
s.bindOpenAICompatAnthropicDigestPromptCacheKey(account, apiKeyID, anthropicDigestChain, promptCacheKey, anthropicMatchedDigestChain)
|
||||
}
|
||||
if responsesReq.ServiceTier != "" {
|
||||
st := responsesReq.ServiceTier
|
||||
result.ServiceTier = &st
|
||||
@ -270,6 +395,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
return result, handleErr
|
||||
}
|
||||
|
||||
func ensureCodexOAuthInstructionsField(reqBody map[string]any) {
|
||||
if reqBody == nil {
|
||||
return
|
||||
}
|
||||
if value, ok := reqBody["instructions"]; !ok || value == nil {
|
||||
reqBody["instructions"] = ""
|
||||
return
|
||||
}
|
||||
if _, ok := reqBody["instructions"].(string); !ok {
|
||||
reqBody["instructions"] = ""
|
||||
}
|
||||
}
|
||||
|
||||
// handleAnthropicErrorResponse reads an upstream error and returns it in
|
||||
// Anthropic error format.
|
||||
func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
|
||||
@ -296,61 +434,9 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
var finalResponse *apicompat.ResponsesResponse
|
||||
var usage OpenAIUsage
|
||||
acc := apicompat.NewBufferedResponseAccumulator()
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
}
|
||||
payload := line[6:]
|
||||
|
||||
var event apicompat.ResponsesStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
logger.L().Warn("openai messages buffered: failed to parse event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// Accumulate delta content for fallback when terminal output is empty.
|
||||
acc.ProcessEvent(&event)
|
||||
|
||||
// Terminal events carry the complete ResponsesResponse with output + usage.
|
||||
if (event.Type == "response.completed" || event.Type == "response.done" ||
|
||||
event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||
event.Response != nil {
|
||||
finalResponse = event.Response
|
||||
if event.Response.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: event.Response.Usage.InputTokens,
|
||||
OutputTokens: event.Response.Usage.OutputTokens,
|
||||
}
|
||||
if event.Response.Usage.InputTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("openai messages buffered: read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai messages buffered", requestID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if finalResponse == nil {
|
||||
@ -371,6 +457,7 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
ResponseID: finalResponse.ID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: billingModel,
|
||||
@ -380,6 +467,153 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
|
||||
}, nil
|
||||
}
|
||||
|
||||
func isOpenAICompatResponsesTerminalEvent(eventType string) bool {
|
||||
switch strings.TrimSpace(eventType) {
|
||||
case "response.completed", "response.done", "response.incomplete", "response.failed":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isOpenAICompatDoneSentinelLine(line string) bool {
|
||||
payload, ok := extractOpenAISSEDataLine(line)
|
||||
return ok && strings.TrimSpace(payload) == "[DONE]"
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) readOpenAICompatBufferedTerminal(
|
||||
resp *http.Response,
|
||||
logPrefix string,
|
||||
requestID string,
|
||||
) (*apicompat.ResponsesResponse, OpenAIUsage, *apicompat.BufferedResponseAccumulator, error) {
|
||||
acc := apicompat.NewBufferedResponseAccumulator()
|
||||
var usage OpenAIUsage
|
||||
if resp == nil || resp.Body == nil {
|
||||
return nil, usage, acc, errors.New("upstream response body is nil")
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
var timeoutCh <-chan time.Time
|
||||
var timeoutTimer *time.Timer
|
||||
resetTimeout := func() {
|
||||
if streamInterval <= 0 {
|
||||
return
|
||||
}
|
||||
if timeoutTimer == nil {
|
||||
timeoutTimer = time.NewTimer(streamInterval)
|
||||
timeoutCh = timeoutTimer.C
|
||||
return
|
||||
}
|
||||
if !timeoutTimer.Stop() {
|
||||
select {
|
||||
case <-timeoutTimer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
timeoutTimer.Reset(streamInterval)
|
||||
}
|
||||
stopTimeout := func() {
|
||||
if timeoutTimer == nil {
|
||||
return
|
||||
}
|
||||
if !timeoutTimer.Stop() {
|
||||
select {
|
||||
case <-timeoutTimer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
resetTimeout()
|
||||
defer stopTimeout()
|
||||
|
||||
type scanEvent struct {
|
||||
line string
|
||||
err error
|
||||
}
|
||||
events := make(chan scanEvent, 16)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case events <- scanEvent{line: scanner.Text()}:
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
select {
|
||||
case events <- scanEvent{err: err}:
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return nil, usage, acc, nil
|
||||
}
|
||||
resetTimeout()
|
||||
if ev.err != nil {
|
||||
if !errors.Is(ev.err, context.Canceled) && !errors.Is(ev.err, context.DeadlineExceeded) {
|
||||
logger.L().Warn(logPrefix+": read error",
|
||||
zap.Error(ev.err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
return nil, usage, acc, ev.err
|
||||
}
|
||||
|
||||
if isOpenAICompatDoneSentinelLine(ev.line) {
|
||||
return nil, usage, acc, nil
|
||||
}
|
||||
payload, ok := extractOpenAISSEDataLine(ev.line)
|
||||
if !ok || payload == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var event apicompat.ResponsesStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
logger.L().Warn(logPrefix+": failed to parse event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
acc.ProcessEvent(&event)
|
||||
|
||||
if isOpenAICompatResponsesTerminalEvent(event.Type) && event.Response != nil {
|
||||
if event.Response.Usage != nil {
|
||||
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
|
||||
}
|
||||
return event.Response, usage, acc, nil
|
||||
}
|
||||
|
||||
case <-timeoutCh:
|
||||
_ = resp.Body.Close()
|
||||
logger.L().Warn(logPrefix+": data interval timeout",
|
||||
zap.String("request_id", requestID),
|
||||
zap.Duration("interval", streamInterval),
|
||||
)
|
||||
return nil, usage, acc, fmt.Errorf("stream data interval timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleAnthropicStreamingResponse reads Responses SSE events from upstream,
|
||||
// converts each to Anthropic SSE events, and writes them to the client.
|
||||
// When StreamKeepaliveInterval is configured, it uses a goroutine + channel
|
||||
@ -407,8 +641,10 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
state := apicompat.NewResponsesEventToAnthropicState()
|
||||
state.Model = originalModel
|
||||
var usage OpenAIUsage
|
||||
responseID := ""
|
||||
var firstTokenMs *int
|
||||
firstChunk := true
|
||||
clientDisconnected := false
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
@ -417,10 +653,25 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
var intervalTicker *time.Ticker
|
||||
if streamInterval > 0 {
|
||||
intervalTicker = time.NewTicker(streamInterval)
|
||||
defer intervalTicker.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTicker != nil {
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
// resultWithUsage builds the final result snapshot.
|
||||
resultWithUsage := func() *OpenAIForwardResult {
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
ResponseID: responseID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: billingModel,
|
||||
@ -432,7 +683,6 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
}
|
||||
|
||||
// processDataLine handles a single "data: ..." SSE line from upstream.
|
||||
// Returns (clientDisconnected bool).
|
||||
processDataLine := func(payload string) bool {
|
||||
if firstChunk {
|
||||
firstChunk = false
|
||||
@ -449,53 +699,63 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
return false
|
||||
}
|
||||
|
||||
// Extract usage from completion events
|
||||
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||
event.Response != nil && event.Response.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: event.Response.Usage.InputTokens,
|
||||
OutputTokens: event.Response.Usage.OutputTokens,
|
||||
// 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。
|
||||
isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
|
||||
if isTerminalEvent && event.Response != nil {
|
||||
if id := strings.TrimSpace(event.Response.ID); id != "" {
|
||||
responseID = id
|
||||
}
|
||||
if event.Response.Usage.InputTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||
if event.Response.Usage != nil {
|
||||
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to Anthropic events
|
||||
events := apicompat.ResponsesEventToAnthropicEvents(&event, state)
|
||||
for _, evt := range events {
|
||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||
if err != nil {
|
||||
logger.L().Warn("openai messages stream: failed to marshal event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
logger.L().Info("openai messages stream: client disconnected",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return true
|
||||
if !clientDisconnected {
|
||||
for _, evt := range events {
|
||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||
if err != nil {
|
||||
logger.L().Warn("openai messages stream: failed to marshal event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Info("openai messages stream: client disconnected, continuing to drain upstream for billing",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(events) > 0 {
|
||||
if len(events) > 0 && !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return false
|
||||
return isTerminalEvent
|
||||
}
|
||||
|
||||
// finalizeStream sends any remaining Anthropic events and returns the result.
|
||||
finalizeStream := func() (*OpenAIForwardResult, error) {
|
||||
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 {
|
||||
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 && !clientDisconnected {
|
||||
for _, evt := range finalEvents {
|
||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Info("openai messages stream: client disconnected during final flush",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
if !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
@ -509,6 +769,9 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
)
|
||||
}
|
||||
}
|
||||
missingTerminalErr := func() (*OpenAIForwardResult, error) {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
|
||||
// ── Determine keepalive interval ──
|
||||
keepaliveInterval := time.Duration(0)
|
||||
@ -517,18 +780,25 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
}
|
||||
|
||||
// ── No keepalive: fast synchronous path (no goroutine overhead) ──
|
||||
if keepaliveInterval <= 0 {
|
||||
if streamInterval <= 0 && keepaliveInterval <= 0 {
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
if isOpenAICompatDoneSentinelLine(line) {
|
||||
return missingTerminalErr()
|
||||
}
|
||||
payload, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if processDataLine(line[6:]) {
|
||||
return resultWithUsage(), nil
|
||||
if processDataLine(payload) {
|
||||
return finalizeStream()
|
||||
}
|
||||
}
|
||||
handleScanErr(scanner.Err())
|
||||
return finalizeStream()
|
||||
if err := scanner.Err(); err != nil {
|
||||
handleScanErr(err)
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
|
||||
}
|
||||
return missingTerminalErr()
|
||||
}
|
||||
|
||||
// ── With keepalive: goroutine + channel + select ──
|
||||
@ -538,6 +808,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
}
|
||||
events := make(chan scanEvent, 16)
|
||||
done := make(chan struct{})
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
@ -549,6 +821,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
@ -559,8 +832,15 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
keepaliveTicker := time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
}
|
||||
var keepaliveCh <-chan time.Time
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
lastDataAt := time.Now()
|
||||
|
||||
for {
|
||||
@ -568,22 +848,44 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
// Upstream closed
|
||||
return finalizeStream()
|
||||
return missingTerminalErr()
|
||||
}
|
||||
if ev.err != nil {
|
||||
handleScanErr(ev.err)
|
||||
return finalizeStream()
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
|
||||
}
|
||||
lastDataAt = time.Now()
|
||||
line := ev.line
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
if isOpenAICompatDoneSentinelLine(line) {
|
||||
return missingTerminalErr()
|
||||
}
|
||||
payload, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if processDataLine(line[6:]) {
|
||||
return resultWithUsage(), nil
|
||||
if processDataLine(payload) {
|
||||
return finalizeStream()
|
||||
}
|
||||
|
||||
case <-keepaliveTicker.C:
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
|
||||
}
|
||||
logger.L().Warn("openai messages stream: data interval timeout",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("model", originalModel),
|
||||
zap.Duration("interval", streamInterval),
|
||||
)
|
||||
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
if clientDisconnected {
|
||||
continue
|
||||
}
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
@ -593,7 +895,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
logger.L().Info("openai messages stream: client disconnected during keepalive",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return resultWithUsage(), nil
|
||||
clientDisconnected = true
|
||||
continue
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
@ -610,3 +913,17 @@ func writeAnthropicError(c *gin.Context, statusCode int, errType, message string
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func copyOpenAIUsageFromResponsesUsage(usage *apicompat.ResponsesUsage) OpenAIUsage {
|
||||
if usage == nil {
|
||||
return OpenAIUsage{}
|
||||
}
|
||||
result := OpenAIUsage{
|
||||
InputTokens: usage.InputTokens,
|
||||
OutputTokens: usage.OutputTokens,
|
||||
}
|
||||
if usage.InputTokensDetails != nil {
|
||||
result.CacheReadInputTokens = usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@ -52,6 +52,12 @@ func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *Usage
|
||||
return &UsageBillingApplyResult{Applied: true}, nil
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_RejectsNilInput(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
require.Error(t, svc.RecordUsage(context.Background(), nil))
|
||||
require.Error(t, svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{}))
|
||||
}
|
||||
|
||||
type openAIRecordUsageUserRepoStub struct {
|
||||
UserRepository
|
||||
|
||||
@ -186,6 +192,56 @@ func max(a, b int) int {
|
||||
return b
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ZeroUsageStillWritesUsageLog(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_zero_usage",
|
||||
Usage: OpenAIUsage{},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1000, Quota: 100, Group: &Group{RateMultiplier: 1}},
|
||||
User: &User{ID: 2000},
|
||||
Account: &Account{ID: 3000, Type: AccountTypeAPIKey},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
require.Equal(t, 0, quotaSvc.quotaCalls)
|
||||
require.Equal(t, 0, quotaSvc.rateLimitCalls)
|
||||
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "resp_zero_usage", usageRepo.lastLog.RequestID)
|
||||
require.Zero(t, usageRepo.lastLog.InputTokens)
|
||||
require.Zero(t, usageRepo.lastLog.OutputTokens)
|
||||
require.Zero(t, usageRepo.lastLog.CacheCreationTokens)
|
||||
require.Zero(t, usageRepo.lastLog.CacheReadTokens)
|
||||
require.Zero(t, usageRepo.lastLog.ImageOutputTokens)
|
||||
require.Zero(t, usageRepo.lastLog.ImageCount)
|
||||
require.Zero(t, usageRepo.lastLog.InputCost)
|
||||
require.Zero(t, usageRepo.lastLog.OutputCost)
|
||||
require.Zero(t, usageRepo.lastLog.TotalCost)
|
||||
require.Zero(t, usageRepo.lastLog.ActualCost)
|
||||
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Zero(t, billingRepo.lastCmd.BalanceCost)
|
||||
require.Zero(t, billingRepo.lastCmd.SubscriptionCost)
|
||||
require.Zero(t, billingRepo.lastCmd.APIKeyQuotaCost)
|
||||
require.Zero(t, billingRepo.lastCmd.APIKeyRateLimitCost)
|
||||
require.Zero(t, billingRepo.lastCmd.AccountQuotaCost)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) {
|
||||
groupID := int64(11)
|
||||
groupRate := 1.4
|
||||
@ -956,9 +1012,8 @@ func TestOpenAIGatewayServiceRecordUsage_ChannelMappedDoesNotOverrideBillingMode
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
|
||||
|
||||
// When channel did NOT map the model (ChannelMappedModel == OriginalModel),
|
||||
// billing should use result.BillingModel (the actual model used after group
|
||||
// DefaultMappedModel resolution), not the unmapped original model.
|
||||
// 渠道未发生模型映射时,应使用 result.BillingModel 中记录的实际上游计费模型,
|
||||
// 而不是未映射的原始请求模型。
|
||||
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
@ -1032,6 +1087,101 @@ func TestOpenAIGatewayServiceRecordUsage_ChannelMappedOverridesBillingModelWhenM
|
||||
require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillsCompactOpenAIModelAlias(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost("gpt-5.5", UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_compact_openai_alias",
|
||||
Model: "gpt5.5",
|
||||
UpstreamModel: "gpt-5.4",
|
||||
Usage: usage,
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10},
|
||||
User: &User{ID: 20},
|
||||
Account: &Account{ID: 30},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "gpt5.5", usageRepo.lastLog.Model)
|
||||
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
|
||||
require.Equal(t, "gpt-5.4", *usageRepo.lastLog.UpstreamModel)
|
||||
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero")
|
||||
require.InDelta(t, expectedCost.ActualCost, userRepo.lastAmount, 1e-12)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_FallsBackToUpstreamModelWhenPrimaryUnpriceable(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_unpriceable_primary_upstream_fallback",
|
||||
Model: "not-priceable-alias",
|
||||
BillingModel: "not-priceable-alias",
|
||||
UpstreamModel: "gpt-5.4",
|
||||
Usage: usage,
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10},
|
||||
User: &User{ID: 20},
|
||||
Account: &Account{ID: 30},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero")
|
||||
require.InDelta(t, expectedCost.ActualCost, userRepo.lastAmount, 1e-12)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ReturnsErrorWhenTokenModelCannotBePriced(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_unpriceable_without_upstream",
|
||||
Model: "not-priceable-alias",
|
||||
Usage: OpenAIUsage{InputTokens: 20, OutputTokens: 10},
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10},
|
||||
User: &User{ID: 20},
|
||||
Account: &Account{ID: 30},
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "calculate OpenAI usage cost failed")
|
||||
require.Equal(t, 0, usageRepo.calls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
@ -1160,3 +1310,278 @@ func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTo
|
||||
require.InDelta(t, 0.0, usageRepo.lastLog.OutputCost, 1e-12)
|
||||
require.InDelta(t, 0.0, usageRepo.lastLog.ImageOutputCost, 1e-12)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ImageSharedMultiplierPreservesExistingBehavior(t *testing.T) {
|
||||
imagePrice := 0.2
|
||||
groupID := int64(121)
|
||||
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_image_shared_multiplier",
|
||||
Model: "gpt-image-2",
|
||||
ImageCount: 1,
|
||||
ImageSize: "1K",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10121,
|
||||
GroupID: i64p(groupID),
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
RateMultiplier: 0.15,
|
||||
ImageRateIndependent: false,
|
||||
ImageRateMultiplier: 1,
|
||||
ImagePrice1K: &imagePrice,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 20121},
|
||||
Account: &Account{ID: 30121},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.InDelta(t, 0.2, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
require.InDelta(t, 0.03, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, 0.15, usageRepo.lastLog.RateMultiplier, 1e-12)
|
||||
require.NotNil(t, usageRepo.lastLog.BillingMode)
|
||||
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ImageSharedMultiplierUsesUserGroupOverride(t *testing.T) {
|
||||
imagePrice := 0.5
|
||||
userRate := 0.2
|
||||
groupID := int64(125)
|
||||
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
svc := newOpenAIRecordUsageServiceForTest(
|
||||
usageRepo,
|
||||
&openAIRecordUsageUserRepoStub{},
|
||||
&openAIRecordUsageSubRepoStub{},
|
||||
&openAIUserGroupRateRepoStub{rate: &userRate},
|
||||
)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_image_user_group_override",
|
||||
Model: "gpt-image-2",
|
||||
ImageCount: 1,
|
||||
ImageSize: "1K",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10125,
|
||||
GroupID: i64p(groupID),
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
RateMultiplier: 0.15,
|
||||
ImageRateIndependent: false,
|
||||
ImageRateMultiplier: 1,
|
||||
ImagePrice1K: &imagePrice,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 20125},
|
||||
Account: &Account{ID: 30125},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.InDelta(t, 0.5, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
require.InDelta(t, 0.1, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, 0.2, usageRepo.lastLog.RateMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ImageIndependentMultiplierUsesImageRate(t *testing.T) {
|
||||
imagePrice := 0.2
|
||||
groupID := int64(122)
|
||||
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_image_independent_multiplier",
|
||||
Model: "gpt-image-2",
|
||||
ImageCount: 1,
|
||||
ImageSize: "1K",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10122,
|
||||
GroupID: i64p(groupID),
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
RateMultiplier: 0.15,
|
||||
ImageRateIndependent: true,
|
||||
ImageRateMultiplier: 1,
|
||||
ImagePrice1K: &imagePrice,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 20122},
|
||||
Account: &Account{ID: 30122},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.InDelta(t, 0.2, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
require.InDelta(t, 0.2, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, 1.0, usageRepo.lastLog.RateMultiplier, 1e-12)
|
||||
require.NotNil(t, usageRepo.lastLog.BillingMode)
|
||||
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ChannelImageBillingUsesImageCountAndSharedMultiplier(t *testing.T) {
|
||||
groupID := int64(123)
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
|
||||
svc.resolver = newOpenAIImageChannelPricingResolverForTest(t, groupID, "gpt-image-2", 0.25)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_image_channel_shared",
|
||||
Model: "gpt-image-2",
|
||||
ImageCount: 3,
|
||||
ImageSize: "1K",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10123,
|
||||
GroupID: i64p(groupID),
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
RateMultiplier: 0.15,
|
||||
ImageRateIndependent: false,
|
||||
ImageRateMultiplier: 1,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 20123},
|
||||
Account: &Account{ID: 30123},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.InDelta(t, 0.75, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
require.InDelta(t, 0.1125, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, 0.15, usageRepo.lastLog.RateMultiplier, 1e-12)
|
||||
require.Equal(t, 3, usageRepo.lastLog.ImageCount)
|
||||
require.NotNil(t, usageRepo.lastLog.BillingMode)
|
||||
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ChannelImageBillingUsesImageCountAndIndependentMultiplier(t *testing.T) {
|
||||
groupID := int64(124)
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
|
||||
svc.resolver = newOpenAIImageChannelPricingResolverForTest(t, groupID, "gpt-image-2", 0.25)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_image_channel_independent",
|
||||
Model: "gpt-image-2",
|
||||
ImageCount: 3,
|
||||
ImageSize: "1K",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10124,
|
||||
GroupID: i64p(groupID),
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
RateMultiplier: 0.15,
|
||||
ImageRateIndependent: true,
|
||||
ImageRateMultiplier: 1,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 20124},
|
||||
Account: &Account{ID: 30124},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.InDelta(t, 0.75, usageRepo.lastLog.TotalCost, 1e-12)
|
||||
require.InDelta(t, 0.75, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, 1.0, usageRepo.lastLog.RateMultiplier, 1e-12)
|
||||
require.Equal(t, 3, usageRepo.lastLog.ImageCount)
|
||||
require.NotNil(t, usageRepo.lastLog.BillingMode)
|
||||
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
|
||||
}
|
||||
|
||||
func newOpenAIImageChannelPricingResolverForTest(t *testing.T, groupID int64, model string, price float64) *ModelPricingResolver {
|
||||
t.Helper()
|
||||
cache := newEmptyChannelCache()
|
||||
cache.pricingByGroupModel[channelModelKey{groupID: groupID, model: model}] = &ChannelModelPricing{
|
||||
BillingMode: BillingModeImage,
|
||||
PerRequestPrice: &price,
|
||||
}
|
||||
cache.channelByGroupID[groupID] = &Channel{ID: groupID, Status: StatusActive}
|
||||
cache.groupPlatform[groupID] = ""
|
||||
cache.loadedAt = time.Now()
|
||||
cs := &ChannelService{}
|
||||
cs.cache.Store(cache)
|
||||
return NewModelPricingResolver(cs, NewBillingService(&config.Config{}, nil))
|
||||
}
|
||||
|
||||
func TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesImageCount(t *testing.T) {
|
||||
groupID := int64(126)
|
||||
billingService := NewBillingService(&config.Config{}, nil)
|
||||
svc := &GatewayService{
|
||||
billingService: billingService,
|
||||
resolver: newOpenAIImageChannelPricingResolverForTest(t, groupID, "gemini-image", 0.25),
|
||||
}
|
||||
|
||||
cost := svc.calculateRecordUsageCost(
|
||||
context.Background(),
|
||||
&ForwardResult{Model: "gemini-image", ImageCount: 2, ImageSize: "1K"},
|
||||
&APIKey{GroupID: i64p(groupID), Group: &Group{ID: groupID}},
|
||||
"gemini-image",
|
||||
0.15,
|
||||
1.0,
|
||||
nil,
|
||||
)
|
||||
|
||||
require.NotNil(t, cost)
|
||||
require.Equal(t, string(BillingModeImage), cost.BillingMode)
|
||||
require.InDelta(t, 0.5, cost.TotalCost, 1e-12)
|
||||
require.InDelta(t, 0.5, cost.ActualCost, 1e-12)
|
||||
}
|
||||
|
||||
func TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesSizeTier(t *testing.T) {
|
||||
groupID := int64(127)
|
||||
defaultPrice := 0.10
|
||||
price4K := 0.40
|
||||
cache := newEmptyChannelCache()
|
||||
cache.pricingByGroupModel[channelModelKey{groupID: groupID, model: "gemini-image"}] = &ChannelModelPricing{
|
||||
BillingMode: BillingModeImage,
|
||||
PerRequestPrice: &defaultPrice,
|
||||
Intervals: []PricingInterval{{
|
||||
TierLabel: "4K",
|
||||
PerRequestPrice: &price4K,
|
||||
}},
|
||||
}
|
||||
cache.channelByGroupID[groupID] = &Channel{ID: groupID, Status: StatusActive}
|
||||
cache.loadedAt = time.Now()
|
||||
channelService := &ChannelService{}
|
||||
channelService.cache.Store(cache)
|
||||
|
||||
svc := &GatewayService{
|
||||
billingService: NewBillingService(&config.Config{}, nil),
|
||||
resolver: NewModelPricingResolver(channelService, NewBillingService(&config.Config{}, nil)),
|
||||
}
|
||||
|
||||
cost := svc.calculateRecordUsageCost(
|
||||
context.Background(),
|
||||
&ForwardResult{Model: "gemini-image", ImageCount: 2, ImageSize: "4K"},
|
||||
&APIKey{GroupID: i64p(groupID), Group: &Group{ID: groupID}},
|
||||
"gemini-image",
|
||||
1.0,
|
||||
1.0,
|
||||
nil,
|
||||
)
|
||||
|
||||
require.NotNil(t, cost)
|
||||
require.Equal(t, string(BillingModeImage), cost.BillingMode)
|
||||
require.InDelta(t, 0.80, cost.TotalCost, 1e-12)
|
||||
require.InDelta(t, 0.80, cost.ActualCost, 1e-12)
|
||||
}
|
||||
|
||||
@ -211,9 +211,10 @@ type OpenAIUsage struct {
|
||||
|
||||
// OpenAIForwardResult represents the result of forwarding
|
||||
type OpenAIForwardResult struct {
|
||||
RequestID string
|
||||
Usage OpenAIUsage
|
||||
Model string // 原始模型(用于响应和日志显示)
|
||||
RequestID string
|
||||
ResponseID string
|
||||
Usage OpenAIUsage
|
||||
Model string // 原始模型(用于响应和日志显示)
|
||||
// BillingModel is the model used for cost calculation.
|
||||
// When non-empty, CalculateCost uses this instead of Model.
|
||||
// This is set by the Anthropic Messages conversion path where
|
||||
@ -346,10 +347,12 @@ type OpenAIGatewayService struct {
|
||||
openaiWSPassthroughDialer openAIWSClientDialer
|
||||
openaiAccountStats *openAIAccountRuntimeStats
|
||||
|
||||
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
|
||||
openaiWSRetryMetrics openAIWSRetryMetrics
|
||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||
codexSnapshotThrottle *accountWriteThrottle
|
||||
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
|
||||
openaiWSRetryMetrics openAIWSRetryMetrics
|
||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||
codexSnapshotThrottle *accountWriteThrottle
|
||||
openaiCompatSessionResponses sync.Map
|
||||
openaiCompatAnthropicDigestSessions sync.Map
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayService creates a new OpenAIGatewayService
|
||||
@ -1992,6 +1995,8 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
originalBody := body
|
||||
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
|
||||
originalModel := reqModel
|
||||
compatMessagesBridge := isOpenAICompatMessagesBridgeBody(body)
|
||||
setOpenAICompatMessagesBridgeContext(c, compatMessagesBridge)
|
||||
|
||||
isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
||||
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
|
||||
@ -2049,6 +2054,21 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
promptCacheKey = strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
apiKey := getAPIKeyFromContext(c)
|
||||
imageGenerationAllowed := GroupAllowsImageGeneration(nil)
|
||||
if apiKey != nil {
|
||||
imageGenerationAllowed = GroupAllowsImageGeneration(apiKey.Group)
|
||||
}
|
||||
if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed {
|
||||
setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "permission_error",
|
||||
"message": ImageGenerationPermissionMessage(),
|
||||
},
|
||||
})
|
||||
return nil, errors.New("image generation disabled for group")
|
||||
}
|
||||
|
||||
// Track if body needs re-serialization
|
||||
bodyModified := false
|
||||
@ -2102,13 +2122,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
|
||||
// 非透传模式下,instructions 为空时注入默认指令。
|
||||
if isInstructionsEmpty(reqBody) {
|
||||
if isInstructionsEmpty(reqBody) && !compatMessagesBridge {
|
||||
reqBody["instructions"] = "You are a helpful coding assistant."
|
||||
bodyModified = true
|
||||
markPatchSet("instructions", "You are a helpful coding assistant.")
|
||||
}
|
||||
|
||||
if isCodexCLI && ensureOpenAIResponsesImageGenerationTool(reqBody) {
|
||||
if isCodexCLI && imageGenerationAllowed && ensureOpenAIResponsesImageGenerationTool(reqBody) {
|
||||
bodyModified = true
|
||||
disablePatch()
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Injected /responses image_generation tool for Codex client")
|
||||
@ -2119,7 +2139,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
disablePatch()
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload")
|
||||
}
|
||||
if isCodexCLI && applyCodexImageGenerationBridgeInstructions(reqBody) {
|
||||
if isCodexCLI && imageGenerationAllowed && applyCodexImageGenerationBridgeInstructions(reqBody) {
|
||||
bodyModified = true
|
||||
disablePatch()
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Added Codex image_generation bridge instructions")
|
||||
@ -2134,7 +2154,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
markPatchSet("model", billingModel)
|
||||
}
|
||||
upstreamModel := billingModel
|
||||
if normalizeOpenAIResponsesImageOnlyModel(reqBody) {
|
||||
if imageGenerationAllowed && normalizeOpenAIResponsesImageOnlyModel(reqBody) {
|
||||
bodyModified = true
|
||||
disablePatch()
|
||||
if model, ok := reqBody["model"].(string); ok {
|
||||
@ -2231,7 +2251,20 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
|
||||
if account.Type == AccountTypeOAuth {
|
||||
codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI, isCompactRequest)
|
||||
codexResult := codexTransformResult{}
|
||||
if compatMessagesBridge {
|
||||
codexResult = applyCodexOAuthTransformWithOptions(reqBody, codexOAuthTransformOptions{
|
||||
IsCodexCLI: isCodexCLI,
|
||||
IsCompact: isCompactRequest,
|
||||
SkipDefaultInstructions: true,
|
||||
PreserveToolCallIDs: true,
|
||||
})
|
||||
ensureCodexOAuthInstructionsField(reqBody)
|
||||
bodyModified = true
|
||||
disablePatch()
|
||||
} else {
|
||||
codexResult = applyCodexOAuthTransform(reqBody, isCodexCLI, isCompactRequest)
|
||||
}
|
||||
if codexResult.Modified {
|
||||
bodyModified = true
|
||||
disablePatch()
|
||||
@ -2355,6 +2388,34 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
}
|
||||
|
||||
if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed {
|
||||
setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "permission_error",
|
||||
"message": ImageGenerationPermissionMessage(),
|
||||
},
|
||||
})
|
||||
return nil, errors.New("image generation disabled for group")
|
||||
}
|
||||
imageBillingModel := ""
|
||||
imageSizeTier := ""
|
||||
if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) {
|
||||
var imageCfgErr error
|
||||
imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfig(reqBody, billingModel)
|
||||
if imageCfgErr != nil {
|
||||
setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "")
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "invalid_request_error",
|
||||
"message": imageCfgErr.Error(),
|
||||
"param": "size",
|
||||
},
|
||||
})
|
||||
return nil, imageCfgErr
|
||||
}
|
||||
}
|
||||
|
||||
// Re-serialize body only if modified
|
||||
if bodyModified {
|
||||
serializedByPatch := false
|
||||
@ -2592,6 +2653,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
wsAttempts,
|
||||
)
|
||||
wsResult.UpstreamModel = upstreamModel
|
||||
if wsResult.ImageCount > 0 {
|
||||
wsResult.ImageSize = imageSizeTier
|
||||
wsResult.BillingModel = imageBillingModel
|
||||
}
|
||||
return wsResult, nil
|
||||
}
|
||||
s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr)
|
||||
@ -2601,7 +2666,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
httpInvalidEncryptedContentRetryTried := false
|
||||
for {
|
||||
// Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
@ -2695,6 +2760,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
// Handle normal response
|
||||
var usage *OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
imageCount := 0
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel)
|
||||
if err != nil {
|
||||
@ -2702,11 +2768,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
usage = streamResult.usage
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
imageCount = streamResult.imageCount
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel)
|
||||
nonStreamResult, err := s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = nonStreamResult.usage
|
||||
imageCount = nonStreamResult.imageCount
|
||||
}
|
||||
|
||||
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||||
@ -2723,7 +2792,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
|
||||
serviceTier := extractOpenAIServiceTier(reqBody)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
forwardResult := &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
@ -2734,7 +2803,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
OpenAIWSMode: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
if imageCount > 0 {
|
||||
forwardResult.ImageCount = imageCount
|
||||
forwardResult.ImageSize = imageSizeTier
|
||||
forwardResult.BillingModel = imageBillingModel
|
||||
}
|
||||
return forwardResult, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -2823,6 +2898,35 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
}
|
||||
body = updatedBody
|
||||
|
||||
apiKey := getAPIKeyFromContext(c)
|
||||
if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) && !GroupAllowsImageGeneration(apiKeyGroup(apiKey)) {
|
||||
setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "permission_error",
|
||||
"message": ImageGenerationPermissionMessage(),
|
||||
},
|
||||
})
|
||||
return nil, errors.New("image generation disabled for group")
|
||||
}
|
||||
imageBillingModel := ""
|
||||
imageSizeTier := ""
|
||||
if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) {
|
||||
var imageCfgErr error
|
||||
imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(body, reqModel)
|
||||
if imageCfgErr != nil {
|
||||
setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "")
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "invalid_request_error",
|
||||
"message": imageCfgErr.Error(),
|
||||
"param": "size",
|
||||
},
|
||||
})
|
||||
return nil, imageCfgErr
|
||||
}
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.openai_gateway",
|
||||
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
|
||||
account.ID,
|
||||
@ -2852,7 +2956,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
@ -2905,6 +3009,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
|
||||
var usage *OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
imageCount := 0
|
||||
if reqStream {
|
||||
result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime, reqModel, upstreamPassthroughModel)
|
||||
if err != nil {
|
||||
@ -2912,11 +3017,14 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
}
|
||||
usage = result.usage
|
||||
firstTokenMs = result.firstTokenMs
|
||||
imageCount = result.imageCount
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel)
|
||||
result, err := s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = result.usage
|
||||
imageCount = result.imageCount
|
||||
}
|
||||
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
@ -2927,7 +3035,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
usage = &OpenAIUsage{}
|
||||
}
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
forwardResult := &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: reqModel,
|
||||
@ -2938,7 +3046,13 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
OpenAIWSMode: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
if imageCount > 0 {
|
||||
forwardResult.ImageCount = imageCount
|
||||
forwardResult.ImageSize = imageSizeTier
|
||||
forwardResult.BillingModel = imageBillingModel
|
||||
}
|
||||
return forwardResult, nil
|
||||
}
|
||||
|
||||
func logOpenAIPassthroughInstructionsRejected(
|
||||
@ -3233,6 +3347,13 @@ func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string {
|
||||
type openaiStreamingResultPassthrough struct {
|
||||
usage *OpenAIUsage
|
||||
firstTokenMs *int
|
||||
imageCount int
|
||||
}
|
||||
|
||||
type openaiNonStreamingResultPassthrough struct {
|
||||
*OpenAIUsage
|
||||
usage *OpenAIUsage
|
||||
imageCount int
|
||||
}
|
||||
|
||||
func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool {
|
||||
@ -3369,6 +3490,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
}
|
||||
|
||||
usage := &OpenAIUsage{}
|
||||
imageCounter := newOpenAIImageOutputCounter()
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
sawDone := false
|
||||
@ -3400,6 +3522,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
|
||||
needModelReplace := strings.TrimSpace(originalModel) != "" && strings.TrimSpace(mappedModel) != "" && strings.TrimSpace(originalModel) != strings.TrimSpace(mappedModel)
|
||||
resultWithUsage := func() *openaiStreamingResultPassthrough {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()}
|
||||
}
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
@ -3419,7 +3544,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
if eventType == "response.failed" {
|
||||
failedMessage = extractOpenAISSEErrorMessage(dataBytes)
|
||||
if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
|
||||
return resultWithUsage(),
|
||||
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, dataBytes, failedMessage)
|
||||
}
|
||||
forceFlushFailedEvent = true
|
||||
@ -3431,6 +3556,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
if openAIStreamEventIsTerminal(trimmedData) {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
imageCounter.AddSSEData(dataBytes)
|
||||
lineStartsClientOutput = forceFlushFailedEvent || openAIStreamDataStartsClientOutput(trimmedData, eventType)
|
||||
if firstTokenMs == nil && lineStartsClientOutput && trimmedData != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
@ -3460,28 +3586,28 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
if sawTerminalEvent && !sawFailedEvent {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
if sawFailedEvent {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
|
||||
return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage)
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
|
||||
}
|
||||
if errors.Is(err, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
return resultWithUsage(), err
|
||||
}
|
||||
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
|
||||
msg := "OpenAI stream disconnected before completion"
|
||||
if errText := strings.TrimSpace(err.Error()); errText != "" {
|
||||
msg += ": " + errText
|
||||
}
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
|
||||
return resultWithUsage(),
|
||||
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, msg)
|
||||
}
|
||||
if clientDisconnected {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", err)
|
||||
}
|
||||
logger.LegacyPrintf("service.openai_gateway",
|
||||
"[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
|
||||
@ -3489,10 +3615,10 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
upstreamRequestID,
|
||||
err,
|
||||
)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
return resultWithUsage(), fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
if sawFailedEvent {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
|
||||
return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage)
|
||||
}
|
||||
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
|
||||
logger.FromContext(ctx).With(
|
||||
@ -3501,13 +3627,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
zap.String("upstream_request_id", upstreamRequestID),
|
||||
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
|
||||
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
|
||||
return resultWithUsage(),
|
||||
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, "OpenAI stream ended before a terminal event")
|
||||
}
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
|
||||
return resultWithUsage(), errors.New("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
|
||||
@ -3516,7 +3642,7 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
|
||||
c *gin.Context,
|
||||
originalModel string,
|
||||
mappedModel string,
|
||||
) (*OpenAIUsage, error) {
|
||||
) (*openaiNonStreamingResultPassthrough, error) {
|
||||
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -3553,14 +3679,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
return usage, nil
|
||||
return &openaiNonStreamingResultPassthrough{
|
||||
OpenAIUsage: usage,
|
||||
usage: usage,
|
||||
imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handlePassthroughSSEToJSON converts an SSE response body into a JSON
|
||||
// response for the passthrough path. It mirrors handleSSEToJSON while
|
||||
// preserving passthrough payloads, except compact-only model remapping may
|
||||
// rewrite model fields back to the original requested model.
|
||||
func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*OpenAIUsage, error) {
|
||||
func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*openaiNonStreamingResultPassthrough, error) {
|
||||
bodyText := string(body)
|
||||
finalResponse, ok := extractCodexFinalResponse(bodyText)
|
||||
|
||||
@ -3611,7 +3741,11 @@ func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
|
||||
return usage, nil
|
||||
return &openaiNonStreamingResultPassthrough{
|
||||
OpenAIUsage: usage,
|
||||
usage: usage,
|
||||
imageCount: countOpenAIImageOutputsFromSSEBody(bodyText),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) {
|
||||
@ -3715,12 +3849,19 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
}
|
||||
}
|
||||
if account.Type == AccountTypeOAuth {
|
||||
compatMessagesBridge := isOpenAICompatMessagesBridgeContext(c) || isOpenAICompatMessagesBridgeBody(body)
|
||||
// 清除客户端透传的 session 头,后续用隔离后的值重新设置,防止跨用户会话碰撞。
|
||||
clientConversationID := strings.TrimSpace(req.Header.Get("conversation_id"))
|
||||
req.Header.Del("conversation_id")
|
||||
req.Header.Del("session_id")
|
||||
|
||||
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||
req.Header.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI))
|
||||
if compatMessagesBridge {
|
||||
req.Header.Del("OpenAI-Beta")
|
||||
req.Header.Del("originator")
|
||||
} else {
|
||||
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||
req.Header.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI))
|
||||
}
|
||||
apiKeyID := getAPIKeyIDFromContext(c)
|
||||
if isOpenAIResponsesCompactPath(c) {
|
||||
req.Header.Set("accept", "application/json")
|
||||
@ -3734,8 +3875,10 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
}
|
||||
if promptCacheKey != "" {
|
||||
isolated := isolateOpenAISessionID(apiKeyID, promptCacheKey)
|
||||
req.Header.Set("conversation_id", isolated)
|
||||
req.Header.Set("session_id", isolated)
|
||||
if !compatMessagesBridge || clientConversationID != "" {
|
||||
req.Header.Set("conversation_id", isolated)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -4025,6 +4168,13 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse(
|
||||
type openaiStreamingResult struct {
|
||||
usage *OpenAIUsage
|
||||
firstTokenMs *int
|
||||
imageCount int
|
||||
}
|
||||
|
||||
type openaiNonStreamingResult struct {
|
||||
*OpenAIUsage
|
||||
usage *OpenAIUsage
|
||||
imageCount int
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
|
||||
@ -4058,6 +4208,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
}
|
||||
|
||||
usage := &OpenAIUsage{}
|
||||
imageCounter := newOpenAIImageOutputCounter()
|
||||
var firstTokenMs *int
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
@ -4136,7 +4287,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
resultWithUsage := func() *openaiStreamingResult {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()}
|
||||
}
|
||||
finalizeStream := func() (*openaiStreamingResult, error) {
|
||||
if !sawTerminalEvent {
|
||||
@ -4231,6 +4382,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
forceFlushFailedEvent = true
|
||||
sawFailedEvent = true
|
||||
}
|
||||
imageCounter.AddSSEData(dataBytes)
|
||||
|
||||
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
||||
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
|
||||
@ -4496,7 +4648,7 @@ func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
|
||||
}, true
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*openaiNonStreamingResult, error) {
|
||||
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -4542,7 +4694,11 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
|
||||
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
|
||||
return usage, nil
|
||||
return &openaiNonStreamingResult{
|
||||
OpenAIUsage: usage,
|
||||
usage: usage,
|
||||
imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func isEventStreamResponse(header http.Header) bool {
|
||||
@ -4550,7 +4706,7 @@ func isEventStreamResponse(header http.Header) bool {
|
||||
return strings.Contains(contentType, "text/event-stream")
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||
func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*openaiNonStreamingResult, error) {
|
||||
bodyText := string(body)
|
||||
finalResponse, ok := extractCodexFinalResponse(bodyText)
|
||||
|
||||
@ -4602,21 +4758,29 @@ func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Conte
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
|
||||
return usage, nil
|
||||
return &openaiNonStreamingResult{
|
||||
OpenAIUsage: usage,
|
||||
usage: usage,
|
||||
imageCount: countOpenAIImageOutputsFromSSEBody(bodyText),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) {
|
||||
lines := strings.Split(body, "\n")
|
||||
for _, line := range lines {
|
||||
data, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok || data == "" || data == "[DONE]" {
|
||||
continue
|
||||
var terminalType string
|
||||
var terminalPayload []byte
|
||||
forEachOpenAISSEDataPayload(body, func(data []byte) {
|
||||
if terminalPayload != nil {
|
||||
return
|
||||
}
|
||||
eventType := strings.TrimSpace(gjson.Get(data, "type").String())
|
||||
eventType := strings.TrimSpace(gjson.GetBytes(data, "type").String())
|
||||
switch eventType {
|
||||
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
|
||||
return eventType, []byte(data), true
|
||||
terminalType = eventType
|
||||
terminalPayload = append([]byte(nil), data...)
|
||||
}
|
||||
})
|
||||
if terminalPayload != nil {
|
||||
return terminalType, terminalPayload, true
|
||||
}
|
||||
return "", nil, false
|
||||
}
|
||||
@ -4651,21 +4815,20 @@ func (s *OpenAIGatewayService) writeOpenAINonStreamingProtocolError(resp *http.R
|
||||
}
|
||||
|
||||
func extractCodexFinalResponse(body string) ([]byte, bool) {
|
||||
lines := strings.Split(body, "\n")
|
||||
for _, line := range lines {
|
||||
data, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
var finalResponse []byte
|
||||
forEachOpenAISSEDataPayload(body, func(data []byte) {
|
||||
if finalResponse != nil {
|
||||
return
|
||||
}
|
||||
if data == "" || data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
eventType := gjson.Get(data, "type").String()
|
||||
eventType := gjson.GetBytes(data, "type").String()
|
||||
if eventType == "response.done" || eventType == "response.completed" {
|
||||
if response := gjson.Get(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" {
|
||||
return []byte(response.Raw), true
|
||||
if response := gjson.GetBytes(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" {
|
||||
finalResponse = []byte(response.Raw)
|
||||
}
|
||||
}
|
||||
})
|
||||
if finalResponse != nil {
|
||||
return finalResponse, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@ -4677,21 +4840,15 @@ func reconstructResponseOutputFromSSE(bodyText string) ([]byte, bool) {
|
||||
acc := apicompat.NewBufferedResponseAccumulator()
|
||||
imageOutputs := make([]json.RawMessage, 0, 1)
|
||||
seenImages := make(map[string]struct{})
|
||||
lines := strings.Split(bodyText, "\n")
|
||||
for _, line := range lines {
|
||||
data, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok || data == "" || data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
if imageOutput, ok := extractImageGenerationOutputFromSSEData([]byte(data), seenImages); ok {
|
||||
forEachOpenAISSEDataPayload(bodyText, func(data []byte) {
|
||||
if imageOutput, ok := extractImageGenerationOutputFromSSEData(data, seenImages); ok {
|
||||
imageOutputs = append(imageOutputs, imageOutput)
|
||||
}
|
||||
var event apicompat.ResponsesStreamEvent
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
continue
|
||||
if err := json.Unmarshal(data, &event); err == nil {
|
||||
acc.ProcessEvent(&event)
|
||||
}
|
||||
acc.ProcessEvent(&event)
|
||||
}
|
||||
})
|
||||
if !acc.HasContent() && len(imageOutputs) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
@ -4744,17 +4901,9 @@ func extractImageGenerationOutputFromSSEData(data []byte, seen map[string]struct
|
||||
|
||||
func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
|
||||
usage := &OpenAIUsage{}
|
||||
lines := strings.Split(body, "\n")
|
||||
for _, line := range lines {
|
||||
data, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if data == "" || data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
s.parseSSEUsageBytes([]byte(data), usage)
|
||||
}
|
||||
forEachOpenAISSEDataPayload(body, func(data []byte) {
|
||||
s.parseSSEUsageBytes(data, usage)
|
||||
})
|
||||
return usage
|
||||
}
|
||||
|
||||
@ -5036,16 +5185,15 @@ type OpenAIRecordUsageInput struct {
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
|
||||
result := input.Result
|
||||
if s.rateLimitService != nil && input != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI {
|
||||
s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
|
||||
if input == nil {
|
||||
return errors.New("openai usage input is nil")
|
||||
}
|
||||
|
||||
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
|
||||
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
|
||||
result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 &&
|
||||
result.Usage.ImageOutputTokens == 0 && result.ImageCount == 0 {
|
||||
return nil
|
||||
result := input.Result
|
||||
if result == nil {
|
||||
return errors.New("openai usage result is nil")
|
||||
}
|
||||
if s.rateLimitService != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI {
|
||||
s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
|
||||
}
|
||||
|
||||
apiKey := input.APIKey
|
||||
@ -5081,6 +5229,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
}
|
||||
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
|
||||
}
|
||||
imageMultiplier := resolveImageRateMultiplier(apiKey, multiplier)
|
||||
|
||||
var cost *CostBreakdown
|
||||
var err error
|
||||
@ -5094,13 +5243,21 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
|
||||
billingModel = input.OriginalModel
|
||||
}
|
||||
billingModels := usageBillingModelCandidates(
|
||||
billingModel,
|
||||
result.BillingModel,
|
||||
input.ChannelMappedModel,
|
||||
input.OriginalModel,
|
||||
result.UpstreamModel,
|
||||
result.Model,
|
||||
)
|
||||
serviceTier := ""
|
||||
if result.ServiceTier != nil {
|
||||
serviceTier = strings.TrimSpace(*result.ServiceTier)
|
||||
}
|
||||
cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, tokens, serviceTier)
|
||||
cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModels, multiplier, imageMultiplier, tokens, serviceTier)
|
||||
if err != nil {
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
return err
|
||||
}
|
||||
|
||||
// Determine billing type
|
||||
@ -5150,7 +5307,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
usageLog.TotalCost = cost.TotalCost
|
||||
usageLog.ActualCost = cost.ActualCost
|
||||
}
|
||||
usageLog.RateMultiplier = multiplier
|
||||
if result.ImageCount > 0 {
|
||||
usageLog.RateMultiplier = imageMultiplier
|
||||
} else {
|
||||
usageLog.RateMultiplier = multiplier
|
||||
}
|
||||
usageLog.AccountRateMultiplier = &accountRateMultiplier
|
||||
usageLog.BillingType = billingType
|
||||
usageLog.Stream = result.Stream
|
||||
@ -5231,14 +5392,45 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
|
||||
ctx context.Context,
|
||||
result *OpenAIForwardResult,
|
||||
apiKey *APIKey,
|
||||
billingModels []string,
|
||||
multiplier float64,
|
||||
imageMultiplier float64,
|
||||
tokens UsageTokens,
|
||||
serviceTier string,
|
||||
) (*CostBreakdown, error) {
|
||||
billingModel := firstUsageBillingModel(billingModels)
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, imageMultiplier), nil
|
||||
}
|
||||
if len(billingModels) == 0 || billingModel == "" {
|
||||
return nil, errors.New("openai usage billing model is empty")
|
||||
}
|
||||
var lastErr error
|
||||
for _, candidate := range billingModels {
|
||||
candidate = strings.TrimSpace(candidate)
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
cost, err := s.calculateOpenAIRecordUsageTokenCost(ctx, apiKey, candidate, multiplier, tokens, serviceTier)
|
||||
if err == nil {
|
||||
return cost, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
if lastErr == nil {
|
||||
lastErr = errors.New("no non-empty billing model candidates")
|
||||
}
|
||||
return nil, fmt.Errorf("calculate OpenAI usage cost failed for billing models %s: %w", strings.Join(billingModels, ","), lastErr)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) calculateOpenAIRecordUsageTokenCost(
|
||||
ctx context.Context,
|
||||
apiKey *APIKey,
|
||||
billingModel string,
|
||||
multiplier float64,
|
||||
tokens UsageTokens,
|
||||
serviceTier string,
|
||||
) (*CostBreakdown, error) {
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil
|
||||
}
|
||||
if s.resolver != nil && apiKey.Group != nil {
|
||||
gid := apiKey.Group.ID
|
||||
return s.billingService.CalculateCostUnified(CostInput{
|
||||
@ -5269,7 +5461,7 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost(
|
||||
Ctx: ctx,
|
||||
Model: billingModel,
|
||||
GroupID: &gid,
|
||||
RequestCount: 1,
|
||||
RequestCount: result.ImageCount,
|
||||
SizeTier: result.ImageSize,
|
||||
RateMultiplier: multiplier,
|
||||
Resolver: s.resolver,
|
||||
|
||||
@ -1846,6 +1846,29 @@ func TestOpenAIBuildUpstreamRequestCompactForcesJSONAcceptForOAuth(t *testing.T)
|
||||
require.NotEmpty(t, req.Header.Get("Session_Id"))
|
||||
}
|
||||
|
||||
func TestOpenAIBuildUpstreamRequestOAuthMessagesBridgeUsesSessionOnly(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := []byte(`{"model":"gpt-5.5","prompt_cache_key":"anthropic-metadata-session-1","input":[{"type":"message","role":"developer","content":[{"type":"input_text","text":"<sub2api-claude-code-todo-guard>"}]},{"type":"message","role":"user","content":"hello"}]}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
|
||||
c.Request.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||
c.Request.Header.Set("originator", "codex_cli_rs")
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
account := &Account{
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"chatgpt_account_id": "chatgpt-acc"},
|
||||
}
|
||||
|
||||
req, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, body, "token", true, "anthropic-metadata-session-1", false)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, req.Header.Get("Session_Id"))
|
||||
require.Empty(t, req.Header.Get("Conversation_Id"))
|
||||
require.Empty(t, req.Header.Get("OpenAI-Beta"))
|
||||
require.Empty(t, req.Header.Get("originator"))
|
||||
}
|
||||
|
||||
func TestOpenAIBuildUpstreamRequestPreservesCompactPathForAPIKeyBaseURL(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@ -0,0 +1,215 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestOpenAIGatewayServiceForward_RejectsDisabledImageGenerationIntents(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
}{
|
||||
{
|
||||
name: "image model",
|
||||
body: []byte(`{"model":"gpt-image-2","input":"draw"}`),
|
||||
},
|
||||
{
|
||||
name: "image tool",
|
||||
body: []byte(`{"model":"gpt-5.4","input":"draw","tools":[{"type":"image_generation"}]}`),
|
||||
},
|
||||
{
|
||||
name: "image tool choice",
|
||||
body: []byte(`{"model":"gpt-5.4","input":"draw","tool_choice":{"type":"image_generation"}}`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
upstream := &httpUpstreamRecorder{}
|
||||
svc := newOpenAIImageGenerationControlTestService(upstream)
|
||||
c, recorder := newOpenAIImageGenerationControlTestContext(false, "unit-test-agent/1.0")
|
||||
account := newOpenAIImageGenerationControlTestAccount()
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, tt.body)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, http.StatusForbidden, recorder.Code)
|
||||
require.Equal(t, "permission_error", gjson.GetBytes(recorder.Body.Bytes(), "error.type").String())
|
||||
require.Nil(t, upstream.lastReq, "disabled image request must not reach upstream")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForward_DisabledGroupAllowsTextOnlyResponses(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_text","model":"gpt-5.4","usage":{"input_tokens":3,"output_tokens":2}}`)),
|
||||
},
|
||||
}
|
||||
svc := newOpenAIImageGenerationControlTestService(upstream)
|
||||
c, recorder := newOpenAIImageGenerationControlTestContext(false, "unit-test-agent/1.0")
|
||||
account := newOpenAIImageGenerationControlTestAccount()
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, []byte(`{"model":"gpt-5.4","input":"write code","stream":false}`))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
require.Equal(t, 3, result.Usage.InputTokens)
|
||||
require.Equal(t, 2, result.Usage.OutputTokens)
|
||||
require.Equal(t, 0, result.ImageCount)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowImages bool
|
||||
wantInjected bool
|
||||
}{
|
||||
{name: "disabled group skips injection", allowImages: false, wantInjected: false},
|
||||
{name: "enabled group injects image tool", allowImages: true, wantInjected: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_codex","model":"gpt-5.4","usage":{"input_tokens":1,"output_tokens":1}}`)),
|
||||
},
|
||||
}
|
||||
svc := newOpenAIImageGenerationControlTestService(upstream)
|
||||
c, _ := newOpenAIImageGenerationControlTestContext(tt.allowImages, "codex_cli_rs/0.98.0")
|
||||
account := newOpenAIImageGenerationControlTestAccount()
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, []byte(`{"model":"gpt-5.4","input":"write code","stream":false}`))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
hasImageTool := gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation")`).Exists()
|
||||
require.Equal(t, tt.wantInjected, hasImageTool)
|
||||
instructions := gjson.GetBytes(upstream.lastBody, "instructions").String()
|
||||
require.Equal(t, tt.wantInjected, strings.Contains(instructions, "image_generation"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceHandleResponsesImageOutputs_NonStreaming(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
svc := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{})
|
||||
c, _ := newOpenAIImageGenerationControlTestContext(true, "unit-test-agent/1.0")
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{
|
||||
"id":"resp_image_json",
|
||||
"model":"gpt-5.4",
|
||||
"output":[{"id":"ig_json_1","type":"image_generation_call","result":"final-image"}],
|
||||
"usage":{"input_tokens":7,"output_tokens":3,"output_tokens_details":{"image_tokens":2}}
|
||||
}`)),
|
||||
}
|
||||
|
||||
result, err := svc.handleNonStreamingResponse(context.Background(), resp, c, &Account{ID: 1, Type: AccountTypeAPIKey}, "gpt-5.4", "gpt-5.4")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 1, result.imageCount)
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 7, result.usage.InputTokens)
|
||||
require.Equal(t, 3, result.usage.OutputTokens)
|
||||
require.Equal(t, 2, result.usage.ImageOutputTokens)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceHandleResponsesImageOutputs_Streaming(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
svc := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{})
|
||||
c, _ := newOpenAIImageGenerationControlTestContext(true, "unit-test-agent/1.0")
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
"data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_stream_1\",\"type\":\"image_generation_call\",\"result\":\"final-image\"}}\n\n" +
|
||||
"data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_image_stream\",\"model\":\"gpt-5.5\",\"output\":[{\"id\":\"ig_stream_1\",\"type\":\"image_generation_call\",\"result\":\"final-image\"}],\"usage\":{\"input_tokens\":11,\"output_tokens\":5,\"output_tokens_details\":{\"image_tokens\":4}}}}\n\n",
|
||||
)),
|
||||
}
|
||||
|
||||
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "gpt-5.5", "gpt-5.5")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 1, result.imageCount)
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 11, result.usage.InputTokens)
|
||||
require.Equal(t, 5, result.usage.OutputTokens)
|
||||
require.Equal(t, 4, result.usage.ImageOutputTokens)
|
||||
}
|
||||
|
||||
func newOpenAIImageGenerationControlTestService(upstream *httpUpstreamRecorder) *OpenAIGatewayService {
|
||||
cfg := &config.Config{}
|
||||
return &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: upstream,
|
||||
cache: &stubGatewayCache{},
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
}
|
||||
|
||||
func newOpenAIImageGenerationControlTestContext(allowImages bool, userAgent string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", userAgent)
|
||||
groupID := int64(4242)
|
||||
c.Set("api_key", &APIKey{
|
||||
ID: 2424,
|
||||
GroupID: &groupID,
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
AllowImageGeneration: allowImages,
|
||||
RateMultiplier: 1,
|
||||
ImageRateMultiplier: 1,
|
||||
},
|
||||
})
|
||||
return c, recorder
|
||||
}
|
||||
|
||||
func newOpenAIImageGenerationControlTestAccount() *Account {
|
||||
return &Account{
|
||||
ID: 5151,
|
||||
Name: "openai-image-controls",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -16,6 +16,7 @@ import (
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
@ -468,14 +469,54 @@ func isOpenAINativeImageOption(name string) bool {
|
||||
}
|
||||
|
||||
func normalizeOpenAIImageSizeTier(size string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(size)) {
|
||||
trimmed := strings.TrimSpace(size)
|
||||
normalized := strings.ToLower(trimmed)
|
||||
switch normalized {
|
||||
case "", "auto":
|
||||
return "2K"
|
||||
case "1024x1024":
|
||||
return "1K"
|
||||
case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "", "auto":
|
||||
case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "2048x2048", "2048x1152", "1152x2048":
|
||||
return "2K"
|
||||
default:
|
||||
case "3840x2160", "2160x3840":
|
||||
return "4K"
|
||||
}
|
||||
width, height, ok := parseOpenAIImageSizeDimensions(trimmed)
|
||||
if !ok {
|
||||
return "2K"
|
||||
}
|
||||
return classifyUnknownOpenAIImageSizeTier(width, height)
|
||||
}
|
||||
|
||||
const (
|
||||
openAIImage2KMaxPixels = 2560 * 1440
|
||||
)
|
||||
|
||||
func parseOpenAIImageSizeDimensions(size string) (int, int, bool) {
|
||||
trimmed := strings.TrimSpace(size)
|
||||
parts := strings.Split(strings.ToLower(trimmed), "x")
|
||||
if len(parts) != 2 {
|
||||
return 0, 0, false
|
||||
}
|
||||
width, err := strconv.Atoi(strings.TrimSpace(parts[0]))
|
||||
if err != nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
height, err := strconv.Atoi(strings.TrimSpace(parts[1]))
|
||||
if err != nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
if width <= 0 || height <= 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
return width, height, true
|
||||
}
|
||||
|
||||
func classifyUnknownOpenAIImageSizeTier(width int, height int) string {
|
||||
if height > 0 && width > openAIImage2KMaxPixels/height {
|
||||
return "4K"
|
||||
}
|
||||
return "2K"
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) ForwardImages(
|
||||
@ -535,11 +576,14 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
|
||||
setOpsUpstreamRequestBody(c, forwardBody)
|
||||
}
|
||||
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, parsed.Stream)
|
||||
defer releaseUpstreamCtx()
|
||||
|
||||
token, _, err := s.GetAccessToken(upstreamCtx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
upstreamReq, err := s.buildOpenAIImagesRequest(ctx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint)
|
||||
upstreamReq, err := s.buildOpenAIImagesRequest(upstreamCtx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -582,23 +626,37 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
})
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
s.handleFailoverSideEffects(upstreamCtx, resp, account)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account, forwardBody)
|
||||
return s.handleErrorResponse(upstreamCtx, resp, c, account, forwardBody)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
var usage OpenAIUsage
|
||||
imageCount := parsed.N
|
||||
var firstTokenMs *int
|
||||
if parsed.Stream {
|
||||
if parsed.Stream && isEventStreamResponse(resp.Header) {
|
||||
streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime)
|
||||
if err != nil {
|
||||
if streamCount > 0 {
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: streamUsage,
|
||||
Model: requestModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
Stream: parsed.Stream,
|
||||
ResponseHeaders: resp.Header.Clone(),
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: ttft,
|
||||
ImageCount: streamCount,
|
||||
ImageSize: parsed.SizeTier,
|
||||
}, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
usage = streamUsage
|
||||
@ -807,39 +865,228 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
|
||||
return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer")
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
usage := OpenAIUsage{}
|
||||
imageCount := 0
|
||||
imageCounter := newOpenAIImageOutputCounter()
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
lastDownstreamWriteAt := time.Now()
|
||||
var fallbackBody bytes.Buffer
|
||||
fallbackBytes := int64(0)
|
||||
fallbackLimit := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
seenSSEData := false
|
||||
fallbackTooLarge := false
|
||||
var sseData openAISSEDataAccumulator
|
||||
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if len(line) > 0 {
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
processSSEData := func(dataBytes []byte) {
|
||||
seenSSEData = true
|
||||
fallbackBody.Reset()
|
||||
fallbackBytes = 0
|
||||
mergeOpenAIUsage(&usage, dataBytes)
|
||||
imageCounter.AddSSEData(dataBytes)
|
||||
}
|
||||
|
||||
flushSSEEvent := func() {
|
||||
sseData.Flush(processSSEData)
|
||||
}
|
||||
|
||||
processLine := func(line []byte) {
|
||||
if len(line) == 0 {
|
||||
return
|
||||
}
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
if !clientDisconnected {
|
||||
if _, writeErr := c.Writer.Write(line); writeErr != nil {
|
||||
return OpenAIUsage{}, 0, firstTokenMs, writeErr
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected, continue draining upstream for billing")
|
||||
} else {
|
||||
flusher.Flush()
|
||||
lastDownstreamWriteAt = time.Now()
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok && data != "" && data != "[DONE]" {
|
||||
dataBytes := []byte(data)
|
||||
mergeOpenAIUsage(&usage, dataBytes)
|
||||
if count := extractOpenAIImageCountFromJSONBytes(dataBytes); count > imageCount {
|
||||
imageCount = count
|
||||
}
|
||||
trimmedLine := strings.TrimRight(string(line), "\r\n")
|
||||
if _, ok := extractOpenAISSEDataLine(trimmedLine); ok || strings.TrimSpace(trimmedLine) == "" {
|
||||
sseData.AddLine(trimmedLine, processSSEData)
|
||||
return
|
||||
}
|
||||
if !seenSSEData && !fallbackTooLarge {
|
||||
fallbackBytes += int64(len(line))
|
||||
if fallbackBytes <= fallbackLimit {
|
||||
_, _ = fallbackBody.Write(line)
|
||||
} else {
|
||||
fallbackTooLarge = true
|
||||
fallbackBody.Reset()
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return OpenAIUsage{}, 0, firstTokenMs, err
|
||||
}
|
||||
}
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
|
||||
finalizeFallbackBody := func() {
|
||||
if seenSSEData || fallbackBody.Len() == 0 {
|
||||
return
|
||||
}
|
||||
body := bytes.TrimSpace(fallbackBody.Bytes())
|
||||
if len(body) == 0 {
|
||||
return
|
||||
}
|
||||
mergeOpenAIUsage(&usage, body)
|
||||
imageCounter.AddJSONResponse(body)
|
||||
}
|
||||
|
||||
streamInterval := s.openAIImageStreamDataInterval()
|
||||
keepaliveInterval := s.openAIImageStreamKeepaliveInterval()
|
||||
if streamInterval <= 0 && keepaliveInterval <= 0 {
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
processLine(line)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
flushSSEEvent()
|
||||
return usage, imageCounter.Count(), firstTokenMs, err
|
||||
}
|
||||
}
|
||||
flushSSEEvent()
|
||||
finalizeFallbackBody()
|
||||
return usage, imageCounter.Count(), firstTokenMs, nil
|
||||
}
|
||||
|
||||
type readEvent struct {
|
||||
line []byte
|
||||
err error
|
||||
}
|
||||
events := make(chan readEvent, 16)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev readEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
defer close(events)
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if len(line) > 0 {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
}
|
||||
if len(line) > 0 && !sendEvent(readEvent{line: line}) {
|
||||
return
|
||||
}
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
_ = sendEvent(readEvent{err: err})
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
var intervalTicker *time.Ticker
|
||||
if streamInterval > 0 {
|
||||
intervalTicker = time.NewTicker(streamInterval)
|
||||
defer intervalTicker.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTicker != nil {
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
}
|
||||
var keepaliveCh <-chan time.Time
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
flushSSEEvent()
|
||||
finalizeFallbackBody()
|
||||
return usage, imageCounter.Count(), firstTokenMs, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
flushSSEEvent()
|
||||
return usage, imageCounter.Count(), firstTokenMs, ev.err
|
||||
}
|
||||
processLine(ev.line)
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
return usage, imageCounter.Count(), firstTokenMs, fmt.Errorf("image stream incomplete after timeout")
|
||||
}
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream data interval timeout: interval=%s", streamInterval)
|
||||
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(fmt.Sprintf("upstream image stream idle for %s", streamInterval)))
|
||||
return usage, imageCounter.Count(), firstTokenMs, fmt.Errorf("image stream data interval timeout")
|
||||
case <-keepaliveCh:
|
||||
if clientDisconnected || time.Since(lastDownstreamWriteAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
if _, writeErr := io.WriteString(c.Writer, ":\n\n"); writeErr != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected during keepalive, continue draining upstream for billing")
|
||||
continue
|
||||
}
|
||||
flusher.Flush()
|
||||
lastDownstreamWriteAt = time.Now()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAIImageStreamDataInterval() time.Duration {
|
||||
if s == nil || s.cfg == nil || s.cfg.Gateway.ImageStreamDataIntervalTimeout <= 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(s.cfg.Gateway.ImageStreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAIImageStreamKeepaliveInterval() time.Duration {
|
||||
if s == nil || s.cfg == nil || s.cfg.Gateway.ImageStreamKeepaliveInterval <= 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(s.cfg.Gateway.ImageStreamKeepaliveInterval) * time.Second
|
||||
}
|
||||
|
||||
func extractOpenAIImagesBillableCountFromJSONBytes(body []byte) int {
|
||||
if count := extractOpenAIImageCountFromJSONBytes(body); count > 0 {
|
||||
return count
|
||||
}
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
return 0
|
||||
}
|
||||
if count := int(gjson.GetBytes(body, "usage.images").Int()); count > 0 {
|
||||
return count
|
||||
}
|
||||
if count := int(gjson.GetBytes(body, "tool_usage.image_gen.images").Int()); count > 0 {
|
||||
return count
|
||||
}
|
||||
eventType := strings.TrimSpace(gjson.GetBytes(body, "type").String())
|
||||
if eventType == "" || !strings.HasSuffix(eventType, ".completed") {
|
||||
return 0
|
||||
}
|
||||
if gjson.GetBytes(body, "b64_json").Exists() || gjson.GetBytes(body, "url").Exists() {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) {
|
||||
@ -863,14 +1110,7 @@ func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) {
|
||||
}
|
||||
|
||||
func extractOpenAIImageCountFromJSONBytes(body []byte) int {
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
return 0
|
||||
}
|
||||
data := gjson.GetBytes(body, "data")
|
||||
if data.Exists() && data.IsArray() {
|
||||
return len(data.Array())
|
||||
}
|
||||
return 0
|
||||
return countOpenAIResponseImageOutputsFromJSONBytes(body)
|
||||
}
|
||||
|
||||
type openAIImagePointerInfo struct {
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
@ -361,21 +362,21 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
|
||||
var (
|
||||
fallbackResults []openAIResponsesImageResult
|
||||
fallbackSeen = make(map[string]struct{})
|
||||
finalResults []openAIResponsesImageResult
|
||||
finalMeta openAIResponsesImageResult
|
||||
collectErr error
|
||||
createdAt int64
|
||||
usageRaw []byte
|
||||
foundFinal bool
|
||||
responseMeta openAIResponsesImageResult
|
||||
)
|
||||
|
||||
for _, line := range bytes.Split(body, []byte("\n")) {
|
||||
line = bytes.TrimRight(line, "\r")
|
||||
data, ok := extractOpenAISSEDataLine(string(line))
|
||||
if !ok || data == "" || data == "[DONE]" {
|
||||
continue
|
||||
forEachOpenAISSEDataPayload(string(body), func(payload []byte) {
|
||||
if collectErr != nil || len(finalResults) > 0 {
|
||||
return
|
||||
}
|
||||
payload := []byte(data)
|
||||
if !gjson.ValidBytes(payload) {
|
||||
continue
|
||||
return
|
||||
}
|
||||
if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(payload); ok {
|
||||
mergeOpenAIResponsesImageMeta(&responseMeta, meta)
|
||||
@ -388,7 +389,8 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
|
||||
case "response.output_item.done":
|
||||
result, itemID, ok, err := extractOpenAIImageFromResponsesOutputItemDone(payload)
|
||||
if err != nil {
|
||||
return nil, 0, nil, openAIResponsesImageResult{}, false, err
|
||||
collectErr = err
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
mergeOpenAIResponsesImageMeta(&result, responseMeta)
|
||||
@ -397,7 +399,8 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
|
||||
case "response.completed":
|
||||
results, completedAt, completedUsageRaw, firstMeta, err := extractOpenAIImagesFromResponsesCompleted(payload)
|
||||
if err != nil {
|
||||
return nil, 0, nil, openAIResponsesImageResult{}, false, err
|
||||
collectErr = err
|
||||
return
|
||||
}
|
||||
foundFinal = true
|
||||
if completedAt > 0 {
|
||||
@ -408,14 +411,24 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
|
||||
}
|
||||
if len(results) > 0 {
|
||||
mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
|
||||
return results, createdAt, usageRaw, firstMeta, true, nil
|
||||
finalResults = results
|
||||
finalMeta = firstMeta
|
||||
return
|
||||
}
|
||||
if len(fallbackResults) > 0 {
|
||||
firstMeta = fallbackResults[0]
|
||||
mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
|
||||
return fallbackResults, createdAt, usageRaw, firstMeta, true, nil
|
||||
finalResults = fallbackResults
|
||||
finalMeta = firstMeta
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
if collectErr != nil {
|
||||
return nil, 0, nil, openAIResponsesImageResult{}, false, collectErr
|
||||
}
|
||||
if len(finalResults) > 0 {
|
||||
return finalResults, createdAt, usageRaw, finalMeta, true, nil
|
||||
}
|
||||
|
||||
if len(fallbackResults) > 0 {
|
||||
@ -505,6 +518,30 @@ func (s *OpenAIGatewayService) writeOpenAIImagesStreamEvent(c *gin.Context, flus
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) tryWriteOpenAIImagesStreamEvent(
|
||||
c *gin.Context,
|
||||
flusher http.Flusher,
|
||||
clientDisconnected *bool,
|
||||
lastWriteAt *time.Time,
|
||||
eventName string,
|
||||
payload []byte,
|
||||
) bool {
|
||||
if clientDisconnected != nil && *clientDisconnected {
|
||||
return false
|
||||
}
|
||||
if err := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); err != nil {
|
||||
if clientDisconnected != nil {
|
||||
*clientDisconnected = true
|
||||
}
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected, continue draining upstream for billing")
|
||||
return false
|
||||
}
|
||||
if lastWriteAt != nil {
|
||||
*lastWriteAt = time.Now()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
@ -517,15 +554,9 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
|
||||
}
|
||||
|
||||
var usage OpenAIUsage
|
||||
for _, line := range bytes.Split(body, []byte("\n")) {
|
||||
line = bytes.TrimRight(line, "\r")
|
||||
data, ok := extractOpenAISSEDataLine(string(line))
|
||||
if !ok || data == "" || data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
dataBytes := []byte(data)
|
||||
s.parseSSEUsageBytes(dataBytes, &usage)
|
||||
}
|
||||
forEachOpenAISSEDataPayload(string(body), func(data []byte) {
|
||||
s.parseSSEUsageBytes(data, &usage)
|
||||
})
|
||||
results, createdAt, usageRaw, firstMeta, _, err := collectOpenAIImagesFromResponsesBody(body)
|
||||
if err != nil {
|
||||
return OpenAIUsage{}, 0, err
|
||||
@ -570,7 +601,6 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
|
||||
format = "b64_json"
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
usage := OpenAIUsage{}
|
||||
imageCount := 0
|
||||
var firstTokenMs *int
|
||||
@ -579,141 +609,307 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
|
||||
pendingSeen := make(map[string]struct{})
|
||||
streamMeta := openAIResponsesImageResult{Model: strings.TrimSpace(fallbackModel)}
|
||||
var createdAt int64
|
||||
clientDisconnected := false
|
||||
lastDownstreamWriteAt := time.Now()
|
||||
var sseData openAISSEDataAccumulator
|
||||
var processDataErr error
|
||||
processDataDone := false
|
||||
|
||||
processData := func(dataBytes []byte) {
|
||||
if processDataDone || processDataErr != nil {
|
||||
return
|
||||
}
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsageBytes(dataBytes, &usage)
|
||||
if !gjson.ValidBytes(dataBytes) {
|
||||
return
|
||||
}
|
||||
if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok {
|
||||
mergeOpenAIResponsesImageMeta(&streamMeta, meta)
|
||||
if eventCreatedAt > 0 {
|
||||
createdAt = eventCreatedAt
|
||||
}
|
||||
}
|
||||
switch gjson.GetBytes(dataBytes, "type").String() {
|
||||
case "response.image_generation_call.partial_image":
|
||||
b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String())
|
||||
if b64 == "" {
|
||||
return
|
||||
}
|
||||
eventName := streamPrefix + ".partial_image"
|
||||
partialMeta := streamMeta
|
||||
mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{
|
||||
OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()),
|
||||
Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()),
|
||||
})
|
||||
payload := buildOpenAIImagesStreamPartialPayload(
|
||||
eventName,
|
||||
b64,
|
||||
gjson.GetBytes(dataBytes, "partial_image_index").Int(),
|
||||
format,
|
||||
createdAt,
|
||||
partialMeta,
|
||||
)
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload)
|
||||
case "response.output_item.done":
|
||||
img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes)
|
||||
if extractErr != nil {
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
|
||||
processDataErr = extractErr
|
||||
processDataDone = true
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
mergeOpenAIResponsesImageMeta(&streamMeta, img)
|
||||
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||
key := openAIResponsesImageResultKey(itemID, img)
|
||||
if _, exists := emitted[key]; exists {
|
||||
return
|
||||
}
|
||||
if _, exists := pendingSeen[key]; exists {
|
||||
return
|
||||
}
|
||||
pendingSeen[key] = struct{}{}
|
||||
pendingResults = append(pendingResults, img)
|
||||
case "response.completed":
|
||||
results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes)
|
||||
if extractErr != nil {
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
|
||||
processDataErr = extractErr
|
||||
processDataDone = true
|
||||
return
|
||||
}
|
||||
mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta)
|
||||
finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults))
|
||||
finalSeen := make(map[string]struct{})
|
||||
for _, img := range results {
|
||||
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
|
||||
}
|
||||
for _, img := range pendingResults {
|
||||
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
|
||||
}
|
||||
if len(finalResults) == 0 {
|
||||
outputErr := fmt.Errorf("upstream did not return image output")
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(outputErr.Error()))
|
||||
processDataErr = outputErr
|
||||
processDataDone = true
|
||||
return
|
||||
}
|
||||
eventName := streamPrefix + ".completed"
|
||||
for _, img := range finalResults {
|
||||
key := openAIResponsesImageResultKey("", img)
|
||||
if _, exists := emitted[key]; exists {
|
||||
continue
|
||||
}
|
||||
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw)
|
||||
emitted[key] = struct{}{}
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload)
|
||||
}
|
||||
imageCount = len(emitted)
|
||||
processDataDone = true
|
||||
}
|
||||
}
|
||||
|
||||
processLine := func(line []byte) (bool, error) {
|
||||
if len(line) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
sseData.AddLine(string(line), processData)
|
||||
if processDataErr != nil {
|
||||
return true, processDataErr
|
||||
}
|
||||
return processDataDone, nil
|
||||
}
|
||||
|
||||
flushData := func() (bool, error) {
|
||||
sseData.Flush(processData)
|
||||
if processDataErr != nil {
|
||||
return true, processDataErr
|
||||
}
|
||||
return processDataDone, nil
|
||||
}
|
||||
|
||||
finalizePending := func() error {
|
||||
if imageCount > 0 {
|
||||
return nil
|
||||
}
|
||||
if len(pendingResults) > 0 {
|
||||
eventName := streamPrefix + ".completed"
|
||||
for _, img := range pendingResults {
|
||||
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||
key := openAIResponsesImageResultKey("", img)
|
||||
if _, exists := emitted[key]; exists {
|
||||
continue
|
||||
}
|
||||
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil)
|
||||
emitted[key] = struct{}{}
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload)
|
||||
}
|
||||
imageCount = len(emitted)
|
||||
return nil
|
||||
}
|
||||
|
||||
streamErr := fmt.Errorf("stream disconnected before image generation completed")
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error()))
|
||||
return streamErr
|
||||
}
|
||||
|
||||
streamInterval := s.openAIImageStreamDataInterval()
|
||||
keepaliveInterval := s.openAIImageStreamKeepaliveInterval()
|
||||
if streamInterval <= 0 && keepaliveInterval <= 0 {
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
done, processErr := processLine(line)
|
||||
if processErr != nil {
|
||||
return usage, imageCount, firstTokenMs, processErr
|
||||
}
|
||||
if done {
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
if done, processErr := flushData(); processErr != nil {
|
||||
return usage, imageCount, firstTokenMs, processErr
|
||||
} else if done {
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
|
||||
return usage, imageCount, firstTokenMs, err
|
||||
}
|
||||
}
|
||||
if done, processErr := flushData(); processErr != nil {
|
||||
return usage, imageCount, firstTokenMs, processErr
|
||||
} else if done {
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
if err := finalizePending(); err != nil {
|
||||
return usage, imageCount, firstTokenMs, err
|
||||
}
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
|
||||
type readEvent struct {
|
||||
line []byte
|
||||
err error
|
||||
}
|
||||
events := make(chan readEvent, 16)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev readEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
defer close(events)
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if len(line) > 0 {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
}
|
||||
if len(line) > 0 && !sendEvent(readEvent{line: line}) {
|
||||
return
|
||||
}
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
_ = sendEvent(readEvent{err: err})
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
var intervalTicker *time.Ticker
|
||||
if streamInterval > 0 {
|
||||
intervalTicker = time.NewTicker(streamInterval)
|
||||
defer intervalTicker.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTicker != nil {
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
}
|
||||
var keepaliveCh <-chan time.Time
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if len(line) > 0 {
|
||||
trimmedLine := strings.TrimRight(string(line), "\r\n")
|
||||
data, ok := extractOpenAISSEDataLine(trimmedLine)
|
||||
if ok && data != "" && data != "[DONE]" {
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
if done, processErr := flushData(); processErr != nil {
|
||||
return usage, imageCount, firstTokenMs, processErr
|
||||
} else if done {
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
dataBytes := []byte(data)
|
||||
s.parseSSEUsageBytes(dataBytes, &usage)
|
||||
if gjson.ValidBytes(dataBytes) {
|
||||
if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok {
|
||||
mergeOpenAIResponsesImageMeta(&streamMeta, meta)
|
||||
if eventCreatedAt > 0 {
|
||||
createdAt = eventCreatedAt
|
||||
}
|
||||
}
|
||||
switch gjson.GetBytes(dataBytes, "type").String() {
|
||||
case "response.image_generation_call.partial_image":
|
||||
b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String())
|
||||
if b64 != "" {
|
||||
eventName := streamPrefix + ".partial_image"
|
||||
partialMeta := streamMeta
|
||||
mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{
|
||||
OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()),
|
||||
Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()),
|
||||
})
|
||||
payload := buildOpenAIImagesStreamPartialPayload(
|
||||
eventName,
|
||||
b64,
|
||||
gjson.GetBytes(dataBytes, "partial_image_index").Int(),
|
||||
format,
|
||||
createdAt,
|
||||
partialMeta,
|
||||
)
|
||||
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
|
||||
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
|
||||
}
|
||||
}
|
||||
case "response.output_item.done":
|
||||
img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes)
|
||||
if extractErr != nil {
|
||||
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
|
||||
return OpenAIUsage{}, imageCount, firstTokenMs, extractErr
|
||||
}
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
mergeOpenAIResponsesImageMeta(&streamMeta, img)
|
||||
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||
key := openAIResponsesImageResultKey(itemID, img)
|
||||
if _, exists := emitted[key]; exists {
|
||||
break
|
||||
}
|
||||
if _, exists := pendingSeen[key]; exists {
|
||||
break
|
||||
}
|
||||
pendingSeen[key] = struct{}{}
|
||||
pendingResults = append(pendingResults, img)
|
||||
case "response.completed":
|
||||
results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes)
|
||||
if extractErr != nil {
|
||||
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
|
||||
return OpenAIUsage{}, imageCount, firstTokenMs, extractErr
|
||||
}
|
||||
mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta)
|
||||
finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults))
|
||||
finalSeen := make(map[string]struct{})
|
||||
for _, img := range results {
|
||||
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
|
||||
}
|
||||
for _, img := range pendingResults {
|
||||
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
|
||||
}
|
||||
if len(finalResults) == 0 {
|
||||
err = fmt.Errorf("upstream did not return image output")
|
||||
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
|
||||
return OpenAIUsage{}, imageCount, firstTokenMs, err
|
||||
}
|
||||
eventName := streamPrefix + ".completed"
|
||||
for _, img := range finalResults {
|
||||
key := openAIResponsesImageResultKey("", img)
|
||||
if _, exists := emitted[key]; exists {
|
||||
continue
|
||||
}
|
||||
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw)
|
||||
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
|
||||
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
|
||||
}
|
||||
emitted[key] = struct{}{}
|
||||
}
|
||||
imageCount = len(emitted)
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
if err := finalizePending(); err != nil {
|
||||
return usage, imageCount, firstTokenMs, err
|
||||
}
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
|
||||
return OpenAIUsage{}, imageCount, firstTokenMs, err
|
||||
}
|
||||
}
|
||||
|
||||
if imageCount > 0 {
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
if len(pendingResults) > 0 {
|
||||
eventName := streamPrefix + ".completed"
|
||||
for _, img := range pendingResults {
|
||||
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||
key := openAIResponsesImageResultKey("", img)
|
||||
if _, exists := emitted[key]; exists {
|
||||
if ev.err != nil {
|
||||
if done, processErr := flushData(); processErr != nil {
|
||||
return usage, imageCount, firstTokenMs, processErr
|
||||
} else if done {
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(ev.err.Error()))
|
||||
return usage, imageCount, firstTokenMs, ev.err
|
||||
}
|
||||
done, processErr := processLine(ev.line)
|
||||
if processErr != nil {
|
||||
return usage, imageCount, firstTokenMs, processErr
|
||||
}
|
||||
if done {
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil)
|
||||
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
|
||||
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
|
||||
if clientDisconnected {
|
||||
return usage, imageCount, firstTokenMs, fmt.Errorf("image stream incomplete after timeout")
|
||||
}
|
||||
emitted[key] = struct{}{}
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images responses stream data interval timeout: interval=%s", streamInterval)
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(fmt.Sprintf("upstream image stream idle for %s", streamInterval)))
|
||||
return usage, imageCount, firstTokenMs, fmt.Errorf("image stream data interval timeout")
|
||||
case <-keepaliveCh:
|
||||
if clientDisconnected || time.Since(lastDownstreamWriteAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
if _, writeErr := io.WriteString(c.Writer, ":\n\n"); writeErr != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images responses stream client disconnected during keepalive, continue draining upstream for billing")
|
||||
continue
|
||||
}
|
||||
flusher.Flush()
|
||||
lastDownstreamWriteAt = time.Now()
|
||||
}
|
||||
imageCount = len(emitted)
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
|
||||
streamErr := fmt.Errorf("stream disconnected before image generation completed")
|
||||
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error()))
|
||||
return OpenAIUsage{}, imageCount, firstTokenMs, streamErr
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
@ -752,7 +948,10 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
)
|
||||
}
|
||||
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, parsed.Stream)
|
||||
defer releaseUpstreamCtx()
|
||||
|
||||
token, _, err := s.GetAccessToken(upstreamCtx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -763,7 +962,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
}
|
||||
setOpsUpstreamRequestBody(c, responsesBody)
|
||||
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -808,14 +1007,14 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
})
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
s.handleFailoverSideEffects(upstreamCtx, resp, account)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account, responsesBody)
|
||||
return s.handleErrorResponse(upstreamCtx, resp, c, account, responsesBody)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
@ -827,6 +1026,20 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
if parsed.Stream {
|
||||
usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel)
|
||||
if err != nil {
|
||||
if imageCount > 0 {
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: usage,
|
||||
Model: requestModel,
|
||||
UpstreamModel: requestModel,
|
||||
Stream: parsed.Stream,
|
||||
ResponseHeaders: resp.Header.Clone(),
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ImageCount: imageCount,
|
||||
ImageSize: parsed.SizeTier,
|
||||
}, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
|
||||
@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
@ -17,6 +18,20 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type failingOpenAIImageWriter struct {
|
||||
gin.ResponseWriter
|
||||
failAfter int
|
||||
writes int
|
||||
}
|
||||
|
||||
func (w *failingOpenAIImageWriter) Write(p []byte) (int, error) {
|
||||
if w.writes >= w.failAfter {
|
||||
return 0, errors.New("write failed: client disconnected")
|
||||
}
|
||||
w.writes++
|
||||
return w.ResponseWriter.Write(p)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","stream":true}`)
|
||||
@ -75,6 +90,100 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T
|
||||
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_NormalizesOfficialAndCustomSizes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
size string
|
||||
wantTier string
|
||||
}{
|
||||
{size: "1024x1024", wantTier: "1K"},
|
||||
{size: "1536x1024", wantTier: "2K"},
|
||||
{size: "1024x1536", wantTier: "2K"},
|
||||
{size: "2048x2048", wantTier: "2K"},
|
||||
{size: "2048x1152", wantTier: "2K"},
|
||||
{size: "3840x2160", wantTier: "4K"},
|
||||
{size: "2160x3840", wantTier: "4K"},
|
||||
{size: "1024X768", wantTier: "2K"},
|
||||
{size: "1280x768", wantTier: "2K"},
|
||||
{size: "2560x1440", wantTier: "2K"},
|
||||
{size: "2560x1600", wantTier: "4K"},
|
||||
{size: "auto", wantTier: "2K"},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.size, func(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"` + tt.size + `"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, parsed)
|
||||
require.Equal(t, tt.size, parsed.Size)
|
||||
require.Equal(t, tt.wantTier, parsed.SizeTier)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_UnknownSizesDoNotBlockPassthrough(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
size string
|
||||
wantTier string
|
||||
}{
|
||||
{size: "2048x1153", wantTier: "2K"},
|
||||
{size: "4096x1024", wantTier: "4K"},
|
||||
{size: "3840x1024", wantTier: "4K"},
|
||||
{size: "512x512", wantTier: "2K"},
|
||||
{size: "invalid", wantTier: "2K"},
|
||||
{size: "999999999999999999999999999x2", wantTier: "2K"},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.size, func(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"` + tt.size + `"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, parsed)
|
||||
require.Equal(t, tt.size, parsed.Size)
|
||||
require.Equal(t, tt.wantTier, parsed.SizeTier)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_LegacyImageModelUnknownSizePassthrough(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-1.5","prompt":"draw a cat","size":"2048x1152"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, parsed)
|
||||
require.Equal(t, "2048x1152", parsed.Size)
|
||||
require.Equal(t, "2K", parsed.SizeTier)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEditWithMaskAndNativeOptions(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@ -446,6 +555,160 @@ func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseU
|
||||
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamJSONResponseBillsImage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{},
|
||||
httpUpstream: &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"X-Request-Id": []string{"req_img_stream_json"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"usage":{"input_tokens":12,"output_tokens":21,"output_tokens_details":{"image_tokens":9}},"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
|
||||
},
|
||||
},
|
||||
}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
|
||||
account := &Account{
|
||||
ID: 7,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-api-key",
|
||||
"base_url": "https://image-upstream.example/v1",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Stream)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, 12, result.Usage.InputTokens)
|
||||
require.Equal(t, 21, result.Usage.OutputTokens)
|
||||
require.Equal(t, 9, result.Usage.ImageOutputTokens)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamRawJSONEventStreamFallbackBillsImage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{},
|
||||
httpUpstream: &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
"X-Request-Id": []string{"req_img_stream_json_mislabeled"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{"created":1710000009,"usage":{"input_tokens":10,"output_tokens":18,"output_tokens_details":{"image_tokens":8}},"data":[{"b64_json":"ZmluYWw="}]}`)),
|
||||
},
|
||||
},
|
||||
}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
|
||||
account := &Account{
|
||||
ID: 8,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-api-key",
|
||||
"base_url": "https://image-upstream.example/v1",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Stream)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, 10, result.Usage.InputTokens)
|
||||
require.Equal(t, 18, result.Usage.OutputTokens)
|
||||
require.Equal(t, 8, result.Usage.ImageOutputTokens)
|
||||
require.Equal(t, "ZmluYWw=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamMultilineSSEDataBillsImage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{},
|
||||
httpUpstream: &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
"X-Request-Id": []string{"req_img_stream_multiline"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
"data: {\"type\":\"image_generation.completed\",\n" +
|
||||
"data: \"usage\":{\"input_tokens\":10,\"output_tokens\":18,\"output_tokens_details\":{\"image_tokens\":8}},\n" +
|
||||
"data: \"b64_json\":\"ZmluYWw=\",\"output_format\":\"png\"}\n\n" +
|
||||
"data: [DONE]\n\n",
|
||||
)),
|
||||
},
|
||||
},
|
||||
}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
|
||||
account := &Account{
|
||||
ID: 8,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-api-key",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Stream)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, 10, result.Usage.InputTokens)
|
||||
require.Equal(t, 18, result.Usage.OutputTokens)
|
||||
require.Equal(t, 8, result.Usage.ImageOutputTokens)
|
||||
}
|
||||
|
||||
func TestExtractOpenAIImagesBillableCountFromJSONBytes_CompletedEvent(t *testing.T) {
|
||||
body := []byte(`{"type":"image_generation.completed","b64_json":"ZmluYWw=","usage":{"input_tokens":10,"output_tokens":18}}`)
|
||||
|
||||
require.Equal(t, 1, extractOpenAIImagesBillableCountFromJSONBytes(body))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@ -583,6 +846,61 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *tes
|
||||
require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamingDrainsAfterClientDisconnect(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
c.Writer = &failingOpenAIImageWriter{ResponseWriter: c.Writer, failAfter: 1}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
ImageStreamDataIntervalTimeout: 1,
|
||||
ImageStreamKeepaliveInterval: 0,
|
||||
},
|
||||
},
|
||||
httpUpstream: &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
"X-Request-Id": []string{"req_img_stream_disconnect_apikey"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
"data: {\"type\":\"image_generation.partial_image\",\"b64_json\":\"cGFydGlhbA==\"}\n\n" +
|
||||
"data: {\"type\":\"image_generation.completed\",\"usage\":{\"input_tokens\":3,\"output_tokens\":4,\"output_tokens_details\":{\"image_tokens\":2}},\"b64_json\":\"ZmluYWw=\",\"output_format\":\"png\"}\n\n" +
|
||||
"data: [DONE]\n\n",
|
||||
)),
|
||||
},
|
||||
},
|
||||
}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
|
||||
account := &Account{
|
||||
ID: 8,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-api-key",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, 3, result.Usage.InputTokens)
|
||||
require.Equal(t, 4, result.Usage.OutputTokens)
|
||||
require.Equal(t, 2, result.Usage.ImageOutputTokens)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@ -798,6 +1116,23 @@ func TestCollectOpenAIImagesFromResponsesBody_FallsBackToOutputItemDone(t *testi
|
||||
require.JSONEq(t, `{"images":1}`, string(usageRaw))
|
||||
}
|
||||
|
||||
func TestCollectOpenAIImagesFromResponsesBody_MultilineSSE(t *testing.T) {
|
||||
body := []byte(
|
||||
"data: {\"type\":\"response.completed\",\n" +
|
||||
"data: \"response\":{\"created_at\":1710000010,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" +
|
||||
"data: [DONE]\n\n",
|
||||
)
|
||||
|
||||
results, createdAt, usageRaw, firstMeta, foundFinal, err := collectOpenAIImagesFromResponsesBody(body)
|
||||
require.NoError(t, err)
|
||||
require.True(t, foundFinal)
|
||||
require.Equal(t, int64(1710000010), createdAt)
|
||||
require.Len(t, results, 1)
|
||||
require.Equal(t, "ZmluYWw=", results[0].Result)
|
||||
require.Equal(t, "png", firstMeta.OutputFormat)
|
||||
require.JSONEq(t, `{"images":1}`, string(usageRaw))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFallback(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
|
||||
@ -854,3 +1189,116 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFa
|
||||
require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
|
||||
require.NotContains(t, rec.Body.String(), "event: error")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesMultilineSSE(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
|
||||
svc.httpUpstream = &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
"X-Request-Id": []string{"req_img_stream_multiline_oauth"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
"data: {\"type\":\"response.completed\",\n" +
|
||||
"data: \"response\":{\"created_at\":1710000011,\"usage\":{\"input_tokens\":6,\"output_tokens\":10,\"output_tokens_details\":{\"image_tokens\":5}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"TXVsdGlsaW5l\",\"output_format\":\"png\"}]}}\n\n" +
|
||||
"data: [DONE]\n\n",
|
||||
)),
|
||||
},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 11,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token-123",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Stream)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, 6, result.Usage.InputTokens)
|
||||
require.Equal(t, 10, result.Usage.OutputTokens)
|
||||
require.Equal(t, 5, result.Usage.ImageOutputTokens)
|
||||
events := parseOpenAIImageTestSSEEvents(rec.Body.String())
|
||||
completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "TXVsdGlsaW5l", gjson.Get(completed.Data, "b64_json").String())
|
||||
require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
|
||||
require.NotContains(t, rec.Body.String(), "event: error")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingDrainsAfterClientDisconnect(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
c.Writer = &failingOpenAIImageWriter{ResponseWriter: c.Writer, failAfter: 1}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
ImageStreamDataIntervalTimeout: 1,
|
||||
ImageStreamKeepaliveInterval: 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
"X-Request-Id": []string{"req_img_stream_disconnect_oauth"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
"data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"png\"}\n\n" +
|
||||
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000009,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" +
|
||||
"data: [DONE]\n\n",
|
||||
)),
|
||||
},
|
||||
}
|
||||
svc.httpUpstream = upstream
|
||||
|
||||
account := &Account{
|
||||
ID: 9,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token-123",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Stream)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, 5, result.Usage.InputTokens)
|
||||
require.Equal(t, 9, result.Usage.OutputTokens)
|
||||
require.Equal(t, 4, result.Usage.ImageOutputTokens)
|
||||
}
|
||||
|
||||
57
backend/internal/service/openai_messages_bridge.go
Normal file
57
backend/internal/service/openai_messages_bridge.go
Normal file
@ -0,0 +1,57 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const openAICompatMessagesBridgeContextKey = "openai_compat_messages_bridge"
|
||||
|
||||
func isOpenAICompatMessagesBridgeBody(body []byte) bool {
|
||||
if len(body) == 0 {
|
||||
return false
|
||||
}
|
||||
if bytes.Contains(body, []byte(openAICompatClaudeCodeTodoGuardMarker)) {
|
||||
return true
|
||||
}
|
||||
return isOpenAICompatMessagesBridgePromptCacheKey(gjson.GetBytes(body, "prompt_cache_key").String())
|
||||
}
|
||||
|
||||
func isOpenAICompatMessagesBridgeRequestBody(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
if input, ok := reqBody["input"].([]any); ok && inputContainsText(input, openAICompatClaudeCodeTodoGuardMarker) {
|
||||
return true
|
||||
}
|
||||
return isOpenAICompatMessagesBridgePromptCacheKey(firstNonEmptyString(reqBody["prompt_cache_key"]))
|
||||
}
|
||||
|
||||
func isOpenAICompatMessagesBridgePromptCacheKey(key string) bool {
|
||||
key = strings.TrimSpace(key)
|
||||
return strings.HasPrefix(key, "anthropic-metadata-") ||
|
||||
strings.HasPrefix(key, "anthropic-cache-") ||
|
||||
strings.HasPrefix(key, "anthropic-digest-")
|
||||
}
|
||||
|
||||
func setOpenAICompatMessagesBridgeContext(c *gin.Context, enabled bool) {
|
||||
if c == nil || !enabled {
|
||||
return
|
||||
}
|
||||
c.Set(openAICompatMessagesBridgeContextKey, true)
|
||||
}
|
||||
|
||||
func isOpenAICompatMessagesBridgeContext(c *gin.Context) bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
value, ok := c.Get(openAICompatMessagesBridgeContextKey)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
enabled, ok := value.(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
277
backend/internal/service/openai_messages_continuation.go
Normal file
277
backend/internal/service/openai_messages_continuation.go
Normal file
@ -0,0 +1,277 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type openAICompatSessionResponseBinding struct {
|
||||
ResponseID string
|
||||
TurnState string
|
||||
ContinuationDisabled bool
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
func openAICompatContinuationEnabled(account *Account, model string) bool {
|
||||
if account == nil || account.Type != AccountTypeAPIKey {
|
||||
return false
|
||||
}
|
||||
return shouldAutoInjectPromptCacheKeyForCompat(model)
|
||||
}
|
||||
|
||||
func trimAnthropicCompatResponsesInputToLatestTurn(req *apicompat.ResponsesRequest) {
|
||||
if req == nil || len(req.Input) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var items []apicompat.ResponsesInputItem
|
||||
if err := json.Unmarshal(req.Input, &items); err != nil || len(items) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
start := len(items) - 1
|
||||
for start > 0 && items[start].Type == "function_call_output" {
|
||||
start--
|
||||
}
|
||||
trimmed := append([]apicompat.ResponsesInputItem(nil), items[start:]...)
|
||||
if len(trimmed) == len(items) {
|
||||
return
|
||||
}
|
||||
if input, err := json.Marshal(trimmed); err == nil {
|
||||
req.Input = input
|
||||
}
|
||||
}
|
||||
|
||||
func isOpenAICompatPreviousResponseNotFound(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
|
||||
if statusCode != http.StatusBadRequest && statusCode != http.StatusNotFound {
|
||||
return false
|
||||
}
|
||||
check := func(s string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(s))
|
||||
return strings.Contains(lower, "previous_response_not_found") ||
|
||||
(strings.Contains(lower, "previous response") && strings.Contains(lower, "not found")) ||
|
||||
(strings.Contains(lower, "unsupported parameter") && strings.Contains(lower, "previous_response_id"))
|
||||
}
|
||||
if check(upstreamMsg) || check(string(upstreamBody)) {
|
||||
return true
|
||||
}
|
||||
return check(gjson.GetBytes(upstreamBody, "error.code").String()) ||
|
||||
check(gjson.GetBytes(upstreamBody, "error.message").String())
|
||||
}
|
||||
|
||||
func isOpenAICompatPreviousResponseUnsupported(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
|
||||
if statusCode != http.StatusBadRequest {
|
||||
return false
|
||||
}
|
||||
check := func(s string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(s))
|
||||
if !strings.Contains(lower, "previous_response_id") {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(lower, "unsupported parameter") ||
|
||||
strings.Contains(lower, "only supported on responses websocket") ||
|
||||
strings.Contains(lower, "not supported")
|
||||
}
|
||||
if check(upstreamMsg) || check(string(upstreamBody)) {
|
||||
return true
|
||||
}
|
||||
return check(gjson.GetBytes(upstreamBody, "error.code").String()) ||
|
||||
check(gjson.GetBytes(upstreamBody, "error.message").String())
|
||||
}
|
||||
|
||||
func openAICompatSessionResponseKey(c *gin.Context, account *Account, promptCacheKey string) string {
|
||||
key := strings.TrimSpace(promptCacheKey)
|
||||
if account == nil || key == "" {
|
||||
return ""
|
||||
}
|
||||
apiKeyID := int64(0)
|
||||
if c != nil {
|
||||
apiKeyID = getAPIKeyIDFromContext(c)
|
||||
}
|
||||
return strings.Join([]string{
|
||||
strconv.FormatInt(account.ID, 10),
|
||||
strconv.FormatInt(apiKeyID, 10),
|
||||
key,
|
||||
}, "\x00")
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getOpenAICompatSessionResponseID(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
key := openAICompatSessionResponseKey(c, account, promptCacheKey)
|
||||
if key == "" {
|
||||
return ""
|
||||
}
|
||||
raw, ok := s.openaiCompatSessionResponses.Load(key)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
binding, ok := raw.(openAICompatSessionResponseBinding)
|
||||
if !ok {
|
||||
s.openaiCompatSessionResponses.Delete(key)
|
||||
return ""
|
||||
}
|
||||
if !binding.ExpiresAt.IsZero() && time.Now().After(binding.ExpiresAt) {
|
||||
s.openaiCompatSessionResponses.Delete(key)
|
||||
return ""
|
||||
}
|
||||
if binding.ContinuationDisabled {
|
||||
return ""
|
||||
}
|
||||
if strings.TrimSpace(binding.ResponseID) == "" {
|
||||
s.openaiCompatSessionResponses.Delete(key)
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(binding.ResponseID)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) bindOpenAICompatSessionResponseID(_ context.Context, c *gin.Context, account *Account, promptCacheKey, responseID string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
key := openAICompatSessionResponseKey(c, account, promptCacheKey)
|
||||
id := strings.TrimSpace(responseID)
|
||||
if key == "" || id == "" {
|
||||
return
|
||||
}
|
||||
binding := openAICompatSessionResponseBinding{
|
||||
ResponseID: id,
|
||||
ExpiresAt: time.Now().Add(s.openAIWSResponseStickyTTL()),
|
||||
}
|
||||
if raw, ok := s.openaiCompatSessionResponses.Load(key); ok {
|
||||
if existing, ok := raw.(openAICompatSessionResponseBinding); ok {
|
||||
if existing.ContinuationDisabled {
|
||||
existing.ResponseID = ""
|
||||
existing.ExpiresAt = time.Now().Add(s.openAIWSResponseStickyTTL())
|
||||
s.openaiCompatSessionResponses.Store(key, existing)
|
||||
return
|
||||
}
|
||||
binding.TurnState = existing.TurnState
|
||||
}
|
||||
}
|
||||
s.openaiCompatSessionResponses.Store(key, binding)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) deleteOpenAICompatSessionResponseID(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
key := openAICompatSessionResponseKey(c, account, promptCacheKey)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
raw, ok := s.openaiCompatSessionResponses.Load(key)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
binding, ok := raw.(openAICompatSessionResponseBinding)
|
||||
if !ok {
|
||||
s.openaiCompatSessionResponses.Delete(key)
|
||||
return
|
||||
}
|
||||
binding.ResponseID = ""
|
||||
if strings.TrimSpace(binding.TurnState) == "" && !binding.ContinuationDisabled {
|
||||
s.openaiCompatSessionResponses.Delete(key)
|
||||
return
|
||||
}
|
||||
binding.ExpiresAt = time.Now().Add(s.openAIWSResponseStickyTTL())
|
||||
s.openaiCompatSessionResponses.Store(key, binding)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) disableOpenAICompatSessionContinuation(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
key := openAICompatSessionResponseKey(c, account, promptCacheKey)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
binding := openAICompatSessionResponseBinding{
|
||||
ContinuationDisabled: true,
|
||||
ExpiresAt: time.Now().Add(s.openAIWSResponseStickyTTL()),
|
||||
}
|
||||
if raw, ok := s.openaiCompatSessionResponses.Load(key); ok {
|
||||
if existing, ok := raw.(openAICompatSessionResponseBinding); ok {
|
||||
binding.TurnState = existing.TurnState
|
||||
}
|
||||
}
|
||||
s.openaiCompatSessionResponses.Store(key, binding)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) isOpenAICompatSessionContinuationDisabled(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
key := openAICompatSessionResponseKey(c, account, promptCacheKey)
|
||||
if key == "" {
|
||||
return false
|
||||
}
|
||||
raw, ok := s.openaiCompatSessionResponses.Load(key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
binding, ok := raw.(openAICompatSessionResponseBinding)
|
||||
if !ok {
|
||||
s.openaiCompatSessionResponses.Delete(key)
|
||||
return false
|
||||
}
|
||||
if !binding.ExpiresAt.IsZero() && time.Now().After(binding.ExpiresAt) {
|
||||
s.openaiCompatSessionResponses.Delete(key)
|
||||
return false
|
||||
}
|
||||
return binding.ContinuationDisabled
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getOpenAICompatSessionTurnState(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
key := openAICompatSessionResponseKey(c, account, promptCacheKey)
|
||||
if key == "" {
|
||||
return ""
|
||||
}
|
||||
raw, ok := s.openaiCompatSessionResponses.Load(key)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
binding, ok := raw.(openAICompatSessionResponseBinding)
|
||||
if !ok || strings.TrimSpace(binding.TurnState) == "" {
|
||||
return ""
|
||||
}
|
||||
if !binding.ExpiresAt.IsZero() && time.Now().After(binding.ExpiresAt) {
|
||||
s.openaiCompatSessionResponses.Delete(key)
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(binding.TurnState)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) bindOpenAICompatSessionTurnState(_ context.Context, c *gin.Context, account *Account, promptCacheKey, turnState string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
key := openAICompatSessionResponseKey(c, account, promptCacheKey)
|
||||
state := strings.TrimSpace(turnState)
|
||||
if key == "" || state == "" {
|
||||
return
|
||||
}
|
||||
binding := openAICompatSessionResponseBinding{
|
||||
TurnState: state,
|
||||
ExpiresAt: time.Now().Add(s.openAIWSResponseStickyTTL()),
|
||||
}
|
||||
if raw, ok := s.openaiCompatSessionResponses.Load(key); ok {
|
||||
if existing, ok := raw.(openAICompatSessionResponseBinding); ok {
|
||||
binding.ResponseID = existing.ResponseID
|
||||
binding.ContinuationDisabled = existing.ContinuationDisabled
|
||||
}
|
||||
}
|
||||
s.openaiCompatSessionResponses.Store(key, binding)
|
||||
}
|
||||
135
backend/internal/service/openai_messages_digest_session.go
Normal file
135
backend/internal/service/openai_messages_digest_session.go
Normal file
@ -0,0 +1,135 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
)
|
||||
|
||||
type openAICompatAnthropicDigestBinding struct {
|
||||
PromptCacheKey string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
func buildOpenAICompatAnthropicDigestChain(req *apicompat.AnthropicRequest) string {
|
||||
if req == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := make([]string, 0, len(req.Messages)+1)
|
||||
if len(req.System) > 0 && strings.TrimSpace(string(req.System)) != "" && strings.TrimSpace(string(req.System)) != "null" {
|
||||
parts = append(parts, "s:"+shortHash(req.System))
|
||||
}
|
||||
for _, msg := range req.Messages {
|
||||
content := msg.Content
|
||||
if len(content) == 0 || strings.TrimSpace(string(content)) == "" {
|
||||
continue
|
||||
}
|
||||
prefix := "u"
|
||||
if strings.TrimSpace(msg.Role) == "assistant" {
|
||||
prefix = "a"
|
||||
}
|
||||
parts = append(parts, prefix+":"+shortHash(content))
|
||||
}
|
||||
return strings.Join(parts, "-")
|
||||
}
|
||||
|
||||
func openAICompatAnthropicDigestNamespace(account *Account, cAPIKeyID int64) string {
|
||||
if account == nil || account.ID <= 0 {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%d|%d|", account.ID, cAPIKeyID)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) findOpenAICompatAnthropicDigestPromptCacheKey(account *Account, cAPIKeyID int64, digestChain string) (promptCacheKey string, matchedChain string) {
|
||||
if s == nil || digestChain == "" {
|
||||
return "", ""
|
||||
}
|
||||
ns := openAICompatAnthropicDigestNamespace(account, cAPIKeyID)
|
||||
if ns == "" {
|
||||
return "", ""
|
||||
}
|
||||
chain := digestChain
|
||||
for {
|
||||
if raw, ok := s.openaiCompatAnthropicDigestSessions.Load(ns + chain); ok {
|
||||
if binding, ok := raw.(openAICompatAnthropicDigestBinding); ok {
|
||||
if binding.ExpiresAt.IsZero() || time.Now().Before(binding.ExpiresAt) {
|
||||
if key := strings.TrimSpace(binding.PromptCacheKey); key != "" {
|
||||
return key, chain
|
||||
}
|
||||
}
|
||||
}
|
||||
s.openaiCompatAnthropicDigestSessions.Delete(ns + chain)
|
||||
}
|
||||
i := strings.LastIndex(chain, "-")
|
||||
if i < 0 {
|
||||
return "", ""
|
||||
}
|
||||
chain = chain[:i]
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) bindOpenAICompatAnthropicDigestPromptCacheKey(account *Account, cAPIKeyID int64, digestChain, promptCacheKey, oldDigestChain string) {
|
||||
if s == nil || digestChain == "" || strings.TrimSpace(promptCacheKey) == "" {
|
||||
return
|
||||
}
|
||||
ns := openAICompatAnthropicDigestNamespace(account, cAPIKeyID)
|
||||
if ns == "" {
|
||||
return
|
||||
}
|
||||
binding := openAICompatAnthropicDigestBinding{
|
||||
PromptCacheKey: strings.TrimSpace(promptCacheKey),
|
||||
ExpiresAt: time.Now().Add(s.openAIWSResponseStickyTTL()),
|
||||
}
|
||||
s.openaiCompatAnthropicDigestSessions.Store(ns+digestChain, binding)
|
||||
if oldDigestChain != "" && oldDigestChain != digestChain {
|
||||
s.openaiCompatAnthropicDigestSessions.Delete(ns + oldDigestChain)
|
||||
}
|
||||
}
|
||||
|
||||
func promptCacheKeyFromAnthropicDigest(digestChain string) string {
|
||||
if strings.TrimSpace(digestChain) == "" {
|
||||
return ""
|
||||
}
|
||||
return "anthropic-digest-" + hashSensitiveValueForLog(digestChain)
|
||||
}
|
||||
|
||||
func promptCacheKeyFromAnthropicMetadataSession(req *apicompat.AnthropicRequest) string {
|
||||
if req == nil || len(req.Metadata) == 0 {
|
||||
return ""
|
||||
}
|
||||
var metadata struct {
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
if err := json.Unmarshal(req.Metadata, &metadata); err != nil {
|
||||
return ""
|
||||
}
|
||||
parsed := ParseMetadataUserID(metadata.UserID)
|
||||
if parsed == nil || strings.TrimSpace(parsed.SessionID) == "" {
|
||||
return ""
|
||||
}
|
||||
seed := strings.Join([]string{
|
||||
"anthropic-metadata",
|
||||
strings.TrimSpace(parsed.DeviceID),
|
||||
strings.TrimSpace(parsed.AccountUUID),
|
||||
strings.TrimSpace(parsed.SessionID),
|
||||
}, "|")
|
||||
return "anthropic-metadata-" + hashSensitiveValueForLog(seed)
|
||||
}
|
||||
|
||||
func cloneAnthropicRequestForDigest(req *apicompat.AnthropicRequest) *apicompat.AnthropicRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
cp := *req
|
||||
if len(req.System) > 0 {
|
||||
cp.System = append(json.RawMessage(nil), req.System...)
|
||||
}
|
||||
if len(req.Messages) > 0 {
|
||||
cp.Messages = append([]apicompat.AnthropicMessage(nil), req.Messages...)
|
||||
}
|
||||
return &cp
|
||||
}
|
||||
90
backend/internal/service/openai_messages_replay_guard.go
Normal file
90
backend/internal/service/openai_messages_replay_guard.go
Normal file
@ -0,0 +1,90 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
)
|
||||
|
||||
const openAICompatAnthropicReplayMaxTailMessages = 12
|
||||
|
||||
func applyAnthropicCompatFullReplayGuard(req *apicompat.AnthropicRequest) bool {
|
||||
if req == nil || len(req.Messages) <= openAICompatAnthropicReplayMaxTailMessages {
|
||||
return false
|
||||
}
|
||||
|
||||
start := len(req.Messages) - openAICompatAnthropicReplayMaxTailMessages
|
||||
start = expandAnthropicCompatTrimBoundary(req.Messages, start)
|
||||
if start <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
req.Messages = append([]apicompat.AnthropicMessage(nil), req.Messages[start:]...)
|
||||
return true
|
||||
}
|
||||
|
||||
func expandAnthropicCompatTrimBoundary(messages []apicompat.AnthropicMessage, start int) int {
|
||||
if start <= 0 || start >= len(messages) {
|
||||
return start
|
||||
}
|
||||
|
||||
toolUseIndex := make(map[string]int)
|
||||
toolResultIndex := make(map[string]int)
|
||||
for i, msg := range messages {
|
||||
uses, results := anthropicCompatMessageToolIDs(msg)
|
||||
for _, id := range uses {
|
||||
if _, exists := toolUseIndex[id]; !exists {
|
||||
toolUseIndex[id] = i
|
||||
}
|
||||
}
|
||||
for _, id := range results {
|
||||
if _, exists := toolResultIndex[id]; !exists {
|
||||
toolResultIndex[id] = i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
next := start
|
||||
for i := start; i < len(messages); i++ {
|
||||
uses, results := anthropicCompatMessageToolIDs(messages[i])
|
||||
for _, id := range results {
|
||||
if useIdx, ok := toolUseIndex[id]; ok && useIdx < next {
|
||||
next = useIdx
|
||||
}
|
||||
}
|
||||
for _, id := range uses {
|
||||
if resultIdx, ok := toolResultIndex[id]; ok && resultIdx < next {
|
||||
next = resultIdx
|
||||
}
|
||||
}
|
||||
}
|
||||
if next == start {
|
||||
return start
|
||||
}
|
||||
start = next
|
||||
}
|
||||
}
|
||||
|
||||
func anthropicCompatMessageToolIDs(msg apicompat.AnthropicMessage) ([]string, []string) {
|
||||
var blocks []apicompat.AnthropicContentBlock
|
||||
if err := json.Unmarshal(msg.Content, &blocks); err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
uses := make([]string, 0, 1)
|
||||
results := make([]string, 0, 1)
|
||||
for _, block := range blocks {
|
||||
switch block.Type {
|
||||
case "tool_use":
|
||||
if block.ID != "" {
|
||||
uses = append(uses, block.ID)
|
||||
}
|
||||
case "tool_result":
|
||||
if block.ToolUseID != "" {
|
||||
results = append(results, block.ToolUseID)
|
||||
}
|
||||
}
|
||||
}
|
||||
return uses, results
|
||||
}
|
||||
@ -0,0 +1,58 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApplyAnthropicCompatFullReplayGuard_TrimsOldMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := &apicompat.AnthropicRequest{Messages: make([]apicompat.AnthropicMessage, 0, openAICompatAnthropicReplayMaxTailMessages+3)}
|
||||
for i := 0; i < openAICompatAnthropicReplayMaxTailMessages+3; i++ {
|
||||
req.Messages = append(req.Messages, apicompat.AnthropicMessage{
|
||||
Role: "user",
|
||||
Content: json.RawMessage(fmt.Sprintf(`"message-%02d"`, i)),
|
||||
})
|
||||
}
|
||||
|
||||
trimmed := applyAnthropicCompatFullReplayGuard(req)
|
||||
|
||||
require.True(t, trimmed)
|
||||
require.Len(t, req.Messages, openAICompatAnthropicReplayMaxTailMessages)
|
||||
require.JSONEq(t, `"message-03"`, string(req.Messages[0].Content))
|
||||
require.JSONEq(t, `"message-14"`, string(req.Messages[len(req.Messages)-1].Content))
|
||||
}
|
||||
|
||||
func TestApplyAnthropicCompatFullReplayGuard_KeepsToolBoundaryIntact(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := &apicompat.AnthropicRequest{Messages: make([]apicompat.AnthropicMessage, 0, openAICompatAnthropicReplayMaxTailMessages+3)}
|
||||
for i := 0; i < openAICompatAnthropicReplayMaxTailMessages+3; i++ {
|
||||
role := "user"
|
||||
content := json.RawMessage(fmt.Sprintf(`"message-%02d"`, i))
|
||||
if i == 1 {
|
||||
role = "assistant"
|
||||
content = json.RawMessage(`[{"type":"tool_use","id":"toolu_keep","name":"Read","input":{"file_path":"main.go"}}]`)
|
||||
}
|
||||
if i == 3 {
|
||||
content = json.RawMessage(`[{"type":"tool_result","tool_use_id":"toolu_keep","content":"ok"}]`)
|
||||
}
|
||||
req.Messages = append(req.Messages, apicompat.AnthropicMessage{
|
||||
Role: role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
trimmed := applyAnthropicCompatFullReplayGuard(req)
|
||||
|
||||
require.True(t, trimmed)
|
||||
require.Len(t, req.Messages, openAICompatAnthropicReplayMaxTailMessages+2)
|
||||
require.Equal(t, "assistant", req.Messages[0].Role)
|
||||
require.Contains(t, string(req.Messages[0].Content), `"toolu_keep"`)
|
||||
require.Contains(t, string(req.Messages[2].Content), `"tool_result"`)
|
||||
}
|
||||
121
backend/internal/service/openai_messages_todo_guard.go
Normal file
121
backend/internal/service/openai_messages_todo_guard.go
Normal file
@ -0,0 +1,121 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
)
|
||||
|
||||
const (
|
||||
openAICompatClaudeCodeTodoGuardMarker = "<sub2api-claude-code-todo-guard>"
|
||||
openAICompatClaudeCodeTodoGuardText = openAICompatClaudeCodeTodoGuardMarker + "\nWhen using Claude Code todo or task tracking tools, keep the visible task list consistent. Do not send final or summary text while any item remains in_progress. Before finishing, asking the user to choose, or reporting a blocker, update the todo list so completed work is completed and deferred work is pending/open; leave an item in_progress only when active work will continue in the same turn.\n</sub2api-claude-code-todo-guard>"
|
||||
)
|
||||
|
||||
func appendOpenAICompatClaudeCodeTodoGuard(req *apicompat.ResponsesRequest) bool {
|
||||
if req == nil || len(req.Input) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
var items []apicompat.ResponsesInputItem
|
||||
if err := json.Unmarshal(req.Input, &items); err != nil {
|
||||
return false
|
||||
}
|
||||
if len(items) == 0 || responsesInputItemsContainText(items, openAICompatClaudeCodeTodoGuardMarker) {
|
||||
return false
|
||||
}
|
||||
|
||||
content, err := json.Marshal([]apicompat.ResponsesContentPart{{
|
||||
Type: "input_text",
|
||||
Text: openAICompatClaudeCodeTodoGuardText,
|
||||
}})
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
guard := apicompat.ResponsesInputItem{
|
||||
Type: "message",
|
||||
Role: "developer",
|
||||
Content: content,
|
||||
}
|
||||
|
||||
insertAt := 0
|
||||
for insertAt < len(items) && items[insertAt].Type == "message" && items[insertAt].Role == "developer" {
|
||||
insertAt++
|
||||
}
|
||||
|
||||
items = append(items, apicompat.ResponsesInputItem{})
|
||||
copy(items[insertAt+1:], items[insertAt:])
|
||||
items[insertAt] = guard
|
||||
|
||||
input, err := json.Marshal(items)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
req.Input = input
|
||||
return true
|
||||
}
|
||||
|
||||
func appendOpenAICompatClaudeCodeTodoGuardToRequestBody(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok || len(input) == 0 || inputContainsText(input, openAICompatClaudeCodeTodoGuardMarker) {
|
||||
return false
|
||||
}
|
||||
|
||||
guard := map[string]any{
|
||||
"type": "message",
|
||||
"role": "developer",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "input_text",
|
||||
"text": openAICompatClaudeCodeTodoGuardText,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
insertAt := 0
|
||||
for insertAt < len(input) {
|
||||
item, ok := input[insertAt].(map[string]any)
|
||||
if !ok || strings.TrimSpace(firstNonEmptyString(item["type"])) != "message" || strings.TrimSpace(firstNonEmptyString(item["role"])) != "developer" {
|
||||
break
|
||||
}
|
||||
insertAt++
|
||||
}
|
||||
|
||||
input = append(input, nil)
|
||||
copy(input[insertAt+1:], input[insertAt:])
|
||||
input[insertAt] = guard
|
||||
reqBody["input"] = input
|
||||
return true
|
||||
}
|
||||
|
||||
func responsesInputItemsContainText(items []apicompat.ResponsesInputItem, needle string) bool {
|
||||
needle = strings.TrimSpace(needle)
|
||||
if needle == "" {
|
||||
return false
|
||||
}
|
||||
for _, item := range items {
|
||||
if strings.Contains(string(item.Content), needle) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func inputContainsText(input []any, needle string) bool {
|
||||
needle = strings.TrimSpace(needle)
|
||||
if needle == "" {
|
||||
return false
|
||||
}
|
||||
for _, item := range input {
|
||||
b, err := json.Marshal(item)
|
||||
if err == nil && strings.Contains(string(b), needle) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
137
backend/internal/service/openai_model_alias.go
Normal file
137
backend/internal/service/openai_model_alias.go
Normal file
@ -0,0 +1,137 @@
|
||||
package service
|
||||
|
||||
import "strings"
|
||||
|
||||
func lastOpenAIModelSegment(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(model, "/") {
|
||||
parts := strings.Split(model, "/")
|
||||
model = parts[len(parts)-1]
|
||||
}
|
||||
return strings.TrimSpace(model)
|
||||
}
|
||||
|
||||
func canonicalizeOpenAIModelAliasSpelling(model string) string {
|
||||
model = strings.ToLower(lastOpenAIModelSegment(model))
|
||||
if model == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
normalized := strings.ReplaceAll(model, "_", "-")
|
||||
normalized = strings.Join(strings.Fields(normalized), "-")
|
||||
for strings.Contains(normalized, "--") {
|
||||
normalized = strings.ReplaceAll(normalized, "--", "-")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(normalized, "gpt5") {
|
||||
normalized = "gpt-5" + strings.TrimPrefix(normalized, "gpt5")
|
||||
}
|
||||
if !strings.HasPrefix(normalized, "gpt-") && !strings.Contains(normalized, "codex") {
|
||||
return ""
|
||||
}
|
||||
|
||||
replacements := []struct {
|
||||
from string
|
||||
to string
|
||||
}{
|
||||
{"gpt-5.4mini", "gpt-5.4-mini"},
|
||||
{"gpt-5.4nano", "gpt-5.4-nano"},
|
||||
{"gpt-5.3-codexspark", "gpt-5.3-codex-spark"},
|
||||
{"gpt-5.3codexspark", "gpt-5.3-codex-spark"},
|
||||
{"gpt-5.3codex", "gpt-5.3-codex"},
|
||||
}
|
||||
for _, replacement := range replacements {
|
||||
normalized = strings.ReplaceAll(normalized, replacement.from, replacement.to)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func normalizeKnownOpenAICodexModel(model string) string {
|
||||
normalized := canonicalizeOpenAIModelAliasSpelling(model)
|
||||
if normalized == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if mapped := getNormalizedCodexModel(normalized); mapped != "" {
|
||||
return mapped
|
||||
}
|
||||
if strings.HasSuffix(normalized, "-openai-compact") {
|
||||
if mapped := getNormalizedCodexModel(strings.TrimSuffix(normalized, "-openai-compact")); mapped != "" {
|
||||
return mapped
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.Contains(normalized, "gpt-5.5"):
|
||||
return "gpt-5.5"
|
||||
case strings.Contains(normalized, "gpt-5.4-mini"):
|
||||
return "gpt-5.4-mini"
|
||||
case strings.Contains(normalized, "gpt-5.4-nano"):
|
||||
return "gpt-5.4-nano"
|
||||
case strings.Contains(normalized, "gpt-5.4"):
|
||||
return "gpt-5.4"
|
||||
case strings.Contains(normalized, "gpt-5.2"):
|
||||
return "gpt-5.2"
|
||||
case strings.Contains(normalized, "gpt-5.3-codex-spark"):
|
||||
return "gpt-5.3-codex-spark"
|
||||
case strings.Contains(normalized, "gpt-5.3-codex"):
|
||||
return "gpt-5.3-codex"
|
||||
case strings.Contains(normalized, "gpt-5.3"):
|
||||
return "gpt-5.3-codex"
|
||||
case strings.Contains(normalized, "codex"):
|
||||
return "gpt-5.3-codex"
|
||||
case strings.Contains(normalized, "gpt-5"):
|
||||
return "gpt-5.4"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func appendUsageBillingModelCandidate(candidates []string, seen map[string]struct{}, model string) []string {
|
||||
trimmed := strings.TrimSpace(model)
|
||||
if trimmed == "" {
|
||||
return candidates
|
||||
}
|
||||
add := func(candidate string) {
|
||||
candidate = strings.TrimSpace(candidate)
|
||||
if candidate == "" {
|
||||
return
|
||||
}
|
||||
key := strings.ToLower(candidate)
|
||||
if _, ok := seen[key]; ok {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
candidates = append(candidates, candidate)
|
||||
}
|
||||
|
||||
add(trimmed)
|
||||
if canonical := canonicalizeOpenAIModelAliasSpelling(trimmed); canonical != "" {
|
||||
add(canonical)
|
||||
}
|
||||
if normalized := normalizeKnownOpenAICodexModel(trimmed); normalized != "" {
|
||||
add(normalized)
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
func usageBillingModelCandidates(primary string, alternates ...string) []string {
|
||||
seen := make(map[string]struct{}, 1+len(alternates))
|
||||
candidates := appendUsageBillingModelCandidate(nil, seen, primary)
|
||||
for _, alternate := range alternates {
|
||||
candidates = appendUsageBillingModelCandidate(candidates, seen, alternate)
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
func firstUsageBillingModel(candidates []string) string {
|
||||
for _, candidate := range candidates {
|
||||
if trimmed := strings.TrimSpace(candidate); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@ -2,44 +2,24 @@ package service
|
||||
|
||||
import "strings"
|
||||
|
||||
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
|
||||
// forwarding. Group-level default mapping only applies when the account itself
|
||||
// did not match any explicit model_mapping rule.
|
||||
// resolveOpenAIForwardModel 解析 OpenAI 兼容转发使用的模型。
|
||||
// defaultMappedModel 只服务于 /v1/messages 的 Claude 系列显式调度映射,
|
||||
// 不作为普通 OpenAI 请求的未知模型兜底。
|
||||
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
|
||||
if account == nil {
|
||||
if defaultMappedModel != "" {
|
||||
if defaultMappedModel != "" && claudeMessagesDispatchFamily(requestedModel) != "" {
|
||||
return defaultMappedModel
|
||||
}
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
mappedModel, matched := account.ResolveMappedModel(requestedModel)
|
||||
if !matched && defaultMappedModel != "" && !isExplicitCodexModel(requestedModel) {
|
||||
if !matched && defaultMappedModel != "" && claudeMessagesDispatchFamily(requestedModel) != "" {
|
||||
return defaultMappedModel
|
||||
}
|
||||
return mappedModel
|
||||
}
|
||||
|
||||
func isExplicitCodexModel(model string) bool {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(model, "/") {
|
||||
parts := strings.Split(model, "/")
|
||||
model = parts[len(parts)-1]
|
||||
}
|
||||
model = strings.ToLower(strings.TrimSpace(model))
|
||||
if getNormalizedCodexModel(model) != "" {
|
||||
return true
|
||||
}
|
||||
if strings.HasSuffix(model, "-openai-compact") {
|
||||
base := strings.TrimSuffix(model, "-openai-compact")
|
||||
return getNormalizedCodexModel(base) != ""
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// resolveOpenAICompactForwardModel determines the compact-only upstream model
|
||||
// for /responses/compact requests. It never affects normal /responses traffic.
|
||||
// When no compact-specific mapping matches, the input model is returned as-is.
|
||||
|
||||
@ -11,7 +11,7 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
|
||||
expectedModel string
|
||||
}{
|
||||
{
|
||||
name: "falls back to group default when account has no mapping",
|
||||
name: "uses messages dispatch default for claude model",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
@ -19,6 +19,15 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
|
||||
defaultMappedModel: "gpt-4o-mini",
|
||||
expectedModel: "gpt-4o-mini",
|
||||
},
|
||||
{
|
||||
name: "does not fall back to group default for invalid gpt model",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
requestedModel: "gpt6",
|
||||
defaultMappedModel: "gpt-5.4",
|
||||
expectedModel: "gpt6",
|
||||
},
|
||||
{
|
||||
name: "preserves explicit gpt-5.4 instead of group default",
|
||||
account: &Account{
|
||||
@ -85,6 +94,15 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
|
||||
defaultMappedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.5",
|
||||
},
|
||||
{
|
||||
name: "preserves compact-spelled gpt5.5 instead of group default",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
requestedModel: "gpt5.5",
|
||||
defaultMappedModel: "gpt-5.4",
|
||||
expectedModel: "gpt5.5",
|
||||
},
|
||||
{
|
||||
name: "preserves openai namespaced gpt-5.5 instead of group default",
|
||||
account: &Account{
|
||||
@ -119,14 +137,14 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt54(t *
|
||||
Credentials: map[string]any{},
|
||||
}
|
||||
|
||||
withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
|
||||
if withoutDefault != "gpt-5.4" {
|
||||
t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.4")
|
||||
withoutDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "")
|
||||
if withoutDefault != "claude-opus-4-6" {
|
||||
t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", withoutDefault, "claude-opus-4-6")
|
||||
}
|
||||
|
||||
withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
|
||||
withDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")
|
||||
if withDefault != "gpt-5.4" {
|
||||
t.Fatalf("normalizeCodexModel(...) = %q, want %q", withDefault, "gpt-5.4")
|
||||
t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", withDefault, "gpt-5.4")
|
||||
}
|
||||
}
|
||||
|
||||
@ -205,6 +223,10 @@ func TestNormalizeCodexModel(t *testing.T) {
|
||||
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3": "gpt-5.3-codex",
|
||||
"gpt-image-2": "gpt-image-2",
|
||||
"gpt-5.4-nano": "gpt-5.4-nano",
|
||||
"gpt-5.4-nano-high": "gpt-5.4-nano",
|
||||
"gpt6": "gpt6",
|
||||
"claude-opus-4-6": "claude-opus-4-6",
|
||||
}
|
||||
|
||||
for input, expected := range cases {
|
||||
@ -222,9 +244,21 @@ func TestNormalizeOpenAIModelForUpstream(t *testing.T) {
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "oauth keeps codex normalization behavior",
|
||||
name: "oauth preserves unknown non codex model",
|
||||
account: &Account{Type: AccountTypeOAuth},
|
||||
model: "gemini-3-flash-preview",
|
||||
want: "gemini-3-flash-preview",
|
||||
},
|
||||
{
|
||||
name: "oauth preserves invalid gpt model",
|
||||
account: &Account{Type: AccountTypeOAuth},
|
||||
model: "gpt6",
|
||||
want: "gpt6",
|
||||
},
|
||||
{
|
||||
name: "oauth normalizes known codex alias",
|
||||
account: &Account{Type: AccountTypeOAuth},
|
||||
model: "gpt-5.4-high",
|
||||
want: "gpt-5.4",
|
||||
},
|
||||
{
|
||||
|
||||
@ -25,9 +25,12 @@ func f64p(v float64) *float64 { return &v }
|
||||
type httpUpstreamRecorder struct {
|
||||
lastReq *http.Request
|
||||
lastBody []byte
|
||||
requests []*http.Request
|
||||
bodies [][]byte
|
||||
|
||||
resp *http.Response
|
||||
err error
|
||||
resp *http.Response
|
||||
responses []*http.Response
|
||||
err error
|
||||
}
|
||||
|
||||
func (u *httpUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
@ -35,12 +38,19 @@ func (u *httpUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID
|
||||
if req != nil && req.Body != nil {
|
||||
b, _ := io.ReadAll(req.Body)
|
||||
u.lastBody = b
|
||||
u.bodies = append(u.bodies, append([]byte(nil), b...))
|
||||
_ = req.Body.Close()
|
||||
req.Body = io.NopCloser(bytes.NewReader(b))
|
||||
}
|
||||
u.requests = append(u.requests, req)
|
||||
if u.err != nil {
|
||||
return nil, u.err
|
||||
}
|
||||
if len(u.responses) > 0 {
|
||||
resp := u.responses[0]
|
||||
u.responses = u.responses[1:]
|
||||
return resp, nil
|
||||
}
|
||||
return u.resp, nil
|
||||
}
|
||||
|
||||
@ -48,6 +58,93 @@ func (u *httpUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, acc
|
||||
return u.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ResponsesUnknownModelDoesNotFallbackToGPT54(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
originalBody := []byte(`{"model":"gpt6","stream":false,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(originalBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_unknown_model"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":{"type":"invalid_request_error","message":"model not found"}}`)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.Equal(t, "https://chatgpt.com/backend-api/codex/responses", upstream.lastReq.URL.String())
|
||||
require.Equal(t, "gpt6", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.NotEqual(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.True(t, rec.Code >= http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthMessagesBridgeDoesNotInjectDefaultInstructions(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
originalBody := []byte(`{"model":"gpt-5.5","stream":true,"prompt_cache_key":"anthropic-metadata-session-1","input":[{"type":"message","role":"developer","content":[{"type":"input_text","text":"<sub2api-claude-code-todo-guard>"}]},{"type":"message","role":"user","content":"hello"}]}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(originalBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_bridge"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":{"type":"invalid_request_error","message":"bridge stop"}}`)),
|
||||
}}
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.Equal(t, "", gjson.GetBytes(upstream.lastBody, "instructions").String())
|
||||
require.False(t, gjson.GetBytes(upstream.lastBody, "prompt_cache_key").Exists())
|
||||
require.NotEmpty(t, upstream.lastReq.Header.Get("Session_Id"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("Conversation_Id"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("OpenAI-Beta"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("originator"))
|
||||
}
|
||||
|
||||
type openAIPassthroughFailoverRepo struct {
|
||||
stubOpenAIAccountRepo
|
||||
rateLimitCalls []time.Time
|
||||
@ -307,6 +404,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_CompactUsesJSONAndKeepsNonStreami
|
||||
require.Contains(t, rec.Body.String(), `"id":"cmp_123"`)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx)
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||
cancel()
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_passthrough_ctx"}},
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
`data: {"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n"))),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||
Extra: map[string]any{"openai_passthrough": true, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
result, err := svc.Forward(reqCtx, c, account, originalBody)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.NoError(t, upstream.lastReq.Context().Err())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureStructuredLog(t)
|
||||
@ -405,6 +548,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *te
|
||||
require.Contains(t, string(upstream.lastBody), `"stream":true`)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthLegacy_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx)
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||
cancel()
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_legacy_ctx"}},
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
`data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n"))),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||
Extra: map[string]any{"openai_passthrough": false, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
result, err := svc.Forward(reqCtx, c, account, originalBody)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.NoError(t, upstream.lastReq.Context().Err())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
70
backend/internal/service/openai_sse_data.go
Normal file
70
backend/internal/service/openai_sse_data.go
Normal file
@ -0,0 +1,70 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type openAISSEDataAccumulator struct {
|
||||
lines []string
|
||||
}
|
||||
|
||||
func (a *openAISSEDataAccumulator) AddLine(line string, fn func([]byte)) {
|
||||
if fn == nil {
|
||||
return
|
||||
}
|
||||
trimmedLine := strings.TrimRight(line, "\r\n")
|
||||
if data, ok := extractOpenAISSEDataLine(trimmedLine); ok {
|
||||
a.lines = append(a.lines, data)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(trimmedLine) == "" {
|
||||
a.Flush(fn)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *openAISSEDataAccumulator) Flush(fn func([]byte)) {
|
||||
if fn == nil || len(a.lines) == 0 {
|
||||
return
|
||||
}
|
||||
emitOpenAISSEDataPayloads(a.lines, fn)
|
||||
a.lines = a.lines[:0]
|
||||
}
|
||||
|
||||
func forEachOpenAISSEDataPayload(body string, fn func([]byte)) {
|
||||
if fn == nil || strings.TrimSpace(body) == "" {
|
||||
return
|
||||
}
|
||||
var acc openAISSEDataAccumulator
|
||||
for _, line := range strings.Split(body, "\n") {
|
||||
acc.AddLine(line, fn)
|
||||
}
|
||||
acc.Flush(fn)
|
||||
}
|
||||
|
||||
func emitOpenAISSEDataPayloads(lines []string, fn func([]byte)) {
|
||||
if fn == nil || len(lines) == 0 {
|
||||
return
|
||||
}
|
||||
if len(lines) == 1 {
|
||||
emitOpenAISSEDataPayload(lines[0], fn)
|
||||
return
|
||||
}
|
||||
joined := strings.Join(lines, "\n")
|
||||
if gjson.Valid(joined) {
|
||||
emitOpenAISSEDataPayload(joined, fn)
|
||||
return
|
||||
}
|
||||
for _, line := range lines {
|
||||
emitOpenAISSEDataPayload(line, fn)
|
||||
}
|
||||
}
|
||||
|
||||
func emitOpenAISSEDataPayload(data string, fn func([]byte)) {
|
||||
data = strings.TrimSpace(data)
|
||||
if data == "" || data == "[DONE]" {
|
||||
return
|
||||
}
|
||||
fn([]byte(data))
|
||||
}
|
||||
@ -219,8 +219,11 @@ func (e *OpenAIWSClientCloseError) Reason() string {
|
||||
|
||||
// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。
|
||||
type OpenAIWSIngressHooks struct {
|
||||
BeforeTurn func(turn int) error
|
||||
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
|
||||
// InitialRequestModel 是首帧渠道映射前的请求模型,只用于 usage metadata
|
||||
// 的 reasoning effort 后缀推导,禁止用于上游请求或计费模型。
|
||||
InitialRequestModel string
|
||||
BeforeTurn func(turn int) error
|
||||
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
|
||||
}
|
||||
|
||||
func normalizeOpenAIWSLogValue(value string) string {
|
||||
@ -1987,6 +1990,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
||||
}
|
||||
|
||||
usage := &OpenAIUsage{}
|
||||
imageCounter := newOpenAIImageOutputCounter()
|
||||
var firstTokenMs *int
|
||||
responseID := ""
|
||||
var finalResponse []byte
|
||||
@ -2168,6 +2172,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
||||
if openAIWSEventShouldParseUsage(eventType) {
|
||||
parseOpenAIWSResponseUsageFromCompletedEvent(message, usage)
|
||||
}
|
||||
imageCounter.AddSSEData(message)
|
||||
|
||||
if eventType == "error" {
|
||||
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
|
||||
@ -2340,6 +2345,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
UpstreamModel: mappedModel,
|
||||
ImageCount: imageCounter.Count(),
|
||||
ServiceTier: extractOpenAIServiceTier(reqBody),
|
||||
ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
|
||||
Stream: reqStream,
|
||||
@ -2446,6 +2452,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
promptCacheKey string
|
||||
previousResponseID string
|
||||
originalModel string
|
||||
imageBillingModel string
|
||||
imageSizeTier string
|
||||
payloadBytes int
|
||||
}
|
||||
|
||||
@ -2543,6 +2551,19 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
}
|
||||
normalized = next
|
||||
}
|
||||
imageIntent := IsImageGenerationIntent(openAIResponsesEndpoint, originalModel, normalized)
|
||||
if imageIntent && !GroupAllowsImageGeneration(apiKeyGroup(getAPIKeyFromContext(c))) {
|
||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, ImageGenerationPermissionMessage(), nil)
|
||||
}
|
||||
imageBillingModel := ""
|
||||
imageSizeTier := ""
|
||||
if imageIntent {
|
||||
var imageCfgErr error
|
||||
imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(normalized, originalModel)
|
||||
if imageCfgErr != nil {
|
||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, imageCfgErr.Error(), imageCfgErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply OpenAI Fast Policy on the response.create frame using the same
|
||||
// evaluator/normalize/scope rules as the HTTP entrypoints. This is the
|
||||
@ -2588,6 +2609,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
promptCacheKey: promptCacheKey,
|
||||
previousResponseID: previousResponseID,
|
||||
originalModel: originalModel,
|
||||
imageBillingModel: imageBillingModel,
|
||||
imageSizeTier: imageSizeTier,
|
||||
payloadBytes: len(normalized),
|
||||
}, nil
|
||||
}
|
||||
@ -2789,7 +2812,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string) (*OpenAIForwardResult, error) {
|
||||
sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string, imageBillingModel string, imageSizeTier string) (*OpenAIForwardResult, error) {
|
||||
if lease == nil {
|
||||
return nil, errors.New("upstream websocket lease is nil")
|
||||
}
|
||||
@ -2814,6 +2837,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
|
||||
responseID := ""
|
||||
usage := OpenAIUsage{}
|
||||
imageCounter := newOpenAIImageOutputCounter()
|
||||
var firstTokenMs *int
|
||||
reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true)
|
||||
turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
|
||||
@ -2935,6 +2959,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
if openAIWSEventShouldParseUsage(eventType) {
|
||||
parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage)
|
||||
}
|
||||
imageCounter.AddSSEData(upstreamMessage)
|
||||
|
||||
if !clientDisconnected {
|
||||
if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) {
|
||||
@ -2994,7 +3019,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
clientDisconnected,
|
||||
)
|
||||
}
|
||||
return &OpenAIForwardResult{
|
||||
imageCount := imageCounter.Count()
|
||||
result := &OpenAIForwardResult{
|
||||
RequestID: responseID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
@ -3006,13 +3032,21 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
ResponseHeaders: lease.HandshakeHeaders(),
|
||||
Duration: time.Since(turnStart),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
if imageCount > 0 {
|
||||
result.ImageCount = imageCount
|
||||
result.ImageSize = imageSizeTier
|
||||
result.BillingModel = imageBillingModel
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
currentPayload := firstPayload.payloadRaw
|
||||
currentOriginalModel := firstPayload.originalModel
|
||||
currentImageBillingModel := firstPayload.imageBillingModel
|
||||
currentImageSizeTier := firstPayload.imageSizeTier
|
||||
currentPayloadBytes := firstPayload.payloadBytes
|
||||
isStrictAffinityTurn := func(payload []byte) bool {
|
||||
if !storeDisabled {
|
||||
@ -3101,6 +3135,12 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
if turnPrevRecoveryTried || !s.openAIWSIngressPreviousResponseRecoveryEnabled() {
|
||||
return false
|
||||
}
|
||||
// 携带 function_call_output 的请求不能丢弃 previous_response_id:
|
||||
// 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use,
|
||||
// 丢弃后会导致 "No tool call found for function call output" 400 错误。
|
||||
if gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() {
|
||||
return false
|
||||
}
|
||||
if isStrictAffinityTurn(currentPayload) {
|
||||
// Layer 2:严格亲和链路命中 previous_response_not_found 时,降级为“去掉 previous_response_id 后重放一次”。
|
||||
// 该错误说明续链锚点已失效,继续 strict fail-close 只会直接中断本轮请求。
|
||||
@ -3367,7 +3407,11 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
truncateOpenAIWSLogValue(pingErr.Error(), openAIWSLogValueMaxLen),
|
||||
)
|
||||
if forcePreferredConn {
|
||||
if !turnPrevRecoveryTried && currentPreviousResponseID != "" {
|
||||
// 携带 function_call_output 的请求不能丢弃 previous_response_id:
|
||||
// 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use,
|
||||
// 丢弃后会导致 "No tool call found for function call output" 400 错误。
|
||||
hasFCOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
|
||||
if !turnPrevRecoveryTried && currentPreviousResponseID != "" && !hasFCOutput {
|
||||
updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
|
||||
if dropErr != nil || !removed {
|
||||
reason := "not_removed"
|
||||
@ -3457,7 +3501,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
)
|
||||
}
|
||||
|
||||
result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel)
|
||||
result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel, currentImageBillingModel, currentImageSizeTier)
|
||||
if relayErr != nil {
|
||||
lastTurnClean = false
|
||||
if recoverIngressPrevResponseNotFound(relayErr, turn, connID) {
|
||||
@ -3579,6 +3623,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
}
|
||||
currentPayload = nextPayload.payloadRaw
|
||||
currentOriginalModel = nextPayload.originalModel
|
||||
currentImageBillingModel = nextPayload.imageBillingModel
|
||||
currentImageSizeTier = nextPayload.imageSizeTier
|
||||
currentPayloadBytes = nextPayload.payloadBytes
|
||||
storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account)
|
||||
if !storeDisabled {
|
||||
|
||||
@ -399,7 +399,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast"}`))
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast","reasoning":{"effort":"HIGH"}}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -431,6 +431,8 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
|
||||
require.Equal(t, 3, result.Usage.OutputTokens)
|
||||
require.NotNil(t, result.ServiceTier)
|
||||
require.Equal(t, "priority", *result.ServiceTier)
|
||||
require.NotNil(t, result.ReasoningEffort)
|
||||
require.Equal(t, "high", *result.ReasoningEffort)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("未收到 passthrough turn 结果回调")
|
||||
}
|
||||
|
||||
@ -171,6 +171,127 @@ func TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky(t *testing.T) {
|
||||
require.Equal(t, "resp_new_1", gjson.GetBytes(responseBody, "id").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2_ImageGenerationCountsOutputs(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade websocket failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
var request map[string]any
|
||||
if err := conn.ReadJSON(&request); err != nil {
|
||||
t.Errorf("read ws request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.WriteJSON(map[string]any{
|
||||
"type": "response.output_item.done",
|
||||
"item": map[string]any{
|
||||
"id": "ig_ws_1",
|
||||
"type": "image_generation_call",
|
||||
"result": "final-image",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Errorf("write response.output_item.done failed: %v", err)
|
||||
return
|
||||
}
|
||||
if err := conn.WriteJSON(map[string]any{
|
||||
"type": "response.completed",
|
||||
"response": map[string]any{
|
||||
"id": "resp_ws_image_1",
|
||||
"model": "gpt-5.4",
|
||||
"output": []any{
|
||||
map[string]any{
|
||||
"id": "ig_ws_1",
|
||||
"type": "image_generation_call",
|
||||
"result": "final-image",
|
||||
},
|
||||
},
|
||||
"usage": map[string]any{
|
||||
"input_tokens": 9,
|
||||
"output_tokens": 4,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Errorf("write response.completed failed: %v", err)
|
||||
return
|
||||
}
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
groupID := int64(1010)
|
||||
c.Set("api_key", &APIKey{
|
||||
GroupID: &groupID,
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
AllowImageGeneration: true,
|
||||
},
|
||||
})
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: &httpUpstreamRecorder{},
|
||||
cache: &stubGatewayCache{},
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 10,
|
||||
Name: "openai-ws-image",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": wsServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.4","stream":false,"input":"draw","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1024x1024"}],"tool_choice":{"type":"image_generation"}}`)
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "resp_ws_image_1", result.RequestID)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, "1K", result.ImageSize)
|
||||
require.Equal(t, "gpt-image-2", result.BillingModel)
|
||||
require.Equal(t, 9, result.Usage.InputTokens)
|
||||
require.Equal(t, 4, result.Usage.OutputTokens)
|
||||
require.True(t, result.OpenAIWSMode)
|
||||
require.Equal(t, "resp_ws_image_1", gjson.GetBytes(rec.Body.Bytes(), "id").String())
|
||||
}
|
||||
|
||||
func requestToJSONString(payload map[string]any) string {
|
||||
if len(payload) == 0 {
|
||||
return "{}"
|
||||
|
||||
@ -124,6 +124,73 @@ func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []
|
||||
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
|
||||
}
|
||||
|
||||
type openAIWSPassthroughUsageMeta struct {
|
||||
serviceTier atomic.Pointer[string]
|
||||
reasoningEffort atomic.Pointer[string]
|
||||
|
||||
// 仅在 client->upstream filter goroutine 中读写;Load 侧通过上方原子指针同步。
|
||||
sessionRequestModel string
|
||||
}
|
||||
|
||||
func newOpenAIWSPassthroughUsageMeta(initialRequestModel string, firstFrame []byte) *openAIWSPassthroughUsageMeta {
|
||||
meta := &openAIWSPassthroughUsageMeta{
|
||||
sessionRequestModel: strings.TrimSpace(initialRequestModel),
|
||||
}
|
||||
if meta.sessionRequestModel == "" {
|
||||
meta.sessionRequestModel = openAIWSPassthroughRequestModelForFrame(firstFrame)
|
||||
}
|
||||
return meta
|
||||
}
|
||||
|
||||
func (m *openAIWSPassthroughUsageMeta) initFromFirstFrame(policyOutput []byte) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.serviceTier.Store(extractOpenAIServiceTierFromBody(policyOutput))
|
||||
m.reasoningEffort.Store(extractOpenAIReasoningEffortFromBody(policyOutput, m.sessionRequestModel))
|
||||
}
|
||||
|
||||
func (m *openAIWSPassthroughUsageMeta) updateSessionRequestModel(payload []byte) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if model := openAIWSPassthroughRequestModelFromSessionFrame(payload); model != "" {
|
||||
m.sessionRequestModel = model
|
||||
}
|
||||
}
|
||||
|
||||
func (m *openAIWSPassthroughUsageMeta) requestModelForFrame(payload []byte) string {
|
||||
if m == nil {
|
||||
return openAIWSPassthroughRequestModelForFrame(payload)
|
||||
}
|
||||
if model := openAIWSPassthroughRequestModelForFrame(payload); model != "" {
|
||||
return model
|
||||
}
|
||||
return m.sessionRequestModel
|
||||
}
|
||||
|
||||
func (m *openAIWSPassthroughUsageMeta) updateFromResponseCreate(policyOutput []byte, requestModelForFrame string) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.serviceTier.Store(extractOpenAIServiceTierFromBody(policyOutput))
|
||||
m.reasoningEffort.Store(extractOpenAIReasoningEffortFromBody(policyOutput, requestModelForFrame))
|
||||
}
|
||||
|
||||
func openAIWSPassthroughRequestModelForFrame(payload []byte) string {
|
||||
if len(payload) == 0 || strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||
}
|
||||
|
||||
func openAIWSPassthroughRequestModelFromSessionFrame(payload []byte) string {
|
||||
if len(payload) == 0 || strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "session.update" {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
|
||||
}
|
||||
|
||||
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
|
||||
|
||||
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
|
||||
@ -204,6 +271,11 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
// silently passed through, defeating the policy on every frame after
|
||||
// the first.
|
||||
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
|
||||
initialRequestModel := ""
|
||||
if hooks != nil {
|
||||
initialRequestModel = hooks.InitialRequestModel
|
||||
}
|
||||
usageMeta := newOpenAIWSPassthroughUsageMeta(initialRequestModel, firstClientMessage)
|
||||
updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
|
||||
if policyErr != nil {
|
||||
return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
|
||||
@ -226,7 +298,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
}
|
||||
firstClientMessage = updatedFirst
|
||||
|
||||
// 在 policy filter 之后再提取 service_tier 用于 billing 上报:filter
|
||||
// 在 policy filter 之后再提取 service_tier / reasoning_effort 用于
|
||||
// usage 上报:filter
|
||||
// 命中时 service_tier 已经从 firstClientMessage 中删除,billing 应当
|
||||
// 反映上游实际处理的 tier(nil = default),而不是用户最初请求的
|
||||
// "priority"。HTTP 入口(line ~2728 extractOpenAIServiceTier(reqBody))
|
||||
@ -237,11 +310,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
// codex-rs/core/src/client.rs build_responses_request 每次重新填值)。
|
||||
// 因此使用 atomic.Pointer[string] 在 filter(runClientToUpstream
|
||||
// goroutine)和 OnTurnComplete / final result(runUpstreamToClient
|
||||
// goroutine)之间同步当前 turn 的 service_tier。
|
||||
// extractOpenAIServiceTierFromBody 返回 *string,本身是指针类型,
|
||||
// 可直接 Store/Load 而无需额外封装。
|
||||
var requestServiceTierPtr atomic.Pointer[string]
|
||||
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage))
|
||||
// goroutine)之间同步当前 turn 的 usage metadata。
|
||||
usageMeta.initFromFirstFrame(firstClientMessage)
|
||||
|
||||
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
||||
if err != nil {
|
||||
@ -327,6 +397,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
|
||||
capturedSessionModel = updated
|
||||
}
|
||||
usageMeta.updateSessionRequestModel(payload)
|
||||
requestModelForThisFrame := usageMeta.requestModelForFrame(payload)
|
||||
// Per-frame model first; if the client omits "model" on a
|
||||
// follow-up frame (legal in Realtime), fall back to the
|
||||
// session-level model captured from the first frame so the
|
||||
@ -337,14 +409,14 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
model = capturedSessionModel
|
||||
}
|
||||
out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
|
||||
// 多轮 passthrough billing:仅在成功(non-block / non-err)
|
||||
// 的 response.create 帧上更新 requestServiceTierPtr,使用
|
||||
// 多轮 passthrough usage:仅在成功(non-block / non-err)
|
||||
// 的 response.create 帧上更新 usageMeta,使用
|
||||
// filter 处理后的 payload,与首帧 policy-after-extract 语义
|
||||
// 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。
|
||||
// - 非 response.create 帧(response.cancel /
|
||||
// conversation.item.create / session.update 等)不携带
|
||||
// per-response service_tier,不应覆盖前一轮值。
|
||||
// - blocked != nil:该帧不会发送上游,billing tier 应保持
|
||||
// per-response metadata,不应覆盖前一轮值。
|
||||
// - blocked != nil:该帧不会发送上游,usage metadata 应保持
|
||||
// 上一轮值。
|
||||
// - policyErr != nil:异常路径,保持上一轮值。
|
||||
// - 不带 service_tier 的 response.create 会让
|
||||
@ -353,7 +425,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
// service_tier 时按 default 处理,billing 应如实反映。
|
||||
if policyErr == nil && blocked == nil &&
|
||||
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
|
||||
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
|
||||
usageMeta.updateFromResponseCreate(out, requestModelForThisFrame)
|
||||
}
|
||||
return out, blocked, policyErr
|
||||
},
|
||||
@ -397,7 +469,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
||||
},
|
||||
Model: turn.RequestModel,
|
||||
ServiceTier: requestServiceTierPtr.Load(),
|
||||
ServiceTier: usageMeta.serviceTier.Load(),
|
||||
ReasoningEffort: usageMeta.reasoningEffort.Load(),
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
@ -445,7 +518,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
||||
},
|
||||
Model: relayResult.RequestModel,
|
||||
ServiceTier: requestServiceTierPtr.Load(),
|
||||
ServiceTier: usageMeta.serviceTier.Load(),
|
||||
ReasoningEffort: usageMeta.reasoningEffort.Load(),
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
|
||||
164
backend/internal/service/ops_cleanup_executor.go
Normal file
164
backend/internal/service/ops_cleanup_executor.go
Normal file
@ -0,0 +1,164 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
opsCleanupDefaultSchedule = "0 2 * * *"
|
||||
opsCleanupBatchSize = 5000
|
||||
opsCleanupCronStopTimeout = 3 * time.Second
|
||||
opsCleanupRunTimeout = 30 * time.Minute
|
||||
opsCleanupHeartbeatTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
type opsCleanupTarget struct {
|
||||
retentionDays int
|
||||
table string
|
||||
timeCol string
|
||||
castDate bool
|
||||
counter *int64
|
||||
}
|
||||
|
||||
type opsCleanupDeletedCounts struct {
|
||||
errorLogs int64
|
||||
retryAttempts int64
|
||||
alertEvents int64
|
||||
systemLogs int64
|
||||
logAudits int64
|
||||
systemMetrics int64
|
||||
hourlyPreagg int64
|
||||
dailyPreagg int64
|
||||
}
|
||||
|
||||
func (c opsCleanupDeletedCounts) String() string {
|
||||
return fmt.Sprintf(
|
||||
"error_logs=%d retry_attempts=%d alert_events=%d system_logs=%d log_audits=%d system_metrics=%d hourly_preagg=%d daily_preagg=%d",
|
||||
c.errorLogs,
|
||||
c.retryAttempts,
|
||||
c.alertEvents,
|
||||
c.systemLogs,
|
||||
c.logAudits,
|
||||
c.systemMetrics,
|
||||
c.hourlyPreagg,
|
||||
c.dailyPreagg,
|
||||
)
|
||||
}
|
||||
|
||||
// opsCleanupPlan 把"保留天数"翻译成具体的清理动作。
|
||||
// - days < 0 → 跳过该项清理(ok=false),保留兼容老数据
|
||||
// - days == 0 → TRUNCATE TABLE(O(1) 全清),truncate=true
|
||||
// - days > 0 → 批量 DELETE 早于 now-N天 的行,cutoff = now - N 天
|
||||
func opsCleanupPlan(now time.Time, days int) (cutoff time.Time, truncate, ok bool) {
|
||||
if days < 0 {
|
||||
return time.Time{}, false, false
|
||||
}
|
||||
if days == 0 {
|
||||
return time.Time{}, true, true
|
||||
}
|
||||
return now.AddDate(0, 0, -days), false, true
|
||||
}
|
||||
|
||||
func opsCleanupRunOne(
|
||||
ctx context.Context,
|
||||
db *sql.DB,
|
||||
truncate bool,
|
||||
cutoff time.Time,
|
||||
table, timeCol string,
|
||||
castDate bool,
|
||||
batchSize int,
|
||||
) (int64, error) {
|
||||
if truncate {
|
||||
return truncateOpsTable(ctx, db, table)
|
||||
}
|
||||
return deleteOldRowsByID(ctx, db, table, timeCol, cutoff, batchSize, castDate)
|
||||
}
|
||||
|
||||
func deleteOldRowsByID(
|
||||
ctx context.Context,
|
||||
db *sql.DB,
|
||||
table string,
|
||||
timeColumn string,
|
||||
cutoff time.Time,
|
||||
batchSize int,
|
||||
castCutoffToDate bool,
|
||||
) (int64, error) {
|
||||
if db == nil {
|
||||
return 0, nil
|
||||
}
|
||||
if batchSize <= 0 {
|
||||
batchSize = opsCleanupBatchSize
|
||||
}
|
||||
|
||||
where := fmt.Sprintf("%s < $1", timeColumn)
|
||||
if castCutoffToDate {
|
||||
where = fmt.Sprintf("%s < $1::date", timeColumn)
|
||||
}
|
||||
|
||||
q := fmt.Sprintf(`
|
||||
WITH batch AS (
|
||||
SELECT id FROM %s
|
||||
WHERE %s
|
||||
ORDER BY id
|
||||
LIMIT $2
|
||||
)
|
||||
DELETE FROM %s
|
||||
WHERE id IN (SELECT id FROM batch)
|
||||
`, table, where, table)
|
||||
|
||||
var total int64
|
||||
for {
|
||||
res, err := db.ExecContext(ctx, q, cutoff, batchSize)
|
||||
if err != nil {
|
||||
if isMissingRelationError(err) {
|
||||
return total, nil
|
||||
}
|
||||
return total, err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
total += affected
|
||||
if affected == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// truncateOpsTable 用 TRUNCATE TABLE 清空指定表,先 SELECT COUNT(*) 取得清空前行数用于 heartbeat。
|
||||
func truncateOpsTable(ctx context.Context, db *sql.DB, table string) (int64, error) {
|
||||
if db == nil {
|
||||
return 0, nil
|
||||
}
|
||||
var count int64
|
||||
if err := db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", table)).Scan(&count); err != nil {
|
||||
if isMissingRelationError(err) {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, fmt.Errorf("count %s: %w", table, err)
|
||||
}
|
||||
if count == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if _, err := db.ExecContext(ctx, fmt.Sprintf("TRUNCATE TABLE %s", table)); err != nil {
|
||||
if isMissingRelationError(err) {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, fmt.Errorf("truncate %s: %w", table, err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func isMissingRelationError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s := strings.ToLower(err.Error())
|
||||
return strings.Contains(s, "does not exist") && strings.Contains(s, "relation")
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user