tokenFactory/service/model_meta_infer.go

508 lines
16 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/QuantumNous/new-api/model"
)
// ────────────────────────────────────────────────────────────────────────────
// Official preset cache (TTL 15 minutes, so that auto_meta does not re-fetch the preset on every call)
// ────────────────────────────────────────────────────────────────────────────
const (
// officialModelsPresetURL is the remote JSON document that fetchOfficialPreset
// downloads to obtain the official model metadata.
officialModelsPresetURL = "https://basellm.github.io/llm-metadata/api/newapi/models.json"
// presetCacheTTL is how long a successfully fetched preset is served from
// memory before the next call tries to refresh it.
presetCacheTTL = 15 * time.Minute
)
// officialModelEntry is one model record decoded from the official preset JSON.
type officialModelEntry struct {
// ModelName is the key the preset is indexed by; entries with an empty name are dropped.
ModelName string `json:"model_name"`
// Endpoints is kept as raw JSON so it can be copied verbatim into the model record.
Endpoints json.RawMessage `json:"endpoints"`
// Tags is a comma-separated tag list; it is passed through filterTags before use.
Tags string `json:"tags"`
VendorName string `json:"vendor_name"`
Description string `json:"description"`
Icon string `json:"icon"`
NameRule int `json:"name_rule"`
// Status is the upstream status; 0 is replaced by 1 via chooseModelStatus when a record is created.
Status int `json:"status"`
}
// officialPresetEnvelope is the wrapped response shape `{ "success": ..., "data": [...] }`.
// fetchOfficialPreset also accepts a bare JSON array of entries when this shape
// yields no data.
type officialPresetEnvelope struct {
Success bool `json:"success"`
Data []officialModelEntry `json:"data"`
}
// In-memory cache of the official preset, shared by all callers of fetchOfficialPreset.
var (
// presetMu guards presetByName and presetFetchAt.
presetMu sync.RWMutex
// presetByName maps model name to its preset entry; nil until the first successful fetch.
presetByName map[string]officialModelEntry
// presetFetchAt is the time of the last successful refresh of presetByName.
presetFetchAt time.Time
)
// fetchOfficialPreset returns the official model preset indexed by model name.
//
// While the cached copy is younger than presetCacheTTL it is returned directly.
// Otherwise the preset is downloaded from officialModelsPresetURL (10 s client
// timeout, response capped at 8 MiB) and the cache is replaced.
//
// The write lock is held for the duration of the download, so concurrent
// callers wait for a single fetch instead of issuing their own. On any failure
// (request error, non-200 status, read error, undecodable body) the previous
// cache — possibly nil — is returned and the fetch time is left untouched, so
// the next call retries.
func fetchOfficialPreset(ctx context.Context) map[string]officialModelEntry {
	// Fast path: fresh cache under the read lock.
	presetMu.RLock()
	if presetByName != nil && time.Since(presetFetchAt) < presetCacheTTL {
		cached := presetByName
		presetMu.RUnlock()
		return cached
	}
	presetMu.RUnlock()

	// Re-check under the write lock: another goroutine may have refreshed the
	// cache while we were waiting for it.
	presetMu.Lock()
	defer presetMu.Unlock()
	if presetByName != nil && time.Since(presetFetchAt) < presetCacheTTL {
		return presetByName
	}

	client := &http.Client{Timeout: 10 * time.Second}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, officialModelsPresetURL, nil)
	if err != nil {
		return presetByName // keep serving the stale cache on failure
	}
	resp, err := client.Do(req)
	if err != nil {
		return presetByName
	}
	// Register the close before inspecting the status so that non-200
	// responses do not leak the body (and its underlying connection).
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return presetByName
	}

	body, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
	if err != nil {
		return presetByName
	}

	// store indexes the entries by name (skipping nameless ones) and installs
	// the result as the new cache.
	store := func(entries []officialModelEntry) {
		byName := make(map[string]officialModelEntry, len(entries))
		for _, e := range entries {
			if e.ModelName != "" {
				byName[e.ModelName] = e
			}
		}
		presetByName = byName
		presetFetchAt = time.Now()
	}

	// Two payload shapes are accepted: the envelope { success, data: [...] }
	// or a bare array of entries.
	var env officialPresetEnvelope
	if err := json.Unmarshal(body, &env); err == nil && len(env.Data) > 0 {
		store(env.Data)
		return presetByName
	}
	var arr []officialModelEntry
	if err := json.Unmarshal(body, &arr); err == nil {
		store(arr)
	}
	return presetByName
}
// ────────────────────────────────────────────────────────────────────────────
// 模型名称规则推断
// ────────────────────────────────────────────────────────────────────────────
// inferEndpoints guesses the endpoint list for a model from its name and
// returns it as a JSON array string such as `["openai"]`.
//
// The name is lowercased and checked category by category; the first category
// that matches wins: embedding, rerank, image generation, video generation,
// and finally the chat default. Because embedding is checked first, a name
// starting with "bge-" or "m3e-" resolves to embeddings even when it also
// contains "rerank".
func inferEndpoints(name string) string {
	n := strings.ToLower(name)

	// Embedding models.
	if strings.HasPrefix(n, "bge-") || strings.HasPrefix(n, "m3e-") ||
		matchesPattern(n, []string{"embed", "jina-embed"}) {
		return `["embeddings"]`
	}

	// Rerank models.
	if matchesPattern(n, []string{"rerank", "jina-rerank"}) {
		return `["jina-rerank"]`
	}

	// Image generation models.
	imageKeywords := []string{
		"dall-e", "sdxl", "stable-diffusion", "wanx", "kolors", "cogview",
		"hunyuan-dit", "flux", "image-alpha", "imagen-", "text-to-image",
	}
	if matchesPattern(n, imageKeywords) {
		return `["image-generation"]`
	}

	// Video generation models.
	videoKeywords := []string{"video-generation", "kling", "vidu", "video-01", "video-02"}
	if matchesPattern(n, videoKeywords) {
		return `["openai-video"]`
	}

	// Everything else is treated as an OpenAI-compatible chat model.
	return `["openai"]`
}
// inferTags guesses classification tags for a model from its name and returns
// them as a comma-separated string ("" when nothing matches).
//
// Matching is done on the lowercased name. Tags are emitted in a fixed order:
// vision, reasoning, coding, embedding, rerank, image, audio, budget.
//
// The budget check looks for size words such as "mini" as plain substrings, so
// the family names "gemini" and "minimax" are masked out first; otherwise every
// model of those families would be tagged as budget regardless of its size.
func inferTags(name string) string {
	lower := strings.ToLower(name)

	// hasAny reports whether s contains at least one of subs.
	hasAny := func(s string, subs ...string) bool {
		for _, sub := range subs {
			if strings.Contains(s, sub) {
				return true
			}
		}
		return false
	}

	var tags []string

	// Vision / multimodal.
	if hasAny(lower, "vision", "-vl", "omni", "visual") {
		tags = append(tags, "vision")
	}
	// Reasoning-enhanced.
	if hasAny(lower, "thinking", "reasoner", "-r1", "-r2", "-think", "qwq") ||
		strings.HasPrefix(lower, "o1") ||
		strings.HasPrefix(lower, "o3") {
		tags = append(tags, "reasoning")
	}
	// Coding. "code" also covers coder, codex, codestral and deepseek-coder.
	if hasAny(lower, "code") {
		tags = append(tags, "coding")
	}
	// Embedding.
	if hasAny(lower, "embed") ||
		strings.HasPrefix(lower, "bge-") ||
		strings.HasPrefix(lower, "m3e-") {
		tags = append(tags, "embedding")
	}
	// Rerank.
	if hasAny(lower, "rerank") {
		tags = append(tags, "rerank")
	}
	// Image generation.
	if hasAny(lower, "dall-e", "sdxl", "flux", "image-generation") {
		tags = append(tags, "image")
	}
	// Audio.
	if hasAny(lower, "whisper", "-asr", "tts") {
		tags = append(tags, "audio")
	}
	// Lightweight / economy. Mask family names that merely contain "mini";
	// a space is substituted so the neighbouring text cannot fuse into a new
	// accidental match.
	sizeProbe := strings.ReplaceAll(lower, "gemini", " ")
	sizeProbe = strings.ReplaceAll(sizeProbe, "minimax", " ")
	if hasAny(sizeProbe, "mini", "lite", "tiny", "nano", "small", "flash", "haiku") {
		tags = append(tags, "budget")
	}

	return strings.Join(tags, ",")
}
// ────────────────────────────────────────────────────────────────────────────
// 标签过滤:移除不适合用户分类使用的标签
// ────────────────────────────────────────────────────────────────────────────
// validTagSet is the allow-list of tags (lowercase) that may be used as model
// classification tags. Anything not listed here — for example context-window
// sizes such as "262.1K" or "128K" — is discarded by filterTags.
var validTagSet = map[string]bool{
	// Capabilities.
	"reasoning": true,
	"tools":     true,
	"files":     true,
	"vision":    true,
	"coding":    true,
	"code":      true,
	"embedding": true,
	"rerank":    true,
	"image":     true,
	"audio":     true,
	"video":     true,
	"budget":    true,

	// Model attributes.
	"open weights": true,
	"open source":  true,
	"proprietary":  true,
	"local":        true,
	"cloud":        true,
	"multilingual": true,

	// General categories.
	"chat":       true,
	"completion": true,
	"instruct":   true,
	"base":       true,
	"fine-tuned": true,
	"lora":       true,
}

// filterTags takes a comma-separated tag string and returns it with only the
// tags present in validTagSet, joined by commas again.
//
// Each tag is trimmed of surrounding whitespace and compared case-insensitively,
// but kept in its original casing. It is used to strip non-category tags (such
// as context-window sizes) from the official preset.
func filterTags(tagsStr string) string {
	if tagsStr == "" {
		return ""
	}

	pieces := strings.Split(tagsStr, ",")
	// Filter in place: the write position never overtakes the read position.
	kept := pieces[:0]
	for _, raw := range pieces {
		trimmed := strings.TrimSpace(raw)
		if trimmed == "" || !validTagSet[strings.ToLower(trimmed)] {
			continue
		}
		kept = append(kept, trimmed)
	}
	return strings.Join(kept, ",")
}
// matchesPattern reports whether lower contains at least one of the given
// patterns as a substring. An empty or nil pattern list never matches.
func matchesPattern(lower string, patterns []string) bool {
	for i := range patterns {
		if strings.Contains(lower, patterns[i]) {
			return true
		}
	}
	return false
}
// ────────────────────────────────────────────────────────────────────────────
// Vendor inference (VendorID)
// ────────────────────────────────────────────────────────────────────────────
// vendorKeywordAliases maps a keyword found in a model name to keywords that
// identify the owning vendor (all lowercase).
// Matching strategy: first look for modelKW in the model name; on a hit, match
// vendorKWs as substrings of the lowercased Vendor.Name values from the database.
// Rules are tried in slice order.
var vendorKeywordAliases = []struct {
modelKW string // searched for in the lowercased model name
vendorKWs []string // a hit on any one of these in Vendor.Name (lowercased) is enough
}{
{"claude", []string{"anthropic"}},
{"gemini", []string{"google"}},
{"gpt", []string{"openai"}},
{"dall-e", []string{"openai"}},
{"whisper", []string{"openai"}},
{"o1-", []string{"openai"}},
{"o3-", []string{"openai"}},
{"o4-", []string{"openai"}},
{"llama", []string{"meta"}},
{"mistral", []string{"mistral"}},
{"mixtral", []string{"mistral"}},
{"codestral", []string{"mistral"}},
{"deepseek", []string{"deepseek"}},
{"qwen", []string{"alibaba", "qwen", "tongyi", "aliyun"}},
{"moonshot", []string{"moonshot"}},
{"kimi", []string{"moonshot"}},
{"doubao", []string{"bytedance", "volcengine", "volcano"}},
{"ernie", []string{"baidu"}},
{"wenxin", []string{"baidu"}},
{"hunyuan", []string{"tencent"}},
{"spark", []string{"xunfei", "iflytek"}},
{"glm", []string{"zhipu", "chatglm"}},
{"chatglm", []string{"zhipu"}},
{"yi-", []string{"lingyiwanwu", "01ai", "zero-one"}},
{"minimax", []string{"minimax"}},
{"abab", []string{"minimax"}},
{"flux", []string{"black forest", "blackforest"}},
{"stable-diffusion", []string{"stability"}},
{"sdxl", []string{"stability"}},
{"cohere", []string{"cohere"}},
{"command-r", []string{"cohere"}},
{"perplexity", []string{"perplexity"}},
{"jina", []string{"jina"}},
{"suno", []string{"suno"}},
{"kling", []string{"kling", "kuaishou"}},
{"vidu", []string{"vidu", "shengshu"}},
{"cogview", []string{"zhipu"}},
{"internlm", []string{"shanghaiai", "intern"}},
{"baichuan", []string{"baichuan"}},
{"xai", []string{"xai"}},
{"grok", []string{"xai"}},
}
// buildVendorIndex loads the vendors from the database in a single call and
// returns a map from lowercased vendor name to vendor ID.
//
// It returns nil when the lookup fails or yields no vendors. The call passes
// (0, 2000) to model.GetAllVendors — presumably offset and limit, so at most
// 2000 vendors are indexed; confirm against that function.
func buildVendorIndex() map[string]int {
	all, err := model.GetAllVendors(0, 2000)
	if err != nil {
		return nil
	}
	if len(all) == 0 {
		return nil
	}

	byName := make(map[string]int, len(all))
	for i := range all {
		byName[strings.ToLower(all[i].Name)] = all[i].Id
	}
	return byName
}
// inferVendorID returns the ID of the vendor that most likely owns modelName,
// or 0 when no vendor can be determined.
//
// vendorIdx maps lowercased vendor names to IDs (see buildVendorIndex).
//
// Resolution order:
//  1. Alias rules: for each rule in vendorKeywordAliases whose modelKW occurs
//     in the lowercased model name, the rule's vendor keywords are tried in
//     their listed order against the vendor names.
//  2. Fallback: any vendor name of at least 4 bytes that occurs verbatim in
//     the model name.
//
// Go randomizes map iteration order, so when several vendors qualify the
// winner is chosen explicitly rather than by whichever entry is visited
// first: for alias rules the shortest vendor name wins (closest to the bare
// keyword), for the fallback the longest one wins (most specific match);
// remaining ties go to the lexicographically smaller name.
func inferVendorID(modelName string, vendorIdx map[string]int) int {
	if len(vendorIdx) == 0 {
		return 0
	}
	lower := strings.ToLower(modelName)

	for _, rule := range vendorKeywordAliases {
		if !strings.Contains(lower, rule.modelKW) {
			continue
		}
		// The model name hit this rule; look for a vendor whose name carries
		// one of the rule's keywords, honouring the keyword priority order.
		for _, vkw := range rule.vendorKWs {
			bestName, bestID, found := "", 0, false
			for vendorName, id := range vendorIdx {
				if !strings.Contains(vendorName, vkw) {
					continue
				}
				shorter := len(vendorName) < len(bestName)
				sameLenSmaller := len(vendorName) == len(bestName) && vendorName < bestName
				if !found || shorter || sameLenSmaller {
					bestName, bestID, found = vendorName, id, true
				}
			}
			if found {
				return bestID
			}
		}
	}

	// Fallback: the model name itself contains a vendor name. Very short
	// vendor names are skipped because they match too easily.
	bestName, bestID := "", 0
	for vendorName, id := range vendorIdx {
		if len(vendorName) < 4 || !strings.Contains(lower, vendorName) {
			continue
		}
		longer := len(vendorName) > len(bestName)
		sameLenSmaller := len(vendorName) == len(bestName) && vendorName < bestName
		if longer || sameLenSmaller {
			bestName, bestID = vendorName, id
		}
	}
	return bestID
}
// ────────────────────────────────────────────────────────────────────────────
// Public API: AutoCreateMissingModelMeta
// ────────────────────────────────────────────────────────────────────────────
// AutoMetaItem is the outcome of automatic metadata inference for a single model.
type AutoMetaItem struct {
ModelName string `json:"model_name"`
// Source is "official" when the data came from the official preset,
// "inferred" when it was derived from name rules, and "exists" when a
// record already existed and the model was skipped.
Source string `json:"source"`
// Endpoints is a JSON array encoded as a string.
Endpoints string `json:"endpoints"`
// Tags is a comma-separated tag list.
Tags string `json:"tags"`
// VendorID is the inferred vendor; 0 (omitted from JSON) means none was found.
VendorID int `json:"vendor_id,omitempty"`
// Err carries the error text when processing the model failed.
Err string `json:"err,omitempty"`
}
// AutoCreateMissingModelMeta creates metadata records for those of modelNames
// that do not have one yet, and returns one AutoMetaItem per input name, in
// input order. It returns nil for an empty input.
//
// For each missing model the official preset is consulted first (exact name
// match); fields the preset does not provide, and models it does not know,
// are filled in from name-based inference. The vendor is always inferred from
// the name.
//
// Error handling:
//   - If the lookup of existing names fails, nothing is inserted and every
//     item carries the lookup error in Err with an empty Source. Inserting
//     without knowing what already exists could duplicate records.
//   - If an individual insert fails, that item's Err is set and processing
//     continues with the next name.
//
// A name that appears more than once in modelNames is inserted once; later
// occurrences are reported with Source "exists".
func AutoCreateMissingModelMeta(ctx context.Context, modelNames []string) []AutoMetaItem {
	if len(modelNames) == 0 {
		return nil
	}
	results := make([]AutoMetaItem, 0, len(modelNames))

	// 1. Determine which names already have a record (those are skipped).
	existingNames, err := model.GetExistingModelNames(modelNames)
	if err != nil {
		msg := fmt.Sprintf("DB error: %v", err)
		for _, name := range modelNames {
			results = append(results, AutoMetaItem{ModelName: name, Err: msg})
		}
		return results
	}
	existingSet := make(map[string]bool, len(existingNames))
	for _, n := range existingNames {
		existingSet[n] = true
	}

	// 2. Load the official preset (cached; may be nil when unavailable).
	preset := fetchOfficialPreset(ctx)

	// 3. Build the vendor index (lowercased vendor name -> ID) for vendor inference.
	vendorIdx := buildVendorIndex()

	for _, name := range modelNames {
		if existingSet[name] {
			results = append(results, AutoMetaItem{ModelName: name, Source: "exists"})
			continue
		}

		// Start from pure name-based inference ...
		item := AutoMetaItem{
			ModelName: name,
			Source:    "inferred",
			Endpoints: inferEndpoints(name),
			Tags:      inferTags(name),
			VendorID:  inferVendorID(name, vendorIdx),
		}
		mi := &model.Model{
			ModelName:    name,
			VendorID:     item.VendorID,
			Status:       1,
			SyncOfficial: 1,
		}

		// ... and let the official preset override whatever it provides.
		if entry, ok := preset[name]; ok {
			item.Source = "official"
			if len(entry.Endpoints) > 0 && string(entry.Endpoints) != "null" {
				item.Endpoints = string(entry.Endpoints)
			}
			if tags := filterTags(entry.Tags); tags != "" {
				item.Tags = tags
			}
			mi.Description = entry.Description
			mi.Icon = entry.Icon
			mi.Status = chooseModelStatus(entry.Status)
			mi.NameRule = entry.NameRule
		}
		mi.Tags = item.Tags
		mi.Endpoints = item.Endpoints

		if err := mi.Insert(); err != nil {
			item.Err = fmt.Sprintf("DB error: %v", err)
		} else {
			// Remember the insert so a repeated name is not inserted again.
			existingSet[name] = true
		}
		results = append(results, item)
	}
	return results
}
// chooseModelStatus converts an upstream status into the value to store:
// 0 is replaced by 1, every other value is passed through unchanged.
// NOTE(review): 0 is presumably "not set upstream" and 1 "enabled" — confirm
// against the model status definitions.
func chooseModelStatus(upstreamStatus int) int {
	switch upstreamStatus {
	case 0:
		return 1
	default:
		return upstreamStatus
	}
}