feat(gateway): add web search emulation for Anthropic API Key accounts
Inject web search capability for Claude Console (API Key) accounts that don't natively support Anthropic's web_search tool. When a pure web_search request is detected, the gateway calls Brave Search or Tavily API directly and constructs an Anthropic-protocol-compliant SSE/JSON response without forwarding to upstream. Backend: - New `pkg/websearch/` SDK: Brave and Tavily provider implementations with io.LimitReader, proxy support, and Redis-based quota tracking (Lua atomic INCR + TTL, DECR rollback on failure) - Global config via `settings.web_search_emulation_config` (JSON) with in-process cache + singleflight, input validation, API key merge on save, and sanitized API responses - Channel-level toggle via `channels.features_config` JSONB column (DB migration 101) - Account-level toggle via `accounts.extra.web_search_emulation` - Request interception in `Forward()` with SSE streaming response construction using json.Marshal (no manual string concatenation) - Manager hot-reload: `RebuildWebSearchManager()` called on config save and startup via `SetWebSearchRedisClient()` - 70 unit tests covering providers, manager, config validation, sanitization, tool detection, query extraction, and response building Frontend: - Settings → Gateway tab: Web Search Emulation config card with global toggle, provider list (add/remove, API key, priority, quota, proxy) - Channels → Anthropic tab: web search emulation toggle with global state linkage (disabled when global off) - Account Create/Edit modals: web search emulation toggle for API Key type with Toggle component - Full i18n coverage (zh + en)
This commit is contained in:
parent
c738cfec93
commit
1b53ffcac7
@ -34,6 +34,7 @@ type createChannelRequest struct {
|
|||||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||||
RestrictModels bool `json:"restrict_models"`
|
RestrictModels bool `json:"restrict_models"`
|
||||||
Features string `json:"features"`
|
Features string `json:"features"`
|
||||||
|
FeaturesConfig map[string]any `json:"features_config"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type updateChannelRequest struct {
|
type updateChannelRequest struct {
|
||||||
@ -46,6 +47,7 @@ type updateChannelRequest struct {
|
|||||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||||
RestrictModels *bool `json:"restrict_models"`
|
RestrictModels *bool `json:"restrict_models"`
|
||||||
Features *string `json:"features"`
|
Features *string `json:"features"`
|
||||||
|
FeaturesConfig map[string]any `json:"features_config"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type channelModelPricingRequest struct {
|
type channelModelPricingRequest struct {
|
||||||
@ -81,6 +83,7 @@ type channelResponse struct {
|
|||||||
BillingModelSource string `json:"billing_model_source"`
|
BillingModelSource string `json:"billing_model_source"`
|
||||||
RestrictModels bool `json:"restrict_models"`
|
RestrictModels bool `json:"restrict_models"`
|
||||||
Features string `json:"features"`
|
Features string `json:"features"`
|
||||||
|
FeaturesConfig map[string]any `json:"features_config"`
|
||||||
GroupIDs []int64 `json:"group_ids"`
|
GroupIDs []int64 `json:"group_ids"`
|
||||||
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
|
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
|
||||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||||
@ -126,6 +129,7 @@ func channelToResponse(ch *service.Channel) *channelResponse {
|
|||||||
Status: ch.Status,
|
Status: ch.Status,
|
||||||
RestrictModels: ch.RestrictModels,
|
RestrictModels: ch.RestrictModels,
|
||||||
Features: ch.Features,
|
Features: ch.Features,
|
||||||
|
FeaturesConfig: ch.FeaturesConfig,
|
||||||
GroupIDs: ch.GroupIDs,
|
GroupIDs: ch.GroupIDs,
|
||||||
ModelMapping: ch.ModelMapping,
|
ModelMapping: ch.ModelMapping,
|
||||||
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||||
@ -305,6 +309,7 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
|||||||
BillingModelSource: req.BillingModelSource,
|
BillingModelSource: req.BillingModelSource,
|
||||||
RestrictModels: req.RestrictModels,
|
RestrictModels: req.RestrictModels,
|
||||||
Features: req.Features,
|
Features: req.Features,
|
||||||
|
FeaturesConfig: req.FeaturesConfig,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
@ -338,6 +343,7 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
|||||||
BillingModelSource: req.BillingModelSource,
|
BillingModelSource: req.BillingModelSource,
|
||||||
RestrictModels: req.RestrictModels,
|
RestrictModels: req.RestrictModels,
|
||||||
Features: req.Features,
|
Features: req.Features,
|
||||||
|
FeaturesConfig: req.FeaturesConfig,
|
||||||
}
|
}
|
||||||
if req.ModelPricing != nil {
|
if req.ModelPricing != nil {
|
||||||
pricing := pricingRequestToService(*req.ModelPricing)
|
pricing := pricingRequestToService(*req.ModelPricing)
|
||||||
|
|||||||
@ -175,6 +175,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
EnableFingerprintUnification: settings.EnableFingerprintUnification,
|
EnableFingerprintUnification: settings.EnableFingerprintUnification,
|
||||||
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
|
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
|
||||||
EnableCCHSigning: settings.EnableCCHSigning,
|
EnableCCHSigning: settings.EnableCCHSigning,
|
||||||
|
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
|
||||||
PaymentEnabled: paymentCfg.Enabled,
|
PaymentEnabled: paymentCfg.Enabled,
|
||||||
PaymentMinAmount: paymentCfg.MinAmount,
|
PaymentMinAmount: paymentCfg.MinAmount,
|
||||||
PaymentMaxAmount: paymentCfg.MaxAmount,
|
PaymentMaxAmount: paymentCfg.MaxAmount,
|
||||||
@ -1847,3 +1848,37 @@ func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) {
|
|||||||
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
|
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetWebSearchEmulationConfig 获取 Web Search 模拟配置
|
||||||
|
// GET /api/v1/admin/settings/web-search-emulation
|
||||||
|
func (h *SettingHandler) GetWebSearchEmulationConfig(c *gin.Context) {
|
||||||
|
cfg, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, service.SanitizeWebSearchConfig(cfg))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateWebSearchEmulationConfig 更新 Web Search 模拟配置
|
||||||
|
// PUT /api/v1/admin/settings/web-search-emulation
|
||||||
|
func (h *SettingHandler) UpdateWebSearchEmulationConfig(c *gin.Context) {
|
||||||
|
var cfg service.WebSearchEmulationConfig
|
||||||
|
if err := c.ShouldBindJSON(&cfg); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.settingService.SaveWebSearchEmulationConfig(c.Request.Context(), &cfg); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-read (with sanitized api keys) to return current state
|
||||||
|
updated, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, service.SanitizeWebSearchConfig(updated))
|
||||||
|
}
|
||||||
|
|||||||
@ -124,6 +124,9 @@ type SystemSettings struct {
|
|||||||
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
|
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
|
||||||
EnableCCHSigning bool `json:"enable_cch_signing"`
|
EnableCCHSigning bool `json:"enable_cch_signing"`
|
||||||
|
|
||||||
|
// Web Search Emulation
|
||||||
|
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
|
||||||
|
|
||||||
// Payment configuration
|
// Payment configuration
|
||||||
PaymentEnabled bool `json:"payment_enabled"`
|
PaymentEnabled bool `json:"payment_enabled"`
|
||||||
PaymentMinAmount float64 `json:"payment_min_amount"`
|
PaymentMinAmount float64 `json:"payment_min_amount"`
|
||||||
|
|||||||
106
backend/internal/pkg/websearch/brave.go
Normal file
106
backend/internal/pkg/websearch/brave.go
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
braveSearchEndpoint = "https://api.search.brave.com/res/v1/web/search"
|
||||||
|
braveMaxCount = 20
|
||||||
|
braveProviderName = "brave"
|
||||||
|
)
|
||||||
|
|
||||||
|
// braveSearchURL is pre-parsed at init time; url.Parse cannot fail on a constant literal.
|
||||||
|
var braveSearchURL, _ = url.Parse(braveSearchEndpoint) //nolint:errcheck
|
||||||
|
|
||||||
|
// BraveProvider implements web search via the Brave Search API.
|
||||||
|
type BraveProvider struct {
|
||||||
|
apiKey string
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBraveProvider creates a Brave Search provider.
|
||||||
|
// The caller is responsible for configuring the http.Client with proxy/timeouts.
|
||||||
|
func NewBraveProvider(apiKey string, httpClient *http.Client) *BraveProvider {
|
||||||
|
if httpClient == nil {
|
||||||
|
httpClient = http.DefaultClient
|
||||||
|
}
|
||||||
|
return &BraveProvider{apiKey: apiKey, httpClient: httpClient}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BraveProvider) Name() string { return braveProviderName }
|
||||||
|
|
||||||
|
func (b *BraveProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
|
||||||
|
count := req.MaxResults
|
||||||
|
if count <= 0 {
|
||||||
|
count = defaultMaxResults
|
||||||
|
}
|
||||||
|
if count > braveMaxCount {
|
||||||
|
count = braveMaxCount
|
||||||
|
}
|
||||||
|
|
||||||
|
u := *braveSearchURL // copy the pre-parsed URL
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("q", req.Query)
|
||||||
|
q.Set("count", strconv.Itoa(count))
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("brave: build request: %w", err)
|
||||||
|
}
|
||||||
|
httpReq.Header.Set("X-Subscription-Token", b.apiKey)
|
||||||
|
httpReq.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := b.httpClient.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("brave: request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("brave: read body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("brave: status %d: %s", resp.StatusCode, truncateBody(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var raw braveResponse
|
||||||
|
if err := json.Unmarshal(body, &raw); err != nil {
|
||||||
|
return nil, fmt.Errorf("brave: decode response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
results := make([]SearchResult, 0, len(raw.Web.Results))
|
||||||
|
for _, r := range raw.Web.Results {
|
||||||
|
results = append(results, SearchResult{
|
||||||
|
URL: r.URL,
|
||||||
|
Title: r.Title,
|
||||||
|
Snippet: r.Description,
|
||||||
|
PageAge: r.Age,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SearchResponse{Results: results, Query: req.Query}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// braveResponse is the minimal structure of the Brave Search API response.
|
||||||
|
type braveResponse struct {
|
||||||
|
Web struct {
|
||||||
|
Results []braveResult `json:"results"`
|
||||||
|
} `json:"web"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type braveResult struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Age string `json:"age"`
|
||||||
|
}
|
||||||
119
backend/internal/pkg/websearch/brave_test.go
Normal file
119
backend/internal/pkg/websearch/brave_test.go
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBraveProvider_Name(t *testing.T) {
|
||||||
|
p := NewBraveProvider("key", nil)
|
||||||
|
require.Equal(t, "brave", p.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBraveProvider_Search_Success(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
require.Equal(t, "test-key", r.Header.Get("X-Subscription-Token"))
|
||||||
|
require.Equal(t, "application/json", r.Header.Get("Accept"))
|
||||||
|
require.Equal(t, "golang", r.URL.Query().Get("q"))
|
||||||
|
require.Equal(t, "3", r.URL.Query().Get("count"))
|
||||||
|
|
||||||
|
resp := braveResponse{}
|
||||||
|
resp.Web.Results = []braveResult{
|
||||||
|
{URL: "https://go.dev", Title: "Go", Description: "Go lang", Age: "1 day"},
|
||||||
|
{URL: "https://pkg.go.dev", Title: "Pkg", Description: "Packages"},
|
||||||
|
{URL: "https://tour.go.dev", Title: "Tour", Description: "A Tour of Go", Age: "3 days"},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
p := NewBraveProvider("test-key", srv.Client())
|
||||||
|
// Override the endpoint for testing
|
||||||
|
origURL := *braveSearchURL
|
||||||
|
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||||
|
*braveSearchURL = *u.URL
|
||||||
|
defer func() { *braveSearchURL = origURL }()
|
||||||
|
|
||||||
|
resp, err := p.Search(context.Background(), SearchRequest{Query: "golang", MaxResults: 3})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, resp.Results, 3)
|
||||||
|
require.Equal(t, "https://go.dev", resp.Results[0].URL)
|
||||||
|
require.Equal(t, "Go lang", resp.Results[0].Snippet)
|
||||||
|
require.Equal(t, "1 day", resp.Results[0].PageAge)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBraveProvider_Search_DefaultMaxResults(t *testing.T) {
|
||||||
|
var receivedCount string
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
receivedCount = r.URL.Query().Get("count")
|
||||||
|
resp := braveResponse{}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
p := NewBraveProvider("key", srv.Client())
|
||||||
|
origURL := *braveSearchURL
|
||||||
|
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||||
|
*braveSearchURL = *u.URL
|
||||||
|
defer func() { *braveSearchURL = origURL }()
|
||||||
|
|
||||||
|
_, _ = p.Search(context.Background(), SearchRequest{Query: "test", MaxResults: 0})
|
||||||
|
require.Equal(t, "5", receivedCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBraveProvider_Search_HTTPError(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(429)
|
||||||
|
w.Write([]byte("rate limited"))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
p := NewBraveProvider("key", srv.Client())
|
||||||
|
origURL := *braveSearchURL
|
||||||
|
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||||
|
*braveSearchURL = *u.URL
|
||||||
|
defer func() { *braveSearchURL = origURL }()
|
||||||
|
|
||||||
|
_, err := p.Search(context.Background(), SearchRequest{Query: "test"})
|
||||||
|
require.ErrorContains(t, err, "brave: status 429")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBraveProvider_Search_InvalidJSON(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.Write([]byte("not json"))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
p := NewBraveProvider("key", srv.Client())
|
||||||
|
origURL := *braveSearchURL
|
||||||
|
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||||
|
*braveSearchURL = *u.URL
|
||||||
|
defer func() { *braveSearchURL = origURL }()
|
||||||
|
|
||||||
|
_, err := p.Search(context.Background(), SearchRequest{Query: "test"})
|
||||||
|
require.ErrorContains(t, err, "brave: decode response")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBraveProvider_Search_EmptyResults(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
resp := braveResponse{}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
p := NewBraveProvider("key", srv.Client())
|
||||||
|
origURL := *braveSearchURL
|
||||||
|
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||||
|
*braveSearchURL = *u.URL
|
||||||
|
defer func() { *braveSearchURL = origURL }()
|
||||||
|
|
||||||
|
resp, err := p.Search(context.Background(), SearchRequest{Query: "test"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, resp.Results)
|
||||||
|
}
|
||||||
14
backend/internal/pkg/websearch/helpers.go
Normal file
14
backend/internal/pkg/websearch/helpers.go
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxResponseSize = 1 << 20 // 1 MB
|
||||||
|
errorBodyTruncLen = 200
|
||||||
|
)
|
||||||
|
|
||||||
|
// truncateBody returns a truncated string of body for error messages.
|
||||||
|
func truncateBody(body []byte) string {
|
||||||
|
if len(body) <= errorBodyTruncLen {
|
||||||
|
return string(body)
|
||||||
|
}
|
||||||
|
return string(body[:errorBodyTruncLen]) + "...(truncated)"
|
||||||
|
}
|
||||||
25
backend/internal/pkg/websearch/helpers_test.go
Normal file
25
backend/internal/pkg/websearch/helpers_test.go
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTruncateBody_Short(t *testing.T) {
|
||||||
|
body := []byte("short body")
|
||||||
|
require.Equal(t, "short body", truncateBody(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTruncateBody_Long(t *testing.T) {
|
||||||
|
body := []byte(strings.Repeat("x", 500))
|
||||||
|
result := truncateBody(body)
|
||||||
|
require.Len(t, result, errorBodyTruncLen+len("...(truncated)"))
|
||||||
|
require.True(t, strings.HasSuffix(result, "...(truncated)"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTruncateBody_ExactBoundary(t *testing.T) {
|
||||||
|
body := []byte(strings.Repeat("x", errorBodyTruncLen))
|
||||||
|
require.Equal(t, string(body), truncateBody(body))
|
||||||
|
}
|
||||||
273
backend/internal/pkg/websearch/manager.go
Normal file
273
backend/internal/pkg/websearch/manager.go
Normal file
@ -0,0 +1,273 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Quota refresh interval constants.
|
||||||
|
const (
|
||||||
|
QuotaRefreshDaily = "daily"
|
||||||
|
QuotaRefreshWeekly = "weekly"
|
||||||
|
QuotaRefreshMonthly = "monthly"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProviderConfig holds the configuration for a single search provider.
|
||||||
|
type ProviderConfig struct {
|
||||||
|
Type string `json:"type"` // ProviderTypeBrave | ProviderTypeTavily
|
||||||
|
APIKey string `json:"api_key"` // secret
|
||||||
|
Priority int `json:"priority"` // lower = higher priority
|
||||||
|
QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited
|
||||||
|
QuotaRefreshInterval string `json:"quota_refresh_interval"` // QuotaRefreshDaily / Weekly / Monthly
|
||||||
|
ProxyURL string `json:"-"` // resolved proxy URL (not persisted)
|
||||||
|
ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration (unix seconds)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manager selects providers by priority and tracks quota via Redis.
|
||||||
|
type Manager struct {
|
||||||
|
configs []ProviderConfig
|
||||||
|
redis *redis.Client
|
||||||
|
|
||||||
|
clientMu sync.Mutex
|
||||||
|
clientCache map[string]*http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
quotaKeyPrefix = "websearch:quota:"
|
||||||
|
searchRequestTimeout = 30 * time.Second
|
||||||
|
quotaTTLBuffer = 24 * time.Hour
|
||||||
|
maxCachedClients = 100
|
||||||
|
)
|
||||||
|
|
||||||
|
// quotaIncrScript atomically increments the counter and sets TTL on first creation.
|
||||||
|
// KEYS[1] = quota key, ARGV[1] = TTL in seconds.
|
||||||
|
// Returns the new counter value.
|
||||||
|
var quotaIncrScript = redis.NewScript(`
|
||||||
|
local val = redis.call('INCR', KEYS[1])
|
||||||
|
if val == 1 then
|
||||||
|
redis.call('EXPIRE', KEYS[1], ARGV[1])
|
||||||
|
else
|
||||||
|
-- Defensive: ensure TTL exists even if a prior EXPIRE failed
|
||||||
|
local ttl = redis.call('TTL', KEYS[1])
|
||||||
|
if ttl == -1 then
|
||||||
|
redis.call('EXPIRE', KEYS[1], ARGV[1])
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return val
|
||||||
|
`)
|
||||||
|
|
||||||
|
// NewManager creates a Manager with the given provider configs and Redis client.
|
||||||
|
func NewManager(configs []ProviderConfig, redisClient *redis.Client) *Manager {
|
||||||
|
sorted := make([]ProviderConfig, len(configs))
|
||||||
|
copy(sorted, configs)
|
||||||
|
sort.Slice(sorted, func(i, j int) bool {
|
||||||
|
return sorted[i].Priority < sorted[j].Priority
|
||||||
|
})
|
||||||
|
return &Manager{
|
||||||
|
configs: sorted,
|
||||||
|
redis: redisClient,
|
||||||
|
clientCache: make(map[string]*http.Client),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchWithBestProvider selects the highest-priority available provider,
|
||||||
|
// reserves quota, executes the search, and rolls back quota on failure.
|
||||||
|
func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) {
|
||||||
|
if strings.TrimSpace(req.Query) == "" {
|
||||||
|
return nil, "", fmt.Errorf("websearch: empty search query")
|
||||||
|
}
|
||||||
|
for _, cfg := range m.configs {
|
||||||
|
if !m.isProviderAvailable(cfg) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
allowed, incremented := m.tryReserveQuota(ctx, cfg)
|
||||||
|
if !allowed {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
resp, err := m.executeSearch(ctx, cfg, req)
|
||||||
|
if err != nil {
|
||||||
|
if incremented {
|
||||||
|
m.rollbackQuota(ctx, cfg)
|
||||||
|
}
|
||||||
|
slog.Warn("websearch: provider search failed",
|
||||||
|
"provider", cfg.Type, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return resp, cfg.Type, nil
|
||||||
|
}
|
||||||
|
return nil, "", fmt.Errorf("websearch: no available provider (all exhausted or failed)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) isProviderAvailable(cfg ProviderConfig) bool {
|
||||||
|
if cfg.APIKey == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if cfg.ExpiresAt != nil && time.Now().Unix() > *cfg.ExpiresAt {
|
||||||
|
slog.Info("websearch: provider expired, skipping",
|
||||||
|
"provider", cfg.Type, "expires_at", *cfg.ExpiresAt)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryReserveQuota atomically increments the counter via Lua script and checks limit.
|
||||||
|
// Returns (allowed, incremented): allowed=true means the request may proceed;
|
||||||
|
// incremented=true means the Redis counter was actually incremented (so rollback is needed on failure).
|
||||||
|
func (m *Manager) tryReserveQuota(ctx context.Context, cfg ProviderConfig) (bool, bool) {
|
||||||
|
if cfg.QuotaLimit <= 0 {
|
||||||
|
return true, false // unlimited, no INCR
|
||||||
|
}
|
||||||
|
if m.redis == nil {
|
||||||
|
slog.Warn("websearch: Redis unavailable, quota check skipped",
|
||||||
|
"provider", cfg.Type)
|
||||||
|
return true, false // allowed but not incremented
|
||||||
|
}
|
||||||
|
key := quotaRedisKey(cfg.Type, cfg.QuotaRefreshInterval)
|
||||||
|
ttlSec := int(quotaTTL(cfg.QuotaRefreshInterval).Seconds())
|
||||||
|
|
||||||
|
newVal, err := quotaIncrScript.Run(ctx, m.redis, []string{key}, ttlSec).Int64()
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("websearch: quota Lua INCR failed, allowing request",
|
||||||
|
"provider", cfg.Type, "error", err)
|
||||||
|
return true, false // allowed but not incremented
|
||||||
|
}
|
||||||
|
if newVal > cfg.QuotaLimit {
|
||||||
|
if decrErr := m.redis.Decr(ctx, key).Err(); decrErr != nil {
|
||||||
|
slog.Warn("websearch: quota over-limit DECR failed",
|
||||||
|
"provider", cfg.Type, "error", decrErr)
|
||||||
|
}
|
||||||
|
slog.Info("websearch: provider quota exhausted",
|
||||||
|
"provider", cfg.Type, "used", newVal, "limit", cfg.QuotaLimit)
|
||||||
|
return false, false // rejected, already rolled back
|
||||||
|
}
|
||||||
|
return true, true // allowed and incremented
|
||||||
|
}
|
||||||
|
|
||||||
|
// rollbackQuota decrements the counter after a search failure.
|
||||||
|
func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) {
|
||||||
|
if cfg.QuotaLimit <= 0 || m.redis == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key := quotaRedisKey(cfg.Type, cfg.QuotaRefreshInterval)
|
||||||
|
if err := m.redis.Decr(ctx, key).Err(); err != nil {
|
||||||
|
slog.Warn("websearch: quota rollback DECR failed",
|
||||||
|
"provider", cfg.Type, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) executeSearch(ctx context.Context, cfg ProviderConfig, req SearchRequest) (*SearchResponse, error) {
|
||||||
|
proxyURL := cfg.ProxyURL
|
||||||
|
if req.ProxyURL != "" {
|
||||||
|
proxyURL = req.ProxyURL
|
||||||
|
}
|
||||||
|
client := m.getOrCreateHTTPClient(proxyURL)
|
||||||
|
provider := m.buildProvider(cfg, client)
|
||||||
|
return provider.Search(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUsage returns the current usage count for the given provider.
|
||||||
|
func (m *Manager) GetUsage(ctx context.Context, providerType, refreshInterval string) (int64, error) {
|
||||||
|
if m.redis == nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
key := quotaRedisKey(providerType, refreshInterval)
|
||||||
|
val, err := m.redis.Get(ctx, key).Int64()
|
||||||
|
if err == redis.Nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return val, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllUsage returns usage for every configured provider.
|
||||||
|
func (m *Manager) GetAllUsage(ctx context.Context) map[string]int64 {
|
||||||
|
result := make(map[string]int64, len(m.configs))
|
||||||
|
for _, cfg := range m.configs {
|
||||||
|
used, _ := m.GetUsage(ctx, cfg.Type, cfg.QuotaRefreshInterval)
|
||||||
|
result[cfg.Type] = used
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- HTTP client cache (bounded) ---
|
||||||
|
|
||||||
|
func (m *Manager) getOrCreateHTTPClient(proxyURL string) *http.Client {
|
||||||
|
m.clientMu.Lock()
|
||||||
|
defer m.clientMu.Unlock()
|
||||||
|
|
||||||
|
if c, ok := m.clientCache[proxyURL]; ok {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
if len(m.clientCache) >= maxCachedClients {
|
||||||
|
m.clientCache = make(map[string]*http.Client) // evict all
|
||||||
|
}
|
||||||
|
c := newHTTPClient(proxyURL)
|
||||||
|
m.clientCache[proxyURL] = c
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTPClient(proxyURL string) *http.Client {
|
||||||
|
transport := &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
|
||||||
|
}
|
||||||
|
if proxyURL != "" {
|
||||||
|
if u, err := url.Parse(proxyURL); err == nil {
|
||||||
|
transport.Proxy = http.ProxyURL(u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &http.Client{Transport: transport, Timeout: searchRequestTimeout}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Provider factory ---
|
||||||
|
|
||||||
|
func (m *Manager) buildProvider(cfg ProviderConfig, client *http.Client) Provider {
|
||||||
|
switch cfg.Type {
|
||||||
|
case braveProviderName:
|
||||||
|
return NewBraveProvider(cfg.APIKey, client)
|
||||||
|
case tavilyProviderName:
|
||||||
|
return NewTavilyProvider(cfg.APIKey, client)
|
||||||
|
default:
|
||||||
|
slog.Warn("websearch: unknown provider type, falling back to brave",
|
||||||
|
"type", cfg.Type)
|
||||||
|
return NewBraveProvider(cfg.APIKey, client)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Redis key helpers ---
|
||||||
|
|
||||||
|
func quotaRedisKey(providerType, refreshInterval string) string {
|
||||||
|
return quotaKeyPrefix + providerType + ":" + periodKey(refreshInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
func periodKey(refreshInterval string) string {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
switch refreshInterval {
|
||||||
|
case QuotaRefreshDaily:
|
||||||
|
return now.Format("2006-01-02")
|
||||||
|
case QuotaRefreshWeekly:
|
||||||
|
year, week := now.ISOWeek()
|
||||||
|
return fmt.Sprintf("%d-W%02d", year, week)
|
||||||
|
default: // QuotaRefreshMonthly
|
||||||
|
return now.Format("2006-01")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func quotaTTL(refreshInterval string) time.Duration {
|
||||||
|
switch refreshInterval {
|
||||||
|
case QuotaRefreshDaily:
|
||||||
|
return 24*time.Hour + quotaTTLBuffer
|
||||||
|
case QuotaRefreshWeekly:
|
||||||
|
return 7*24*time.Hour + quotaTTLBuffer
|
||||||
|
default:
|
||||||
|
return 31*24*time.Hour + quotaTTLBuffer
|
||||||
|
}
|
||||||
|
}
|
||||||
149
backend/internal/pkg/websearch/manager_test.go
Normal file
149
backend/internal/pkg/websearch/manager_test.go
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewManager_SortsByPriority(t *testing.T) {
|
||||||
|
configs := []ProviderConfig{
|
||||||
|
{Type: "brave", APIKey: "k3", Priority: 30},
|
||||||
|
{Type: "tavily", APIKey: "k1", Priority: 10},
|
||||||
|
}
|
||||||
|
m := NewManager(configs, nil)
|
||||||
|
require.Equal(t, 10, m.configs[0].Priority)
|
||||||
|
require.Equal(t, 30, m.configs[1].Priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_SearchWithBestProvider_EmptyQuery(t *testing.T) {
|
||||||
|
m := NewManager([]ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||||
|
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: ""})
|
||||||
|
require.ErrorContains(t, err, "empty search query")
|
||||||
|
|
||||||
|
_, _, err = m.SearchWithBestProvider(context.Background(), SearchRequest{Query: " "})
|
||||||
|
require.ErrorContains(t, err, "empty search query")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_SearchWithBestProvider_SkipEmptyAPIKey(t *testing.T) {
|
||||||
|
m := NewManager([]ProviderConfig{{Type: "brave", APIKey: ""}}, nil)
|
||||||
|
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||||
|
require.ErrorContains(t, err, "no available provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_SearchWithBestProvider_SkipExpired(t *testing.T) {
|
||||||
|
past := time.Now().Add(-1 * time.Hour).Unix()
|
||||||
|
m := NewManager([]ProviderConfig{
|
||||||
|
{Type: "brave", APIKey: "k", ExpiresAt: &past},
|
||||||
|
}, nil)
|
||||||
|
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||||
|
require.ErrorContains(t, err, "no available provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_SearchWithBestProvider_PriorityOrder(t *testing.T) {
|
||||||
|
// Create two mock servers that return different results
|
||||||
|
srvBrave := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
resp := braveResponse{}
|
||||||
|
resp.Web.Results = []braveResult{{URL: "https://brave.com", Title: "Brave", Description: "from brave"}}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer srvBrave.Close()
|
||||||
|
|
||||||
|
// Override brave endpoint for test
|
||||||
|
origURL := *braveSearchURL
|
||||||
|
u, _ := http.NewRequest("GET", srvBrave.URL, nil)
|
||||||
|
*braveSearchURL = *u.URL
|
||||||
|
defer func() { *braveSearchURL = origURL }()
|
||||||
|
|
||||||
|
m := NewManager([]ProviderConfig{
|
||||||
|
{Type: "brave", APIKey: "k1", Priority: 1},
|
||||||
|
{Type: "tavily", APIKey: "k2", Priority: 2},
|
||||||
|
}, nil)
|
||||||
|
// Inject the test server's client
|
||||||
|
m.clientCache[srvBrave.URL] = srvBrave.Client()
|
||||||
|
m.clientCache[""] = srvBrave.Client()
|
||||||
|
|
||||||
|
resp, providerName, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "brave", providerName)
|
||||||
|
require.Len(t, resp.Results, 1)
|
||||||
|
require.Equal(t, "from brave", resp.Results[0].Snippet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) {
|
||||||
|
// With nil Redis, quota check is skipped (always allowed)
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
resp := braveResponse{}
|
||||||
|
resp.Web.Results = []braveResult{{URL: "https://test.com", Title: "Test", Description: "result"}}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
origURL := *braveSearchURL
|
||||||
|
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||||
|
*braveSearchURL = *u.URL
|
||||||
|
defer func() { *braveSearchURL = origURL }()
|
||||||
|
|
||||||
|
m := NewManager([]ProviderConfig{
|
||||||
|
{Type: "brave", APIKey: "k", Priority: 1, QuotaLimit: 100},
|
||||||
|
}, nil) // nil Redis
|
||||||
|
m.clientCache[""] = srv.Client()
|
||||||
|
|
||||||
|
resp, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, resp.Results, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_GetUsage_NilRedis(t *testing.T) {
|
||||||
|
m := NewManager(nil, nil)
|
||||||
|
used, err := m.GetUsage(context.Background(), "brave", "monthly")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(0), used)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_GetAllUsage_NilRedis(t *testing.T) {
|
||||||
|
m := NewManager([]ProviderConfig{
|
||||||
|
{Type: "brave", QuotaRefreshInterval: "monthly"},
|
||||||
|
}, nil)
|
||||||
|
usage := m.GetAllUsage(context.Background())
|
||||||
|
require.Equal(t, int64(0), usage["brave"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Key/TTL helpers ---
|
||||||
|
|
||||||
|
func TestQuotaTTL_Daily(t *testing.T) {
|
||||||
|
require.Equal(t, 24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshDaily))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuotaTTL_Weekly(t *testing.T) {
|
||||||
|
require.Equal(t, 7*24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshWeekly))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuotaTTL_Monthly(t *testing.T) {
|
||||||
|
require.Equal(t, 31*24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshMonthly))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeriodKey_Daily(t *testing.T) {
|
||||||
|
key := periodKey(QuotaRefreshDaily)
|
||||||
|
require.Regexp(t, `^\d{4}-\d{2}-\d{2}$`, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeriodKey_Weekly(t *testing.T) {
|
||||||
|
key := periodKey(QuotaRefreshWeekly)
|
||||||
|
require.Regexp(t, `^\d{4}-W\d{2}$`, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeriodKey_Monthly(t *testing.T) {
|
||||||
|
key := periodKey(QuotaRefreshMonthly)
|
||||||
|
require.Regexp(t, `^\d{4}-\d{2}$`, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuotaRedisKey_Format(t *testing.T) {
|
||||||
|
key := quotaRedisKey("brave", QuotaRefreshDaily)
|
||||||
|
require.Contains(t, key, "websearch:quota:brave:")
|
||||||
|
}
|
||||||
11
backend/internal/pkg/websearch/provider.go
Normal file
11
backend/internal/pkg/websearch/provider.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// Provider is the interface every search backend must implement.
|
||||||
|
type Provider interface {
|
||||||
|
// Name returns the provider identifier ("brave" or "tavily").
|
||||||
|
Name() string
|
||||||
|
// Search executes a web search and returns results.
|
||||||
|
Search(ctx context.Context, req SearchRequest) (*SearchResponse, error)
|
||||||
|
}
|
||||||
107
backend/internal/pkg/websearch/tavily.go
Normal file
107
backend/internal/pkg/websearch/tavily.go
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
tavilySearchEndpoint = "https://api.tavily.com/search"
|
||||||
|
tavilyProviderName = "tavily"
|
||||||
|
tavilySearchDepthBasic = "basic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TavilyProvider implements web search via the Tavily Search API.
|
||||||
|
type TavilyProvider struct {
|
||||||
|
apiKey string
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTavilyProvider creates a Tavily Search provider.
|
||||||
|
// The caller is responsible for configuring the http.Client with proxy/timeouts.
|
||||||
|
func NewTavilyProvider(apiKey string, httpClient *http.Client) *TavilyProvider {
|
||||||
|
if httpClient == nil {
|
||||||
|
httpClient = http.DefaultClient
|
||||||
|
}
|
||||||
|
return &TavilyProvider{apiKey: apiKey, httpClient: httpClient}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TavilyProvider) Name() string { return tavilyProviderName }
|
||||||
|
|
||||||
|
func (t *TavilyProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
|
||||||
|
maxResults := req.MaxResults
|
||||||
|
if maxResults <= 0 {
|
||||||
|
maxResults = defaultMaxResults
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := tavilyRequest{
|
||||||
|
APIKey: t.apiKey,
|
||||||
|
Query: req.Query,
|
||||||
|
MaxResults: maxResults,
|
||||||
|
SearchDepth: tavilySearchDepthBasic,
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("tavily: encode request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tavilySearchEndpoint, bytes.NewReader(bodyBytes))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("tavily: build request: %w", err)
|
||||||
|
}
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := t.httpClient.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("tavily: request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("tavily: read body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("tavily: status %d: %s", resp.StatusCode, truncateBody(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var raw tavilyResponse
|
||||||
|
if err := json.Unmarshal(body, &raw); err != nil {
|
||||||
|
return nil, fmt.Errorf("tavily: decode response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
results := make([]SearchResult, 0, len(raw.Results))
|
||||||
|
for _, r := range raw.Results {
|
||||||
|
results = append(results, SearchResult{
|
||||||
|
URL: r.URL,
|
||||||
|
Title: r.Title,
|
||||||
|
Snippet: r.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SearchResponse{Results: results, Query: req.Query}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type tavilyRequest struct {
|
||||||
|
APIKey string `json:"api_key"`
|
||||||
|
Query string `json:"query"`
|
||||||
|
MaxResults int `json:"max_results"`
|
||||||
|
SearchDepth string `json:"search_depth"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type tavilyResponse struct {
|
||||||
|
Results []tavilyResult `json:"results"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type tavilyResult struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Score float64 `json:"score"`
|
||||||
|
}
|
||||||
63
backend/internal/pkg/websearch/tavily_test.go
Normal file
63
backend/internal/pkg/websearch/tavily_test.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTavilyProvider_Name(t *testing.T) {
|
||||||
|
p := NewTavilyProvider("key", nil)
|
||||||
|
require.Equal(t, "tavily", p.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTavilyProvider_Search_RequestConstruction(t *testing.T) {
|
||||||
|
// Verify tavilyRequest struct fields map correctly
|
||||||
|
req := tavilyRequest{
|
||||||
|
APIKey: "test-key",
|
||||||
|
Query: "golang",
|
||||||
|
MaxResults: 3,
|
||||||
|
SearchDepth: tavilySearchDepthBasic,
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var parsed map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(data, &parsed))
|
||||||
|
require.Equal(t, "test-key", parsed["api_key"])
|
||||||
|
require.Equal(t, "golang", parsed["query"])
|
||||||
|
require.Equal(t, float64(3), parsed["max_results"])
|
||||||
|
require.Equal(t, "basic", parsed["search_depth"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTavilyProvider_Search_ResponseParsing(t *testing.T) {
|
||||||
|
rawResp := `{"results":[{"url":"https://go.dev","title":"Go","content":"Go programming language","score":0.95}]}`
|
||||||
|
var resp tavilyResponse
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(rawResp), &resp))
|
||||||
|
require.Len(t, resp.Results, 1)
|
||||||
|
require.Equal(t, "https://go.dev", resp.Results[0].URL)
|
||||||
|
require.Equal(t, "Go programming language", resp.Results[0].Content)
|
||||||
|
require.InDelta(t, 0.95, resp.Results[0].Score, 0.001)
|
||||||
|
|
||||||
|
// Verify mapping to SearchResult
|
||||||
|
results := make([]SearchResult, 0, len(resp.Results))
|
||||||
|
for _, r := range resp.Results {
|
||||||
|
results = append(results, SearchResult{
|
||||||
|
URL: r.URL, Title: r.Title, Snippet: r.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
require.Equal(t, "Go programming language", results[0].Snippet)
|
||||||
|
require.Equal(t, "", results[0].PageAge)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTavilyProvider_Search_EmptyResults(t *testing.T) {
|
||||||
|
var resp tavilyResponse
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(`{"results":[]}`), &resp))
|
||||||
|
require.Empty(t, resp.Results)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTavilyProvider_Search_InvalidJSON(t *testing.T) {
|
||||||
|
var resp tavilyResponse
|
||||||
|
require.Error(t, json.Unmarshal([]byte("not json"), &resp))
|
||||||
|
}
|
||||||
30
backend/internal/pkg/websearch/types.go
Normal file
30
backend/internal/pkg/websearch/types.go
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
// SearchResult represents a single web search result.
|
||||||
|
type SearchResult struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Snippet string `json:"snippet"`
|
||||||
|
PageAge string `json:"page_age,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchRequest describes a web search to perform.
|
||||||
|
type SearchRequest struct {
|
||||||
|
Query string
|
||||||
|
MaxResults int // defaults to defaultMaxResults if <= 0
|
||||||
|
ProxyURL string // optional HTTP proxy URL
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchResponse holds the results of a web search.
|
||||||
|
type SearchResponse struct {
|
||||||
|
Results []SearchResult
|
||||||
|
Query string // the query that was actually executed
|
||||||
|
}
|
||||||
|
|
||||||
|
const defaultMaxResults = 5
|
||||||
|
|
||||||
|
// Provider type identifiers.
|
||||||
|
const (
|
||||||
|
ProviderTypeBrave = "brave"
|
||||||
|
ProviderTypeTavily = "tavily"
|
||||||
|
)
|
||||||
@ -41,10 +41,14 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
err = tx.QueryRowContext(ctx,
|
err = tx.QueryRowContext(ctx,
|
||||||
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features) VALUES ($1, $2, $3, $4, $5, $6, $7)
|
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
RETURNING id, created_at, updated_at`,
|
RETURNING id, created_at, updated_at`,
|
||||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features,
|
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON,
|
||||||
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
|
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isUniqueViolation(err) {
|
if isUniqueViolation(err) {
|
||||||
@ -73,11 +77,11 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
|
|||||||
|
|
||||||
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
|
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
|
||||||
ch := &service.Channel{}
|
ch := &service.Channel{}
|
||||||
var modelMappingJSON []byte
|
var modelMappingJSON, featuresConfigJSON []byte
|
||||||
err := r.db.QueryRowContext(ctx,
|
err := r.db.QueryRowContext(ctx,
|
||||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, created_at, updated_at
|
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at
|
||||||
FROM channels WHERE id = $1`, id,
|
FROM channels WHERE id = $1`, id,
|
||||||
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt)
|
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, service.ErrChannelNotFound
|
return nil, service.ErrChannelNotFound
|
||||||
}
|
}
|
||||||
@ -85,6 +89,7 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
|
|||||||
return nil, fmt.Errorf("get channel: %w", err)
|
return nil, fmt.Errorf("get channel: %w", err)
|
||||||
}
|
}
|
||||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||||
|
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||||
|
|
||||||
groupIDs, err := r.GetGroupIDs(ctx, id)
|
groupIDs, err := r.GetGroupIDs(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -107,10 +112,14 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
result, err := tx.ExecContext(ctx,
|
result, err := tx.ExecContext(ctx,
|
||||||
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, updated_at = NOW()
|
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, updated_at = NOW()
|
||||||
WHERE id = $8`,
|
WHERE id = $9`,
|
||||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ID,
|
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isUniqueViolation(err) {
|
if isUniqueViolation(err) {
|
||||||
@ -187,9 +196,9 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
|||||||
|
|
||||||
// 查询 channel 列表
|
// 查询 channel 列表
|
||||||
dataQuery := fmt.Sprintf(
|
dataQuery := fmt.Sprintf(
|
||||||
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.created_at, c.updated_at
|
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.created_at, c.updated_at
|
||||||
FROM channels c WHERE %s ORDER BY c.id ASC LIMIT $%d OFFSET $%d`,
|
FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
|
||||||
whereClause, argIdx, argIdx+1,
|
whereClause, channelListOrderBy(params), argIdx, argIdx+1,
|
||||||
)
|
)
|
||||||
args = append(args, pageSize, offset)
|
args = append(args, pageSize, offset)
|
||||||
|
|
||||||
@ -203,11 +212,12 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
|||||||
var channelIDs []int64
|
var channelIDs []int64
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var ch service.Channel
|
var ch service.Channel
|
||||||
var modelMappingJSON []byte
|
var modelMappingJSON, featuresConfigJSON []byte
|
||||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||||
return nil, nil, fmt.Errorf("scan channel: %w", err)
|
return nil, nil, fmt.Errorf("scan channel: %w", err)
|
||||||
}
|
}
|
||||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||||
|
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||||
channels = append(channels, ch)
|
channels = append(channels, ch)
|
||||||
channelIDs = append(channelIDs, ch.ID)
|
channelIDs = append(channelIDs, ch.ID)
|
||||||
}
|
}
|
||||||
@ -246,9 +256,34 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
|||||||
return channels, paginationResult, nil
|
return channels, paginationResult, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func channelListOrderBy(params pagination.PaginationParams) string {
|
||||||
|
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
|
||||||
|
sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderAsc))
|
||||||
|
|
||||||
|
var column string
|
||||||
|
switch sortBy {
|
||||||
|
case "":
|
||||||
|
column = "c.id"
|
||||||
|
sortOrder = "ASC"
|
||||||
|
case "id":
|
||||||
|
column = "c.id"
|
||||||
|
case "name":
|
||||||
|
column = "c.name"
|
||||||
|
case "status":
|
||||||
|
column = "c.status"
|
||||||
|
case "created_at":
|
||||||
|
column = "c.created_at"
|
||||||
|
default:
|
||||||
|
column = "c.id"
|
||||||
|
sortOrder = "ASC"
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s %s, c.id %s", column, sortOrder, sortOrder)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
|
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
|
||||||
rows, err := r.db.QueryContext(ctx,
|
rows, err := r.db.QueryContext(ctx,
|
||||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, created_at, updated_at FROM channels ORDER BY id`,
|
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at FROM channels ORDER BY id`,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query all channels: %w", err)
|
return nil, fmt.Errorf("query all channels: %w", err)
|
||||||
@ -259,11 +294,12 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
|
|||||||
var channelIDs []int64
|
var channelIDs []int64
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var ch service.Channel
|
var ch service.Channel
|
||||||
var modelMappingJSON []byte
|
var modelMappingJSON, featuresConfigJSON []byte
|
||||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||||
return nil, fmt.Errorf("scan channel: %w", err)
|
return nil, fmt.Errorf("scan channel: %w", err)
|
||||||
}
|
}
|
||||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||||
|
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||||
channels = append(channels, ch)
|
channels = append(channels, ch)
|
||||||
channelIDs = append(channelIDs, ch.ID)
|
channelIDs = append(channelIDs, ch.ID)
|
||||||
}
|
}
|
||||||
@ -431,6 +467,28 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string {
|
|||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func marshalFeaturesConfig(m map[string]any) ([]byte, error) {
|
||||||
|
if len(m) == 0 {
|
||||||
|
return []byte("{}"), nil
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal features_config: %w", err)
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func unmarshalFeaturesConfig(data []byte) map[string]any {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var m map[string]any
|
||||||
|
if err := json.Unmarshal(data, &m); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
// GetGroupPlatforms 批量查询分组 ID 对应的平台
|
// GetGroupPlatforms 批量查询分组 ID 对应的平台
|
||||||
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||||||
if len(groupIDs) == 0 {
|
if len(groupIDs) == 0 {
|
||||||
|
|||||||
@ -407,6 +407,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
// Beta 策略配置
|
// Beta 策略配置
|
||||||
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
|
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
|
||||||
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
|
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
|
||||||
|
// Web Search 模拟配置
|
||||||
|
adminSettings.GET("/web-search-emulation", h.Admin.Setting.GetWebSearchEmulationConfig)
|
||||||
|
adminSettings.PUT("/web-search-emulation", h.Admin.Setting.UpdateWebSearchEmulationConfig)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -969,7 +969,7 @@ func (a *Account) IsOveragesEnabled() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。
|
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用"自动透传(仅替换认证)"。
|
||||||
//
|
//
|
||||||
// 新字段:accounts.extra.openai_passthrough。
|
// 新字段:accounts.extra.openai_passthrough。
|
||||||
// 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。
|
// 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。
|
||||||
@ -1133,7 +1133,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
|
|||||||
return resolvedDefault
|
return resolvedDefault
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。
|
// IsOpenAIWSForceHTTPEnabled 返回账号级"强制 HTTP"开关。
|
||||||
// 字段:accounts.extra.openai_ws_force_http。
|
// 字段:accounts.extra.openai_ws_force_http。
|
||||||
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
|
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
|
||||||
if a == nil || !a.IsOpenAI() || a.Extra == nil {
|
if a == nil || !a.IsOpenAI() || a.Extra == nil {
|
||||||
@ -1158,7 +1158,7 @@ func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
|
|||||||
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
|
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用“自动透传(仅替换认证)”。
|
// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用"自动透传(仅替换认证)"。
|
||||||
// 字段:accounts.extra.anthropic_passthrough。
|
// 字段:accounts.extra.anthropic_passthrough。
|
||||||
// 字段缺失或类型不正确时,按 false(关闭)处理。
|
// 字段缺失或类型不正确时,按 false(关闭)处理。
|
||||||
func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
|
func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
|
||||||
@ -1169,7 +1169,18 @@ func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
|
|||||||
return ok && enabled
|
return ok && enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”。
|
// IsWebSearchEmulationEnabled 返回 Anthropic API Key 账号是否启用 web search 模拟。
|
||||||
|
// 字段:accounts.extra.web_search_emulation。
|
||||||
|
// 字段缺失或类型不正确时,按 false(关闭)处理。
|
||||||
|
func (a *Account) IsWebSearchEmulationEnabled() bool {
|
||||||
|
if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
enabled, ok := a.Extra[featureKeyWebSearchEmulation].(bool)
|
||||||
|
return ok && enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用"仅允许 Codex 官方客户端"。
|
||||||
// 字段:accounts.extra.codex_cli_only。
|
// 字段:accounts.extra.codex_cli_only。
|
||||||
// 字段缺失或类型不正确时,按 false(关闭)处理。
|
// 字段缺失或类型不正确时,按 false(关闭)处理。
|
||||||
func (a *Account) IsCodexCLIOnlyEnabled() bool {
|
func (a *Account) IsCodexCLIOnlyEnabled() bool {
|
||||||
|
|||||||
71
backend/internal/service/account_websearch_test.go
Normal file
71
backend/internal/service/account_websearch_test.go
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAccount_IsWebSearchEmulationEnabled_Enabled(t *testing.T) {
|
||||||
|
a := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Extra: map[string]any{featureKeyWebSearchEmulation: true},
|
||||||
|
}
|
||||||
|
require.True(t, a.IsWebSearchEmulationEnabled())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccount_IsWebSearchEmulationEnabled_Disabled(t *testing.T) {
|
||||||
|
a := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Extra: map[string]any{featureKeyWebSearchEmulation: false},
|
||||||
|
}
|
||||||
|
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccount_IsWebSearchEmulationEnabled_MissingField(t *testing.T) {
|
||||||
|
a := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Extra: map[string]any{},
|
||||||
|
}
|
||||||
|
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccount_IsWebSearchEmulationEnabled_WrongType(t *testing.T) {
|
||||||
|
a := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Extra: map[string]any{featureKeyWebSearchEmulation: "true"},
|
||||||
|
}
|
||||||
|
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccount_IsWebSearchEmulationEnabled_NilExtra(t *testing.T) {
|
||||||
|
a := &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Extra: nil}
|
||||||
|
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccount_IsWebSearchEmulationEnabled_NilAccount(t *testing.T) {
|
||||||
|
var a *Account
|
||||||
|
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccount_IsWebSearchEmulationEnabled_NonAnthropicPlatform(t *testing.T) {
|
||||||
|
a := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Extra: map[string]any{featureKeyWebSearchEmulation: true},
|
||||||
|
}
|
||||||
|
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccount_IsWebSearchEmulationEnabled_NonAPIKeyType(t *testing.T) {
|
||||||
|
a := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{featureKeyWebSearchEmulation: true},
|
||||||
|
}
|
||||||
|
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||||
|
}
|
||||||
@ -49,6 +49,21 @@ type Channel struct {
|
|||||||
ModelPricing []ChannelModelPricing
|
ModelPricing []ChannelModelPricing
|
||||||
// 渠道级模型映射(按平台分组:platform → {src→dst})
|
// 渠道级模型映射(按平台分组:platform → {src→dst})
|
||||||
ModelMapping map[string]map[string]string
|
ModelMapping map[string]map[string]string
|
||||||
|
// 渠道特性配置(如 {"web_search_emulation": {"anthropic": true}})
|
||||||
|
FeaturesConfig map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。
|
||||||
|
func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
|
||||||
|
if c == nil || c.FeaturesConfig == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
enabled, ok := wse[platform].(bool)
|
||||||
|
return ok && enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChannelModelPricing 渠道模型定价条目
|
// ChannelModelPricing 渠道模型定价条目
|
||||||
|
|||||||
@ -197,10 +197,8 @@ func newEmptyChannelCache() *channelCache {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
|
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
|
||||||
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
|
// 各平台严格独立:antigravity 分组只匹配 antigravity 定价,不会匹配 anthropic/gemini 的定价。
|
||||||
// 缓存 key 使用定价条目的原始平台(pricing.Platform),而非分组平台,
|
// 查找时通过 lookupPricingAcrossPlatforms() 在本平台内查找。
|
||||||
// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。
|
|
||||||
// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。
|
|
||||||
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
||||||
for j := range ch.ModelPricing {
|
for j := range ch.ModelPricing {
|
||||||
pricing := &ch.ModelPricing[j]
|
pricing := &ch.ModelPricing[j]
|
||||||
@ -226,8 +224,7 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform
|
|||||||
}
|
}
|
||||||
|
|
||||||
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
|
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
|
||||||
// antigravity 平台同时服务 Claude 和 Gemini 模型。
|
// 各平台严格独立:antigravity 分组只匹配 antigravity 映射。
|
||||||
// 缓存 key 使用映射条目的原始平台(mappingPlatform),避免跨平台同名映射覆盖。
|
|
||||||
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
||||||
for _, mappingPlatform := range matchingPlatforms(platform) {
|
for _, mappingPlatform := range matchingPlatforms(platform) {
|
||||||
platformMapping, ok := ch.ModelMapping[mappingPlatform]
|
platformMapping, ok := ch.ModelMapping[mappingPlatform]
|
||||||
@ -251,40 +248,58 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// storeErrorCache 存入短 TTL 空缓存,防止 DB 错误后紧密重试。
|
||||||
|
// 通过回退 loadedAt 使剩余 TTL = channelErrorTTL。
|
||||||
|
func (s *ChannelService) storeErrorCache() {
|
||||||
|
errorCache := newEmptyChannelCache()
|
||||||
|
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
|
||||||
|
s.cache.Store(errorCache)
|
||||||
|
}
|
||||||
|
|
||||||
// buildCache 从数据库构建渠道缓存。
|
// buildCache 从数据库构建渠道缓存。
|
||||||
// 使用独立 context 避免请求取消导致空值被长期缓存。
|
// 使用独立 context 避免请求取消导致空值被长期缓存。
|
||||||
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
|
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
|
||||||
// 断开请求取消链,避免客户端断连导致空值被长期缓存
|
|
||||||
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
|
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
channels, err := s.repo.ListAll(dbCtx)
|
channels, groupPlatforms, err := s.fetchChannelData(dbCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
|
return nil, err
|
||||||
slog.Warn("failed to build channel cache", "error", err)
|
}
|
||||||
errorCache := newEmptyChannelCache()
|
|
||||||
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL
|
cache := populateChannelCache(channels, groupPlatforms)
|
||||||
s.cache.Store(errorCache)
|
s.cache.Store(cache)
|
||||||
return nil, fmt.Errorf("list all channels: %w", err)
|
return cache, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchChannelData 从数据库加载渠道列表和分组平台映射。
|
||||||
|
func (s *ChannelService) fetchChannelData(ctx context.Context) ([]Channel, map[int64]string, error) {
|
||||||
|
channels, err := s.repo.ListAll(ctx)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to build channel cache", "error", err)
|
||||||
|
s.storeErrorCache()
|
||||||
|
return nil, nil, fmt.Errorf("list all channels: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 收集所有 groupID,批量查询 platform
|
|
||||||
var allGroupIDs []int64
|
var allGroupIDs []int64
|
||||||
for i := range channels {
|
for i := range channels {
|
||||||
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
|
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
groupPlatforms := make(map[int64]string)
|
groupPlatforms := make(map[int64]string)
|
||||||
if len(allGroupIDs) > 0 {
|
if len(allGroupIDs) > 0 {
|
||||||
groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs)
|
groupPlatforms, err = s.repo.GetGroupPlatforms(ctx, allGroupIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("failed to load group platforms for channel cache", "error", err)
|
slog.Warn("failed to load group platforms for channel cache", "error", err)
|
||||||
errorCache := newEmptyChannelCache()
|
s.storeErrorCache()
|
||||||
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
|
return nil, nil, fmt.Errorf("get group platforms: %w", err)
|
||||||
s.cache.Store(errorCache)
|
|
||||||
return nil, fmt.Errorf("get group platforms: %w", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return channels, groupPlatforms, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。
|
||||||
|
func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache {
|
||||||
cache := newEmptyChannelCache()
|
cache := newEmptyChannelCache()
|
||||||
cache.groupPlatform = groupPlatforms
|
cache.groupPlatform = groupPlatforms
|
||||||
cache.byID = make(map[int64]*Channel, len(channels))
|
cache.byID = make(map[int64]*Channel, len(channels))
|
||||||
@ -293,7 +308,6 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
|||||||
for i := range channels {
|
for i := range channels {
|
||||||
ch := &channels[i]
|
ch := &channels[i]
|
||||||
cache.byID[ch.ID] = ch
|
cache.byID[ch.ID] = ch
|
||||||
|
|
||||||
for _, gid := range ch.GroupIDs {
|
for _, gid := range ch.GroupIDs {
|
||||||
cache.channelByGroupID[gid] = ch
|
cache.channelByGroupID[gid] = ch
|
||||||
platform := groupPlatforms[gid]
|
platform := groupPlatforms[gid]
|
||||||
@ -302,32 +316,20 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 通配符条目保持配置顺序(最先匹配到优先)
|
return cache
|
||||||
|
|
||||||
s.cache.Store(cache)
|
|
||||||
return cache, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// invalidateCache 使缓存失效,让下次读取时自然重建
|
// invalidateCache 使缓存失效,让下次读取时自然重建
|
||||||
|
|
||||||
// isPlatformPricingMatch 判断定价条目的平台是否匹配分组平台。
|
// isPlatformPricingMatch 判断定价条目的平台是否匹配分组平台。
|
||||||
// antigravity 平台同时服务 Claude(anthropic)和 Gemini(gemini)模型,
|
// 各平台(antigravity / anthropic / gemini / openai)严格独立,不跨平台匹配。
|
||||||
// 因此 antigravity 分组应匹配 anthropic 和 gemini 的定价条目。
|
|
||||||
func isPlatformPricingMatch(groupPlatform, pricingPlatform string) bool {
|
func isPlatformPricingMatch(groupPlatform, pricingPlatform string) bool {
|
||||||
if groupPlatform == pricingPlatform {
|
return groupPlatform == pricingPlatform
|
||||||
return true
|
|
||||||
}
|
|
||||||
if groupPlatform == PlatformAntigravity {
|
|
||||||
return pricingPlatform == PlatformAnthropic || pricingPlatform == PlatformGemini
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// matchingPlatforms 返回分组平台对应的所有可匹配平台列表。
|
// matchingPlatforms 返回分组平台对应的可匹配平台列表。
|
||||||
|
// 各平台严格独立,只返回自身。
|
||||||
func matchingPlatforms(groupPlatform string) []string {
|
func matchingPlatforms(groupPlatform string) []string {
|
||||||
if groupPlatform == PlatformAntigravity {
|
|
||||||
return []string{PlatformAntigravity, PlatformAnthropic, PlatformGemini}
|
|
||||||
}
|
|
||||||
return []string{groupPlatform}
|
return []string{groupPlatform}
|
||||||
}
|
}
|
||||||
func (s *ChannelService) invalidateCache() {
|
func (s *ChannelService) invalidateCache() {
|
||||||
@ -364,10 +366,8 @@ func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。
|
// lookupPricingAcrossPlatforms 在分组平台内查找模型定价。
|
||||||
// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试
|
// 各平台严格独立,只在本平台内查找(先精确匹配,再通配符)。
|
||||||
// matchingPlatforms() 返回的所有平台(antigravity → anthropic → gemini),
|
|
||||||
// 返回第一个命中的结果。非 antigravity 平台只尝试自身。
|
|
||||||
func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing {
|
func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing {
|
||||||
for _, p := range matchingPlatforms(groupPlatform) {
|
for _, p := range matchingPlatforms(groupPlatform) {
|
||||||
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
|
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
|
||||||
@ -384,7 +384,7 @@ func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatf
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookupMappingAcrossPlatforms 在所有匹配平台中查找模型映射。
|
// lookupMappingAcrossPlatforms 在分组平台内查找模型映射。
|
||||||
// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。
|
// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。
|
||||||
func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string {
|
func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string {
|
||||||
for _, p := range matchingPlatforms(groupPlatform) {
|
for _, p := range matchingPlatforms(groupPlatform) {
|
||||||
@ -442,8 +442,7 @@ func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。
|
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。
|
||||||
// antigravity 分组依次尝试所有匹配平台(antigravity → anthropic → gemini),
|
// 各平台严格独立,只在本平台内查找定价。
|
||||||
// 确保跨平台同名模型各自独立匹配。
|
|
||||||
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
|
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
|
||||||
lk, err := s.lookupGroupChannel(ctx, groupID)
|
lk, err := s.lookupGroupChannel(ctx, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -481,7 +480,10 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
|
|||||||
// 返回 true 表示模型被限制(不在允许列表中)。
|
// 返回 true 表示模型被限制(不在允许列表中)。
|
||||||
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
|
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
|
||||||
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
|
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
|
||||||
lk, _ := s.lookupGroupChannel(ctx, groupID)
|
lk, err := s.lookupGroupChannel(ctx, groupID)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to load channel cache for model restriction check", "group_id", groupID, "error", err)
|
||||||
|
}
|
||||||
if lk == nil {
|
if lk == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -524,7 +526,7 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// checkRestricted 基于已查找的渠道信息检查模型是否被限制。
|
// checkRestricted 基于已查找的渠道信息检查模型是否被限制。
|
||||||
// antigravity 分组依次尝试所有匹配平台的定价列表。
|
// 只在本平台的定价列表中查找。
|
||||||
func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
|
func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
|
||||||
if !lk.channel.RestrictModels {
|
if !lk.channel.RestrictModels {
|
||||||
return false
|
return false
|
||||||
@ -552,6 +554,91 @@ func ReplaceModelInBody(body []byte, newModel string) []byte {
|
|||||||
return newBody
|
return newBody
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。
|
||||||
|
// Create 和 Update 共用此函数,避免重复。
|
||||||
|
func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error {
|
||||||
|
if err := validateNoConflictingModels(pricing); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := validatePricingIntervals(pricing); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := validateNoConflictingMappings(mapping); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return validatePricingBillingMode(pricing)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validatePricingBillingMode 校验计费模式配置:按次/图片模式必须配价格或区间,所有价格字段不能为负,区间至少有一个价格字段。
|
||||||
|
func validatePricingBillingMode(pricing []ChannelModelPricing) error {
|
||||||
|
for _, p := range pricing {
|
||||||
|
if err := checkBillingModeRequirements(p); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := checkPricesNotNegative(p); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := checkIntervalsHavePrices(p); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkBillingModeRequirements(p ChannelModelPricing) error {
|
||||||
|
if p.BillingMode == BillingModePerRequest || p.BillingMode == BillingModeImage {
|
||||||
|
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
|
||||||
|
return infraerrors.BadRequest(
|
||||||
|
"BILLING_MODE_MISSING_PRICE",
|
||||||
|
"per-request price or intervals required for per_request/image billing mode",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkPricesNotNegative(p ChannelModelPricing) error {
|
||||||
|
checks := []struct {
|
||||||
|
field string
|
||||||
|
val *float64
|
||||||
|
}{
|
||||||
|
{"input_price", p.InputPrice},
|
||||||
|
{"output_price", p.OutputPrice},
|
||||||
|
{"cache_write_price", p.CacheWritePrice},
|
||||||
|
{"cache_read_price", p.CacheReadPrice},
|
||||||
|
{"image_output_price", p.ImageOutputPrice},
|
||||||
|
{"per_request_price", p.PerRequestPrice},
|
||||||
|
}
|
||||||
|
for _, c := range checks {
|
||||||
|
if c.val != nil && *c.val < 0 {
|
||||||
|
return infraerrors.BadRequest("NEGATIVE_PRICE", fmt.Sprintf("%s must be >= 0", c.field))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkIntervalsHavePrices(p ChannelModelPricing) error {
|
||||||
|
for _, iv := range p.Intervals {
|
||||||
|
if iv.InputPrice == nil && iv.OutputPrice == nil &&
|
||||||
|
iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
|
||||||
|
iv.PerRequestPrice == nil {
|
||||||
|
return infraerrors.BadRequest(
|
||||||
|
"INTERVAL_MISSING_PRICE",
|
||||||
|
fmt.Sprintf("interval [%d, %s] has no price fields set for model %v",
|
||||||
|
iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatMaxTokens(max *int) string {
|
||||||
|
if max == nil {
|
||||||
|
return "∞"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%d", *max)
|
||||||
|
}
|
||||||
|
|
||||||
// --- CRUD ---
|
// --- CRUD ---
|
||||||
|
|
||||||
// Create 创建渠道
|
// Create 创建渠道
|
||||||
@ -564,15 +651,8 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
|||||||
return nil, ErrChannelExists
|
return nil, ErrChannelExists
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查分组冲突
|
if err := s.checkGroupConflicts(ctx, 0, input.GroupIDs); err != nil {
|
||||||
if len(input.GroupIDs) > 0 {
|
return nil, err
|
||||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("check group conflicts: %w", err)
|
|
||||||
}
|
|
||||||
if len(conflicting) > 0 {
|
|
||||||
return nil, ErrGroupAlreadyInChannel
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
channel := &Channel{
|
channel := &Channel{
|
||||||
@ -585,18 +665,13 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
|||||||
ModelPricing: input.ModelPricing,
|
ModelPricing: input.ModelPricing,
|
||||||
ModelMapping: input.ModelMapping,
|
ModelMapping: input.ModelMapping,
|
||||||
Features: input.Features,
|
Features: input.Features,
|
||||||
|
FeaturesConfig: input.FeaturesConfig,
|
||||||
}
|
}
|
||||||
if channel.BillingModelSource == "" {
|
if channel.BillingModelSource == "" {
|
||||||
channel.BillingModelSource = BillingModelSourceChannelMapped
|
channel.BillingModelSource = BillingModelSourceChannelMapped
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -620,105 +695,118 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
|
|||||||
return nil, fmt.Errorf("get channel: %w", err)
|
return nil, fmt.Errorf("get channel: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.Name != "" && input.Name != channel.Name {
|
if err := s.applyUpdateInput(ctx, channel, input); err != nil {
|
||||||
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("check channel exists: %w", err)
|
|
||||||
}
|
|
||||||
if exists {
|
|
||||||
return nil, ErrChannelExists
|
|
||||||
}
|
|
||||||
channel.Name = input.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.Description != nil {
|
|
||||||
channel.Description = *input.Description
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.Status != "" {
|
|
||||||
channel.Status = input.Status
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.RestrictModels != nil {
|
|
||||||
channel.RestrictModels = *input.RestrictModels
|
|
||||||
}
|
|
||||||
if input.Features != nil {
|
|
||||||
channel.Features = *input.Features
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查分组冲突
|
|
||||||
if input.GroupIDs != nil {
|
|
||||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("check group conflicts: %w", err)
|
|
||||||
}
|
|
||||||
if len(conflicting) > 0 {
|
|
||||||
return nil, ErrGroupAlreadyInChannel
|
|
||||||
}
|
|
||||||
channel.GroupIDs = *input.GroupIDs
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.ModelPricing != nil {
|
|
||||||
channel.ModelPricing = *input.ModelPricing
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.ModelMapping != nil {
|
|
||||||
channel.ModelMapping = input.ModelMapping
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.BillingModelSource != "" {
|
|
||||||
channel.BillingModelSource = input.BillingModelSource
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 先获取旧分组,Update 后旧分组关联已删除,无法再查到
|
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
|
||||||
var oldGroupIDs []int64
|
return nil, err
|
||||||
if s.authCacheInvalidator != nil {
|
|
||||||
var err2 error
|
|
||||||
oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
|
|
||||||
if err2 != nil {
|
|
||||||
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oldGroupIDs := s.getOldGroupIDs(ctx, id)
|
||||||
|
|
||||||
if err := s.repo.Update(ctx, channel); err != nil {
|
if err := s.repo.Update(ctx, channel); err != nil {
|
||||||
return nil, fmt.Errorf("update channel: %w", err)
|
return nil, fmt.Errorf("update channel: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.invalidateCache()
|
s.invalidateCache()
|
||||||
|
s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs)
|
||||||
// 失效新旧分组的 auth 缓存
|
|
||||||
if s.authCacheInvalidator != nil {
|
|
||||||
seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs))
|
|
||||||
for _, gid := range oldGroupIDs {
|
|
||||||
if _, ok := seen[gid]; !ok {
|
|
||||||
seen[gid] = struct{}{}
|
|
||||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, gid := range channel.GroupIDs {
|
|
||||||
if _, ok := seen[gid]; !ok {
|
|
||||||
seen[gid] = struct{}{}
|
|
||||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.repo.GetByID(ctx, id)
|
return s.repo.GetByID(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// applyUpdateInput 将更新请求的字段应用到渠道实体上。
|
||||||
|
func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel, input *UpdateChannelInput) error {
|
||||||
|
if input.Name != "" && input.Name != channel.Name {
|
||||||
|
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, channel.ID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("check channel exists: %w", err)
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
return ErrChannelExists
|
||||||
|
}
|
||||||
|
channel.Name = input.Name
|
||||||
|
}
|
||||||
|
if input.Description != nil {
|
||||||
|
channel.Description = *input.Description
|
||||||
|
}
|
||||||
|
if input.Status != "" {
|
||||||
|
channel.Status = input.Status
|
||||||
|
}
|
||||||
|
if input.RestrictModels != nil {
|
||||||
|
channel.RestrictModels = *input.RestrictModels
|
||||||
|
}
|
||||||
|
if input.Features != nil {
|
||||||
|
channel.Features = *input.Features
|
||||||
|
}
|
||||||
|
if input.GroupIDs != nil {
|
||||||
|
if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
channel.GroupIDs = *input.GroupIDs
|
||||||
|
}
|
||||||
|
if input.ModelPricing != nil {
|
||||||
|
channel.ModelPricing = *input.ModelPricing
|
||||||
|
}
|
||||||
|
if input.ModelMapping != nil {
|
||||||
|
channel.ModelMapping = input.ModelMapping
|
||||||
|
}
|
||||||
|
if input.BillingModelSource != "" {
|
||||||
|
channel.BillingModelSource = input.BillingModelSource
|
||||||
|
}
|
||||||
|
if input.FeaturesConfig != nil {
|
||||||
|
channel.FeaturesConfig = input.FeaturesConfig
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkGroupConflicts 检查待关联的分组是否已属于其他渠道。
|
||||||
|
// channelID 为当前渠道 ID(Create 时传 0)。
|
||||||
|
func (s *ChannelService) checkGroupConflicts(ctx context.Context, channelID int64, groupIDs []int64) error {
|
||||||
|
if len(groupIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, channelID, groupIDs)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("check group conflicts: %w", err)
|
||||||
|
}
|
||||||
|
if len(conflicting) > 0 {
|
||||||
|
return ErrGroupAlreadyInChannel
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getOldGroupIDs 获取渠道更新前的关联分组 ID(用于失效 auth 缓存)。
|
||||||
|
func (s *ChannelService) getOldGroupIDs(ctx context.Context, channelID int64) []int64 {
|
||||||
|
if s.authCacheInvalidator == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
oldGroupIDs, err := s.repo.GetGroupIDs(ctx, channelID)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", channelID, "error", err)
|
||||||
|
}
|
||||||
|
return oldGroupIDs
|
||||||
|
}
|
||||||
|
|
||||||
|
// invalidateAuthCacheForGroups 对新旧分组去重后逐个失效 auth 缓存。
|
||||||
|
func (s *ChannelService) invalidateAuthCacheForGroups(ctx context.Context, groupIDSets ...[]int64) {
|
||||||
|
if s.authCacheInvalidator == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen := make(map[int64]struct{})
|
||||||
|
for _, ids := range groupIDSets {
|
||||||
|
for _, gid := range ids {
|
||||||
|
if _, ok := seen[gid]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[gid] = struct{}{}
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Delete 删除渠道
|
// Delete 删除渠道
|
||||||
func (s *ChannelService) Delete(ctx context.Context, id int64) error {
|
func (s *ChannelService) Delete(ctx context.Context, id int64) error {
|
||||||
// 先获取关联分组用于失效缓存
|
|
||||||
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
|
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
|
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
|
||||||
@ -729,12 +817,7 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.invalidateCache()
|
s.invalidateCache()
|
||||||
|
s.invalidateAuthCacheForGroups(ctx, groupIDs)
|
||||||
if s.authCacheInvalidator != nil {
|
|
||||||
for _, gid := range groupIDs {
|
|
||||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -847,6 +930,7 @@ type CreateChannelInput struct {
|
|||||||
BillingModelSource string
|
BillingModelSource string
|
||||||
RestrictModels bool
|
RestrictModels bool
|
||||||
Features string
|
Features string
|
||||||
|
FeaturesConfig map[string]any
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateChannelInput 更新渠道输入
|
// UpdateChannelInput 更新渠道输入
|
||||||
@ -860,4 +944,5 @@ type UpdateChannelInput struct {
|
|||||||
BillingModelSource string
|
BillingModelSource string
|
||||||
RestrictModels *bool
|
RestrictModels *bool
|
||||||
Features *string
|
Features *string
|
||||||
|
FeaturesConfig map[string]any
|
||||||
}
|
}
|
||||||
|
|||||||
62
backend/internal/service/channel_websearch_test.go
Normal file
62
backend/internal/service/channel_websearch_test.go
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestChannel_IsWebSearchEmulationEnabled_Enabled(t *testing.T) {
|
||||||
|
c := &Channel{
|
||||||
|
FeaturesConfig: map[string]any{
|
||||||
|
featureKeyWebSearchEmulation: map[string]any{"anthropic": true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.True(t, c.IsWebSearchEmulationEnabled("anthropic"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChannel_IsWebSearchEmulationEnabled_DifferentPlatform(t *testing.T) {
|
||||||
|
c := &Channel{
|
||||||
|
FeaturesConfig: map[string]any{
|
||||||
|
featureKeyWebSearchEmulation: map[string]any{"anthropic": true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.False(t, c.IsWebSearchEmulationEnabled("openai"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChannel_IsWebSearchEmulationEnabled_Disabled(t *testing.T) {
|
||||||
|
c := &Channel{
|
||||||
|
FeaturesConfig: map[string]any{
|
||||||
|
featureKeyWebSearchEmulation: map[string]any{"anthropic": false},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChannel_IsWebSearchEmulationEnabled_NilFeaturesConfig(t *testing.T) {
|
||||||
|
c := &Channel{FeaturesConfig: nil}
|
||||||
|
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChannel_IsWebSearchEmulationEnabled_NilChannel(t *testing.T) {
|
||||||
|
var c *Channel
|
||||||
|
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChannel_IsWebSearchEmulationEnabled_WrongStructure(t *testing.T) {
|
||||||
|
c := &Channel{
|
||||||
|
FeaturesConfig: map[string]any{
|
||||||
|
featureKeyWebSearchEmulation: true, // not a map
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChannel_IsWebSearchEmulationEnabled_PlatformValueNotBool(t *testing.T) {
|
||||||
|
c := &Channel{
|
||||||
|
FeaturesConfig: map[string]any{
|
||||||
|
featureKeyWebSearchEmulation: map[string]any{"anthropic": "yes"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
|
||||||
|
}
|
||||||
@ -249,6 +249,10 @@ const (
|
|||||||
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
|
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
|
||||||
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
|
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
|
||||||
SettingKeyEnableCCHSigning = "enable_cch_signing"
|
SettingKeyEnableCCHSigning = "enable_cch_signing"
|
||||||
|
|
||||||
|
// Web Search Emulation
|
||||||
|
// SettingKeyWebSearchEmulationConfig 全局 web search 模拟配置(JSON)
|
||||||
|
SettingKeyWebSearchEmulationConfig = "web_search_emulation_config"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||||
|
|||||||
@ -3785,6 +3785,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
return nil, fmt.Errorf("parse request: empty request")
|
return nil, fmt.Errorf("parse request: empty request")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
|
||||||
|
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.Body) {
|
||||||
|
return s.handleWebSearchEmulation(ctx, c, account, parsed)
|
||||||
|
}
|
||||||
|
|
||||||
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
||||||
passthroughBody := parsed.Body
|
passthroughBody := parsed.Body
|
||||||
passthroughModel := parsed.Model
|
passthroughModel := parsed.Model
|
||||||
|
|||||||
358
backend/internal/service/gateway_websearch_emulation.go
Normal file
358
backend/internal/service/gateway_websearch_emulation.go
Normal file
@ -0,0 +1,358 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Web search emulation constants
|
||||||
|
const (
|
||||||
|
toolTypeWebSearchPrefix = "web_search"
|
||||||
|
toolTypeGoogleSearch = "google_search"
|
||||||
|
toolNameWebSearch = "web_search"
|
||||||
|
toolNameGoogleSearch = "google_search"
|
||||||
|
toolNameWebSearch2025 = "web_search_20250305"
|
||||||
|
|
||||||
|
webSearchDefaultMaxResults = 5
|
||||||
|
defaultWebSearchModel = "claude-sonnet-4-6"
|
||||||
|
webSearchMsgIDPrefix = "msg_ws_"
|
||||||
|
webSearchToolUseIDPrefix = "srvtoolu_ws_"
|
||||||
|
tokenEstimateDivisor = 4
|
||||||
|
|
||||||
|
// featureKeyWebSearchEmulation is the key used in Account.Extra and Channel.FeaturesConfig.
|
||||||
|
featureKeyWebSearchEmulation = "web_search_emulation"
|
||||||
|
)
|
||||||
|
|
||||||
|
// webSearchManagerPtr stores *websearch.Manager atomically for concurrent safety.
|
||||||
|
var webSearchManagerPtr atomic.Pointer[websearch.Manager]
|
||||||
|
|
||||||
|
// SetWebSearchManager wires the websearch.Manager into the gateway (goroutine-safe).
|
||||||
|
func SetWebSearchManager(m *websearch.Manager) {
|
||||||
|
webSearchManagerPtr.Store(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getWebSearchManager() *websearch.Manager {
|
||||||
|
return webSearchManagerPtr.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldEmulateWebSearch checks whether a request should be intercepted.
|
||||||
|
//
|
||||||
|
// Judgment chain: manager exists → only web_search tool → global enabled → account enabled.
|
||||||
|
// Note: channel-level control is enforced via the account's extra field; the channel toggle
|
||||||
|
// in the admin UI sets the account's flag for all accounts in that channel's groups.
|
||||||
|
func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, body []byte) bool {
|
||||||
|
if getWebSearchManager() == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !isOnlyWebSearchToolInBody(body) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !s.settingService.IsWebSearchEmulationEnabled(ctx) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !account.IsWebSearchEmulationEnabled() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool.
|
||||||
|
func isOnlyWebSearchToolInBody(body []byte) bool {
|
||||||
|
tools := gjson.GetBytes(body, "tools")
|
||||||
|
if !tools.IsArray() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
arr := tools.Array()
|
||||||
|
if len(arr) != 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return isWebSearchToolJSON(arr[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func isWebSearchToolJSON(tool gjson.Result) bool {
|
||||||
|
toolType := tool.Get("type").String()
|
||||||
|
if strings.HasPrefix(toolType, toolTypeWebSearchPrefix) || toolType == toolTypeGoogleSearch {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
switch tool.Get("name").String() {
|
||||||
|
case toolNameWebSearch, toolNameGoogleSearch, toolNameWebSearch2025:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractSearchQueryFromBody extracts the last user message text as the search query.
|
||||||
|
func extractSearchQueryFromBody(body []byte) string {
|
||||||
|
messages := gjson.GetBytes(body, "messages")
|
||||||
|
if !messages.IsArray() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
arr := messages.Array()
|
||||||
|
if len(arr) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
lastMsg := arr[len(arr)-1]
|
||||||
|
if lastMsg.Get("role").String() != "user" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return extractWebSearchTextFromContent(lastMsg.Get("content"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractWebSearchTextFromContent(content gjson.Result) string {
|
||||||
|
if content.Type == gjson.String {
|
||||||
|
return content.String()
|
||||||
|
}
|
||||||
|
if content.IsArray() {
|
||||||
|
for _, block := range content.Array() {
|
||||||
|
if block.Get("type").String() == "text" {
|
||||||
|
if text := block.Get("text").String(); text != "" {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleWebSearchEmulation intercepts a web-search-only request,
|
||||||
|
// calls a third-party search API, and constructs an Anthropic-format response.
|
||||||
|
func (s *GatewayService) handleWebSearchEmulation(
|
||||||
|
ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest,
|
||||||
|
) (*ForwardResult, error) {
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
// Release the serial queue lock immediately — we don't need upstream.
|
||||||
|
if parsed.OnUpstreamAccepted != nil {
|
||||||
|
parsed.OnUpstreamAccepted()
|
||||||
|
}
|
||||||
|
|
||||||
|
query := extractSearchQueryFromBody(parsed.Body)
|
||||||
|
if query == "" {
|
||||||
|
return nil, fmt.Errorf("web search emulation: no query found in messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("web search emulation: executing search",
|
||||||
|
"account_id", account.ID, "account_name", account.Name, "query", query)
|
||||||
|
|
||||||
|
resp, providerName, err := doWebSearch(ctx, account, query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("web search emulation: search completed",
|
||||||
|
"provider", providerName, "results_count", len(resp.Results))
|
||||||
|
|
||||||
|
model := parsed.Model
|
||||||
|
if model == "" {
|
||||||
|
model = defaultWebSearchModel
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsed.Stream {
|
||||||
|
return writeWebSearchStreamResponse(c, query, resp, model, startTime)
|
||||||
|
}
|
||||||
|
return writeWebSearchNonStreamResponse(c, query, resp, model, startTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func doWebSearch(ctx context.Context, account *Account, query string) (*websearch.SearchResponse, string, error) {
|
||||||
|
proxyURL := resolveAccountProxyURL(account)
|
||||||
|
mgr := getWebSearchManager()
|
||||||
|
if mgr == nil {
|
||||||
|
return nil, "", fmt.Errorf("web search emulation: manager not initialized")
|
||||||
|
}
|
||||||
|
resp, providerName, err := mgr.SearchWithBestProvider(ctx, websearch.SearchRequest{
|
||||||
|
Query: query, MaxResults: webSearchDefaultMaxResults, ProxyURL: proxyURL,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("web search emulation: search failed", "error", err)
|
||||||
|
return nil, "", fmt.Errorf("web search emulation: %w", err)
|
||||||
|
}
|
||||||
|
return resp, providerName, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveAccountProxyURL(account *Account) string {
|
||||||
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
|
return account.Proxy.URL()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- SSE streaming response ---
|
||||||
|
|
||||||
|
func writeWebSearchStreamResponse(
|
||||||
|
c *gin.Context, query string, resp *websearch.SearchResponse, model string, startTime time.Time,
|
||||||
|
) (*ForwardResult, error) {
|
||||||
|
msgID := webSearchMsgIDPrefix + uuid.New().String()
|
||||||
|
toolUseID := webSearchToolUseIDPrefix + uuid.New().String()[:16]
|
||||||
|
|
||||||
|
setSSEHeaders(c)
|
||||||
|
if err := writeSSEMessageStart(c.Writer, msgID, model); err != nil {
|
||||||
|
return nil, fmt.Errorf("web search emulation: SSE write: %w", err)
|
||||||
|
}
|
||||||
|
writeSSEServerToolUse(c.Writer, toolUseID, query, 0)
|
||||||
|
writeSSEToolResult(c.Writer, toolUseID, resp.Results, 1)
|
||||||
|
textSummary := buildTextSummary(query, resp.Results)
|
||||||
|
writeSSETextBlock(c.Writer, textSummary, 2)
|
||||||
|
writeSSEMessageEnd(c.Writer, len(textSummary)/tokenEstimateDivisor)
|
||||||
|
c.Writer.Flush()
|
||||||
|
|
||||||
|
return &ForwardResult{Model: model, Duration: time.Since(startTime), Usage: ClaudeUsage{}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setSSEHeaders(c *gin.Context) {
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
c.Writer.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeSSEMessageStart(w http.ResponseWriter, msgID, model string) error {
|
||||||
|
evt := map[string]any{
|
||||||
|
"type": "message_start",
|
||||||
|
"message": map[string]any{
|
||||||
|
"id": msgID, "type": "message", "role": "assistant", "model": model,
|
||||||
|
"content": []any{}, "stop_reason": nil, "stop_sequence": nil,
|
||||||
|
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return flushSSEJSON(w, "message_start", evt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeSSEServerToolUse(w http.ResponseWriter, toolUseID, query string, index int) {
|
||||||
|
start := map[string]any{
|
||||||
|
"type": "content_block_start", "index": index,
|
||||||
|
"content_block": map[string]any{
|
||||||
|
"type": "server_tool_use", "id": toolUseID,
|
||||||
|
"name": toolNameWebSearch, "input": map[string]string{"query": query},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_ = flushSSEJSON(w, "content_block_start", start)
|
||||||
|
_ = flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeSSEToolResult(w http.ResponseWriter, toolUseID string, results []websearch.SearchResult, index int) {
|
||||||
|
start := map[string]any{
|
||||||
|
"type": "content_block_start", "index": index,
|
||||||
|
"content_block": map[string]any{
|
||||||
|
"type": "web_search_tool_result", "tool_use_id": toolUseID,
|
||||||
|
"content": buildSearchResultBlocks(results),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_ = flushSSEJSON(w, "content_block_start", start)
|
||||||
|
_ = flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeSSETextBlock(w http.ResponseWriter, text string, index int) {
|
||||||
|
_ = flushSSEJSON(w, "content_block_start", map[string]any{
|
||||||
|
"type": "content_block_start", "index": index,
|
||||||
|
"content_block": map[string]any{"type": "text", "text": ""},
|
||||||
|
})
|
||||||
|
_ = flushSSEJSON(w, "content_block_delta", map[string]any{
|
||||||
|
"type": "content_block_delta", "index": index,
|
||||||
|
"delta": map[string]string{"type": "text_delta", "text": text},
|
||||||
|
})
|
||||||
|
_ = flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeSSEMessageEnd(w http.ResponseWriter, outputTokens int) {
|
||||||
|
_ = flushSSEJSON(w, "message_delta", map[string]any{
|
||||||
|
"type": "message_delta",
|
||||||
|
"delta": map[string]any{"stop_reason": "end_turn", "stop_sequence": nil},
|
||||||
|
"usage": map[string]int{"output_tokens": outputTokens},
|
||||||
|
})
|
||||||
|
_ = flushSSEJSON(w, "message_stop", map[string]string{"type": "message_stop"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// flushSSEJSON marshals data to JSON and writes an SSE event. Returns error on marshal failure.
|
||||||
|
func flushSSEJSON(w http.ResponseWriter, event string, data any) error {
|
||||||
|
b, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("web search emulation: failed to marshal SSE event",
|
||||||
|
"event", event, "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event, b)
|
||||||
|
if f, ok := w.(http.Flusher); ok {
|
||||||
|
f.Flush()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Non-streaming JSON response ---
|
||||||
|
|
||||||
|
func writeWebSearchNonStreamResponse(
|
||||||
|
c *gin.Context, query string, resp *websearch.SearchResponse, model string, startTime time.Time,
|
||||||
|
) (*ForwardResult, error) {
|
||||||
|
msgID := webSearchMsgIDPrefix + uuid.New().String()
|
||||||
|
toolUseID := webSearchToolUseIDPrefix + uuid.New().String()[:16]
|
||||||
|
textSummary := buildTextSummary(query, resp.Results)
|
||||||
|
|
||||||
|
msg := map[string]any{
|
||||||
|
"id": msgID, "type": "message", "role": "assistant", "model": model,
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "server_tool_use", "id": toolUseID,
|
||||||
|
"name": toolNameWebSearch, "input": map[string]string{"query": query},
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"type": "web_search_tool_result", "tool_use_id": toolUseID,
|
||||||
|
"content": buildSearchResultBlocks(resp.Results),
|
||||||
|
},
|
||||||
|
map[string]any{"type": "text", "text": textSummary},
|
||||||
|
},
|
||||||
|
"stop_reason": "end_turn", "stop_sequence": nil,
|
||||||
|
"usage": map[string]int{"input_tokens": 0, "output_tokens": len(textSummary) / tokenEstimateDivisor},
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("web search emulation: marshal response: %w", err)
|
||||||
|
}
|
||||||
|
c.Data(http.StatusOK, "application/json", body)
|
||||||
|
|
||||||
|
return &ForwardResult{Model: model, Duration: time.Since(startTime), Usage: ClaudeUsage{}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Helpers ---
|
||||||
|
|
||||||
|
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]string {
|
||||||
|
blocks := make([]map[string]string, 0, len(results))
|
||||||
|
for _, r := range results {
|
||||||
|
block := map[string]string{
|
||||||
|
"type": "web_search_result",
|
||||||
|
"url": r.URL,
|
||||||
|
"title": r.Title,
|
||||||
|
}
|
||||||
|
if r.Snippet != "" {
|
||||||
|
block["page_content"] = r.Snippet
|
||||||
|
}
|
||||||
|
if r.PageAge != "" {
|
||||||
|
block["page_age"] = r.PageAge
|
||||||
|
}
|
||||||
|
blocks = append(blocks, block)
|
||||||
|
}
|
||||||
|
return blocks
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildTextSummary(query string, results []websearch.SearchResult) string {
|
||||||
|
if len(results) == 0 {
|
||||||
|
return "No search results found for: " + query
|
||||||
|
}
|
||||||
|
var sb strings.Builder
|
||||||
|
fmt.Fprintf(&sb, "Here are the search results for \"%s\":\n\n", query)
|
||||||
|
for i, r := range results {
|
||||||
|
fmt.Fprintf(&sb, "%d. **%s**\n %s\n %s\n\n", i+1, r.Title, r.URL, r.Snippet)
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
142
backend/internal/service/gateway_websearch_emulation_test.go
Normal file
142
backend/internal/service/gateway_websearch_emulation_test.go
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- isOnlyWebSearchToolInBody ---
|
||||||
|
|
||||||
|
func TestIsOnlyWebSearchToolInBody_WebSearchType(t *testing.T) {
|
||||||
|
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"web_search"}]}`)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsOnlyWebSearchToolInBody_WebSearch2025Type(t *testing.T) {
|
||||||
|
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"web_search_20250305"}]}`)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsOnlyWebSearchToolInBody_GoogleSearchType(t *testing.T) {
|
||||||
|
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"google_search"}]}`)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsOnlyWebSearchToolInBody_NameWebSearch(t *testing.T) {
|
||||||
|
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"web_search"}]}`)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsOnlyWebSearchToolInBody_NameWebSearch2025(t *testing.T) {
|
||||||
|
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"web_search_20250305"}]}`)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsOnlyWebSearchToolInBody_NameGoogleSearch(t *testing.T) {
|
||||||
|
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"google_search"}]}`)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsOnlyWebSearchToolInBody_MultipleTools(t *testing.T) {
|
||||||
|
require.False(t, isOnlyWebSearchToolInBody(
|
||||||
|
[]byte(`{"tools":[{"type":"web_search"},{"type":"text_editor"}]}`)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsOnlyWebSearchToolInBody_NoTools(t *testing.T) {
|
||||||
|
require.False(t, isOnlyWebSearchToolInBody([]byte(`{"model":"claude-3"}`)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsOnlyWebSearchToolInBody_EmptyToolsArray(t *testing.T) {
|
||||||
|
require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[]}`)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsOnlyWebSearchToolInBody_NonWebSearchTool(t *testing.T) {
|
||||||
|
require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"text_editor"}]}`)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsOnlyWebSearchToolInBody_ToolsNotArray(t *testing.T) {
|
||||||
|
require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":"web_search"}`)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- extractSearchQueryFromBody ---
|
||||||
|
|
||||||
|
func TestExtractSearchQueryFromBody_StringContent(t *testing.T) {
|
||||||
|
body := `{"messages":[{"role":"user","content":"what is golang"}]}`
|
||||||
|
require.Equal(t, "what is golang", extractSearchQueryFromBody([]byte(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractSearchQueryFromBody_ArrayContent(t *testing.T) {
|
||||||
|
body := `{"messages":[{"role":"user","content":[{"type":"text","text":"search this"}]}]}`
|
||||||
|
require.Equal(t, "search this", extractSearchQueryFromBody([]byte(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractSearchQueryFromBody_MultipleMessages(t *testing.T) {
|
||||||
|
body := `{"messages":[{"role":"user","content":"first"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}]}`
|
||||||
|
require.Equal(t, "second", extractSearchQueryFromBody([]byte(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractSearchQueryFromBody_LastMessageNotUser(t *testing.T) {
|
||||||
|
body := `{"messages":[{"role":"user","content":"q"},{"role":"assistant","content":"a"}]}`
|
||||||
|
require.Equal(t, "", extractSearchQueryFromBody([]byte(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractSearchQueryFromBody_EmptyMessages(t *testing.T) {
|
||||||
|
require.Equal(t, "", extractSearchQueryFromBody([]byte(`{"messages":[]}`)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractSearchQueryFromBody_NoMessages(t *testing.T) {
|
||||||
|
require.Equal(t, "", extractSearchQueryFromBody([]byte(`{"model":"claude-3"}`)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractSearchQueryFromBody_ArrayContentSkipsEmptyText(t *testing.T) {
|
||||||
|
body := `{"messages":[{"role":"user","content":[{"type":"image"},{"type":"text","text":""},{"type":"text","text":"real query"}]}]}`
|
||||||
|
require.Equal(t, "real query", extractSearchQueryFromBody([]byte(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractSearchQueryFromBody_ArrayContentNoTextBlock(t *testing.T) {
|
||||||
|
body := `{"messages":[{"role":"user","content":[{"type":"image","source":{}}]}]}`
|
||||||
|
require.Equal(t, "", extractSearchQueryFromBody([]byte(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- buildSearchResultBlocks ---
|
||||||
|
|
||||||
|
func TestBuildSearchResultBlocks_WithResults(t *testing.T) {
|
||||||
|
results := []websearch.SearchResult{
|
||||||
|
{URL: "https://a.com", Title: "A", Snippet: "snippet a", PageAge: "2 days"},
|
||||||
|
{URL: "https://b.com", Title: "B", Snippet: "snippet b"},
|
||||||
|
}
|
||||||
|
blocks := buildSearchResultBlocks(results)
|
||||||
|
require.Len(t, blocks, 2)
|
||||||
|
require.Equal(t, "web_search_result", blocks[0]["type"])
|
||||||
|
require.Equal(t, "https://a.com", blocks[0]["url"])
|
||||||
|
require.Equal(t, "snippet a", blocks[0]["page_content"])
|
||||||
|
require.Equal(t, "2 days", blocks[0]["page_age"])
|
||||||
|
// Second result has no PageAge
|
||||||
|
require.Equal(t, "https://b.com", blocks[1]["url"])
|
||||||
|
_, hasPageAge := blocks[1]["page_age"]
|
||||||
|
require.False(t, hasPageAge)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildSearchResultBlocks_Empty(t *testing.T) {
|
||||||
|
blocks := buildSearchResultBlocks(nil)
|
||||||
|
require.Empty(t, blocks)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildSearchResultBlocks_SnippetEmpty(t *testing.T) {
|
||||||
|
blocks := buildSearchResultBlocks([]websearch.SearchResult{{URL: "https://x.com", Title: "X", Snippet: ""}})
|
||||||
|
_, hasContent := blocks[0]["page_content"]
|
||||||
|
require.False(t, hasContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- buildTextSummary ---
|
||||||
|
|
||||||
|
func TestBuildTextSummary_WithResults(t *testing.T) {
|
||||||
|
results := []websearch.SearchResult{
|
||||||
|
{URL: "https://a.com", Title: "A", Snippet: "desc a"},
|
||||||
|
}
|
||||||
|
summary := buildTextSummary("test query", results)
|
||||||
|
require.Contains(t, summary, "test query")
|
||||||
|
require.Contains(t, summary, "1. **A**")
|
||||||
|
require.Contains(t, summary, "https://a.com")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildTextSummary_NoResults(t *testing.T) {
|
||||||
|
summary := buildTextSummary("test", nil)
|
||||||
|
require.Contains(t, summary, "No search results found for: test")
|
||||||
|
}
|
||||||
@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
"golang.org/x/sync/singleflight"
|
"golang.org/x/sync/singleflight"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -106,6 +107,7 @@ type SettingService struct {
|
|||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
onUpdate func() // Callback when settings are updated (for cache invalidation)
|
onUpdate func() // Callback when settings are updated (for cache invalidation)
|
||||||
version string // Application version
|
version string // Application version
|
||||||
|
webSearchRedis *redis.Client // optional: Redis client for web search quota tracking
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSettingService 创建系统设置服务实例
|
// NewSettingService 创建系统设置服务实例
|
||||||
@ -1217,6 +1219,14 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
|||||||
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
|
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
|
||||||
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
|
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
|
||||||
|
|
||||||
|
// Web search emulation: quick enabled check from the JSON config
|
||||||
|
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
|
||||||
|
var wsCfg WebSearchEmulationConfig
|
||||||
|
if err := json.Unmarshal([]byte(raw), &wsCfg); err == nil {
|
||||||
|
result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -106,6 +106,9 @@ type SystemSettings struct {
|
|||||||
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
|
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
|
||||||
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
|
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
|
||||||
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
|
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
|
||||||
|
|
||||||
|
// Web Search Emulation (read-only quick check; full config via dedicated API)
|
||||||
|
WebSearchEmulationEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefaultSubscriptionSetting struct {
|
type DefaultSubscriptionSetting struct {
|
||||||
|
|||||||
253
backend/internal/service/websearch_config.go
Normal file
253
backend/internal/service/websearch_config.go
Normal file
@ -0,0 +1,253 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WebSearchEmulationConfig holds the global web search emulation configuration.
|
||||||
|
type WebSearchEmulationConfig struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Providers []WebSearchProviderConfig `json:"providers"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSearchProviderConfig describes a single search provider (Brave or Tavily).
|
||||||
|
type WebSearchProviderConfig struct {
|
||||||
|
Type string `json:"type"` // websearch.ProviderTypeBrave | Tavily
|
||||||
|
APIKey string `json:"api_key,omitempty"` // secret — omitted in API responses
|
||||||
|
APIKeyConfigured bool `json:"api_key_configured"` // read-only mask
|
||||||
|
Priority int `json:"priority"` // lower = higher priority
|
||||||
|
QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited
|
||||||
|
QuotaRefreshInterval string `json:"quota_refresh_interval"` // websearch.QuotaRefresh*
|
||||||
|
QuotaUsed int64 `json:"quota_used,omitempty"` // read-only: current period usage
|
||||||
|
ProxyID *int64 `json:"proxy_id"` // optional proxy association
|
||||||
|
ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration timestamp
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Validation ---
|
||||||
|
|
||||||
|
const maxWebSearchProviders = 10
|
||||||
|
|
||||||
|
var validProviderTypes = map[string]bool{
|
||||||
|
websearch.ProviderTypeBrave: true,
|
||||||
|
websearch.ProviderTypeTavily: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var validQuotaIntervals = map[string]bool{
|
||||||
|
websearch.QuotaRefreshDaily: true,
|
||||||
|
websearch.QuotaRefreshWeekly: true,
|
||||||
|
websearch.QuotaRefreshMonthly: true,
|
||||||
|
"": true, // defaults to monthly
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateWebSearchConfig(cfg *WebSearchEmulationConfig) error {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(cfg.Providers) > maxWebSearchProviders {
|
||||||
|
return fmt.Errorf("too many providers (max %d)", maxWebSearchProviders)
|
||||||
|
}
|
||||||
|
seen := make(map[string]bool, len(cfg.Providers))
|
||||||
|
for i, p := range cfg.Providers {
|
||||||
|
if !validProviderTypes[p.Type] {
|
||||||
|
return fmt.Errorf("provider[%d]: invalid type %q", i, p.Type)
|
||||||
|
}
|
||||||
|
if !validQuotaIntervals[p.QuotaRefreshInterval] {
|
||||||
|
return fmt.Errorf("provider[%d]: invalid quota_refresh_interval %q", i, p.QuotaRefreshInterval)
|
||||||
|
}
|
||||||
|
if p.QuotaLimit < 0 {
|
||||||
|
return fmt.Errorf("provider[%d]: quota_limit must be >= 0", i)
|
||||||
|
}
|
||||||
|
if seen[p.Type] {
|
||||||
|
return fmt.Errorf("provider[%d]: duplicate type %q", i, p.Type)
|
||||||
|
}
|
||||||
|
seen[p.Type] = true
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- In-process cache (same pattern as gateway forwarding settings) ---
|
||||||
|
|
||||||
|
const sfKeyWebSearchConfig = "web_search_emulation_config"
|
||||||
|
|
||||||
|
type cachedWebSearchEmulationConfig struct {
|
||||||
|
config *WebSearchEmulationConfig
|
||||||
|
expiresAt int64 // unix nano
|
||||||
|
}
|
||||||
|
|
||||||
|
var webSearchEmulationCache atomic.Value // *cachedWebSearchEmulationConfig
|
||||||
|
var webSearchEmulationSF singleflight.Group
|
||||||
|
|
||||||
|
const (
|
||||||
|
webSearchEmulationCacheTTL = 60 * time.Second
|
||||||
|
webSearchEmulationErrorTTL = 5 * time.Second
|
||||||
|
webSearchEmulationDBTimeout = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetWebSearchEmulationConfig returns the configuration with in-process cache + singleflight.
|
||||||
|
func (s *SettingService) GetWebSearchEmulationConfig(ctx context.Context) (*WebSearchEmulationConfig, error) {
|
||||||
|
if cached := webSearchEmulationCache.Load(); cached != nil {
|
||||||
|
c := cached.(*cachedWebSearchEmulationConfig)
|
||||||
|
if time.Now().UnixNano() < c.expiresAt {
|
||||||
|
return c.config, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result, err, _ := webSearchEmulationSF.Do(sfKeyWebSearchConfig, func() (any, error) {
|
||||||
|
return s.loadWebSearchConfigFromDB()
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return &WebSearchEmulationConfig{}, err
|
||||||
|
}
|
||||||
|
return result.(*WebSearchEmulationConfig), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingService) loadWebSearchConfigFromDB() (*WebSearchEmulationConfig, error) {
|
||||||
|
dbCtx, cancel := context.WithTimeout(context.Background(), webSearchEmulationDBTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
raw, err := s.settingRepo.GetValue(dbCtx, SettingKeyWebSearchEmulationConfig)
|
||||||
|
if err != nil {
|
||||||
|
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
|
||||||
|
config: &WebSearchEmulationConfig{},
|
||||||
|
expiresAt: time.Now().Add(webSearchEmulationErrorTTL).UnixNano(),
|
||||||
|
})
|
||||||
|
return &WebSearchEmulationConfig{}, err
|
||||||
|
}
|
||||||
|
cfg := parseWebSearchConfigJSON(raw)
|
||||||
|
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
|
||||||
|
config: cfg,
|
||||||
|
expiresAt: time.Now().Add(webSearchEmulationCacheTTL).UnixNano(),
|
||||||
|
})
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseWebSearchConfigJSON(raw string) *WebSearchEmulationConfig {
|
||||||
|
cfg := &WebSearchEmulationConfig{}
|
||||||
|
if raw == "" {
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(raw), cfg); err != nil {
|
||||||
|
slog.Warn("websearch: failed to parse config JSON", "error", err)
|
||||||
|
return &WebSearchEmulationConfig{}
|
||||||
|
}
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveWebSearchEmulationConfig validates and persists the configuration.
|
||||||
|
// Empty API keys in the input are preserved from the existing config.
|
||||||
|
func (s *SettingService) SaveWebSearchEmulationConfig(ctx context.Context, cfg *WebSearchEmulationConfig) error {
|
||||||
|
if err := validateWebSearchConfig(cfg); err != nil {
|
||||||
|
return infraerrors.BadRequest("INVALID_WEB_SEARCH_CONFIG", err.Error())
|
||||||
|
}
|
||||||
|
s.mergeExistingAPIKeys(ctx, cfg)
|
||||||
|
|
||||||
|
data, err := json.Marshal(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("websearch: marshal config: %w", err)
|
||||||
|
}
|
||||||
|
if err := s.settingRepo.Set(ctx, SettingKeyWebSearchEmulationConfig, string(data)); err != nil {
|
||||||
|
return fmt.Errorf("websearch: save config: %w", err)
|
||||||
|
}
|
||||||
|
// Invalidate: forget singleflight first, then store new value
|
||||||
|
webSearchEmulationSF.Forget(sfKeyWebSearchConfig)
|
||||||
|
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
|
||||||
|
config: cfg,
|
||||||
|
expiresAt: time.Now().Add(webSearchEmulationCacheTTL).UnixNano(),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hot-reload: rebuild the global Manager with new config
|
||||||
|
s.RebuildWebSearchManager(ctx)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeExistingAPIKeys preserves API keys from the current config when incoming value is empty.
|
||||||
|
func (s *SettingService) mergeExistingAPIKeys(ctx context.Context, cfg *WebSearchEmulationConfig) {
|
||||||
|
existing, _ := s.getWebSearchEmulationConfigRaw(ctx)
|
||||||
|
if existing == nil || cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
existingByType := make(map[string]string, len(existing.Providers))
|
||||||
|
for _, p := range existing.Providers {
|
||||||
|
if p.APIKey != "" {
|
||||||
|
existingByType[p.Type] = p.APIKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := range cfg.Providers {
|
||||||
|
if cfg.Providers[i].APIKey == "" {
|
||||||
|
if key, ok := existingByType[cfg.Providers[i].Type]; ok {
|
||||||
|
cfg.Providers[i].APIKey = key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingService) getWebSearchEmulationConfigRaw(ctx context.Context) (*WebSearchEmulationConfig, error) {
|
||||||
|
raw, err := s.settingRepo.GetValue(ctx, SettingKeyWebSearchEmulationConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return parseWebSearchConfigJSON(raw), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsWebSearchEmulationEnabled is a quick check for whether the global switch is on.
|
||||||
|
func (s *SettingService) IsWebSearchEmulationEnabled(ctx context.Context) bool {
|
||||||
|
cfg, err := s.GetWebSearchEmulationConfig(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return cfg.Enabled && len(cfg.Providers) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWebSearchRedisClient injects the Redis client used for quota tracking.
|
||||||
|
// Call after construction, before first use. Triggers initial Manager build.
|
||||||
|
func (s *SettingService) SetWebSearchRedisClient(ctx context.Context, redisClient *redis.Client) {
|
||||||
|
s.webSearchRedis = redisClient
|
||||||
|
s.RebuildWebSearchManager(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RebuildWebSearchManager reads the current config and (re)creates the global websearch.Manager.
|
||||||
|
// Called on startup and after SaveWebSearchEmulationConfig.
|
||||||
|
func (s *SettingService) RebuildWebSearchManager(ctx context.Context) {
|
||||||
|
cfg, err := s.GetWebSearchEmulationConfig(ctx)
|
||||||
|
if err != nil || !cfg.Enabled || len(cfg.Providers) == 0 {
|
||||||
|
SetWebSearchManager(nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
providerConfigs := make([]websearch.ProviderConfig, 0, len(cfg.Providers))
|
||||||
|
for _, p := range cfg.Providers {
|
||||||
|
providerConfigs = append(providerConfigs, websearch.ProviderConfig{
|
||||||
|
Type: p.Type,
|
||||||
|
APIKey: p.APIKey,
|
||||||
|
Priority: p.Priority,
|
||||||
|
QuotaLimit: p.QuotaLimit,
|
||||||
|
QuotaRefreshInterval: p.QuotaRefreshInterval,
|
||||||
|
ExpiresAt: p.ExpiresAt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
SetWebSearchManager(websearch.NewManager(providerConfigs, s.webSearchRedis))
|
||||||
|
slog.Info("websearch: manager rebuilt", "provider_count", len(providerConfigs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeWebSearchConfig returns a copy with api_key fields masked for API responses.
|
||||||
|
func SanitizeWebSearchConfig(cfg *WebSearchEmulationConfig) *WebSearchEmulationConfig {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := *cfg
|
||||||
|
out.Providers = make([]WebSearchProviderConfig, len(cfg.Providers))
|
||||||
|
for i, p := range cfg.Providers {
|
||||||
|
out.Providers[i] = p
|
||||||
|
out.Providers[i].APIKeyConfigured = p.APIKey != ""
|
||||||
|
out.Providers[i].APIKey = "" // never return the secret
|
||||||
|
}
|
||||||
|
return &out
|
||||||
|
}
|
||||||
148
backend/internal/service/websearch_config_test.go
Normal file
148
backend/internal/service/websearch_config_test.go
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- validateWebSearchConfig ---
|
||||||
|
|
||||||
|
func TestValidateWebSearchConfig_Nil(t *testing.T) {
|
||||||
|
require.NoError(t, validateWebSearchConfig(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateWebSearchConfig_Valid(t *testing.T) {
|
||||||
|
cfg := &WebSearchEmulationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Providers: []WebSearchProviderConfig{
|
||||||
|
{Type: "brave", Priority: 1, QuotaLimit: 1000, QuotaRefreshInterval: "monthly"},
|
||||||
|
{Type: "tavily", Priority: 2, QuotaLimit: 500, QuotaRefreshInterval: "daily"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.NoError(t, validateWebSearchConfig(cfg))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateWebSearchConfig_TooManyProviders(t *testing.T) {
|
||||||
|
cfg := &WebSearchEmulationConfig{Providers: make([]WebSearchProviderConfig, 11)}
|
||||||
|
for i := range cfg.Providers {
|
||||||
|
cfg.Providers[i] = WebSearchProviderConfig{Type: "brave"}
|
||||||
|
}
|
||||||
|
err := validateWebSearchConfig(cfg)
|
||||||
|
require.ErrorContains(t, err, "too many providers")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateWebSearchConfig_InvalidType(t *testing.T) {
|
||||||
|
cfg := &WebSearchEmulationConfig{
|
||||||
|
Providers: []WebSearchProviderConfig{{Type: "bing"}},
|
||||||
|
}
|
||||||
|
require.ErrorContains(t, validateWebSearchConfig(cfg), "invalid type")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateWebSearchConfig_InvalidQuotaInterval(t *testing.T) {
|
||||||
|
cfg := &WebSearchEmulationConfig{
|
||||||
|
Providers: []WebSearchProviderConfig{{Type: "brave", QuotaRefreshInterval: "hourly"}},
|
||||||
|
}
|
||||||
|
require.ErrorContains(t, validateWebSearchConfig(cfg), "invalid quota_refresh_interval")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateWebSearchConfig_NegativeQuotaLimit(t *testing.T) {
|
||||||
|
cfg := &WebSearchEmulationConfig{
|
||||||
|
Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: -1}},
|
||||||
|
}
|
||||||
|
require.ErrorContains(t, validateWebSearchConfig(cfg), "quota_limit must be >= 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateWebSearchConfig_DuplicateType(t *testing.T) {
|
||||||
|
cfg := &WebSearchEmulationConfig{
|
||||||
|
Providers: []WebSearchProviderConfig{
|
||||||
|
{Type: "brave", Priority: 1},
|
||||||
|
{Type: "brave", Priority: 2},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.ErrorContains(t, validateWebSearchConfig(cfg), "duplicate type")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateWebSearchConfig_EmptyQuotaInterval(t *testing.T) {
|
||||||
|
cfg := &WebSearchEmulationConfig{
|
||||||
|
Providers: []WebSearchProviderConfig{{Type: "brave", QuotaRefreshInterval: ""}},
|
||||||
|
}
|
||||||
|
require.NoError(t, validateWebSearchConfig(cfg))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateWebSearchConfig_ZeroQuotaLimit(t *testing.T) {
|
||||||
|
cfg := &WebSearchEmulationConfig{
|
||||||
|
Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: 0}},
|
||||||
|
}
|
||||||
|
require.NoError(t, validateWebSearchConfig(cfg))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- parseWebSearchConfigJSON ---
|
||||||
|
|
||||||
|
func TestParseWebSearchConfigJSON_ValidJSON(t *testing.T) {
|
||||||
|
raw := `{"enabled":true,"providers":[{"type":"brave","api_key":"sk-xxx"}]}`
|
||||||
|
cfg := parseWebSearchConfigJSON(raw)
|
||||||
|
require.True(t, cfg.Enabled)
|
||||||
|
require.Len(t, cfg.Providers, 1)
|
||||||
|
require.Equal(t, "brave", cfg.Providers[0].Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWebSearchConfigJSON_EmptyString(t *testing.T) {
|
||||||
|
cfg := parseWebSearchConfigJSON("")
|
||||||
|
require.False(t, cfg.Enabled)
|
||||||
|
require.Empty(t, cfg.Providers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWebSearchConfigJSON_InvalidJSON(t *testing.T) {
|
||||||
|
cfg := parseWebSearchConfigJSON("not{json")
|
||||||
|
require.False(t, cfg.Enabled)
|
||||||
|
require.Empty(t, cfg.Providers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- SanitizeWebSearchConfig ---
|
||||||
|
|
||||||
|
func TestSanitizeWebSearchConfig_MaskAPIKey(t *testing.T) {
|
||||||
|
cfg := &WebSearchEmulationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Providers: []WebSearchProviderConfig{
|
||||||
|
{Type: "brave", APIKey: "sk-secret-xxx"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
out := SanitizeWebSearchConfig(cfg)
|
||||||
|
require.Equal(t, "", out.Providers[0].APIKey)
|
||||||
|
require.True(t, out.Providers[0].APIKeyConfigured)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeWebSearchConfig_NoAPIKey(t *testing.T) {
|
||||||
|
cfg := &WebSearchEmulationConfig{
|
||||||
|
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: ""}},
|
||||||
|
}
|
||||||
|
out := SanitizeWebSearchConfig(cfg)
|
||||||
|
require.Equal(t, "", out.Providers[0].APIKey)
|
||||||
|
require.False(t, out.Providers[0].APIKeyConfigured)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeWebSearchConfig_Nil(t *testing.T) {
|
||||||
|
require.Nil(t, SanitizeWebSearchConfig(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeWebSearchConfig_PreservesOtherFields(t *testing.T) {
|
||||||
|
cfg := &WebSearchEmulationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Providers: []WebSearchProviderConfig{
|
||||||
|
{Type: "brave", APIKey: "secret", Priority: 10, QuotaLimit: 1000},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
out := SanitizeWebSearchConfig(cfg)
|
||||||
|
require.True(t, out.Enabled)
|
||||||
|
require.Equal(t, 10, out.Providers[0].Priority)
|
||||||
|
require.Equal(t, int64(1000), out.Providers[0].QuotaLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeWebSearchConfig_DoesNotMutateOriginal(t *testing.T) {
|
||||||
|
cfg := &WebSearchEmulationConfig{
|
||||||
|
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "secret"}},
|
||||||
|
}
|
||||||
|
_ = SanitizeWebSearchConfig(cfg)
|
||||||
|
require.Equal(t, "secret", cfg.Providers[0].APIKey)
|
||||||
|
}
|
||||||
2
backend/migrations/101_add_channel_features_config.sql
Normal file
2
backend/migrations/101_add_channel_features_config.sql
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
ALTER TABLE channels ADD COLUMN IF NOT EXISTS features_config JSONB NOT NULL DEFAULT '{}';
|
||||||
|
COMMENT ON COLUMN channels.features_config IS '渠道特性配置(如 web_search_emulation),JSON 对象格式';
|
||||||
@ -41,6 +41,7 @@ export interface Channel {
|
|||||||
status: string
|
status: string
|
||||||
billing_model_source: string // "requested" | "upstream"
|
billing_model_source: string // "requested" | "upstream"
|
||||||
restrict_models: boolean
|
restrict_models: boolean
|
||||||
|
features_config?: Record<string, unknown>
|
||||||
group_ids: number[]
|
group_ids: number[]
|
||||||
model_pricing: ChannelModelPricing[]
|
model_pricing: ChannelModelPricing[]
|
||||||
model_mapping: Record<string, Record<string, string>> // platform → {src→dst}
|
model_mapping: Record<string, Record<string, string>> // platform → {src→dst}
|
||||||
@ -56,6 +57,7 @@ export interface CreateChannelRequest {
|
|||||||
model_mapping?: Record<string, Record<string, string>>
|
model_mapping?: Record<string, Record<string, string>>
|
||||||
billing_model_source?: string
|
billing_model_source?: string
|
||||||
restrict_models?: boolean
|
restrict_models?: boolean
|
||||||
|
features_config?: Record<string, unknown>
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface UpdateChannelRequest {
|
export interface UpdateChannelRequest {
|
||||||
@ -67,6 +69,7 @@ export interface UpdateChannelRequest {
|
|||||||
model_mapping?: Record<string, Record<string, string>>
|
model_mapping?: Record<string, Record<string, string>>
|
||||||
billing_model_source?: string
|
billing_model_source?: string
|
||||||
restrict_models?: boolean
|
restrict_models?: boolean
|
||||||
|
features_config?: Record<string, unknown>
|
||||||
}
|
}
|
||||||
|
|
||||||
interface PaginatedResponse<T> {
|
interface PaginatedResponse<T> {
|
||||||
|
|||||||
@ -482,6 +482,42 @@ export async function updateBetaPolicySettings(
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Web Search Emulation Config ---
|
||||||
|
|
||||||
|
export interface WebSearchProviderConfig {
|
||||||
|
type: 'brave' | 'tavily'
|
||||||
|
api_key: string
|
||||||
|
api_key_configured: boolean
|
||||||
|
priority: number
|
||||||
|
quota_limit: number
|
||||||
|
quota_refresh_interval: 'daily' | 'weekly' | 'monthly'
|
||||||
|
quota_used?: number
|
||||||
|
proxy_id: number | null
|
||||||
|
expires_at: number | null
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface WebSearchEmulationConfig {
|
||||||
|
enabled: boolean
|
||||||
|
providers: WebSearchProviderConfig[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function getWebSearchEmulationConfig(): Promise<WebSearchEmulationConfig> {
|
||||||
|
const { data } = await apiClient.get<WebSearchEmulationConfig>(
|
||||||
|
'/admin/settings/web-search-emulation'
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function updateWebSearchEmulationConfig(
|
||||||
|
config: WebSearchEmulationConfig
|
||||||
|
): Promise<WebSearchEmulationConfig> {
|
||||||
|
const { data } = await apiClient.put<WebSearchEmulationConfig>(
|
||||||
|
'/admin/settings/web-search-emulation',
|
||||||
|
config
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
export const settingsAPI = {
|
export const settingsAPI = {
|
||||||
getSettings,
|
getSettings,
|
||||||
updateSettings,
|
updateSettings,
|
||||||
@ -497,7 +533,9 @@ export const settingsAPI = {
|
|||||||
getRectifierSettings,
|
getRectifierSettings,
|
||||||
updateRectifierSettings,
|
updateRectifierSettings,
|
||||||
getBetaPolicySettings,
|
getBetaPolicySettings,
|
||||||
updateBetaPolicySettings
|
updateBetaPolicySettings,
|
||||||
|
getWebSearchEmulationConfig,
|
||||||
|
updateWebSearchEmulationConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
export default settingsAPI
|
export default settingsAPI
|
||||||
|
|||||||
@ -2325,6 +2325,22 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Anthropic API Key: Web Search Emulation -->
|
||||||
|
<div
|
||||||
|
v-if="form.platform === 'anthropic' && accountCategory === 'apikey'"
|
||||||
|
class="border-t border-gray-200 pt-4 dark:border-dark-600"
|
||||||
|
>
|
||||||
|
<div class="flex items-center justify-between">
|
||||||
|
<div>
|
||||||
|
<label class="input-label mb-0">{{ t('admin.accounts.anthropic.webSearchEmulation') }}</label>
|
||||||
|
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.accounts.anthropic.webSearchEmulationDesc') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<Toggle v-model="webSearchEmulationEnabled" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- OpenAI OAuth Codex 官方客户端限制开关 -->
|
<!-- OpenAI OAuth Codex 官方客户端限制开关 -->
|
||||||
<div
|
<div
|
||||||
v-if="form.platform === 'openai' && accountCategory === 'oauth-based'"
|
v-if="form.platform === 'openai' && accountCategory === 'oauth-based'"
|
||||||
@ -2830,6 +2846,7 @@ import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
|||||||
import Select from '@/components/common/Select.vue'
|
import Select from '@/components/common/Select.vue'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
import ProxySelector from '@/components/common/ProxySelector.vue'
|
import ProxySelector from '@/components/common/ProxySelector.vue'
|
||||||
|
import Toggle from '@/components/common/Toggle.vue'
|
||||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||||
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
||||||
import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
|
import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
|
||||||
@ -2980,6 +2997,7 @@ const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF
|
|||||||
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
||||||
const codexCLIOnlyEnabled = ref(false)
|
const codexCLIOnlyEnabled = ref(false)
|
||||||
const anthropicPassthroughEnabled = ref(false)
|
const anthropicPassthroughEnabled = ref(false)
|
||||||
|
const webSearchEmulationEnabled = ref(false)
|
||||||
const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling
|
const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling
|
||||||
const allowOverages = ref(false) // For antigravity accounts: enable AI Credits overages
|
const allowOverages = ref(false) // For antigravity accounts: enable AI Credits overages
|
||||||
const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream
|
const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream
|
||||||
@ -3307,6 +3325,7 @@ watch(
|
|||||||
}
|
}
|
||||||
if (newPlatform !== 'anthropic') {
|
if (newPlatform !== 'anthropic') {
|
||||||
anthropicPassthroughEnabled.value = false
|
anthropicPassthroughEnabled.value = false
|
||||||
|
webSearchEmulationEnabled.value = false
|
||||||
}
|
}
|
||||||
// Reset OAuth states
|
// Reset OAuth states
|
||||||
oauth.resetState()
|
oauth.resetState()
|
||||||
@ -3326,6 +3345,7 @@ watch(
|
|||||||
}
|
}
|
||||||
if (platform !== 'anthropic' || category !== 'apikey') {
|
if (platform !== 'anthropic' || category !== 'apikey') {
|
||||||
anthropicPassthroughEnabled.value = false
|
anthropicPassthroughEnabled.value = false
|
||||||
|
webSearchEmulationEnabled.value = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -3690,6 +3710,7 @@ const resetForm = () => {
|
|||||||
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
||||||
codexCLIOnlyEnabled.value = false
|
codexCLIOnlyEnabled.value = false
|
||||||
anthropicPassthroughEnabled.value = false
|
anthropicPassthroughEnabled.value = false
|
||||||
|
webSearchEmulationEnabled.value = false
|
||||||
// Reset quota control state
|
// Reset quota control state
|
||||||
windowCostEnabled.value = false
|
windowCostEnabled.value = false
|
||||||
windowCostLimit.value = null
|
windowCostLimit.value = null
|
||||||
@ -3777,6 +3798,11 @@ const buildAnthropicExtra = (base?: Record<string, unknown>): Record<string, unk
|
|||||||
} else {
|
} else {
|
||||||
delete extra.anthropic_passthrough
|
delete extra.anthropic_passthrough
|
||||||
}
|
}
|
||||||
|
if (webSearchEmulationEnabled.value) {
|
||||||
|
extra.web_search_emulation = true
|
||||||
|
} else {
|
||||||
|
delete extra.web_search_emulation
|
||||||
|
}
|
||||||
|
|
||||||
return Object.keys(extra).length > 0 ? extra : undefined
|
return Object.keys(extra).length > 0 ? extra : undefined
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1149,10 +1149,61 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- API Key / Bedrock 账号配额限制 -->
|
<!-- Anthropic API Key: Web Search Emulation -->
|
||||||
<div v-if="account?.type === 'apikey' || account?.type === 'bedrock'" class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4">
|
<div
|
||||||
|
v-if="account?.platform === 'anthropic' && account?.type === 'apikey'"
|
||||||
|
class="border-t border-gray-200 pt-4 dark:border-dark-600"
|
||||||
|
>
|
||||||
|
<div class="flex items-center justify-between">
|
||||||
|
<div>
|
||||||
|
<label class="input-label mb-0">{{ t('admin.accounts.anthropic.webSearchEmulation') }}</label>
|
||||||
|
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.accounts.anthropic.webSearchEmulationDesc') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<Toggle v-model="webSearchEmulationEnabled" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 配额控制 (Anthropic apikey/bedrock: 配额限制 + 亲和) -->
|
||||||
|
<div
|
||||||
|
v-if="account?.platform === 'anthropic' && (account?.type === 'apikey' || account?.type === 'bedrock')"
|
||||||
|
class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4"
|
||||||
|
>
|
||||||
<div class="mb-3">
|
<div class="mb-3">
|
||||||
<h3 class="input-label mb-0 text-base font-semibold">{{ t('admin.accounts.quotaLimit') }}</h3>
|
<h3 class="input-label mb-0 text-base font-semibold">{{ t('admin.accounts.quotaControl.title') }}</h3>
|
||||||
|
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.accounts.quotaControl.hint') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<QuotaLimitCard
|
||||||
|
:totalLimit="editQuotaLimit"
|
||||||
|
:dailyLimit="editQuotaDailyLimit"
|
||||||
|
:weeklyLimit="editQuotaWeeklyLimit"
|
||||||
|
:dailyResetMode="editDailyResetMode"
|
||||||
|
:dailyResetHour="editDailyResetHour"
|
||||||
|
:weeklyResetMode="editWeeklyResetMode"
|
||||||
|
:weeklyResetDay="editWeeklyResetDay"
|
||||||
|
:weeklyResetHour="editWeeklyResetHour"
|
||||||
|
:resetTimezone="editResetTimezone"
|
||||||
|
@update:totalLimit="editQuotaLimit = $event"
|
||||||
|
@update:dailyLimit="editQuotaDailyLimit = $event"
|
||||||
|
@update:weeklyLimit="editQuotaWeeklyLimit = $event"
|
||||||
|
@update:dailyResetMode="editDailyResetMode = $event"
|
||||||
|
@update:dailyResetHour="editDailyResetHour = $event"
|
||||||
|
@update:weeklyResetMode="editWeeklyResetMode = $event"
|
||||||
|
@update:weeklyResetDay="editWeeklyResetDay = $event"
|
||||||
|
@update:weeklyResetHour="editWeeklyResetHour = $event"
|
||||||
|
@update:resetTimezone="editResetTimezone = $event"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<!-- 配额控制 (非 Anthropic apikey/bedrock) -->
|
||||||
|
<div
|
||||||
|
v-else-if="account?.type === 'apikey' || account?.type === 'bedrock'"
|
||||||
|
class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4"
|
||||||
|
>
|
||||||
|
<div class="mb-3">
|
||||||
|
<h3 class="input-label mb-0 text-base font-semibold">{{ t('admin.accounts.quotaControl.title') }}</h3>
|
||||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||||
{{ t('admin.accounts.quotaLimitHint') }}
|
{{ t('admin.accounts.quotaLimitHint') }}
|
||||||
</p>
|
</p>
|
||||||
@ -1237,7 +1288,7 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Quota Control Section (Anthropic OAuth/SetupToken only) -->
|
<!-- 配额控制 (Anthropic OAuth/SetupToken: 亲和 + 窗口费用 + 会话 + RPM 等) -->
|
||||||
<div
|
<div
|
||||||
v-if="account?.platform === 'anthropic' && (account?.type === 'oauth' || account?.type === 'setup-token')"
|
v-if="account?.platform === 'anthropic' && (account?.type === 'oauth' || account?.type === 'setup-token')"
|
||||||
class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4"
|
class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4"
|
||||||
@ -1757,6 +1808,7 @@ import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
|||||||
import Select from '@/components/common/Select.vue'
|
import Select from '@/components/common/Select.vue'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
import ProxySelector from '@/components/common/ProxySelector.vue'
|
import ProxySelector from '@/components/common/ProxySelector.vue'
|
||||||
|
import Toggle from '@/components/common/Toggle.vue'
|
||||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||||
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
||||||
import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
|
import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
|
||||||
@ -1898,6 +1950,7 @@ const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF
|
|||||||
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
||||||
const codexCLIOnlyEnabled = ref(false)
|
const codexCLIOnlyEnabled = ref(false)
|
||||||
const anthropicPassthroughEnabled = ref(false)
|
const anthropicPassthroughEnabled = ref(false)
|
||||||
|
const webSearchEmulationEnabled = ref(false)
|
||||||
const editQuotaLimit = ref<number | null>(null)
|
const editQuotaLimit = ref<number | null>(null)
|
||||||
const editQuotaDailyLimit = ref<number | null>(null)
|
const editQuotaDailyLimit = ref<number | null>(null)
|
||||||
const editQuotaWeeklyLimit = ref<number | null>(null)
|
const editQuotaWeeklyLimit = ref<number | null>(null)
|
||||||
@ -2067,6 +2120,7 @@ const syncFormFromAccount = (newAccount: Account | null) => {
|
|||||||
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
||||||
codexCLIOnlyEnabled.value = false
|
codexCLIOnlyEnabled.value = false
|
||||||
anthropicPassthroughEnabled.value = false
|
anthropicPassthroughEnabled.value = false
|
||||||
|
webSearchEmulationEnabled.value = false
|
||||||
if (newAccount.platform === 'openai' && (newAccount.type === 'oauth' || newAccount.type === 'apikey')) {
|
if (newAccount.platform === 'openai' && (newAccount.type === 'oauth' || newAccount.type === 'apikey')) {
|
||||||
openaiPassthroughEnabled.value = extra?.openai_passthrough === true || extra?.openai_oauth_passthrough === true
|
openaiPassthroughEnabled.value = extra?.openai_passthrough === true || extra?.openai_oauth_passthrough === true
|
||||||
openaiOAuthResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, {
|
openaiOAuthResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, {
|
||||||
@ -2087,6 +2141,7 @@ const syncFormFromAccount = (newAccount: Account | null) => {
|
|||||||
}
|
}
|
||||||
if (newAccount.platform === 'anthropic' && newAccount.type === 'apikey') {
|
if (newAccount.platform === 'anthropic' && newAccount.type === 'apikey') {
|
||||||
anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true
|
anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true
|
||||||
|
webSearchEmulationEnabled.value = extra?.web_search_emulation === true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load quota limit for apikey/bedrock accounts (bedrock quota is also loaded in its own branch above)
|
// Load quota limit for apikey/bedrock accounts (bedrock quota is also loaded in its own branch above)
|
||||||
@ -2522,8 +2577,13 @@ function loadQuotaControlSettings(account: Account) {
|
|||||||
customBaseUrlEnabled.value = false
|
customBaseUrlEnabled.value = false
|
||||||
customBaseUrl.value = ''
|
customBaseUrl.value = ''
|
||||||
|
|
||||||
// Only applies to Anthropic OAuth/SetupToken accounts
|
// Remaining quota control settings only apply to Anthropic accounts
|
||||||
if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) {
|
if (account.platform !== 'anthropic') {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Window cost / session limit only apply to Anthropic OAuth/SetupToken accounts
|
||||||
|
if (account.type !== 'oauth' && account.type !== 'setup-token') {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2949,7 +3009,7 @@ const handleSubmit = async () => {
|
|||||||
|
|
||||||
// For Anthropic OAuth/SetupToken accounts, handle quota control settings in extra
|
// For Anthropic OAuth/SetupToken accounts, handle quota control settings in extra
|
||||||
if (props.account.platform === 'anthropic' && (props.account.type === 'oauth' || props.account.type === 'setup-token')) {
|
if (props.account.platform === 'anthropic' && (props.account.type === 'oauth' || props.account.type === 'setup-token')) {
|
||||||
const currentExtra = (props.account.extra as Record<string, unknown>) || {}
|
const currentExtra = (updatePayload.extra as Record<string, unknown>) || (props.account.extra as Record<string, unknown>) || {}
|
||||||
const newExtra: Record<string, unknown> = { ...currentExtra }
|
const newExtra: Record<string, unknown> = { ...currentExtra }
|
||||||
|
|
||||||
// Window cost limit settings
|
// Window cost limit settings
|
||||||
@ -3037,15 +3097,20 @@ const handleSubmit = async () => {
|
|||||||
updatePayload.extra = newExtra
|
updatePayload.extra = newExtra
|
||||||
}
|
}
|
||||||
|
|
||||||
// For Anthropic API Key accounts, handle passthrough mode in extra
|
// For Anthropic API Key accounts, handle passthrough mode + web search emulation in extra
|
||||||
if (props.account.platform === 'anthropic' && props.account.type === 'apikey') {
|
if (props.account.platform === 'anthropic' && props.account.type === 'apikey') {
|
||||||
const currentExtra = (props.account.extra as Record<string, unknown>) || {}
|
const currentExtra = (updatePayload.extra as Record<string, unknown>) || (props.account.extra as Record<string, unknown>) || {}
|
||||||
const newExtra: Record<string, unknown> = { ...currentExtra }
|
const newExtra: Record<string, unknown> = { ...currentExtra }
|
||||||
if (anthropicPassthroughEnabled.value) {
|
if (anthropicPassthroughEnabled.value) {
|
||||||
newExtra.anthropic_passthrough = true
|
newExtra.anthropic_passthrough = true
|
||||||
} else {
|
} else {
|
||||||
delete newExtra.anthropic_passthrough
|
delete newExtra.anthropic_passthrough
|
||||||
}
|
}
|
||||||
|
if (webSearchEmulationEnabled.value) {
|
||||||
|
newExtra.web_search_emulation = true
|
||||||
|
} else {
|
||||||
|
delete newExtra.web_search_emulation
|
||||||
|
}
|
||||||
updatePayload.extra = newExtra
|
updatePayload.extra = newExtra
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3089,20 +3154,27 @@ const handleSubmit = async () => {
|
|||||||
const currentExtra = (updatePayload.extra as Record<string, unknown>) ||
|
const currentExtra = (updatePayload.extra as Record<string, unknown>) ||
|
||||||
(props.account.extra as Record<string, unknown>) || {}
|
(props.account.extra as Record<string, unknown>) || {}
|
||||||
const newExtra: Record<string, unknown> = { ...currentExtra }
|
const newExtra: Record<string, unknown> = { ...currentExtra }
|
||||||
|
// Total quota
|
||||||
if (editQuotaLimit.value != null && editQuotaLimit.value > 0) {
|
if (editQuotaLimit.value != null && editQuotaLimit.value > 0) {
|
||||||
newExtra.quota_limit = editQuotaLimit.value
|
newExtra.quota_limit = editQuotaLimit.value
|
||||||
} else {
|
} else {
|
||||||
delete newExtra.quota_limit
|
delete newExtra.quota_limit
|
||||||
}
|
}
|
||||||
|
// Daily quota
|
||||||
if (editQuotaDailyLimit.value != null && editQuotaDailyLimit.value > 0) {
|
if (editQuotaDailyLimit.value != null && editQuotaDailyLimit.value > 0) {
|
||||||
newExtra.quota_daily_limit = editQuotaDailyLimit.value
|
newExtra.quota_daily_limit = editQuotaDailyLimit.value
|
||||||
} else {
|
} else {
|
||||||
delete newExtra.quota_daily_limit
|
delete newExtra.quota_daily_limit
|
||||||
|
delete newExtra.quota_daily_used
|
||||||
|
delete newExtra.quota_daily_start
|
||||||
}
|
}
|
||||||
|
// Weekly quota
|
||||||
if (editQuotaWeeklyLimit.value != null && editQuotaWeeklyLimit.value > 0) {
|
if (editQuotaWeeklyLimit.value != null && editQuotaWeeklyLimit.value > 0) {
|
||||||
newExtra.quota_weekly_limit = editQuotaWeeklyLimit.value
|
newExtra.quota_weekly_limit = editQuotaWeeklyLimit.value
|
||||||
} else {
|
} else {
|
||||||
delete newExtra.quota_weekly_limit
|
delete newExtra.quota_weekly_limit
|
||||||
|
delete newExtra.quota_weekly_used
|
||||||
|
delete newExtra.quota_weekly_start
|
||||||
}
|
}
|
||||||
// Quota reset mode config
|
// Quota reset mode config
|
||||||
if (editDailyResetMode.value === 'fixed') {
|
if (editDailyResetMode.value === 'fixed') {
|
||||||
|
|||||||
@ -1836,6 +1836,9 @@ export default {
|
|||||||
defaultPerRequestPrice: 'Default per-request price (fallback when no tier matches)',
|
defaultPerRequestPrice: 'Default per-request price (fallback when no tier matches)',
|
||||||
defaultImagePrice: 'Default image price (fallback when no tier matches)',
|
defaultImagePrice: 'Default image price (fallback when no tier matches)',
|
||||||
platformConfig: 'Platform Configuration',
|
platformConfig: 'Platform Configuration',
|
||||||
|
webSearchEmulation: 'Web Search Emulation',
|
||||||
|
webSearchEmulationHint: '⚠️ When enabled, all accounts in this channel\'s Anthropic groups will intercept web_search requests. Use with caution.',
|
||||||
|
webSearchEmulationGlobalDisabled: 'Please enable the global switch first in Settings → Gateway → Web Search Emulation',
|
||||||
basicSettings: 'Basic Settings',
|
basicSettings: 'Basic Settings',
|
||||||
addPlatform: 'Add Platform',
|
addPlatform: 'Add Platform',
|
||||||
noPlatforms: 'Click "Add Platform" to start configuring the channel',
|
noPlatforms: 'Click "Add Platform" to start configuring the channel',
|
||||||
@ -2325,7 +2328,10 @@ export default {
|
|||||||
anthropic: {
|
anthropic: {
|
||||||
apiKeyPassthrough: 'Auto passthrough (auth only)',
|
apiKeyPassthrough: 'Auto passthrough (auth only)',
|
||||||
apiKeyPassthroughDesc:
|
apiKeyPassthroughDesc:
|
||||||
'Only applies to Anthropic API Key accounts. When enabled, messages/count_tokens are forwarded in passthrough mode with auth replacement only, while billing/concurrency/audit and safety filtering are preserved. Disable to roll back immediately.'
|
'Only applies to Anthropic API Key accounts. When enabled, messages/count_tokens are forwarded in passthrough mode with auth replacement only, while billing/concurrency/audit and safety filtering are preserved. Disable to roll back immediately.',
|
||||||
|
webSearchEmulation: 'Web Search Emulation',
|
||||||
|
webSearchEmulationDesc:
|
||||||
|
'Enable web search emulation for this API Key account. When a pure web_search request is detected, the gateway calls a third-party search API and constructs the response locally.',
|
||||||
},
|
},
|
||||||
modelRestriction: 'Model Restriction (Optional)',
|
modelRestriction: 'Model Restriction (Optional)',
|
||||||
modelWhitelist: 'Model Whitelist',
|
modelWhitelist: 'Model Whitelist',
|
||||||
@ -4358,6 +4364,31 @@ export default {
|
|||||||
cchSigning: 'CCH Signing',
|
cchSigning: 'CCH Signing',
|
||||||
cchSigningHint: 'Sign the billing header in forwarded requests with CCH hash. When disabled, the placeholder is preserved.',
|
cchSigningHint: 'Sign the billing header in forwarded requests with CCH hash. When disabled, the placeholder is preserved.',
|
||||||
},
|
},
|
||||||
|
webSearchEmulation: {
|
||||||
|
title: 'Web Search Emulation',
|
||||||
|
description: 'Inject web search capability for Anthropic API Key accounts that don\'t natively support it',
|
||||||
|
enabled: 'Enable Web Search Emulation',
|
||||||
|
enabledHint: 'Global switch. When disabled, web search emulation is inactive for all channels and accounts.',
|
||||||
|
providers: 'Search Providers',
|
||||||
|
addProvider: 'Add Provider',
|
||||||
|
providerType: 'Provider Type',
|
||||||
|
apiKey: 'API Key',
|
||||||
|
apiKeyPlaceholder: 'Enter API Key',
|
||||||
|
apiKeyConfigured: 'Configured',
|
||||||
|
priority: 'Priority',
|
||||||
|
priorityHint: 'Lower number = higher priority',
|
||||||
|
quotaLimit: 'Quota Limit',
|
||||||
|
quotaLimitHint: '0 = unlimited',
|
||||||
|
quotaRefreshInterval: 'Refresh Interval',
|
||||||
|
quotaUsed: 'Used',
|
||||||
|
proxy: 'Proxy',
|
||||||
|
expiresAt: 'Expires At',
|
||||||
|
removeProvider: 'Remove',
|
||||||
|
daily: 'Daily',
|
||||||
|
weekly: 'Weekly',
|
||||||
|
monthly: 'Monthly',
|
||||||
|
noProviders: 'No search providers configured',
|
||||||
|
},
|
||||||
site: {
|
site: {
|
||||||
title: 'Site Settings',
|
title: 'Site Settings',
|
||||||
description: 'Customize site branding',
|
description: 'Customize site branding',
|
||||||
|
|||||||
@ -1915,6 +1915,9 @@ export default {
|
|||||||
defaultPerRequestPrice: '默认单次价格(未命中层级时使用)',
|
defaultPerRequestPrice: '默认单次价格(未命中层级时使用)',
|
||||||
defaultImagePrice: '默认图片价格(未命中层级时使用)',
|
defaultImagePrice: '默认图片价格(未命中层级时使用)',
|
||||||
platformConfig: '平台配置',
|
platformConfig: '平台配置',
|
||||||
|
webSearchEmulation: 'Web Search 模拟',
|
||||||
|
webSearchEmulationHint: '⚠️ 开启后该渠道下所有 Anthropic 分组的账号将自动拦截 web_search 请求,请谨慎操作',
|
||||||
|
webSearchEmulationGlobalDisabled: '请先在系统设置 → 网关 → Web Search 模拟中启用全局开关',
|
||||||
basicSettings: '基础设置',
|
basicSettings: '基础设置',
|
||||||
addPlatform: '添加平台',
|
addPlatform: '添加平台',
|
||||||
noPlatforms: '点击"添加平台"开始配置渠道',
|
noPlatforms: '点击"添加平台"开始配置渠道',
|
||||||
@ -2472,7 +2475,10 @@ export default {
|
|||||||
anthropic: {
|
anthropic: {
|
||||||
apiKeyPassthrough: '自动透传(仅替换认证)',
|
apiKeyPassthrough: '自动透传(仅替换认证)',
|
||||||
apiKeyPassthroughDesc:
|
apiKeyPassthroughDesc:
|
||||||
'仅对 Anthropic API Key 生效。开启后,messages/count_tokens 请求将透传上游并仅替换认证,保留计费/并发/审计及必要安全过滤;关闭即可回滚到现有兼容链路。'
|
'仅对 Anthropic API Key 生效。开启后,messages/count_tokens 请求将透传上游并仅替换认证,保留计费/并发/审计及必要安全过滤;关闭即可回滚到现有兼容链路。',
|
||||||
|
webSearchEmulation: 'Web Search 模拟',
|
||||||
|
webSearchEmulationDesc:
|
||||||
|
'为该 API Key 账号启用 web search 模拟。客户端发送纯 web_search 请求时,由网关调用第三方搜索 API 并构造响应返回。',
|
||||||
},
|
},
|
||||||
modelRestriction: '模型限制(可选)',
|
modelRestriction: '模型限制(可选)',
|
||||||
modelWhitelist: '模型白名单',
|
modelWhitelist: '模型白名单',
|
||||||
@ -4520,6 +4526,31 @@ export default {
|
|||||||
cchSigning: 'CCH 签名',
|
cchSigning: 'CCH 签名',
|
||||||
cchSigningHint: '对转发请求的 billing header 进行 CCH 哈希签名。关闭时保留原始占位符。',
|
cchSigningHint: '对转发请求的 billing header 进行 CCH 哈希签名。关闭时保留原始占位符。',
|
||||||
},
|
},
|
||||||
|
webSearchEmulation: {
|
||||||
|
title: 'Web Search 模拟',
|
||||||
|
description: '为不原生支持搜索的 Anthropic API Key 账号注入 web search 能力',
|
||||||
|
enabled: '启用 Web Search 模拟',
|
||||||
|
enabledHint: '全局开关。关闭后所有渠道和账号的 web search 模拟均不生效。',
|
||||||
|
providers: '搜索服务商',
|
||||||
|
addProvider: '添加服务商',
|
||||||
|
providerType: '服务商类型',
|
||||||
|
apiKey: 'API Key',
|
||||||
|
apiKeyPlaceholder: '输入 API Key',
|
||||||
|
apiKeyConfigured: '已配置',
|
||||||
|
priority: '优先级',
|
||||||
|
priorityHint: '数值越小优先级越高',
|
||||||
|
quotaLimit: '配额上限',
|
||||||
|
quotaLimitHint: '0 表示无限制',
|
||||||
|
quotaRefreshInterval: '刷新周期',
|
||||||
|
quotaUsed: '已使用',
|
||||||
|
proxy: '代理',
|
||||||
|
expiresAt: '过期时间',
|
||||||
|
removeProvider: '删除',
|
||||||
|
daily: '每日',
|
||||||
|
weekly: '每周',
|
||||||
|
monthly: '每月',
|
||||||
|
noProviders: '未配置搜索服务商',
|
||||||
|
},
|
||||||
site: {
|
site: {
|
||||||
title: '站点设置',
|
title: '站点设置',
|
||||||
description: '自定义站点品牌',
|
description: '自定义站点品牌',
|
||||||
|
|||||||
@ -306,6 +306,24 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Web Search Emulation (Anthropic only) -->
|
||||||
|
<div v-if="section.platform === 'anthropic'" class="border-t border-gray-200 pt-3 dark:border-dark-600">
|
||||||
|
<div class="flex items-center justify-between">
|
||||||
|
<div>
|
||||||
|
<label class="text-xs font-medium text-orange-600 dark:text-orange-400">
|
||||||
|
{{ t('admin.channels.form.webSearchEmulation') }}
|
||||||
|
</label>
|
||||||
|
<p v-if="webSearchGlobalEnabled" class="mt-0.5 text-[11px] text-amber-500 dark:text-amber-400">
|
||||||
|
{{ t('admin.channels.form.webSearchEmulationHint') }}
|
||||||
|
</p>
|
||||||
|
<p v-else class="mt-0.5 text-[11px] text-gray-400">
|
||||||
|
{{ t('admin.channels.form.webSearchEmulationGlobalDisabled') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<Toggle v-model="section.web_search_emulation" :disabled="!webSearchGlobalEnabled" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Model Mapping -->
|
<!-- Model Mapping -->
|
||||||
<div>
|
<div>
|
||||||
<div class="mb-1 flex items-center justify-between">
|
<div class="mb-1 flex items-center justify-between">
|
||||||
@ -423,6 +441,7 @@
|
|||||||
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
|
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import { useAppStore } from '@/stores/app'
|
import { useAppStore } from '@/stores/app'
|
||||||
|
import { extractApiErrorMessage } from '@/utils/apiError'
|
||||||
import { adminAPI } from '@/api/admin'
|
import { adminAPI } from '@/api/admin'
|
||||||
import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest } from '@/api/admin/channels'
|
import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest } from '@/api/admin/channels'
|
||||||
import type { PricingFormEntry } from '@/components/admin/channel/types'
|
import type { PricingFormEntry } from '@/components/admin/channel/types'
|
||||||
@ -446,6 +465,18 @@ import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
|
|||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
const appStore = useAppStore()
|
const appStore = useAppStore()
|
||||||
|
|
||||||
|
// Web Search global enabled state (loaded once on mount)
|
||||||
|
const webSearchGlobalEnabled = ref(false)
|
||||||
|
async function loadWebSearchGlobalState() {
|
||||||
|
try {
|
||||||
|
const cfg = await adminAPI.settings.getWebSearchEmulationConfig()
|
||||||
|
webSearchGlobalEnabled.value = cfg?.enabled === true && (cfg?.providers?.length ?? 0) > 0
|
||||||
|
} catch (err: unknown) {
|
||||||
|
console.warn('Failed to load web search global state:', err)
|
||||||
|
webSearchGlobalEnabled.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ── Platform Section type ──
|
// ── Platform Section type ──
|
||||||
interface PlatformSection {
|
interface PlatformSection {
|
||||||
platform: GroupPlatform
|
platform: GroupPlatform
|
||||||
@ -454,6 +485,7 @@ interface PlatformSection {
|
|||||||
group_ids: number[]
|
group_ids: number[]
|
||||||
model_mapping: Record<string, string>
|
model_mapping: Record<string, string>
|
||||||
model_pricing: PricingFormEntry[]
|
model_pricing: PricingFormEntry[]
|
||||||
|
web_search_emulation: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Table columns ──
|
// ── Table columns ──
|
||||||
@ -565,7 +597,8 @@ function addPlatformSection(platform: GroupPlatform) {
|
|||||||
collapsed: false,
|
collapsed: false,
|
||||||
group_ids: [],
|
group_ids: [],
|
||||||
model_mapping: {},
|
model_mapping: {},
|
||||||
model_pricing: []
|
model_pricing: [],
|
||||||
|
web_search_emulation: false,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -679,10 +712,14 @@ function renameMappingKey(sectionIdx: number, oldKey: string, newKey: string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ── Form ↔ API conversion ──
|
// ── Form ↔ API conversion ──
|
||||||
function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[], model_mapping: Record<string, Record<string, string>> } {
|
function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[], model_mapping: Record<string, Record<string, string>>, features_config: Record<string, unknown> } {
|
||||||
const group_ids: number[] = []
|
const group_ids: number[] = []
|
||||||
const model_pricing: ChannelModelPricing[] = []
|
const model_pricing: ChannelModelPricing[] = []
|
||||||
const model_mapping: Record<string, Record<string, string>> = {}
|
const model_mapping: Record<string, Record<string, string>> = {}
|
||||||
|
// Preserve existing features_config fields not managed by the form
|
||||||
|
const featuresConfig: Record<string, unknown> = editingChannel.value?.features_config
|
||||||
|
? { ...editingChannel.value.features_config }
|
||||||
|
: {}
|
||||||
|
|
||||||
for (const section of form.platforms) {
|
for (const section of form.platforms) {
|
||||||
if (!section.enabled) continue
|
if (!section.enabled) continue
|
||||||
@ -711,7 +748,19 @@ function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return { group_ids, model_pricing, model_mapping }
|
// Collect web_search_emulation (only anthropic platform supports it)
|
||||||
|
const wsEmulation: Record<string, boolean> = {}
|
||||||
|
for (const section of form.platforms) {
|
||||||
|
if (!section.enabled) continue
|
||||||
|
if (section.web_search_emulation && section.platform === 'anthropic') {
|
||||||
|
wsEmulation[section.platform] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Object.keys(wsEmulation).length > 0) {
|
||||||
|
featuresConfig.web_search_emulation = wsEmulation
|
||||||
|
}
|
||||||
|
|
||||||
|
return { group_ids, model_pricing, model_mapping, features_config: featuresConfig }
|
||||||
}
|
}
|
||||||
|
|
||||||
function apiToForm(channel: Channel): PlatformSection[] {
|
function apiToForm(channel: Channel): PlatformSection[] {
|
||||||
@ -755,13 +804,19 @@ function apiToForm(channel: Channel): PlatformSection[] {
|
|||||||
intervals: apiIntervalsToForm(p.intervals || [])
|
intervals: apiIntervalsToForm(p.intervals || [])
|
||||||
} as PricingFormEntry))
|
} as PricingFormEntry))
|
||||||
|
|
||||||
|
// Read web_search_emulation from features_config
|
||||||
|
const fc = channel.features_config
|
||||||
|
const wsEmulation = fc?.web_search_emulation as Record<string, boolean> | undefined
|
||||||
|
const webSearchEnabled = wsEmulation?.[platform] === true
|
||||||
|
|
||||||
sections.push({
|
sections.push({
|
||||||
platform,
|
platform,
|
||||||
enabled: true,
|
enabled: true,
|
||||||
collapsed: false,
|
collapsed: false,
|
||||||
group_ids: groupIds,
|
group_ids: groupIds,
|
||||||
model_mapping: { ...mapping },
|
model_mapping: { ...mapping },
|
||||||
model_pricing: pricing
|
model_pricing: pricing,
|
||||||
|
web_search_emulation: webSearchEnabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -786,10 +841,10 @@ async function loadChannels() {
|
|||||||
if (ctrl.signal.aborted || abortController !== ctrl) return
|
if (ctrl.signal.aborted || abortController !== ctrl) return
|
||||||
channels.value = response.items || []
|
channels.value = response.items || []
|
||||||
pagination.total = response.total
|
pagination.total = response.total
|
||||||
} catch (error: any) {
|
} catch (error: unknown) {
|
||||||
if (error?.name === 'AbortError' || error?.code === 'ERR_CANCELED') return
|
const e = error as { name?: string; code?: string }
|
||||||
appStore.showError(t('admin.channels.loadError', 'Failed to load channels'))
|
if (e?.name === 'AbortError' || e?.code === 'ERR_CANCELED') return
|
||||||
console.error('Error loading channels:', error)
|
appStore.showError(extractApiErrorMessage(error, t('admin.channels.loadError', 'Failed to load channels')))
|
||||||
} finally {
|
} finally {
|
||||||
if (abortController === ctrl) {
|
if (abortController === ctrl) {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
@ -969,8 +1024,7 @@ async function handleSubmit() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const { group_ids, model_pricing, model_mapping } = formToAPI()
|
const { group_ids, model_pricing, model_mapping, features_config } = formToAPI()
|
||||||
console.log('[handleSubmit] model_pricing to send:', JSON.stringify(model_pricing))
|
|
||||||
|
|
||||||
submitting.value = true
|
submitting.value = true
|
||||||
try {
|
try {
|
||||||
@ -983,7 +1037,8 @@ async function handleSubmit() {
|
|||||||
model_pricing,
|
model_pricing,
|
||||||
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
|
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
|
||||||
billing_model_source: form.billing_model_source,
|
billing_model_source: form.billing_model_source,
|
||||||
restrict_models: form.restrict_models
|
restrict_models: form.restrict_models,
|
||||||
|
features_config,
|
||||||
}
|
}
|
||||||
await adminAPI.channels.update(editingChannel.value.id, req)
|
await adminAPI.channels.update(editingChannel.value.id, req)
|
||||||
appStore.showSuccess(t('admin.channels.updateSuccess', 'Channel updated'))
|
appStore.showSuccess(t('admin.channels.updateSuccess', 'Channel updated'))
|
||||||
@ -995,19 +1050,18 @@ async function handleSubmit() {
|
|||||||
model_pricing,
|
model_pricing,
|
||||||
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
|
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
|
||||||
billing_model_source: form.billing_model_source,
|
billing_model_source: form.billing_model_source,
|
||||||
restrict_models: form.restrict_models
|
restrict_models: form.restrict_models,
|
||||||
|
features_config,
|
||||||
}
|
}
|
||||||
await adminAPI.channels.create(req)
|
await adminAPI.channels.create(req)
|
||||||
appStore.showSuccess(t('admin.channels.createSuccess', 'Channel created'))
|
appStore.showSuccess(t('admin.channels.createSuccess', 'Channel created'))
|
||||||
}
|
}
|
||||||
closeDialog()
|
closeDialog()
|
||||||
loadChannels()
|
loadChannels()
|
||||||
} catch (error: any) {
|
} catch (error: unknown) {
|
||||||
const msg = error.response?.data?.detail || (editingChannel.value
|
appStore.showError(extractApiErrorMessage(error, editingChannel.value
|
||||||
? t('admin.channels.updateError', 'Failed to update channel')
|
? t('admin.channels.updateError', 'Failed to update channel')
|
||||||
: t('admin.channels.createError', 'Failed to create channel'))
|
: t('admin.channels.createError', 'Failed to create channel')))
|
||||||
appStore.showError(msg)
|
|
||||||
console.error('Error saving channel:', error)
|
|
||||||
} finally {
|
} finally {
|
||||||
submitting.value = false
|
submitting.value = false
|
||||||
}
|
}
|
||||||
@ -1045,9 +1099,8 @@ async function confirmDelete() {
|
|||||||
showDeleteDialog.value = false
|
showDeleteDialog.value = false
|
||||||
deletingChannel.value = null
|
deletingChannel.value = null
|
||||||
loadChannels()
|
loadChannels()
|
||||||
} catch (error: any) {
|
} catch (error: unknown) {
|
||||||
appStore.showError(error.response?.data?.detail || t('admin.channels.deleteError', 'Failed to delete channel'))
|
appStore.showError(extractApiErrorMessage(error, t('admin.channels.deleteError', 'Failed to delete channel')))
|
||||||
console.error('Error deleting channel:', error)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1055,6 +1108,7 @@ async function confirmDelete() {
|
|||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
loadChannels()
|
loadChannels()
|
||||||
loadGroups()
|
loadGroups()
|
||||||
|
loadWebSearchGlobalState()
|
||||||
})
|
})
|
||||||
|
|
||||||
onUnmounted(() => {
|
onUnmounted(() => {
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user