tokenFactory/controller/supplier_scope.go

152 lines
4.1 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package controller
import (
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
)
// supplierEditableModelOptionKeys is the set of model-ratio-related option
// keys that a supplier is allowed to operate on. It is used purely as a
// membership set, so the values are empty structs.
var supplierEditableModelOptionKeys = map[string]struct{}{
	// Keys ending in "Ratio".
	"ModelRatio":           {},
	"CompletionRatio":      {},
	"CacheRatio":           {},
	"CreateCacheRatio":     {},
	"ImageRatio":           {},
	"AudioRatio":           {},
	"AudioCompletionRatio": {},
	"VideoRatio":           {},
	"VideoCompletionRatio": {},
	// Keys ending in "Price".
	"ModelPrice": {},
	"ImagePrice": {},
	"VideoPrice": {},
	// Keys ending in "PricingRules".
	"ImagePricingRules": {},
	"VideoPricingRules": {},
}
// collectSupplierOwnedModelNames builds the set of model names found on the
// channels and model records returned for the supplier user userID.
//
// Names are gathered from model.SearchSupplierChannels (via each channel's
// GetModels) and from model.SearchSupplierModels (via ModelName). Every name
// is whitespace-trimmed, and names that are empty after trimming are skipped.
// If either lookup fails, its error is returned unchanged with a nil set.
//
// NOTE(review): the second return value of both lookups is discarded and the
// calls are bounded by the literals 0 and 100000 — presumably offset and page
// size; confirm that results beyond that bound cannot occur or do not matter.
func collectSupplierOwnedModelNames(userID int) (map[string]struct{}, error) {
	names := make(map[string]struct{})
	// add records a single model name after trimming, ignoring blanks.
	add := func(raw string) {
		if trimmed := strings.TrimSpace(raw); trimmed != "" {
			names[trimmed] = struct{}{}
		}
	}

	supplierChannels, _, err := model.SearchSupplierChannels(&userID, 0, 100000, model.SupplierChannelSearchFilter{})
	if err != nil {
		return nil, err
	}
	for _, ch := range supplierChannels {
		for _, name := range ch.GetModels() {
			add(name)
		}
	}

	supplierModels, _, err := model.SearchSupplierModels(&userID, "", "", 0, 100000)
	if err != nil {
		return nil, err
	}
	for _, m := range supplierModels {
		add(m.ModelName)
	}

	return names, nil
}
// collectAllSupplierOwnedModelNames builds the set of model names across all
// suppliers (intended for administrator statistics).
//
// It performs the same two lookups as the per-supplier variant but passes a
// nil user filter to model.SearchSupplierChannels and
// model.SearchSupplierModels. Every name is whitespace-trimmed, and names
// that are empty after trimming are skipped. If either lookup fails, its
// error is returned unchanged with a nil set.
//
// NOTE(review): both calls are bounded by the literals 0 and 100000 —
// presumably offset and page size; confirm the bound is large enough for a
// platform-wide query.
func collectAllSupplierOwnedModelNames() (map[string]struct{}, error) {
	allNames := make(map[string]struct{})
	// remember records a single model name after trimming, ignoring blanks.
	remember := func(raw string) {
		trimmed := strings.TrimSpace(raw)
		if trimmed == "" {
			return
		}
		allNames[trimmed] = struct{}{}
	}

	allChannels, _, err := model.SearchSupplierChannels(nil, 0, 100000, model.SupplierChannelSearchFilter{})
	if err != nil {
		return nil, err
	}
	for _, ch := range allChannels {
		for _, name := range ch.GetModels() {
			remember(name)
		}
	}

	allModels, _, err := model.SearchSupplierModels(nil, "", "", 0, 100000)
	if err != nil {
		return nil, err
	}
	for _, record := range allModels {
		remember(record.ModelName)
	}

	return allNames, nil
}
// collectSupplierOwnedModelNamesBySupplierID resolves the supplier record
// identified by supplierID (the supplier_application_id) and returns the
// model-name set of that record's applicant user.
//
// A lookup error from model.GetSupplierByID is returned unchanged with a nil
// set; otherwise the result of collectSupplierOwnedModelNames is returned.
//
// NOTE(review): assumes GetSupplierByID never returns a nil record together
// with a nil error — confirm against the model package.
func collectSupplierOwnedModelNamesBySupplierID(supplierID int) (map[string]struct{}, error) {
	supplier, lookupErr := model.GetSupplierByID(supplierID)
	if lookupErr != nil {
		return nil, lookupErr
	}
	applicantID := supplier.ApplicantUserID
	return collectSupplierOwnedModelNames(applicantID)
}
// filterModelJSONByOwnedModels returns a JSON object containing only those
// top-level entries of raw whose key is present in ownedModels.
//
// raw is whitespace-trimmed first; an empty result short-circuits to "{}".
// Otherwise raw is decoded as a JSON object keyed by model name, entries for
// models not in ownedModels are dropped, and the remainder is re-encoded.
// Decode and encode errors are returned unchanged with an empty string.
func filterModelJSONByOwnedModels(raw string, ownedModels map[string]struct{}) (string, error) {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return "{}", nil
	}

	var decoded map[string]any
	if err := common.UnmarshalJsonStr(trimmed, &decoded); err != nil {
		return "", err
	}

	kept := make(map[string]any)
	for name, value := range decoded {
		if _, owned := ownedModels[name]; owned {
			kept[name] = value
		}
	}

	encoded, err := common.Marshal(kept)
	if err != nil {
		return "", err
	}
	return string(encoded), nil
}
// mergeModelJSONByOwnedModels applies a supplier's update to a per-model JSON
// object while restricting the update to the supplier's own models.
//
// currentRaw is the stored JSON object (whitespace-trimmed; empty means "no
// entries yet"). incomingRaw is the JSON object submitted by the supplier.
// For every key in incomingRaw that is present in ownedModels the value
// replaces (or adds) the entry in the current object; keys not in ownedModels
// are ignored, and all existing entries that are not overwritten keep their
// original values. The merged object is returned re-encoded as JSON.
//
// Decode and encode errors are returned unchanged with an empty string. The
// trimmed incomingRaw is handed to the decoder as-is, so an empty incomingRaw
// is subject to whatever common.UnmarshalJsonStr does with empty input.
func mergeModelJSONByOwnedModels(currentRaw string, incomingRaw string, ownedModels map[string]struct{}) (string, error) {
	base := make(map[string]any)
	currentRaw = strings.TrimSpace(currentRaw)
	if currentRaw != "" {
		if err := common.UnmarshalJsonStr(currentRaw, &base); err != nil {
			return "", err
		}
		// Under encoding/json semantics (which common.UnmarshalJsonStr is
		// assumed to follow) the document `null` resets the target map to nil.
		// Writing to a nil map panics, so restore an empty map before merging.
		if base == nil {
			base = make(map[string]any)
		}
	}

	patch := make(map[string]any)
	if err := common.UnmarshalJsonStr(strings.TrimSpace(incomingRaw), &patch); err != nil {
		return "", err
	}

	// Only keys the supplier owns may be written; everything else in the
	// incoming payload is silently dropped.
	for modelName, value := range patch {
		if _, ok := ownedModels[modelName]; !ok {
			continue
		}
		base[modelName] = value
	}

	bytes, err := common.Marshal(base)
	if err != nil {
		return "", err
	}
	return string(bytes), nil
}