chore: merge upstream v0.1.122-123, keep Windsurf/Antigravity customizations

New upstream features:
- feat: improve OpenAI messages compatibility for Claude Code
- feat: image generation stream & concurrency controls
- fix(rate-limit): remove 429 cooldown config option
- fix: skip previous_response_id recovery when payload has function_call_output
- feat: support select search in group/account views
- fix: ops cleanup settings
- chore: remove openspec and update axios

Conflict resolutions:
- config.go: kept AntigravityLSWorker+NodeTLSProxy AND added ImageConcurrency
- account_test_service.go: kept windsurf import AND added openai_compat import
- docker-compose.yml: kept Windsurf env vars AND added image concurrency env vars
This commit is contained in:
win 2026-05-06 11:50:54 +08:00
commit 3fe228d143
146 changed files with 13916 additions and 1361 deletions

View File

@ -1 +1 @@
0.1.121 0.1.123

View File

@ -265,7 +265,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig) opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig)
opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig) opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig)
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService, settingRepository, opsService)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository) accountExpiryService := service.ProvideAccountExpiryService(accountRepository)

View File

@ -47,6 +47,12 @@ type Group struct {
MonthlyLimitUsd *float64 `json:"monthly_limit_usd,omitempty"` MonthlyLimitUsd *float64 `json:"monthly_limit_usd,omitempty"`
// DefaultValidityDays holds the value of the "default_validity_days" field. // DefaultValidityDays holds the value of the "default_validity_days" field.
DefaultValidityDays int `json:"default_validity_days,omitempty"` DefaultValidityDays int `json:"default_validity_days,omitempty"`
// Whether this group is allowed to use the image generation capability.
AllowImageGeneration bool `json:"allow_image_generation,omitempty"`
// Whether image generation uses an independent rate multiplier; false means it shares the group's effective multiplier.
ImageRateIndependent bool `json:"image_rate_independent,omitempty"`
// Independent rate multiplier for image generation; only takes effect when image_rate_independent=true.
ImageRateMultiplier float64 `json:"image_rate_multiplier,omitempty"`
// ImagePrice1k holds the value of the "image_price_1k" field. // ImagePrice1k holds the value of the "image_price_1k" field.
ImagePrice1k *float64 `json:"image_price_1k,omitempty"` ImagePrice1k *float64 `json:"image_price_1k,omitempty"`
// ImagePrice2k holds the value of the "image_price_2k" field. // ImagePrice2k holds the value of the "image_price_2k" field.
@ -189,9 +195,9 @@ func (*Group) scanValues(columns []string) ([]any, error) {
switch columns[i] { switch columns[i] {
case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig: case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig:
values[i] = new([]byte) values[i] = new([]byte)
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet: case group.FieldIsExclusive, group.FieldAllowImageGeneration, group.FieldImageRateIndependent, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet:
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImageRateMultiplier, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
values[i] = new(sql.NullFloat64) values[i] = new(sql.NullFloat64)
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder, group.FieldRpmLimit: case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder, group.FieldRpmLimit:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
@ -309,6 +315,24 @@ func (_m *Group) assignValues(columns []string, values []any) error {
} else if value.Valid { } else if value.Valid {
_m.DefaultValidityDays = int(value.Int64) _m.DefaultValidityDays = int(value.Int64)
} }
case group.FieldAllowImageGeneration:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field allow_image_generation", values[i])
} else if value.Valid {
_m.AllowImageGeneration = value.Bool
}
case group.FieldImageRateIndependent:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field image_rate_independent", values[i])
} else if value.Valid {
_m.ImageRateIndependent = value.Bool
}
case group.FieldImageRateMultiplier:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field image_rate_multiplier", values[i])
} else if value.Valid {
_m.ImageRateMultiplier = value.Float64
}
case group.FieldImagePrice1k: case group.FieldImagePrice1k:
if value, ok := values[i].(*sql.NullFloat64); !ok { if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field image_price_1k", values[i]) return fmt.Errorf("unexpected type %T for field image_price_1k", values[i])
@ -550,6 +574,15 @@ func (_m *Group) String() string {
builder.WriteString("default_validity_days=") builder.WriteString("default_validity_days=")
builder.WriteString(fmt.Sprintf("%v", _m.DefaultValidityDays)) builder.WriteString(fmt.Sprintf("%v", _m.DefaultValidityDays))
builder.WriteString(", ") builder.WriteString(", ")
builder.WriteString("allow_image_generation=")
builder.WriteString(fmt.Sprintf("%v", _m.AllowImageGeneration))
builder.WriteString(", ")
builder.WriteString("image_rate_independent=")
builder.WriteString(fmt.Sprintf("%v", _m.ImageRateIndependent))
builder.WriteString(", ")
builder.WriteString("image_rate_multiplier=")
builder.WriteString(fmt.Sprintf("%v", _m.ImageRateMultiplier))
builder.WriteString(", ")
if v := _m.ImagePrice1k; v != nil { if v := _m.ImagePrice1k; v != nil {
builder.WriteString("image_price_1k=") builder.WriteString("image_price_1k=")
builder.WriteString(fmt.Sprintf("%v", *v)) builder.WriteString(fmt.Sprintf("%v", *v))

View File

@ -44,6 +44,12 @@ const (
FieldMonthlyLimitUsd = "monthly_limit_usd" FieldMonthlyLimitUsd = "monthly_limit_usd"
// FieldDefaultValidityDays holds the string denoting the default_validity_days field in the database. // FieldDefaultValidityDays holds the string denoting the default_validity_days field in the database.
FieldDefaultValidityDays = "default_validity_days" FieldDefaultValidityDays = "default_validity_days"
// FieldAllowImageGeneration holds the string denoting the allow_image_generation field in the database.
FieldAllowImageGeneration = "allow_image_generation"
// FieldImageRateIndependent holds the string denoting the image_rate_independent field in the database.
FieldImageRateIndependent = "image_rate_independent"
// FieldImageRateMultiplier holds the string denoting the image_rate_multiplier field in the database.
FieldImageRateMultiplier = "image_rate_multiplier"
// FieldImagePrice1k holds the string denoting the image_price_1k field in the database. // FieldImagePrice1k holds the string denoting the image_price_1k field in the database.
FieldImagePrice1k = "image_price_1k" FieldImagePrice1k = "image_price_1k"
// FieldImagePrice2k holds the string denoting the image_price_2k field in the database. // FieldImagePrice2k holds the string denoting the image_price_2k field in the database.
@ -167,6 +173,9 @@ var Columns = []string{
FieldWeeklyLimitUsd, FieldWeeklyLimitUsd,
FieldMonthlyLimitUsd, FieldMonthlyLimitUsd,
FieldDefaultValidityDays, FieldDefaultValidityDays,
FieldAllowImageGeneration,
FieldImageRateIndependent,
FieldImageRateMultiplier,
FieldImagePrice1k, FieldImagePrice1k,
FieldImagePrice2k, FieldImagePrice2k,
FieldImagePrice4k, FieldImagePrice4k,
@ -239,6 +248,12 @@ var (
SubscriptionTypeValidator func(string) error SubscriptionTypeValidator func(string) error
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field. // DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
DefaultDefaultValidityDays int DefaultDefaultValidityDays int
// DefaultAllowImageGeneration holds the default value on creation for the "allow_image_generation" field.
DefaultAllowImageGeneration bool
// DefaultImageRateIndependent holds the default value on creation for the "image_rate_independent" field.
DefaultImageRateIndependent bool
// DefaultImageRateMultiplier holds the default value on creation for the "image_rate_multiplier" field.
DefaultImageRateMultiplier float64
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field. // DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
DefaultClaudeCodeOnly bool DefaultClaudeCodeOnly bool
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field. // DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
@ -343,6 +358,21 @@ func ByDefaultValidityDays(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDefaultValidityDays, opts...).ToFunc() return sql.OrderByField(FieldDefaultValidityDays, opts...).ToFunc()
} }
// ByAllowImageGeneration returns an OrderOption that orders results by the
// allow_image_generation field, applying any supplied term options.
func ByAllowImageGeneration(opts ...sql.OrderTermOption) OrderOption {
	byField := sql.OrderByField(FieldAllowImageGeneration, opts...)
	return byField.ToFunc()
}
// ByImageRateIndependent returns an OrderOption that orders results by the
// image_rate_independent field, applying any supplied term options.
func ByImageRateIndependent(opts ...sql.OrderTermOption) OrderOption {
	byField := sql.OrderByField(FieldImageRateIndependent, opts...)
	return byField.ToFunc()
}
// ByImageRateMultiplier returns an OrderOption that orders results by the
// image_rate_multiplier field, applying any supplied term options.
func ByImageRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
	byField := sql.OrderByField(FieldImageRateMultiplier, opts...)
	return byField.ToFunc()
}
// ByImagePrice1k orders the results by the image_price_1k field. // ByImagePrice1k orders the results by the image_price_1k field.
func ByImagePrice1k(opts ...sql.OrderTermOption) OrderOption { func ByImagePrice1k(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImagePrice1k, opts...).ToFunc() return sql.OrderByField(FieldImagePrice1k, opts...).ToFunc()

View File

@ -125,6 +125,21 @@ func DefaultValidityDays(v int) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldDefaultValidityDays, v)) return predicate.Group(sql.FieldEQ(FieldDefaultValidityDays, v))
} }
// AllowImageGeneration builds an equality predicate on the
// "allow_image_generation" field. It is identical to AllowImageGenerationEQ.
func AllowImageGeneration(v bool) predicate.Group {
	eq := sql.FieldEQ(FieldAllowImageGeneration, v)
	return predicate.Group(eq)
}
// ImageRateIndependent builds an equality predicate on the
// "image_rate_independent" field. It is identical to ImageRateIndependentEQ.
func ImageRateIndependent(v bool) predicate.Group {
	eq := sql.FieldEQ(FieldImageRateIndependent, v)
	return predicate.Group(eq)
}
// ImageRateMultiplier builds an equality predicate on the
// "image_rate_multiplier" field. It is identical to ImageRateMultiplierEQ.
func ImageRateMultiplier(v float64) predicate.Group {
	eq := sql.FieldEQ(FieldImageRateMultiplier, v)
	return predicate.Group(eq)
}
// ImagePrice1k applies equality check predicate on the "image_price_1k" field. It's identical to ImagePrice1kEQ. // ImagePrice1k applies equality check predicate on the "image_price_1k" field. It's identical to ImagePrice1kEQ.
func ImagePrice1k(v float64) predicate.Group { func ImagePrice1k(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldImagePrice1k, v)) return predicate.Group(sql.FieldEQ(FieldImagePrice1k, v))
@ -900,6 +915,66 @@ func DefaultValidityDaysLTE(v int) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldDefaultValidityDays, v)) return predicate.Group(sql.FieldLTE(FieldDefaultValidityDays, v))
} }
// AllowImageGenerationEQ builds the EQ predicate on the
// "allow_image_generation" field.
func AllowImageGenerationEQ(v bool) predicate.Group {
	cond := sql.FieldEQ(FieldAllowImageGeneration, v)
	return predicate.Group(cond)
}
// AllowImageGenerationNEQ builds the NEQ predicate on the
// "allow_image_generation" field.
func AllowImageGenerationNEQ(v bool) predicate.Group {
	cond := sql.FieldNEQ(FieldAllowImageGeneration, v)
	return predicate.Group(cond)
}
// ImageRateIndependentEQ builds the EQ predicate on the
// "image_rate_independent" field.
func ImageRateIndependentEQ(v bool) predicate.Group {
	cond := sql.FieldEQ(FieldImageRateIndependent, v)
	return predicate.Group(cond)
}
// ImageRateIndependentNEQ builds the NEQ predicate on the
// "image_rate_independent" field.
func ImageRateIndependentNEQ(v bool) predicate.Group {
	cond := sql.FieldNEQ(FieldImageRateIndependent, v)
	return predicate.Group(cond)
}
// ImageRateMultiplierEQ builds the EQ predicate on the
// "image_rate_multiplier" field.
func ImageRateMultiplierEQ(v float64) predicate.Group {
	cond := sql.FieldEQ(FieldImageRateMultiplier, v)
	return predicate.Group(cond)
}
// ImageRateMultiplierNEQ builds the NEQ predicate on the
// "image_rate_multiplier" field.
func ImageRateMultiplierNEQ(v float64) predicate.Group {
	cond := sql.FieldNEQ(FieldImageRateMultiplier, v)
	return predicate.Group(cond)
}
// ImageRateMultiplierIn builds the In predicate on the
// "image_rate_multiplier" field for the given candidate values.
func ImageRateMultiplierIn(vs ...float64) predicate.Group {
	cond := sql.FieldIn(FieldImageRateMultiplier, vs...)
	return predicate.Group(cond)
}
// ImageRateMultiplierNotIn builds the NotIn predicate on the
// "image_rate_multiplier" field for the given excluded values.
func ImageRateMultiplierNotIn(vs ...float64) predicate.Group {
	cond := sql.FieldNotIn(FieldImageRateMultiplier, vs...)
	return predicate.Group(cond)
}
// ImageRateMultiplierGT builds the GT predicate on the
// "image_rate_multiplier" field.
func ImageRateMultiplierGT(v float64) predicate.Group {
	cond := sql.FieldGT(FieldImageRateMultiplier, v)
	return predicate.Group(cond)
}
// ImageRateMultiplierGTE builds the GTE predicate on the
// "image_rate_multiplier" field.
func ImageRateMultiplierGTE(v float64) predicate.Group {
	cond := sql.FieldGTE(FieldImageRateMultiplier, v)
	return predicate.Group(cond)
}
// ImageRateMultiplierLT builds the LT predicate on the
// "image_rate_multiplier" field.
func ImageRateMultiplierLT(v float64) predicate.Group {
	cond := sql.FieldLT(FieldImageRateMultiplier, v)
	return predicate.Group(cond)
}
// ImageRateMultiplierLTE builds the LTE predicate on the
// "image_rate_multiplier" field.
func ImageRateMultiplierLTE(v float64) predicate.Group {
	cond := sql.FieldLTE(FieldImageRateMultiplier, v)
	return predicate.Group(cond)
}
// ImagePrice1kEQ applies the EQ predicate on the "image_price_1k" field. // ImagePrice1kEQ applies the EQ predicate on the "image_price_1k" field.
func ImagePrice1kEQ(v float64) predicate.Group { func ImagePrice1kEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldImagePrice1k, v)) return predicate.Group(sql.FieldEQ(FieldImagePrice1k, v))

View File

@ -217,6 +217,48 @@ func (_c *GroupCreate) SetNillableDefaultValidityDays(v *int) *GroupCreate {
return _c return _c
} }
// SetAllowImageGeneration sets the "allow_image_generation" field.
// It forwards v to the builder's mutation and returns the same builder so
// that further setters can be chained.
func (_c *GroupCreate) SetAllowImageGeneration(v bool) *GroupCreate {
_c.mutation.SetAllowImageGeneration(v)
return _c
}
// SetNillableAllowImageGeneration sets the "allow_image_generation" field
// only when v is non-nil; a nil pointer leaves the builder unchanged.
func (_c *GroupCreate) SetNillableAllowImageGeneration(v *bool) *GroupCreate {
	if v == nil {
		return _c
	}
	_c.SetAllowImageGeneration(*v)
	return _c
}
// SetImageRateIndependent sets the "image_rate_independent" field.
// It forwards v to the builder's mutation and returns the same builder so
// that further setters can be chained.
func (_c *GroupCreate) SetImageRateIndependent(v bool) *GroupCreate {
_c.mutation.SetImageRateIndependent(v)
return _c
}
// SetNillableImageRateIndependent sets the "image_rate_independent" field
// only when v is non-nil; a nil pointer leaves the builder unchanged.
func (_c *GroupCreate) SetNillableImageRateIndependent(v *bool) *GroupCreate {
	if v == nil {
		return _c
	}
	_c.SetImageRateIndependent(*v)
	return _c
}
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
// It forwards v to the builder's mutation and returns the same builder so
// that further setters can be chained.
func (_c *GroupCreate) SetImageRateMultiplier(v float64) *GroupCreate {
_c.mutation.SetImageRateMultiplier(v)
return _c
}
// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field
// only when v is non-nil; a nil pointer leaves the builder unchanged.
func (_c *GroupCreate) SetNillableImageRateMultiplier(v *float64) *GroupCreate {
	if v == nil {
		return _c
	}
	_c.SetImageRateMultiplier(*v)
	return _c
}
// SetImagePrice1k sets the "image_price_1k" field. // SetImagePrice1k sets the "image_price_1k" field.
func (_c *GroupCreate) SetImagePrice1k(v float64) *GroupCreate { func (_c *GroupCreate) SetImagePrice1k(v float64) *GroupCreate {
_c.mutation.SetImagePrice1k(v) _c.mutation.SetImagePrice1k(v)
@ -604,6 +646,18 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultDefaultValidityDays v := group.DefaultDefaultValidityDays
_c.mutation.SetDefaultValidityDays(v) _c.mutation.SetDefaultValidityDays(v)
} }
if _, ok := _c.mutation.AllowImageGeneration(); !ok {
v := group.DefaultAllowImageGeneration
_c.mutation.SetAllowImageGeneration(v)
}
if _, ok := _c.mutation.ImageRateIndependent(); !ok {
v := group.DefaultImageRateIndependent
_c.mutation.SetImageRateIndependent(v)
}
if _, ok := _c.mutation.ImageRateMultiplier(); !ok {
v := group.DefaultImageRateMultiplier
_c.mutation.SetImageRateMultiplier(v)
}
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
v := group.DefaultClaudeCodeOnly v := group.DefaultClaudeCodeOnly
_c.mutation.SetClaudeCodeOnly(v) _c.mutation.SetClaudeCodeOnly(v)
@ -700,6 +754,15 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.DefaultValidityDays(); !ok { if _, ok := _c.mutation.DefaultValidityDays(); !ok {
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)} return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
} }
if _, ok := _c.mutation.AllowImageGeneration(); !ok {
return &ValidationError{Name: "allow_image_generation", err: errors.New(`ent: missing required field "Group.allow_image_generation"`)}
}
if _, ok := _c.mutation.ImageRateIndependent(); !ok {
return &ValidationError{Name: "image_rate_independent", err: errors.New(`ent: missing required field "Group.image_rate_independent"`)}
}
if _, ok := _c.mutation.ImageRateMultiplier(); !ok {
return &ValidationError{Name: "image_rate_multiplier", err: errors.New(`ent: missing required field "Group.image_rate_multiplier"`)}
}
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)} return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
} }
@ -821,6 +884,18 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldDefaultValidityDays, field.TypeInt, value) _spec.SetField(group.FieldDefaultValidityDays, field.TypeInt, value)
_node.DefaultValidityDays = value _node.DefaultValidityDays = value
} }
if value, ok := _c.mutation.AllowImageGeneration(); ok {
_spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value)
_node.AllowImageGeneration = value
}
if value, ok := _c.mutation.ImageRateIndependent(); ok {
_spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value)
_node.ImageRateIndependent = value
}
if value, ok := _c.mutation.ImageRateMultiplier(); ok {
_spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
_node.ImageRateMultiplier = value
}
if value, ok := _c.mutation.ImagePrice1k(); ok { if value, ok := _c.mutation.ImagePrice1k(); ok {
_spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value) _spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value)
_node.ImagePrice1k = &value _node.ImagePrice1k = &value
@ -1261,6 +1336,48 @@ func (u *GroupUpsert) AddDefaultValidityDays(v int) *GroupUpsert {
return u return u
} }
// SetAllowImageGeneration sets the "allow_image_generation" field.
// It passes the column name and v to u.Set and returns u for chaining.
func (u *GroupUpsert) SetAllowImageGeneration(v bool) *GroupUpsert {
u.Set(group.FieldAllowImageGeneration, v)
return u
}
// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create.
// It does so by calling u.SetExcluded with the column name, and returns u
// for chaining.
func (u *GroupUpsert) UpdateAllowImageGeneration() *GroupUpsert {
u.SetExcluded(group.FieldAllowImageGeneration)
return u
}
// SetImageRateIndependent sets the "image_rate_independent" field.
// It passes the column name and v to u.Set and returns u for chaining.
func (u *GroupUpsert) SetImageRateIndependent(v bool) *GroupUpsert {
u.Set(group.FieldImageRateIndependent, v)
return u
}
// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create.
// It does so by calling u.SetExcluded with the column name, and returns u
// for chaining.
func (u *GroupUpsert) UpdateImageRateIndependent() *GroupUpsert {
u.SetExcluded(group.FieldImageRateIndependent)
return u
}
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
// It passes the column name and v to u.Set and returns u for chaining.
func (u *GroupUpsert) SetImageRateMultiplier(v float64) *GroupUpsert {
u.Set(group.FieldImageRateMultiplier, v)
return u
}
// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create.
// It does so by calling u.SetExcluded with the column name, and returns u
// for chaining.
func (u *GroupUpsert) UpdateImageRateMultiplier() *GroupUpsert {
u.SetExcluded(group.FieldImageRateMultiplier)
return u
}
// AddImageRateMultiplier adds v to the "image_rate_multiplier" field.
// It passes the column name and v to u.Add and returns u for chaining.
func (u *GroupUpsert) AddImageRateMultiplier(v float64) *GroupUpsert {
u.Add(group.FieldImageRateMultiplier, v)
return u
}
// SetImagePrice1k sets the "image_price_1k" field. // SetImagePrice1k sets the "image_price_1k" field.
func (u *GroupUpsert) SetImagePrice1k(v float64) *GroupUpsert { func (u *GroupUpsert) SetImagePrice1k(v float64) *GroupUpsert {
u.Set(group.FieldImagePrice1k, v) u.Set(group.FieldImagePrice1k, v)
@ -1840,6 +1957,55 @@ func (u *GroupUpsertOne) UpdateDefaultValidityDays() *GroupUpsertOne {
}) })
} }
// SetAllowImageGeneration sets the "allow_image_generation" field by
// registering an update step that applies v to the underlying GroupUpsert.
func (u *GroupUpsertOne) SetAllowImageGeneration(v bool) *GroupUpsertOne {
	apply := func(s *GroupUpsert) {
		s.SetAllowImageGeneration(v)
	}
	return u.Update(apply)
}
// UpdateAllowImageGeneration sets the "allow_image_generation" field to the
// value that was provided on create.
func (u *GroupUpsertOne) UpdateAllowImageGeneration() *GroupUpsertOne {
	apply := func(s *GroupUpsert) {
		s.UpdateAllowImageGeneration()
	}
	return u.Update(apply)
}
// SetImageRateIndependent sets the "image_rate_independent" field by
// registering an update step that applies v to the underlying GroupUpsert.
func (u *GroupUpsertOne) SetImageRateIndependent(v bool) *GroupUpsertOne {
	apply := func(s *GroupUpsert) {
		s.SetImageRateIndependent(v)
	}
	return u.Update(apply)
}
// UpdateImageRateIndependent sets the "image_rate_independent" field to the
// value that was provided on create.
func (u *GroupUpsertOne) UpdateImageRateIndependent() *GroupUpsertOne {
	apply := func(s *GroupUpsert) {
		s.UpdateImageRateIndependent()
	}
	return u.Update(apply)
}
// SetImageRateMultiplier sets the "image_rate_multiplier" field by
// registering an update step that applies v to the underlying GroupUpsert.
func (u *GroupUpsertOne) SetImageRateMultiplier(v float64) *GroupUpsertOne {
	apply := func(s *GroupUpsert) {
		s.SetImageRateMultiplier(v)
	}
	return u.Update(apply)
}
// AddImageRateMultiplier adds v to the "image_rate_multiplier" field by
// registering an update step on the underlying GroupUpsert.
func (u *GroupUpsertOne) AddImageRateMultiplier(v float64) *GroupUpsertOne {
	apply := func(s *GroupUpsert) {
		s.AddImageRateMultiplier(v)
	}
	return u.Update(apply)
}
// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the
// value that was provided on create.
func (u *GroupUpsertOne) UpdateImageRateMultiplier() *GroupUpsertOne {
	apply := func(s *GroupUpsert) {
		s.UpdateImageRateMultiplier()
	}
	return u.Update(apply)
}
// SetImagePrice1k sets the "image_price_1k" field. // SetImagePrice1k sets the "image_price_1k" field.
func (u *GroupUpsertOne) SetImagePrice1k(v float64) *GroupUpsertOne { func (u *GroupUpsertOne) SetImagePrice1k(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) { return u.Update(func(s *GroupUpsert) {
@ -2632,6 +2798,55 @@ func (u *GroupUpsertBulk) UpdateDefaultValidityDays() *GroupUpsertBulk {
}) })
} }
// SetAllowImageGeneration sets the "allow_image_generation" field by
// registering an update step that applies v to the underlying GroupUpsert.
func (u *GroupUpsertBulk) SetAllowImageGeneration(v bool) *GroupUpsertBulk {
	apply := func(s *GroupUpsert) {
		s.SetAllowImageGeneration(v)
	}
	return u.Update(apply)
}
// UpdateAllowImageGeneration sets the "allow_image_generation" field to the
// value that was provided on create.
func (u *GroupUpsertBulk) UpdateAllowImageGeneration() *GroupUpsertBulk {
	apply := func(s *GroupUpsert) {
		s.UpdateAllowImageGeneration()
	}
	return u.Update(apply)
}
// SetImageRateIndependent sets the "image_rate_independent" field by
// registering an update step that applies v to the underlying GroupUpsert.
func (u *GroupUpsertBulk) SetImageRateIndependent(v bool) *GroupUpsertBulk {
	apply := func(s *GroupUpsert) {
		s.SetImageRateIndependent(v)
	}
	return u.Update(apply)
}
// UpdateImageRateIndependent sets the "image_rate_independent" field to the
// value that was provided on create.
func (u *GroupUpsertBulk) UpdateImageRateIndependent() *GroupUpsertBulk {
	apply := func(s *GroupUpsert) {
		s.UpdateImageRateIndependent()
	}
	return u.Update(apply)
}
// SetImageRateMultiplier sets the "image_rate_multiplier" field by
// registering an update step that applies v to the underlying GroupUpsert.
func (u *GroupUpsertBulk) SetImageRateMultiplier(v float64) *GroupUpsertBulk {
	apply := func(s *GroupUpsert) {
		s.SetImageRateMultiplier(v)
	}
	return u.Update(apply)
}
// AddImageRateMultiplier adds v to the "image_rate_multiplier" field by
// registering an update step on the underlying GroupUpsert.
func (u *GroupUpsertBulk) AddImageRateMultiplier(v float64) *GroupUpsertBulk {
	apply := func(s *GroupUpsert) {
		s.AddImageRateMultiplier(v)
	}
	return u.Update(apply)
}
// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the
// value that was provided on create.
func (u *GroupUpsertBulk) UpdateImageRateMultiplier() *GroupUpsertBulk {
	apply := func(s *GroupUpsert) {
		s.UpdateImageRateMultiplier()
	}
	return u.Update(apply)
}
// SetImagePrice1k sets the "image_price_1k" field. // SetImagePrice1k sets the "image_price_1k" field.
func (u *GroupUpsertBulk) SetImagePrice1k(v float64) *GroupUpsertBulk { func (u *GroupUpsertBulk) SetImagePrice1k(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) { return u.Update(func(s *GroupUpsert) {

View File

@ -275,6 +275,55 @@ func (_u *GroupUpdate) AddDefaultValidityDays(v int) *GroupUpdate {
return _u return _u
} }
// SetAllowImageGeneration sets the "allow_image_generation" field.
// It forwards v to the builder's mutation and returns the same builder so
// that further setters can be chained.
func (_u *GroupUpdate) SetAllowImageGeneration(v bool) *GroupUpdate {
_u.mutation.SetAllowImageGeneration(v)
return _u
}
// SetNillableAllowImageGeneration sets the "allow_image_generation" field
// only when v is non-nil; a nil pointer leaves the builder unchanged.
func (_u *GroupUpdate) SetNillableAllowImageGeneration(v *bool) *GroupUpdate {
	if v == nil {
		return _u
	}
	_u.SetAllowImageGeneration(*v)
	return _u
}
// SetImageRateIndependent sets the "image_rate_independent" field.
// It forwards v to the builder's mutation and returns the same builder so
// that further setters can be chained.
func (_u *GroupUpdate) SetImageRateIndependent(v bool) *GroupUpdate {
_u.mutation.SetImageRateIndependent(v)
return _u
}
// SetNillableImageRateIndependent sets the "image_rate_independent" field
// only when v is non-nil; a nil pointer leaves the builder unchanged.
func (_u *GroupUpdate) SetNillableImageRateIndependent(v *bool) *GroupUpdate {
	if v == nil {
		return _u
	}
	_u.SetImageRateIndependent(*v)
	return _u
}
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
// The mutation's state for this field is reset before the new value is set,
// and the same builder is returned for chaining.
func (_u *GroupUpdate) SetImageRateMultiplier(v float64) *GroupUpdate {
// NOTE(review): presumably the reset discards any increment queued earlier
// via AddImageRateMultiplier — confirm against GroupMutation.
_u.mutation.ResetImageRateMultiplier()
_u.mutation.SetImageRateMultiplier(v)
return _u
}
// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field
// only when v is non-nil; a nil pointer leaves the builder unchanged.
func (_u *GroupUpdate) SetNillableImageRateMultiplier(v *float64) *GroupUpdate {
	if v == nil {
		return _u
	}
	_u.SetImageRateMultiplier(*v)
	return _u
}
// AddImageRateMultiplier adds value to the "image_rate_multiplier" field.
// It forwards v to the builder's mutation as an increment and returns the
// same builder for chaining.
func (_u *GroupUpdate) AddImageRateMultiplier(v float64) *GroupUpdate {
_u.mutation.AddImageRateMultiplier(v)
return _u
}
// SetImagePrice1k sets the "image_price_1k" field. // SetImagePrice1k sets the "image_price_1k" field.
func (_u *GroupUpdate) SetImagePrice1k(v float64) *GroupUpdate { func (_u *GroupUpdate) SetImagePrice1k(v float64) *GroupUpdate {
_u.mutation.ResetImagePrice1k() _u.mutation.ResetImagePrice1k()
@ -962,6 +1011,18 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedDefaultValidityDays(); ok { if value, ok := _u.mutation.AddedDefaultValidityDays(); ok {
_spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value) _spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value)
} }
if value, ok := _u.mutation.AllowImageGeneration(); ok {
_spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value)
}
if value, ok := _u.mutation.ImageRateIndependent(); ok {
_spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value)
}
if value, ok := _u.mutation.ImageRateMultiplier(); ok {
_spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedImageRateMultiplier(); ok {
_spec.AddField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.ImagePrice1k(); ok { if value, ok := _u.mutation.ImagePrice1k(); ok {
_spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value) _spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value)
} }
@ -1610,6 +1671,55 @@ func (_u *GroupUpdateOne) AddDefaultValidityDays(v int) *GroupUpdateOne {
return _u return _u
} }
// SetAllowImageGeneration sets the "allow_image_generation" field.
// It forwards v to the builder's mutation and returns the same builder so
// that further setters can be chained.
func (_u *GroupUpdateOne) SetAllowImageGeneration(v bool) *GroupUpdateOne {
_u.mutation.SetAllowImageGeneration(v)
return _u
}
// SetNillableAllowImageGeneration sets the "allow_image_generation" field
// only when v is non-nil; a nil pointer leaves the builder unchanged.
func (_u *GroupUpdateOne) SetNillableAllowImageGeneration(v *bool) *GroupUpdateOne {
	if v == nil {
		return _u
	}
	_u.SetAllowImageGeneration(*v)
	return _u
}
// SetImageRateIndependent sets the "image_rate_independent" field.
// It forwards v to the builder's mutation and returns the same builder so
// that further setters can be chained.
func (_u *GroupUpdateOne) SetImageRateIndependent(v bool) *GroupUpdateOne {
_u.mutation.SetImageRateIndependent(v)
return _u
}
// SetNillableImageRateIndependent sets the "image_rate_independent" field
// only when v is non-nil; a nil pointer leaves the builder unchanged.
func (_u *GroupUpdateOne) SetNillableImageRateIndependent(v *bool) *GroupUpdateOne {
	if v == nil {
		return _u
	}
	_u.SetImageRateIndependent(*v)
	return _u
}
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
// The receiver is returned so calls can be chained.
func (_u *GroupUpdateOne) SetImageRateMultiplier(v float64) *GroupUpdateOne {
// Drop any previously staged value or increment before staging v.
_u.mutation.ResetImageRateMultiplier()
_u.mutation.SetImageRateMultiplier(v)
return _u
}
// SetNillableImageRateMultiplier stages the "image_rate_multiplier" field
// only when v is non-nil. A nil v leaves the builder untouched. The receiver
// is returned in both cases so calls can be chained.
func (_u *GroupUpdateOne) SetNillableImageRateMultiplier(v *float64) *GroupUpdateOne {
	if v == nil {
		return _u
	}
	_u.SetImageRateMultiplier(*v)
	return _u
}
// AddImageRateMultiplier adds value to the "image_rate_multiplier" field.
// The increment v is staged on the builder's mutation; the receiver is
// returned so calls can be chained.
func (_u *GroupUpdateOne) AddImageRateMultiplier(v float64) *GroupUpdateOne {
_u.mutation.AddImageRateMultiplier(v)
return _u
}
// SetImagePrice1k sets the "image_price_1k" field. // SetImagePrice1k sets the "image_price_1k" field.
func (_u *GroupUpdateOne) SetImagePrice1k(v float64) *GroupUpdateOne { func (_u *GroupUpdateOne) SetImagePrice1k(v float64) *GroupUpdateOne {
_u.mutation.ResetImagePrice1k() _u.mutation.ResetImagePrice1k()
@ -2327,6 +2437,18 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if value, ok := _u.mutation.AddedDefaultValidityDays(); ok { if value, ok := _u.mutation.AddedDefaultValidityDays(); ok {
_spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value) _spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value)
} }
if value, ok := _u.mutation.AllowImageGeneration(); ok {
_spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value)
}
if value, ok := _u.mutation.ImageRateIndependent(); ok {
_spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value)
}
if value, ok := _u.mutation.ImageRateMultiplier(); ok {
_spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedImageRateMultiplier(); ok {
_spec.AddField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.ImagePrice1k(); ok { if value, ok := _u.mutation.ImagePrice1k(); ok {
_spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value) _spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value)
} }

View File

@ -638,6 +638,9 @@ var (
{Name: "weekly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "weekly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "monthly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "monthly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "default_validity_days", Type: field.TypeInt, Default: 30}, {Name: "default_validity_days", Type: field.TypeInt, Default: 30},
{Name: "allow_image_generation", Type: field.TypeBool, Default: false},
{Name: "image_rate_independent", Type: field.TypeBool, Default: false},
{Name: "image_rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
@ -690,7 +693,7 @@ var (
{ {
Name: "group_sort_order", Name: "group_sort_order",
Unique: false, Unique: false,
Columns: []*schema.Column{GroupsColumns[25]}, Columns: []*schema.Column{GroupsColumns[28]},
}, },
}, },
} }

View File

@ -14764,6 +14764,10 @@ type GroupMutation struct {
addmonthly_limit_usd *float64 addmonthly_limit_usd *float64
default_validity_days *int default_validity_days *int
adddefault_validity_days *int adddefault_validity_days *int
allow_image_generation *bool
image_rate_independent *bool
image_rate_multiplier *float64
addimage_rate_multiplier *float64
image_price_1k *float64 image_price_1k *float64
addimage_price_1k *float64 addimage_price_1k *float64
image_price_2k *float64 image_price_2k *float64
@ -15583,6 +15587,134 @@ func (m *GroupMutation) ResetDefaultValidityDays() {
m.adddefault_validity_days = nil m.adddefault_validity_days = nil
} }
// SetAllowImageGeneration sets the "allow_image_generation" field.
// The mutation keeps a pointer to its own copy of b; a nil pointer means the
// field has not been set in this mutation.
func (m *GroupMutation) SetAllowImageGeneration(b bool) {
m.allow_image_generation = &b
}
// AllowImageGeneration reports the value staged for the
// "allow_image_generation" field. exists is false (and r is the zero value)
// when the field has not been set in this mutation.
func (m *GroupMutation) AllowImageGeneration() (r bool, exists bool) {
	if staged := m.allow_image_generation; staged != nil {
		r, exists = *staged, true
	}
	return r, exists
}
// OldAllowImageGeneration returns the "allow_image_generation" value the Group
// entity had before this mutation, as reported by the mutation's oldValue
// loader.
// It fails when the operation is not UpdateOne, when the mutation has no ID or
// no oldValue loader, or when the loader itself returns an error.
func (m *GroupMutation) OldAllowImageGeneration(ctx context.Context) (v bool, err error) {
	switch {
	case !m.op.Is(OpUpdateOne):
		return v, errors.New("OldAllowImageGeneration is only allowed on UpdateOne operations")
	case m.id == nil || m.oldValue == nil:
		return v, errors.New("OldAllowImageGeneration requires an ID field in the mutation")
	}
	previous, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldAllowImageGeneration: %w", err)
	}
	return previous.AllowImageGeneration, nil
}
// ResetAllowImageGeneration resets all changes to the "allow_image_generation" field.
// Clearing the pointer marks the field as not set in this mutation.
func (m *GroupMutation) ResetAllowImageGeneration() {
m.allow_image_generation = nil
}
// SetImageRateIndependent sets the "image_rate_independent" field.
// The mutation keeps a pointer to its own copy of b; a nil pointer means the
// field has not been set in this mutation.
func (m *GroupMutation) SetImageRateIndependent(b bool) {
m.image_rate_independent = &b
}
// ImageRateIndependent reports the value staged for the
// "image_rate_independent" field. exists is false (and r is the zero value)
// when the field has not been set in this mutation.
func (m *GroupMutation) ImageRateIndependent() (r bool, exists bool) {
	if staged := m.image_rate_independent; staged != nil {
		r, exists = *staged, true
	}
	return r, exists
}
// OldImageRateIndependent returns the "image_rate_independent" value the Group
// entity had before this mutation, as reported by the mutation's oldValue
// loader.
// It fails when the operation is not UpdateOne, when the mutation has no ID or
// no oldValue loader, or when the loader itself returns an error.
func (m *GroupMutation) OldImageRateIndependent(ctx context.Context) (v bool, err error) {
	switch {
	case !m.op.Is(OpUpdateOne):
		return v, errors.New("OldImageRateIndependent is only allowed on UpdateOne operations")
	case m.id == nil || m.oldValue == nil:
		return v, errors.New("OldImageRateIndependent requires an ID field in the mutation")
	}
	previous, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldImageRateIndependent: %w", err)
	}
	return previous.ImageRateIndependent, nil
}
// ResetImageRateIndependent resets all changes to the "image_rate_independent" field.
// Clearing the pointer marks the field as not set in this mutation.
func (m *GroupMutation) ResetImageRateIndependent() {
m.image_rate_independent = nil
}
// SetImageRateMultiplier sets the "image_rate_multiplier" field.
// The mutation keeps a pointer to its own copy of f.
func (m *GroupMutation) SetImageRateMultiplier(f float64) {
m.image_rate_multiplier = &f
// An explicit value supersedes any increment staged earlier.
m.addimage_rate_multiplier = nil
}
// ImageRateMultiplier reports the value staged for the
// "image_rate_multiplier" field. exists is false (and r is zero) when the
// field has not been set in this mutation.
func (m *GroupMutation) ImageRateMultiplier() (r float64, exists bool) {
	if staged := m.image_rate_multiplier; staged != nil {
		r, exists = *staged, true
	}
	return r, exists
}
// OldImageRateMultiplier returns the "image_rate_multiplier" value the Group
// entity had before this mutation, as reported by the mutation's oldValue
// loader.
// It fails when the operation is not UpdateOne, when the mutation has no ID or
// no oldValue loader, or when the loader itself returns an error.
func (m *GroupMutation) OldImageRateMultiplier(ctx context.Context) (v float64, err error) {
	switch {
	case !m.op.Is(OpUpdateOne):
		return v, errors.New("OldImageRateMultiplier is only allowed on UpdateOne operations")
	case m.id == nil || m.oldValue == nil:
		return v, errors.New("OldImageRateMultiplier requires an ID field in the mutation")
	}
	previous, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldImageRateMultiplier: %w", err)
	}
	return previous.ImageRateMultiplier, nil
}
// AddImageRateMultiplier accumulates f into the increment staged for the
// "image_rate_multiplier" field. The first call starts the increment at f;
// later calls add to it.
func (m *GroupMutation) AddImageRateMultiplier(f float64) {
	if m.addimage_rate_multiplier == nil {
		m.addimage_rate_multiplier = &f
		return
	}
	*m.addimage_rate_multiplier += f
}
// AddedImageRateMultiplier reports the increment staged for the
// "image_rate_multiplier" field. exists is false (and r is zero) when no
// increment has been staged in this mutation.
func (m *GroupMutation) AddedImageRateMultiplier() (r float64, exists bool) {
	if delta := m.addimage_rate_multiplier; delta != nil {
		r, exists = *delta, true
	}
	return r, exists
}
// ResetImageRateMultiplier resets all changes to the "image_rate_multiplier" field.
// Both the staged value and the staged increment are cleared.
func (m *GroupMutation) ResetImageRateMultiplier() {
m.image_rate_multiplier = nil
m.addimage_rate_multiplier = nil
}
// SetImagePrice1k sets the "image_price_1k" field. // SetImagePrice1k sets the "image_price_1k" field.
func (m *GroupMutation) SetImagePrice1k(f float64) { func (m *GroupMutation) SetImagePrice1k(f float64) {
m.image_price_1k = &f m.image_price_1k = &f
@ -16791,7 +16923,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *GroupMutation) Fields() []string { func (m *GroupMutation) Fields() []string {
fields := make([]string, 0, 31) fields := make([]string, 0, 34)
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt) fields = append(fields, group.FieldCreatedAt)
} }
@ -16834,6 +16966,15 @@ func (m *GroupMutation) Fields() []string {
if m.default_validity_days != nil { if m.default_validity_days != nil {
fields = append(fields, group.FieldDefaultValidityDays) fields = append(fields, group.FieldDefaultValidityDays)
} }
if m.allow_image_generation != nil {
fields = append(fields, group.FieldAllowImageGeneration)
}
if m.image_rate_independent != nil {
fields = append(fields, group.FieldImageRateIndependent)
}
if m.image_rate_multiplier != nil {
fields = append(fields, group.FieldImageRateMultiplier)
}
if m.image_price_1k != nil { if m.image_price_1k != nil {
fields = append(fields, group.FieldImagePrice1k) fields = append(fields, group.FieldImagePrice1k)
} }
@ -16921,6 +17062,12 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.MonthlyLimitUsd() return m.MonthlyLimitUsd()
case group.FieldDefaultValidityDays: case group.FieldDefaultValidityDays:
return m.DefaultValidityDays() return m.DefaultValidityDays()
case group.FieldAllowImageGeneration:
return m.AllowImageGeneration()
case group.FieldImageRateIndependent:
return m.ImageRateIndependent()
case group.FieldImageRateMultiplier:
return m.ImageRateMultiplier()
case group.FieldImagePrice1k: case group.FieldImagePrice1k:
return m.ImagePrice1k() return m.ImagePrice1k()
case group.FieldImagePrice2k: case group.FieldImagePrice2k:
@ -16992,6 +17139,12 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldMonthlyLimitUsd(ctx) return m.OldMonthlyLimitUsd(ctx)
case group.FieldDefaultValidityDays: case group.FieldDefaultValidityDays:
return m.OldDefaultValidityDays(ctx) return m.OldDefaultValidityDays(ctx)
case group.FieldAllowImageGeneration:
return m.OldAllowImageGeneration(ctx)
case group.FieldImageRateIndependent:
return m.OldImageRateIndependent(ctx)
case group.FieldImageRateMultiplier:
return m.OldImageRateMultiplier(ctx)
case group.FieldImagePrice1k: case group.FieldImagePrice1k:
return m.OldImagePrice1k(ctx) return m.OldImagePrice1k(ctx)
case group.FieldImagePrice2k: case group.FieldImagePrice2k:
@ -17133,6 +17286,27 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
} }
m.SetDefaultValidityDays(v) m.SetDefaultValidityDays(v)
return nil return nil
case group.FieldAllowImageGeneration:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetAllowImageGeneration(v)
return nil
case group.FieldImageRateIndependent:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetImageRateIndependent(v)
return nil
case group.FieldImageRateMultiplier:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetImageRateMultiplier(v)
return nil
case group.FieldImagePrice1k: case group.FieldImagePrice1k:
v, ok := value.(float64) v, ok := value.(float64)
if !ok { if !ok {
@ -17275,6 +17449,9 @@ func (m *GroupMutation) AddedFields() []string {
if m.adddefault_validity_days != nil { if m.adddefault_validity_days != nil {
fields = append(fields, group.FieldDefaultValidityDays) fields = append(fields, group.FieldDefaultValidityDays)
} }
if m.addimage_rate_multiplier != nil {
fields = append(fields, group.FieldImageRateMultiplier)
}
if m.addimage_price_1k != nil { if m.addimage_price_1k != nil {
fields = append(fields, group.FieldImagePrice1k) fields = append(fields, group.FieldImagePrice1k)
} }
@ -17314,6 +17491,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedMonthlyLimitUsd() return m.AddedMonthlyLimitUsd()
case group.FieldDefaultValidityDays: case group.FieldDefaultValidityDays:
return m.AddedDefaultValidityDays() return m.AddedDefaultValidityDays()
case group.FieldImageRateMultiplier:
return m.AddedImageRateMultiplier()
case group.FieldImagePrice1k: case group.FieldImagePrice1k:
return m.AddedImagePrice1k() return m.AddedImagePrice1k()
case group.FieldImagePrice2k: case group.FieldImagePrice2k:
@ -17372,6 +17551,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
} }
m.AddDefaultValidityDays(v) m.AddDefaultValidityDays(v)
return nil return nil
case group.FieldImageRateMultiplier:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddImageRateMultiplier(v)
return nil
case group.FieldImagePrice1k: case group.FieldImagePrice1k:
v, ok := value.(float64) v, ok := value.(float64)
if !ok { if !ok {
@ -17559,6 +17745,15 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldDefaultValidityDays: case group.FieldDefaultValidityDays:
m.ResetDefaultValidityDays() m.ResetDefaultValidityDays()
return nil return nil
case group.FieldAllowImageGeneration:
m.ResetAllowImageGeneration()
return nil
case group.FieldImageRateIndependent:
m.ResetImageRateIndependent()
return nil
case group.FieldImageRateMultiplier:
m.ResetImageRateMultiplier()
return nil
case group.FieldImagePrice1k: case group.FieldImagePrice1k:
m.ResetImagePrice1k() m.ResetImagePrice1k()
return nil return nil

View File

@ -803,50 +803,62 @@ func init() {
groupDescDefaultValidityDays := groupFields[10].Descriptor() groupDescDefaultValidityDays := groupFields[10].Descriptor()
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field. // group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int) group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
// groupDescAllowImageGeneration is the schema descriptor for allow_image_generation field.
groupDescAllowImageGeneration := groupFields[11].Descriptor()
// group.DefaultAllowImageGeneration holds the default value on creation for the allow_image_generation field.
group.DefaultAllowImageGeneration = groupDescAllowImageGeneration.Default.(bool)
// groupDescImageRateIndependent is the schema descriptor for image_rate_independent field.
groupDescImageRateIndependent := groupFields[12].Descriptor()
// group.DefaultImageRateIndependent holds the default value on creation for the image_rate_independent field.
group.DefaultImageRateIndependent = groupDescImageRateIndependent.Default.(bool)
// groupDescImageRateMultiplier is the schema descriptor for image_rate_multiplier field.
groupDescImageRateMultiplier := groupFields[13].Descriptor()
// group.DefaultImageRateMultiplier holds the default value on creation for the image_rate_multiplier field.
group.DefaultImageRateMultiplier = groupDescImageRateMultiplier.Default.(float64)
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field. // groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
groupDescClaudeCodeOnly := groupFields[14].Descriptor() groupDescClaudeCodeOnly := groupFields[17].Descriptor()
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
groupDescModelRoutingEnabled := groupFields[18].Descriptor() groupDescModelRoutingEnabled := groupFields[21].Descriptor()
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
// groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field. // groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field.
groupDescMcpXMLInject := groupFields[19].Descriptor() groupDescMcpXMLInject := groupFields[22].Descriptor()
// group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field. // group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field.
group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool) group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool)
// groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field. // groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field.
groupDescSupportedModelScopes := groupFields[20].Descriptor() groupDescSupportedModelScopes := groupFields[23].Descriptor()
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field. // group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string) group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
// groupDescSortOrder is the schema descriptor for sort_order field. // groupDescSortOrder is the schema descriptor for sort_order field.
groupDescSortOrder := groupFields[21].Descriptor() groupDescSortOrder := groupFields[24].Descriptor()
// group.DefaultSortOrder holds the default value on creation for the sort_order field. // group.DefaultSortOrder holds the default value on creation for the sort_order field.
group.DefaultSortOrder = groupDescSortOrder.Default.(int) group.DefaultSortOrder = groupDescSortOrder.Default.(int)
// groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field. // groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field.
groupDescAllowMessagesDispatch := groupFields[22].Descriptor() groupDescAllowMessagesDispatch := groupFields[25].Descriptor()
// group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field. // group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field.
group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool) group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool)
// groupDescRequireOauthOnly is the schema descriptor for require_oauth_only field. // groupDescRequireOauthOnly is the schema descriptor for require_oauth_only field.
groupDescRequireOauthOnly := groupFields[23].Descriptor() groupDescRequireOauthOnly := groupFields[26].Descriptor()
// group.DefaultRequireOauthOnly holds the default value on creation for the require_oauth_only field. // group.DefaultRequireOauthOnly holds the default value on creation for the require_oauth_only field.
group.DefaultRequireOauthOnly = groupDescRequireOauthOnly.Default.(bool) group.DefaultRequireOauthOnly = groupDescRequireOauthOnly.Default.(bool)
// groupDescRequirePrivacySet is the schema descriptor for require_privacy_set field. // groupDescRequirePrivacySet is the schema descriptor for require_privacy_set field.
groupDescRequirePrivacySet := groupFields[24].Descriptor() groupDescRequirePrivacySet := groupFields[27].Descriptor()
// group.DefaultRequirePrivacySet holds the default value on creation for the require_privacy_set field. // group.DefaultRequirePrivacySet holds the default value on creation for the require_privacy_set field.
group.DefaultRequirePrivacySet = groupDescRequirePrivacySet.Default.(bool) group.DefaultRequirePrivacySet = groupDescRequirePrivacySet.Default.(bool)
// groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field. // groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field.
groupDescDefaultMappedModel := groupFields[25].Descriptor() groupDescDefaultMappedModel := groupFields[28].Descriptor()
// group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field. // group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field.
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string) group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. // group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error) group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
// groupDescMessagesDispatchModelConfig is the schema descriptor for messages_dispatch_model_config field. // groupDescMessagesDispatchModelConfig is the schema descriptor for messages_dispatch_model_config field.
groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor() groupDescMessagesDispatchModelConfig := groupFields[29].Descriptor()
// group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field. // group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field.
group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig) group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig)
// groupDescRpmLimit is the schema descriptor for rpm_limit field. // groupDescRpmLimit is the schema descriptor for rpm_limit field.
groupDescRpmLimit := groupFields[27].Descriptor() groupDescRpmLimit := groupFields[30].Descriptor()
// group.DefaultRpmLimit holds the default value on creation for the rpm_limit field. // group.DefaultRpmLimit holds the default value on creation for the rpm_limit field.
group.DefaultRpmLimit = groupDescRpmLimit.Default.(int) group.DefaultRpmLimit = groupDescRpmLimit.Default.(int)
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()

View File

@ -74,6 +74,16 @@ func (Group) Fields() []ent.Field {
Default(30), Default(30),
// 图片生成计费配置antigravity 和 gemini 平台使用) // 图片生成计费配置antigravity 和 gemini 平台使用)
field.Bool("allow_image_generation").
Default(false).
Comment("是否允许该分组使用图片生成能力"),
field.Bool("image_rate_independent").
Default(false).
Comment("图片生成是否使用独立倍率false 表示共享分组有效倍率"),
field.Float("image_rate_multiplier").
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}).
Default(1.0).
Comment("图片生成独立倍率,仅 image_rate_independent=true 时生效"),
field.Float("image_price_1k"). field.Float("image_price_1k").
Optional(). Optional().
Nillable(). Nillable().

View File

@ -576,6 +576,24 @@ type ConcurrencyConfig struct {
PingInterval int `mapstructure:"ping_interval"` PingInterval int `mapstructure:"ping_interval"`
} }
// ImageConcurrencyConfig holds the settings for the dedicated concurrency
// limit on image-generation requests.
type ImageConcurrencyConfig struct {
// Enabled: whether the dedicated image-generation concurrency limit is enabled; off by default to preserve existing behavior.
Enabled bool `mapstructure:"enabled"`
// MaxConcurrentRequests: number of image-generation requests this process may handle at once; 0 means unlimited.
MaxConcurrentRequests int `mapstructure:"max_concurrent_requests"`
// OverflowMode: what to do once the image concurrency limit is reached (reject/wait).
OverflowMode string `mapstructure:"overflow_mode"`
// WaitTimeoutSeconds: timeout in seconds for waiting on an image concurrency slot when overflow_mode=wait.
WaitTimeoutSeconds int `mapstructure:"wait_timeout_seconds"`
// MaxWaitingRequests: number of image requests this process allows to queue when overflow_mode=wait.
MaxWaitingRequests int `mapstructure:"max_waiting_requests"`
}
// Named values for the image concurrency overflow mode ("reject" and "wait").
const (
ImageConcurrencyOverflowModeReject = "reject"
ImageConcurrencyOverflowModeWait = "wait"
)
// GatewayConfig API网关相关配置 // GatewayConfig API网关相关配置
type GatewayConfig struct { type GatewayConfig struct {
// 等待上游响应头的超时时间0表示无超时 // 等待上游响应头的超时时间0表示无超时
@ -609,6 +627,8 @@ type GatewayConfig struct {
AntigravityLSWorker GatewayAntigravityLSWorkerConfig `mapstructure:"antigravity_ls_worker"` AntigravityLSWorker GatewayAntigravityLSWorkerConfig `mapstructure:"antigravity_ls_worker"`
// NodeTLSProxy: Node.js TLS 代理配置 // NodeTLSProxy: Node.js TLS 代理配置
NodeTLSProxy NodeTLSProxyConfig `mapstructure:"node_tls_proxy"` NodeTLSProxy NodeTLSProxyConfig `mapstructure:"node_tls_proxy"`
// ImageConcurrency: 图片生成独立并发限制配置(默认关闭)
ImageConcurrency ImageConcurrencyConfig `mapstructure:"image_concurrency"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优) // HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数 // MaxIdleConns: 所有主机的最大空闲连接总数
@ -640,6 +660,10 @@ type GatewayConfig struct {
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"` StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
// StreamKeepaliveInterval: 流式 keepalive 间隔0表示禁用 // StreamKeepaliveInterval: 流式 keepalive 间隔0表示禁用
StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"` StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"`
// ImageStreamDataIntervalTimeout: 图片流数据间隔超时0表示禁用
ImageStreamDataIntervalTimeout int `mapstructure:"image_stream_data_interval_timeout"`
// ImageStreamKeepaliveInterval: 图片流式 keepalive 间隔0表示禁用
ImageStreamKeepaliveInterval int `mapstructure:"image_stream_keepalive_interval"`
// MaxLineSize: 上游 SSE 单行最大字节数0使用默认值 // MaxLineSize: 上游 SSE 单行最大字节数0使用默认值
MaxLineSize int `mapstructure:"max_line_size"` MaxLineSize int `mapstructure:"max_line_size"`
@ -1789,6 +1813,11 @@ func setDefaults() {
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7) viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8) viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5) viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5)
viper.SetDefault("gateway.image_concurrency.enabled", false)
viper.SetDefault("gateway.image_concurrency.max_concurrent_requests", 0)
viper.SetDefault("gateway.image_concurrency.overflow_mode", ImageConcurrencyOverflowModeReject)
viper.SetDefault("gateway.image_concurrency.wait_timeout_seconds", 30)
viper.SetDefault("gateway.image_concurrency.max_waiting_requests", 100)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10) viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024)) viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
@ -1806,6 +1835,8 @@ func setDefaults() {
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求) viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.stream_data_interval_timeout", 180) viper.SetDefault("gateway.stream_data_interval_timeout", 180)
viper.SetDefault("gateway.stream_keepalive_interval", 10) viper.SetDefault("gateway.stream_keepalive_interval", 10)
viper.SetDefault("gateway.image_stream_data_interval_timeout", 900)
viper.SetDefault("gateway.image_stream_keepalive_interval", 10)
viper.SetDefault("gateway.max_line_size", 500*1024*1024) viper.SetDefault("gateway.max_line_size", 500*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second) viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
@ -2410,6 +2441,21 @@ func (c *Config) Validate() error {
ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy) ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy)
} }
} }
if c.Gateway.ImageConcurrency.MaxConcurrentRequests < 0 {
return fmt.Errorf("gateway.image_concurrency.max_concurrent_requests must be non-negative")
}
switch strings.TrimSpace(c.Gateway.ImageConcurrency.OverflowMode) {
case "", ImageConcurrencyOverflowModeReject, ImageConcurrencyOverflowModeWait:
default:
return fmt.Errorf("gateway.image_concurrency.overflow_mode must be one of: %s/%s",
ImageConcurrencyOverflowModeReject, ImageConcurrencyOverflowModeWait)
}
if c.Gateway.ImageConcurrency.WaitTimeoutSeconds < 0 {
return fmt.Errorf("gateway.image_concurrency.wait_timeout_seconds must be non-negative")
}
if c.Gateway.ImageConcurrency.MaxWaitingRequests < 0 {
return fmt.Errorf("gateway.image_concurrency.max_waiting_requests must be non-negative")
}
if c.Gateway.MaxIdleConns <= 0 { if c.Gateway.MaxIdleConns <= 0 {
return fmt.Errorf("gateway.max_idle_conns must be positive") return fmt.Errorf("gateway.max_idle_conns must be positive")
} }
@ -2448,6 +2494,20 @@ func (c *Config) Validate() error {
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) { (c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds") return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
} }
if c.Gateway.ImageStreamDataIntervalTimeout < 0 {
return fmt.Errorf("gateway.image_stream_data_interval_timeout must be non-negative")
}
if c.Gateway.ImageStreamDataIntervalTimeout != 0 &&
(c.Gateway.ImageStreamDataIntervalTimeout < 60 || c.Gateway.ImageStreamDataIntervalTimeout > 1800) {
return fmt.Errorf("gateway.image_stream_data_interval_timeout must be 0 or between 60-1800 seconds")
}
if c.Gateway.ImageStreamKeepaliveInterval < 0 {
return fmt.Errorf("gateway.image_stream_keepalive_interval must be non-negative")
}
if c.Gateway.ImageStreamKeepaliveInterval != 0 &&
(c.Gateway.ImageStreamKeepaliveInterval < 5 || c.Gateway.ImageStreamKeepaliveInterval > 60) {
return fmt.Errorf("gateway.image_stream_keepalive_interval must be 0 or between 5-60 seconds")
}
// 兼容旧键 sticky_previous_response_ttl_seconds // 兼容旧键 sticky_previous_response_ttl_seconds
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 { if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds

View File

@ -1282,6 +1282,46 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = -1 }, mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = -1 },
wantErr: "gateway.stream_data_interval_timeout must be non-negative", wantErr: "gateway.stream_data_interval_timeout must be non-negative",
}, },
{
name: "gateway image stream keepalive range",
mutate: func(c *Config) { c.Gateway.ImageStreamKeepaliveInterval = 4 },
wantErr: "gateway.image_stream_keepalive_interval",
},
{
name: "gateway image stream keepalive negative",
mutate: func(c *Config) { c.Gateway.ImageStreamKeepaliveInterval = -1 },
wantErr: "gateway.image_stream_keepalive_interval must be non-negative",
},
{
name: "gateway image stream data interval range",
mutate: func(c *Config) { c.Gateway.ImageStreamDataIntervalTimeout = 30 },
wantErr: "gateway.image_stream_data_interval_timeout",
},
{
name: "gateway image stream data interval negative",
mutate: func(c *Config) { c.Gateway.ImageStreamDataIntervalTimeout = -1 },
wantErr: "gateway.image_stream_data_interval_timeout must be non-negative",
},
{
name: "gateway image concurrency max negative",
mutate: func(c *Config) { c.Gateway.ImageConcurrency.MaxConcurrentRequests = -1 },
wantErr: "gateway.image_concurrency.max_concurrent_requests must be non-negative",
},
{
name: "gateway image concurrency overflow mode invalid",
mutate: func(c *Config) { c.Gateway.ImageConcurrency.OverflowMode = "queue" },
wantErr: "gateway.image_concurrency.overflow_mode",
},
{
name: "gateway image concurrency wait timeout negative",
mutate: func(c *Config) { c.Gateway.ImageConcurrency.WaitTimeoutSeconds = -1 },
wantErr: "gateway.image_concurrency.wait_timeout_seconds must be non-negative",
},
{
name: "gateway image concurrency max waiting negative",
mutate: func(c *Config) { c.Gateway.ImageConcurrency.MaxWaitingRequests = -1 },
wantErr: "gateway.image_concurrency.max_waiting_requests must be non-negative",
},
{ {
name: "gateway max line size", name: "gateway max line size",
mutate: func(c *Config) { c.Gateway.MaxLineSize = 1024 }, mutate: func(c *Config) { c.Gateway.MaxLineSize = 1024 },
@ -1754,3 +1794,41 @@ func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
t.Fatalf("auto_scale_cooldown_seconds = %d, want 10", cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds) t.Fatalf("auto_scale_cooldown_seconds = %d, want 10", cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds)
} }
} }
// TestLoad_DefaultGatewayImageStreamConfig checks the values Load produces for
// the gateway stream timing settings and the image concurrency settings when
// the test supplies no overrides, and that the image stream data timeout is
// larger than the ordinary stream data timeout.
func TestLoad_DefaultGatewayImageStreamConfig(t *testing.T) {
	resetViperWithJWTSecret(t)

	loaded, loadErr := Load()
	if loadErr != nil {
		t.Fatalf("Load() error: %v", loadErr)
	}
	gw := &loaded.Gateway
	img := &gw.ImageConcurrency

	// Ordinary stream timing.
	if got := gw.StreamDataIntervalTimeout; got != 180 {
		t.Fatalf("stream_data_interval_timeout = %d, want 180", got)
	}
	if got := gw.StreamKeepaliveInterval; got != 10 {
		t.Fatalf("stream_keepalive_interval = %d, want 10", got)
	}

	// Image stream timing.
	if got := gw.ImageStreamDataIntervalTimeout; got != 900 {
		t.Fatalf("image_stream_data_interval_timeout = %d, want 900", got)
	}
	if got := gw.ImageStreamKeepaliveInterval; got != 10 {
		t.Fatalf("image_stream_keepalive_interval = %d, want 10", got)
	}

	// Image concurrency limiter settings.
	if img.Enabled {
		t.Fatalf("image_concurrency.enabled = true, want false")
	}
	if got := img.MaxConcurrentRequests; got != 0 {
		t.Fatalf("image_concurrency.max_concurrent_requests = %d, want 0", got)
	}
	if got := img.OverflowMode; got != ImageConcurrencyOverflowModeReject {
		t.Fatalf("image_concurrency.overflow_mode = %q, want %q", got, ImageConcurrencyOverflowModeReject)
	}
	if got := img.WaitTimeoutSeconds; got != 30 {
		t.Fatalf("image_concurrency.wait_timeout_seconds = %d, want 30", got)
	}
	if got := img.MaxWaitingRequests; got != 100 {
		t.Fatalf("image_concurrency.max_waiting_requests = %d, want 100", got)
	}

	// Relationship between the two data timeouts.
	if imageTimeout, streamTimeout := gw.ImageStreamDataIntervalTimeout, gw.StreamDataIntervalTimeout; imageTimeout <= streamTimeout {
		t.Fatalf("image stream timeout = %d, want greater than ordinary stream timeout %d", imageTimeout, streamTimeout)
	}
}

View File

@ -529,6 +529,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
// 确定是否跳过混合渠道检查 // 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
// 捕获闭包内创建的账号引用,用于创建成功后触发异步探测。
// 幂等重放时闭包不会执行 → createdAccount 为 nil → 不重复调度。
var createdAccount *service.Account
result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Name: req.Name, Name: req.Name,
@ -550,6 +554,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
if execErr != nil { if execErr != nil {
return nil, execErr return nil, execErr
} }
createdAccount = account
// Antigravity OAuth: 新账号直接设置隐私 // Antigravity OAuth: 新账号直接设置隐私
h.adminService.ForceAntigravityPrivacy(ctx, account) h.adminService.ForceAntigravityPrivacy(ctx, account)
// OpenAI OAuth: 新账号直接设置隐私 // OpenAI OAuth: 新账号直接设置隐私
@ -578,6 +583,9 @@ func (h *AccountHandler) Create(c *gin.Context) {
if result != nil && result.Replayed { if result != nil && result.Replayed {
c.Header("X-Idempotency-Replayed", "true") c.Header("X-Idempotency-Replayed", "true")
} }
// OpenAI APIKey 账号创建后异步探测上游 /v1/responses 能力。
// 探测失败不影响账号创建响应。
h.scheduleOpenAIResponsesProbe(createdAccount)
response.Success(c, result.Data) response.Success(c, result.Data)
} }
@ -638,9 +646,39 @@ func (h *AccountHandler) Update(c *gin.Context) {
return return
} }
// OpenAI APIKey: credentials 修改后重新探测上游能力base_url/api_key 可能变更)。
// 异步执行,探测失败不影响账号更新响应。
if len(req.Credentials) > 0 {
h.scheduleOpenAIResponsesProbe(account)
}
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
} }
// scheduleOpenAIResponsesProbe asynchronously triggers the Responses API
// capability probe for an OpenAI API-key account.
//
// It only acts on accounts with platform=openai and type=apikey; for any other
// account (or a nil account, or a handler without an account test service) it
// is a no-op. The probe runs in its own goroutine (it sends one HTTP request
// upstream) so the current request is never blocked. Probe errors are only
// logged and are not propagated: when the probe fails the capability marker
// stays absent and the gateway defaults to Responses ("the status quo is the
// evidence").
func (h *AccountHandler) scheduleOpenAIResponsesProbe(account *service.Account) {
	switch {
	case account == nil, h.accountTestService == nil:
		return
	case account.Platform != service.PlatformOpenAI, account.Type != service.AccountTypeAPIKey:
		return
	}

	// Capture only the ID so the goroutine does not hold on to the account value.
	accountID := account.ID
	probe := func() {
		// A panic inside the probe must not take down the process.
		defer func() {
			if r := recover(); r != nil {
				slog.Error("openai_responses_probe_panic", "account_id", accountID, "recover", r)
			}
		}()
		h.accountTestService.ProbeOpenAIAPIKeyResponsesSupport(context.Background(), accountID)
	}
	go probe()
}
// Delete handles deleting an account // Delete handles deleting an account
// DELETE /api/v1/admin/accounts/:id // DELETE /api/v1/admin/accounts/:id
func (h *AccountHandler) Delete(c *gin.Context) { func (h *AccountHandler) Delete(c *gin.Context) {
@ -1232,6 +1270,8 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
openaiPrivacyAccounts = append(openaiPrivacyAccounts, account) openaiPrivacyAccounts = append(openaiPrivacyAccounts, account)
} }
} }
// OpenAI APIKey 账号异步探测 /v1/responses 能力。
h.scheduleOpenAIResponsesProbe(account)
success++ success++
results = append(results, gin.H{ results = append(results, gin.H{
"name": item.Name, "name": item.Name,

View File

@ -2,8 +2,11 @@ package admin
import ( import (
"strconv" "strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -181,3 +184,108 @@ func (h *AffiliateHandler) LookupUsers(c *gin.Context) {
} }
response.Success(c, result) response.Success(c, result)
} }
// GetUserOverview returns the affiliate overview of a single user.
// GET /api/v1/admin/affiliates/users/:user_id/overview
//
// The :user_id path parameter must be a positive integer; otherwise the
// request is answered with a bad-request response.
func (h *AffiliateHandler) GetUserOverview(c *gin.Context) {
	rawID := c.Param("user_id")
	userID, parseErr := strconv.ParseInt(rawID, 10, 64)
	if parseErr != nil || userID <= 0 {
		response.BadRequest(c, "Invalid user_id")
		return
	}

	result, lookupErr := h.affiliateService.AdminGetUserOverview(c.Request.Context(), userID)
	if lookupErr != nil {
		response.ErrorFrom(c, lookupErr)
		return
	}
	response.Success(c, result)
}
// ListInviteRecords returns a page of inviter-invitee relationships.
// GET /api/v1/admin/affiliates/invites
//
// Filtering, sorting and paging come from the query string via
// parseAffiliateRecordFilter.
func (h *AffiliateHandler) ListInviteRecords(c *gin.Context) {
	pageNum, perPage := response.ParsePagination(c)
	query := parseAffiliateRecordFilter(c, pageNum, perPage)

	records, count, listErr := h.affiliateService.AdminListInviteRecords(c.Request.Context(), query)
	if listErr != nil {
		response.ErrorFrom(c, listErr)
		return
	}
	response.Paginated(c, records, count, query.Page, query.PageSize)
}
// ListRebateRecords returns a page of order-level affiliate rebate records.
// GET /api/v1/admin/affiliates/rebates
//
// Filtering, sorting and paging come from the query string via
// parseAffiliateRecordFilter.
func (h *AffiliateHandler) ListRebateRecords(c *gin.Context) {
	pageNum, perPage := response.ParsePagination(c)
	query := parseAffiliateRecordFilter(c, pageNum, perPage)

	records, count, listErr := h.affiliateService.AdminListRebateRecords(c.Request.Context(), query)
	if listErr != nil {
		response.ErrorFrom(c, listErr)
		return
	}
	response.Paginated(c, records, count, query.Page, query.PageSize)
}
// ListTransferRecords returns a page of affiliate quota-to-balance transfer
// records.
// GET /api/v1/admin/affiliates/transfers
//
// Filtering, sorting and paging come from the query string via
// parseAffiliateRecordFilter.
func (h *AffiliateHandler) ListTransferRecords(c *gin.Context) {
	pageNum, perPage := response.ParsePagination(c)
	query := parseAffiliateRecordFilter(c, pageNum, perPage)

	records, count, listErr := h.affiliateService.AdminListTransferRecords(c.Request.Context(), query)
	if listErr != nil {
		response.ErrorFrom(c, listErr)
		return
	}
	response.Paginated(c, records, count, query.Page, query.PageSize)
}
// parseAffiliateRecordFilter builds the record filter shared by the affiliate
// list endpoints from the request's query string.
//
// Query parameters read: search, sort_by, sort_order (anything other than the
// exact value "asc" sorts descending), timezone, start_at and end_at. The page
// size is capped at 100. start_at / end_at that are empty or cannot be parsed
// leave the corresponding bound unset.
func parseAffiliateRecordFilter(c *gin.Context, page, pageSize int) service.AffiliateRecordFilter {
	const maxPageSize = 100
	if pageSize > maxPageSize {
		pageSize = maxPageSize
	}

	result := service.AffiliateRecordFilter{
		Search:   c.Query("search"),
		Page:     page,
		PageSize: pageSize,
		SortBy:   c.Query("sort_by"),
		SortDesc: c.Query("sort_order") != "asc",
	}

	// Date-only bounds are interpreted in the caller-supplied timezone.
	tz := c.Query("timezone")
	if from := parseAffiliateRecordStartTime(c.Query("start_at"), tz); from != nil {
		result.StartAt = from
	}
	if to := parseAffiliateRecordEndTime(c.Query("end_at"), tz); to != nil {
		result.EndAt = to
	}
	return result
}
// parseAffiliateRecordStartTime parses the lower bound of a record time range.
//
// raw is trimmed and then tried as an RFC 3339 timestamp; if that fails it is
// tried as a "2006-01-02" date via timezone.ParseInUserLocation with userTZ.
// It returns nil when raw is blank or matches neither format.
func parseAffiliateRecordStartTime(raw string, userTZ string) *time.Time {
	value := strings.TrimSpace(raw)
	if value == "" {
		return nil
	}

	instant, err := time.Parse(time.RFC3339, value)
	if err != nil {
		// Fall back to a date-only value.
		instant, err = timezone.ParseInUserLocation("2006-01-02", value, userTZ)
	}
	if err != nil {
		return nil
	}
	return &instant
}
// parseAffiliateRecordEndTime parses the upper bound of a record time range.
//
// raw is trimmed and then tried as an RFC 3339 timestamp, which is returned
// unchanged. If that fails it is tried as a "2006-01-02" date via
// timezone.ParseInUserLocation with userTZ, and the result is moved to one
// nanosecond before the start of the following day so the bound includes the
// whole day. It returns nil when raw is blank or matches neither format.
func parseAffiliateRecordEndTime(raw string, userTZ string) *time.Time {
	value := strings.TrimSpace(raw)
	if value == "" {
		return nil
	}

	if exact, err := time.Parse(time.RFC3339, value); err == nil {
		return &exact
	}

	day, err := timezone.ParseInUserLocation("2006-01-02", value, userTZ)
	if err != nil {
		return nil
	}
	endOfDay := day.AddDate(0, 0, 1).Add(-time.Nanosecond)
	return &endOfDay
}

View File

@ -92,6 +92,9 @@ type CreateGroupRequest struct {
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"` WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"` MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
// 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置) // 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置)
AllowImageGeneration bool `json:"allow_image_generation"`
ImageRateIndependent bool `json:"image_rate_independent"`
ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
@ -129,6 +132,9 @@ type UpdateGroupRequest struct {
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"` WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"` MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
// 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置) // 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置)
AllowImageGeneration *bool `json:"allow_image_generation"`
ImageRateIndependent *bool `json:"image_rate_independent"`
ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
@ -251,6 +257,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(), DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(), WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(), MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
AllowImageGeneration: req.AllowImageGeneration,
ImageRateIndependent: req.ImageRateIndependent,
ImageRateMultiplier: req.ImageRateMultiplier,
ImagePrice1K: req.ImagePrice1K, ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K, ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K, ImagePrice4K: req.ImagePrice4K,
@ -303,6 +312,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(), DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(), WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(), MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
AllowImageGeneration: req.AllowImageGeneration,
ImageRateIndependent: req.ImageRateIndependent,
ImageRateMultiplier: req.ImageRateMultiplier,
ImagePrice1K: req.ImagePrice1K, ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K, ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K, ImagePrice4K: req.ImagePrice4K,

View File

@ -2462,6 +2462,58 @@ func (h *SettingHandler) UpdateOverloadCooldownSettings(c *gin.Context) {
}) })
} }
// GetRateLimit429CooldownSettings returns the default 429 back-off (cooldown)
// settings.
// GET /api/v1/admin/settings/rate-limit-429-cooldown
func (h *SettingHandler) GetRateLimit429CooldownSettings(c *gin.Context) {
	current, loadErr := h.settingService.GetRateLimit429CooldownSettings(c.Request.Context())
	if loadErr != nil {
		response.ErrorFrom(c, loadErr)
		return
	}

	payload := dto.RateLimit429CooldownSettings{
		Enabled:         current.Enabled,
		CooldownSeconds: current.CooldownSeconds,
	}
	response.Success(c, payload)
}
// UpdateRateLimit429CooldownSettingsRequest is the JSON request body for
// updating the default 429 back-off (cooldown) settings.
type UpdateRateLimit429CooldownSettingsRequest struct {
Enabled bool `json:"enabled"`
CooldownSeconds int `json:"cooldown_seconds"`
}
// UpdateRateLimit429CooldownSettings updates the default 429 back-off
// (cooldown) settings and responds with the settings as read back afterwards.
// PUT /api/v1/admin/settings/rate-limit-429-cooldown
//
// A body that cannot be bound, or settings rejected by the setting service,
// produce a bad-request response.
func (h *SettingHandler) UpdateRateLimit429CooldownSettings(c *gin.Context) {
	var body UpdateRateLimit429CooldownSettingsRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		response.BadRequest(c, "Invalid request: "+bindErr.Error())
		return
	}

	desired := service.RateLimit429CooldownSettings{
		Enabled:         body.Enabled,
		CooldownSeconds: body.CooldownSeconds,
	}
	if saveErr := h.settingService.SetRateLimit429CooldownSettings(c.Request.Context(), &desired); saveErr != nil {
		response.BadRequest(c, saveErr.Error())
		return
	}

	// Re-read so the response reflects what the service actually stored.
	stored, loadErr := h.settingService.GetRateLimit429CooldownSettings(c.Request.Context())
	if loadErr != nil {
		response.ErrorFrom(c, loadErr)
		return
	}
	payload := dto.RateLimit429CooldownSettings{
		Enabled:         stored.Enabled,
		CooldownSeconds: stored.CooldownSeconds,
	}
	response.Success(c, payload)
}
// GetStreamTimeoutSettings 获取流超时处理配置 // GetStreamTimeoutSettings 获取流超时处理配置
// GET /api/v1/admin/settings/stream-timeout // GET /api/v1/admin/settings/stream-timeout
func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) { func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {

View File

@ -390,7 +390,7 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
// GetBalanceHistory handles getting user's balance/concurrency change history // GetBalanceHistory handles getting user's balance/concurrency change history
// GET /api/v1/admin/users/:id/balance-history // GET /api/v1/admin/users/:id/balance-history
// Query params: // Query params:
// - type: filter by record type (balance, admin_balance, concurrency, admin_concurrency, subscription) // - type: filter by record type (balance, affiliate_balance, admin_balance, concurrency, admin_concurrency, subscription)
func (h *UserHandler) GetBalanceHistory(c *gin.Context) { func (h *UserHandler) GetBalanceHistory(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64) userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil { if err != nil {

View File

@ -176,6 +176,9 @@ func groupFromServiceBase(g *service.Group) Group {
DailyLimitUSD: g.DailyLimitUSD, DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD, WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD, MonthlyLimitUSD: g.MonthlyLimitUSD,
AllowImageGeneration: g.AllowImageGeneration,
ImageRateIndependent: g.ImageRateIndependent,
ImageRateMultiplier: g.ImageRateMultiplier,
ImagePrice1K: g.ImagePrice1K, ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K, ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K, ImagePrice4K: g.ImagePrice4K,

View File

@ -264,6 +264,12 @@ type OverloadCooldownSettings struct {
CooldownMinutes int `json:"cooldown_minutes"` CooldownMinutes int `json:"cooldown_minutes"`
} }
// RateLimit429CooldownSettings is the DTO for the default 429 back-off
// (cooldown) settings.
type RateLimit429CooldownSettings struct {
Enabled bool `json:"enabled"`
CooldownSeconds int `json:"cooldown_seconds"`
}
// StreamTimeoutSettings 流超时处理配置 DTO // StreamTimeoutSettings 流超时处理配置 DTO
type StreamTimeoutSettings struct { type StreamTimeoutSettings struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`

View File

@ -94,9 +94,12 @@ type Group struct {
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
// 图片生成计费配置(仅 antigravity 平台使用) // 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64 `json:"image_price_1k"` AllowImageGeneration bool `json:"allow_image_generation"`
ImagePrice2K *float64 `json:"image_price_2k"` ImageRateIndependent bool `json:"image_rate_independent"`
ImagePrice4K *float64 `json:"image_price_4k"` ImageRateMultiplier float64 `json:"image_rate_multiplier"`
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
// Claude Code 客户端限制 // Claude Code 客户端限制
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`

View File

@ -0,0 +1,126 @@
package handler
import (
"context"
"sync"
"time"
)
// imageConcurrencyLimiter bounds how many image generation requests run at the
// same time inside this process and, optionally, lets excess requests wait in
// a bounded queue until a slot frees up.
//
// The zero value is ready to use. It contains a sync.Mutex, so it must be
// shared by pointer and never copied after first use.
type imageConcurrencyLimiter struct {
	mu      sync.Mutex
	notify  chan struct{} // closed and replaced on every release to wake all waiters
	limit   int           // limit supplied by the most recent caller
	active  int           // slots currently held
	waiting int           // callers currently queued for a slot
	enabled bool          // enabled flag supplied by the most recent caller
}

// TryAcquire attempts to take a slot without waiting.
//
// It returns (nil, true) when limiting is off (enabled is false or limit is
// not positive), (release, true) when a slot was taken, and (nil, false) when
// all slots are in use. release is safe to call more than once.
func (l *imageConcurrencyLimiter) TryAcquire(enabled bool, limit int) (func(), bool) {
	return l.acquire(context.Background(), enabled, limit, false, 0, 0)
}

// Acquire takes a slot, optionally waiting for one.
//
// When wait is true the caller queues for at most timeout (and until ctx is
// done) as long as fewer than maxWaiting callers are already queued; a
// maxWaiting of zero or less means the queue is unbounded. The return values
// have the same meaning as for TryAcquire.
func (l *imageConcurrencyLimiter) Acquire(ctx context.Context, enabled bool, limit int, wait bool, timeout time.Duration, maxWaiting int) (func(), bool) {
	return l.acquire(ctx, enabled, limit, wait, timeout, maxWaiting)
}

// acquire implements TryAcquire and Acquire.
func (l *imageConcurrencyLimiter) acquire(ctx context.Context, enabled bool, limit int, wait bool, timeout time.Duration, maxWaiting int) (func(), bool) {
	if !enabled || limit <= 0 {
		return nil, true
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if wait && timeout <= 0 {
		// No time budget for waiting: fall back to a single non-blocking
		// attempt instead of rejecting outright, so a free slot is still used.
		wait = false
	}
	if wait {
		waitCtx, cancel := context.WithTimeout(ctx, timeout)
		defer cancel()
		ctx = waitCtx
	}
	if maxWaiting < 0 {
		maxWaiting = 0
	}

	// queued records whether this caller already occupies a queue position.
	// The position is kept across wake-ups so that a caller admitted to the
	// queue cannot later be bounced by maxWaiting while it is still within
	// its timeout.
	queued := false
	for {
		release, notify, stillQueued := l.tryAcquireOrEnqueue(enabled, limit, wait, maxWaiting, queued)
		if release != nil {
			return release, true
		}
		if notify == nil {
			// Not allowed to wait, or the queue is full.
			return nil, false
		}
		queued = stillQueued
		if !l.waitForSlot(ctx, notify) {
			l.leaveQueue()
			return nil, false
		}
	}
}

// tryAcquireOrEnqueue performs one attempt under the lock.
//
// If a slot is free it is taken (and the caller's queue position, if any, is
// given up) and a release function is returned. Otherwise, when wait is true,
// the caller is placed in the queue unless it is already queued or the queue
// is full; the returned channel is closed on the next release. A nil release
// together with a nil channel means the caller must give up. The final result
// reports whether the caller holds a queue position afterwards.
func (l *imageConcurrencyLimiter) tryAcquireOrEnqueue(enabled bool, limit int, wait bool, maxWaiting int, queued bool) (func(), <-chan struct{}, bool) {
	l.mu.Lock()
	defer l.mu.Unlock()

	if l.notify == nil {
		l.notify = make(chan struct{})
	}
	l.enabled = enabled
	l.limit = limit

	if l.active < l.limit {
		l.active++
		if queued && l.waiting > 0 {
			l.waiting--
		}
		return l.releaseFunc(), nil, false
	}
	if !wait {
		return nil, nil, false
	}
	if !queued {
		if maxWaiting > 0 && l.waiting >= maxWaiting {
			return nil, nil, false
		}
		l.waiting++
	}
	return nil, l.notify, true
}

// waitForSlot blocks until notify is closed (true) or ctx is done (false).
func (l *imageConcurrencyLimiter) waitForSlot(ctx context.Context, notify <-chan struct{}) bool {
	select {
	case <-notify:
		return true
	case <-ctx.Done():
		return false
	}
}

// leaveQueue gives up the caller's queue position.
func (l *imageConcurrencyLimiter) leaveQueue() {
	l.mu.Lock()
	if l.waiting > 0 {
		l.waiting--
	}
	l.mu.Unlock()
}

// releaseFunc returns a function that frees one slot and wakes every waiter.
// Calls after the first are no-ops.
func (l *imageConcurrencyLimiter) releaseFunc() func() {
	var once sync.Once
	return func() {
		once.Do(func() {
			l.mu.Lock()
			if l.active > 0 {
				l.active--
			}
			if l.notify != nil {
				// Closing wakes all current waiters; later waiters get the
				// fresh channel.
				close(l.notify)
				l.notify = make(chan struct{})
			}
			l.mu.Unlock()
		})
	}
}

View File

@ -0,0 +1,230 @@
package handler
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestImageConcurrencyLimiter_DefaultDisabledAllowsRequests verifies that a
// limiter called with enabled=false admits the request and hands back no
// release function.
func TestImageConcurrencyLimiter_DefaultDisabledAllowsRequests(t *testing.T) {
	var limiter imageConcurrencyLimiter

	releaseFn, admitted := limiter.TryAcquire(false, 1)

	require.True(t, admitted)
	require.Nil(t, releaseFn)
}
// TestImageConcurrencyLimiter_RejectsWhenLimitReachedAndAllowsAfterRelease
// verifies that with a limit of one a second non-blocking attempt is rejected
// while the slot is held and succeeds again once it has been released.
func TestImageConcurrencyLimiter_RejectsWhenLimitReachedAndAllowsAfterRelease(t *testing.T) {
	var limiter imageConcurrencyLimiter

	// Take the only slot.
	firstRelease, firstOK := limiter.TryAcquire(true, 1)
	require.True(t, firstOK)
	require.NotNil(t, firstRelease)

	// The limiter is now full.
	rejectedRelease, rejectedOK := limiter.TryAcquire(true, 1)
	require.False(t, rejectedOK)
	require.Nil(t, rejectedRelease)

	// Freeing the slot makes it available again.
	firstRelease()
	retryRelease, retryOK := limiter.TryAcquire(true, 1)
	require.True(t, retryOK)
	require.NotNil(t, retryRelease)
	retryRelease()
}
// TestImageConcurrencyLimiter_WaitsUntilSlotReleased verifies that a caller in
// wait mode obtains the slot once the current holder releases it.
//
// The waiting goroutine only reports its result over a channel; every
// assertion runs on the test goroutine, because require.* ends in t.FailNow,
// which must not be called from goroutines started by the test.
func TestImageConcurrencyLimiter_WaitsUntilSlotReleased(t *testing.T) {
	limiter := &imageConcurrencyLimiter{}
	release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
	require.True(t, acquired)
	require.NotNil(t, release)

	type waitResult struct {
		release  func()
		acquired bool
	}
	resultCh := make(chan waitResult, 1)
	go func() {
		waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
		resultCh <- waitResult{release: waitRelease, acquired: waitAcquired}
	}()

	// Give the waiter a moment to queue up before freeing the slot.
	time.Sleep(20 * time.Millisecond)
	release()

	select {
	case got := <-resultCh:
		require.True(t, got.acquired)
		require.NotNil(t, got.release)
		got.release()
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for image concurrency slot")
	}
}
// TestImageConcurrencyLimiter_WaitTimesOut verifies that a waiting caller is
// rejected, with no release function, when the slot stays occupied for longer
// than its wait timeout.
func TestImageConcurrencyLimiter_WaitTimesOut(t *testing.T) {
	var limiter imageConcurrencyLimiter
	ctx := context.Background()

	holder, held := limiter.Acquire(ctx, true, 1, true, time.Second, 1)
	require.True(t, held)
	require.NotNil(t, holder)
	defer holder()

	// The slot is never freed within the 10ms budget.
	lateRelease, lateOK := limiter.Acquire(ctx, true, 1, true, 10*time.Millisecond, 1)
	require.False(t, lateOK)
	require.Nil(t, lateRelease)
}
// TestImageConcurrencyLimiter_MaxWaitingRequestsRejectsOverflow verifies that
// with one slot and a queue bound of one, a second would-be waiter is rejected
// while the first is still queued.
func TestImageConcurrencyLimiter_MaxWaitingRequestsRejectsOverflow(t *testing.T) {
	var limiter imageConcurrencyLimiter
	ctx := context.Background()

	holder, held := limiter.Acquire(ctx, true, 1, true, time.Second, 1)
	require.True(t, held)
	require.NotNil(t, holder)
	defer holder()

	started := make(chan struct{})
	finished := make(chan struct{})
	go func() {
		close(started)
		queuedRelease, queuedOK := limiter.Acquire(ctx, true, 1, true, time.Second, 1)
		if queuedOK && queuedRelease != nil {
			queuedRelease()
		}
		close(finished)
	}()

	<-started
	// Let the background caller reach the queue before probing for overflow.
	time.Sleep(20 * time.Millisecond)

	extraRelease, extraOK := limiter.Acquire(ctx, true, 1, true, time.Second, 1)
	require.False(t, extraOK)
	require.Nil(t, extraRelease)

	// Free the slot so the queued caller can finish, then wait for it.
	holder()
	<-finished
}
// TestOpenAIGatewayHandlerAcquireImageGenerationSlot_Returns429WhenFull checks
// that, with image concurrency enabled, a limit of one and overflow mode
// "reject", a second acquireImageGenerationSlot call made while the first slot
// is still held fails and writes an HTTP 429 rate_limit_error response.
func TestOpenAIGatewayHandlerAcquireImageGenerationSlot_Returns429WhenFull(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
// Handler wired with only the config and limiter that the slot helper is given here.
h := &OpenAIGatewayHandler{
cfg: &config.Config{
Gateway: config.GatewayConfig{
ImageConcurrency: config.ImageConcurrencyConfig{
Enabled: true,
MaxConcurrentRequests: 1,
OverflowMode: config.ImageConcurrencyOverflowModeReject,
},
},
},
imageLimiter: &imageConcurrencyLimiter{},
}
// First call takes the only slot and keeps it until the test ends.
release, acquired := h.acquireImageGenerationSlot(c, false)
require.True(t, acquired)
require.NotNil(t, release)
defer release()
// Second call overflows and must be rejected with a 429 payload.
blockedRelease, blocked := h.acquireImageGenerationSlot(c, false)
require.False(t, blocked)
require.Nil(t, blockedRelease)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
require.Equal(t, "rate_limit_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String())
require.Contains(t, rec.Body.String(), "Image generation concurrency limit exceeded")
}
// TestOpenAIGatewayHandlerResponses_ImageIntentRejectedByImageConcurrency
// checks that a /v1/responses request whose body lists an image_generation
// tool is answered with HTTP 429 rate_limit_error when the single image
// concurrency slot is already held.
func TestOpenAIGatewayHandlerResponses_ImageIntentRejectedByImageConcurrency(t *testing.T) {
gin.SetMode(gin.TestMode)
body := `{"model":"gpt-5.4","input":"draw","tools":[{"type":"image_generation"}]}`
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(body))
groupID := int64(1)
// Authenticated API key whose group allows image generation.
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
ID: 10,
GroupID: &groupID,
Group: &service.Group{
ID: groupID,
AllowImageGeneration: true,
},
User: &service.User{ID: 20},
})
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 20, Concurrency: 1})
h := &OpenAIGatewayHandler{
gatewayService: &service.OpenAIGatewayService{},
billingCacheService: &service.BillingCacheService{},
apiKeyService: &service.APIKeyService{},
concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})},
errorPassthroughService: nil,
cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{
Enabled: true,
MaxConcurrentRequests: 1,
OverflowMode: config.ImageConcurrencyOverflowModeReject,
}}},
imageLimiter: &imageConcurrencyLimiter{},
}
// Occupy the only image slot before invoking the handler.
release, acquired := h.acquireImageGenerationSlot(c, false)
require.True(t, acquired)
require.NotNil(t, release)
defer release()
// Reset the recorder so the assertions below only observe what Responses writes.
rec.Body.Reset()
rec.Code = 0
h.Responses(c)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
require.Equal(t, "rate_limit_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String())
require.Contains(t, rec.Body.String(), "Image generation concurrency limit exceeded")
}
// TestOpenAIGatewayHandlerResponses_TextOnlyNotRejectedByImageConcurrency
// checks that a /v1/responses request without any image tool is not answered
// with the image concurrency 429, even though the single image slot is held.
// It only asserts the absence of that rejection, not a successful response.
func TestOpenAIGatewayHandlerResponses_TextOnlyNotRejectedByImageConcurrency(t *testing.T) {
gin.SetMode(gin.TestMode)
body := `{"model":"gpt-5.4","input":"write code"}`
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(body))
groupID := int64(1)
// Authenticated API key whose group allows image generation.
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
ID: 10,
GroupID: &groupID,
Group: &service.Group{
ID: groupID,
AllowImageGeneration: true,
},
User: &service.User{ID: 20},
})
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 20, Concurrency: 1})
h := &OpenAIGatewayHandler{
gatewayService: &service.OpenAIGatewayService{},
billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}),
apiKeyService: &service.APIKeyService{},
concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})},
cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{
Enabled: true,
MaxConcurrentRequests: 1,
OverflowMode: config.ImageConcurrencyOverflowModeReject,
}}},
imageLimiter: &imageConcurrencyLimiter{},
}
// Occupy the only image slot before invoking the handler.
release, acquired := h.acquireImageGenerationSlot(c, false)
require.True(t, acquired)
require.NotNil(t, release)
defer release()
// Reset the recorder so the assertions below only observe what Responses writes.
rec.Body.Reset()
rec.Code = 0
h.Responses(c)
require.NotEqual(t, http.StatusTooManyRequests, rec.Code)
require.NotContains(t, rec.Body.String(), "Image generation concurrency limit exceeded")
}

View File

@ -10,6 +10,7 @@ import (
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -120,7 +121,6 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
var lastFailoverErr *service.UpstreamFailoverError var lastFailoverErr *service.UpstreamFailoverError
for { for {
c.Set("openai_chat_completions_fallback_model", "")
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(), c.Request.Context(),
@ -138,32 +138,8 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
zap.Int("excluded_account_count", len(failedAccountIDs)), zap.Int("excluded_account_count", len(failedAccountIDs)),
) )
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
defaultModel := "" h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
if apiKey.Group != nil { return
defaultModel = apiKey.Group.DefaultMappedModel
}
if defaultModel != "" && defaultModel != reqModel {
reqLog.Info("openai_chat_completions.fallback_to_default_model",
zap.String("default_mapped_model", defaultModel),
)
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"",
sessionHash,
defaultModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
false,
)
if err == nil && selection != nil {
c.Set("openai_chat_completions_fallback_model", defaultModel)
}
}
if err != nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
} else { } else {
if lastFailoverErr != nil { if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted) h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
@ -191,12 +167,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now() forwardStart := time.Now()
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model"))
forwardBody := body forwardBody := body
if channelMapping.Mapped { if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
} }
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel) result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "")
forwardDurationMs := time.Since(forwardStart).Milliseconds() forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
@ -212,52 +187,60 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
} }
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError if result != nil && result.ImageCount > 0 {
if errors.As(err, &failoverErr) { reqLog.Warn("openai_chat_completions.forward_partial_error_with_image_result",
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// Pool mode: retry on the same account
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
}
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
switchCount++
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode), zap.Int("image_count", result.ImageCount),
zap.Int("switch_count", switchCount), zap.Error(err),
zap.Int("max_switches", maxAccountSwitches),
) )
continue } else {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// Pool mode: retry on the same account
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
}
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
switchCount++
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Warn("openai_chat_completions.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return
} }
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Warn("openai_chat_completions.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return
} }
if result != nil { if result != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
@ -267,16 +250,18 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := resolveRawCCUpstreamEndpoint(c, account)
h.submitUsageRecordTask(func(ctx context.Context) { h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result, Result: result,
APIKey: apiKey, APIKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c), InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
@ -299,3 +284,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
return return
} }
} }
// resolveRawCCUpstreamEndpoint returns the actual upstream endpoint for
// OpenAI Chat Completions requests. For APIKey accounts whose upstream
// has been probed to not support the Responses API, the request is
// forwarded directly to /v1/chat/completions — not through the default
// CC→Responses conversion path.
//
// A nil account is tolerated: no account-specific decision can be made, so the
// lookup is delegated to GetUpstreamEndpoint with an empty platform instead of
// dereferencing the nil pointer.
func resolveRawCCUpstreamEndpoint(c *gin.Context, account *service.Account) string {
	if account == nil {
		// Previously this path fell through to account.Platform and panicked.
		return GetUpstreamEndpoint(c, "")
	}
	if account.Type == service.AccountTypeAPIKey && !openai_compat.ShouldUseResponsesAPI(account.Extra) {
		return "/v1/chat/completions"
	}
	return GetUpstreamEndpoint(c, account.Platform)
}

View File

@ -33,20 +33,11 @@ type OpenAIGatewayHandler struct {
usageRecordWorkerPool *service.UsageRecordWorkerPool usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
imageLimiter *imageConcurrencyLimiter
maxAccountSwitches int maxAccountSwitches int
cfg *config.Config cfg *config.Config
} }
// resolveOpenAIForwardDefaultMappedModel picks the default mapped model to pass
// to the forwarder. A non-blank fallbackModel wins (returned trimmed);
// otherwise the trimmed DefaultMappedModel of the API key's group is used, and
// an empty string is returned when the key or its group is missing.
func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackModel string) string {
	if explicit := strings.TrimSpace(fallbackModel); explicit != "" {
		return explicit
	}
	if apiKey != nil && apiKey.Group != nil {
		return strings.TrimSpace(apiKey.Group.DefaultMappedModel)
	}
	return ""
}
func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string { func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string {
if apiKey == nil || apiKey.Group == nil { if apiKey == nil || apiKey.Group == nil {
return "" return ""
@ -79,6 +70,7 @@ func NewOpenAIGatewayHandler(
usageRecordWorkerPool: usageRecordWorkerPool, usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService, errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
imageLimiter: &imageConcurrencyLimiter{},
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
cfg: cfg, cfg: cfg,
} }
@ -197,6 +189,23 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
imageIntent := service.IsImageGenerationIntent("/v1/responses", reqModel, body)
if imageIntent && !service.GroupAllowsImageGeneration(apiKey.Group) {
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
return
}
var imageReleaseFunc func()
if imageIntent {
var imageAcquired bool
imageReleaseFunc, imageAcquired = h.acquireImageGenerationSlot(c, streamStarted)
if !imageAcquired {
return
}
if imageReleaseFunc != nil {
defer imageReleaseFunc()
}
}
// 解析渠道级模型映射 // 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
@ -328,57 +337,65 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
} }
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError if result != nil && result.ImageCount > 0 {
if errors.As(err, &failoverErr) { reqLog.Warn("openai.forward_partial_error_with_image_result",
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) zap.Int64("account_id", account.ID),
// 池模式:同账号重试 zap.Int("image_count", result.ImageCount),
if failoverErr.RetryableOnSameAccount { zap.Error(err),
retryLimit := account.GetPoolModeRetryCount() )
if sameAccountRetryCount[account.ID] < retryLimit { } else {
sameAccountRetryCount[account.ID]++ var failoverErr *service.UpstreamFailoverError
reqLog.Warn("openai.pool_mode_same_account_retry", if errors.As(err, &failoverErr) {
zap.Int64("account_id", account.ID), h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
zap.Int("upstream_status", failoverErr.StatusCode), // 池模式:同账号重试
zap.Int("retry_limit", retryLimit), if failoverErr.RetryableOnSameAccount {
zap.Int("retry_count", sameAccountRetryCount[account.ID]), retryLimit := account.GetPoolModeRetryCount()
) if sameAccountRetryCount[account.ID] < retryLimit {
select { sameAccountRetryCount[account.ID]++
case <-c.Request.Context().Done(): reqLog.Warn("openai.pool_mode_same_account_retry",
return zap.Int64("account_id", account.ID),
case <-time.After(sameAccountRetryDelay): zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
} }
continue
} }
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
switchCount++
reqLog.Warn("openai.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
} }
h.gatewayService.RecordOpenAIAccountSwitch() h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
failedAccountIDs[account.ID] = struct{}{} wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
lastFailoverErr = failoverErr fields := []zap.Field{
if switchCount >= maxAccountSwitches { zap.Int64("account_id", account.ID),
h.handleFailoverExhausted(c, failoverErr, streamStarted) zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
}
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
reqLog.Warn("openai.forward_failed", fields...)
return return
} }
switchCount++ reqLog.Error("openai.forward_failed", fields...)
reqLog.Warn("openai.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
fields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
}
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
reqLog.Warn("openai.forward_failed", fields...)
return return
} }
reqLog.Error("openai.forward_failed", fields...)
return
} }
if result != nil { if result != nil {
if account.Type == service.AccountTypeOAuth { if account.Type == service.AccountTypeOAuth {
@ -393,17 +410,19 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body) requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) { h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result, Result: result,
APIKey: apiKey, APIKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c), InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
@ -613,21 +632,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
sessionHash := h.gatewayService.GenerateSessionHash(c, body) sessionHash := h.gatewayService.GenerateSessionHash(c, body)
promptCacheKey := h.gatewayService.ExtractSessionID(c, body) promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
sessionHash, promptCacheKey = resolveOpenAIMessagesMetadataSession(sessionHash, promptCacheKey, reqModel, body)
// Anthropic 格式的请求在 metadata.user_id 中携带 session 标识,
// 而非 OpenAI 的 session_id/conversation_id headers。
// 从中派生 sessionHashsticky session和 promptCacheKeyupstream cache
if sessionHash == "" || promptCacheKey == "" {
if userID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()); userID != "" {
seed := reqModel + "-" + userID
if promptCacheKey == "" {
promptCacheKey = service.GenerateSessionUUID(seed)
}
if sessionHash == "" {
sessionHash = service.DeriveSessionHashFromSeed(seed)
}
}
}
maxAccountSwitches := h.maxAccountSwitches maxAccountSwitches := h.maxAccountSwitches
switchCount := 0 switchCount := 0
@ -711,52 +716,60 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
} }
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError if result != nil && result.ImageCount > 0 {
if errors.As(err, &failoverErr) { reqLog.Warn("openai_messages.forward_partial_error_with_image_result",
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// 池模式:同账号重试
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
}
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
return
}
switchCount++
reqLog.Warn("openai_messages.upstream_failover_switching",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode), zap.Int("image_count", result.ImageCount),
zap.Int("switch_count", switchCount), zap.Error(err),
zap.Int("max_switches", maxAccountSwitches),
) )
continue } else {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// 池模式:同账号重试
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
}
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
return
}
switchCount++
reqLog.Warn("openai_messages.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted)
reqLog.Warn("openai_messages.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return
} }
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted)
reqLog.Warn("openai_messages.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return
} }
if result != nil { if result != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
@ -767,16 +780,18 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body) requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitUsageRecordTask(func(ctx context.Context) { h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result, Result: result,
APIKey: apiKey, APIKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c), InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
@ -801,6 +816,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
} }
} }
// resolveOpenAIMessagesMetadataSession fills in a missing session hash from the
// request body's metadata.user_id field and returns it together with the
// promptCacheKey, which is always passed through unchanged.
//
// The Anthropic metadata.user_id is used only as an account-stickiness signal.
// The upstream GPT/Codex cache key is left to ForwardAsAnthropic, which derives
// it from cache_control or a digest of the full messages, so that a fixed
// metadata-derived key does not hold back cache rolling on later turns.
func resolveOpenAIMessagesMetadataSession(sessionHash, promptCacheKey, reqModel string, body []byte) (string, string) {
	if sessionHash == "" {
		userID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String())
		if userID != "" {
			// Seed combines the requested model and the user id.
			sessionHash = service.DeriveSessionHashFromSeed(reqModel + "-" + userID)
		}
	}
	return sessionHash, promptCacheKey
}
// anthropicErrorResponse writes an error in Anthropic Messages API format. // anthropicErrorResponse writes an error in Anthropic Messages API format.
func (h *OpenAIGatewayHandler) anthropicErrorResponse(c *gin.Context, status int, errType, message string) { func (h *OpenAIGatewayHandler) anthropicErrorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{ c.JSON(status, gin.H{
@ -1124,6 +1153,11 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
setOpsRequestContext(c, reqModel, true, firstMessage) setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
if service.IsImageGenerationIntent("/v1/responses", reqModel, firstMessage) && !service.GroupAllowsImageGeneration(apiKey.Group) {
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, service.ImageGenerationPermissionMessage())
return
}
// 解析渠道级模型映射 // 解析渠道级模型映射
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel) channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
@ -1233,6 +1267,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
) )
hooks := &service.OpenAIWSIngressHooks{ hooks := &service.OpenAIWSIngressHooks{
InitialRequestModel: reqModel,
BeforeTurn: func(turn int) error { BeforeTurn: func(turn int) error {
if turn == 1 { if turn == 1 {
return nil return nil
@ -1266,22 +1301,34 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
}, },
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) { AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
releaseTurnSlots() releaseTurnSlots()
if turnErr != nil || result == nil { if turnErr != nil {
if result == nil || result.ImageCount <= 0 {
return
}
reqLog.Warn("openai.websocket_partial_error_with_image_result",
zap.Int64("account_id", account.ID),
zap.Int("image_count", result.ImageCount),
zap.Error(turnErr),
)
}
if result == nil {
return return
} }
if account.Type == service.AccountTypeOAuth { if account.Type == service.AccountTypeOAuth {
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders) h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
} }
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
h.submitUsageRecordTask(func(taskCtx context.Context) { inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
Result: result, Result: result,
APIKey: apiKey, APIKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c), InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
@ -1449,6 +1496,60 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas
task(ctx) task(ctx)
} }
// submitOpenAIUsageRecordTask routes a usage-record task to the appropriate
// submission path: results that report at least one generated image go through
// submitMandatoryUsageRecordTask, everything else (including a nil result)
// goes through submitUsageRecordTask.
func (h *OpenAIGatewayHandler) submitOpenAIUsageRecordTask(result *service.OpenAIForwardResult, task service.UsageRecordTask) {
	producedImages := result != nil && result.ImageCount > 0
	if !producedImages {
		h.submitUsageRecordTask(task)
		return
	}
	h.submitMandatoryUsageRecordTask(task)
}
// submitMandatoryUsageRecordTask submits a usage-record task that must not be
// dropped. The task is first offered to the worker pool (when one is
// configured); if the pool reports it as dropped, or no pool exists, the task
// runs synchronously on the calling goroutine with a 10-second timeout. A
// panic raised by the synchronous run is recovered and logged. A nil task is
// ignored.
func (h *OpenAIGatewayHandler) submitMandatoryUsageRecordTask(task service.UsageRecordTask) {
	if task == nil {
		return
	}

	if pool := h.usageRecordWorkerPool; pool != nil {
		if pool.Submit(task) != service.UsageRecordSubmitModeDropped {
			return
		}
		// The pool refused the task; fall through and run it inline.
		logger.L().With(
			zap.String("component", "handler.openai_gateway.usage"),
		).Warn("openai.usage_record_task_mandatory_sync_fallback")
	}

	syncCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	// Registered after cancel, so it runs first and shields the caller from a
	// panicking task.
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		logger.L().With(
			zap.String("component", "handler.openai_gateway.usage"),
			zap.Any("panic", recovered),
		).Error("openai.usage_record_task_panic_recovered")
	}()

	task(syncCtx)
}
// acquireImageGenerationSlot asks the handler's image limiter for a slot using
// the gateway's image-concurrency configuration.
//
// It returns the release function and true when a slot was acquired. When the
// handler, its config, or its limiter is nil, it returns (nil, true) without
// consulting any limiter. When no slot could be acquired, it writes a 429
// rate_limit_error response (streaming-aware) and returns (nil, false).
func (h *OpenAIGatewayHandler) acquireImageGenerationSlot(c *gin.Context, streamStarted bool) (func(), bool) {
	if h == nil || h.cfg == nil || h.imageLimiter == nil {
		return nil, true
	}

	limits := h.cfg.Gateway.ImageConcurrency
	// Overflow mode decides whether to queue for a slot or fail immediately.
	shouldWait := strings.TrimSpace(limits.OverflowMode) == config.ImageConcurrencyOverflowModeWait
	waitTimeout := time.Duration(limits.WaitTimeoutSeconds) * time.Second

	releaseSlot, ok := h.imageLimiter.Acquire(
		c.Request.Context(),
		limits.Enabled,
		limits.MaxConcurrentRequests,
		shouldWait,
		waitTimeout,
		limits.MaxWaitingRequests,
	)
	if !ok {
		h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Image generation concurrency limit exceeded, please retry later", streamStarted)
		return nil, false
	}
	return releaseSlot, true
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response // handleConcurrencyError handles concurrency-related errors with proper 429 response
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",

View File

@ -10,6 +10,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@ -91,6 +92,24 @@ func TestOpenAIHandleStreamingAwareError_JSONEscaping(t *testing.T) {
} }
} }
// TestResolveOpenAIMessagesMetadataSession_DoesNotDerivePromptCacheKey checks
// that, with no incoming session hash or prompt cache key, metadata.user_id
// yields a session hash while the prompt cache key stays empty.
func TestResolveOpenAIMessagesMetadataSession_DoesNotDerivePromptCacheKey(t *testing.T) {
	const model = "claude-sonnet-4-5"
	payload := []byte(`{"model":"claude-sonnet-4-5","metadata":{"user_id":"claude-code-session"},"messages":[{"role":"user","content":"hello"}]}`)

	gotHash, gotCacheKey := resolveOpenAIMessagesMetadataSession("", "", model, payload)

	require.NotEmpty(t, gotHash)
	require.Empty(t, gotCacheKey)
}
// TestResolveOpenAIMessagesMetadataSession_PreservesExplicitPromptCacheKey
// checks that a caller-supplied prompt cache key is returned untouched while a
// session hash is still derived from metadata.user_id.
func TestResolveOpenAIMessagesMetadataSession_PreservesExplicitPromptCacheKey(t *testing.T) {
	const explicitKey = "explicit-cache"
	payload := []byte(`{"metadata":{"user_id":"claude-code-session"}}`)

	gotHash, gotCacheKey := resolveOpenAIMessagesMetadataSession("", explicitKey, "claude-sonnet-4-5", payload)

	require.NotEmpty(t, gotHash)
	require.Equal(t, "explicit-cache", gotCacheKey)
}
func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) { func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -352,30 +371,6 @@ func TestOpenAIEnsureResponsesDependencies(t *testing.T) {
}) })
} }
// TestResolveOpenAIForwardDefaultMappedModel covers the precedence rules of
// resolveOpenAIForwardDefaultMappedModel: an explicit fallback model beats the
// group default, the group default is used when no fallback is given, and an
// empty string is returned when no group default is available.
func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
	keyWithGroupDefault := func(model string) *service.APIKey {
		return &service.APIKey{
			Group: &service.Group{DefaultMappedModel: model},
		}
	}

	t.Run("prefers_explicit_fallback_model", func(t *testing.T) {
		resolved := resolveOpenAIForwardDefaultMappedModel(keyWithGroupDefault("gpt-5.4"), " gpt-5.2 ")
		require.Equal(t, "gpt-5.2", resolved)
	})

	t.Run("uses_group_default_when_explicit_fallback_absent", func(t *testing.T) {
		resolved := resolveOpenAIForwardDefaultMappedModel(keyWithGroupDefault("gpt-5.4"), "")
		require.Equal(t, "gpt-5.4", resolved)
	})

	t.Run("returns_empty_without_group_default", func(t *testing.T) {
		// nil key, key without a group, and key whose group has no default.
		candidates := []*service.APIKey{
			nil,
			{},
			{Group: &service.Group{}},
		}
		for _, candidate := range candidates {
			require.Empty(t, resolveOpenAIForwardDefaultMappedModel(candidate, ""))
		}
	})
}
func TestResolveOpenAIMessagesDispatchMappedModel(t *testing.T) { func TestResolveOpenAIMessagesDispatchMappedModel(t *testing.T) {
t.Run("exact_claude_model_override_wins", func(t *testing.T) { t.Run("exact_claude_model_override_wins", func(t *testing.T) {
apiKey := &service.APIKey{ apiKey := &service.APIKey{
@ -651,6 +646,46 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot") require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
} }
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`,
userAgent: testStringPtr("codex_cli_rs/0.125.0 test"),
})
require.NotNil(t, got.log.UserAgent)
require.Equal(t, "codex_cli_rs/0.125.0 test", *got.log.UserAgent)
require.NotNil(t, got.log.ReasoningEffort)
require.Equal(t, "high", *got.log.ReasoningEffort)
require.True(t, got.log.OpenAIWSMode)
}
func TestOpenAIResponsesWebSocket_PassthroughUsageLogInfersReasoningFromInitialRequestModel(t *testing.T) {
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
firstPayload: `{"type":"response.create","model":"gpt-5.4-xhigh","stream":false}`,
userAgent: testStringPtr("codex_cli_rs/0.125.0 mapped"),
channelMapping: map[string]string{
"gpt-5.4-xhigh": "gpt-5.4",
},
})
require.Equal(t, "gpt-5.4", gjson.GetBytes(got.upstreamFirstPayload, "model").String(),
"上游首帧应使用渠道映射后的模型")
require.NotNil(t, got.log.ReasoningEffort)
require.Equal(t, "xhigh", *got.log.ReasoningEffort,
"usage log reasoning effort 必须使用渠道映射前首帧模型后缀推导")
}
func TestOpenAIResponsesWebSocket_PassthroughUsageLogLeavesUserAgentNilWhenMissing(t *testing.T) {
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"medium"}}`,
userAgent: testStringPtr(""),
})
require.Nil(t, got.log.UserAgent, "空入站 User-Agent 不应由上游握手 UA 或默认 UA 兜底")
require.NotNil(t, got.log.ReasoningEffort)
require.Equal(t, "medium", *got.log.ReasoningEffort)
}
func TestSetOpenAIClientTransportHTTP(t *testing.T) { func TestSetOpenAIClientTransportHTTP(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
@ -796,3 +831,278 @@ func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject
router.GET("/openai/v1/responses", h.ResponsesWebSocket) router.GET("/openai/v1/responses", h.ResponsesWebSocket)
return httptest.NewServer(router) return httptest.NewServer(router)
} }
// openAIResponsesWSUsageLogCase describes one input scenario for
// runOpenAIResponsesWebSocketUsageLogCase.
type openAIResponsesWSUsageLogCase struct {
	// firstPayload is the text frame the test client writes to the handler.
	firstPayload string
	// userAgent, when non-nil, is set as the client's User-Agent header
	// (an empty string is still set explicitly).
	userAgent *string
	// channelMapping, when non-empty, is installed as the OpenAI model mapping
	// of a test channel bound to the test group.
	channelMapping map[string]string
}
// openAIResponsesWSUsageLogResult is what runOpenAIResponsesWebSocketUsageLogCase
// returns to the calling test.
type openAIResponsesWSUsageLogResult struct {
	// log is the usage log captured by the usage-log repository stub.
	log *service.UsageLog
	// upstreamFirstPayload is the first frame the fake upstream server received.
	upstreamFirstPayload []byte
}
// openAIWSUsageHandlerAccountRepoStub is an account repository test double
// that serves a single fixed account. It embeds service.AccountRepository
// (presumably an interface — confirm) so only the methods defined on the stub
// are overridden.
type openAIWSUsageHandlerAccountRepoStub struct {
	service.AccountRepository
	// account is the only account this stub ever returns.
	account service.Account
}
// ListSchedulableByPlatform returns the stub's single account when its
// platform equals the requested one, and a nil slice otherwise. It never
// returns an error.
func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
	if s.account.Platform == platform {
		return []service.Account{s.account}, nil
	}
	return nil, nil
}
// ListSchedulableByGroupIDAndPlatform ignores groupID and answers exactly as
// ListSchedulableByPlatform does for the given platform.
func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
	accounts, err := s.ListSchedulableByPlatform(ctx, platform)
	return accounts, err
}
// GetByID returns a pointer to a copy of the stub's account when id matches
// its ID. For any other id it returns (nil, nil).
func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) {
	if id == s.account.ID {
		// Copy so callers cannot mutate the stub's stored account through the pointer.
		found := s.account
		return &found, nil
	}
	return nil, nil
}
// openAIWSUsageHandlerUsageLogRepoStub is a usage-log repository test double
// that forwards created logs to a channel. It embeds service.UsageLogRepository
// (presumably an interface — confirm) so only Create is overridden.
type openAIWSUsageHandlerUsageLogRepoStub struct {
	service.UsageLogRepository
	// created receives every log passed to Create; nil disables forwarding.
	created chan *service.UsageLog
}
// Create sends log on the created channel when one is configured (the send
// blocks if the channel is full) and always returns (true, nil).
func (s *openAIWSUsageHandlerUsageLogRepoStub) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
	if s.created == nil {
		return true, nil
	}
	s.created <- log
	return true, nil
}
// openAIWSUsageHandlerChannelRepoStub is a channel repository test double
// backed by in-memory data. It embeds service.ChannelRepository (presumably an
// interface — confirm) so only the methods defined on the stub are overridden.
type openAIWSUsageHandlerChannelRepoStub struct {
	service.ChannelRepository
	// channels is returned verbatim by ListAll.
	channels []service.Channel
	// groupPlatforms maps a group ID to its platform name for GetGroupPlatforms.
	groupPlatforms map[int64]string
}
// ListAll returns the stub's configured channels unchanged and never fails.
func (s *openAIWSUsageHandlerChannelRepoStub) ListAll(ctx context.Context) ([]service.Channel, error) {
	all := s.channels
	return all, nil
}
// GetGroupPlatforms returns the platform for each requested group ID that has
// a non-blank entry in the stub's groupPlatforms map. Values are stored
// whitespace-trimmed; unknown or blank groups are omitted. It never fails.
func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
	platforms := make(map[int64]string, len(groupIDs))
	for i := range groupIDs {
		id := groupIDs[i]
		name := strings.TrimSpace(s.groupPlatforms[id])
		if name == "" {
			continue
		}
		platforms[id] = name
	}
	return platforms, nil
}
// runOpenAIResponsesWebSocketUsageLogCase drives one end-to-end WebSocket
// request through OpenAIGatewayHandler.ResponsesWebSocket against a fake
// upstream server and returns the captured usage log together with the first
// frame the upstream received.
//
// Flow: start a fake upstream WebSocket server, wire a gateway service and
// handler with stub repositories, dial the handler as a client (optionally
// with tc.userAgent), send tc.firstPayload, expect one "response.completed"
// event back, then wait (3s each) for the usage log, the upstream's first
// frame, and the upstream handler's final status. Any failure aborts the test.
func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult {
	t.Helper()
	gin.SetMode(gin.TestMode)
	// Both channels are buffered with capacity 1 so the upstream handler can
	// send its single payload/status without a receiver being ready.
	upstreamPayloadCh := make(chan []byte, 1)
	upstreamErrCh := make(chan error, 1)
	// Fake upstream: accept one WebSocket connection, capture the first frame,
	// reply with a single response.completed event, then close normally.
	upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
			CompressionMode: coderws.CompressionContextTakeover,
		})
		if err != nil {
			upstreamErrCh <- err
			return
		}
		defer func() {
			_ = conn.CloseNow()
		}()
		readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
		msgType, payload, readErr := conn.Read(readCtx)
		cancelRead()
		if readErr != nil {
			upstreamErrCh <- readErr
			return
		}
		if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
			upstreamErrCh <- errors.New("unexpected upstream websocket message type")
			return
		}
		// Hand the captured first frame to the test goroutine.
		upstreamPayloadCh <- payload
		writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
		writeErr := conn.Write(writeCtx, coderws.MessageText, []byte(
			`{"type":"response.completed","response":{"id":"resp_usage_e2e","model":"gpt-5.4","usage":{"input_tokens":2,"output_tokens":1}}}`,
		))
		cancelWrite()
		if writeErr != nil {
			upstreamErrCh <- writeErr
			return
		}
		_ = conn.Close(coderws.StatusNormalClosure, "done")
		// nil signals that the upstream side completed without error.
		upstreamErrCh <- nil
	}))
	defer upstreamServer.Close()
	groupID := int64(4201)
	// Single API-key account pointing at the fake upstream, with the
	// responses-websockets-v2 passthrough mode switched on via Extra.
	account := service.Account{
		ID:          9901,
		Name:        "openai-ws-passthrough-usage-e2e",
		Platform:    service.PlatformOpenAI,
		Type:        service.AccountTypeAPIKey,
		Status:      service.StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": upstreamServer.URL,
		},
		Extra: map[string]any{
			"openai_apikey_responses_websockets_v2_enabled": true,
			"openai_apikey_responses_websockets_v2_mode":    service.OpenAIWSIngressModePassthrough,
		},
	}
	cfg := &config.Config{}
	cfg.RunMode = config.RunModeSimple
	cfg.Default.RateMultiplier = 1
	// The fake upstream is a plain http:// httptest server; presumably these
	// two settings are what allow the gateway to dial it — confirm.
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
	accountRepo := &openAIWSUsageHandlerAccountRepoStub{account: account}
	usageRepo := &openAIWSUsageHandlerUsageLogRepoStub{created: make(chan *service.UsageLog, 1)}
	// A channel service is only built when the case supplies a model mapping;
	// otherwise channelSvc stays nil.
	var channelSvc *service.ChannelService
	if len(tc.channelMapping) > 0 {
		channelSvc = service.NewChannelService(&openAIWSUsageHandlerChannelRepoStub{
			channels: []service.Channel{{
				ID:           7701,
				Name:         "openai-ws-e2e-channel",
				Status:       service.StatusActive,
				GroupIDs:     []int64{groupID},
				ModelMapping: map[string]map[string]string{service.PlatformOpenAI: tc.channelMapping},
			}},
			groupPlatforms: map[int64]string{groupID: service.PlatformOpenAI},
		}, nil, nil, nil)
	}
	billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
	// Positional constructor: only the dependencies supplied here are non-nil;
	// every other argument is passed as nil.
	gatewaySvc := service.NewOpenAIGatewayService(
		accountRepo,
		usageRepo,
		nil,
		nil,
		nil,
		nil,
		nil,
		cfg,
		nil,
		nil,
		service.NewBillingService(cfg, nil),
		nil,
		billingCacheSvc,
		nil,
		&service.DeferredService{},
		nil,
		nil,
		channelSvc,
		nil,
		nil,
	)
	// Concurrency cache mock that always grants user and account slots.
	cache := &concurrencyCacheMock{
		acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
			return true, nil
		},
		acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
			return true, nil
		},
	}
	h := &OpenAIGatewayHandler{
		gatewayService:      gatewaySvc,
		billingCacheService: billingCacheSvc,
		apiKeyService:       &service.APIKeyService{},
		concurrencyHelper:   NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
	}
	apiKey := &service.APIKey{
		ID:      1801,
		GroupID: &groupID,
		User:    &service.User{ID: 1701, Status: service.StatusActive},
	}
	router := gin.New()
	// Inject the API key and auth subject into the gin context directly, in
	// place of real authentication middleware.
	router.Use(func(c *gin.Context) {
		c.Set(string(middleware.ContextKeyAPIKey), apiKey)
		c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1})
		c.Next()
	})
	router.GET("/openai/v1/responses", h.ResponsesWebSocket)
	handlerServer := httptest.NewServer(router)
	defer handlerServer.Close()
	headers := http.Header{}
	// A nil userAgent leaves the header untouched; a non-nil one (even "")
	// is set explicitly.
	if tc.userAgent != nil {
		headers.Set("User-Agent", *tc.userAgent)
	}
	dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
	// Rewrite the server's http:// URL to ws:// by swapping the scheme prefix.
	clientConn, _, err := coderws.Dial(
		dialCtx,
		"ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses",
		&coderws.DialOptions{HTTPHeader: headers, CompressionMode: coderws.CompressionContextTakeover},
	)
	cancelDial()
	require.NoError(t, err)
	defer func() {
		_ = clientConn.CloseNow()
	}()
	writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
	err = clientConn.Write(writeCtx, coderws.MessageText, []byte(tc.firstPayload))
	cancelWrite()
	require.NoError(t, err)
	readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
	_, event, err := clientConn.Read(readCtx)
	cancelRead()
	require.NoError(t, err)
	require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
	_ = clientConn.Close(coderws.StatusNormalClosure, "done")
	// Wait for the usage log written through the repository stub.
	var usageLog *service.UsageLog
	select {
	case usageLog = <-usageRepo.created:
		require.NotNil(t, usageLog)
	case <-time.After(3 * time.Second):
		t.Fatal("等待 WebSocket usage log 写入超时")
	}
	// Collect the first frame the fake upstream received.
	var upstreamFirstPayload []byte
	select {
	case upstreamFirstPayload = <-upstreamPayloadCh:
	case <-time.After(3 * time.Second):
		t.Fatal("等待上游 WebSocket 首帧超时")
	}
	// Ensure the fake upstream handler finished without error.
	select {
	case upstreamErr := <-upstreamErrCh:
		require.NoError(t, upstreamErr)
	case <-time.After(3 * time.Second):
		t.Fatal("等待上游 WebSocket 结束超时")
	}
	return openAIResponsesWSUsageLogResult{
		log:                  usageLog,
		upstreamFirstPayload: upstreamFirstPayload,
	}
}
// testStringPtr returns a pointer to a fresh string holding v, for test
// fixtures that need optional (*string) fields.
func testStringPtr(v string) *string {
	ptr := new(string)
	*ptr = v
	return ptr
}

View File

@ -81,6 +81,18 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
zap.String("capability", string(parsed.RequiredCapability)), zap.String("capability", string(parsed.RequiredCapability)),
) )
if !service.GroupAllowsImageGeneration(apiKey.Group) {
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
return
}
imageReleaseFunc, acquired := h.acquireImageGenerationSlot(c, streamStarted)
if !acquired {
return
}
if imageReleaseFunc != nil {
defer imageReleaseFunc()
}
if parsed.Multipart { if parsed.Multipart {
setOpsRequestContext(c, parsed.Model, parsed.Stream, nil) setOpsRequestContext(c, parsed.Model, parsed.Stream, nil)
} else { } else {
@ -188,62 +200,69 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
responseLatencyMs = forwardDurationMs - upstreamLatencyMs responseLatencyMs = forwardDurationMs - upstreamLatencyMs
} }
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs) service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
if err == nil && result != nil && result.FirstTokenMs != nil { if result != nil && result.FirstTokenMs != nil {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
} }
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError if result != nil && result.ImageCount > 0 {
if errors.As(err, &failoverErr) { reqLog.Warn("openai.images.forward_partial_error_with_image_result",
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) zap.Int64("account_id", account.ID),
if failoverErr.RetryableOnSameAccount { zap.Int("image_count", result.ImageCount),
retryLimit := account.GetPoolModeRetryCount() zap.Error(err),
if sameAccountRetryCount[account.ID] < retryLimit { )
sameAccountRetryCount[account.ID]++ } else {
reqLog.Warn("openai.images.pool_mode_same_account_retry", var failoverErr *service.UpstreamFailoverError
zap.Int64("account_id", account.ID), if errors.As(err, &failoverErr) {
zap.Int("upstream_status", failoverErr.StatusCode), h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
zap.Int("retry_limit", retryLimit), if failoverErr.RetryableOnSameAccount {
zap.Int("retry_count", sameAccountRetryCount[account.ID]), retryLimit := account.GetPoolModeRetryCount()
) if sameAccountRetryCount[account.ID] < retryLimit {
select { sameAccountRetryCount[account.ID]++
case <-c.Request.Context().Done(): reqLog.Warn("openai.images.pool_mode_same_account_retry",
return zap.Int64("account_id", account.ID),
case <-time.After(sameAccountRetryDelay): zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
} }
continue
} }
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
switchCount++
reqLog.Warn("openai.images.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
} }
h.gatewayService.RecordOpenAIAccountSwitch() h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
failedAccountIDs[account.ID] = struct{}{} wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
lastFailoverErr = failoverErr fields := []zap.Field{
if switchCount >= maxAccountSwitches { zap.Int64("account_id", account.ID),
h.handleFailoverExhausted(c, failoverErr, streamStarted) zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
}
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
reqLog.Warn("openai.images.forward_failed", fields...)
return return
} }
switchCount++ reqLog.Error("openai.images.forward_failed", fields...)
reqLog.Warn("openai.images.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
fields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
}
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
reqLog.Warn("openai.images.forward_failed", fields...)
return return
} }
reqLog.Error("openai.images.forward_failed", fields...)
return
} }
if result != nil { if result != nil {
if account.Type == service.AccountTypeOAuth { if account.Type == service.AccountTypeOAuth {
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders) h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders)
@ -259,21 +278,27 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
if parsed.Multipart { if parsed.Multipart {
requestPayloadHash = service.HashUsageRequestPayload([]byte(parsed.StickySessionSeed())) requestPayloadHash = service.HashUsageRequestPayload([]byte(parsed.StickySessionSeed()))
} }
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitUsageRecordTask(func(ctx context.Context) { upstreamModel := ""
if result != nil {
upstreamModel = result.UpstreamModel
}
h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result, Result: result,
APIKey: apiKey, APIKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c), InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, result.UpstreamModel), ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, upstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.openai_gateway.images"), zap.String("component", "handler.openai_gateway.images"),

View File

@ -0,0 +1,49 @@
package handler
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestOpenAIGatewayHandlerImages_DisabledGroupRejectsBeforeScheduling calls
// the Images handler with an API key whose group has AllowImageGeneration set
// to false and asserts the response is HTTP 403 with error type
// "permission_error" and the image-generation permission message. The handler
// is built from zero-value services only.
func TestOpenAIGatewayHandlerImages_DisabledGroupRejectsBeforeScheduling(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Request: a JSON image-generation call.
	payload := []byte(`{"model":"gpt-image-2","prompt":"draw","size":"1024x1024"}`)
	request := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(payload))
	request.Header.Set("Content-Type", "application/json")

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = request

	// Caller identity: API key bound to a group that disallows image generation.
	groupID := int64(111)
	restrictedKey := &service.APIKey{
		ID:      222,
		GroupID: &groupID,
		Group: &service.Group{
			ID:                   groupID,
			AllowImageGeneration: false,
		},
		User: &service.User{ID: 333},
	}
	ginCtx.Set(string(middleware2.ContextKeyAPIKey), restrictedKey)
	ginCtx.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 333, Concurrency: 1})

	handler := &OpenAIGatewayHandler{
		gatewayService:      &service.OpenAIGatewayService{},
		billingCacheService: &service.BillingCacheService{},
		apiKeyService:       &service.APIKeyService{},
		concurrencyHelper:   &ConcurrencyHelper{concurrencyService: &service.ConcurrencyService{}},
	}
	handler.Images(ginCtx)

	responseBody := recorder.Body
	require.Equal(t, http.StatusForbidden, recorder.Code)
	require.Equal(t, "permission_error", gjson.GetBytes(responseBody.Bytes(), "error.type").String())
	require.Contains(t, responseBody.String(), service.ImageGenerationPermissionMessage())
}

View File

@ -129,3 +129,63 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
}) })
require.True(t, called.Load(), "panic 后后续任务应仍可执行") require.True(t, called.Load(), "panic 后后续任务应仍可执行")
} }
// TestOpenAIGatewayHandlerSubmitMandatoryUsageRecordTask_DroppedTaskSyncFallback
// saturates a one-worker, one-slot pool configured with the "drop" overflow
// policy and checks that submitMandatoryUsageRecordTask still executes its
// task before returning.
//
// The assertion is made while the only worker is still blocked, so the task
// can only have run on the calling goroutine. Asserting after the worker is
// released would also let a racy asynchronous execution pass.
func TestOpenAIGatewayHandlerSubmitMandatoryUsageRecordTask_DroppedTaskSyncFallback(t *testing.T) {
	pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
		WorkerCount:           1,
		QueueSize:             1,
		TaskTimeout:           time.Second,
		OverflowPolicy:        "drop",
		OverflowSamplePercent: 0,
		AutoScaleEnabled:      false,
	})
	t.Cleanup(pool.Stop)
	h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}

	block := make(chan struct{})
	release := make(chan struct{})
	// Release the blocked worker on every exit path (including a failed
	// assertion or a panic in the code under test). Deferred calls run before
	// the t.Cleanup callbacks, so the worker is unblocked before pool.Stop.
	defer close(release)

	// Occupy the single worker until release is closed.
	pool.Submit(func(ctx context.Context) {
		close(block)
		<-release
	})
	<-block
	// Fill the single queue slot so the next async submit overflows.
	pool.Submit(func(ctx context.Context) {})

	var called atomic.Bool
	h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
		called.Store(true)
	})
	require.True(t, called.Load(), "mandatory usage task must run synchronously when async submit is dropped")
}
// TestOpenAIGatewayHandlerSubmitOpenAIUsageRecordTask_ImageResultUsesMandatoryFallback
// saturates a one-worker, one-slot pool configured with the "drop" overflow
// policy and checks that submitOpenAIUsageRecordTask, given a forward result
// with ImageCount 1, still executes its task before returning.
//
// The assertion is made while the only worker is still blocked, so the task
// can only have run on the calling goroutine. Asserting after the worker is
// released would also let a racy asynchronous execution pass.
func TestOpenAIGatewayHandlerSubmitOpenAIUsageRecordTask_ImageResultUsesMandatoryFallback(t *testing.T) {
	pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
		WorkerCount:           1,
		QueueSize:             1,
		TaskTimeout:           time.Second,
		OverflowPolicy:        "drop",
		OverflowSamplePercent: 0,
		AutoScaleEnabled:      false,
	})
	t.Cleanup(pool.Stop)
	h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}

	block := make(chan struct{})
	release := make(chan struct{})
	// Release the blocked worker on every exit path (including a failed
	// assertion or a panic in the code under test). Deferred calls run before
	// the t.Cleanup callbacks, so the worker is unblocked before pool.Stop.
	defer close(release)

	// Occupy the single worker until release is closed.
	pool.Submit(func(ctx context.Context) {
		close(block)
		<-release
	})
	<-block
	// Fill the single queue slot so the next async submit overflows.
	pool.Submit(func(ctx context.Context) {})

	var called atomic.Bool
	h.submitOpenAIUsageRecordTask(&service.OpenAIForwardResult{ImageCount: 1}, func(ctx context.Context) {
		called.Store(true)
	})
	require.True(t, called.Load(), "image usage task must be mandatory when async submit is dropped")
}

View File

@ -32,7 +32,13 @@ func TestAnthropicToResponses_BasicText(t *testing.T) {
var items []ResponsesInputItem var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items)) require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 1) require.Len(t, items, 1)
assert.Equal(t, "message", items[0].Type)
assert.Equal(t, "user", items[0].Role) assert.Equal(t, "user", items[0].Role)
var parts []ResponsesContentPart
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
require.Len(t, parts, 1)
assert.Equal(t, "input_text", parts[0].Type)
assert.Equal(t, "Hello", parts[0].Text)
} }
func TestAnthropicToResponses_SystemPrompt(t *testing.T) { func TestAnthropicToResponses_SystemPrompt(t *testing.T) {
@ -49,7 +55,12 @@ func TestAnthropicToResponses_SystemPrompt(t *testing.T) {
var items []ResponsesInputItem var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items)) require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 2) require.Len(t, items, 2)
assert.Equal(t, "system", items[0].Role) assert.Equal(t, "developer", items[0].Role)
var parts []ResponsesContentPart
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
require.Len(t, parts, 1)
assert.Equal(t, "input_text", parts[0].Type)
assert.Equal(t, "You are helpful.", parts[0].Text)
}) })
t.Run("array", func(t *testing.T) { t.Run("array", func(t *testing.T) {
@ -65,11 +76,33 @@ func TestAnthropicToResponses_SystemPrompt(t *testing.T) {
var items []ResponsesInputItem var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items)) require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 2) require.Len(t, items, 2)
assert.Equal(t, "system", items[0].Role) assert.Equal(t, "developer", items[0].Role)
// System text should be joined with double newline. var parts []ResponsesContentPart
var text string require.NoError(t, json.Unmarshal(items[0].Content, &parts))
require.NoError(t, json.Unmarshal(items[0].Content, &text)) require.Len(t, parts, 2)
assert.Equal(t, "Part 1\n\nPart 2", text) assert.Equal(t, "input_text", parts[0].Type)
assert.Equal(t, "Part 1", parts[0].Text)
assert.Equal(t, "input_text", parts[1].Type)
assert.Equal(t, "Part 2", parts[1].Text)
})
t.Run("billing header skipped", func(t *testing.T) {
req := &AnthropicRequest{
Model: "gpt-5.2",
MaxTokens: 100,
System: json.RawMessage(`[{"type":"text","text":"x-anthropic-billing-header: cc_version=1;"},{"type":"text","text":"Project prompt"}]`),
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
}
resp, err := AnthropicToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 2)
var parts []ResponsesContentPart
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
require.Len(t, parts, 1)
assert.Equal(t, "Project prompt", parts[0].Text)
}) })
} }
@ -94,6 +127,8 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) {
require.Len(t, resp.Tools, 1) require.Len(t, resp.Tools, 1)
assert.Equal(t, "function", resp.Tools[0].Type) assert.Equal(t, "function", resp.Tools[0].Type)
assert.Equal(t, "get_weather", resp.Tools[0].Name) assert.Equal(t, "get_weather", resp.Tools[0].Name)
require.NotNil(t, resp.Tools[0].Strict)
assert.False(t, *resp.Tools[0].Strict)
// Check input items // Check input items
var items []ResponsesInputItem var items []ResponsesInputItem
@ -104,10 +139,10 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) {
assert.Equal(t, "user", items[0].Role) assert.Equal(t, "user", items[0].Role)
assert.Equal(t, "assistant", items[1].Role) assert.Equal(t, "assistant", items[1].Role)
assert.Equal(t, "function_call", items[2].Type) assert.Equal(t, "function_call", items[2].Type)
assert.Equal(t, "fc_call_1", items[2].CallID) assert.Equal(t, "call_1", items[2].CallID)
assert.Empty(t, items[2].ID) assert.Empty(t, items[2].ID)
assert.Equal(t, "function_call_output", items[3].Type) assert.Equal(t, "function_call_output", items[3].Type)
assert.Equal(t, "fc_call_1", items[3].CallID) assert.Equal(t, "call_1", items[3].CallID)
assert.Equal(t, "Sunny, 72°F", items[3].Output) assert.Equal(t, "Sunny, 72°F", items[3].Output)
} }
@ -261,6 +296,34 @@ func TestResponsesToAnthropic_ToolUse(t *testing.T) {
assert.JSONEq(t, `{"city":"NYC"}`, string(anth.Content[1].Input)) assert.JSONEq(t, `{"city":"NYC"}`, string(anth.Content[1].Input))
} }
func TestResponsesToAnthropic_ToolUseStopReasonDoesNotDependOnLastBlock(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_tool_then_text",
Model: "gpt-5.5",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "function_call",
CallID: "call_todo",
Name: "TodoWrite",
Arguments: `{"todos":[{"content":"review changes","status":"in_progress"}]}`,
},
{
Type: "message",
Content: []ResponsesContentPart{
{Type: "output_text", Text: "Task list updated."},
},
},
},
}
anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
assert.Equal(t, "tool_use", anth.StopReason)
require.Len(t, anth.Content, 2)
assert.Equal(t, "tool_use", anth.Content[0].Type)
assert.Equal(t, "text", anth.Content[1].Type)
}
func TestResponsesToAnthropic_ReadToolDropsEmptyPages(t *testing.T) { func TestResponsesToAnthropic_ReadToolDropsEmptyPages(t *testing.T) {
resp := &ResponsesResponse{ resp := &ResponsesResponse{
ID: "resp_read", ID: "resp_read",
@ -434,6 +497,45 @@ func TestStreamingTextOnly(t *testing.T) {
assert.Equal(t, "message_stop", events[1].Type) assert.Equal(t, "message_stop", events[1].Type)
} }
// TestResponsesEventToAnthropicEvents_ResponseDone feeds a "response.done"
// event with status "completed" and checks it yields a message_delta
// (stop reason "end_turn", usage 12/4) followed by message_stop, after which
// finalizing the stream produces nothing.
func TestResponsesEventToAnthropicEvents_ResponseDone(t *testing.T) {
	streamState := NewResponsesEventToAnthropicState()
	streamState.Model = "gpt-4o"

	doneEvent := &ResponsesStreamEvent{
		Type: "response.done",
		Response: &ResponsesResponse{
			Status: "completed",
			Usage:  &ResponsesUsage{InputTokens: 12, OutputTokens: 4},
		},
	}
	converted := ResponsesEventToAnthropicEvents(doneEvent, streamState)

	require.Len(t, converted, 2)
	assert.Equal(t, "message_delta", converted[0].Type)
	assert.Equal(t, "end_turn", converted[0].Delta.StopReason)
	assert.Equal(t, 12, converted[0].Usage.InputTokens)
	assert.Equal(t, 4, converted[0].Usage.OutputTokens)
	assert.Equal(t, "message_stop", converted[1].Type)

	// The stream is already terminated, so finalization must emit no events.
	assert.Nil(t, FinalizeResponsesAnthropicStream(streamState))
}
// TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete feeds a
// "response.done" event with status "incomplete" and reason
// "max_output_tokens", and checks the emitted message_delta carries stop
// reason "max_tokens" and is followed by message_stop, with nothing left for
// finalization to emit.
func TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete(t *testing.T) {
	streamState := NewResponsesEventToAnthropicState()
	streamState.Model = "gpt-4o"

	truncated := &ResponsesResponse{
		Status:            "incomplete",
		IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
		Usage:             &ResponsesUsage{InputTokens: 12, OutputTokens: 4},
	}
	converted := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
		Type:     "response.done",
		Response: truncated,
	}, streamState)

	require.Len(t, converted, 2)
	assert.Equal(t, "message_delta", converted[0].Type)
	assert.Equal(t, "max_tokens", converted[0].Delta.StopReason)
	assert.Equal(t, "message_stop", converted[1].Type)

	assert.Nil(t, FinalizeResponsesAnthropicStream(streamState))
}
func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) { func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
state := NewResponsesEventToAnthropicState() state := NewResponsesEventToAnthropicState()
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
@ -514,6 +616,81 @@ func TestStreamingToolCall(t *testing.T) {
assert.Equal(t, "tool_use", events[0].Delta.StopReason) assert.Equal(t, "tool_use", events[0].Delta.StopReason)
} }
// TestStreamingToolCallStopReasonSurvivesLaterText streams a function call
// (output index 0) followed by a text delta (output index 1) and then a
// completed response, asserting the final message_delta still reports
// "tool_use" as the stop reason.
func TestStreamingToolCallStopReasonSurvivesLaterText(t *testing.T) {
	streamState := NewResponsesEventToAnthropicState()

	// Step 1: open the response; the emitted events are not inspected.
	ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
		Type:     "response.created",
		Response: &ResponsesResponse{ID: "resp_tool_then_text", Model: "gpt-5.5"},
	}, streamState)

	// Step 2: a function_call item opens a content block.
	toolStart := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
		Type:        "response.output_item.added",
		OutputIndex: 0,
		Item:        &ResponsesOutput{Type: "function_call", CallID: "call_todo", Name: "TodoWrite"},
	}, streamState)
	require.Len(t, toolStart, 1)
	assert.Equal(t, "content_block_start", toolStart[0].Type)

	// Step 3: final arguments yield a delta and close the tool block.
	toolArgs := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
		Type:        "response.function_call_arguments.done",
		OutputIndex: 0,
		Arguments:   `{"todos":[{"content":"review changes","status":"in_progress","activeForm":"reviewing changes"}]}`,
	}, streamState)
	require.Len(t, toolArgs, 2)
	assert.Equal(t, "content_block_delta", toolArgs[0].Type)
	assert.Equal(t, "content_block_stop", toolArgs[1].Type)

	// Step 4: text arriving after the tool call opens a new block.
	laterText := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
		Type:        "response.output_text.delta",
		OutputIndex: 1,
		Delta:       "I will continue after the task list updates.",
	}, streamState)
	require.Len(t, laterText, 2)
	assert.Equal(t, "content_block_start", laterText[0].Type)
	assert.Equal(t, "content_block_delta", laterText[1].Type)

	// Step 5: completion closes the text block and reports the stop reason.
	completion := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
		Type: "response.completed",
		Response: &ResponsesResponse{
			Status: "completed",
			Usage:  &ResponsesUsage{InputTokens: 20, OutputTokens: 10},
		},
	}, streamState)
	require.Len(t, completion, 3)
	assert.Equal(t, "content_block_stop", completion[0].Type)
	assert.Equal(t, "tool_use", completion[1].Delta.StopReason)
	assert.Equal(t, "message_stop", completion[2].Type)
}
// TestStreamingToolCallDoneWithoutDeltaEmitsArguments checks that when a
// function call's arguments arrive only in the
// "response.function_call_arguments.done" event (no preceding argument
// deltas), the converter still emits them as an input_json_delta before
// closing the block.
func TestStreamingToolCallDoneWithoutDeltaEmitsArguments(t *testing.T) {
	const bashArgs = `{"command":"git -C \"/mnt/d/nodejs/other/edmt\" status --short --ignored"}`

	streamState := NewResponsesEventToAnthropicState()
	ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
		Type:     "response.created",
		Response: &ResponsesResponse{ID: "resp_bash", Model: "gpt-5.5"},
	}, streamState)

	toolStart := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
		Type:        "response.output_item.added",
		OutputIndex: 0,
		Item:        &ResponsesOutput{Type: "function_call", CallID: "call_bash", Name: "Bash"},
	}, streamState)
	require.Len(t, toolStart, 1)
	assert.Equal(t, "content_block_start", toolStart[0].Type)

	toolDone := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
		Type:        "response.function_call_arguments.done",
		OutputIndex: 0,
		Arguments:   bashArgs,
	}, streamState)
	require.Len(t, toolDone, 2)
	assert.Equal(t, "content_block_delta", toolDone[0].Type)
	assert.Equal(t, "input_json_delta", toolDone[0].Delta.Type)
	assert.JSONEq(t, bashArgs, toolDone[0].Delta.PartialJSON)
	assert.Equal(t, "content_block_stop", toolDone[1].Type)
}
func TestStreamingReadToolDropsEmptyPages(t *testing.T) { func TestStreamingReadToolDropsEmptyPages(t *testing.T) {
state := NewResponsesEventToAnthropicState() state := NewResponsesEventToAnthropicState()
@ -653,6 +830,27 @@ func TestFinalizeStream_AbnormalTermination(t *testing.T) {
assert.Equal(t, "message_stop", events[2].Type) assert.Equal(t, "message_stop", events[2].Type)
} }
// TestFinalizeStream_ToolCallAbnormalTerminationKeepsToolUseStopReason opens a
// function-call block and then finalizes the stream without any completion
// event, asserting finalization closes the block and reports "tool_use" as
// the stop reason before message_stop.
func TestFinalizeStream_ToolCallAbnormalTerminationKeepsToolUseStopReason(t *testing.T) {
	streamState := NewResponsesEventToAnthropicState()

	// Events up to the interruption; their converted output is not inspected.
	interrupted := []*ResponsesStreamEvent{
		{
			Type:     "response.created",
			Response: &ResponsesResponse{ID: "resp_tool_interrupted", Model: "gpt-5.5"},
		},
		{
			Type:        "response.output_item.added",
			OutputIndex: 0,
			Item:        &ResponsesOutput{Type: "function_call", CallID: "call_todo", Name: "TodoWrite"},
		},
	}
	for _, ev := range interrupted {
		ResponsesEventToAnthropicEvents(ev, streamState)
	}

	finalEvents := FinalizeResponsesAnthropicStream(streamState)
	require.Len(t, finalEvents, 3)
	assert.Equal(t, "content_block_stop", finalEvents[0].Type)
	assert.Equal(t, "message_delta", finalEvents[1].Type)
	assert.Equal(t, "tool_use", finalEvents[1].Delta.StopReason)
	assert.Equal(t, "message_stop", finalEvents[2].Type)
}
func TestStreamingEmptyResponse(t *testing.T) { func TestStreamingEmptyResponse(t *testing.T) {
state := NewResponsesEventToAnthropicState() state := NewResponsesEventToAnthropicState()
@ -788,8 +986,8 @@ func TestAnthropicToResponses_ThinkingEnabled(t *testing.T) {
resp, err := AnthropicToResponses(req) resp, err := AnthropicToResponses(req)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, resp.Reasoning) require.NotNil(t, resp.Reasoning)
// thinking.type is ignored for effort; default high applies. // thinking.type is ignored for effort; Codex bridge default medium applies.
assert.Equal(t, "high", resp.Reasoning.Effort) assert.Equal(t, "medium", resp.Reasoning.Effort)
assert.Equal(t, "auto", resp.Reasoning.Summary) assert.Equal(t, "auto", resp.Reasoning.Summary)
assert.Contains(t, resp.Include, "reasoning.encrypted_content") assert.Contains(t, resp.Include, "reasoning.encrypted_content")
assert.NotContains(t, resp.Include, "reasoning.summary") assert.NotContains(t, resp.Include, "reasoning.summary")
@ -806,8 +1004,8 @@ func TestAnthropicToResponses_ThinkingAdaptive(t *testing.T) {
resp, err := AnthropicToResponses(req) resp, err := AnthropicToResponses(req)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, resp.Reasoning) require.NotNil(t, resp.Reasoning)
// thinking.type is ignored for effort; default high applies. // thinking.type is ignored for effort; Codex bridge default medium applies.
assert.Equal(t, "high", resp.Reasoning.Effort) assert.Equal(t, "medium", resp.Reasoning.Effort)
assert.Equal(t, "auto", resp.Reasoning.Summary) assert.Equal(t, "auto", resp.Reasoning.Summary)
assert.NotContains(t, resp.Include, "reasoning.summary") assert.NotContains(t, resp.Include, "reasoning.summary")
} }
@ -822,9 +1020,9 @@ func TestAnthropicToResponses_ThinkingDisabled(t *testing.T) {
resp, err := AnthropicToResponses(req) resp, err := AnthropicToResponses(req)
require.NoError(t, err) require.NoError(t, err)
// Default effort applies (high → high) even when thinking is disabled. // Default effort applies (medium) even when thinking is disabled.
require.NotNil(t, resp.Reasoning) require.NotNil(t, resp.Reasoning)
assert.Equal(t, "high", resp.Reasoning.Effort) assert.Equal(t, "medium", resp.Reasoning.Effort)
} }
func TestAnthropicToResponses_NoThinking(t *testing.T) { func TestAnthropicToResponses_NoThinking(t *testing.T) {
@ -836,9 +1034,9 @@ func TestAnthropicToResponses_NoThinking(t *testing.T) {
resp, err := AnthropicToResponses(req) resp, err := AnthropicToResponses(req)
require.NoError(t, err) require.NoError(t, err)
// Default effort applies (high → high) when no thinking/output_config is set. // Default effort applies (medium) when no thinking/output_config is set.
require.NotNil(t, resp.Reasoning) require.NotNil(t, resp.Reasoning)
assert.Equal(t, "high", resp.Reasoning.Effort) assert.Equal(t, "medium", resp.Reasoning.Effort)
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@ -846,7 +1044,7 @@ func TestAnthropicToResponses_NoThinking(t *testing.T) {
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) { func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) {
// Default is high, but output_config.effort="low" overrides. low→low after mapping. // Default is medium, but output_config.effort="low" overrides. low→low after mapping.
req := &AnthropicRequest{ req := &AnthropicRequest{
Model: "gpt-5.2", Model: "gpt-5.2",
MaxTokens: 1024, MaxTokens: 1024,
@ -880,7 +1078,7 @@ func TestAnthropicToResponses_OutputConfigWithoutThinking(t *testing.T) {
} }
func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) { func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) {
// output_config.effort="high" → mapped to "high" (1:1, both sides' default). // output_config.effort="high" → mapped to "high" (1:1).
req := &AnthropicRequest{ req := &AnthropicRequest{
Model: "gpt-5.2", Model: "gpt-5.2",
MaxTokens: 1024, MaxTokens: 1024,
@ -912,7 +1110,7 @@ func TestAnthropicToResponses_OutputConfigMax(t *testing.T) {
} }
func TestAnthropicToResponses_NoOutputConfig(t *testing.T) { func TestAnthropicToResponses_NoOutputConfig(t *testing.T) {
// No output_config → default high regardless of thinking.type. // No output_config → default medium regardless of thinking.type.
req := &AnthropicRequest{ req := &AnthropicRequest{
Model: "gpt-5.2", Model: "gpt-5.2",
MaxTokens: 1024, MaxTokens: 1024,
@ -923,11 +1121,11 @@ func TestAnthropicToResponses_NoOutputConfig(t *testing.T) {
resp, err := AnthropicToResponses(req) resp, err := AnthropicToResponses(req)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, resp.Reasoning) require.NotNil(t, resp.Reasoning)
assert.Equal(t, "high", resp.Reasoning.Effort) assert.Equal(t, "medium", resp.Reasoning.Effort)
} }
func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) { func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) {
// output_config present but effort empty (e.g. only format set) → default high. // output_config present but effort empty (e.g. only format set) → default medium.
req := &AnthropicRequest{ req := &AnthropicRequest{
Model: "gpt-5.2", Model: "gpt-5.2",
MaxTokens: 1024, MaxTokens: 1024,
@ -938,7 +1136,7 @@ func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) {
resp, err := AnthropicToResponses(req) resp, err := AnthropicToResponses(req)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, resp.Reasoning) require.NotNil(t, resp.Reasoning)
assert.Equal(t, "high", resp.Reasoning.Effort) assert.Equal(t, "medium", resp.Reasoning.Effort)
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@ -1110,7 +1308,7 @@ func TestAnthropicToResponses_ToolResultWithImage(t *testing.T) {
// function_call_output should have text-only output (no image). // function_call_output should have text-only output (no image).
assert.Equal(t, "function_call_output", items[2].Type) assert.Equal(t, "function_call_output", items[2].Type)
assert.Equal(t, "fc_toolu_1", items[2].CallID) assert.Equal(t, "toolu_1", items[2].CallID)
assert.Equal(t, "(empty)", items[2].Output) assert.Equal(t, "(empty)", items[2].Output)
// Image should be in a separate user message. // Image should be in a separate user message.

View File

@ -32,6 +32,9 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
storeFalse := false storeFalse := false
out.Store = &storeFalse out.Store = &storeFalse
parallelToolCalls := true
out.ParallelToolCalls = &parallelToolCalls
out.Text = &ResponsesText{Verbosity: "medium"}
if req.MaxTokens > 0 { if req.MaxTokens > 0 {
v := req.MaxTokens v := req.MaxTokens
@ -46,10 +49,10 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
} }
// Determine reasoning effort: only output_config.effort controls the // Determine reasoning effort: only output_config.effort controls the
// level; thinking.type is ignored. Default is high when unset (both // level; thinking.type is ignored. Default follows Codex CLI / airgate's
// Anthropic and OpenAI default to high). // Anthropic bridge shape, which uses medium when unset.
// Anthropic levels map 1:1 to OpenAI: low→low, medium→medium, high→high, max→xhigh. // Anthropic levels map 1:1 to OpenAI: low→low, medium→medium, high→high, max→xhigh.
effort := "high" // default → both sides' default effort := "medium"
if req.OutputConfig != nil && req.OutputConfig.Effort != "" { if req.OutputConfig != nil && req.OutputConfig.Effort != "" {
effort = req.OutputConfig.Effort effort = req.OutputConfig.Effort
} }
@ -108,16 +111,19 @@ func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage
func convertAnthropicToResponsesInput(system json.RawMessage, msgs []AnthropicMessage) ([]ResponsesInputItem, error) { func convertAnthropicToResponsesInput(system json.RawMessage, msgs []AnthropicMessage) ([]ResponsesInputItem, error) {
var out []ResponsesInputItem var out []ResponsesInputItem
// System prompt → system role input item. // System prompt → developer role input item. ChatGPT Codex SSE behaves like
// Codex CLI here: keeping Anthropic system text in input preserves the
// conversation/cache shape better than moving it into instructions.
if len(system) > 0 { if len(system) > 0 {
sysText, err := parseAnthropicSystemPrompt(system) sysParts, err := parseAnthropicSystemContentParts(system)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if sysText != "" { if len(sysParts) > 0 {
content, _ := json.Marshal(sysText) content, _ := json.Marshal(sysParts)
out = append(out, ResponsesInputItem{ out = append(out, ResponsesInputItem{
Role: "system", Type: "message",
Role: "developer",
Content: content, Content: content,
}) })
} }
@ -133,24 +139,32 @@ func convertAnthropicToResponsesInput(system json.RawMessage, msgs []AnthropicMe
return out, nil return out, nil
} }
// parseAnthropicSystemPrompt handles the Anthropic system field which can be // parseAnthropicSystemContentParts handles the Anthropic system field which can
// a plain string or an array of text blocks. // be a plain string or an array of text blocks. Claude Code may include an
func parseAnthropicSystemPrompt(raw json.RawMessage) (string, error) { // x-anthropic-billing-header block; airgate drops it before sending to Codex.
func parseAnthropicSystemContentParts(raw json.RawMessage) ([]ResponsesContentPart, error) {
var s string var s string
if err := json.Unmarshal(raw, &s); err == nil { if err := json.Unmarshal(raw, &s); err == nil {
return s, nil if isAnthropicBillingHeaderText(s) || s == "" {
return nil, nil
}
return []ResponsesContentPart{{Type: "input_text", Text: s}}, nil
} }
var blocks []AnthropicContentBlock var blocks []AnthropicContentBlock
if err := json.Unmarshal(raw, &blocks); err != nil { if err := json.Unmarshal(raw, &blocks); err != nil {
return "", err return nil, err
} }
var parts []string var parts []ResponsesContentPart
for _, b := range blocks { for _, b := range blocks {
if b.Type == "text" && b.Text != "" { if b.Type == "text" && b.Text != "" && !isAnthropicBillingHeaderText(b.Text) {
parts = append(parts, b.Text) parts = append(parts, ResponsesContentPart{Type: "input_text", Text: b.Text})
} }
} }
return strings.Join(parts, "\n\n"), nil return parts, nil
}
func isAnthropicBillingHeaderText(text string) bool {
return strings.HasPrefix(text, "x-anthropic-billing-header: ")
} }
// anthropicMsgToResponsesItems converts a single Anthropic message into one // anthropicMsgToResponsesItems converts a single Anthropic message into one
@ -173,8 +187,12 @@ func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error)
// Try plain string. // Try plain string.
var s string var s string
if err := json.Unmarshal(raw, &s); err == nil { if err := json.Unmarshal(raw, &s); err == nil {
content, _ := json.Marshal(s) parts := []ResponsesContentPart{{Type: "input_text", Text: s}}
return []ResponsesInputItem{{Role: "user", Content: content}}, nil partsJSON, err := json.Marshal(parts)
if err != nil {
return nil, err
}
return []ResponsesInputItem{{Type: "message", Role: "user", Content: partsJSON}}, nil
} }
var blocks []AnthropicContentBlock var blocks []AnthropicContentBlock
@ -223,7 +241,7 @@ func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
out = append(out, ResponsesInputItem{Role: "user", Content: content}) out = append(out, ResponsesInputItem{Type: "message", Role: "user", Content: content})
} }
return out, nil return out, nil
@ -242,7 +260,7 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
if err != nil { if err != nil {
return nil, err return nil, err
} }
return []ResponsesInputItem{{Role: "assistant", Content: partsJSON}}, nil return []ResponsesInputItem{{Type: "message", Role: "assistant", Content: partsJSON}}, nil
} }
var blocks []AnthropicContentBlock var blocks []AnthropicContentBlock
@ -260,7 +278,7 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
if err != nil { if err != nil {
return nil, err return nil, err
} }
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON}) items = append(items, ResponsesInputItem{Type: "message", Role: "assistant", Content: partsJSON})
} }
// tool_use → function_call items. // tool_use → function_call items.
@ -284,17 +302,14 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
return items, nil return items, nil
} }
// toResponsesCallID converts an Anthropic tool ID (toolu_xxx / call_xxx) to a // toResponsesCallID preserves Anthropic tool IDs as Responses call_id values.
// Responses API function_call ID that starts with "fc_". // Claude Code sends tool_result.tool_use_id back verbatim, and ChatGPT Codex
// continuation expects that call_id to match the original tool_use id.
func toResponsesCallID(id string) string { func toResponsesCallID(id string) string {
if strings.HasPrefix(id, "fc_") { return id
return id
}
return "fc_" + id
} }
// fromResponsesCallID reverses toResponsesCallID, stripping the "fc_" prefix // fromResponsesCallID reverses old prefixed IDs while preserving current IDs.
// that was added during request conversion.
func fromResponsesCallID(id string) string { func fromResponsesCallID(id string) string {
if after, ok := strings.CutPrefix(id, "fc_"); ok { if after, ok := strings.CutPrefix(id, "fc_"); ok {
// Only strip if the remainder doesn't look like it was already "fc_" prefixed. // Only strip if the remainder doesn't look like it was already "fc_" prefixed.
@ -412,11 +427,16 @@ func convertAnthropicToolsToResponses(tools []AnthropicTool) []ResponsesTool {
Name: t.Name, Name: t.Name,
Description: t.Description, Description: t.Description,
Parameters: normalizeToolParameters(t.InputSchema), Parameters: normalizeToolParameters(t.InputSchema),
Strict: boolPtr(false),
}) })
} }
return out return out
} }
// boolPtr returns a pointer to a freshly allocated bool holding v. Each call
// yields a distinct pointer, so callers may mutate the result independently.
func boolPtr(v bool) *bool {
	p := new(bool)
	*p = v
	return p
}
// normalizeToolParameters ensures the tool parameter schema is valid for // normalizeToolParameters ensures the tool parameter schema is valid for
// OpenAI's Responses API, which requires "properties" on object schemas. // OpenAI's Responses API, which requires "properties" on object schemas.
// //

View File

@ -720,6 +720,49 @@ func TestResponsesEventToChatChunks_Completed(t *testing.T) {
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens) assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
} }
// TestResponsesEventToChatChunks_ResponseDone checks that a "response.done"
// event with status "completed" produces a finish chunk ("stop") followed by a
// usage chunk, and that finalizing the stream afterwards yields nothing more.
func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) {
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.IncludeUsage = true

	doneEvt := &ResponsesStreamEvent{
		Type: "response.done",
		Response: &ResponsesResponse{
			Status: "completed",
			Usage:  &ResponsesUsage{InputTokens: 13, OutputTokens: 7},
		},
	}
	chunks := ResponsesEventToChatChunks(doneEvt, state)
	require.Len(t, chunks, 2)

	// First chunk carries the finish reason.
	finishReason := chunks[0].Choices[0].FinishReason
	require.NotNil(t, finishReason)
	assert.Equal(t, "stop", *finishReason)

	// Second chunk carries the usage totals.
	usage := chunks[1].Usage
	require.NotNil(t, usage)
	assert.Equal(t, 13, usage.PromptTokens)
	assert.Equal(t, 7, usage.CompletionTokens)

	assert.Nil(t, FinalizeResponsesChatStream(state))
}
// TestResponsesEventToChatChunks_ResponseDoneIncomplete checks that a
// "response.done" event with status "incomplete" and reason
// "max_output_tokens" produces a finish chunk with reason "length" followed by
// a usage chunk, and that finalizing the stream afterwards yields nothing more.
func TestResponsesEventToChatChunks_ResponseDoneIncomplete(t *testing.T) {
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.IncludeUsage = true

	doneEvt := &ResponsesStreamEvent{
		Type: "response.done",
		Response: &ResponsesResponse{
			Status:            "incomplete",
			IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
			Usage:             &ResponsesUsage{InputTokens: 13, OutputTokens: 7},
		},
	}
	chunks := ResponsesEventToChatChunks(doneEvt, state)
	require.Len(t, chunks, 2)

	// First chunk carries the finish reason.
	finishReason := chunks[0].Choices[0].FinishReason
	require.NotNil(t, finishReason)
	assert.Equal(t, "length", *finishReason)

	// Second chunk carries the usage totals.
	usage := chunks[1].Usage
	require.NotNil(t, usage)
	assert.Equal(t, 13, usage.PromptTokens)
	assert.Equal(t, 7, usage.CompletionTokens)

	assert.Nil(t, FinalizeResponsesChatStream(state))
}
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) { func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
state := NewResponsesEventToChatState() state := NewResponsesEventToChatState()
state.Model = "gpt-4o" state.Model = "gpt-4o"

View File

@ -120,7 +120,7 @@ func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncom
} }
return "end_turn" return "end_turn"
case "completed": case "completed":
if len(blocks) > 0 && blocks[len(blocks)-1].Type == "tool_use" { if containsAnthropicToolUseBlock(blocks) {
return "tool_use" return "tool_use"
} }
return "end_turn" return "end_turn"
@ -129,6 +129,15 @@ func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncom
} }
} }
// containsAnthropicToolUseBlock reports whether any block in blocks has type
// "tool_use". A nil or empty slice yields false.
func containsAnthropicToolUseBlock(blocks []AnthropicContentBlock) bool {
	for i := range blocks {
		if blocks[i].Type != "tool_use" {
			continue
		}
		return true
	}
	return false
}
func sanitizeAnthropicToolUseInput(name string, raw string) json.RawMessage { func sanitizeAnthropicToolUseInput(name string, raw string) json.RawMessage {
if name != "Read" || raw == "" { if name != "Read" || raw == "" {
return json.RawMessage(raw) return json.RawMessage(raw)
@ -161,11 +170,13 @@ type ResponsesEventToAnthropicState struct {
MessageStartSent bool MessageStartSent bool
MessageStopSent bool MessageStopSent bool
ContentBlockIndex int ContentBlockIndex int
ContentBlockOpen bool ContentBlockOpen bool
CurrentBlockType string // "text" | "thinking" | "tool_use" CurrentBlockType string // "text" | "thinking" | "tool_use"
CurrentToolName string CurrentToolName string
CurrentToolArgs string CurrentToolArgs string
CurrentToolHadDelta bool
HasToolCall bool
// OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index. // OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index.
OutputIndexToBlockIdx map[int]int OutputIndexToBlockIdx map[int]int
@ -212,7 +223,9 @@ func ResponsesEventToAnthropicEvents(
return resToAnthHandleReasoningDelta(evt, state) return resToAnthHandleReasoningDelta(evt, state)
case "response.reasoning_summary_text.done": case "response.reasoning_summary_text.done":
return resToAnthHandleBlockDone(state) return resToAnthHandleBlockDone(state)
case "response.completed", "response.incomplete", "response.failed": // response.done 是 Realtime/WS 与项目透传路径使用的终止别名;
// 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。
case "response.completed", "response.done", "response.incomplete", "response.failed":
return resToAnthHandleCompleted(evt, state) return resToAnthHandleCompleted(evt, state)
default: default:
return nil return nil
@ -229,11 +242,16 @@ func FinalizeResponsesAnthropicStream(state *ResponsesEventToAnthropicState) []A
var events []AnthropicStreamEvent var events []AnthropicStreamEvent
events = append(events, closeCurrentBlock(state)...) events = append(events, closeCurrentBlock(state)...)
stopReason := "end_turn"
if state.HasToolCall {
stopReason = "tool_use"
}
events = append(events, events = append(events,
AnthropicStreamEvent{ AnthropicStreamEvent{
Type: "message_delta", Type: "message_delta",
Delta: &AnthropicDelta{ Delta: &AnthropicDelta{
StopReason: "end_turn", StopReason: stopReason,
}, },
Usage: &AnthropicUsage{ Usage: &AnthropicUsage{
InputTokens: state.InputTokens, InputTokens: state.InputTokens,
@ -304,6 +322,8 @@ func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesE
state.CurrentBlockType = "tool_use" state.CurrentBlockType = "tool_use"
state.CurrentToolName = evt.Item.Name state.CurrentToolName = evt.Item.Name
state.CurrentToolArgs = "" state.CurrentToolArgs = ""
state.CurrentToolHadDelta = false
state.HasToolCall = true
events = append(events, AnthropicStreamEvent{ events = append(events, AnthropicStreamEvent{
Type: "content_block_start", Type: "content_block_start",
@ -388,6 +408,9 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
state.CurrentToolArgs += evt.Delta state.CurrentToolArgs += evt.Delta
return nil return nil
} }
if state.CurrentBlockType == "tool_use" {
state.CurrentToolHadDelta = true
}
blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex] blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
if !ok { if !ok {
@ -405,7 +428,7 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
} }
func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
if state.CurrentBlockType != "tool_use" || state.CurrentToolName != "Read" { if state.CurrentBlockType != "tool_use" {
return resToAnthHandleBlockDone(state) return resToAnthHandleBlockDone(state)
} }
@ -413,10 +436,16 @@ func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEven
if raw == "" { if raw == "" {
raw = state.CurrentToolArgs raw = state.CurrentToolArgs
} }
sanitized := sanitizeAnthropicToolUseInput(state.CurrentToolName, raw) if raw == "" || state.CurrentToolHadDelta {
if len(sanitized) == 0 {
return closeCurrentBlock(state) return closeCurrentBlock(state)
} }
if state.CurrentToolName == "Read" {
sanitized := sanitizeAnthropicToolUseInput(state.CurrentToolName, raw)
if len(sanitized) == 0 {
return closeCurrentBlock(state)
}
raw = string(sanitized)
}
idx := state.ContentBlockIndex idx := state.ContentBlockIndex
events := []AnthropicStreamEvent{{ events := []AnthropicStreamEvent{{
@ -424,7 +453,7 @@ func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEven
Index: &idx, Index: &idx,
Delta: &AnthropicDelta{ Delta: &AnthropicDelta{
Type: "input_json_delta", Type: "input_json_delta",
PartialJSON: string(sanitized), PartialJSON: raw,
}, },
}} }}
events = append(events, closeCurrentBlock(state)...) events = append(events, closeCurrentBlock(state)...)
@ -551,7 +580,7 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo
stopReason = "max_tokens" stopReason = "max_tokens"
} }
case "completed": case "completed":
if state.ContentBlockIndex > 0 && state.CurrentBlockType == "tool_use" { if state.HasToolCall {
stopReason = "tool_use" stopReason = "tool_use"
} }
} }
@ -584,6 +613,7 @@ func closeCurrentBlock(state *ResponsesEventToAnthropicState) []AnthropicStreamE
state.ContentBlockIndex++ state.ContentBlockIndex++
state.CurrentToolName = "" state.CurrentToolName = ""
state.CurrentToolArgs = "" state.CurrentToolArgs = ""
state.CurrentToolHadDelta = false
return []AnthropicStreamEvent{{ return []AnthropicStreamEvent{{
Type: "content_block_stop", Type: "content_block_stop",
Index: &idx, Index: &idx,

View File

@ -160,7 +160,9 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
return resToChatHandleReasoningDelta(evt, state) return resToChatHandleReasoningDelta(evt, state)
case "response.reasoning_summary_text.done": case "response.reasoning_summary_text.done":
return nil return nil
case "response.completed", "response.incomplete", "response.failed": // response.done 是 Realtime/WS 与项目透传路径使用的终止别名;
// 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。
case "response.completed", "response.done", "response.incomplete", "response.failed":
return resToChatHandleCompleted(evt, state) return resToChatHandleCompleted(evt, state)
default: default:
return nil return nil

View File

@ -53,6 +53,8 @@ type AnthropicMessage struct {
type AnthropicContentBlock struct { type AnthropicContentBlock struct {
Type string `json:"type"` Type string `json:"type"`
CacheControl *AnthropicCacheControl `json:"cache_control,omitempty"`
// type=text // type=text
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
@ -165,19 +167,23 @@ type AnthropicDelta struct {
// ResponsesRequest is the request body for POST /v1/responses. // ResponsesRequest is the request body for POST /v1/responses.
type ResponsesRequest struct { type ResponsesRequest struct {
Model string `json:"model"` Model string `json:"model"`
Instructions string `json:"instructions,omitempty"` Instructions string `json:"instructions,omitempty"`
Input json.RawMessage `json:"input"` // string or []ResponsesInputItem Input json.RawMessage `json:"input"` // string or []ResponsesInputItem
MaxOutputTokens *int `json:"max_output_tokens,omitempty"` MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"` TopP *float64 `json:"top_p,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Tools []ResponsesTool `json:"tools,omitempty"` Tools []ResponsesTool `json:"tools,omitempty"`
Include []string `json:"include,omitempty"` Include []string `json:"include,omitempty"`
Store *bool `json:"store,omitempty"` Store *bool `json:"store,omitempty"`
Reasoning *ResponsesReasoning `json:"reasoning,omitempty"` ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"` Reasoning *ResponsesReasoning `json:"reasoning,omitempty"`
ServiceTier string `json:"service_tier,omitempty"` Text *ResponsesText `json:"text,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
PromptCacheKey string `json:"prompt_cache_key,omitempty"`
PreviousResponseID string `json:"previous_response_id,omitempty"`
} }
// ResponsesReasoning configures reasoning effort in the Responses API. // ResponsesReasoning configures reasoning effort in the Responses API.
@ -186,13 +192,18 @@ type ResponsesReasoning struct {
Summary string `json:"summary,omitempty"` // "auto" | "concise" | "detailed" Summary string `json:"summary,omitempty"` // "auto" | "concise" | "detailed"
} }
// ResponsesText configures text output options in the Responses API.
// Verbosity is serialized as "verbosity" and is omitted from the JSON when empty.
type ResponsesText struct {
Verbosity string `json:"verbosity,omitempty"` // "low" | "medium" | "high"
}
// ResponsesInputItem is one item in the Responses API input array. // ResponsesInputItem is one item in the Responses API input array.
// The Type field determines which other fields are populated. // The Type field determines which other fields are populated.
type ResponsesInputItem struct { type ResponsesInputItem struct {
// Common // Common
Type string `json:"type,omitempty"` // "" for role-based messages Type string `json:"type,omitempty"` // "" for role-based messages
// Role-based messages (system/user/assistant) // Role-based messages (developer/system/user/assistant)
Role string `json:"role,omitempty"` Role string `json:"role,omitempty"`
Content json.RawMessage `json:"content,omitempty"` // string or []ResponsesContentPart Content json.RawMessage `json:"content,omitempty"` // string or []ResponsesContentPart
@ -314,7 +325,7 @@ type ResponsesOutputTokensDetails struct {
type ResponsesStreamEvent struct { type ResponsesStreamEvent struct {
Type string `json:"type"` Type string `json:"type"`
// response.created / response.completed / response.failed / response.incomplete // response.created / response.completed / response.done / response.failed / response.incomplete
Response *ResponsesResponse `json:"response,omitempty"` Response *ResponsesResponse `json:"response,omitempty"`
// response.output_item.added / response.output_item.done // response.output_item.added / response.output_item.done

View File

@ -0,0 +1,75 @@
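// Package openai_compat provides helpers for determining capability
// differences between upstreams in the OpenAI protocol family.
//
// Background: sub2api's OpenAI API-key accounts connect to a variety of
// third-party OpenAI-compatible upstreams (DeepSeek, Kimi, GLM, Qwen, etc.)
// via base_url. These upstreams generally support only /v1/chat/completions
// and have no /v1/responses endpoint. The gateway's legacy code nevertheless
// applied the CC→Responses conversion indiscriminately and sent requests to
// /v1/responses, which caused 404s on compatible upstreams.
//
// This package provides capability checks based on a per-account probe marker.
// It works together with internal/service/openai_apikey_responses_probe.go,
// which probes once and persists the marker when an account is created or
// modified.
//
// Design trade-offs:
//   - No static host allowlist is maintained, so adding a new vendor does not
//     require a code change (discussion recorded in
//     pensieve/short-term/knowledge/upstream-capability-detection-design-tradeoffs).
//   - When the marker is missing the default is true (i.e. "use Responses"),
//     keeping the behavior of existing accounts identical to the pre-refactor
//     code (the "current state is evidence" principle; see
//     pensieve/short-term/maxims/preserve-existing-runtime-behavior-when-replacing-logic-in-stateful-systems).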
package openai_compat
// AccountResponsesSupport describes whether an account's upstream supports the
// OpenAI Responses API.
//
// It is intended only for accounts with platform=openai and type=apikey; other
// account types should not rely on the checks in this package.
type AccountResponsesSupport int

const (
	// ResponsesSupportUnknown means the account has not completed capability
	// probing yet (the extra field is missing). Routing code should follow the
	// "current state is evidence" principle and default to the Responses path,
	// matching the pre-refactor behavior.
	ResponsesSupportUnknown AccountResponsesSupport = iota
	// ResponsesSupportYes means probing confirmed that the upstream supports
	// /v1/responses.
	ResponsesSupportYes
	// ResponsesSupportNo means probing confirmed that the upstream does not
	// support /v1/responses; the direct /v1/chat/completions path should be
	// used instead.
	ResponsesSupportNo
)

// ExtraKeyResponsesSupported is the key under which the probe result is stored
// in the accounts.extra JSON. The value is a bool: true means supported, false
// means unsupported, and a missing key means the account was never probed.
const ExtraKeyResponsesSupported = "openai_responses_supported"

// ResolveResponsesSupport reads the probe marker from an account's extra map.
//
// It returns ResponsesSupportUnknown when the marker is missing or is not a
// bool; callers should treat that as "not probed, keep the old behavior, use
// Responses" (see ShouldUseResponsesAPI).
func ResolveResponsesSupport(extra map[string]any) AccountResponsesSupport {
	// Indexing a nil map is safe and yields a nil value, and the comma-ok
	// assertion fails for both a missing key and a non-bool value, so a single
	// lookup covers every "unknown" case.
	flag, isBool := extra[ExtraKeyResponsesSupported].(bool)
	switch {
	case !isBool:
		return ResponsesSupportUnknown
	case flag:
		return ResponsesSupportYes
	default:
		return ResponsesSupportNo
	}
}

// ShouldUseResponsesAPI reports whether an inbound /v1/chat/completions request
// for an OpenAI API-key account should take the "CC→Responses conversion plus
// upstream /v1/responses" path.
//
// It returns true in two cases:
//  1. the account has been probed and confirmed to support Responses;
//  2. the account has not been probed (marker missing), in which case the old
//     behavior is preserved per the "current state is evidence" principle.
//
// It returns false only when the account has been probed and confirmed not to
// support Responses; callers should then use the direct CC path (see
// internal/service/openai_gateway_chat_completions_raw.go).
func ShouldUseResponsesAPI(extra map[string]any) bool {
	support := ResolveResponsesSupport(extra)
	return support != ResponsesSupportNo
}

View File

@ -0,0 +1,55 @@
package openai_compat
import "testing"
// TestResolveResponsesSupport is a table-driven check that the probe marker is
// mapped to Yes/No for bool values and to Unknown for a nil map, a missing key,
// a nil value, or a value of the wrong type.
func TestResolveResponsesSupport(t *testing.T) {
	type testCase struct {
		name  string
		extra map[string]any
		want  AccountResponsesSupport
	}
	cases := []testCase{
		{name: "nil extra", extra: nil, want: ResponsesSupportUnknown},
		{name: "empty extra", extra: map[string]any{}, want: ResponsesSupportUnknown},
		{name: "key missing", extra: map[string]any{"other": "value"}, want: ResponsesSupportUnknown},
		{name: "value true", extra: map[string]any{ExtraKeyResponsesSupported: true}, want: ResponsesSupportYes},
		{name: "value false", extra: map[string]any{ExtraKeyResponsesSupported: false}, want: ResponsesSupportNo},
		{name: "value wrong type string", extra: map[string]any{ExtraKeyResponsesSupported: "true"}, want: ResponsesSupportUnknown},
		{name: "value wrong type number", extra: map[string]any{ExtraKeyResponsesSupported: 1}, want: ResponsesSupportUnknown},
		{name: "value nil", extra: map[string]any{ExtraKeyResponsesSupported: nil}, want: ResponsesSupportUnknown},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := ResolveResponsesSupport(c.extra); got != c.want {
				t.Errorf("ResolveResponsesSupport(%v) = %v, want %v", c.extra, got, c.want)
			}
		})
	}
}
// TestShouldUseResponsesAPI is a table-driven check that only an explicit
// false marker disables the Responses path; every other input keeps it enabled.
func TestShouldUseResponsesAPI(t *testing.T) {
	type testCase struct {
		name  string
		extra map[string]any
		want  bool
	}
	cases := []testCase{
		// Key invariant: an unprobed account must return true (old behavior preserved).
		{name: "unknown defaults to true (preserve old behavior)", extra: nil, want: true},
		{name: "unknown empty defaults to true", extra: map[string]any{}, want: true},
		{name: "unknown wrong type defaults to true", extra: map[string]any{ExtraKeyResponsesSupported: "yes"}, want: true},
		// Probed accounts: the marker decides.
		{name: "explicitly supported", extra: map[string]any{ExtraKeyResponsesSupported: true}, want: true},
		{name: "explicitly unsupported", extra: map[string]any{ExtraKeyResponsesSupported: false}, want: false},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := ShouldUseResponsesAPI(c.extra); got != c.want {
				t.Errorf("ShouldUseResponsesAPI(%v) = %v, want %v", c.extra, got, c.want)
			}
		})
	}
}

View File

@ -22,6 +22,34 @@ const (
var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789") var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789")
// affiliateUserOverviewSQL selects at most one affiliate overview row for the
// user id bound to $1. Columns, in scan order:
//
//  1. user id
//  2. email ('' when NULL)
//  3. username ('' when NULL)
//  4. affiliate code
//  5. custom rebate rate percent (0 when no custom rate is stored)
//  6. whether a custom rebate rate is stored
//  7. invite count (aff_count)
//  8. number of distinct source users with at least one 'accrue' ledger row
//  9. available quota: aff_quota plus the sum of 'accrue' ledger amounts whose
//     frozen_until is set and not later than NOW()
//  10. history quota
//
// The exact text of this query is asserted by a unit test, so keep the
// "ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0)" and
// "frozen_until <= NOW()" fragments intact when editing.
const affiliateUserOverviewSQL = `
SELECT ua.user_id,
COALESCE(u.email, ''),
COALESCE(u.username, ''),
ua.aff_code,
COALESCE(ua.aff_rebate_rate_percent, 0)::double precision,
(ua.aff_rebate_rate_percent IS NOT NULL) AS has_custom_rate,
ua.aff_count,
COALESCE(rebated.rebated_invitee_count, 0),
(ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0))::double precision,
ua.aff_history_quota::double precision
FROM user_affiliates ua
JOIN users u ON u.id = ua.user_id
LEFT JOIN (
SELECT user_id, COUNT(DISTINCT source_user_id)::integer AS rebated_invitee_count
FROM user_affiliate_ledger
WHERE action = 'accrue' AND source_user_id IS NOT NULL
GROUP BY user_id
) rebated ON rebated.user_id = ua.user_id
LEFT JOIN (
SELECT user_id, COALESCE(SUM(amount), 0)::double precision AS matured_frozen_quota
FROM user_affiliate_ledger
WHERE action = 'accrue' AND frozen_until IS NOT NULL AND frozen_until <= NOW()
GROUP BY user_id
) matured ON matured.user_id = ua.user_id
WHERE ua.user_id = $1
LIMIT 1`
type affiliateQueryExecer interface { type affiliateQueryExecer interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
@ -86,7 +114,7 @@ func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID
return bound, nil return bound, nil
} }
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) { func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error) {
if amount <= 0 { if amount <= 0 {
return false, nil return false, nil
} }
@ -112,15 +140,15 @@ func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, invite
if freezeHours > 0 { if freezeHours > 0 {
if _, err = txClient.ExecContext(txCtx, ` if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, frozen_until, created_at, updated_at) INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, source_order_id, frozen_until, created_at, updated_at)
VALUES ($1, 'accrue', $2, $3, NOW() + make_interval(hours => $4), NOW(), NOW())`, VALUES ($1, 'accrue', $2, $3, $4, NOW() + make_interval(hours => $5), NOW(), NOW())`,
inviterID, amount, inviteeUserID, freezeHours); err != nil { inviterID, amount, inviteeUserID, nullableInt64Arg(sourceOrderID), freezeHours); err != nil {
return fmt.Errorf("insert affiliate accrue ledger: %w", err) return fmt.Errorf("insert affiliate accrue ledger: %w", err)
} }
} else { } else {
if _, err = txClient.ExecContext(txCtx, ` if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at) INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, source_order_id, created_at, updated_at)
VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil { VALUES ($1, 'accrue', $2, $3, $4, NOW(), NOW())`, inviterID, amount, inviteeUserID, nullableInt64Arg(sourceOrderID)); err != nil {
return fmt.Errorf("insert affiliate accrue ledger: %w", err) return fmt.Errorf("insert affiliate accrue ledger: %w", err)
} }
} }
@ -275,9 +303,32 @@ FROM cleared`, userID)
return err return err
} }
snapshot, err := queryAffiliateTransferSnapshot(txCtx, txClient, userID)
if err != nil {
return err
}
if _, err = txClient.ExecContext(txCtx, ` if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at) INSERT INTO user_affiliate_ledger (
VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred); err != nil { user_id,
action,
amount,
source_user_id,
balance_after,
aff_quota_after,
aff_frozen_quota_after,
aff_history_quota_after,
created_at,
updated_at
)
VALUES ($1, 'transfer', $2, NULL, $3, $4, $5, $6, NOW(), NOW())`,
userID,
transferred,
snapshot.BalanceAfter,
snapshot.AvailableQuotaAfter,
snapshot.FrozenQuotaAfter,
snapshot.HistoryQuotaAfter,
); err != nil {
return fmt.Errorf("insert affiliate transfer ledger: %w", err) return fmt.Errorf("insert affiliate transfer ledger: %w", err)
} }
@ -332,6 +383,349 @@ LIMIT $2`, inviterID, limit)
return invitees, nil return invitees, nil
} }
// ListAffiliateInviteRecords returns one page of inviter/invitee bindings and
// the total number of bindings matching the filter.
//
// The filter's time range is applied to ua.created_at and its search text is
// matched against inviter/invitee email and username, both user ids (as text)
// and the inviter's affiliate code. Each returned record also carries
// total_rebate: the sum of the inviter's 'accrue' ledger amounts whose
// source_user_id is that invitee (0 when there are none).
//
// Returns (records, total, error); records is an empty non-nil slice when the
// page has no rows.
func (r *affiliateRepository) ListAffiliateInviteRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateInviteRecord, int64, error) {
client := clientFromContext(ctx, r.client)
where, args := buildAffiliateRecordWhere(filter, "ua.created_at", []string{
"inviter.email", "inviter.username", "invitee.email", "invitee.username",
"ua.inviter_id::text", "ua.user_id::text", "inviter_aff.aff_code",
})
// The count query uses the same joins and WHERE clause as the page query but
// omits the ledger join, so it counts bindings rather than ledger rows.
total, err := queryAffiliateRecordCount(ctx, client, `
SELECT COUNT(*)
FROM user_affiliates ua
JOIN users invitee ON invitee.id = ua.user_id
JOIN users inviter ON inviter.id = ua.inviter_id
JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id
`+where, args...)
if err != nil {
return nil, 0, err
}
// Whitelist of sortable keys; an unknown or empty SortBy falls back to ua.created_at.
orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
"inviter": "inviter.email",
"invitee": "invitee.email",
"aff_code": "inviter_aff.aff_code",
"total_rebate": "total_rebate",
"created_at": "ua.created_at",
}, "ua.created_at")
// Page size and offset are appended last, so their placeholders are the final
// two positions in args (see the LIMIT/OFFSET construction below).
// NOTE(review): the offset is negative when filter.Page < 1; presumably the
// caller normalizes Page/PageSize beforehand, confirm.
args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
rows, err := client.QueryContext(ctx, `
SELECT ua.inviter_id,
COALESCE(inviter.email, ''),
COALESCE(inviter.username, ''),
ua.user_id,
COALESCE(invitee.email, ''),
COALESCE(invitee.username, ''),
COALESCE(inviter_aff.aff_code, ''),
COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate,
ua.created_at
FROM user_affiliates ua
JOIN users invitee ON invitee.id = ua.user_id
JOIN users inviter ON inviter.id = ua.inviter_id
JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id
LEFT JOIN user_affiliate_ledger ual
ON ual.user_id = ua.inviter_id
AND ual.source_user_id = ua.user_id
AND ual.action = 'accrue'
`+where+`
GROUP BY ua.inviter_id, inviter.email, inviter.username, ua.user_id, invitee.email, invitee.username, inviter_aff.aff_code, ua.created_at
`+orderBy+`
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
if err != nil {
return nil, 0, err
}
defer func() { _ = rows.Close() }()
items := make([]service.AffiliateInviteRecord, 0)
for rows.Next() {
var item service.AffiliateInviteRecord
// Scan order must match the SELECT column order above.
if err := rows.Scan(
&item.InviterID,
&item.InviterEmail,
&item.InviterUsername,
&item.InviteeID,
&item.InviteeEmail,
&item.InviteeUsername,
&item.AffCode,
&item.TotalRebate,
&item.CreatedAt,
); err != nil {
return nil, 0, err
}
items = append(items, item)
}
// Surface any error that ended the iteration early.
if err := rows.Err(); err != nil {
return nil, 0, err
}
return items, total, nil
}
// ListAffiliateRebateRecords returns one page of rebate records and the total
// number of records matching the filter.
//
// A rebate record is an 'accrue' row of user_affiliate_ledger that has a
// source_order_id, joined to its payment order, the inviter (ledger user_id)
// and the invitee (ledger source_user_id). Because of the inner joins, accrue
// rows without a matching order or users are not listed. The filter's time
// range is applied to ual.created_at; its search text is matched against
// inviter/invitee email and username and the order id, out_trade_no,
// payment_type and status.
//
// Returns (records, total, error); records is an empty non-nil slice when the
// page has no rows.
func (r *affiliateRepository) ListAffiliateRebateRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateRebateRecord, int64, error) {
client := clientFromContext(ctx, r.client)
where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{
"inviter.email", "inviter.username", "invitee.email", "invitee.username",
"po.id::text", "po.out_trade_no", "po.payment_type", "po.status",
})
// Shared FROM/JOIN/WHERE fragment used by both the count and the page query.
baseJoin := `
FROM user_affiliate_ledger ual
JOIN payment_orders po ON po.id = ual.source_order_id
JOIN users invitee ON invitee.id = ual.source_user_id
JOIN users inviter ON inviter.id = ual.user_id
WHERE ual.action = 'accrue'
AND ual.source_order_id IS NOT NULL`
// baseJoin already contains a WHERE clause, so the generated "WHERE ..." prefix
// is rewritten to " AND ..." to extend it instead of starting a second one.
if where != "" {
where = strings.Replace(where, "WHERE ", " AND ", 1)
}
total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...)
if err != nil {
return nil, 0, err
}
// Whitelist of sortable keys; an unknown or empty SortBy falls back to ual.created_at.
orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
"order": "po.id",
"inviter": "inviter.email",
"invitee": "invitee.email",
"order_amount": "po.amount",
"pay_amount": "po.pay_amount",
"rebate_amount": "ual.amount",
"payment_type": "po.payment_type",
"order_status": "po.status",
"created_at": "ual.created_at",
}, "ual.created_at")
// Page size and offset are appended last, so their placeholders are the final
// two positions in args (see the LIMIT/OFFSET construction below).
// NOTE(review): the offset is negative when filter.Page < 1; presumably the
// caller normalizes Page/PageSize beforehand, confirm.
args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
rows, err := client.QueryContext(ctx, `
SELECT po.id,
po.out_trade_no,
ual.user_id,
COALESCE(inviter.email, ''),
COALESCE(inviter.username, ''),
ual.source_user_id,
COALESCE(invitee.email, ''),
COALESCE(invitee.username, ''),
po.amount::double precision,
po.pay_amount::double precision,
ual.amount::double precision,
po.payment_type,
po.status,
ual.created_at
`+baseJoin+where+`
`+orderBy+`
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
if err != nil {
return nil, 0, err
}
defer func() { _ = rows.Close() }()
items := make([]service.AffiliateRebateRecord, 0)
for rows.Next() {
var item service.AffiliateRebateRecord
// Scan order must match the SELECT column order above.
if err := rows.Scan(
&item.OrderID,
&item.OutTradeNo,
&item.InviterID,
&item.InviterEmail,
&item.InviterUsername,
&item.InviteeID,
&item.InviteeEmail,
&item.InviteeUsername,
&item.OrderAmount,
&item.PayAmount,
&item.RebateAmount,
&item.PaymentType,
&item.OrderStatus,
&item.CreatedAt,
); err != nil {
return nil, 0, err
}
items = append(items, item)
}
// Surface any error that ended the iteration early.
if err := rows.Err(); err != nil {
return nil, 0, err
}
return items, total, nil
}
// ListAffiliateTransferRecords returns one page of 'transfer' ledger rows,
// joined to their user, and the total number of rows matching the filter.
//
// The filter's time range is applied to ual.created_at; its search text is
// matched against the user's email, username and id (as text). The four
// *_after snapshot columns may be NULL in the ledger: each is exposed as a nil
// pointer in that case, and SnapshotAvailable is true only when all four are
// present.
//
// Returns (records, total, error); records is an empty non-nil slice when the
// page has no rows.
func (r *affiliateRepository) ListAffiliateTransferRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateTransferRecord, int64, error) {
client := clientFromContext(ctx, r.client)
where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{
"u.email", "u.username", "u.id::text",
})
// Shared FROM/JOIN/WHERE fragment used by both the count and the page query.
baseJoin := `
FROM user_affiliate_ledger ual
JOIN users u ON u.id = ual.user_id
WHERE ual.action = 'transfer'`
// baseJoin already contains a WHERE clause, so the generated "WHERE ..." prefix
// is rewritten to " AND ..." to extend it instead of starting a second one.
if where != "" {
where = strings.Replace(where, "WHERE ", " AND ", 1)
}
total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...)
if err != nil {
return nil, 0, err
}
// Whitelist of sortable keys; an unknown or empty SortBy falls back to ual.created_at.
orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
"user": "u.email",
"amount": "ual.amount",
"balance_after": "ual.balance_after",
"available_quota_after": "ual.aff_quota_after",
"frozen_quota_after": "ual.aff_frozen_quota_after",
"history_quota_after": "ual.aff_history_quota_after",
"created_at": "ual.created_at",
}, "ual.created_at")
// Page size and offset are appended last, so their placeholders are the final
// two positions in args (see the LIMIT/OFFSET construction below).
// NOTE(review): the offset is negative when filter.Page < 1; presumably the
// caller normalizes Page/PageSize beforehand, confirm.
args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
rows, err := client.QueryContext(ctx, `
SELECT ual.id,
ual.user_id,
COALESCE(u.email, ''),
COALESCE(u.username, ''),
ual.amount::double precision,
ual.balance_after::double precision,
ual.aff_quota_after::double precision,
ual.aff_frozen_quota_after::double precision,
ual.aff_history_quota_after::double precision,
ual.created_at
`+baseJoin+where+`
`+orderBy+`
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
if err != nil {
return nil, 0, err
}
defer func() { _ = rows.Close() }()
items := make([]service.AffiliateTransferRecord, 0)
for rows.Next() {
var item service.AffiliateTransferRecord
// The snapshot columns are scanned through sql.NullFloat64 so that a NULL
// column does not fail the scan.
var balanceAfter sql.NullFloat64
var availableQuotaAfter sql.NullFloat64
var frozenQuotaAfter sql.NullFloat64
var historyQuotaAfter sql.NullFloat64
// Scan order must match the SELECT column order above.
if err := rows.Scan(
&item.LedgerID,
&item.UserID,
&item.UserEmail,
&item.Username,
&item.Amount,
&balanceAfter,
&availableQuotaAfter,
&frozenQuotaAfter,
&historyQuotaAfter,
&item.CreatedAt,
); err != nil {
return nil, 0, err
}
item.BalanceAfter = nullableFloat64Ptr(balanceAfter)
item.AvailableQuotaAfter = nullableFloat64Ptr(availableQuotaAfter)
item.FrozenQuotaAfter = nullableFloat64Ptr(frozenQuotaAfter)
item.HistoryQuotaAfter = nullableFloat64Ptr(historyQuotaAfter)
// The snapshot counts as available only when every *_after column is non-NULL.
item.SnapshotAvailable = balanceAfter.Valid &&
availableQuotaAfter.Valid &&
frozenQuotaAfter.Valid &&
historyQuotaAfter.Valid
items = append(items, item)
}
// Surface any error that ended the iteration early.
if err := rows.Err(); err != nil {
return nil, 0, err
}
return items, total, nil
}
// GetAffiliateUserOverview loads the affiliate overview of a single user by
// running affiliateUserOverviewSQL.
//
// It returns service.ErrUserNotFound when userID is not positive or when the
// query yields no row. RebateRatePercent and RebateRateCustom are only set
// when the user has a custom rebate rate stored; otherwise they keep their
// zero values.
func (r *affiliateRepository) GetAffiliateUserOverview(ctx context.Context, userID int64) (*service.AffiliateUserOverview, error) {
	if userID <= 0 {
		return nil, service.ErrUserNotFound
	}

	rows, err := clientFromContext(ctx, r.client).QueryContext(ctx, affiliateUserOverviewSQL, userID)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	if found := rows.Next(); !found {
		// Distinguish "iteration failed" from "no such user".
		if iterErr := rows.Err(); iterErr != nil {
			return nil, iterErr
		}
		return nil, service.ErrUserNotFound
	}

	var (
		rate    float64
		hasRate bool
	)
	result := new(service.AffiliateUserOverview)
	// Destination order mirrors the SELECT column order of affiliateUserOverviewSQL.
	scanErr := rows.Scan(
		&result.UserID,
		&result.Email,
		&result.Username,
		&result.AffCode,
		&rate,
		&hasRate,
		&result.InvitedCount,
		&result.RebatedInviteeCount,
		&result.AvailableQuota,
		&result.HistoryQuota,
	)
	if scanErr != nil {
		return nil, scanErr
	}

	if hasRate {
		result.RebateRateCustom = true
		result.RebateRatePercent = rate
	}
	return result, rows.Err()
}
// buildAffiliateRecordWhere assembles the optional WHERE clause shared by the
// affiliate record list queries, together with its positional arguments.
//
// Up to three conditions are produced, joined with AND, with placeholders
// numbered from $1 in the order the arguments are appended:
//   - timeColumn >= filter.StartAt, when StartAt is set
//   - timeColumn <= filter.EndAt, when EndAt is set
//   - a case-insensitive substring match of the trimmed filter.Search against
//     any of searchColumns, when both are non-empty
//
// timeColumn and searchColumns are interpolated into the SQL text and must be
// trusted, caller-supplied column expressions; only the filter values are
// passed as bind arguments.
//
// It returns an empty clause (and the possibly empty args) when no condition
// applies; otherwise the clause starts with "WHERE ".
func buildAffiliateRecordWhere(filter service.AffiliateRecordFilter, timeColumn string, searchColumns []string) (string, []any) {
	clauses := make([]string, 0, 3)
	args := make([]any, 0, 3)

	if filter.StartAt != nil {
		args = append(args, *filter.StartAt)
		clauses = append(clauses, fmt.Sprintf("%s >= $%d", timeColumn, len(args)))
	}
	if filter.EndAt != nil {
		args = append(args, *filter.EndAt)
		clauses = append(clauses, fmt.Sprintf("%s <= $%d", timeColumn, len(args)))
	}

	if search := strings.TrimSpace(filter.Search); search != "" && len(searchColumns) > 0 {
		// Escape LIKE metacharacters so that user text such as "john_doe" or
		// "50%" is matched literally instead of acting as a wildcard. Backslash
		// is PostgreSQL's default LIKE escape character, so no ESCAPE clause is
		// needed; the backslash itself is escaped first by the replacer.
		literal := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(strings.ToLower(search))
		args = append(args, "%"+literal+"%")

		// Every search column shares the same placeholder.
		parts := make([]string, 0, len(searchColumns))
		for _, col := range searchColumns {
			parts = append(parts, fmt.Sprintf("LOWER(%s) LIKE $%d", col, len(args)))
		}
		clauses = append(clauses, "("+strings.Join(parts, " OR ")+")")
	}

	if len(clauses) == 0 {
		return "", args
	}
	return "WHERE " + strings.Join(clauses, " AND "), args
}
// buildAffiliateRecordOrderBy renders the ORDER BY clause for the affiliate
// record list queries.
//
// filter.SortBy is translated through sortColumns, which acts as a whitelist:
// an unknown key (or one mapped to an empty string) falls back to
// fallbackColumn. The direction is DESC when filter.SortDesc is true and ASC
// otherwise; NULLS LAST is always appended.
func buildAffiliateRecordOrderBy(filter service.AffiliateRecordFilter, sortColumns map[string]string, fallbackColumn string) string {
	target := fallbackColumn
	if mapped := sortColumns[filter.SortBy]; mapped != "" {
		target = mapped
	}

	order := "ASC"
	if filter.SortDesc {
		order = "DESC"
	}

	return fmt.Sprintf("ORDER BY %s %s NULLS LAST", target, order)
}
// queryAffiliateRecordCount runs a single-value count query and returns the
// scanned integer.
//
// When the query yields no row the count is 0 and the returned error is
// whatever the row iterator reports (nil on a clean end of rows).
func queryAffiliateRecordCount(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (int64, error) {
	result, queryErr := client.QueryContext(ctx, query, args...)
	if queryErr != nil {
		return 0, queryErr
	}
	defer func() { _ = result.Close() }()

	var count int64
	if result.Next() {
		if scanErr := result.Scan(&count); scanErr != nil {
			return 0, scanErr
		}
	}
	return count, result.Err()
}
func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error { func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
if tx := dbent.TxFromContext(ctx); tx != nil { if tx := dbent.TxFromContext(ctx); tx != nil {
return fn(ctx, tx.Client()) return fn(ctx, tx.Client())
@ -516,6 +910,54 @@ func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID i
return balance, nil return balance, nil
} }
// affiliateTransferSnapshot carries the balance and affiliate quota values of
// a user as read by queryAffiliateTransferSnapshot.
type affiliateTransferSnapshot struct {
// BalanceAfter is read from users.balance.
BalanceAfter float64
// AvailableQuotaAfter is read from user_affiliates.aff_quota.
AvailableQuotaAfter float64
// FrozenQuotaAfter is read from user_affiliates.aff_frozen_quota.
FrozenQuotaAfter float64
// HistoryQuotaAfter is read from user_affiliates.aff_history_quota.
HistoryQuotaAfter float64
}
// queryAffiliateTransferSnapshot reads the user's current balance together
// with the available, frozen and history affiliate quotas in a single query.
//
// It returns service.ErrUserNotFound when no users row with a matching
// user_affiliates row exists for userID. A failure to run the query is wrapped
// with context; iteration and scan errors are returned as-is.
func queryAffiliateTransferSnapshot(ctx context.Context, client affiliateQueryExecer, userID int64) (*affiliateTransferSnapshot, error) {
rows, err := client.QueryContext(ctx, `
SELECT u.balance::double precision,
ua.aff_quota::double precision,
ua.aff_frozen_quota::double precision,
ua.aff_history_quota::double precision
FROM users u
JOIN user_affiliates ua ON ua.user_id = u.id
WHERE u.id = $1
LIMIT 1`, userID)
if err != nil {
return nil, fmt.Errorf("query affiliate transfer snapshot: %w", err)
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
// Distinguish "iteration failed" from "no such user".
if err := rows.Err(); err != nil {
return nil, err
}
return nil, service.ErrUserNotFound
}
var snapshot affiliateTransferSnapshot
// Scan order must match the SELECT column order above.
if err := rows.Scan(
&snapshot.BalanceAfter,
&snapshot.AvailableQuotaAfter,
&snapshot.FrozenQuotaAfter,
&snapshot.HistoryQuotaAfter,
); err != nil {
return nil, err
}
return &snapshot, rows.Err()
}
// nullableFloat64Ptr converts a sql.NullFloat64 into an optional float64:
// nil for SQL NULL, otherwise a pointer to a copy of the stored value.
func nullableFloat64Ptr(v sql.NullFloat64) *float64 {
	if v.Valid {
		value := v.Float64
		return &value
	}
	return nil
}
func generateAffiliateCode() (string, error) { func generateAffiliateCode() (string, error) {
buf := make([]byte, affiliateCodeLength) buf := make([]byte, affiliateCodeLength)
if _, err := rand.Read(buf); err != nil { if _, err := rand.Read(buf); err != nil {
@ -674,6 +1116,13 @@ func nullableArg(v *float64) any {
return *v return *v
} }
// nullableInt64Arg turns an optional int64 into a query argument: the
// pointed-to value when v is set, or an untyped nil (SQL NULL) when it is not.
func nullableInt64Arg(v *int64) any {
	if v != nil {
		return *v
	}
	return nil
}
// ListUsersWithCustomSettings 列出有专属配置(自定义码或专属比例)的用户。 // ListUsersWithCustomSettings 列出有专属配置(自定义码或专属比例)的用户。
// //
// 单一查询同时处理"无搜索"与"按邮箱/用户名模糊搜索" // 单一查询同时处理"无搜索"与"按邮箱/用户名模糊搜索"

View File

@ -78,6 +78,26 @@ VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34)
ledgerCount := querySingleInt(t, txCtx, client, ledgerCount := querySingleInt(t, txCtx, client,
"SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID) "SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID)
require.Equal(t, 1, ledgerCount) require.Equal(t, 1, ledgerCount)
rows, err := client.QueryContext(txCtx, `
SELECT amount::double precision,
balance_after::double precision,
aff_quota_after::double precision,
aff_frozen_quota_after::double precision,
aff_history_quota_after::double precision
FROM user_affiliate_ledger
WHERE user_id = $1 AND action = 'transfer'
LIMIT 1`, u.ID)
require.NoError(t, err)
defer func() { _ = rows.Close() }()
require.True(t, rows.Next(), "expected transfer ledger")
var amount, balanceAfter, quotaAfter, frozenAfter, historyAfter float64
require.NoError(t, rows.Scan(&amount, &balanceAfter, &quotaAfter, &frozenAfter, &historyAfter))
require.InDelta(t, 12.34, amount, 1e-9)
require.InDelta(t, 17.84, balanceAfter, 1e-9)
require.InDelta(t, 0.0, quotaAfter, 1e-9)
require.InDelta(t, 0.0, frozenAfter, 1e-9)
require.InDelta(t, 12.34, historyAfter, 1e-9)
} }
// TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the // TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the
@ -125,7 +145,7 @@ func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.True(t, bound, "invitee must bind to inviter") require.True(t, bound, "invitee must bind to inviter")
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0) applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0, nil)
require.NoError(t, err) require.NoError(t, err)
require.True(t, applied, "AccrueQuota must report applied=true") require.True(t, applied, "AccrueQuota must report applied=true")

View File

@ -0,0 +1,28 @@
package repository
import (
"os"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// TestAffiliateUserOverviewSQLIncludesMaturedFrozenQuota guards the overview
// query text: after collapsing whitespace it must still add the matured
// frozen quota to aff_quota and filter on frozen_until <= NOW().
func TestAffiliateUserOverviewSQLIncludesMaturedFrozenQuota(t *testing.T) {
	normalized := strings.Join(strings.Fields(affiliateUserOverviewSQL), " ")

	for _, fragment := range []string{
		"ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0)",
		"frozen_until <= NOW()",
	} {
		require.Contains(t, normalized, fragment)
	}
}
// TestAffiliateRecordQueriesUseLedgerAuditFields is a source-level guard: it
// reads affiliate_repo.go from the working directory and checks that the
// record queries rely on the ledger's order/amount/balance columns, and that
// two legacy constructs are absent from the file.
func TestAffiliateRecordQueriesUseLedgerAuditFields(t *testing.T) {
	raw, err := os.ReadFile("affiliate_repo.go")
	require.NoError(t, err)
	text := string(raw)

	required := []string{
		"JOIN payment_orders po ON po.id = ual.source_order_id",
		"ual.amount::double precision",
		"ual.balance_after::double precision",
	}
	for _, fragment := range required {
		require.Contains(t, text, fragment)
	}

	forbidden := []string{
		"parseAffiliateRebateAmount",
		`"current_balance": "u.balance"`,
	}
	for _, fragment := range forbidden {
		require.NotContains(t, text, fragment)
	}
}

View File

@ -166,6 +166,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldDailyLimitUsd, group.FieldDailyLimitUsd,
group.FieldWeeklyLimitUsd, group.FieldWeeklyLimitUsd,
group.FieldMonthlyLimitUsd, group.FieldMonthlyLimitUsd,
group.FieldAllowImageGeneration,
group.FieldImageRateIndependent,
group.FieldImageRateMultiplier,
group.FieldImagePrice1k, group.FieldImagePrice1k,
group.FieldImagePrice2k, group.FieldImagePrice2k,
group.FieldImagePrice4k, group.FieldImagePrice4k,
@ -699,6 +702,9 @@ func groupEntityToService(g *dbent.Group) *service.Group {
DailyLimitUSD: g.DailyLimitUsd, DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd, WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd, MonthlyLimitUSD: g.MonthlyLimitUsd,
AllowImageGeneration: g.AllowImageGeneration,
ImageRateIndependent: g.ImageRateIndependent,
ImageRateMultiplier: g.ImageRateMultiplier,
ImagePrice1K: g.ImagePrice1k, ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k, ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k, ImagePrice4K: g.ImagePrice4k,

View File

@ -50,6 +50,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD). SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD). SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD). SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
SetAllowImageGeneration(groupIn.AllowImageGeneration).
SetImageRateIndependent(groupIn.ImageRateIndependent).
SetImageRateMultiplier(groupIn.ImageRateMultiplier).
SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K). SetNillableImagePrice4k(groupIn.ImagePrice4K).
@ -120,6 +123,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD). SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD). SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD). SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
SetAllowImageGeneration(groupIn.AllowImageGeneration).
SetImageRateIndependent(groupIn.ImageRateIndependent).
SetImageRateMultiplier(groupIn.ImageRateMultiplier).
SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K). SetNillableImagePrice4k(groupIn.ImagePrice4K).

View File

@ -328,6 +328,9 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null, "image_price_1k": null,
"image_price_2k": null, "image_price_2k": null,
"image_price_4k": null, "image_price_4k": null,
"allow_image_generation": false,
"image_rate_independent": false,
"image_rate_multiplier": 0,
"claude_code_only": false, "claude_code_only": false,
"allow_messages_dispatch": false, "allow_messages_dispatch": false,
"fallback_group_id": null, "fallback_group_id": null,

View File

@ -412,6 +412,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// 529过载冷却配置 // 529过载冷却配置
adminSettings.GET("/overload-cooldown", h.Admin.Setting.GetOverloadCooldownSettings) adminSettings.GET("/overload-cooldown", h.Admin.Setting.GetOverloadCooldownSettings)
adminSettings.PUT("/overload-cooldown", h.Admin.Setting.UpdateOverloadCooldownSettings) adminSettings.PUT("/overload-cooldown", h.Admin.Setting.UpdateOverloadCooldownSettings)
// 429默认回避配置
adminSettings.GET("/rate-limit-429-cooldown", h.Admin.Setting.GetRateLimit429CooldownSettings)
adminSettings.PUT("/rate-limit-429-cooldown", h.Admin.Setting.UpdateRateLimit429CooldownSettings)
// 流超时处理配置 // 流超时处理配置
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings) adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings) adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
@ -624,11 +627,16 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
affiliates := admin.Group("/affiliates") affiliates := admin.Group("/affiliates")
{ {
affiliates.GET("/invites", h.Admin.Affiliate.ListInviteRecords)
affiliates.GET("/rebates", h.Admin.Affiliate.ListRebateRecords)
affiliates.GET("/transfers", h.Admin.Affiliate.ListTransferRecords)
users := affiliates.Group("/users") users := affiliates.Group("/users")
{ {
users.GET("", h.Admin.Affiliate.ListUsers) users.GET("", h.Admin.Affiliate.ListUsers)
users.GET("/lookup", h.Admin.Affiliate.LookupUsers) users.GET("/lookup", h.Admin.Affiliate.LookupUsers)
users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate) users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate)
users.GET("/:user_id/overview", h.Admin.Affiliate.GetUserOverview)
users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings) users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings)
users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings) users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings)
} }

View File

@ -230,7 +230,11 @@ func applyAccountStatsCost(
if model == "" { if model == "" {
model = requestedModel model = requestedModel
} }
requestCount := 1
if usageLog != nil && usageLog.ImageCount > 0 {
requestCount = usageLog.ImageCount
}
usageLog.AccountStatsCost = resolveAccountStatsCost( usageLog.AccountStatsCost = resolveAccountStatsCost(
ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost, ctx, cs, bs, accountID, groupID, model, tokens, requestCount, totalCost,
) )
} }

View File

@ -21,6 +21,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
"github.com/Wei-Shaw/sub2api/internal/pkg/windsurf" "github.com/Wei-Shaw/sub2api/internal/pkg/windsurf"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -571,7 +572,16 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
} }
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses" // 账号已被探测为不支持 Responses如 DeepSeek/Kimi 等)时,丢出明确提示。
// 账号本身可用(网关会走 CC 直转),仅测试入口需要补齐 CC SSE 处理逻辑。
// TODO实现 CC 格式的账号测试路径(需专门的 CC SSE handler
if !openai_compat.ShouldUseResponsesAPI(account.Extra) {
return s.sendErrorAndEnd(c,
"账号已被探测为不支持 OpenAI Responses API如 DeepSeek/Kimi 等三方兼容上游),"+
"账号本身可正常使用,但当前测试接口仅支持 Responses API 路径。请直接通过实际 API 调用验证。",
)
}
apiURL = buildOpenAIResponsesURL(normalizedBaseURL)
} else { } else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
} }

View File

@ -0,0 +1,86 @@
package service
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// TestMergeBalanceHistoryCodesIncludesAffiliateTransfersByDefault feeds
// mergeBalanceHistoryCodes two redeem records and one affiliate transfer and
// expects the first page of size 2 to contain the affiliate transfer (the
// entry with the latest timestamp) followed by the balance redeem record.
func TestMergeBalanceHistoryCodesIncludesAffiliateTransfersByDefault(t *testing.T) {
	t.Parallel()

	var (
		reference = time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC)
		earlier   = reference.Add(-2 * time.Hour)
		later     = reference.Add(time.Hour)
		userID    = int64(10)
	)

	fromRedeem := []RedeemCode{
		{
			ID:        1,
			Type:      RedeemTypeBalance,
			Value:     8,
			Status:    StatusUsed,
			UsedBy:    &userID,
			UsedAt:    &reference,
			CreatedAt: reference,
		},
		{
			ID:        2,
			Type:      RedeemTypeConcurrency,
			Value:     1,
			Status:    StatusUsed,
			UsedBy:    &userID,
			UsedAt:    &earlier,
			CreatedAt: earlier,
		},
	}
	fromAffiliate := []RedeemCode{
		{
			ID:        -20,
			Type:      RedeemTypeAffiliateBalance,
			Value:     3.5,
			Status:    StatusUsed,
			UsedBy:    &userID,
			UsedAt:    &later,
			CreatedAt: later,
		},
	}
	firstPage := pagination.PaginationParams{Page: 1, PageSize: 2}

	merged := mergeBalanceHistoryCodes(fromRedeem, fromAffiliate, firstPage)

	require.Len(t, merged, 2)
	require.Equal(t, RedeemTypeAffiliateBalance, merged[0].Type)
	require.Equal(t, RedeemTypeBalance, merged[1].Type)
}
func TestMergeBalanceHistoryCodesPaginatesAfterCombiningSources(t *testing.T) {
t.Parallel()
base := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC)
usedBy := int64(10)
at := func(hours int) *time.Time {
v := base.Add(time.Duration(hours) * time.Hour)
return &v
}
got := mergeBalanceHistoryCodes(
[]RedeemCode{
{ID: 1, Type: RedeemTypeBalance, UsedBy: &usedBy, UsedAt: at(4), CreatedAt: *at(4)},
{ID: 2, Type: RedeemTypeConcurrency, UsedBy: &usedBy, UsedAt: at(2), CreatedAt: *at(2)},
},
[]RedeemCode{
{ID: -3, Type: RedeemTypeAffiliateBalance, UsedBy: &usedBy, UsedAt: at(3), CreatedAt: *at(3)},
{ID: -4, Type: RedeemTypeAffiliateBalance, UsedBy: &usedBy, UsedAt: at(1), CreatedAt: *at(1)},
},
pagination.PaginationParams{Page: 2, PageSize: 2},
)
require.Len(t, got, 2)
require.Equal(t, RedeemTypeConcurrency, got[0].Type)
require.Equal(t, int64(-4), got[1].ID)
}

View File

@ -2,6 +2,7 @@ package service
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -188,11 +189,14 @@ type CreateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD) WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用) // 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64 AllowImageGeneration bool
ImagePrice2K *float64 ImageRateIndependent bool
ImagePrice4K *float64 ImageRateMultiplier *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 ImagePrice1K *float64
FallbackGroupID *int64 // 降级分组 ID ImagePrice2K *float64
ImagePrice4K *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID仅 anthropic 平台使用) // 无效请求兜底分组 ID仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64 FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
@ -225,11 +229,14 @@ type UpdateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD) WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用) // 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64 AllowImageGeneration *bool
ImagePrice2K *float64 ImageRateIndependent *bool
ImagePrice4K *float64 ImageRateMultiplier *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 ImagePrice1K *float64
FallbackGroupID *int64 // 降级分组 ID ImagePrice2K *float64
ImagePrice4K *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID仅 anthropic 平台使用) // 无效请求兜底分组 ID仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64 FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
@ -973,16 +980,213 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user. // GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) { func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
if codeType == RedeemTypeAffiliateBalance {
codes, total, err := s.listAffiliateBalanceHistory(ctx, userID, params)
if err != nil {
return nil, 0, 0, err
}
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
if err != nil {
return nil, 0, 0, err
}
return codes, total, totalRecharged, nil
}
if codeType == "" {
return s.getAllUserBalanceHistory(ctx, userID, params)
}
codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType) codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
total := result.Total
// Aggregate total recharged amount (only once, regardless of type filter) // Aggregate total recharged amount (only once, regardless of type filter)
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID) totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
return codes, result.Total, totalRecharged, nil return codes, total, totalRecharged, nil
}
// getAllUserBalanceHistory serves the unfiltered balance history view by
// combining redeem-code records with affiliate transfer records.
//
// It loads the leading rows of each source, lets mergeBalanceHistoryCodes cut
// the requested page out of the merged list, and returns that page together
// with the sum of both source totals and the user's total recharged amount
// (SumPositiveBalanceByUser). Any error from a source or from the sum is
// returned unchanged with zero values for the other results.
func (s *adminServiceImpl) getAllUserBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, float64, error) {
	// Each source must supply at least offset+limit rows so that the merged
	// page is complete. When the sum ends up below the limit itself
	// (presumably integer overflow or a negative offset — confirm), fall back
	// to the limit.
	fetchCount := params.Offset() + params.Limit()
	if pageLimit := params.Limit(); fetchCount < pageLimit {
		fetchCount = pageLimit
	}

	fromRedeem, redeemCount, err := s.listRedeemBalanceHistoryForMerge(ctx, userID, fetchCount)
	if err != nil {
		return nil, 0, 0, err
	}

	fromAffiliate, affiliateCount, err := s.listAffiliateBalanceHistoryForMerge(ctx, userID, fetchCount)
	if err != nil {
		return nil, 0, 0, err
	}

	page := mergeBalanceHistoryCodes(fromRedeem, fromAffiliate, params)

	recharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
	if err != nil {
		return nil, 0, 0, err
	}

	grandTotal := redeemCount + affiliateCount
	return page, grandTotal, recharged, nil
}
// listRedeemBalanceHistoryForMerge collects up to `needed` redeem-code records
// for the user (no type filter) by walking the repository in pages of 1000.
//
// It returns the collected records, truncated to `needed`, and the total row
// count last reported by the repository. A non-positive `needed` yields
// (nil, 0, nil) without touching the repository.
func (s *adminServiceImpl) listRedeemBalanceHistoryForMerge(ctx context.Context, userID int64, needed int) ([]RedeemCode, int64, error) {
	if needed <= 0 {
		return nil, 0, nil
	}

	var collected []RedeemCode
	var knownTotal int64

	pageNo := 1
	for len(collected) < needed {
		batchParams := pagination.PaginationParams{Page: pageNo, PageSize: 1000}
		batch, pageInfo, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, batchParams, "")
		if err != nil {
			return nil, 0, err
		}
		// The repository may return no pagination result; keep the previous
		// total in that case.
		if pageInfo != nil {
			knownTotal = pageInfo.Total
		}
		collected = append(collected, batch...)

		// Stop on a short page or once everything the repository reported
		// has been read.
		shortPage := len(batch) < batchParams.Limit()
		readAll := int64(len(collected)) >= knownTotal
		if shortPage || readAll {
			break
		}
		pageNo++
	}

	if len(collected) > needed {
		collected = collected[:needed]
	}
	return collected, knownTotal, nil
}
// listAffiliateBalanceHistoryForMerge collects up to `needed` affiliate
// transfer records for the user by calling listAffiliateBalanceHistory in
// pages of 1000.
//
// It returns the collected records, truncated to `needed`, and the total
// reported by the most recent page call. A non-positive `needed` yields
// (nil, 0, nil) without running any query.
func (s *adminServiceImpl) listAffiliateBalanceHistoryForMerge(ctx context.Context, userID int64, needed int) ([]RedeemCode, int64, error) {
	if needed <= 0 {
		return nil, 0, nil
	}

	var (
		gathered    []RedeemCode
		ledgerTotal int64
	)
	pageNo := 1
	for len(gathered) < needed {
		batchParams := pagination.PaginationParams{Page: pageNo, PageSize: 1000}
		batch, reportedTotal, err := s.listAffiliateBalanceHistory(ctx, userID, batchParams)
		if err != nil {
			return nil, 0, err
		}
		ledgerTotal = reportedTotal
		gathered = append(gathered, batch...)

		// A short page or having read every reported row ends the walk.
		shortPage := len(batch) < batchParams.Limit()
		readAll := int64(len(gathered)) >= ledgerTotal
		if shortPage || readAll {
			break
		}
		pageNo++
	}

	if len(gathered) > needed {
		gathered = gathered[:needed]
	}
	return gathered, ledgerTotal, nil
}
// listAffiliateBalanceHistory returns one page of the user's affiliate
// "transfer" ledger rows, newest first (created_at DESC, id DESC), together
// with the total number of such rows.
//
// Every ledger row is presented as a synthetic RedeemCode so it can be shown
// next to real redeem codes: ID is the negated ledger id, Code is
// "AFF-<ledger id>", Type is RedeemTypeAffiliateBalance, Status is StatusUsed,
// UsedBy is the user, and UsedAt/CreatedAt are the ledger row's created_at.
//
// A nil receiver, a nil ent client or a non-positive userID yields
// (nil, 0, nil).
func (s *adminServiceImpl) listAffiliateBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, error) {
	if s == nil || s.entClient == nil || userID <= 0 {
		return nil, 0, nil
	}

	const pageQuery = `
SELECT id,
amount::double precision,
created_at
FROM user_affiliate_ledger
WHERE user_id = $1
AND action = 'transfer'
ORDER BY created_at DESC, id DESC
OFFSET $2
LIMIT $3`

	rows, err := s.entClient.QueryContext(ctx, pageQuery, userID, params.Offset(), params.Limit())
	if err != nil {
		return nil, 0, err
	}
	defer func() { _ = rows.Close() }()

	entries := make([]RedeemCode, 0, params.Limit())
	for rows.Next() {
		var (
			ledgerID   int64
			transfered float64
			recordedAt time.Time
		)
		if scanErr := rows.Scan(&ledgerID, &transfered, &recordedAt); scanErr != nil {
			return nil, 0, scanErr
		}

		// Fresh copies per row so the pointers stored below do not alias
		// each other.
		owner, usedTime := userID, recordedAt
		entries = append(entries, RedeemCode{
			// Negated ledger id; presumably keeps synthetic entries from
			// colliding with real redeem code IDs — confirm with consumers.
			ID:        -ledgerID,
			Code:      fmt.Sprintf("AFF-%d", ledgerID),
			Type:      RedeemTypeAffiliateBalance,
			Value:     transfered,
			Status:    StatusUsed,
			UsedBy:    &owner,
			UsedAt:    &usedTime,
			CreatedAt: recordedAt,
		})
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, 0, iterErr
	}

	rowCount, err := countAffiliateBalanceHistory(ctx, s.entClient, userID)
	if err != nil {
		return nil, 0, err
	}
	return entries, rowCount, nil
}
// countAffiliateBalanceHistory returns how many 'transfer' rows the user has
// in user_affiliate_ledger. A missing or NULL count is reported as 0.
func countAffiliateBalanceHistory(ctx context.Context, client *dbent.Client, userID int64) (int64, error) {
	const countQuery = `
SELECT COUNT(*)
FROM user_affiliate_ledger
WHERE user_id = $1
AND action = 'transfer'`

	rows, err := client.QueryContext(ctx, countQuery, userID)
	if err != nil {
		return 0, err
	}
	defer func() { _ = rows.Close() }()

	var counted sql.NullInt64
	if rows.Next() {
		if scanErr := rows.Scan(&counted); scanErr != nil {
			return 0, scanErr
		}
	}
	if iterErr := rows.Err(); iterErr != nil {
		return 0, iterErr
	}

	if counted.Valid {
		return counted.Int64, nil
	}
	return 0, nil
}
// mergeBalanceHistoryCodes merges redeem-code history and affiliate transfer
// history into one list ordered newest-first and returns the page selected by
// params (Offset/Limit).
//
// Entries are ordered by redeemCodeHistoryTime. The sort is stable and redeem
// entries are placed ahead of affiliate entries before sorting, so entries
// with equal timestamps keep their per-source order with redeem entries first.
// The input slices are not modified.
//
// The result is never nil: an offset at or past the end of the merged list
// yields an empty slice. A negative offset or limit, or an offset+limit that
// overflows, is clamped instead of causing a slice-bounds panic.
func mergeBalanceHistoryCodes(redeemCodes, affiliateCodes []RedeemCode, params pagination.PaginationParams) []RedeemCode {
	// Single allocation sized for both sources.
	combined := make([]RedeemCode, 0, len(redeemCodes)+len(affiliateCodes))
	combined = append(combined, redeemCodes...)
	combined = append(combined, affiliateCodes...)

	sort.SliceStable(combined, func(i, j int) bool {
		return redeemCodeHistoryTime(combined[i]).After(redeemCodeHistoryTime(combined[j]))
	})

	offset := params.Offset()
	if offset < 0 {
		offset = 0
	}
	if offset >= len(combined) {
		return []RedeemCode{}
	}

	limit := params.Limit()
	if limit < 0 {
		limit = 0
	}
	end := offset + limit
	// end < offset can only happen when offset+limit overflowed.
	if end > len(combined) || end < offset {
		end = len(combined)
	}
	return combined[offset:end]
}
// redeemCodeHistoryTime returns the timestamp used to order a balance history
// entry: UsedAt when it is set, otherwise CreatedAt.
func redeemCodeHistoryTime(code RedeemCode) time.Time {
if code.UsedAt != nil {
return *code.UsedAt
}
return code.CreatedAt
} }
func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) { func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) {
@ -1359,6 +1563,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
imagePrice1K := normalizePrice(input.ImagePrice1K) imagePrice1K := normalizePrice(input.ImagePrice1K)
imagePrice2K := normalizePrice(input.ImagePrice2K) imagePrice2K := normalizePrice(input.ImagePrice2K)
imagePrice4K := normalizePrice(input.ImagePrice4K) imagePrice4K := normalizePrice(input.ImagePrice4K)
imageRateMultiplier := 1.0
if input.ImageRateMultiplier != nil {
if *input.ImageRateMultiplier < 0 {
return nil, errors.New("image_rate_multiplier must be >= 0")
}
imageRateMultiplier = *input.ImageRateMultiplier
}
// 校验降级分组 // 校验降级分组
if input.FallbackGroupID != nil { if input.FallbackGroupID != nil {
@ -1426,6 +1637,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
DailyLimitUSD: dailyLimit, DailyLimitUSD: dailyLimit,
WeeklyLimitUSD: weeklyLimit, WeeklyLimitUSD: weeklyLimit,
MonthlyLimitUSD: monthlyLimit, MonthlyLimitUSD: monthlyLimit,
AllowImageGeneration: input.AllowImageGeneration,
ImageRateIndependent: input.ImageRateIndependent,
ImageRateMultiplier: imageRateMultiplier,
ImagePrice1K: imagePrice1K, ImagePrice1K: imagePrice1K,
ImagePrice2K: imagePrice2K, ImagePrice2K: imagePrice2K,
ImagePrice4K: imagePrice4K, ImagePrice4K: imagePrice4K,
@ -1602,6 +1816,18 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD) group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD) group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
// 图片生成计费配置:负数表示清除(使用默认价格) // 图片生成计费配置:负数表示清除(使用默认价格)
if input.AllowImageGeneration != nil {
group.AllowImageGeneration = *input.AllowImageGeneration
}
if input.ImageRateIndependent != nil {
group.ImageRateIndependent = *input.ImageRateIndependent
}
if input.ImageRateMultiplier != nil {
if *input.ImageRateMultiplier < 0 {
return nil, errors.New("image_rate_multiplier must be >= 0")
}
group.ImageRateMultiplier = *input.ImageRateMultiplier
}
if input.ImagePrice1K != nil { if input.ImagePrice1K != nil {
group.ImagePrice1K = normalizePrice(input.ImagePrice1K) group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
} }

View File

@ -266,6 +266,50 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require.Nil(t, repo.updated.ImagePrice4K) require.Nil(t, repo.updated.ImagePrice4K)
} }
// TestAdminService_UpdateGroup_PreservesImageGenerationControlsWhenOmitted
// checks that an update which does not mention the image generation controls
// leaves AllowImageGeneration, ImageRateIndependent and ImageRateMultiplier
// as they were on the stored group.
func TestAdminService_UpdateGroup_PreservesImageGenerationControlsWhenOmitted(t *testing.T) {
	const storedMultiplier = 0.5

	stored := &Group{
		ID:                   1,
		Name:                 "existing-group",
		Platform:             PlatformOpenAI,
		Status:               StatusActive,
		AllowImageGeneration: true,
		ImageRateIndependent: true,
		ImageRateMultiplier:  storedMultiplier,
	}
	stub := &groupRepoStubForAdmin{getByID: stored}
	service := &adminServiceImpl{groupRepo: stub}

	// Only the description is supplied; every image control is omitted.
	input := &UpdateGroupInput{Description: "updated"}
	result, err := service.UpdateGroup(context.Background(), 1, input)

	require.NoError(t, err)
	require.NotNil(t, result)
	require.NotNil(t, stub.updated)

	persisted := stub.updated
	require.True(t, persisted.AllowImageGeneration)
	require.True(t, persisted.ImageRateIndependent)
	require.InDelta(t, 0.5, persisted.ImageRateMultiplier, 1e-12)
}
func TestAdminService_UpdateGroup_RejectsNegativeImageRateMultiplier(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "existing-group",
Platform: PlatformOpenAI,
Status: StatusActive,
ImageRateMultiplier: 1,
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
negative := -0.1
_, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
ImageRateMultiplier: &negative,
})
require.Error(t, err)
require.Nil(t, repo.updated)
}
func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) { func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
existingGroup := &Group{ existingGroup := &Group{
ID: 1, ID: 1,

View File

@ -98,7 +98,7 @@ type AffiliateRepository interface {
EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error) GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error)
GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
ThawFrozenQuota(ctx context.Context, userID int64) (float64, error) ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
@ -110,6 +110,10 @@ type AffiliateRepository interface {
SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error
BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error
ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error) ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error)
ListAffiliateInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error)
ListAffiliateRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error)
ListAffiliateTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error)
GetAffiliateUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error)
} }
// AffiliateAdminFilter 列表筛选条件 // AffiliateAdminFilter 列表筛选条件
@ -130,6 +134,76 @@ type AffiliateAdminEntry struct {
AffCount int `json:"aff_count"` AffCount int `json:"aff_count"`
} }
// AffiliateRecordFilter carries the search, paging, time-range and sorting
// options passed to the affiliate record list queries
// (ListAffiliateInviteRecords, ListAffiliateRebateRecords and
// ListAffiliateTransferRecords). The Admin* list methods pass it through
// normalizeAffiliateRecordFilter before handing it to the repository.
type AffiliateRecordFilter struct {
// Search is a search term; surrounding whitespace is trimmed on normalization.
// Which columns it matches is decided by the repository (not visible here).
Search string
// Page is the page number; values <= 0 are normalized to 1.
Page int
// PageSize is the page size; values <= 0 become 20 and values above 100 are capped at 100.
PageSize int
// StartAt is optional; presumably the lower bound of the record time range — confirm against the repository.
StartAt *time.Time
// EndAt is optional; presumably the upper bound of the record time range — confirm against the repository.
EndAt *time.Time
// SortBy names the sort key; surrounding whitespace is trimmed on normalization.
// Accepted values are defined by the repository (not visible here).
SortBy string
// SortDesc presumably selects descending order when true — confirm against the repository.
SortDesc bool
}
// AffiliateInviteRecord is one row of the list returned by
// ListAffiliateInviteRecords. It pairs an inviter with an invitee and is
// serialized to JSON with snake_case keys.
type AffiliateInviteRecord struct {
// Inviter identity.
InviterID int64 `json:"inviter_id"`
InviterEmail string `json:"inviter_email"`
InviterUsername string `json:"inviter_username"`
// Invitee identity.
InviteeID int64 `json:"invitee_id"`
InviteeEmail string `json:"invitee_email"`
InviteeUsername string `json:"invitee_username"`
// AffCode is presumably the affiliate code used for the invitation — confirm against the repository query.
AffCode string `json:"aff_code"`
// TotalRebate is presumably the rebate accrued by the inviter from this invitee — confirm against the repository query.
TotalRebate float64 `json:"total_rebate"`
// CreatedAt is the record timestamp; which event it marks is defined by the repository query.
CreatedAt time.Time `json:"created_at"`
}
// AffiliateRebateRecord is one row of the list returned by
// ListAffiliateRebateRecords. It ties a rebate amount to an order and to the
// inviter/invitee pair involved, and is serialized to JSON with snake_case keys.
type AffiliateRebateRecord struct {
// Order identifiers.
OrderID int64 `json:"order_id"`
OutTradeNo string `json:"out_trade_no"`
// Inviter identity.
InviterID int64 `json:"inviter_id"`
InviterEmail string `json:"inviter_email"`
InviterUsername string `json:"inviter_username"`
// Invitee identity.
InviteeID int64 `json:"invitee_id"`
InviteeEmail string `json:"invitee_email"`
InviteeUsername string `json:"invitee_username"`
// Monetary values; units and the exact meaning of order vs. pay amount are
// defined by the order data (not visible here).
OrderAmount float64 `json:"order_amount"`
PayAmount float64 `json:"pay_amount"`
RebateAmount float64 `json:"rebate_amount"`
// Payment method and order state as stored on the order.
PaymentType string `json:"payment_type"`
OrderStatus string `json:"order_status"`
// CreatedAt is the record timestamp; which event it marks is defined by the repository query.
CreatedAt time.Time `json:"created_at"`
}
// AffiliateTransferRecord is one row of the list returned by
// ListAffiliateTransferRecords, describing a single affiliate transfer ledger
// entry and the user it belongs to.
type AffiliateTransferRecord struct {
// Ledger entry and owning user.
LedgerID int64 `json:"ledger_id"`
UserID int64 `json:"user_id"`
UserEmail string `json:"user_email"`
Username string `json:"username"`
// Amount is the transferred amount.
Amount float64 `json:"amount"`
// The *After fields are optional and omitted from JSON when nil.
// NOTE(review): presumably they hold the balance/quota values recorded right
// after the transfer and are only set when SnapshotAvailable is true — confirm
// against the repository.
BalanceAfter *float64 `json:"balance_after,omitempty"`
AvailableQuotaAfter *float64 `json:"available_quota_after,omitempty"`
FrozenQuotaAfter *float64 `json:"frozen_quota_after,omitempty"`
HistoryQuotaAfter *float64 `json:"history_quota_after,omitempty"`
SnapshotAvailable bool `json:"snapshot_available"`
// The following fields are excluded from JSON output (json:"-").
// NOTE(review): presumably the user's current values, kept for internal use
// when no snapshot exists — confirm against the consumers.
CurrentBalance float64 `json:"-"`
RemainingQuota float64 `json:"-"`
FrozenQuota float64 `json:"-"`
HistoryQuota float64 `json:"-"`
// CreatedAt is the record timestamp.
CreatedAt time.Time `json:"created_at"`
}
// AffiliateUserOverview is the per-user affiliate summary returned by
// GetAffiliateUserOverview and exposed through AdminGetUserOverview.
type AffiliateUserOverview struct {
// User identity and affiliate code.
UserID int64 `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
AffCode string `json:"aff_code"`
// RebateRatePercent is the rebate rate reported for the user.
// AdminGetUserOverview overwrites it with the global rate when
// RebateRateCustom is false and always clamps the final value.
RebateRatePercent float64 `json:"rebate_rate_percent"`
// RebateRateCustom tells AdminGetUserOverview whether the repository supplied
// a user-specific rate; it is not serialized to JSON.
RebateRateCustom bool `json:"-"`
// Invitation counters; exact counting rules are defined by the repository (not visible here).
InvitedCount int `json:"invited_count"`
RebatedInviteeCount int `json:"rebated_invitee_count"`
// Quota figures; exact semantics are defined by the repository (not visible here).
AvailableQuota float64 `json:"available_quota"`
HistoryQuota float64 `json:"history_quota"`
}
type AffiliateService struct { type AffiliateService struct {
repo AffiliateRepository repo AffiliateRepository
settingService *SettingService settingService *SettingService
@ -238,6 +312,10 @@ func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64,
} }
func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) { func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) {
return s.AccrueInviteRebateForOrder(ctx, inviteeUserID, baseRechargeAmount, nil)
}
func (s *AffiliateService) AccrueInviteRebateForOrder(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64, sourceOrderID *int64) (float64, error) {
if s == nil || s.repo == nil { if s == nil || s.repo == nil {
return 0, nil return 0, nil
} }
@ -298,7 +376,7 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID
freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx) freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx)
} }
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours) applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours, sourceOrderID)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -488,3 +566,59 @@ func (s *AffiliateService) AdminListCustomUsers(ctx context.Context, filter Affi
} }
return s.repo.ListUsersWithCustomSettings(ctx, filter) return s.repo.ListUsersWithCustomSettings(ctx, filter)
} }
// AdminListInviteRecords returns a page of affiliate invite records and the
// total reported by the repository. The filter is normalized (paging defaults,
// trimmed search/sort) before the repository is queried. A nil service or
// repository results in a SERVICE_UNAVAILABLE error.
func (s *AffiliateService) AdminListInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error) {
	if s == nil || s.repo == nil {
		return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
	}

	normalized := normalizeAffiliateRecordFilter(filter)
	return s.repo.ListAffiliateInviteRecords(ctx, normalized)
}
// AdminListRebateRecords returns a page of affiliate rebate records and the
// total reported by the repository. The filter is normalized (paging defaults,
// trimmed search/sort) before the repository is queried. A nil service or
// repository results in a SERVICE_UNAVAILABLE error.
func (s *AffiliateService) AdminListRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error) {
	if s == nil || s.repo == nil {
		return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
	}

	normalized := normalizeAffiliateRecordFilter(filter)
	return s.repo.ListAffiliateRebateRecords(ctx, normalized)
}
// AdminListTransferRecords returns a page of affiliate transfer records and
// the total reported by the repository. The filter is normalized (paging
// defaults, trimmed search/sort) before the repository is queried. A nil
// service or repository results in a SERVICE_UNAVAILABLE error.
func (s *AffiliateService) AdminListTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error) {
	if s == nil || s.repo == nil {
		return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
	}

	normalized := normalizeAffiliateRecordFilter(filter)
	return s.repo.ListAffiliateTransferRecords(ctx, normalized)
}
// AdminGetUserOverview loads the affiliate overview for one user.
//
// A non-positive userID is rejected with INVALID_USER; a nil service or
// repository yields SERVICE_UNAVAILABLE. When the repository returns an
// overview, its RebateRatePercent is replaced by the global rate unless the
// user has a custom rate, and the final value is clamped. A nil overview from
// the repository is passed through as (nil, nil).
func (s *AffiliateService) AdminGetUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error) {
	switch {
	case userID <= 0:
		return nil, infraerrors.BadRequest("INVALID_USER", "invalid user")
	case s == nil || s.repo == nil:
		return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
	}

	overview, err := s.repo.GetAffiliateUserOverview(ctx, userID)
	if err != nil {
		return nil, err
	}
	if overview == nil {
		return nil, nil
	}

	// Users without a custom rate report the global rate.
	effectiveRate := overview.RebateRatePercent
	if !overview.RebateRateCustom {
		effectiveRate = s.globalRebateRatePercent(ctx)
	}
	overview.RebateRatePercent = clampAffiliateRebateRate(effectiveRate)
	return overview, nil
}
// normalizeAffiliateRecordFilter returns a copy of filter with safe paging
// values and trimmed text fields: Page defaults to 1, PageSize defaults to 20
// and is capped at 100, and Search and SortBy lose surrounding whitespace.
// The remaining fields are returned untouched.
func normalizeAffiliateRecordFilter(filter AffiliateRecordFilter) AffiliateRecordFilter {
	const (
		firstPage       = 1
		defaultPageSize = 20
		maxPageSize     = 100
	)

	if filter.Page < firstPage {
		filter.Page = firstPage
	}

	switch {
	case filter.PageSize <= 0:
		filter.PageSize = defaultPageSize
	case filter.PageSize > maxPageSize:
		filter.PageSize = maxPageSize
	}

	filter.SortBy = strings.TrimSpace(filter.SortBy)
	filter.Search = strings.TrimSpace(filter.Search)
	return filter
}

View File

@ -63,6 +63,9 @@ type APIKeyAuthGroupSnapshot struct {
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
AllowImageGeneration bool `json:"allow_image_generation"`
ImageRateIndependent bool `json:"image_rate_independent"`
ImageRateMultiplier float64 `json:"image_rate_multiplier"`
ImagePrice1K *float64 `json:"image_price_1k,omitempty"` ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
ImagePrice2K *float64 `json:"image_price_2k,omitempty"` ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
ImagePrice4K *float64 `json:"image_price_4k,omitempty"` ImagePrice4K *float64 `json:"image_price_4k,omitempty"`

View File

@ -14,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto" "github.com/dgraph-io/ristretto"
) )
const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot const apiKeyAuthSnapshotVersion = 8 // v8: added group image generation controls
type apiKeyAuthCacheConfig struct { type apiKeyAuthCacheConfig struct {
l1Size int l1Size int
@ -255,6 +255,9 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey)
DailyLimitUSD: apiKey.Group.DailyLimitUSD, DailyLimitUSD: apiKey.Group.DailyLimitUSD,
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
AllowImageGeneration: apiKey.Group.AllowImageGeneration,
ImageRateIndependent: apiKey.Group.ImageRateIndependent,
ImageRateMultiplier: apiKey.Group.ImageRateMultiplier,
ImagePrice1K: apiKey.Group.ImagePrice1K, ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K, ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K, ImagePrice4K: apiKey.Group.ImagePrice4K,
@ -321,6 +324,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
DailyLimitUSD: snapshot.Group.DailyLimitUSD, DailyLimitUSD: snapshot.Group.DailyLimitUSD,
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
AllowImageGeneration: snapshot.Group.AllowImageGeneration,
ImageRateIndependent: snapshot.Group.ImageRateIndependent,
ImageRateMultiplier: snapshot.Group.ImageRateMultiplier,
ImagePrice1K: snapshot.Group.ImagePrice1K, ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K, ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K, ImagePrice4K: snapshot.Group.ImagePrice4K,

View File

@ -226,6 +226,12 @@ func (s *BillingService) initFallbackPricing() {
CacheReadPricePerToken: 7.5e-8, CacheReadPricePerToken: 7.5e-8,
SupportsCacheBreakdown: false, SupportsCacheBreakdown: false,
} }
s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{
InputPricePerToken: 2e-7,
OutputPricePerToken: 1.25e-6,
CacheReadPricePerToken: 2e-8,
SupportsCacheBreakdown: false,
}
// OpenAI GPT-5.2(本地兜底) // OpenAI GPT-5.2(本地兜底)
s.fallbackPrices["gpt-5.2"] = &ModelPricing{ s.fallbackPrices["gpt-5.2"] = &ModelPricing{
InputPricePerToken: 1.75e-6, InputPricePerToken: 1.75e-6,
@ -288,13 +294,14 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
} }
// OpenAI 仅匹配已知 GPT-5/Codex 族,避免未知 OpenAI 型号误计价。 // OpenAI 仅匹配已知 GPT-5/Codex 族,避免未知 OpenAI 型号误计价。
if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") { if normalized := normalizeKnownOpenAICodexModel(modelLower); normalized != "" {
normalized := normalizeCodexModel(modelLower)
switch normalized { switch normalized {
case "gpt-5.5": case "gpt-5.5":
return s.fallbackPrices["gpt-5.5"] return s.fallbackPrices["gpt-5.5"]
case "gpt-5.4-mini": case "gpt-5.4-mini":
return s.fallbackPrices["gpt-5.4-mini"] return s.fallbackPrices["gpt-5.4-mini"]
case "gpt-5.4-nano":
return s.fallbackPrices["gpt-5.4-nano"]
case "gpt-5.4": case "gpt-5.4":
return s.fallbackPrices["gpt-5.4"] return s.fallbackPrices["gpt-5.4"]
case "gpt-5.2": case "gpt-5.2":
@ -636,13 +643,10 @@ func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens
} }
func isOpenAIGPT54Model(model string) bool { func isOpenAIGPT54Model(model string) bool {
trimmed := strings.TrimSpace(strings.ToLower(model)) // 仅当模型字符串实际属于已知 GPT-5/Codex 族时才做归一判定,避免
// 仅当模型字符串实际属于 GPT-5/Codex 族时才做归一判定,避免 normalizeCodexModel // normalizeCodexModel 的默认兜底把非 OpenAI 模型claude-*、gemini-*、gpt-4o
// 的默认兜底把非 OpenAI 模型claude-*、gemini-*、gpt-4o误识别为 gpt-5.4。 // 误识别为 gpt-5.4。
if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") { normalized := normalizeKnownOpenAICodexModel(model)
return false
}
normalized := normalizeCodexModel(trimmed)
return normalized == "gpt-5.4" || normalized == "gpt-5.5" return normalized == "gpt-5.4" || normalized == "gpt-5.5"
} }

View File

@ -137,6 +137,35 @@ func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) {
require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12) require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12)
} }
// TestGetModelPricing_OpenAICompactAliasesFallback checks that compact or
// prefixed spellings of OpenAI model names (no dash after "gpt", an "openai/"
// prefix, a run-together codex suffix) get the expected fallback prices and
// long-context threshold from GetModelPricing.
func TestGetModelPricing_OpenAICompactAliasesFallback(t *testing.T) {
	type aliasCase struct {
		model       string
		inputPrice  float64
		outputPrice float64
		cacheRead   float64
		longContext int
	}

	cases := []aliasCase{
		{model: "gpt5.5", inputPrice: 2.5e-6, outputPrice: 15e-6, cacheRead: 0.25e-6, longContext: 272000},
		{model: "openai/gpt5.4", inputPrice: 2.5e-6, outputPrice: 15e-6, cacheRead: 0.25e-6, longContext: 272000},
		{model: "gpt5.4-mini", inputPrice: 7.5e-7, outputPrice: 4.5e-6, cacheRead: 7.5e-8, longContext: 0},
		{model: "gpt5.3codexspark", inputPrice: 1.5e-6, outputPrice: 12e-6, cacheRead: 0.15e-6, longContext: 0},
	}

	svc := newTestBillingService()
	for _, tc := range cases {
		t.Run(tc.model, func(t *testing.T) {
			got, err := svc.GetModelPricing(tc.model)
			require.NoError(t, err)
			require.NotNil(t, got)

			require.InDelta(t, tc.inputPrice, got.InputPricePerToken, 1e-12)
			require.InDelta(t, tc.outputPrice, got.OutputPricePerToken, 1e-12)
			require.InDelta(t, tc.cacheRead, got.CacheReadPricePerToken, 1e-12)
			require.Equal(t, tc.longContext, got.LongContextInputThreshold)
		})
	}
}
func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) { func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) {
svc := newTestBillingService() svc := newTestBillingService()

View File

@ -52,10 +52,11 @@ const (
// Redeem type constants // Redeem type constants
const ( const (
RedeemTypeBalance = domain.RedeemTypeBalance RedeemTypeBalance = domain.RedeemTypeBalance
RedeemTypeConcurrency = domain.RedeemTypeConcurrency RedeemTypeConcurrency = domain.RedeemTypeConcurrency
RedeemTypeSubscription = domain.RedeemTypeSubscription RedeemTypeSubscription = domain.RedeemTypeSubscription
RedeemTypeInvitation = domain.RedeemTypeInvitation RedeemTypeInvitation = domain.RedeemTypeInvitation
RedeemTypeAffiliateBalance = "affiliate_balance"
) )
// PromoCode status constants // PromoCode status constants
@ -287,6 +288,9 @@ const (
// SettingKeyOverloadCooldownSettings stores JSON config for 529 overload cooldown handling. // SettingKeyOverloadCooldownSettings stores JSON config for 529 overload cooldown handling.
SettingKeyOverloadCooldownSettings = "overload_cooldown_settings" SettingKeyOverloadCooldownSettings = "overload_cooldown_settings"
// SettingKeyRateLimit429CooldownSettings stores JSON config for 429 fallback cooldown handling.
SettingKeyRateLimit429CooldownSettings = "rate_limit_429_cooldown_settings"
// ========================= // =========================
// Stream Timeout Handling // Stream Timeout Handling
// ========================= // =========================

View File

@ -8297,9 +8297,16 @@ func detachedBillingContext(ctx context.Context) (context.Context, context.Cance
} }
func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
if ctx == nil {
return context.Background(), func() {}
}
if !stream { if !stream {
return ctx, func() {} return ctx, func() {}
} }
return context.WithoutCancel(ctx), func() {}
}
func detachUpstreamContext(ctx context.Context) (context.Context, context.CancelFunc) {
if ctx == nil { if ctx == nil {
return context.Background(), func() {} return context.Background(), func() {}
} }
@ -8483,6 +8490,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
groupDefault := apiKey.Group.RateMultiplier groupDefault := apiKey.Group.RateMultiplier
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
} }
imageMultiplier := resolveImageRateMultiplier(apiKey, multiplier)
// 确定计费模型 // 确定计费模型
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
@ -8500,7 +8508,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
} }
// 计算费用 // 计算费用
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts) cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, imageMultiplier, opts)
// 判断计费方式:订阅模式 vs 余额模式 // 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
@ -8512,7 +8520,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
// 创建使用日志 // 创建使用日志
accountRateMultiplier := account.BillingRateMultiplier() accountRateMultiplier := account.BillingRateMultiplier()
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription, usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts) requestedModel, multiplier, imageMultiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
// 计算账号统计定价费用(使用最终上游模型匹配自定义规则) // 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
if apiKey.GroupID != nil { if apiKey.GroupID != nil {
@ -8566,11 +8574,12 @@ func (s *GatewayService) calculateRecordUsageCost(
apiKey *APIKey, apiKey *APIKey,
billingModel string, billingModel string,
multiplier float64, multiplier float64,
imageMultiplier float64,
opts *recordUsageOpts, opts *recordUsageOpts,
) *CostBreakdown { ) *CostBreakdown {
// 图片生成计费 // 图片生成计费
if result.ImageCount > 0 { if result.ImageCount > 0 {
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier) return s.calculateImageCost(ctx, result, apiKey, billingModel, imageMultiplier)
} }
// Token 计费 // Token 计费
@ -8611,7 +8620,8 @@ func (s *GatewayService) calculateImageCost(
Model: billingModel, Model: billingModel,
GroupID: &gid, GroupID: &gid,
Tokens: tokens, Tokens: tokens,
RequestCount: 1, RequestCount: result.ImageCount,
SizeTier: result.ImageSize,
RateMultiplier: multiplier, RateMultiplier: multiplier,
Resolver: s.resolver, Resolver: s.resolver,
Resolved: resolved, Resolved: resolved,
@ -8696,6 +8706,7 @@ func (s *GatewayService) buildRecordUsageLog(
subscription *UserSubscription, subscription *UserSubscription,
requestedModel string, requestedModel string,
multiplier float64, multiplier float64,
imageMultiplier float64,
accountRateMultiplier float64, accountRateMultiplier float64,
billingType int8, billingType int8,
cacheTTLOverridden bool, cacheTTLOverridden bool,
@ -8740,6 +8751,9 @@ func (s *GatewayService) buildRecordUsageLog(
SubscriptionID: optionalSubscriptionID(subscription), SubscriptionID: optionalSubscriptionID(subscription),
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
if result.ImageCount > 0 {
usageLog.RateMultiplier = imageMultiplier
}
if cost != nil { if cost != nil {
usageLog.InputCost = cost.InputCost usageLog.InputCost = cost.InputCost
usageLog.OutputCost = cost.OutputCost usageLog.OutputCost = cost.OutputCost

View File

@ -13,6 +13,8 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type upstreamContextTestKey string
func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) { func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
cfg := &config.Config{ cfg := &config.Config{
@ -50,3 +52,14 @@ func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testi
require.Equal(t, 3, result.usage.InputTokens) require.Equal(t, 3, result.usage.InputTokens)
require.Equal(t, 7, result.usage.OutputTokens) require.Equal(t, 7, result.usage.OutputTokens)
} }
// TestDetachUpstreamContextIgnoresClientCancel checks that the context returned
// by detachUpstreamContext reports no error after its parent has been cancelled,
// while still exposing values stored on the parent.
func TestDetachUpstreamContextIgnoresClientCancel(t *testing.T) {
parent, cancel := context.WithCancel(context.WithValue(context.Background(), upstreamContextTestKey("test-key"), "test-value"))
upstreamCtx, release := detachUpstreamContext(parent)
defer release()
// Cancel the parent first; the detached context must stay alive.
cancel()
require.NoError(t, upstreamCtx.Err())
require.Equal(t, "test-value", upstreamCtx.Value(upstreamContextTestKey("test-key")))
}

View File

@ -26,9 +26,12 @@ type Group struct {
DefaultValidityDays int DefaultValidityDays int
// 图片生成计费配置antigravity 和 gemini 平台使用) // 图片生成计费配置antigravity 和 gemini 平台使用)
ImagePrice1K *float64 AllowImageGeneration bool
ImagePrice2K *float64 ImageRateIndependent bool
ImagePrice4K *float64 ImageRateMultiplier float64
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
// Claude Code 客户端限制 // Claude Code 客户端限制
ClaudeCodeOnly bool ClaudeCodeOnly bool

View File

@ -45,19 +45,25 @@ type GroupSortOrderUpdate struct {
// CreateGroupRequest 创建分组请求 // CreateGroupRequest 创建分组请求
type CreateGroupRequest struct { type CreateGroupRequest struct {
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description"` Description string `json:"description"`
RateMultiplier float64 `json:"rate_multiplier"` RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"` IsExclusive bool `json:"is_exclusive"`
AllowImageGeneration bool `json:"allow_image_generation"`
ImageRateIndependent bool `json:"image_rate_independent"`
ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
} }
// UpdateGroupRequest 更新分组请求 // UpdateGroupRequest 更新分组请求
type UpdateGroupRequest struct { type UpdateGroupRequest struct {
Name *string `json:"name"` Name *string `json:"name"`
Description *string `json:"description"` Description *string `json:"description"`
RateMultiplier *float64 `json:"rate_multiplier"` RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"` IsExclusive *bool `json:"is_exclusive"`
Status *string `json:"status"` Status *string `json:"status"`
AllowImageGeneration *bool `json:"allow_image_generation"`
ImageRateIndependent *bool `json:"image_rate_independent"`
ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
} }
// GroupService 分组管理服务 // GroupService 分组管理服务
@ -76,6 +82,13 @@ func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthC
// Create 创建分组 // Create 创建分组
func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Group, error) { func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Group, error) {
imageRateMultiplier := 1.0
if req.ImageRateMultiplier != nil {
if *req.ImageRateMultiplier < 0 {
return nil, fmt.Errorf("image_rate_multiplier must be >= 0")
}
imageRateMultiplier = *req.ImageRateMultiplier
}
// 检查名称是否已存在 // 检查名称是否已存在
exists, err := s.groupRepo.ExistsByName(ctx, req.Name) exists, err := s.groupRepo.ExistsByName(ctx, req.Name)
if err != nil { if err != nil {
@ -87,13 +100,16 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Gro
// 创建分组 // 创建分组
group := &Group{ group := &Group{
Name: req.Name, Name: req.Name,
Description: req.Description, Description: req.Description,
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
RateMultiplier: req.RateMultiplier, RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive, IsExclusive: req.IsExclusive,
Status: StatusActive, Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard, SubscriptionType: SubscriptionTypeStandard,
AllowImageGeneration: req.AllowImageGeneration,
ImageRateIndependent: req.ImageRateIndependent,
ImageRateMultiplier: imageRateMultiplier,
} }
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
@ -165,6 +181,18 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ
if req.Status != nil { if req.Status != nil {
group.Status = *req.Status group.Status = *req.Status
} }
if req.AllowImageGeneration != nil {
group.AllowImageGeneration = *req.AllowImageGeneration
}
if req.ImageRateIndependent != nil {
group.ImageRateIndependent = *req.ImageRateIndependent
}
if req.ImageRateMultiplier != nil {
if *req.ImageRateMultiplier < 0 {
return nil, fmt.Errorf("image_rate_multiplier must be >= 0")
}
group.ImageRateMultiplier = *req.ImageRateMultiplier
}
if err := s.groupRepo.Update(ctx, group); err != nil { if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, fmt.Errorf("update group: %w", err) return nil, fmt.Errorf("update group: %w", err)

View File

@ -0,0 +1,11 @@
package service
// resolveImageRateMultiplier selects the rate multiplier applied to image usage.
//
// If the key belongs to a group with ImageRateIndependent set, the group's own
// ImageRateMultiplier is returned, with negative values clamped to 0. In every
// other case (nil key, no group, flag off) the caller-supplied
// effectiveGroupMultiplier is returned unchanged.
func resolveImageRateMultiplier(apiKey *APIKey, effectiveGroupMultiplier float64) float64 {
	if apiKey == nil || apiKey.Group == nil || !apiKey.Group.ImageRateIndependent {
		return effectiveGroupMultiplier
	}
	independent := apiKey.Group.ImageRateMultiplier
	if independent < 0 {
		return 0
	}
	return independent
}

View File

@ -0,0 +1,220 @@
package service
import (
"encoding/json"
"strings"
"github.com/tidwall/gjson"
)
// Endpoint paths and the user-facing message used by the image-generation checks.
const (
// openAIResponsesEndpoint is the Responses API path.
openAIResponsesEndpoint = "/v1/responses"
// openAIResponsesCompactEndpoint is the compact variant of the Responses API path.
openAIResponsesCompactEndpoint = "/v1/responses/compact"
// imageGenerationPermissionMessage is the text exposed through ImageGenerationPermissionMessage.
imageGenerationPermissionMessage = "Image generation is not enabled for this group"
)
// ImageGenerationPermissionMessage returns the stable end-user error text used
// when a group does not have image generation enabled. The wording comes from
// the imageGenerationPermissionMessage constant so all callers share one string.
func ImageGenerationPermissionMessage() string {
return imageGenerationPermissionMessage
}
// GroupAllowsImageGeneration reports whether image generation is permitted.
// A nil group (ungrouped key) is always allowed; otherwise the group's
// AllowImageGeneration flag decides.
func GroupAllowsImageGeneration(group *Group) bool {
	if group == nil {
		return true
	}
	return group.AllowImageGeneration
}
// IsImageGenerationIntent reports whether a raw JSON request can produce
// generated images. The checks run in order and stop at the first match:
// a dedicated image endpoint, an image model passed as requestedModel, an image
// model in the body's "model" field, an "image_generation" entry in "tools",
// and finally a "tool_choice" that selects image generation. An empty or
// invalid JSON body is classified from the endpoint and requestedModel alone.
func IsImageGenerationIntent(endpoint string, requestedModel string, body []byte) bool {
	if IsImageGenerationEndpoint(endpoint) || isOpenAIImageGenerationModel(requestedModel) {
		return true
	}
	if len(body) == 0 || !gjson.ValidBytes(body) {
		return false
	}
	bodyModel := strings.TrimSpace(gjson.GetBytes(body, "model").String())
	switch {
	case isOpenAIImageGenerationModel(bodyModel):
		return true
	case openAIJSONToolsContainImageGeneration(gjson.GetBytes(body, "tools")):
		return true
	default:
		return openAIJSONToolChoiceSelectsImageGeneration(gjson.GetBytes(body, "tool_choice"))
	}
}
// IsImageGenerationIntentMap is the counterpart of IsImageGenerationIntent for
// a request body that has already been decoded into a map. It applies the same
// ordered checks: endpoint, requestedModel, body "model", image tool presence,
// then "tool_choice". A nil map is classified from endpoint and requestedModel only.
func IsImageGenerationIntentMap(endpoint string, requestedModel string, reqBody map[string]any) bool {
	if IsImageGenerationEndpoint(endpoint) || isOpenAIImageGenerationModel(requestedModel) {
		return true
	}
	if reqBody == nil {
		return false
	}
	switch {
	case isOpenAIImageGenerationModel(firstNonEmptyString(reqBody["model"])):
		return true
	case hasOpenAIImageGenerationTool(reqBody):
		return true
	default:
		return openAIAnyToolChoiceSelectsImageGeneration(reqBody["tool_choice"])
	}
}
// IsImageGenerationEndpoint reports whether endpoint, once normalized, is one
// of the dedicated image generation or image edit paths (with or without the
// "/v1" prefix).
func IsImageGenerationEndpoint(endpoint string) bool {
	normalized := normalizeImageGenerationEndpoint(endpoint)
	switch normalized {
	case "/v1/images/generations", "/v1/images/edits", "/images/generations", "/images/edits":
		return true
	}
	return false
}
// normalizeImageGenerationEndpoint canonicalizes an endpoint for comparison:
// it lowercases and trims the input, removes a leading "https://api.openai.com",
// cuts everything from the first '?' onward, and strips trailing slashes.
// Blank input yields "".
func normalizeImageGenerationEndpoint(endpoint string) string {
	cleaned := strings.ToLower(endpoint)
	cleaned = strings.TrimSpace(cleaned)
	if cleaned == "" {
		return ""
	}
	cleaned = strings.TrimPrefix(cleaned, "https://api.openai.com")
	// Only the part before the query string matters for endpoint matching.
	path, _, _ := strings.Cut(cleaned, "?")
	return strings.TrimRight(path, "/")
}
// openAIJSONToolsContainImageGeneration reports whether the given JSON value is
// an array containing at least one element whose "type" (after trimming) is
// "image_generation". Non-array values yield false.
func openAIJSONToolsContainImageGeneration(tools gjson.Result) bool {
	if !tools.IsArray() {
		return false
	}
	for _, tool := range tools.Array() {
		if strings.TrimSpace(tool.Get("type").String()) == "image_generation" {
			return true
		}
	}
	return false
}
// openAIJSONToolChoiceSelectsImageGeneration reports whether a JSON
// "tool_choice" value selects image generation. A string value must equal
// "image_generation"; an object matches when its "type", "tool.type" or
// "function.name" equals "image_generation". Missing values and all other JSON
// types yield false.
func openAIJSONToolChoiceSelectsImageGeneration(choice gjson.Result) bool {
	if !choice.Exists() {
		return false
	}
	switch {
	case choice.Type == gjson.String:
		return strings.TrimSpace(choice.String()) == "image_generation"
	case !choice.IsObject():
		return false
	}
	for _, path := range []string{"type", "tool.type", "function.name"} {
		if strings.TrimSpace(choice.Get(path).String()) == "image_generation" {
			return true
		}
	}
	return false
}
// openAIAnyToolChoiceSelectsImageGeneration is the decoded-value counterpart of
// openAIJSONToolChoiceSelectsImageGeneration. A string choice must equal
// "image_generation"; a map choice matches when its "type", its nested
// "tool"."type" or its nested "function"."name" equals "image_generation".
// Any other dynamic type yields false.
func openAIAnyToolChoiceSelectsImageGeneration(choice any) bool {
	if name, isString := choice.(string); isString {
		return strings.TrimSpace(name) == "image_generation"
	}
	obj, isMap := choice.(map[string]any)
	if !isMap {
		return false
	}
	if strings.TrimSpace(firstNonEmptyString(obj["type"])) == "image_generation" {
		return true
	}
	if tool, ok := obj["tool"].(map[string]any); ok && strings.TrimSpace(firstNonEmptyString(tool["type"])) == "image_generation" {
		return true
	}
	fn, ok := obj["function"].(map[string]any)
	return ok && strings.TrimSpace(firstNonEmptyString(fn["name"])) == "image_generation"
}
// getAPIKeyFromContext reads the value stored under "api_key" from a
// key/value context and returns it as *APIKey. It returns nil when c is nil,
// the key is absent, or the stored value is not an *APIKey.
func getAPIKeyFromContext(c interface{ Get(string) (any, bool) }) *APIKey {
	if c == nil {
		return nil
	}
	if stored, found := c.Get("api_key"); found {
		if key, ok := stored.(*APIKey); ok {
			return key
		}
	}
	return nil
}
// apiKeyGroup returns the group attached to apiKey, or nil when apiKey is nil.
func apiKeyGroup(apiKey *APIKey) *Group {
	if apiKey != nil {
		return apiKey.Group
	}
	return nil
}
// cloneRequestMapForImageIntent decodes a JSON request body into a fresh
// map[string]any. It returns nil for an empty body or when the body cannot be
// unmarshalled into a map (invalid JSON or a non-object top-level value).
func cloneRequestMapForImageIntent(body []byte) map[string]any {
	if len(body) > 0 {
		var decoded map[string]any
		if json.Unmarshal(body, &decoded) == nil {
			return decoded
		}
	}
	return nil
}
// resolveOpenAIResponsesImageBillingConfig derives the image model and size
// tier used for billing from a decoded Responses request body.
//
// Model resolution, first non-empty value wins:
//  1. the "model" of the first "image_generation" entry in "tools";
//  2. the body's top-level "model", but only when it looks like an image model
//     (isOpenAIImageBillingModelAlias) or the request has no image tool;
//  3. "gpt-image-2" when an image tool is present;
//  4. the trimmed fallbackModel.
//
// Size resolution: the image tool's "size", else the top-level "size"; the raw
// value is then mapped through normalizeOpenAIImageSizeTier.
//
// Returns (imageModel, sizeTier, err). The error is always nil in this
// implementation.
func resolveOpenAIResponsesImageBillingConfig(reqBody map[string]any, fallbackModel string) (string, string, error) {
imageModel := ""
imageSize := ""
hasImageTool := false
if reqBody != nil {
rawTools, _ := reqBody["tools"].([]any)
for _, rawTool := range rawTools {
toolMap, ok := rawTool.(map[string]any)
if !ok || strings.TrimSpace(firstNonEmptyString(toolMap["type"])) != "image_generation" {
continue
}
// Only the first image_generation tool is consulted.
hasImageTool = true
imageModel = strings.TrimSpace(firstNonEmptyString(toolMap["model"]))
imageSize = strings.TrimSpace(firstNonEmptyString(toolMap["size"]))
break
}
// Fall back to a top-level "size" when the tool did not specify one.
if imageSize == "" {
imageSize = strings.TrimSpace(firstNonEmptyString(reqBody["size"]))
}
}
if imageModel == "" && reqBody != nil {
bodyModel := strings.TrimSpace(firstNonEmptyString(reqBody["model"]))
// With an image tool present, a non-image body model (e.g. a text model)
// must not be billed as the image model.
if isOpenAIImageBillingModelAlias(bodyModel) || !hasImageTool {
imageModel = bodyModel
}
}
// Default image model for tool-based requests that name no image model.
if imageModel == "" && hasImageTool {
imageModel = "gpt-image-2"
}
if imageModel == "" {
imageModel = strings.TrimSpace(fallbackModel)
}
sizeTier := normalizeOpenAIImageSizeTier(imageSize)
return imageModel, sizeTier, nil
}
func resolveOpenAIResponsesImageBillingConfigFromBody(body []byte, fallbackModel string) (string, string, error) {
reqBody := cloneRequestMapForImageIntent(body)
return resolveOpenAIResponsesImageBillingConfig(reqBody, fallbackModel)
}
// isOpenAIImageBillingModelAlias reports whether model should be billed as an
// image model: either isOpenAIImageGenerationModel recognizes it, or its
// lowercased, trimmed name contains "image". Blank names yield false.
func isOpenAIImageBillingModelAlias(model string) bool {
	normalized := strings.ToLower(strings.TrimSpace(model))
	switch {
	case normalized == "":
		return false
	case isOpenAIImageGenerationModel(normalized):
		return true
	default:
		return strings.Contains(normalized, "image")
	}
}

View File

@ -0,0 +1,184 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestIsImageGenerationIntent is a table-driven test of IsImageGenerationIntent
// covering the image endpoint, an image model, an image_generation tool, an
// image_generation tool_choice, and two text-only requests that must not match.
func TestIsImageGenerationIntent(t *testing.T) {
tests := []struct {
name string
endpoint string
model string
body []byte
want bool
}{
{
name: "images endpoint",
endpoint: "/v1/images/generations",
body: []byte(`{"model":"gpt-image-2"}`),
want: true,
},
{
name: "image model",
endpoint: "/v1/responses",
model: "gpt-image-2",
body: []byte(`{"model":"gpt-image-2"}`),
want: true,
},
{
name: "image tool",
endpoint: "/v1/responses",
model: "gpt-5.4",
body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation"}]}`),
want: true,
},
{
name: "image tool choice",
endpoint: "/v1/responses",
model: "gpt-5.4",
body: []byte(`{"model":"gpt-5.4","tool_choice":{"type":"image_generation"}}`),
want: true,
},
{
name: "required tool choice alone is text",
endpoint: "/v1/responses",
model: "gpt-5.4",
body: []byte(`{"model":"gpt-5.4","tool_choice":"required"}`),
want: false,
},
{
name: "text only gpt 5.4",
endpoint: "/v1/responses",
model: "gpt-5.4",
body: []byte(`{"model":"gpt-5.4","input":"write code"}`),
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, IsImageGenerationIntent(tt.endpoint, tt.model, tt.body))
})
}
}
// TestResolveOpenAIResponsesImageBillingConfigUsesCurrentBodyModel checks that
// when the image tool names no model, the body's own "model" (here containing
// "image") is billed instead of the fallback model, and 1024x1024 maps to "1K".
func TestResolveOpenAIResponsesImageBillingConfigUsesCurrentBodyModel(t *testing.T) {
imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(
[]byte(`{"model":"mapped-image-model","tools":[{"type":"image_generation","size":"1024x1024"}]}`),
"requested-model",
)
require.NoError(t, err)
require.Equal(t, "mapped-image-model", imageModel)
require.Equal(t, "1K", imageSize)
}
// TestResolveOpenAIResponsesImageBillingConfigToolModelWins checks that a model
// declared on the image_generation tool takes precedence over the body's
// top-level model, and that 1536x1024 maps to the "2K" tier.
func TestResolveOpenAIResponsesImageBillingConfigToolModelWins(t *testing.T) {
imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(
[]byte(`{"model":"mapped-text-model","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1536x1024"}]}`),
"requested-model",
)
require.NoError(t, err)
require.Equal(t, "gpt-image-2", imageModel)
require.Equal(t, "2K", imageSize)
}
// TestResolveOpenAIResponsesImageBillingConfigSupportsOfficialAndCustomSizes is
// a table-driven test asserting the size tier resolved for several image sizes,
// including a size given at the top level of the body rather than on the tool.
// Each case also requires a non-empty resolved image model.
func TestResolveOpenAIResponsesImageBillingConfigSupportsOfficialAndCustomSizes(t *testing.T) {
tests := []struct {
name string
body []byte
wantTier string
}{
{
name: "official 2k landscape",
body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-2","size":"2048x1152"}]}`),
wantTier: "2K",
},
{
name: "official 4k landscape",
body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-2","size":"3840x2160"}]}`),
wantTier: "4K",
},
{
name: "custom valid 2k",
body: []byte(`{"model":"gpt-5.5","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1280x768"}]}`),
wantTier: "2K",
},
{
name: "default image tool model supports flexible size",
body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","size":"2048x1152"}]}`),
wantTier: "2K",
},
{
name: "top level image size is moved into billing",
body: []byte(`{"model":"gpt-image-2","size":"2048x2048","tools":[{"type":"image_generation","model":"gpt-image-2"}]}`),
wantTier: "2K",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(tt.body, "requested-model")
require.NoError(t, err)
require.NotEmpty(t, imageModel)
require.Equal(t, tt.wantTier, imageSize)
})
}
}
// TestResolveOpenAIResponsesImageBillingConfigDoesNotRejectUnknownSizes checks
// that resolving billing config for "gpt-image-1.5" with size 2048x1152 returns
// no error, keeps the tool's model, and reports the "2K" tier.
func TestResolveOpenAIResponsesImageBillingConfigDoesNotRejectUnknownSizes(t *testing.T) {
imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(
[]byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-1.5","size":"2048x1152"}]}`),
"requested-model",
)
require.NoError(t, err)
require.Equal(t, "gpt-image-1.5", imageModel)
require.Equal(t, "2K", imageSize)
}
// TestOpenAIImageOutputCounterDeduplicatesFinalImages feeds a partial-image
// event, an output_item.done event for "ig_1", and a response.completed event
// listing "ig_1" and "ig_2"; the partial image must be ignored and "ig_1"
// counted once, giving a total of 2.
func TestOpenAIImageOutputCounterDeduplicatesFinalImages(t *testing.T) {
counter := newOpenAIImageOutputCounter()
counter.AddSSEData([]byte(`{"type":"response.image_generation_call.partial_image","partial_image_b64":"abc"}`))
counter.AddSSEData([]byte(`{"type":"response.output_item.done","item":{"id":"ig_1","type":"image_generation_call","result":"final-a"}}`))
counter.AddSSEData([]byte(`{"type":"response.completed","response":{"output":[{"id":"ig_1","type":"image_generation_call","result":"final-a"},{"id":"ig_2","type":"image_generation_call","result":"final-b"}]}}`))
require.Equal(t, 2, counter.Count())
}
// TestOpenAIImageOutputCounterCountsImagesAPIStreamShapes checks two things:
// three differently shaped events with distinct ids are each counted once, and
// for payloads carrying a top-level "data" array the largest array length seen
// (3) is reported rather than the sum of the lengths.
func TestOpenAIImageOutputCounterCountsImagesAPIStreamShapes(t *testing.T) {
counter := newOpenAIImageOutputCounter()
counter.AddSSEData([]byte(`{"type":"image_generation.completed","id":"ig_complete","b64_json":"final-a"}`))
counter.AddSSEData([]byte(`{"type":"response.output_item.done","item":{"id":"ig_item","type":"image_generation_call","result":"final-b"}}`))
counter.AddSSEData([]byte(`{"type":"response.completed","response":{"output":[{"id":"ig_done","type":"image_generation_call","result":"final-c"}]}}`))
require.Equal(t, 3, counter.Count())
dataCounter := newOpenAIImageOutputCounter()
dataCounter.AddSSEData([]byte(`{"data":[{"b64_json":"a"},{"b64_json":"b"}]}`))
dataCounter.AddSSEData([]byte(`{"data":[{"b64_json":"a"},{"b64_json":"b"},{"b64_json":"c"}]}`))
require.Equal(t, 3, dataCounter.Count())
}
// TestOpenAIImageOutputCounterCountsMultilineSSEDataPayload checks that a JSON
// payload containing an embedded newline between fields is still parsed and
// counted as one image.
func TestOpenAIImageOutputCounterCountsMultilineSSEDataPayload(t *testing.T) {
counter := newOpenAIImageOutputCounter()
counter.AddSSEData([]byte("{\"type\":\"image_generation.completed\",\n\"b64_json\":\"final-a\"}"))
require.Equal(t, 1, counter.Count())
}
// TestOpenAIImageOutputCounterCountsMultilineSSEBodyPayload checks that an SSE
// event whose JSON payload is split across two consecutive "data:" lines is
// counted as one image, and that the trailing "[DONE]" event adds nothing.
func TestOpenAIImageOutputCounterCountsMultilineSSEBodyPayload(t *testing.T) {
counter := newOpenAIImageOutputCounter()
counter.AddSSEBody(
"data: {\"type\":\"image_generation.completed\",\n" +
"data: \"b64_json\":\"final-a\"}\n\n" +
"data: [DONE]\n\n",
)
require.Equal(t, 1, counter.Count())
}
// TestOpenAIImageOutputCounterFallsBackForInvalidMultilineSSEBody checks that
// two complete JSON objects on consecutive "data:" lines within one event
// (which do not form valid JSON when joined) are still counted as two images.
func TestOpenAIImageOutputCounterFallsBackForInvalidMultilineSSEBody(t *testing.T) {
counter := newOpenAIImageOutputCounter()
counter.AddSSEBody(
"data: {\"type\":\"image_generation.completed\",\"b64_json\":\"final-a\"}\n" +
"data: {\"type\":\"image_generation.completed\",\"b64_json\":\"final-b\"}\n\n",
)
require.Equal(t, 2, counter.Count())
}

View File

@ -0,0 +1,149 @@
package service
import (
"crypto/sha256"
"encoding/hex"
"strings"
"github.com/tidwall/gjson"
)
// openAIImageOutputCounter tallies generated images observed in upstream
// responses, deduplicating items that appear in more than one event.
type openAIImageOutputCounter struct {
// seen holds the dedup keys (item id, call_id, or result hash) already counted.
seen map[string]struct{}
// count is the number of distinct image output items recorded via seen.
count int
// maxDataCount is the largest top-level "data" array length observed.
maxDataCount int
}
// newOpenAIImageOutputCounter returns a counter with an initialized dedup set
// and zeroed totals.
func newOpenAIImageOutputCounter() *openAIImageOutputCounter {
return &openAIImageOutputCounter{seen: make(map[string]struct{})}
}
// Count returns the number of images observed so far: the larger of the
// deduplicated item count and the biggest "data" array length seen. A nil
// receiver reports 0.
func (c *openAIImageOutputCounter) Count() int {
	if c == nil {
		return 0
	}
	total := c.count
	if c.maxDataCount > total {
		total = c.maxDataCount
	}
	return total
}
// AddJSONResponse records images found in a non-streaming JSON response body.
// It inspects the top-level "data" array and the "output" and
// "response.output" arrays. Nil receivers and empty or invalid JSON bodies are
// ignored.
func (c *openAIImageOutputCounter) AddJSONResponse(body []byte) {
	if c == nil || len(body) == 0 || !gjson.ValidBytes(body) {
		return
	}
	c.addDataArray(gjson.GetBytes(body, "data"))
	for _, path := range []string{"output", "response.output"} {
		c.addOutputArray(gjson.GetBytes(body, path))
	}
}
// AddSSEData records images found in a single SSE data payload.
//
// Nil receivers, empty payloads, the "[DONE]" sentinel and invalid JSON are
// ignored. A top-level "data" array is always inspected; the remaining handling
// depends on the payload's "type":
//   - "response.output_item.done": the "item" object is considered;
//   - "response.completed" / "response.done": every entry of "response.output";
//   - "image_generation.completed": "item" if present, else "output" if
//     present, else the payload root itself.
//
// Other event types contribute nothing beyond the "data" array check.
func (c *openAIImageOutputCounter) AddSSEData(data []byte) {
if c == nil || len(data) == 0 || strings.TrimSpace(string(data)) == "[DONE]" || !gjson.ValidBytes(data) {
return
}
root := gjson.ParseBytes(data)
c.addDataArray(root.Get("data"))
eventType := strings.TrimSpace(root.Get("type").String())
switch eventType {
case "response.output_item.done":
c.addImageOutputItem(root.Get("item"))
case "response.completed", "response.done":
c.addOutputArray(root.Get("response.output"))
case "image_generation.completed":
// Prefer a nested item/output object; fall back to the event root.
if item := root.Get("item"); item.Exists() {
c.addImageOutputItem(item)
return
}
if output := root.Get("output"); output.Exists() {
c.addImageOutputItem(output)
return
}
c.addImageOutputItem(root)
}
}
// AddSSEBody records images found in a complete SSE response body by passing
// each data payload extracted by forEachOpenAISSEDataPayload to AddSSEData.
// Nil receivers and blank bodies are ignored.
func (c *openAIImageOutputCounter) AddSSEBody(body string) {
	if c != nil && strings.TrimSpace(body) != "" {
		forEachOpenAISSEDataPayload(body, c.AddSSEData)
	}
}
// addDataArray raises maxDataCount to the length of data when data is a JSON
// array longer than any seen before. Non-array values are ignored.
func (c *openAIImageOutputCounter) addDataArray(data gjson.Result) {
	if data.IsArray() {
		if n := len(data.Array()); n > c.maxDataCount {
			c.maxDataCount = n
		}
	}
}
// addOutputArray passes every element of a JSON output array to
// addImageOutputItem. Non-array values are ignored.
func (c *openAIImageOutputCounter) addOutputArray(output gjson.Result) {
	if !output.IsArray() {
		return
	}
	for _, entry := range output.Array() {
		c.addImageOutputItem(entry)
	}
}
// addImageOutputItem counts one final image output item, at most once.
//
// An item is skipped when it is not a JSON object, when its "type" is set to
// something other than "image_generation_call" or "image_generation.completed",
// when its raw JSON mentions "partial_image" (case-insensitive), or when it has
// no "result", "b64_json" or "url" (unless its type is
// "image_generation.completed").
//
// The dedup key is the item's "id", else its "call_id", else a SHA-256 hash of
// the result value. Items with no usable key, or whose key was already seen,
// are not counted.
func (c *openAIImageOutputCounter) addImageOutputItem(item gjson.Result) {
if !item.Exists() || !item.IsObject() {
return
}
itemType := strings.TrimSpace(item.Get("type").String())
if itemType != "" && itemType != "image_generation_call" && itemType != "image_generation.completed" {
return
}
// Partial (preview) images are not counted.
if strings.Contains(strings.ToLower(item.Raw), "partial_image") {
return
}
// The image payload may be carried under any of these three fields.
result := strings.TrimSpace(item.Get("result").String())
if result == "" {
result = strings.TrimSpace(item.Get("b64_json").String())
}
if result == "" {
result = strings.TrimSpace(item.Get("url").String())
}
if result == "" && itemType != "image_generation.completed" {
return
}
// Build the dedup key: id, then call_id, then a hash of the payload.
key := strings.TrimSpace(item.Get("id").String())
if key == "" {
key = strings.TrimSpace(item.Get("call_id").String())
}
if key == "" {
key = hashOpenAIImageOutputResult(result)
}
if key == "" {
return
}
if _, exists := c.seen[key]; exists {
return
}
c.seen[key] = struct{}{}
c.count++
}
// hashOpenAIImageOutputResult returns the lowercase hex SHA-256 digest of the
// whitespace-trimmed result, or "" when the trimmed result is empty.
func hashOpenAIImageOutputResult(result string) string {
	if trimmed := strings.TrimSpace(result); trimmed != "" {
		digest := sha256.Sum256([]byte(trimmed))
		return hex.EncodeToString(digest[:])
	}
	return ""
}
// countOpenAIResponseImageOutputsFromJSONBytes returns the number of images
// found in a single JSON response body, using a fresh counter.
func countOpenAIResponseImageOutputsFromJSONBytes(body []byte) int {
counter := newOpenAIImageOutputCounter()
counter.AddJSONResponse(body)
return counter.Count()
}
// countOpenAIImageOutputsFromSSEBody returns the number of images found in a
// complete SSE response body, using a fresh counter.
func countOpenAIImageOutputsFromSSEBody(body string) int {
counter := newOpenAIImageOutputCounter()
counter.AddSSEBody(body)
return counter.Count()
}

View File

@ -5,6 +5,7 @@ import (
"context" "context"
"fmt" "fmt"
"hash/fnv" "hash/fnv"
"log/slog"
"math" "math"
"sort" "sort"
"strconv" "strconv"
@ -345,7 +346,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil return nil, nil
} }
if !s.isAccountRequestCompatible(account, req) { if !s.isAccountRequestCompatible(ctx, account, req) {
return nil, nil return nil, nil
} }
if !s.isAccountTransportCompatible(account, req.RequiredTransport) { if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
@ -621,7 +622,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue continue
} }
if !s.isAccountRequestCompatible(account, req) { if !s.isAccountRequestCompatible(ctx, account, req) {
continue continue
} }
if !s.isAccountTransportCompatible(account, req.RequiredTransport) { if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
@ -828,11 +829,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
for i := 0; i < len(selectionOrder); i++ { for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i] candidate := selectionOrder[i]
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false) fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue continue
} }
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false) fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue continue
} }
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 { if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
@ -859,11 +860,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
// WaitPlan.MaxConcurrency 使用 Concurrency非 EffectiveLoadFactor因为 WaitPlan 控制的是 Redis 实际并发槽位等待。 // WaitPlan.MaxConcurrency 使用 Concurrency非 EffectiveLoadFactor因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
for _, candidate := range selectionOrder { for _, candidate := range selectionOrder {
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false) fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue continue
} }
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false) fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue continue
} }
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 { if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
@ -894,13 +895,18 @@ func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Ac
return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport) return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport)
} }
func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(account *Account, req OpenAIAccountScheduleRequest) bool { func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.Context, account *Account, req OpenAIAccountScheduleRequest) bool {
if account == nil { if account == nil {
return false return false
} }
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
return false return false
} }
if req.GroupID != nil && s != nil && s.service != nil &&
s.service.needsUpstreamChannelRestrictionCheck(ctx, req.GroupID) &&
s.service.isUpstreamModelRestrictedByChannel(ctx, *req.GroupID, account, req.RequestedModel, req.RequireCompact) {
return false
}
return account.SupportsOpenAIImageCapability(req.RequiredImageCapability) return account.SupportsOpenAIImageCapability(req.RequiredImageCapability)
} }
@ -1112,6 +1118,13 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
} }
} }
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
"model", requestedModel)
return nil, decision, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
var stickyAccountID int64 var stickyAccountID int64
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && accountID > 0 { if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && accountID > 0 {

View File

@ -0,0 +1,149 @@
package service
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
)
// openaiResponsesProbeTimeout is the timeout applied to the probe request.
// The probe must fail fast: a timeout should never block the account
// create/update flow.
const openaiResponsesProbeTimeout = 8 * time.Second
// openaiResponsesProbePayload builds the minimal Responses request body used
// by the probe. It is a capability probe only, so the quality of the response
// content does not matter; "stream" is false to avoid SSE parsing overhead.
// A blank modelID is replaced with openai.DefaultTestModel.
//
// Note: the probe only needs to tell "endpoint exists" from "endpoint does not
// exist". Any non-404 4xx/5xx from upstream (for example 400
// invalid_request_error, 401 unauthorized, 422) is treated as "endpoint exists,
// Responses supported". Only 404 / 405 are treated as "endpoint does not exist".
func openaiResponsesProbePayload(modelID string) []byte {
if strings.TrimSpace(modelID) == "" {
modelID = openai.DefaultTestModel
}
// The Marshal error is discarded: the payload is built from fixed keys and
// plain string/bool values.
body, _ := json.Marshal(map[string]any{
"model": modelID,
"input": []map[string]any{
{
"role": "user",
"content": []map[string]any{
{"type": "input_text", "text": "hi"},
},
},
},
"instructions": openai.DefaultInstructions,
"stream": false,
})
return body
}
// ProbeOpenAIAPIKeyResponsesSupport probes whether the upstream of an OpenAI
// API-key account supports the /v1/responses endpoint and persists the result
// in the account's extra data under openai_compat.ExtraKeyResponsesSupported.
//
// When to call: after an account is created or updated, and only when
// platform=openai && type=apikey (other accounts return immediately).
//
// Probe policy (see the package docs of internal/pkg/openai_compat):
//   - upstream 404 / 405 → not supported, writes false
//   - upstream 2xx / other 4xx (401/422/400 etc.) / 5xx → supported, writes true
//   - network-level failure (connection error, timeout) → nothing is written,
//     the flag stays unknown (later requests keep defaulting to Responses)
//
// The method is idempotent: repeated calls overwrite the flag with the latest
// probe result.
//
// Failure handling: a probe failure must not block account creation. The probe
// result only influences later routing, so every error is logged and never
// returned to the caller.
func (s *AccountTestService) ProbeOpenAIAPIKeyResponsesSupport(ctx context.Context, accountID int64) {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
logger.LegacyPrintf("service.openai_probe", "probe_load_account_failed: account_id=%d err=%v", accountID, err)
return
}
if account.Platform != PlatformOpenAI || account.Type != AccountTypeAPIKey {
// Only OpenAI API-key accounts need probing; other account types have no capability difference.
return
}
apiKey := account.GetOpenAIApiKey()
if apiKey == "" {
logger.LegacyPrintf("service.openai_probe", "probe_skip_no_apikey: account_id=%d", accountID)
return
}
// Default to the official OpenAI host when the account sets no base URL.
baseURL := account.GetOpenAIBaseURL()
if baseURL == "" {
baseURL = "https://api.openai.com"
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
logger.LegacyPrintf("service.openai_probe", "probe_invalid_baseurl: account_id=%d base_url=%q err=%v", accountID, baseURL, err)
return
}
probeURL := buildOpenAIResponsesURL(normalizedBaseURL)
// Bound the probe with its own timeout, derived from the caller's context.
probeCtx, cancel := context.WithTimeout(ctx, openaiResponsesProbeTimeout)
defer cancel()
req, err := http.NewRequestWithContext(probeCtx, http.MethodPost, probeURL, bytes.NewReader(openaiResponsesProbePayload("")))
if err != nil {
logger.LegacyPrintf("service.openai_probe", "probe_build_request_failed: account_id=%d err=%v", accountID, err)
return
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Accept", "application/json")
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
// Network-level failure: write no flag and stay unknown; a later retry or the gateway fallback handles it.
logger.LegacyPrintf("service.openai_probe", "probe_request_failed: account_id=%d url=%s err=%v", accountID, probeURL, err)
return
}
// Drain at most 1 MiB of the body before closing it.
defer func() {
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1<<20))
_ = resp.Body.Close()
}()
supported := isResponsesEndpointSupportedByStatus(resp.StatusCode)
// Persisting uses the caller's ctx rather than the probe's timeout context.
if err := s.accountRepo.UpdateExtra(ctx, accountID, map[string]any{
openai_compat.ExtraKeyResponsesSupported: supported,
}); err != nil {
logger.LegacyPrintf("service.openai_probe", "probe_persist_failed: account_id=%d supported=%v err=%v", accountID, supported, err)
return
}
logger.LegacyPrintf("service.openai_probe",
"probe_done: account_id=%d base_url=%s status=%d supported=%v",
accountID, normalizedBaseURL, resp.StatusCode, supported,
)
}
// isResponsesEndpointSupportedByStatus decides, from the HTTP status code of
// the probe response, whether the upstream exposes the /v1/responses endpoint.
//
// Key observation: third-party OpenAI-compatible upstreams (DeepSeek, Kimi,
// etc.) answer unknown endpoints with 404 or 405, whereas OpenAI itself and
// upstreams that implement Responses reject the minimal probe body with a
// business error such as 400/422 while the endpoint still exists.
//
// Therefore only 404 and 405 mean "endpoint missing"; every other status means
// "endpoint present". That includes 5xx, so an occasional upstream failure is
// not misread as lack of support.
func isResponsesEndpointSupportedByStatus(status int) bool {
	missing := status == http.StatusNotFound || status == http.StatusMethodNotAllowed
	return !missing
}

View File

@ -38,6 +38,29 @@ var codexModelMap = map[string]string{
"gpt-5.2-medium": "gpt-5.2", "gpt-5.2-medium": "gpt-5.2",
"gpt-5.2-high": "gpt-5.2", "gpt-5.2-high": "gpt-5.2",
"gpt-5.2-xhigh": "gpt-5.2", "gpt-5.2-xhigh": "gpt-5.2",
"gpt-5": "gpt-5.4",
"gpt-5-mini": "gpt-5.4",
"gpt-5-nano": "gpt-5.4",
"gpt-5.1": "gpt-5.4",
"gpt-5.1-codex": "gpt-5.3-codex",
"gpt-5.1-codex-max": "gpt-5.3-codex",
"gpt-5.1-codex-mini": "gpt-5.3-codex",
"gpt-5.2-codex": "gpt-5.2",
"codex-mini-latest": "gpt-5.3-codex",
"gpt-5-codex": "gpt-5.3-codex",
}
var codexVersionModelPrefixes = []struct {
prefix string
target string
}{
{prefix: "gpt-5.3-codex-spark", target: "gpt-5.3-codex-spark"},
{prefix: "gpt-5.3-codex", target: "gpt-5.3-codex"},
{prefix: "gpt-5.4-mini", target: "gpt-5.4-mini"},
{prefix: "gpt-5.4-nano", target: "gpt-5.4-nano"},
{prefix: "gpt-5.5", target: "gpt-5.5"},
{prefix: "gpt-5.4", target: "gpt-5.4"},
{prefix: "gpt-5.2", target: "gpt-5.2"},
} }
type codexTransformResult struct { type codexTransformResult struct {
@ -46,6 +69,13 @@ type codexTransformResult struct {
PromptCacheKey string PromptCacheKey string
} }
// codexOAuthTransformOptions bundles the switches accepted by
// applyCodexOAuthTransformWithOptions. The zero value reproduces a call to
// applyCodexOAuthTransform with isCodexCLI=false and isCompact=false.
type codexOAuthTransformOptions struct {
// IsCodexCLI is forwarded to applyInstructions when default instructions
// are applied.
IsCodexCLI bool
// IsCompact selects the compact-request branch of the transform; the part
// of that branch visible here removes the "store" field from the body.
IsCompact bool
// SkipDefaultInstructions, when true, prevents the transform from calling
// applyInstructions at all.
SkipDefaultInstructions bool
// PreserveToolCallIDs is passed to the input filter as
// codexInputFilterOptions.PreserveCallIDs.
PreserveToolCallIDs bool
}
const ( const (
codexImageGenerationBridgeMarker = "<sub2api-codex-image-generation>" codexImageGenerationBridgeMarker = "<sub2api-codex-image-generation>"
codexImageGenerationBridgeText = codexImageGenerationBridgeMarker + "\nWhen the user asks for raster image generation or editing, use the OpenAI Responses native `image_generation` tool attached to this request. The local Codex client may not expose an `image_gen` namespace, but that does not mean image generation is unavailable. Do not ask the user to switch to CLI fallback solely because `image_gen` is absent.\n</sub2api-codex-image-generation>" codexImageGenerationBridgeText = codexImageGenerationBridgeMarker + "\nWhen the user asks for raster image generation or editing, use the OpenAI Responses native `image_generation` tool attached to this request. The local Codex client may not expose an `image_gen` namespace, but that does not mean image generation is unavailable. Do not ask the user to switch to CLI fallback solely because `image_gen` is absent.\n</sub2api-codex-image-generation>"
@ -71,6 +101,13 @@ var openAICodexOAuthUnsupportedFields = append([]string{
}, openAIChatGPTInternalUnsupportedFields...) }, openAIChatGPTInternalUnsupportedFields...)
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult { func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult {
return applyCodexOAuthTransformWithOptions(reqBody, codexOAuthTransformOptions{
IsCodexCLI: isCodexCLI,
IsCompact: isCompact,
})
}
func applyCodexOAuthTransformWithOptions(reqBody map[string]any, opts codexOAuthTransformOptions) codexTransformResult {
result := codexTransformResult{} result := codexTransformResult{}
// 工具续链需求会影响存储策略与 input 过滤逻辑。 // 工具续链需求会影响存储策略与 input 过滤逻辑。
needsToolContinuation := NeedsToolContinuation(reqBody) needsToolContinuation := NeedsToolContinuation(reqBody)
@ -88,7 +125,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
result.NormalizedModel = normalizedModel result.NormalizedModel = normalizedModel
} }
if isCompact { if opts.IsCompact {
if _, ok := reqBody["store"]; ok { if _, ok := reqBody["store"]; ok {
delete(reqBody, "store") delete(reqBody, "store")
result.Modified = true result.Modified = true
@ -160,6 +197,10 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
if v, ok := reqBody["prompt_cache_key"].(string); ok { if v, ok := reqBody["prompt_cache_key"].(string); ok {
result.PromptCacheKey = strings.TrimSpace(v) result.PromptCacheKey = strings.TrimSpace(v)
if isOpenAICompatMessagesBridgeRequestBody(reqBody) {
delete(reqBody, "prompt_cache_key")
result.Modified = true
}
} }
// 提取 input 中 role:"system" 消息至 instructionsOAuth 上游不支持 system role // 提取 input 中 role:"system" 消息至 instructionsOAuth 上游不支持 system role
@ -168,7 +209,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
} }
// instructions 处理逻辑:根据是否是 Codex CLI 分别调用不同方法 // instructions 处理逻辑:根据是否是 Codex CLI 分别调用不同方法
if applyInstructions(reqBody, isCodexCLI) { if !opts.SkipDefaultInstructions && applyInstructions(reqBody, opts.IsCodexCLI) {
result.Modified = true result.Modified = true
} }
if isCodexSparkModel(normalizedModel) && applyCodexSparkImageUnsupportedInstructions(reqBody) { if isCodexSparkModel(normalizedModel) && applyCodexSparkImageUnsupportedInstructions(reqBody) {
@ -185,7 +226,10 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
input = normalizedInput input = normalizedInput
result.Modified = true result.Modified = true
} }
input = filterCodexInput(input, needsToolContinuation) input = filterCodexInputWithOptions(input, codexInputFilterOptions{
PreserveReferences: needsToolContinuation,
PreserveCallIDs: opts.PreserveToolCallIDs,
})
reqBody["input"] = input reqBody["input"] = input
result.Modified = true result.Modified = true
} else if inputStr, ok := reqBody["input"].(string); ok { } else if inputStr, ok := reqBody["input"].(string); ok {
@ -447,51 +491,81 @@ func normalizeCodexModel(model string) string {
if model == "" { if model == "" {
return "gpt-5.4" return "gpt-5.4"
} }
if mapped, ok := normalizeKnownCodexModel(model); ok {
return mapped
}
return model
}
func normalizeKnownCodexModel(model string) (string, bool) {
model = strings.TrimSpace(model)
if model == "" {
return "", false
}
if isOpenAIImageGenerationModel(model) { if isOpenAIImageGenerationModel(model) {
return model return model, true
} }
modelID := model modelID := lastOpenAIModelSegment(model)
if normalized := canonicalizeOpenAIModelAliasSpelling(modelID); normalized != "" {
modelID = normalized
}
if mapped := normalizeKnownOpenAICodexModel(modelID); mapped != "" {
return mapped, true
}
key := codexModelLookupKey(modelID)
if key == "" {
return "", false
}
if mapped := getNormalizedCodexModel(key); mapped != "" {
return mapped, true
}
for _, item := range codexVersionModelPrefixes {
if key == item.prefix {
return item.target, true
}
suffix, ok := strings.CutPrefix(key, item.prefix+"-")
if ok && isKnownCodexModelSuffix(suffix) {
return item.target, true
}
}
return "", false
}
func codexModelLookupKey(modelID string) string {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
return ""
}
if strings.Contains(modelID, "/") { if strings.Contains(modelID, "/") {
parts := strings.Split(modelID, "/") parts := strings.Split(modelID, "/")
modelID = parts[len(parts)-1] modelID = parts[len(parts)-1]
} }
return strings.ToLower(strings.Join(strings.Fields(modelID), "-"))
}
if mapped := getNormalizedCodexModel(modelID); mapped != "" { func isKnownCodexModelSuffix(suffix string) bool {
return mapped switch suffix {
case "none", "minimal", "low", "medium", "high", "xhigh":
return true
} }
return isCodexDateSuffix(suffix)
}
normalized := strings.ToLower(modelID) func isCodexDateSuffix(suffix string) bool {
parts := strings.Split(suffix, "-")
if strings.Contains(normalized, "gpt-5.5") || strings.Contains(normalized, "gpt 5.5") { if len(parts) != 3 || len(parts[0]) != 4 || len(parts[1]) != 2 || len(parts[2]) != 2 {
return "gpt-5.5" return false
} }
if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") { for _, part := range parts {
return "gpt-5.4-mini" for _, r := range part {
if r < '0' || r > '9' {
return false
}
}
} }
if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") { return true
return "gpt-5.4"
}
if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
return "gpt-5.2"
}
if strings.Contains(normalized, "gpt-5.3-codex-spark") || strings.Contains(normalized, "gpt 5.3 codex spark") {
return "gpt-5.3-codex-spark"
}
if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") {
return "gpt-5.3-codex"
}
if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") {
return "gpt-5.3-codex"
}
if strings.Contains(normalized, "codex") {
return "gpt-5.3-codex"
}
if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") {
return "gpt-5.4"
}
return "gpt-5.4"
} }
func isCodexSparkModel(model string) bool { func isCodexSparkModel(model string) bool {
@ -789,23 +863,18 @@ func SupportsVerbosity(model string) bool {
} }
func getNormalizedCodexModel(modelID string) string { func getNormalizedCodexModel(modelID string) string {
if modelID == "" { key := codexModelLookupKey(modelID)
if key == "" {
return "" return ""
} }
if mapped, ok := codexModelMap[modelID]; ok { if mapped, ok := codexModelMap[key]; ok {
return mapped return mapped
} }
lower := strings.ToLower(modelID)
for key, value := range codexModelMap {
if strings.ToLower(key) == lower {
return value
}
}
return "" return ""
} }
// extractTextFromContent extracts plain text from a content value that is either // extractTextFromContent extracts plain text from a content value that is either
// a Go string or a []any of content-part maps with type:"text". // a Go string or a []any of text-like content-part maps.
func extractTextFromContent(content any) string { func extractTextFromContent(content any) string {
switch v := content.(type) { switch v := content.(type) {
case string: case string:
@ -817,7 +886,8 @@ func extractTextFromContent(content any) string {
if !ok { if !ok {
continue continue
} }
if t, _ := m["type"].(string); t == "text" { switch t, _ := m["type"].(string); t {
case "text", "input_text", "output_text":
if text, ok := m["text"].(string); ok { if text, ok := m["text"].(string); ok {
parts = append(parts, text) parts = append(parts, text)
} }
@ -871,6 +941,28 @@ func extractSystemMessagesFromInput(reqBody map[string]any) bool {
return true return true
} }
// extractPromptLikeInstructionsFromInput collects the text of every item in
// reqBody["input"] whose role is "developer" or "system" and returns the
// trimmed, non-empty texts joined by a blank line ("\n\n").
//
// It returns "" when "input" is missing, is not a []any, or is empty, and it
// silently skips items that are not map[string]any. reqBody is only read,
// never modified.
func extractPromptLikeInstructionsFromInput(reqBody map[string]any) string {
	items, isSlice := reqBody["input"].([]any)
	if !isSlice || len(items) == 0 {
		return ""
	}

	var collected []string
	for _, raw := range items {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		// Only prompt-like roles contribute; a missing or non-string role
		// yields "" and is skipped as well.
		if role, _ := entry["role"].(string); role != "developer" && role != "system" {
			continue
		}
		text := strings.TrimSpace(extractTextFromContent(entry["content"]))
		if text == "" {
			continue
		}
		collected = append(collected, text)
	}
	return strings.Join(collected, "\n\n")
}
// applyInstructions 处理 instructions 字段:仅在 instructions 为空时填充默认值。 // applyInstructions 处理 instructions 字段:仅在 instructions 为空时填充默认值。
func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool { func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
if !isInstructionsEmpty(reqBody) { if !isInstructionsEmpty(reqBody) {
@ -897,9 +989,20 @@ func isInstructionsEmpty(reqBody map[string]any) bool {
return strings.TrimSpace(str) == "" return strings.TrimSpace(str) == ""
} }
// codexInputFilterOptions controls filterCodexInputWithOptions.
type codexInputFilterOptions struct {
// PreserveReferences keeps item_reference entries and item ids; when false,
// item_reference entries are dropped and "id" is deleted from items.
PreserveReferences bool
// PreserveCallIDs, when true, makes the call-id prefix fix-up return ids
// unchanged instead of rewriting them.
PreserveCallIDs bool
}
// filterCodexInput 按需过滤 item_reference 与 id。 // filterCodexInput 按需过滤 item_reference 与 id。
// preserveReferences 为 true 时保持引用与 id以满足续链请求对上下文的依赖。 // preserveReferences 为 true 时保持引用与 id以满足续链请求对上下文的依赖。
func filterCodexInput(input []any, preserveReferences bool) []any { func filterCodexInput(input []any, preserveReferences bool) []any {
return filterCodexInputWithOptions(input, codexInputFilterOptions{
PreserveReferences: preserveReferences,
})
}
func filterCodexInputWithOptions(input []any, opts codexInputFilterOptions) []any {
filtered := make([]any, 0, len(input)) filtered := make([]any, 0, len(input))
for _, item := range input { for _, item := range input {
m, ok := item.(map[string]any) m, ok := item.(map[string]any)
@ -920,6 +1023,9 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
// 仅修正真正的 tool/function call 标识,避免误改普通 message/reasoning id // 仅修正真正的 tool/function call 标识,避免误改普通 message/reasoning id
// 若 item_reference 指向 legacy call_* 标识,则仅修正该引用本身。 // 若 item_reference 指向 legacy call_* 标识,则仅修正该引用本身。
fixCallIDPrefix := func(id string) string { fixCallIDPrefix := func(id string) string {
if opts.PreserveCallIDs {
return id
}
if id == "" || strings.HasPrefix(id, "fc") { if id == "" || strings.HasPrefix(id, "fc") {
return id return id
} }
@ -930,7 +1036,7 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
} }
if typ == "item_reference" { if typ == "item_reference" {
if !preserveReferences { if !opts.PreserveReferences {
continue continue
} }
newItem := make(map[string]any, len(m)) newItem := make(map[string]any, len(m))
@ -998,7 +1104,7 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
} }
} }
if !preserveReferences { if !opts.PreserveReferences {
ensureCopy() ensureCopy()
delete(newItem, "id") delete(newItem, "id")
} }

View File

@ -44,6 +44,39 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
require.Equal(t, "fc1", second["call_id"]) require.Equal(t, "fc1", second["call_id"])
} }
// TestApplyCodexOAuthTransform_MessagesBridgePromptCacheKeyIsHeaderOnly checks
// that, for a request whose developer message carries the Claude Code
// todo-guard marker, the transform still reports the prompt cache key in its
// result but removes "prompt_cache_key" from the request body.
// NOTE(review): per the test name the key is presumably sent as a header
// instead; that is not visible in this test.
func TestApplyCodexOAuthTransform_MessagesBridgePromptCacheKeyIsHeaderOnly(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.5",
"prompt_cache_key": "anthropic-metadata-session-1",
"input": []any{
map[string]any{
"type": "message",
"role": "developer",
"content": []any{
map[string]any{
"type": "input_text",
"text": openAICompatClaudeCodeTodoGuardMarker,
},
},
},
map[string]any{
"type": "message",
"role": "user",
"content": "hello",
},
},
}
result := applyCodexOAuthTransformWithOptions(reqBody, codexOAuthTransformOptions{
SkipDefaultInstructions: true,
PreserveToolCallIDs: true,
})
// The key is surfaced to the caller through the result...
require.Equal(t, "anthropic-metadata-session-1", result.PromptCacheKey)
require.True(t, result.Modified)
// ...but must no longer be present in the body.
require.NotContains(t, reqBody, "prompt_cache_key")
}
func TestApplyCodexOAuthTransform_ToolContinuationPreservesNativeMessageAndReasoningIDs(t *testing.T) { func TestApplyCodexOAuthTransform_ToolContinuationPreservesNativeMessageAndReasoningIDs(t *testing.T) {
reqBody := map[string]any{ reqBody := map[string]any{
"model": "gpt-5.2", "model": "gpt-5.2",
@ -804,15 +837,25 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
func TestNormalizeCodexModel_Gpt53(t *testing.T) { func TestNormalizeCodexModel_Gpt53(t *testing.T) {
cases := map[string]string{ cases := map[string]string{
"gpt-5.4": "gpt-5.4", "gpt-5.4": "gpt-5.4",
"gpt5.5": "gpt-5.5",
"openai/gpt5.5": "gpt-5.5",
"gpt5.4": "gpt-5.4",
"gpt-5.4-high": "gpt-5.4", "gpt-5.4-high": "gpt-5.4",
"gpt-5.4-chat-latest": "gpt-5.4", "gpt-5.4-chat-latest": "gpt-5.4",
"gpt 5.4": "gpt-5.4", "gpt 5.4": "gpt-5.4",
"gpt-5.4-mini": "gpt-5.4-mini", "gpt-5.4-mini": "gpt-5.4-mini",
"gpt5.4-mini": "gpt-5.4-mini",
"gpt5.4mini": "gpt-5.4-mini",
"gpt 5.4 mini": "gpt-5.4-mini", "gpt 5.4 mini": "gpt-5.4-mini",
"gpt-5.3": "gpt-5.3-codex", "gpt-5.3": "gpt-5.3-codex",
"gpt5.3": "gpt-5.3-codex",
"gpt-5.3-codex": "gpt-5.3-codex", "gpt-5.3-codex": "gpt-5.3-codex",
"gpt5.3-codex": "gpt-5.3-codex",
"gpt5.3codex": "gpt-5.3-codex",
"gpt-5.3-codex-xhigh": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex",
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark", "gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
"gpt5.3-codex-spark": "gpt-5.3-codex-spark",
"gpt5.3codexspark": "gpt-5.3-codex-spark",
"gpt 5.3 codex spark": "gpt-5.3-codex-spark", "gpt 5.3 codex spark": "gpt-5.3-codex-spark",
"gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark", "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark", "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,9 @@
package service package service
import ( import (
"crypto/sha256"
"encoding/json" "encoding/json"
"fmt"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
@ -16,12 +18,8 @@ func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") { if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
return false return false
} }
switch normalizeCodexModel(trimmed) { normalized := strings.TrimSpace(strings.ToLower(normalizeCodexModel(trimmed)))
case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark": return strings.HasPrefix(normalized, "gpt-5") || strings.Contains(normalized, "codex")
return true
default:
return false
}
} }
func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedModel string) string { func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedModel string) string {
@ -71,6 +69,102 @@ func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedMod
return compatPromptCacheKeyPrefix + hashSensitiveValueForLog(strings.Join(seedParts, "|")) return compatPromptCacheKeyPrefix + hashSensitiveValueForLog(strings.Join(seedParts, "|"))
} }
// deriveAnthropicCompatPromptCacheKey derives a prompt cache key for an
// Anthropic-format request.
//
// If the request carries cache_control anchors, the key produced by
// deriveAnthropicCacheControlPromptCacheKey is returned as is. Otherwise the
// key is compatPromptCacheKeyPrefix followed by a hash of a "|"-joined seed
// built, in this order, from: the normalized model, the output effort, the
// tool choice, the tools, the system prompt and the first user message.
// Later messages are not part of the seed.
//
// It returns "" for a nil request.
func deriveAnthropicCompatPromptCacheKey(req *apicompat.AnthropicRequest, mappedModel string) string {
	if req == nil {
		return ""
	}
	if key := deriveAnthropicCacheControlPromptCacheKey(req); key != "" {
		return key
	}

	// Resolve the model: the mapped model first, then the request model,
	// and finally the raw (trimmed) request model.
	requestModel := strings.TrimSpace(req.Model)
	model := normalizeCodexModel(strings.TrimSpace(mappedModel))
	if model == "" {
		model = normalizeCodexModel(requestModel)
	}
	if model == "" {
		model = requestModel
	}

	seed := []string{"model=" + model}
	if cfg := req.OutputConfig; cfg != nil {
		if effort := strings.TrimSpace(cfg.Effort); effort != "" {
			seed = append(seed, "effort="+effort)
		}
	}
	if len(req.ToolChoice) > 0 {
		seed = append(seed, "tool_choice="+normalizeCompatSeedJSON(req.ToolChoice))
	}
	if len(req.Tools) > 0 {
		// Tools that cannot be marshalled are left out of the seed.
		encoded, err := json.Marshal(req.Tools)
		if err == nil {
			seed = append(seed, "tools="+normalizeCompatSeedJSON(encoded))
		}
	}
	if len(req.System) > 0 {
		seed = append(seed, "system="+normalizeCompatSeedJSON(req.System))
	}

	// Only the first user message contributes to the seed.
	for _, msg := range req.Messages {
		if strings.TrimSpace(msg.Role) == "user" {
			seed = append(seed, "first_user="+normalizeCompatSeedJSON(msg.Content))
			break
		}
	}

	return compatPromptCacheKeyPrefix + hashSensitiveValueForLog(strings.Join(seed, "|"))
}
// deriveAnthropicCacheControlPromptCacheKey builds a prompt cache key from the
// text blocks that the client marked with cache_control type "ephemeral".
//
// Anchors are collected in this order: every marked system block
// ("system:<text>"), every marked assistant block ("assistant:<text>") and
// finally the first marked user block ("user_anchor:<text>"). System or
// message content that is not a JSON array of content blocks is ignored.
//
// The result is "anthropic-cache-" followed by the hex encoding of the first
// 16 bytes of the SHA-256 of the joined anchors, or "" when req is nil or no
// anchor was found.
func deriveAnthropicCacheControlPromptCacheKey(req *apicompat.AnthropicRequest) string {
	if req == nil {
		return ""
	}

	// anchorText yields the trimmed text of a block that qualifies as a
	// cache anchor: a non-blank text block with ephemeral cache_control.
	anchorText := func(b apicompat.AnthropicContentBlock) (string, bool) {
		if b.Type != "text" || b.CacheControl == nil {
			return "", false
		}
		if strings.TrimSpace(b.CacheControl.Type) != "ephemeral" {
			return "", false
		}
		text := strings.TrimSpace(b.Text)
		return text, text != ""
	}

	var anchors []string

	if len(req.System) > 0 {
		var systemBlocks []apicompat.AnthropicContentBlock
		if json.Unmarshal(req.System, &systemBlocks) == nil {
			for _, b := range systemBlocks {
				if text, ok := anchorText(b); ok {
					anchors = append(anchors, "system:"+text)
				}
			}
		}
	}

	userAnchor := ""
	for _, msg := range req.Messages {
		if len(msg.Content) == 0 {
			continue
		}
		var blocks []apicompat.AnthropicContentBlock
		if err := json.Unmarshal(msg.Content, &blocks); err != nil {
			continue
		}
		role := strings.TrimSpace(msg.Role)
		for _, b := range blocks {
			text, ok := anchorText(b)
			if !ok {
				continue
			}
			if role == "assistant" {
				anchors = append(anchors, "assistant:"+text)
			} else if role == "user" && userAnchor == "" {
				// Only the first marked user block is kept.
				userAnchor = text
			}
		}
	}
	if userAnchor != "" {
		anchors = append(anchors, "user_anchor:"+userAnchor)
	}

	if len(anchors) == 0 {
		return ""
	}

	digest := sha256.Sum256([]byte("anthropic-cache:" + strings.Join(anchors, "\n")))
	return fmt.Sprintf("anthropic-cache-%x", digest[:16])
}
func normalizeCompatSeedJSON(v json.RawMessage) string { func normalizeCompatSeedJSON(v json.RawMessage) string {
if len(v) == 0 { if len(v) == 0 {
return "" return ""

View File

@ -2,6 +2,7 @@ package service
import ( import (
"encoding/json" "encoding/json"
"strings"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
@ -14,7 +15,10 @@ func mustRawJSON(t *testing.T, s string) json.RawMessage {
} }
func TestShouldAutoInjectPromptCacheKeyForCompat(t *testing.T) { func TestShouldAutoInjectPromptCacheKeyForCompat(t *testing.T) {
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.5"))
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4")) require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4"))
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4-mini"))
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.2"))
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3")) require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3"))
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex")) require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex"))
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex-spark")) require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex-spark"))
@ -77,3 +81,57 @@ func TestDeriveCompatPromptCacheKey_UsesResolvedSparkFamily(t *testing.T) {
require.NotEmpty(t, k1) require.NotEmpty(t, k1)
require.Equal(t, k1, k2, "resolved spark family should derive a stable compat cache key") require.Equal(t, k1, k2, "resolved spark family should derive a stable compat cache key")
} }
// TestDeriveAnthropicCompatPromptCacheKey_StableAcrossLaterTurns verifies that
// two requests sharing the same model, system prompt and first user message
// derive the same cache key even when the second one has additional
// assistant/user turns appended.
func TestDeriveAnthropicCompatPromptCacheKey_StableAcrossLaterTurns(t *testing.T) {
// First turn of the conversation.
base := &apicompat.AnthropicRequest{
Model: "claude-sonnet-4-5",
System: mustRawJSON(t, `"You are helpful."`),
Messages: []apicompat.AnthropicMessage{
{Role: "user", Content: mustRawJSON(t, `"Open repo"`)},
},
}
// Same conversation with later turns appended.
extended := &apicompat.AnthropicRequest{
Model: "claude-sonnet-4-5",
System: mustRawJSON(t, `"You are helpful."`),
Messages: []apicompat.AnthropicMessage{
{Role: "user", Content: mustRawJSON(t, `"Open repo"`)},
{Role: "assistant", Content: mustRawJSON(t, `"Opened."`)},
{Role: "user", Content: mustRawJSON(t, `"Run tests"`)},
},
}
k1 := deriveAnthropicCompatPromptCacheKey(base, "gpt-5.3-codex")
k2 := deriveAnthropicCompatPromptCacheKey(extended, "gpt-5.3-codex")
require.NotEmpty(t, k1)
require.Equal(t, k1, k2, "cache key should stay stable as later Claude Code turns append history")
}
// TestDeriveAnthropicCompatPromptCacheKey_UsesCacheControlAnchors verifies
// that when system and user blocks are marked with ephemeral cache_control,
// the derived key is stable across appended turns, starts with
// "anthropic-cache-" and does not use compatPromptCacheKeyPrefix.
func TestDeriveAnthropicCompatPromptCacheKey_UsesCacheControlAnchors(t *testing.T) {
// Request whose system block and first user block are both cache anchors.
base := &apicompat.AnthropicRequest{
Model: "claude-sonnet-4-5",
System: mustRawJSON(t, `[
{"type":"text","text":"project instructions","cache_control":{"type":"ephemeral"}}
]`),
Messages: []apicompat.AnthropicMessage{
{Role: "user", Content: mustRawJSON(t, `[
{"type":"text","text":"repo anchor","cache_control":{"type":"ephemeral"}}
]`)},
},
}
// Later turns carry no cache_control and must not change the key.
extended := &apicompat.AnthropicRequest{
Model: base.Model,
System: base.System,
Messages: []apicompat.AnthropicMessage{
base.Messages[0],
{Role: "assistant", Content: mustRawJSON(t, `[{"type":"text","text":"Opened."}]`)},
{Role: "user", Content: mustRawJSON(t, `[{"type":"text","text":"Run tests"}]`)},
},
}
k1 := deriveAnthropicCompatPromptCacheKey(base, "gpt-5.4")
k2 := deriveAnthropicCompatPromptCacheKey(extended, "gpt-5.4")
require.NotEmpty(t, k1)
require.Equal(t, k1, k2)
require.True(t, strings.HasPrefix(k1, "anthropic-cache-"))
require.False(t, strings.HasPrefix(k1, compatPromptCacheKeyPrefix))
}

View File

@ -972,6 +972,62 @@ func TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames(t *testing
"turn 3: response.create without service_tier overwrites billing to nil to match upstream default") "turn 3: response.create without service_tier overwrites billing to nil to match upstream default")
} }
// TestPassthroughUsageMeta_TracksReasoningEffortAcrossTurns drives a sequence
// of WebSocket passthrough frames through the fast-policy step and checks how
// the usage metadata's reasoningEffort evolves: it is initialised from the
// first frame, left untouched by session.update and response.cancel frames,
// replaced by a later response.create that carries a flat reasoning_effort,
// and cleared by a response.create that carries none.
func TestPassthroughUsageMeta_TracksReasoningEffortAcrossTurns(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// First frame: nested reasoning.effort = "medium".
firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","reasoning":{"effort":"medium"},"service_tier":"priority"}`)
meta := newOpenAIWSPassthroughUsageMeta("", firstFrame)
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstFrame)
firstOut, firstBlocked, firstErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, capturedSessionModel, firstFrame)
require.NoError(t, firstErr)
require.Nil(t, firstBlocked)
meta.initFromFirstFrame(firstOut)
require.NotNil(t, meta.reasoningEffort.Load())
require.Equal(t, "medium", *meta.reasoningEffort.Load())
// process handles one subsequent frame: it refreshes the session-level
// model, applies the fast policy, and updates the usage metadata only for
// response.create frames that were neither blocked nor errored.
process := func(payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
capturedSessionModel = updated
}
meta.updateSessionRequestModel(payload)
requestModelForThisFrame := meta.requestModelForFrame(payload)
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
if model == "" {
// Frame carries no model of its own: fall back to the session model.
model = capturedSessionModel
}
out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
if policyErr == nil && blocked == nil &&
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
meta.updateFromResponseCreate(out, requestModelForThisFrame)
}
return out, blocked, policyErr
}
// session.update must not overwrite the current turn's effort.
_, blockedSession, errSession := process([]byte(`{"type":"session.update","session":{"model":"gpt-5-high"}}`))
require.NoError(t, errSession)
require.Nil(t, blockedSession)
require.NotNil(t, meta.reasoningEffort.Load())
require.Equal(t, "medium", *meta.reasoningEffort.Load(), "session.update 只刷新后续 fallback model不覆盖当前 turn metadata")
// A non-response.create frame must not touch the metadata either.
_, blockedCancel, errCancel := process([]byte(`{"type":"response.cancel","reasoning_effort":"x-high"}`))
require.NoError(t, errCancel)
require.Nil(t, blockedCancel)
require.NotNil(t, meta.reasoningEffort.Load())
require.Equal(t, "medium", *meta.reasoningEffort.Load(), "非 response.create 帧不能污染当前 turn metadata")
// A flat reasoning_effort on response.create is recorded as "xhigh".
_, blockedFlat, errFlat := process([]byte(`{"type":"response.create","reasoning_effort":"x-high"}`))
require.NoError(t, errFlat)
require.Nil(t, blockedFlat)
require.NotNil(t, meta.reasoningEffort.Load())
require.Equal(t, "xhigh", *meta.reasoningEffort.Load(), "flat reasoning_effort 必须进入 passthrough usage metadata")
// A response.create without any effort clears the previous value.
_, blockedClear, errClear := process([]byte(`{"type":"response.create","model":"gpt-4o"}`))
require.NoError(t, errClear)
require.Nil(t, blockedClear)
require.Nil(t, meta.reasoningEffort.Load(), "新的 response.create 无 effort 且无可推导后缀时必须清空旧值")
}
// TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier locks in the // TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier locks in the
// "block keeps previous" semantic: when policy returns block on a // "block keeps previous" semantic: when policy returns block on a
// response.create frame, that frame is never sent upstream, so billing tier // response.create frame, that frame is never sent upstream, so billing tier

View File

@ -20,20 +20,29 @@ func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accou
return nil return nil
} }
func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) { func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterForZeroUsage(t *testing.T) {
counter := &openAI403CounterResetStub{} counter := &openAI403CounterResetStub{}
rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil) rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil)
rateLimitSvc.SetOpenAI403CounterCache(counter) rateLimitSvc.SetOpenAI403CounterCache(counter)
svc := &OpenAIGatewayService{ usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
rateLimitService: rateLimitSvc, billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
} userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
svc.rateLimitService = rateLimitSvc
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{}, Result: &OpenAIForwardResult{
RequestID: "resp_zero_usage_reset_403",
Model: "gpt-5.1",
},
APIKey: &APIKey{ID: 1001, Group: &Group{RateMultiplier: 1}},
User: &User{ID: 2001},
Account: &Account{ID: 777, Platform: PlatformOpenAI}, Account: &Account{ID: 777, Platform: PlatformOpenAI},
}) })
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, []int64{777}, counter.resetCalls) require.Equal(t, []int64{777}, counter.resetCalls)
require.Equal(t, 1, usageRepo.calls)
} }

View File

@ -10,10 +10,12 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@ -39,9 +41,18 @@ var cursorResponsesUnsupportedFields = []string{
// ForwardAsChatCompletions accepts a Chat Completions request body, converts it // ForwardAsChatCompletions accepts a Chat Completions request body, converts it
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts // to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
// the response back to Chat Completions format. All account types (OAuth and API // the response back to Chat Completions format.
// Key) go through the Responses API conversion path since the upstream only //
// exposes the /v1/responses endpoint. // 历史背景:该函数原本对所有 OpenAI 账号无差别走 CC→Responses 转换 + /v1/responses
// 端点——这在 OAuth(ChatGPT 内部 API 仅支持 Responses)和官方 APIKey 账号上是
// 正确的,但 sub2api 接入 DeepSeek/Kimi/GLM 等第三方 OpenAI 兼容上游后假设破裂:
// 这些上游普遍只支持 /v1/chat/completions无 /v1/responses 端点。
//
// 当前路由策略(基于账号探测标记,详见 openai_compat.ShouldUseResponsesAPI
// - APIKey 账号 + 探测确认不支持 Responses → 走 forwardAsRawChatCompletions
// 直转上游 /v1/chat/completions不做协议转换
// - 其他所有情况(OAuth、APIKey 探测确认支持、未探测)→ 走原有 CC→Responses
// 转换路径(保留旧行为,存量未探测账号零兼容破坏)
func (s *OpenAIGatewayService) ForwardAsChatCompletions( func (s *OpenAIGatewayService) ForwardAsChatCompletions(
ctx context.Context, ctx context.Context,
c *gin.Context, c *gin.Context,
@ -50,6 +61,12 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
promptCacheKey string, promptCacheKey string,
defaultMappedModel string, defaultMappedModel string,
) (*OpenAIForwardResult, error) { ) (*OpenAIForwardResult, error) {
// 入口分流:APIKey 账号 + 已探测且确认上游不支持 Responses,走 CC 直转。
// 标记缺失(未探测)按"现状即证据"原则继续走下方原 Responses 转换路径。
if account.Type == AccountTypeAPIKey && !openai_compat.ShouldUseResponsesAPI(account.Extra) {
return s.forwardAsRawChatCompletions(ctx, c, account, body, defaultMappedModel)
}
startTime := time.Now() startTime := time.Now()
// 1. Parse Chat Completions request // 1. Parse Chat Completions request
@ -189,7 +206,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
} }
// 6. Build upstream request // 6. Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false) upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, promptCacheKey, false)
releaseUpstreamCtx()
if err != nil { if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err) return nil, fmt.Errorf("build upstream request: %w", err)
} }
@ -348,59 +367,9 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
) (*OpenAIForwardResult, error) { ) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id") requestID := resp.Header.Get("x-request-id")
scanner := bufio.NewScanner(resp.Body) finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai chat_completions buffered", requestID)
maxLineSize := defaultMaxLineSize if err != nil {
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { return nil, err
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
var finalResponse *apicompat.ResponsesResponse
var usage OpenAIUsage
acc := apicompat.NewBufferedResponseAccumulator()
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
payload := line[6:]
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai chat_completions buffered: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
// Accumulate delta content for fallback when terminal output is empty.
acc.ProcessEvent(&event)
if (event.Type == "response.completed" || event.Type == "response.done" ||
event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil {
finalResponse = event.Response
if event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai chat_completions buffered: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
} }
if finalResponse == nil { if finalResponse == nil {
@ -459,6 +428,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
var usage OpenAIUsage var usage OpenAIUsage
var firstTokenMs *int var firstTokenMs *int
firstChunk := true firstChunk := true
clientDisconnected := false
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize maxLineSize := defaultMaxLineSize
@ -467,6 +437,20 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
} }
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
resultWithUsage := func() *OpenAIForwardResult { resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{ return &OpenAIForwardResult{
RequestID: requestID, RequestID: requestID,
@ -496,54 +480,66 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
return false return false
} }
// Extract usage from completion events // 仅按兼容转换器支持的终止事件提取 usage避免无意扩大事件语义。
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
event.Response != nil && event.Response.Usage != nil { if isTerminalEvent && event.Response != nil && event.Response.Usage != nil {
usage = OpenAIUsage{ usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
} }
chunks := apicompat.ResponsesEventToChatChunks(&event, state) chunks := apicompat.ResponsesEventToChatChunks(&event, state)
for _, chunk := range chunks { if !clientDisconnected {
sse, err := apicompat.ChatChunkToSSE(chunk) for _, chunk := range chunks {
if err != nil { sse, err := apicompat.ChatChunkToSSE(chunk)
logger.L().Warn("openai chat_completions stream: failed to marshal chunk", if err != nil {
zap.Error(err), logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
zap.String("request_id", requestID), zap.Error(err),
) zap.String("request_id", requestID),
continue )
} continue
if _, err := fmt.Fprint(c.Writer, sse); err != nil { }
logger.L().Info("openai chat_completions stream: client disconnected", if _, err := fmt.Fprint(c.Writer, sse); err != nil {
zap.String("request_id", requestID), clientDisconnected = true
) logger.L().Info("openai chat_completions stream: client disconnected, continuing to drain upstream for billing",
return true zap.String("request_id", requestID),
)
break
}
} }
} }
if len(chunks) > 0 { if len(chunks) > 0 && !clientDisconnected {
c.Writer.Flush() c.Writer.Flush()
} }
return false return isTerminalEvent
} }
finalizeStream := func() (*OpenAIForwardResult, error) { finalizeStream := func() (*OpenAIForwardResult, error) {
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 { if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 && !clientDisconnected {
for _, chunk := range finalChunks { for _, chunk := range finalChunks {
sse, err := apicompat.ChatChunkToSSE(chunk) sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil { if err != nil {
continue continue
} }
fmt.Fprint(c.Writer, sse) //nolint:errcheck if _, err := fmt.Fprint(c.Writer, sse); err != nil {
clientDisconnected = true
logger.L().Info("openai chat_completions stream: client disconnected during final flush",
zap.String("request_id", requestID),
)
break
}
} }
} }
// Send [DONE] sentinel // Send [DONE] sentinel
fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck if !clientDisconnected {
c.Writer.Flush() if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
clientDisconnected = true
logger.L().Info("openai chat_completions stream: client disconnected during done flush",
zap.String("request_id", requestID),
)
}
}
if !clientDisconnected {
c.Writer.Flush()
}
return resultWithUsage(), nil return resultWithUsage(), nil
} }
@ -555,6 +551,9 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
) )
} }
} }
missingTerminalErr := func() (*OpenAIForwardResult, error) {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
// Determine keepalive interval // Determine keepalive interval
keepaliveInterval := time.Duration(0) keepaliveInterval := time.Duration(0)
@ -563,18 +562,25 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
} }
// No keepalive: fast synchronous path // No keepalive: fast synchronous path
if keepaliveInterval <= 0 { if streamInterval <= 0 && keepaliveInterval <= 0 {
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { payload, ok := extractOpenAISSEDataLine(line)
if !ok {
continue continue
} }
if processDataLine(line[6:]) { if strings.TrimSpace(payload) == "[DONE]" {
return resultWithUsage(), nil return missingTerminalErr()
}
if processDataLine(payload) {
return finalizeStream()
} }
} }
handleScanErr(scanner.Err()) if err := scanner.Err(); err != nil {
return finalizeStream() handleScanErr(err)
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
}
return missingTerminalErr()
} }
// With keepalive: goroutine + channel + select // With keepalive: goroutine + channel + select
@ -584,6 +590,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
} }
events := make(chan scanEvent, 16) events := make(chan scanEvent, 16)
done := make(chan struct{}) done := make(chan struct{})
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
sendEvent := func(ev scanEvent) bool { sendEvent := func(ev scanEvent) bool {
select { select {
case events <- ev: case events <- ev:
@ -595,6 +603,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
go func() { go func() {
defer close(events) defer close(events)
for scanner.Scan() { for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) { if !sendEvent(scanEvent{line: scanner.Text()}) {
return return
} }
@ -605,30 +614,59 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
}() }()
defer close(done) defer close(done)
keepaliveTicker := time.NewTicker(keepaliveInterval) var keepaliveTicker *time.Ticker
defer keepaliveTicker.Stop() if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now() lastDataAt := time.Now()
for { for {
select { select {
case ev, ok := <-events: case ev, ok := <-events:
if !ok { if !ok {
return finalizeStream() return missingTerminalErr()
} }
if ev.err != nil { if ev.err != nil {
handleScanErr(ev.err) handleScanErr(ev.err)
return finalizeStream() return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
} }
lastDataAt = time.Now() lastDataAt = time.Now()
line := ev.line line := ev.line
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { payload, ok := extractOpenAISSEDataLine(line)
if !ok {
continue continue
} }
if processDataLine(line[6:]) { if strings.TrimSpace(payload) == "[DONE]" {
return resultWithUsage(), nil return missingTerminalErr()
}
if processDataLine(payload) {
return finalizeStream()
} }
case <-keepaliveTicker.C: case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
}
logger.L().Warn("openai chat_completions stream: data interval timeout",
zap.String("request_id", requestID),
zap.String("model", originalModel),
zap.Duration("interval", streamInterval),
)
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval { if time.Since(lastDataAt) < keepaliveInterval {
continue continue
} }
@ -637,7 +675,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
logger.L().Info("openai chat_completions stream: client disconnected during keepalive", logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
zap.String("request_id", requestID), zap.String("request_id", requestID),
) )
return resultWithUsage(), nil clientDisconnected = true
continue
} }
c.Writer.Flush() c.Writer.Flush()
} }

View File

@ -0,0 +1,437 @@
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.uber.org/zap"
)
// openaiCCRawAllowedHeaders is the allowlist of client request headers that
// the raw Chat Completions (CC) pass-through path copies to the upstream.
//
// Important: do not reuse openaiAllowedHeaders here. That list contains
// Codex-client-specific headers (originator / session_id /
// x-codex-turn-state / x-codex-turn-metadata / conversation_id). They are
// required by the ChatGPT OAuth upstream, but forwarding them to third-party
// OpenAI-compatible upstreams such as DeepSeek/Kimi/GLM results in one of:
//   - being silently ignored (most lenient vendors), which quietly pollutes
//     upstream statistics
//   - a 400 "unknown parameter" response (strict upstreams), a visible error
//
// Only generic HTTP headers are let through here. content-type,
// authorization and accept are set explicitly by the caller and do not rely
// on pass-through.
//
// See the decision record:
// pensieve/short-term/maxims/dont-reuse-shared-headers-whitelist-across-different-upstream-trust-domains
var openaiCCRawAllowedHeaders = map[string]bool{
	"accept-language": true,
	"user-agent":      true,
}
// forwardAsRawChatCompletions forwards the client's Chat Completions request
// straight to the upstream `{base_url}/v1/chat/completions` endpoint WITHOUT
// any CC<->Responses protocol conversion.
//
// Intended for account.platform=openai && account.type=apikey accounts whose
// upstream has been probed and confirmed not to support the /v1/responses
// endpoint (third-party OpenAI-compatible upstreams such as
// DeepSeek/Kimi/GLM/Qwen).
//
// Key differences from ForwardAsChatCompletions:
//
//   - apicompat.ChatCompletionsToResponses is not called; the body is not
//     protocol-converted (the model ID is rewritten, the fast policy is
//     applied and, for streams, stream_options.include_usage is forced on)
//   - the upstream URL ends in /v1/chat/completions rather than /v1/responses
//   - streaming SSE is relayed to the client as-is (upstream chunks are
//     already in CC format)
//   - non-streaming JSON is relayed as-is; only usage is extracted from it
//   - no codex OAuth transform is applied (the API-key path has no OAuth)
//   - prompt_cache_key is not injected (an OAuth-only mechanism)
//
// Entry point: openai_gateway_chat_completions.go::ForwardAsChatCompletions,
// which branches at the top of the function based on
// openai_compat.ShouldUseResponsesAPI.
//
// It returns the forward result used for billing, or an error. A
// *UpstreamFailoverError is returned when the upstream error response is
// classified as eligible for failover.
func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	defaultMappedModel string,
) (*OpenAIForwardResult, error) {
	startTime := time.Now()

	// 1. Parse minimal fields needed for routing/billing
	originalModel := gjson.GetBytes(body, "model").String()
	if originalModel == "" {
		writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required")
		return nil, fmt.Errorf("missing model in request")
	}
	clientStream := gjson.GetBytes(body, "stream").Bool()

	// 1b. Extract reasoning effort and service tier from the raw body before any transformation.
	reasoningEffort := extractOpenAIReasoningEffortFromBody(body, originalModel)
	serviceTier := extractOpenAIServiceTierFromBody(body)

	// 2. Resolve model mapping (same as ForwardAsChatCompletions)
	billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
	upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)

	// 3. Rewrite model in body (no protocol conversion)
	upstreamBody := body
	if upstreamModel != originalModel {
		upstreamBody = ReplaceModelInBody(body, upstreamModel)
	}

	// 4. Apply OpenAI fast policy on the CC body
	updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, upstreamBody)
	if policyErr != nil {
		// Only a policy block is reported to the client here; any other
		// policy error is returned to the caller without a response body.
		var blocked *OpenAIFastBlockedError
		if errors.As(policyErr, &blocked) {
			writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message)
		}
		return nil, policyErr
	}
	upstreamBody = updatedBody

	// Streaming requests always ask the upstream for usage so the final
	// chunk carries token counts for billing.
	if clientStream {
		var usageErr error
		upstreamBody, usageErr = ensureOpenAIChatStreamUsage(upstreamBody)
		if usageErr != nil {
			return nil, fmt.Errorf("enable stream usage: %w", usageErr)
		}
	}

	logger.L().Debug("openai chat_completions raw: forwarding without protocol conversion",
		zap.Int64("account_id", account.ID),
		zap.String("original_model", originalModel),
		zap.String("billing_model", billingModel),
		zap.String("upstream_model", upstreamModel),
		zap.Bool("stream", clientStream),
	)

	// 5. Build upstream request
	apiKey := account.GetOpenAIApiKey()
	if apiKey == "" {
		return nil, fmt.Errorf("account %d missing api_key", account.ID)
	}
	baseURL := account.GetOpenAIBaseURL()
	if baseURL == "" {
		baseURL = "https://api.openai.com"
	}
	validatedURL, err := s.validateUpstreamBaseURL(baseURL)
	if err != nil {
		return nil, fmt.Errorf("invalid base_url: %w", err)
	}
	targetURL := buildOpenAIChatCompletionsURL(validatedURL)

	// NOTE(review): releaseUpstreamCtx is invoked right after the request is
	// built, i.e. before the request is sent. Presumably it does not cancel
	// upstreamCtx (the unit tests assert the request context has no error);
	// confirm against detachUpstreamContext.
	upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
	upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody))
	releaseUpstreamCtx()
	if err != nil {
		return nil, fmt.Errorf("build upstream request: %w", err)
	}
	upstreamReq.Header.Set("Content-Type", "application/json")
	upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
	if clientStream {
		upstreamReq.Header.Set("Accept", "text/event-stream")
	} else {
		upstreamReq.Header.Set("Accept", "application/json")
	}

	// Pass through only the allowlisted client headers; see the design notes
	// on openaiCCRawAllowedHeaders.
	for key, values := range c.Request.Header {
		lowerKey := strings.ToLower(key)
		if openaiCCRawAllowedHeaders[lowerKey] {
			for _, v := range values {
				upstreamReq.Header.Add(key, v)
			}
		}
	}

	// An account-level user agent, when set, replaces whatever the client sent.
	customUA := account.GetOpenAIUserAgent()
	if customUA != "" {
		upstreamReq.Header.Set("user-agent", customUA)
	}

	// 6. Send request
	proxyURL := ""
	if account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		// The transport error is sanitized before it is recorded or returned.
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: 0,
			Kind:               "request_error",
			Message:            safeErr,
		})
		writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	defer func() { _ = resp.Body.Close() }()

	// 7. Handle error response with failover
	if resp.StatusCode >= 400 {
		// Read at most 2 MiB of the error body, then swap in an in-memory
		// copy so the body can still be read through resp after this point.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
		_ = resp.Body.Close()
		resp.Body = io.NopCloser(bytes.NewReader(respBody))

		upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
		upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
		if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
			// The raw upstream body is attached to the ops event only when
			// explicitly enabled, and truncated (2048 bytes by default).
			upstreamDetail := ""
			if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
				maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
				if maxBytes <= 0 {
					maxBytes = 2048
				}
				upstreamDetail = truncateString(string(respBody), maxBytes)
			}
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				Kind:               "failover",
				Message:            upstreamMsg,
				Detail:             upstreamDetail,
			})
			if s.rateLimitService != nil {
				s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
			}
			return nil, &UpstreamFailoverError{
				StatusCode:             resp.StatusCode,
				ResponseBody:           respBody,
				RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
			}
		}
		return s.handleChatCompletionsErrorResponse(resp, c, account)
	}

	// 8. Forward response
	if clientStream {
		return s.streamRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
	}
	return s.bufferRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
}
// streamRawChatCompletions relays the upstream CC SSE stream to the client
// line by line and records usage from any chunk that carries a usage object
// (per the OpenAI CC protocol this is a chunk preceding the final [DONE]).
//
// Upstream only emits usage when stream_options.include_usage=true. The
// gateway forces that option on for the upstream request so billing is
// complete, and forwards the usage chunk downstream untouched so cascaded
// proxies or downstream billing systems also receive the full usage.
//
// Once a write to the client fails, the client is treated as disconnected:
// nothing more is written or flushed, but the upstream stream is still read
// to the end so usage can be captured. Read errors are logged (except
// context cancellation/deadline) and do not turn into a returned error.
func (s *OpenAIGatewayService) streamRawChatCompletions(
	c *gin.Context,
	resp *http.Response,
	originalModel string,
	billingModel string,
	upstreamModel string,
	reasoningEffort *string,
	serviceTier *string,
	startTime time.Time,
) (*OpenAIForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")

	// Response headers: filtered upstream headers first, then the SSE set.
	outHeader := c.Writer.Header()
	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(outHeader, resp.Header, s.responseHeaderFilter)
	}
	outHeader.Set("Content-Type", "text/event-stream")
	outHeader.Set("Cache-Control", "no-cache")
	outHeader.Set("Connection", "keep-alive")
	outHeader.Set("X-Accel-Buffering", "no")
	c.Writer.WriteHeader(http.StatusOK)

	lineLimit := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		lineLimit = s.cfg.Gateway.MaxLineSize
	}
	lines := bufio.NewScanner(resp.Body)
	lines.Buffer(make([]byte, 0, 64*1024), lineLimit)

	var (
		usage        OpenAIUsage
		firstTokenMs *int
		clientGone   bool
	)

	for lines.Scan() {
		line := lines.Text()

		// Inspect data lines (other than the [DONE] sentinel) for usage and
		// for the first-token timestamp.
		if payload, isData := extractOpenAISSEDataLine(line); isData && strings.TrimSpace(payload) != "[DONE]" {
			if parsed := extractCCStreamUsage(payload); parsed != nil {
				usage = *parsed
			}
			// A usage-only chunk carries no content, so it must not count
			// as the first token.
			if firstTokenMs == nil && !isOpenAIChatUsageOnlyStreamChunk(payload) {
				ms := int(time.Since(startTime).Milliseconds())
				firstTokenMs = &ms
			}
		}

		// After a disconnect, keep draining upstream without writing.
		if clientGone {
			continue
		}
		if _, werr := c.Writer.WriteString(line + "\n"); werr != nil {
			clientGone = true
			logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing",
				zap.Error(werr),
				zap.String("request_id", requestID),
			)
			continue
		}
		c.Writer.Flush()
	}

	if err := lines.Err(); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
		logger.L().Warn("openai chat_completions raw: stream read error",
			zap.Error(err),
			zap.String("request_id", requestID),
		)
	}

	result := &OpenAIForwardResult{
		RequestID:       requestID,
		Usage:           usage,
		Model:           originalModel,
		BillingModel:    billingModel,
		UpstreamModel:   upstreamModel,
		ReasoningEffort: reasoningEffort,
		ServiceTier:     serviceTier,
		Stream:          true,
		Duration:        time.Since(startTime),
		FirstTokenMs:    firstTokenMs,
	}
	return result, nil
}
// ensureOpenAIChatStreamUsage sets stream_options.include_usage=true on a
// raw Chat Completions request body so the upstream reports usage in the
// stream. The usage is also relayed downstream, which supports cascaded
// proxies and downstream billing systems.
//
// On failure the original body is returned together with the error.
func ensureOpenAIChatStreamUsage(body []byte) ([]byte, error) {
	patched, err := sjson.SetBytes(body, "stream_options.include_usage", true)
	if err == nil {
		return patched, nil
	}
	return body, err
}
// isOpenAIChatUsageOnlyStreamChunk reports whether a CC stream chunk payload
// carries a "usage" field together with a "choices" array that is present
// but empty, i.e. the chunk conveys usage and no content.
func isOpenAIChatUsageOnlyStreamChunk(payload string) bool {
	if strings.TrimSpace(payload) == "" || !gjson.Get(payload, "usage").Exists() {
		return false
	}
	choices := gjson.Get(payload, "choices")
	if !choices.Exists() || !choices.IsArray() {
		return false
	}
	return len(choices.Array()) == 0
}
// extractCCStreamUsage reads the usage object from a single CC stream chunk
// payload. It returns nil when the payload has no "usage" field or the field
// is not a JSON object.
//
// In the CC protocol usage normally appears only in the trailing chunk (and
// only when include_usage is in effect), but an upstream may repeat it in
// several chunks; callers are expected to keep the most recent value.
func extractCCStreamUsage(payload string) *OpenAIUsage {
	usageNode := gjson.Get(payload, "usage")
	if !(usageNode.Exists() && usageNode.IsObject()) {
		return nil
	}

	out := &OpenAIUsage{
		InputTokens:  int(usageNode.Get("prompt_tokens").Int()),
		OutputTokens: int(usageNode.Get("completion_tokens").Int()),
	}
	// Cached-token details are optional; leave the field at zero if absent.
	if cached := usageNode.Get("prompt_tokens_details.cached_tokens"); cached.Exists() {
		out.CacheReadInputTokens = int(cached.Int())
	}
	return out
}
// bufferRawChatCompletions relays a non-streaming upstream CC JSON response
// to the client unchanged and extracts usage from it for billing.
//
// If the body cannot be read, an error is returned. A 502 is written here
// unless the failure is ErrUpstreamResponseBodyTooLarge. A body that is not
// valid CC JSON is still relayed; usage then stays at its zero value.
func (s *OpenAIGatewayService) bufferRawChatCompletions(
	c *gin.Context,
	resp *http.Response,
	originalModel string,
	billingModel string,
	upstreamModel string,
	reasoningEffort *string,
	serviceTier *string,
	startTime time.Time,
) (*OpenAIForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")

	payload, readErr := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
	if readErr != nil {
		if !errors.Is(readErr, ErrUpstreamResponseBodyTooLarge) {
			writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response")
		}
		return nil, fmt.Errorf("read upstream body: %w", readErr)
	}

	// Best-effort usage extraction; parse failures are intentionally ignored.
	var usage OpenAIUsage
	var parsed apicompat.ChatCompletionsResponse
	if json.Unmarshal(payload, &parsed) == nil && parsed.Usage != nil {
		reported := parsed.Usage
		usage.InputTokens = reported.PromptTokens
		usage.OutputTokens = reported.CompletionTokens
		if reported.PromptTokensDetails != nil {
			usage.CacheReadInputTokens = reported.PromptTokensDetails.CachedTokens
		}
	}

	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	// Mirror the upstream content type, defaulting to JSON when it is absent.
	contentType := resp.Header.Get("Content-Type")
	if contentType == "" {
		contentType = "application/json"
	}
	c.Writer.Header().Set("Content-Type", contentType)
	c.Writer.WriteHeader(http.StatusOK)
	_, _ = c.Writer.Write(payload)

	result := &OpenAIForwardResult{
		RequestID:       requestID,
		Usage:           usage,
		Model:           originalModel,
		BillingModel:    billingModel,
		UpstreamModel:   upstreamModel,
		ReasoningEffort: reasoningEffort,
		ServiceTier:     serviceTier,
		Stream:          false,
		Duration:        time.Since(startTime),
	}
	return result, nil
}
// buildOpenAIChatCompletionsURL derives the upstream Chat Completions
// endpoint URL from a base URL. Surrounding whitespace and trailing slashes
// are removed first, then:
//
//   - a base already ending in /chat/completions is returned as-is
//   - a base ending in /v1 gets /chat/completions appended
//   - any other base gets /v1/chat/completions appended
//
// It is the sister function of buildOpenAIResponsesURL.
func buildOpenAIChatCompletionsURL(base string) string {
	trimmed := strings.TrimRight(strings.TrimSpace(base), "/")
	switch {
	case strings.HasSuffix(trimmed, "/chat/completions"):
		return trimmed
	case strings.HasSuffix(trimmed, "/v1"):
		return trimmed + "/chat/completions"
	default:
		return trimmed + "/v1/chat/completions"
	}
}

View File

@ -0,0 +1,260 @@
//go:build unit
package service
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestBuildOpenAIChatCompletionsURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
base string
want string
}{
// 已是 /chat/completions原样返回
{"already chat/completions", "https://api.openai.com/v1/chat/completions", "https://api.openai.com/v1/chat/completions"},
// 以 /v1 结尾:追加 /chat/completions
{"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/chat/completions"},
// 其他情况:追加 /v1/chat/completions
{"bare domain", "https://api.openai.com", "https://api.openai.com/v1/chat/completions"},
{"domain with trailing slash", "https://api.openai.com/", "https://api.openai.com/v1/chat/completions"},
// 第三方上游常见形式
{"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/chat/completions"},
{"third-party with path prefix", "https://api.gptgod.online/api", "https://api.gptgod.online/api/v1/chat/completions"},
// 带空白字符
{"whitespace trimmed", " https://api.openai.com/v1 ", "https://api.openai.com/v1/chat/completions"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := buildOpenAIChatCompletionsURL(tt.base)
require.Equal(t, tt.want, got)
})
}
}
// TestBuildOpenAIResponsesURL_ProbeURL 锁定 probe/测试端点使用的 URL 构建逻辑,
// 确保 buildOpenAIResponsesURL 对标准 OpenAI base_url 格式均拼出 `/v1/responses`。
func TestBuildOpenAIResponsesURL_ProbeURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
base string
want string
}{
{"bare domain", "https://api.openai.com", "https://api.openai.com/v1/responses"},
{"domain trailing slash", "https://api.openai.com/", "https://api.openai.com/v1/responses"},
{"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/responses"},
{"already /responses", "https://api.openai.com/v1/responses", "https://api.openai.com/v1/responses"},
{"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/responses"},
{"only domain, no scheme", "api.gptgod.online", "api.gptgod.online/v1/responses"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := buildOpenAIResponsesURL(tt.base)
require.Equal(t, tt.want, got)
})
}
}
// TestForwardAsRawChatCompletions_ForcesStreamUsageUpstreamAndPassesUsageDownstream
// verifies that a streaming raw request has stream_options.include_usage
// enabled on the upstream body, that usage is captured in the result, and
// that the usage chunk and the [DONE] sentinel reach the client.
func TestForwardAsRawChatCompletions_ForcesStreamUsageUpstreamAndPassesUsageDownstream(t *testing.T) {
	gin.SetMode(gin.TestMode)

	reqBody := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(reqBody))
	ginCtx.Request.Header.Set("Content-Type", "application/json")

	// Upstream SSE: one content chunk, one usage-only chunk, then [DONE].
	sseLines := []string{
		`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`,
		"",
		`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":4,"total_tokens":13,"prompt_tokens_details":{"cached_tokens":3}}}`,
		"",
		"data: [DONE]",
		"",
	}
	fakeUpstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_usage"}},
		Body:       io.NopCloser(strings.NewReader(strings.Join(sseLines, "\n"))),
	}}
	gateway := &OpenAIGatewayService{
		cfg:          rawChatCompletionsTestConfig(),
		httpUpstream: fakeUpstream,
	}

	got, err := gateway.forwardAsRawChatCompletions(context.Background(), ginCtx, rawChatCompletionsTestAccount(), reqBody, "")
	require.NoError(t, err)
	require.NotNil(t, got)

	// Usage captured for billing.
	require.Equal(t, 9, got.Usage.InputTokens)
	require.Equal(t, 4, got.Usage.OutputTokens)
	require.Equal(t, 3, got.Usage.CacheReadInputTokens)

	// Upstream request is live and has include_usage forced on.
	require.NotNil(t, fakeUpstream.lastReq)
	require.NoError(t, fakeUpstream.lastReq.Context().Err())
	require.True(t, gjson.GetBytes(fakeUpstream.lastBody, "stream_options.include_usage").Bool())

	// Client received the usage chunk and the terminator.
	downstream := recorder.Body.String()
	require.Contains(t, downstream, `"usage"`)
	require.Contains(t, downstream, "data: [DONE]")
}
// TestForwardAsRawChatCompletions_ClientDisconnectDrainsUsage installs a
// writer configured to fail from the first write and checks that usage from
// the upstream stream is still captured and no error is returned.
func TestForwardAsRawChatCompletions_ClientDisconnectDrainsUsage(t *testing.T) {
	gin.SetMode(gin.TestMode)

	reqBody := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Writer = &openAIChatFailingWriter{ResponseWriter: ginCtx.Writer, failAfter: 0}
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(reqBody))
	ginCtx.Request.Header.Set("Content-Type", "application/json")

	// Upstream SSE: one content chunk, one usage-only chunk, then [DONE].
	sseLines := []string{
		`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`,
		"",
		`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":17,"completion_tokens":8,"total_tokens":25,"prompt_tokens_details":{"cached_tokens":6}}}`,
		"",
		"data: [DONE]",
		"",
	}
	fakeUpstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_disconnect"}},
		Body:       io.NopCloser(strings.NewReader(strings.Join(sseLines, "\n"))),
	}}
	gateway := &OpenAIGatewayService{
		cfg:          rawChatCompletionsTestConfig(),
		httpUpstream: fakeUpstream,
	}

	got, err := gateway.forwardAsRawChatCompletions(context.Background(), ginCtx, rawChatCompletionsTestAccount(), reqBody, "")
	require.NoError(t, err)
	require.NotNil(t, got)
	require.Equal(t, 17, got.Usage.InputTokens)
	require.Equal(t, 8, got.Usage.OutputTokens)
	require.Equal(t, 6, got.Usage.CacheReadInputTokens)
	require.True(t, gjson.GetBytes(fakeUpstream.lastBody, "stream_options.include_usage").Bool())
}
// TestForwardAsRawChatCompletions_UpstreamRequestIgnoresClientCancel cancels
// the client request context before forwarding and asserts that the context
// attached to the upstream request reports no error.
func TestForwardAsRawChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) {
	gin.SetMode(gin.TestMode)

	clientCtx, cancelClient := context.WithCancel(context.Background())
	reqBody := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(reqBody)).WithContext(clientCtx)
	ginCtx.Request.Header.Set("Content-Type", "application/json")
	// The client is already gone by the time the request is forwarded.
	cancelClient()

	sseLines := []string{
		`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`,
		"",
		"data: [DONE]",
		"",
	}
	fakeUpstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_ctx"}},
		Body:       io.NopCloser(strings.NewReader(strings.Join(sseLines, "\n"))),
	}}
	gateway := &OpenAIGatewayService{
		cfg:          rawChatCompletionsTestConfig(),
		httpUpstream: fakeUpstream,
	}

	got, err := gateway.forwardAsRawChatCompletions(clientCtx, ginCtx, rawChatCompletionsTestAccount(), reqBody, "")
	require.NoError(t, err)
	require.NotNil(t, got)
	require.NotNil(t, fakeUpstream.lastReq)
	require.NoError(t, fakeUpstream.lastReq.Context().Err())
}
func TestIsOpenAIChatUsageOnlyStreamChunk(t *testing.T) {
t.Parallel()
require.True(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[],"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[{"index":0}],"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[]}`))
require.False(t, isOpenAIChatUsageOnlyStreamChunk(``))
}
// TestEnsureOpenAIChatStreamUsage checks that include_usage ends up true both
// when stream_options is absent and when the client explicitly set it false.
func TestEnsureOpenAIChatStreamUsage(t *testing.T) {
	t.Parallel()

	inputs := []string{
		`{"model":"gpt-5.4"}`,
		`{"model":"gpt-5.4","stream_options":{"include_usage":false}}`,
	}
	for _, raw := range inputs {
		patched, err := ensureOpenAIChatStreamUsage([]byte(raw))
		require.NoError(t, err)
		require.True(t, gjson.GetBytes(patched, "stream_options.include_usage").Bool())
	}
}
// TestBufferRawChatCompletions_RejectsOversizedResponse caps the upstream
// read limit below the body size and expects ErrUpstreamResponseBodyTooLarge,
// a nil result and a 502 recorded on the response.
func TestBufferRawChatCompletions_RejectsOversizedResponse(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)

	// 7-byte body against a 3-byte read limit.
	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"application/json"}},
		Body:       io.NopCloser(strings.NewReader("toolong")),
	}
	cfg := rawChatCompletionsTestConfig()
	cfg.Gateway.UpstreamResponseReadMaxBytes = 3
	gateway := &OpenAIGatewayService{cfg: cfg}

	got, err := gateway.bufferRawChatCompletions(ginCtx, upstreamResp, "gpt-5.4", "gpt-5.4", "gpt-5.4", nil, nil, time.Now())
	require.ErrorIs(t, err, ErrUpstreamResponseBodyTooLarge)
	require.Nil(t, got)
	require.Equal(t, http.StatusBadGateway, recorder.Code)
}
// rawChatCompletionsTestConfig returns a fresh config for the raw chat
// completions tests: the URL allowlist is disabled and insecure HTTP is
// allowed. All other fields are left at their zero values.
func rawChatCompletionsTestConfig() *config.Config {
	allowlist := config.URLAllowlistConfig{
		Enabled:           false,
		AllowInsecureHTTP: true,
	}
	return &config.Config{
		Security: config.SecurityConfig{URLAllowlist: allowlist},
	}
}
// rawChatCompletionsTestAccount returns a fresh OpenAI API-key account fixture
// (ID 101, concurrency 1) whose credentials carry a dummy key and a plain-HTTP
// base URL.
func rawChatCompletionsTestAccount() *Account {
	creds := map[string]any{
		"api_key":  "sk-test",
		"base_url": "http://upstream.example",
	}
	fixture := &Account{
		ID:          101,
		Name:        "raw-openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: creds,
	}
	return fixture
}

View File

@ -1,13 +1,36 @@
package service package service
import ( import (
"bytes"
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing" "testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
// openAIChatFailingWriter is a test double that wraps a gin.ResponseWriter and
// makes Write start returning an error once a fixed number of writes have
// succeeded. Tests use it to simulate a client that has disconnected.
type openAIChatFailingWriter struct {
// Embedded writer; every method other than Write is delegated to it unchanged.
gin.ResponseWriter
// failAfter is the number of Write calls allowed to succeed; 0 fails every Write.
failAfter int
// writes counts the Write calls that have been passed through so far.
writes int
}
// Write forwards p to the wrapped writer while fewer than failAfter writes
// have been passed through; after that it writes nothing and returns an error
// describing a disconnected client.
func (w *openAIChatFailingWriter) Write(p []byte) (int, error) {
	if w.writes < w.failAfter {
		w.writes++
		return w.ResponseWriter.Write(p)
	}
	return 0, errors.New("write failed: client disconnected")
}
func TestNormalizeResponsesRequestServiceTier(t *testing.T) { func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
t.Parallel() t.Parallel()
@ -73,3 +96,278 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
require.Empty(t, tier) require.Empty(t, tier)
require.False(t, gjson.GetBytes(body, "service_tier").Exists()) require.False(t, gjson.GetBytes(body, "service_tier").Exists())
} }
// TestForwardAsChatCompletions_UnknownModelDoesNotUseDefaultMappedModel sends a
// request for the unknown model "gpt6" while passing "gpt-5.4" as the default
// mapped model. The recorded upstream body must still name "gpt6", and the
// upstream 400 must surface as an error with a 400 written to the client.
func TestForwardAsChatCompletions_UnknownModelDoesNotUseDefaultMappedModel(t *testing.T) {
	gin.SetMode(gin.TestMode)

	reqBody := []byte(`{"model":"gpt6","messages":[{"role":"user","content":"hello"}],"stream":false}`)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(reqBody))
	ginCtx.Request.Header.Set("Content-Type", "application/json")

	// Upstream rejects the request as if the model did not exist.
	upstreamHeader := http.Header{
		"Content-Type": []string{"application/json"},
		"x-request-id": []string{"rid_chat_unknown_model"},
	}
	fakeUpstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusBadRequest,
		Header:     upstreamHeader,
		Body:       io.NopCloser(strings.NewReader(`{"error":{"type":"invalid_request_error","message":"model not found"}}`)),
	}}

	gateway := &OpenAIGatewayService{httpUpstream: fakeUpstream}
	oauthAccount := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	got, err := gateway.ForwardAsChatCompletions(context.Background(), ginCtx, oauthAccount, reqBody, "", "gpt-5.4")
	require.Error(t, err)
	require.Nil(t, got)

	// The model forwarded upstream must be the one the client asked for.
	forwardedModel := gjson.GetBytes(fakeUpstream.lastBody, "model").String()
	require.Equal(t, "gpt6", forwardedModel)
	require.NotEqual(t, "gpt-5.4", forwardedModel)
	require.Equal(t, http.StatusBadRequest, recorder.Code)
}
// TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage runs a
// streaming request whose client writer fails on every Write (failAfter: 0).
// The call must still succeed and report the usage carried by the upstream
// response.completed event (11 input, 5 output, 4 cached tokens).
func TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage(t *testing.T) {
	gin.SetMode(gin.TestMode)

	reqBody := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	// Every client write fails, simulating a client that is already gone.
	ginCtx.Writer = &openAIChatFailingWriter{ResponseWriter: ginCtx.Writer, failAfter: 0}
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(reqBody))
	ginCtx.Request.Header.Set("Content-Type", "application/json")

	// Upstream SSE: created, one text delta, completed (with usage), then [DONE].
	sseLines := []string{
		`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`,
		"",
		`data: {"type":"response.output_text.delta","delta":"ok"}`,
		"",
		`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":11,"output_tokens":5,"total_tokens":16,"input_tokens_details":{"cached_tokens":4}}}}`,
		"",
		"data: [DONE]",
		"",
	}
	upstreamBody := strings.Join(sseLines, "\n")

	upstreamHeader := http.Header{
		"Content-Type": []string{"text/event-stream"},
		"x-request-id": []string{"rid_chat_disconnect"},
	}
	fakeUpstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     upstreamHeader,
		Body:       io.NopCloser(strings.NewReader(upstreamBody)),
	}}

	gateway := &OpenAIGatewayService{httpUpstream: fakeUpstream}
	oauthAccount := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	got, err := gateway.ForwardAsChatCompletions(context.Background(), ginCtx, oauthAccount, reqBody, "", "gpt-5.1")
	require.NoError(t, err)
	require.NotNil(t, got)

	require.Equal(t, 11, got.Usage.InputTokens)
	require.Equal(t, 5, got.Usage.OutputTokens)
	require.Equal(t, 4, got.Usage.CacheReadInputTokens)
}
// TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns checks
// that a streaming ForwardAsChatCompletions call returns within one second
// once the upstream has sent a response.completed event, even though the
// upstream body is not closed, and that the usage from that event
// (17 input, 8 output, 6 cached tokens) is reported.
func TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
// failAfter: 0 makes every client write fail.
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
// The upstream payload is a single terminal event; no [DONE] sentinel follows.
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n")
// NOTE(review): presumably this reader yields upstreamBody and then blocks
// until Close is called, emulating an upstream that keeps the connection
// open — confirm against newOpenAICompatBlockingReadCloser.
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
defer func() {
require.NoError(t, upstreamStream.Close())
}()
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_terminal_no_close"}},
Body: upstreamStream,
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
type forwardResult struct {
result *OpenAIForwardResult
err error
}
// Buffered so the goroutine can deliver its result and exit even if the
// test has already failed on the timeout branch.
resultCh := make(chan forwardResult, 1)
// Run the call in the background so the test can enforce a deadline on it.
go func() {
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
resultCh <- forwardResult{result: result, err: err}
}()
select {
case got := <-resultCh:
require.NoError(t, got.err)
require.NotNil(t, got.result)
require.Equal(t, 17, got.result.Usage.InputTokens)
require.Equal(t, 8, got.result.Usage.OutputTokens)
require.Equal(t, 6, got.result.Usage.CacheReadInputTokens)
case <-time.After(time.Second):
// The call did not return in time: it is still waiting on the upstream body.
require.Fail(t, "ForwardAsChatCompletions should return after terminal usage event even if upstream keeps the connection open")
}
}
// TestForwardAsChatCompletions_BufferedTerminalWithoutUpstreamCloseReturns is
// the non-streaming ("stream":false) counterpart: the call must return within
// one second after the upstream response.completed event even though the
// upstream body is not closed, report that event's usage (17 input, 8 output,
// 6 cached tokens), and write a client body containing "finish_reason":"stop".
func TestForwardAsChatCompletions_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
// The upstream payload is a single terminal event; no [DONE] sentinel follows.
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n")
// NOTE(review): presumably this reader yields upstreamBody and then blocks
// until Close is called, emulating an upstream that keeps the connection
// open — confirm against newOpenAICompatBlockingReadCloser.
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
defer func() {
require.NoError(t, upstreamStream.Close())
}()
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_buffered_terminal_no_close"}},
Body: upstreamStream,
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
type forwardResult struct {
result *OpenAIForwardResult
err error
}
// Buffered so the goroutine can deliver its result and exit even if the
// test has already failed on the timeout branch.
resultCh := make(chan forwardResult, 1)
// Run the call in the background so the test can enforce a deadline on it.
go func() {
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
resultCh <- forwardResult{result: result, err: err}
}()
select {
case got := <-resultCh:
require.NoError(t, got.err)
require.NotNil(t, got.result)
require.Equal(t, 17, got.result.Usage.InputTokens)
require.Equal(t, 8, got.result.Usage.OutputTokens)
require.Equal(t, 6, got.result.Usage.CacheReadInputTokens)
// The buffered path must have written a finished chat completion to the client.
require.Contains(t, rec.Body.String(), `"finish_reason":"stop"`)
case <-time.After(time.Second):
// The call did not return in time: it is still waiting on the upstream body.
require.Fail(t, "ForwardAsChatCompletions buffered response should return after terminal usage event even if upstream keeps the connection open")
}
}
// TestForwardAsChatCompletions_DoneSentinelWithoutTerminalReturnsError covers
// an upstream stream that consists solely of the [DONE] sentinel. The call is
// expected to fail with an error mentioning "missing terminal event" while
// still returning a non-nil result whose input/output usage is zero.
func TestForwardAsChatCompletions_DoneSentinelWithoutTerminalReturnsError(t *testing.T) {
	gin.SetMode(gin.TestMode)

	reqBody := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(reqBody))
	ginCtx.Request.Header.Set("Content-Type", "application/json")

	// No Responses event precedes the sentinel.
	upstreamBody := "data: [DONE]\n\n"
	upstreamHeader := http.Header{
		"Content-Type": []string{"text/event-stream"},
		"x-request-id": []string{"rid_chat_missing_terminal"},
	}
	fakeUpstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     upstreamHeader,
		Body:       io.NopCloser(strings.NewReader(upstreamBody)),
	}}

	gateway := &OpenAIGatewayService{httpUpstream: fakeUpstream}
	oauthAccount := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	got, err := gateway.ForwardAsChatCompletions(context.Background(), ginCtx, oauthAccount, reqBody, "", "gpt-5.1")
	require.Error(t, err)
	require.Contains(t, err.Error(), "missing terminal event")

	require.NotNil(t, got)
	require.Zero(t, got.Usage.InputTokens)
	require.Zero(t, got.Usage.OutputTokens)
}
// TestForwardAsChatCompletions_UpstreamRequestIgnoresClientCancel cancels the
// client context before forwarding and then asserts that the call succeeds
// and that the context attached to the recorded upstream request reports no
// error, i.e. the upstream request is not bound to the cancelled client context.
func TestForwardAsChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) {
	gin.SetMode(gin.TestMode)

	reqBody := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	clientCtx, cancelClient := context.WithCancel(context.Background())
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(reqBody)).WithContext(clientCtx)
	ginCtx.Request.Header.Set("Content-Type", "application/json")
	// The client context is already cancelled when forwarding starts.
	cancelClient()

	sseLines := []string{
		`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
		"",
		"data: [DONE]",
		"",
	}
	upstreamBody := strings.Join(sseLines, "\n")

	upstreamHeader := http.Header{
		"Content-Type": []string{"text/event-stream"},
		"x-request-id": []string{"rid_chat_ctx"},
	}
	fakeUpstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     upstreamHeader,
		Body:       io.NopCloser(strings.NewReader(upstreamBody)),
	}}

	gateway := &OpenAIGatewayService{httpUpstream: fakeUpstream}
	oauthAccount := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	got, err := gateway.ForwardAsChatCompletions(clientCtx, ginCtx, oauthAccount, reqBody, "", "gpt-5.1")
	require.NoError(t, err)
	require.NotNil(t, got)

	// The upstream request must carry a context that is still live.
	require.NotNil(t, fakeUpstream.lastReq)
	require.NoError(t, fakeUpstream.lastReq.Context().Err())
}

View File

@ -10,6 +10,7 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
@ -39,12 +40,54 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
if err := json.Unmarshal(body, &anthropicReq); err != nil { if err := json.Unmarshal(body, &anthropicReq); err != nil {
return nil, fmt.Errorf("parse anthropic request: %w", err) return nil, fmt.Errorf("parse anthropic request: %w", err)
} }
anthropicDigestReq := cloneAnthropicRequestForDigest(&anthropicReq)
originalModel := anthropicReq.Model originalModel := anthropicReq.Model
applyOpenAICompatModelNormalization(&anthropicReq) applyOpenAICompatModelNormalization(&anthropicReq)
normalizedModel := anthropicReq.Model normalizedModel := anthropicReq.Model
clientStream := anthropicReq.Stream // client's original stream preference clientStream := anthropicReq.Stream // client's original stream preference
// 2. Convert Anthropic → Responses // 2. Model mapping
billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel)
upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
promptCacheKey = strings.TrimSpace(promptCacheKey)
apiKeyID := getAPIKeyIDFromContext(c)
anthropicDigestChain := ""
anthropicMatchedDigestChain := ""
compatPromptCacheInjected := false
if promptCacheKey == "" && shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) {
promptCacheKey = promptCacheKeyFromAnthropicMetadataSession(&anthropicReq)
if promptCacheKey == "" {
promptCacheKey = deriveAnthropicCacheControlPromptCacheKey(&anthropicReq)
}
if promptCacheKey == "" {
anthropicDigestChain = buildOpenAICompatAnthropicDigestChain(anthropicDigestReq)
if reusedKey, matchedChain := s.findOpenAICompatAnthropicDigestPromptCacheKey(account, apiKeyID, anthropicDigestChain); reusedKey != "" {
promptCacheKey = reusedKey
anthropicMatchedDigestChain = matchedChain
} else {
promptCacheKey = promptCacheKeyFromAnthropicDigest(anthropicDigestChain)
}
}
compatPromptCacheInjected = promptCacheKey != ""
}
compatReplayTrimmed := false
compatReplayGuardEnabled := shouldAutoInjectPromptCacheKeyForCompat(upstreamModel)
compatContinuationEnabled := openAICompatContinuationEnabled(account, upstreamModel)
previousResponseID := ""
if compatContinuationEnabled {
previousResponseID = s.getOpenAICompatSessionResponseID(ctx, c, account, promptCacheKey)
}
compatContinuationDisabled := compatContinuationEnabled &&
s.isOpenAICompatSessionContinuationDisabled(ctx, c, account, promptCacheKey)
compatTurnState := ""
// OAuth/Plus relies on session_id + x-codex-turn-state; trimming to a
// sliding 12-message window makes the cached prefix stall at system/tools.
// Keep full replay there so upstream prompt caching can grow turn by turn.
if compatReplayGuardEnabled && account.Type != AccountTypeOAuth && previousResponseID == "" && !compatContinuationDisabled {
compatReplayTrimmed = applyAnthropicCompatFullReplayGuard(&anthropicReq)
}
// 3. Convert Anthropic → Responses after compatibility-only replay guard.
responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq) responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq)
if err != nil { if err != nil {
return nil, fmt.Errorf("convert anthropic to responses: %w", err) return nil, fmt.Errorf("convert anthropic to responses: %w", err)
@ -55,24 +98,50 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
responsesReq.Stream = true responsesReq.Stream = true
isStream := true isStream := true
// 2b. Handle BetaFastMode → service_tier: "priority" // 3b. Handle BetaFastMode → service_tier: "priority"
if containsBetaToken(c.GetHeader("anthropic-beta"), claude.BetaFastMode) { if containsBetaToken(c.GetHeader("anthropic-beta"), claude.BetaFastMode) {
responsesReq.ServiceTier = "priority" responsesReq.ServiceTier = "priority"
} }
// 3. Model mapping
billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel)
upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
responsesReq.Model = upstreamModel responsesReq.Model = upstreamModel
if previousResponseID != "" {
responsesReq.PreviousResponseID = previousResponseID
trimAnthropicCompatResponsesInputToLatestTurn(responsesReq)
}
if compatReplayGuardEnabled && account.Type != AccountTypeOAuth {
appendOpenAICompatClaudeCodeTodoGuard(responsesReq)
}
logger.L().Debug("openai messages: model mapping applied", logFields := []zap.Field{
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel), zap.String("original_model", originalModel),
zap.String("normalized_model", normalizedModel), zap.String("normalized_model", normalizedModel),
zap.String("billing_model", billingModel), zap.String("billing_model", billingModel),
zap.String("upstream_model", upstreamModel), zap.String("upstream_model", upstreamModel),
zap.Bool("stream", isStream), zap.Bool("stream", isStream),
) }
if compatPromptCacheInjected {
logFields = append(logFields,
zap.Bool("compat_prompt_cache_key_injected", true),
zap.String("compat_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey)),
)
}
if compatReplayTrimmed {
logFields = append(logFields,
zap.Bool("compat_full_replay_trimmed", true),
zap.Int("compat_messages_after_trim", len(anthropicReq.Messages)),
)
}
if previousResponseID != "" {
logFields = append(logFields,
zap.Bool("compat_previous_response_id_attached", true),
zap.String("compat_previous_response_id", truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen)),
)
}
if compatTurnState != "" {
logFields = append(logFields, zap.Bool("compat_turn_state_attached", true))
}
logger.L().Debug("openai messages: model mapping applied", logFields...)
// 4. Marshal Responses request body, then apply OAuth codex transform // 4. Marshal Responses request body, then apply OAuth codex transform
responsesBody, err := json.Marshal(responsesReq) responsesBody, err := json.Marshal(responsesReq)
@ -85,7 +154,10 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
if err := json.Unmarshal(responsesBody, &reqBody); err != nil { if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
return nil, fmt.Errorf("unmarshal for codex transform: %w", err) return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
} }
codexResult := applyCodexOAuthTransform(reqBody, false, false) codexResult := applyCodexOAuthTransformWithOptions(reqBody, codexOAuthTransformOptions{
SkipDefaultInstructions: true,
PreserveToolCallIDs: true,
})
forcedTemplateText := "" forcedTemplateText := ""
if s.cfg != nil { if s.cfg != nil {
forcedTemplateText = s.cfg.Gateway.ForcedCodexInstructionsTemplate forcedTemplateText = s.cfg.Gateway.ForcedCodexInstructionsTemplate
@ -95,6 +167,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
templateUpstreamModel = codexResult.NormalizedModel templateUpstreamModel = codexResult.NormalizedModel
} }
existingInstructions, _ := reqBody["instructions"].(string) existingInstructions, _ := reqBody["instructions"].(string)
if strings.TrimSpace(existingInstructions) == "" {
existingInstructions = extractPromptLikeInstructionsFromInput(reqBody)
}
if _, err := applyForcedCodexInstructionsTemplate(reqBody, forcedTemplateText, forcedCodexInstructionsTemplateData{ if _, err := applyForcedCodexInstructionsTemplate(reqBody, forcedTemplateText, forcedCodexInstructionsTemplateData{
ExistingInstructions: strings.TrimSpace(existingInstructions), ExistingInstructions: strings.TrimSpace(existingInstructions),
OriginalModel: originalModel, OriginalModel: originalModel,
@ -104,13 +179,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
ensureCodexOAuthInstructionsField(reqBody)
if shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) {
appendOpenAICompatClaudeCodeTodoGuardToRequestBody(reqBody)
}
if codexResult.NormalizedModel != "" { if codexResult.NormalizedModel != "" {
upstreamModel = codexResult.NormalizedModel upstreamModel = codexResult.NormalizedModel
} }
if codexResult.PromptCacheKey != "" { if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey promptCacheKey = codexResult.PromptCacheKey
} else if promptCacheKey != "" { }
reqBody["prompt_cache_key"] = promptCacheKey delete(reqBody, "prompt_cache_key")
if shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) {
compatTurnState = s.getOpenAICompatSessionTurnState(ctx, c, account, promptCacheKey)
} }
// OAuth codex transform forces stream=true upstream, so always use // OAuth codex transform forces stream=true upstream, so always use
// the streaming response handler regardless of what the client asked. // the streaming response handler regardless of what the client asked.
@ -163,7 +244,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
} }
// 6. Build upstream request // 6. Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, isStream, promptCacheKey, false) upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, isStream, promptCacheKey, false)
releaseUpstreamCtx()
if err != nil { if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err) return nil, fmt.Errorf("build upstream request: %w", err)
} }
@ -171,8 +254,25 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
// Override session_id with a deterministic UUID derived from the isolated // Override session_id with a deterministic UUID derived from the isolated
// session key, ensuring different API keys produce different upstream sessions. // session key, ensuring different API keys produce different upstream sessions.
if promptCacheKey != "" { if promptCacheKey != "" {
apiKeyID := getAPIKeyIDFromContext(c) isolatedSessionID := generateSessionUUID(isolateOpenAISessionID(apiKeyID, promptCacheKey))
upstreamReq.Header.Set("session_id", generateSessionUUID(isolateOpenAISessionID(apiKeyID, promptCacheKey))) upstreamReq.Header.Set("session_id", isolatedSessionID)
if upstreamReq.Header.Get("conversation_id") != "" {
upstreamReq.Header.Set("conversation_id", isolatedSessionID)
}
}
if account.Type == AccountTypeOAuth {
// Anthropic Messages compatibility uses the ChatGPT Codex SSE endpoint.
// Match airgate-openai's request shape: the SSE endpoint does not need
// the Responses experimental beta header, and forcing originator can make
// ChatGPT select a different internal continuation path.
upstreamReq.Header.Del("OpenAI-Beta")
upstreamReq.Header.Del("originator")
}
if account.Type == AccountTypeOAuth && promptCacheKey != "" && strings.TrimSpace(c.GetHeader("conversation_id")) == "" {
upstreamReq.Header.Del("conversation_id")
}
if compatTurnState != "" && upstreamReq.Header.Get("x-codex-turn-state") == "" {
upstreamReq.Header.Set("x-codex-turn-state", compatTurnState)
} }
// 7. Send request // 7. Send request
@ -205,6 +305,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if previousResponseID != "" && (isOpenAICompatPreviousResponseNotFound(resp.StatusCode, upstreamMsg, respBody) || isOpenAICompatPreviousResponseUnsupported(resp.StatusCode, upstreamMsg, respBody)) {
if isOpenAICompatPreviousResponseUnsupported(resp.StatusCode, upstreamMsg, respBody) {
s.disableOpenAICompatSessionContinuation(ctx, c, account, promptCacheKey)
} else {
s.deleteOpenAICompatSessionResponseID(ctx, c, account, promptCacheKey)
}
logger.L().Info("openai messages: previous_response_id unavailable, retrying without continuation",
zap.Int64("account_id", account.ID),
zap.String("previous_response_id", truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen)),
zap.String("upstream_model", upstreamModel),
)
return s.ForwardAsAnthropic(ctx, c, account, body, promptCacheKey, defaultMappedModel)
}
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := "" upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
@ -237,6 +350,12 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return s.handleAnthropicErrorResponse(resp, c, account) return s.handleAnthropicErrorResponse(resp, c, account)
} }
if account.Type == AccountTypeOAuth && promptCacheKey != "" {
if turnState := strings.TrimSpace(resp.Header.Get("x-codex-turn-state")); turnState != "" {
s.bindOpenAICompatSessionTurnState(ctx, c, account, promptCacheKey, turnState)
}
}
// 9. Handle normal response // 9. Handle normal response
// Upstream is always streaming; choose response format based on client preference. // Upstream is always streaming; choose response format based on client preference.
var result *OpenAIForwardResult var result *OpenAIForwardResult
@ -250,6 +369,12 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
// Propagate ServiceTier and ReasoningEffort to result for billing // Propagate ServiceTier and ReasoningEffort to result for billing
if handleErr == nil && result != nil { if handleErr == nil && result != nil {
if compatContinuationEnabled && promptCacheKey != "" && result.ResponseID != "" {
s.bindOpenAICompatSessionResponseID(ctx, c, account, promptCacheKey, result.ResponseID)
}
if promptCacheKey != "" && anthropicDigestChain != "" {
s.bindOpenAICompatAnthropicDigestPromptCacheKey(account, apiKeyID, anthropicDigestChain, promptCacheKey, anthropicMatchedDigestChain)
}
if responsesReq.ServiceTier != "" { if responsesReq.ServiceTier != "" {
st := responsesReq.ServiceTier st := responsesReq.ServiceTier
result.ServiceTier = &st result.ServiceTier = &st
@ -270,6 +395,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return result, handleErr return result, handleErr
} }
// ensureCodexOAuthInstructionsField guarantees that reqBody["instructions"]
// holds a string. A missing key, a nil value, or a value of any non-string
// type is replaced with the empty string; an existing string is left
// untouched. A nil map is ignored.
func ensureCodexOAuthInstructionsField(reqBody map[string]any) {
	if reqBody == nil {
		return
	}
	switch reqBody["instructions"].(type) {
	case string:
		// Already a string (possibly empty): keep as-is.
	default:
		// Absent, nil, or wrong type: normalize to an empty string.
		reqBody["instructions"] = ""
	}
}
// handleAnthropicErrorResponse reads an upstream error and returns it in // handleAnthropicErrorResponse reads an upstream error and returns it in
// Anthropic error format. // Anthropic error format.
func (s *OpenAIGatewayService) handleAnthropicErrorResponse( func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
@ -296,61 +434,9 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
) (*OpenAIForwardResult, error) { ) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id") requestID := resp.Header.Get("x-request-id")
scanner := bufio.NewScanner(resp.Body) finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai messages buffered", requestID)
maxLineSize := defaultMaxLineSize if err != nil {
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { return nil, err
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
var finalResponse *apicompat.ResponsesResponse
var usage OpenAIUsage
acc := apicompat.NewBufferedResponseAccumulator()
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
payload := line[6:]
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai messages buffered: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
// Accumulate delta content for fallback when terminal output is empty.
acc.ProcessEvent(&event)
// Terminal events carry the complete ResponsesResponse with output + usage.
if (event.Type == "response.completed" || event.Type == "response.done" ||
event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil {
finalResponse = event.Response
if event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai messages buffered: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
} }
if finalResponse == nil { if finalResponse == nil {
@ -371,6 +457,7 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
return &OpenAIForwardResult{ return &OpenAIForwardResult{
RequestID: requestID, RequestID: requestID,
ResponseID: finalResponse.ID,
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
BillingModel: billingModel, BillingModel: billingModel,
@ -380,6 +467,153 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
}, nil }, nil
} }
// isOpenAICompatResponsesTerminalEvent reports whether eventType, after
// trimming surrounding whitespace, is one of the Responses stream event types
// treated as terminal here: response.completed, response.done,
// response.incomplete, or response.failed. Matching is case-sensitive.
func isOpenAICompatResponsesTerminalEvent(eventType string) bool {
	trimmed := strings.TrimSpace(eventType)
	return trimmed == "response.completed" ||
		trimmed == "response.done" ||
		trimmed == "response.incomplete" ||
		trimmed == "response.failed"
}
// isOpenAICompatDoneSentinelLine reports whether line is an SSE data line
// (as recognized by extractOpenAISSEDataLine) whose payload, ignoring
// surrounding whitespace, is exactly "[DONE]".
func isOpenAICompatDoneSentinelLine(line string) bool {
	data, isDataLine := extractOpenAISSEDataLine(line)
	if !isDataLine {
		return false
	}
	return strings.TrimSpace(data) == "[DONE]"
}
// readOpenAICompatBufferedTerminal reads the upstream SSE body line by line
// and stops at the first terminal Responses event that carries a response.
//
// Parameters:
//   - resp: upstream HTTP response whose Body is scanned.
//   - logPrefix: prefix for the warning log messages emitted here.
//   - requestID: attached to those log messages as "request_id".
//
// Returns, in order: the terminal event's response (nil when none was seen),
// the usage copied from that response (zero value when absent), the
// accumulator that was fed every successfully parsed event, and an error.
//
// Outcomes:
//   - terminal event with non-nil Response: (response, usage, acc, nil)
//   - "[DONE]" sentinel or end of stream before a terminal event:
//     (nil, usage, acc, nil) — callers must check for a nil response
//   - nil resp / nil body, or a scanner read error: (nil, usage, acc, err)
//   - no line received within Gateway.StreamDataIntervalTimeout seconds
//     (only when configured > 0): resp.Body is closed and an error is returned
//
// The function does not wait for the upstream to close the body after a
// terminal event.
func (s *OpenAIGatewayService) readOpenAICompatBufferedTerminal(
resp *http.Response,
logPrefix string,
requestID string,
) (*apicompat.ResponsesResponse, OpenAIUsage, *apicompat.BufferedResponseAccumulator, error) {
acc := apicompat.NewBufferedResponseAccumulator()
var usage OpenAIUsage
if resp == nil || resp.Body == nil {
return nil, usage, acc, errors.New("upstream response body is nil")
}
scanner := bufio.NewScanner(resp.Body)
// Per-line size cap: configured Gateway.MaxLineSize when positive, else the default.
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
// streamInterval stays 0 (timeout disabled) unless a positive value is configured.
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
// timeoutCh remains nil when the timeout is disabled, so its select case
// below can never fire.
var timeoutCh <-chan time.Time
var timeoutTimer *time.Timer
// resetTimeout (re)arms the idle timer; it creates the timer on first use.
resetTimeout := func() {
if streamInterval <= 0 {
return
}
if timeoutTimer == nil {
timeoutTimer = time.NewTimer(streamInterval)
timeoutCh = timeoutTimer.C
return
}
// Stop and drain any already-fired value before Reset so a stale tick
// is not observed after re-arming.
if !timeoutTimer.Stop() {
select {
case <-timeoutTimer.C:
default:
}
}
timeoutTimer.Reset(streamInterval)
}
// stopTimeout stops the timer (if one was created) and drains a pending tick.
stopTimeout := func() {
if timeoutTimer == nil {
return
}
if !timeoutTimer.Stop() {
select {
case <-timeoutTimer.C:
default:
}
}
}
resetTimeout()
defer stopTimeout()
// scanEvent is one item from the reader goroutine: a scanned line or a scan error.
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
// done is closed when this function returns, releasing a reader goroutine
// that is blocked trying to send on events.
done := make(chan struct{})
// Scanning runs in its own goroutine so the select loop below can also
// react to the idle timeout while a read is pending.
// NOTE(review): a goroutine blocked inside scanner.Scan() is not
// interrupted by done; presumably the caller closes resp.Body after this
// function returns — confirm.
go func() {
defer close(events)
for scanner.Scan() {
select {
case events <- scanEvent{line: scanner.Text()}:
case <-done:
return
}
}
if err := scanner.Err(); err != nil {
select {
case events <- scanEvent{err: err}:
case <-done:
}
}
}()
defer close(done)
for {
select {
case ev, ok := <-events:
if !ok {
// Stream ended without a terminal event or error.
return nil, usage, acc, nil
}
// Any received item (including blank lines) counts as activity.
resetTimeout()
if ev.err != nil {
// Context cancellation/deadline errors are returned without a warning log.
if !errors.Is(ev.err, context.Canceled) && !errors.Is(ev.err, context.DeadlineExceeded) {
logger.L().Warn(logPrefix+": read error",
zap.Error(ev.err),
zap.String("request_id", requestID),
)
}
return nil, usage, acc, ev.err
}
// "[DONE]" before any terminal event: stop with a nil response and no error.
if isOpenAICompatDoneSentinelLine(ev.line) {
return nil, usage, acc, nil
}
// Skip lines that are not SSE data lines or that have an empty payload.
payload, ok := extractOpenAISSEDataLine(ev.line)
if !ok || payload == "" {
continue
}
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
// Unparseable events are logged and skipped, not treated as fatal.
logger.L().Warn(logPrefix+": failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
acc.ProcessEvent(&event)
// A terminal event without a Response payload does not end the loop.
if isOpenAICompatResponsesTerminalEvent(event.Type) && event.Response != nil {
if event.Response.Usage != nil {
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
}
return event.Response, usage, acc, nil
}
case <-timeoutCh:
// Idle timeout: close the body (error deliberately ignored) and give up.
_ = resp.Body.Close()
logger.L().Warn(logPrefix+": data interval timeout",
zap.String("request_id", requestID),
zap.Duration("interval", streamInterval),
)
return nil, usage, acc, fmt.Errorf("stream data interval timeout")
}
}
}
// handleAnthropicStreamingResponse reads Responses SSE events from upstream, // handleAnthropicStreamingResponse reads Responses SSE events from upstream,
// converts each to Anthropic SSE events, and writes them to the client. // converts each to Anthropic SSE events, and writes them to the client.
// When StreamKeepaliveInterval is configured, it uses a goroutine + channel // When StreamKeepaliveInterval is configured, it uses a goroutine + channel
@ -407,8 +641,10 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
state := apicompat.NewResponsesEventToAnthropicState() state := apicompat.NewResponsesEventToAnthropicState()
state.Model = originalModel state.Model = originalModel
var usage OpenAIUsage var usage OpenAIUsage
responseID := ""
var firstTokenMs *int var firstTokenMs *int
firstChunk := true firstChunk := true
clientDisconnected := false
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize maxLineSize := defaultMaxLineSize
@ -417,10 +653,25 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
} }
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// resultWithUsage builds the final result snapshot. // resultWithUsage builds the final result snapshot.
resultWithUsage := func() *OpenAIForwardResult { resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{ return &OpenAIForwardResult{
RequestID: requestID, RequestID: requestID,
ResponseID: responseID,
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
BillingModel: billingModel, BillingModel: billingModel,
@ -432,7 +683,6 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
} }
// processDataLine handles a single "data: ..." SSE line from upstream. // processDataLine handles a single "data: ..." SSE line from upstream.
// Returns (clientDisconnected bool).
processDataLine := func(payload string) bool { processDataLine := func(payload string) bool {
if firstChunk { if firstChunk {
firstChunk = false firstChunk = false
@ -449,53 +699,63 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
return false return false
} }
// Extract usage from completion events // 仅按兼容转换器支持的终止事件提取 usage避免无意扩大事件语义。
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
event.Response != nil && event.Response.Usage != nil { if isTerminalEvent && event.Response != nil {
usage = OpenAIUsage{ if id := strings.TrimSpace(event.Response.ID); id != "" {
InputTokens: event.Response.Usage.InputTokens, responseID = id
OutputTokens: event.Response.Usage.OutputTokens,
} }
if event.Response.Usage.InputTokensDetails != nil { if event.Response.Usage != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
} }
} }
// Convert to Anthropic events // Convert to Anthropic events
events := apicompat.ResponsesEventToAnthropicEvents(&event, state) events := apicompat.ResponsesEventToAnthropicEvents(&event, state)
for _, evt := range events { if !clientDisconnected {
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) for _, evt := range events {
if err != nil { sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
logger.L().Warn("openai messages stream: failed to marshal event", if err != nil {
zap.Error(err), logger.L().Warn("openai messages stream: failed to marshal event",
zap.String("request_id", requestID), zap.Error(err),
) zap.String("request_id", requestID),
continue )
} continue
if _, err := fmt.Fprint(c.Writer, sse); err != nil { }
logger.L().Info("openai messages stream: client disconnected", if _, err := fmt.Fprint(c.Writer, sse); err != nil {
zap.String("request_id", requestID), clientDisconnected = true
) logger.L().Info("openai messages stream: client disconnected, continuing to drain upstream for billing",
return true zap.String("request_id", requestID),
)
break
}
} }
} }
if len(events) > 0 { if len(events) > 0 && !clientDisconnected {
c.Writer.Flush() c.Writer.Flush()
} }
return false return isTerminalEvent
} }
// finalizeStream sends any remaining Anthropic events and returns the result. // finalizeStream sends any remaining Anthropic events and returns the result.
finalizeStream := func() (*OpenAIForwardResult, error) { finalizeStream := func() (*OpenAIForwardResult, error) {
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 { if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 && !clientDisconnected {
for _, evt := range finalEvents { for _, evt := range finalEvents {
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
if err != nil { if err != nil {
continue continue
} }
fmt.Fprint(c.Writer, sse) //nolint:errcheck if _, err := fmt.Fprint(c.Writer, sse); err != nil {
clientDisconnected = true
logger.L().Info("openai messages stream: client disconnected during final flush",
zap.String("request_id", requestID),
)
break
}
}
if !clientDisconnected {
c.Writer.Flush()
} }
c.Writer.Flush()
} }
return resultWithUsage(), nil return resultWithUsage(), nil
} }
@ -509,6 +769,9 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
) )
} }
} }
missingTerminalErr := func() (*OpenAIForwardResult, error) {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
// ── Determine keepalive interval ── // ── Determine keepalive interval ──
keepaliveInterval := time.Duration(0) keepaliveInterval := time.Duration(0)
@ -517,18 +780,25 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
} }
// ── No keepalive: fast synchronous path (no goroutine overhead) ── // ── No keepalive: fast synchronous path (no goroutine overhead) ──
if keepaliveInterval <= 0 { if streamInterval <= 0 && keepaliveInterval <= 0 {
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { if isOpenAICompatDoneSentinelLine(line) {
return missingTerminalErr()
}
payload, ok := extractOpenAISSEDataLine(line)
if !ok {
continue continue
} }
if processDataLine(line[6:]) { if processDataLine(payload) {
return resultWithUsage(), nil return finalizeStream()
} }
} }
handleScanErr(scanner.Err()) if err := scanner.Err(); err != nil {
return finalizeStream() handleScanErr(err)
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
}
return missingTerminalErr()
} }
// ── With keepalive: goroutine + channel + select ── // ── With keepalive: goroutine + channel + select ──
@ -538,6 +808,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
} }
events := make(chan scanEvent, 16) events := make(chan scanEvent, 16)
done := make(chan struct{}) done := make(chan struct{})
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
sendEvent := func(ev scanEvent) bool { sendEvent := func(ev scanEvent) bool {
select { select {
case events <- ev: case events <- ev:
@ -549,6 +821,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
go func() { go func() {
defer close(events) defer close(events)
for scanner.Scan() { for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) { if !sendEvent(scanEvent{line: scanner.Text()}) {
return return
} }
@ -559,8 +832,15 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
}() }()
defer close(done) defer close(done)
keepaliveTicker := time.NewTicker(keepaliveInterval) var keepaliveTicker *time.Ticker
defer keepaliveTicker.Stop() if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now() lastDataAt := time.Now()
for { for {
@ -568,22 +848,44 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
case ev, ok := <-events: case ev, ok := <-events:
if !ok { if !ok {
// Upstream closed // Upstream closed
return finalizeStream() return missingTerminalErr()
} }
if ev.err != nil { if ev.err != nil {
handleScanErr(ev.err) handleScanErr(ev.err)
return finalizeStream() return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
} }
lastDataAt = time.Now() lastDataAt = time.Now()
line := ev.line line := ev.line
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { if isOpenAICompatDoneSentinelLine(line) {
return missingTerminalErr()
}
payload, ok := extractOpenAISSEDataLine(line)
if !ok {
continue continue
} }
if processDataLine(line[6:]) { if processDataLine(payload) {
return resultWithUsage(), nil return finalizeStream()
} }
case <-keepaliveTicker.C: case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
}
logger.L().Warn("openai messages stream: data interval timeout",
zap.String("request_id", requestID),
zap.String("model", originalModel),
zap.Duration("interval", streamInterval),
)
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval { if time.Since(lastDataAt) < keepaliveInterval {
continue continue
} }
@ -593,7 +895,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
logger.L().Info("openai messages stream: client disconnected during keepalive", logger.L().Info("openai messages stream: client disconnected during keepalive",
zap.String("request_id", requestID), zap.String("request_id", requestID),
) )
return resultWithUsage(), nil clientDisconnected = true
continue
} }
c.Writer.Flush() c.Writer.Flush()
} }
@ -610,3 +913,17 @@ func writeAnthropicError(c *gin.Context, statusCode int, errType, message string
}, },
}) })
} }
// copyOpenAIUsageFromResponsesUsage converts the usage section of a Responses
// payload into an OpenAIUsage value.
//
// A nil usage yields the zero OpenAIUsage. The cached-token count is copied
// into CacheReadInputTokens only when InputTokensDetails is present.
func copyOpenAIUsageFromResponsesUsage(usage *apicompat.ResponsesUsage) OpenAIUsage {
	var out OpenAIUsage
	if usage == nil {
		return out
	}

	out.InputTokens = usage.InputTokens
	out.OutputTokens = usage.OutputTokens
	if details := usage.InputTokensDetails; details != nil {
		out.CacheReadInputTokens = details.CachedTokens
	}
	return out
}

View File

@ -52,6 +52,12 @@ func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *Usage
return &UsageBillingApplyResult{Applied: true}, nil return &UsageBillingApplyResult{Applied: true}, nil
} }
// TestOpenAIGatewayServiceRecordUsage_RejectsNilInput checks that RecordUsage
// returns an error both for a nil input and for an input with no fields set.
func TestOpenAIGatewayServiceRecordUsage_RejectsNilInput(t *testing.T) {
	svc := &OpenAIGatewayService{}

	// Nil input pointer.
	errNil := svc.RecordUsage(context.Background(), nil)
	require.Error(t, errNil)

	// Non-nil input whose fields are all zero values.
	errEmpty := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{})
	require.Error(t, errEmpty)
}
type openAIRecordUsageUserRepoStub struct { type openAIRecordUsageUserRepoStub struct {
UserRepository UserRepository
@ -186,6 +192,56 @@ func max(a, b int) int {
return b return b
} }
// TestOpenAIGatewayServiceRecordUsage_ZeroUsageStillWritesUsageLog verifies
// that a result carrying an empty OpenAIUsage still reaches the billing repo
// and the usage log exactly once, that every logged token count and cost is
// zero, and that the user, subscription, and API-key quota stubs are never
// invoked.
func TestOpenAIGatewayServiceRecordUsage_ZeroUsageStillWritesUsageLog(t *testing.T) {
	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
	billing := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
	users := &openAIRecordUsageUserRepoStub{}
	subs := &openAIRecordUsageSubRepoStub{}
	quota := &openAIRecordUsageAPIKeyQuotaStub{}
	svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(logRepo, billing, users, subs, nil)

	input := &OpenAIRecordUsageInput{
		Result: &OpenAIForwardResult{
			RequestID: "resp_zero_usage",
			Usage:     OpenAIUsage{},
			Model:     "gpt-5.1",
			Duration:  time.Second,
		},
		APIKey:        &APIKey{ID: 1000, Quota: 100, Group: &Group{RateMultiplier: 1}},
		User:          &User{ID: 2000},
		Account:       &Account{ID: 3000, Type: AccountTypeAPIKey},
		APIKeyService: quota,
	}
	require.NoError(t, svc.RecordUsage(context.Background(), input))

	// Call counts: billing and logging happen once, nothing else is touched.
	require.Equal(t, 1, billing.calls)
	require.Equal(t, 1, logRepo.calls)
	require.Equal(t, 0, users.deductCalls)
	require.Equal(t, 0, subs.incrementCalls)
	require.Equal(t, 0, quota.quotaCalls)
	require.Equal(t, 0, quota.rateLimitCalls)

	// The persisted usage log carries the request ID and all-zero figures.
	require.NotNil(t, logRepo.lastLog)
	logged := logRepo.lastLog
	require.Equal(t, "resp_zero_usage", logged.RequestID)
	require.Zero(t, logged.InputTokens)
	require.Zero(t, logged.OutputTokens)
	require.Zero(t, logged.CacheCreationTokens)
	require.Zero(t, logged.CacheReadTokens)
	require.Zero(t, logged.ImageOutputTokens)
	require.Zero(t, logged.ImageCount)
	require.Zero(t, logged.InputCost)
	require.Zero(t, logged.OutputCost)
	require.Zero(t, logged.TotalCost)
	require.Zero(t, logged.ActualCost)

	// The billing command must not charge any cost bucket.
	require.NotNil(t, billing.lastCmd)
	cmd := billing.lastCmd
	require.Zero(t, cmd.BalanceCost)
	require.Zero(t, cmd.SubscriptionCost)
	require.Zero(t, cmd.APIKeyQuotaCost)
	require.Zero(t, cmd.APIKeyRateLimitCost)
	require.Zero(t, cmd.AccountQuotaCost)
}
func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) { func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) {
groupID := int64(11) groupID := int64(11)
groupRate := 1.4 groupRate := 1.4
@ -956,9 +1012,8 @@ func TestOpenAIGatewayServiceRecordUsage_ChannelMappedDoesNotOverrideBillingMode
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10} usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
// When channel did NOT map the model (ChannelMappedModel == OriginalModel), // 渠道未发生模型映射时,应使用 result.BillingModel 中记录的实际上游计费模型,
// billing should use result.BillingModel (the actual model used after group // 而不是未映射的原始请求模型。
// DefaultMappedModel resolution), not the unmapped original model.
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{ expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{
InputTokens: 20, InputTokens: 20,
OutputTokens: 10, OutputTokens: 10,
@ -1032,6 +1087,101 @@ func TestOpenAIGatewayServiceRecordUsage_ChannelMappedOverridesBillingModelWhenM
require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero") require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero")
} }
// TestOpenAIGatewayServiceRecordUsage_BillsCompactOpenAIModelAlias records
// usage for the compact model name "gpt5.5" (upstream "gpt-5.4") and expects
// the charged cost to equal what the billing service computes for "gpt-5.5",
// while the log keeps the original model and upstream model strings.
func TestOpenAIGatewayServiceRecordUsage_BillsCompactOpenAIModelAlias(t *testing.T) {
	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
	users := &openAIRecordUsageUserRepoStub{}
	subs := &openAIRecordUsageSubRepoStub{}
	svc := newOpenAIRecordUsageServiceForTest(logRepo, users, subs, nil)

	tokens := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
	want, err := svc.billingService.CalculateCost(
		"gpt-5.5",
		UsageTokens{InputTokens: 20, OutputTokens: 10},
		1.1,
	)
	require.NoError(t, err)

	input := &OpenAIRecordUsageInput{
		Result: &OpenAIForwardResult{
			RequestID:     "resp_compact_openai_alias",
			Model:         "gpt5.5",
			UpstreamModel: "gpt-5.4",
			Usage:         tokens,
			Duration:      time.Second,
		},
		APIKey:  &APIKey{ID: 10},
		User:    &User{ID: 20},
		Account: &Account{ID: 30},
	}
	err = svc.RecordUsage(context.Background(), input)
	require.NoError(t, err)

	require.NotNil(t, logRepo.lastLog)
	logged := logRepo.lastLog
	require.Equal(t, "gpt5.5", logged.Model)
	require.NotNil(t, logged.UpstreamModel)
	require.Equal(t, "gpt-5.4", *logged.UpstreamModel)
	require.InDelta(t, want.ActualCost, logged.ActualCost, 1e-12)
	require.True(t, logged.ActualCost > 0, "cost must not be zero")
	require.InDelta(t, want.ActualCost, users.lastAmount, 1e-12)
}
// TestOpenAIGatewayServiceRecordUsage_FallsBackToUpstreamModelWhenPrimaryUnpriceable
// sets both Model and BillingModel to "not-priceable-alias" with UpstreamModel
// "gpt-5.4", and expects the charged cost to match the billing service's price
// for "gpt-5.4".
func TestOpenAIGatewayServiceRecordUsage_FallsBackToUpstreamModelWhenPrimaryUnpriceable(t *testing.T) {
	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
	users := &openAIRecordUsageUserRepoStub{}
	subs := &openAIRecordUsageSubRepoStub{}
	svc := newOpenAIRecordUsageServiceForTest(logRepo, users, subs, nil)

	tokens := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
	want, err := svc.billingService.CalculateCost(
		"gpt-5.4",
		UsageTokens{InputTokens: 20, OutputTokens: 10},
		1.1,
	)
	require.NoError(t, err)

	input := &OpenAIRecordUsageInput{
		Result: &OpenAIForwardResult{
			RequestID:     "resp_unpriceable_primary_upstream_fallback",
			Model:         "not-priceable-alias",
			BillingModel:  "not-priceable-alias",
			UpstreamModel: "gpt-5.4",
			Usage:         tokens,
			Duration:      time.Second,
		},
		APIKey:  &APIKey{ID: 10},
		User:    &User{ID: 20},
		Account: &Account{ID: 30},
	}
	err = svc.RecordUsage(context.Background(), input)
	require.NoError(t, err)

	require.NotNil(t, logRepo.lastLog)
	charged := logRepo.lastLog.ActualCost
	require.InDelta(t, want.ActualCost, charged, 1e-12)
	require.True(t, logRepo.lastLog.ActualCost > 0, "cost must not be zero")
	require.InDelta(t, want.ActualCost, users.lastAmount, 1e-12)
}
// TestOpenAIGatewayServiceRecordUsage_ReturnsErrorWhenTokenModelCannotBePriced
// records token usage for "not-priceable-alias" with no upstream model set and
// expects RecordUsage to fail with a cost-calculation error, without writing a
// usage log, deducting balance, or incrementing subscription usage.
func TestOpenAIGatewayServiceRecordUsage_ReturnsErrorWhenTokenModelCannotBePriced(t *testing.T) {
	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
	users := &openAIRecordUsageUserRepoStub{}
	subs := &openAIRecordUsageSubRepoStub{}
	svc := newOpenAIRecordUsageServiceForTest(logRepo, users, subs, nil)

	input := &OpenAIRecordUsageInput{
		Result: &OpenAIForwardResult{
			RequestID: "resp_unpriceable_without_upstream",
			Model:     "not-priceable-alias",
			Usage:     OpenAIUsage{InputTokens: 20, OutputTokens: 10},
			Duration:  time.Second,
		},
		APIKey:  &APIKey{ID: 10},
		User:    &User{ID: 20},
		Account: &Account{ID: 30},
	}
	err := svc.RecordUsage(context.Background(), input)

	require.Error(t, err)
	require.Contains(t, err.Error(), "calculate OpenAI usage cost failed")

	// Nothing may be persisted or charged when pricing fails.
	require.Equal(t, 0, logRepo.calls)
	require.Equal(t, 0, users.deductCalls)
	require.Equal(t, 0, subs.incrementCalls)
}
func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) { func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{} userRepo := &openAIRecordUsageUserRepoStub{}
@ -1160,3 +1310,278 @@ func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTo
require.InDelta(t, 0.0, usageRepo.lastLog.OutputCost, 1e-12) require.InDelta(t, 0.0, usageRepo.lastLog.OutputCost, 1e-12)
require.InDelta(t, 0.0, usageRepo.lastLog.ImageOutputCost, 1e-12) require.InDelta(t, 0.0, usageRepo.lastLog.ImageOutputCost, 1e-12)
} }
// TestOpenAIGatewayServiceRecordUsage_ImageSharedMultiplierPreservesExistingBehavior
// bills one 1K image priced at 0.2 for a group with ImageRateIndependent=false
// and RateMultiplier=0.15. It expects TotalCost 0.2, ActualCost 0.03, a logged
// rate multiplier of 0.15, and the image billing mode.
func TestOpenAIGatewayServiceRecordUsage_ImageSharedMultiplierPreservesExistingBehavior(t *testing.T) {
	pricePerImage := 0.2
	gid := int64(121)
	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
	svc := newOpenAIRecordUsageServiceForTest(logRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)

	group := &Group{
		ID:                   gid,
		RateMultiplier:       0.15,
		ImageRateIndependent: false,
		ImageRateMultiplier:  1,
		ImagePrice1K:         &pricePerImage,
	}
	input := &OpenAIRecordUsageInput{
		Result: &OpenAIForwardResult{
			RequestID:  "resp_image_shared_multiplier",
			Model:      "gpt-image-2",
			ImageCount: 1,
			ImageSize:  "1K",
			Duration:   time.Second,
		},
		APIKey:  &APIKey{ID: 10121, GroupID: i64p(gid), Group: group},
		User:    &User{ID: 20121},
		Account: &Account{ID: 30121},
	}
	err := svc.RecordUsage(context.Background(), input)
	require.NoError(t, err)

	require.NotNil(t, logRepo.lastLog)
	logged := logRepo.lastLog
	require.InDelta(t, 0.2, logged.TotalCost, 1e-12)
	require.InDelta(t, 0.03, logged.ActualCost, 1e-12)
	require.InDelta(t, 0.15, logged.RateMultiplier, 1e-12)
	require.NotNil(t, logged.BillingMode)
	require.Equal(t, string(BillingModeImage), *logged.BillingMode)
}
// TestOpenAIGatewayServiceRecordUsage_ImageSharedMultiplierUsesUserGroupOverride
// bills one 1K image priced at 0.5 for a group with ImageRateIndependent=false
// and RateMultiplier=0.15, while the user-group rate stub supplies 0.2. It
// expects TotalCost 0.5, ActualCost 0.1, and a logged rate multiplier of 0.2.
func TestOpenAIGatewayServiceRecordUsage_ImageSharedMultiplierUsesUserGroupOverride(t *testing.T) {
	pricePerImage := 0.5
	overrideRate := 0.2
	gid := int64(125)
	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
	svc := newOpenAIRecordUsageServiceForTest(
		logRepo,
		&openAIRecordUsageUserRepoStub{},
		&openAIRecordUsageSubRepoStub{},
		&openAIUserGroupRateRepoStub{rate: &overrideRate},
	)

	group := &Group{
		ID:                   gid,
		RateMultiplier:       0.15,
		ImageRateIndependent: false,
		ImageRateMultiplier:  1,
		ImagePrice1K:         &pricePerImage,
	}
	input := &OpenAIRecordUsageInput{
		Result: &OpenAIForwardResult{
			RequestID:  "resp_image_user_group_override",
			Model:      "gpt-image-2",
			ImageCount: 1,
			ImageSize:  "1K",
			Duration:   time.Second,
		},
		APIKey:  &APIKey{ID: 10125, GroupID: i64p(gid), Group: group},
		User:    &User{ID: 20125},
		Account: &Account{ID: 30125},
	}
	err := svc.RecordUsage(context.Background(), input)
	require.NoError(t, err)

	require.NotNil(t, logRepo.lastLog)
	logged := logRepo.lastLog
	require.InDelta(t, 0.5, logged.TotalCost, 1e-12)
	require.InDelta(t, 0.1, logged.ActualCost, 1e-12)
	require.InDelta(t, 0.2, logged.RateMultiplier, 1e-12)
}
// TestOpenAIGatewayServiceRecordUsage_ImageIndependentMultiplierUsesImageRate
// bills one 1K image priced at 0.2 for a group with ImageRateIndependent=true,
// ImageRateMultiplier=1 and RateMultiplier=0.15. It expects TotalCost and
// ActualCost both 0.2, a logged rate multiplier of 1.0, and the image billing
// mode.
func TestOpenAIGatewayServiceRecordUsage_ImageIndependentMultiplierUsesImageRate(t *testing.T) {
	pricePerImage := 0.2
	gid := int64(122)
	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
	svc := newOpenAIRecordUsageServiceForTest(logRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)

	group := &Group{
		ID:                   gid,
		RateMultiplier:       0.15,
		ImageRateIndependent: true,
		ImageRateMultiplier:  1,
		ImagePrice1K:         &pricePerImage,
	}
	input := &OpenAIRecordUsageInput{
		Result: &OpenAIForwardResult{
			RequestID:  "resp_image_independent_multiplier",
			Model:      "gpt-image-2",
			ImageCount: 1,
			ImageSize:  "1K",
			Duration:   time.Second,
		},
		APIKey:  &APIKey{ID: 10122, GroupID: i64p(gid), Group: group},
		User:    &User{ID: 20122},
		Account: &Account{ID: 30122},
	}
	err := svc.RecordUsage(context.Background(), input)
	require.NoError(t, err)

	require.NotNil(t, logRepo.lastLog)
	logged := logRepo.lastLog
	require.InDelta(t, 0.2, logged.TotalCost, 1e-12)
	require.InDelta(t, 0.2, logged.ActualCost, 1e-12)
	require.InDelta(t, 1.0, logged.RateMultiplier, 1e-12)
	require.NotNil(t, logged.BillingMode)
	require.Equal(t, string(BillingModeImage), *logged.BillingMode)
}
// TestOpenAIGatewayServiceRecordUsage_ChannelImageBillingUsesImageCountAndSharedMultiplier
// installs a channel pricing resolver with a 0.25 image price and bills three
// images for a group with ImageRateIndependent=false and RateMultiplier=0.15.
// It expects TotalCost 0.75, ActualCost 0.1125, a logged rate multiplier of
// 0.15, ImageCount 3, and the image billing mode.
func TestOpenAIGatewayServiceRecordUsage_ChannelImageBillingUsesImageCountAndSharedMultiplier(t *testing.T) {
	gid := int64(123)
	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
	svc := newOpenAIRecordUsageServiceForTest(logRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
	svc.resolver = newOpenAIImageChannelPricingResolverForTest(t, gid, "gpt-image-2", 0.25)

	group := &Group{
		ID:                   gid,
		RateMultiplier:       0.15,
		ImageRateIndependent: false,
		ImageRateMultiplier:  1,
	}
	input := &OpenAIRecordUsageInput{
		Result: &OpenAIForwardResult{
			RequestID:  "resp_image_channel_shared",
			Model:      "gpt-image-2",
			ImageCount: 3,
			ImageSize:  "1K",
			Duration:   time.Second,
		},
		APIKey:  &APIKey{ID: 10123, GroupID: i64p(gid), Group: group},
		User:    &User{ID: 20123},
		Account: &Account{ID: 30123},
	}
	err := svc.RecordUsage(context.Background(), input)
	require.NoError(t, err)

	require.NotNil(t, logRepo.lastLog)
	logged := logRepo.lastLog
	require.InDelta(t, 0.75, logged.TotalCost, 1e-12)
	require.InDelta(t, 0.1125, logged.ActualCost, 1e-12)
	require.InDelta(t, 0.15, logged.RateMultiplier, 1e-12)
	require.Equal(t, 3, logged.ImageCount)
	require.NotNil(t, logged.BillingMode)
	require.Equal(t, string(BillingModeImage), *logged.BillingMode)
}
// TestOpenAIGatewayServiceRecordUsage_ChannelImageBillingUsesImageCountAndIndependentMultiplier
// installs a channel pricing resolver with a 0.25 image price and bills three
// images for a group with ImageRateIndependent=true and ImageRateMultiplier=1.
// It expects TotalCost and ActualCost both 0.75, a logged rate multiplier of
// 1.0, ImageCount 3, and the image billing mode.
func TestOpenAIGatewayServiceRecordUsage_ChannelImageBillingUsesImageCountAndIndependentMultiplier(t *testing.T) {
	gid := int64(124)
	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
	svc := newOpenAIRecordUsageServiceForTest(logRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
	svc.resolver = newOpenAIImageChannelPricingResolverForTest(t, gid, "gpt-image-2", 0.25)

	group := &Group{
		ID:                   gid,
		RateMultiplier:       0.15,
		ImageRateIndependent: true,
		ImageRateMultiplier:  1,
	}
	input := &OpenAIRecordUsageInput{
		Result: &OpenAIForwardResult{
			RequestID:  "resp_image_channel_independent",
			Model:      "gpt-image-2",
			ImageCount: 3,
			ImageSize:  "1K",
			Duration:   time.Second,
		},
		APIKey:  &APIKey{ID: 10124, GroupID: i64p(gid), Group: group},
		User:    &User{ID: 20124},
		Account: &Account{ID: 30124},
	}
	err := svc.RecordUsage(context.Background(), input)
	require.NoError(t, err)

	require.NotNil(t, logRepo.lastLog)
	logged := logRepo.lastLog
	require.InDelta(t, 0.75, logged.TotalCost, 1e-12)
	require.InDelta(t, 0.75, logged.ActualCost, 1e-12)
	require.InDelta(t, 1.0, logged.RateMultiplier, 1e-12)
	require.Equal(t, 3, logged.ImageCount)
	require.NotNil(t, logged.BillingMode)
	require.Equal(t, string(BillingModeImage), *logged.BillingMode)
}
// newOpenAIImageChannelPricingResolverForTest builds a ModelPricingResolver
// backed by a ChannelService whose cache holds a single image-mode pricing
// entry for (groupID, model) with the given per-request price, plus an active
// channel for that group and an empty platform string.
func newOpenAIImageChannelPricingResolverForTest(t *testing.T, groupID int64, model string, price float64) *ModelPricingResolver {
	t.Helper()

	key := channelModelKey{groupID: groupID, model: model}
	pricing := &ChannelModelPricing{
		BillingMode:     BillingModeImage,
		PerRequestPrice: &price,
	}

	snapshot := newEmptyChannelCache()
	snapshot.pricingByGroupModel[key] = pricing
	snapshot.channelByGroupID[groupID] = &Channel{ID: groupID, Status: StatusActive}
	snapshot.groupPlatform[groupID] = ""
	snapshot.loadedAt = time.Now()

	channels := &ChannelService{}
	channels.cache.Store(snapshot)

	return NewModelPricingResolver(channels, NewBillingService(&config.Config{}, nil))
}
// TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesImageCount
// calls GatewayService.calculateRecordUsageCost with a channel image price of
// 0.25 and a result containing two images. It expects the image billing mode
// and TotalCost and ActualCost both equal to 0.5.
func TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesImageCount(t *testing.T) {
	gid := int64(126)
	billing := NewBillingService(&config.Config{}, nil)
	svc := &GatewayService{
		billingService: billing,
		resolver:       newOpenAIImageChannelPricingResolverForTest(t, gid, "gemini-image", 0.25),
	}

	result := &ForwardResult{Model: "gemini-image", ImageCount: 2, ImageSize: "1K"}
	key := &APIKey{GroupID: i64p(gid), Group: &Group{ID: gid}}
	got := svc.calculateRecordUsageCost(context.Background(), result, key, "gemini-image", 0.15, 1.0, nil)

	require.NotNil(t, got)
	require.Equal(t, string(BillingModeImage), got.BillingMode)
	require.InDelta(t, 0.5, got.TotalCost, 1e-12)
	require.InDelta(t, 0.5, got.ActualCost, 1e-12)
}
// TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesSizeTier
// configures channel image pricing with a default per-request price of 0.10
// and a "4K" interval priced at 0.40, then bills two "4K" images. It expects
// the image billing mode and TotalCost and ActualCost both equal to 0.80.
func TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesSizeTier(t *testing.T) {
	gid := int64(127)
	basePrice := 0.10
	tier4KPrice := 0.40

	snapshot := newEmptyChannelCache()
	pricingKey := channelModelKey{groupID: gid, model: "gemini-image"}
	snapshot.pricingByGroupModel[pricingKey] = &ChannelModelPricing{
		BillingMode:     BillingModeImage,
		PerRequestPrice: &basePrice,
		Intervals: []PricingInterval{
			{TierLabel: "4K", PerRequestPrice: &tier4KPrice},
		},
	}
	snapshot.channelByGroupID[gid] = &Channel{ID: gid, Status: StatusActive}
	snapshot.loadedAt = time.Now()

	channels := &ChannelService{}
	channels.cache.Store(snapshot)

	// The service and the resolver each get their own BillingService instance.
	svc := &GatewayService{
		billingService: NewBillingService(&config.Config{}, nil),
		resolver:       NewModelPricingResolver(channels, NewBillingService(&config.Config{}, nil)),
	}

	result := &ForwardResult{Model: "gemini-image", ImageCount: 2, ImageSize: "4K"}
	key := &APIKey{GroupID: i64p(gid), Group: &Group{ID: gid}}
	got := svc.calculateRecordUsageCost(context.Background(), result, key, "gemini-image", 1.0, 1.0, nil)

	require.NotNil(t, got)
	require.Equal(t, string(BillingModeImage), got.BillingMode)
	require.InDelta(t, 0.80, got.TotalCost, 1e-12)
	require.InDelta(t, 0.80, got.ActualCost, 1e-12)
}

View File

@ -211,9 +211,10 @@ type OpenAIUsage struct {
// OpenAIForwardResult represents the result of forwarding // OpenAIForwardResult represents the result of forwarding
type OpenAIForwardResult struct { type OpenAIForwardResult struct {
RequestID string RequestID string
Usage OpenAIUsage ResponseID string
Model string // 原始模型(用于响应和日志显示) Usage OpenAIUsage
Model string // 原始模型(用于响应和日志显示)
// BillingModel is the model used for cost calculation. // BillingModel is the model used for cost calculation.
// When non-empty, CalculateCost uses this instead of Model. // When non-empty, CalculateCost uses this instead of Model.
// This is set by the Anthropic Messages conversion path where // This is set by the Anthropic Messages conversion path where
@ -346,10 +347,12 @@ type OpenAIGatewayService struct {
openaiWSPassthroughDialer openAIWSClientDialer openaiWSPassthroughDialer openAIWSClientDialer
openaiAccountStats *openAIAccountRuntimeStats openaiAccountStats *openAIAccountRuntimeStats
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
openaiWSRetryMetrics openAIWSRetryMetrics openaiWSRetryMetrics openAIWSRetryMetrics
responseHeaderFilter *responseheaders.CompiledHeaderFilter responseHeaderFilter *responseheaders.CompiledHeaderFilter
codexSnapshotThrottle *accountWriteThrottle codexSnapshotThrottle *accountWriteThrottle
openaiCompatSessionResponses sync.Map
openaiCompatAnthropicDigestSessions sync.Map
} }
// NewOpenAIGatewayService creates a new OpenAIGatewayService // NewOpenAIGatewayService creates a new OpenAIGatewayService
@ -1992,6 +1995,8 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
originalBody := body originalBody := body
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
originalModel := reqModel originalModel := reqModel
compatMessagesBridge := isOpenAICompatMessagesBridgeBody(body)
setOpenAICompatMessagesBridgeContext(c, compatMessagesBridge)
isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
@ -2049,6 +2054,21 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
promptCacheKey = strings.TrimSpace(v) promptCacheKey = strings.TrimSpace(v)
} }
} }
apiKey := getAPIKeyFromContext(c)
imageGenerationAllowed := GroupAllowsImageGeneration(nil)
if apiKey != nil {
imageGenerationAllowed = GroupAllowsImageGeneration(apiKey.Group)
}
if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed {
setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"type": "permission_error",
"message": ImageGenerationPermissionMessage(),
},
})
return nil, errors.New("image generation disabled for group")
}
// Track if body needs re-serialization // Track if body needs re-serialization
bodyModified := false bodyModified := false
@ -2102,13 +2122,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
} }
// 非透传模式下instructions 为空时注入默认指令。 // 非透传模式下instructions 为空时注入默认指令。
if isInstructionsEmpty(reqBody) { if isInstructionsEmpty(reqBody) && !compatMessagesBridge {
reqBody["instructions"] = "You are a helpful coding assistant." reqBody["instructions"] = "You are a helpful coding assistant."
bodyModified = true bodyModified = true
markPatchSet("instructions", "You are a helpful coding assistant.") markPatchSet("instructions", "You are a helpful coding assistant.")
} }
if isCodexCLI && ensureOpenAIResponsesImageGenerationTool(reqBody) { if isCodexCLI && imageGenerationAllowed && ensureOpenAIResponsesImageGenerationTool(reqBody) {
bodyModified = true bodyModified = true
disablePatch() disablePatch()
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Injected /responses image_generation tool for Codex client") logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Injected /responses image_generation tool for Codex client")
@ -2119,7 +2139,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
disablePatch() disablePatch()
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload") logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload")
} }
if isCodexCLI && applyCodexImageGenerationBridgeInstructions(reqBody) { if isCodexCLI && imageGenerationAllowed && applyCodexImageGenerationBridgeInstructions(reqBody) {
bodyModified = true bodyModified = true
disablePatch() disablePatch()
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Added Codex image_generation bridge instructions") logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Added Codex image_generation bridge instructions")
@ -2134,7 +2154,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
markPatchSet("model", billingModel) markPatchSet("model", billingModel)
} }
upstreamModel := billingModel upstreamModel := billingModel
if normalizeOpenAIResponsesImageOnlyModel(reqBody) { if imageGenerationAllowed && normalizeOpenAIResponsesImageOnlyModel(reqBody) {
bodyModified = true bodyModified = true
disablePatch() disablePatch()
if model, ok := reqBody["model"].(string); ok { if model, ok := reqBody["model"].(string); ok {
@ -2231,7 +2251,20 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
} }
if account.Type == AccountTypeOAuth { if account.Type == AccountTypeOAuth {
codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI, isCompactRequest) codexResult := codexTransformResult{}
if compatMessagesBridge {
codexResult = applyCodexOAuthTransformWithOptions(reqBody, codexOAuthTransformOptions{
IsCodexCLI: isCodexCLI,
IsCompact: isCompactRequest,
SkipDefaultInstructions: true,
PreserveToolCallIDs: true,
})
ensureCodexOAuthInstructionsField(reqBody)
bodyModified = true
disablePatch()
} else {
codexResult = applyCodexOAuthTransform(reqBody, isCodexCLI, isCompactRequest)
}
if codexResult.Modified { if codexResult.Modified {
bodyModified = true bodyModified = true
disablePatch() disablePatch()
@ -2355,6 +2388,34 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
} }
} }
if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed {
setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"type": "permission_error",
"message": ImageGenerationPermissionMessage(),
},
})
return nil, errors.New("image generation disabled for group")
}
imageBillingModel := ""
imageSizeTier := ""
if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) {
var imageCfgErr error
imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfig(reqBody, billingModel)
if imageCfgErr != nil {
setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "")
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"type": "invalid_request_error",
"message": imageCfgErr.Error(),
"param": "size",
},
})
return nil, imageCfgErr
}
}
// Re-serialize body only if modified // Re-serialize body only if modified
if bodyModified { if bodyModified {
serializedByPatch := false serializedByPatch := false
@ -2592,6 +2653,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
wsAttempts, wsAttempts,
) )
wsResult.UpstreamModel = upstreamModel wsResult.UpstreamModel = upstreamModel
if wsResult.ImageCount > 0 {
wsResult.ImageSize = imageSizeTier
wsResult.BillingModel = imageBillingModel
}
return wsResult, nil return wsResult, nil
} }
s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr)
@ -2601,7 +2666,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
httpInvalidEncryptedContentRetryTried := false httpInvalidEncryptedContentRetryTried := false
for { for {
// Build upstream request // Build upstream request
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
releaseUpstreamCtx() releaseUpstreamCtx()
if err != nil { if err != nil {
@ -2695,6 +2760,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Handle normal response // Handle normal response
var usage *OpenAIUsage var usage *OpenAIUsage
var firstTokenMs *int var firstTokenMs *int
imageCount := 0
if reqStream { if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel) streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel)
if err != nil { if err != nil {
@ -2702,11 +2768,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
} }
usage = streamResult.usage usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs firstTokenMs = streamResult.firstTokenMs
imageCount = streamResult.imageCount
} else { } else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel) nonStreamResult, err := s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
usage = nonStreamResult.usage
imageCount = nonStreamResult.imageCount
} }
// Extract and save Codex usage snapshot from response headers (for OAuth accounts) // Extract and save Codex usage snapshot from response headers (for OAuth accounts)
@ -2723,7 +2792,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel) reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
serviceTier := extractOpenAIServiceTier(reqBody) serviceTier := extractOpenAIServiceTier(reqBody)
return &OpenAIForwardResult{ forwardResult := &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"), RequestID: resp.Header.Get("x-request-id"),
Usage: *usage, Usage: *usage,
Model: originalModel, Model: originalModel,
@ -2734,7 +2803,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
OpenAIWSMode: false, OpenAIWSMode: false,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
}, nil }
if imageCount > 0 {
forwardResult.ImageCount = imageCount
forwardResult.ImageSize = imageSizeTier
forwardResult.BillingModel = imageBillingModel
}
return forwardResult, nil
} }
} }
@ -2823,6 +2898,35 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
} }
body = updatedBody body = updatedBody
apiKey := getAPIKeyFromContext(c)
if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) && !GroupAllowsImageGeneration(apiKeyGroup(apiKey)) {
setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"type": "permission_error",
"message": ImageGenerationPermissionMessage(),
},
})
return nil, errors.New("image generation disabled for group")
}
imageBillingModel := ""
imageSizeTier := ""
if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) {
var imageCfgErr error
imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(body, reqModel)
if imageCfgErr != nil {
setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "")
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"type": "invalid_request_error",
"message": imageCfgErr.Error(),
"param": "size",
},
})
return nil, imageCfgErr
}
}
logger.LegacyPrintf("service.openai_gateway", logger.LegacyPrintf("service.openai_gateway",
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v", "[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
account.ID, account.ID,
@ -2852,7 +2956,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
return nil, err return nil, err
} }
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token) upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
releaseUpstreamCtx() releaseUpstreamCtx()
if err != nil { if err != nil {
@ -2905,6 +3009,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
var usage *OpenAIUsage var usage *OpenAIUsage
var firstTokenMs *int var firstTokenMs *int
imageCount := 0
if reqStream { if reqStream {
result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime, reqModel, upstreamPassthroughModel) result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime, reqModel, upstreamPassthroughModel)
if err != nil { if err != nil {
@ -2912,11 +3017,14 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
} }
usage = result.usage usage = result.usage
firstTokenMs = result.firstTokenMs firstTokenMs = result.firstTokenMs
imageCount = result.imageCount
} else { } else {
usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel) result, err := s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
usage = result.usage
imageCount = result.imageCount
} }
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
@ -2927,7 +3035,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
usage = &OpenAIUsage{} usage = &OpenAIUsage{}
} }
return &OpenAIForwardResult{ forwardResult := &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"), RequestID: resp.Header.Get("x-request-id"),
Usage: *usage, Usage: *usage,
Model: reqModel, Model: reqModel,
@ -2938,7 +3046,13 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
OpenAIWSMode: false, OpenAIWSMode: false,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
}, nil }
if imageCount > 0 {
forwardResult.ImageCount = imageCount
forwardResult.ImageSize = imageSizeTier
forwardResult.BillingModel = imageBillingModel
}
return forwardResult, nil
} }
func logOpenAIPassthroughInstructionsRejected( func logOpenAIPassthroughInstructionsRejected(
@ -3233,6 +3347,13 @@ func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string {
type openaiStreamingResultPassthrough struct { type openaiStreamingResultPassthrough struct {
usage *OpenAIUsage usage *OpenAIUsage
firstTokenMs *int firstTokenMs *int
imageCount int
}
type openaiNonStreamingResultPassthrough struct {
*OpenAIUsage
usage *OpenAIUsage
imageCount int
} }
func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool { func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool {
@ -3369,6 +3490,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
} }
usage := &OpenAIUsage{} usage := &OpenAIUsage{}
imageCounter := newOpenAIImageOutputCounter()
var firstTokenMs *int var firstTokenMs *int
clientDisconnected := false clientDisconnected := false
sawDone := false sawDone := false
@ -3400,6 +3522,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
defer putSSEScannerBuf64K(scanBuf) defer putSSEScannerBuf64K(scanBuf)
needModelReplace := strings.TrimSpace(originalModel) != "" && strings.TrimSpace(mappedModel) != "" && strings.TrimSpace(originalModel) != strings.TrimSpace(mappedModel) needModelReplace := strings.TrimSpace(originalModel) != "" && strings.TrimSpace(mappedModel) != "" && strings.TrimSpace(originalModel) != strings.TrimSpace(mappedModel)
resultWithUsage := func() *openaiStreamingResultPassthrough {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()}
}
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
@ -3419,7 +3544,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
if eventType == "response.failed" { if eventType == "response.failed" {
failedMessage = extractOpenAISSEErrorMessage(dataBytes) failedMessage = extractOpenAISSEErrorMessage(dataBytes)
if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) { if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, return resultWithUsage(),
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, dataBytes, failedMessage) s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, dataBytes, failedMessage)
} }
forceFlushFailedEvent = true forceFlushFailedEvent = true
@ -3431,6 +3556,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
if openAIStreamEventIsTerminal(trimmedData) { if openAIStreamEventIsTerminal(trimmedData) {
sawTerminalEvent = true sawTerminalEvent = true
} }
imageCounter.AddSSEData(dataBytes)
lineStartsClientOutput = forceFlushFailedEvent || openAIStreamDataStartsClientOutput(trimmedData, eventType) lineStartsClientOutput = forceFlushFailedEvent || openAIStreamDataStartsClientOutput(trimmedData, eventType)
if firstTokenMs == nil && lineStartsClientOutput && trimmedData != "[DONE]" { if firstTokenMs == nil && lineStartsClientOutput && trimmedData != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds()) ms := int(time.Since(startTime).Milliseconds())
@ -3460,28 +3586,28 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
} }
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
if sawTerminalEvent && !sawFailedEvent { if sawTerminalEvent && !sawFailedEvent {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil return resultWithUsage(), nil
} }
if sawFailedEvent { if sawFailedEvent {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage) return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage)
} }
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err) return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
} }
if errors.Is(err, bufio.ErrTooLong) { if errors.Is(err, bufio.ErrTooLong) {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err) logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err return resultWithUsage(), err
} }
if !openAIStreamClientOutputStarted(c, clientOutputStarted) { if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
msg := "OpenAI stream disconnected before completion" msg := "OpenAI stream disconnected before completion"
if errText := strings.TrimSpace(err.Error()); errText != "" { if errText := strings.TrimSpace(err.Error()); errText != "" {
msg += ": " + errText msg += ": " + errText
} }
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, return resultWithUsage(),
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, msg) s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, msg)
} }
if clientDisconnected { if clientDisconnected {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err) return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", err)
} }
logger.LegacyPrintf("service.openai_gateway", logger.LegacyPrintf("service.openai_gateway",
"[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v", "[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
@ -3489,10 +3615,10 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
upstreamRequestID, upstreamRequestID,
err, err,
) )
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) return resultWithUsage(), fmt.Errorf("stream read error: %w", err)
} }
if sawFailedEvent { if sawFailedEvent {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage) return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage)
} }
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil { if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
logger.FromContext(ctx).With( logger.FromContext(ctx).With(
@ -3501,13 +3627,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
zap.String("upstream_request_id", upstreamRequestID), zap.String("upstream_request_id", upstreamRequestID),
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流") ).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
if !openAIStreamClientOutputStarted(c, clientOutputStarted) { if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, return resultWithUsage(),
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, "OpenAI stream ended before a terminal event") s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, "OpenAI stream ended before a terminal event")
} }
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event") return resultWithUsage(), errors.New("stream usage incomplete: missing terminal event")
} }
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil return resultWithUsage(), nil
} }
func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
@ -3516,7 +3642,7 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
c *gin.Context, c *gin.Context,
originalModel string, originalModel string,
mappedModel string, mappedModel string,
) (*OpenAIUsage, error) { ) (*openaiNonStreamingResultPassthrough, error) {
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
if err != nil { if err != nil {
return nil, err return nil, err
@ -3553,14 +3679,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
body = s.replaceModelInResponseBody(body, mappedModel, originalModel) body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
} }
c.Data(resp.StatusCode, contentType, body) c.Data(resp.StatusCode, contentType, body)
return usage, nil return &openaiNonStreamingResultPassthrough{
OpenAIUsage: usage,
usage: usage,
imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body),
}, nil
} }
// handlePassthroughSSEToJSON converts an SSE response body into a JSON // handlePassthroughSSEToJSON converts an SSE response body into a JSON
// response for the passthrough path. It mirrors handleSSEToJSON while // response for the passthrough path. It mirrors handleSSEToJSON while
// preserving passthrough payloads, except compact-only model remapping may // preserving passthrough payloads, except compact-only model remapping may
// rewrite model fields back to the original requested model. // rewrite model fields back to the original requested model.
func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*OpenAIUsage, error) { func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*openaiNonStreamingResultPassthrough, error) {
bodyText := string(body) bodyText := string(body)
finalResponse, ok := extractCodexFinalResponse(bodyText) finalResponse, ok := extractCodexFinalResponse(bodyText)
@ -3611,7 +3741,11 @@ func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c
} }
c.Data(resp.StatusCode, contentType, body) c.Data(resp.StatusCode, contentType, body)
return usage, nil return &openaiNonStreamingResultPassthrough{
OpenAIUsage: usage,
usage: usage,
imageCount: countOpenAIImageOutputsFromSSEBody(bodyText),
}, nil
} }
func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) { func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) {
@ -3715,12 +3849,19 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
} }
} }
if account.Type == AccountTypeOAuth { if account.Type == AccountTypeOAuth {
compatMessagesBridge := isOpenAICompatMessagesBridgeContext(c) || isOpenAICompatMessagesBridgeBody(body)
// 清除客户端透传的 session 头,后续用隔离后的值重新设置,防止跨用户会话碰撞。 // 清除客户端透传的 session 头,后续用隔离后的值重新设置,防止跨用户会话碰撞。
clientConversationID := strings.TrimSpace(req.Header.Get("conversation_id"))
req.Header.Del("conversation_id") req.Header.Del("conversation_id")
req.Header.Del("session_id") req.Header.Del("session_id")
req.Header.Set("OpenAI-Beta", "responses=experimental") if compatMessagesBridge {
req.Header.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI)) req.Header.Del("OpenAI-Beta")
req.Header.Del("originator")
} else {
req.Header.Set("OpenAI-Beta", "responses=experimental")
req.Header.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI))
}
apiKeyID := getAPIKeyIDFromContext(c) apiKeyID := getAPIKeyIDFromContext(c)
if isOpenAIResponsesCompactPath(c) { if isOpenAIResponsesCompactPath(c) {
req.Header.Set("accept", "application/json") req.Header.Set("accept", "application/json")
@ -3734,8 +3875,10 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
} }
if promptCacheKey != "" { if promptCacheKey != "" {
isolated := isolateOpenAISessionID(apiKeyID, promptCacheKey) isolated := isolateOpenAISessionID(apiKeyID, promptCacheKey)
req.Header.Set("conversation_id", isolated)
req.Header.Set("session_id", isolated) req.Header.Set("session_id", isolated)
if !compatMessagesBridge || clientConversationID != "" {
req.Header.Set("conversation_id", isolated)
}
} }
} }
@ -4025,6 +4168,13 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse(
type openaiStreamingResult struct { type openaiStreamingResult struct {
usage *OpenAIUsage usage *OpenAIUsage
firstTokenMs *int firstTokenMs *int
imageCount int
}
type openaiNonStreamingResult struct {
*OpenAIUsage
usage *OpenAIUsage
imageCount int
} }
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) { func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
@ -4058,6 +4208,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
} }
usage := &OpenAIUsage{} usage := &OpenAIUsage{}
imageCounter := newOpenAIImageOutputCounter()
var firstTokenMs *int var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize maxLineSize := defaultMaxLineSize
@ -4136,7 +4287,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
needModelReplace := originalModel != mappedModel needModelReplace := originalModel != mappedModel
resultWithUsage := func() *openaiStreamingResult { resultWithUsage := func() *openaiStreamingResult {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs} return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()}
} }
finalizeStream := func() (*openaiStreamingResult, error) { finalizeStream := func() (*openaiStreamingResult, error) {
if !sawTerminalEvent { if !sawTerminalEvent {
@ -4231,6 +4382,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
forceFlushFailedEvent = true forceFlushFailedEvent = true
sawFailedEvent = true sawFailedEvent = true
} }
imageCounter.AddSSEData(dataBytes)
// Correct Codex tool calls if needed (apply_patch -> edit, etc.) // Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected { if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
@ -4496,7 +4648,7 @@ func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
}, true }, true
} }
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) { func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*openaiNonStreamingResult, error) {
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
if err != nil { if err != nil {
return nil, err return nil, err
@ -4542,7 +4694,11 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
c.Data(resp.StatusCode, contentType, body) c.Data(resp.StatusCode, contentType, body)
return usage, nil return &openaiNonStreamingResult{
OpenAIUsage: usage,
usage: usage,
imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body),
}, nil
} }
func isEventStreamResponse(header http.Header) bool { func isEventStreamResponse(header http.Header) bool {
@ -4550,7 +4706,7 @@ func isEventStreamResponse(header http.Header) bool {
return strings.Contains(contentType, "text/event-stream") return strings.Contains(contentType, "text/event-stream")
} }
func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) { func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*openaiNonStreamingResult, error) {
bodyText := string(body) bodyText := string(body)
finalResponse, ok := extractCodexFinalResponse(bodyText) finalResponse, ok := extractCodexFinalResponse(bodyText)
@ -4602,21 +4758,29 @@ func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Conte
} }
c.Data(resp.StatusCode, contentType, body) c.Data(resp.StatusCode, contentType, body)
return usage, nil return &openaiNonStreamingResult{
OpenAIUsage: usage,
usage: usage,
imageCount: countOpenAIImageOutputsFromSSEBody(bodyText),
}, nil
} }
func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) { func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) {
lines := strings.Split(body, "\n") var terminalType string
for _, line := range lines { var terminalPayload []byte
data, ok := extractOpenAISSEDataLine(line) forEachOpenAISSEDataPayload(body, func(data []byte) {
if !ok || data == "" || data == "[DONE]" { if terminalPayload != nil {
continue return
} }
eventType := strings.TrimSpace(gjson.Get(data, "type").String()) eventType := strings.TrimSpace(gjson.GetBytes(data, "type").String())
switch eventType { switch eventType {
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
return eventType, []byte(data), true terminalType = eventType
terminalPayload = append([]byte(nil), data...)
} }
})
if terminalPayload != nil {
return terminalType, terminalPayload, true
} }
return "", nil, false return "", nil, false
} }
@ -4651,21 +4815,20 @@ func (s *OpenAIGatewayService) writeOpenAINonStreamingProtocolError(resp *http.R
} }
func extractCodexFinalResponse(body string) ([]byte, bool) { func extractCodexFinalResponse(body string) ([]byte, bool) {
lines := strings.Split(body, "\n") var finalResponse []byte
for _, line := range lines { forEachOpenAISSEDataPayload(body, func(data []byte) {
data, ok := extractOpenAISSEDataLine(line) if finalResponse != nil {
if !ok { return
continue
} }
if data == "" || data == "[DONE]" { eventType := gjson.GetBytes(data, "type").String()
continue
}
eventType := gjson.Get(data, "type").String()
if eventType == "response.done" || eventType == "response.completed" { if eventType == "response.done" || eventType == "response.completed" {
if response := gjson.Get(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" { if response := gjson.GetBytes(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" {
return []byte(response.Raw), true finalResponse = []byte(response.Raw)
} }
} }
})
if finalResponse != nil {
return finalResponse, true
} }
return nil, false return nil, false
} }
@ -4677,21 +4840,15 @@ func reconstructResponseOutputFromSSE(bodyText string) ([]byte, bool) {
acc := apicompat.NewBufferedResponseAccumulator() acc := apicompat.NewBufferedResponseAccumulator()
imageOutputs := make([]json.RawMessage, 0, 1) imageOutputs := make([]json.RawMessage, 0, 1)
seenImages := make(map[string]struct{}) seenImages := make(map[string]struct{})
lines := strings.Split(bodyText, "\n") forEachOpenAISSEDataPayload(bodyText, func(data []byte) {
for _, line := range lines { if imageOutput, ok := extractImageGenerationOutputFromSSEData(data, seenImages); ok {
data, ok := extractOpenAISSEDataLine(line)
if !ok || data == "" || data == "[DONE]" {
continue
}
if imageOutput, ok := extractImageGenerationOutputFromSSEData([]byte(data), seenImages); ok {
imageOutputs = append(imageOutputs, imageOutput) imageOutputs = append(imageOutputs, imageOutput)
} }
var event apicompat.ResponsesStreamEvent var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(data), &event); err != nil { if err := json.Unmarshal(data, &event); err == nil {
continue acc.ProcessEvent(&event)
} }
acc.ProcessEvent(&event) })
}
if !acc.HasContent() && len(imageOutputs) == 0 { if !acc.HasContent() && len(imageOutputs) == 0 {
return nil, false return nil, false
} }
@ -4744,17 +4901,9 @@ func extractImageGenerationOutputFromSSEData(data []byte, seen map[string]struct
func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
usage := &OpenAIUsage{} usage := &OpenAIUsage{}
lines := strings.Split(body, "\n") forEachOpenAISSEDataPayload(body, func(data []byte) {
for _, line := range lines { s.parseSSEUsageBytes(data, usage)
data, ok := extractOpenAISSEDataLine(line) })
if !ok {
continue
}
if data == "" || data == "[DONE]" {
continue
}
s.parseSSEUsageBytes([]byte(data), usage)
}
return usage return usage
} }
@ -5036,16 +5185,15 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance // RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error { func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
result := input.Result if input == nil {
if s.rateLimitService != nil && input != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI { return errors.New("openai usage input is nil")
s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
} }
result := input.Result
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库 if result == nil {
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 && return errors.New("openai usage result is nil")
result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 && }
result.Usage.ImageOutputTokens == 0 && result.ImageCount == 0 { if s.rateLimitService != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI {
return nil s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
} }
apiKey := input.APIKey apiKey := input.APIKey
@ -5081,6 +5229,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
} }
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier) multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
} }
imageMultiplier := resolveImageRateMultiplier(apiKey, multiplier)
var cost *CostBreakdown var cost *CostBreakdown
var err error var err error
@ -5094,13 +5243,21 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" { if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel billingModel = input.OriginalModel
} }
billingModels := usageBillingModelCandidates(
billingModel,
result.BillingModel,
input.ChannelMappedModel,
input.OriginalModel,
result.UpstreamModel,
result.Model,
)
serviceTier := "" serviceTier := ""
if result.ServiceTier != nil { if result.ServiceTier != nil {
serviceTier = strings.TrimSpace(*result.ServiceTier) serviceTier = strings.TrimSpace(*result.ServiceTier)
} }
cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, tokens, serviceTier) cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModels, multiplier, imageMultiplier, tokens, serviceTier)
if err != nil { if err != nil {
cost = &CostBreakdown{ActualCost: 0} return err
} }
// Determine billing type // Determine billing type
@ -5150,7 +5307,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.TotalCost = cost.TotalCost usageLog.TotalCost = cost.TotalCost
usageLog.ActualCost = cost.ActualCost usageLog.ActualCost = cost.ActualCost
} }
usageLog.RateMultiplier = multiplier if result.ImageCount > 0 {
usageLog.RateMultiplier = imageMultiplier
} else {
usageLog.RateMultiplier = multiplier
}
usageLog.AccountRateMultiplier = &accountRateMultiplier usageLog.AccountRateMultiplier = &accountRateMultiplier
usageLog.BillingType = billingType usageLog.BillingType = billingType
usageLog.Stream = result.Stream usageLog.Stream = result.Stream
@ -5231,14 +5392,45 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
ctx context.Context, ctx context.Context,
result *OpenAIForwardResult, result *OpenAIForwardResult,
apiKey *APIKey, apiKey *APIKey,
billingModels []string,
multiplier float64,
imageMultiplier float64,
tokens UsageTokens,
serviceTier string,
) (*CostBreakdown, error) {
billingModel := firstUsageBillingModel(billingModels)
if result != nil && result.ImageCount > 0 {
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, imageMultiplier), nil
}
if len(billingModels) == 0 || billingModel == "" {
return nil, errors.New("openai usage billing model is empty")
}
var lastErr error
for _, candidate := range billingModels {
candidate = strings.TrimSpace(candidate)
if candidate == "" {
continue
}
cost, err := s.calculateOpenAIRecordUsageTokenCost(ctx, apiKey, candidate, multiplier, tokens, serviceTier)
if err == nil {
return cost, nil
}
lastErr = err
}
if lastErr == nil {
lastErr = errors.New("no non-empty billing model candidates")
}
return nil, fmt.Errorf("calculate OpenAI usage cost failed for billing models %s: %w", strings.Join(billingModels, ","), lastErr)
}
func (s *OpenAIGatewayService) calculateOpenAIRecordUsageTokenCost(
ctx context.Context,
apiKey *APIKey,
billingModel string, billingModel string,
multiplier float64, multiplier float64,
tokens UsageTokens, tokens UsageTokens,
serviceTier string, serviceTier string,
) (*CostBreakdown, error) { ) (*CostBreakdown, error) {
if result != nil && result.ImageCount > 0 {
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil
}
if s.resolver != nil && apiKey.Group != nil { if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID gid := apiKey.Group.ID
return s.billingService.CalculateCostUnified(CostInput{ return s.billingService.CalculateCostUnified(CostInput{
@ -5269,7 +5461,7 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost(
Ctx: ctx, Ctx: ctx,
Model: billingModel, Model: billingModel,
GroupID: &gid, GroupID: &gid,
RequestCount: 1, RequestCount: result.ImageCount,
SizeTier: result.ImageSize, SizeTier: result.ImageSize,
RateMultiplier: multiplier, RateMultiplier: multiplier,
Resolver: s.resolver, Resolver: s.resolver,

View File

@ -1846,6 +1846,29 @@ func TestOpenAIBuildUpstreamRequestCompactForcesJSONAcceptForOAuth(t *testing.T)
require.NotEmpty(t, req.Header.Get("Session_Id")) require.NotEmpty(t, req.Header.Get("Session_Id"))
} }
// TestOpenAIBuildUpstreamRequestOAuthMessagesBridgeUsesSessionOnly builds an
// upstream request for an OAuth account from an inbound request that carries
// the OpenAI-Beta and originator headers, and checks that the resulting
// upstream request has a Session_Id header but none of Conversation_Id,
// OpenAI-Beta or originator.
func TestOpenAIBuildUpstreamRequestOAuthMessagesBridgeUsesSessionOnly(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const promptCacheKey = "anthropic-metadata-session-1"
	payload := []byte(`{"model":"gpt-5.5","prompt_cache_key":"anthropic-metadata-session-1","input":[{"type":"message","role":"developer","content":[{"type":"input_text","text":"<sub2api-claude-code-todo-guard>"}]},{"type":"message","role":"user","content":"hello"}]}`)

	// Inbound request with the two client headers that must not be forwarded.
	inbound := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(payload))
	inbound.Header.Set("OpenAI-Beta", "responses=experimental")
	inbound.Header.Set("originator", "codex_cli_rs")

	ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ginCtx.Request = inbound

	oauthAccount := &Account{
		Type:        AccountTypeOAuth,
		Credentials: map[string]any{"chatgpt_account_id": "chatgpt-acc"},
	}
	gateway := &OpenAIGatewayService{}

	upstreamReq, err := gateway.buildUpstreamRequest(ginCtx.Request.Context(), ginCtx, oauthAccount, payload, "token", true, promptCacheKey, false)
	require.NoError(t, err)

	require.NotEmpty(t, upstreamReq.Header.Get("Session_Id"))
	for _, absent := range []string{"Conversation_Id", "OpenAI-Beta", "originator"} {
		require.Empty(t, upstreamReq.Header.Get(absent))
	}
}
func TestOpenAIBuildUpstreamRequestPreservesCompactPathForAPIKeyBaseURL(t *testing.T) { func TestOpenAIBuildUpstreamRequestPreservesCompactPathForAPIKeyBaseURL(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()

View File

@ -0,0 +1,215 @@
package service
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestOpenAIGatewayServiceForward_RejectsDisabledImageGenerationIntents checks
// that Forward fails with HTTP 403 / "permission_error" and never contacts the
// upstream when the API key's group has AllowImageGeneration=false and the
// request asks for images via the model name, a tool, or tool_choice.
func TestOpenAIGatewayServiceForward_RejectsDisabledImageGenerationIntents(t *testing.T) {
	gin.SetMode(gin.TestMode)

	type rejectCase struct {
		label   string
		payload string
	}
	cases := []rejectCase{
		{label: "image model", payload: `{"model":"gpt-image-2","input":"draw"}`},
		{label: "image tool", payload: `{"model":"gpt-5.4","input":"draw","tools":[{"type":"image_generation"}]}`},
		{label: "image tool choice", payload: `{"model":"gpt-5.4","input":"draw","tool_choice":{"type":"image_generation"}}`},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.label, func(t *testing.T) {
			recordingUpstream := &httpUpstreamRecorder{}
			gateway := newOpenAIImageGenerationControlTestService(recordingUpstream)
			// allowImages=false: the group attached to the API key forbids images.
			ginCtx, rec := newOpenAIImageGenerationControlTestContext(false, "unit-test-agent/1.0")
			acct := newOpenAIImageGenerationControlTestAccount()

			forwarded, err := gateway.Forward(context.Background(), ginCtx, acct, []byte(tc.payload))

			require.Error(t, err)
			require.Nil(t, forwarded)
			require.Equal(t, http.StatusForbidden, rec.Code)
			errType := gjson.GetBytes(rec.Body.Bytes(), "error.type").String()
			require.Equal(t, "permission_error", errType)
			require.Nil(t, recordingUpstream.lastReq, "disabled image request must not reach upstream")
		})
	}
}
// TestOpenAIGatewayServiceForward_DisabledGroupAllowsTextOnlyResponses checks
// that a plain text request is still forwarded upstream and its usage is
// reported when the group has AllowImageGeneration=false.
func TestOpenAIGatewayServiceForward_DisabledGroupAllowsTextOnlyResponses(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const (
		upstreamJSON = `{"id":"resp_text","model":"gpt-5.4","usage":{"input_tokens":3,"output_tokens":2}}`
		requestJSON  = `{"model":"gpt-5.4","input":"write code","stream":false}`
	)

	cannedResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"application/json"}},
		Body:       io.NopCloser(strings.NewReader(upstreamJSON)),
	}
	recordingUpstream := &httpUpstreamRecorder{resp: cannedResp}
	gateway := newOpenAIImageGenerationControlTestService(recordingUpstream)
	ginCtx, rec := newOpenAIImageGenerationControlTestContext(false, "unit-test-agent/1.0")
	acct := newOpenAIImageGenerationControlTestAccount()

	forwarded, err := gateway.Forward(context.Background(), ginCtx, acct, []byte(requestJSON))

	require.NoError(t, err)
	require.NotNil(t, forwarded)
	require.Equal(t, http.StatusOK, rec.Code)
	// Usage must mirror the canned upstream payload; no images were produced.
	require.Equal(t, 3, forwarded.Usage.InputTokens)
	require.Equal(t, 2, forwarded.Usage.OutputTokens)
	require.Equal(t, 0, forwarded.ImageCount)
	require.NotNil(t, recordingUpstream.lastReq)
}
// TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability
// sends a text request with a codex_cli_rs User-Agent and checks that the body
// forwarded upstream contains an image_generation tool (and mentions it in
// "instructions") only when the group's AllowImageGeneration flag is true.
func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const (
		upstreamJSON = `{"id":"resp_codex","model":"gpt-5.4","usage":{"input_tokens":1,"output_tokens":1}}`
		requestJSON  = `{"model":"gpt-5.4","input":"write code","stream":false}`
	)

	scenarios := []struct {
		label             string
		groupAllowsImages bool
		expectInjection   bool
	}{
		{label: "disabled group skips injection", groupAllowsImages: false, expectInjection: false},
		{label: "enabled group injects image tool", groupAllowsImages: true, expectInjection: true},
	}

	for i := range scenarios {
		sc := scenarios[i]
		t.Run(sc.label, func(t *testing.T) {
			recordingUpstream := &httpUpstreamRecorder{
				resp: &http.Response{
					StatusCode: http.StatusOK,
					Header:     http.Header{"Content-Type": []string{"application/json"}},
					Body:       io.NopCloser(strings.NewReader(upstreamJSON)),
				},
			}
			gateway := newOpenAIImageGenerationControlTestService(recordingUpstream)
			ginCtx, _ := newOpenAIImageGenerationControlTestContext(sc.groupAllowsImages, "codex_cli_rs/0.98.0")
			acct := newOpenAIImageGenerationControlTestAccount()

			forwarded, err := gateway.Forward(context.Background(), ginCtx, acct, []byte(requestJSON))
			require.NoError(t, err)
			require.NotNil(t, forwarded)
			require.NotNil(t, recordingUpstream.lastReq)

			// Inspect what was actually sent upstream.
			sentBody := recordingUpstream.lastBody
			toolInjected := gjson.GetBytes(sentBody, `tools.#(type=="image_generation")`).Exists()
			require.Equal(t, sc.expectInjection, toolInjected)

			sentInstructions := gjson.GetBytes(sentBody, "instructions").String()
			require.Equal(t, sc.expectInjection, strings.Contains(sentInstructions, "image_generation"))
		})
	}
}
// TestOpenAIGatewayServiceHandleResponsesImageOutputs_NonStreaming feeds a JSON
// response containing one image_generation_call output to
// handleNonStreamingResponse and checks the reported image count and usage
// (input, output and image output tokens).
func TestOpenAIGatewayServiceHandleResponsesImageOutputs_NonStreaming(t *testing.T) {
	gin.SetMode(gin.TestMode)

	gateway := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{})
	ginCtx, _ := newOpenAIImageGenerationControlTestContext(true, "unit-test-agent/1.0")

	jsonBody := strings.NewReader(`{
"id":"resp_image_json",
"model":"gpt-5.4",
"output":[{"id":"ig_json_1","type":"image_generation_call","result":"final-image"}],
"usage":{"input_tokens":7,"output_tokens":3,"output_tokens_details":{"image_tokens":2}}
}`)
	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"application/json"}},
		Body:       io.NopCloser(jsonBody),
	}
	acct := &Account{ID: 1, Type: AccountTypeAPIKey}

	handled, err := gateway.handleNonStreamingResponse(context.Background(), upstreamResp, ginCtx, acct, "gpt-5.4", "gpt-5.4")

	require.NoError(t, err)
	require.NotNil(t, handled)
	require.Equal(t, 1, handled.imageCount)
	require.NotNil(t, handled.usage)

	gotUsage := handled.usage
	require.Equal(t, 7, gotUsage.InputTokens)
	require.Equal(t, 3, gotUsage.OutputTokens)
	require.Equal(t, 2, gotUsage.ImageOutputTokens)
}
// TestOpenAIGatewayServiceHandleResponsesImageOutputs_Streaming feeds an SSE
// body to handleStreamingResponse in which the same image item (ig_stream_1)
// appears both in an output_item.done event and in the final
// response.completed event, and checks that it is counted once and that usage
// comes from the completed event.
func TestOpenAIGatewayServiceHandleResponsesImageOutputs_Streaming(t *testing.T) {
	gin.SetMode(gin.TestMode)

	gateway := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{})
	ginCtx, _ := newOpenAIImageGenerationControlTestContext(true, "unit-test-agent/1.0")

	sseBody := "data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_stream_1\",\"type\":\"image_generation_call\",\"result\":\"final-image\"}}\n\n" +
		"data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_image_stream\",\"model\":\"gpt-5.5\",\"output\":[{\"id\":\"ig_stream_1\",\"type\":\"image_generation_call\",\"result\":\"final-image\"}],\"usage\":{\"input_tokens\":11,\"output_tokens\":5,\"output_tokens_details\":{\"image_tokens\":4}}}}\n\n"
	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body:       io.NopCloser(strings.NewReader(sseBody)),
	}

	handled, err := gateway.handleStreamingResponse(context.Background(), upstreamResp, ginCtx, &Account{ID: 1}, time.Now(), "gpt-5.5", "gpt-5.5")

	require.NoError(t, err)
	require.NotNil(t, handled)
	require.Equal(t, 1, handled.imageCount)
	require.NotNil(t, handled.usage)

	gotUsage := handled.usage
	require.Equal(t, 11, gotUsage.InputTokens)
	require.Equal(t, 5, gotUsage.OutputTokens)
	require.Equal(t, 4, gotUsage.ImageOutputTokens)
}
// newOpenAIImageGenerationControlTestService returns an OpenAIGatewayService
// wired with an empty config, the given upstream recorder, a stub gateway
// cache, a WS protocol resolver built from that config, and a Codex tool
// corrector.
func newOpenAIImageGenerationControlTestService(upstream *httpUpstreamRecorder) *OpenAIGatewayService {
	emptyCfg := &config.Config{}

	svc := &OpenAIGatewayService{cfg: emptyCfg}
	svc.httpUpstream = upstream
	svc.cache = &stubGatewayCache{}
	svc.openaiWSResolver = NewOpenAIWSProtocolResolver(emptyCfg)
	svc.toolCorrector = NewCodexToolCorrector()
	return svc
}
// newOpenAIImageGenerationControlTestContext creates a gin test context and
// its response recorder for a POST to /openai/v1/responses with the given
// User-Agent. The context carries an "api_key" entry whose group (ID 4242) has
// AllowImageGeneration set to allowImages and both rate multipliers set to 1.
func newOpenAIImageGenerationControlTestContext(allowImages bool, userAgent string) (*gin.Context, *httptest.ResponseRecorder) {
	rec := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(rec)

	req := httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	req.Header.Set("User-Agent", userAgent)
	ginCtx.Request = req

	var gid int64 = 4242
	group := &Group{
		ID:                   gid,
		AllowImageGeneration: allowImages,
		RateMultiplier:       1,
		ImageRateMultiplier:  1,
	}
	ginCtx.Set("api_key", &APIKey{ID: 2424, GroupID: &gid, Group: group})

	return ginCtx, rec
}
// newOpenAIImageGenerationControlTestAccount returns an active, schedulable
// OpenAI API-key account (ID 5151, concurrency 1) with a dummy "sk-test" key.
func newOpenAIImageGenerationControlTestAccount() *Account {
	acct := &Account{
		ID:       5151,
		Name:     "openai-image-controls",
		Platform: PlatformOpenAI,
		Type:     AccountTypeAPIKey,
	}
	acct.Status = StatusActive
	acct.Schedulable = true
	acct.Concurrency = 1
	acct.Credentials = map[string]any{"api_key": "sk-test"}
	return acct
}

View File

@ -16,6 +16,7 @@ import (
"net/textproto" "net/textproto"
"strconv" "strconv"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@ -468,14 +469,54 @@ func isOpenAINativeImageOption(name string) bool {
} }
func normalizeOpenAIImageSizeTier(size string) string { func normalizeOpenAIImageSizeTier(size string) string {
switch strings.ToLower(strings.TrimSpace(size)) { trimmed := strings.TrimSpace(size)
normalized := strings.ToLower(trimmed)
switch normalized {
case "", "auto":
return "2K"
case "1024x1024": case "1024x1024":
return "1K" return "1K"
case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "", "auto": case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "2048x2048", "2048x1152", "1152x2048":
return "2K" return "2K"
default: case "3840x2160", "2160x3840":
return "4K"
}
width, height, ok := parseOpenAIImageSizeDimensions(trimmed)
if !ok {
return "2K" return "2K"
} }
return classifyUnknownOpenAIImageSizeTier(width, height)
}
const (
// openAIImage2KMaxPixels is the pixel count of a 2560x1440 image. It is the
// threshold classifyUnknownOpenAIImageSizeTier compares against: sizes with
// more pixels are reported as "4K", all others as "2K".
openAIImage2KMaxPixels = 2560 * 1440
)
// parseOpenAIImageSizeDimensions parses a "<width>x<height>" size string.
//
// The input is trimmed and lower-cased first, so "X" works as the separator,
// and whitespace around each number is ignored. It returns (width, height,
// true) only when there is exactly one separator and both sides are base-10
// integers greater than zero; otherwise it returns (0, 0, false).
func parseOpenAIImageSizeDimensions(size string) (int, int, bool) {
	spec := strings.ToLower(strings.TrimSpace(size))

	rawWidth, rawHeight, hasSep := strings.Cut(spec, "x")
	// Exactly one "x" is allowed: none, or a second one in the remainder, is invalid.
	if !hasSep || strings.Contains(rawHeight, "x") {
		return 0, 0, false
	}

	w, wErr := strconv.Atoi(strings.TrimSpace(rawWidth))
	if wErr != nil {
		return 0, 0, false
	}
	h, hErr := strconv.Atoi(strings.TrimSpace(rawHeight))
	if hErr != nil {
		return 0, 0, false
	}

	if w > 0 && h > 0 {
		return w, h, true
	}
	return 0, 0, false
}
// classifyUnknownOpenAIImageSizeTier maps explicit pixel dimensions to a size
// tier: "4K" when the image has more pixels than openAIImage2KMaxPixels,
// otherwise "2K" (including when height is not positive).
func classifyUnknownOpenAIImageSizeTier(width int, height int) string {
// For positive ints this is equivalent to width*height > openAIImage2KMaxPixels,
// but it never computes the product.
if height > 0 && width > openAIImage2KMaxPixels/height {
return "4K"
}
return "2K"
} }
func (s *OpenAIGatewayService) ForwardImages( func (s *OpenAIGatewayService) ForwardImages(
@ -535,11 +576,14 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
setOpsUpstreamRequestBody(c, forwardBody) setOpsUpstreamRequestBody(c, forwardBody)
} }
token, _, err := s.GetAccessToken(ctx, account) upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, parsed.Stream)
defer releaseUpstreamCtx()
token, _, err := s.GetAccessToken(upstreamCtx, account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
upstreamReq, err := s.buildOpenAIImagesRequest(ctx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint) upstreamReq, err := s.buildOpenAIImagesRequest(upstreamCtx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -582,23 +626,37 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
Kind: "failover", Kind: "failover",
Message: upstreamMsg, Message: upstreamMsg,
}) })
s.handleFailoverSideEffects(ctx, resp, account) s.handleFailoverSideEffects(upstreamCtx, resp, account)
return nil, &UpstreamFailoverError{ return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
ResponseBody: respBody, ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
} }
} }
return s.handleErrorResponse(ctx, resp, c, account, forwardBody) return s.handleErrorResponse(upstreamCtx, resp, c, account, forwardBody)
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
var usage OpenAIUsage var usage OpenAIUsage
imageCount := parsed.N imageCount := parsed.N
var firstTokenMs *int var firstTokenMs *int
if parsed.Stream { if parsed.Stream && isEventStreamResponse(resp.Header) {
streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime) streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime)
if err != nil { if err != nil {
if streamCount > 0 {
return &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: streamUsage,
Model: requestModel,
UpstreamModel: upstreamModel,
Stream: parsed.Stream,
ResponseHeaders: resp.Header.Clone(),
Duration: time.Since(startTime),
FirstTokenMs: ttft,
ImageCount: streamCount,
ImageSize: parsed.SizeTier,
}, err
}
return nil, err return nil, err
} }
usage = streamUsage usage = streamUsage
@ -807,39 +865,228 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer") return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer")
} }
reader := bufio.NewReader(resp.Body)
usage := OpenAIUsage{} usage := OpenAIUsage{}
imageCount := 0 imageCounter := newOpenAIImageOutputCounter()
var firstTokenMs *int var firstTokenMs *int
clientDisconnected := false
lastDownstreamWriteAt := time.Now()
var fallbackBody bytes.Buffer
fallbackBytes := int64(0)
fallbackLimit := resolveUpstreamResponseReadLimit(s.cfg)
seenSSEData := false
fallbackTooLarge := false
var sseData openAISSEDataAccumulator
for { processSSEData := func(dataBytes []byte) {
line, err := reader.ReadBytes('\n') seenSSEData = true
if len(line) > 0 { fallbackBody.Reset()
if firstTokenMs == nil { fallbackBytes = 0
ms := int(time.Since(startTime).Milliseconds()) mergeOpenAIUsage(&usage, dataBytes)
firstTokenMs = &ms imageCounter.AddSSEData(dataBytes)
} }
flushSSEEvent := func() {
sseData.Flush(processSSEData)
}
processLine := func(line []byte) {
if len(line) == 0 {
return
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
if !clientDisconnected {
if _, writeErr := c.Writer.Write(line); writeErr != nil { if _, writeErr := c.Writer.Write(line); writeErr != nil {
return OpenAIUsage{}, 0, firstTokenMs, writeErr clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected, continue draining upstream for billing")
} else {
flusher.Flush()
lastDownstreamWriteAt = time.Now()
} }
flusher.Flush() }
if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok && data != "" && data != "[DONE]" { trimmedLine := strings.TrimRight(string(line), "\r\n")
dataBytes := []byte(data) if _, ok := extractOpenAISSEDataLine(trimmedLine); ok || strings.TrimSpace(trimmedLine) == "" {
mergeOpenAIUsage(&usage, dataBytes) sseData.AddLine(trimmedLine, processSSEData)
if count := extractOpenAIImageCountFromJSONBytes(dataBytes); count > imageCount { return
imageCount = count }
} if !seenSSEData && !fallbackTooLarge {
fallbackBytes += int64(len(line))
if fallbackBytes <= fallbackLimit {
_, _ = fallbackBody.Write(line)
} else {
fallbackTooLarge = true
fallbackBody.Reset()
} }
} }
if err == io.EOF {
break
}
if err != nil {
return OpenAIUsage{}, 0, firstTokenMs, err
}
} }
return usage, imageCount, firstTokenMs, nil
finalizeFallbackBody := func() {
if seenSSEData || fallbackBody.Len() == 0 {
return
}
body := bytes.TrimSpace(fallbackBody.Bytes())
if len(body) == 0 {
return
}
mergeOpenAIUsage(&usage, body)
imageCounter.AddJSONResponse(body)
}
streamInterval := s.openAIImageStreamDataInterval()
keepaliveInterval := s.openAIImageStreamKeepaliveInterval()
if streamInterval <= 0 && keepaliveInterval <= 0 {
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadBytes('\n')
processLine(line)
if err == io.EOF {
break
}
if err != nil {
flushSSEEvent()
return usage, imageCounter.Count(), firstTokenMs, err
}
}
flushSSEEvent()
finalizeFallbackBody()
return usage, imageCounter.Count(), firstTokenMs, nil
}
type readEvent struct {
line []byte
err error
}
events := make(chan readEvent, 16)
done := make(chan struct{})
sendEvent := func(ev readEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadBytes('\n')
if len(line) > 0 {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
}
if len(line) > 0 && !sendEvent(readEvent{line: line}) {
return
}
if err == io.EOF {
return
}
if err != nil {
_ = sendEvent(readEvent{err: err})
return
}
}
}()
defer close(done)
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
for {
select {
case ev, ok := <-events:
if !ok {
flushSSEEvent()
finalizeFallbackBody()
return usage, imageCounter.Count(), firstTokenMs, nil
}
if ev.err != nil {
flushSSEEvent()
return usage, imageCounter.Count(), firstTokenMs, ev.err
}
processLine(ev.line)
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
return usage, imageCounter.Count(), firstTokenMs, fmt.Errorf("image stream incomplete after timeout")
}
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream data interval timeout: interval=%s", streamInterval)
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(fmt.Sprintf("upstream image stream idle for %s", streamInterval)))
return usage, imageCounter.Count(), firstTokenMs, fmt.Errorf("image stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected || time.Since(lastDownstreamWriteAt) < keepaliveInterval {
continue
}
if _, writeErr := io.WriteString(c.Writer, ":\n\n"); writeErr != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected during keepalive, continue draining upstream for billing")
continue
}
flusher.Flush()
lastDownstreamWriteAt = time.Now()
}
}
}
// openAIImageStreamDataInterval converts the configured
// Gateway.ImageStreamDataIntervalTimeout (in seconds) to a time.Duration.
// It returns 0 when the receiver or its config is nil, or when the configured
// value is not positive.
func (s *OpenAIGatewayService) openAIImageStreamDataInterval() time.Duration {
	if s == nil || s.cfg == nil {
		return 0
	}
	seconds := s.cfg.Gateway.ImageStreamDataIntervalTimeout
	if seconds <= 0 {
		return 0
	}
	return time.Duration(seconds) * time.Second
}
// openAIImageStreamKeepaliveInterval converts the configured
// Gateway.ImageStreamKeepaliveInterval (in seconds) to a time.Duration.
// It returns 0 when the receiver or its config is nil, or when the configured
// value is not positive.
func (s *OpenAIGatewayService) openAIImageStreamKeepaliveInterval() time.Duration {
	if s == nil || s.cfg == nil {
		return 0
	}
	seconds := s.cfg.Gateway.ImageStreamKeepaliveInterval
	if seconds <= 0 {
		return 0
	}
	return time.Duration(seconds) * time.Second
}
// extractOpenAIImagesBillableCountFromJSONBytes returns the number of images
// found in a JSON payload, trying several sources in order and returning the
// first positive count; it returns 0 when none applies or the body is empty or
// not valid JSON.
func extractOpenAIImagesBillableCountFromJSONBytes(body []byte) int {
// 1. Count reported by extractOpenAIImageCountFromJSONBytes.
if count := extractOpenAIImageCountFromJSONBytes(body); count > 0 {
return count
}
if len(body) == 0 || !gjson.ValidBytes(body) {
return 0
}
// 2. Explicit counters carried in the payload.
if count := int(gjson.GetBytes(body, "usage.images").Int()); count > 0 {
return count
}
if count := int(gjson.GetBytes(body, "tool_usage.image_gen.images").Int()); count > 0 {
return count
}
// 3. A "*.completed" event with a top-level b64_json or url field counts as
// exactly one image; any other event type contributes nothing.
eventType := strings.TrimSpace(gjson.GetBytes(body, "type").String())
if eventType == "" || !strings.HasSuffix(eventType, ".completed") {
return 0
}
if gjson.GetBytes(body, "b64_json").Exists() || gjson.GetBytes(body, "url").Exists() {
return 1
}
return 0
} }
func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) { func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) {
@ -863,14 +1110,7 @@ func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) {
} }
func extractOpenAIImageCountFromJSONBytes(body []byte) int { func extractOpenAIImageCountFromJSONBytes(body []byte) int {
if len(body) == 0 || !gjson.ValidBytes(body) { return countOpenAIResponseImageOutputsFromJSONBytes(body)
return 0
}
data := gjson.GetBytes(body, "data")
if data.Exists() && data.IsArray() {
return len(data.Array())
}
return 0
} }
type openAIImagePointerInfo struct { type openAIImagePointerInfo struct {

View File

@ -9,6 +9,7 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@ -361,21 +362,21 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
var ( var (
fallbackResults []openAIResponsesImageResult fallbackResults []openAIResponsesImageResult
fallbackSeen = make(map[string]struct{}) fallbackSeen = make(map[string]struct{})
finalResults []openAIResponsesImageResult
finalMeta openAIResponsesImageResult
collectErr error
createdAt int64 createdAt int64
usageRaw []byte usageRaw []byte
foundFinal bool foundFinal bool
responseMeta openAIResponsesImageResult responseMeta openAIResponsesImageResult
) )
for _, line := range bytes.Split(body, []byte("\n")) { forEachOpenAISSEDataPayload(string(body), func(payload []byte) {
line = bytes.TrimRight(line, "\r") if collectErr != nil || len(finalResults) > 0 {
data, ok := extractOpenAISSEDataLine(string(line)) return
if !ok || data == "" || data == "[DONE]" {
continue
} }
payload := []byte(data)
if !gjson.ValidBytes(payload) { if !gjson.ValidBytes(payload) {
continue return
} }
if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(payload); ok { if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(payload); ok {
mergeOpenAIResponsesImageMeta(&responseMeta, meta) mergeOpenAIResponsesImageMeta(&responseMeta, meta)
@ -388,7 +389,8 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
case "response.output_item.done": case "response.output_item.done":
result, itemID, ok, err := extractOpenAIImageFromResponsesOutputItemDone(payload) result, itemID, ok, err := extractOpenAIImageFromResponsesOutputItemDone(payload)
if err != nil { if err != nil {
return nil, 0, nil, openAIResponsesImageResult{}, false, err collectErr = err
return
} }
if ok { if ok {
mergeOpenAIResponsesImageMeta(&result, responseMeta) mergeOpenAIResponsesImageMeta(&result, responseMeta)
@ -397,7 +399,8 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
case "response.completed": case "response.completed":
results, completedAt, completedUsageRaw, firstMeta, err := extractOpenAIImagesFromResponsesCompleted(payload) results, completedAt, completedUsageRaw, firstMeta, err := extractOpenAIImagesFromResponsesCompleted(payload)
if err != nil { if err != nil {
return nil, 0, nil, openAIResponsesImageResult{}, false, err collectErr = err
return
} }
foundFinal = true foundFinal = true
if completedAt > 0 { if completedAt > 0 {
@ -408,14 +411,24 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
} }
if len(results) > 0 { if len(results) > 0 {
mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta) mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
return results, createdAt, usageRaw, firstMeta, true, nil finalResults = results
finalMeta = firstMeta
return
} }
if len(fallbackResults) > 0 { if len(fallbackResults) > 0 {
firstMeta = fallbackResults[0] firstMeta = fallbackResults[0]
mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta) mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
return fallbackResults, createdAt, usageRaw, firstMeta, true, nil finalResults = fallbackResults
finalMeta = firstMeta
return
} }
} }
})
if collectErr != nil {
return nil, 0, nil, openAIResponsesImageResult{}, false, collectErr
}
if len(finalResults) > 0 {
return finalResults, createdAt, usageRaw, finalMeta, true, nil
} }
if len(fallbackResults) > 0 { if len(fallbackResults) > 0 {
@ -505,6 +518,30 @@ func (s *OpenAIGatewayService) writeOpenAIImagesStreamEvent(c *gin.Context, flus
return nil return nil
} }
// tryWriteOpenAIImagesStreamEvent writes one stream event via
// writeOpenAIImagesStreamEvent unless the client is already flagged as
// disconnected, and reports whether the event was written.
//
// clientDisconnected and lastWriteAt are optional (nil is allowed). When the
// write fails, *clientDisconnected is set to true and the failure is logged;
// when it succeeds, *lastWriteAt is set to the current time.
func (s *OpenAIGatewayService) tryWriteOpenAIImagesStreamEvent(
	c *gin.Context,
	flusher http.Flusher,
	clientDisconnected *bool,
	lastWriteAt *time.Time,
	eventName string,
	payload []byte,
) bool {
	tracksDisconnect := clientDisconnected != nil
	if tracksDisconnect && *clientDisconnected {
		// A previous write already failed; skip further writes to this client.
		return false
	}

	writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload)
	switch {
	case writeErr != nil:
		if tracksDisconnect {
			*clientDisconnected = true
		}
		logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected, continue draining upstream for billing")
		return false
	case lastWriteAt != nil:
		*lastWriteAt = time.Now()
	}
	return true
}
func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse( func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
resp *http.Response, resp *http.Response,
c *gin.Context, c *gin.Context,
@ -517,15 +554,9 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
} }
var usage OpenAIUsage var usage OpenAIUsage
for _, line := range bytes.Split(body, []byte("\n")) { forEachOpenAISSEDataPayload(string(body), func(data []byte) {
line = bytes.TrimRight(line, "\r") s.parseSSEUsageBytes(data, &usage)
data, ok := extractOpenAISSEDataLine(string(line)) })
if !ok || data == "" || data == "[DONE]" {
continue
}
dataBytes := []byte(data)
s.parseSSEUsageBytes(dataBytes, &usage)
}
results, createdAt, usageRaw, firstMeta, _, err := collectOpenAIImagesFromResponsesBody(body) results, createdAt, usageRaw, firstMeta, _, err := collectOpenAIImagesFromResponsesBody(body)
if err != nil { if err != nil {
return OpenAIUsage{}, 0, err return OpenAIUsage{}, 0, err
@ -570,7 +601,6 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
format = "b64_json" format = "b64_json"
} }
reader := bufio.NewReader(resp.Body)
usage := OpenAIUsage{} usage := OpenAIUsage{}
imageCount := 0 imageCount := 0
var firstTokenMs *int var firstTokenMs *int
@ -579,141 +609,307 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
pendingSeen := make(map[string]struct{}) pendingSeen := make(map[string]struct{})
streamMeta := openAIResponsesImageResult{Model: strings.TrimSpace(fallbackModel)} streamMeta := openAIResponsesImageResult{Model: strings.TrimSpace(fallbackModel)}
var createdAt int64 var createdAt int64
clientDisconnected := false
lastDownstreamWriteAt := time.Now()
var sseData openAISSEDataAccumulator
var processDataErr error
processDataDone := false
processData := func(dataBytes []byte) {
if processDataDone || processDataErr != nil {
return
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsageBytes(dataBytes, &usage)
if !gjson.ValidBytes(dataBytes) {
return
}
if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok {
mergeOpenAIResponsesImageMeta(&streamMeta, meta)
if eventCreatedAt > 0 {
createdAt = eventCreatedAt
}
}
switch gjson.GetBytes(dataBytes, "type").String() {
case "response.image_generation_call.partial_image":
b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String())
if b64 == "" {
return
}
eventName := streamPrefix + ".partial_image"
partialMeta := streamMeta
mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{
OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()),
Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()),
})
payload := buildOpenAIImagesStreamPartialPayload(
eventName,
b64,
gjson.GetBytes(dataBytes, "partial_image_index").Int(),
format,
createdAt,
partialMeta,
)
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload)
case "response.output_item.done":
img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes)
if extractErr != nil {
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
processDataErr = extractErr
processDataDone = true
return
}
if !ok {
return
}
mergeOpenAIResponsesImageMeta(&streamMeta, img)
mergeOpenAIResponsesImageMeta(&img, streamMeta)
key := openAIResponsesImageResultKey(itemID, img)
if _, exists := emitted[key]; exists {
return
}
if _, exists := pendingSeen[key]; exists {
return
}
pendingSeen[key] = struct{}{}
pendingResults = append(pendingResults, img)
case "response.completed":
results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes)
if extractErr != nil {
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
processDataErr = extractErr
processDataDone = true
return
}
mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta)
finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults))
finalSeen := make(map[string]struct{})
for _, img := range results {
mergeOpenAIResponsesImageMeta(&img, streamMeta)
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
}
for _, img := range pendingResults {
mergeOpenAIResponsesImageMeta(&img, streamMeta)
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
}
if len(finalResults) == 0 {
outputErr := fmt.Errorf("upstream did not return image output")
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(outputErr.Error()))
processDataErr = outputErr
processDataDone = true
return
}
eventName := streamPrefix + ".completed"
for _, img := range finalResults {
key := openAIResponsesImageResultKey("", img)
if _, exists := emitted[key]; exists {
continue
}
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw)
emitted[key] = struct{}{}
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload)
}
imageCount = len(emitted)
processDataDone = true
}
}
processLine := func(line []byte) (bool, error) {
if len(line) == 0 {
return false, nil
}
sseData.AddLine(string(line), processData)
if processDataErr != nil {
return true, processDataErr
}
return processDataDone, nil
}
flushData := func() (bool, error) {
sseData.Flush(processData)
if processDataErr != nil {
return true, processDataErr
}
return processDataDone, nil
}
finalizePending := func() error {
if imageCount > 0 {
return nil
}
if len(pendingResults) > 0 {
eventName := streamPrefix + ".completed"
for _, img := range pendingResults {
mergeOpenAIResponsesImageMeta(&img, streamMeta)
key := openAIResponsesImageResultKey("", img)
if _, exists := emitted[key]; exists {
continue
}
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil)
emitted[key] = struct{}{}
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload)
}
imageCount = len(emitted)
return nil
}
streamErr := fmt.Errorf("stream disconnected before image generation completed")
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error()))
return streamErr
}
streamInterval := s.openAIImageStreamDataInterval()
keepaliveInterval := s.openAIImageStreamKeepaliveInterval()
if streamInterval <= 0 && keepaliveInterval <= 0 {
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadBytes('\n')
done, processErr := processLine(line)
if processErr != nil {
return usage, imageCount, firstTokenMs, processErr
}
if done {
return usage, imageCount, firstTokenMs, nil
}
if err == io.EOF {
break
}
if err != nil {
if done, processErr := flushData(); processErr != nil {
return usage, imageCount, firstTokenMs, processErr
} else if done {
return usage, imageCount, firstTokenMs, nil
}
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
return usage, imageCount, firstTokenMs, err
}
}
if done, processErr := flushData(); processErr != nil {
return usage, imageCount, firstTokenMs, processErr
} else if done {
return usage, imageCount, firstTokenMs, nil
}
if err := finalizePending(); err != nil {
return usage, imageCount, firstTokenMs, err
}
return usage, imageCount, firstTokenMs, nil
}
type readEvent struct {
line []byte
err error
}
events := make(chan readEvent, 16)
done := make(chan struct{})
sendEvent := func(ev readEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadBytes('\n')
if len(line) > 0 {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
}
if len(line) > 0 && !sendEvent(readEvent{line: line}) {
return
}
if err == io.EOF {
return
}
if err != nil {
_ = sendEvent(readEvent{err: err})
return
}
}
}()
defer close(done)
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
for { for {
line, err := reader.ReadBytes('\n') select {
if len(line) > 0 { case ev, ok := <-events:
trimmedLine := strings.TrimRight(string(line), "\r\n") if !ok {
data, ok := extractOpenAISSEDataLine(trimmedLine) if done, processErr := flushData(); processErr != nil {
if ok && data != "" && data != "[DONE]" { return usage, imageCount, firstTokenMs, processErr
if firstTokenMs == nil { } else if done {
ms := int(time.Since(startTime).Milliseconds()) return usage, imageCount, firstTokenMs, nil
firstTokenMs = &ms
} }
dataBytes := []byte(data) if err := finalizePending(); err != nil {
s.parseSSEUsageBytes(dataBytes, &usage) return usage, imageCount, firstTokenMs, err
if gjson.ValidBytes(dataBytes) {
if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok {
mergeOpenAIResponsesImageMeta(&streamMeta, meta)
if eventCreatedAt > 0 {
createdAt = eventCreatedAt
}
}
switch gjson.GetBytes(dataBytes, "type").String() {
case "response.image_generation_call.partial_image":
b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String())
if b64 != "" {
eventName := streamPrefix + ".partial_image"
partialMeta := streamMeta
mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{
OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()),
Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()),
})
payload := buildOpenAIImagesStreamPartialPayload(
eventName,
b64,
gjson.GetBytes(dataBytes, "partial_image_index").Int(),
format,
createdAt,
partialMeta,
)
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
}
}
case "response.output_item.done":
img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes)
if extractErr != nil {
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
return OpenAIUsage{}, imageCount, firstTokenMs, extractErr
}
if !ok {
break
}
mergeOpenAIResponsesImageMeta(&streamMeta, img)
mergeOpenAIResponsesImageMeta(&img, streamMeta)
key := openAIResponsesImageResultKey(itemID, img)
if _, exists := emitted[key]; exists {
break
}
if _, exists := pendingSeen[key]; exists {
break
}
pendingSeen[key] = struct{}{}
pendingResults = append(pendingResults, img)
case "response.completed":
results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes)
if extractErr != nil {
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
return OpenAIUsage{}, imageCount, firstTokenMs, extractErr
}
mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta)
finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults))
finalSeen := make(map[string]struct{})
for _, img := range results {
mergeOpenAIResponsesImageMeta(&img, streamMeta)
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
}
for _, img := range pendingResults {
mergeOpenAIResponsesImageMeta(&img, streamMeta)
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
}
if len(finalResults) == 0 {
err = fmt.Errorf("upstream did not return image output")
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
return OpenAIUsage{}, imageCount, firstTokenMs, err
}
eventName := streamPrefix + ".completed"
for _, img := range finalResults {
key := openAIResponsesImageResultKey("", img)
if _, exists := emitted[key]; exists {
continue
}
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw)
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
}
emitted[key] = struct{}{}
}
imageCount = len(emitted)
return usage, imageCount, firstTokenMs, nil
}
} }
return usage, imageCount, firstTokenMs, nil
} }
} if ev.err != nil {
if err == io.EOF { if done, processErr := flushData(); processErr != nil {
break return usage, imageCount, firstTokenMs, processErr
} } else if done {
if err != nil { return usage, imageCount, firstTokenMs, nil
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error())) }
return OpenAIUsage{}, imageCount, firstTokenMs, err s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(ev.err.Error()))
} return usage, imageCount, firstTokenMs, ev.err
} }
done, processErr := processLine(ev.line)
if imageCount > 0 { if processErr != nil {
return usage, imageCount, firstTokenMs, nil return usage, imageCount, firstTokenMs, processErr
} }
if len(pendingResults) > 0 { if done {
eventName := streamPrefix + ".completed" return usage, imageCount, firstTokenMs, nil
for _, img := range pendingResults { }
mergeOpenAIResponsesImageMeta(&img, streamMeta) case <-intervalCh:
key := openAIResponsesImageResultKey("", img) lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if _, exists := emitted[key]; exists { if time.Since(lastRead) < streamInterval {
continue continue
} }
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil) if clientDisconnected {
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil { return usage, imageCount, firstTokenMs, fmt.Errorf("image stream incomplete after timeout")
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
} }
emitted[key] = struct{}{} logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images responses stream data interval timeout: interval=%s", streamInterval)
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(fmt.Sprintf("upstream image stream idle for %s", streamInterval)))
return usage, imageCount, firstTokenMs, fmt.Errorf("image stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected || time.Since(lastDownstreamWriteAt) < keepaliveInterval {
continue
}
if _, writeErr := io.WriteString(c.Writer, ":\n\n"); writeErr != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images responses stream client disconnected during keepalive, continue draining upstream for billing")
continue
}
flusher.Flush()
lastDownstreamWriteAt = time.Now()
} }
imageCount = len(emitted)
return usage, imageCount, firstTokenMs, nil
} }
streamErr := fmt.Errorf("stream disconnected before image generation completed")
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error()))
return OpenAIUsage{}, imageCount, firstTokenMs, streamErr
} }
func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
@ -752,7 +948,10 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
) )
} }
token, _, err := s.GetAccessToken(ctx, account) upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, parsed.Stream)
defer releaseUpstreamCtx()
token, _, err := s.GetAccessToken(upstreamCtx, account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -763,7 +962,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
} }
setOpsUpstreamRequestBody(c, responsesBody) setOpsUpstreamRequestBody(c, responsesBody)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false) upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -808,14 +1007,14 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
Kind: "failover", Kind: "failover",
Message: upstreamMsg, Message: upstreamMsg,
}) })
s.handleFailoverSideEffects(ctx, resp, account) s.handleFailoverSideEffects(upstreamCtx, resp, account)
return nil, &UpstreamFailoverError{ return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
ResponseBody: respBody, ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
} }
} }
return s.handleErrorResponse(ctx, resp, c, account, responsesBody) return s.handleErrorResponse(upstreamCtx, resp, c, account, responsesBody)
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
@ -827,6 +1026,20 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
if parsed.Stream { if parsed.Stream {
usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel) usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel)
if err != nil { if err != nil {
if imageCount > 0 {
return &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: usage,
Model: requestModel,
UpstreamModel: requestModel,
Stream: parsed.Stream,
ResponseHeaders: resp.Header.Clone(),
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ImageCount: imageCount,
ImageSize: parsed.SizeTier,
}, err
}
return nil, err return nil, err
} }
} else { } else {

View File

@ -3,6 +3,7 @@ package service
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"io" "io"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
@ -17,6 +18,20 @@ import (
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
// failingOpenAIImageWriter is a test double that wraps a gin.ResponseWriter
// and starts rejecting Write calls after a fixed number of successful ones,
// simulating a client that disconnects mid-response.
//
// NOTE(review): only Write is intercepted; any other method promoted from the
// embedded gin.ResponseWriter goes straight to the wrapped writer. Confirm the
// code under test emits its payload through Write.
type failingOpenAIImageWriter struct {
	gin.ResponseWriter
	// writes counts the Write calls that were allowed through so far.
	writes int
	// failAfter is the number of Write calls permitted before failures begin.
	failAfter int
}

// Write forwards p to the wrapped writer while the success budget lasts and
// returns a "client disconnected" error (writing nothing) once it is used up.
func (w *failingOpenAIImageWriter) Write(p []byte) (int, error) {
	if w.writes < w.failAfter {
		w.writes++
		return w.ResponseWriter.Write(p)
	}
	return 0, errors.New("write failed: client disconnected")
}
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) { func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","stream":true}`) body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","stream":true}`)
@ -75,6 +90,100 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
} }
// TestOpenAIGatewayServiceParseOpenAIImagesRequest_NormalizesOfficialAndCustomSizes
// feeds a table of "size" values through ParseOpenAIImagesRequest and checks
// that the parsed Size is kept verbatim while SizeTier matches the expected
// billing tier for that size.
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_NormalizesOfficialAndCustomSizes(t *testing.T) {
	gin.SetMode(gin.TestMode)

	type sizeCase struct {
		size     string
		wantTier string
	}
	cases := []sizeCase{
		{"1024x1024", "1K"},
		{"1536x1024", "2K"},
		{"1024x1536", "2K"},
		{"2048x2048", "2K"},
		{"2048x1152", "2K"},
		{"3840x2160", "4K"},
		{"2160x3840", "4K"},
		{"1024X768", "2K"},
		{"1280x768", "2K"},
		{"2560x1440", "2K"},
		{"2560x1600", "4K"},
		{"auto", "2K"},
	}

	gateway := &OpenAIGatewayService{}
	for _, tc := range cases {
		// One subtest per size, named after the size itself.
		t.Run(tc.size, func(t *testing.T) {
			payload := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"` + tc.size + `"}`)
			httpReq := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(payload))
			httpReq.Header.Set("Content-Type", "application/json")
			ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
			ginCtx.Request = httpReq

			got, err := gateway.ParseOpenAIImagesRequest(ginCtx, payload)
			require.NoError(t, err)
			require.NotNil(t, got)
			require.Equal(t, tc.size, got.Size)
			require.Equal(t, tc.wantTier, got.SizeTier)
		})
	}
}
// TestOpenAIGatewayServiceParseOpenAIImagesRequest_UnknownSizesDoNotBlockPassthrough
// checks that unusual, malformed, or overflowing "size" values do not make
// ParseOpenAIImagesRequest fail: the raw size string is preserved and a
// SizeTier is still assigned.
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_UnknownSizesDoNotBlockPassthrough(t *testing.T) {
	gin.SetMode(gin.TestMode)

	type sizeCase struct {
		size     string
		wantTier string
	}
	cases := []sizeCase{
		{"2048x1153", "2K"},
		{"4096x1024", "4K"},
		{"3840x1024", "4K"},
		{"512x512", "2K"},
		{"invalid", "2K"},
		{"999999999999999999999999999x2", "2K"},
	}

	gateway := &OpenAIGatewayService{}
	for _, tc := range cases {
		// One subtest per size, named after the size itself.
		t.Run(tc.size, func(t *testing.T) {
			payload := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"` + tc.size + `"}`)
			httpReq := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(payload))
			httpReq.Header.Set("Content-Type", "application/json")
			ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
			ginCtx.Request = httpReq

			got, err := gateway.ParseOpenAIImagesRequest(ginCtx, payload)
			require.NoError(t, err)
			require.NotNil(t, got)
			require.Equal(t, tc.size, got.Size)
			require.Equal(t, tc.wantTier, got.SizeTier)
		})
	}
}
// TestOpenAIGatewayServiceParseOpenAIImagesRequest_LegacyImageModelUnknownSizePassthrough
// parses a gpt-image-1.5 request with size 2048x1152 and expects the size to
// be passed through unchanged with a "2K" tier.
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_LegacyImageModelUnknownSizePassthrough(t *testing.T) {
	gin.SetMode(gin.TestMode)

	payload := []byte(`{"model":"gpt-image-1.5","prompt":"draw a cat","size":"2048x1152"}`)
	httpReq := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(payload))
	httpReq.Header.Set("Content-Type", "application/json")
	ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ginCtx.Request = httpReq

	gateway := &OpenAIGatewayService{}
	got, err := gateway.ParseOpenAIImagesRequest(ginCtx, payload)

	require.NoError(t, err)
	require.NotNil(t, got)
	require.Equal(t, "2048x1152", got.Size)
	require.Equal(t, "2K", got.SizeTier)
}
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEditWithMaskAndNativeOptions(t *testing.T) { func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEditWithMaskAndNativeOptions(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
@ -446,6 +555,160 @@ func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseU
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
} }
// TestOpenAIGatewayServiceForwardImages_APIKeyStreamJSONResponseBillsImage
// sends a stream=true image request through an API-key account whose upstream
// replies with a plain application/json body instead of an event stream. The
// result must still be flagged as a stream, report one image and the upstream
// token usage, and the JSON image payload must reach the client.
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamJSONResponseBillsImage(t *testing.T) {
	gin.SetMode(gin.TestMode)

	reqBody := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
	httpReq := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(reqBody))
	httpReq.Header.Set("Content-Type", "application/json")
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httpReq

	// Canned upstream reply: a complete, non-streamed images response.
	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header: http.Header{
			"Content-Type": []string{"application/json"},
			"X-Request-Id": []string{"req_img_stream_json"},
		},
		Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"usage":{"input_tokens":12,"output_tokens":21,"output_tokens_details":{"image_tokens":9}},"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
	}
	gateway := &OpenAIGatewayService{
		cfg:          &config.Config{},
		httpUpstream: &httpUpstreamRecorder{resp: upstreamResp},
	}

	parsedReq, err := gateway.ParseOpenAIImagesRequest(ginCtx, reqBody)
	require.NoError(t, err)

	apiKeyAccount := &Account{
		ID:       7,
		Name:     "openai-apikey",
		Platform: PlatformOpenAI,
		Type:     AccountTypeAPIKey,
		Credentials: map[string]any{
			"api_key":  "test-api-key",
			"base_url": "https://image-upstream.example/v1",
		},
	}

	forwarded, err := gateway.ForwardImages(context.Background(), ginCtx, apiKeyAccount, reqBody, parsedReq, "")
	require.NoError(t, err)
	require.NotNil(t, forwarded)
	require.True(t, forwarded.Stream)
	require.Equal(t, 1, forwarded.ImageCount)
	require.Equal(t, 12, forwarded.Usage.InputTokens)
	require.Equal(t, 21, forwarded.Usage.OutputTokens)
	require.Equal(t, 9, forwarded.Usage.ImageOutputTokens)
	require.Equal(t, http.StatusOK, recorder.Code)
	require.Equal(t, "aGVsbG8=", gjson.Get(recorder.Body.String(), "data.0.b64_json").String())
}
// TestOpenAIGatewayServiceForwardImages_APIKeyStreamRawJSONEventStreamFallbackBillsImage
// covers an upstream that labels its reply text/event-stream but actually
// sends a single raw JSON document. The forwarded result must still count one
// image, carry the upstream usage, and deliver the image data to the client.
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamRawJSONEventStreamFallbackBillsImage(t *testing.T) {
	gin.SetMode(gin.TestMode)

	reqBody := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
	httpReq := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(reqBody))
	httpReq.Header.Set("Content-Type", "application/json")
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httpReq

	// Canned upstream reply: JSON body behind an event-stream content type.
	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header: http.Header{
			"Content-Type": []string{"text/event-stream"},
			"X-Request-Id": []string{"req_img_stream_json_mislabeled"},
		},
		Body: io.NopCloser(strings.NewReader(`{"created":1710000009,"usage":{"input_tokens":10,"output_tokens":18,"output_tokens_details":{"image_tokens":8}},"data":[{"b64_json":"ZmluYWw="}]}`)),
	}
	gateway := &OpenAIGatewayService{
		cfg:          &config.Config{},
		httpUpstream: &httpUpstreamRecorder{resp: upstreamResp},
	}

	parsedReq, err := gateway.ParseOpenAIImagesRequest(ginCtx, reqBody)
	require.NoError(t, err)

	apiKeyAccount := &Account{
		ID:       8,
		Name:     "openai-apikey",
		Platform: PlatformOpenAI,
		Type:     AccountTypeAPIKey,
		Credentials: map[string]any{
			"api_key":  "test-api-key",
			"base_url": "https://image-upstream.example/v1",
		},
	}

	forwarded, err := gateway.ForwardImages(context.Background(), ginCtx, apiKeyAccount, reqBody, parsedReq, "")
	require.NoError(t, err)
	require.NotNil(t, forwarded)
	require.True(t, forwarded.Stream)
	require.Equal(t, 1, forwarded.ImageCount)
	require.Equal(t, 10, forwarded.Usage.InputTokens)
	require.Equal(t, 18, forwarded.Usage.OutputTokens)
	require.Equal(t, 8, forwarded.Usage.ImageOutputTokens)
	require.Equal(t, "ZmluYWw=", gjson.Get(recorder.Body.String(), "data.0.b64_json").String())
}
// TestOpenAIGatewayServiceForwardImages_APIKeyStreamMultilineSSEDataBillsImage
// streams an image_generation.completed event whose JSON is split across
// several consecutive "data:" lines, and expects the forwarded result to count
// one image and report the usage carried by that event.
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamMultilineSSEDataBillsImage(t *testing.T) {
	gin.SetMode(gin.TestMode)

	reqBody := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
	httpReq := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(reqBody))
	httpReq.Header.Set("Content-Type", "application/json")
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httpReq

	// A single SSE event spread over three data lines, followed by [DONE].
	sseBody := "data: {\"type\":\"image_generation.completed\",\n" +
		"data: \"usage\":{\"input_tokens\":10,\"output_tokens\":18,\"output_tokens_details\":{\"image_tokens\":8}},\n" +
		"data: \"b64_json\":\"ZmluYWw=\",\"output_format\":\"png\"}\n\n" +
		"data: [DONE]\n\n"
	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header: http.Header{
			"Content-Type": []string{"text/event-stream"},
			"X-Request-Id": []string{"req_img_stream_multiline"},
		},
		Body: io.NopCloser(strings.NewReader(sseBody)),
	}
	gateway := &OpenAIGatewayService{
		cfg:          &config.Config{},
		httpUpstream: &httpUpstreamRecorder{resp: upstreamResp},
	}

	parsedReq, err := gateway.ParseOpenAIImagesRequest(ginCtx, reqBody)
	require.NoError(t, err)

	apiKeyAccount := &Account{
		ID:       8,
		Name:     "openai-apikey",
		Platform: PlatformOpenAI,
		Type:     AccountTypeAPIKey,
		Credentials: map[string]any{
			"api_key": "test-api-key",
		},
	}

	forwarded, err := gateway.ForwardImages(context.Background(), ginCtx, apiKeyAccount, reqBody, parsedReq, "")
	require.NoError(t, err)
	require.NotNil(t, forwarded)
	require.True(t, forwarded.Stream)
	require.Equal(t, 1, forwarded.ImageCount)
	require.Equal(t, 10, forwarded.Usage.InputTokens)
	require.Equal(t, 18, forwarded.Usage.OutputTokens)
	require.Equal(t, 8, forwarded.Usage.ImageOutputTokens)
}
func TestExtractOpenAIImagesBillableCountFromJSONBytes_CompletedEvent(t *testing.T) {
body := []byte(`{"type":"image_generation.completed","b64_json":"ZmluYWw=","usage":{"input_tokens":10,"output_tokens":18}}`)
require.Equal(t, 1, extractOpenAIImagesBillableCountFromJSONBytes(body))
}
func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) { func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
@ -583,6 +846,61 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *tes
require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists()) require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
} }
// TestOpenAIGatewayServiceForwardImages_APIKeyStreamingDrainsAfterClientDisconnect
// replaces the response writer with one that accepts a single Write and then
// fails, and expects ForwardImages to still succeed and report the image count
// and usage from the upstream's final completed event.
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamingDrainsAfterClientDisconnect(t *testing.T) {
	gin.SetMode(gin.TestMode)

	reqBody := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true}`)
	httpReq := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(reqBody))
	httpReq.Header.Set("Content-Type", "application/json")
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httpReq
	// Simulate the client going away after the first downstream write.
	ginCtx.Writer = &failingOpenAIImageWriter{ResponseWriter: ginCtx.Writer, failAfter: 1}

	// Upstream sends a partial image, then the completed image, then [DONE].
	sseBody := "data: {\"type\":\"image_generation.partial_image\",\"b64_json\":\"cGFydGlhbA==\"}\n\n" +
		"data: {\"type\":\"image_generation.completed\",\"usage\":{\"input_tokens\":3,\"output_tokens\":4,\"output_tokens_details\":{\"image_tokens\":2}},\"b64_json\":\"ZmluYWw=\",\"output_format\":\"png\"}\n\n" +
		"data: [DONE]\n\n"
	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header: http.Header{
			"Content-Type": []string{"text/event-stream"},
			"X-Request-Id": []string{"req_img_stream_disconnect_apikey"},
		},
		Body: io.NopCloser(strings.NewReader(sseBody)),
	}
	gatewayCfg := &config.Config{
		Gateway: config.GatewayConfig{
			ImageStreamDataIntervalTimeout: 1,
			ImageStreamKeepaliveInterval:   0,
		},
	}
	gateway := &OpenAIGatewayService{
		cfg:          gatewayCfg,
		httpUpstream: &httpUpstreamRecorder{resp: upstreamResp},
	}

	parsedReq, err := gateway.ParseOpenAIImagesRequest(ginCtx, reqBody)
	require.NoError(t, err)

	apiKeyAccount := &Account{
		ID:       8,
		Name:     "openai-apikey",
		Platform: PlatformOpenAI,
		Type:     AccountTypeAPIKey,
		Credentials: map[string]any{
			"api_key": "test-api-key",
		},
	}

	forwarded, err := gateway.ForwardImages(context.Background(), ginCtx, apiKeyAccount, reqBody, parsedReq, "")
	require.NoError(t, err)
	require.NotNil(t, forwarded)
	require.Equal(t, 1, forwarded.ImageCount)
	require.Equal(t, 3, forwarded.Usage.InputTokens)
	require.Equal(t, 4, forwarded.Usage.OutputTokens)
	require.Equal(t, 2, forwarded.Usage.ImageOutputTokens)
}
func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t *testing.T) { func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
@ -798,6 +1116,23 @@ func TestCollectOpenAIImagesFromResponsesBody_FallsBackToOutputItemDone(t *testi
require.JSONEq(t, `{"images":1}`, string(usageRaw)) require.JSONEq(t, `{"images":1}`, string(usageRaw))
} }
// TestCollectOpenAIImagesFromResponsesBody_MultilineSSE passes a
// response.completed event split across two "data:" lines to
// collectOpenAIImagesFromResponsesBody and checks the extracted image,
// created_at timestamp, output format, and raw image_gen usage.
func TestCollectOpenAIImagesFromResponsesBody_MultilineSSE(t *testing.T) {
	sseBody := "data: {\"type\":\"response.completed\",\n" +
		"data: \"response\":{\"created_at\":1710000010,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" +
		"data: [DONE]\n\n"

	images, created, rawUsage, meta, sawFinal, err := collectOpenAIImagesFromResponsesBody([]byte(sseBody))

	require.NoError(t, err)
	require.True(t, sawFinal)
	require.Equal(t, int64(1710000010), created)
	require.Len(t, images, 1)
	require.Equal(t, "ZmluYWw=", images[0].Result)
	require.Equal(t, "png", meta.OutputFormat)
	require.JSONEq(t, `{"images":1}`, string(rawUsage))
}
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFallback(t *testing.T) { func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFallback(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`) body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
@ -854,3 +1189,116 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFa
require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw) require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
require.NotContains(t, rec.Body.String(), "event: error") require.NotContains(t, rec.Body.String(), "event: error")
} }
// TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesMultilineSSE
// runs ForwardImages for an OAuth account with a streaming request and an
// upstream SSE body whose response.completed payload is split across two
// "data:" lines. It checks the returned usage figures and that the client
// stream contains an image_generation.completed event and no error event.
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesMultilineSSE(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = req
svc := &OpenAIGatewayService{}
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
require.NoError(t, err)
// Canned upstream response: one event spread over two data lines, then [DONE].
svc.httpUpstream = &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/event-stream"},
"X-Request-Id": []string{"req_img_stream_multiline_oauth"},
},
Body: io.NopCloser(strings.NewReader(
"data: {\"type\":\"response.completed\",\n" +
"data: \"response\":{\"created_at\":1710000011,\"usage\":{\"input_tokens\":6,\"output_tokens\":10,\"output_tokens_details\":{\"image_tokens\":5}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"TXVsdGlsaW5l\",\"output_format\":\"png\"}]}}\n\n" +
"data: [DONE]\n\n",
)),
},
}
account := &Account{
ID: 11,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "token-123",
},
}
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.Stream)
require.Equal(t, 1, result.ImageCount)
// Expected token counts mirror the usage object in the fixture above.
require.Equal(t, 6, result.Usage.InputTokens)
require.Equal(t, 10, result.Usage.OutputTokens)
require.Equal(t, 5, result.Usage.ImageOutputTokens)
// Inspect what was written to the client-facing recorder.
events := parseOpenAIImageTestSSEEvents(rec.Body.String())
completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed")
require.True(t, ok)
require.Equal(t, "TXVsdGlsaW5l", gjson.Get(completed.Data, "b64_json").String())
require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
require.NotContains(t, rec.Body.String(), "event: error")
}
// TestOpenAIGatewayServiceForwardImages_OAuthStreamingDrainsAfterClientDisconnect
// runs a streaming ForwardImages call whose response writer is wrapped in
// failingOpenAIImageWriter and checks that the call still succeeds and
// reports the image count and usage carried by the final response.completed
// event of the upstream stream.
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingDrainsAfterClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = req
// NOTE(review): failingOpenAIImageWriter is defined elsewhere; presumably it
// starts failing writes after failAfter successful ones to simulate a client
// disconnect — confirm against its definition.
c.Writer = &failingOpenAIImageWriter{ResponseWriter: c.Writer, failAfter: 1}
svc := &OpenAIGatewayService{
cfg: &config.Config{
Gateway: config.GatewayConfig{
ImageStreamDataIntervalTimeout: 1,
ImageStreamKeepaliveInterval: 0,
},
},
}
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
require.NoError(t, err)
// Canned upstream stream: a partial image event, the completed event, then [DONE].
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/event-stream"},
"X-Request-Id": []string{"req_img_stream_disconnect_oauth"},
},
Body: io.NopCloser(strings.NewReader(
"data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"png\"}\n\n" +
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000009,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" +
"data: [DONE]\n\n",
)),
},
}
svc.httpUpstream = upstream
account := &Account{
ID: 9,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "token-123",
},
}
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.Stream)
require.Equal(t, 1, result.ImageCount)
// Expected token counts mirror the usage object of the response.completed event.
require.Equal(t, 5, result.Usage.InputTokens)
require.Equal(t, 9, result.Usage.OutputTokens)
require.Equal(t, 4, result.Usage.ImageOutputTokens)
}

View File

@ -0,0 +1,57 @@
package service
import (
"bytes"
"strings"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
// openAICompatMessagesBridgeContextKey is the gin context key under which
// setOpenAICompatMessagesBridgeContext stores the bridge flag and from which
// isOpenAICompatMessagesBridgeContext reads it back.
const openAICompatMessagesBridgeContextKey = "openai_compat_messages_bridge"
// isOpenAICompatMessagesBridgeBody reports whether a raw JSON request body
// either contains the Claude Code todo guard marker or carries a
// prompt_cache_key with one of the anthropic-* prefixes accepted by
// isOpenAICompatMessagesBridgePromptCacheKey. An empty body never matches.
func isOpenAICompatMessagesBridgeBody(body []byte) bool {
	if len(body) == 0 {
		return false
	}
	marker := openAICompatClaudeCodeTodoGuardMarker
	if bytes.Contains(body, []byte(marker)) {
		return true
	}
	// encoding/json writes '<', '>' and '&' inside strings as \u003c, \u003e
	// and \u0026 by default, so a body serialised with json.Marshal holds the
	// marker in that escaped spelling rather than literally.
	escaped := strings.NewReplacer("<", `\u003c`, ">", `\u003e`, "&", `\u0026`).Replace(marker)
	if escaped != marker && bytes.Contains(body, []byte(escaped)) {
		return true
	}
	return isOpenAICompatMessagesBridgePromptCacheKey(gjson.GetBytes(body, "prompt_cache_key").String())
}
// isOpenAICompatMessagesBridgeRequestBody is the decoded-map counterpart of
// isOpenAICompatMessagesBridgeBody: it returns true when the "input" list
// contains the Claude Code todo guard marker, or when "prompt_cache_key" has
// one of the recognised anthropic-* prefixes.
func isOpenAICompatMessagesBridgeRequestBody(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	items, isList := reqBody["input"].([]any)
	if isList && inputContainsText(items, openAICompatClaudeCodeTodoGuardMarker) {
		return true
	}
	cacheKey := firstNonEmptyString(reqBody["prompt_cache_key"])
	return isOpenAICompatMessagesBridgePromptCacheKey(cacheKey)
}
// isOpenAICompatMessagesBridgePromptCacheKey reports whether key, after
// trimming surrounding whitespace, starts with "anthropic-metadata-",
// "anthropic-cache-" or "anthropic-digest-".
func isOpenAICompatMessagesBridgePromptCacheKey(key string) bool {
	trimmed := strings.TrimSpace(key)
	for _, prefix := range [...]string{
		"anthropic-metadata-",
		"anthropic-cache-",
		"anthropic-digest-",
	} {
		if strings.HasPrefix(trimmed, prefix) {
			return true
		}
	}
	return false
}
// setOpenAICompatMessagesBridgeContext stores true under
// openAICompatMessagesBridgeContextKey when enabled is set and c is non-nil.
// It never clears a previously stored flag.
func setOpenAICompatMessagesBridgeContext(c *gin.Context, enabled bool) {
	if enabled && c != nil {
		c.Set(openAICompatMessagesBridgeContextKey, true)
	}
}
// isOpenAICompatMessagesBridgeContext reports whether the gin context holds a
// boolean true under openAICompatMessagesBridgeContextKey. A nil context, a
// missing key, or a value of any other type all yield false.
func isOpenAICompatMessagesBridgeContext(c *gin.Context) bool {
	if c == nil {
		return false
	}
	if raw, found := c.Get(openAICompatMessagesBridgeContextKey); found {
		flag, isBool := raw.(bool)
		return isBool && flag
	}
	return false
}

View File

@ -0,0 +1,277 @@
package service
import (
"context"
"encoding/json"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
// openAICompatSessionResponseBinding is the value stored per session key
// (see openAICompatSessionResponseKey) in
// OpenAIGatewayService.openaiCompatSessionResponses.
type openAICompatSessionResponseBinding struct {
// ResponseID is the trimmed response ID recorded by
// bindOpenAICompatSessionResponseID; empty when none is bound.
ResponseID string
// TurnState is the trimmed value recorded by bindOpenAICompatSessionTurnState.
TurnState string
// ContinuationDisabled is set by disableOpenAICompatSessionContinuation; while
// set, getOpenAICompatSessionResponseID returns "" and new response IDs are
// not stored.
ContinuationDisabled bool
// ExpiresAt is the instant after which readers treat the entry as stale and
// delete it. A zero value is treated as "never expires" by the readers.
ExpiresAt time.Time
}
// openAICompatContinuationEnabled returns true only for API-key accounts
// whose model passes shouldAutoInjectPromptCacheKeyForCompat. A nil account
// or any other account type yields false.
func openAICompatContinuationEnabled(account *Account, model string) bool {
	if account != nil && account.Type == AccountTypeAPIKey {
		return shouldAutoInjectPromptCacheKeyForCompat(model)
	}
	return false
}
// trimAnthropicCompatResponsesInputToLatestTurn shrinks req.Input to its
// final "turn": walking backwards from the last item it skips trailing
// function_call_output items and keeps everything from the first item that
// is not one (or from index 0 if none is found). The request is left
// untouched when the input cannot be decoded, is empty, would not shrink, or
// cannot be re-encoded.
func trimAnthropicCompatResponsesInputToLatestTurn(req *apicompat.ResponsesRequest) {
	if req == nil || len(req.Input) == 0 {
		return
	}
	var all []apicompat.ResponsesInputItem
	if json.Unmarshal(req.Input, &all) != nil || len(all) == 0 {
		return
	}
	first := len(all) - 1
	for ; first > 0; first-- {
		if all[first].Type != "function_call_output" {
			break
		}
	}
	if first == 0 {
		// Nothing would be dropped.
		return
	}
	tail := make([]apicompat.ResponsesInputItem, len(all)-first)
	copy(tail, all[first:])
	encoded, err := json.Marshal(tail)
	if err != nil {
		return
	}
	req.Input = encoded
}
// isOpenAICompatPreviousResponseNotFound reports whether an upstream 400/404
// error looks like a "previous response not found" failure. The check is
// case-insensitive and is applied to upstreamMsg, the raw body, and the
// body's error.code and error.message fields, in that order.
func isOpenAICompatPreviousResponseNotFound(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
	switch statusCode {
	case http.StatusBadRequest, http.StatusNotFound:
	default:
		return false
	}
	matches := func(text string) bool {
		normalized := strings.ToLower(strings.TrimSpace(text))
		switch {
		case strings.Contains(normalized, "previous_response_not_found"):
			return true
		case strings.Contains(normalized, "previous response") && strings.Contains(normalized, "not found"):
			return true
		default:
			return strings.Contains(normalized, "unsupported parameter") &&
				strings.Contains(normalized, "previous_response_id")
		}
	}
	if matches(upstreamMsg) || matches(string(upstreamBody)) {
		return true
	}
	for _, path := range [...]string{"error.code", "error.message"} {
		if matches(gjson.GetBytes(upstreamBody, path).String()) {
			return true
		}
	}
	return false
}
// isOpenAICompatPreviousResponseUnsupported reports whether an upstream 400
// error says that previous_response_id is not supported. A text matches when
// it mentions "previous_response_id" together with one of the known
// "unsupported" phrases (case-insensitive). upstreamMsg, the raw body, and
// the body's error.code and error.message fields are checked in that order.
func isOpenAICompatPreviousResponseUnsupported(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
	if statusCode != http.StatusBadRequest {
		return false
	}
	phrases := [...]string{
		"unsupported parameter",
		"only supported on responses websocket",
		"not supported",
	}
	matches := func(text string) bool {
		normalized := strings.ToLower(strings.TrimSpace(text))
		if !strings.Contains(normalized, "previous_response_id") {
			return false
		}
		for _, phrase := range phrases {
			if strings.Contains(normalized, phrase) {
				return true
			}
		}
		return false
	}
	if matches(upstreamMsg) || matches(string(upstreamBody)) {
		return true
	}
	for _, path := range [...]string{"error.code", "error.message"} {
		if matches(gjson.GetBytes(upstreamBody, path).String()) {
			return true
		}
	}
	return false
}
// openAICompatSessionResponseKey builds the store key for a session binding:
// account ID, API key ID (0 when no gin context is given) and the trimmed
// prompt cache key, joined with NUL bytes. It returns "" when the account is
// nil or the prompt cache key is blank.
func openAICompatSessionResponseKey(c *gin.Context, account *Account, promptCacheKey string) string {
	cacheKey := strings.TrimSpace(promptCacheKey)
	if cacheKey == "" || account == nil {
		return ""
	}
	var apiKeyID int64
	if c != nil {
		apiKeyID = getAPIKeyIDFromContext(c)
	}
	const sep = "\x00"
	return strconv.FormatInt(account.ID, 10) + sep + strconv.FormatInt(apiKeyID, 10) + sep + cacheKey
}
// getOpenAICompatSessionResponseID returns the trimmed response ID bound to
// the session, or "" when there is none. Entries of an unexpected type,
// expired entries, and entries with a blank response ID are deleted from the
// store. An entry with continuation disabled is kept but yields "".
func (s *OpenAIGatewayService) getOpenAICompatSessionResponseID(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) string {
	if s == nil {
		return ""
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	if key == "" {
		return ""
	}
	entry, found := s.openaiCompatSessionResponses.Load(key)
	if !found {
		return ""
	}
	binding, valid := entry.(openAICompatSessionResponseBinding)
	responseID := strings.TrimSpace(binding.ResponseID)
	switch {
	case !valid:
		// Unknown value type: fall through to deletion below.
	case !binding.ExpiresAt.IsZero() && time.Now().After(binding.ExpiresAt):
		// Expired: fall through to deletion below.
	case binding.ContinuationDisabled:
		return ""
	case responseID == "":
		// Nothing useful stored: fall through to deletion below.
	default:
		return responseID
	}
	s.openaiCompatSessionResponses.Delete(key)
	return ""
}
// bindOpenAICompatSessionResponseID records responseID for the session and
// refreshes the entry's expiry. An existing TurnState is carried over. When
// the existing entry has continuation disabled, the response ID is not
// stored: the entry is kept with an empty ResponseID and a fresh expiry.
// Blank keys or IDs are ignored.
func (s *OpenAIGatewayService) bindOpenAICompatSessionResponseID(_ context.Context, c *gin.Context, account *Account, promptCacheKey, responseID string) {
	if s == nil {
		return
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	id := strings.TrimSpace(responseID)
	if key == "" || id == "" {
		return
	}
	freshExpiry := func() time.Time {
		return time.Now().Add(s.openAIWSResponseStickyTTL())
	}
	next := openAICompatSessionResponseBinding{
		ResponseID: id,
		ExpiresAt:  freshExpiry(),
	}
	if entry, found := s.openaiCompatSessionResponses.Load(key); found {
		if prev, valid := entry.(openAICompatSessionResponseBinding); valid {
			if prev.ContinuationDisabled {
				// Keep the disabled marker alive, but never remember an ID.
				prev.ResponseID = ""
				prev.ExpiresAt = freshExpiry()
				next = prev
			} else {
				next.TurnState = prev.TurnState
			}
		}
	}
	s.openaiCompatSessionResponses.Store(key, next)
}
// deleteOpenAICompatSessionResponseID forgets the response ID bound to the
// session. The whole entry is removed when it has an unexpected type, or
// when it carries neither a TurnState nor the continuation-disabled flag;
// otherwise the entry is kept with an empty ResponseID and a fresh expiry.
func (s *OpenAIGatewayService) deleteOpenAICompatSessionResponseID(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) {
	if s == nil {
		return
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	if key == "" {
		return
	}
	entry, found := s.openaiCompatSessionResponses.Load(key)
	if !found {
		return
	}
	binding, valid := entry.(openAICompatSessionResponseBinding)
	hasOtherState := strings.TrimSpace(binding.TurnState) != "" || binding.ContinuationDisabled
	if !valid || !hasOtherState {
		s.openaiCompatSessionResponses.Delete(key)
		return
	}
	binding.ResponseID = ""
	binding.ExpiresAt = time.Now().Add(s.openAIWSResponseStickyTTL())
	s.openaiCompatSessionResponses.Store(key, binding)
}
// disableOpenAICompatSessionContinuation marks the session as
// continuation-disabled. The stored entry is replaced by one with the flag
// set, no ResponseID, a fresh expiry, and the TurnState of the previous
// entry if there was one.
func (s *OpenAIGatewayService) disableOpenAICompatSessionContinuation(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) {
	if s == nil {
		return
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	if key == "" {
		return
	}
	var turnState string
	if entry, found := s.openaiCompatSessionResponses.Load(key); found {
		if prev, valid := entry.(openAICompatSessionResponseBinding); valid {
			turnState = prev.TurnState
		}
	}
	s.openaiCompatSessionResponses.Store(key, openAICompatSessionResponseBinding{
		TurnState:            turnState,
		ContinuationDisabled: true,
		ExpiresAt:            time.Now().Add(s.openAIWSResponseStickyTTL()),
	})
}
// isOpenAICompatSessionContinuationDisabled reports whether the session's
// stored entry has the continuation-disabled flag set. Entries of an
// unexpected type and expired entries are deleted and count as not disabled.
func (s *OpenAIGatewayService) isOpenAICompatSessionContinuationDisabled(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) bool {
	if s == nil {
		return false
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	if key == "" {
		return false
	}
	entry, found := s.openaiCompatSessionResponses.Load(key)
	if !found {
		return false
	}
	binding, valid := entry.(openAICompatSessionResponseBinding)
	expired := valid && !binding.ExpiresAt.IsZero() && time.Now().After(binding.ExpiresAt)
	if !valid || expired {
		s.openaiCompatSessionResponses.Delete(key)
		return false
	}
	return binding.ContinuationDisabled
}
// getOpenAICompatSessionTurnState returns the trimmed TurnState stored for
// the session, or "" when none is available. An expired entry that carries a
// TurnState is deleted; entries without a TurnState or of an unexpected type
// are left in place.
func (s *OpenAIGatewayService) getOpenAICompatSessionTurnState(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) string {
	if s == nil {
		return ""
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	if key == "" {
		return ""
	}
	entry, found := s.openaiCompatSessionResponses.Load(key)
	if !found {
		return ""
	}
	binding, valid := entry.(openAICompatSessionResponseBinding)
	if !valid {
		return ""
	}
	state := strings.TrimSpace(binding.TurnState)
	if state == "" {
		return ""
	}
	if !binding.ExpiresAt.IsZero() && time.Now().After(binding.ExpiresAt) {
		s.openaiCompatSessionResponses.Delete(key)
		return ""
	}
	return state
}
// bindOpenAICompatSessionTurnState records the trimmed turnState for the
// session with a fresh expiry, carrying over the ResponseID and
// continuation-disabled flag of any existing entry. Blank keys or states are
// ignored.
func (s *OpenAIGatewayService) bindOpenAICompatSessionTurnState(_ context.Context, c *gin.Context, account *Account, promptCacheKey, turnState string) {
	if s == nil {
		return
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	state := strings.TrimSpace(turnState)
	if key == "" || state == "" {
		return
	}
	next := openAICompatSessionResponseBinding{TurnState: state}
	next.ExpiresAt = time.Now().Add(s.openAIWSResponseStickyTTL())
	if entry, found := s.openaiCompatSessionResponses.Load(key); found {
		if prev, valid := entry.(openAICompatSessionResponseBinding); valid {
			next.ResponseID = prev.ResponseID
			next.ContinuationDisabled = prev.ContinuationDisabled
		}
	}
	s.openaiCompatSessionResponses.Store(key, next)
}

View File

@ -0,0 +1,135 @@
package service
import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
)
// openAICompatAnthropicDigestBinding is the value stored in
// OpenAIGatewayService.openaiCompatAnthropicDigestSessions, keyed by the
// account/API-key namespace followed by a digest chain.
type openAICompatAnthropicDigestBinding struct {
// PromptCacheKey is the trimmed prompt cache key bound to the digest chain.
PromptCacheKey string
// ExpiresAt is the instant after which lookups treat the binding as stale
// and delete it; a zero value is treated as "never expires".
ExpiresAt time.Time
}
// buildOpenAICompatAnthropicDigestChain renders the request as a
// "-"-separated chain of short hashes: an optional "s:<hash>" segment for a
// non-blank, non-null system prompt, followed by one segment per message
// with non-blank content, prefixed "a:" for assistant messages and "u:" for
// every other role. A nil request yields "".
func buildOpenAICompatAnthropicDigestChain(req *apicompat.AnthropicRequest) string {
	if req == nil {
		return ""
	}
	segments := make([]string, 0, len(req.Messages)+1)
	if system := strings.TrimSpace(string(req.System)); system != "" && system != "null" {
		segments = append(segments, "s:"+shortHash(req.System))
	}
	for i := range req.Messages {
		msg := &req.Messages[i]
		if strings.TrimSpace(string(msg.Content)) == "" {
			continue
		}
		tag := "u"
		if strings.TrimSpace(msg.Role) == "assistant" {
			tag = "a"
		}
		segments = append(segments, tag+":"+shortHash(msg.Content))
	}
	return strings.Join(segments, "-")
}
// openAICompatAnthropicDigestNamespace returns the store-key prefix
// "<accountID>|<apiKeyID>|" used for digest bindings, or "" when the account
// is nil or has a non-positive ID.
func openAICompatAnthropicDigestNamespace(account *Account, cAPIKeyID int64) string {
	if account != nil && account.ID > 0 {
		return fmt.Sprintf("%d|%d|", account.ID, cAPIKeyID)
	}
	return ""
}
// findOpenAICompatAnthropicDigestPromptCacheKey looks up the prompt cache key
// bound to digestChain or to its longest stored prefix. Starting with the
// full chain, it repeatedly drops the last "-"-separated segment until a live
// binding with a non-blank key is found, returning that key together with
// the chain that matched. Stale, blank, or malformed bindings met on the way
// are deleted. It returns two empty strings when nothing matches.
func (s *OpenAIGatewayService) findOpenAICompatAnthropicDigestPromptCacheKey(account *Account, cAPIKeyID int64, digestChain string) (promptCacheKey string, matchedChain string) {
	if s == nil || digestChain == "" {
		return "", ""
	}
	ns := openAICompatAnthropicDigestNamespace(account, cAPIKeyID)
	if ns == "" {
		return "", ""
	}
	for candidate := digestChain; ; {
		storeKey := ns + candidate
		if entry, found := s.openaiCompatAnthropicDigestSessions.Load(storeKey); found {
			binding, valid := entry.(openAICompatAnthropicDigestBinding)
			live := valid && (binding.ExpiresAt.IsZero() || time.Now().Before(binding.ExpiresAt))
			if live {
				if cacheKey := strings.TrimSpace(binding.PromptCacheKey); cacheKey != "" {
					return cacheKey, candidate
				}
			}
			// Anything stored here that could not be used is dropped.
			s.openaiCompatAnthropicDigestSessions.Delete(storeKey)
		}
		cut := strings.LastIndex(candidate, "-")
		if cut < 0 {
			return "", ""
		}
		candidate = candidate[:cut]
	}
}
// bindOpenAICompatAnthropicDigestPromptCacheKey stores the trimmed prompt
// cache key under digestChain with a fresh expiry and, when oldDigestChain is
// set and differs, removes the binding stored under the old chain. Calls with
// an empty chain, a blank key, or an unusable account are ignored.
func (s *OpenAIGatewayService) bindOpenAICompatAnthropicDigestPromptCacheKey(account *Account, cAPIKeyID int64, digestChain, promptCacheKey, oldDigestChain string) {
	cacheKey := strings.TrimSpace(promptCacheKey)
	if s == nil || digestChain == "" || cacheKey == "" {
		return
	}
	ns := openAICompatAnthropicDigestNamespace(account, cAPIKeyID)
	if ns == "" {
		return
	}
	s.openaiCompatAnthropicDigestSessions.Store(ns+digestChain, openAICompatAnthropicDigestBinding{
		PromptCacheKey: cacheKey,
		ExpiresAt:      time.Now().Add(s.openAIWSResponseStickyTTL()),
	})
	if oldDigestChain == "" || oldDigestChain == digestChain {
		return
	}
	s.openaiCompatAnthropicDigestSessions.Delete(ns + oldDigestChain)
}
// promptCacheKeyFromAnthropicDigest derives a prompt cache key of the form
// "anthropic-digest-<hash>" from the digest chain, or returns "" for a blank
// chain. The hash is taken over the chain exactly as given (untrimmed).
func promptCacheKeyFromAnthropicDigest(digestChain string) string {
	if strings.TrimSpace(digestChain) != "" {
		return "anthropic-digest-" + hashSensitiveValueForLog(digestChain)
	}
	return ""
}
// promptCacheKeyFromAnthropicMetadataSession derives a prompt cache key of
// the form "anthropic-metadata-<hash>" from the request's metadata.user_id.
// The user ID is parsed with ParseMetadataUserID and the hash seed combines
// its trimmed device ID, account UUID and session ID. It returns "" when the
// metadata is absent, not valid JSON, unparseable, or has no session ID.
func promptCacheKeyFromAnthropicMetadataSession(req *apicompat.AnthropicRequest) string {
	if req == nil || len(req.Metadata) == 0 {
		return ""
	}
	var payload struct {
		UserID string `json:"user_id"`
	}
	if json.Unmarshal(req.Metadata, &payload) != nil {
		return ""
	}
	ids := ParseMetadataUserID(payload.UserID)
	if ids == nil {
		return ""
	}
	sessionID := strings.TrimSpace(ids.SessionID)
	if sessionID == "" {
		return ""
	}
	seed := "anthropic-metadata" +
		"|" + strings.TrimSpace(ids.DeviceID) +
		"|" + strings.TrimSpace(ids.AccountUUID) +
		"|" + sessionID
	return "anthropic-metadata-" + hashSensitiveValueForLog(seed)
}
// cloneAnthropicRequestForDigest returns a copy of req whose System bytes and
// Messages slice have their own backing storage. All other fields, and the
// contents of each message, are copied shallowly. A nil request yields nil.
func cloneAnthropicRequestForDigest(req *apicompat.AnthropicRequest) *apicompat.AnthropicRequest {
	if req == nil {
		return nil
	}
	clone := *req
	if n := len(req.System); n > 0 {
		system := make(json.RawMessage, n)
		copy(system, req.System)
		clone.System = system
	}
	if n := len(req.Messages); n > 0 {
		messages := make([]apicompat.AnthropicMessage, n)
		copy(messages, req.Messages)
		clone.Messages = messages
	}
	return &clone
}

View File

@ -0,0 +1,90 @@
package service
import (
"encoding/json"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
)
// openAICompatAnthropicReplayMaxTailMessages is the number of trailing
// messages applyAnthropicCompatFullReplayGuard keeps before the cut point is
// moved back to keep tool_use/tool_result pairs together.
const openAICompatAnthropicReplayMaxTailMessages = 12
// applyAnthropicCompatFullReplayGuard trims req.Messages to roughly its last
// openAICompatAnthropicReplayMaxTailMessages entries. The cut point is moved
// earlier by expandAnthropicCompatTrimBoundary so that tool_use/tool_result
// pairs are not split. It returns true when messages were dropped; when the
// request is short enough or the adjusted cut point reaches the start,
// nothing changes and it returns false.
func applyAnthropicCompatFullReplayGuard(req *apicompat.AnthropicRequest) bool {
	if req == nil {
		return false
	}
	total := len(req.Messages)
	if total <= openAICompatAnthropicReplayMaxTailMessages {
		return false
	}
	cut := expandAnthropicCompatTrimBoundary(req.Messages, total-openAICompatAnthropicReplayMaxTailMessages)
	if cut <= 0 {
		return false
	}
	// Copy the tail so the dropped prefix is no longer referenced.
	var kept []apicompat.AnthropicMessage
	kept = append(kept, req.Messages[cut:]...)
	req.Messages = kept
	return true
}
// expandAnthropicCompatTrimBoundary moves the trim boundary start towards the
// beginning of messages until no kept message (index >= start) refers to a
// tool call whose counterpart lies before the boundary. For every
// tool_result in the kept range the boundary is pulled back to the first
// message containing the matching tool_use, and vice versa; the process
// repeats until the boundary is stable. A start outside (0, len(messages))
// is returned unchanged.
//
// Each message's content is decoded exactly once up front; the fixpoint loop
// then works on the cached IDs instead of re-parsing JSON on every pass.
func expandAnthropicCompatTrimBoundary(messages []apicompat.AnthropicMessage, start int) int {
	if start <= 0 || start >= len(messages) {
		return start
	}

	type toolIDs struct {
		uses    []string
		results []string
	}
	parsed := make([]toolIDs, len(messages))
	// First message index at which each tool ID appears as a use / a result.
	toolUseIndex := make(map[string]int)
	toolResultIndex := make(map[string]int)
	for i, msg := range messages {
		uses, results := anthropicCompatMessageToolIDs(msg)
		parsed[i] = toolIDs{uses: uses, results: results}
		for _, id := range uses {
			if _, exists := toolUseIndex[id]; !exists {
				toolUseIndex[id] = i
			}
		}
		for _, id := range results {
			if _, exists := toolResultIndex[id]; !exists {
				toolResultIndex[id] = i
			}
		}
	}

	for {
		next := start
		for i := start; i < len(messages); i++ {
			for _, id := range parsed[i].results {
				if useIdx, ok := toolUseIndex[id]; ok && useIdx < next {
					next = useIdx
				}
			}
			for _, id := range parsed[i].uses {
				if resultIdx, ok := toolResultIndex[id]; ok && resultIdx < next {
					next = resultIdx
				}
			}
		}
		if next == start {
			return start
		}
		// Newly included messages may reference even earlier ones; go again.
		start = next
	}
}
// anthropicCompatMessageToolIDs decodes the message content as a list of
// content blocks and returns the non-empty IDs of its tool_use blocks and the
// non-empty tool_use_id values of its tool_result blocks, in order. Content
// that does not decode as a block list (for example a plain string) yields
// two nil slices.
func anthropicCompatMessageToolIDs(msg apicompat.AnthropicMessage) ([]string, []string) {
	var blocks []apicompat.AnthropicContentBlock
	if json.Unmarshal(msg.Content, &blocks) != nil {
		return nil, nil
	}
	toolUses := make([]string, 0, 1)
	toolResults := make([]string, 0, 1)
	for i := range blocks {
		block := &blocks[i]
		if block.Type == "tool_use" {
			if block.ID != "" {
				toolUses = append(toolUses, block.ID)
			}
		} else if block.Type == "tool_result" {
			if block.ToolUseID != "" {
				toolResults = append(toolResults, block.ToolUseID)
			}
		}
	}
	return toolUses, toolResults
}

View File

@ -0,0 +1,58 @@
package service
import (
"encoding/json"
"fmt"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/stretchr/testify/require"
)
func TestApplyAnthropicCompatFullReplayGuard_TrimsOldMessages(t *testing.T) {
t.Parallel()
req := &apicompat.AnthropicRequest{Messages: make([]apicompat.AnthropicMessage, 0, openAICompatAnthropicReplayMaxTailMessages+3)}
for i := 0; i < openAICompatAnthropicReplayMaxTailMessages+3; i++ {
req.Messages = append(req.Messages, apicompat.AnthropicMessage{
Role: "user",
Content: json.RawMessage(fmt.Sprintf(`"message-%02d"`, i)),
})
}
trimmed := applyAnthropicCompatFullReplayGuard(req)
require.True(t, trimmed)
require.Len(t, req.Messages, openAICompatAnthropicReplayMaxTailMessages)
require.JSONEq(t, `"message-03"`, string(req.Messages[0].Content))
require.JSONEq(t, `"message-14"`, string(req.Messages[len(req.Messages)-1].Content))
}
func TestApplyAnthropicCompatFullReplayGuard_KeepsToolBoundaryIntact(t *testing.T) {
t.Parallel()
req := &apicompat.AnthropicRequest{Messages: make([]apicompat.AnthropicMessage, 0, openAICompatAnthropicReplayMaxTailMessages+3)}
for i := 0; i < openAICompatAnthropicReplayMaxTailMessages+3; i++ {
role := "user"
content := json.RawMessage(fmt.Sprintf(`"message-%02d"`, i))
if i == 1 {
role = "assistant"
content = json.RawMessage(`[{"type":"tool_use","id":"toolu_keep","name":"Read","input":{"file_path":"main.go"}}]`)
}
if i == 3 {
content = json.RawMessage(`[{"type":"tool_result","tool_use_id":"toolu_keep","content":"ok"}]`)
}
req.Messages = append(req.Messages, apicompat.AnthropicMessage{
Role: role,
Content: content,
})
}
trimmed := applyAnthropicCompatFullReplayGuard(req)
require.True(t, trimmed)
require.Len(t, req.Messages, openAICompatAnthropicReplayMaxTailMessages+2)
require.Equal(t, "assistant", req.Messages[0].Role)
require.Contains(t, string(req.Messages[0].Content), `"toolu_keep"`)
require.Contains(t, string(req.Messages[2].Content), `"tool_result"`)
}

View File

@ -0,0 +1,121 @@
package service
import (
"encoding/json"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
)
const (
// openAICompatClaudeCodeTodoGuardMarker is the opening tag of the injected
// guard text. Its presence in a request is used to detect that the guard has
// already been added.
openAICompatClaudeCodeTodoGuardMarker = "<sub2api-claude-code-todo-guard>"
// openAICompatClaudeCodeTodoGuardText is the full developer instruction that
// appendOpenAICompatClaudeCodeTodoGuard and
// appendOpenAICompatClaudeCodeTodoGuardToRequestBody insert into the request
// input. It starts with the marker and ends with the matching closing tag.
openAICompatClaudeCodeTodoGuardText  = openAICompatClaudeCodeTodoGuardMarker + "\nWhen using Claude Code todo or task tracking tools, keep the visible task list consistent. Do not send final or summary text while any item remains in_progress. Before finishing, asking the user to choose, or reporting a blocker, update the todo list so completed work is completed and deferred work is pending/open; leave an item in_progress only when active work will continue in the same turn.\n</sub2api-claude-code-todo-guard>"
)
// appendOpenAICompatClaudeCodeTodoGuard inserts the Claude Code todo guard as
// a developer message into req.Input, directly after any leading run of
// developer messages. It returns true when the input was rewritten. Nothing
// is changed (and false is returned) when the input is missing, cannot be
// decoded as an item list, is empty, already contains the guard marker, or
// cannot be re-encoded.
func appendOpenAICompatClaudeCodeTodoGuard(req *apicompat.ResponsesRequest) bool {
	if req == nil || len(req.Input) == 0 {
		return false
	}
	var existing []apicompat.ResponsesInputItem
	if json.Unmarshal(req.Input, &existing) != nil {
		return false
	}
	if len(existing) == 0 {
		return false
	}
	if responsesInputItemsContainText(existing, openAICompatClaudeCodeTodoGuardMarker) {
		return false
	}

	guardContent, err := json.Marshal([]apicompat.ResponsesContentPart{{
		Type: "input_text",
		Text: openAICompatClaudeCodeTodoGuardText,
	}})
	if err != nil {
		return false
	}

	// Skip the leading developer messages so the guard lands right after them.
	pos := 0
	for ; pos < len(existing); pos++ {
		if existing[pos].Type != "message" || existing[pos].Role != "developer" {
			break
		}
	}

	merged := make([]apicompat.ResponsesInputItem, 0, len(existing)+1)
	merged = append(merged, existing[:pos]...)
	merged = append(merged, apicompat.ResponsesInputItem{
		Type:    "message",
		Role:    "developer",
		Content: guardContent,
	})
	merged = append(merged, existing[pos:]...)

	encoded, err := json.Marshal(merged)
	if err != nil {
		return false
	}
	req.Input = encoded
	return true
}
// appendOpenAICompatClaudeCodeTodoGuardToRequestBody is the decoded-map
// variant of appendOpenAICompatClaudeCodeTodoGuard: it inserts the guard as a
// developer message into reqBody["input"], directly after any leading run of
// developer messages, and returns true when it did so. It returns false
// without changes when "input" is not a non-empty list or already contains
// the guard marker.
func appendOpenAICompatClaudeCodeTodoGuardToRequestBody(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	items, isList := reqBody["input"].([]any)
	switch {
	case !isList, len(items) == 0:
		return false
	case inputContainsText(items, openAICompatClaudeCodeTodoGuardMarker):
		return false
	}

	isDeveloperMessage := func(raw any) bool {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			return false
		}
		return strings.TrimSpace(firstNonEmptyString(entry["type"])) == "message" &&
			strings.TrimSpace(firstNonEmptyString(entry["role"])) == "developer"
	}
	pos := 0
	for pos < len(items) && isDeveloperMessage(items[pos]) {
		pos++
	}

	// Grow by one, shift the tail right, and drop the guard into the gap.
	items = append(items, nil)
	copy(items[pos+1:], items[pos:])
	items[pos] = map[string]any{
		"type": "message",
		"role": "developer",
		"content": []any{
			map[string]any{
				"type": "input_text",
				"text": openAICompatClaudeCodeTodoGuardText,
			},
		},
	}
	reqBody["input"] = items
	return true
}
// responsesInputItemsContainText reports whether the raw JSON content of any
// item contains needle (after trimming surrounding whitespace from needle).
// A blank needle never matches.
//
// Item content is raw JSON, and encoding/json writes '<', '>' and '&' inside
// strings as \u003c, \u003e and \u0026 by default. Content that has been
// through json.Marshal therefore holds such a needle in escaped form, so both
// the literal and the escaped spelling are searched for.
func responsesInputItemsContainText(items []apicompat.ResponsesInputItem, needle string) bool {
	needle = strings.TrimSpace(needle)
	if needle == "" {
		return false
	}
	escaped := strings.NewReplacer("<", `\u003c`, ">", `\u003e`, "&", `\u0026`).Replace(needle)
	for _, item := range items {
		raw := string(item.Content)
		if strings.Contains(raw, needle) {
			return true
		}
		if escaped != needle && strings.Contains(raw, escaped) {
			return true
		}
	}
	return false
}
// inputContainsText reports whether the JSON serialisation of any element of
// input contains needle (after trimming surrounding whitespace from needle).
// A blank needle never matches, and elements that cannot be serialised are
// skipped.
//
// Elements are encoded with HTML escaping disabled. json.Marshal would write
// '<', '>' and '&' as \u003c, \u003e and \u0026, which makes a needle
// containing those characters impossible to find.
func inputContainsText(input []any, needle string) bool {
	needle = strings.TrimSpace(needle)
	if needle == "" {
		return false
	}
	var buf strings.Builder
	enc := json.NewEncoder(&buf)
	enc.SetEscapeHTML(false)
	for _, item := range input {
		buf.Reset()
		if err := enc.Encode(item); err != nil {
			// Best effort: an element that cannot be encoded cannot match.
			continue
		}
		if strings.Contains(buf.String(), needle) {
			return true
		}
	}
	return false
}

View File

@ -0,0 +1,137 @@
package service
import "strings"
// lastOpenAIModelSegment returns the part of model after its last "/" with
// surrounding whitespace removed, or the trimmed model itself when it has no
// "/". A blank model yields "".
func lastOpenAIModelSegment(model string) string {
	name := strings.TrimSpace(model)
	if slash := strings.LastIndex(name, "/"); slash >= 0 {
		name = strings.TrimSpace(name[slash+1:])
	}
	return name
}
// canonicalizeOpenAIModelAliasSpelling normalises loose spellings of GPT and
// Codex model names. It takes the last "/"-separated segment, lower-cases it,
// turns underscores and whitespace runs into single hyphens, collapses
// repeated hyphens, rewrites a leading "gpt5" to "gpt-5", and inserts the
// missing hyphens in a few known compact spellings (for example
// "gpt-5.4mini" becomes "gpt-5.4-mini"). It returns "" for a blank model or
// for a name that neither starts with "gpt-" nor contains "codex".
func canonicalizeOpenAIModelAliasSpelling(model string) string {
	name := strings.ToLower(lastOpenAIModelSegment(model))
	if name == "" {
		return ""
	}

	name = strings.ReplaceAll(name, "_", "-")
	name = strings.Join(strings.Fields(name), "-")
	for {
		collapsed := strings.ReplaceAll(name, "--", "-")
		if collapsed == name {
			break
		}
		name = collapsed
	}
	if strings.HasPrefix(name, "gpt5") {
		name = "gpt-5" + name[len("gpt5"):]
	}

	isGPT := strings.HasPrefix(name, "gpt-")
	if !isGPT && !strings.Contains(name, "codex") {
		return ""
	}

	// Applied one after another, in this order.
	fixes := [][2]string{
		{"gpt-5.4mini", "gpt-5.4-mini"},
		{"gpt-5.4nano", "gpt-5.4-nano"},
		{"gpt-5.3-codexspark", "gpt-5.3-codex-spark"},
		{"gpt-5.3codexspark", "gpt-5.3-codex-spark"},
		{"gpt-5.3codex", "gpt-5.3-codex"},
	}
	for _, fix := range fixes {
		name = strings.ReplaceAll(name, fix[0], fix[1])
	}
	return name
}
// normalizeKnownOpenAICodexModel maps a model name to a known model ID, or
// returns "" when it cannot. After canonicalising the spelling it first asks
// getNormalizedCodexModel for a match (also trying the name without a
// trailing "-openai-compact"), and otherwise falls back to the first
// matching substring rule in the ordered table below.
func normalizeKnownOpenAICodexModel(model string) string {
	name := canonicalizeOpenAIModelAliasSpelling(model)
	if name == "" {
		return ""
	}
	if mapped := getNormalizedCodexModel(name); mapped != "" {
		return mapped
	}
	const compactSuffix = "-openai-compact"
	if strings.HasSuffix(name, compactSuffix) {
		base := strings.TrimSuffix(name, compactSuffix)
		if mapped := getNormalizedCodexModel(base); mapped != "" {
			return mapped
		}
	}

	// Order matters: more specific substrings come before the ones they contain.
	rules := [...]struct {
		contains string
		result   string
	}{
		{"gpt-5.5", "gpt-5.5"},
		{"gpt-5.4-mini", "gpt-5.4-mini"},
		{"gpt-5.4-nano", "gpt-5.4-nano"},
		{"gpt-5.4", "gpt-5.4"},
		{"gpt-5.2", "gpt-5.2"},
		{"gpt-5.3-codex-spark", "gpt-5.3-codex-spark"},
		{"gpt-5.3-codex", "gpt-5.3-codex"},
		{"gpt-5.3", "gpt-5.3-codex"},
		{"codex", "gpt-5.3-codex"},
		{"gpt-5", "gpt-5.4"},
	}
	for _, rule := range rules {
		if strings.Contains(name, rule.contains) {
			return rule.result
		}
	}
	return ""
}
// appendUsageBillingModelCandidate appends up to three spellings of model to
// candidates: the trimmed name, its canonical alias spelling, and its
// normalised known model ID, in that order. Blank values and values already
// recorded in seen (compared case-insensitively) are skipped; every appended
// value is added to seen. A blank model leaves candidates unchanged.
func appendUsageBillingModelCandidate(candidates []string, seen map[string]struct{}, model string) []string {
	trimmed := strings.TrimSpace(model)
	if trimmed == "" {
		return candidates
	}
	spellings := [...]string{
		trimmed,
		canonicalizeOpenAIModelAliasSpelling(trimmed),
		normalizeKnownOpenAICodexModel(trimmed),
	}
	for _, spelling := range spellings {
		candidate := strings.TrimSpace(spelling)
		if candidate == "" {
			continue
		}
		lookup := strings.ToLower(candidate)
		if _, dup := seen[lookup]; dup {
			continue
		}
		seen[lookup] = struct{}{}
		candidates = append(candidates, candidate)
	}
	return candidates
}
// usageBillingModelCandidates returns the de-duplicated list of model name
// candidates for primary followed by each alternate, using
// appendUsageBillingModelCandidate to expand every name. The result is nil
// when every input is blank.
func usageBillingModelCandidates(primary string, alternates ...string) []string {
	seen := make(map[string]struct{}, 1+len(alternates))
	var out []string
	out = appendUsageBillingModelCandidate(out, seen, primary)
	for i := range alternates {
		out = appendUsageBillingModelCandidate(out, seen, alternates[i])
	}
	return out
}
// firstUsageBillingModel returns the first candidate that is non-blank after
// trimming surrounding whitespace, or "" when there is none.
func firstUsageBillingModel(candidates []string) string {
	for i := range candidates {
		trimmed := strings.TrimSpace(candidates[i])
		if trimmed == "" {
			continue
		}
		return trimmed
	}
	return ""
}

View File

@ -2,44 +2,24 @@ package service
import "strings" import "strings"
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible // resolveOpenAIForwardModel 解析 OpenAI 兼容转发使用的模型。
// forwarding. Group-level default mapping only applies when the account itself // defaultMappedModel 只服务于 /v1/messages 的 Claude 系列显式调度映射,
// did not match any explicit model_mapping rule. // 不作为普通 OpenAI 请求的未知模型兜底。
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string { func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
if account == nil { if account == nil {
if defaultMappedModel != "" { if defaultMappedModel != "" && claudeMessagesDispatchFamily(requestedModel) != "" {
return defaultMappedModel return defaultMappedModel
} }
return requestedModel return requestedModel
} }
mappedModel, matched := account.ResolveMappedModel(requestedModel) mappedModel, matched := account.ResolveMappedModel(requestedModel)
if !matched && defaultMappedModel != "" && !isExplicitCodexModel(requestedModel) { if !matched && defaultMappedModel != "" && claudeMessagesDispatchFamily(requestedModel) != "" {
return defaultMappedModel return defaultMappedModel
} }
return mappedModel return mappedModel
} }
func isExplicitCodexModel(model string) bool {
model = strings.TrimSpace(model)
if model == "" {
return false
}
if strings.Contains(model, "/") {
parts := strings.Split(model, "/")
model = parts[len(parts)-1]
}
model = strings.ToLower(strings.TrimSpace(model))
if getNormalizedCodexModel(model) != "" {
return true
}
if strings.HasSuffix(model, "-openai-compact") {
base := strings.TrimSuffix(model, "-openai-compact")
return getNormalizedCodexModel(base) != ""
}
return false
}
// resolveOpenAICompactForwardModel determines the compact-only upstream model // resolveOpenAICompactForwardModel determines the compact-only upstream model
// for /responses/compact requests. It never affects normal /responses traffic. // for /responses/compact requests. It never affects normal /responses traffic.
// When no compact-specific mapping matches, the input model is returned as-is. // When no compact-specific mapping matches, the input model is returned as-is.

View File

@ -11,7 +11,7 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
expectedModel string expectedModel string
}{ }{
{ {
name: "falls back to group default when account has no mapping", name: "uses messages dispatch default for claude model",
account: &Account{ account: &Account{
Credentials: map[string]any{}, Credentials: map[string]any{},
}, },
@ -19,6 +19,15 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
defaultMappedModel: "gpt-4o-mini", defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-4o-mini", expectedModel: "gpt-4o-mini",
}, },
{
name: "does not fall back to group default for invalid gpt model",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt6",
defaultMappedModel: "gpt-5.4",
expectedModel: "gpt6",
},
{ {
name: "preserves explicit gpt-5.4 instead of group default", name: "preserves explicit gpt-5.4 instead of group default",
account: &Account{ account: &Account{
@ -85,6 +94,15 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
defaultMappedModel: "gpt-5.4", defaultMappedModel: "gpt-5.4",
expectedModel: "gpt-5.5", expectedModel: "gpt-5.5",
}, },
{
name: "preserves compact-spelled gpt5.5 instead of group default",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt5.5",
defaultMappedModel: "gpt-5.4",
expectedModel: "gpt5.5",
},
{ {
name: "preserves openai namespaced gpt-5.5 instead of group default", name: "preserves openai namespaced gpt-5.5 instead of group default",
account: &Account{ account: &Account{
@ -119,14 +137,14 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt54(t *
Credentials: map[string]any{}, Credentials: map[string]any{},
} }
withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "")) withoutDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "")
if withoutDefault != "gpt-5.4" { if withoutDefault != "claude-opus-4-6" {
t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.4") t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", withoutDefault, "claude-opus-4-6")
} }
withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")) withDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")
if withDefault != "gpt-5.4" { if withDefault != "gpt-5.4" {
t.Fatalf("normalizeCodexModel(...) = %q, want %q", withDefault, "gpt-5.4") t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", withDefault, "gpt-5.4")
} }
} }
@ -205,6 +223,10 @@ func TestNormalizeCodexModel(t *testing.T) {
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark", "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
"gpt-5.3": "gpt-5.3-codex", "gpt-5.3": "gpt-5.3-codex",
"gpt-image-2": "gpt-image-2", "gpt-image-2": "gpt-image-2",
"gpt-5.4-nano": "gpt-5.4-nano",
"gpt-5.4-nano-high": "gpt-5.4-nano",
"gpt6": "gpt6",
"claude-opus-4-6": "claude-opus-4-6",
} }
for input, expected := range cases { for input, expected := range cases {
@ -222,9 +244,21 @@ func TestNormalizeOpenAIModelForUpstream(t *testing.T) {
want string want string
}{ }{
{ {
name: "oauth keeps codex normalization behavior", name: "oauth preserves unknown non codex model",
account: &Account{Type: AccountTypeOAuth}, account: &Account{Type: AccountTypeOAuth},
model: "gemini-3-flash-preview", model: "gemini-3-flash-preview",
want: "gemini-3-flash-preview",
},
{
name: "oauth preserves invalid gpt model",
account: &Account{Type: AccountTypeOAuth},
model: "gpt6",
want: "gpt6",
},
{
name: "oauth normalizes known codex alias",
account: &Account{Type: AccountTypeOAuth},
model: "gpt-5.4-high",
want: "gpt-5.4", want: "gpt-5.4",
}, },
{ {

View File

@ -25,9 +25,12 @@ func f64p(v float64) *float64 { return &v }
type httpUpstreamRecorder struct { type httpUpstreamRecorder struct {
lastReq *http.Request lastReq *http.Request
lastBody []byte lastBody []byte
requests []*http.Request
bodies [][]byte
resp *http.Response resp *http.Response
err error responses []*http.Response
err error
} }
func (u *httpUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { func (u *httpUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
@ -35,12 +38,19 @@ func (u *httpUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID
if req != nil && req.Body != nil { if req != nil && req.Body != nil {
b, _ := io.ReadAll(req.Body) b, _ := io.ReadAll(req.Body)
u.lastBody = b u.lastBody = b
u.bodies = append(u.bodies, append([]byte(nil), b...))
_ = req.Body.Close() _ = req.Body.Close()
req.Body = io.NopCloser(bytes.NewReader(b)) req.Body = io.NopCloser(bytes.NewReader(b))
} }
u.requests = append(u.requests, req)
if u.err != nil { if u.err != nil {
return nil, u.err return nil, u.err
} }
if len(u.responses) > 0 {
resp := u.responses[0]
u.responses = u.responses[1:]
return resp, nil
}
return u.resp, nil return u.resp, nil
} }
@ -48,6 +58,93 @@ func (u *httpUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, acc
return u.Do(req, proxyURL, accountID, accountConcurrency) return u.Do(req, proxyURL, accountID, accountConcurrency)
} }
// TestOpenAIGatewayService_ResponsesUnknownModelDoesNotFallbackToGPT54 checks
// that a /v1/responses request for the unknown model "gpt6" on an OAuth account
// reaches the upstream stub with the model field still set to "gpt6", not
// rewritten to "gpt-5.4". The stub answers 400, so Forward is expected to fail.
func TestOpenAIGatewayService_ResponsesUnknownModelDoesNotFallbackToGPT54(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
originalBody := []byte(`{"model":"gpt6","stream":false,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(originalBody))
c.Request.Header.Set("Content-Type", "application/json")
// Stub upstream: always replies with a 400 "model not found" error body.
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusBadRequest,
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_unknown_model"}},
Body: io.NopCloser(strings.NewReader(`{"error":{"type":"invalid_request_error","message":"model not found"}}`)),
}}
svc := &OpenAIGatewayService{
cfg: &config.Config{},
httpUpstream: upstream,
}
// OAuth account with no model mapping in its credentials.
account := &Account{
ID: 123,
Name: "acc",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
Status: StatusActive,
Schedulable: true,
}
result, err := svc.Forward(context.Background(), c, account, originalBody)
require.Error(t, err)
require.Nil(t, result)
require.NotNil(t, upstream.lastReq)
require.Equal(t, "https://chatgpt.com/backend-api/codex/responses", upstream.lastReq.URL.String())
// The recorded upstream body must carry the requested model verbatim.
require.Equal(t, "gpt6", gjson.GetBytes(upstream.lastBody, "model").String())
require.NotEqual(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
require.True(t, rec.Code >= http.StatusBadRequest)
}
// TestOpenAIGatewayService_OAuthMessagesBridgeDoesNotInjectDefaultInstructions
// forwards a request whose input begins with a developer message containing
// the "<sub2api-claude-code-todo-guard>" marker and asserts on what the
// upstream stub receives: an empty "instructions" value, no "prompt_cache_key"
// field, a non-empty Session_Id header, and empty Conversation_Id, OpenAI-Beta
// and originator headers. The stub answers 400, so Forward is expected to fail.
func TestOpenAIGatewayService_OAuthMessagesBridgeDoesNotInjectDefaultInstructions(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
originalBody := []byte(`{"model":"gpt-5.5","stream":true,"prompt_cache_key":"anthropic-metadata-session-1","input":[{"type":"message","role":"developer","content":[{"type":"input_text","text":"<sub2api-claude-code-todo-guard>"}]},{"type":"message","role":"user","content":"hello"}]}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(originalBody))
c.Request.Header.Set("Content-Type", "application/json")
// Stub upstream: replies 400 so the forward stops right after the request is recorded.
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusBadRequest,
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_bridge"}},
Body: io.NopCloser(strings.NewReader(`{"error":{"type":"invalid_request_error","message":"bridge stop"}}`)),
}}
svc := &OpenAIGatewayService{
cfg: &config.Config{},
httpUpstream: upstream,
}
account := &Account{
ID: 123,
Name: "acc",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
Status: StatusActive,
Schedulable: true,
}
result, err := svc.Forward(context.Background(), c, account, originalBody)
require.Error(t, err)
require.Nil(t, result)
require.NotNil(t, upstream.lastReq)
// Body expectations: no instructions injected, prompt_cache_key absent.
require.Equal(t, "", gjson.GetBytes(upstream.lastBody, "instructions").String())
require.False(t, gjson.GetBytes(upstream.lastBody, "prompt_cache_key").Exists())
// Header expectations on the upstream request.
require.NotEmpty(t, upstream.lastReq.Header.Get("Session_Id"))
require.Empty(t, upstream.lastReq.Header.Get("Conversation_Id"))
require.Empty(t, upstream.lastReq.Header.Get("OpenAI-Beta"))
require.Empty(t, upstream.lastReq.Header.Get("originator"))
}
type openAIPassthroughFailoverRepo struct { type openAIPassthroughFailoverRepo struct {
stubOpenAIAccountRepo stubOpenAIAccountRepo
rateLimitCalls []time.Time rateLimitCalls []time.Time
@ -307,6 +404,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_CompactUsesJSONAndKeepsNonStreami
require.Contains(t, rec.Body.String(), `"id":"cmp_123"`) require.Contains(t, rec.Body.String(), `"id":"cmp_123"`)
} }
// TestOpenAIGatewayService_OAuthPassthrough_UpstreamRequestIgnoresClientCancel
// runs Forward for an account with "openai_passthrough" set to true, using a
// client request context that has already been cancelled. It asserts that
// Forward still succeeds and that the context of the request handed to the
// upstream stub reports no error.
func TestOpenAIGatewayService_OAuthPassthrough_UpstreamRequestIgnoresClientCancel(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
reqCtx, cancel := context.WithCancel(context.Background())
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
// Cancel before Forward is called so the client context is already done.
cancel()
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
// Stub upstream: a short SSE stream with a response.completed event and the [DONE] marker.
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_passthrough_ctx"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`data: {"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1}}}`,
"",
"data: [DONE]",
"",
}, "\n"))),
}}
svc := &OpenAIGatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
httpUpstream: upstream,
}
account := &Account{
ID: 123,
Name: "acc",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
Extra: map[string]any{"openai_passthrough": true, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff},
Status: StatusActive,
Schedulable: true,
RateMultiplier: f64p(1),
}
result, err := svc.Forward(reqCtx, c, account, originalBody)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, upstream.lastReq)
// The upstream request's context must not report the client's cancellation.
require.NoError(t, upstream.lastReq.Context().Err())
}
func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) { func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
logSink, restore := captureStructuredLog(t) logSink, restore := captureStructuredLog(t)
@ -405,6 +548,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *te
require.Contains(t, string(upstream.lastBody), `"stream":true`) require.Contains(t, string(upstream.lastBody), `"stream":true`)
} }
// TestOpenAIGatewayService_OAuthLegacy_UpstreamRequestIgnoresClientCancel is
// the counterpart of the passthrough cancellation test for an account with
// "openai_passthrough" set to false. With an already-cancelled client context
// it asserts that Forward succeeds and that the context of the request handed
// to the upstream stub reports no error.
func TestOpenAIGatewayService_OAuthLegacy_UpstreamRequestIgnoresClientCancel(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
reqCtx, cancel := context.WithCancel(context.Background())
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
// Cancel before Forward is called so the client context is already done.
cancel()
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
// Stub upstream: replies with an SSE body even though the client asked for stream=false.
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_legacy_ctx"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`,
"",
"data: [DONE]",
"",
}, "\n"))),
}}
svc := &OpenAIGatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
httpUpstream: upstream,
}
account := &Account{
ID: 123,
Name: "acc",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
Extra: map[string]any{"openai_passthrough": false, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff},
Status: StatusActive,
Schedulable: true,
RateMultiplier: f64p(1),
}
result, err := svc.Forward(reqCtx, c, account, originalBody)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, upstream.lastReq)
// The upstream request's context must not report the client's cancellation.
require.NoError(t, upstream.lastReq.Context().Err())
}
func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) { func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)

View File

@ -0,0 +1,70 @@
package service
import (
"strings"
"github.com/tidwall/gjson"
)
// openAISSEDataAccumulator buffers the data payloads of the SSE event that is
// currently being read. AddLine appends to the buffer and Flush emits the
// buffered payloads and resets it.
type openAISSEDataAccumulator struct {
// lines holds the data payloads collected since the last flush.
lines []string
}
// AddLine feeds one raw line of an SSE stream into the accumulator.
//
// Trailing CR/LF characters are stripped first. A line that
// extractOpenAISSEDataLine accepts as a data line is buffered; a line that is
// blank after trimming ends the current event and flushes the buffer to fn.
// Any other line is ignored. A nil fn makes the call a no-op.
func (a *openAISSEDataAccumulator) AddLine(line string, fn func([]byte)) {
	if fn == nil {
		return
	}
	content := strings.TrimRight(line, "\r\n")
	payload, isData := extractOpenAISSEDataLine(content)
	switch {
	case isData:
		a.lines = append(a.lines, payload)
	case strings.TrimSpace(content) == "":
		// Blank line: event boundary.
		a.Flush(fn)
	}
}
// Flush emits the buffered data payloads through fn and empties the buffer.
// Nothing happens (and the buffer is left untouched) when fn is nil or no
// payloads are buffered.
func (a *openAISSEDataAccumulator) Flush(fn func([]byte)) {
	if fn == nil {
		return
	}
	if len(a.lines) == 0 {
		return
	}
	emitOpenAISSEDataPayloads(a.lines, fn)
	// Truncate in place so the backing array is reused by the next event.
	a.lines = a.lines[:0]
}
// forEachOpenAISSEDataPayload walks an SSE body line by line (split on "\n")
// and invokes fn for each data payload the accumulator emits. A final flush
// covers a last event that is not followed by a blank line. The call is a
// no-op when fn is nil or body contains only whitespace.
func forEachOpenAISSEDataPayload(body string, fn func([]byte)) {
	if fn == nil {
		return
	}
	if strings.TrimSpace(body) == "" {
		return
	}
	var acc openAISSEDataAccumulator
	for remaining, more := body, true; more; {
		var line string
		line, remaining, more = strings.Cut(remaining, "\n")
		acc.AddLine(line, fn)
	}
	acc.Flush(fn)
}
// emitOpenAISSEDataPayloads delivers the data lines of one SSE event to fn.
//
// A single line is emitted as is. Several lines are first joined with "\n";
// if the joined text is valid JSON it is emitted as one payload, otherwise
// each line is emitted as its own payload. The call is a no-op when fn is nil
// or lines is empty.
func emitOpenAISSEDataPayloads(lines []string, fn func([]byte)) {
	if fn == nil {
		return
	}
	switch len(lines) {
	case 0:
		return
	case 1:
		emitOpenAISSEDataPayload(lines[0], fn)
		return
	}
	if combined := strings.Join(lines, "\n"); gjson.Valid(combined) {
		emitOpenAISSEDataPayload(combined, fn)
		return
	}
	// The joined text is not one JSON document: treat each line independently.
	for i := range lines {
		emitOpenAISSEDataPayload(lines[i], fn)
	}
}
// emitOpenAISSEDataPayload passes a single whitespace-trimmed payload to fn.
// Payloads that are empty after trimming, and the literal "[DONE]" marker,
// are dropped without calling fn.
func emitOpenAISSEDataPayload(data string, fn func([]byte)) {
	switch payload := strings.TrimSpace(data); payload {
	case "", "[DONE]":
		return
	default:
		fn([]byte(payload))
	}
}

View File

@ -219,8 +219,11 @@ func (e *OpenAIWSClientCloseError) Reason() string {
// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。 // OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。
type OpenAIWSIngressHooks struct { type OpenAIWSIngressHooks struct {
BeforeTurn func(turn int) error // InitialRequestModel 是首帧渠道映射前的请求模型,只用于 usage metadata
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error) // 的 reasoning effort 后缀推导,禁止用于上游请求或计费模型。
InitialRequestModel string
BeforeTurn func(turn int) error
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
} }
func normalizeOpenAIWSLogValue(value string) string { func normalizeOpenAIWSLogValue(value string) string {
@ -1987,6 +1990,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
} }
usage := &OpenAIUsage{} usage := &OpenAIUsage{}
imageCounter := newOpenAIImageOutputCounter()
var firstTokenMs *int var firstTokenMs *int
responseID := "" responseID := ""
var finalResponse []byte var finalResponse []byte
@ -2168,6 +2172,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
if openAIWSEventShouldParseUsage(eventType) { if openAIWSEventShouldParseUsage(eventType) {
parseOpenAIWSResponseUsageFromCompletedEvent(message, usage) parseOpenAIWSResponseUsageFromCompletedEvent(message, usage)
} }
imageCounter.AddSSEData(message)
if eventType == "error" { if eventType == "error" {
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
@ -2340,6 +2345,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
Usage: *usage, Usage: *usage,
Model: originalModel, Model: originalModel,
UpstreamModel: mappedModel, UpstreamModel: mappedModel,
ImageCount: imageCounter.Count(),
ServiceTier: extractOpenAIServiceTier(reqBody), ServiceTier: extractOpenAIServiceTier(reqBody),
ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
Stream: reqStream, Stream: reqStream,
@ -2446,6 +2452,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
promptCacheKey string promptCacheKey string
previousResponseID string previousResponseID string
originalModel string originalModel string
imageBillingModel string
imageSizeTier string
payloadBytes int payloadBytes int
} }
@ -2543,6 +2551,19 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
} }
normalized = next normalized = next
} }
imageIntent := IsImageGenerationIntent(openAIResponsesEndpoint, originalModel, normalized)
if imageIntent && !GroupAllowsImageGeneration(apiKeyGroup(getAPIKeyFromContext(c))) {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, ImageGenerationPermissionMessage(), nil)
}
imageBillingModel := ""
imageSizeTier := ""
if imageIntent {
var imageCfgErr error
imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(normalized, originalModel)
if imageCfgErr != nil {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, imageCfgErr.Error(), imageCfgErr)
}
}
// Apply OpenAI Fast Policy on the response.create frame using the same // Apply OpenAI Fast Policy on the response.create frame using the same
// evaluator/normalize/scope rules as the HTTP entrypoints. This is the // evaluator/normalize/scope rules as the HTTP entrypoints. This is the
@ -2588,6 +2609,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
promptCacheKey: promptCacheKey, promptCacheKey: promptCacheKey,
previousResponseID: previousResponseID, previousResponseID: previousResponseID,
originalModel: originalModel, originalModel: originalModel,
imageBillingModel: imageBillingModel,
imageSizeTier: imageSizeTier,
payloadBytes: len(normalized), payloadBytes: len(normalized),
}, nil }, nil
} }
@ -2789,7 +2812,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
return payload, nil return payload, nil
} }
sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string) (*OpenAIForwardResult, error) { sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string, imageBillingModel string, imageSizeTier string) (*OpenAIForwardResult, error) {
if lease == nil { if lease == nil {
return nil, errors.New("upstream websocket lease is nil") return nil, errors.New("upstream websocket lease is nil")
} }
@ -2814,6 +2837,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
responseID := "" responseID := ""
usage := OpenAIUsage{} usage := OpenAIUsage{}
imageCounter := newOpenAIImageOutputCounter()
var firstTokenMs *int var firstTokenMs *int
reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true) reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true)
turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id") turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
@ -2935,6 +2959,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
if openAIWSEventShouldParseUsage(eventType) { if openAIWSEventShouldParseUsage(eventType) {
parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage) parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage)
} }
imageCounter.AddSSEData(upstreamMessage)
if !clientDisconnected { if !clientDisconnected {
if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) { if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) {
@ -2994,7 +3019,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
clientDisconnected, clientDisconnected,
) )
} }
return &OpenAIForwardResult{ imageCount := imageCounter.Count()
result := &OpenAIForwardResult{
RequestID: responseID, RequestID: responseID,
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
@ -3006,13 +3032,21 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
ResponseHeaders: lease.HandshakeHeaders(), ResponseHeaders: lease.HandshakeHeaders(),
Duration: time.Since(turnStart), Duration: time.Since(turnStart),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
}, nil }
if imageCount > 0 {
result.ImageCount = imageCount
result.ImageSize = imageSizeTier
result.BillingModel = imageBillingModel
}
return result, nil
} }
} }
} }
currentPayload := firstPayload.payloadRaw currentPayload := firstPayload.payloadRaw
currentOriginalModel := firstPayload.originalModel currentOriginalModel := firstPayload.originalModel
currentImageBillingModel := firstPayload.imageBillingModel
currentImageSizeTier := firstPayload.imageSizeTier
currentPayloadBytes := firstPayload.payloadBytes currentPayloadBytes := firstPayload.payloadBytes
isStrictAffinityTurn := func(payload []byte) bool { isStrictAffinityTurn := func(payload []byte) bool {
if !storeDisabled { if !storeDisabled {
@ -3101,6 +3135,12 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
if turnPrevRecoveryTried || !s.openAIWSIngressPreviousResponseRecoveryEnabled() { if turnPrevRecoveryTried || !s.openAIWSIngressPreviousResponseRecoveryEnabled() {
return false return false
} }
// 携带 function_call_output 的请求不能丢弃 previous_response_id
// 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use
// 丢弃后会导致 "No tool call found for function call output" 400 错误。
if gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() {
return false
}
if isStrictAffinityTurn(currentPayload) { if isStrictAffinityTurn(currentPayload) {
// Layer 2严格亲和链路命中 previous_response_not_found 时,降级为“去掉 previous_response_id 后重放一次”。 // Layer 2严格亲和链路命中 previous_response_not_found 时,降级为“去掉 previous_response_id 后重放一次”。
// 该错误说明续链锚点已失效,继续 strict fail-close 只会直接中断本轮请求。 // 该错误说明续链锚点已失效,继续 strict fail-close 只会直接中断本轮请求。
@ -3367,7 +3407,11 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
truncateOpenAIWSLogValue(pingErr.Error(), openAIWSLogValueMaxLen), truncateOpenAIWSLogValue(pingErr.Error(), openAIWSLogValueMaxLen),
) )
if forcePreferredConn { if forcePreferredConn {
if !turnPrevRecoveryTried && currentPreviousResponseID != "" { // 携带 function_call_output 的请求不能丢弃 previous_response_id
// 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use
// 丢弃后会导致 "No tool call found for function call output" 400 错误。
hasFCOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
if !turnPrevRecoveryTried && currentPreviousResponseID != "" && !hasFCOutput {
updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
if dropErr != nil || !removed { if dropErr != nil || !removed {
reason := "not_removed" reason := "not_removed"
@ -3457,7 +3501,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
) )
} }
result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel) result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel, currentImageBillingModel, currentImageSizeTier)
if relayErr != nil { if relayErr != nil {
lastTurnClean = false lastTurnClean = false
if recoverIngressPrevResponseNotFound(relayErr, turn, connID) { if recoverIngressPrevResponseNotFound(relayErr, turn, connID) {
@ -3579,6 +3623,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
} }
currentPayload = nextPayload.payloadRaw currentPayload = nextPayload.payloadRaw
currentOriginalModel = nextPayload.originalModel currentOriginalModel = nextPayload.originalModel
currentImageBillingModel = nextPayload.imageBillingModel
currentImageSizeTier = nextPayload.imageSizeTier
currentPayloadBytes = nextPayload.payloadBytes currentPayloadBytes = nextPayload.payloadBytes
storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account) storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account)
if !storeDisabled { if !storeDisabled {

View File

@ -399,7 +399,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
}() }()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast"}`)) err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast","reasoning":{"effort":"HIGH"}}`))
cancelWrite() cancelWrite()
require.NoError(t, err) require.NoError(t, err)
@ -431,6 +431,8 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
require.Equal(t, 3, result.Usage.OutputTokens) require.Equal(t, 3, result.Usage.OutputTokens)
require.NotNil(t, result.ServiceTier) require.NotNil(t, result.ServiceTier)
require.Equal(t, "priority", *result.ServiceTier) require.Equal(t, "priority", *result.ServiceTier)
require.NotNil(t, result.ReasoningEffort)
require.Equal(t, "high", *result.ReasoningEffort)
case <-time.After(2 * time.Second): case <-time.After(2 * time.Second):
t.Fatal("未收到 passthrough turn 结果回调") t.Fatal("未收到 passthrough turn 结果回调")
} }

View File

@ -171,6 +171,127 @@ func TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky(t *testing.T) {
require.Equal(t, "resp_new_1", gjson.GetBytes(responseBody, "id").String()) require.Equal(t, "resp_new_1", gjson.GetBytes(responseBody, "id").String())
} }
// TestOpenAIGatewayService_Forward_WSv2_ImageGenerationCountsOutputs forwards
// a request that includes an image_generation tool to an in-process WebSocket
// server. The server replies with one image_generation_call output item and a
// response.completed event. The test asserts that the forward result reports
// one image, size tier "1K", billing model "gpt-image-2", the usage token
// counts from the completed event, and WS mode.
func TestOpenAIGatewayService_Forward_WSv2_ImageGenerationCountsOutputs(t *testing.T) {
gin.SetMode(gin.TestMode)
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
// Fake upstream: reads one JSON request, then writes two events and closes.
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket failed: %v", err)
return
}
defer func() {
_ = conn.Close()
}()
var request map[string]any
if err := conn.ReadJSON(&request); err != nil {
t.Errorf("read ws request failed: %v", err)
return
}
// First event: a finished image_generation_call output item.
if err := conn.WriteJSON(map[string]any{
"type": "response.output_item.done",
"item": map[string]any{
"id": "ig_ws_1",
"type": "image_generation_call",
"result": "final-image",
},
}); err != nil {
t.Errorf("write response.output_item.done failed: %v", err)
return
}
// Second event: response.completed repeating the same item id in its output list.
if err := conn.WriteJSON(map[string]any{
"type": "response.completed",
"response": map[string]any{
"id": "resp_ws_image_1",
"model": "gpt-5.4",
"output": []any{
map[string]any{
"id": "ig_ws_1",
"type": "image_generation_call",
"result": "final-image",
},
},
"usage": map[string]any{
"input_tokens": 9,
"output_tokens": 4,
},
},
}); err != nil {
t.Errorf("write response.completed failed: %v", err)
return
}
}))
defer wsServer.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
groupID := int64(1010)
// API key whose group has AllowImageGeneration enabled.
c.Set("api_key", &APIKey{
GroupID: &groupID,
Group: &Group{
ID: groupID,
AllowImageGeneration: true,
},
})
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
// API-key account whose base_url points at the fake WebSocket server.
account := &Account{
ID: 10,
Name: "openai-ws-image",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.4","stream":false,"input":"draw","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1024x1024"}],"tool_choice":{"type":"image_generation"}}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "resp_ws_image_1", result.RequestID)
// The same item id appears in both events; exactly one image is expected.
require.Equal(t, 1, result.ImageCount)
require.Equal(t, "1K", result.ImageSize)
require.Equal(t, "gpt-image-2", result.BillingModel)
require.Equal(t, 9, result.Usage.InputTokens)
require.Equal(t, 4, result.Usage.OutputTokens)
require.True(t, result.OpenAIWSMode)
require.Equal(t, "resp_ws_image_1", gjson.GetBytes(rec.Body.Bytes(), "id").String())
}
func requestToJSONString(payload map[string]any) string { func requestToJSONString(payload map[string]any) string {
if len(payload) == 0 { if len(payload) == 0 {
return "{}" return "{}"

View File

@ -124,6 +124,73 @@ func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original)) return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
} }
// openAIWSPassthroughUsageMeta holds the usage metadata (service tier and
// reasoning effort) most recently extracted from a client frame, plus the
// request model used as a fallback when a frame carries no model of its own.
type openAIWSPassthroughUsageMeta struct {
// serviceTier and reasoningEffort are published with atomic stores by the
// init/update methods of this type.
serviceTier atomic.Pointer[string]
reasoningEffort atomic.Pointer[string]
// Read and written only in the client->upstream filter goroutine; the Load side synchronizes through the atomic pointers above.
sessionRequestModel string
}
// newOpenAIWSPassthroughUsageMeta builds the usage metadata holder for a
// passthrough session. The session request model is the trimmed
// initialRequestModel; when that is empty it falls back to the model read
// from firstFrame by openAIWSPassthroughRequestModelForFrame.
func newOpenAIWSPassthroughUsageMeta(initialRequestModel string, firstFrame []byte) *openAIWSPassthroughUsageMeta {
	model := strings.TrimSpace(initialRequestModel)
	if model == "" {
		model = openAIWSPassthroughRequestModelForFrame(firstFrame)
	}
	return &openAIWSPassthroughUsageMeta{sessionRequestModel: model}
}
// initFromFirstFrame records the service tier and reasoning effort extracted
// from policyOutput, using the session request model for the reasoning-effort
// extraction. It is a no-op on a nil receiver.
func (m *openAIWSPassthroughUsageMeta) initFromFirstFrame(policyOutput []byte) {
	if m == nil {
		return
	}
	// Same extraction and stores as a response.create update, keyed on the
	// session-level model.
	m.updateFromResponseCreate(policyOutput, m.sessionRequestModel)
}
// updateSessionRequestModel replaces the session request model with the model
// that openAIWSPassthroughRequestModelFromSessionFrame reads from payload.
// The stored model is left unchanged when the receiver is nil or no model is
// found.
func (m *openAIWSPassthroughUsageMeta) updateSessionRequestModel(payload []byte) {
	if m == nil {
		return
	}
	sessionModel := openAIWSPassthroughRequestModelFromSessionFrame(payload)
	if sessionModel == "" {
		return
	}
	m.sessionRequestModel = sessionModel
}
// requestModelForFrame returns the request model to associate with payload:
// the frame's own model when openAIWSPassthroughRequestModelForFrame finds
// one, otherwise the session request model. With a nil receiver only the
// frame's own model (possibly empty) is returned.
func (m *openAIWSPassthroughUsageMeta) requestModelForFrame(payload []byte) string {
	frameModel := openAIWSPassthroughRequestModelForFrame(payload)
	if m == nil || frameModel != "" {
		return frameModel
	}
	return m.sessionRequestModel
}
// updateFromResponseCreate stores the service tier and reasoning effort
// extracted from policyOutput, with requestModelForFrame passed to the
// reasoning-effort extraction. It is a no-op on a nil receiver.
func (m *openAIWSPassthroughUsageMeta) updateFromResponseCreate(policyOutput []byte, requestModelForFrame string) {
	if m == nil {
		return
	}
	tier := extractOpenAIServiceTierFromBody(policyOutput)
	m.serviceTier.Store(tier)
	effort := extractOpenAIReasoningEffortFromBody(policyOutput, requestModelForFrame)
	m.reasoningEffort.Store(effort)
}
// openAIWSPassthroughRequestModelForFrame returns the trimmed top-level
// "model" value of payload when its "type" is "response.create". It returns
// "" for an empty payload or any other frame type.
func openAIWSPassthroughRequestModelForFrame(payload []byte) string {
	if len(payload) == 0 {
		return ""
	}
	frameType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
	if frameType != "response.create" {
		return ""
	}
	return strings.TrimSpace(gjson.GetBytes(payload, "model").String())
}
// openAIWSPassthroughRequestModelFromSessionFrame returns the trimmed
// "session.model" value of payload when its "type" is "session.update". It
// returns "" for an empty payload or any other frame type.
func openAIWSPassthroughRequestModelFromSessionFrame(payload []byte) string {
	if len(payload) == 0 {
		return ""
	}
	frameType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
	if frameType != "session.update" {
		return ""
	}
	return strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
}
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2" const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil) var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
@ -204,6 +271,11 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
// silently passed through, defeating the policy on every frame after // silently passed through, defeating the policy on every frame after
// the first. // the first.
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage) capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
initialRequestModel := ""
if hooks != nil {
initialRequestModel = hooks.InitialRequestModel
}
usageMeta := newOpenAIWSPassthroughUsageMeta(initialRequestModel, firstClientMessage)
updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage) updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
if policyErr != nil { if policyErr != nil {
return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr) return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
@ -226,7 +298,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
} }
firstClientMessage = updatedFirst firstClientMessage = updatedFirst
// 在 policy filter 之后再提取 service_tier 用于 billing 上报filter // 在 policy filter 之后再提取 service_tier / reasoning_effort 用于
// usage 上报filter
// 命中时 service_tier 已经从 firstClientMessage 中删除billing 应当 // 命中时 service_tier 已经从 firstClientMessage 中删除billing 应当
// 反映上游实际处理的 tiernil = default而不是用户最初请求的 // 反映上游实际处理的 tiernil = default而不是用户最初请求的
// "priority"。HTTP 入口line ~2728 extractOpenAIServiceTier(reqBody) // "priority"。HTTP 入口line ~2728 extractOpenAIServiceTier(reqBody)
@ -237,11 +310,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
// codex-rs/core/src/client.rs build_responses_request 每次重新填值)。 // codex-rs/core/src/client.rs build_responses_request 每次重新填值)。
// 因此使用 atomic.Pointer[string] 在 filterrunClientToUpstream // 因此使用 atomic.Pointer[string] 在 filterrunClientToUpstream
// goroutine和 OnTurnComplete / final resultrunUpstreamToClient // goroutine和 OnTurnComplete / final resultrunUpstreamToClient
// goroutine之间同步当前 turn 的 service_tier。 // goroutine之间同步当前 turn 的 usage metadata。
// extractOpenAIServiceTierFromBody 返回 *string本身是指针类型 usageMeta.initFromFirstFrame(firstClientMessage)
// 可直接 Store/Load 而无需额外封装。
var requestServiceTierPtr atomic.Pointer[string]
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage))
wsURL, err := s.buildOpenAIResponsesWSURL(account) wsURL, err := s.buildOpenAIResponsesWSURL(account)
if err != nil { if err != nil {
@ -327,6 +397,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" { if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
capturedSessionModel = updated capturedSessionModel = updated
} }
usageMeta.updateSessionRequestModel(payload)
requestModelForThisFrame := usageMeta.requestModelForFrame(payload)
// Per-frame model first; if the client omits "model" on a // Per-frame model first; if the client omits "model" on a
// follow-up frame (legal in Realtime), fall back to the // follow-up frame (legal in Realtime), fall back to the
// session-level model captured from the first frame so the // session-level model captured from the first frame so the
@ -337,14 +409,14 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
model = capturedSessionModel model = capturedSessionModel
} }
out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload) out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
// 多轮 passthrough billing仅在成功non-block / non-err // 多轮 passthrough usage仅在成功non-block / non-err
// 的 response.create 帧上更新 requestServiceTierPtr,使用 // 的 response.create 帧上更新 usageMeta,使用
// filter 处理后的 payload与首帧 policy-after-extract 语义 // filter 处理后的 payload与首帧 policy-after-extract 语义
// 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。 // 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。
// - 非 response.create 帧response.cancel / // - 非 response.create 帧response.cancel /
// conversation.item.create / session.update 等)不携带 // conversation.item.create / session.update 等)不携带
// per-response service_tier,不应覆盖前一轮值。 // per-response metadata,不应覆盖前一轮值。
// - blocked != nil该帧不会发送上游billing tier 应保持 // - blocked != nil该帧不会发送上游usage metadata 应保持
// 上一轮值。 // 上一轮值。
// - policyErr != nil异常路径保持上一轮值。 // - policyErr != nil异常路径保持上一轮值。
// - 不带 service_tier 的 response.create 会让 // - 不带 service_tier 的 response.create 会让
@ -353,7 +425,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
// service_tier 时按 default 处理billing 应如实反映。 // service_tier 时按 default 处理billing 应如实反映。
if policyErr == nil && blocked == nil && if policyErr == nil && blocked == nil &&
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" { strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out)) usageMeta.updateFromResponseCreate(out, requestModelForThisFrame)
} }
return out, blocked, policyErr return out, blocked, policyErr
}, },
@ -397,7 +469,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
CacheReadInputTokens: turn.Usage.CacheReadInputTokens, CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
}, },
Model: turn.RequestModel, Model: turn.RequestModel,
ServiceTier: requestServiceTierPtr.Load(), ServiceTier: usageMeta.serviceTier.Load(),
ReasoningEffort: usageMeta.reasoningEffort.Load(),
Stream: true, Stream: true,
OpenAIWSMode: true, OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders), ResponseHeaders: cloneHeader(handshakeHeaders),
@ -445,7 +518,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens, CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
}, },
Model: relayResult.RequestModel, Model: relayResult.RequestModel,
ServiceTier: requestServiceTierPtr.Load(), ServiceTier: usageMeta.serviceTier.Load(),
ReasoningEffort: usageMeta.reasoningEffort.Load(),
Stream: true, Stream: true,
OpenAIWSMode: true, OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders), ResponseHeaders: cloneHeader(handshakeHeaders),

View File

@ -0,0 +1,164 @@
package service
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
)
// Tunables for the ops cleanup job.
const (
// opsCleanupDefaultSchedule is the default schedule expression. It has the
// shape of a five-field cron spec (presumably "every day at 02:00" — confirm
// against the scheduler that parses it).
opsCleanupDefaultSchedule = "0 2 * * *"
// opsCleanupBatchSize is the per-batch row limit that deleteOldRowsByID falls
// back to when it is given a non-positive batch size.
opsCleanupBatchSize = 5000
// NOTE(review): the three timeouts below are not referenced by the code
// visible here; the descriptions are inferred from their names — confirm
// against the code that uses them.
// opsCleanupCronStopTimeout: presumably how long to wait for the cron
// scheduler to stop.
opsCleanupCronStopTimeout = 3 * time.Second
// opsCleanupRunTimeout: presumably the upper bound for one full cleanup run.
opsCleanupRunTimeout = 30 * time.Minute
// opsCleanupHeartbeatTimeout: presumably the deadline for writing the job
// heartbeat.
opsCleanupHeartbeatTimeout = 2 * time.Second
)
// opsCleanupTarget describes one table handled by the cleanup job.
// NOTE(review): nothing visible here constructs or reads this type, so the
// field notes below are inferred from the field names and from the matching
// opsCleanupRunOne / deleteOldRowsByID parameters — confirm against the caller.
type opsCleanupTarget struct {
// retentionDays: presumably the retention setting handed to opsCleanupPlan.
retentionDays int
// table: presumably the name of the table to delete from or truncate.
table string
// timeCol: presumably the time column compared against the cutoff.
timeCol string
// castDate: presumably selects the "$1::date" form of the cutoff comparison.
castDate bool
// counter: presumably where the number of removed rows is accumulated.
counter *int64
}
// opsCleanupDeletedCounts holds one removed-row counter per cleanup category.
// The trailing comment on each field is the label String prints for it.
type opsCleanupDeletedCounts struct {
	errorLogs     int64 // error_logs
	retryAttempts int64 // retry_attempts
	alertEvents   int64 // alert_events
	systemLogs    int64 // system_logs
	logAudits     int64 // log_audits
	systemMetrics int64 // system_metrics
	hourlyPreagg  int64 // hourly_preagg
	dailyPreagg   int64 // daily_preagg
}

// String formats the counters as space-separated label=value pairs on a
// single line, in field declaration order.
func (c opsCleanupDeletedCounts) String() string {
	const layout = "error_logs=%d retry_attempts=%d alert_events=%d system_logs=%d log_audits=%d system_metrics=%d hourly_preagg=%d daily_preagg=%d"

	return fmt.Sprintf(layout,
		c.errorLogs, c.retryAttempts, c.alertEvents, c.systemLogs,
		c.logAudits, c.systemMetrics, c.hourlyPreagg, c.dailyPreagg,
	)
}
// opsCleanupPlan translates a retention setting in days into a concrete
// cleanup action.
//   - days < 0: skip this target (ok=false); kept for compatibility with
//     older data.
//   - days == 0: clear the whole table with TRUNCATE TABLE (truncate=true).
//   - days > 0: batch-delete rows older than the cutoff, where
//     cutoff = now minus days calendar days.
//
// cutoff is the zero time.Time unless days > 0.
func opsCleanupPlan(now time.Time, days int) (cutoff time.Time, truncate, ok bool) {
	switch {
	case days < 0:
		return time.Time{}, false, false
	case days == 0:
		return time.Time{}, true, true
	default:
		return now.AddDate(0, 0, -days), false, true
	}
}
func opsCleanupRunOne(
ctx context.Context,
db *sql.DB,
truncate bool,
cutoff time.Time,
table, timeCol string,
castDate bool,
batchSize int,
) (int64, error) {
if truncate {
return truncateOpsTable(ctx, db, table)
}
return deleteOldRowsByID(ctx, db, table, timeCol, cutoff, batchSize, castDate)
}
// deleteOldRowsByID removes rows of table whose timeColumn is older than
// cutoff, working in id-ordered batches until a batch deletes nothing.
//
// A nil db is a no-op. A non-positive batchSize falls back to
// opsCleanupBatchSize. With castCutoffToDate the cutoff parameter is cast to
// a date inside the comparison. table and timeColumn are interpolated
// directly into the SQL text.
//
// It returns the number of rows deleted so far, including when it stops on
// an error. An error recognised by isMissingRelationError ends the loop
// without being reported.
func deleteOldRowsByID(
	ctx context.Context,
	db *sql.DB,
	table string,
	timeColumn string,
	cutoff time.Time,
	batchSize int,
	castCutoffToDate bool,
) (int64, error) {
	if db == nil {
		return 0, nil
	}

	limit := batchSize
	if limit <= 0 {
		limit = opsCleanupBatchSize
	}

	predicateFormat := "%s < $1"
	if castCutoffToDate {
		predicateFormat = "%s < $1::date"
	}
	predicate := fmt.Sprintf(predicateFormat, timeColumn)

	stmt := fmt.Sprintf(`
WITH batch AS (
SELECT id FROM %s
WHERE %s
ORDER BY id
LIMIT $2
)
DELETE FROM %s
WHERE id IN (SELECT id FROM batch)
`, table, predicate, table)

	var deleted int64
	for {
		res, err := db.ExecContext(ctx, stmt, cutoff, limit)
		switch {
		case err == nil:
			// Batch executed; fall through to the row accounting below.
		case isMissingRelationError(err):
			return deleted, nil
		default:
			return deleted, err
		}

		n, err := res.RowsAffected()
		if err != nil {
			return deleted, err
		}
		deleted += n
		if n == 0 {
			// An empty batch means nothing older than the cutoff is left.
			return deleted, nil
		}
	}
}
// truncateOpsTable empties table with TRUNCATE TABLE. It runs SELECT COUNT(*)
// first so it can return the number of rows that existed before the truncate
// (the original note says this count feeds the heartbeat).
//
// A nil db, an empty table, or an error recognised by isMissingRelationError
// all yield (0, nil); an empty table is not truncated at all. Other errors
// are wrapped with the failing step and the table name.
func truncateOpsTable(ctx context.Context, db *sql.DB, table string) (int64, error) {
	if db == nil {
		return 0, nil
	}

	var existing int64
	countSQL := fmt.Sprintf("SELECT COUNT(*) FROM %s", table)
	err := db.QueryRowContext(ctx, countSQL).Scan(&existing)
	switch {
	case isMissingRelationError(err):
		return 0, nil
	case err != nil:
		return 0, fmt.Errorf("count %s: %w", table, err)
	case existing == 0:
		return 0, nil
	}

	truncateSQL := fmt.Sprintf("TRUNCATE TABLE %s", table)
	_, err = db.ExecContext(ctx, truncateSQL)
	switch {
	case isMissingRelationError(err):
		return 0, nil
	case err != nil:
		return 0, fmt.Errorf("truncate %s: %w", table, err)
	}
	return existing, nil
}
// isMissingRelationError reports whether err's message, compared
// case-insensitively, contains both "does not exist" and "relation".
// A nil err is never a match.
func isMissingRelationError(err error) bool {
	if err == nil {
		return false
	}
	msg := strings.ToLower(err.Error())
	for _, fragment := range []string{"does not exist", "relation"} {
		if !strings.Contains(msg, fragment) {
			return false
		}
	}
	return true
}

Some files were not shown because too many files have changed in this diff Show More