package controller

import (
	"encoding/json"
	"fmt"
	"sort"
	"strconv"
	"strings"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/model"
	"github.com/QuantumNous/new-api/service"
	"github.com/gin-gonic/gin"
)

// defaultModelTags lists the built-in tags that GetModelTags always places
// at the front of its result.
var defaultModelTags = []string{"文本", "视频", "图片"}

// GetAllModelsMeta returns one page of model metadata together with
// per-vendor model counts.
//
// Callers whose role is at least common.RoleAdminUser see every model; other
// callers only see the models returned by model.SearchSupplierModels for
// their own user ID.
func GetAllModelsMeta(c *gin.Context) {
	pageInfo := common.GetPageQuery(c)
	isAdmin := c.GetInt("role") >= common.RoleAdminUser

	var (
		modelsMeta []*model.Model
		total      int64
		err        error
	)
	if isAdmin {
		modelsMeta, err = model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
		if err != nil {
			common.ApiError(c, err)
			return
		}
		// A failed count must not be reported as "zero models".
		if err = model.DB.Model(&model.Model{}).Count(&total).Error; err != nil {
			common.ApiError(c, err)
			return
		}
	} else {
		ownerUserID := c.GetInt("id")
		modelsMeta, total, err = model.SearchSupplierModels(&ownerUserID, "", "", pageInfo.GetStartIdx(), pageInfo.GetPageSize())
		if err != nil {
			common.ApiError(c, err)
			return
		}
	}

	// Fill the derived fields for the whole page in one batch.
	enrichModels(modelsMeta)

	var vendorCounts map[int64]int64
	if isAdmin {
		// Vendor counts cover all models and are not affected by paging.
		// Best effort: on error the field is simply returned empty.
		vendorCounts, _ = model.GetVendorModelCounts()
	} else {
		// NOTE(review): for non-admin callers the counts are derived from the
		// current page only, so they are affected by paging — confirm this is
		// the intended behavior.
		vendorCounts = make(map[int64]int64)
		for _, item := range modelsMeta {
			vendorCounts[int64(item.VendorID)]++
		}
	}

	pageInfo.SetTotal(int(total))
	pageInfo.SetItems(modelsMeta)
	common.ApiSuccess(c, gin.H{
		"items":         modelsMeta,
		"total":         total,
		"page":          pageInfo.GetPage(),
		"page_size":     pageInfo.GetPageSize(),
		"vendor_counts": vendorCounts,
	})
}

// normalizeModelTags trims every tag, drops empty ones and removes
// duplicates while keeping the order of first appearance. The returned slice
// is never nil.
func normalizeModelTags(tags []string) []string {
	result := make([]string, 0, len(tags))
	seen := make(map[string]struct{}, len(tags))
	for _, tag := range tags {
		name := strings.TrimSpace(tag)
		if name == "" {
			continue
		}
		if _, ok := seen[name]; ok {
			continue
		}
		seen[name] = struct{}{}
		result = append(result, name)
	}
	return result
}

// splitModelTagsCSV parses a comma-separated tag string into a normalized
// tag list (see normalizeModelTags). It returns nil when csv is blank.
func splitModelTagsCSV(csv string) []string {
	if strings.TrimSpace(csv) == "" {
		return nil
	}
	return normalizeModelTags(strings.Split(csv, ","))
}

// GetModelTags returns the merged, de-duplicated list of known tags: the
// built-in defaults first, then the names from model.GetAllModelTagNames,
// then every tag found in the non-empty "tags" column of the model table.
// The merged list is also written back through model.UpsertModelTags.
func GetModelTags(c *gin.Context) {
	candidates := make([]string, 0, 32)
	candidates = append(candidates, defaultModelTags...)

	dbTags, err := model.GetAllModelTagNames()
	if err != nil {
		common.ApiError(c, err)
		return
	}
	candidates = append(candidates, dbTags...)

	var allTagCSVs []string
	if err := model.DB.Model(&model.Model{}).Where("tags <> ?", "").Pluck("tags", &allTagCSVs).Error; err != nil {
		common.ApiError(c, err)
		return
	}
	for _, csv := range allTagCSVs {
		candidates = append(candidates, splitModelTagsCSV(csv)...)
	}

	// One normalization pass trims, drops blanks and de-duplicates in
	// first-seen order.
	merged := normalizeModelTags(candidates)
	if err := model.UpsertModelTags(merged); err != nil {
		common.ApiError(c, err)
		return
	}
	common.ApiSuccess(c, merged)
}

// BatchSetModelTagsRequest is the JSON body accepted by BatchSetModelTags.
type BatchSetModelTagsRequest struct {
	IDs  []int    `json:"ids"`
	Tags []string `json:"tags"`
	Mode string   `json:"mode"` // add | replace
}

// BatchSetModelTags sets the tags of several models at once.
//
// In "replace" mode every selected model gets exactly the requested tags; in
// "add" mode the requested tags are appended to each model's existing tags.
// The response reports how many models were updated.
func BatchSetModelTags(c *gin.Context) {
	var req BatchSetModelTagsRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		common.ApiError(c, err)
		return
	}
	if len(req.IDs) == 0 {
		common.ApiErrorMsg(c, "请选择至少一个模型")
		return
	}
	// Tags are stored comma-joined, so a tag that itself contains a comma
	// would later be read back as several tags. Joining and re-splitting
	// breaks such input apart up front, keeping the stored CSV and the tag
	// table consistent.
	normalizedTags := splitModelTagsCSV(strings.Join(req.Tags, ","))
	if len(normalizedTags) == 0 {
		common.ApiErrorMsg(c, "请至少填写一个标签")
		return
	}
	if req.Mode != "add" && req.Mode != "replace" {
		common.ApiErrorMsg(c, "标签设置模式无效")
		return
	}

	var modelsMeta []*model.Model
	if err := model.DB.Where("id IN ?", req.IDs).Find(&modelsMeta).Error; err != nil {
		common.ApiError(c, err)
		return
	}
	if len(modelsMeta) == 0 {
		common.ApiErrorMsg(c, "未找到可更新的模型")
		return
	}

	// Apply all row updates in one transaction so a failure half-way through
	// does not leave only some of the selected models updated.
	tx := model.DB.Begin()
	if tx.Error != nil {
		common.ApiError(c, tx.Error)
		return
	}
	updated := 0
	for _, item := range modelsMeta {
		newTags := normalizedTags
		if req.Mode == "add" {
			existing := splitModelTagsCSV(item.Tags)
			newTags = normalizeModelTags(append(existing, normalizedTags...))
		}
		csv := strings.Join(newTags, ",")
		if err := tx.Model(&model.Model{}).Where("id = ?", item.Id).Update("tags", csv).Error; err != nil {
			tx.Rollback()
			common.ApiError(c, err)
			return
		}
		updated++
	}
	if err := tx.Commit().Error; err != nil {
		common.ApiError(c, err)
		return
	}

	upsertErr := model.UpsertModelTags(normalizedTags)
	// The model rows are already committed at this point, so refresh the
	// pricing data even when registering the tag names failed.
	model.RefreshPricing()
	if upsertErr != nil {
		common.ApiError(c, upsertErr)
		return
	}
	common.ApiSuccess(c, gin.H{
		"updated": updated,
	})
}

// SearchModelsMeta returns one page of models filtered by the "keyword" and
// "vendor" query parameters. Callers below common.RoleAdminUser are limited
// to the result of model.SearchSupplierModels for their own user ID.
func SearchModelsMeta(c *gin.Context) {
	keyword := c.Query("keyword")
	vendor := c.Query("vendor")
	pageInfo := common.GetPageQuery(c)

	var (
		modelsMeta []*model.Model
		total      int64
		err        error
	)
	if c.GetInt("role") >= common.RoleAdminUser {
		modelsMeta, total, err = model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
	} else {
		ownerUserID := c.GetInt("id")
		modelsMeta, total, err = model.SearchSupplierModels(&ownerUserID, keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
	}
	if err != nil {
		common.ApiError(c, err)
		return
	}

	// Fill the derived fields for the whole page in one batch.
	enrichModels(modelsMeta)
	pageInfo.SetTotal(int(total))
	pageInfo.SetItems(modelsMeta)
	common.ApiSuccess(c, pageInfo)
}

// GetModelMeta returns a single model looked up by the numeric "id" path
// parameter. The literal value "tags" is dispatched to GetModelTags instead.
func GetModelMeta(c *gin.Context) {
	idStr := c.Param("id")
	if idStr == "tags" {
		GetModelTags(c)
		return
	}
	id, err := strconv.Atoi(idStr)
	if err != nil {
		common.ApiError(c, err)
		return
	}
	var m model.Model
	if err := model.DB.First(&m, id).Error; err != nil {
		common.ApiError(c, err)
		return
	}
	enrichModels([]*model.Model{&m})
	common.ApiSuccess(c, &m)
}

// CreateModelMeta creates a new model from the JSON request body. The model
// name must be non-empty and must not already exist.
func CreateModelMeta(c *gin.Context) {
	var m model.Model
	if err := c.ShouldBindJSON(&m); err != nil {
		common.ApiError(c, err)
		return
	}
	if m.ModelName == "" {
		common.ApiErrorMsg(c, "模型名称不能为空")
		return
	}

	// Reject names that are already taken.
	dup, err := model.IsModelNameDuplicated(0, m.ModelName)
	if err != nil {
		common.ApiError(c, err)
		return
	}
	if dup {
		common.ApiErrorMsg(c, "模型名称已存在")
		return
	}

	if err := m.Insert(); err != nil {
		common.ApiError(c, err)
		return
	}
	model.RefreshPricing()
	service.RecordCreateOperation(c, "model", m.Id, m.ModelName, "创建模型: "+m.ModelName, "")
	common.ApiSuccess(c, &m)
}

// UpdateModelMeta updates the model identified by the "id" field of the JSON
// body.
//
// Query parameters select a partial update: docs_only=true writes only
// doc_introduction and api_docs, status_only=true writes only status.
// docs_only takes precedence when both are set. Without either flag the full
// record is updated after validating the model name.
func UpdateModelMeta(c *gin.Context) {
	statusOnly := c.Query("status_only") == "true"
	docsOnly := c.Query("docs_only") == "true"

	var m model.Model
	if err := c.ShouldBindJSON(&m); err != nil {
		common.ApiError(c, err)
		return
	}
	if m.Id == 0 {
		common.ApiErrorMsg(c, "缺少模型 ID")
		return
	}

	switch {
	case docsOnly:
		if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Updates(map[string]interface{}{
			"doc_introduction": m.DocIntroduction,
			"api_docs":         m.ApiDocs,
		}).Error; err != nil {
			common.ApiError(c, err)
			return
		}
	case statusOnly:
		// Only touch the status column so other fields are not cleared by
		// accident.
		if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil {
			common.ApiError(c, err)
			return
		}
	default:
		// Same rule as CreateModelMeta: a full update must not blank the name.
		if m.ModelName == "" {
			common.ApiErrorMsg(c, "模型名称不能为空")
			return
		}
		// Reject names that are already used by another model.
		dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName)
		if err != nil {
			common.ApiError(c, err)
			return
		}
		if dup {
			common.ApiErrorMsg(c, "模型名称已存在")
			return
		}
		if err := m.Update(); err != nil {
			common.ApiError(c, err)
			return
		}
	}

	model.RefreshPricing()
	service.RecordUpdateOperation(c, "model", m.Id, m.ModelName, "更新模型: "+m.ModelName, "")
	common.ApiSuccess(c, &m)
}

// DeleteModelMeta deletes the model identified by the numeric "id" path
// parameter.
func DeleteModelMeta(c *gin.Context) {
	idStr := c.Param("id")
	id, err := strconv.Atoi(idStr)
	if err != nil {
		common.ApiError(c, err)
		return
	}
	if err := model.DB.Delete(&model.Model{}, id).Error; err != nil {
		common.ApiError(c, err)
		return
	}
	model.RefreshPricing()
	service.RecordDeleteOperation(c, "model", id, "", fmt.Sprintf("删除模型 (ID: %d)", id), "")
	common.ApiSuccess(c, nil)
}

// BatchUpdateModelWeightRequest is the JSON body accepted by
// BatchUpdateModelWeight.
type BatchUpdateModelWeightRequest struct {
	IDs                []int   `json:"ids"`
	SortWeight         float64 `json:"sort_weight"`
	ManualBaseReqCount int64   `json:"manual_base_req_count"`
}

// BatchUpdateModelWeight writes the same sort weight and manual base request
// count to every model listed in the request. The "updated" field of the
// response is the number of requested IDs.
func BatchUpdateModelWeight(c *gin.Context) {
	var req BatchUpdateModelWeightRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		common.ApiError(c, err)
		return
	}
	if len(req.IDs) == 0 {
		common.ApiErrorMsg(c, "请选择至少一个模型")
		return
	}

	// A map is used so both columns are written together in one statement.
	updates := map[string]interface{}{
		"sort_weight":           req.SortWeight,
		"manual_base_req_count": req.ManualBaseReqCount,
	}
	if err := model.DB.Model(&model.Model{}).Where("id IN ?", req.IDs).Updates(updates).Error; err != nil {
		common.ApiError(c, err)
		return
	}
	model.RefreshPricing()
	common.ApiSuccess(c, gin.H{
		"updated": len(req.IDs),
	})
}

// enrichModels fills the derived fields of every model in place — endpoints,
// bound channels, enabled groups and quota types — using batched lookups
// instead of one query per model.
//
// Models with the exact name rule are resolved by their own name. Every other
// model is matched against the names in model.GetPricing by prefix, suffix or
// substring and receives the union of the data of all matches, plus
// MatchedModels and MatchedCount. Nil entries are skipped. Slices built from
// sets are sorted so the output is stable between calls.
func enrichModels(models []*model.Model) {
	if len(models) == 0 {
		return
	}

	// 1) Split the models into exact-name and rule-based ones.
	exactNames := make([]string, 0, len(models))
	exactIdx := make(map[string][]int, len(models)) // model name -> indices in models
	ruleIndices := make([]int, 0)
	for i, m := range models {
		if m == nil {
			continue
		}
		if m.NameRule != model.NameRuleExact {
			ruleIndices = append(ruleIndices, i)
			continue
		}
		// Query each distinct name only once.
		if _, seen := exactIdx[m.ModelName]; !seen {
			exactNames = append(exactNames, m.ModelName)
		}
		exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i)
	}

	// 2) Look up the bound channels of all exact models in one call.
	// Best effort: on error the models are returned without channels.
	channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames)

	// 3) Exact models: endpoints (only when not already set), channels,
	// groups and quota types are all resolved by the model's own name.
	for _, name := range exactNames {
		chs := channelsByModel[name]
		for _, idx := range exactIdx[name] {
			mm := models[idx]
			if mm.Endpoints == "" {
				eps := model.GetModelSupportEndpointTypes(mm.ModelName)
				if b, err := json.Marshal(eps); err == nil {
					mm.Endpoints = string(b)
				}
			}
			mm.BoundChannels = chs
			mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName)
			mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName)
		}
	}

	if len(ruleIndices) == 0 {
		return
	}

	// 4) Read the pricing list once and match every rule model in memory,
	// collecting matched names and the endpoint/group/quota sets per model.
	pricings := model.GetPricing()
	matchedNamesByIdx := make(map[int][]string)
	endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{})
	groupSetByIdx := make(map[int]map[string]struct{})
	quotaSetByIdx := make(map[int]map[int]struct{})

	for _, p := range pricings {
		for _, idx := range ruleIndices {
			mm := models[idx]
			var matched bool
			switch mm.NameRule {
			case model.NameRulePrefix:
				matched = strings.HasPrefix(p.ModelName, mm.ModelName)
			case model.NameRuleSuffix:
				matched = strings.HasSuffix(p.ModelName, mm.ModelName)
			case model.NameRuleContains:
				matched = strings.Contains(p.ModelName, mm.ModelName)
			}
			if !matched {
				continue
			}
			matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName)

			es := endpointSetByIdx[idx]
			if es == nil {
				es = make(map[constant.EndpointType]struct{})
				endpointSetByIdx[idx] = es
			}
			for _, et := range p.SupportedEndpointTypes {
				es[et] = struct{}{}
			}

			gs := groupSetByIdx[idx]
			if gs == nil {
				gs = make(map[string]struct{})
				groupSetByIdx[idx] = gs
			}
			for _, g := range p.EnableGroup {
				gs[g] = struct{}{}
			}

			qs := quotaSetByIdx[idx]
			if qs == nil {
				qs = make(map[int]struct{})
				quotaSetByIdx[idx] = qs
			}
			qs[p.QuotaType] = struct{}{}
		}
	}

	// 5) Collect every matched name and look up their channels in one call.
	allMatchedSet := make(map[string]struct{})
	for _, names := range matchedNamesByIdx {
		for _, n := range names {
			allMatchedSet[n] = struct{}{}
		}
	}
	allMatched := make([]string, 0, len(allMatchedSet))
	for n := range allMatchedSet {
		allMatched = append(allMatched, n)
	}
	// Best effort: on error the rule models are returned without channels.
	matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched)

	// 6) Write the unions back to every rule model.
	for _, idx := range ruleIndices {
		mm := models[idx]

		// Endpoint union, serialized as JSON, only when not already set.
		if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" {
			eps := make([]constant.EndpointType, 0, len(es))
			for et := range es {
				eps = append(eps, et)
			}
			// The underlying type of EndpointType is not visible in this
			// file, so order by the formatted value instead of using "<".
			sort.Slice(eps, func(i, j int) bool {
				return fmt.Sprint(eps[i]) < fmt.Sprint(eps[j])
			})
			if b, err := json.Marshal(eps); err == nil {
				mm.Endpoints = string(b)
			}
		}

		// Group union.
		if gs, ok := groupSetByIdx[idx]; ok {
			groups := make([]string, 0, len(gs))
			for g := range gs {
				groups = append(groups, g)
			}
			sort.Strings(groups)
			mm.EnableGroups = groups
		}

		// Quota type set, de-duplicated and sorted.
		if qs, ok := quotaSetByIdx[idx]; ok {
			arr := make([]int, 0, len(qs))
			for k := range qs {
				arr = append(arr, k)
			}
			sort.Ints(arr)
			mm.QuotaTypes = arr
		}

		// Channel union, de-duplicated by channel name and type.
		names := matchedNamesByIdx[idx]
		channelSet := make(map[string]model.BoundChannel)
		for _, n := range names {
			for _, ch := range matchedChannelsByModel[n] {
				key := ch.Name + "_" + strconv.Itoa(ch.Type)
				channelSet[key] = ch
			}
		}
		if len(channelSet) > 0 {
			keys := make([]string, 0, len(channelSet))
			for key := range channelSet {
				keys = append(keys, key)
			}
			sort.Strings(keys)
			chs := make([]model.BoundChannel, 0, len(keys))
			for _, key := range keys {
				chs = append(chs, channelSet[key])
			}
			mm.BoundChannels = chs
		}

		// Match information.
		mm.MatchedModels = names
		mm.MatchedCount = len(names)
	}
}