tokenFactory/model/supplier_pricing.go

486 lines
17 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package model
import (
"errors"
"fmt"
"strings"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"gorm.io/gorm"
)
// NormalizeOwnedModelsForPricing builds the set of model names used for the
// pricing permission check. For every non-blank owned name it contains both the
// trimmed name itself and the name returned by
// ratio_setting.FormatMatchingModelName for it. Blank names are skipped.
func NormalizeOwnedModelsForPricing(owned map[string]struct{}) map[string]struct{} {
	// Each owned name contributes at most two entries.
	allowed := make(map[string]struct{}, 2*len(owned))
	for rawName := range owned {
		trimmed := strings.TrimSpace(rawName)
		if trimmed != "" {
			allowed[trimmed] = struct{}{}
			allowed[ratio_setting.FormatMatchingModelName(trimmed)] = struct{}{}
		}
	}
	return allowed
}
// supplierPricingModelsJoined collects every model name that appears in any of
// the given per-category maps (used for batch persistence). Names are trimmed
// with strings.TrimSpace and blank names are dropped. The result is never nil.
func supplierPricingModelsJoined(maps map[string]map[string]float64) map[string]struct{} {
	names := map[string]struct{}{}
	for _, perModel := range maps {
		for rawName := range perModel {
			trimmed := strings.TrimSpace(rawName)
			if trimmed == "" {
				continue
			}
			names[trimmed] = struct{}{}
		}
	}
	return names
}
// UpsertSupplierModelPricingMaps merges the model -> value maps submitted by the
// frontend into supplier_model_pricings, replacing whatever is stored for the
// same supplier and model name.
//
// maps is keyed by pricing category ("ModelPrice", "ModelRatio",
// "CompletionRatio", "CacheRatio", "CreateCacheRatio", "ImageRatio",
// "AudioRatio", "AudioCompletionRatio"); each inner map goes from model name to
// value. Every submitted model name (after trimming) must be a key of
// ownedModels, otherwise an error is returned before anything is written.
//
// All writes happen in one transaction. For each model:
//   - no category carries a value: the stored row, if any, is deleted;
//   - no row exists yet: a new row is created;
//   - otherwise the existing row is overwritten via Save.
//
// QuotaType is 1 when a ModelPrice was submitted and 0 otherwise.
//
// NOTE(review): if two submitted names map to the same
// FormatMatchingModelName result they target the same row and are processed in
// map-iteration order, so which one wins is not deterministic — confirm whether
// callers can submit such pairs.
// NOTE(review): Save writes every column of the freshly built row; columns not
// populated here (e.g. creation timestamps, if the model has any) would be
// stored as zero values — verify against the struct definition.
func UpsertSupplierModelPricingMaps(supplierApplicationID int, userID int, maps map[string]map[string]float64, ownedModels map[string]struct{}) error {
	submitted := supplierPricingModelsJoined(maps)

	// Reject the whole request as soon as one model is not owned by the caller.
	for candidate := range submitted {
		if _, owned := ownedModels[candidate]; !owned {
			return fmt.Errorf("无权配置模型: %s", candidate)
		}
	}

	return DB.Transaction(func(tx *gorm.DB) error {
		for submittedName := range submitted {
			// Rows are stored under the normalized matching name.
			storedName := ratio_setting.FormatMatchingModelName(submittedName)

			// optional returns a pointer to the submitted value of one category,
			// or nil when that category has no entry for this model.
			optional := func(category string) *float64 {
				value, present := pickFloat(maps[category], submittedName)
				if !present {
					return nil
				}
				return floatPtr(value)
			}

			record := SupplierModelPricing{
				SupplierApplicationID: supplierApplicationID,
				ModelName:             storedName,
				UpdatedByUserID:       userID,
				ModelPrice:            optional("ModelPrice"),
				ModelRatio:            optional("ModelRatio"),
				CompletionRatio:       optional("CompletionRatio"),
				CacheRatio:            optional("CacheRatio"),
				CreateCacheRatio:      optional("CreateCacheRatio"),
				ImageRatio:            optional("ImageRatio"),
				AudioRatio:            optional("AudioRatio"),
				AudioCompletionRatio:  optional("AudioCompletionRatio"),
			}
			// A fixed price switches the row to per-call billing.
			if record.ModelPrice != nil {
				record.QuotaType = 1
			}

			// scope starts a fresh query restricted to this supplier/model row.
			scope := func() *gorm.DB {
				return tx.Where("supplier_application_id = ? AND model_name = ?", supplierApplicationID, storedName)
			}

			empty := true
			for _, override := range []*float64{
				record.ModelPrice, record.ModelRatio,
				record.CompletionRatio, record.CacheRatio, record.CreateCacheRatio,
				record.ImageRatio, record.AudioRatio, record.AudioCompletionRatio,
			} {
				if override != nil {
					empty = false
					break
				}
			}
			if empty {
				// Nothing left to store for this model: drop its row.
				if err := scope().Delete(&SupplierModelPricing{}).Error; err != nil {
					return err
				}
				continue
			}

			var current SupplierModelPricing
			switch lookupErr := scope().First(&current).Error; {
			case errors.Is(lookupErr, gorm.ErrRecordNotFound):
				if err := tx.Create(&record).Error; err != nil {
					return err
				}
			case lookupErr != nil:
				return lookupErr
			default:
				// Reuse the primary key so Save updates the existing row.
				record.ID = current.ID
				if err := tx.Save(&record).Error; err != nil {
					return err
				}
			}
		}
		return nil
	})
}
// UpsertSupplierChannelModelPricingMaps writes the submitted pricing maps for
// one supplier channel, replacing whatever is stored for the same supplier,
// channel and model name.
//
// maps is keyed by pricing category ("ModelPrice", "ModelRatio",
// "CompletionRatio", "CacheRatio", "CreateCacheRatio", "ImageRatio",
// "AudioRatio", "AudioCompletionRatio"); each inner map goes from model name to
// value. Every submitted model name (after trimming) must be a key of
// ownedModels, otherwise an error is returned before anything is written.
//
// All writes happen in one transaction. For each model:
//   - no category carries a value: the stored row, if any, is deleted;
//   - no row exists yet: a new row is created;
//   - otherwise the existing row is overwritten via Save.
//
// QuotaType is 1 when a ModelPrice was submitted and 0 otherwise.
//
// NOTE(review): if two submitted names map to the same
// FormatMatchingModelName result they target the same row and are processed in
// map-iteration order, so which one wins is not deterministic — confirm whether
// callers can submit such pairs.
// NOTE(review): Save writes every column of the freshly built row; columns not
// populated here (e.g. creation timestamps, if the model has any) would be
// stored as zero values — verify against the struct definition.
func UpsertSupplierChannelModelPricingMaps(supplierApplicationID int, channelID int, userID int, maps map[string]map[string]float64, ownedModels map[string]struct{}) error {
	submitted := supplierPricingModelsJoined(maps)

	// Reject the whole request as soon as one model is not owned by the caller.
	for candidate := range submitted {
		if _, owned := ownedModels[candidate]; !owned {
			return fmt.Errorf("无权配置模型: %s", candidate)
		}
	}

	return DB.Transaction(func(tx *gorm.DB) error {
		for submittedName := range submitted {
			// Rows are stored under the normalized matching name.
			storedName := ratio_setting.FormatMatchingModelName(submittedName)

			// optional returns a pointer to the submitted value of one category,
			// or nil when that category has no entry for this model.
			optional := func(category string) *float64 {
				value, present := pickFloat(maps[category], submittedName)
				if !present {
					return nil
				}
				return floatPtr(value)
			}

			record := SupplierChannelModelPricing{
				SupplierApplicationID: supplierApplicationID,
				ChannelID:             channelID,
				ModelName:             storedName,
				UpdatedByUserID:       userID,
				ModelPrice:            optional("ModelPrice"),
				ModelRatio:            optional("ModelRatio"),
				CompletionRatio:       optional("CompletionRatio"),
				CacheRatio:            optional("CacheRatio"),
				CreateCacheRatio:      optional("CreateCacheRatio"),
				ImageRatio:            optional("ImageRatio"),
				AudioRatio:            optional("AudioRatio"),
				AudioCompletionRatio:  optional("AudioCompletionRatio"),
			}
			// A fixed price switches the row to per-call billing.
			if record.ModelPrice != nil {
				record.QuotaType = 1
			}

			// scope starts a fresh query restricted to this supplier/channel/model row.
			scope := func() *gorm.DB {
				return tx.Where("supplier_application_id = ? AND channel_id = ? AND model_name = ?", supplierApplicationID, channelID, storedName)
			}

			empty := true
			for _, override := range []*float64{
				record.ModelPrice, record.ModelRatio,
				record.CompletionRatio, record.CacheRatio, record.CreateCacheRatio,
				record.ImageRatio, record.AudioRatio, record.AudioCompletionRatio,
			} {
				if override != nil {
					empty = false
					break
				}
			}
			if empty {
				// Nothing left to store for this model: drop its row.
				if err := scope().Delete(&SupplierChannelModelPricing{}).Error; err != nil {
					return err
				}
				continue
			}

			var current SupplierChannelModelPricing
			switch lookupErr := scope().First(&current).Error; {
			case errors.Is(lookupErr, gorm.ErrRecordNotFound):
				if err := tx.Create(&record).Error; err != nil {
					return err
				}
			case lookupErr != nil:
				return lookupErr
			default:
				// Reuse the primary key so Save updates the existing row.
				record.ID = current.ID
				if err := tx.Save(&record).Error; err != nil {
					return err
				}
			}
		}
		return nil
	})
}
// pickFloat looks up key in m and reports whether a value was found.
//
// An exact key match always wins. When there is none, the keys of m are
// compared after strings.TrimSpace, so a whitespace-padded key such as
// " gpt-4 " is still found under its trimmed name. This matters because the
// upsert helpers in this file collect model names through
// supplierPricingModelsJoined, which trims them, and then read the values back
// through pickFloat: without the fallback a padded key would look like
// "nothing submitted" and the stored pricing row would be deleted.
//
// If several padded keys trim to key, the lexicographically smallest raw key is
// used so the result does not depend on map iteration order.
// A nil or empty map yields (0, false).
func pickFloat(m map[string]float64, key string) (float64, bool) {
	// Reading from a nil map is safe and simply misses.
	if v, ok := m[key]; ok {
		return v, true
	}

	var (
		bestKey string
		bestVal float64
		found   bool
	)
	for k, v := range m {
		if strings.TrimSpace(k) != key {
			continue
		}
		if !found || k < bestKey {
			bestKey, bestVal, found = k, v, true
		}
	}
	return bestVal, found
}
// floatPtr returns a pointer to a fresh copy of v.
func floatPtr(v float64) *float64 {
	p := new(float64)
	*p = v
	return p
}
// GetSupplierModelPricingRow loads the supplier-wide pricing row for one model.
// model_name is stored as the ratio_setting.FormatMatchingModelName value, so
// rawModelName is normalized the same way before querying.
//
// It returns (nil, nil) when supplierApplicationID is not positive or when no
// row exists, and (nil, err) for any other query error.
func GetSupplierModelPricingRow(supplierApplicationID int, rawModelName string) (*SupplierModelPricing, error) {
	if supplierApplicationID <= 0 {
		return nil, nil
	}

	var found SupplierModelPricing
	lookup := DB.Where(
		"supplier_application_id = ? AND model_name = ?",
		supplierApplicationID,
		ratio_setting.FormatMatchingModelName(rawModelName),
	)
	switch err := lookup.First(&found).Error; {
	case err == nil:
		return &found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// A missing row is not an error for callers.
		return nil, nil
	default:
		return nil, err
	}
}
// GetSupplierChannelModelPricingRow loads the pricing row of one model for a
// specific supplier channel. rawModelName is normalized with
// ratio_setting.FormatMatchingModelName before querying.
//
// It returns (nil, nil) when supplierApplicationID or channelID is not positive
// or when no row exists, and (nil, err) for any other query error.
func GetSupplierChannelModelPricingRow(supplierApplicationID int, channelID int, rawModelName string) (*SupplierChannelModelPricing, error) {
	if supplierApplicationID <= 0 || channelID <= 0 {
		return nil, nil
	}

	var found SupplierChannelModelPricing
	lookup := DB.Where(
		"supplier_application_id = ? AND channel_id = ? AND model_name = ?",
		supplierApplicationID,
		channelID,
		ratio_setting.FormatMatchingModelName(rawModelName),
	)
	switch err := lookup.First(&found).Error; {
	case err == nil:
		return &found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// A missing row is not an error for callers.
		return nil, nil
	default:
		return nil, err
	}
}
// ListSupplierModelPricingsForSupplier returns every supplier-wide pricing row
// of the given supplier, ordered by model_name ascending (used by the GET API).
// The rows are returned together with the query error, if any.
func ListSupplierModelPricingsForSupplier(supplierApplicationID int) ([]SupplierModelPricing, error) {
	var pricings []SupplierModelPricing
	result := DB.
		Where("supplier_application_id = ?", supplierApplicationID).
		Order("model_name asc").
		Find(&pricings)
	return pricings, result.Error
}
// ListSupplierChannelModelPricings returns every pricing row stored for the
// given supplier and channel, ordered by model_name ascending.
// The rows are returned together with the query error, if any.
func ListSupplierChannelModelPricings(supplierApplicationID int, channelID int) ([]SupplierChannelModelPricing, error) {
	var pricings []SupplierChannelModelPricing
	result := DB.
		Where("supplier_application_id = ? AND channel_id = ?", supplierApplicationID, channelID).
		Order("model_name asc").
		Find(&pricings)
	return pricings, result.Error
}
// BuildOptionLikeMapsFromSupplierGlobalRows converts supplier-wide pricing rows
// into the category -> model -> value layout used by the Option JSON, so the
// frontend can reuse the same editor.
//
// The result always contains all eight categories, even when empty. A model
// appears in a category only if the corresponding column of its row is non-nil.
func BuildOptionLikeMapsFromSupplierGlobalRows(rows []SupplierModelPricing) map[string]map[string]float64 {
	categories := []string{
		"ModelPrice", "ModelRatio", "CompletionRatio", "CacheRatio",
		"CreateCacheRatio", "ImageRatio", "AudioRatio", "AudioCompletionRatio",
	}
	result := make(map[string]map[string]float64, len(categories))
	for _, category := range categories {
		result[category] = map[string]float64{}
	}

	// record stores value under category/model unless the column is unset.
	record := func(category, model string, value *float64) {
		if value != nil {
			result[category][model] = *value
		}
	}

	for i := range rows {
		pricing := &rows[i]
		model := pricing.ModelName
		record("ModelPrice", model, pricing.ModelPrice)
		record("ModelRatio", model, pricing.ModelRatio)
		record("CompletionRatio", model, pricing.CompletionRatio)
		record("CacheRatio", model, pricing.CacheRatio)
		record("CreateCacheRatio", model, pricing.CreateCacheRatio)
		record("ImageRatio", model, pricing.ImageRatio)
		record("AudioRatio", model, pricing.AudioRatio)
		record("AudioCompletionRatio", model, pricing.AudioCompletionRatio)
	}
	return result
}
// BuildOptionLikeMapsFromSupplierChannelRows converts supplier channel pricing
// rows into a category -> model -> value map.
//
// The result always contains all eight categories, even when empty. A model
// appears in a category only if the corresponding column of its row is non-nil.
func BuildOptionLikeMapsFromSupplierChannelRows(rows []SupplierChannelModelPricing) map[string]map[string]float64 {
	categories := []string{
		"ModelPrice", "ModelRatio", "CompletionRatio", "CacheRatio",
		"CreateCacheRatio", "ImageRatio", "AudioRatio", "AudioCompletionRatio",
	}
	result := make(map[string]map[string]float64, len(categories))
	for _, category := range categories {
		result[category] = map[string]float64{}
	}

	// record stores value under category/model unless the column is unset.
	record := func(category, model string, value *float64) {
		if value != nil {
			result[category][model] = *value
		}
	}

	for i := range rows {
		pricing := &rows[i]
		model := pricing.ModelName
		record("ModelPrice", model, pricing.ModelPrice)
		record("ModelRatio", model, pricing.ModelRatio)
		record("CompletionRatio", model, pricing.CompletionRatio)
		record("CacheRatio", model, pricing.CacheRatio)
		record("CreateCacheRatio", model, pricing.CreateCacheRatio)
		record("ImageRatio", model, pricing.ImageRatio)
		record("AudioRatio", model, pricing.AudioRatio)
		record("AudioCompletionRatio", model, pricing.AudioCompletionRatio)
	}
	return result
}
// ResolveSupplierScopedFixedModelPrice resolves the fixed per-call price
// ($/call) of a model. Sources are tried in this order and the first hit wins:
//
//  1. supplier channel pricing row (needs supplierApplicationID > 0 and channelID > 0)
//  2. supplier global pricing row (needs supplierApplicationID > 0)
//  3. ratio_setting.GetChannelModelPrice
//  4. ratio_setting.GetModelPrice
//  5. ratio_setting.GetSupplierModelPrice, the legacy supplier price
//     (needs supplierApplicationID > 0)
//
// It returns the price and true on a hit, or (0, false) when no source has one.
//
// The supplier global row is consulted whenever a supplier is given, not only
// when a channel is given as well, matching the image/audio/cache resolvers in
// this file.
func ResolveSupplierScopedFixedModelPrice(channelID int, supplierApplicationID int, modelName string) (float64, bool) {
	modelName = ratio_setting.FormatMatchingModelName(modelName)
	// Row lookup errors are deliberately ignored: a failed lookup yields a nil
	// row and resolution falls through to the next source.
	if supplierApplicationID > 0 && channelID > 0 {
		if row, _ := GetSupplierChannelModelPricingRow(supplierApplicationID, channelID, modelName); row != nil && row.ModelPrice != nil {
			return *row.ModelPrice, true
		}
	}
	if supplierApplicationID > 0 {
		if row, _ := GetSupplierModelPricingRow(supplierApplicationID, modelName); row != nil && row.ModelPrice != nil {
			return *row.ModelPrice, true
		}
	}
	if v, ok := ratio_setting.GetChannelModelPrice(channelID, modelName); ok {
		return v, true
	}
	if v, ok := ratio_setting.GetModelPrice(modelName, false); ok {
		return v, true
	}
	if supplierApplicationID > 0 {
		if v, ok := ratio_setting.GetSupplierModelPrice(supplierApplicationID, modelName); ok {
			return v, true
		}
	}
	return 0, false
}
// ResolveSupplierScopedModelRatio resolves the input ratio of a model. Sources
// are tried in this order and the first hit wins:
//
//  1. supplier channel pricing row (needs supplierApplicationID > 0 and channelID > 0)
//  2. supplier global pricing row (needs supplierApplicationID > 0)
//  3. ratio_setting.GetChannelModelRatio
//  4. ratio_setting.GetSupplierModelRatio, the legacy supplier ratio
//     (needs supplierApplicationID > 0)
//  5. ratio_setting.GetModelRatio, whose three results are returned unchanged
//
// For sources 1-4 the result is (ratio, true, "").
//
// The supplier global row is consulted whenever a supplier is given, not only
// when a channel is given as well, matching the image/audio/cache resolvers in
// this file.
func ResolveSupplierScopedModelRatio(channelID int, supplierApplicationID int, modelName string) (float64, bool, string) {
	modelName = ratio_setting.FormatMatchingModelName(modelName)
	// Row lookup errors are deliberately ignored: a failed lookup yields a nil
	// row and resolution falls through to the next source.
	if supplierApplicationID > 0 && channelID > 0 {
		if row, _ := GetSupplierChannelModelPricingRow(supplierApplicationID, channelID, modelName); row != nil && row.ModelRatio != nil {
			return *row.ModelRatio, true, ""
		}
	}
	if supplierApplicationID > 0 {
		if row, _ := GetSupplierModelPricingRow(supplierApplicationID, modelName); row != nil && row.ModelRatio != nil {
			return *row.ModelRatio, true, ""
		}
	}
	if v, ok := ratio_setting.GetChannelModelRatio(channelID, modelName); ok {
		return v, true, ""
	}
	if supplierApplicationID > 0 {
		if v, ok := ratio_setting.GetSupplierModelRatio(supplierApplicationID, modelName); ok {
			return v, true, ""
		}
	}
	return ratio_setting.GetModelRatio(modelName)
}
// ResolveSupplierScopedCompletionRatio resolves the output (completion) ratio of
// a model. Sources are tried in this order and the first hit wins:
//
//  1. supplier channel pricing row (needs supplierApplicationID > 0 and channelID > 0)
//  2. supplier global pricing row (needs supplierApplicationID > 0)
//  3. ratio_setting.GetChannelCompletionRatio
//  4. ratio_setting.GetCompletionRatio
//
// The supplier global row is consulted whenever a supplier is given, not only
// when a channel is given as well, matching the image/audio/cache resolvers in
// this file.
func ResolveSupplierScopedCompletionRatio(channelID int, supplierApplicationID int, modelName string) float64 {
	modelName = ratio_setting.FormatMatchingModelName(modelName)
	// Row lookup errors are deliberately ignored: a failed lookup yields a nil
	// row and resolution falls through to the next source.
	if supplierApplicationID > 0 && channelID > 0 {
		if row, _ := GetSupplierChannelModelPricingRow(supplierApplicationID, channelID, modelName); row != nil && row.CompletionRatio != nil {
			return *row.CompletionRatio
		}
	}
	if supplierApplicationID > 0 {
		if row, _ := GetSupplierModelPricingRow(supplierApplicationID, modelName); row != nil && row.CompletionRatio != nil {
			return *row.CompletionRatio
		}
	}
	if v, ok := ratio_setting.GetChannelCompletionRatio(channelID, modelName); ok {
		return v
	}
	return ratio_setting.GetCompletionRatio(modelName)
}
// ResolveSupplierScopedImageRatio resolves the image ratio of a model. The
// supplier channel row is preferred, then the supplier global row, then
// ratio_setting.GetChannelImageRatio, and finally ratio_setting.GetImageRatio,
// whose results are returned unchanged.
func ResolveSupplierScopedImageRatio(channelID int, supplierApplicationID int, modelName string) (float64, bool) {
	key := ratio_setting.FormatMatchingModelName(modelName)

	if supplierApplicationID > 0 {
		// Lookup errors are ignored on purpose: a nil row just means
		// "no override here" and the next source is tried.
		if channelID > 0 {
			channelRow, _ := GetSupplierChannelModelPricingRow(supplierApplicationID, channelID, key)
			if channelRow != nil && channelRow.ImageRatio != nil {
				return *channelRow.ImageRatio, true
			}
		}
		globalRow, _ := GetSupplierModelPricingRow(supplierApplicationID, key)
		if globalRow != nil && globalRow.ImageRatio != nil {
			return *globalRow.ImageRatio, true
		}
	}

	if ratio, found := ratio_setting.GetChannelImageRatio(channelID, key); found {
		return ratio, true
	}
	return ratio_setting.GetImageRatio(key)
}
// ResolveSupplierScopedAudioRatio resolves the audio input ratio of a model. The
// supplier channel row is preferred, then the supplier global row, then
// ratio_setting.GetChannelAudioRatio, and finally ratio_setting.GetAudioRatio.
func ResolveSupplierScopedAudioRatio(channelID int, supplierApplicationID int, modelName string) float64 {
	key := ratio_setting.FormatMatchingModelName(modelName)

	if supplierApplicationID > 0 {
		// Lookup errors are ignored on purpose: a nil row just means
		// "no override here" and the next source is tried.
		if channelID > 0 {
			channelRow, _ := GetSupplierChannelModelPricingRow(supplierApplicationID, channelID, key)
			if channelRow != nil && channelRow.AudioRatio != nil {
				return *channelRow.AudioRatio
			}
		}
		globalRow, _ := GetSupplierModelPricingRow(supplierApplicationID, key)
		if globalRow != nil && globalRow.AudioRatio != nil {
			return *globalRow.AudioRatio
		}
	}

	if ratio, found := ratio_setting.GetChannelAudioRatio(channelID, key); found {
		return ratio
	}
	return ratio_setting.GetAudioRatio(key)
}
// ResolveSupplierScopedAudioCompletionRatio resolves the audio output ratio
// (relative to audio input) of a model. The supplier channel row is preferred,
// then the supplier global row, then
// ratio_setting.GetChannelAudioCompletionRatio, and finally
// ratio_setting.GetAudioCompletionRatio.
func ResolveSupplierScopedAudioCompletionRatio(channelID int, supplierApplicationID int, modelName string) float64 {
	key := ratio_setting.FormatMatchingModelName(modelName)

	if supplierApplicationID > 0 {
		// Lookup errors are ignored on purpose: a nil row just means
		// "no override here" and the next source is tried.
		if channelID > 0 {
			channelRow, _ := GetSupplierChannelModelPricingRow(supplierApplicationID, channelID, key)
			if channelRow != nil && channelRow.AudioCompletionRatio != nil {
				return *channelRow.AudioCompletionRatio
			}
		}
		globalRow, _ := GetSupplierModelPricingRow(supplierApplicationID, key)
		if globalRow != nil && globalRow.AudioCompletionRatio != nil {
			return *globalRow.AudioCompletionRatio
		}
	}

	if ratio, found := ratio_setting.GetChannelAudioCompletionRatio(channelID, key); found {
		return ratio
	}
	return ratio_setting.GetAudioCompletionRatio(key)
}
// ResolveSupplierScopedCacheRatios resolves the cache read ratio and the cache
// creation ratio of a model. Each of the two values is resolved independently
// with the same priority: supplier channel row, supplier global row, the
// channel-level ratio_setting getter, then the platform-level ratio_setting
// getter.
//
// Both supplier rows are loaded up front (when the ids allow it) and shared by
// the two resolutions.
func ResolveSupplierScopedCacheRatios(channelID int, supplierApplicationID int, modelName string) (cacheRatio float64, createCacheRatio float64) {
	key := ratio_setting.FormatMatchingModelName(modelName)

	// Lookup errors are ignored on purpose: a nil row just means
	// "no override here" and the next source is tried.
	var (
		channelRow *SupplierChannelModelPricing
		globalRow  *SupplierModelPricing
	)
	if supplierApplicationID > 0 {
		if channelID > 0 {
			channelRow, _ = GetSupplierChannelModelPricingRow(supplierApplicationID, channelID, key)
		}
		globalRow, _ = GetSupplierModelPricingRow(supplierApplicationID, key)
	}

	switch {
	case channelRow != nil && channelRow.CacheRatio != nil:
		cacheRatio = *channelRow.CacheRatio
	case globalRow != nil && globalRow.CacheRatio != nil:
		cacheRatio = *globalRow.CacheRatio
	default:
		if ratio, found := ratio_setting.GetChannelCacheRatio(channelID, key); found {
			cacheRatio = ratio
		} else {
			cacheRatio, _ = ratio_setting.GetCacheRatio(key)
		}
	}

	switch {
	case channelRow != nil && channelRow.CreateCacheRatio != nil:
		createCacheRatio = *channelRow.CreateCacheRatio
	case globalRow != nil && globalRow.CreateCacheRatio != nil:
		createCacheRatio = *globalRow.CreateCacheRatio
	default:
		if ratio, found := ratio_setting.GetChannelCreateCacheRatio(channelID, key); found {
			createCacheRatio = ratio
		} else {
			createCacheRatio, _ = ratio_setting.GetCreateCacheRatio(key)
		}
	}

	return cacheRatio, createCacheRatio
}