tokenFactory/controller/model_meta.go

546 lines
14 KiB
Go

package controller
import (
"encoding/json"
"fmt"
"sort"
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
"github.com/gin-gonic/gin"
)
// defaultModelTags are the built-in tag names that GetModelTags always lists
// first, ahead of any tag names loaded from the database.
var defaultModelTags = []string{"文本", "视频", "图片"}
// GetAllModelsMeta returns one page of model metadata.
//
// Callers whose role is at least common.RoleAdminUser receive every model;
// all other callers only receive what model.SearchSupplierModels returns for
// their own user ID. The response contains the page items, the total row
// count, the paging parameters and a per-vendor model count.
func GetAllModelsMeta(c *gin.Context) {
	pageInfo := common.GetPageQuery(c)
	isAdmin := c.GetInt("role") >= common.RoleAdminUser

	var (
		modelsMeta []*model.Model
		total      int64
		err        error
	)
	if isAdmin {
		modelsMeta, err = model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
		if err != nil {
			common.ApiError(c, err)
			return
		}
		// A failed count would otherwise be reported as "total: 0" next to a
		// non-empty page, so surface the error instead of dropping it.
		if err = model.DB.Model(&model.Model{}).Count(&total).Error; err != nil {
			common.ApiError(c, err)
			return
		}
	} else {
		ownerUserID := c.GetInt("id")
		modelsMeta, total, err = model.SearchSupplierModels(&ownerUserID, "", "", pageInfo.GetStartIdx(), pageInfo.GetPageSize())
		if err != nil {
			common.ApiError(c, err)
			return
		}
	}

	// Fill the derived fields for the whole page in one batch.
	enrichModels(modelsMeta)

	var vendorCounts map[int64]int64
	if isAdmin {
		// Vendor counts over all rows, independent of paging. Best effort: the
		// list itself is still useful without them, so an error only leaves
		// the counts empty.
		vendorCounts, _ = model.GetVendorModelCounts()
	} else {
		// Non-admin callers never see the global counts, so the global query
		// is skipped entirely for them.
		// NOTE: these counts only cover the items on the current page.
		vendorCounts = make(map[int64]int64)
		for _, item := range modelsMeta {
			vendorCounts[int64(item.VendorID)]++
		}
	}

	pageInfo.SetTotal(int(total))
	pageInfo.SetItems(modelsMeta)
	common.ApiSuccess(c, gin.H{
		"items":         modelsMeta,
		"total":         total,
		"page":          pageInfo.GetPage(),
		"page_size":     pageInfo.GetPageSize(),
		"vendor_counts": vendorCounts,
	})
}
// normalizeModelTags trims surrounding whitespace from every tag, drops the
// entries that end up empty, and removes duplicates while keeping the order in
// which each distinct tag first appears. The returned slice is never nil.
func normalizeModelTags(tags []string) []string {
	unique := make(map[string]struct{}, len(tags))
	cleaned := make([]string, 0, len(tags))
	for i := range tags {
		trimmed := strings.TrimSpace(tags[i])
		_, duplicate := unique[trimmed]
		if trimmed == "" || duplicate {
			continue
		}
		unique[trimmed] = struct{}{}
		cleaned = append(cleaned, trimmed)
	}
	return cleaned
}
// splitModelTagsCSV parses a comma-separated tag string into a normalized tag
// list (see normalizeModelTags). A blank or whitespace-only input yields nil.
func splitModelTagsCSV(csv string) []string {
	if len(strings.TrimSpace(csv)) > 0 {
		parts := strings.Split(csv, ",")
		return normalizeModelTags(parts)
	}
	return nil
}
// GetModelTags responds with the de-duplicated list of known tag names.
//
// The list is assembled, in this order, from defaultModelTags, the names
// returned by model.GetAllModelTagNames, and every tag found in the non-empty
// "tags" column of the models table. The merged list is passed to
// model.UpsertModelTags before it is returned to the client.
func GetModelTags(c *gin.Context) {
	storedNames, err := model.GetAllModelTagNames()
	if err != nil {
		common.ApiError(c, err)
		return
	}

	var tagCSVs []string
	if err := model.DB.Model(&model.Model{}).Where("tags <> ?", "").Pluck("tags", &tagCSVs).Error; err != nil {
		common.ApiError(c, err)
		return
	}

	// Collect every candidate in priority order, then trim and de-duplicate
	// once; first occurrence wins, so the defaults stay at the front.
	candidates := make([]string, 0, len(defaultModelTags)+len(storedNames)+len(tagCSVs))
	candidates = append(candidates, defaultModelTags...)
	for _, name := range storedNames {
		candidates = append(candidates, name)
	}
	for _, row := range tagCSVs {
		candidates = append(candidates, splitModelTagsCSV(row)...)
	}
	merged := normalizeModelTags(candidates)

	if err := model.UpsertModelTags(merged); err != nil {
		common.ApiError(c, err)
		return
	}
	common.ApiSuccess(c, merged)
}
// BatchSetModelTagsRequest is the JSON body accepted by BatchSetModelTags.
type BatchSetModelTagsRequest struct {
// IDs are the primary keys of the models to update; must not be empty.
IDs []int `json:"ids"`
// Tags are the tag names to apply; they are trimmed and de-duplicated
// before use, and at least one non-empty tag is required.
Tags []string `json:"tags"`
// Mode is "add" (merge Tags into each model's existing tags) or "replace"
// (overwrite each model's tags with Tags).
Mode string `json:"mode"` // add | replace
}
// BatchSetModelTags adds tags to, or replaces the tags of, a set of models.
//
// The request body is a BatchSetModelTagsRequest. In "add" mode the given
// tags are merged into each model's existing tags; in "replace" mode they
// overwrite them. All model rows are read and written inside one database
// transaction, so a failure part-way through leaves no model modified.
// On success the response is {"updated": <number of models written>}.
func BatchSetModelTags(c *gin.Context) {
	var req BatchSetModelTagsRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		common.ApiError(c, err)
		return
	}
	if len(req.IDs) == 0 {
		common.ApiErrorMsg(c, "请选择至少一个模型")
		return
	}
	normalizedTags := normalizeModelTags(req.Tags)
	if len(normalizedTags) == 0 {
		common.ApiErrorMsg(c, "请至少填写一个标签")
		return
	}
	if req.Mode != "add" && req.Mode != "replace" {
		common.ApiErrorMsg(c, "标签设置模式无效")
		return
	}

	tx := model.DB.Begin()
	if tx.Error != nil {
		common.ApiError(c, tx.Error)
		return
	}
	// Roll back on every early return (and on panic) until the commit succeeds.
	committed := false
	defer func() {
		if !committed {
			tx.Rollback()
		}
	}()

	var modelsMeta []*model.Model
	if err := tx.Where("id IN ?", req.IDs).Find(&modelsMeta).Error; err != nil {
		common.ApiError(c, err)
		return
	}
	if len(modelsMeta) == 0 {
		common.ApiErrorMsg(c, "未找到可更新的模型")
		return
	}

	updated := 0
	for _, item := range modelsMeta {
		newTags := normalizedTags
		if req.Mode == "add" {
			// splitModelTagsCSV returns a fresh slice, so appending to it
			// cannot alias normalizedTags.
			existing := splitModelTagsCSV(item.Tags)
			newTags = normalizeModelTags(append(existing, normalizedTags...))
		}
		csv := strings.Join(newTags, ",")
		if err := tx.Model(&model.Model{}).Where("id = ?", item.Id).Update("tags", csv).Error; err != nil {
			common.ApiError(c, err)
			return
		}
		updated++
	}
	if err := tx.Commit().Error; err != nil {
		common.ApiError(c, err)
		return
	}
	committed = true

	// The model rows have changed at this point, so refresh pricing before
	// anything else can fail and skip it.
	model.RefreshPricing()

	if err := model.UpsertModelTags(normalizedTags); err != nil {
		common.ApiError(c, err)
		return
	}
	common.ApiSuccess(c, gin.H{
		"updated": updated,
	})
}
// SearchModelsMeta returns one page of models filtered by the "keyword" and
// "vendor" query parameters. Callers below common.RoleAdminUser are limited to
// the results of model.SearchSupplierModels for their own user ID.
func SearchModelsMeta(c *gin.Context) {
	keyword, vendor := c.Query("keyword"), c.Query("vendor")
	pageInfo := common.GetPageQuery(c)

	var (
		found []*model.Model
		count int64
		err   error
	)
	if isAdmin := c.GetInt("role") >= common.RoleAdminUser; !isAdmin {
		uid := c.GetInt("id")
		found, count, err = model.SearchSupplierModels(&uid, keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
	} else {
		found, count, err = model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
	}
	if err != nil {
		common.ApiError(c, err)
		return
	}

	// Fill the derived fields for the whole page in one batch.
	enrichModels(found)

	pageInfo.SetTotal(int(count))
	pageInfo.SetItems(found)
	common.ApiSuccess(c, pageInfo)
}
// GetModelMeta returns a single model looked up by the numeric "id" path
// parameter. The literal path value "tags" is forwarded to GetModelTags
// instead of being parsed as an ID.
func GetModelMeta(c *gin.Context) {
	rawID := c.Param("id")
	if rawID == "tags" {
		GetModelTags(c)
		return
	}

	modelID, convErr := strconv.Atoi(rawID)
	if convErr != nil {
		common.ApiError(c, convErr)
		return
	}

	record := new(model.Model)
	if dbErr := model.DB.First(record, modelID).Error; dbErr != nil {
		common.ApiError(c, dbErr)
		return
	}

	enrichModels([]*model.Model{record})
	common.ApiSuccess(c, record)
}
// CreateModelMeta creates a model from the JSON request body. The model name
// must be non-empty and must not already exist. On success the pricing data is
// refreshed, a create operation is recorded, and the new model is returned.
func CreateModelMeta(c *gin.Context) {
	var payload model.Model
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		common.ApiError(c, bindErr)
		return
	}
	if len(payload.ModelName) == 0 {
		common.ApiErrorMsg(c, "模型名称不能为空")
		return
	}

	// Reject the request when another model already uses this name.
	exists, dupErr := model.IsModelNameDuplicated(0, payload.ModelName)
	if dupErr != nil {
		common.ApiError(c, dupErr)
		return
	}
	if exists {
		common.ApiErrorMsg(c, "模型名称已存在")
		return
	}

	if insErr := payload.Insert(); insErr != nil {
		common.ApiError(c, insErr)
		return
	}

	model.RefreshPricing()
	service.RecordCreateOperation(c, "model", payload.Id, payload.ModelName, "创建模型: "+payload.ModelName, "")
	common.ApiSuccess(c, &payload)
}
// UpdateModelMeta updates the model identified by the "id" field of the JSON
// request body. Two query flags narrow what is written:
//
//   - docs_only=true   writes only doc_introduction and api_docs;
//   - status_only=true writes only status (checked after docs_only);
//   - otherwise the name is checked for duplicates and the whole model is
//     updated via its Update method.
//
// Afterwards the pricing data is refreshed, an update operation is recorded,
// and the bound request model is echoed back.
func UpdateModelMeta(c *gin.Context) {
	docsOnly := c.Query("docs_only") == "true"
	statusOnly := c.Query("status_only") == "true"

	var payload model.Model
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		common.ApiError(c, bindErr)
		return
	}
	if payload.Id == 0 {
		common.ApiErrorMsg(c, "缺少模型 ID")
		return
	}

	switch {
	case docsOnly:
		docFields := map[string]interface{}{
			"doc_introduction": payload.DocIntroduction,
			"api_docs":         payload.ApiDocs,
		}
		if err := model.DB.Model(&model.Model{}).Where("id = ?", payload.Id).Updates(docFields).Error; err != nil {
			common.ApiError(c, err)
			return
		}
	case statusOnly:
		// Touch only the status column so no other field is cleared by accident.
		if err := model.DB.Model(&model.Model{}).Where("id = ?", payload.Id).Update("status", payload.Status).Error; err != nil {
			common.ApiError(c, err)
			return
		}
	default:
		// Reject the request when a different model already uses this name.
		exists, dupErr := model.IsModelNameDuplicated(payload.Id, payload.ModelName)
		if dupErr != nil {
			common.ApiError(c, dupErr)
			return
		}
		if exists {
			common.ApiErrorMsg(c, "模型名称已存在")
			return
		}
		if updErr := payload.Update(); updErr != nil {
			common.ApiError(c, updErr)
			return
		}
	}

	model.RefreshPricing()
	service.RecordUpdateOperation(c, "model", payload.Id, payload.ModelName, "更新模型: "+payload.ModelName, "")
	common.ApiSuccess(c, &payload)
}
// DeleteModelMeta deletes the model whose numeric ID is given in the "id" path
// parameter, then refreshes the pricing data and records a delete operation.
func DeleteModelMeta(c *gin.Context) {
	modelID, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		common.ApiError(c, convErr)
		return
	}

	result := model.DB.Delete(&model.Model{}, modelID)
	if result.Error != nil {
		common.ApiError(c, result.Error)
		return
	}

	model.RefreshPricing()
	detail := fmt.Sprintf("删除模型 (ID: %d)", modelID)
	service.RecordDeleteOperation(c, "model", modelID, "", detail, "")
	common.ApiSuccess(c, nil)
}
// BatchUpdateModelWeightRequest is the JSON body accepted by
// BatchUpdateModelWeight.
type BatchUpdateModelWeightRequest struct {
// IDs are the primary keys of the models to update; must not be empty.
IDs []int `json:"ids"`
// SortWeight is written to the sort_weight column of every selected model.
SortWeight float64 `json:"sort_weight"`
// ManualBaseReqCount is written to the manual_base_req_count column of
// every selected model.
ManualBaseReqCount int64 `json:"manual_base_req_count"`
}
// BatchUpdateModelWeight writes the same sort weight and manual base request
// count to every model listed in the request, then refreshes the pricing data.
// The "updated" value in the response is the number of IDs in the request.
func BatchUpdateModelWeight(c *gin.Context) {
	var payload BatchUpdateModelWeightRequest
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		common.ApiError(c, bindErr)
		return
	}
	if len(payload.IDs) < 1 {
		common.ApiErrorMsg(c, "请选择至少一个模型")
		return
	}

	// A map (rather than a struct) is passed to Updates so that both columns
	// are written even when their values are zero.
	result := model.DB.Model(&model.Model{}).Where("id IN ?", payload.IDs).Updates(map[string]interface{}{
		"sort_weight":           payload.SortWeight,
		"manual_base_req_count": payload.ManualBaseReqCount,
	})
	if result.Error != nil {
		common.ApiError(c, result.Error)
		return
	}

	model.RefreshPricing()
	common.ApiSuccess(c, gin.H{"updated": len(payload.IDs)})
}
// enrichModels fills in the derived fields of every non-nil model in place:
// Endpoints (only when empty), BoundChannels, EnableGroups and QuotaTypes,
// plus MatchedModels / MatchedCount for rule-based models.
//
// Models whose NameRule is model.NameRuleExact are resolved by their own name.
// All other models are treated as prefix / suffix / contains rules and receive
// the union of the data of every pricing entry their rule matches. Bound
// channels are looked up with two batched calls in total instead of one call
// per model. Every union is emitted in sorted order so that repeated calls
// produce identical output.
func enrichModels(models []*model.Model) {
	if len(models) == 0 {
		return
	}

	// 1) Split the input into exact-name models and rule-based models.
	exactNames := make([]string, 0)
	exactIdx := make(map[string][]int) // model name -> indices in models
	ruleIndices := make([]int, 0)
	for i, m := range models {
		if m == nil {
			continue
		}
		if m.NameRule == model.NameRuleExact {
			exactNames = append(exactNames, m.ModelName)
			exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i)
		} else {
			ruleIndices = append(ruleIndices, i)
		}
	}

	// 2) One batched channel lookup for all exact models. Best effort: on
	// error the affected models are simply returned without bound channels.
	channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames)

	// 3) Exact models: endpoints, groups and quota types come from the model
	// package helpers, channels from the batched lookup above.
	for name, indices := range exactIdx {
		chs := channelsByModel[name]
		for _, idx := range indices {
			mm := models[idx]
			if mm.Endpoints == "" {
				eps := model.GetModelSupportEndpointTypes(mm.ModelName)
				if b, err := json.Marshal(eps); err == nil {
					mm.Endpoints = string(b)
				}
			}
			mm.BoundChannels = chs
			mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName)
			mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName)
		}
	}

	if len(ruleIndices) == 0 {
		return
	}

	// 4) Read the pricing list once and match every rule model against it in
	// memory, collecting per model: matched names, endpoint union, group union
	// and quota-type set.
	pricings := model.GetPricing()
	matchedNamesByIdx := make(map[int][]string)
	endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{})
	groupSetByIdx := make(map[int]map[string]struct{})
	quotaSetByIdx := make(map[int]map[int]struct{})
	for _, p := range pricings {
		for _, idx := range ruleIndices {
			mm := models[idx]
			var matched bool
			switch mm.NameRule {
			case model.NameRulePrefix:
				matched = strings.HasPrefix(p.ModelName, mm.ModelName)
			case model.NameRuleSuffix:
				matched = strings.HasSuffix(p.ModelName, mm.ModelName)
			case model.NameRuleContains:
				matched = strings.Contains(p.ModelName, mm.ModelName)
			}
			if !matched {
				continue
			}
			matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName)

			es := endpointSetByIdx[idx]
			if es == nil {
				es = make(map[constant.EndpointType]struct{})
				endpointSetByIdx[idx] = es
			}
			for _, et := range p.SupportedEndpointTypes {
				es[et] = struct{}{}
			}

			gs := groupSetByIdx[idx]
			if gs == nil {
				gs = make(map[string]struct{})
				groupSetByIdx[idx] = gs
			}
			for _, g := range p.EnableGroup {
				gs[g] = struct{}{}
			}

			qs := quotaSetByIdx[idx]
			if qs == nil {
				qs = make(map[int]struct{})
				quotaSetByIdx[idx] = qs
			}
			qs[p.QuotaType] = struct{}{}
		}
	}

	// 5) Gather every matched model name and look up their channels in one
	// batched call.
	allMatchedSet := make(map[string]struct{})
	for _, names := range matchedNamesByIdx {
		for _, n := range names {
			allMatchedSet[n] = struct{}{}
		}
	}
	allMatched := make([]string, 0, len(allMatchedSet))
	for n := range allMatchedSet {
		allMatched = append(allMatched, n)
	}
	// Best effort, same as the exact-model lookup above.
	matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched)

	// 6) Write the unions back to each rule model. The sets are Go maps, whose
	// iteration order is unspecified, so each one is sorted before it is
	// stored to keep the output stable between calls.
	for _, idx := range ruleIndices {
		mm := models[idx]

		// Endpoint union -> JSON, only when the model has no explicit value.
		if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" {
			eps := make([]constant.EndpointType, 0, len(es))
			for et := range es {
				eps = append(eps, et)
			}
			sort.Slice(eps, func(a, b int) bool { return eps[a] < eps[b] })
			if b, err := json.Marshal(eps); err == nil {
				mm.Endpoints = string(b)
			}
		}

		// Group union.
		if gs, ok := groupSetByIdx[idx]; ok {
			groups := make([]string, 0, len(gs))
			for g := range gs {
				groups = append(groups, g)
			}
			sort.Strings(groups)
			mm.EnableGroups = groups
		}

		// Quota types, de-duplicated and sorted.
		if qs, ok := quotaSetByIdx[idx]; ok {
			arr := make([]int, 0, len(qs))
			for k := range qs {
				arr = append(arr, k)
			}
			sort.Ints(arr)
			mm.QuotaTypes = arr
		}

		// Channel union, de-duplicated by "<name>_<type>" and ordered by that key.
		names := matchedNamesByIdx[idx]
		channelSet := make(map[string]model.BoundChannel)
		for _, n := range names {
			for _, ch := range matchedChannelsByModel[n] {
				key := ch.Name + "_" + strconv.Itoa(ch.Type)
				channelSet[key] = ch
			}
		}
		if len(channelSet) > 0 {
			keys := make([]string, 0, len(channelSet))
			for k := range channelSet {
				keys = append(keys, k)
			}
			sort.Strings(keys)
			chs := make([]model.BoundChannel, 0, len(keys))
			for _, k := range keys {
				chs = append(chs, channelSet[k])
			}
			mm.BoundChannels = chs
		}

		// Match details.
		mm.MatchedModels = names
		mm.MatchedCount = len(names)
	}
}