tokenFactory/relay/helper/model_mapped.go

175 lines
6.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package helper
import (
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
relaycommon "github.com/QuantumNous/new-api/relay/common"
relayconstant "github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
// ModelMappedHelper resolves the model name that is sent upstream for the
// current relay attempt. It stores the result in info.UpstreamModelName and,
// when request is non-nil, in the request body via request.SetModelName.
//
// Steps, in order:
//  1. In Responses-compact relay mode, strip ratio_setting.CompactModelSuffix
//     from the origin model name to obtain the name used for mapping.
//  2. Apply the channel's model_mapping (JSON object read from the gin context
//     key "model_mapping"), following redirect chains to their end. This is
//     skipped for TokenFactoryOpen channels (see below).
//  3. In compact mode, set UpstreamModelName to the resolved name and
//     OriginModelName to that name with the compact suffix re-added.
//  4. If a TFOpen upstream route is present in the context, rewrite
//     UpstreamModelName to "{model}/{route_slug}" or
//     "{alias}/{model}/{channelNo}".
//  5. Otherwise, when nothing was mapped, strip a trailing "/route_slug" from
//     the model name so the whole string is not forwarded as an upstream
//     model id.
//
// It returns an error when info is nil, when model_mapping is not valid JSON,
// or when model_mapping contains a redirect cycle.
func ModelMappedHelper(c *gin.Context, info *relaycommon.RelayInfo, request dto.Request) error {
	// Previously only the first statement was nil-guarded and the next line
	// dereferenced info anyway; fail explicitly instead of panicking.
	if info == nil {
		return errors.New("relay_info_is_nil")
	}
	info.TFOpenUpstreamRouteApplied = false
	if info.ChannelMeta == nil {
		info.ChannelMeta = &relaycommon.ChannelMeta{}
	}

	isResponsesCompact := info.RelayMode == relayconstant.RelayModeResponsesCompact
	mappingModelName := info.OriginModelName
	if isResponsesCompact {
		mappingModelName = strings.TrimSuffix(mappingModelName, ratio_setting.CompactModelSuffix)
	}

	// TokenFactoryOpen channels point at an upstream TokenFactory platform whose
	// distributor parses a "/" in a model name as route syntax
	// ({model}/{route_slug} or {alias}/{model}/{channel_no}). For these channels
	// model_mapping is skipped and the local model name is kept. TFOpen-synced
	// channels (source=tokenfactory_open) get their route appended by the TFOpen
	// route step below, which also uses the local model name.
	channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
	isTFOpenUpstream := channelType == constant.ChannelTypeTokenFactoryOpen

	if modelMapping := c.GetString("model_mapping"); modelMapping != "" && modelMapping != "{}" && !isTFOpenUpstream {
		if err := applyChannelModelMapping(info, modelMapping, mappingModelName); err != nil {
			return err
		}
	}

	if isResponsesCompact {
		finalUpstreamModelName := mappingModelName
		if info.IsModelMapped && info.UpstreamModelName != "" {
			finalUpstreamModelName = info.UpstreamModelName
		}
		info.UpstreamModelName = finalUpstreamModelName
		info.OriginModelName = ratio_setting.WithCompactModelSuffix(finalUpstreamModelName)
	}

	if tfRoute := c.GetString(string(constant.ContextKeyTFOpenUpstreamChannelRoute)); tfRoute != "" {
		applyTFOpenUpstreamRoute(info, tfRoute, isResponsesCompact)
	}

	// Nothing was mapped and no TFOpen route was applied, but the name may still
	// look like "Seedance2.0/route_slug". This happens, for example, when a
	// sub-site's slug does not exist in the upstream database and the
	// distributor could not rewrite the body. Strip the slug so external video
	// gateways (e.g. Hidream/MaaS) do not reject the whole string with
	// "Invalid input params".
	// NOTE(review): this also truncates any plain "org/name" model id whose last
	// segment passes model.IsValidRouteSlug; confirm slug rules make that
	// impossible for real model ids.
	if !isTFOpenUpstream && !info.TFOpenUpstreamRouteApplied && !info.IsModelMapped {
		um := strings.TrimSpace(info.UpstreamModelName)
		if um == "" {
			um = strings.TrimSpace(mappingModelName)
		}
		if strings.Contains(um, "/") {
			if base, ok := stripModelRouteSlugSuffix(um); ok {
				info.UpstreamModelName = base
			}
		}
	}

	if request != nil {
		request.SetModelName(info.UpstreamModelName)
	}
	return nil
}

// applyChannelModelMapping applies the channel's model_mapping JSON (an object
// of model name -> replacement model name) to info, starting from
// mappingModelName and following chained redirects to the last model.
func applyChannelModelMapping(info *relaycommon.RelayInfo, modelMapping, mappingModelName string) error {
	modelMap := make(map[string]string)
	if err := json.Unmarshal([]byte(modelMapping), &modelMap); err != nil {
		return fmt.Errorf("unmarshal_model_mapping_failed: %w", err)
	}

	// For "Seedance2.0/route_slug"-style names, map from the base model name.
	// This also applies when the route did not resolve, for example because the
	// sub-site and upstream databases disagree or the slug does not exist
	// upstream. Forwarding the full string as an upstream model id would make
	// external gateways fail with "Invalid input params".
	currentModel := mappingModelName
	if base, ok := stripModelRouteSlugSuffix(mappingModelName); ok {
		currentModel = base
	}

	// Follow chained redirects (a -> b -> c) and use the last model in the
	// chain. visited detects cycles so the loop always terminates.
	visited := map[string]bool{currentModel: true}
	for {
		mappedModel, exists := modelMap[currentModel]
		if !exists || mappedModel == "" {
			break
		}
		if visited[mappedModel] {
			if mappedModel != currentModel {
				return errors.New("model_mapping_contains_cycle")
			}
			// A self-mapping ("x": "x") ends the chain. When it is the origin
			// model itself nothing was remapped. Ending the chain here, instead
			// of returning early as before, lets the caller still run compact
			// handling, TFOpen routing and request.SetModelName.
			info.IsModelMapped = currentModel != info.OriginModelName
			break
		}
		visited[mappedModel] = true
		currentModel = mappedModel
		info.IsModelMapped = true
	}
	if info.IsModelMapped {
		info.UpstreamModelName = currentModel
	}
	return nil
}

// applyTFOpenUpstreamRoute rewrites info.UpstreamModelName for precise routing
// on an upstream TokenFactory platform:
//   - "legacy|{alias}|{channelNo}" becomes "{alias}/{model}/{channelNo}"
//   - "{route_slug}" becomes "{model}/{route_slug}"
//
// {model} is info.OriginModelName (without the compact suffix in compact mode),
// not the model_mapping result, because the upstream TF platform understands
// the local model name and could misparse a mapped name (e.g. HuggingFace
// "org/name" style). On a rewrite, IsModelMapped is cleared and
// TFOpenUpstreamRouteApplied is set. A malformed or empty route leaves info
// unchanged.
func applyTFOpenUpstreamRoute(info *relaycommon.RelayInfo, tfRoute string, isResponsesCompact bool) {
	modelForUpstream := info.OriginModelName
	if isResponsesCompact {
		modelForUpstream = strings.TrimSuffix(modelForUpstream, ratio_setting.CompactModelSuffix)
	}

	if strings.HasPrefix(tfRoute, "legacy|") {
		// Legacy three-segment route: legacy|alias|channelNo.
		parts := strings.SplitN(tfRoute, "|", 3)
		if len(parts) != 3 || parts[1] == "" || parts[2] == "" {
			return
		}
		info.UpstreamModelName = parts[1] + "/" + modelForUpstream + "/" + parts[2]
	} else {
		// Current two-segment route: a bare route_slug.
		routeSlug := strings.TrimSpace(tfRoute)
		if routeSlug == "" {
			return
		}
		info.UpstreamModelName = modelForUpstream + "/" + routeSlug
	}
	info.IsModelMapped = false
	info.TFOpenUpstreamRouteApplied = true
}

// stripModelRouteSlugSuffix returns the base model name of a
// "{model}/{route_slug}"-style name and true. It returns name and false when
// name does not look like one.
//
// A route recognised by service.ParseModelRouteIndex takes precedence.
// Otherwise the last "/"-separated segment is dropped when it is a valid route
// slug (model.IsValidRouteSlug) and the remaining prefix is non-empty after
// trimming spaces.
func stripModelRouteSlugSuffix(name string) (string, bool) {
	// The parse error is ignored on purpose: stripping is best-effort, and the
	// syntactic fallback below still applies.
	if idx, matched, _ := service.ParseModelRouteIndex(name); matched && idx != nil {
		return idx.ModelName, true
	}
	lastSlash := strings.LastIndex(name, "/")
	if lastSlash <= 0 || lastSlash >= len(name)-1 {
		return name, false
	}
	slug := strings.TrimSpace(name[lastSlash+1:])
	base := strings.TrimSpace(name[:lastSlash])
	if base == "" || !model.IsValidRouteSlug(slug) {
		return name, false
	}
	return base, true
}