200 lines
5.1 KiB
Go
200 lines
5.1 KiB
Go
package helper
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"fmt"
	"math"
	"strings"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/logger"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/service"

	"github.com/gin-gonic/gin"
)
|
|
|
|
// FinalizeImagePerImageBilling adjusts billing from the upstream image response:
|
|
// actual image count and optional resolution inferred from output/input images.
|
|
func FinalizeImagePerImageBilling(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ImageRequest, responseBody []byte) {
|
|
if info == nil || info.ImageBilling == nil || !info.PriceData.UsePrice {
|
|
return
|
|
}
|
|
|
|
channelID := 0
|
|
if info.ChannelMeta != nil {
|
|
channelID = info.ChannelId
|
|
}
|
|
modelName := info.OriginModelName
|
|
if !HasImageGenerationPricing(channelID, modelName) {
|
|
return
|
|
}
|
|
|
|
actualCount := countImagesInResponseBody(responseBody)
|
|
if actualCount <= 0 {
|
|
if request != nil && request.N != nil && *request.N > 0 {
|
|
actualCount = int(*request.N)
|
|
} else if n, ok := info.PriceData.OtherRatios["n"]; ok && n > 0 {
|
|
actualCount = int(math.Round(n))
|
|
} else {
|
|
actualCount = 1
|
|
}
|
|
}
|
|
|
|
estimateCtx := estimateImageRequestContext(c, info)
|
|
estimateCtx.Count = actualCount
|
|
if w, h, ok := resolveImageDimensions(c, request, responseBody); ok {
|
|
estimateCtx.Width = w
|
|
estimateCtx.Height = h
|
|
}
|
|
|
|
if !SyncImagePerImagePriceData(c, info, estimateCtx) {
|
|
info.PriceData.AddOtherRatio("n", float64(actualCount))
|
|
if info.ImageBilling != nil {
|
|
info.ImageBilling.Width = estimateCtx.Width
|
|
info.ImageBilling.Height = estimateCtx.Height
|
|
info.ImageBilling.Count = actualCount
|
|
info.ImageBilling.Mode = string(estimateCtx.Mode)
|
|
}
|
|
return
|
|
}
|
|
|
|
if common.DebugEnabled && info.ImageBilling != nil {
|
|
logger.LogDebug(c, fmt.Sprintf(
|
|
"[image][finalize] model=%s mode=%s w=%d h=%d actualCount=%d channelUSD=%.6f globalUSD=%.6f effUSD=%.6f quota=%d",
|
|
modelName, estimateCtx.Mode, estimateCtx.Width, estimateCtx.Height, actualCount,
|
|
info.PriceData.ModelPrice, info.PriceData.GlobalModelPrice, info.ImageBilling.UsdPerImage, info.PriceData.Quota,
|
|
))
|
|
}
|
|
}
|
|
|
|
func countImagesInResponseBody(body []byte) int {
|
|
body = bytesTrimSpace(body)
|
|
if len(body) == 0 {
|
|
return 0
|
|
}
|
|
var imageResp dto.ImageResponse
|
|
if err := common.Unmarshal(body, &imageResp); err != nil {
|
|
return 0
|
|
}
|
|
count := 0
|
|
for _, item := range imageResp.Data {
|
|
if strings.TrimSpace(item.Url) != "" || strings.TrimSpace(item.B64Json) != "" {
|
|
count++
|
|
}
|
|
}
|
|
return count
|
|
}
|
|
|
|
func resolveImageDimensions(c *gin.Context, request *dto.ImageRequest, responseBody []byte) (int, int, bool) {
|
|
if request != nil {
|
|
if w, h, ok := parseResolutionFlexible(request.Size); ok {
|
|
return w, h, true
|
|
}
|
|
}
|
|
if w, h, ok := dimensionsFromImageResponseBody(c, responseBody); ok {
|
|
return w, h, true
|
|
}
|
|
if request != nil {
|
|
for _, url := range extractImageInputURLs(request.Image) {
|
|
if w, h, ok := decodeImageURLDimensions(url); ok {
|
|
return w, h, true
|
|
}
|
|
}
|
|
}
|
|
return 0, 0, false
|
|
}
|
|
|
|
func dimensionsFromImageResponseBody(c *gin.Context, body []byte) (int, int, bool) {
|
|
body = bytesTrimSpace(body)
|
|
if len(body) == 0 {
|
|
return 0, 0, false
|
|
}
|
|
var imageResp dto.ImageResponse
|
|
if err := common.Unmarshal(body, &imageResp); err != nil {
|
|
return 0, 0, false
|
|
}
|
|
for _, item := range imageResp.Data {
|
|
if w, h, ok := dimensionsFromImageData(c, item); ok {
|
|
return w, h, true
|
|
}
|
|
}
|
|
return 0, 0, false
|
|
}
|
|
|
|
func dimensionsFromImageData(c *gin.Context, item dto.ImageData) (int, int, bool) {
|
|
if url := strings.TrimSpace(item.Url); url != "" {
|
|
return decodeImageURLDimensions(url)
|
|
}
|
|
if b64 := strings.TrimSpace(item.B64Json); b64 != "" {
|
|
if cfg, _, _, err := service.DecodeBase64ImageData(b64); err == nil && cfg.Width > 0 && cfg.Height > 0 {
|
|
return cfg.Width, cfg.Height, true
|
|
}
|
|
}
|
|
_ = c
|
|
return 0, 0, false
|
|
}
|
|
|
|
func decodeImageURLDimensions(url string) (int, int, bool) {
|
|
url = strings.TrimSpace(url)
|
|
if url == "" {
|
|
return 0, 0, false
|
|
}
|
|
cfg, _, err := service.DecodeUrlImageData(url)
|
|
if err != nil || cfg.Width <= 0 || cfg.Height <= 0 {
|
|
return 0, 0, false
|
|
}
|
|
return cfg.Width, cfg.Height, true
|
|
}
|
|
|
|
func extractImageInputURLs(raw json.RawMessage) []string {
|
|
if len(raw) == 0 {
|
|
return nil
|
|
}
|
|
s := strings.TrimSpace(string(raw))
|
|
if s == "" || s == "null" {
|
|
return nil
|
|
}
|
|
var single string
|
|
if err := common.Unmarshal(raw, &single); err == nil {
|
|
single = strings.TrimSpace(single)
|
|
if single != "" {
|
|
return []string{single}
|
|
}
|
|
}
|
|
var list []string
|
|
if err := common.Unmarshal(raw, &list); err == nil {
|
|
out := make([]string, 0, len(list))
|
|
for _, item := range list {
|
|
item = strings.TrimSpace(item)
|
|
if item != "" {
|
|
out = append(out, item)
|
|
}
|
|
}
|
|
if len(out) > 0 {
|
|
return out
|
|
}
|
|
}
|
|
var objects []struct {
|
|
URL string `json:"url"`
|
|
}
|
|
if err := common.Unmarshal(raw, &objects); err == nil {
|
|
out := make([]string, 0, len(objects))
|
|
for _, obj := range objects {
|
|
u := strings.TrimSpace(obj.URL)
|
|
if u != "" {
|
|
out = append(out, u)
|
|
}
|
|
}
|
|
if len(out) > 0 {
|
|
return out
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// bytesTrimSpace returns b without leading and trailing white space, using the
// same definition of white space (unicode.IsSpace) as strings.TrimSpace.
//
// The previous implementation converted []byte -> string -> []byte, copying
// the whole input twice. That was wasteful for upstream image responses, which
// may carry large base64 payloads. bytes.TrimSpace does no allocation.
//
// The result is a sub-slice that shares memory with b; callers must not
// modify it if b must stay unchanged. An all-space input may return nil;
// check emptiness with len, not a nil comparison.
func bytesTrimSpace(b []byte) []byte {
	return bytes.TrimSpace(b)
}
|