package handler import ( "context" "encoding/json" "net/http" "net/http/httptest" "testing" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) type gatewayModelsAccountRepoStub struct { service.AccountRepository byGroup map[int64][]service.Account } type gatewayModelsResponseForTest struct { Object string `json:"object"` Data []gatewayModelItemForTest `json:"data"` } type gatewayModelItemForTest struct { ID string `json:"id"` } func (s *gatewayModelsAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) { accounts, ok := s.byGroup[groupID] if !ok { return nil, nil } out := make([]service.Account, len(accounts)) copy(out, accounts) return out, nil } func newGatewayModelsHandlerForTest(repo service.AccountRepository) *GatewayHandler { return &GatewayHandler{ gatewayService: service.NewGatewayService( repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, ), } } func TestGatewayModels_GeminiGroupFallsBackToGeminiModels(t *testing.T) { gin.SetMode(gin.TestMode) groupID := int64(20) h := newGatewayModelsHandlerForTest( &gatewayModelsAccountRepoStub{ byGroup: map[int64][]service.Account{ groupID: { {ID: 1, Platform: service.PlatformGemini}, }, }, }, ) rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ Group: &service.Group{ID: groupID, Platform: service.PlatformGemini}, }) h.Models(c) require.Equal(t, http.StatusOK, rec.Code) var got gatewayModelsResponseForTest require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) require.Equal(t, "list", got.Object) require.Contains(t, modelIDsForTest(got.Data), "gemini-2.5-flash") require.NotContains(t, modelIDsForTest(got.Data), "claude-sonnet-4-6") } func TestGatewayModels_GeminiGroupFiltersMappedModelsByPlatform(t *testing.T) { gin.SetMode(gin.TestMode) groupID := int64(21) h := newGatewayModelsHandlerForTest( &gatewayModelsAccountRepoStub{ byGroup: map[int64][]service.Account{ groupID: { { ID: 1, Platform: service.PlatformAnthropic, Credentials: map[string]any{ "model_mapping": map[string]any{ "claude-sonnet-4-6": "claude-sonnet-4-6", }, }, }, { ID: 2, Platform: service.PlatformGemini, Credentials: map[string]any{ "model_mapping": map[string]any{ "gemini-2.5-flash": "gemini-2.5-flash", }, }, }, }, }, }, ) rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ Group: &service.Group{ID: groupID, Platform: service.PlatformGemini}, }) h.Models(c) require.Equal(t, http.StatusOK, rec.Code) var got gatewayModelsResponseForTest require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) require.Equal(t, []string{"gemini-2.5-flash"}, modelIDsForTest(got.Data)) } func modelIDsForTest(models []gatewayModelItemForTest) []string { ids := make([]string, 0, len(models)) for _, model := range models { ids = append(ids, model.ID) } return ids }