package handler import ( "context" "encoding/json" "net/http" "net/http/httptest" "testing" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) type gatewayModelsAccountRepoStub struct { service.AccountRepository byGroup map[int64][]service.Account } type gatewayModelsResponseForTest struct { Object string `json:"object"` Data []gatewayModelItemForTest `json:"data"` } type gatewayModelItemForTest struct { ID string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` OwnedBy string `json:"owned_by"` CreatedAt string `json:"created_at"` } func (s *gatewayModelsAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) { accounts, ok := s.byGroup[groupID] if !ok { return nil, nil } out := make([]service.Account, len(accounts)) copy(out, accounts) return out, nil } func newGatewayModelsHandlerForTest(repo service.AccountRepository) *GatewayHandler { return &GatewayHandler{ gatewayService: service.NewGatewayService( repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, ), } } func TestGatewayModels_GeminiGroupFallsBackToGeminiModels(t *testing.T) { gin.SetMode(gin.TestMode) groupID := int64(20) h := newGatewayModelsHandlerForTest( &gatewayModelsAccountRepoStub{ byGroup: map[int64][]service.Account{ groupID: { {ID: 1, Platform: service.PlatformGemini}, }, }, }, ) rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ Group: &service.Group{ID: groupID, Platform: service.PlatformGemini}, }) h.Models(c) require.Equal(t, http.StatusOK, rec.Code) var got gatewayModelsResponseForTest require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) require.Equal(t, "list", got.Object) require.Contains(t, modelIDsForTest(got.Data), "gemini-2.5-flash") require.NotContains(t, modelIDsForTest(got.Data), "claude-sonnet-4-6") } func TestGatewayModels_GeminiGroupFiltersMappedModelsByPlatform(t *testing.T) { gin.SetMode(gin.TestMode) groupID := int64(21) h := newGatewayModelsHandlerForTest( &gatewayModelsAccountRepoStub{ byGroup: map[int64][]service.Account{ groupID: { { ID: 1, Platform: service.PlatformAnthropic, Credentials: map[string]any{ "model_mapping": map[string]any{ "claude-sonnet-4-6": "claude-sonnet-4-6", }, }, }, { ID: 2, Platform: service.PlatformGemini, Credentials: map[string]any{ "model_mapping": map[string]any{ "gemini-2.5-flash": "gemini-2.5-flash", }, }, }, }, }, }, ) rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ Group: &service.Group{ID: groupID, Platform: service.PlatformGemini}, }) h.Models(c) require.Equal(t, http.StatusOK, rec.Code) var got gatewayModelsResponseForTest require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) require.Equal(t, []string{"gemini-2.5-flash"}, modelIDsForTest(got.Data)) } func TestGatewayModels_CustomModelsListDisabledKeepsOriginalModels(t *testing.T) { gin.SetMode(gin.TestMode) groupID := int64(22) h := newGatewayModelsHandlerForTest( &gatewayModelsAccountRepoStub{ byGroup: map[int64][]service.Account{ groupID: { { ID: 1, Platform: service.PlatformOpenAI, Credentials: map[string]any{ "model_mapping": map[string]any{ "gpt-5.5": "gpt-5.5", "gpt-5.4": "gpt-5.4", }, }, }, }, }, }, ) rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ Group: &service.Group{ ID: groupID, Platform: service.PlatformOpenAI, ModelsListConfig: service.GroupModelsListConfig{ Enabled: false, Models: []string{"gpt-5.5"}, }, }, }) h.Models(c) require.Equal(t, http.StatusOK, rec.Code) var got gatewayModelsResponseForTest require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) require.Equal(t, []string{"gpt-5.4", "gpt-5.5"}, modelIDsForTest(got.Data)) } func TestGatewayModels_CustomModelsListFiltersAndOrdersMappedModels(t *testing.T) { gin.SetMode(gin.TestMode) groupID := int64(23) h := newGatewayModelsHandlerForTest( &gatewayModelsAccountRepoStub{ byGroup: map[int64][]service.Account{ groupID: { { ID: 1, Platform: service.PlatformOpenAI, Credentials: map[string]any{ "model_mapping": map[string]any{ "gpt-5.4": "gpt-5.4", "gpt-5.5": "gpt-5.5", "legacy-gpt-2024": "legacy-gpt-2024", }, }, }, }, }, }, ) rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ Group: &service.Group{ ID: groupID, Platform: service.PlatformOpenAI, ModelsListConfig: service.GroupModelsListConfig{ Enabled: true, Models: []string{"gpt-5.5", "missing-model", "gpt-5.4"}, }, }, }) h.Models(c) require.Equal(t, http.StatusOK, rec.Code) var got gatewayModelsResponseForTest require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) require.Equal(t, []string{"gpt-5.5", "gpt-5.4"}, modelIDsForTest(got.Data)) } func TestGatewayModels_CustomModelsListKeepsConcreteModelAllowedByWildcardMapping(t *testing.T) { gin.SetMode(gin.TestMode) groupID := int64(26) h := newGatewayModelsHandlerForTest( &gatewayModelsAccountRepoStub{ byGroup: map[int64][]service.Account{ groupID: { { ID: 1, Platform: service.PlatformAnthropic, Credentials: map[string]any{ "model_mapping": map[string]any{ "claude-*": "claude-sonnet-4-6", }, }, }, }, }, }, ) rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ Group: &service.Group{ ID: groupID, Platform: service.PlatformAnthropic, ModelsListConfig: service.GroupModelsListConfig{ Enabled: true, Models: []string{"claude-sonnet-4-6"}, }, }, }) h.Models(c) require.Equal(t, http.StatusOK, rec.Code) var got gatewayModelsResponseForTest require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) require.Equal(t, []string{"claude-sonnet-4-6"}, modelIDsForTest(got.Data)) } func TestGatewayModels_CustomModelsListCanReturnEmptyWhenSelectionsUnavailable(t *testing.T) { gin.SetMode(gin.TestMode) groupID := int64(24) h := newGatewayModelsHandlerForTest( &gatewayModelsAccountRepoStub{ byGroup: map[int64][]service.Account{ groupID: { { ID: 1, Platform: service.PlatformOpenAI, Credentials: map[string]any{ "model_mapping": map[string]any{ "gpt-5.4": "gpt-5.4", }, }, }, }, }, }, ) rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ Group: &service.Group{ ID: groupID, Platform: service.PlatformOpenAI, ModelsListConfig: service.GroupModelsListConfig{ Enabled: true, Models: []string{"gpt-5.5"}, }, }, }) h.Models(c) require.Equal(t, http.StatusOK, rec.Code) var got gatewayModelsResponseForTest require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) require.Empty(t, modelIDsForTest(got.Data)) } func TestGatewayModels_CustomModelsListFiltersDefaultFallbackModels(t *testing.T) { gin.SetMode(gin.TestMode) groupID := int64(25) h := newGatewayModelsHandlerForTest( &gatewayModelsAccountRepoStub{ byGroup: map[int64][]service.Account{ groupID: { {ID: 1, Platform: service.PlatformOpenAI}, }, }, }, ) rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ Group: &service.Group{ ID: groupID, Platform: service.PlatformOpenAI, ModelsListConfig: service.GroupModelsListConfig{ Enabled: true, Models: []string{"gpt-5.5", "legacy-gpt-2024", "gpt-5.4"}, }, }, }) h.Models(c) require.Equal(t, http.StatusOK, rec.Code) var got gatewayModelsResponseForTest require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) require.Equal(t, []string{"gpt-5.5", "gpt-5.4"}, modelIDsForTest(got.Data)) } func TestGatewayModels_OpenAICustomModelsListKeepsOpenAIResponseShapeForDefaultFallback(t *testing.T) { gin.SetMode(gin.TestMode) groupID := int64(27) h := newGatewayModelsHandlerForTest( &gatewayModelsAccountRepoStub{ byGroup: map[int64][]service.Account{ groupID: { {ID: 1, Platform: service.PlatformOpenAI}, }, }, }, ) rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ Group: &service.Group{ ID: groupID, Platform: service.PlatformOpenAI, ModelsListConfig: service.GroupModelsListConfig{ Enabled: true, Models: []string{"gpt-5.5", "gpt-5.4"}, }, }, }) h.Models(c) require.Equal(t, http.StatusOK, rec.Code) var got gatewayModelsResponseForTest require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) require.Equal(t, []string{"gpt-5.5", "gpt-5.4"}, modelIDsForTest(got.Data)) require.Equal(t, "model", got.Data[0].Object) require.NotZero(t, got.Data[0].Created) require.Equal(t, "openai", got.Data[0].OwnedBy) require.Empty(t, got.Data[0].CreatedAt) } func modelIDsForTest(models []gatewayModelItemForTest) []string { ids := make([]string, 0, len(models)) for _, model := range models { ids = append(ids, model.ID) } return ids }