diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 65836a7e..4fe89615 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -18,6 +18,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -946,8 +947,8 @@ func (h *GatewayHandler) Models(c *gin.Context) { platform = forcedPlatform } - // Get available models from account configurations (without platform filter) - availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "") + // Get available models from account configurations for the selected group platform. + availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, platform) if len(availableModels) > 0 { // Build model list from whitelist @@ -968,7 +969,7 @@ func (h *GatewayHandler) Models(c *gin.Context) { } // Fallback to default models - if platform == "openai" { + if platform == service.PlatformOpenAI { c.JSON(http.StatusOK, gin.H{ "object": "list", "data": openai.DefaultModels, @@ -976,6 +977,14 @@ func (h *GatewayHandler) Models(c *gin.Context) { return } + if platform == service.PlatformGemini { + c.JSON(http.StatusOK, gin.H{ + "object": "list", + "data": geminicli.DefaultModels, + }) + return + } + c.JSON(http.StatusOK, gin.H{ "object": "list", "data": claude.DefaultModels, diff --git a/backend/internal/handler/gateway_models_test.go b/backend/internal/handler/gateway_models_test.go new file mode 100644 index 00000000..af52ae23 --- /dev/null +++ b/backend/internal/handler/gateway_models_test.go @@ -0,0 +1,136 @@ +package handler + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type gatewayModelsAccountRepoStub struct { + service.AccountRepository + + byGroup map[int64][]service.Account +} + +type gatewayModelsResponseForTest struct { + Object string `json:"object"` + Data []gatewayModelItemForTest `json:"data"` +} + +type gatewayModelItemForTest struct { + ID string `json:"id"` +} + +func (s *gatewayModelsAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) { + accounts, ok := s.byGroup[groupID] + if !ok { + return nil, nil + } + out := make([]service.Account, len(accounts)) + copy(out, accounts) + return out, nil +} + +func newGatewayModelsHandlerForTest(repo service.AccountRepository) *GatewayHandler { + return &GatewayHandler{ + gatewayService: service.NewGatewayService( + repo, + nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + ), + } +} + +func TestGatewayModels_GeminiGroupFallsBackToGeminiModels(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(20) + h := newGatewayModelsHandlerForTest( + &gatewayModelsAccountRepoStub{ + byGroup: map[int64][]service.Account{ + groupID: { + {ID: 1, Platform: service.PlatformGemini}, + }, + }, + }, + ) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + Group: &service.Group{ID: groupID, Platform: service.PlatformGemini}, + }) + + h.Models(c) + + require.Equal(t, http.StatusOK, rec.Code) + + var got gatewayModelsResponseForTest + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, "list", got.Object) + require.Contains(t, modelIDsForTest(got.Data), "gemini-2.5-flash") + require.NotContains(t, modelIDsForTest(got.Data), "claude-sonnet-4-6") +} + +func TestGatewayModels_GeminiGroupFiltersMappedModelsByPlatform(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(21) + h := newGatewayModelsHandlerForTest( + &gatewayModelsAccountRepoStub{ + byGroup: map[int64][]service.Account{ + groupID: { + { + ID: 1, + Platform: service.PlatformAnthropic, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-sonnet-4-6": "claude-sonnet-4-6", + }, + }, + }, + { + ID: 2, + Platform: service.PlatformGemini, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-2.5-flash": "gemini-2.5-flash", + }, + }, + }, + }, + }, + }, + ) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + Group: &service.Group{ID: groupID, Platform: service.PlatformGemini}, + }) + + h.Models(c) + + require.Equal(t, http.StatusOK, rec.Code) + + var got gatewayModelsResponseForTest + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, []string{"gemini-2.5-flash"}, modelIDsForTest(got.Data)) +} + +func modelIDsForTest(models []gatewayModelItemForTest) []string { + ids := make([]string, 0, len(models)) + for _, model := range models { + ids = append(ids, model.ID) + } + return ids +}