diff --git a/backend/internal/service/channel_monitor_checker.go b/backend/internal/service/channel_monitor_checker.go index 33570629..25737e45 100644 --- a/backend/internal/service/channel_monitor_checker.go +++ b/backend/internal/service/channel_monitor_checker.go @@ -40,6 +40,8 @@ func newSSRFSafeHTTPClient(timeout time.Duration) *http.Client { // CheckOptions 承载一次检测的自定义入参。 // 所有字段都是可选(零值即等价于"用默认行为")。 type CheckOptions struct { + // APIMode 仅对 OpenAI provider 生效;空串等同 chat_completions。 + APIMode string // ExtraHeaders 用户自定义 HTTP 头(merge 到 adapter 默认 headers,用户优先)。 ExtraHeaders map[string]string // BodyOverrideMode: off | merge | replace @@ -164,21 +166,7 @@ type providerAdapter struct { // //nolint:gochecknoglobals // 适配器表是只读静态数据,初始化后不变更。 var providerAdapters = map[string]providerAdapter{ - MonitorProviderOpenAI: { - buildPath: func(string) string { return providerOpenAIPath }, - buildBody: func(model, prompt string) ([]byte, error) { - return json.Marshal(map[string]any{ - "model": model, - "messages": []map[string]string{{"role": "user", "content": prompt}}, - "max_tokens": monitorChallengeMaxTokens, - "stream": false, - }) - }, - buildHeaders: func(apiKey string) map[string]string { - return map[string]string{"Authorization": "Bearer " + apiKey} - }, - textPath: "choices.0.message.content", - }, + MonitorProviderOpenAI: providerOpenAIChatAdapter, MonitorProviderAnthropic: { buildPath: func(string) string { return providerAnthropicPath }, buildBody: func(model, prompt string) ([]byte, error) { @@ -215,6 +203,50 @@ var providerAdapters = map[string]providerAdapter{ }, } +//nolint:gochecknoglobals // 适配器表是只读静态数据,初始化后不变更。 +var providerOpenAIChatAdapter = providerAdapter{ + buildPath: func(string) string { return providerOpenAIPath }, + buildBody: func(model, prompt string) ([]byte, error) { + return json.Marshal(map[string]any{ + "model": model, + "messages": []map[string]string{{"role": "user", "content": prompt}}, + "max_tokens": monitorChallengeMaxTokens, + "stream": false, + }) + }, + buildHeaders: func(apiKey string) map[string]string { + return map[string]string{"Authorization": "Bearer " + apiKey} + }, + textPath: "choices.0.message.content", +} + +//nolint:gochecknoglobals // 适配器表是只读静态数据,初始化后不变更。 +var providerOpenAIResponsesAdapter = providerAdapter{ + buildPath: func(string) string { return providerOpenAIResponsesPath }, + buildBody: func(model, prompt string) ([]byte, error) { + return json.Marshal(map[string]any{ + "model": model, + "instructions": "You are a channel health-check endpoint. Answer the arithmetic challenge exactly and briefly.", + "input": prompt, + "max_output_tokens": monitorChallengeMaxTokens, + "stream": false, + }) + }, + buildHeaders: func(apiKey string) map[string]string { + return map[string]string{"Authorization": "Bearer " + apiKey} + }, + textPath: "output.0.content.0.text", +} + +// providerAdapterFor 按 provider + api_mode 选择具体 adapter。 +func providerAdapterFor(provider, apiMode string) (providerAdapter, string, bool) { + if provider == MonitorProviderOpenAI && defaultAPIMode(apiMode) == MonitorAPIModeResponses { + return providerOpenAIResponsesAdapter, MonitorAPIModeResponses, true + } + adapter, ok := providerAdapters[provider] + return adapter, MonitorAPIModeChatCompletions, ok +} + // isSupportedProvider 校验 provider 字符串是否在 adapter 表中。 // 供 validate.go 的 validateProvider 复用,避免两份 switch 漂移。 func isSupportedProvider(p string) bool { @@ -231,11 +263,15 @@ func isSupportedProvider(p string) bool { // - status: HTTP 状态码 // - err: 网络 / 序列化错误 func callProvider(ctx context.Context, provider, endpoint, apiKey, model, prompt string, opts *CheckOptions) (extractedText, rawBody string, status int, err error) { - adapter, ok := providerAdapters[provider] + requestedAPIMode := checkAPIMode(opts) + if err := validateAPIMode(provider, requestedAPIMode); err != nil { + return "", "", 0, err + } + adapter, apiMode, ok := providerAdapterFor(provider, requestedAPIMode) if !ok { return "", "", 0, fmt.Errorf("unsupported provider %q", provider) } - body, err := buildRequestBody(adapter, provider, model, prompt, opts) + body, err := buildRequestBody(adapter, provider, apiMode, model, prompt, opts) if err != nil { return "", "", 0, err } @@ -275,13 +311,16 @@ func mergeHeaders(base map[string]string, opts *CheckOptions) map[string]string // - replace: 直接 marshal BodyOverride 作为完整 body // // 任何 mode 返回的 []byte 都已经是合法 JSON,可直接送入 postRawJSON。 -func buildRequestBody(adapter providerAdapter, provider, model, prompt string, opts *CheckOptions) ([]byte, error) { +func buildRequestBody(adapter providerAdapter, provider, apiMode, model, prompt string, opts *CheckOptions) ([]byte, error) { mode := bodyOverrideMode(opts) if mode == MonitorBodyOverrideModeReplace { if opts == nil || len(opts.BodyOverride) == 0 { return nil, fmt.Errorf("replace mode: body_override is empty") } + if err := validateReplaceRequestBody(provider, apiMode, opts.BodyOverride); err != nil { + return nil, err + } body, err := json.Marshal(opts.BodyOverride) if err != nil { return nil, fmt.Errorf("marshal body_override (replace): %w", err) @@ -301,7 +340,7 @@ func buildRequestBody(adapter providerAdapter, provider, model, prompt string, o if err := json.Unmarshal(defaultBody, &defaultMap); err != nil { return nil, fmt.Errorf("unmarshal default body for merge: %w", err) } - deny := bodyMergeKeyDenyList[provider] + deny := bodyMergeKeyDenyList[bodyMergeDenyKey(provider, apiMode)] for k, v := range opts.BodyOverride { if deny[k] { continue @@ -321,9 +360,63 @@ func buildRequestBody(adapter providerAdapter, provider, model, prompt string, o // //nolint:gochecknoglobals // 静态查表,初始化后不变。 var bodyMergeKeyDenyList = map[string]map[string]bool{ - MonitorProviderOpenAI: {"model": true, "messages": true, "stream": true}, - MonitorProviderAnthropic: {"model": true, "messages": true}, - MonitorProviderGemini: {"contents": true}, + MonitorProviderOpenAI + ":" + MonitorAPIModeChatCompletions: {"model": true, "messages": true, "stream": true}, + MonitorProviderOpenAI + ":" + MonitorAPIModeResponses: {"model": true, "instructions": true, "input": true, "stream": true}, + MonitorProviderAnthropic: {"model": true, "messages": true}, + MonitorProviderGemini: {"contents": true}, +} + +func checkAPIMode(opts *CheckOptions) string { + if opts == nil { + return MonitorAPIModeChatCompletions + } + return defaultAPIMode(opts.APIMode) +} + +func bodyMergeDenyKey(provider, apiMode string) string { + if provider == MonitorProviderOpenAI { + return provider + ":" + defaultAPIMode(apiMode) + } + return provider +} + +func validateReplaceRequestBody(provider, apiMode string, body map[string]any) error { + if provider != MonitorProviderOpenAI { + return nil + } + switch defaultAPIMode(apiMode) { + case MonitorAPIModeResponses: + if strings.TrimSpace(stringFromAny(body["instructions"])) == "" || !hasNonEmptyBodyValue(body["input"]) { + return fmt.Errorf("replace mode responses body: instructions and input are required") + } + case MonitorAPIModeChatCompletions: + if !hasNonEmptyBodyValue(body["messages"]) { + return fmt.Errorf("replace mode chat_completions body: messages are required") + } + } + return nil +} + +func stringFromAny(v any) string { + s, _ := v.(string) + return s +} + +func hasNonEmptyBodyValue(v any) bool { + switch val := v.(type) { + case nil: + return false + case string: + return strings.TrimSpace(val) != "" + case []any: + return len(val) > 0 + case []map[string]any: + return len(val) > 0 + case []map[string]string: + return len(val) > 0 + default: + return true + } } // postRawJSON 发送 POST + 已序列化好的 JSON 字节,限制响应体大小,返回响应字节、HTTP status、错误。 diff --git a/backend/internal/service/channel_monitor_checker_body_test.go b/backend/internal/service/channel_monitor_checker_body_test.go index 323cf8b7..620cf565 100644 --- a/backend/internal/service/channel_monitor_checker_body_test.go +++ b/backend/internal/service/channel_monitor_checker_body_test.go @@ -7,6 +7,8 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "regexp" + "strconv" "strings" "testing" "time" @@ -57,6 +59,76 @@ func setupFakeAnthropic(t *testing.T, handler *captureHandler) string { return srv.URL } +type openAICaptureHandler struct { + lastBody map[string]any + lastHeaders http.Header + lastPath string + status int +} + +func (h *openAICaptureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.lastHeaders = r.Header.Clone() + h.lastPath = r.URL.Path + defer func() { _ = r.Body.Close() }() + var parsed map[string]any + _ = json.NewDecoder(r.Body).Decode(&parsed) + h.lastBody = parsed + + if h.status == 0 { + h.status = http.StatusOK + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(h.status) + + answer := answerFromOpenAIRequest(parsed) + if h.lastPath == providerOpenAIResponsesPath { + _ = json.NewEncoder(w).Encode(map[string]any{ + "output": []map[string]any{{ + "content": []map[string]any{{"type": "output_text", "text": answer}}, + }}, + }) + return + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{{"message": map[string]any{"content": answer}}}, + }) +} + +func setupFakeOpenAI(t *testing.T, handler *openAICaptureHandler) string { + t.Helper() + swapMonitorHTTPClient(t) + srv := httptest.NewServer(handler) + t.Cleanup(srv.Close) + return srv.URL +} + +func answerFromOpenAIRequest(body map[string]any) string { + prompt, _ := body["input"].(string) + if prompt == "" { + if messages, ok := body["messages"].([]any); ok && len(messages) > 0 { + if msg, ok := messages[0].(map[string]any); ok { + prompt, _ = msg["content"].(string) + } + } + } + return answerFromChallengePrompt(prompt) +} + +var challengeQuestionRegex = regexp.MustCompile(`Q: (\d+) ([+-]) (\d+) = \?\nA:$`) + +func answerFromChallengePrompt(prompt string) string { + m := challengeQuestionRegex.FindStringSubmatch(prompt) + if len(m) != 4 { + return "0" + } + left, _ := strconv.Atoi(m[1]) + right, _ := strconv.Atoi(m[3]) + if m[2] == "+" { + return strconv.Itoa(left + right) + } + return strconv.Itoa(left - right) +} + func TestRunCheckForModel_OffMode_PreservesDefaultBody(t *testing.T) { h := &captureHandler{respondText: "the answer is 42"} endpoint := setupFakeAnthropic(t, h) @@ -75,6 +147,95 @@ func TestRunCheckForModel_OffMode_PreservesDefaultBody(t *testing.T) { } } +func TestRunCheckForModel_OpenAI_DefaultChatRequest(t *testing.T) { + h := &openAICaptureHandler{} + endpoint := setupFakeOpenAI(t, h) + + res := runCheckForModel(context.Background(), MonitorProviderOpenAI, endpoint, "sk-openai", "gpt-test", nil) + + if res.Status != MonitorStatusOperational { + t.Fatalf("default chat request should pass challenge, got status=%s message=%q", res.Status, res.Message) + } + if h.lastPath != providerOpenAIPath { + t.Fatalf("expected chat completions path %q, got %q", providerOpenAIPath, h.lastPath) + } + if h.lastBody["model"] != "gpt-test" { + t.Errorf("chat body should contain model=gpt-test, got %v", h.lastBody["model"]) + } + if _, ok := h.lastBody["messages"]; !ok { + t.Error("chat body should contain messages") + } + if _, ok := h.lastBody["instructions"]; ok { + t.Error("chat body must not contain top-level instructions") + } + if h.lastBody["stream"] != false { + t.Errorf("chat body should set stream=false, got %v", h.lastBody["stream"]) + } + if h.lastHeaders.Get("Authorization") != "Bearer sk-openai" { + t.Errorf("expected bearer auth header, got %q", h.lastHeaders.Get("Authorization")) + } +} + +func TestRunCheckForModel_OpenAIResponses_DefaultRequest(t *testing.T) { + h := &openAICaptureHandler{} + endpoint := setupFakeOpenAI(t, h) + + res := runCheckForModel(context.Background(), MonitorProviderOpenAI, endpoint, "sk-openai", "gpt-test", &CheckOptions{ + APIMode: MonitorAPIModeResponses, + }) + + if res.Status != MonitorStatusOperational { + t.Fatalf("default responses request should pass challenge, got status=%s message=%q", res.Status, res.Message) + } + if h.lastPath != providerOpenAIResponsesPath { + t.Fatalf("expected responses path %q, got %q", providerOpenAIResponsesPath, h.lastPath) + } + if h.lastBody["model"] != "gpt-test" { + t.Errorf("responses body should contain model=gpt-test, got %v", h.lastBody["model"]) + } + instructions, _ := h.lastBody["instructions"].(string) + if strings.TrimSpace(instructions) == "" { + t.Error("responses body should contain non-empty instructions") + } + input, _ := h.lastBody["input"].(string) + if strings.TrimSpace(input) == "" { + t.Error("responses body should contain non-empty input") + } + if _, ok := h.lastBody["messages"]; ok { + t.Error("responses body must not contain chat messages") + } + if h.lastBody["stream"] != false { + t.Errorf("responses body should set stream=false, got %v", h.lastBody["stream"]) + } + if h.lastHeaders.Get("Authorization") != "Bearer sk-openai" { + t.Errorf("expected bearer auth header, got %q", h.lastHeaders.Get("Authorization")) + } +} + +func TestRunCheckForModel_OpenAIResponsesReplaceMissingInstructionsFailsLocally(t *testing.T) { + h := &openAICaptureHandler{} + endpoint := setupFakeOpenAI(t, h) + + res := runCheckForModel(context.Background(), MonitorProviderOpenAI, endpoint, "sk-openai", "gpt-test", &CheckOptions{ + APIMode: MonitorAPIModeResponses, + BodyOverrideMode: MonitorBodyOverrideModeReplace, + BodyOverride: map[string]any{ + "model": "gpt-test", + "input": "hello", + }, + }) + + if res.Status != MonitorStatusError { + t.Fatalf("invalid responses replace body should fail locally as error, got status=%s", res.Status) + } + if !strings.Contains(res.Message, "instructions and input are required") { + t.Errorf("expected local validation message about instructions/input, got %q", res.Message) + } + if h.lastPath != "" { + t.Errorf("invalid replace body should fail before HTTP request, got path %q", h.lastPath) + } +} + func TestRunCheckForModel_MergeMode_UserFieldsWinButDenyListProtects(t *testing.T) { h := &captureHandler{respondText: "the answer is 42"} endpoint := setupFakeAnthropic(t, h)