486 lines
17 KiB
Go
486 lines
17 KiB
Go
package model
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"strings"
|
||
|
||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// supplierPricingModelsJoined 收集一组映射中出现的全部模型名(用于批量入库)。
|
||
// NormalizeOwnedModelsForPricing 将自有模型名与 FormatMatching 规范化名一并纳入权限校验集合。
|
||
func NormalizeOwnedModelsForPricing(owned map[string]struct{}) map[string]struct{} {
|
||
out := make(map[string]struct{}, len(owned)*2)
|
||
for k := range owned {
|
||
k = strings.TrimSpace(k)
|
||
if k == "" {
|
||
continue
|
||
}
|
||
out[k] = struct{}{}
|
||
out[ratio_setting.FormatMatchingModelName(k)] = struct{}{}
|
||
}
|
||
return out
|
||
}
|
||
|
||
// supplierPricingModelsJoined collects every model name that appears as a key
// in any of the inner maps (used for batch persistence). Names are trimmed of
// surrounding whitespace and blank names are skipped. The result is always a
// non-nil set, even when maps is nil or empty.
func supplierPricingModelsJoined(maps map[string]map[string]float64) map[string]struct{} {
	names := make(map[string]struct{})
	for _, values := range maps {
		for key := range values {
			trimmed := strings.TrimSpace(key)
			if trimmed == "" {
				continue
			}
			names[trimmed] = struct{}{}
		}
	}
	return names
}
|
||
|
||
// UpsertSupplierModelPricingMaps 将前端提交的模型→数值映射合并写入 supplier_model_pricings(覆盖同一供应商下同模型名)。
|
||
func UpsertSupplierModelPricingMaps(supplierApplicationID int, userID int, maps map[string]map[string]float64, ownedModels map[string]struct{}) error {
|
||
modelNames := supplierPricingModelsJoined(maps)
|
||
for name := range modelNames {
|
||
if _, ok := ownedModels[name]; !ok {
|
||
return fmt.Errorf("无权配置模型: %s", name)
|
||
}
|
||
}
|
||
return DB.Transaction(func(tx *gorm.DB) error {
|
||
for rawName := range modelNames {
|
||
modelName := ratio_setting.FormatMatchingModelName(rawName)
|
||
row := SupplierModelPricing{
|
||
SupplierApplicationID: supplierApplicationID,
|
||
ModelName: modelName,
|
||
QuotaType: 0,
|
||
UpdatedByUserID: userID,
|
||
}
|
||
if v, ok := pickFloat(maps["ModelPrice"], rawName); ok {
|
||
row.QuotaType = 1
|
||
row.ModelPrice = floatPtr(v)
|
||
}
|
||
if v, ok := pickFloat(maps["ModelRatio"], rawName); ok {
|
||
row.ModelRatio = floatPtr(v)
|
||
if row.ModelPrice == nil {
|
||
row.QuotaType = 0
|
||
}
|
||
}
|
||
if v, ok := pickFloat(maps["CompletionRatio"], rawName); ok {
|
||
row.CompletionRatio = floatPtr(v)
|
||
}
|
||
if v, ok := pickFloat(maps["CacheRatio"], rawName); ok {
|
||
row.CacheRatio = floatPtr(v)
|
||
}
|
||
if v, ok := pickFloat(maps["CreateCacheRatio"], rawName); ok {
|
||
row.CreateCacheRatio = floatPtr(v)
|
||
}
|
||
if v, ok := pickFloat(maps["ImageRatio"], rawName); ok {
|
||
row.ImageRatio = floatPtr(v)
|
||
}
|
||
if v, ok := pickFloat(maps["AudioRatio"], rawName); ok {
|
||
row.AudioRatio = floatPtr(v)
|
||
}
|
||
if v, ok := pickFloat(maps["AudioCompletionRatio"], rawName); ok {
|
||
row.AudioCompletionRatio = floatPtr(v)
|
||
}
|
||
if row.ModelPrice == nil && row.ModelRatio == nil &&
|
||
row.CompletionRatio == nil && row.CacheRatio == nil && row.CreateCacheRatio == nil &&
|
||
row.ImageRatio == nil && row.AudioRatio == nil && row.AudioCompletionRatio == nil {
|
||
if err := tx.Where("supplier_application_id = ? AND model_name = ?", supplierApplicationID, modelName).
|
||
Delete(&SupplierModelPricing{}).Error; err != nil {
|
||
return err
|
||
}
|
||
continue
|
||
}
|
||
var existing SupplierModelPricing
|
||
err := tx.Where("supplier_application_id = ? AND model_name = ?", supplierApplicationID, modelName).First(&existing).Error
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
if err := tx.Create(&row).Error; err != nil {
|
||
return err
|
||
}
|
||
continue
|
||
}
|
||
if err != nil {
|
||
return err
|
||
}
|
||
row.ID = existing.ID
|
||
if err := tx.Save(&row).Error; err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
})
|
||
}
|
||
|
||
// UpsertSupplierChannelModelPricingMaps 写入供应商渠道维度定价映射。
|
||
func UpsertSupplierChannelModelPricingMaps(supplierApplicationID int, channelID int, userID int, maps map[string]map[string]float64, ownedModels map[string]struct{}) error {
|
||
modelNames := supplierPricingModelsJoined(maps)
|
||
for name := range modelNames {
|
||
if _, ok := ownedModels[name]; !ok {
|
||
return fmt.Errorf("无权配置模型: %s", name)
|
||
}
|
||
}
|
||
return DB.Transaction(func(tx *gorm.DB) error {
|
||
for rawName := range modelNames {
|
||
modelName := ratio_setting.FormatMatchingModelName(rawName)
|
||
row := SupplierChannelModelPricing{
|
||
SupplierApplicationID: supplierApplicationID,
|
||
ChannelID: channelID,
|
||
ModelName: modelName,
|
||
QuotaType: 0,
|
||
UpdatedByUserID: userID,
|
||
}
|
||
if v, ok := pickFloat(maps["ModelPrice"], rawName); ok {
|
||
row.QuotaType = 1
|
||
row.ModelPrice = floatPtr(v)
|
||
}
|
||
if v, ok := pickFloat(maps["ModelRatio"], rawName); ok {
|
||
row.ModelRatio = floatPtr(v)
|
||
if row.ModelPrice == nil {
|
||
row.QuotaType = 0
|
||
}
|
||
}
|
||
if v, ok := pickFloat(maps["CompletionRatio"], rawName); ok {
|
||
row.CompletionRatio = floatPtr(v)
|
||
}
|
||
if v, ok := pickFloat(maps["CacheRatio"], rawName); ok {
|
||
row.CacheRatio = floatPtr(v)
|
||
}
|
||
if v, ok := pickFloat(maps["CreateCacheRatio"], rawName); ok {
|
||
row.CreateCacheRatio = floatPtr(v)
|
||
}
|
||
if v, ok := pickFloat(maps["ImageRatio"], rawName); ok {
|
||
row.ImageRatio = floatPtr(v)
|
||
}
|
||
if v, ok := pickFloat(maps["AudioRatio"], rawName); ok {
|
||
row.AudioRatio = floatPtr(v)
|
||
}
|
||
if v, ok := pickFloat(maps["AudioCompletionRatio"], rawName); ok {
|
||
row.AudioCompletionRatio = floatPtr(v)
|
||
}
|
||
if row.ModelPrice == nil && row.ModelRatio == nil &&
|
||
row.CompletionRatio == nil && row.CacheRatio == nil && row.CreateCacheRatio == nil &&
|
||
row.ImageRatio == nil && row.AudioRatio == nil && row.AudioCompletionRatio == nil {
|
||
if err := tx.Where("supplier_application_id = ? AND channel_id = ? AND model_name = ?", supplierApplicationID, channelID, modelName).
|
||
Delete(&SupplierChannelModelPricing{}).Error; err != nil {
|
||
return err
|
||
}
|
||
continue
|
||
}
|
||
var existing SupplierChannelModelPricing
|
||
err := tx.Where("supplier_application_id = ? AND channel_id = ? AND model_name = ?", supplierApplicationID, channelID, modelName).First(&existing).Error
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
if err := tx.Create(&row).Error; err != nil {
|
||
return err
|
||
}
|
||
continue
|
||
}
|
||
if err != nil {
|
||
return err
|
||
}
|
||
row.ID = existing.ID
|
||
if err := tx.Save(&row).Error; err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
})
|
||
}
|
||
|
||
// pickFloat looks key up in m and reports whether it was present. A nil map
// behaves like an empty one: indexing it yields (0, false), so no separate
// nil guard is needed.
func pickFloat(m map[string]float64, key string) (float64, bool) {
	v, ok := m[key]
	return v, ok
}
|
||
|
||
// floatPtr returns a pointer to a fresh copy of v, so each call yields a
// distinct pointer.
func floatPtr(v float64) *float64 {
	return &v
}
|
||
|
||
// GetSupplierModelPricingRow 读取供应商全局定价单行(model_name 存库为 FormatMatching 规范化值)。
|
||
func GetSupplierModelPricingRow(supplierApplicationID int, rawModelName string) (*SupplierModelPricing, error) {
|
||
if supplierApplicationID <= 0 {
|
||
return nil, nil
|
||
}
|
||
key := ratio_setting.FormatMatchingModelName(rawModelName)
|
||
var row SupplierModelPricing
|
||
err := DB.Where("supplier_application_id = ? AND model_name = ?", supplierApplicationID, key).First(&row).Error
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, nil
|
||
}
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &row, nil
|
||
}
|
||
|
||
// GetSupplierChannelModelPricingRow 读取供应商渠道定价单行。
|
||
func GetSupplierChannelModelPricingRow(supplierApplicationID int, channelID int, rawModelName string) (*SupplierChannelModelPricing, error) {
|
||
if supplierApplicationID <= 0 || channelID <= 0 {
|
||
return nil, nil
|
||
}
|
||
key := ratio_setting.FormatMatchingModelName(rawModelName)
|
||
var row SupplierChannelModelPricing
|
||
err := DB.Where("supplier_application_id = ? AND channel_id = ? AND model_name = ?", supplierApplicationID, channelID, key).First(&row).Error
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, nil
|
||
}
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &row, nil
|
||
}
|
||
|
||
// ListSupplierModelPricingsForSupplier 列出某供应商全部全局定价行(用于 GET API)。
|
||
func ListSupplierModelPricingsForSupplier(supplierApplicationID int) ([]SupplierModelPricing, error) {
|
||
var rows []SupplierModelPricing
|
||
err := DB.Where("supplier_application_id = ?", supplierApplicationID).Order("model_name asc").Find(&rows).Error
|
||
return rows, err
|
||
}
|
||
|
||
// ListSupplierChannelModelPricings 列出某渠道下全部供应商渠道定价行。
|
||
func ListSupplierChannelModelPricings(supplierApplicationID int, channelID int) ([]SupplierChannelModelPricing, error) {
|
||
var rows []SupplierChannelModelPricing
|
||
err := DB.Where("supplier_application_id = ? AND channel_id = ?", supplierApplicationID, channelID).Order("model_name asc").Find(&rows).Error
|
||
return rows, err
|
||
}
|
||
|
||
// BuildOptionLikeMapsFromSupplierGlobalRows 将表行转为与 Option JSON 相同结构的 map(便于前端复用编辑器)。
|
||
func BuildOptionLikeMapsFromSupplierGlobalRows(rows []SupplierModelPricing) map[string]map[string]float64 {
|
||
out := map[string]map[string]float64{
|
||
"ModelPrice": {},
|
||
"ModelRatio": {},
|
||
"CompletionRatio": {},
|
||
"CacheRatio": {},
|
||
"CreateCacheRatio": {},
|
||
"ImageRatio": {},
|
||
"AudioRatio": {},
|
||
"AudioCompletionRatio": {},
|
||
}
|
||
for _, r := range rows {
|
||
name := r.ModelName
|
||
if r.ModelPrice != nil {
|
||
out["ModelPrice"][name] = *r.ModelPrice
|
||
}
|
||
if r.ModelRatio != nil {
|
||
out["ModelRatio"][name] = *r.ModelRatio
|
||
}
|
||
if r.CompletionRatio != nil {
|
||
out["CompletionRatio"][name] = *r.CompletionRatio
|
||
}
|
||
if r.CacheRatio != nil {
|
||
out["CacheRatio"][name] = *r.CacheRatio
|
||
}
|
||
if r.CreateCacheRatio != nil {
|
||
out["CreateCacheRatio"][name] = *r.CreateCacheRatio
|
||
}
|
||
if r.ImageRatio != nil {
|
||
out["ImageRatio"][name] = *r.ImageRatio
|
||
}
|
||
if r.AudioRatio != nil {
|
||
out["AudioRatio"][name] = *r.AudioRatio
|
||
}
|
||
if r.AudioCompletionRatio != nil {
|
||
out["AudioCompletionRatio"][name] = *r.AudioCompletionRatio
|
||
}
|
||
}
|
||
return out
|
||
}
|
||
|
||
// BuildOptionLikeMapsFromSupplierChannelRows 将渠道定价行转为 map。
|
||
func BuildOptionLikeMapsFromSupplierChannelRows(rows []SupplierChannelModelPricing) map[string]map[string]float64 {
|
||
out := map[string]map[string]float64{
|
||
"ModelPrice": {},
|
||
"ModelRatio": {},
|
||
"CompletionRatio": {},
|
||
"CacheRatio": {},
|
||
"CreateCacheRatio": {},
|
||
"ImageRatio": {},
|
||
"AudioRatio": {},
|
||
"AudioCompletionRatio": {},
|
||
}
|
||
for _, r := range rows {
|
||
name := r.ModelName
|
||
if r.ModelPrice != nil {
|
||
out["ModelPrice"][name] = *r.ModelPrice
|
||
}
|
||
if r.ModelRatio != nil {
|
||
out["ModelRatio"][name] = *r.ModelRatio
|
||
}
|
||
if r.CompletionRatio != nil {
|
||
out["CompletionRatio"][name] = *r.CompletionRatio
|
||
}
|
||
if r.CacheRatio != nil {
|
||
out["CacheRatio"][name] = *r.CacheRatio
|
||
}
|
||
if r.CreateCacheRatio != nil {
|
||
out["CreateCacheRatio"][name] = *r.CreateCacheRatio
|
||
}
|
||
if r.ImageRatio != nil {
|
||
out["ImageRatio"][name] = *r.ImageRatio
|
||
}
|
||
if r.AudioRatio != nil {
|
||
out["AudioRatio"][name] = *r.AudioRatio
|
||
}
|
||
if r.AudioCompletionRatio != nil {
|
||
out["AudioCompletionRatio"][name] = *r.AudioCompletionRatio
|
||
}
|
||
}
|
||
return out
|
||
}
|
||
|
||
// ResolveSupplierScopedFixedModelPrice 解析按次固定价($/次):供应商渠道表 > 供应商全局表 > Option 渠道价 > 平台全局 > 旧 SupplierModelPrice。
|
||
func ResolveSupplierScopedFixedModelPrice(channelID int, supplierApplicationID int, modelName string) (float64, bool) {
|
||
modelName = ratio_setting.FormatMatchingModelName(modelName)
|
||
if supplierApplicationID > 0 && channelID > 0 {
|
||
if row, _ := GetSupplierChannelModelPricingRow(supplierApplicationID, channelID, modelName); row != nil && row.ModelPrice != nil {
|
||
return *row.ModelPrice, true
|
||
}
|
||
if row, _ := GetSupplierModelPricingRow(supplierApplicationID, modelName); row != nil && row.ModelPrice != nil {
|
||
return *row.ModelPrice, true
|
||
}
|
||
}
|
||
if v, ok := ratio_setting.GetChannelModelPrice(channelID, modelName); ok {
|
||
return v, true
|
||
}
|
||
if v, ok := ratio_setting.GetModelPrice(modelName, false); ok {
|
||
return v, true
|
||
}
|
||
if supplierApplicationID > 0 {
|
||
if v, ok := ratio_setting.GetSupplierModelPrice(supplierApplicationID, modelName); ok {
|
||
return v, true
|
||
}
|
||
}
|
||
return 0, false
|
||
}
|
||
|
||
// ResolveSupplierScopedModelRatio 解析输入倍率:供应商渠道表 > 供应商全局表 > Option 渠道倍率 > 旧 SupplierModelRatio > 平台全局。
|
||
func ResolveSupplierScopedModelRatio(channelID int, supplierApplicationID int, modelName string) (float64, bool, string) {
|
||
modelName = ratio_setting.FormatMatchingModelName(modelName)
|
||
if supplierApplicationID > 0 && channelID > 0 {
|
||
if row, _ := GetSupplierChannelModelPricingRow(supplierApplicationID, channelID, modelName); row != nil && row.ModelRatio != nil {
|
||
return *row.ModelRatio, true, ""
|
||
}
|
||
if row, _ := GetSupplierModelPricingRow(supplierApplicationID, modelName); row != nil && row.ModelRatio != nil {
|
||
return *row.ModelRatio, true, ""
|
||
}
|
||
}
|
||
if v, ok := ratio_setting.GetChannelModelRatio(channelID, modelName); ok {
|
||
return v, true, ""
|
||
}
|
||
if supplierApplicationID > 0 {
|
||
if v, ok := ratio_setting.GetSupplierModelRatio(supplierApplicationID, modelName); ok {
|
||
return v, true, ""
|
||
}
|
||
}
|
||
return ratio_setting.GetModelRatio(modelName)
|
||
}
|
||
|
||
// ResolveSupplierScopedCompletionRatio 解析输出倍率:供应商渠道行 > 供应商全局行 > Option 渠道 > 平台全局。
|
||
func ResolveSupplierScopedCompletionRatio(channelID int, supplierApplicationID int, modelName string) float64 {
|
||
modelName = ratio_setting.FormatMatchingModelName(modelName)
|
||
if supplierApplicationID > 0 && channelID > 0 {
|
||
if row, _ := GetSupplierChannelModelPricingRow(supplierApplicationID, channelID, modelName); row != nil && row.CompletionRatio != nil {
|
||
return *row.CompletionRatio
|
||
}
|
||
if row, _ := GetSupplierModelPricingRow(supplierApplicationID, modelName); row != nil && row.CompletionRatio != nil {
|
||
return *row.CompletionRatio
|
||
}
|
||
}
|
||
if v, ok := ratio_setting.GetChannelCompletionRatio(channelID, modelName); ok {
|
||
return v
|
||
}
|
||
return ratio_setting.GetCompletionRatio(modelName)
|
||
}
|
||
|
||
// ResolveSupplierScopedImageRatio 供应商渠道/全局行优先,其次 Option 渠道与平台全局。
|
||
func ResolveSupplierScopedImageRatio(channelID int, supplierApplicationID int, modelName string) (float64, bool) {
|
||
modelName = ratio_setting.FormatMatchingModelName(modelName)
|
||
if supplierApplicationID > 0 && channelID > 0 {
|
||
if row, _ := GetSupplierChannelModelPricingRow(supplierApplicationID, channelID, modelName); row != nil && row.ImageRatio != nil {
|
||
return *row.ImageRatio, true
|
||
}
|
||
}
|
||
if supplierApplicationID > 0 {
|
||
if row, _ := GetSupplierModelPricingRow(supplierApplicationID, modelName); row != nil && row.ImageRatio != nil {
|
||
return *row.ImageRatio, true
|
||
}
|
||
}
|
||
if v, ok := ratio_setting.GetChannelImageRatio(channelID, modelName); ok {
|
||
return v, true
|
||
}
|
||
return ratio_setting.GetImageRatio(modelName)
|
||
}
|
||
|
||
// ResolveSupplierScopedAudioRatio 音频输入倍率。
|
||
func ResolveSupplierScopedAudioRatio(channelID int, supplierApplicationID int, modelName string) float64 {
|
||
modelName = ratio_setting.FormatMatchingModelName(modelName)
|
||
if supplierApplicationID > 0 && channelID > 0 {
|
||
if row, _ := GetSupplierChannelModelPricingRow(supplierApplicationID, channelID, modelName); row != nil && row.AudioRatio != nil {
|
||
return *row.AudioRatio
|
||
}
|
||
}
|
||
if supplierApplicationID > 0 {
|
||
if row, _ := GetSupplierModelPricingRow(supplierApplicationID, modelName); row != nil && row.AudioRatio != nil {
|
||
return *row.AudioRatio
|
||
}
|
||
}
|
||
if v, ok := ratio_setting.GetChannelAudioRatio(channelID, modelName); ok {
|
||
return v
|
||
}
|
||
return ratio_setting.GetAudioRatio(modelName)
|
||
}
|
||
|
||
// ResolveSupplierScopedAudioCompletionRatio 音频输出相对输入倍率。
|
||
func ResolveSupplierScopedAudioCompletionRatio(channelID int, supplierApplicationID int, modelName string) float64 {
|
||
modelName = ratio_setting.FormatMatchingModelName(modelName)
|
||
if supplierApplicationID > 0 && channelID > 0 {
|
||
if row, _ := GetSupplierChannelModelPricingRow(supplierApplicationID, channelID, modelName); row != nil && row.AudioCompletionRatio != nil {
|
||
return *row.AudioCompletionRatio
|
||
}
|
||
}
|
||
if supplierApplicationID > 0 {
|
||
if row, _ := GetSupplierModelPricingRow(supplierApplicationID, modelName); row != nil && row.AudioCompletionRatio != nil {
|
||
return *row.AudioCompletionRatio
|
||
}
|
||
}
|
||
if v, ok := ratio_setting.GetChannelAudioCompletionRatio(channelID, modelName); ok {
|
||
return v
|
||
}
|
||
return ratio_setting.GetAudioCompletionRatio(modelName)
|
||
}
|
||
|
||
// ResolveSupplierScopedCacheRatios 解析缓存读写倍率(供应商渠道/全局行优先)。
|
||
func ResolveSupplierScopedCacheRatios(channelID int, supplierApplicationID int, modelName string) (cacheRatio float64, createCacheRatio float64) {
|
||
modelName = ratio_setting.FormatMatchingModelName(modelName)
|
||
var chRow *SupplierChannelModelPricing
|
||
var glRow *SupplierModelPricing
|
||
if supplierApplicationID > 0 && channelID > 0 {
|
||
chRow, _ = GetSupplierChannelModelPricingRow(supplierApplicationID, channelID, modelName)
|
||
}
|
||
if supplierApplicationID > 0 {
|
||
glRow, _ = GetSupplierModelPricingRow(supplierApplicationID, modelName)
|
||
}
|
||
if chRow != nil && chRow.CacheRatio != nil {
|
||
cacheRatio = *chRow.CacheRatio
|
||
} else if glRow != nil && glRow.CacheRatio != nil {
|
||
cacheRatio = *glRow.CacheRatio
|
||
} else if v, ok := ratio_setting.GetChannelCacheRatio(channelID, modelName); ok {
|
||
cacheRatio = v
|
||
} else {
|
||
cacheRatio, _ = ratio_setting.GetCacheRatio(modelName)
|
||
}
|
||
|
||
if chRow != nil && chRow.CreateCacheRatio != nil {
|
||
createCacheRatio = *chRow.CreateCacheRatio
|
||
} else if glRow != nil && glRow.CreateCacheRatio != nil {
|
||
createCacheRatio = *glRow.CreateCacheRatio
|
||
} else if v, ok := ratio_setting.GetChannelCreateCacheRatio(channelID, modelName); ok {
|
||
createCacheRatio = v
|
||
} else {
|
||
createCacheRatio, _ = ratio_setting.GetCreateCacheRatio(modelName)
|
||
}
|
||
return cacheRatio, createCacheRatio
|
||
}
|