344 lines
9.9 KiB
Go
344 lines
9.9 KiB
Go
package controller
|
||
|
||
import (
|
||
"fmt"
|
||
"net/http"
|
||
"time"
|
||
|
||
"github.com/QuantumNous/new-api/common"
|
||
"github.com/QuantumNous/new-api/constant"
|
||
"github.com/QuantumNous/new-api/dto"
|
||
"github.com/QuantumNous/new-api/model"
|
||
"github.com/QuantumNous/new-api/relay"
|
||
"github.com/QuantumNous/new-api/relay/channel/ai360"
|
||
"github.com/QuantumNous/new-api/relay/channel/lingyiwanwu"
|
||
"github.com/QuantumNous/new-api/relay/channel/minimax"
|
||
"github.com/QuantumNous/new-api/relay/channel/moonshot"
|
||
taskalivideo "github.com/QuantumNous/new-api/relay/channel/task/alivideo"
|
||
taskopenaivideo "github.com/QuantumNous/new-api/relay/channel/task/openaivideo"
|
||
tasktencentvod "github.com/QuantumNous/new-api/relay/channel/task/tencentvod"
|
||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||
"github.com/QuantumNous/new-api/service"
|
||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||
"github.com/QuantumNous/new-api/types"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/samber/lo"
|
||
)
|
||
|
||
// https://platform.openai.com/docs/api-reference/models/list
|
||
|
||
var openAIModels []dto.OpenAIModels
|
||
var openAIModelsMap map[string]dto.OpenAIModels
|
||
var channelId2Models map[int][]string
|
||
|
||
func init() {
|
||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||
for i := 0; i < constant.APITypeDummy; i++ {
|
||
if i == constant.APITypeAIProxyLibrary {
|
||
continue
|
||
}
|
||
adaptor := relay.GetAdaptor(i)
|
||
channelName := adaptor.GetChannelName()
|
||
modelNames := adaptor.GetModelList()
|
||
for _, modelName := range modelNames {
|
||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||
Id: modelName,
|
||
Object: "model",
|
||
Created: 1626777600,
|
||
OwnedBy: channelName,
|
||
})
|
||
}
|
||
}
|
||
for _, modelName := range ai360.ModelList {
|
||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||
Id: modelName,
|
||
Object: "model",
|
||
Created: 1626777600,
|
||
OwnedBy: ai360.ChannelName,
|
||
})
|
||
}
|
||
for _, modelName := range moonshot.ModelList {
|
||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||
Id: modelName,
|
||
Object: "model",
|
||
Created: 1626777600,
|
||
OwnedBy: moonshot.ChannelName,
|
||
})
|
||
}
|
||
for _, modelName := range lingyiwanwu.ModelList {
|
||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||
Id: modelName,
|
||
Object: "model",
|
||
Created: 1626777600,
|
||
OwnedBy: lingyiwanwu.ChannelName,
|
||
})
|
||
}
|
||
for _, modelName := range minimax.ModelList {
|
||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||
Id: modelName,
|
||
Object: "model",
|
||
Created: 1626777600,
|
||
OwnedBy: minimax.ChannelName,
|
||
})
|
||
}
|
||
for modelName, _ := range constant.MidjourneyModel2Action {
|
||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||
Id: modelName,
|
||
Object: "model",
|
||
Created: 1626777600,
|
||
OwnedBy: "midjourney",
|
||
})
|
||
}
|
||
for _, modelName := range taskopenaivideo.ModelList {
|
||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||
Id: modelName,
|
||
Object: "model",
|
||
Created: 1626777600,
|
||
OwnedBy: taskopenaivideo.ChannelName,
|
||
})
|
||
}
|
||
for _, modelName := range tasktencentvod.ModelList {
|
||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||
Id: modelName,
|
||
Object: "model",
|
||
Created: 1626777600,
|
||
OwnedBy: tasktencentvod.ChannelName,
|
||
})
|
||
}
|
||
for _, modelName := range taskalivideo.ModelList {
|
||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||
Id: modelName,
|
||
Object: "model",
|
||
Created: 1626777600,
|
||
OwnedBy: taskalivideo.ChannelName,
|
||
})
|
||
}
|
||
openAIModelsMap = make(map[string]dto.OpenAIModels)
|
||
for _, aiModel := range openAIModels {
|
||
openAIModelsMap[aiModel.Id] = aiModel
|
||
}
|
||
channelId2Models = make(map[int][]string)
|
||
for i := 1; i <= constant.ChannelTypeDummy; i++ {
|
||
apiType, success := common.ChannelType2APIType(i)
|
||
if !success || apiType == constant.APITypeAIProxyLibrary {
|
||
continue
|
||
}
|
||
meta := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{
|
||
ChannelType: i,
|
||
}}
|
||
adaptor := relay.GetAdaptor(apiType)
|
||
adaptor.Init(meta)
|
||
channelId2Models[i] = adaptor.GetModelList()
|
||
}
|
||
// 任务式渠道(如 OpenAI 视频网关)不走 ChannelType2APIType,需要手动登记默认
|
||
// 模型列表,否则前端「获取模型列表」按钮拿不到内置模型。
|
||
channelId2Models[constant.ChannelTypeOpenAIVideo] = taskopenaivideo.ModelList
|
||
channelId2Models[constant.ChannelTypeVideoGenerator] = taskopenaivideo.ModelList
|
||
channelId2Models[constant.ChannelTypeTencentCloudVideo] = tasktencentvod.ModelList
|
||
channelId2Models[constant.ChannelTypeAliVideo] = taskalivideo.ModelList
|
||
openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
|
||
return m.Id
|
||
})
|
||
}
|
||
|
||
func ListModels(c *gin.Context, modelType int) {
|
||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
||
|
||
acceptUnsetRatioModel := operation_setting.SelfUseModeEnabled
|
||
if !acceptUnsetRatioModel {
|
||
userId := c.GetInt("id")
|
||
if userId > 0 {
|
||
userSettings, _ := model.GetUserSetting(userId, false)
|
||
if userSettings.AcceptUnsetRatioModel {
|
||
acceptUnsetRatioModel = true
|
||
}
|
||
}
|
||
}
|
||
|
||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||
if modelLimitEnable {
|
||
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||
var tokenModelLimit map[string]bool
|
||
if ok {
|
||
tokenModelLimit = s.(map[string]bool)
|
||
} else {
|
||
tokenModelLimit = map[string]bool{}
|
||
}
|
||
for allowModel, _ := range tokenModelLimit {
|
||
if !acceptUnsetRatioModel {
|
||
_, _, exist := ratio_setting.GetModelRatioOrPrice(allowModel)
|
||
if !exist {
|
||
continue
|
||
}
|
||
}
|
||
if oaiModel, ok := openAIModelsMap[allowModel]; ok {
|
||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
|
||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||
} else {
|
||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||
Id: allowModel,
|
||
Object: "model",
|
||
Created: 1626777600,
|
||
OwnedBy: "custom",
|
||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
|
||
})
|
||
}
|
||
}
|
||
} else {
|
||
userId := c.GetInt("id")
|
||
userGroup, err := model.GetUserGroup(userId, false)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "get user group failed",
|
||
})
|
||
return
|
||
}
|
||
group := userGroup
|
||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||
if tokenGroup != "" {
|
||
group = tokenGroup
|
||
}
|
||
var models []string
|
||
if tokenGroup == "auto" {
|
||
for _, autoGroup := range service.GetUserAutoGroup(userGroup) {
|
||
groupModels := model.GetGroupEnabledModels(autoGroup)
|
||
for _, g := range groupModels {
|
||
if !common.StringsContains(models, g) {
|
||
models = append(models, g)
|
||
}
|
||
}
|
||
}
|
||
} else {
|
||
models = model.GetGroupEnabledModels(group)
|
||
}
|
||
for _, modelName := range models {
|
||
if !acceptUnsetRatioModel {
|
||
_, _, exist := ratio_setting.GetModelRatioOrPrice(modelName)
|
||
if !exist {
|
||
continue
|
||
}
|
||
}
|
||
if oaiModel, ok := openAIModelsMap[modelName]; ok {
|
||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
|
||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||
} else {
|
||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||
Id: modelName,
|
||
Object: "model",
|
||
Created: 1626777600,
|
||
OwnedBy: "custom",
|
||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
|
||
})
|
||
}
|
||
}
|
||
}
|
||
|
||
switch modelType {
|
||
case constant.ChannelTypeAnthropic:
|
||
useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
|
||
for i, model := range userOpenAiModels {
|
||
useranthropicModels[i] = dto.AnthropicModel{
|
||
ID: model.Id,
|
||
CreatedAt: time.Unix(int64(model.Created), 0).UTC().Format(time.RFC3339),
|
||
DisplayName: model.Id,
|
||
Type: "model",
|
||
}
|
||
}
|
||
c.JSON(200, gin.H{
|
||
"data": useranthropicModels,
|
||
"first_id": useranthropicModels[0].ID,
|
||
"has_more": false,
|
||
"last_id": useranthropicModels[len(useranthropicModels)-1].ID,
|
||
})
|
||
case constant.ChannelTypeGemini:
|
||
userGeminiModels := make([]dto.GeminiModel, len(userOpenAiModels))
|
||
for i, model := range userOpenAiModels {
|
||
userGeminiModels[i] = dto.GeminiModel{
|
||
Name: model.Id,
|
||
DisplayName: model.Id,
|
||
}
|
||
}
|
||
c.JSON(200, gin.H{
|
||
"models": userGeminiModels,
|
||
"nextPageToken": nil,
|
||
})
|
||
default:
|
||
c.JSON(200, gin.H{
|
||
"success": true,
|
||
"data": userOpenAiModels,
|
||
"object": "list",
|
||
})
|
||
}
|
||
}
|
||
|
||
func ChannelListModels(c *gin.Context) {
|
||
// 管理员查看全量模型;已审核供应商仅查看自己渠道/模型关联的模型。
|
||
if c.GetInt("role") < common.RoleAdminUser {
|
||
ownedModels, err := collectSupplierOwnedModelNames(c.GetInt("id"))
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
models := make([]dto.OpenAIModels, 0, len(openAIModels))
|
||
for _, item := range openAIModels {
|
||
if _, ok := ownedModels[item.Id]; !ok {
|
||
continue
|
||
}
|
||
models = append(models, item)
|
||
}
|
||
c.JSON(200, gin.H{
|
||
"success": true,
|
||
"data": models,
|
||
})
|
||
return
|
||
}
|
||
|
||
c.JSON(200, gin.H{
|
||
"success": true,
|
||
"data": openAIModels,
|
||
})
|
||
}
|
||
|
||
func DashboardListModels(c *gin.Context) {
|
||
c.JSON(200, gin.H{
|
||
"success": true,
|
||
"data": channelId2Models,
|
||
})
|
||
}
|
||
|
||
func EnabledListModels(c *gin.Context) {
|
||
c.JSON(200, gin.H{
|
||
"success": true,
|
||
"data": model.GetEnabledModels(),
|
||
})
|
||
}
|
||
|
||
func RetrieveModel(c *gin.Context, modelType int) {
|
||
modelId := c.Param("model")
|
||
if aiModel, ok := openAIModelsMap[modelId]; ok {
|
||
switch modelType {
|
||
case constant.ChannelTypeAnthropic:
|
||
c.JSON(200, dto.AnthropicModel{
|
||
ID: aiModel.Id,
|
||
CreatedAt: time.Unix(int64(aiModel.Created), 0).UTC().Format(time.RFC3339),
|
||
DisplayName: aiModel.Id,
|
||
Type: "model",
|
||
})
|
||
default:
|
||
c.JSON(200, aiModel)
|
||
}
|
||
} else {
|
||
openAIError := types.OpenAIError{
|
||
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
||
Type: "invalid_request_error",
|
||
Param: "model",
|
||
Code: "model_not_found",
|
||
}
|
||
c.JSON(200, gin.H{
|
||
"error": openAIError,
|
||
})
|
||
}
|
||
}
|