sub2api/backend/internal/service/image_output_accounting.go
2026-05-12 15:21:31 +08:00

205 lines
4.9 KiB
Go

package service
import (
"crypto/sha256"
"encoding/hex"
"strings"
"github.com/tidwall/gjson"
)
type openAIImageOutputCounter struct {
seen map[string]struct{}
seenSizes map[string]string
seenOrder []string
dataSizes []string
count int
maxDataCount int
}
func newOpenAIImageOutputCounter() *openAIImageOutputCounter {
return &openAIImageOutputCounter{
seen: make(map[string]struct{}),
seenSizes: make(map[string]string),
}
}
func (c *openAIImageOutputCounter) Count() int {
if c == nil {
return 0
}
if c.maxDataCount > c.count {
return c.maxDataCount
}
return c.count
}
func (c *openAIImageOutputCounter) Sizes() []string {
if c == nil {
return nil
}
sizes := make([]string, 0, len(c.seenOrder)+len(c.dataSizes))
for _, key := range c.seenOrder {
if size := strings.TrimSpace(c.seenSizes[key]); size != "" {
sizes = append(sizes, size)
}
}
if len(sizes) == 0 && len(c.dataSizes) > 0 {
sizes = append(sizes, c.dataSizes...)
}
if len(sizes) == 0 {
return nil
}
return sizes
}
func (c *openAIImageOutputCounter) AddJSONResponse(body []byte) {
if c == nil || len(body) == 0 || !gjson.ValidBytes(body) {
return
}
c.addDataArray(gjson.GetBytes(body, "data"))
c.addOutputArray(gjson.GetBytes(body, "output"))
c.addOutputArray(gjson.GetBytes(body, "response.output"))
}
func (c *openAIImageOutputCounter) AddSSEData(data []byte) {
if c == nil || len(data) == 0 || strings.TrimSpace(string(data)) == "[DONE]" || !gjson.ValidBytes(data) {
return
}
root := gjson.ParseBytes(data)
c.addDataArray(root.Get("data"))
eventType := strings.TrimSpace(root.Get("type").String())
switch eventType {
case "response.output_item.done":
c.addImageOutputItem(root.Get("item"))
case "response.completed", "response.done":
c.addOutputArray(root.Get("response.output"))
case "image_generation.completed":
if item := root.Get("item"); item.Exists() {
c.addImageOutputItem(item)
return
}
if output := root.Get("output"); output.Exists() {
c.addImageOutputItem(output)
return
}
c.addImageOutputItem(root)
}
}
func (c *openAIImageOutputCounter) AddSSEBody(body string) {
if c == nil || strings.TrimSpace(body) == "" {
return
}
forEachOpenAISSEDataPayload(body, c.AddSSEData)
}
func (c *openAIImageOutputCounter) addDataArray(data gjson.Result) {
if !data.IsArray() {
return
}
items := data.Array()
count := len(items)
if count > c.maxDataCount {
c.maxDataCount = count
}
sizes := make([]string, 0, len(items))
for _, item := range items {
if size := strings.TrimSpace(item.Get("size").String()); size != "" {
sizes = append(sizes, size)
}
}
if len(sizes) > 0 {
c.dataSizes = sizes
}
}
func (c *openAIImageOutputCounter) addOutputArray(output gjson.Result) {
if !output.IsArray() {
return
}
output.ForEach(func(_, item gjson.Result) bool {
c.addImageOutputItem(item)
return true
})
}
func (c *openAIImageOutputCounter) addImageOutputItem(item gjson.Result) {
if !item.Exists() || !item.IsObject() {
return
}
itemType := strings.TrimSpace(item.Get("type").String())
if itemType != "" && itemType != "image_generation_call" && itemType != "image_generation.completed" {
return
}
if strings.Contains(strings.ToLower(item.Raw), "partial_image") {
return
}
result := strings.TrimSpace(item.Get("result").String())
if result == "" {
result = strings.TrimSpace(item.Get("b64_json").String())
}
if result == "" {
result = strings.TrimSpace(item.Get("url").String())
}
if result == "" && itemType != "image_generation.completed" {
return
}
key := strings.TrimSpace(item.Get("id").String())
if key == "" {
key = strings.TrimSpace(item.Get("call_id").String())
}
if key == "" {
key = hashOpenAIImageOutputResult(result)
}
if key == "" {
return
}
size := strings.TrimSpace(item.Get("size").String())
if _, exists := c.seen[key]; exists {
if size != "" && strings.TrimSpace(c.seenSizes[key]) == "" {
c.seenSizes[key] = size
}
return
}
c.seen[key] = struct{}{}
c.seenOrder = append(c.seenOrder, key)
if size != "" {
c.seenSizes[key] = size
}
c.count++
}
func hashOpenAIImageOutputResult(result string) string {
result = strings.TrimSpace(result)
if result == "" {
return ""
}
sum := sha256.Sum256([]byte(result))
return hex.EncodeToString(sum[:])
}
func countOpenAIResponseImageOutputsFromJSONBytes(body []byte) int {
counter := newOpenAIImageOutputCounter()
counter.AddJSONResponse(body)
return counter.Count()
}
func collectOpenAIResponseImageOutputSizesFromJSONBytes(body []byte) []string {
counter := newOpenAIImageOutputCounter()
counter.AddJSONResponse(body)
return counter.Sizes()
}
func countOpenAIImageOutputsFromSSEBody(body string) int {
counter := newOpenAIImageOutputCounter()
counter.AddSSEBody(body)
return counter.Count()
}
func collectOpenAIImageOutputSizesFromSSEBody(body string) []string {
counter := newOpenAIImageOutputCounter()
counter.AddSSEBody(body)
return counter.Sizes()
}