2710 lines
75 KiB
Go
2710 lines
75 KiB
Go
package controller
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"net/url"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/QuantumNous/new-api/common"
|
||
"github.com/QuantumNous/new-api/constant"
|
||
"github.com/QuantumNous/new-api/dto"
|
||
"github.com/QuantumNous/new-api/model"
|
||
relaychannel "github.com/QuantumNous/new-api/relay/channel"
|
||
"github.com/QuantumNous/new-api/relay/channel/gemini"
|
||
"github.com/QuantumNous/new-api/relay/channel/ollama"
|
||
"github.com/QuantumNous/new-api/service"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// OpenAIModel is one entry of an OpenAI-compatible model listing. Field names
// and JSON tags mirror OpenAI's model object (id/object/created/owned_by plus
// the legacy permission/root/parent fields).
type OpenAIModel struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
	// Metadata is omitted from the JSON output when nil or empty.
	Metadata map[string]interface{} `json:"metadata,omitempty"`
	// Permission holds legacy OpenAI permission objects; the element type is
	// intentionally an anonymous struct.
	Permission []struct {
		ID                 string `json:"id"`
		Object             string `json:"object"`
		Created            int64  `json:"created"`
		AllowCreateEngine  bool   `json:"allow_create_engine"`
		AllowSampling      bool   `json:"allow_sampling"`
		AllowLogprobs      bool   `json:"allow_logprobs"`
		AllowSearchIndices bool   `json:"allow_search_indices"`
		AllowView          bool   `json:"allow_view"`
		AllowFineTuning    bool   `json:"allow_fine_tuning"`
		Organization       string `json:"organization"`
		Group              string `json:"group"`
		IsBlocking         bool   `json:"is_blocking"`
	} `json:"permission"`
	Root   string `json:"root"`
	Parent string `json:"parent"`
}

// OpenAIModelsResponse wraps a list of OpenAIModel entries with a success flag.
type OpenAIModelsResponse struct {
	Data    []OpenAIModel `json:"data"`
	Success bool          `json:"success"`
}
|
||
|
||
// channelAllowedSupplierTypes is the closed set of supplier_type values a
// channel may carry.
var channelAllowedSupplierTypes = map[string]struct{}{
	"公有云":   {},
	"AIDC":  {},
	"企业中转站": {},
	"个人中转站": {},
}

// defaultChannelSupplierType is the fallback supplier_type used when neither
// the channel row nor its linked supplier application provides one. It must
// be a member of channelAllowedSupplierTypes.
const defaultChannelSupplierType = "公有云"

// isValidChannelSupplierType reports whether supplierType is one of the
// predefined values in channelAllowedSupplierTypes (exact, case-sensitive match).
func isValidChannelSupplierType(supplierType string) bool {
	if _, known := channelAllowedSupplierTypes[supplierType]; known {
		return true
	}
	return false
}
|
||
|
||
func parseStatusFilter(statusParam string) int {
|
||
switch strings.ToLower(statusParam) {
|
||
case "enabled", "1":
|
||
return common.ChannelStatusEnabled
|
||
case "disabled", "0":
|
||
return 0
|
||
default:
|
||
return -1
|
||
}
|
||
}
|
||
|
||
func clearChannelInfo(channel *model.Channel) {
|
||
if channel.ChannelInfo.IsMultiKey {
|
||
channel.ChannelInfo.MultiKeyDisabledReason = nil
|
||
channel.ChannelInfo.MultiKeyDisabledTime = nil
|
||
}
|
||
}
|
||
|
||
// attachSupplierNames 为渠道列表补齐供应商用户名(owner_user_id 对应 users.username)。
|
||
func attachSupplierNames(channels []*model.Channel) {
|
||
ownerIDs := make([]int, 0)
|
||
ownerSet := make(map[int]struct{})
|
||
for _, channel := range channels {
|
||
if channel == nil || channel.OwnerUserID <= 0 {
|
||
continue
|
||
}
|
||
if _, ok := ownerSet[channel.OwnerUserID]; ok {
|
||
continue
|
||
}
|
||
ownerSet[channel.OwnerUserID] = struct{}{}
|
||
ownerIDs = append(ownerIDs, channel.OwnerUserID)
|
||
}
|
||
if len(ownerIDs) == 0 {
|
||
return
|
||
}
|
||
var users []model.User
|
||
if err := model.DB.Select("id, username").Where("id IN ?", ownerIDs).Find(&users).Error; err != nil {
|
||
return
|
||
}
|
||
userMap := make(map[int]string, len(users))
|
||
for _, user := range users {
|
||
userMap[user.Id] = user.Username
|
||
}
|
||
for _, channel := range channels {
|
||
if channel == nil || channel.OwnerUserID <= 0 {
|
||
continue
|
||
}
|
||
channel.SupplierName = userMap[channel.OwnerUserID]
|
||
}
|
||
}
|
||
|
||
func GetAllChannels(c *gin.Context) {
|
||
pageInfo := common.GetPageQuery(c)
|
||
channelData := make([]*model.Channel, 0)
|
||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||
supplierKeyword := strings.TrimSpace(c.Query("supplier"))
|
||
statusParam := c.Query("status")
|
||
// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
|
||
statusFilter := parseStatusFilter(statusParam)
|
||
// type filter
|
||
typeStr := c.Query("type")
|
||
typeFilter := -1
|
||
if typeStr != "" {
|
||
if t, err := strconv.Atoi(typeStr); err == nil {
|
||
typeFilter = t
|
||
}
|
||
}
|
||
// 供应商复用原渠道列表接口:仅查看自己渠道,管理员保持原有全量逻辑。
|
||
if c.GetInt("role") < common.RoleAdminUser {
|
||
baseQuery := model.DB.Model(&model.Channel{}).Where("owner_user_id = ?", c.GetInt("id"))
|
||
if typeFilter >= 0 {
|
||
baseQuery = baseQuery.Where("type = ?", typeFilter)
|
||
}
|
||
if statusFilter == common.ChannelStatusEnabled {
|
||
baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||
} else if statusFilter == 0 {
|
||
baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
|
||
}
|
||
var total int64
|
||
if err := baseQuery.Count(&total).Error; err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
order := "priority desc"
|
||
if idSort {
|
||
order = "id desc"
|
||
}
|
||
if err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error; err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
for _, datum := range channelData {
|
||
clearChannelInfo(datum)
|
||
}
|
||
attachSupplierNames(channelData)
|
||
typeCounts := make(map[int64]int64)
|
||
for _, channel := range channelData {
|
||
typeCounts[int64(channel.Type)]++
|
||
}
|
||
common.ApiSuccess(c, gin.H{
|
||
"items": channelData,
|
||
"total": total,
|
||
"page": pageInfo.GetPage(),
|
||
"page_size": pageInfo.GetPageSize(),
|
||
"type_counts": typeCounts,
|
||
})
|
||
return
|
||
}
|
||
|
||
var total int64
|
||
|
||
if enableTagMode {
|
||
tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||
if err != nil {
|
||
common.SysError("failed to get paginated tags: " + err.Error())
|
||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签失败,请稍后重试"})
|
||
return
|
||
}
|
||
for _, tag := range tags {
|
||
if tag == nil || *tag == "" {
|
||
continue
|
||
}
|
||
tagChannels, err := model.GetChannelsByTag(*tag, idSort, false)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
filtered := make([]*model.Channel, 0)
|
||
for _, ch := range tagChannels {
|
||
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
|
||
continue
|
||
}
|
||
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
|
||
continue
|
||
}
|
||
if typeFilter >= 0 && ch.Type != typeFilter {
|
||
continue
|
||
}
|
||
filtered = append(filtered, ch)
|
||
}
|
||
channelData = append(channelData, filtered...)
|
||
}
|
||
total, _ = model.CountAllTags()
|
||
} else {
|
||
baseQuery := model.DB.Model(&model.Channel{})
|
||
if supplierKeyword != "" {
|
||
baseQuery = baseQuery.Joins("LEFT JOIN users ON users.id = channels.owner_user_id").Where("users.username LIKE ?", "%"+supplierKeyword+"%")
|
||
}
|
||
if typeFilter >= 0 {
|
||
baseQuery = baseQuery.Where("type = ?", typeFilter)
|
||
}
|
||
if statusFilter == common.ChannelStatusEnabled {
|
||
baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||
} else if statusFilter == 0 {
|
||
baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
|
||
}
|
||
|
||
baseQuery.Count(&total)
|
||
|
||
order := "priority desc"
|
||
if idSort {
|
||
order = "id desc"
|
||
}
|
||
|
||
err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
|
||
if err != nil {
|
||
common.SysError("failed to get channels: " + err.Error())
|
||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道列表失败,请稍后重试"})
|
||
return
|
||
}
|
||
}
|
||
|
||
for _, datum := range channelData {
|
||
clearChannelInfo(datum)
|
||
}
|
||
attachSupplierNames(channelData)
|
||
|
||
countQuery := model.DB.Model(&model.Channel{})
|
||
if statusFilter == common.ChannelStatusEnabled {
|
||
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||
} else if statusFilter == 0 {
|
||
countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
|
||
}
|
||
var results []struct {
|
||
Type int64
|
||
Count int64
|
||
}
|
||
_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
|
||
typeCounts := make(map[int64]int64)
|
||
for _, r := range results {
|
||
typeCounts[r.Type] = r.Count
|
||
}
|
||
common.ApiSuccess(c, gin.H{
|
||
"items": channelData,
|
||
"total": total,
|
||
"page": pageInfo.GetPage(),
|
||
"page_size": pageInfo.GetPageSize(),
|
||
"type_counts": typeCounts,
|
||
})
|
||
return
|
||
}
|
||
|
||
func buildFetchModelsHeaders(channel *model.Channel, key string) (http.Header, error) {
|
||
var headers http.Header
|
||
switch channel.Type {
|
||
case constant.ChannelTypeAnthropic:
|
||
headers = GetClaudeAuthHeader(key)
|
||
default:
|
||
headers = GetAuthHeader(key)
|
||
}
|
||
|
||
headerOverride := channel.GetHeaderOverride()
|
||
for k, v := range headerOverride {
|
||
if relaychannel.IsHeaderPassthroughRuleKey(k) {
|
||
continue
|
||
}
|
||
str, ok := v.(string)
|
||
if !ok {
|
||
return nil, fmt.Errorf("invalid header override for key %s", k)
|
||
}
|
||
if strings.Contains(str, "{api_key}") {
|
||
str = strings.ReplaceAll(str, "{api_key}", key)
|
||
}
|
||
headers.Set(k, str)
|
||
}
|
||
|
||
return headers, nil
|
||
}
|
||
|
||
func FetchUpstreamModels(c *gin.Context) {
|
||
id, err := strconv.Atoi(c.Param("id"))
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
|
||
channel, err := model.GetChannelById(id, true)
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
// 供应商只允许拉取自己渠道的上游模型,防止跨供应商越权读取。
|
||
if c.GetInt("role") < common.RoleAdminUser && channel.OwnerUserID != c.GetInt("id") {
|
||
c.JSON(http.StatusForbidden, gin.H{
|
||
"success": false,
|
||
"message": "无权访问其他供应商渠道",
|
||
})
|
||
return
|
||
}
|
||
|
||
ids, err := fetchChannelUpstreamModelIDs(channel)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": fmt.Sprintf("获取模型列表失败: %s", err.Error()),
|
||
})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": ids,
|
||
})
|
||
}
|
||
|
||
func FixChannelsAbilities(c *gin.Context) {
|
||
success, fails, err := model.FixAbility()
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": gin.H{
|
||
"success": success,
|
||
"fails": fails,
|
||
},
|
||
})
|
||
}
|
||
|
||
func SearchChannels(c *gin.Context) {
|
||
keyword := c.Query("keyword")
|
||
supplierKeyword := strings.TrimSpace(c.Query("supplier"))
|
||
group := c.Query("group")
|
||
modelKeyword := c.Query("model")
|
||
statusParam := c.Query("status")
|
||
statusFilter := parseStatusFilter(statusParam)
|
||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||
// 供应商复用原渠道搜索接口:仅查询自己渠道。
|
||
if c.GetInt("role") < common.RoleAdminUser {
|
||
channelID, _ := model.ParseSupplierChannelIDFilter(keyword)
|
||
filter := model.SupplierChannelSearchFilter{
|
||
ChannelID: channelID,
|
||
Keyword: keyword,
|
||
Supplier: supplierKeyword,
|
||
ModelKeyword: modelKeyword,
|
||
Group: group,
|
||
}
|
||
ownerUserID := c.GetInt("id")
|
||
channelData, total, err := model.SearchSupplierChannels(&ownerUserID, 0, 100000, filter)
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
|
||
filtered := make([]*model.Channel, 0, len(channelData))
|
||
for _, ch := range channelData {
|
||
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
|
||
continue
|
||
}
|
||
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
|
||
continue
|
||
}
|
||
filtered = append(filtered, ch)
|
||
}
|
||
channelData = filtered
|
||
total = int64(len(filtered))
|
||
}
|
||
typeParam := c.Query("type")
|
||
typeFilter := -1
|
||
if typeParam != "" {
|
||
if tp, err := strconv.Atoi(typeParam); err == nil {
|
||
typeFilter = tp
|
||
}
|
||
}
|
||
if typeFilter >= 0 {
|
||
filtered := make([]*model.Channel, 0, len(channelData))
|
||
for _, ch := range channelData {
|
||
if ch.Type == typeFilter {
|
||
filtered = append(filtered, ch)
|
||
}
|
||
}
|
||
channelData = filtered
|
||
total = int64(len(filtered))
|
||
}
|
||
typeCounts := make(map[int64]int64)
|
||
for _, channel := range channelData {
|
||
typeCounts[int64(channel.Type)]++
|
||
}
|
||
page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
|
||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||
if page < 1 {
|
||
page = 1
|
||
}
|
||
if pageSize <= 0 {
|
||
pageSize = 20
|
||
}
|
||
startIdx := (page - 1) * pageSize
|
||
if startIdx > len(channelData) {
|
||
startIdx = len(channelData)
|
||
}
|
||
endIdx := startIdx + pageSize
|
||
if endIdx > len(channelData) {
|
||
endIdx = len(channelData)
|
||
}
|
||
pagedData := channelData[startIdx:endIdx]
|
||
for _, datum := range pagedData {
|
||
clearChannelInfo(datum)
|
||
}
|
||
attachSupplierNames(pagedData)
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": gin.H{
|
||
"items": pagedData,
|
||
"total": total,
|
||
"type_counts": typeCounts,
|
||
},
|
||
})
|
||
return
|
||
}
|
||
channelData := make([]*model.Channel, 0)
|
||
if enableTagMode {
|
||
tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
for _, tag := range tags {
|
||
if tag != nil && *tag != "" {
|
||
tagChannel, err := model.GetChannelsByTag(*tag, idSort, false)
|
||
if err == nil {
|
||
channelData = append(channelData, tagChannel...)
|
||
}
|
||
}
|
||
}
|
||
} else {
|
||
channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
channelData = channels
|
||
}
|
||
attachSupplierNames(channelData)
|
||
if supplierKeyword != "" {
|
||
filteredBySupplier := make([]*model.Channel, 0, len(channelData))
|
||
for _, ch := range channelData {
|
||
if strings.Contains(strings.ToLower(ch.SupplierName), strings.ToLower(supplierKeyword)) {
|
||
filteredBySupplier = append(filteredBySupplier, ch)
|
||
}
|
||
}
|
||
channelData = filteredBySupplier
|
||
}
|
||
|
||
if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
|
||
filtered := make([]*model.Channel, 0, len(channelData))
|
||
for _, ch := range channelData {
|
||
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
|
||
continue
|
||
}
|
||
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
|
||
continue
|
||
}
|
||
filtered = append(filtered, ch)
|
||
}
|
||
channelData = filtered
|
||
}
|
||
|
||
// calculate type counts for search results
|
||
typeCounts := make(map[int64]int64)
|
||
for _, channel := range channelData {
|
||
typeCounts[int64(channel.Type)]++
|
||
}
|
||
|
||
typeParam := c.Query("type")
|
||
typeFilter := -1
|
||
if typeParam != "" {
|
||
if tp, err := strconv.Atoi(typeParam); err == nil {
|
||
typeFilter = tp
|
||
}
|
||
}
|
||
|
||
if typeFilter >= 0 {
|
||
filtered := make([]*model.Channel, 0, len(channelData))
|
||
for _, ch := range channelData {
|
||
if ch.Type == typeFilter {
|
||
filtered = append(filtered, ch)
|
||
}
|
||
}
|
||
channelData = filtered
|
||
}
|
||
|
||
page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
|
||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||
if page < 1 {
|
||
page = 1
|
||
}
|
||
if pageSize <= 0 {
|
||
pageSize = 20
|
||
}
|
||
|
||
total := len(channelData)
|
||
startIdx := (page - 1) * pageSize
|
||
if startIdx > total {
|
||
startIdx = total
|
||
}
|
||
endIdx := startIdx + pageSize
|
||
if endIdx > total {
|
||
endIdx = total
|
||
}
|
||
|
||
pagedData := channelData[startIdx:endIdx]
|
||
|
||
for _, datum := range pagedData {
|
||
clearChannelInfo(datum)
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": gin.H{
|
||
"items": pagedData,
|
||
"total": total,
|
||
"type_counts": typeCounts,
|
||
},
|
||
})
|
||
return
|
||
}
|
||
|
||
func GetChannel(c *gin.Context) {
|
||
id, err := strconv.Atoi(c.Param("id"))
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
channel, err := model.GetChannelById(id, false)
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
// 供应商仅允许查看自己归属的渠道。
|
||
if c.GetInt("role") < common.RoleAdminUser && channel.OwnerUserID != c.GetInt("id") {
|
||
c.JSON(http.StatusForbidden, gin.H{
|
||
"success": false,
|
||
"message": "无权访问其他供应商渠道",
|
||
})
|
||
return
|
||
}
|
||
if channel != nil {
|
||
clearChannelInfo(channel)
|
||
}
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": channel,
|
||
})
|
||
return
|
||
}
|
||
|
||
// GetChannelKey 获取渠道密钥(需要通过安全验证中间件)
|
||
// 此函数依赖 SecureVerificationRequired 中间件,确保用户已通过安全验证
|
||
func GetChannelKey(c *gin.Context) {
|
||
userId := c.GetInt("id")
|
||
channelId, err := strconv.Atoi(c.Param("id"))
|
||
if err != nil {
|
||
common.ApiError(c, fmt.Errorf("渠道ID格式错误: %v", err))
|
||
return
|
||
}
|
||
|
||
// 获取渠道信息(包含密钥)
|
||
channel, err := model.GetChannelById(channelId, true)
|
||
if err != nil {
|
||
common.ApiError(c, fmt.Errorf("获取渠道信息失败: %v", err))
|
||
return
|
||
}
|
||
|
||
if channel == nil {
|
||
common.ApiError(c, fmt.Errorf("渠道不存在"))
|
||
return
|
||
}
|
||
|
||
// 记录操作日志
|
||
model.RecordLog(userId, model.LogTypeSystem, fmt.Sprintf("查看渠道密钥信息 (渠道ID: %d)", channelId))
|
||
|
||
// 返回渠道密钥
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "获取成功",
|
||
"data": map[string]interface{}{
|
||
"key": channel.Key,
|
||
},
|
||
})
|
||
}
|
||
|
||
// validateTwoFactorAuth 统一的2FA验证函数
|
||
func validateTwoFactorAuth(twoFA *model.TwoFA, code string) bool {
|
||
// 尝试验证TOTP
|
||
if cleanCode, err := common.ValidateNumericCode(code); err == nil {
|
||
if isValid, _ := twoFA.ValidateTOTPAndUpdateUsage(cleanCode); isValid {
|
||
return true
|
||
}
|
||
}
|
||
|
||
// 尝试验证备用码
|
||
if isValid, err := twoFA.ValidateBackupCodeAndUpdateUsage(code); err == nil && isValid {
|
||
return true
|
||
}
|
||
|
||
return false
|
||
}
|
||
|
||
// validateChannel 通用的渠道校验函数
|
||
func validateChannel(channel *model.Channel, isAdd bool) error {
|
||
if channel == nil {
|
||
return fmt.Errorf("channel cannot be empty")
|
||
}
|
||
|
||
channel.CompanyLogoURL = strings.TrimSpace(channel.CompanyLogoURL)
|
||
channel.SupplierType = strings.TrimSpace(channel.SupplierType)
|
||
// TokenFactoryOpen (type=60) 渠道的 supplier_type 由上游同步继承,创建时允许为空
|
||
if channel.SupplierType == "" && channel.Type != constant.ChannelTypeTokenFactoryOpen {
|
||
return fmt.Errorf("供应商类型不能为空")
|
||
}
|
||
if channel.SupplierType != "" && !isValidChannelSupplierType(channel.SupplierType) {
|
||
return fmt.Errorf("供应商类型无效")
|
||
}
|
||
|
||
// 校验 channel settings
|
||
if err := channel.ValidateSettings(); err != nil {
|
||
return fmt.Errorf("渠道额外设置[channel setting] 格式错误:%s", err.Error())
|
||
}
|
||
|
||
// 如果是添加操作,检查 channel 和 key 是否为空
|
||
if isAdd {
|
||
if channel.Key == "" {
|
||
return fmt.Errorf("channel cannot be empty")
|
||
}
|
||
|
||
// 检查模型名称长度是否超过 255
|
||
for _, m := range channel.GetModels() {
|
||
if len(m) > 255 {
|
||
return fmt.Errorf("模型名称过长: %s", m)
|
||
}
|
||
}
|
||
}
|
||
|
||
// VertexAI 特殊校验
|
||
if channel.Type == constant.ChannelTypeVertexAi {
|
||
if channel.Other == "" {
|
||
return fmt.Errorf("部署地区不能为空")
|
||
}
|
||
|
||
regionMap, err := common.StrToMap(channel.Other)
|
||
if err != nil {
|
||
return fmt.Errorf("部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}")
|
||
}
|
||
|
||
if regionMap["default"] == nil {
|
||
return fmt.Errorf("部署地区必须包含default字段")
|
||
}
|
||
}
|
||
|
||
// Codex OAuth key validation (optional, only when JSON object is provided)
|
||
if channel.Type == constant.ChannelTypeCodex {
|
||
trimmedKey := strings.TrimSpace(channel.Key)
|
||
if isAdd || trimmedKey != "" {
|
||
if !strings.HasPrefix(trimmedKey, "{") {
|
||
return fmt.Errorf("Codex key must be a valid JSON object")
|
||
}
|
||
var keyMap map[string]any
|
||
if err := common.Unmarshal([]byte(trimmedKey), &keyMap); err != nil {
|
||
return fmt.Errorf("Codex key must be a valid JSON object")
|
||
}
|
||
if v, ok := keyMap["access_token"]; !ok || v == nil || strings.TrimSpace(fmt.Sprintf("%v", v)) == "" {
|
||
return fmt.Errorf("Codex key JSON must include access_token")
|
||
}
|
||
if v, ok := keyMap["account_id"]; !ok || v == nil || strings.TrimSpace(fmt.Sprintf("%v", v)) == "" {
|
||
return fmt.Errorf("Codex key JSON must include account_id")
|
||
}
|
||
}
|
||
}
|
||
|
||
if channel != nil && channel.PriceDiscountPercent != nil {
|
||
v := *channel.PriceDiscountPercent
|
||
if v < 0 || v > 1000 {
|
||
return fmt.Errorf("价格折扣(百分比)须介于 0 与 1000 之间,100 表示无折扣,60 表示按原价 60%% 计费")
|
||
}
|
||
}
|
||
|
||
if rs := strings.TrimSpace(channel.RouteSlug); rs != "" && !model.IsValidRouteSlug(rs) {
|
||
return fmt.Errorf("route_slug 格式无效(2~32 位字母数字,且不能为 c 加纯数字)")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func RefreshCodexChannelCredential(c *gin.Context) {
|
||
channelId, err := strconv.Atoi(c.Param("id"))
|
||
if err != nil {
|
||
common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
|
||
return
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
|
||
defer cancel()
|
||
|
||
oauthKey, ch, err := service.RefreshCodexChannelCredential(ctx, channelId, service.CodexCredentialRefreshOptions{ResetCaches: true})
|
||
if err != nil {
|
||
common.SysError("failed to refresh codex channel credential: " + err.Error())
|
||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "刷新凭证失败,请稍后重试"})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "refreshed",
|
||
"data": gin.H{
|
||
"expires_at": oauthKey.Expired,
|
||
"last_refresh": oauthKey.LastRefresh,
|
||
"account_id": oauthKey.AccountID,
|
||
"email": oauthKey.Email,
|
||
"channel_id": ch.Id,
|
||
"channel_type": ch.Type,
|
||
"channel_name": ch.Name,
|
||
},
|
||
})
|
||
}
|
||
|
||
type AddChannelRequest struct {
|
||
Mode string `json:"mode"`
|
||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||
BatchAddSetKeyPrefix2Name bool `json:"batch_add_set_key_prefix_2_name"`
|
||
Channel *model.Channel `json:"channel"`
|
||
}
|
||
|
||
// applySupplierChannelOwnershipForCreate 在供应商创建渠道时强制写入归属信息,防止越权伪造 owner 字段。
|
||
func applySupplierChannelOwnershipForCreate(c *gin.Context, channel *model.Channel) error {
|
||
if c.GetInt("role") >= common.RoleAdminUser {
|
||
return nil
|
||
}
|
||
app, err := model.GetApprovedSupplierApplicationByApplicant(c.GetInt("id"))
|
||
if err != nil {
|
||
return err
|
||
}
|
||
channel.OwnerUserID = c.GetInt("id")
|
||
channel.SupplierApplicationID = app.ID
|
||
return nil
|
||
}
|
||
|
||
// validateSupplierChannelOwnershipForUpdate 校验供应商仅可更新自己的渠道,管理员不受限制。
|
||
func validateSupplierChannelOwnershipForUpdate(c *gin.Context, originChannel *model.Channel) bool {
|
||
if c.GetInt("role") >= common.RoleAdminUser {
|
||
return true
|
||
}
|
||
return originChannel.OwnerUserID == c.GetInt("id")
|
||
}
|
||
|
||
func getVertexArrayKeys(keys string) ([]string, error) {
|
||
if keys == "" {
|
||
return nil, nil
|
||
}
|
||
var keyArray []interface{}
|
||
err := common.Unmarshal([]byte(keys), &keyArray)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err)
|
||
}
|
||
cleanKeys := make([]string, 0, len(keyArray))
|
||
for _, key := range keyArray {
|
||
var keyStr string
|
||
switch v := key.(type) {
|
||
case string:
|
||
keyStr = strings.TrimSpace(v)
|
||
default:
|
||
bytes, err := json.Marshal(v)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
|
||
}
|
||
keyStr = string(bytes)
|
||
}
|
||
if keyStr != "" {
|
||
cleanKeys = append(cleanKeys, keyStr)
|
||
}
|
||
}
|
||
if len(cleanKeys) == 0 {
|
||
return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
|
||
}
|
||
return cleanKeys, nil
|
||
}
|
||
|
||
// upstreamChannelSyncItem is one channel record decoded from an upstream
// TokenFactory channel listing by decodeUpstreamChannelPayload.
type upstreamChannelSyncItem struct {
	ID                  int    `json:"id"`
	Name                string `json:"name"`
	Models              string `json:"models"`
	Group               string `json:"group"`
	Status              int    `json:"status"`
	Type                int    `json:"type"`
	ChannelNo           string `json:"channel_no"`
	RouteSlug           string `json:"route_slug"`
	SupplierApplication int    `json:"supplier_application_id"`
	SupplierAlias       string `json:"supplier_alias"`
	SupplierType        string `json:"supplier_type"`
	CompanyLogoURL      string `json:"company_logo_url"`
	// Pricing adjustments reported by the upstream.
	PriceDiscountPercent float64 `json:"price_discount_percent"`
	MarkupDiscountRate   float64 `json:"markup_discount_rate"`
	// ModelMapping is carried as a JSON-encoded string.
	ModelMapping string `json:"model_mapping"`
	// ModelPrice and ModelRatio stay nil when the upstream provides no entries.
	ModelPrice map[string]float64 `json:"model_price"`
	ModelRatio map[string]float64 `json:"model_ratio"`
}
|
||
|
||
func decodeUpstreamModelMapping(m map[string]any) string {
|
||
raw, ok := m["model_mapping"]
|
||
if !ok || raw == nil {
|
||
return ""
|
||
}
|
||
switch x := raw.(type) {
|
||
case string:
|
||
return strings.TrimSpace(x)
|
||
case map[string]any:
|
||
b, err := json.Marshal(x)
|
||
if err != nil {
|
||
return ""
|
||
}
|
||
return strings.TrimSpace(string(b))
|
||
default:
|
||
b, err := json.Marshal(raw)
|
||
if err != nil {
|
||
return strings.TrimSpace(common.Interface2String(raw))
|
||
}
|
||
return strings.TrimSpace(string(b))
|
||
}
|
||
}
|
||
|
||
// isTokenFactoryOpenBaseURL reports whether raw, with surrounding whitespace
// ignored, parses as an http or https URL with a non-empty host name.
func isTokenFactoryOpenBaseURL(raw string) bool {
	u, err := url.Parse(strings.TrimSpace(raw))
	if err != nil {
		return false
	}
	switch strings.ToLower(strings.TrimSpace(u.Scheme)) {
	case "http", "https":
		return strings.TrimSpace(u.Hostname()) != ""
	default:
		return false
	}
}
|
||
|
||
func isLikelyTokenFactoryStatusData(data map[string]any, systemName string) bool {
|
||
name := strings.ToLower(strings.TrimSpace(systemName))
|
||
if strings.Contains(name, "tokenfactory") ||
|
||
strings.Contains(name, "词元工厂") ||
|
||
strings.Contains(name, "开放词元工厂") {
|
||
return true
|
||
}
|
||
|
||
score := 0
|
||
|
||
if strings.TrimSpace(common.Interface2String(data["version"])) != "" {
|
||
score++
|
||
}
|
||
if startTimeRaw := strings.TrimSpace(common.Interface2String(data["start_time"])); startTimeRaw != "" {
|
||
if startTime, err := strconv.ParseInt(startTimeRaw, 10, 64); err == nil && startTime > 0 {
|
||
score++
|
||
}
|
||
}
|
||
if strings.TrimSpace(common.Interface2String(data["quota_display_type"])) != "" ||
|
||
strings.TrimSpace(common.Interface2String(data["quota_per_unit"])) != "" {
|
||
score++
|
||
}
|
||
if _, ok := data["enable_drawing"]; ok {
|
||
score++
|
||
}
|
||
if _, ok := data["enable_task"]; ok {
|
||
score++
|
||
}
|
||
if _, ok := data["system_name"]; ok {
|
||
score++
|
||
}
|
||
|
||
// 命中特征达到阈值即视为 TokenFactory 平台实例,避免只依赖 system_name 英文名。
|
||
return score >= 4
|
||
}
|
||
|
||
func fetchTokenFactoryStatus(baseURL string, key string) error {
|
||
client := &http.Client{Timeout: 10 * time.Second}
|
||
u := strings.TrimRight(strings.TrimSpace(baseURL), "/") + "/api/status"
|
||
req, err := http.NewRequest("GET", u, nil)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(key))
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer resp.Body.Close()
|
||
if resp.StatusCode != http.StatusOK {
|
||
return fmt.Errorf("status code %d", resp.StatusCode)
|
||
}
|
||
var payload map[string]any
|
||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||
return err
|
||
}
|
||
if success, ok := payload["success"].(bool); ok && !success {
|
||
return fmt.Errorf("status 接口返回失败")
|
||
}
|
||
var systemName string
|
||
var statusData map[string]any
|
||
if parsedData, ok := payload["data"].(map[string]any); ok {
|
||
statusData = parsedData
|
||
systemName = strings.TrimSpace(common.Interface2String(statusData["system_name"]))
|
||
}
|
||
if statusData == nil {
|
||
return fmt.Errorf("status 返回结构缺少 data")
|
||
}
|
||
if !isLikelyTokenFactoryStatusData(statusData, systemName) {
|
||
return fmt.Errorf("status 特征不匹配 TokenFactory 平台(system_name=%s)", systemName)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func decodeUpstreamChannelPayload(payload map[string]any, itemsKey string) ([]upstreamChannelSyncItem, error) {
|
||
successRaw, exists := payload["success"]
|
||
if !exists {
|
||
return nil, fmt.Errorf("上游响应缺少 success 字段")
|
||
}
|
||
success, ok := successRaw.(bool)
|
||
if !ok {
|
||
return nil, fmt.Errorf("上游 success 字段类型异常: %T", successRaw)
|
||
}
|
||
if !success {
|
||
upstreamMessage := strings.TrimSpace(common.Interface2String(payload["message"]))
|
||
if upstreamMessage == "" {
|
||
upstreamMessage = "上游返回失败(message 为空)"
|
||
}
|
||
return nil, fmt.Errorf("%s", upstreamMessage)
|
||
}
|
||
data, _ := payload["data"].(map[string]any)
|
||
if data == nil {
|
||
return nil, fmt.Errorf("上游响应缺少 data")
|
||
}
|
||
rawItems, _ := data[itemsKey].([]any)
|
||
items := make([]upstreamChannelSyncItem, 0, len(rawItems))
|
||
for _, raw := range rawItems {
|
||
m, ok := raw.(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
item := upstreamChannelSyncItem{
|
||
ID: common.String2Int(common.Interface2String(m["id"])),
|
||
Name: strings.TrimSpace(common.Interface2String(m["name"])),
|
||
Models: strings.TrimSpace(common.Interface2String(m["models"])),
|
||
Group: strings.TrimSpace(common.Interface2String(m["group"])),
|
||
Status: common.String2Int(common.Interface2String(m["status"])),
|
||
Type: common.String2Int(common.Interface2String(m["type"])),
|
||
ChannelNo: strings.TrimSpace(common.Interface2String(m["channel_no"])),
|
||
RouteSlug: strings.TrimSpace(common.Interface2String(m["route_slug"])),
|
||
SupplierApplication: common.String2Int(common.Interface2String(m["supplier_application_id"])),
|
||
SupplierAlias: strings.TrimSpace(common.Interface2String(m["supplier_alias"])),
|
||
SupplierType: strings.TrimSpace(common.Interface2String(m["supplier_type"])),
|
||
CompanyLogoURL: strings.TrimSpace(common.Interface2String(m["company_logo_url"])),
|
||
}
|
||
if v, ok := m["price_discount_percent"]; ok && v != nil {
|
||
switch x := v.(type) {
|
||
case float64:
|
||
item.PriceDiscountPercent = x
|
||
case json.Number:
|
||
if f, err := x.Float64(); err == nil {
|
||
item.PriceDiscountPercent = f
|
||
}
|
||
default:
|
||
if f, err := strconv.ParseFloat(strings.TrimSpace(common.Interface2String(v)), 64); err == nil {
|
||
item.PriceDiscountPercent = f
|
||
}
|
||
}
|
||
}
|
||
if v, ok := m["markup_discount_rate"]; ok && v != nil {
|
||
switch x := v.(type) {
|
||
case float64:
|
||
item.MarkupDiscountRate = x
|
||
case json.Number:
|
||
if f, err := x.Float64(); err == nil {
|
||
item.MarkupDiscountRate = f
|
||
}
|
||
default:
|
||
if f, err := strconv.ParseFloat(strings.TrimSpace(common.Interface2String(v)), 64); err == nil {
|
||
item.MarkupDiscountRate = f
|
||
}
|
||
}
|
||
}
|
||
if mp, ok := m["model_price"].(map[string]any); ok && len(mp) > 0 {
|
||
item.ModelPrice = jsonAnyMapToFloatMap(mp)
|
||
}
|
||
if mr, ok := m["model_ratio"].(map[string]any); ok && len(mr) > 0 {
|
||
item.ModelRatio = jsonAnyMapToFloatMap(mr)
|
||
}
|
||
item.ModelMapping = decodeUpstreamModelMapping(m)
|
||
items = append(items, item)
|
||
}
|
||
return items, nil
|
||
}
|
||
|
||
func fetchTokenFactoryUpstreamChannelsExport(baseURL string, key string) ([]upstreamChannelSyncItem, error) {
|
||
client := &http.Client{Timeout: 45 * time.Second}
|
||
u := strings.TrimRight(strings.TrimSpace(baseURL), "/") + "/api/tf_open_sync/channels"
|
||
req, err := http.NewRequest("GET", u, nil)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
k := strings.TrimSpace(key)
|
||
req.Header.Set("Authorization", "Bearer "+k)
|
||
req.Header.Set("X-TokenFactory-Open-Sync-Secret", k)
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer resp.Body.Close()
|
||
if resp.StatusCode == http.StatusNotFound {
|
||
return nil, errTfOpenExportNotFound
|
||
}
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||
return nil, fmt.Errorf("export 接口 status code %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||
}
|
||
var payload map[string]any
|
||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||
return nil, err
|
||
}
|
||
return decodeUpstreamChannelPayload(payload, "channels")
|
||
}
|
||
|
||
var (
	// errTfOpenExportNotFound reports that the upstream answered 404 on the
	// tf_open_sync export endpoint; fetchTokenFactoryUpstreamChannels treats it
	// as "fall back to the legacy channel listing" rather than as a failure.
	errTfOpenExportNotFound = errors.New("tf_open_sync channels export not found")
)
|
||
|
||
func jsonAnyMapToFloatMap(raw map[string]any) map[string]float64 {
|
||
out := make(map[string]float64)
|
||
for k, v := range raw {
|
||
switch x := v.(type) {
|
||
case float64:
|
||
out[k] = x
|
||
case json.Number:
|
||
if f, err := x.Float64(); err == nil {
|
||
out[k] = f
|
||
}
|
||
default:
|
||
if f, err := strconv.ParseFloat(strings.TrimSpace(common.Interface2String(v)), 64); err == nil {
|
||
out[k] = f
|
||
}
|
||
}
|
||
}
|
||
if len(out) == 0 {
|
||
return nil
|
||
}
|
||
return out
|
||
}
|
||
|
||
func fetchTokenFactoryUpstreamChannelsLegacy(baseURL string, key string) ([]upstreamChannelSyncItem, error) {
|
||
client := &http.Client{Timeout: 15 * time.Second}
|
||
u := strings.TrimRight(strings.TrimSpace(baseURL), "/") + "/api/channel/?p=1&page_size=100000&id_sort=true"
|
||
req, err := http.NewRequest("GET", u, nil)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(key))
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer resp.Body.Close()
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||
return nil, fmt.Errorf("upstream status code %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||
}
|
||
var payload map[string]any
|
||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||
return nil, err
|
||
}
|
||
return decodeUpstreamChannelPayload(payload, "items")
|
||
}
|
||
|
||
func fetchTokenFactoryUpstreamChannels(baseURL string, key string) ([]upstreamChannelSyncItem, error) {
|
||
items, err := fetchTokenFactoryUpstreamChannelsExport(baseURL, key)
|
||
if err == nil && len(items) > 0 {
|
||
return items, nil
|
||
}
|
||
if err != nil && !errors.Is(err, errTfOpenExportNotFound) {
|
||
return nil, fmt.Errorf("拉取上游渠道(export): %w", err)
|
||
}
|
||
legacy, err2 := fetchTokenFactoryUpstreamChannelsLegacy(baseURL, key)
|
||
if err2 != nil {
|
||
return nil, fmt.Errorf("拉取上游渠道失败: %w", err2)
|
||
}
|
||
return legacy, nil
|
||
}
|
||
|
||
func tfOpenLocalChannelNo(up upstreamChannelSyncItem) string {
|
||
// 留空让本地按既有逻辑分配 cN(按 supplier_application_id 递增)。
|
||
return ""
|
||
}
|
||
|
||
func buildTokenFactorySyncedChannels(base *model.Channel) ([]model.Channel, []model.TFOpenUpstreamPricing, error) {
|
||
baseURL := base.GetBaseURL()
|
||
if !isTokenFactoryOpenBaseURL(baseURL) {
|
||
return nil, nil, fmt.Errorf("TokenFactoryOpen 渠道的 API 地址必须指向 TokenFactory 平台")
|
||
}
|
||
key := strings.TrimSpace(base.Key)
|
||
if key == "" {
|
||
return nil, nil, fmt.Errorf("TokenFactoryOpen 渠道密钥不能为空")
|
||
}
|
||
if err := fetchTokenFactoryStatus(baseURL, key); err != nil {
|
||
return nil, nil, fmt.Errorf("TokenFactoryOpen 平台识别失败: %w", err)
|
||
}
|
||
upstreamChannels, err := fetchTokenFactoryUpstreamChannels(baseURL, key)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("拉取上游渠道失败: %w", err)
|
||
}
|
||
if len(upstreamChannels) == 0 {
|
||
return nil, nil, fmt.Errorf("上游未返回可同步渠道")
|
||
}
|
||
now := common.GetTimestamp()
|
||
result := make([]model.Channel, 0, len(upstreamChannels))
|
||
pricing := make([]model.TFOpenUpstreamPricing, 0, len(upstreamChannels))
|
||
for i, upstream := range upstreamChannels {
|
||
clone := *base
|
||
clone.Id = 0
|
||
clone.CreatedTime = now
|
||
if upstream.Type > 0 {
|
||
clone.Type = constant.ChannelTypeTokenFactoryOpen
|
||
} else {
|
||
clone.Type = constant.ChannelTypeTokenFactoryOpen
|
||
}
|
||
// 直接使用上游渠道名称,不再拼接序号后缀
|
||
upstreamName := strings.TrimSpace(upstream.Name)
|
||
seqIdx := model.EncodeBase62(int64(i))
|
||
if upstreamName != "" {
|
||
clone.Name = upstreamName
|
||
} else {
|
||
clone.Name = fmt.Sprintf("upstream-%s", seqIdx)
|
||
}
|
||
clone.Models = strings.TrimSpace(upstream.Models)
|
||
if strings.TrimSpace(upstream.Group) != "" {
|
||
clone.Group = strings.TrimSpace(upstream.Group)
|
||
}
|
||
if upstream.Status > 0 {
|
||
clone.Status = upstream.Status
|
||
}
|
||
upstreamSupplierType := strings.TrimSpace(upstream.SupplierType)
|
||
if upstreamSupplierType != "" && isValidChannelSupplierType(upstreamSupplierType) {
|
||
clone.SupplierType = upstreamSupplierType
|
||
} else if strings.TrimSpace(clone.SupplierType) == "" || !isValidChannelSupplierType(strings.TrimSpace(clone.SupplierType)) {
|
||
clone.SupplierType = defaultChannelSupplierType
|
||
}
|
||
upstreamLogoURL := strings.TrimSpace(upstream.CompanyLogoURL)
|
||
if upstreamLogoURL != "" {
|
||
clone.CompanyLogoURL = upstreamLogoURL
|
||
}
|
||
if upstream.PriceDiscountPercent > 0 {
|
||
clone.PriceDiscountPercent = &upstream.PriceDiscountPercent
|
||
}
|
||
mm := strings.TrimSpace(upstream.ModelMapping)
|
||
if mm != "" {
|
||
clone.ModelMapping = &mm
|
||
} else {
|
||
clone.ModelMapping = nil
|
||
}
|
||
clone.ChannelNo = tfOpenLocalChannelNo(upstream)
|
||
clone.RouteSlug = ""
|
||
syncMeta := map[string]any{
|
||
"source": "tokenfactory_open",
|
||
"upstream_channel_id": upstream.ID,
|
||
"upstream_channel_no": strings.TrimSpace(upstream.ChannelNo),
|
||
"upstream_route_slug": strings.TrimSpace(upstream.RouteSlug),
|
||
"upstream_supplier_app_id": upstream.SupplierApplication,
|
||
"upstream_supplier_alias": strings.TrimSpace(upstream.SupplierAlias),
|
||
"upstream_channel_type": upstream.Type,
|
||
"local_channel_no": clone.ChannelNo,
|
||
"sync_seq_index": seqIdx, // 本次同步批次内的 base-62 顺序编号
|
||
"synced_at": now,
|
||
}
|
||
metaJSON, _ := common.Marshal(syncMeta)
|
||
clone.OtherInfo = string(metaJSON)
|
||
result = append(result, clone)
|
||
pricing = append(pricing, model.TFOpenUpstreamPricing{
|
||
ModelPrice: upstream.ModelPrice,
|
||
ModelRatio: upstream.ModelRatio,
|
||
})
|
||
}
|
||
return result, pricing, nil
|
||
}
|
||
|
||
func AddChannel(c *gin.Context) {
|
||
addChannelRequest := AddChannelRequest{}
|
||
err := c.ShouldBindJSON(&addChannelRequest)
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
// 使用统一的校验函数
|
||
if err := validateChannel(addChannelRequest.Channel, true); err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
if addChannelRequest.Channel != nil && addChannelRequest.Channel.PriceDiscountPercent == nil {
|
||
v := 100.0
|
||
addChannelRequest.Channel.PriceDiscountPercent = &v
|
||
}
|
||
if err := applySupplierChannelOwnershipForCreate(c, addChannelRequest.Channel); err != nil {
|
||
c.JSON(http.StatusForbidden, gin.H{
|
||
"success": false,
|
||
"message": "当前用户未通过供应商审核,无权创建渠道",
|
||
})
|
||
return
|
||
}
|
||
addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
|
||
keys := make([]string, 0)
|
||
switch addChannelRequest.Mode {
|
||
case "multi_to_single":
|
||
addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
|
||
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
|
||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||
array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array)
|
||
addChannelRequest.Channel.Key = strings.Join(array, "\n")
|
||
} else {
|
||
cleanKeys := make([]string, 0)
|
||
for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
|
||
if key == "" {
|
||
continue
|
||
}
|
||
key = strings.TrimSpace(key)
|
||
cleanKeys = append(cleanKeys, key)
|
||
}
|
||
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
|
||
addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
|
||
}
|
||
keys = []string{addChannelRequest.Channel.Key}
|
||
case "batch":
|
||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||
// multi json
|
||
keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
} else {
|
||
keys = strings.Split(addChannelRequest.Channel.Key, "\n")
|
||
}
|
||
case "single":
|
||
keys = []string{addChannelRequest.Channel.Key}
|
||
default:
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "不支持的添加模式",
|
||
})
|
||
return
|
||
}
|
||
|
||
channels := make([]model.Channel, 0, len(keys))
|
||
for _, key := range keys {
|
||
if key == "" {
|
||
continue
|
||
}
|
||
localChannel := addChannelRequest.Channel
|
||
localChannel.Key = key
|
||
if addChannelRequest.BatchAddSetKeyPrefix2Name && len(keys) > 1 {
|
||
keyPrefix := localChannel.Key
|
||
if len(localChannel.Key) > 8 {
|
||
keyPrefix = localChannel.Key[:8]
|
||
}
|
||
localChannel.Name = fmt.Sprintf("%s %s", localChannel.Name, keyPrefix)
|
||
}
|
||
channels = append(channels, *localChannel)
|
||
}
|
||
var tfOpenPricing []model.TFOpenUpstreamPricing
|
||
if addChannelRequest.Channel.Type == constant.ChannelTypeTokenFactoryOpen {
|
||
syncBase := *addChannelRequest.Channel
|
||
if len(channels) > 0 {
|
||
syncBase.Key = strings.TrimSpace(channels[0].Key)
|
||
}
|
||
syncedChannels, pricing, syncErr := buildTokenFactorySyncedChannels(&syncBase)
|
||
if syncErr != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": syncErr.Error(),
|
||
})
|
||
return
|
||
}
|
||
channels = syncedChannels
|
||
tfOpenPricing = pricing
|
||
}
|
||
if len(channels) > 1 {
|
||
for i := range channels {
|
||
channels[i].RouteSlug = ""
|
||
}
|
||
}
|
||
if addChannelRequest.Channel.Type == constant.ChannelTypeTokenFactoryOpen {
|
||
err = model.BatchInsertChannelsWithTfOpenUpstreamPricing(channels, tfOpenPricing)
|
||
} else {
|
||
err = model.BatchInsertChannels(channels)
|
||
}
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
service.ResetProxyClientCache()
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
})
|
||
return
|
||
}
|
||
|
||
func DeleteChannel(c *gin.Context) {
|
||
id, _ := strconv.Atoi(c.Param("id"))
|
||
channel := model.Channel{Id: id}
|
||
err := channel.Delete()
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
model.InitChannelCache()
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
})
|
||
return
|
||
}
|
||
|
||
func DeleteDisabledChannel(c *gin.Context) {
|
||
rows, err := model.DeleteDisabledChannel()
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
model.InitChannelCache()
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": rows,
|
||
})
|
||
return
|
||
}
|
||
|
||
// ChannelTag is the request body of the tag-level channel endpoints
// (DisableTagChannels, EnableTagChannels, EditTagChannels). Tag selects the
// channels. The optional pointer fields stay nil when absent from the JSON
// body and are passed through to model.EditChannelByTag as-is.
type ChannelTag struct {
	Tag            string  `json:"tag"`
	NewTag         *string `json:"new_tag"`
	Priority       *int64  `json:"priority"`
	Weight         *uint   `json:"weight"`
	ModelMapping   *string `json:"model_mapping"`
	Models         *string `json:"models"`
	Groups         *string `json:"groups"`
	ParamOverride  *string `json:"param_override"`
	HeaderOverride *string `json:"header_override"`
}
|
||
|
||
func DisableTagChannels(c *gin.Context) {
|
||
channelTag := ChannelTag{}
|
||
err := c.ShouldBindJSON(&channelTag)
|
||
if err != nil || channelTag.Tag == "" {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "参数错误",
|
||
})
|
||
return
|
||
}
|
||
err = model.DisableChannelByTag(channelTag.Tag)
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
model.InitChannelCache()
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
})
|
||
return
|
||
}
|
||
|
||
func EnableTagChannels(c *gin.Context) {
|
||
channelTag := ChannelTag{}
|
||
err := c.ShouldBindJSON(&channelTag)
|
||
if err != nil || channelTag.Tag == "" {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "参数错误",
|
||
})
|
||
return
|
||
}
|
||
err = model.EnableChannelByTag(channelTag.Tag)
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
model.InitChannelCache()
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
})
|
||
return
|
||
}
|
||
|
||
func EditTagChannels(c *gin.Context) {
|
||
channelTag := ChannelTag{}
|
||
err := c.ShouldBindJSON(&channelTag)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "参数错误",
|
||
})
|
||
return
|
||
}
|
||
if channelTag.Tag == "" {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "tag不能为空",
|
||
})
|
||
return
|
||
}
|
||
if channelTag.ParamOverride != nil {
|
||
trimmed := strings.TrimSpace(*channelTag.ParamOverride)
|
||
if trimmed != "" && !json.Valid([]byte(trimmed)) {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "参数覆盖必须是合法的 JSON 格式",
|
||
})
|
||
return
|
||
}
|
||
channelTag.ParamOverride = common.GetPointer[string](trimmed)
|
||
}
|
||
if channelTag.HeaderOverride != nil {
|
||
trimmed := strings.TrimSpace(*channelTag.HeaderOverride)
|
||
if trimmed != "" && !json.Valid([]byte(trimmed)) {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "请求头覆盖必须是合法的 JSON 格式",
|
||
})
|
||
return
|
||
}
|
||
channelTag.HeaderOverride = common.GetPointer[string](trimmed)
|
||
}
|
||
err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight, channelTag.ParamOverride, channelTag.HeaderOverride)
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
model.InitChannelCache()
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
})
|
||
return
|
||
}
|
||
|
||
// ChannelBatch is the request body for bulk operations over a set of channel
// ids (DeleteChannelBatch, BatchSetChannelTag). Tag is only read by the
// tag-assignment endpoint and stays nil when omitted.
type ChannelBatch struct {
	Ids []int   `json:"ids"`
	Tag *string `json:"tag"`
}
|
||
|
||
func DeleteChannelBatch(c *gin.Context) {
|
||
channelBatch := ChannelBatch{}
|
||
err := c.ShouldBindJSON(&channelBatch)
|
||
if err != nil || len(channelBatch.Ids) == 0 {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "参数错误",
|
||
})
|
||
return
|
||
}
|
||
err = model.BatchDeleteChannels(channelBatch.Ids)
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
model.InitChannelCache()
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": len(channelBatch.Ids),
|
||
})
|
||
return
|
||
}
|
||
|
||
type PatchChannel struct {
|
||
model.Channel
|
||
MultiKeyMode *string `json:"multi_key_mode"`
|
||
KeyMode *string `json:"key_mode"` // 多key模式下密钥覆盖或者追加
|
||
}
|
||
|
||
func UpdateChannel(c *gin.Context) {
|
||
channel := PatchChannel{}
|
||
err := c.ShouldBindJSON(&channel)
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
if channel.Id <= 0 {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "渠道ID无效",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
|
||
originChannel, err := model.GetChannelById(channel.Id, true)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
if !validateSupplierChannelOwnershipForUpdate(c, originChannel) {
|
||
c.JSON(http.StatusForbidden, gin.H{
|
||
"success": false,
|
||
"message": "无权修改其他供应商渠道",
|
||
})
|
||
return
|
||
}
|
||
oldBalance := originChannel.Balance
|
||
// 部分更新(如仅改状态/优先级/权重):请求未带供应商类型时沿用库中值,否则 validateChannel 会因零值失败。
|
||
if strings.TrimSpace(channel.SupplierType) == "" {
|
||
channel.SupplierType = strings.TrimSpace(originChannel.SupplierType)
|
||
}
|
||
if strings.TrimSpace(channel.SupplierType) == "" && originChannel.SupplierApplicationID > 0 {
|
||
var app model.SupplierApplication
|
||
if err := model.DB.Select("supplier_type").Where("id = ?", originChannel.SupplierApplicationID).First(&app).Error; err == nil {
|
||
channel.SupplierType = strings.TrimSpace(app.SupplierType)
|
||
}
|
||
}
|
||
if strings.TrimSpace(channel.SupplierType) == "" {
|
||
channel.SupplierType = defaultChannelSupplierType
|
||
}
|
||
|
||
// route_slug:空则沿用库中值;非空变更时校验格式与全局唯一。
|
||
if strings.TrimSpace(channel.RouteSlug) == "" {
|
||
channel.RouteSlug = originChannel.RouteSlug
|
||
} else {
|
||
channel.RouteSlug = strings.TrimSpace(channel.RouteSlug)
|
||
if channel.RouteSlug != originChannel.RouteSlug {
|
||
if !model.IsValidRouteSlug(channel.RouteSlug) {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "route_slug 格式无效(2~32 位字母数字,且不能为 c 加纯数字)",
|
||
})
|
||
return
|
||
}
|
||
var cnt int64
|
||
if err := model.DB.Model(&model.Channel{}).Where("route_slug = ? AND id <> ?", channel.RouteSlug, channel.Id).Count(&cnt).Error; err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
if cnt > 0 {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "route_slug 已被其他渠道占用",
|
||
})
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// 使用统一的校验函数
|
||
if err := validateChannel(&channel.Channel, false); err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
// 供应商更新时强制保持归属信息不变,防止通过请求体篡改 owner/supplier 关联。
|
||
if c.GetInt("role") < common.RoleAdminUser {
|
||
channel.OwnerUserID = originChannel.OwnerUserID
|
||
channel.SupplierApplicationID = originChannel.SupplierApplicationID
|
||
}
|
||
|
||
// Always copy the original ChannelInfo so that fields like IsMultiKey and MultiKeySize are retained.
|
||
channel.ChannelInfo = originChannel.ChannelInfo
|
||
|
||
// If the request explicitly specifies a new MultiKeyMode, apply it on top of the original info.
|
||
if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
|
||
channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
|
||
}
|
||
|
||
// 处理多key模式下的密钥追加/覆盖逻辑
|
||
if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey {
|
||
switch *channel.KeyMode {
|
||
case "append":
|
||
// 追加模式:将新密钥添加到现有密钥列表
|
||
if originChannel.Key != "" {
|
||
var newKeys []string
|
||
var existingKeys []string
|
||
|
||
// 解析现有密钥
|
||
if strings.HasPrefix(strings.TrimSpace(originChannel.Key), "[") {
|
||
// JSON数组格式
|
||
var arr []json.RawMessage
|
||
if err := json.Unmarshal([]byte(strings.TrimSpace(originChannel.Key)), &arr); err == nil {
|
||
existingKeys = make([]string, len(arr))
|
||
for i, v := range arr {
|
||
existingKeys[i] = string(v)
|
||
}
|
||
}
|
||
} else {
|
||
// 换行分隔格式
|
||
existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n")
|
||
}
|
||
|
||
// 处理 Vertex AI 的特殊情况
|
||
if channel.Type == constant.ChannelTypeVertexAi && channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||
// 尝试解析新密钥为JSON数组
|
||
if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
|
||
array, err := getVertexArrayKeys(channel.Key)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "追加密钥解析失败: " + err.Error(),
|
||
})
|
||
return
|
||
}
|
||
newKeys = array
|
||
} else {
|
||
// 单个JSON密钥
|
||
newKeys = []string{channel.Key}
|
||
}
|
||
} else {
|
||
// 普通渠道的处理
|
||
inputKeys := strings.Split(channel.Key, "\n")
|
||
for _, key := range inputKeys {
|
||
key = strings.TrimSpace(key)
|
||
if key != "" {
|
||
newKeys = append(newKeys, key)
|
||
}
|
||
}
|
||
}
|
||
|
||
seen := make(map[string]struct{}, len(existingKeys)+len(newKeys))
|
||
for _, key := range existingKeys {
|
||
normalized := strings.TrimSpace(key)
|
||
if normalized == "" {
|
||
continue
|
||
}
|
||
seen[normalized] = struct{}{}
|
||
}
|
||
dedupedNewKeys := make([]string, 0, len(newKeys))
|
||
for _, key := range newKeys {
|
||
normalized := strings.TrimSpace(key)
|
||
if normalized == "" {
|
||
continue
|
||
}
|
||
if _, ok := seen[normalized]; ok {
|
||
continue
|
||
}
|
||
seen[normalized] = struct{}{}
|
||
dedupedNewKeys = append(dedupedNewKeys, normalized)
|
||
}
|
||
|
||
allKeys := append(existingKeys, dedupedNewKeys...)
|
||
channel.Key = strings.Join(allKeys, "\n")
|
||
}
|
||
case "replace":
|
||
// 覆盖模式:直接使用新密钥(默认行为,不需要特殊处理)
|
||
}
|
||
}
|
||
err = channel.Update()
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
notifyChannelBalanceAlertIfNeeded(originChannel, oldBalance, channel.Balance)
|
||
model.InitChannelCache()
|
||
service.ResetProxyClientCache()
|
||
channel.Key = ""
|
||
clearChannelInfo(&channel.Channel)
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": channel,
|
||
})
|
||
return
|
||
}
|
||
|
||
func FetchModels(c *gin.Context) {
|
||
var req struct {
|
||
BaseURL string `json:"base_url"`
|
||
Type int `json:"type"`
|
||
Key string `json:"key"`
|
||
}
|
||
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"success": false,
|
||
"message": "Invalid request",
|
||
})
|
||
return
|
||
}
|
||
|
||
baseURL := req.BaseURL
|
||
if baseURL == "" {
|
||
baseURL = constant.ChannelBaseURLs[req.Type]
|
||
}
|
||
|
||
// remove line breaks and extra spaces.
|
||
key := strings.TrimSpace(req.Key)
|
||
key = strings.Split(key, "\n")[0]
|
||
|
||
if req.Type == constant.ChannelTypeOllama {
|
||
models, err := ollama.FetchOllamaModels(c.Request.Context(), baseURL, key)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": fmt.Sprintf("获取Ollama模型失败: %s", err.Error()),
|
||
})
|
||
return
|
||
}
|
||
|
||
names := make([]string, 0, len(models))
|
||
for _, modelInfo := range models {
|
||
names = append(names, modelInfo.Name)
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"data": names,
|
||
})
|
||
return
|
||
}
|
||
|
||
if req.Type == constant.ChannelTypeGemini {
|
||
models, err := gemini.FetchGeminiModels(c.Request.Context(), baseURL, key, "")
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": fmt.Sprintf("获取Gemini模型失败: %s", err.Error()),
|
||
})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"data": models,
|
||
})
|
||
return
|
||
}
|
||
|
||
client := &http.Client{}
|
||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||
|
||
request, err := http.NewRequest("GET", url, nil)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
|
||
request.Header.Set("Authorization", "Bearer "+key)
|
||
|
||
response, err := client.Do(request)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
//check status code
|
||
if response.StatusCode != http.StatusOK {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"success": false,
|
||
"message": "Failed to fetch models",
|
||
})
|
||
return
|
||
}
|
||
defer response.Body.Close()
|
||
|
||
var result struct {
|
||
Data []struct {
|
||
ID string `json:"id"`
|
||
} `json:"data"`
|
||
}
|
||
|
||
if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
|
||
var models []string
|
||
for _, model := range result.Data {
|
||
models = append(models, model.ID)
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"data": models,
|
||
})
|
||
}
|
||
|
||
func BatchSetChannelTag(c *gin.Context) {
|
||
channelBatch := ChannelBatch{}
|
||
err := c.ShouldBindJSON(&channelBatch)
|
||
if err != nil || len(channelBatch.Ids) == 0 {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "参数错误",
|
||
})
|
||
return
|
||
}
|
||
err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
model.InitChannelCache()
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": len(channelBatch.Ids),
|
||
})
|
||
return
|
||
}
|
||
|
||
func GetTagModels(c *gin.Context) {
|
||
tag := c.Query("tag")
|
||
if tag == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"success": false,
|
||
"message": "tag不能为空",
|
||
})
|
||
return
|
||
}
|
||
|
||
channels, err := model.GetChannelsByTag(tag, false, false) // idSort=false, selectAll=false
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
|
||
var longestModels string
|
||
maxLength := 0
|
||
|
||
// Find the longest models string among all channels with the given tag
|
||
for _, channel := range channels {
|
||
if channel.Models != "" {
|
||
currentModels := strings.Split(channel.Models, ",")
|
||
if len(currentModels) > maxLength {
|
||
maxLength = len(currentModels)
|
||
longestModels = channel.Models
|
||
}
|
||
}
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": longestModels,
|
||
})
|
||
return
|
||
}
|
||
|
||
// CopyChannel handles cloning an existing channel with its key.
|
||
// POST /api/channel/copy/:id
|
||
// Optional query params:
|
||
//
|
||
// suffix - string appended to the original name (default "_复制")
|
||
// reset_balance - bool, when true will reset balance & used_quota to 0 (default true)
|
||
func CopyChannel(c *gin.Context) {
|
||
id, err := strconv.Atoi(c.Param("id"))
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "invalid id"})
|
||
return
|
||
}
|
||
|
||
suffix := c.DefaultQuery("suffix", "_复制")
|
||
resetBalance := true
|
||
if rbStr := c.DefaultQuery("reset_balance", "true"); rbStr != "" {
|
||
if v, err := strconv.ParseBool(rbStr); err == nil {
|
||
resetBalance = v
|
||
}
|
||
}
|
||
|
||
// fetch original channel with key
|
||
origin, err := model.GetChannelById(id, true)
|
||
if err != nil {
|
||
common.SysError("failed to get channel by id: " + err.Error())
|
||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道信息失败,请稍后重试"})
|
||
return
|
||
}
|
||
|
||
// clone channel
|
||
clone := *origin // shallow copy is sufficient as we will overwrite primitives
|
||
clone.Id = 0 // let DB auto-generate
|
||
clone.CreatedTime = common.GetTimestamp()
|
||
clone.Name = origin.Name + suffix
|
||
clone.TestTime = 0
|
||
clone.ResponseTime = 0
|
||
if resetBalance {
|
||
clone.Balance = 0
|
||
clone.UsedQuota = 0
|
||
}
|
||
clone.ChannelNo = ""
|
||
clone.RouteSlug = ""
|
||
|
||
// insert
|
||
if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil {
|
||
common.SysError("failed to clone channel: " + err.Error())
|
||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "复制渠道失败,请稍后重试"})
|
||
return
|
||
}
|
||
model.InitChannelCache()
|
||
// success
|
||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}})
|
||
}
|
||
|
||
// MultiKeyManageRequest is the request body of ManageMultiKeys.
type MultiKeyManageRequest struct {
	ChannelId int `json:"channel_id"`
	// Action is one of "disable_key", "enable_key", "delete_key",
	// "delete_disabled_keys" or "get_key_status".
	Action string `json:"action"`
	// KeyIndex selects the key for disable_key, enable_key and delete_key.
	KeyIndex *int `json:"key_index,omitempty"`
	// Page and PageSize paginate get_key_status.
	Page     int `json:"page,omitempty"`
	PageSize int `json:"page_size,omitempty"`
	// Status filters get_key_status: 1=enabled, 2=manual_disabled,
	// 3=auto_disabled, nil=all.
	Status *int `json:"status,omitempty"`
}
|
||
|
||
// MultiKeyStatusResponse represents the response for key status query
|
||
type MultiKeyStatusResponse struct {
|
||
Keys []KeyStatus `json:"keys"`
|
||
Total int `json:"total"`
|
||
Page int `json:"page"`
|
||
PageSize int `json:"page_size"`
|
||
TotalPages int `json:"total_pages"`
|
||
// Statistics
|
||
EnabledCount int `json:"enabled_count"`
|
||
ManualDisabledCount int `json:"manual_disabled_count"`
|
||
AutoDisabledCount int `json:"auto_disabled_count"`
|
||
}
|
||
|
||
// KeyStatus describes one key of a multi-key channel, as reported by
// ManageMultiKeys' get_key_status action.
type KeyStatus struct {
	// Index is the key's position in the channel's key list.
	Index int `json:"index"`
	// Status: 1 = enabled, 2 = manually disabled, 3 = automatically disabled.
	Status int `json:"status"`
	// DisabledTime and Reason are only filled in for keys that are not enabled.
	DisabledTime int64  `json:"disabled_time,omitempty"`
	Reason       string `json:"reason,omitempty"`
	// KeyPreview is the first 10 bytes of the key, followed by "..." when the
	// key is longer, so the key can be identified without exposing it.
	KeyPreview string `json:"key_preview"`
}
|
||
|
||
// ManageMultiKeys handles multi-key management operations
|
||
func ManageMultiKeys(c *gin.Context) {
|
||
request := MultiKeyManageRequest{}
|
||
err := c.ShouldBindJSON(&request)
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
|
||
channel, err := model.GetChannelById(request.ChannelId, true)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "渠道不存在",
|
||
})
|
||
return
|
||
}
|
||
|
||
if !channel.ChannelInfo.IsMultiKey {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "该渠道不是多密钥模式",
|
||
})
|
||
return
|
||
}
|
||
|
||
lock := model.GetChannelPollingLock(channel.Id)
|
||
lock.Lock()
|
||
defer lock.Unlock()
|
||
|
||
switch request.Action {
|
||
case "get_key_status":
|
||
keys := channel.GetKeys()
|
||
|
||
// Default pagination parameters
|
||
page := request.Page
|
||
pageSize := request.PageSize
|
||
if page <= 0 {
|
||
page = 1
|
||
}
|
||
if pageSize <= 0 {
|
||
pageSize = 50 // Default page size
|
||
}
|
||
|
||
// Statistics for all keys (unchanged by filtering)
|
||
var enabledCount, manualDisabledCount, autoDisabledCount int
|
||
|
||
// Build all key status data first
|
||
var allKeyStatusList []KeyStatus
|
||
for i, key := range keys {
|
||
status := 1 // default enabled
|
||
var disabledTime int64
|
||
var reason string
|
||
|
||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
|
||
status = s
|
||
}
|
||
}
|
||
|
||
// Count for statistics (all keys)
|
||
switch status {
|
||
case 1:
|
||
enabledCount++
|
||
case 2:
|
||
manualDisabledCount++
|
||
case 3:
|
||
autoDisabledCount++
|
||
}
|
||
|
||
if status != 1 {
|
||
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
|
||
disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i]
|
||
}
|
||
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
|
||
reason = channel.ChannelInfo.MultiKeyDisabledReason[i]
|
||
}
|
||
}
|
||
|
||
// Create key preview (first 10 chars)
|
||
keyPreview := key
|
||
if len(key) > 10 {
|
||
keyPreview = key[:10] + "..."
|
||
}
|
||
|
||
allKeyStatusList = append(allKeyStatusList, KeyStatus{
|
||
Index: i,
|
||
Status: status,
|
||
DisabledTime: disabledTime,
|
||
Reason: reason,
|
||
KeyPreview: keyPreview,
|
||
})
|
||
}
|
||
|
||
// Apply status filter if specified
|
||
var filteredKeyStatusList []KeyStatus
|
||
if request.Status != nil {
|
||
for _, keyStatus := range allKeyStatusList {
|
||
if keyStatus.Status == *request.Status {
|
||
filteredKeyStatusList = append(filteredKeyStatusList, keyStatus)
|
||
}
|
||
}
|
||
} else {
|
||
filteredKeyStatusList = allKeyStatusList
|
||
}
|
||
|
||
// Calculate pagination based on filtered results
|
||
filteredTotal := len(filteredKeyStatusList)
|
||
totalPages := (filteredTotal + pageSize - 1) / pageSize
|
||
if totalPages == 0 {
|
||
totalPages = 1
|
||
}
|
||
if page > totalPages {
|
||
page = totalPages
|
||
}
|
||
|
||
// Calculate range for current page
|
||
start := (page - 1) * pageSize
|
||
end := start + pageSize
|
||
if end > filteredTotal {
|
||
end = filteredTotal
|
||
}
|
||
|
||
// Get the page data
|
||
var pageKeyStatusList []KeyStatus
|
||
if start < filteredTotal {
|
||
pageKeyStatusList = filteredKeyStatusList[start:end]
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": MultiKeyStatusResponse{
|
||
Keys: pageKeyStatusList,
|
||
Total: filteredTotal, // Total of filtered results
|
||
Page: page,
|
||
PageSize: pageSize,
|
||
TotalPages: totalPages,
|
||
EnabledCount: enabledCount, // Overall statistics
|
||
ManualDisabledCount: manualDisabledCount, // Overall statistics
|
||
AutoDisabledCount: autoDisabledCount, // Overall statistics
|
||
},
|
||
})
|
||
return
|
||
|
||
case "disable_key":
|
||
if request.KeyIndex == nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "未指定要禁用的密钥索引",
|
||
})
|
||
return
|
||
}
|
||
|
||
keyIndex := *request.KeyIndex
|
||
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "密钥索引超出范围",
|
||
})
|
||
return
|
||
}
|
||
|
||
if channel.ChannelInfo.MultiKeyStatusList == nil {
|
||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||
}
|
||
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
|
||
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||
}
|
||
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
|
||
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||
}
|
||
|
||
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // disabled
|
||
|
||
err = channel.Update()
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
|
||
model.InitChannelCache()
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "密钥已禁用",
|
||
})
|
||
return
|
||
|
||
case "enable_key":
|
||
if request.KeyIndex == nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "未指定要启用的密钥索引",
|
||
})
|
||
return
|
||
}
|
||
|
||
keyIndex := *request.KeyIndex
|
||
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "密钥索引超出范围",
|
||
})
|
||
return
|
||
}
|
||
|
||
// 从状态列表中删除该密钥的记录,使其回到默认启用状态
|
||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
|
||
}
|
||
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
|
||
delete(channel.ChannelInfo.MultiKeyDisabledTime, keyIndex)
|
||
}
|
||
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
|
||
delete(channel.ChannelInfo.MultiKeyDisabledReason, keyIndex)
|
||
}
|
||
|
||
err = channel.Update()
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
|
||
model.InitChannelCache()
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "密钥已启用",
|
||
})
|
||
return
|
||
|
||
case "enable_all_keys":
|
||
// 清空所有禁用状态,使所有密钥回到默认启用状态
|
||
var enabledCount int
|
||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||
enabledCount = len(channel.ChannelInfo.MultiKeyStatusList)
|
||
}
|
||
|
||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||
|
||
err = channel.Update()
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
|
||
model.InitChannelCache()
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": fmt.Sprintf("已启用 %d 个密钥", enabledCount),
|
||
})
|
||
return
|
||
|
||
case "disable_all_keys":
|
||
// 禁用所有启用的密钥
|
||
if channel.ChannelInfo.MultiKeyStatusList == nil {
|
||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||
}
|
||
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
|
||
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||
}
|
||
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
|
||
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||
}
|
||
|
||
var disabledCount int
|
||
for i := 0; i < channel.ChannelInfo.MultiKeySize; i++ {
|
||
status := 1 // default enabled
|
||
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
|
||
status = s
|
||
}
|
||
|
||
// 只禁用当前启用的密钥
|
||
if status == 1 {
|
||
channel.ChannelInfo.MultiKeyStatusList[i] = 2 // disabled
|
||
disabledCount++
|
||
}
|
||
}
|
||
|
||
if disabledCount == 0 {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "没有可禁用的密钥",
|
||
})
|
||
return
|
||
}
|
||
|
||
err = channel.Update()
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
|
||
model.InitChannelCache()
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": fmt.Sprintf("已禁用 %d 个密钥", disabledCount),
|
||
})
|
||
return
|
||
|
||
case "delete_key":
|
||
if request.KeyIndex == nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "未指定要删除的密钥索引",
|
||
})
|
||
return
|
||
}
|
||
|
||
keyIndex := *request.KeyIndex
|
||
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "密钥索引超出范围",
|
||
})
|
||
return
|
||
}
|
||
|
||
keys := channel.GetKeys()
|
||
var remainingKeys []string
|
||
var newStatusList = make(map[int]int)
|
||
var newDisabledTime = make(map[int]int64)
|
||
var newDisabledReason = make(map[int]string)
|
||
|
||
newIndex := 0
|
||
for i, key := range keys {
|
||
// 跳过要删除的密钥
|
||
if i == keyIndex {
|
||
continue
|
||
}
|
||
|
||
remainingKeys = append(remainingKeys, key)
|
||
|
||
// 保留其他密钥的状态信息,重新索引
|
||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||
if status, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists && status != 1 {
|
||
newStatusList[newIndex] = status
|
||
}
|
||
}
|
||
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
|
||
if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
|
||
newDisabledTime[newIndex] = t
|
||
}
|
||
}
|
||
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
|
||
if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
|
||
newDisabledReason[newIndex] = r
|
||
}
|
||
}
|
||
newIndex++
|
||
}
|
||
|
||
if len(remainingKeys) == 0 {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "不能删除最后一个密钥",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Update channel with remaining keys
|
||
channel.Key = strings.Join(remainingKeys, "\n")
|
||
channel.ChannelInfo.MultiKeySize = len(remainingKeys)
|
||
channel.ChannelInfo.MultiKeyStatusList = newStatusList
|
||
channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
|
||
channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
|
||
|
||
err = channel.Update()
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
|
||
model.InitChannelCache()
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "密钥已删除",
|
||
})
|
||
return
|
||
|
||
case "delete_disabled_keys":
|
||
keys := channel.GetKeys()
|
||
var remainingKeys []string
|
||
var deletedCount int
|
||
var newStatusList = make(map[int]int)
|
||
var newDisabledTime = make(map[int]int64)
|
||
var newDisabledReason = make(map[int]string)
|
||
|
||
newIndex := 0
|
||
for i, key := range keys {
|
||
status := 1 // default enabled
|
||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
|
||
status = s
|
||
}
|
||
}
|
||
|
||
// 只删除自动禁用(status == 3)的密钥,保留启用(status == 1)和手动禁用(status == 2)的密钥
|
||
if status == 3 {
|
||
deletedCount++
|
||
} else {
|
||
remainingKeys = append(remainingKeys, key)
|
||
// 保留非自动禁用密钥的状态信息,重新索引
|
||
if status != 1 {
|
||
newStatusList[newIndex] = status
|
||
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
|
||
if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
|
||
newDisabledTime[newIndex] = t
|
||
}
|
||
}
|
||
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
|
||
if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
|
||
newDisabledReason[newIndex] = r
|
||
}
|
||
}
|
||
}
|
||
newIndex++
|
||
}
|
||
}
|
||
|
||
if deletedCount == 0 {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "没有需要删除的自动禁用密钥",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Update channel with remaining keys
|
||
channel.Key = strings.Join(remainingKeys, "\n")
|
||
channel.ChannelInfo.MultiKeySize = len(remainingKeys)
|
||
channel.ChannelInfo.MultiKeyStatusList = newStatusList
|
||
channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
|
||
channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
|
||
|
||
err = channel.Update()
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
|
||
model.InitChannelCache()
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": fmt.Sprintf("已删除 %d 个自动禁用的密钥", deletedCount),
|
||
"data": deletedCount,
|
||
})
|
||
return
|
||
|
||
default:
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "不支持的操作",
|
||
})
|
||
return
|
||
}
|
||
}
|
||
|
||
// OllamaPullModel 拉取 Ollama 模型
|
||
func OllamaPullModel(c *gin.Context) {
|
||
var req struct {
|
||
ChannelID int `json:"channel_id"`
|
||
ModelName string `json:"model_name"`
|
||
}
|
||
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"success": false,
|
||
"message": "Invalid request parameters",
|
||
})
|
||
return
|
||
}
|
||
|
||
if req.ChannelID == 0 || req.ModelName == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"success": false,
|
||
"message": "Channel ID and model name are required",
|
||
})
|
||
return
|
||
}
|
||
|
||
// 获取渠道信息
|
||
channel, err := model.GetChannelById(req.ChannelID, true)
|
||
if err != nil {
|
||
c.JSON(http.StatusNotFound, gin.H{
|
||
"success": false,
|
||
"message": "Channel not found",
|
||
})
|
||
return
|
||
}
|
||
|
||
// 检查是否是 Ollama 渠道
|
||
if channel.Type != constant.ChannelTypeOllama {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"success": false,
|
||
"message": "This operation is only supported for Ollama channels",
|
||
})
|
||
return
|
||
}
|
||
|
||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||
if channel.GetBaseURL() != "" {
|
||
baseURL = channel.GetBaseURL()
|
||
}
|
||
|
||
key := strings.Split(channel.Key, "\n")[0]
|
||
err = ollama.PullOllamaModel(baseURL, key, req.ModelName)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"success": false,
|
||
"message": fmt.Sprintf("Failed to pull model: %s", err.Error()),
|
||
})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": fmt.Sprintf("Model %s pulled successfully", req.ModelName),
|
||
})
|
||
}
|
||
|
||
// OllamaPullModelStream 流式拉取 Ollama 模型
|
||
func OllamaPullModelStream(c *gin.Context) {
|
||
var req struct {
|
||
ChannelID int `json:"channel_id"`
|
||
ModelName string `json:"model_name"`
|
||
}
|
||
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"success": false,
|
||
"message": "Invalid request parameters",
|
||
})
|
||
return
|
||
}
|
||
|
||
if req.ChannelID == 0 || req.ModelName == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"success": false,
|
||
"message": "Channel ID and model name are required",
|
||
})
|
||
return
|
||
}
|
||
|
||
// 获取渠道信息
|
||
channel, err := model.GetChannelById(req.ChannelID, true)
|
||
if err != nil {
|
||
c.JSON(http.StatusNotFound, gin.H{
|
||
"success": false,
|
||
"message": "Channel not found",
|
||
})
|
||
return
|
||
}
|
||
|
||
// 检查是否是 Ollama 渠道
|
||
if channel.Type != constant.ChannelTypeOllama {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"success": false,
|
||
"message": "This operation is only supported for Ollama channels",
|
||
})
|
||
return
|
||
}
|
||
|
||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||
if channel.GetBaseURL() != "" {
|
||
baseURL = channel.GetBaseURL()
|
||
}
|
||
|
||
// 设置 SSE 头部
|
||
c.Header("Content-Type", "text/event-stream")
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
c.Header("Access-Control-Allow-Origin", "*")
|
||
|
||
key := strings.Split(channel.Key, "\n")[0]
|
||
|
||
// 创建进度回调函数
|
||
progressCallback := func(progress ollama.OllamaPullResponse) {
|
||
data, _ := json.Marshal(progress)
|
||
fmt.Fprintf(c.Writer, "data: %s\n\n", string(data))
|
||
c.Writer.Flush()
|
||
}
|
||
|
||
// 执行拉取
|
||
err = ollama.PullOllamaModelStream(baseURL, key, req.ModelName, progressCallback)
|
||
|
||
if err != nil {
|
||
errorData, _ := json.Marshal(gin.H{
|
||
"error": err.Error(),
|
||
})
|
||
fmt.Fprintf(c.Writer, "data: %s\n\n", string(errorData))
|
||
} else {
|
||
successData, _ := json.Marshal(gin.H{
|
||
"message": fmt.Sprintf("Model %s pulled successfully", req.ModelName),
|
||
})
|
||
fmt.Fprintf(c.Writer, "data: %s\n\n", string(successData))
|
||
}
|
||
|
||
// 发送结束标志
|
||
fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||
c.Writer.Flush()
|
||
}
|
||
|
||
// OllamaDeleteModel 删除 Ollama 模型
|
||
func OllamaDeleteModel(c *gin.Context) {
|
||
var req struct {
|
||
ChannelID int `json:"channel_id"`
|
||
ModelName string `json:"model_name"`
|
||
}
|
||
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"success": false,
|
||
"message": "Invalid request parameters",
|
||
})
|
||
return
|
||
}
|
||
|
||
if req.ChannelID == 0 || req.ModelName == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"success": false,
|
||
"message": "Channel ID and model name are required",
|
||
})
|
||
return
|
||
}
|
||
|
||
// 获取渠道信息
|
||
channel, err := model.GetChannelById(req.ChannelID, true)
|
||
if err != nil {
|
||
c.JSON(http.StatusNotFound, gin.H{
|
||
"success": false,
|
||
"message": "Channel not found",
|
||
})
|
||
return
|
||
}
|
||
|
||
// 检查是否是 Ollama 渠道
|
||
if channel.Type != constant.ChannelTypeOllama {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"success": false,
|
||
"message": "This operation is only supported for Ollama channels",
|
||
})
|
||
return
|
||
}
|
||
|
||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||
if channel.GetBaseURL() != "" {
|
||
baseURL = channel.GetBaseURL()
|
||
}
|
||
|
||
key := strings.Split(channel.Key, "\n")[0]
|
||
err = ollama.DeleteOllamaModel(baseURL, key, req.ModelName)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"success": false,
|
||
"message": fmt.Sprintf("Failed to delete model: %s", err.Error()),
|
||
})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": fmt.Sprintf("Model %s deleted successfully", req.ModelName),
|
||
})
|
||
}
|
||
|
||
// OllamaVersion 获取 Ollama 服务版本信息
|
||
func OllamaVersion(c *gin.Context) {
|
||
id, err := strconv.Atoi(c.Param("id"))
|
||
if err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"success": false,
|
||
"message": "Invalid channel id",
|
||
})
|
||
return
|
||
}
|
||
|
||
channel, err := model.GetChannelById(id, true)
|
||
if err != nil {
|
||
c.JSON(http.StatusNotFound, gin.H{
|
||
"success": false,
|
||
"message": "Channel not found",
|
||
})
|
||
return
|
||
}
|
||
|
||
if channel.Type != constant.ChannelTypeOllama {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"success": false,
|
||
"message": "This operation is only supported for Ollama channels",
|
||
})
|
||
return
|
||
}
|
||
|
||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||
if channel.GetBaseURL() != "" {
|
||
baseURL = channel.GetBaseURL()
|
||
}
|
||
|
||
key := strings.Split(channel.Key, "\n")[0]
|
||
version, err := ollama.FetchOllamaVersion(baseURL, key)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": fmt.Sprintf("获取Ollama版本失败: %s", err.Error()),
|
||
})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"data": gin.H{
|
||
"version": version,
|
||
},
|
||
})
|
||
}
|