sub2api/backend/internal/handler/gateway_models_test.go
win f519a02ec9 chore: merge upstream Wei-Shaw/sub2api v0.1.132
Conflicts resolved (preserving fork customizations):
- config.go: keep NodeTLSProxy + add upstream OpenAIHTTP2
- gateway_service.go: NewGatewayService now takes both rpmTokenBucketSvc
  (local) and userPlatformQuotaRepo (upstream)
- wire_gen.go: wire both new args into the call site
- http_upstream.go: drop redundant settings re-assignment; keep proxy
  URL log redaction
- http_upstream_test.go: adopt upstream's explicit-0-disables semantics;
  keep 600s default constant in nil-cfg fallback test
- user_handler_test.go / gateway_record_usage_test.go: pick up new
  userPlatformQuotaRepo nil parameter

Also updated test stubs (windsurf_google_login_test.go,
windsurf_tier_access_service_test.go, gateway_models_test.go) for new
SetModelRateLimit variadic signature and the extra NewGatewayService arg.

Upstream highlights: OpenAI embeddings gateway, user x platform USD
quota, content-moderation risk thresholds, OAuth 401 credentials
no-overwrite fix, HTTP/2 OpenAI upstream config, pool retry status code
configurability, long-context cache pricing multipliers.
2026-05-29 07:21:32 +08:00

402 lines
10 KiB
Go

package handler
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// gatewayModelsAccountRepoStub is a test double for service.AccountRepository.
// It embeds the interface so that only the methods these tests exercise need
// an implementation; the embedded value is never set in this file, so calling
// any other repository method on the stub would panic on the nil interface.
type gatewayModelsAccountRepoStub struct {
	service.AccountRepository
	// byGroup holds the accounts ListSchedulableByGroupID returns per group ID.
	byGroup map[int64][]service.Account
}
type (
	// gatewayModelsResponseForTest decodes the part of the /v1/models JSON
	// body written by GatewayHandler.Models that these tests inspect: the
	// top-level object tag and the list of model entries.
	gatewayModelsResponseForTest struct {
		Object string                    `json:"object"`
		Data   []gatewayModelItemForTest `json:"data"`
	}

	// gatewayModelItemForTest decodes one model entry. CreatedAt is included
	// so tests can assert that it is absent from OpenAI-shaped responses.
	gatewayModelItemForTest struct {
		ID        string `json:"id"`
		Object    string `json:"object"`
		Created   int64  `json:"created"`
		OwnedBy   string `json:"owned_by"`
		CreatedAt string `json:"created_at"`
	}
)
// ListSchedulableByGroupID returns a copy of the accounts registered for
// groupID, so callers may mutate the result without affecting the stub.
// Unknown groups yield a nil slice; the error is always nil.
func (s *gatewayModelsAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
	if stored, found := s.byGroup[groupID]; found {
		// Appending onto a fresh, exactly-sized slice detaches the result
		// from the stub's backing array (and is non-nil even when empty).
		return append(make([]service.Account, 0, len(stored)), stored...), nil
	}
	return nil, nil
}
// newGatewayModelsHandlerForTest builds a GatewayHandler whose gateway
// service is wired with repo as its account repository and nil for every
// other NewGatewayService dependency.
// NOTE(review): paths that touch any of the nil dependencies will presumably
// panic, so this helper only suits handler paths that read from repo.
func newGatewayModelsHandlerForTest(repo service.AccountRepository) *GatewayHandler {
	// The positional nils must track NewGatewayService's parameter list.
	svc := service.NewGatewayService(
		repo,
		nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
		nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
	)
	return &GatewayHandler{gatewayService: svc}
}
// TestGatewayModels_GeminiGroupFallsBackToGeminiModels checks that a Gemini
// group whose only account has no model_mapping still lists Gemini models and
// does not include Claude model IDs.
func TestGatewayModels_GeminiGroupFallsBackToGeminiModels(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const groupID int64 = 20
	gw := newGatewayModelsHandlerForTest(&gatewayModelsAccountRepoStub{
		byGroup: map[int64][]service.Account{
			groupID: {{ID: 1, Platform: service.PlatformGemini}},
		},
	})
	apiKey := &service.APIKey{
		Group: &service.Group{ID: groupID, Platform: service.PlatformGemini},
	}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
	ginCtx.Set(string(middleware2.ContextKeyAPIKey), apiKey)

	gw.Models(ginCtx)

	require.Equal(t, http.StatusOK, recorder.Code)
	var resp gatewayModelsResponseForTest
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
	require.Equal(t, "list", resp.Object)
	ids := modelIDsForTest(resp.Data)
	require.Contains(t, ids, "gemini-2.5-flash")
	require.NotContains(t, ids, "claude-sonnet-4-6")
}
// TestGatewayModels_GeminiGroupFiltersMappedModelsByPlatform checks that a
// Gemini group holding both an Anthropic and a Gemini account lists only the
// Gemini account's model_mapping entries.
func TestGatewayModels_GeminiGroupFiltersMappedModelsByPlatform(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const groupID int64 = 21
	gw := newGatewayModelsHandlerForTest(&gatewayModelsAccountRepoStub{
		byGroup: map[int64][]service.Account{
			groupID: {
				{
					ID:       1,
					Platform: service.PlatformAnthropic,
					Credentials: map[string]any{
						"model_mapping": map[string]any{"claude-sonnet-4-6": "claude-sonnet-4-6"},
					},
				},
				{
					ID:       2,
					Platform: service.PlatformGemini,
					Credentials: map[string]any{
						"model_mapping": map[string]any{"gemini-2.5-flash": "gemini-2.5-flash"},
					},
				},
			},
		},
	})
	apiKey := &service.APIKey{
		Group: &service.Group{ID: groupID, Platform: service.PlatformGemini},
	}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
	ginCtx.Set(string(middleware2.ContextKeyAPIKey), apiKey)

	gw.Models(ginCtx)

	require.Equal(t, http.StatusOK, recorder.Code)
	var resp gatewayModelsResponseForTest
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
	require.Equal(t, []string{"gemini-2.5-flash"}, modelIDsForTest(resp.Data))
}
// TestGatewayModels_CustomModelsListDisabledKeepsOriginalModels checks that a
// disabled ModelsListConfig is ignored: both mapped OpenAI models are listed
// even though the config selects only one of them.
func TestGatewayModels_CustomModelsListDisabledKeepsOriginalModels(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const groupID int64 = 22
	gw := newGatewayModelsHandlerForTest(&gatewayModelsAccountRepoStub{
		byGroup: map[int64][]service.Account{
			groupID: {{
				ID:       1,
				Platform: service.PlatformOpenAI,
				Credentials: map[string]any{
					"model_mapping": map[string]any{
						"gpt-5.5": "gpt-5.5",
						"gpt-5.4": "gpt-5.4",
					},
				},
			}},
		},
	})
	group := &service.Group{
		ID:       groupID,
		Platform: service.PlatformOpenAI,
		ModelsListConfig: service.GroupModelsListConfig{
			Enabled: false,
			Models:  []string{"gpt-5.5"},
		},
	}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
	ginCtx.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{Group: group})

	gw.Models(ginCtx)

	require.Equal(t, http.StatusOK, recorder.Code)
	var resp gatewayModelsResponseForTest
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
	require.Equal(t, []string{"gpt-5.4", "gpt-5.5"}, modelIDsForTest(resp.Data))
}
// TestGatewayModels_CustomModelsListFiltersAndOrdersMappedModels checks that
// an enabled ModelsListConfig keeps only the selected models that are actually
// mapped, in the configured order, dropping selections with no mapping.
func TestGatewayModels_CustomModelsListFiltersAndOrdersMappedModels(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const groupID int64 = 23
	gw := newGatewayModelsHandlerForTest(&gatewayModelsAccountRepoStub{
		byGroup: map[int64][]service.Account{
			groupID: {{
				ID:       1,
				Platform: service.PlatformOpenAI,
				Credentials: map[string]any{
					"model_mapping": map[string]any{
						"gpt-5.4":         "gpt-5.4",
						"gpt-5.5":         "gpt-5.5",
						"legacy-gpt-2024": "legacy-gpt-2024",
					},
				},
			}},
		},
	})
	group := &service.Group{
		ID:       groupID,
		Platform: service.PlatformOpenAI,
		ModelsListConfig: service.GroupModelsListConfig{
			Enabled: true,
			Models:  []string{"gpt-5.5", "missing-model", "gpt-5.4"},
		},
	}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
	ginCtx.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{Group: group})

	gw.Models(ginCtx)

	require.Equal(t, http.StatusOK, recorder.Code)
	var resp gatewayModelsResponseForTest
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
	require.Equal(t, []string{"gpt-5.5", "gpt-5.4"}, modelIDsForTest(resp.Data))
}
// TestGatewayModels_CustomModelsListKeepsConcreteModelAllowedByWildcardMapping
// checks that a selected concrete model ID is kept when the account only maps
// it through a wildcard pattern ("claude-*").
func TestGatewayModels_CustomModelsListKeepsConcreteModelAllowedByWildcardMapping(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const groupID int64 = 26
	gw := newGatewayModelsHandlerForTest(&gatewayModelsAccountRepoStub{
		byGroup: map[int64][]service.Account{
			groupID: {{
				ID:       1,
				Platform: service.PlatformAnthropic,
				Credentials: map[string]any{
					"model_mapping": map[string]any{"claude-*": "claude-sonnet-4-6"},
				},
			}},
		},
	})
	group := &service.Group{
		ID:       groupID,
		Platform: service.PlatformAnthropic,
		ModelsListConfig: service.GroupModelsListConfig{
			Enabled: true,
			Models:  []string{"claude-sonnet-4-6"},
		},
	}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
	ginCtx.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{Group: group})

	gw.Models(ginCtx)

	require.Equal(t, http.StatusOK, recorder.Code)
	var resp gatewayModelsResponseForTest
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
	require.Equal(t, []string{"claude-sonnet-4-6"}, modelIDsForTest(resp.Data))
}
// TestGatewayModels_CustomModelsListCanReturnEmptyWhenSelectionsUnavailable
// checks that when none of the selected models is mapped by the group's
// accounts, the response lists no models.
func TestGatewayModels_CustomModelsListCanReturnEmptyWhenSelectionsUnavailable(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const groupID int64 = 24
	gw := newGatewayModelsHandlerForTest(&gatewayModelsAccountRepoStub{
		byGroup: map[int64][]service.Account{
			groupID: {{
				ID:       1,
				Platform: service.PlatformOpenAI,
				Credentials: map[string]any{
					"model_mapping": map[string]any{"gpt-5.4": "gpt-5.4"},
				},
			}},
		},
	})
	group := &service.Group{
		ID:       groupID,
		Platform: service.PlatformOpenAI,
		ModelsListConfig: service.GroupModelsListConfig{
			Enabled: true,
			Models:  []string{"gpt-5.5"},
		},
	}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
	ginCtx.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{Group: group})

	gw.Models(ginCtx)

	require.Equal(t, http.StatusOK, recorder.Code)
	var resp gatewayModelsResponseForTest
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
	require.Empty(t, modelIDsForTest(resp.Data))
}
// TestGatewayModels_CustomModelsListFiltersDefaultFallbackModels checks that
// with no model_mapping on the OpenAI account, the custom list filters the
// default fallback models: "legacy-gpt-2024" is dropped and the configured
// order of the remaining selections is kept.
func TestGatewayModels_CustomModelsListFiltersDefaultFallbackModels(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const groupID int64 = 25
	gw := newGatewayModelsHandlerForTest(&gatewayModelsAccountRepoStub{
		byGroup: map[int64][]service.Account{
			groupID: {{ID: 1, Platform: service.PlatformOpenAI}},
		},
	})
	group := &service.Group{
		ID:       groupID,
		Platform: service.PlatformOpenAI,
		ModelsListConfig: service.GroupModelsListConfig{
			Enabled: true,
			Models:  []string{"gpt-5.5", "legacy-gpt-2024", "gpt-5.4"},
		},
	}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
	ginCtx.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{Group: group})

	gw.Models(ginCtx)

	require.Equal(t, http.StatusOK, recorder.Code)
	var resp gatewayModelsResponseForTest
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
	require.Equal(t, []string{"gpt-5.5", "gpt-5.4"}, modelIDsForTest(resp.Data))
}
// TestGatewayModels_OpenAICustomModelsListKeepsOpenAIResponseShapeForDefaultFallback
// checks that entries produced from the default fallback for an OpenAI group
// use the OpenAI item shape: object "model", a non-zero created timestamp,
// owned_by "openai", and no created_at field.
func TestGatewayModels_OpenAICustomModelsListKeepsOpenAIResponseShapeForDefaultFallback(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const groupID int64 = 27
	gw := newGatewayModelsHandlerForTest(&gatewayModelsAccountRepoStub{
		byGroup: map[int64][]service.Account{
			groupID: {{ID: 1, Platform: service.PlatformOpenAI}},
		},
	})
	group := &service.Group{
		ID:       groupID,
		Platform: service.PlatformOpenAI,
		ModelsListConfig: service.GroupModelsListConfig{
			Enabled: true,
			Models:  []string{"gpt-5.5", "gpt-5.4"},
		},
	}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
	ginCtx.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{Group: group})

	gw.Models(ginCtx)

	require.Equal(t, http.StatusOK, recorder.Code)
	var resp gatewayModelsResponseForTest
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
	// The ID assertion pins len(resp.Data) == 2, so indexing [0] below is safe.
	require.Equal(t, []string{"gpt-5.5", "gpt-5.4"}, modelIDsForTest(resp.Data))
	first := resp.Data[0]
	require.Equal(t, "model", first.Object)
	require.NotZero(t, first.Created)
	require.Equal(t, "openai", first.OwnedBy)
	require.Empty(t, first.CreatedAt)
}
// modelIDsForTest projects decoded model entries onto their IDs, preserving
// order. It always returns a non-nil slice, even for empty input.
func modelIDsForTest(models []gatewayModelItemForTest) []string {
	ids := make([]string, len(models))
	for i := range models {
		ids[i] = models[i].ID
	}
	return ids
}