diff --git a/backend/ent/group.go b/backend/ent/group.go index a4f52c73..298df88a 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -85,6 +85,8 @@ type Group struct { DefaultMappedModel string `json:"default_mapped_model,omitempty"` // OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型 MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"` + // 自定义 /v1/models 展示列表配置;仅影响模型列表响应,不影响调度 + ModelsListConfig domain.GroupModelsListConfig `json:"models_list_config,omitempty"` // 分组 RPM 上限,0 表示不限制;设置后接管该分组用户的限流 RpmLimit int `json:"rpm_limit,omitempty"` // Edges holds the relations/edges for other nodes in the graph. @@ -193,7 +195,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig: + case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig, group.FieldModelsListConfig: values[i] = new([]byte) case group.FieldIsExclusive, group.FieldAllowImageGeneration, group.FieldImageRateIndependent, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet: values[i] = new(sql.NullBool) @@ -440,6 +442,14 @@ func (_m *Group) assignValues(columns []string, values []any) error { return fmt.Errorf("unmarshal field messages_dispatch_model_config: %w", err) } } + case group.FieldModelsListConfig: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field models_list_config", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.ModelsListConfig); err != nil { + return fmt.Errorf("unmarshal field models_list_config: %w", err) + } + } case group.FieldRpmLimit: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for field rpm_limit", values[i]) @@ -641,6 +651,9 @@ func (_m *Group) String() string { builder.WriteString("messages_dispatch_model_config=") builder.WriteString(fmt.Sprintf("%v", _m.MessagesDispatchModelConfig)) builder.WriteString(", ") + builder.WriteString("models_list_config=") + builder.WriteString(fmt.Sprintf("%v", _m.ModelsListConfig)) + builder.WriteString(", ") builder.WriteString("rpm_limit=") builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit)) builder.WriteByte(')') diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 4e9ba6b6..ebe9bd7e 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -82,6 +82,8 @@ const ( FieldDefaultMappedModel = "default_mapped_model" // FieldMessagesDispatchModelConfig holds the string denoting the messages_dispatch_model_config field in the database. FieldMessagesDispatchModelConfig = "messages_dispatch_model_config" + // FieldModelsListConfig holds the string denoting the models_list_config field in the database. + FieldModelsListConfig = "models_list_config" // FieldRpmLimit holds the string denoting the rpm_limit field in the database. FieldRpmLimit = "rpm_limit" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. @@ -192,6 +194,7 @@ var Columns = []string{ FieldRequirePrivacySet, FieldDefaultMappedModel, FieldMessagesDispatchModelConfig, + FieldModelsListConfig, FieldRpmLimit, } @@ -276,6 +279,8 @@ var ( DefaultMappedModelValidator func(string) error // DefaultMessagesDispatchModelConfig holds the default value on creation for the "messages_dispatch_model_config" field. DefaultMessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig + // DefaultModelsListConfig holds the default value on creation for the "models_list_config" field. + DefaultModelsListConfig domain.GroupModelsListConfig // DefaultRpmLimit holds the default value on creation for the "rpm_limit" field. DefaultRpmLimit int ) diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 44b905bd..d5ed0c19 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -467,6 +467,20 @@ func (_c *GroupCreate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe return _c } +// SetModelsListConfig sets the "models_list_config" field. +func (_c *GroupCreate) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupCreate { + _c.mutation.SetModelsListConfig(v) + return _c +} + +// SetNillableModelsListConfig sets the "models_list_config" field if the given value is not nil. +func (_c *GroupCreate) SetNillableModelsListConfig(v *domain.GroupModelsListConfig) *GroupCreate { + if v != nil { + _c.SetModelsListConfig(*v) + } + return _c +} + // SetRpmLimit sets the "rpm_limit" field. func (_c *GroupCreate) SetRpmLimit(v int) *GroupCreate { _c.mutation.SetRpmLimit(v) @@ -698,6 +712,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultMessagesDispatchModelConfig _c.mutation.SetMessagesDispatchModelConfig(v) } + if _, ok := _c.mutation.ModelsListConfig(); !ok { + v := group.DefaultModelsListConfig + _c.mutation.SetModelsListConfig(v) + } if _, ok := _c.mutation.RpmLimit(); !ok { v := group.DefaultRpmLimit _c.mutation.SetRpmLimit(v) @@ -798,6 +816,9 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok { return &ValidationError{Name: "messages_dispatch_model_config", err: errors.New(`ent: missing required field "Group.messages_dispatch_model_config"`)} } + if _, ok := _c.mutation.ModelsListConfig(); !ok { + return &ValidationError{Name: "models_list_config", err: errors.New(`ent: missing required field "Group.models_list_config"`)} + } if _, ok := _c.mutation.RpmLimit(); !ok { return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "Group.rpm_limit"`)} } @@ -960,6 +981,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) _node.MessagesDispatchModelConfig = value } + if value, ok := _c.mutation.ModelsListConfig(); ok { + _spec.SetField(group.FieldModelsListConfig, field.TypeJSON, value) + _node.ModelsListConfig = value + } if value, ok := _c.mutation.RpmLimit(); ok { _spec.SetField(group.FieldRpmLimit, field.TypeInt, value) _node.RpmLimit = value @@ -1642,6 +1667,18 @@ func (u *GroupUpsert) UpdateMessagesDispatchModelConfig() *GroupUpsert { return u } +// SetModelsListConfig sets the "models_list_config" field. +func (u *GroupUpsert) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpsert { + u.Set(group.FieldModelsListConfig, v) + return u +} + +// UpdateModelsListConfig sets the "models_list_config" field to the value that was provided on create. +func (u *GroupUpsert) UpdateModelsListConfig() *GroupUpsert { + u.SetExcluded(group.FieldModelsListConfig) + return u +} + // SetRpmLimit sets the "rpm_limit" field. func (u *GroupUpsert) SetRpmLimit(v int) *GroupUpsert { u.Set(group.FieldRpmLimit, v) @@ -2314,6 +2351,20 @@ func (u *GroupUpsertOne) UpdateMessagesDispatchModelConfig() *GroupUpsertOne { }) } +// SetModelsListConfig sets the "models_list_config" field. +func (u *GroupUpsertOne) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetModelsListConfig(v) + }) +} + +// UpdateModelsListConfig sets the "models_list_config" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateModelsListConfig() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateModelsListConfig() + }) +} + // SetRpmLimit sets the "rpm_limit" field. func (u *GroupUpsertOne) SetRpmLimit(v int) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -3155,6 +3206,20 @@ func (u *GroupUpsertBulk) UpdateMessagesDispatchModelConfig() *GroupUpsertBulk { }) } +// SetModelsListConfig sets the "models_list_config" field. +func (u *GroupUpsertBulk) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetModelsListConfig(v) + }) +} + +// UpdateModelsListConfig sets the "models_list_config" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateModelsListConfig() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateModelsListConfig() + }) +} + // SetRpmLimit sets the "rpm_limit" field. func (u *GroupUpsertBulk) SetRpmLimit(v int) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index fe55982c..c10d60ec 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -616,6 +616,20 @@ func (_u *GroupUpdate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe return _u } +// SetModelsListConfig sets the "models_list_config" field. +func (_u *GroupUpdate) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpdate { + _u.mutation.SetModelsListConfig(v) + return _u +} + +// SetNillableModelsListConfig sets the "models_list_config" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableModelsListConfig(v *domain.GroupModelsListConfig) *GroupUpdate { + if v != nil { + _u.SetModelsListConfig(*v) + } + return _u +} + // SetRpmLimit sets the "rpm_limit" field. func (_u *GroupUpdate) SetRpmLimit(v int) *GroupUpdate { _u.mutation.ResetRpmLimit() @@ -1112,6 +1126,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok { _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) } + if value, ok := _u.mutation.ModelsListConfig(); ok { + _spec.SetField(group.FieldModelsListConfig, field.TypeJSON, value) + } if value, ok := _u.mutation.RpmLimit(); ok { _spec.SetField(group.FieldRpmLimit, field.TypeInt, value) } @@ -2012,6 +2029,20 @@ func (_u *GroupUpdateOne) SetNillableMessagesDispatchModelConfig(v *domain.OpenA return _u } +// SetModelsListConfig sets the "models_list_config" field. +func (_u *GroupUpdateOne) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpdateOne { + _u.mutation.SetModelsListConfig(v) + return _u +} + +// SetNillableModelsListConfig sets the "models_list_config" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableModelsListConfig(v *domain.GroupModelsListConfig) *GroupUpdateOne { + if v != nil { + _u.SetModelsListConfig(*v) + } + return _u +} + // SetRpmLimit sets the "rpm_limit" field. func (_u *GroupUpdateOne) SetRpmLimit(v int) *GroupUpdateOne { _u.mutation.ResetRpmLimit() @@ -2538,6 +2569,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok { _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) } + if value, ok := _u.mutation.ModelsListConfig(); ok { + _spec.SetField(group.FieldModelsListConfig, field.TypeJSON, value) + } if value, ok := _u.mutation.RpmLimit(); ok { _spec.SetField(group.FieldRpmLimit, field.TypeInt, value) } diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 447f71ef..7abe4c60 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -669,6 +669,7 @@ var ( {Name: "require_privacy_set", Type: field.TypeBool, Default: false}, {Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""}, {Name: "messages_dispatch_model_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "models_list_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "rpm_limit", Type: field.TypeInt, Default: 0}, } // GroupsTable holds the schema information for the "groups" table. diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 2e8fa7f4..003e25d5 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -14901,6 +14901,7 @@ type GroupMutation struct { require_privacy_set *bool default_mapped_model *string messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig + models_list_config *domain.GroupModelsListConfig rpm_limit *int addrpm_limit *int clearedFields map[string]struct{} @@ -16619,6 +16620,42 @@ func (m *GroupMutation) ResetMessagesDispatchModelConfig() { m.messages_dispatch_model_config = nil } +// SetModelsListConfig sets the "models_list_config" field. +func (m *GroupMutation) SetModelsListConfig(dmlc domain.GroupModelsListConfig) { + m.models_list_config = &dmlc +} + +// ModelsListConfig returns the value of the "models_list_config" field in the mutation. +func (m *GroupMutation) ModelsListConfig() (r domain.GroupModelsListConfig, exists bool) { + v := m.models_list_config + if v == nil { + return + } + return *v, true +} + +// OldModelsListConfig returns the old "models_list_config" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldModelsListConfig(ctx context.Context) (v domain.GroupModelsListConfig, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldModelsListConfig is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldModelsListConfig requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldModelsListConfig: %w", err) + } + return oldValue.ModelsListConfig, nil +} + +// ResetModelsListConfig resets all changes to the "models_list_config" field. +func (m *GroupMutation) ResetModelsListConfig() { + m.models_list_config = nil +} + // SetRpmLimit sets the "rpm_limit" field. func (m *GroupMutation) SetRpmLimit(i int) { m.rpm_limit = &i @@ -17033,7 +17070,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 34) + fields := make([]string, 0, 35) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -17133,6 +17170,9 @@ func (m *GroupMutation) Fields() []string { if m.messages_dispatch_model_config != nil { fields = append(fields, group.FieldMessagesDispatchModelConfig) } + if m.models_list_config != nil { + fields = append(fields, group.FieldModelsListConfig) + } if m.rpm_limit != nil { fields = append(fields, group.FieldRpmLimit) } @@ -17210,6 +17250,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.DefaultMappedModel() case group.FieldMessagesDispatchModelConfig: return m.MessagesDispatchModelConfig() + case group.FieldModelsListConfig: + return m.ModelsListConfig() case group.FieldRpmLimit: return m.RpmLimit() } @@ -17287,6 +17329,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldDefaultMappedModel(ctx) case group.FieldMessagesDispatchModelConfig: return m.OldMessagesDispatchModelConfig(ctx) + case group.FieldModelsListConfig: + return m.OldModelsListConfig(ctx) case group.FieldRpmLimit: return m.OldRpmLimit(ctx) } @@ -17529,6 +17573,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetMessagesDispatchModelConfig(v) return nil + case group.FieldModelsListConfig: + v, ok := value.(domain.GroupModelsListConfig) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetModelsListConfig(v) + return nil case group.FieldRpmLimit: v, ok := value.(int) if !ok { @@ -17912,6 +17963,9 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldMessagesDispatchModelConfig: m.ResetMessagesDispatchModelConfig() return nil + case group.FieldModelsListConfig: + m.ResetModelsListConfig() + return nil case group.FieldRpmLimit: m.ResetRpmLimit() return nil diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index aa6130f0..fdb837e8 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -870,8 +870,12 @@ func init() { groupDescMessagesDispatchModelConfig := groupFields[29].Descriptor() // group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field. group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig) + // groupDescModelsListConfig is the schema descriptor for models_list_config field. + groupDescModelsListConfig := groupFields[30].Descriptor() + // group.DefaultModelsListConfig holds the default value on creation for the models_list_config field. + group.DefaultModelsListConfig = groupDescModelsListConfig.Default.(domain.GroupModelsListConfig) // groupDescRpmLimit is the schema descriptor for rpm_limit field. - groupDescRpmLimit := groupFields[30].Descriptor() + groupDescRpmLimit := groupFields[31].Descriptor() // group.DefaultRpmLimit holds the default value on creation for the rpm_limit field. group.DefaultRpmLimit = groupDescRpmLimit.Default.(int) idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index d47e8710..2a1715f8 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -155,6 +155,10 @@ func (Group) Fields() []ent.Field { Default(domain.OpenAIMessagesDispatchModelConfig{}). SchemaType(map[string]string{dialect.Postgres: "jsonb"}). Comment("OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型"), + field.JSON("models_list_config", domain.GroupModelsListConfig{}). + Default(domain.GroupModelsListConfig{}). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}). + Comment("自定义 /v1/models 展示列表配置;仅影响模型列表响应,不影响调度"), // 分组级每分钟请求数上限(0 = 不限制)。设置后优先于用户级兜底生效。 field.Int("rpm_limit"). diff --git a/backend/internal/domain/models_list_config.go b/backend/internal/domain/models_list_config.go new file mode 100644 index 00000000..3f050585 --- /dev/null +++ b/backend/internal/domain/models_list_config.go @@ -0,0 +1,7 @@ +package domain + +// GroupModelsListConfig controls the optional custom /v1/models response list. +type GroupModelsListConfig struct { + Enabled bool `json:"enabled"` + Models []string `json:"models,omitempty"` +} diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index 7b74bafc..bffddc8a 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -33,6 +33,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { router.GET("/api/v1/admin/groups", groupHandler.List) router.GET("/api/v1/admin/groups/all", groupHandler.GetAll) + router.GET("/api/v1/admin/groups/:id/models-list-candidates", groupHandler.GetModelsListCandidates) router.GET("/api/v1/admin/groups/:id", groupHandler.GetByID) router.POST("/api/v1/admin/groups", groupHandler.Create) router.PUT("/api/v1/admin/groups/:id", groupHandler.Update) @@ -177,6 +178,12 @@ func TestGroupHandlerEndpoints(t *testing.T) { router.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/0/models-list-candidates?platform=openai", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "gpt-5.5") + body, _ := json.Marshal(map[string]any{"name": "new", "platform": "anthropic", "subscription_type": "standard"}) rec = httptest.NewRecorder() req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/groups", bytes.NewReader(body)) diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 65b71492..fd0ec459 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -265,6 +265,13 @@ func (s *stubAdminService) GetGroup(ctx context.Context, id int64) (*service.Gro return &group, nil } +func (s *stubAdminService) GetGroupModelsListCandidates(ctx context.Context, id int64, platform string) ([]string, error) { + if platform == service.PlatformOpenAI { + return []string{"gpt-5.5", "gpt-5.4"}, nil + } + return []string{"claude-sonnet-4-6"}, nil +} + func (s *stubAdminService) CreateGroup(ctx context.Context, input *service.CreateGroupInput) (*service.Group, error) { group := service.Group{ID: 200, Name: input.Name, Status: service.StatusActive} return &group, nil diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 3667bbcd..dbf6f709 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -113,6 +113,7 @@ type CreateGroupRequest struct { RequirePrivacySet bool `json:"require_privacy_set"` DefaultMappedModel string `json:"default_mapped_model"` MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` + ModelsListConfig service.GroupModelsListConfig `json:"models_list_config"` // 分组 RPM 上限(0 = 不限制) RPMLimit int `json:"rpm_limit"` // 从指定分组复制账号(创建后自动绑定) @@ -153,6 +154,7 @@ type UpdateGroupRequest struct { RequirePrivacySet *bool `json:"require_privacy_set"` DefaultMappedModel *string `json:"default_mapped_model"` MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` + ModelsListConfig *service.GroupModelsListConfig `json:"models_list_config"` // 分组 RPM 上限(0 = 不限制);nil 表示未提供不改动 RPMLimit *int `json:"rpm_limit"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) @@ -238,6 +240,28 @@ func (h *GroupHandler) GetByID(c *gin.Context) { response.Success(c, dto.GroupFromServiceAdmin(group)) } +// GetModelsListCandidates handles getting candidate model IDs for custom /v1/models list. +// GET /api/v1/admin/groups/:id/models-list-candidates +func (h *GroupHandler) GetModelsListCandidates(c *gin.Context) { + groupID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil || groupID < 0 { + response.BadRequest(c, "Invalid group ID") + return + } + + models, err := h.adminService.GetGroupModelsListCandidates( + c.Request.Context(), + groupID, + c.Query("platform"), + ) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"models": models}) +} + // Create handles creating a new group // POST /api/v1/admin/groups func (h *GroupHandler) Create(c *gin.Context) { @@ -275,6 +299,7 @@ func (h *GroupHandler) Create(c *gin.Context) { RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, MessagesDispatchModelConfig: req.MessagesDispatchModelConfig, + ModelsListConfig: req.ModelsListConfig, RPMLimit: req.RPMLimit, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) @@ -330,6 +355,7 @@ func (h *GroupHandler) Update(c *gin.Context) { RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, MessagesDispatchModelConfig: req.MessagesDispatchModelConfig, + ModelsListConfig: req.ModelsListConfig, RPMLimit: req.RPMLimit, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 2c71be9d..51a11ea7 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -147,6 +147,7 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { MCPXMLInject: g.MCPXMLInject, DefaultMappedModel: g.DefaultMappedModel, MessagesDispatchModelConfig: g.MessagesDispatchModelConfig, + ModelsListConfig: g.ModelsListConfig, SupportedModelScopes: g.SupportedModelScopes, AccountCount: g.AccountCount, ActiveAccountCount: g.ActiveAccountCount, diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 31828375..b1841c62 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -138,6 +138,7 @@ type AdminGroup struct { // OpenAI Messages 调度配置(仅 openai 平台使用) DefaultMappedModel string `json:"default_mapped_model"` MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` + ModelsListConfig domain.GroupModelsListConfig `json:"models_list_config"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 87a935fd..4695a791 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -961,22 +961,14 @@ func (h *GatewayHandler) Models(c *gin.Context) { // Get available models from account configurations for the selected group platform. availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, platform) + if apiKey != nil && apiKey.Group != nil && apiKey.Group.CustomModelsListEnabled() { + availableModels = filterModelsByCustomList(availableModels, defaultModelIDsForPlatform(platform), apiKey.Group.ModelsListConfig.Models) + writeCustomModelsList(c, platform, availableModels) + return + } if len(availableModels) > 0 { - // Build model list from whitelist - models := make([]claude.Model, 0, len(availableModels)) - for _, modelID := range availableModels { - models = append(models, claude.Model{ - ID: modelID, - Type: "model", - DisplayName: modelID, - CreatedAt: "2024-01-01T00:00:00Z", - }) - } - c.JSON(http.StatusOK, gin.H{ - "object": "list", - "data": models, - }) + writeModelsList(c, availableModels) return } @@ -1003,6 +995,134 @@ func (h *GatewayHandler) Models(c *gin.Context) { }) } +func writeModelsList(c *gin.Context, modelIDs []string) { + models := make([]claude.Model, 0, len(modelIDs)) + for _, modelID := range modelIDs { + models = append(models, claude.Model{ + ID: modelID, + Type: "model", + DisplayName: modelID, + CreatedAt: "2024-01-01T00:00:00Z", + }) + } + c.JSON(http.StatusOK, gin.H{ + "object": "list", + "data": models, + }) +} + +func writeCustomModelsList(c *gin.Context, platform string, modelIDs []string) { + if platform == service.PlatformOpenAI { + writeOpenAIModelsList(c, modelIDs) + return + } + writeModelsList(c, modelIDs) +} + +func writeOpenAIModelsList(c *gin.Context, modelIDs []string) { + defaultsByID := make(map[string]openai.Model, len(openai.DefaultModels)) + for _, model := range openai.DefaultModels { + defaultsByID[model.ID] = model + } + + models := make([]openai.Model, 0, len(modelIDs)) + for _, modelID := range modelIDs { + if model, ok := defaultsByID[modelID]; ok { + models = append(models, model) + continue + } + models = append(models, openai.Model{ + ID: modelID, + Object: "model", + Created: 1704067200, + OwnedBy: "openai", + Type: "model", + DisplayName: modelID, + }) + } + c.JSON(http.StatusOK, gin.H{ + "object": "list", + "data": models, + }) +} + +func filterModelsByCustomList(availableModels, fallbackModels, selectedModels []string) []string { + if len(selectedModels) == 0 { + return availableModels + } + source := availableModels + if len(source) == 0 { + source = fallbackModels + } + if len(source) == 0 { + return nil + } + + allowed := make([]string, 0, len(source)) + for _, model := range source { + model = strings.TrimSpace(model) + if model != "" { + allowed = append(allowed, model) + } + } + + seen := make(map[string]struct{}, len(selectedModels)) + filtered := make([]string, 0, len(selectedModels)) + for _, model := range selectedModels { + model = strings.TrimSpace(model) + if model == "" { + continue + } + if !customModelsListAllowsModel(allowed, model) { + continue + } + if _, ok := seen[model]; ok { + continue + } + seen[model] = struct{}{} + filtered = append(filtered, model) + } + return filtered +} + +func customModelsListAllowsModel(availablePatterns []string, model string) bool { + for _, pattern := range availablePatterns { + if pattern == model { + return true + } + if strings.HasSuffix(pattern, "*") && strings.HasPrefix(model, strings.TrimSuffix(pattern, "*")) { + return true + } + } + return false +} + +func defaultModelIDsForPlatform(platform string) []string { + switch platform { + case service.PlatformOpenAI: + return openai.DefaultModelIDs() + case service.PlatformGemini: + ids := make([]string, 0, len(geminicli.DefaultModels)) + for _, model := range geminicli.DefaultModels { + ids = append(ids, model.ID) + } + return ids + case service.PlatformAntigravity: + models := antigravity.DefaultModels() + ids := make([]string, 0, len(models)) + for _, model := range models { + ids = append(ids, model.ID) + } + return ids + default: + ids := make([]string, 0, len(claude.DefaultModels)) + for _, model := range claude.DefaultModels { + ids = append(ids, model.ID) + } + return ids + } +} + // AntigravityModels 返回 Antigravity 支持的全部模型 // GET /antigravity/models func (h *GatewayHandler) AntigravityModels(c *gin.Context) { diff --git a/backend/internal/handler/gateway_models_test.go b/backend/internal/handler/gateway_models_test.go index 78b07a1a..c5238f2a 100644 --- a/backend/internal/handler/gateway_models_test.go +++ b/backend/internal/handler/gateway_models_test.go @@ -25,7 +25,11 @@ type gatewayModelsResponseForTest struct { } type gatewayModelItemForTest struct { - ID string `json:"id"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + OwnedBy string `json:"owned_by"` + CreatedAt string `json:"created_at"` } func (s *gatewayModelsAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) { @@ -127,6 +131,267 @@ func TestGatewayModels_GeminiGroupFiltersMappedModelsByPlatform(t *testing.T) { require.Equal(t, []string{"gemini-2.5-flash"}, modelIDsForTest(got.Data)) } +func TestGatewayModels_CustomModelsListDisabledKeepsOriginalModels(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(22) + h := newGatewayModelsHandlerForTest( + &gatewayModelsAccountRepoStub{ + byGroup: map[int64][]service.Account{ + groupID: { + { + ID: 1, + Platform: service.PlatformOpenAI, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5.5": "gpt-5.5", + "gpt-5.4": "gpt-5.4", + }, + }, + }, + }, + }, + }, + ) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + Group: &service.Group{ + ID: groupID, + Platform: service.PlatformOpenAI, + ModelsListConfig: service.GroupModelsListConfig{ + Enabled: false, + Models: []string{"gpt-5.5"}, + }, + }, + }) + + h.Models(c) + + require.Equal(t, http.StatusOK, rec.Code) + + var got gatewayModelsResponseForTest + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, []string{"gpt-5.4", "gpt-5.5"}, modelIDsForTest(got.Data)) +} + +func TestGatewayModels_CustomModelsListFiltersAndOrdersMappedModels(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(23) + h := newGatewayModelsHandlerForTest( + &gatewayModelsAccountRepoStub{ + byGroup: map[int64][]service.Account{ + groupID: { + { + ID: 1, + Platform: service.PlatformOpenAI, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5.4": "gpt-5.4", + "gpt-5.5": "gpt-5.5", + "legacy-gpt-2024": "legacy-gpt-2024", + }, + }, + }, + }, + }, + }, + ) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + Group: &service.Group{ + ID: groupID, + Platform: service.PlatformOpenAI, + ModelsListConfig: service.GroupModelsListConfig{ + Enabled: true, + Models: []string{"gpt-5.5", "missing-model", "gpt-5.4"}, + }, + }, + }) + + h.Models(c) + + require.Equal(t, http.StatusOK, rec.Code) + + var got gatewayModelsResponseForTest + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, []string{"gpt-5.5", "gpt-5.4"}, modelIDsForTest(got.Data)) +} + +func TestGatewayModels_CustomModelsListKeepsConcreteModelAllowedByWildcardMapping(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(26) + h := newGatewayModelsHandlerForTest( + &gatewayModelsAccountRepoStub{ + byGroup: map[int64][]service.Account{ + groupID: { + { + ID: 1, + Platform: service.PlatformAnthropic, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-*": "claude-sonnet-4-6", + }, + }, + }, + }, + }, + }, + ) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + Group: &service.Group{ + ID: groupID, + Platform: service.PlatformAnthropic, + ModelsListConfig: service.GroupModelsListConfig{ + Enabled: true, + Models: []string{"claude-sonnet-4-6"}, + }, + }, + }) + + h.Models(c) + + require.Equal(t, http.StatusOK, rec.Code) + + var got gatewayModelsResponseForTest + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, []string{"claude-sonnet-4-6"}, modelIDsForTest(got.Data)) +} + +func TestGatewayModels_CustomModelsListCanReturnEmptyWhenSelectionsUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(24) + h := newGatewayModelsHandlerForTest( + &gatewayModelsAccountRepoStub{ + byGroup: map[int64][]service.Account{ + groupID: { + { + ID: 1, + Platform: service.PlatformOpenAI, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5.4": "gpt-5.4", + }, + }, + }, + }, + }, + }, + ) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + Group: &service.Group{ + ID: groupID, + Platform: service.PlatformOpenAI, + ModelsListConfig: service.GroupModelsListConfig{ + Enabled: true, + Models: []string{"gpt-5.5"}, + }, + }, + }) + + h.Models(c) + + require.Equal(t, http.StatusOK, rec.Code) + + var got gatewayModelsResponseForTest + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Empty(t, modelIDsForTest(got.Data)) +} + +func TestGatewayModels_CustomModelsListFiltersDefaultFallbackModels(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(25) + h := newGatewayModelsHandlerForTest( + &gatewayModelsAccountRepoStub{ + byGroup: map[int64][]service.Account{ + groupID: { + {ID: 1, Platform: service.PlatformOpenAI}, + }, + }, + }, + ) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + Group: &service.Group{ + ID: groupID, + Platform: service.PlatformOpenAI, + ModelsListConfig: service.GroupModelsListConfig{ + Enabled: true, + Models: []string{"gpt-5.5", "legacy-gpt-2024", "gpt-5.4"}, + }, + }, + }) + + h.Models(c) + + require.Equal(t, http.StatusOK, rec.Code) + + var got gatewayModelsResponseForTest + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, []string{"gpt-5.5", "gpt-5.4"}, modelIDsForTest(got.Data)) +} + +func TestGatewayModels_OpenAICustomModelsListKeepsOpenAIResponseShapeForDefaultFallback(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(27) + h := newGatewayModelsHandlerForTest( + &gatewayModelsAccountRepoStub{ + byGroup: map[int64][]service.Account{ + groupID: { + {ID: 1, Platform: service.PlatformOpenAI}, + }, + }, + }, + ) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + Group: &service.Group{ + ID: groupID, + Platform: service.PlatformOpenAI, + ModelsListConfig: service.GroupModelsListConfig{ + Enabled: true, + Models: []string{"gpt-5.5", "gpt-5.4"}, + }, + }, + }) + + h.Models(c) + + require.Equal(t, http.StatusOK, rec.Code) + + var got gatewayModelsResponseForTest + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, []string{"gpt-5.5", "gpt-5.4"}, modelIDsForTest(got.Data)) + require.Equal(t, "model", got.Data[0].Object) + require.NotZero(t, got.Data[0].Created) + require.Equal(t, "openai", got.Data[0].OwnedBy) + require.Empty(t, got.Data[0].CreatedAt) +} + func modelIDsForTest(models []gatewayModelItemForTest) []string { ids := make([]string, 0, len(models)) for _, model := range models { diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 43b13937..bfe09283 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -183,6 +183,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldAllowMessagesDispatch, group.FieldDefaultMappedModel, group.FieldMessagesDispatchModelConfig, + group.FieldModelsListConfig, group.FieldRpmLimit, ) }). @@ -723,6 +724,7 @@ func groupEntityToService(g *dbent.Group) *service.Group { RequirePrivacySet: g.RequirePrivacySet, DefaultMappedModel: g.DefaultMappedModel, MessagesDispatchModelConfig: g.MessagesDispatchModelConfig, + ModelsListConfig: g.ModelsListConfig, RPMLimit: g.RpmLimit, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 9c3b2010..ac8669ab 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -66,6 +66,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetRequirePrivacySet(groupIn.RequirePrivacySet). SetDefaultMappedModel(groupIn.DefaultMappedModel). SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig). + SetModelsListConfig(groupIn.ModelsListConfig). SetRpmLimit(groupIn.RPMLimit) // 设置模型路由配置 @@ -141,6 +142,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetRequirePrivacySet(groupIn.RequirePrivacySet). SetDefaultMappedModel(groupIn.DefaultMappedModel). SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig). + SetModelsListConfig(groupIn.ModelsListConfig). SetRpmLimit(groupIn.RPMLimit) // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 349c520c..2301adda 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -259,6 +259,7 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) { groups.GET("/usage-summary", h.Admin.Group.GetUsageSummary) groups.GET("/capacity-summary", h.Admin.Group.GetCapacitySummary) groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder) + groups.GET("/:id/models-list-candidates", h.Admin.Group.GetModelsListCandidates) groups.GET("/:id", h.Admin.Group.GetByID) groups.POST("", h.Admin.Group.Create) groups.PUT("/:id", h.Admin.Group.Update) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index fc8f3fbb..d46b636f 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -17,9 +17,13 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/util/httputil" ) @@ -48,6 +52,7 @@ type AdminService interface { GetAllGroups(ctx context.Context) ([]Group, error) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) GetGroup(ctx context.Context, id int64) (*Group, error) + GetGroupModelsListCandidates(ctx context.Context, id int64, platform string) ([]string, error) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) DeleteGroup(ctx context.Context, id int64) error @@ -215,6 +220,7 @@ type CreateGroupInput struct { RequireOAuthOnly bool RequirePrivacySet bool MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig + ModelsListConfig GroupModelsListConfig // RPMLimit 分组 RPM 上限(0 = 不限制) RPMLimit int // 从指定分组复制账号(创建分组后在同一事务内绑定) @@ -255,6 +261,7 @@ type UpdateGroupInput struct { RequireOAuthOnly *bool RequirePrivacySet *bool MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig + ModelsListConfig *GroupModelsListConfig // RPMLimit 分组 RPM 上限(0 = 不限制),nil 表示未提供不改动。 RPMLimit *int // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) @@ -1582,6 +1589,80 @@ func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, erro return s.groupRepo.GetByID(ctx, id) } +func (s *adminServiceImpl) GetGroupModelsListCandidates(ctx context.Context, id int64, platform string) ([]string, error) { + platform = strings.TrimSpace(platform) + if id > 0 { + group, err := s.groupRepo.GetByIDLite(ctx, id) + if err != nil { + return nil, err + } + if platform == "" { + platform = group.Platform + } + } + if platform == "" { + platform = PlatformAnthropic + } + + candidates := defaultModelsListCandidateIDs(platform) + if id <= 0 || s.accountRepo == nil { + return candidates, nil + } + + accounts, err := s.accountRepo.ListSchedulableByGroupID(ctx, id) + if err != nil { + return nil, err + } + + seen := make(map[string]struct{}, len(candidates)) + for _, model := range candidates { + seen[model] = struct{}{} + } + for _, acc := range accounts { + if acc.Platform != platform { + continue + } + for model := range acc.GetModelMapping() { + model = strings.TrimSpace(model) + if model == "" { + continue + } + if _, ok := seen[model]; ok { + continue + } + seen[model] = struct{}{} + candidates = append(candidates, model) + } + } + return candidates, nil +} + +func defaultModelsListCandidateIDs(platform string) []string { + switch platform { + case PlatformOpenAI: + return openai.DefaultModelIDs() + case PlatformGemini: + ids := make([]string, 0, len(geminicli.DefaultModels)) + for _, model := range geminicli.DefaultModels { + ids = append(ids, model.ID) + } + return ids + case PlatformAntigravity: + models := antigravity.DefaultModels() + ids := make([]string, 0, len(models)) + for _, model := range models { + ids = append(ids, model.ID) + } + return ids + default: + ids := make([]string, 0, len(claude.DefaultModels)) + for _, model := range claude.DefaultModels { + ids = append(ids, model.ID) + } + return ids + } +} + func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) { if input.RateMultiplier <= 0 { return nil, errors.New("rate_multiplier must be > 0") @@ -1697,6 +1778,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn RequirePrivacySet: input.RequirePrivacySet, DefaultMappedModel: input.DefaultMappedModel, MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig), + ModelsListConfig: normalizeGroupModelsListConfig(input.ModelsListConfig), RPMLimit: input.RPMLimit, } sanitizeGroupMessagesDispatchFields(group) @@ -1944,6 +2026,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.MessagesDispatchModelConfig != nil { group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig) } + if input.ModelsListConfig != nil { + group.ModelsListConfig = normalizeGroupModelsListConfig(*input.ModelsListConfig) + } if input.RPMLimit != nil { group.RPMLimit = *input.RPMLimit } diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 3553a18a..74163179 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -87,6 +87,7 @@ type APIKeyAuthGroupSnapshot struct { AllowMessagesDispatch bool `json:"allow_messages_dispatch"` DefaultMappedModel string `json:"default_mapped_model,omitempty"` MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"` + ModelsListConfig GroupModelsListConfig `json:"models_list_config,omitempty"` // RPMLimit 分组级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 级联判断。 RPMLimit int `json:"rpm_limit"` diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index c752ce28..69c6086f 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -14,7 +14,7 @@ import ( "github.com/dgraph-io/ristretto" ) -const apiKeyAuthSnapshotVersion = 10 // v10: reload snapshots for group availability checks +const apiKeyAuthSnapshotVersion = 11 // v11: reload snapshots for custom models_list_config type apiKeyAuthCacheConfig struct { l1Size int @@ -272,6 +272,7 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch, DefaultMappedModel: apiKey.Group.DefaultMappedModel, MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig, + ModelsListConfig: apiKey.Group.ModelsListConfig, RPMLimit: apiKey.Group.RPMLimit, } } @@ -342,6 +343,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch, DefaultMappedModel: snapshot.Group.DefaultMappedModel, MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig, + ModelsListConfig: snapshot.Group.ModelsListConfig, RPMLimit: snapshot.Group.RPMLimit, } } diff --git a/backend/internal/service/api_key_auth_cache_version_test.go b/backend/internal/service/api_key_auth_cache_version_test.go new file mode 100644 index 00000000..5982e526 --- /dev/null +++ b/backend/internal/service/api_key_auth_cache_version_test.go @@ -0,0 +1,43 @@ +package service + +import "testing" + +func TestAPIKeyService_RejectsV10AuthSnapshotWithoutModelsListConfig(t *testing.T) { + groupID := int64(9) + svc := &APIKeyService{} + + apiKey, ok, err := svc.applyAuthCacheEntry("k-legacy-models-list", &APIKeyAuthCacheEntry{ + Snapshot: &APIKeyAuthSnapshot{ + Version: 10, + APIKeyID: 1, + UserID: 2, + GroupID: &groupID, + Status: StatusActive, + User: APIKeyAuthUserSnapshot{ + ID: 2, + Status: StatusActive, + Role: RoleUser, + Balance: 10, + Concurrency: 3, + }, + Group: &APIKeyAuthGroupSnapshot{ + ID: groupID, + Name: "openai", + Platform: PlatformOpenAI, + Status: StatusActive, + SubscriptionType: SubscriptionTypeStandard, + RateMultiplier: 1, + }, + }, + }) + + if err != nil { + t.Fatalf("expected stale snapshot to be ignored without error, got %v", err) + } + if ok { + t.Fatalf("expected v10 auth snapshot to be rejected after models_list_config was added") + } + if apiKey != nil { + t.Fatalf("expected no API key from stale snapshot, got %#v", apiKey) + } +} diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index f6155352..9aa2a52f 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -8,6 +8,7 @@ import ( ) type OpenAIMessagesDispatchModelConfig = domain.OpenAIMessagesDispatchModelConfig +type GroupModelsListConfig = domain.GroupModelsListConfig type Group struct { ID int64 @@ -61,6 +62,7 @@ type Group struct { RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini) DefaultMappedModel string MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig + ModelsListConfig GroupModelsListConfig // RPMLimit 分组级每分钟请求数上限(0 = 不限制)。 // 一旦设置即接管该分组用户的限流(覆盖用户级 rpm_limit),可被 user-group rpm_override 进一步覆盖。 diff --git a/backend/internal/service/group_models_list.go b/backend/internal/service/group_models_list.go new file mode 100644 index 00000000..b10de724 --- /dev/null +++ b/backend/internal/service/group_models_list.go @@ -0,0 +1,32 @@ +package service + +import "strings" + +func normalizeGroupModelsListConfig(cfg GroupModelsListConfig) GroupModelsListConfig { + out := GroupModelsListConfig{Enabled: cfg.Enabled} + if len(cfg.Models) == 0 { + return out + } + + seen := make(map[string]struct{}, len(cfg.Models)) + out.Models = make([]string, 0, len(cfg.Models)) + for _, model := range cfg.Models { + model = strings.TrimSpace(model) + if model == "" { + continue + } + if _, ok := seen[model]; ok { + continue + } + seen[model] = struct{}{} + out.Models = append(out.Models, model) + } + if len(out.Models) == 0 { + out.Models = nil + } + return out +} + +func (g *Group) CustomModelsListEnabled() bool { + return g != nil && g.ModelsListConfig.Enabled && len(g.ModelsListConfig.Models) > 0 +} diff --git a/backend/migrations/143_group_models_list_config.sql b/backend/migrations/143_group_models_list_config.sql new file mode 100644 index 00000000..67f27623 --- /dev/null +++ b/backend/migrations/143_group_models_list_config.sql @@ -0,0 +1,5 @@ +-- 分组级自定义 /v1/models 展示列表配置。 +-- 仅用于控制 GET /v1/models 的展示结果,不参与账号白名单、模型映射或网关调度。 + +ALTER TABLE groups + ADD COLUMN IF NOT EXISTS models_list_config JSONB NOT NULL DEFAULT '{}'::jsonb; diff --git a/frontend/src/api/admin/groups.ts b/frontend/src/api/admin/groups.ts index 6b94b799..b7846efd 100644 --- a/frontend/src/api/admin/groups.ts +++ b/frontend/src/api/admin/groups.ts @@ -76,6 +76,23 @@ export async function getById(id: number): Promise { return data } +/** + * Get candidate models for custom /v1/models list. + * id=0 returns platform default models for create flow. + */ +export async function getModelsListCandidates( + id: number, + platform?: GroupPlatform +): Promise { + const { data } = await apiClient.get<{ models: string[] }>( + `/admin/groups/${id}/models-list-candidates`, + { + params: platform ? { platform } : undefined + } + ) + return data.models || [] +} + /** * Create new group * @param groupData - Group data @@ -306,6 +323,7 @@ export const groupsAPI = { getAll, getByPlatform, getById, + getModelsListCandidates, create, update, delete: deleteGroup, diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 1d94fa29..6dfa0aa2 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -2179,6 +2179,12 @@ export default { finalPricePreview: 'Final per-image price preview', notConfigured: 'Not configured' }, + modelsList: { + title: 'Custom /v1/models Model List', + hint: 'Only changes the /v1/models response. Whitelist model calls and account routing are unchanged.', + loading: 'Loading model list...', + empty: 'No displayable models' + }, claudeCode: { title: 'Claude Code Client Restriction', tooltip: 'When enabled, this group only allows official Claude Code clients. Non-Claude Code requests will be rejected or fallback to the specified group.', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 8fa15e72..8a996412 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -2262,6 +2262,12 @@ export default { finalPricePreview: '最终单张价格预览', notConfigured: '未配置' }, + modelsList: { + title: '自定义 /v1/models 模型列表', + hint: '仅影响 /v1/models 展示结果,不影响白名单模型调用和账号调度。', + loading: '正在加载模型列表...', + empty: '暂无可展示模型' + }, claudeCode: { title: 'Claude Code 客户端限制', tooltip: diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 632e5108..eae5e455 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -548,11 +548,17 @@ export interface AdminGroup extends Group { // OpenAI Messages 调度配置(仅 openai 平台使用) default_mapped_model?: string messages_dispatch_model_config?: OpenAIMessagesDispatchModelConfig + models_list_config?: ModelsListConfig // 分组排序 sort_order: number } +export interface ModelsListConfig { + enabled: boolean + models: string[] +} + export interface ApiKey { id: number user_id: number @@ -632,6 +638,13 @@ export interface CreateGroupRequest { fallback_group_id_on_invalid_request?: number | null mcp_xml_inject?: boolean supported_model_scopes?: string[] + models_list_config?: ModelsListConfig + allow_messages_dispatch?: boolean + default_mapped_model?: string + messages_dispatch_model_config?: OpenAIMessagesDispatchModelConfig + model_routing?: Record | null + model_routing_enabled?: boolean + rpm_limit?: number require_oauth_only?: boolean require_privacy_set?: boolean // 从指定分组复制账号 @@ -660,6 +673,13 @@ export interface UpdateGroupRequest { fallback_group_id_on_invalid_request?: number | null mcp_xml_inject?: boolean supported_model_scopes?: string[] + models_list_config?: ModelsListConfig + allow_messages_dispatch?: boolean + default_mapped_model?: string + messages_dispatch_model_config?: OpenAIMessagesDispatchModelConfig + model_routing?: Record | null + model_routing_enabled?: boolean + rpm_limit?: number require_oauth_only?: boolean require_privacy_set?: boolean copy_accounts_from_group_ids?: number[] diff --git a/frontend/src/views/admin/GroupsView.vue b/frontend/src/views/admin/GroupsView.vue index ebb57bd9..0b583a09 100644 --- a/frontend/src/views/admin/GroupsView.vue +++ b/frontend/src/views/admin/GroupsView.vue @@ -69,7 +69,7 @@ {{ t("admin.groups.sortOrder") }} + +
+
+ + 已选 {{ createModelsListSelectedCount }} / + {{ createModelsListState.items.length }} + +
+ + +
+
+
+

+ {{ t("admin.groups.modelsList.loading") }} +

+

+ {{ t("admin.groups.modelsList.empty") }} +

+
+ + + {{ item.id }} + + + +
+
+
+ +
+
+
+ +

+ {{ t("admin.groups.modelsList.hint") }} +

+
+ +
+
+
+ + 已选 {{ editModelsListSelectedCount }} / + {{ editModelsListState.items.length }} + +
+ + +
+
+
+

+ {{ t("admin.groups.modelsList.loading") }} +

+

+ {{ t("admin.groups.modelsList.empty") }} +

+
+ + + {{ item.id }} + + + +
+
+
+
+
(null); const sortableGroups = ref([]); const createMessagesDispatchDefaults = createDefaultMessagesDispatchFormState(); const editMessagesDispatchDefaults = createDefaultMessagesDispatchFormState(); +const createModelsListState = reactive(createInitialModelsListState()); +const editModelsListState = reactive(createInitialModelsListState()); +const createModelsListLoading = ref(false); +const editModelsListLoading = ref(false); +const modelsListCandidatesTracker = createModelsListCandidatesTracker(); +const createModelsListSelectedCount = computed( + () => createModelsListState.items.filter((item) => item.selected).length, +); +const editModelsListSelectedCount = computed( + () => editModelsListState.items.filter((item) => item.selected).length, +); const createForm = reactive({ name: "", @@ -3335,6 +3561,52 @@ const removeEditRoutingRule = (rule: ModelRoutingRule) => { editModelRoutingRules.value.splice(index, 1); }; +const resetModelsListState = ( + state: typeof createModelsListState, + config?: Parameters[0], +) => { + const fresh = createInitialModelsListState(config); + state.enabled = fresh.enabled; + state.savedModels = fresh.savedModels; + state.items = fresh.items; +}; + +const loadModelsListCandidates = async ( + mode: "create" | "edit", + groupID: number, + platform: GroupPlatform, +) => { + const request = { mode, groupID, platform }; + const requestID = modelsListCandidatesTracker.next(request); + const state = mode === "create" ? createModelsListState : editModelsListState; + const loadingRef = mode === "create" ? createModelsListLoading : editModelsListLoading; + loadingRef.value = true; + try { + const models = await adminAPI.groups.getModelsListCandidates(groupID, platform); + if (!modelsListCandidatesTracker.isCurrent(requestID, request)) { + return; + } + setModelsListCandidates(state, models); + } catch (error) { + if (!modelsListCandidatesTracker.isCurrent(requestID, request)) { + return; + } + console.error("Error loading group models list candidates:", error); + } finally { + if (modelsListCandidatesTracker.isCurrent(requestID, request)) { + loadingRef.value = false; + } + } +}; + +const moveCreateModelsListItem = (fromIndex: number, toIndex: number) => { + moveModelsListItem(createModelsListState, fromIndex, toIndex); +}; + +const moveEditModelsListItem = (fromIndex: number, toIndex: number) => { + moveModelsListItem(editModelsListState, fromIndex, toIndex); +}; + // 将 UI 格式的路由规则转换为 API 格式 const convertRoutingRulesToApiFormat = ( rules: ModelRoutingRule[], @@ -3624,6 +3896,11 @@ const handleSort = (key: string, order: 'asc' | 'desc') => { loadGroups(); }; +const openCreateModal = () => { + showCreateModal.value = true; + loadModelsListCandidates("create", 0, createForm.platform); +}; + const closeCreateModal = () => { showCreateModal.value = false; createModelRoutingRules.value.forEach((rule) => { @@ -3654,6 +3931,8 @@ const closeCreateModal = () => { createForm.supported_model_scopes = ["claude", "gemini_text", "gemini_image"]; createForm.mcp_xml_inject = true; createForm.copy_accounts_from_group_ids = []; + createForm.rpm_limit = 0; + resetModelsListState(createModelsListState); createModelRoutingRules.value = []; }; @@ -3708,6 +3987,7 @@ const handleCreateGroup = async () => { model_routing: convertRoutingRulesToApiFormat( createModelRoutingRules.value, ), + models_list_config: buildModelsListConfig(createModelsListState), supported_model_scopes: normalizeSupportedModelScopesForPlatform( createForm.platform, createForm.supported_model_scopes, @@ -3794,10 +4074,12 @@ const handleEdit = async (group: AdminGroup) => { editForm.mcp_xml_inject = group.mcp_xml_inject ?? true; editForm.copy_accounts_from_group_ids = []; // 复制账号字段每次编辑时重置为空 editForm.rpm_limit = group.rpm_limit ?? 0; + resetModelsListState(editModelsListState, group.models_list_config); // 加载模型路由规则(异步加载账号名称) editModelRoutingRules.value = await convertApiFormatToRoutingRules( group.model_routing, ); + loadModelsListCandidates("edit", group.id, group.platform); showEditModal.value = true; }; @@ -3811,6 +4093,7 @@ const closeEditModal = () => { editModelRoutingRules.value = []; editForm.copy_accounts_from_group_ids = []; resetMessagesDispatchFormState(editForm); + resetModelsListState(editModelsListState); }; const handleUpdateGroup = async () => { @@ -3843,6 +4126,7 @@ const handleUpdateGroup = async () => { model_routing: convertRoutingRulesToApiFormat( editModelRoutingRules.value, ), + models_list_config: buildModelsListConfig(editModelsListState), supported_model_scopes: normalizeSupportedModelScopesForPlatform( editForm.platform, editForm.supported_model_scopes, @@ -3960,6 +4244,8 @@ watch( createForm.require_oauth_only = false; createForm.require_privacy_set = false; } + resetModelsListState(createModelsListState); + loadModelsListCandidates("create", 0, newVal); }, ); @@ -3976,6 +4262,10 @@ watch( editForm.require_oauth_only = false; editForm.require_privacy_set = false; } + if (editingGroup.value) { + resetModelsListState(editModelsListState, editForm.platform === editingGroup.value.platform ? editingGroup.value.models_list_config : undefined); + loadModelsListCandidates("edit", editingGroup.value.id, newVal); + } }, ); @@ -4049,6 +4339,7 @@ const saveSortOrder = async () => { onMounted(() => { loadGroups(); + loadModelsListCandidates("create", 0, createForm.platform); document.addEventListener("click", handleClickOutside); }); diff --git a/frontend/src/views/admin/__tests__/groupsModelsList.spec.ts b/frontend/src/views/admin/__tests__/groupsModelsList.spec.ts new file mode 100644 index 00000000..ae50c861 --- /dev/null +++ b/frontend/src/views/admin/__tests__/groupsModelsList.spec.ts @@ -0,0 +1,125 @@ +import { describe, expect, it } from "vitest"; + +import { + buildModelsListConfig, + createModelsListState, + hydrateModelsListState, + invertModelsListSelection, + moveModelsListItem, + selectAllModelsListItems, + setModelsListCandidates, + toggleModelsListItem, +} from "../groupsModelsList"; + +describe("groupsModelsList", () => { + it("selects all default candidates for a new disabled config", () => { + const state = createModelsListState(); + + setModelsListCandidates(state, ["gpt-5.5", "gpt-5.4"]); + + expect(state.enabled).toBe(false); + expect(state.items).toEqual([ + { id: "gpt-5.5", selected: true }, + { id: "gpt-5.4", selected: true }, + ]); + }); + + it("keeps saved selections and marks new candidates as unselected when editing", () => { + const state = createModelsListState({ + enabled: true, + models: ["gpt-5.5", "gpt-5.4"], + }); + + setModelsListCandidates(state, ["gpt-5.4", "legacy-gpt", "gpt-5.5"]); + + expect(state.enabled).toBe(true); + expect(state.items).toEqual([ + { id: "gpt-5.5", selected: true }, + { id: "gpt-5.4", selected: true }, + { id: "legacy-gpt", selected: false }, + ]); + }); + + it("preserves explicitly unselected saved candidates when candidates refresh", () => { + const state = createModelsListState({ + enabled: true, + models: ["gpt-5.5"], + }); + + setModelsListCandidates(state, ["gpt-5.5", "gpt-5.4"]); + + expect(state.items).toEqual([ + { id: "gpt-5.5", selected: true }, + { id: "gpt-5.4", selected: false }, + ]); + }); + + it("builds config with selected models in current display order", () => { + const state = hydrateModelsListState({ + enabled: true, + models: ["gpt-5.5", "gpt-5.4", "legacy-gpt"], + }, ["gpt-5.5", "gpt-5.4", "legacy-gpt"]); + + toggleModelsListItem(state, "legacy-gpt"); + moveModelsListItem(state, 1, 0); + + expect(buildModelsListConfig(state)).toEqual({ + enabled: true, + models: ["gpt-5.4", "gpt-5.5"], + }); + }); + + it("keeps selected models in payload even when disabled so reopening can restore choices", () => { + const state = hydrateModelsListState({ + enabled: false, + models: ["gpt-5.5"], + }, ["gpt-5.5", "gpt-5.4"]); + + expect(buildModelsListConfig(state)).toEqual({ + enabled: false, + models: ["gpt-5.5"], + }); + }); + + it("preserves saved models when candidates have not loaded yet", () => { + const state = createModelsListState({ + enabled: true, + models: ["gpt-5.5", "gpt-5.4"], + }); + + expect(buildModelsListConfig(state)).toEqual({ + enabled: true, + models: ["gpt-5.5", "gpt-5.4"], + }); + }); + + it("selects all candidate models from the toolbar action", () => { + const state = hydrateModelsListState({ + enabled: true, + models: ["gpt-5.5"], + }, ["gpt-5.5", "gpt-5.4", "gpt-5.4-mini"]); + + selectAllModelsListItems(state); + + expect(state.items).toEqual([ + { id: "gpt-5.5", selected: true }, + { id: "gpt-5.4", selected: true }, + { id: "gpt-5.4-mini", selected: true }, + ]); + }); + + it("inverts selected models from the toolbar action", () => { + const state = hydrateModelsListState({ + enabled: true, + models: ["gpt-5.5"], + }, ["gpt-5.5", "gpt-5.4", "gpt-5.4-mini"]); + + invertModelsListSelection(state); + + expect(state.items).toEqual([ + { id: "gpt-5.5", selected: false }, + { id: "gpt-5.4", selected: true }, + { id: "gpt-5.4-mini", selected: true }, + ]); + }); +}); diff --git a/frontend/src/views/admin/__tests__/groupsModelsListCandidates.spec.ts b/frontend/src/views/admin/__tests__/groupsModelsListCandidates.spec.ts new file mode 100644 index 00000000..ec292c63 --- /dev/null +++ b/frontend/src/views/admin/__tests__/groupsModelsListCandidates.spec.ts @@ -0,0 +1,65 @@ +import { describe, expect, it } from "vitest"; + +import { + createModelsListCandidatesTracker, +} from "../groupsModelsListCandidates"; + +describe("groupsModelsListCandidates", () => { + it("rejects stale candidate responses after a newer platform request starts", () => { + const tracker = createModelsListCandidatesTracker(); + const first = { + mode: "create" as const, + groupID: 0, + platform: "openai" as const, + }; + const second = { + mode: "create" as const, + groupID: 0, + platform: "anthropic" as const, + }; + + const firstID = tracker.next(first); + const secondID = tracker.next(second); + + expect(tracker.isCurrent(firstID, first)).toBe(false); + expect(tracker.isCurrent(secondID, second)).toBe(true); + }); + + it("rejects responses for a previous edit group even with the same platform", () => { + const tracker = createModelsListCandidatesTracker(); + const first = { + mode: "edit" as const, + groupID: 10, + platform: "openai" as const, + }; + const second = { + mode: "edit" as const, + groupID: 11, + platform: "openai" as const, + }; + + const firstID = tracker.next(first); + tracker.next(second); + + expect(tracker.isCurrent(firstID, first)).toBe(false); + }); + + it("tracks create and edit requests independently", () => { + const tracker = createModelsListCandidatesTracker(); + const editRequest = { + mode: "edit" as const, + groupID: 10, + platform: "openai" as const, + }; + const createRequest = { + mode: "create" as const, + groupID: 0, + platform: "anthropic" as const, + }; + + const editID = tracker.next(editRequest); + tracker.next(createRequest); + + expect(tracker.isCurrent(editID, editRequest)).toBe(true); + }); +}); diff --git a/frontend/src/views/admin/__tests__/groupsModelsListLayout.spec.ts b/frontend/src/views/admin/__tests__/groupsModelsListLayout.spec.ts new file mode 100644 index 00000000..6ac3d769 --- /dev/null +++ b/frontend/src/views/admin/__tests__/groupsModelsListLayout.spec.ts @@ -0,0 +1,19 @@ +import { readFileSync } from "node:fs"; +import { fileURLToPath } from "node:url"; +import { dirname, resolve } from "node:path"; + +import { describe, expect, it } from "vitest"; + +const currentDir = dirname(fileURLToPath(import.meta.url)); +const groupsViewSource = readFileSync( + resolve(currentDir, "../GroupsView.vue"), + "utf8", +); + +describe("groups models list layout", () => { + it("keeps the toolbar outside of the scrolling list content", () => { + expect(groupsViewSource).toContain("overflow-hidden rounded-lg border"); + expect(groupsViewSource).toContain("max-h-64 space-y-2 overflow-y-auto p-2"); + expect(groupsViewSource).not.toContain("sticky top-0"); + }); +}); diff --git a/frontend/src/views/admin/groupsModelsList.ts b/frontend/src/views/admin/groupsModelsList.ts new file mode 100644 index 00000000..790268fe --- /dev/null +++ b/frontend/src/views/admin/groupsModelsList.ts @@ -0,0 +1,121 @@ +export interface ModelsListConfig { + enabled: boolean + models: string[] +} + +export interface ModelsListItem { + id: string + selected: boolean +} + +export interface ModelsListState { + enabled: boolean + savedModels: string[] + items: ModelsListItem[] +} + +export const createModelsListState = ( + config?: Partial | null, +): ModelsListState => ({ + enabled: config?.enabled ?? false, + savedModels: normalizeModels(config?.models ?? []), + items: [], +}) + +export const hydrateModelsListState = ( + config: Partial | null | undefined, + candidates: string[], +): ModelsListState => { + const state = createModelsListState(config) + setModelsListCandidates(state, candidates) + return state +} + +export const setModelsListCandidates = ( + state: ModelsListState, + candidates: string[], +) => { + const normalizedCandidates = normalizeModels(candidates) + const currentSelected = new Set( + state.items.filter(item => item.selected).map(item => item.id), + ) + const currentKnown = new Set(state.items.map(item => item.id)) + const savedSelected = new Set(state.savedModels) + const hasExistingItems = state.items.length > 0 + const selectionOrder = normalizeModels([ + ...state.items.map(item => item.id), + ...state.savedModels, + ...normalizedCandidates, + ]) + + state.items = selectionOrder.map(id => { + const selected = hasExistingItems + ? currentSelected.has(id) + : state.savedModels.length > 0 + ? savedSelected.has(id) + : normalizedCandidates.includes(id) + + return { + id, + selected: selected && (currentKnown.has(id) || savedSelected.has(id) || state.savedModels.length === 0), + } + }) +} + +export const toggleModelsListItem = (state: ModelsListState, modelID: string) => { + const item = state.items.find(item => item.id === modelID) + if (item) { + item.selected = !item.selected + } +} + +export const selectAllModelsListItems = (state: ModelsListState) => { + state.items.forEach(item => { + item.selected = true + }) +} + +export const invertModelsListSelection = (state: ModelsListState) => { + state.items.forEach(item => { + item.selected = !item.selected + }) +} + +export const moveModelsListItem = ( + state: ModelsListState, + fromIndex: number, + toIndex: number, +) => { + if ( + fromIndex === toIndex || + fromIndex < 0 || + toIndex < 0 || + fromIndex >= state.items.length || + toIndex >= state.items.length + ) { + return + } + const [item] = state.items.splice(fromIndex, 1) + state.items.splice(toIndex, 0, item) +} + +export const buildModelsListConfig = (state: ModelsListState): ModelsListConfig => ({ + enabled: state.enabled, + models: state.items.length > 0 + ? state.items.filter(item => item.selected).map(item => item.id) + : [...state.savedModels], +}) + +const normalizeModels = (models: string[]): string[] => { + const seen = new Set() + const out: string[] = [] + for (const raw of models) { + const model = raw.trim() + if (!model || seen.has(model)) { + continue + } + seen.add(model) + out.push(model) + } + return out +} diff --git a/frontend/src/views/admin/groupsModelsListCandidates.ts b/frontend/src/views/admin/groupsModelsListCandidates.ts new file mode 100644 index 00000000..2c722af8 --- /dev/null +++ b/frontend/src/views/admin/groupsModelsListCandidates.ts @@ -0,0 +1,41 @@ +import type { GroupPlatform } from "@/types"; + +export type ModelsListCandidatesMode = "create" | "edit"; + +export interface ModelsListCandidatesRequest { + mode: ModelsListCandidatesMode; + groupID: number; + platform: GroupPlatform; +} + +export interface ModelsListCandidatesTracker { + next(request: ModelsListCandidatesRequest): number; + isCurrent(requestID: number, request: ModelsListCandidatesRequest): boolean; +} + +export const createModelsListCandidatesTracker = (): ModelsListCandidatesTracker => { + let currentRequestID = 0; + const currentByMode: Partial> = {}; + + return { + next(request) { + currentRequestID += 1; + currentByMode[request.mode] = { + id: currentRequestID, + request: { ...request }, + }; + return currentRequestID; + }, + isCurrent(requestID, request) { + const current = currentByMode[request.mode]; + return ( + current?.id === requestID && + current.request.groupID === request.groupID && + current.request.platform === request.platform + ); + }, + }; +};