tokenFactory/service/smart_router.go

303 lines
9.6 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"encoding/json"
"os"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/fyinfor/router-engine/pkg/router"
"github.com/gin-gonic/gin"
)
// SmartRouterEnabled reports whether smart routing is switched on. It is on by
// default: only when the SMART_ROUTER_ENABLED environment variable, with
// surrounding whitespace trimmed, is "0" or matches "false", "no" or "off"
// (case-insensitively) does it return false. An unset or empty variable, or
// any other value, keeps it enabled.
func SmartRouterEnabled() bool {
	raw := strings.TrimSpace(os.Getenv("SMART_ROUTER_ENABLED"))
	for _, off := range [...]string{"0", "false", "no", "off"} {
		if strings.EqualFold(raw, off) {
			return false
		}
	}
	return true
}
// channelProviderSlug maps a channel's type to the provider slug attached to
// router candidates. Types without an explicit case fall back to their entry
// in constant.ChannelTypeNames, lower-cased with spaces removed; a type with
// no entry there yields "unknown".
func channelProviderSlug(ch *model.Channel) string {
	kind := ch.Type

	// Explicit mappings, listed alphabetically by slug.
	switch kind {
	case constant.ChannelTypeAnthropic:
		return "anthropic"
	case constant.ChannelTypeAws:
		return "aws"
	case constant.ChannelTypeAzure:
		return "azure"
	case constant.ChannelTypeBaidu, constant.ChannelTypeBaiduV2:
		return "baidu"
	case constant.ChannelTypeCohere:
		return "cohere"
	case constant.ChannelTypeAli:
		return "dashscope"
	case constant.ChannelTypeDeepSeek:
		return "deepseek"
	case constant.ChannelTypeGemini:
		return "google"
	case constant.ChannelTypeVertexAi:
		return "google-vertex"
	case constant.ChannelTypeMistral:
		return "mistral"
	case constant.ChannelTypeMoonshot:
		return "moonshot"
	case constant.ChannelTypeOpenAI:
		return "openai"
	case constant.ChannelTypeOpenRouter:
		return "openrouter"
	case constant.ChannelTypePerplexity:
		return "perplexity"
	case constant.ChannelTypeSiliconFlow:
		return "siliconflow"
	case constant.ChannelTypeTencent:
		return "tencent"
	case constant.ChannelTypeVolcEngine:
		return "volcengine"
	case constant.ChannelTypeXai:
		return "xai"
	case constant.ChannelTypeZhipu, constant.ChannelTypeZhipu_v4:
		return "zhipu"
	}

	// No explicit mapping: derive a slug from the registered display name.
	displayName, known := constant.ChannelTypeNames[kind]
	if !known {
		return "unknown"
	}
	return strings.ToLower(strings.ReplaceAll(displayName, " ", ""))
}
// buildRouterCandidates returns the router candidates for (group, modelName)
// without any additional channel filter. It delegates to
// buildRouterCandidatesFiltered with a nil filter.
func buildRouterCandidates(group, modelName string) ([]*router.EndpointCandidate, error) {
	var noFilter func(*model.Channel) bool // nil: keep every eligible channel
	return buildRouterCandidatesFiltered(group, modelName, noFilter)
}
// buildRouterCandidatesFiltered builds the router-engine candidate list for
// (group, modelName), optionally restricted by filter.
//
// A channel becomes a candidate only if it can be loaded from the cache, its
// status is common.ChannelStatusEnabled, it is enabled for the group/model
// pair, and filter (when non-nil) returns true for it. With a nil filter the
// result is the same as buildRouterCandidates.
//
// The returned slice is nil when no channel qualifies. The error result is
// always nil in the current implementation.
func buildRouterCandidatesFiltered(group, modelName string, filter func(*model.Channel) bool) ([]*router.EndpointCandidate, error) {
	channelIDs := model.ListChannelIDsForGroupModel(group, modelName)
	if len(channelIDs) == 0 {
		return nil, nil
	}

	var candidates []*router.EndpointCandidate
	for _, channelID := range channelIDs {
		channel, err := model.CacheGetChannel(channelID)

		// Skip anything that is not a usable, in-scope channel. The cases are
		// evaluated top to bottom, so channel is known to be non-nil before it
		// is dereferenced.
		switch {
		case err != nil, channel == nil:
			continue
		case channel.Status != common.ChannelStatusEnabled:
			continue
		case !model.IsChannelEnabledForGroupModel(group, modelName, channel.Id):
			continue
		case filter != nil && !filter(channel):
			continue
		}

		// UnitPrice is the primary sorting signal for smart routing (kept in
		// line with the relay pricing priority): supplier-scoped fixed price
		// first, then the supplier-scoped model ratio.
		// NOTE(review): the global model ratio below is consulted only when
		// the value resolved here is non-positive; when neither supplier-scoped
		// lookup succeeds the default of 1.0 is kept as is.
		price := 1.0
		supplierID := channel.SupplierApplicationID
		if fixed, found := model.ResolveSupplierScopedFixedModelPrice(channel.Id, supplierID, modelName); found {
			price = fixed
		} else if scoped, found, _ := model.ResolveSupplierScopedModelRatio(channel.Id, supplierID, modelName); found {
			price = scoped
		}
		if price <= 0 {
			if globalRatio, _, _ := ratio_setting.GetModelRatio(modelName); globalRatio > 0 {
				price = globalRatio
			}
		}
		if price <= 0 {
			price = 1
		}

		// ResponseTime is scaled by 1/1000 (presumably milliseconds to
		// seconds; confirm against the model). Non-positive values are
		// clamped to 0.001 so the throughput estimate never divides by zero.
		latency := float64(channel.ResponseTime) / 1000.0
		if latency <= 0 {
			latency = 0.001
		}
		throughput := 1.0 / latency

		// Weight and Priority are optional pointers; absent values count as 0.
		weight := 0
		if channel.Weight != nil {
			weight = int(*channel.Weight)
		}
		priority := int64(0)
		if channel.Priority != nil {
			priority = *channel.Priority
		}

		candidate := &router.EndpointCandidate{
			ChannelID:         channel.Id,
			Model:             modelName,
			ProviderSlug:      channelProviderSlug(channel),
			UnitPrice:         price,
			Healthy:           true,
			LatencyP50Seconds: latency,
			ThroughputTps:     throughput,
			Priority:          priority,
			Weight:            weight,
		}
		candidates = append(candidates, candidate)
	}
	return candidates, nil
}
// resolveSmartRouteGroup decides which concrete group smart routing should
// use. Any usingGroup other than "auto" is returned unchanged. For "auto" it
// walks GetUserAutoGroup(userGroup) in order and returns the first group that
// has at least one channel ID listed for modelName, or "" when none does.
func resolveSmartRouteGroup(usingGroup, userGroup, modelName string) string {
	if usingGroup == "auto" {
		for _, candidate := range GetUserAutoGroup(userGroup) {
			channelIDs := model.ListChannelIDsForGroupModel(candidate, modelName)
			if len(channelIDs) != 0 {
				return candidate
			}
		}
		return ""
	}
	return usingGroup
}
// TrySmartRouteChannel selects a channel with the in-process router-engine
// when SmartRouterEnabled() is true.
//
// It resolves the effective group (expanding "auto" via
// resolveSmartRouteGroup), builds the candidates for that group and
// modelName, and asks router.SelectProviders for an ordering. The request
// carries the model list and the has-tools flag found in the gin context
// (falling back to just modelName) together with providerJSON.
//
// On success it stores the ordered channel IDs under
// ContextKeySmartRouteChannelOrder and the group under
// ContextKeySmartRouteSelectGroup, and returns the first channel of the
// ordering, the selected group and true. On every failure path it returns
// (nil, "", false) and leaves the context untouched: the context is written
// only after the first channel has been loaded and confirmed enabled, so a
// failed attempt cannot leak a routing order into later handling.
func TrySmartRouteChannel(c *gin.Context, usingGroup, userGroup, modelName, providerJSON string) (*model.Channel, string, bool) {
	if !SmartRouterEnabled() {
		return nil, "", false
	}
	selectGroup := resolveSmartRouteGroup(usingGroup, userGroup, modelName)
	if selectGroup == "" {
		return nil, "", false
	}
	cands, err := buildRouterCandidates(selectGroup, modelName)
	if err != nil || len(cands) == 0 {
		return nil, "", false
	}

	// Prefer the model list captured from the request body, if present.
	models := []string{modelName}
	if raw, ok := common.GetContextKey(c, constant.ContextKeyRequestModelsList); ok {
		if sl, ok := raw.([]string); ok && len(sl) > 0 {
			models = sl
		}
	}
	req := router.SelectRequest{
		Models:                  models,
		ProviderPreferencesJSON: providerJSON,
		Candidates:              cands,
	}
	if v, ok := common.GetContextKey(c, constant.ContextKeyRequestHasTools); ok {
		if b, ok := v.(bool); ok {
			req.RequestHasTools = b
		}
	}

	res, err := router.SelectProviders(req)
	if err != nil || len(res.OrderedChannelIDs) == 0 {
		return nil, "", false
	}

	// Validate the head of the ordering before publishing anything to the
	// context, so that a false return never leaves partial routing state.
	ch, err := model.CacheGetChannel(res.OrderedChannelIDs[0])
	if err != nil || ch == nil || ch.Status != common.ChannelStatusEnabled {
		return nil, "", false
	}

	common.SetContextKey(c, constant.ContextKeySmartRouteChannelOrder, res.OrderedChannelIDs)
	common.SetContextKey(c, constant.ContextKeySmartRouteSelectGroup, selectGroup)
	return ch, selectGroup, true
}
// TrySupplierRouteChannel selects a channel under "forced supplier"
// semantics: the candidate pool is limited to enabled channels that satisfy
// (group, model) and whose SupplierApplicationID equals
// supplierApplicationID.
//
// When usingGroup is "auto", the groups from GetUserAutoGroup(userGroup) are
// tried in order and the first one that yields candidates for this supplier
// is used; a group whose candidate build reports an error is skipped.
//
// When SmartRouterEnabled() is true and router.SelectProviders returns a
// non-empty ordering, that ordering is used. Otherwise the candidates are kept
// in the order produced by buildRouterCandidatesFiltered.
// NOTE(review): an earlier comment described this fallback as a
// priority + weight random selection matching GetRandomSatisfiedChannel; no
// such ordering is applied in this function. Confirm whether that behavior is
// expected here.
//
// The first channel of the final order that loads from the cache and is
// enabled is returned. The full order is stored under
// ContextKeySmartRouteChannelOrder (with ContextKeySmartRouteSelectGroup, and
// ContextKeyAutoGroup when usingGroup is "auto") so that retries driven by
// that order stay within the same supplier's candidates.
//
// It returns (channel, selectGroup, true) when a channel was chosen. A false
// result means no candidate was available; callers should treat it as a
// regular "no available channel" error rather than falling back to smart or
// random routing, because that would bypass the supplier constraint.
func TrySupplierRouteChannel(c *gin.Context, usingGroup, userGroup, modelName, providerJSON string, supplierApplicationID int) (*model.Channel, string, bool) {
	filter := func(ch *model.Channel) bool { return ch.SupplierApplicationID == supplierApplicationID }

	selectGroup := usingGroup
	var cands []*router.EndpointCandidate
	if usingGroup == "auto" {
		// Pick the first auto sub-group that has candidates for this supplier
		// and model, and keep the candidates that were just built instead of
		// building them a second time.
		selectGroup = ""
		for _, g := range GetUserAutoGroup(userGroup) {
			found, err := buildRouterCandidatesFiltered(g, modelName, filter)
			if err != nil || len(found) == 0 {
				continue
			}
			selectGroup, cands = g, found
			break
		}
		if selectGroup == "" {
			return nil, "", false
		}
	} else {
		var err error
		cands, err = buildRouterCandidatesFiltered(selectGroup, modelName, filter)
		if err != nil {
			return nil, "", false
		}
	}
	if len(cands) == 0 {
		return nil, "", false
	}

	// Default order: candidates as built. May be replaced by the router's
	// ordering below.
	candidateIDs := make([]int, 0, len(cands))
	for _, cand := range cands {
		candidateIDs = append(candidateIDs, cand.ChannelID)
	}

	if SmartRouterEnabled() {
		models := []string{modelName}
		if raw, ok := common.GetContextKey(c, constant.ContextKeyRequestModelsList); ok {
			if sl, ok := raw.([]string); ok && len(sl) > 0 {
				models = sl
			}
		}
		req := router.SelectRequest{
			Models:                  models,
			ProviderPreferencesJSON: providerJSON,
			Candidates:              cands,
		}
		if v, ok := common.GetContextKey(c, constant.ContextKeyRequestHasTools); ok {
			if b, ok := v.(bool); ok {
				req.RequestHasTools = b
			}
		}
		// A router error or an empty ordering keeps the default order.
		if res, err := router.SelectProviders(req); err == nil && len(res.OrderedChannelIDs) > 0 {
			candidateIDs = res.OrderedChannelIDs
		}
	}

	// Take the first enabled channel in candidateIDs order as this request's
	// hit; the remaining IDs are left for retry fallback.
	var chosen *model.Channel
	for _, id := range candidateIDs {
		ch, err := model.CacheGetChannel(id)
		if err != nil || ch == nil || ch.Status != common.ChannelStatusEnabled {
			continue
		}
		chosen = ch
		break
	}
	if chosen == nil {
		return nil, "", false
	}

	common.SetContextKey(c, constant.ContextKeySmartRouteChannelOrder, candidateIDs)
	common.SetContextKey(c, constant.ContextKeySmartRouteSelectGroup, selectGroup)
	if usingGroup == "auto" {
		common.SetContextKey(c, constant.ContextKeyAutoGroup, selectGroup)
	}
	return chosen, selectGroup, true
}
// IngestChatCompletionRoutingHints extracts the OpenRouter-compatible routing
// hints ("provider", "models", "tools") from the JSON request body and stores
// them in the gin context.
//
// It does nothing when c is nil, when the request path does not contain
// "chat/completions", or when the body cannot be unmarshalled. Otherwise:
//   - a non-empty "provider" value is stored verbatim as a string under
//     ContextKeyOpenRouterProviderJSON;
//   - ContextKeyRequestModelsList receives the body's "models" list, or
//     []string{modelName} when the list is empty and modelName is not;
//   - ContextKeyRequestHasTools is set to whether "tools" has any entries.
func IngestChatCompletionRoutingHints(c *gin.Context, modelName string) {
	if c == nil {
		return
	}
	if !strings.Contains(c.Request.URL.Path, "chat/completions") {
		return
	}

	hints := struct {
		Provider json.RawMessage   `json:"provider"`
		Models   []string          `json:"models"`
		Tools    []json.RawMessage `json:"tools"`
	}{}
	if common.UnmarshalBodyReusable(c, &hints) != nil {
		return
	}

	if providerJSON := string(hints.Provider); providerJSON != "" {
		common.SetContextKey(c, constant.ContextKeyOpenRouterProviderJSON, providerJSON)
	}

	switch {
	case len(hints.Models) > 0:
		common.SetContextKey(c, constant.ContextKeyRequestModelsList, hints.Models)
	case modelName != "":
		common.SetContextKey(c, constant.ContextKeyRequestModelsList, []string{modelName})
	}

	hasTools := len(hints.Tools) > 0
	common.SetContextKey(c, constant.ContextKeyRequestHasTools, hasTools)
}