508 lines
16 KiB
Go
508 lines
16 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/QuantumNous/new-api/model"
|
||
)
|
||
|
||
// ────────────────────────────────────────────────────────────────────────────
|
||
// 官方预设缓存(TTL 15 分钟,避免每次 auto_meta 都重复抓取)
|
||
// ────────────────────────────────────────────────────────────────────────────
|
||
|
||
const (
|
||
officialModelsPresetURL = "https://basellm.github.io/llm-metadata/api/newapi/models.json"
|
||
presetCacheTTL = 15 * time.Minute
|
||
)
|
||
|
||
type officialModelEntry struct {
|
||
ModelName string `json:"model_name"`
|
||
Endpoints json.RawMessage `json:"endpoints"`
|
||
Tags string `json:"tags"`
|
||
VendorName string `json:"vendor_name"`
|
||
Description string `json:"description"`
|
||
Icon string `json:"icon"`
|
||
NameRule int `json:"name_rule"`
|
||
Status int `json:"status"`
|
||
}
|
||
|
||
type officialPresetEnvelope struct {
|
||
Success bool `json:"success"`
|
||
Data []officialModelEntry `json:"data"`
|
||
}
|
||
|
||
var (
|
||
presetMu sync.RWMutex
|
||
presetByName map[string]officialModelEntry
|
||
presetFetchAt time.Time
|
||
)
|
||
|
||
// fetchOfficialPreset 获取官方模型预设(带本地缓存)。
|
||
// 缓存未过期时直接返回内存副本;过期或首次调用时请求远端。
|
||
func fetchOfficialPreset(ctx context.Context) map[string]officialModelEntry {
|
||
presetMu.RLock()
|
||
if presetByName != nil && time.Since(presetFetchAt) < presetCacheTTL {
|
||
m := presetByName
|
||
presetMu.RUnlock()
|
||
return m
|
||
}
|
||
presetMu.RUnlock()
|
||
|
||
// 升级为写锁后二次检查,防止并发重复抓取
|
||
presetMu.Lock()
|
||
defer presetMu.Unlock()
|
||
if presetByName != nil && time.Since(presetFetchAt) < presetCacheTTL {
|
||
return presetByName
|
||
}
|
||
|
||
client := &http.Client{Timeout: 10 * time.Second}
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, officialModelsPresetURL, nil)
|
||
if err != nil {
|
||
return presetByName // 失败时沿用旧缓存
|
||
}
|
||
resp, err := client.Do(req)
|
||
if err != nil || resp.StatusCode != http.StatusOK {
|
||
return presetByName
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
body, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
|
||
if err != nil {
|
||
return presetByName
|
||
}
|
||
|
||
// 兼容两种格式:envelope{ success, data:[] } 或 直接 []
|
||
var env officialPresetEnvelope
|
||
if err := json.Unmarshal(body, &env); err == nil && len(env.Data) > 0 {
|
||
m := make(map[string]officialModelEntry, len(env.Data))
|
||
for _, e := range env.Data {
|
||
if e.ModelName != "" {
|
||
m[e.ModelName] = e
|
||
}
|
||
}
|
||
presetByName = m
|
||
presetFetchAt = time.Now()
|
||
return presetByName
|
||
}
|
||
// 尝试直接解析为数组
|
||
var arr []officialModelEntry
|
||
if err := json.Unmarshal(body, &arr); err == nil {
|
||
m := make(map[string]officialModelEntry, len(arr))
|
||
for _, e := range arr {
|
||
if e.ModelName != "" {
|
||
m[e.ModelName] = e
|
||
}
|
||
}
|
||
presetByName = m
|
||
presetFetchAt = time.Now()
|
||
}
|
||
return presetByName
|
||
}
|
||
|
||
// ────────────────────────────────────────────────────────────────────────────
|
||
// 模型名称规则推断
|
||
// ────────────────────────────────────────────────────────────────────────────
|
||
|
||
// inferEndpoints 根据模型名推断 Endpoints JSON 字符串(如 `["openai"]`)。
|
||
// 推断顺序:Embedding → Rerank → Image → Video → Chat(默认)
|
||
func inferEndpoints(name string) string {
|
||
lower := strings.ToLower(name)
|
||
|
||
switch {
|
||
// Embedding
|
||
case strings.Contains(lower, "embed"),
|
||
strings.HasPrefix(lower, "bge-"),
|
||
strings.HasPrefix(lower, "m3e-"),
|
||
strings.Contains(lower, "jina-embed"):
|
||
return `["embeddings"]`
|
||
|
||
// Rerank
|
||
case strings.Contains(lower, "rerank"),
|
||
strings.Contains(lower, "jina-rerank"):
|
||
return `["jina-rerank"]`
|
||
|
||
// Image generation
|
||
case strings.Contains(lower, "dall-e"),
|
||
strings.Contains(lower, "sdxl"),
|
||
strings.Contains(lower, "stable-diffusion"),
|
||
strings.Contains(lower, "wanx"),
|
||
strings.Contains(lower, "kolors"),
|
||
strings.Contains(lower, "cogview"),
|
||
strings.Contains(lower, "hunyuan-dit"),
|
||
strings.Contains(lower, "flux"),
|
||
matchesPattern(lower, []string{"image-alpha", "imagen-", "text-to-image"}):
|
||
return `["image-generation"]`
|
||
|
||
// Video generation
|
||
case strings.Contains(lower, "video-generation"),
|
||
strings.Contains(lower, "kling"),
|
||
strings.Contains(lower, "vidu"),
|
||
matchesPattern(lower, []string{"video-01", "video-02"}):
|
||
return `["openai-video"]`
|
||
|
||
// 默认:Chat (openai-compatible)
|
||
default:
|
||
return `["openai"]`
|
||
}
|
||
}
|
||
|
||
// inferTags 根据模型名推断标签(逗号分隔字符串)。
|
||
func inferTags(name string) string {
|
||
lower := strings.ToLower(name)
|
||
var tags []string
|
||
seen := make(map[string]bool)
|
||
|
||
add := func(t string) {
|
||
if !seen[t] {
|
||
seen[t] = true
|
||
tags = append(tags, t)
|
||
}
|
||
}
|
||
|
||
// 视觉/多模态
|
||
if strings.Contains(lower, "vision") ||
|
||
strings.Contains(lower, "-vl") ||
|
||
strings.Contains(lower, "omni") ||
|
||
strings.Contains(lower, "visual") {
|
||
add("vision")
|
||
}
|
||
// 推理增强
|
||
if strings.Contains(lower, "thinking") ||
|
||
strings.Contains(lower, "reasoner") ||
|
||
strings.Contains(lower, "-r1") ||
|
||
strings.Contains(lower, "-r2") ||
|
||
strings.HasPrefix(lower, "o1") ||
|
||
strings.HasPrefix(lower, "o3") ||
|
||
strings.Contains(lower, "-think") ||
|
||
strings.Contains(lower, "qwq") {
|
||
add("reasoning")
|
||
}
|
||
// 代码
|
||
if strings.Contains(lower, "code") ||
|
||
strings.Contains(lower, "coder") ||
|
||
strings.Contains(lower, "codex") ||
|
||
strings.Contains(lower, "codestral") ||
|
||
strings.Contains(lower, "deepseek-coder") {
|
||
add("coding")
|
||
}
|
||
// Embedding
|
||
if strings.Contains(lower, "embed") ||
|
||
strings.HasPrefix(lower, "bge-") ||
|
||
strings.HasPrefix(lower, "m3e-") {
|
||
add("embedding")
|
||
}
|
||
// Rerank
|
||
if strings.Contains(lower, "rerank") {
|
||
add("rerank")
|
||
}
|
||
// Image
|
||
if strings.Contains(lower, "dall-e") ||
|
||
strings.Contains(lower, "sdxl") ||
|
||
strings.Contains(lower, "flux") ||
|
||
strings.Contains(lower, "image-generation") {
|
||
add("image")
|
||
}
|
||
// 音频
|
||
if strings.Contains(lower, "whisper") ||
|
||
strings.Contains(lower, "-asr") ||
|
||
strings.Contains(lower, "tts") {
|
||
add("audio")
|
||
}
|
||
// 轻量/经济型
|
||
if strings.Contains(lower, "mini") ||
|
||
strings.Contains(lower, "lite") ||
|
||
strings.Contains(lower, "tiny") ||
|
||
strings.Contains(lower, "nano") ||
|
||
strings.Contains(lower, "small") ||
|
||
strings.Contains(lower, "flash") ||
|
||
strings.Contains(lower, "haiku") {
|
||
add("budget")
|
||
}
|
||
|
||
return strings.Join(tags, ",")
|
||
}
|
||
|
||
// ────────────────────────────────────────────────────────────────────────────
|
||
// 标签过滤:移除不适合用户分类使用的标签
|
||
// ────────────────────────────────────────────────────────────────────────────
|
||
|
||
// validTagSet 定义允许作为模型分类标签的合法标签集合(小写)。
|
||
// 不在此集合中的标签将被过滤掉(如上下文窗口大小 "262.1K"、"128K" 等数值型标签)。
|
||
var validTagSet = map[string]bool{
|
||
// 能力分类
|
||
"reasoning": true,
|
||
"tools": true,
|
||
"files": true,
|
||
"vision": true,
|
||
"coding": true,
|
||
"code": true,
|
||
"embedding": true,
|
||
"rerank": true,
|
||
"image": true,
|
||
"audio": true,
|
||
"video": true,
|
||
"budget": true,
|
||
// 模型属性
|
||
"open weights": true,
|
||
"open source": true,
|
||
"proprietary": true,
|
||
"local": true,
|
||
"cloud": true,
|
||
"multilingual": true,
|
||
// 通用分类
|
||
"chat": true,
|
||
"completion": true,
|
||
"instruct": true,
|
||
"base": true,
|
||
"fine-tuned": true,
|
||
"lora": true,
|
||
}
|
||
|
||
// filterTags 过滤逗号分隔的标签字符串,只保留合法的分类标签。
|
||
// 用于清理官方预设中可能包含的上下文窗口大小(如 "262.1K"、"128K")等
|
||
// 不适合作为用户筛选分类的数值型标签。
|
||
func filterTags(tagsStr string) string {
|
||
if tagsStr == "" {
|
||
return ""
|
||
}
|
||
parts := strings.Split(tagsStr, ",")
|
||
var filtered []string
|
||
for _, p := range parts {
|
||
tag := strings.TrimSpace(p)
|
||
if tag == "" {
|
||
continue
|
||
}
|
||
// 精确匹配合法标签(不区分大小写)
|
||
if validTagSet[strings.ToLower(tag)] {
|
||
filtered = append(filtered, tag)
|
||
}
|
||
}
|
||
return strings.Join(filtered, ",")
|
||
}
|
||
|
||
// matchesPattern 检查 lower 是否包含 patterns 中的任意一个。
|
||
func matchesPattern(lower string, patterns []string) bool {
|
||
for _, p := range patterns {
|
||
if strings.Contains(lower, p) {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// ────────────────────────────────────────────────────────────────────────────
|
||
// 供应商推断(VendorID)
|
||
// ────────────────────────────────────────────────────────────────────────────
|
||
|
||
// vendorKeywordAliases 将"模型名关键词"映射到"供应商名关键词"(小写)。
|
||
// 匹配策略:先在模型名中搜索 key,命中后再用 values 匹配数据库中 Vendor.Name(小写子串)。
|
||
var vendorKeywordAliases = []struct {
|
||
modelKW string // 在模型名中搜索(小写)
|
||
vendorKWs []string // 在 Vendor.Name 中任意一个命中即可(小写)
|
||
}{
|
||
{"claude", []string{"anthropic"}},
|
||
{"gemini", []string{"google"}},
|
||
{"gpt", []string{"openai"}},
|
||
{"dall-e", []string{"openai"}},
|
||
{"whisper", []string{"openai"}},
|
||
{"o1-", []string{"openai"}},
|
||
{"o3-", []string{"openai"}},
|
||
{"o4-", []string{"openai"}},
|
||
{"llama", []string{"meta"}},
|
||
{"mistral", []string{"mistral"}},
|
||
{"mixtral", []string{"mistral"}},
|
||
{"codestral", []string{"mistral"}},
|
||
{"deepseek", []string{"deepseek"}},
|
||
{"qwen", []string{"alibaba", "qwen", "tongyi", "aliyun"}},
|
||
{"moonshot", []string{"moonshot"}},
|
||
{"kimi", []string{"moonshot"}},
|
||
{"doubao", []string{"bytedance", "volcengine", "volcano"}},
|
||
{"ernie", []string{"baidu"}},
|
||
{"wenxin", []string{"baidu"}},
|
||
{"hunyuan", []string{"tencent"}},
|
||
{"spark", []string{"xunfei", "iflytek"}},
|
||
{"glm", []string{"zhipu", "chatglm"}},
|
||
{"chatglm", []string{"zhipu"}},
|
||
{"yi-", []string{"lingyiwanwu", "01ai", "zero-one"}},
|
||
{"minimax", []string{"minimax"}},
|
||
{"abab", []string{"minimax"}},
|
||
{"flux", []string{"black forest", "blackforest"}},
|
||
{"stable-diffusion", []string{"stability"}},
|
||
{"sdxl", []string{"stability"}},
|
||
{"cohere", []string{"cohere"}},
|
||
{"command-r", []string{"cohere"}},
|
||
{"perplexity", []string{"perplexity"}},
|
||
{"jina", []string{"jina"}},
|
||
{"suno", []string{"suno"}},
|
||
{"kling", []string{"kling", "kuaishou"}},
|
||
{"vidu", []string{"vidu", "shengshu"}},
|
||
{"cogview", []string{"zhipu"}},
|
||
{"internlm", []string{"shanghaiai", "intern"}},
|
||
{"baichuan", []string{"baichuan"}},
|
||
{"xai", []string{"xai"}},
|
||
{"grok", []string{"xai"}},
|
||
}
|
||
|
||
// buildVendorIndex 一次性从 DB 中加载所有 Vendor,构建 name.lower → id 的映射。
|
||
func buildVendorIndex() map[string]int {
|
||
vendors, err := model.GetAllVendors(0, 2000)
|
||
if err != nil || len(vendors) == 0 {
|
||
return nil
|
||
}
|
||
idx := make(map[string]int, len(vendors))
|
||
for _, v := range vendors {
|
||
idx[strings.ToLower(v.Name)] = v.Id
|
||
}
|
||
return idx
|
||
}
|
||
|
||
// inferVendorID 根据模型名在 vendorIdx 中查找最可能的供应商 ID,找不到返回 0。
|
||
func inferVendorID(modelName string, vendorIdx map[string]int) int {
|
||
if len(vendorIdx) == 0 {
|
||
return 0
|
||
}
|
||
lower := strings.ToLower(modelName)
|
||
|
||
for _, rule := range vendorKeywordAliases {
|
||
if !strings.Contains(lower, rule.modelKW) {
|
||
continue
|
||
}
|
||
// 模型名匹配到关键词 → 在 vendorIdx 中搜索供应商名关键词
|
||
for vendorNameLower, id := range vendorIdx {
|
||
for _, vkw := range rule.vendorKWs {
|
||
if strings.Contains(vendorNameLower, vkw) {
|
||
return id
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 兜底:尝试用 vendorIdx 中的供应商名直接匹配模型名(如模型名直接含供应商名)
|
||
for vendorNameLower, id := range vendorIdx {
|
||
if len(vendorNameLower) >= 4 && strings.Contains(lower, vendorNameLower) {
|
||
return id
|
||
}
|
||
}
|
||
return 0
|
||
}
|
||
|
||
// ────────────────────────────────────────────────────────────────────────────
|
||
// 对外接口:AutoCreateMissingModelMeta
|
||
// ────────────────────────────────────────────────────────────────────────────
|
||
|
||
// AutoMetaItem 单个模型的自动推断结果。
|
||
type AutoMetaItem struct {
|
||
ModelName string `json:"model_name"`
|
||
// "official":来自官方预设;"inferred":名称规则推断;"exists":已有记录跳过
|
||
Source string `json:"source"`
|
||
Endpoints string `json:"endpoints"`
|
||
Tags string `json:"tags"`
|
||
VendorID int `json:"vendor_id,omitempty"`
|
||
Err string `json:"err,omitempty"`
|
||
}
|
||
|
||
// AutoCreateMissingModelMeta 对给定模型名列表,为缺少 model_meta 记录的模型
|
||
// 自动推断并创建元数据(先查官方预设,再用名称规则兜底)。
|
||
// 返回每个模型的处理结果。
|
||
func AutoCreateMissingModelMeta(ctx context.Context, modelNames []string) []AutoMetaItem {
|
||
if len(modelNames) == 0 {
|
||
return nil
|
||
}
|
||
|
||
// 1. 找出已存在的模型名(跳过)
|
||
existingNames, _ := model.GetExistingModelNames(modelNames)
|
||
existingSet := make(map[string]bool, len(existingNames))
|
||
for _, n := range existingNames {
|
||
existingSet[n] = true
|
||
}
|
||
|
||
// 2. 拉取官方预设(带缓存)
|
||
preset := fetchOfficialPreset(ctx)
|
||
|
||
// 3. 构建供应商索引(vendor name lower → id),用于 VendorID 推断
|
||
vendorIdx := buildVendorIndex()
|
||
|
||
results := make([]AutoMetaItem, 0, len(modelNames))
|
||
|
||
for _, name := range modelNames {
|
||
// 已存在:跳过
|
||
if existingSet[name] {
|
||
results = append(results, AutoMetaItem{
|
||
ModelName: name,
|
||
Source: "exists",
|
||
})
|
||
continue
|
||
}
|
||
|
||
item := AutoMetaItem{ModelName: name}
|
||
|
||
// 3a. 优先:官方预设精确匹配
|
||
if entry, ok := preset[name]; ok {
|
||
item.Source = "official"
|
||
if len(entry.Endpoints) > 0 && string(entry.Endpoints) != "null" {
|
||
item.Endpoints = string(entry.Endpoints)
|
||
} else {
|
||
item.Endpoints = inferEndpoints(name)
|
||
}
|
||
item.Tags = filterTags(entry.Tags)
|
||
if item.Tags == "" {
|
||
item.Tags = inferTags(name)
|
||
}
|
||
|
||
vendorID := inferVendorID(name, vendorIdx)
|
||
item.VendorID = vendorID
|
||
mi := &model.Model{
|
||
ModelName: name,
|
||
Description: entry.Description,
|
||
Icon: entry.Icon,
|
||
Tags: item.Tags,
|
||
Endpoints: item.Endpoints,
|
||
VendorID: vendorID,
|
||
Status: chooseModelStatus(entry.Status),
|
||
NameRule: entry.NameRule,
|
||
SyncOfficial: 1,
|
||
}
|
||
if err := mi.Insert(); err != nil {
|
||
item.Err = fmt.Sprintf("DB error: %v", err)
|
||
}
|
||
} else {
|
||
// 3b. 兜底:名称规则推断
|
||
item.Source = "inferred"
|
||
item.Endpoints = inferEndpoints(name)
|
||
item.Tags = inferTags(name)
|
||
|
||
vendorID := inferVendorID(name, vendorIdx)
|
||
item.VendorID = vendorID
|
||
mi := &model.Model{
|
||
ModelName: name,
|
||
Tags: item.Tags,
|
||
Endpoints: item.Endpoints,
|
||
VendorID: vendorID,
|
||
Status: 1,
|
||
SyncOfficial: 1,
|
||
}
|
||
if err := mi.Insert(); err != nil {
|
||
item.Err = fmt.Sprintf("DB error: %v", err)
|
||
}
|
||
}
|
||
|
||
results = append(results, item)
|
||
}
|
||
|
||
return results
|
||
}
|
||
|
||
func chooseModelStatus(upstreamStatus int) int {
|
||
if upstreamStatus == 0 {
|
||
return 1
|
||
}
|
||
return upstreamStatus
|
||
}
|