152 lines
4.1 KiB
Go
152 lines
4.1 KiB
Go
package controller
|
||
|
||
import (
|
||
"strings"
|
||
|
||
"github.com/QuantumNous/new-api/common"
|
||
"github.com/QuantumNous/new-api/model"
|
||
)
|
||
|
||
// supplierEditableModelOptionKeys is the set of model pricing/ratio option
// keys that a supplier is allowed to operate on. It is used purely as a
// membership set, so the values carry no data.
var supplierEditableModelOptionKeys = map[string]struct{}{
	// Base model price and ratios.
	"ModelPrice":      {},
	"ModelRatio":      {},
	"CompletionRatio": {},

	// Cache ratios.
	"CacheRatio":       {},
	"CreateCacheRatio": {},

	// Image options.
	"ImageRatio":        {},
	"ImagePrice":        {},
	"ImagePricingRules": {},

	// Audio options.
	"AudioRatio":           {},
	"AudioCompletionRatio": {},

	// Video options.
	"VideoRatio":           {},
	"VideoCompletionRatio": {},
	"VideoPrice":           {},
	"VideoPricingRules":    {},
}
|
||
|
||
// collectSupplierOwnedModelNames 收集供应商名下渠道与模型中的模型名集合。
|
||
func collectSupplierOwnedModelNames(userID int) (map[string]struct{}, error) {
|
||
ownedModels := make(map[string]struct{})
|
||
|
||
channels, _, err := model.SearchSupplierChannels(&userID, 0, 100000, model.SupplierChannelSearchFilter{})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, channel := range channels {
|
||
for _, modelName := range channel.GetModels() {
|
||
modelName = strings.TrimSpace(modelName)
|
||
if modelName == "" {
|
||
continue
|
||
}
|
||
ownedModels[modelName] = struct{}{}
|
||
}
|
||
}
|
||
|
||
models, _, err := model.SearchSupplierModels(&userID, "", "", 0, 100000)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, item := range models {
|
||
modelName := strings.TrimSpace(item.ModelName)
|
||
if modelName == "" {
|
||
continue
|
||
}
|
||
ownedModels[modelName] = struct{}{}
|
||
}
|
||
|
||
return ownedModels, nil
|
||
}
|
||
|
||
// collectAllSupplierOwnedModelNames 收集全部供应商名下的模型名集合(管理员统计用)。
|
||
func collectAllSupplierOwnedModelNames() (map[string]struct{}, error) {
|
||
ownedModels := make(map[string]struct{})
|
||
|
||
channels, _, err := model.SearchSupplierChannels(nil, 0, 100000, model.SupplierChannelSearchFilter{})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, channel := range channels {
|
||
for _, modelName := range channel.GetModels() {
|
||
modelName = strings.TrimSpace(modelName)
|
||
if modelName == "" {
|
||
continue
|
||
}
|
||
ownedModels[modelName] = struct{}{}
|
||
}
|
||
}
|
||
|
||
models, _, err := model.SearchSupplierModels(nil, "", "", 0, 100000)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, item := range models {
|
||
modelName := strings.TrimSpace(item.ModelName)
|
||
if modelName == "" {
|
||
continue
|
||
}
|
||
ownedModels[modelName] = struct{}{}
|
||
}
|
||
return ownedModels, nil
|
||
}
|
||
|
||
// collectSupplierOwnedModelNamesBySupplierID 收集指定供应商申请(supplier_application_id)名下模型集合。
|
||
func collectSupplierOwnedModelNamesBySupplierID(supplierID int) (map[string]struct{}, error) {
|
||
app, err := model.GetSupplierByID(supplierID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return collectSupplierOwnedModelNames(app.ApplicantUserID)
|
||
}
|
||
|
||
// filterModelJSONByOwnedModels 仅保留属于供应商自有模型的 JSON 键值。
|
||
func filterModelJSONByOwnedModels(raw string, ownedModels map[string]struct{}) (string, error) {
|
||
raw = strings.TrimSpace(raw)
|
||
if raw == "" {
|
||
return "{}", nil
|
||
}
|
||
var origin map[string]any
|
||
if err := common.UnmarshalJsonStr(raw, &origin); err != nil {
|
||
return "", err
|
||
}
|
||
filtered := make(map[string]any)
|
||
for modelName, value := range origin {
|
||
if _, ok := ownedModels[modelName]; !ok {
|
||
continue
|
||
}
|
||
filtered[modelName] = value
|
||
}
|
||
bytes, err := common.Marshal(filtered)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return string(bytes), nil
|
||
}
|
||
|
||
// mergeModelJSONByOwnedModels 仅允许供应商更新自有模型键,其余键保持原值。
|
||
func mergeModelJSONByOwnedModels(currentRaw string, incomingRaw string, ownedModels map[string]struct{}) (string, error) {
|
||
base := make(map[string]any)
|
||
currentRaw = strings.TrimSpace(currentRaw)
|
||
if currentRaw != "" {
|
||
if err := common.UnmarshalJsonStr(currentRaw, &base); err != nil {
|
||
return "", err
|
||
}
|
||
}
|
||
|
||
patch := make(map[string]any)
|
||
if err := common.UnmarshalJsonStr(strings.TrimSpace(incomingRaw), &patch); err != nil {
|
||
return "", err
|
||
}
|
||
for modelName, value := range patch {
|
||
if _, ok := ownedModels[modelName]; !ok {
|
||
continue
|
||
}
|
||
base[modelName] = value
|
||
}
|
||
bytes, err := common.Marshal(base)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return string(bytes), nil
|
||
}
|