package repository

import (
	"io"
	"net/http"
	"net/http/httptrace"
	"os"
	"strings"
	"testing"
)

// transportAuditStub is a fake base transport for the transport-audit tests.
// It answers every request with a canned 200 "ok" response and records what it
// observed, so assertions run on the test goroutine after RoundTrip returns
// rather than from inside the transport callback.
type transportAuditStub struct {
	calls    int  // number of requests that reached the base transport
	sawTrace bool // whether the most recent request carried an httptrace.ClientTrace
}

// transport returns the stub as a roundTripFunc suitable for passing to
// wrapTransportAuditIfEnabled.
func (s *transportAuditStub) transport() roundTripFunc {
	return roundTripFunc(func(r *http.Request) (*http.Response, error) {
		s.calls++
		s.sawTrace = httptrace.ContextClientTrace(r.Context()) != nil
		return &http.Response{
			StatusCode: 200,
			Proto:      "HTTP/1.1",
			Body:       io.NopCloser(strings.NewReader("ok")),
			Request:    r,
		}, nil
	})
}

// sendTransportAuditRequest issues a GET for a representative Messages API URL
// through rt, failing the test on any error, and closes the response body.
func sendTransportAuditRequest(t *testing.T, rt http.RoundTripper) {
	t.Helper()
	req, err := http.NewRequest(http.MethodGet, "https://api.anthropic.com/v1/messages?beta=true", nil)
	if err != nil {
		t.Fatalf("new request: %v", err)
	}
	resp, err := rt.RoundTrip(req)
	if err != nil {
		t.Fatalf("round trip: %v", err)
	}
	if resp != nil && resp.Body != nil {
		// Close per the RoundTripper contract; the close error is not under test.
		_ = resp.Body.Close()
	}
}

// TestWrapTransportAuditIfEnabledDisabled checks that, with the audit env var
// set to the empty string, wrapTransportAuditIfEnabled does not return a
// *transportAuditRoundTripper, and requests reach the base transport without
// an httptrace.ClientTrace on their context.
func TestWrapTransportAuditIfEnabledDisabled(t *testing.T) {
	t.Setenv(transportAuditEnv, "")

	stub := &transportAuditStub{}
	wrapped := wrapTransportAuditIfEnabled(stub.transport(), "plain")
	if _, ok := wrapped.(*transportAuditRoundTripper); ok {
		t.Fatalf("expected base transport when audit disabled")
	}

	sendTransportAuditRequest(t, wrapped)
	// Guard against a vacuous pass: the trace check below is meaningless if
	// the request never reached the base transport.
	if stub.calls == 0 {
		t.Fatalf("base transport was never called")
	}
	if stub.sawTrace {
		t.Fatalf("unexpected client trace when audit is disabled")
	}
}

// TestWrapTransportAuditIfEnabledEnabled checks that, with the audit env var
// set to "1", wrapTransportAuditIfEnabled returns a *transportAuditRoundTripper
// and requests reach the base transport with an httptrace.ClientTrace attached
// to their context.
func TestWrapTransportAuditIfEnabledEnabled(t *testing.T) {
	t.Setenv(transportAuditEnv, "1")

	stub := &transportAuditStub{}
	wrapped := wrapTransportAuditIfEnabled(stub.transport(), "tlsfp")
	if _, ok := wrapped.(*transportAuditRoundTripper); !ok {
		t.Fatalf("expected wrapped transport when audit enabled")
	}

	sendTransportAuditRequest(t, wrapped)
	// Without this, a wrapper that never forwards would pass silently.
	if stub.calls == 0 {
		t.Fatalf("base transport was never called")
	}
	if !stub.sawTrace {
		t.Fatalf("expected client trace when audit is enabled")
	}
}

// TestTransportAuditEnabled checks how transportAuditEnabled interprets
// various values of the audit env var. The "empty" case exercises the
// variable being absent entirely.
func TestTransportAuditEnabled(t *testing.T) {
	cases := []struct {
		name string
		raw  string
		want bool
	}{
		{name: "empty", raw: "", want: false},
		{name: "one", raw: "1", want: true},
		{name: "true", raw: "true", want: true},
		{name: "on", raw: "on", want: true},
		{name: "no", raw: "no", want: false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if tc.raw == "" {
				// t.Setenv first so the original value (or its absence) is
				// restored at cleanup; a bare os.Unsetenv would leak the change
				// into every later test in the process.
				t.Setenv(transportAuditEnv, "")
				if err := os.Unsetenv(transportAuditEnv); err != nil {
					t.Fatalf("unset %s: %v", transportAuditEnv, err)
				}
			} else {
				t.Setenv(transportAuditEnv, tc.raw)
			}
			if got := transportAuditEnabled(); got != tc.want {
				t.Fatalf("transportAuditEnabled() = %v, want %v", got, tc.want)
			}
		})
	}
}