fix(service): normalize user agent for sticky session hashes
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
parent
13b72f6bc2
commit
bcf84cc153
@ -5,6 +5,8 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
"regexp"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
@ -34,6 +36,9 @@ var (
|
|||||||
patternEmptyTextSpaced = []byte(`"text": ""`)
|
patternEmptyTextSpaced = []byte(`"text": ""`)
|
||||||
patternEmptyTextSp1 = []byte(`"text" : ""`)
|
patternEmptyTextSp1 = []byte(`"text" : ""`)
|
||||||
patternEmptyTextSp2 = []byte(`"text" :""`)
|
patternEmptyTextSp2 = []byte(`"text" :""`)
|
||||||
|
|
||||||
|
sessionUserAgentProductPattern = regexp.MustCompile(`([A-Za-z0-9._-]+)/[A-Za-z0-9._-]+`)
|
||||||
|
sessionUserAgentVersionPattern = regexp.MustCompile(`\bv?\d+(?:\.\d+){1,3}\b`)
|
||||||
)
|
)
|
||||||
|
|
||||||
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
|
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
|
||||||
@ -75,6 +80,49 @@ type ParsedRequest struct {
|
|||||||
OnUpstreamAccepted func()
|
OnUpstreamAccepted func()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NormalizeSessionUserAgent reduces UA noise for sticky-session and digest hashing.
|
||||||
|
// It preserves the set of product names from Product/Version tokens while
|
||||||
|
// discarding version-only changes and incidental comments.
|
||||||
|
func NormalizeSessionUserAgent(raw string) string {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
matches := sessionUserAgentProductPattern.FindAllStringSubmatch(raw, -1)
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return normalizeSessionUserAgentFallback(raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
products := make([]string, 0, len(matches))
|
||||||
|
seen := make(map[string]struct{}, len(matches))
|
||||||
|
for _, match := range matches {
|
||||||
|
if len(match) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
product := strings.ToLower(strings.TrimSpace(match[1]))
|
||||||
|
if product == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := seen[product]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[product] = struct{}{}
|
||||||
|
products = append(products, product)
|
||||||
|
}
|
||||||
|
if len(products) == 0 {
|
||||||
|
return normalizeSessionUserAgentFallback(raw)
|
||||||
|
}
|
||||||
|
sort.Strings(products)
|
||||||
|
return strings.Join(products, "+")
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeSessionUserAgentFallback(raw string) string {
|
||||||
|
normalized := strings.ToLower(strings.Join(strings.Fields(raw), " "))
|
||||||
|
normalized = sessionUserAgentVersionPattern.ReplaceAllString(normalized, "")
|
||||||
|
return strings.Join(strings.Fields(normalized), " ")
|
||||||
|
}
|
||||||
|
|
||||||
// ParseGatewayRequest 解析网关请求体并返回结构化结果。
|
// ParseGatewayRequest 解析网关请求体并返回结构化结果。
|
||||||
// protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini),
|
// protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini),
|
||||||
// 不同协议使用不同的 system/messages 字段名。
|
// 不同协议使用不同的 system/messages 字段名。
|
||||||
|
|||||||
@ -658,7 +658,7 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
|||||||
if parsed.SessionContext != nil {
|
if parsed.SessionContext != nil {
|
||||||
_, _ = combined.WriteString(parsed.SessionContext.ClientIP)
|
_, _ = combined.WriteString(parsed.SessionContext.ClientIP)
|
||||||
_, _ = combined.WriteString(":")
|
_, _ = combined.WriteString(":")
|
||||||
_, _ = combined.WriteString(parsed.SessionContext.UserAgent)
|
_, _ = combined.WriteString(NormalizeSessionUserAgent(parsed.SessionContext.UserAgent))
|
||||||
_, _ = combined.WriteString(":")
|
_, _ = combined.WriteString(":")
|
||||||
_, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10))
|
_, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10))
|
||||||
_, _ = combined.WriteString("|")
|
_, _ = combined.WriteString("|")
|
||||||
|
|||||||
@ -504,6 +504,48 @@ func TestGenerateSessionHash_SessionContext_UADifference(t *testing.T) {
|
|||||||
require.NotEqual(t, h1, h2, "different User-Agent should produce different hash")
|
require.NotEqual(t, h1, h2, "different User-Agent should produce different hash")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGenerateSessionHash_SessionContext_UAVersionNoiseIgnored(t *testing.T) {
|
||||||
|
svc := &GatewayService{}
|
||||||
|
|
||||||
|
base := func(ua string) *ParsedRequest {
|
||||||
|
return &ParsedRequest{
|
||||||
|
Messages: []any{
|
||||||
|
map[string]any{"role": "user", "content": "test"},
|
||||||
|
},
|
||||||
|
SessionContext: &SessionContext{
|
||||||
|
ClientIP: "1.1.1.1",
|
||||||
|
UserAgent: ua,
|
||||||
|
APIKeyID: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h1 := svc.GenerateSessionHash(base("Mozilla/5.0 codex_cli_rs/0.1.0"))
|
||||||
|
h2 := svc.GenerateSessionHash(base("Mozilla/5.0 codex_cli_rs/0.1.1"))
|
||||||
|
require.Equal(t, h1, h2, "version-only User-Agent changes should not perturb the sticky session hash")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateSessionHash_SessionContext_FreeformUAVersionNoiseIgnored(t *testing.T) {
|
||||||
|
svc := &GatewayService{}
|
||||||
|
|
||||||
|
base := func(ua string) *ParsedRequest {
|
||||||
|
return &ParsedRequest{
|
||||||
|
Messages: []any{
|
||||||
|
map[string]any{"role": "user", "content": "test"},
|
||||||
|
},
|
||||||
|
SessionContext: &SessionContext{
|
||||||
|
ClientIP: "1.1.1.1",
|
||||||
|
UserAgent: ua,
|
||||||
|
APIKeyID: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h1 := svc.GenerateSessionHash(base("Codex CLI 0.1.0"))
|
||||||
|
h2 := svc.GenerateSessionHash(base("Codex CLI 0.1.1"))
|
||||||
|
require.Equal(t, h1, h2, "free-form version-only User-Agent changes should not perturb the sticky session hash")
|
||||||
|
}
|
||||||
|
|
||||||
func TestGenerateSessionHash_SessionContext_APIKeyIDDifference(t *testing.T) {
|
func TestGenerateSessionHash_SessionContext_APIKeyIDDifference(t *testing.T) {
|
||||||
svc := &GatewayService{}
|
svc := &GatewayService{}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user