feat(channel-monitor): 拆分 OpenAI 检测协议
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
parent
b685fe69a4
commit
1184ef265f
@ -40,6 +40,8 @@ func newSSRFSafeHTTPClient(timeout time.Duration) *http.Client {
|
||||
// CheckOptions 承载一次检测的自定义入参。
|
||||
// 所有字段都是可选(零值即等价于"用默认行为")。
|
||||
type CheckOptions struct {
|
||||
// APIMode 仅对 OpenAI provider 生效;空串等同 chat_completions。
|
||||
APIMode string
|
||||
// ExtraHeaders 用户自定义 HTTP 头(merge 到 adapter 默认 headers,用户优先)。
|
||||
ExtraHeaders map[string]string
|
||||
// BodyOverrideMode: off | merge | replace
|
||||
@ -164,21 +166,7 @@ type providerAdapter struct {
|
||||
//
|
||||
//nolint:gochecknoglobals // 适配器表是只读静态数据,初始化后不变更。
|
||||
var providerAdapters = map[string]providerAdapter{
|
||||
MonitorProviderOpenAI: {
|
||||
buildPath: func(string) string { return providerOpenAIPath },
|
||||
buildBody: func(model, prompt string) ([]byte, error) {
|
||||
return json.Marshal(map[string]any{
|
||||
"model": model,
|
||||
"messages": []map[string]string{{"role": "user", "content": prompt}},
|
||||
"max_tokens": monitorChallengeMaxTokens,
|
||||
"stream": false,
|
||||
})
|
||||
},
|
||||
buildHeaders: func(apiKey string) map[string]string {
|
||||
return map[string]string{"Authorization": "Bearer " + apiKey}
|
||||
},
|
||||
textPath: "choices.0.message.content",
|
||||
},
|
||||
MonitorProviderOpenAI: providerOpenAIChatAdapter,
|
||||
MonitorProviderAnthropic: {
|
||||
buildPath: func(string) string { return providerAnthropicPath },
|
||||
buildBody: func(model, prompt string) ([]byte, error) {
|
||||
@ -215,6 +203,50 @@ var providerAdapters = map[string]providerAdapter{
|
||||
},
|
||||
}
|
||||
|
||||
//nolint:gochecknoglobals // 适配器表是只读静态数据,初始化后不变更。
|
||||
var providerOpenAIChatAdapter = providerAdapter{
|
||||
buildPath: func(string) string { return providerOpenAIPath },
|
||||
buildBody: func(model, prompt string) ([]byte, error) {
|
||||
return json.Marshal(map[string]any{
|
||||
"model": model,
|
||||
"messages": []map[string]string{{"role": "user", "content": prompt}},
|
||||
"max_tokens": monitorChallengeMaxTokens,
|
||||
"stream": false,
|
||||
})
|
||||
},
|
||||
buildHeaders: func(apiKey string) map[string]string {
|
||||
return map[string]string{"Authorization": "Bearer " + apiKey}
|
||||
},
|
||||
textPath: "choices.0.message.content",
|
||||
}
|
||||
|
||||
//nolint:gochecknoglobals // 适配器表是只读静态数据,初始化后不变更。
|
||||
var providerOpenAIResponsesAdapter = providerAdapter{
|
||||
buildPath: func(string) string { return providerOpenAIResponsesPath },
|
||||
buildBody: func(model, prompt string) ([]byte, error) {
|
||||
return json.Marshal(map[string]any{
|
||||
"model": model,
|
||||
"instructions": "You are a channel health-check endpoint. Answer the arithmetic challenge exactly and briefly.",
|
||||
"input": prompt,
|
||||
"max_output_tokens": monitorChallengeMaxTokens,
|
||||
"stream": false,
|
||||
})
|
||||
},
|
||||
buildHeaders: func(apiKey string) map[string]string {
|
||||
return map[string]string{"Authorization": "Bearer " + apiKey}
|
||||
},
|
||||
textPath: "output.0.content.0.text",
|
||||
}
|
||||
|
||||
// providerAdapterFor 按 provider + api_mode 选择具体 adapter。
|
||||
func providerAdapterFor(provider, apiMode string) (providerAdapter, string, bool) {
|
||||
if provider == MonitorProviderOpenAI && defaultAPIMode(apiMode) == MonitorAPIModeResponses {
|
||||
return providerOpenAIResponsesAdapter, MonitorAPIModeResponses, true
|
||||
}
|
||||
adapter, ok := providerAdapters[provider]
|
||||
return adapter, MonitorAPIModeChatCompletions, ok
|
||||
}
|
||||
|
||||
// isSupportedProvider 校验 provider 字符串是否在 adapter 表中。
|
||||
// 供 validate.go 的 validateProvider 复用,避免两份 switch 漂移。
|
||||
func isSupportedProvider(p string) bool {
|
||||
@ -231,11 +263,15 @@ func isSupportedProvider(p string) bool {
|
||||
// - status: HTTP 状态码
|
||||
// - err: 网络 / 序列化错误
|
||||
func callProvider(ctx context.Context, provider, endpoint, apiKey, model, prompt string, opts *CheckOptions) (extractedText, rawBody string, status int, err error) {
|
||||
adapter, ok := providerAdapters[provider]
|
||||
requestedAPIMode := checkAPIMode(opts)
|
||||
if err := validateAPIMode(provider, requestedAPIMode); err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
adapter, apiMode, ok := providerAdapterFor(provider, requestedAPIMode)
|
||||
if !ok {
|
||||
return "", "", 0, fmt.Errorf("unsupported provider %q", provider)
|
||||
}
|
||||
body, err := buildRequestBody(adapter, provider, model, prompt, opts)
|
||||
body, err := buildRequestBody(adapter, provider, apiMode, model, prompt, opts)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
@ -275,13 +311,16 @@ func mergeHeaders(base map[string]string, opts *CheckOptions) map[string]string
|
||||
// - replace: 直接 marshal BodyOverride 作为完整 body
|
||||
//
|
||||
// 任何 mode 返回的 []byte 都已经是合法 JSON,可直接送入 postRawJSON。
|
||||
func buildRequestBody(adapter providerAdapter, provider, model, prompt string, opts *CheckOptions) ([]byte, error) {
|
||||
func buildRequestBody(adapter providerAdapter, provider, apiMode, model, prompt string, opts *CheckOptions) ([]byte, error) {
|
||||
mode := bodyOverrideMode(opts)
|
||||
|
||||
if mode == MonitorBodyOverrideModeReplace {
|
||||
if opts == nil || len(opts.BodyOverride) == 0 {
|
||||
return nil, fmt.Errorf("replace mode: body_override is empty")
|
||||
}
|
||||
if err := validateReplaceRequestBody(provider, apiMode, opts.BodyOverride); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body, err := json.Marshal(opts.BodyOverride)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal body_override (replace): %w", err)
|
||||
@ -301,7 +340,7 @@ func buildRequestBody(adapter providerAdapter, provider, model, prompt string, o
|
||||
if err := json.Unmarshal(defaultBody, &defaultMap); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal default body for merge: %w", err)
|
||||
}
|
||||
deny := bodyMergeKeyDenyList[provider]
|
||||
deny := bodyMergeKeyDenyList[bodyMergeDenyKey(provider, apiMode)]
|
||||
for k, v := range opts.BodyOverride {
|
||||
if deny[k] {
|
||||
continue
|
||||
@ -321,9 +360,63 @@ func buildRequestBody(adapter providerAdapter, provider, model, prompt string, o
|
||||
//
|
||||
//nolint:gochecknoglobals // 静态查表,初始化后不变。
|
||||
var bodyMergeKeyDenyList = map[string]map[string]bool{
|
||||
MonitorProviderOpenAI: {"model": true, "messages": true, "stream": true},
|
||||
MonitorProviderAnthropic: {"model": true, "messages": true},
|
||||
MonitorProviderGemini: {"contents": true},
|
||||
MonitorProviderOpenAI + ":" + MonitorAPIModeChatCompletions: {"model": true, "messages": true, "stream": true},
|
||||
MonitorProviderOpenAI + ":" + MonitorAPIModeResponses: {"model": true, "instructions": true, "input": true, "stream": true},
|
||||
MonitorProviderAnthropic: {"model": true, "messages": true},
|
||||
MonitorProviderGemini: {"contents": true},
|
||||
}
|
||||
|
||||
func checkAPIMode(opts *CheckOptions) string {
|
||||
if opts == nil {
|
||||
return MonitorAPIModeChatCompletions
|
||||
}
|
||||
return defaultAPIMode(opts.APIMode)
|
||||
}
|
||||
|
||||
func bodyMergeDenyKey(provider, apiMode string) string {
|
||||
if provider == MonitorProviderOpenAI {
|
||||
return provider + ":" + defaultAPIMode(apiMode)
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
func validateReplaceRequestBody(provider, apiMode string, body map[string]any) error {
|
||||
if provider != MonitorProviderOpenAI {
|
||||
return nil
|
||||
}
|
||||
switch defaultAPIMode(apiMode) {
|
||||
case MonitorAPIModeResponses:
|
||||
if strings.TrimSpace(stringFromAny(body["instructions"])) == "" || !hasNonEmptyBodyValue(body["input"]) {
|
||||
return fmt.Errorf("replace mode responses body: instructions and input are required")
|
||||
}
|
||||
case MonitorAPIModeChatCompletions:
|
||||
if !hasNonEmptyBodyValue(body["messages"]) {
|
||||
return fmt.Errorf("replace mode chat_completions body: messages are required")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func stringFromAny(v any) string {
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
}
|
||||
|
||||
func hasNonEmptyBodyValue(v any) bool {
|
||||
switch val := v.(type) {
|
||||
case nil:
|
||||
return false
|
||||
case string:
|
||||
return strings.TrimSpace(val) != ""
|
||||
case []any:
|
||||
return len(val) > 0
|
||||
case []map[string]any:
|
||||
return len(val) > 0
|
||||
case []map[string]string:
|
||||
return len(val) > 0
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// postRawJSON 发送 POST + 已序列化好的 JSON 字节,限制响应体大小,返回响应字节、HTTP status、错误。
|
||||
|
||||
@ -7,6 +7,8 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@ -57,6 +59,76 @@ func setupFakeAnthropic(t *testing.T, handler *captureHandler) string {
|
||||
return srv.URL
|
||||
}
|
||||
|
||||
type openAICaptureHandler struct {
|
||||
lastBody map[string]any
|
||||
lastHeaders http.Header
|
||||
lastPath string
|
||||
status int
|
||||
}
|
||||
|
||||
func (h *openAICaptureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
h.lastHeaders = r.Header.Clone()
|
||||
h.lastPath = r.URL.Path
|
||||
defer func() { _ = r.Body.Close() }()
|
||||
var parsed map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&parsed)
|
||||
h.lastBody = parsed
|
||||
|
||||
if h.status == 0 {
|
||||
h.status = http.StatusOK
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(h.status)
|
||||
|
||||
answer := answerFromOpenAIRequest(parsed)
|
||||
if h.lastPath == providerOpenAIResponsesPath {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"output": []map[string]any{{
|
||||
"content": []map[string]any{{"type": "output_text", "text": answer}},
|
||||
}},
|
||||
})
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{{"message": map[string]any{"content": answer}}},
|
||||
})
|
||||
}
|
||||
|
||||
func setupFakeOpenAI(t *testing.T, handler *openAICaptureHandler) string {
|
||||
t.Helper()
|
||||
swapMonitorHTTPClient(t)
|
||||
srv := httptest.NewServer(handler)
|
||||
t.Cleanup(srv.Close)
|
||||
return srv.URL
|
||||
}
|
||||
|
||||
func answerFromOpenAIRequest(body map[string]any) string {
|
||||
prompt, _ := body["input"].(string)
|
||||
if prompt == "" {
|
||||
if messages, ok := body["messages"].([]any); ok && len(messages) > 0 {
|
||||
if msg, ok := messages[0].(map[string]any); ok {
|
||||
prompt, _ = msg["content"].(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
return answerFromChallengePrompt(prompt)
|
||||
}
|
||||
|
||||
var challengeQuestionRegex = regexp.MustCompile(`Q: (\d+) ([+-]) (\d+) = \?\nA:$`)
|
||||
|
||||
func answerFromChallengePrompt(prompt string) string {
|
||||
m := challengeQuestionRegex.FindStringSubmatch(prompt)
|
||||
if len(m) != 4 {
|
||||
return "0"
|
||||
}
|
||||
left, _ := strconv.Atoi(m[1])
|
||||
right, _ := strconv.Atoi(m[3])
|
||||
if m[2] == "+" {
|
||||
return strconv.Itoa(left + right)
|
||||
}
|
||||
return strconv.Itoa(left - right)
|
||||
}
|
||||
|
||||
func TestRunCheckForModel_OffMode_PreservesDefaultBody(t *testing.T) {
|
||||
h := &captureHandler{respondText: "the answer is 42"}
|
||||
endpoint := setupFakeAnthropic(t, h)
|
||||
@ -75,6 +147,95 @@ func TestRunCheckForModel_OffMode_PreservesDefaultBody(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunCheckForModel_OpenAI_DefaultChatRequest(t *testing.T) {
|
||||
h := &openAICaptureHandler{}
|
||||
endpoint := setupFakeOpenAI(t, h)
|
||||
|
||||
res := runCheckForModel(context.Background(), MonitorProviderOpenAI, endpoint, "sk-openai", "gpt-test", nil)
|
||||
|
||||
if res.Status != MonitorStatusOperational {
|
||||
t.Fatalf("default chat request should pass challenge, got status=%s message=%q", res.Status, res.Message)
|
||||
}
|
||||
if h.lastPath != providerOpenAIPath {
|
||||
t.Fatalf("expected chat completions path %q, got %q", providerOpenAIPath, h.lastPath)
|
||||
}
|
||||
if h.lastBody["model"] != "gpt-test" {
|
||||
t.Errorf("chat body should contain model=gpt-test, got %v", h.lastBody["model"])
|
||||
}
|
||||
if _, ok := h.lastBody["messages"]; !ok {
|
||||
t.Error("chat body should contain messages")
|
||||
}
|
||||
if _, ok := h.lastBody["instructions"]; ok {
|
||||
t.Error("chat body must not contain top-level instructions")
|
||||
}
|
||||
if h.lastBody["stream"] != false {
|
||||
t.Errorf("chat body should set stream=false, got %v", h.lastBody["stream"])
|
||||
}
|
||||
if h.lastHeaders.Get("Authorization") != "Bearer sk-openai" {
|
||||
t.Errorf("expected bearer auth header, got %q", h.lastHeaders.Get("Authorization"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunCheckForModel_OpenAIResponses_DefaultRequest(t *testing.T) {
|
||||
h := &openAICaptureHandler{}
|
||||
endpoint := setupFakeOpenAI(t, h)
|
||||
|
||||
res := runCheckForModel(context.Background(), MonitorProviderOpenAI, endpoint, "sk-openai", "gpt-test", &CheckOptions{
|
||||
APIMode: MonitorAPIModeResponses,
|
||||
})
|
||||
|
||||
if res.Status != MonitorStatusOperational {
|
||||
t.Fatalf("default responses request should pass challenge, got status=%s message=%q", res.Status, res.Message)
|
||||
}
|
||||
if h.lastPath != providerOpenAIResponsesPath {
|
||||
t.Fatalf("expected responses path %q, got %q", providerOpenAIResponsesPath, h.lastPath)
|
||||
}
|
||||
if h.lastBody["model"] != "gpt-test" {
|
||||
t.Errorf("responses body should contain model=gpt-test, got %v", h.lastBody["model"])
|
||||
}
|
||||
instructions, _ := h.lastBody["instructions"].(string)
|
||||
if strings.TrimSpace(instructions) == "" {
|
||||
t.Error("responses body should contain non-empty instructions")
|
||||
}
|
||||
input, _ := h.lastBody["input"].(string)
|
||||
if strings.TrimSpace(input) == "" {
|
||||
t.Error("responses body should contain non-empty input")
|
||||
}
|
||||
if _, ok := h.lastBody["messages"]; ok {
|
||||
t.Error("responses body must not contain chat messages")
|
||||
}
|
||||
if h.lastBody["stream"] != false {
|
||||
t.Errorf("responses body should set stream=false, got %v", h.lastBody["stream"])
|
||||
}
|
||||
if h.lastHeaders.Get("Authorization") != "Bearer sk-openai" {
|
||||
t.Errorf("expected bearer auth header, got %q", h.lastHeaders.Get("Authorization"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunCheckForModel_OpenAIResponsesReplaceMissingInstructionsFailsLocally(t *testing.T) {
|
||||
h := &openAICaptureHandler{}
|
||||
endpoint := setupFakeOpenAI(t, h)
|
||||
|
||||
res := runCheckForModel(context.Background(), MonitorProviderOpenAI, endpoint, "sk-openai", "gpt-test", &CheckOptions{
|
||||
APIMode: MonitorAPIModeResponses,
|
||||
BodyOverrideMode: MonitorBodyOverrideModeReplace,
|
||||
BodyOverride: map[string]any{
|
||||
"model": "gpt-test",
|
||||
"input": "hello",
|
||||
},
|
||||
})
|
||||
|
||||
if res.Status != MonitorStatusError {
|
||||
t.Fatalf("invalid responses replace body should fail locally as error, got status=%s", res.Status)
|
||||
}
|
||||
if !strings.Contains(res.Message, "instructions and input are required") {
|
||||
t.Errorf("expected local validation message about instructions/input, got %q", res.Message)
|
||||
}
|
||||
if h.lastPath != "" {
|
||||
t.Errorf("invalid replace body should fail before HTTP request, got path %q", h.lastPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunCheckForModel_MergeMode_UserFieldsWinButDenyListProtects(t *testing.T) {
|
||||
h := &captureHandler{respondText: "the answer is 42"}
|
||||
endpoint := setupFakeAnthropic(t, h)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user