303 lines
9.6 KiB
Go
303 lines
9.6 KiB
Go
package service
|
||
|
||
import (
|
||
"encoding/json"
|
||
"os"
|
||
"strings"
|
||
|
||
"github.com/QuantumNous/new-api/common"
|
||
"github.com/QuantumNous/new-api/constant"
|
||
"github.com/QuantumNous/new-api/model"
|
||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||
|
||
"github.com/fyinfor/router-engine/pkg/router"
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// SmartRouterEnabled 默认开启。仅当 SMART_ROUTER_ENABLED 为 0 / false / no / off(不区分大小写)时关闭。
|
||
func SmartRouterEnabled() bool {
|
||
v := strings.TrimSpace(os.Getenv("SMART_ROUTER_ENABLED"))
|
||
if v == "" {
|
||
return true
|
||
}
|
||
if v == "0" || strings.EqualFold(v, "false") || strings.EqualFold(v, "no") || strings.EqualFold(v, "off") {
|
||
return false
|
||
}
|
||
return true
|
||
}
|
||
|
||
func channelProviderSlug(ch *model.Channel) string {
|
||
switch ch.Type {
|
||
case constant.ChannelTypeOpenAI:
|
||
return "openai"
|
||
case constant.ChannelTypeAzure:
|
||
return "azure"
|
||
case constant.ChannelTypeAnthropic:
|
||
return "anthropic"
|
||
case constant.ChannelTypeOpenRouter:
|
||
return "openrouter"
|
||
case constant.ChannelTypeGemini:
|
||
return "google"
|
||
case constant.ChannelTypeVertexAi:
|
||
return "google-vertex"
|
||
case constant.ChannelTypeDeepSeek:
|
||
return "deepseek"
|
||
case constant.ChannelTypeSiliconFlow:
|
||
return "siliconflow"
|
||
case constant.ChannelTypeVolcEngine:
|
||
return "volcengine"
|
||
case constant.ChannelTypeMoonshot:
|
||
return "moonshot"
|
||
case constant.ChannelTypeXai:
|
||
return "xai"
|
||
case constant.ChannelTypeMistral:
|
||
return "mistral"
|
||
case constant.ChannelTypePerplexity:
|
||
return "perplexity"
|
||
case constant.ChannelTypeTencent:
|
||
return "tencent"
|
||
case constant.ChannelTypeZhipu, constant.ChannelTypeZhipu_v4:
|
||
return "zhipu"
|
||
case constant.ChannelTypeBaidu, constant.ChannelTypeBaiduV2:
|
||
return "baidu"
|
||
case constant.ChannelTypeAli:
|
||
return "dashscope"
|
||
case constant.ChannelTypeAws:
|
||
return "aws"
|
||
case constant.ChannelTypeCohere:
|
||
return "cohere"
|
||
default:
|
||
if n, ok := constant.ChannelTypeNames[ch.Type]; ok {
|
||
return strings.ToLower(strings.ReplaceAll(n, " ", ""))
|
||
}
|
||
return "unknown"
|
||
}
|
||
}
|
||
|
||
func buildRouterCandidates(group, modelName string) ([]*router.EndpointCandidate, error) {
|
||
return buildRouterCandidatesFiltered(group, modelName, nil)
|
||
}
|
||
|
||
// buildRouterCandidatesFiltered 在 buildRouterCandidates 基础上额外支持按渠道过滤。
|
||
// filter 为 nil 时行为与 buildRouterCandidates 相同;filter 返回 false 的渠道将被剔除。
|
||
func buildRouterCandidatesFiltered(group, modelName string, filter func(*model.Channel) bool) ([]*router.EndpointCandidate, error) {
|
||
ids := model.ListChannelIDsForGroupModel(group, modelName)
|
||
if len(ids) == 0 {
|
||
return nil, nil
|
||
}
|
||
var out []*router.EndpointCandidate
|
||
for _, id := range ids {
|
||
ch, err := model.CacheGetChannel(id)
|
||
if err != nil || ch == nil || ch.Status != common.ChannelStatusEnabled {
|
||
continue
|
||
}
|
||
if !model.IsChannelEnabledForGroupModel(group, modelName, ch.Id) {
|
||
continue
|
||
}
|
||
if filter != nil && !filter(ch) {
|
||
continue
|
||
}
|
||
// UnitPrice is the primary sorting signal for smart routing(与 relay 定价优先级对齐)。
|
||
unitPrice := 1.0
|
||
sid := ch.SupplierApplicationID
|
||
if p, ok := model.ResolveSupplierScopedFixedModelPrice(ch.Id, sid, modelName); ok {
|
||
unitPrice = p
|
||
} else if r, ok, _ := model.ResolveSupplierScopedModelRatio(ch.Id, sid, modelName); ok {
|
||
unitPrice = r
|
||
}
|
||
if unitPrice <= 0 {
|
||
ratio, _, _ := ratio_setting.GetModelRatio(modelName)
|
||
if ratio > 0 {
|
||
unitPrice = ratio
|
||
}
|
||
}
|
||
if unitPrice <= 0 {
|
||
unitPrice = 1
|
||
}
|
||
latSec := float64(ch.ResponseTime) / 1000.0
|
||
if latSec <= 0 {
|
||
latSec = 0.001
|
||
}
|
||
tps := 1.0 / latSec
|
||
w := 0
|
||
if ch.Weight != nil {
|
||
w = int(*ch.Weight)
|
||
}
|
||
prio := int64(0)
|
||
if ch.Priority != nil {
|
||
prio = *ch.Priority
|
||
}
|
||
out = append(out, &router.EndpointCandidate{
|
||
ChannelID: ch.Id,
|
||
Model: modelName,
|
||
ProviderSlug: channelProviderSlug(ch),
|
||
UnitPrice: unitPrice,
|
||
Healthy: true,
|
||
LatencyP50Seconds: latSec,
|
||
ThroughputTps: tps,
|
||
Priority: prio,
|
||
Weight: w,
|
||
})
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
func resolveSmartRouteGroup(usingGroup, userGroup, modelName string) string {
|
||
if usingGroup != "auto" {
|
||
return usingGroup
|
||
}
|
||
for _, g := range GetUserAutoGroup(userGroup) {
|
||
if len(model.ListChannelIDsForGroupModel(g, modelName)) > 0 {
|
||
return g
|
||
}
|
||
}
|
||
return ""
|
||
}
|
||
|
||
// TrySmartRouteChannel runs in-process router-engine when SmartRouterEnabled(). On success it stores
|
||
// ContextKeySmartRouteChannelOrder for relay retries and returns the first channel.
|
||
func TrySmartRouteChannel(c *gin.Context, usingGroup, userGroup, modelName, providerJSON string) (*model.Channel, string, bool) {
|
||
if !SmartRouterEnabled() {
|
||
return nil, "", false
|
||
}
|
||
selectGroup := resolveSmartRouteGroup(usingGroup, userGroup, modelName)
|
||
if selectGroup == "" {
|
||
return nil, "", false
|
||
}
|
||
cands, err := buildRouterCandidates(selectGroup, modelName)
|
||
if err != nil || len(cands) == 0 {
|
||
return nil, "", false
|
||
}
|
||
models := []string{modelName}
|
||
if raw, ok := common.GetContextKey(c, constant.ContextKeyRequestModelsList); ok {
|
||
if sl, ok := raw.([]string); ok && len(sl) > 0 {
|
||
models = sl
|
||
}
|
||
}
|
||
req := router.SelectRequest{
|
||
Models: models,
|
||
ProviderPreferencesJSON: providerJSON,
|
||
Candidates: cands,
|
||
}
|
||
if v, ok := common.GetContextKey(c, constant.ContextKeyRequestHasTools); ok {
|
||
if b, ok := v.(bool); ok {
|
||
req.RequestHasTools = b
|
||
}
|
||
}
|
||
res, err := router.SelectProviders(req)
|
||
if err != nil || len(res.OrderedChannelIDs) == 0 {
|
||
return nil, "", false
|
||
}
|
||
common.SetContextKey(c, constant.ContextKeySmartRouteChannelOrder, res.OrderedChannelIDs)
|
||
common.SetContextKey(c, constant.ContextKeySmartRouteSelectGroup, selectGroup)
|
||
firstID := res.OrderedChannelIDs[0]
|
||
ch, err := model.CacheGetChannel(firstID)
|
||
if err != nil || ch == nil || ch.Status != common.ChannelStatusEnabled {
|
||
return nil, "", false
|
||
}
|
||
return ch, selectGroup, true
|
||
}
|
||
|
||
// TrySupplierRouteChannel 在「强制供应商」语义下选择渠道:候选池限制为该供应商下满足
|
||
// (group, model) 条件的启用渠道。SmartRouter 开启时走 router-engine 排序;关闭或 router-engine
|
||
// 无可用候选时,回退到按优先级 + 权重的随机选择(与 GetRandomSatisfiedChannel 一致),并把最终
|
||
// 候选顺序写入 ContextKeySmartRouteChannelOrder,保证控制器侧重试也严格落在同一供应商内。
|
||
//
|
||
// 返回 (channel, selectGroup, true) 表示已完成选择;返回 false 时表示候选为空,调用方应按
|
||
// 正常"无可用渠道"错误处理,而不是再去兜底 SmartRouter / 随机,因为那会绕过供应商约束。
|
||
func TrySupplierRouteChannel(c *gin.Context, usingGroup, userGroup, modelName, providerJSON string, supplierApplicationID int) (*model.Channel, string, bool) {
|
||
filter := func(ch *model.Channel) bool { return ch.SupplierApplicationID == supplierApplicationID }
|
||
|
||
// 自动分组下挑选一个"对该供应商下的该模型有候选"的子分组。
|
||
selectGroup := usingGroup
|
||
if usingGroup == "auto" {
|
||
selectGroup = ""
|
||
for _, g := range GetUserAutoGroup(userGroup) {
|
||
cands, _ := buildRouterCandidatesFiltered(g, modelName, filter)
|
||
if len(cands) > 0 {
|
||
selectGroup = g
|
||
break
|
||
}
|
||
}
|
||
if selectGroup == "" {
|
||
return nil, "", false
|
||
}
|
||
}
|
||
|
||
cands, err := buildRouterCandidatesFiltered(selectGroup, modelName, filter)
|
||
if err != nil || len(cands) == 0 {
|
||
return nil, "", false
|
||
}
|
||
|
||
candidateIDs := make([]int, 0, len(cands))
|
||
for _, c := range cands {
|
||
candidateIDs = append(candidateIDs, c.ChannelID)
|
||
}
|
||
|
||
if SmartRouterEnabled() {
|
||
models := []string{modelName}
|
||
if raw, ok := common.GetContextKey(c, constant.ContextKeyRequestModelsList); ok {
|
||
if sl, ok := raw.([]string); ok && len(sl) > 0 {
|
||
models = sl
|
||
}
|
||
}
|
||
req := router.SelectRequest{
|
||
Models: models,
|
||
ProviderPreferencesJSON: providerJSON,
|
||
Candidates: cands,
|
||
}
|
||
if v, ok := common.GetContextKey(c, constant.ContextKeyRequestHasTools); ok {
|
||
if b, ok := v.(bool); ok {
|
||
req.RequestHasTools = b
|
||
}
|
||
}
|
||
if res, err := router.SelectProviders(req); err == nil && len(res.OrderedChannelIDs) > 0 {
|
||
candidateIDs = res.OrderedChannelIDs
|
||
}
|
||
}
|
||
|
||
// 按 candidateIDs 顺序取第一个启用渠道作为本次命中;其余供重试回退。
|
||
var chosen *model.Channel
|
||
for _, id := range candidateIDs {
|
||
ch, err := model.CacheGetChannel(id)
|
||
if err != nil || ch == nil || ch.Status != common.ChannelStatusEnabled {
|
||
continue
|
||
}
|
||
chosen = ch
|
||
break
|
||
}
|
||
if chosen == nil {
|
||
return nil, "", false
|
||
}
|
||
common.SetContextKey(c, constant.ContextKeySmartRouteChannelOrder, candidateIDs)
|
||
common.SetContextKey(c, constant.ContextKeySmartRouteSelectGroup, selectGroup)
|
||
if usingGroup == "auto" {
|
||
common.SetContextKey(c, constant.ContextKeyAutoGroup, selectGroup)
|
||
}
|
||
return chosen, selectGroup, true
|
||
}
|
||
|
||
// IngestChatCompletionRoutingHints parses provider / models / tools from JSON body (OpenRouter-compatible).
|
||
func IngestChatCompletionRoutingHints(c *gin.Context, modelName string) {
|
||
if c == nil || !strings.Contains(c.Request.URL.Path, "chat/completions") {
|
||
return
|
||
}
|
||
var pick struct {
|
||
Provider json.RawMessage `json:"provider"`
|
||
Models []string `json:"models"`
|
||
Tools []json.RawMessage `json:"tools"`
|
||
}
|
||
if err := common.UnmarshalBodyReusable(c, &pick); err != nil {
|
||
return
|
||
}
|
||
if len(pick.Provider) > 0 {
|
||
common.SetContextKey(c, constant.ContextKeyOpenRouterProviderJSON, string(pick.Provider))
|
||
}
|
||
if len(pick.Models) > 0 {
|
||
common.SetContextKey(c, constant.ContextKeyRequestModelsList, pick.Models)
|
||
} else if modelName != "" {
|
||
common.SetContextKey(c, constant.ContextKeyRequestModelsList, []string{modelName})
|
||
}
|
||
common.SetContextKey(c, constant.ContextKeyRequestHasTools, len(pick.Tools) > 0)
|
||
}
|