tokenFactory/model/model_meta.go

231 lines
7.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package model
import (
"strconv"
"github.com/QuantumNous/new-api/common"
"gorm.io/gorm"
)
// Name-matching rules for a model entry. These are presumably the values
// stored in Model.NameRule; if so they are persisted, so existing numbers
// must never be reassigned.
const (
	NameRuleExact    = 0 // match the model name exactly
	NameRulePrefix   = 1 // match names starting with the model name
	NameRuleContains = 2 // match names containing the model name
	NameRuleSuffix   = 3 // match names ending with the model name
)
// BoundChannel describes a channel that serves a model, as returned by
// GetBoundChannelsByModelsMap.
type BoundChannel struct {
	// Name is the channel's display name (channels.name).
	Name string `json:"name"`
	// Type is the channel's type code (channels.type).
	Type int `json:"type"`
}
// Model is the metadata record for a model name. It is stored in the
// "models" table (the name the queries in this file use to qualify
// columns). Fields tagged gorm:"-" are not database columns.
type Model struct {
	Id int `json:"id"`
	// ModelName is required and is unique together with DeletedAt
	// (index uk_model_name_delete_at).
	ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name_delete_at,priority:1"`
	Description string `json:"description,omitempty" gorm:"type:text"`
	DocIntroduction string `json:"doc_introduction,omitempty" gorm:"type:text"`
	ApiDocs string `json:"api_docs,omitempty" gorm:"type:text"`
	Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
	Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"`
	VendorID int `json:"vendor_id,omitempty" gorm:"index"`
	// NOTE(review): stored as free text; the encoding is not visible here — confirm.
	Endpoints string `json:"endpoints,omitempty" gorm:"type:text"`
	// Status and SyncOfficial carry a GORM default of 1. Insert writes the
	// caller's values again after Create so that an explicit 0 is kept.
	Status int `json:"status" gorm:"default:1"`
	SyncOfficial int `json:"sync_official" gorm:"default:1"`
	// NOTE(review): "bigint" is not in GORM's key:value tag form
	// ("type:bigint"), so GORM probably ignores it and uses its default
	// int64 mapping. Confirm before changing; a real type tag can alter
	// what AutoMigrate does.
	CreatedTime int64 `json:"created_time" gorm:"bigint"`
	UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
	// DeletedAt enables GORM soft deletes and is the second column of
	// uk_model_name_delete_at.
	DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name_delete_at,priority:2"`
	// Not persisted (gorm:"-"); presumably filled in by higher layers for API responses.
	BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"`
	EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"`
	QuotaTypes []int `json:"quota_types,omitempty" gorm:"-"`
	// NameRule presumably holds one of the NameRule* constants.
	NameRule int `json:"name_rule" gorm:"default:0"`
	OwnerUserID int `json:"owner_user_id" gorm:"type:int;index;default:0"` // ID of the user who owns this model (supplier scenario)
	SupplierApplicationID int `json:"supplier_application_id" gorm:"type:int;index;default:0"` // associated supplier_applications.id
	// Not persisted (gorm:"-").
	MatchedModels []string `json:"matched_models,omitempty" gorm:"-"`
	MatchedCount int `json:"matched_count,omitempty" gorm:"-"`
	// Sort weight and manual call-count baseline (used to manually adjust popularity ranking).
	SortWeight float64 `json:"sort_weight" gorm:"default:1"`
	ManualBaseReqCount int64 `json:"manual_base_req_count" gorm:"default:0"` // manually set baseline for the call count
}
// Insert creates the model row and sets CreatedTime and UpdatedTime to the
// current timestamp.
//
// Status and SyncOfficial carry a GORM `default:1` tag. During Create, GORM
// replaces their zero values with that default, both in the INSERT and on
// mi itself. To persist an explicit 0, the caller's original values are
// written back with a follow-up UPDATE.
//
// Both statements run in one transaction. If the follow-up fails, no row
// with the wrong status is left behind while an error is reported. After
// success, mi's fields are restored so the struct matches what was stored.
func (mi *Model) Insert() error {
	now := common.GetTimestamp()
	mi.CreatedTime = now
	mi.UpdatedTime = now

	// Capture the caller's values before Create can overwrite zero-valued
	// fields that have a GORM default.
	status := mi.Status
	syncOfficial := mi.SyncOfficial

	err := DB.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(mi).Error; err != nil {
			return err
		}
		return tx.Model(&Model{}).Where("id = ?", mi.Id).Updates(map[string]interface{}{
			"status":        status,
			"sync_official": syncOfficial,
		}).Error
	})
	if err != nil {
		return err
	}

	// Undo GORM's in-memory substitution of the tag default.
	mi.Status = status
	mi.SyncOfficial = syncOfficial
	return nil
}
// IsModelNameDuplicated reports whether some model other than the one with
// the given id already uses name. An empty name is never reported as a
// duplicate and does not touch the database. Rows excluded by GORM's
// soft-delete scope (Model has a gorm.DeletedAt field) are not counted.
func IsModelNameDuplicated(id int, name string) (bool, error) {
	if len(name) == 0 {
		return false, nil
	}
	var matches int64
	err := DB.Model(&Model{}).
		Where("model_name = ? AND id <> ?", name, id).
		Count(&matches).Error
	duplicated := matches > 0
	return duplicated, err
}
// Update stamps UpdatedTime and writes mi's editable columns to the row
// identified by mi.Id.
//
// The columns are listed explicitly so GORM writes them even when they hold
// zero values. created_time, sort_weight and manual_base_req_count are not
// in the list and are therefore left untouched.
func (mi *Model) Update() error {
	mi.UpdatedTime = common.GetTimestamp()
	writable := []string{
		"model_name",
		"description",
		"doc_introduction",
		"api_docs",
		"icon",
		"tags",
		"vendor_id",
		"endpoints",
		"status",
		"sync_official",
		"name_rule",
		"owner_user_id",
		"supplier_application_id",
		"updated_time",
	}
	result := DB.Model(&Model{}).
		Where("id = ?", mi.Id).
		Select(writable).
		Updates(mi)
	return result.Error
}
// Delete deletes mi through GORM. Because Model has a gorm.DeletedAt field,
// GORM performs a soft delete (sets deleted_at) instead of removing the row.
func (mi *Model) Delete() error {
	result := DB.Delete(mi)
	return result.Error
}
// GetVendorModelCounts returns the number of model rows per vendor_id, keyed
// by vendor ID.
func GetVendorModelCounts() (map[int64]int64, error) {
	// Field names map to the "vendor_id" and "count" result columns.
	type vendorCount struct {
		VendorID int64
		Count    int64
	}
	var rows []vendorCount
	err := DB.Model(&Model{}).
		Select("vendor_id as vendor_id, count(*) as count").
		Group("vendor_id").
		Scan(&rows).Error
	if err != nil {
		return nil, err
	}

	counts := make(map[int64]int64, len(rows))
	for i := range rows {
		counts[rows[i].VendorID] = rows[i].Count
	}
	return counts, nil
}
// GetAllModels returns one page of models, newest id first.
func GetAllModels(offset int, limit int) ([]*Model, error) {
	var page []*Model
	err := DB.Order("id DESC").
		Offset(offset).
		Limit(limit).
		Find(&page).Error
	return page, err
}
// ListModelsByOwnerUser returns one page of the models owned by ownerUserID,
// newest id first, together with the total number of models that user owns.
// On any error it returns (nil, 0, err).
func ListModelsByOwnerUser(ownerUserID int, offset int, limit int) ([]*Model, int64, error) {
	owned := DB.Model(&Model{}).Where("owner_user_id = ?", ownerUserID)

	var total int64
	if err := owned.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var page []*Model
	err := owned.Order("id DESC").Offset(offset).Limit(limit).Find(&page).Error
	if err != nil {
		return nil, 0, err
	}
	return page, total, nil
}
// SearchSupplierModels searches supplier models, paginated.
//
// When ownerUserID is non-nil, only that user's models are searched (the
// supplier's own view). When it is nil, every model with both a positive
// owner_user_id and a positive supplier_application_id is searched (the
// admin view).
//
// Optional filters:
//   - keyword, when non-empty, is matched as a LIKE substring against
//     model_name, description and tags.
//   - vendor, when non-empty, is used as a vendor ID if it parses as an
//     integer. Otherwise it is matched as a substring of vendors.name,
//     which joins the vendors table.
//
// Results are ordered by id descending. The returned total is the match
// count before offset/limit. On any error it returns (nil, 0, err).
//
// Every models column is qualified with "models." because of the optional
// vendors join. An unqualified name that the vendors table also defines
// would make the query fail as ambiguous.
func SearchSupplierModels(ownerUserID *int, keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
	var (
		models []*Model
		total  int64
	)
	db := DB.Model(&Model{})
	if ownerUserID != nil {
		db = db.Where("models.owner_user_id = ?", *ownerUserID)
	} else {
		db = db.Where("models.owner_user_id > ? AND models.supplier_application_id > ?", 0, 0)
	}
	if keyword != "" {
		// NOTE(review): '%' and '_' inside keyword act as LIKE wildcards.
		like := "%" + keyword + "%"
		db = db.Where("models.model_name LIKE ? OR models.description LIKE ? OR models.tags LIKE ?", like, like, like)
	}
	if vendor != "" {
		if vid, err := strconv.Atoi(vendor); err == nil {
			db = db.Where("models.vendor_id = ?", vid)
		} else {
			db = db.Joins("JOIN vendors ON vendors.id = models.vendor_id").Where("vendors.name LIKE ?", "%"+vendor+"%")
		}
	}
	if err := db.Count(&total).Error; err != nil {
		return nil, 0, err
	}
	if err := db.Order("models.id DESC").Offset(offset).Limit(limit).Find(&models).Error; err != nil {
		return nil, 0, err
	}
	return models, total, nil
}
// GetBoundChannelsByModelsMap returns, for each name in modelNames, the
// distinct channels that have an enabled ability for that model. Models
// without any such channel are absent from the map. An empty input returns
// an empty (non-nil) map without querying the database.
func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel, error) {
	bound := make(map[string][]BoundChannel)
	if len(modelNames) == 0 {
		return bound, nil
	}

	// Field names map to the "model", "name" and "type" result aliases.
	var rows []struct {
		Model string
		Name  string
		Type  int
	}
	query := DB.Table("channels").
		Select("abilities.model as model, channels.name as name, channels.type as type").
		Joins("JOIN abilities ON abilities.channel_id = channels.id").
		Where("abilities.model IN ? AND abilities.enabled = ?", modelNames, true).
		Distinct()
	if err := query.Scan(&rows).Error; err != nil {
		return nil, err
	}

	for i := range rows {
		row := &rows[i]
		bound[row.Model] = append(bound[row.Model], BoundChannel{Name: row.Name, Type: row.Type})
	}
	return bound, nil
}
// SearchModels searches all models, paginated.
//
// Optional filters:
//   - keyword, when non-empty, is matched as a LIKE substring against
//     model_name, description and tags.
//   - vendor, when non-empty, is used as a vendor ID if it parses as an
//     integer. Otherwise it is matched as a substring of vendors.name,
//     which joins the vendors table.
//
// Results are ordered by id descending. The returned total is the match
// count before offset/limit. On any error it returns (nil, 0, err).
//
// Every models column is qualified with "models." because of the optional
// vendors join. An unqualified name that the vendors table also defines
// would make the query fail as ambiguous.
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
	var models []*Model
	db := DB.Model(&Model{})
	if keyword != "" {
		// NOTE(review): '%' and '_' inside keyword act as LIKE wildcards.
		like := "%" + keyword + "%"
		db = db.Where("models.model_name LIKE ? OR models.description LIKE ? OR models.tags LIKE ?", like, like, like)
	}
	if vendor != "" {
		if vid, err := strconv.Atoi(vendor); err == nil {
			db = db.Where("models.vendor_id = ?", vid)
		} else {
			db = db.Joins("JOIN vendors ON vendors.id = models.vendor_id").Where("vendors.name LIKE ?", "%"+vendor+"%")
		}
	}
	var total int64
	if err := db.Count(&total).Error; err != nil {
		return nil, 0, err
	}
	if err := db.Order("models.id DESC").Offset(offset).Limit(limit).Find(&models).Error; err != nil {
		return nil, 0, err
	}
	return models, total, nil
}
// GetExistingModelNames returns the subset of names that already have a
// model row in the models table. It is used by the listing-wizard
// diagnostics to quickly find which models still need their metadata
// configured manually under /console/models.
//
// Each name appears at most once in the result, in no particular order. An
// empty input returns (nil, nil) without querying the database. On error it
// returns (nil, err).
func GetExistingModelNames(names []string) ([]string, error) {
	if len(names) == 0 {
		return nil, nil
	}
	var existing []string
	// DISTINCT: the unique index covers (model_name, deleted_at), and live
	// rows have a NULL deleted_at. Most SQL engines treat NULLs as distinct
	// in unique indexes, so duplicate live names are not prevented by the
	// database and would otherwise be returned twice.
	err := DB.Model(&Model{}).
		Distinct().
		Where("model_name IN ?", names).
		Pluck("model_name", &existing).Error
	if err != nil {
		return nil, err
	}
	return existing, nil
}