feat(channel-monitor): 拆分 OpenAI 检测协议

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
benjamin 2026-05-19 22:04:55 +08:00
parent b685fe69a4
commit 1184ef265f
2 changed files with 276 additions and 22 deletions

View File

@ -40,6 +40,8 @@ func newSSRFSafeHTTPClient(timeout time.Duration) *http.Client {
// CheckOptions 承载一次检测的自定义入参。
// 所有字段都是可选(零值即等价于"用默认行为")。
type CheckOptions struct {
// APIMode 仅对 OpenAI provider 生效;空串等同 chat_completions。
APIMode string
// ExtraHeaders 用户自定义 HTTP 头merge 到 adapter 默认 headers用户优先
ExtraHeaders map[string]string
// BodyOverrideMode: off | merge | replace
@ -164,21 +166,7 @@ type providerAdapter struct {
//
//nolint:gochecknoglobals // 适配器表是只读静态数据,初始化后不变更。
var providerAdapters = map[string]providerAdapter{
MonitorProviderOpenAI: {
buildPath: func(string) string { return providerOpenAIPath },
buildBody: func(model, prompt string) ([]byte, error) {
return json.Marshal(map[string]any{
"model": model,
"messages": []map[string]string{{"role": "user", "content": prompt}},
"max_tokens": monitorChallengeMaxTokens,
"stream": false,
})
},
buildHeaders: func(apiKey string) map[string]string {
return map[string]string{"Authorization": "Bearer " + apiKey}
},
textPath: "choices.0.message.content",
},
MonitorProviderOpenAI: providerOpenAIChatAdapter,
MonitorProviderAnthropic: {
buildPath: func(string) string { return providerAnthropicPath },
buildBody: func(model, prompt string) ([]byte, error) {
@ -215,6 +203,50 @@ var providerAdapters = map[string]providerAdapter{
},
}
//nolint:gochecknoglobals // 适配器表是只读静态数据,初始化后不变更。
var providerOpenAIChatAdapter = providerAdapter{
buildPath: func(string) string { return providerOpenAIPath },
buildBody: func(model, prompt string) ([]byte, error) {
return json.Marshal(map[string]any{
"model": model,
"messages": []map[string]string{{"role": "user", "content": prompt}},
"max_tokens": monitorChallengeMaxTokens,
"stream": false,
})
},
buildHeaders: func(apiKey string) map[string]string {
return map[string]string{"Authorization": "Bearer " + apiKey}
},
textPath: "choices.0.message.content",
}
//nolint:gochecknoglobals // 适配器表是只读静态数据,初始化后不变更。
var providerOpenAIResponsesAdapter = providerAdapter{
buildPath: func(string) string { return providerOpenAIResponsesPath },
buildBody: func(model, prompt string) ([]byte, error) {
return json.Marshal(map[string]any{
"model": model,
"instructions": "You are a channel health-check endpoint. Answer the arithmetic challenge exactly and briefly.",
"input": prompt,
"max_output_tokens": monitorChallengeMaxTokens,
"stream": false,
})
},
buildHeaders: func(apiKey string) map[string]string {
return map[string]string{"Authorization": "Bearer " + apiKey}
},
textPath: "output.0.content.0.text",
}
// providerAdapterFor 按 provider + api_mode 选择具体 adapter。
func providerAdapterFor(provider, apiMode string) (providerAdapter, string, bool) {
if provider == MonitorProviderOpenAI && defaultAPIMode(apiMode) == MonitorAPIModeResponses {
return providerOpenAIResponsesAdapter, MonitorAPIModeResponses, true
}
adapter, ok := providerAdapters[provider]
return adapter, MonitorAPIModeChatCompletions, ok
}
// isSupportedProvider 校验 provider 字符串是否在 adapter 表中。
// 供 validate.go 的 validateProvider 复用,避免两份 switch 漂移。
func isSupportedProvider(p string) bool {
@ -231,11 +263,15 @@ func isSupportedProvider(p string) bool {
// - status: HTTP 状态码
// - err: 网络 / 序列化错误
func callProvider(ctx context.Context, provider, endpoint, apiKey, model, prompt string, opts *CheckOptions) (extractedText, rawBody string, status int, err error) {
adapter, ok := providerAdapters[provider]
requestedAPIMode := checkAPIMode(opts)
if err := validateAPIMode(provider, requestedAPIMode); err != nil {
return "", "", 0, err
}
adapter, apiMode, ok := providerAdapterFor(provider, requestedAPIMode)
if !ok {
return "", "", 0, fmt.Errorf("unsupported provider %q", provider)
}
body, err := buildRequestBody(adapter, provider, model, prompt, opts)
body, err := buildRequestBody(adapter, provider, apiMode, model, prompt, opts)
if err != nil {
return "", "", 0, err
}
@ -275,13 +311,16 @@ func mergeHeaders(base map[string]string, opts *CheckOptions) map[string]string
// - replace: 直接 marshal BodyOverride 作为完整 body
//
// 任何 mode 返回的 []byte 都已经是合法 JSON可直接送入 postRawJSON。
func buildRequestBody(adapter providerAdapter, provider, model, prompt string, opts *CheckOptions) ([]byte, error) {
func buildRequestBody(adapter providerAdapter, provider, apiMode, model, prompt string, opts *CheckOptions) ([]byte, error) {
mode := bodyOverrideMode(opts)
if mode == MonitorBodyOverrideModeReplace {
if opts == nil || len(opts.BodyOverride) == 0 {
return nil, fmt.Errorf("replace mode: body_override is empty")
}
if err := validateReplaceRequestBody(provider, apiMode, opts.BodyOverride); err != nil {
return nil, err
}
body, err := json.Marshal(opts.BodyOverride)
if err != nil {
return nil, fmt.Errorf("marshal body_override (replace): %w", err)
@ -301,7 +340,7 @@ func buildRequestBody(adapter providerAdapter, provider, model, prompt string, o
if err := json.Unmarshal(defaultBody, &defaultMap); err != nil {
return nil, fmt.Errorf("unmarshal default body for merge: %w", err)
}
deny := bodyMergeKeyDenyList[provider]
deny := bodyMergeKeyDenyList[bodyMergeDenyKey(provider, apiMode)]
for k, v := range opts.BodyOverride {
if deny[k] {
continue
@ -321,9 +360,63 @@ func buildRequestBody(adapter providerAdapter, provider, model, prompt string, o
//
//nolint:gochecknoglobals // 静态查表,初始化后不变。
var bodyMergeKeyDenyList = map[string]map[string]bool{
MonitorProviderOpenAI: {"model": true, "messages": true, "stream": true},
MonitorProviderAnthropic: {"model": true, "messages": true},
MonitorProviderGemini: {"contents": true},
MonitorProviderOpenAI + ":" + MonitorAPIModeChatCompletions: {"model": true, "messages": true, "stream": true},
MonitorProviderOpenAI + ":" + MonitorAPIModeResponses: {"model": true, "instructions": true, "input": true, "stream": true},
MonitorProviderAnthropic: {"model": true, "messages": true},
MonitorProviderGemini: {"contents": true},
}
func checkAPIMode(opts *CheckOptions) string {
if opts == nil {
return MonitorAPIModeChatCompletions
}
return defaultAPIMode(opts.APIMode)
}
func bodyMergeDenyKey(provider, apiMode string) string {
if provider == MonitorProviderOpenAI {
return provider + ":" + defaultAPIMode(apiMode)
}
return provider
}
func validateReplaceRequestBody(provider, apiMode string, body map[string]any) error {
if provider != MonitorProviderOpenAI {
return nil
}
switch defaultAPIMode(apiMode) {
case MonitorAPIModeResponses:
if strings.TrimSpace(stringFromAny(body["instructions"])) == "" || !hasNonEmptyBodyValue(body["input"]) {
return fmt.Errorf("replace mode responses body: instructions and input are required")
}
case MonitorAPIModeChatCompletions:
if !hasNonEmptyBodyValue(body["messages"]) {
return fmt.Errorf("replace mode chat_completions body: messages are required")
}
}
return nil
}
func stringFromAny(v any) string {
s, _ := v.(string)
return s
}
func hasNonEmptyBodyValue(v any) bool {
switch val := v.(type) {
case nil:
return false
case string:
return strings.TrimSpace(val) != ""
case []any:
return len(val) > 0
case []map[string]any:
return len(val) > 0
case []map[string]string:
return len(val) > 0
default:
return true
}
}
// postRawJSON 发送 POST + 已序列化好的 JSON 字节限制响应体大小返回响应字节、HTTP status、错误。

View File

@ -7,6 +7,8 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"regexp"
"strconv"
"strings"
"testing"
"time"
@ -57,6 +59,76 @@ func setupFakeAnthropic(t *testing.T, handler *captureHandler) string {
return srv.URL
}
type openAICaptureHandler struct {
lastBody map[string]any
lastHeaders http.Header
lastPath string
status int
}
func (h *openAICaptureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.lastHeaders = r.Header.Clone()
h.lastPath = r.URL.Path
defer func() { _ = r.Body.Close() }()
var parsed map[string]any
_ = json.NewDecoder(r.Body).Decode(&parsed)
h.lastBody = parsed
if h.status == 0 {
h.status = http.StatusOK
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(h.status)
answer := answerFromOpenAIRequest(parsed)
if h.lastPath == providerOpenAIResponsesPath {
_ = json.NewEncoder(w).Encode(map[string]any{
"output": []map[string]any{{
"content": []map[string]any{{"type": "output_text", "text": answer}},
}},
})
return
}
_ = json.NewEncoder(w).Encode(map[string]any{
"choices": []map[string]any{{"message": map[string]any{"content": answer}}},
})
}
func setupFakeOpenAI(t *testing.T, handler *openAICaptureHandler) string {
t.Helper()
swapMonitorHTTPClient(t)
srv := httptest.NewServer(handler)
t.Cleanup(srv.Close)
return srv.URL
}
func answerFromOpenAIRequest(body map[string]any) string {
prompt, _ := body["input"].(string)
if prompt == "" {
if messages, ok := body["messages"].([]any); ok && len(messages) > 0 {
if msg, ok := messages[0].(map[string]any); ok {
prompt, _ = msg["content"].(string)
}
}
}
return answerFromChallengePrompt(prompt)
}
var challengeQuestionRegex = regexp.MustCompile(`Q: (\d+) ([+-]) (\d+) = \?\nA:$`)
func answerFromChallengePrompt(prompt string) string {
m := challengeQuestionRegex.FindStringSubmatch(prompt)
if len(m) != 4 {
return "0"
}
left, _ := strconv.Atoi(m[1])
right, _ := strconv.Atoi(m[3])
if m[2] == "+" {
return strconv.Itoa(left + right)
}
return strconv.Itoa(left - right)
}
func TestRunCheckForModel_OffMode_PreservesDefaultBody(t *testing.T) {
h := &captureHandler{respondText: "the answer is 42"}
endpoint := setupFakeAnthropic(t, h)
@ -75,6 +147,95 @@ func TestRunCheckForModel_OffMode_PreservesDefaultBody(t *testing.T) {
}
}
func TestRunCheckForModel_OpenAI_DefaultChatRequest(t *testing.T) {
h := &openAICaptureHandler{}
endpoint := setupFakeOpenAI(t, h)
res := runCheckForModel(context.Background(), MonitorProviderOpenAI, endpoint, "sk-openai", "gpt-test", nil)
if res.Status != MonitorStatusOperational {
t.Fatalf("default chat request should pass challenge, got status=%s message=%q", res.Status, res.Message)
}
if h.lastPath != providerOpenAIPath {
t.Fatalf("expected chat completions path %q, got %q", providerOpenAIPath, h.lastPath)
}
if h.lastBody["model"] != "gpt-test" {
t.Errorf("chat body should contain model=gpt-test, got %v", h.lastBody["model"])
}
if _, ok := h.lastBody["messages"]; !ok {
t.Error("chat body should contain messages")
}
if _, ok := h.lastBody["instructions"]; ok {
t.Error("chat body must not contain top-level instructions")
}
if h.lastBody["stream"] != false {
t.Errorf("chat body should set stream=false, got %v", h.lastBody["stream"])
}
if h.lastHeaders.Get("Authorization") != "Bearer sk-openai" {
t.Errorf("expected bearer auth header, got %q", h.lastHeaders.Get("Authorization"))
}
}
func TestRunCheckForModel_OpenAIResponses_DefaultRequest(t *testing.T) {
h := &openAICaptureHandler{}
endpoint := setupFakeOpenAI(t, h)
res := runCheckForModel(context.Background(), MonitorProviderOpenAI, endpoint, "sk-openai", "gpt-test", &CheckOptions{
APIMode: MonitorAPIModeResponses,
})
if res.Status != MonitorStatusOperational {
t.Fatalf("default responses request should pass challenge, got status=%s message=%q", res.Status, res.Message)
}
if h.lastPath != providerOpenAIResponsesPath {
t.Fatalf("expected responses path %q, got %q", providerOpenAIResponsesPath, h.lastPath)
}
if h.lastBody["model"] != "gpt-test" {
t.Errorf("responses body should contain model=gpt-test, got %v", h.lastBody["model"])
}
instructions, _ := h.lastBody["instructions"].(string)
if strings.TrimSpace(instructions) == "" {
t.Error("responses body should contain non-empty instructions")
}
input, _ := h.lastBody["input"].(string)
if strings.TrimSpace(input) == "" {
t.Error("responses body should contain non-empty input")
}
if _, ok := h.lastBody["messages"]; ok {
t.Error("responses body must not contain chat messages")
}
if h.lastBody["stream"] != false {
t.Errorf("responses body should set stream=false, got %v", h.lastBody["stream"])
}
if h.lastHeaders.Get("Authorization") != "Bearer sk-openai" {
t.Errorf("expected bearer auth header, got %q", h.lastHeaders.Get("Authorization"))
}
}
func TestRunCheckForModel_OpenAIResponsesReplaceMissingInstructionsFailsLocally(t *testing.T) {
h := &openAICaptureHandler{}
endpoint := setupFakeOpenAI(t, h)
res := runCheckForModel(context.Background(), MonitorProviderOpenAI, endpoint, "sk-openai", "gpt-test", &CheckOptions{
APIMode: MonitorAPIModeResponses,
BodyOverrideMode: MonitorBodyOverrideModeReplace,
BodyOverride: map[string]any{
"model": "gpt-test",
"input": "hello",
},
})
if res.Status != MonitorStatusError {
t.Fatalf("invalid responses replace body should fail locally as error, got status=%s", res.Status)
}
if !strings.Contains(res.Message, "instructions and input are required") {
t.Errorf("expected local validation message about instructions/input, got %q", res.Message)
}
if h.lastPath != "" {
t.Errorf("invalid replace body should fail before HTTP request, got path %q", h.lastPath)
}
}
func TestRunCheckForModel_MergeMode_UserFieldsWinButDenyListProtects(t *testing.T) {
h := &captureHandler{respondText: "the answer is 42"}
endpoint := setupFakeAnthropic(t, h)