package helper

import (
	"bytes"
	"encoding/json"
	"fmt"
	"math"
	"strings"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/logger"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/service"
	"github.com/gin-gonic/gin"
)

// FinalizeImagePerImageBilling adjusts billing from the upstream image response:
// actual image count and optional resolution inferred from output/input images.
//
// It does nothing unless info carries ImageBilling state, fixed-price billing
// (PriceData.UsePrice) is active, and HasImageGenerationPricing reports pricing
// for the channel/model pair.
//
// The billed count is the number of entries in responseBody that carry a
// non-blank url or b64_json. If there are none, it falls back to request.N,
// then to the "n" entry of PriceData.OtherRatios (rounded), then to 1; the
// result is never below 1. Dimensions are taken, in order, from request.Size,
// the response images, and the request's input images; when none resolve,
// the estimate context keeps its own dimensions.
//
// When SyncImagePerImagePriceData returns false, only the "n" ratio and the
// ImageBilling bookkeeping fields (Width, Height, Count, Mode) are updated.
func FinalizeImagePerImageBilling(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ImageRequest, responseBody []byte) {
	if info == nil || info.ImageBilling == nil || !info.PriceData.UsePrice {
		return
	}
	channelID := 0
	if info.ChannelMeta != nil {
		channelID = info.ChannelId
	}
	modelName := info.OriginModelName
	if !HasImageGenerationPricing(channelID, modelName) {
		return
	}

	// Decode the response once. It may embed multi-megabyte base64 images,
	// and both the count and the dimension lookup need the decoded form.
	imageResp := parseImageResponseBody(responseBody)

	actualCount := countImagesInResponse(imageResp)
	if actualCount <= 0 {
		if request != nil && request.N != nil && *request.N > 0 {
			actualCount = int(*request.N)
		} else if n, ok := info.PriceData.OtherRatios["n"]; ok && n > 0 {
			actualCount = int(math.Round(n))
		}
		// A fractional ratio in (0, 0.5) rounds to 0; never bill zero images.
		if actualCount <= 0 {
			actualCount = 1
		}
	}

	estimateCtx := estimateImageRequestContext(c, info)
	estimateCtx.Count = actualCount
	if w, h, ok := resolveImageDimensionsFromResponse(c, request, imageResp); ok {
		estimateCtx.Width = w
		estimateCtx.Height = h
	}

	if !SyncImagePerImagePriceData(c, info, estimateCtx) {
		info.PriceData.AddOtherRatio("n", float64(actualCount))
		if info.ImageBilling != nil {
			info.ImageBilling.Width = estimateCtx.Width
			info.ImageBilling.Height = estimateCtx.Height
			info.ImageBilling.Count = actualCount
			info.ImageBilling.Mode = string(estimateCtx.Mode)
		}
		return
	}

	if common.DebugEnabled && info.ImageBilling != nil {
		logger.LogDebug(c, fmt.Sprintf(
			"[image][finalize] model=%s mode=%s w=%d h=%d actualCount=%d channelUSD=%.6f globalUSD=%.6f effUSD=%.6f quota=%d",
			modelName, estimateCtx.Mode, estimateCtx.Width, estimateCtx.Height, actualCount,
			info.PriceData.ModelPrice, info.PriceData.GlobalModelPrice, info.ImageBilling.UsdPerImage, info.PriceData.Quota,
		))
	}
}

// parseImageResponseBody decodes body as a dto.ImageResponse. It returns nil
// when body is empty/whitespace-only or does not decode into that type.
func parseImageResponseBody(body []byte) *dto.ImageResponse {
	body = bytes.TrimSpace(body)
	if len(body) == 0 {
		return nil
	}
	var imageResp dto.ImageResponse
	if err := common.Unmarshal(body, &imageResp); err != nil {
		return nil
	}
	return &imageResp
}

// countImagesInResponseBody returns how many entries of the decoded image
// response carry a non-blank url or b64_json. Undecodable or empty bodies
// count as 0.
func countImagesInResponseBody(body []byte) int {
	return countImagesInResponse(parseImageResponseBody(body))
}

// countImagesInResponse counts entries of resp.Data with a non-blank Url or
// B64Json. A nil resp counts as 0.
func countImagesInResponse(resp *dto.ImageResponse) int {
	if resp == nil {
		return 0
	}
	count := 0
	for i := range resp.Data {
		item := &resp.Data[i]
		if strings.TrimSpace(item.Url) != "" || strings.TrimSpace(item.B64Json) != "" {
			count++
		}
	}
	return count
}

// resolveImageDimensions infers output width/height, trying request.Size,
// then the images in responseBody, then the request's input images.
func resolveImageDimensions(c *gin.Context, request *dto.ImageRequest, responseBody []byte) (int, int, bool) {
	return resolveImageDimensionsFromResponse(c, request, parseImageResponseBody(responseBody))
}

// resolveImageDimensionsFromResponse is resolveImageDimensions on an
// already-decoded response (resp may be nil). The lookup order is
// request.Size, the response images, then the request's input image URLs;
// the first source that yields positive dimensions wins.
func resolveImageDimensionsFromResponse(c *gin.Context, request *dto.ImageRequest, resp *dto.ImageResponse) (int, int, bool) {
	if request != nil {
		if w, h, ok := parseResolutionFlexible(request.Size); ok {
			return w, h, true
		}
	}
	if w, h, ok := dimensionsFromImageResponse(c, resp); ok {
		return w, h, true
	}
	if request != nil {
		// NOTE(review): these URLs come from the client request; this relies on
		// service.DecodeUrlImageData applying whatever fetch restrictions the
		// project uses for user-supplied image URLs — confirm.
		for _, u := range extractImageInputURLs(request.Image) {
			if w, h, ok := decodeImageURLDimensions(u); ok {
				return w, h, true
			}
		}
	}
	return 0, 0, false
}

// dimensionsFromImageResponseBody returns the dimensions of the first image
// in body whose size can be determined.
func dimensionsFromImageResponseBody(c *gin.Context, body []byte) (int, int, bool) {
	return dimensionsFromImageResponse(c, parseImageResponseBody(body))
}

// dimensionsFromImageResponse returns the dimensions of the first entry of
// resp.Data whose size can be determined. A nil resp yields false.
func dimensionsFromImageResponse(c *gin.Context, resp *dto.ImageResponse) (int, int, bool) {
	if resp == nil {
		return 0, 0, false
	}
	for i := range resp.Data {
		if w, h, ok := dimensionsFromImageData(c, resp.Data[i]); ok {
			return w, h, true
		}
	}
	return 0, 0, false
}

// dimensionsFromImageData determines an image's size from its Url. If the
// Url is blank or cannot be decoded, it falls back to the B64Json payload.
// It only reports success when both dimensions are positive.
func dimensionsFromImageData(c *gin.Context, item dto.ImageData) (int, int, bool) {
	_ = c // unused; kept so the signature stays stable for callers
	if w, h, ok := decodeImageURLDimensions(item.Url); ok {
		return w, h, true
	}
	if b64 := strings.TrimSpace(item.B64Json); b64 != "" {
		if cfg, _, _, err := service.DecodeBase64ImageData(b64); err == nil && cfg.Width > 0 && cfg.Height > 0 {
			return cfg.Width, cfg.Height, true
		}
	}
	return 0, 0, false
}

// decodeImageURLDimensions returns the width/height reported by
// service.DecodeUrlImageData for url. Blank URLs, decode errors, and
// non-positive dimensions all yield false.
// NOTE(review): for remote URLs this presumably downloads the image — verify
// the cost/timeout behavior of service.DecodeUrlImageData.
func decodeImageURLDimensions(url string) (int, int, bool) {
	url = strings.TrimSpace(url)
	if url == "" {
		return 0, 0, false
	}
	cfg, _, err := service.DecodeUrlImageData(url)
	if err != nil || cfg.Width <= 0 || cfg.Height <= 0 {
		return 0, 0, false
	}
	return cfg.Width, cfg.Height, true
}

// extractImageInputURLs pulls image references out of a request's raw
// "image" field. It accepts a JSON string, an array of strings, or an array
// of objects with a "url" key. Entries are whitespace-trimmed and blanks are
// dropped. It returns nil for empty/null input or any other shape.
func extractImageInputURLs(raw json.RawMessage) []string {
	if len(raw) == 0 {
		return nil
	}
	trimmed := bytes.TrimSpace(raw)
	if len(trimmed) == 0 || string(trimmed) == "null" {
		return nil
	}

	var single string
	if err := common.Unmarshal(raw, &single); err == nil {
		if single = strings.TrimSpace(single); single != "" {
			return []string{single}
		}
	}

	var list []string
	if err := common.Unmarshal(raw, &list); err == nil {
		out := make([]string, 0, len(list))
		for _, item := range list {
			if item = strings.TrimSpace(item); item != "" {
				out = append(out, item)
			}
		}
		if len(out) > 0 {
			return out
		}
	}

	var objects []struct {
		URL string `json:"url"`
	}
	if err := common.Unmarshal(raw, &objects); err == nil {
		out := make([]string, 0, len(objects))
		for _, obj := range objects {
			if u := strings.TrimSpace(obj.URL); u != "" {
				out = append(out, u)
			}
		}
		if len(out) > 0 {
			return out
		}
	}
	return nil
}

// bytesTrimSpace trims leading and trailing Unicode whitespace from b.
// The result is a subslice of b (no copy); callers must not mutate it if
// they still need b intact.
func bytesTrimSpace(b []byte) []byte {
	return bytes.TrimSpace(b)
}