tokenFactory/controller/channel.go

2720 lines
76 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package controller
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
relaychannel "github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/gemini"
"github.com/QuantumNous/new-api/relay/channel/ollama"
"github.com/QuantumNous/new-api/service"
"github.com/gin-gonic/gin"
)
// OpenAIModel is one entry of an OpenAI-style model listing; the JSON tags define the wire
// field names. NOTE(review): presumably used to decode upstream /v1/models responses — its users
// are not visible in this chunk, confirm.
type OpenAIModel struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
// Metadata is dropped from the JSON output when empty (omitempty).
Metadata map[string]any `json:"metadata,omitempty"`
// Permission holds per-model permission objects; presumably mirrors the legacy OpenAI
// "permission" array — confirm against the upstreams actually queried.
Permission []struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group string `json:"group"`
IsBlocking bool `json:"is_blocking"`
} `json:"permission"`
Root string `json:"root"`
Parent string `json:"parent"`
}
// OpenAIModelsResponse is a model-list envelope: the models under "data" plus a "success" flag.
// NOTE(review): "success" is not part of the official OpenAI schema; presumably emitted by
// new-api style upstreams — confirm.
type OpenAIModelsResponse struct {
Data []OpenAIModel `json:"data"`
Success bool `json:"success"`
}
// channelAllowedSupplierTypes is the set of accepted channel supplier_type values
// (public cloud, AIDC, enterprise relay station, personal relay station). Only the keys matter.
var channelAllowedSupplierTypes = map[string]struct{}{
	"公有云":   {},
	"AIDC":  {},
	"企业中转站": {},
	"个人中转站": {},
}

// defaultChannelSupplierType is the fallback supplier_type used when neither the channel row nor
// its linked supplier application provides one. It must be a key of channelAllowedSupplierTypes.
const defaultChannelSupplierType = "公有云"

// isValidChannelSupplierType reports whether supplierType is exactly one of the predefined values
// in channelAllowedSupplierTypes. No trimming or case folding is applied.
func isValidChannelSupplierType(supplierType string) bool {
	if _, known := channelAllowedSupplierTypes[supplierType]; known {
		return true
	}
	return false
}
// parseStatusFilter maps the "status" query parameter (case-insensitive) to a filter code:
// common.ChannelStatusEnabled for "enabled" or "1", 0 for "disabled" or "0", and -1 (no status
// filtering) for any other value, including the empty string. Callers in this file treat 0 as
// "any status other than enabled".
func parseStatusFilter(statusParam string) int {
	normalized := strings.ToLower(statusParam)
	if normalized == "enabled" || normalized == "1" {
		return common.ChannelStatusEnabled
	}
	if normalized == "disabled" || normalized == "0" {
		return 0
	}
	return -1
}
// clearChannelInfo resets MultiKeyDisabledReason and MultiKeyDisabledTime to nil on multi-key
// channels; single-key channels are left untouched. Callers in this file apply it to channels
// right before putting them into an API response.
func clearChannelInfo(channel *model.Channel) {
	if !channel.ChannelInfo.IsMultiKey {
		return
	}
	channel.ChannelInfo.MultiKeyDisabledReason = nil
	channel.ChannelInfo.MultiKeyDisabledTime = nil
}
// attachSupplierNames fills Channel.SupplierName with the username of each channel's owner
// (channels.owner_user_id -> users.username) using a single batched query.
//
// Nil channels and channels without an owner (OwnerUserID <= 0) are skipped. Owners that have no
// matching user row get an empty SupplierName. The lookup is best-effort: if the query fails,
// the error is logged and the channels are left unchanged.
func attachSupplierNames(channels []*model.Channel) {
	ownerSet := make(map[int]struct{}, len(channels))
	ownerIDs := make([]int, 0, len(channels))
	for _, channel := range channels {
		if channel == nil || channel.OwnerUserID <= 0 {
			continue
		}
		if _, seen := ownerSet[channel.OwnerUserID]; seen {
			continue
		}
		ownerSet[channel.OwnerUserID] = struct{}{}
		ownerIDs = append(ownerIDs, channel.OwnerUserID)
	}
	if len(ownerIDs) == 0 {
		return
	}
	var users []model.User
	if err := model.DB.Select("id, username").Where("id IN ?", ownerIDs).Find(&users).Error; err != nil {
		// Enrichment must not fail the listing, but a broken lookup should not go unnoticed.
		common.SysError("failed to load supplier names: " + err.Error())
		return
	}
	userMap := make(map[int]string, len(users))
	for _, user := range users {
		userMap[user.Id] = user.Username
	}
	for _, channel := range channels {
		if channel == nil || channel.OwnerUserID <= 0 {
			continue
		}
		channel.SupplierName = userMap[channel.OwnerUserID]
	}
}
// GetAllChannels returns one page of channels plus per-type counts.
//
// Query parameters:
//   - p / page_size: pagination, via common.GetPageQuery.
//   - id_sort: order by id desc instead of priority desc.
//   - tag_mode: admins only; page over tags and return every matching channel of each tag.
//   - supplier: admins only, non-tag mode; owner username substring.
//   - status: "enabled"/"1" or "disabled"/"0"; disabled means any status other than enabled.
//   - type: numeric channel type.
//
// Callers below the admin role (suppliers) only ever see channels they own, and tag mode is
// ignored for them. type_counts groups channels by type within the same ownership and status
// scope. It ignores the type and supplier filters so clients can render per-type tabs.
func GetAllChannels(c *gin.Context) {
	pageInfo := common.GetPageQuery(c)
	channelData := make([]*model.Channel, 0)
	idSort, _ := strconv.ParseBool(c.Query("id_sort"))
	enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
	supplierKeyword := strings.TrimSpace(c.Query("supplier"))
	// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
	statusFilter := parseStatusFilter(c.Query("status"))
	typeFilter := -1
	if typeStr := c.Query("type"); typeStr != "" {
		if t, err := strconv.Atoi(typeStr); err == nil {
			typeFilter = t
		}
	}
	isSupplier := c.GetInt("role") < common.RoleAdminUser
	userID := c.GetInt("id")

	// Columns are qualified with the channels table. The supplier-keyword filter joins users,
	// which also has id and status columns, so bare names would be ambiguous in SQL.
	order := "channels.priority desc"
	if idSort {
		order = "channels.id desc"
	}

	var total int64
	switch {
	case isSupplier:
		// Suppliers reuse the admin listing endpoint but are restricted to their own channels.
		baseQuery := model.DB.Model(&model.Channel{}).Where("channels.owner_user_id = ?", userID)
		if typeFilter >= 0 {
			baseQuery = baseQuery.Where("channels.type = ?", typeFilter)
		}
		if statusFilter == common.ChannelStatusEnabled {
			baseQuery = baseQuery.Where("channels.status = ?", common.ChannelStatusEnabled)
		} else if statusFilter == 0 {
			baseQuery = baseQuery.Where("channels.status != ?", common.ChannelStatusEnabled)
		}
		if err := baseQuery.Count(&total).Error; err != nil {
			common.ApiError(c, err)
			return
		}
		if err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error; err != nil {
			common.ApiError(c, err)
			return
		}
	case enableTagMode:
		tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
		if err != nil {
			common.SysError("failed to get paginated tags: " + err.Error())
			c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签失败,请稍后重试"})
			return
		}
		for _, tag := range tags {
			if tag == nil || *tag == "" {
				continue
			}
			tagChannels, err := model.GetChannelsByTag(*tag, idSort, false)
			if err != nil {
				continue
			}
			for _, ch := range tagChannels {
				if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
					continue
				}
				if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
					continue
				}
				if typeFilter >= 0 && ch.Type != typeFilter {
					continue
				}
				channelData = append(channelData, ch)
			}
		}
		// Tags are the pagination unit in tag mode, so total is the tag count.
		total, _ = model.CountAllTags()
	default:
		baseQuery := model.DB.Model(&model.Channel{})
		if supplierKeyword != "" {
			baseQuery = baseQuery.Joins("LEFT JOIN users ON users.id = channels.owner_user_id").Where("users.username LIKE ?", "%"+supplierKeyword+"%")
		}
		if typeFilter >= 0 {
			baseQuery = baseQuery.Where("channels.type = ?", typeFilter)
		}
		if statusFilter == common.ChannelStatusEnabled {
			baseQuery = baseQuery.Where("channels.status = ?", common.ChannelStatusEnabled)
		} else if statusFilter == 0 {
			baseQuery = baseQuery.Where("channels.status != ?", common.ChannelStatusEnabled)
		}
		if err := baseQuery.Count(&total).Error; err != nil {
			common.SysError("failed to count channels: " + err.Error())
			c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道列表失败,请稍后重试"})
			return
		}
		if err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error; err != nil {
			common.SysError("failed to get channels: " + err.Error())
			c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道列表失败,请稍后重试"})
			return
		}
	}

	for _, datum := range channelData {
		clearChannelInfo(datum)
	}
	attachSupplierNames(channelData)

	countQuery := model.DB.Model(&model.Channel{})
	if isSupplier {
		countQuery = countQuery.Where("owner_user_id = ?", userID)
	}
	if statusFilter == common.ChannelStatusEnabled {
		countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
	} else if statusFilter == 0 {
		countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
	}
	var results []struct {
		Type  int64
		Count int64
	}
	// Best-effort: type counts only drive UI tabs, so a failure degrades to empty counts.
	_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
	typeCounts := make(map[int64]int64, len(results))
	for _, r := range results {
		typeCounts[r.Type] = r.Count
	}
	common.ApiSuccess(c, gin.H{
		"items":       channelData,
		"total":       total,
		"page":        pageInfo.GetPage(),
		"page_size":   pageInfo.GetPageSize(),
		"type_counts": typeCounts,
	})
}
// buildFetchModelsHeaders builds the headers used to query a channel's upstream model list.
//
// Anthropic channels start from GetClaudeAuthHeader(key); every other type starts from
// GetAuthHeader(key). The channel's header override entries are then applied on top, except keys
// recognized as pass-through rules. Each "{api_key}" placeholder in an override value is replaced
// with key. A non-string override value returns an error naming the offending key.
func buildFetchModelsHeaders(channel *model.Channel, key string) (http.Header, error) {
	var headers http.Header
	if channel.Type == constant.ChannelTypeAnthropic {
		headers = GetClaudeAuthHeader(key)
	} else {
		headers = GetAuthHeader(key)
	}
	for name, rawValue := range channel.GetHeaderOverride() {
		if relaychannel.IsHeaderPassthroughRuleKey(name) {
			continue
		}
		value, isString := rawValue.(string)
		if !isString {
			return nil, fmt.Errorf("invalid header override for key %s", name)
		}
		// ReplaceAll is a no-op when the placeholder is absent.
		headers.Set(name, strings.ReplaceAll(value, "{api_key}", key))
	}
	return headers, nil
}
// FetchUpstreamModels returns the model IDs reported by the upstream of the channel identified by
// the :id path parameter.
//
// Callers below the admin role get 403 unless they own the channel. Upstream failures are
// reported as HTTP 200 with success=false and the error text in message.
func FetchUpstreamModels(c *gin.Context) {
	channelID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		common.ApiError(c, parseErr)
		return
	}
	channel, err := model.GetChannelById(channelID, true)
	if err != nil {
		common.ApiError(c, err)
		return
	}
	isAdmin := c.GetInt("role") >= common.RoleAdminUser
	if !isAdmin && channel.OwnerUserID != c.GetInt("id") {
		// Prevent suppliers from probing upstreams of channels they do not own.
		c.JSON(http.StatusForbidden, gin.H{
			"success": false,
			"message": "无权访问其他供应商渠道",
		})
		return
	}
	modelIDs, fetchErr := fetchChannelUpstreamModelIDs(channel)
	if fetchErr != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": fmt.Sprintf("获取模型列表失败: %s", fetchErr.Error()),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": modelIDs})
}
// FixChannelsAbilities runs model.FixAbility and reports its success and failure counts under
// data.success and data.fails.
func FixChannelsAbilities(c *gin.Context) {
	fixed, failed, err := model.FixAbility()
	if err != nil {
		common.ApiError(c, err)
		return
	}
	outcome := gin.H{
		"success": fixed,
		"fails":   failed,
	}
	c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": outcome})
}
// SearchChannels searches channels and returns one page of results with per-type counts.
//
// Query parameters:
//   - keyword, group, model: search terms.
//   - supplier: owner username, case-insensitive substring.
//   - status: "enabled"/"1" or "disabled"/"0".
//   - type: numeric channel type.
//   - id_sort, tag_mode.
//   - p: 1-based page.
//   - page_size: default 20.
//
// Callers below the admin role (suppliers) search only their own channels via
// model.SearchSupplierChannels. Admins search everything, expanding whole tags in tag mode.
// Status, type and pagination are applied in memory. type_counts is computed after the
// status/supplier filters but before the type filter in both branches, so clients can keep
// showing per-type counts while a type is selected.
func SearchChannels(c *gin.Context) {
	keyword := c.Query("keyword")
	supplierKeyword := strings.TrimSpace(c.Query("supplier"))
	group := c.Query("group")
	modelKeyword := c.Query("model")
	statusFilter := parseStatusFilter(c.Query("status"))
	idSort, _ := strconv.ParseBool(c.Query("id_sort"))
	enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
	typeFilter := -1
	if typeParam := c.Query("type"); typeParam != "" {
		if tp, err := strconv.Atoi(typeParam); err == nil {
			typeFilter = tp
		}
	}
	page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
	if page < 1 {
		page = 1
	}
	pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
	if pageSize <= 0 {
		pageSize = 20
	}
	statusFiltered := statusFilter == common.ChannelStatusEnabled || statusFilter == 0

	// filterByStatus keeps only channels matching statusFilter; without a status filter it
	// returns the input unchanged.
	filterByStatus := func(channels []*model.Channel) []*model.Channel {
		if !statusFiltered {
			return channels
		}
		kept := make([]*model.Channel, 0, len(channels))
		for _, ch := range channels {
			if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
				continue
			}
			if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
				continue
			}
			kept = append(kept, ch)
		}
		return kept
	}
	// filterByType keeps only channels of typeFilter; without a type filter it returns the input.
	filterByType := func(channels []*model.Channel) []*model.Channel {
		if typeFilter < 0 {
			return channels
		}
		kept := make([]*model.Channel, 0, len(channels))
		for _, ch := range channels {
			if ch.Type == typeFilter {
				kept = append(kept, ch)
			}
		}
		return kept
	}
	// countByType tallies channels per type.
	countByType := func(channels []*model.Channel) map[int64]int64 {
		counts := make(map[int64]int64)
		for _, ch := range channels {
			counts[int64(ch.Type)]++
		}
		return counts
	}
	// pageOf returns the requested page. Bounds are clamped without computing
	// (page-1)*pageSize or start+pageSize when those could overflow for huge query values.
	pageOf := func(channels []*model.Channel) []*model.Channel {
		n := len(channels)
		start := n
		if page-1 <= n/pageSize {
			start = (page - 1) * pageSize
		}
		end := n
		if pageSize < n-start {
			end = start + pageSize
		}
		return channels[start:end]
	}

	// Suppliers reuse the admin search endpoint but only search their own channels.
	if c.GetInt("role") < common.RoleAdminUser {
		channelID, _ := model.ParseSupplierChannelIDFilter(keyword)
		filter := model.SupplierChannelSearchFilter{
			ChannelID:    channelID,
			Keyword:      keyword,
			Supplier:     supplierKeyword,
			ModelKeyword: modelKeyword,
			Group:        group,
		}
		ownerUserID := c.GetInt("id")
		channelData, total, err := model.SearchSupplierChannels(&ownerUserID, 0, 100000, filter)
		if err != nil {
			common.ApiError(c, err)
			return
		}
		channelData = filterByStatus(channelData)
		if statusFiltered {
			total = int64(len(channelData))
		}
		// Count per type before the type filter, matching the admin branch.
		typeCounts := countByType(channelData)
		if typeFilter >= 0 {
			channelData = filterByType(channelData)
			total = int64(len(channelData))
		}
		pagedData := pageOf(channelData)
		for _, datum := range pagedData {
			clearChannelInfo(datum)
		}
		attachSupplierNames(pagedData)
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"message": "",
			"data": gin.H{
				"items":       pagedData,
				"total":       total,
				"type_counts": typeCounts,
			},
		})
		return
	}

	channelData := make([]*model.Channel, 0)
	if enableTagMode {
		tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
		for _, tag := range tags {
			if tag == nil || *tag == "" {
				continue
			}
			tagChannels, err := model.GetChannelsByTag(*tag, idSort, false)
			if err == nil {
				channelData = append(channelData, tagChannels...)
			}
		}
	} else {
		channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
		channelData = channels
	}
	// Supplier names are needed for the supplier keyword filter below.
	attachSupplierNames(channelData)
	if supplierKeyword != "" {
		needle := strings.ToLower(supplierKeyword)
		bySupplier := make([]*model.Channel, 0, len(channelData))
		for _, ch := range channelData {
			if strings.Contains(strings.ToLower(ch.SupplierName), needle) {
				bySupplier = append(bySupplier, ch)
			}
		}
		channelData = bySupplier
	}
	channelData = filterByStatus(channelData)
	// calculate type counts for search results (before the type filter)
	typeCounts := countByType(channelData)
	channelData = filterByType(channelData)
	total := len(channelData)
	pagedData := pageOf(channelData)
	for _, datum := range pagedData {
		clearChannelInfo(datum)
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data": gin.H{
			"items":       pagedData,
			"total":       total,
			"type_counts": typeCounts,
		},
	})
}
// GetChannel returns the channel identified by the :id path parameter.
//
// The channel is loaded with model.GetChannelById(id, false). NOTE(review): the false flag
// presumably excludes sensitive columns such as the key (GetChannelKey passes true) — confirm.
// A missing channel is reported as an error. Callers below the admin role get 403 unless they
// own the channel. Multi-key disable bookkeeping is cleared before the channel is returned.
func GetChannel(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		common.ApiError(c, err)
		return
	}
	channel, err := model.GetChannelById(id, false)
	if err != nil {
		common.ApiError(c, err)
		return
	}
	// Check for nil before the ownership check dereferences channel.
	if channel == nil {
		common.ApiError(c, fmt.Errorf("渠道不存在"))
		return
	}
	// Suppliers may only view channels they own.
	if c.GetInt("role") < common.RoleAdminUser && channel.OwnerUserID != c.GetInt("id") {
		c.JSON(http.StatusForbidden, gin.H{
			"success": false,
			"message": "无权访问其他供应商渠道",
		})
		return
	}
	clearChannelInfo(channel)
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    channel,
	})
}
// GetChannelKey returns the plaintext key of the channel identified by the :id path parameter.
//
// It relies on the SecureVerificationRequired middleware (attached to the route elsewhere) to
// ensure the caller has passed security verification. Callers below the admin role may only read
// keys of channels they own, which is the same ownership rule GetChannel and
// FetchUpstreamModels enforce. Each successful read is recorded as a system log entry for the
// caller before the key is returned.
func GetChannelKey(c *gin.Context) {
	userId := c.GetInt("id")
	channelId, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		common.ApiError(c, fmt.Errorf("渠道ID格式错误: %v", err))
		return
	}
	// Load the channel including its key.
	channel, err := model.GetChannelById(channelId, true)
	if err != nil {
		common.ApiError(c, fmt.Errorf("获取渠道信息失败: %v", err))
		return
	}
	if channel == nil {
		common.ApiError(c, fmt.Errorf("渠道不存在"))
		return
	}
	// A supplier must never be able to read another supplier's upstream credentials.
	if c.GetInt("role") < common.RoleAdminUser && channel.OwnerUserID != userId {
		c.JSON(http.StatusForbidden, gin.H{
			"success": false,
			"message": "无权访问其他供应商渠道",
		})
		return
	}
	// Audit every key read.
	model.RecordLog(userId, model.LogTypeSystem, fmt.Sprintf("查看渠道密钥信息 (渠道ID: %d)", channelId))
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "获取成功",
		"data": map[string]interface{}{
			"key": channel.Key,
		},
	})
}
// validateTwoFactorAuth reports whether code is accepted by twoFA, either as a TOTP code or as a
// backup code.
//
// The TOTP check runs only when common.ValidateNumericCode accepts the input, and it uses the
// normalized value that function returns; TOTP validator errors are ignored. If TOTP does not
// succeed, the raw code is tried as a backup code. Both validators are the "...AndUpdateUsage"
// variants, so a successful check presumably records usage (per their names — confirm).
func validateTwoFactorAuth(twoFA *model.TwoFA, code string) bool {
	numeric, numErr := common.ValidateNumericCode(code)
	if numErr == nil {
		totpOK, _ := twoFA.ValidateTOTPAndUpdateUsage(numeric)
		if totpOK {
			return true
		}
	}
	backupOK, backupErr := twoFA.ValidateBackupCodeAndUpdateUsage(code)
	return backupErr == nil && backupOK
}
// validateChannel checks a channel payload before it is created (isAdd) or updated.
//
// Side effect: CompanyLogoURL and SupplierType are whitespace-trimmed in place.
// Checks, in order; the first failure is returned:
//   - supplier_type must be non-empty, except for TokenFactoryOpen channels, which inherit it from
//     upstream sync. When set, it must be one of channelAllowedSupplierTypes.
//   - channel.ValidateSettings must succeed.
//   - on add: the key must be non-empty and every model name at most 255 bytes.
//   - Vertex AI: Other must be a JSON object with a "default" region entry.
//   - Codex: on add, or when a key is supplied on update, the key must be a JSON object whose
//     access_token and account_id are present and non-blank.
//   - PriceDiscountPercent, when set, must lie in [0, 1000].
//   - RouteSlug, when non-blank, must satisfy model.IsValidRouteSlug.
func validateChannel(channel *model.Channel, isAdd bool) error {
	if channel == nil {
		return fmt.Errorf("channel cannot be empty")
	}
	channel.CompanyLogoURL = strings.TrimSpace(channel.CompanyLogoURL)
	channel.SupplierType = strings.TrimSpace(channel.SupplierType)

	switch {
	case channel.SupplierType == "" && channel.Type != constant.ChannelTypeTokenFactoryOpen:
		return fmt.Errorf("供应商类型不能为空")
	case channel.SupplierType != "" && !isValidChannelSupplierType(channel.SupplierType):
		return fmt.Errorf("供应商类型无效")
	}

	if settingsErr := channel.ValidateSettings(); settingsErr != nil {
		return fmt.Errorf("渠道额外设置[channel setting] 格式错误:%s", settingsErr.Error())
	}

	if isAdd {
		if channel.Key == "" {
			return fmt.Errorf("channel cannot be empty")
		}
		for _, modelName := range channel.GetModels() {
			if len(modelName) > 255 {
				return fmt.Errorf("模型名称过长: %s", modelName)
			}
		}
	}

	switch channel.Type {
	case constant.ChannelTypeVertexAi:
		if channel.Other == "" {
			return fmt.Errorf("部署地区不能为空")
		}
		regions, parseErr := common.StrToMap(channel.Other)
		if parseErr != nil {
			return fmt.Errorf("部署地区必须是标准的Json格式例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}")
		}
		if regions["default"] == nil {
			return fmt.Errorf("部署地区必须包含default字段")
		}
	case constant.ChannelTypeCodex:
		// On update an empty key means "keep the stored key", so only validate a supplied one.
		codexKey := strings.TrimSpace(channel.Key)
		if isAdd || codexKey != "" {
			if !strings.HasPrefix(codexKey, "{") {
				return fmt.Errorf("Codex key must be a valid JSON object")
			}
			var fields map[string]any
			if jsonErr := common.Unmarshal([]byte(codexKey), &fields); jsonErr != nil {
				return fmt.Errorf("Codex key must be a valid JSON object")
			}
			missing := func(name string) bool {
				v, present := fields[name]
				return !present || v == nil || strings.TrimSpace(fmt.Sprintf("%v", v)) == ""
			}
			if missing("access_token") {
				return fmt.Errorf("Codex key JSON must include access_token")
			}
			if missing("account_id") {
				return fmt.Errorf("Codex key JSON must include account_id")
			}
		}
	}

	if pct := channel.PriceDiscountPercent; pct != nil && (*pct < 0 || *pct > 1000) {
		return fmt.Errorf("价格折扣(百分比)须介于 0 与 1000 之间100 表示无折扣60 表示按原价 60%% 计费")
	}
	if slug := strings.TrimSpace(channel.RouteSlug); slug != "" && !model.IsValidRouteSlug(slug) {
		return fmt.Errorf("route_slug 格式无效232 位字母数字,且不能为 c 加纯数字)")
	}
	return nil
}
// RefreshCodexChannelCredential refreshes the Codex OAuth credential of the channel identified by
// the :id path parameter, with a 10-second timeout derived from the request context and cache
// reset enabled.
//
// On success it returns the new expiry, last refresh time, account id and email together with the
// channel's id, type and name. Refresh failures are logged and reported as a generic
// success=false response.
func RefreshCodexChannelCredential(c *gin.Context) {
	channelId, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
		return
	}
	ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
	defer cancel()
	opts := service.CodexCredentialRefreshOptions{ResetCaches: true}
	oauthKey, refreshed, err := service.RefreshCodexChannelCredential(ctx, channelId, opts)
	if err != nil {
		common.SysError("failed to refresh codex channel credential: " + err.Error())
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "刷新凭证失败,请稍后重试"})
		return
	}
	summary := gin.H{
		"expires_at":   oauthKey.Expired,
		"last_refresh": oauthKey.LastRefresh,
		"account_id":   oauthKey.AccountID,
		"email":        oauthKey.Email,
		"channel_id":   refreshed.Id,
		"channel_type": refreshed.Type,
		"channel_name": refreshed.Name,
	}
	c.JSON(http.StatusOK, gin.H{"success": true, "message": "refreshed", "data": summary})
}
// AddChannelRequest is the JSON body for channel creation: the channel template plus creation
// options. NOTE(review): the accepted Mode values and the exact meaning of MultiKeyMode and
// BatchAddSetKeyPrefix2Name are defined by AddChannel, which is outside this chunk — confirm there.
type AddChannelRequest struct {
Mode string `json:"mode"`
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
BatchAddSetKeyPrefix2Name bool `json:"batch_add_set_key_prefix_2_name"`
Channel *model.Channel `json:"channel"`
}
// applySupplierChannelOwnershipForCreate stamps ownership on a channel being created by a
// supplier. Admins are left untouched.
//
// For callers below the admin role it looks up the caller's approved supplier application and
// overwrites OwnerUserID and SupplierApplicationID. Client-supplied owner fields therefore cannot
// be forged. The lookup error is returned as-is.
func applySupplierChannelOwnershipForCreate(c *gin.Context, channel *model.Channel) error {
	isAdmin := c.GetInt("role") >= common.RoleAdminUser
	if isAdmin {
		return nil
	}
	callerID := c.GetInt("id")
	approved, err := model.GetApprovedSupplierApplicationByApplicant(callerID)
	if err != nil {
		return err
	}
	channel.OwnerUserID = callerID
	channel.SupplierApplicationID = approved.ID
	return nil
}
// validateSupplierChannelOwnershipForUpdate reports whether the caller may update originChannel.
// Admins always may. Other callers may only update channels whose OwnerUserID equals their user
// id. originChannel is not inspected for admins.
func validateSupplierChannelOwnershipForUpdate(c *gin.Context, originChannel *model.Channel) bool {
	isAdmin := c.GetInt("role") >= common.RoleAdminUser
	return isAdmin || originChannel.OwnerUserID == c.GetInt("id")
}
// getVertexArrayKeys splits a batch Vertex AI key input (a JSON array) into individual key strings.
//
// An empty input returns (nil, nil). String elements are whitespace-trimmed. Any other element,
// such as a service-account JSON object, is re-encoded with encoding/json. Empty results are
// dropped. Errors are returned for input that is not a JSON array, for an element that cannot be
// re-encoded, and when no non-empty key remains.
func getVertexArrayKeys(keys string) ([]string, error) {
	if len(keys) == 0 {
		return nil, nil
	}
	var entries []any
	if err := common.Unmarshal([]byte(keys), &entries); err != nil {
		return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式例如[{key1}, {key2}...],请检查输入: %w", err)
	}
	collected := make([]string, 0, len(entries))
	for _, entry := range entries {
		text, isString := entry.(string)
		if isString {
			text = strings.TrimSpace(text)
		} else {
			encoded, err := json.Marshal(entry)
			if err != nil {
				return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
			}
			text = string(encoded)
		}
		if text != "" {
			collected = append(collected, text)
		}
	}
	if len(collected) == 0 {
		return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
	}
	return collected, nil
}
// upstreamChannelSyncItem is one channel reported by an upstream TokenFactory platform during
// TokenFactoryOpen sync. decodeUpstreamChannelPayload fills it field by field from the decoded
// JSON object; the tags document the upstream wire names. NOTE(review): ModelPrice / ModelRatio
// are presumably keyed by model name — confirm against the upstream export format.
type upstreamChannelSyncItem struct {
ID int `json:"id"`
Name string `json:"name"`
Models string `json:"models"`
Group string `json:"group"`
Status int `json:"status"`
Type int `json:"type"`
ChannelNo string `json:"channel_no"`
RouteSlug string `json:"route_slug"`
// SupplierApplication carries the upstream supplier_application_id.
SupplierApplication int `json:"supplier_application_id"`
SupplierAlias string `json:"supplier_alias"`
SupplierType string `json:"supplier_type"`
CompanyLogoURL string `json:"company_logo_url"`
PriceDiscountPercent float64 `json:"price_discount_percent"`
MarkupDiscountRate float64 `json:"markup_discount_rate"`
// ModelMapping is the upstream model_mapping normalized to a JSON string ("" when absent).
ModelMapping string `json:"model_mapping"`
ModelPrice map[string]float64 `json:"model_price"`
ModelRatio map[string]float64 `json:"model_ratio"`
}
// decodeUpstreamModelMapping normalizes m["model_mapping"] to a trimmed string.
//
// A missing or null value yields "". A string is returned trimmed. Any other value (typically a
// JSON object) is re-encoded with encoding/json. If re-encoding fails, a map value yields "" and
// anything else falls back to common.Interface2String.
func decodeUpstreamModelMapping(m map[string]any) string {
	raw := m["model_mapping"]
	if raw == nil {
		return ""
	}
	if text, isString := raw.(string); isString {
		return strings.TrimSpace(text)
	}
	encoded, err := json.Marshal(raw)
	if err == nil {
		return strings.TrimSpace(string(encoded))
	}
	if _, isObject := raw.(map[string]any); isObject {
		return ""
	}
	return strings.TrimSpace(common.Interface2String(raw))
}
// isTokenFactoryOpenBaseURL reports whether raw, after trimming surrounding whitespace, parses as
// an http:// or https:// URL with a non-empty host name. The scheme is compared case-insensitively.
func isTokenFactoryOpenBaseURL(raw string) bool {
	u, err := url.Parse(strings.TrimSpace(raw))
	if err != nil {
		return false
	}
	switch strings.ToLower(strings.TrimSpace(u.Scheme)) {
	case "http", "https":
		return strings.TrimSpace(u.Hostname()) != ""
	default:
		return false
	}
}
// isLikelyTokenFactoryStatusData heuristically decides whether a /api/status "data" object comes
// from a TokenFactory platform.
//
// A system name containing "tokenfactory" (case-insensitive), "词元工厂" or "开放词元工厂" matches
// immediately. Otherwise one point is scored for each of the following, and 4 or more points
// match:
//   - a non-blank version;
//   - a positive integer start_time;
//   - a non-blank quota_display_type or quota_per_unit;
//   - the presence of each of enable_drawing, enable_task and system_name (one point each).
//
// The feature score avoids relying on the English system name alone.
func isLikelyTokenFactoryStatusData(data map[string]any, systemName string) bool {
	lowered := strings.ToLower(strings.TrimSpace(systemName))
	for _, marker := range []string{"tokenfactory", "词元工厂", "开放词元工厂"} {
		if strings.Contains(lowered, marker) {
			return true
		}
	}

	nonBlank := func(field string) bool {
		return strings.TrimSpace(common.Interface2String(data[field])) != ""
	}
	points := 0
	if nonBlank("version") {
		points++
	}
	if rawStart := strings.TrimSpace(common.Interface2String(data["start_time"])); rawStart != "" {
		if started, err := strconv.ParseInt(rawStart, 10, 64); err == nil && started > 0 {
			points++
		}
	}
	if nonBlank("quota_display_type") || nonBlank("quota_per_unit") {
		points++
	}
	for _, field := range []string{"enable_drawing", "enable_task", "system_name"} {
		if _, present := data[field]; present {
			points++
		}
	}
	return points >= 4
}
// fetchTokenFactoryStatus verifies that baseURL hosts a TokenFactory platform.
//
// It sends GET {baseURL}/api/status with the trimmed key as a bearer token and a 10-second
// timeout. It returns an error when any of the following holds:
//   - the request fails;
//   - the status is not 200;
//   - the body is not JSON;
//   - "success" is explicitly false;
//   - "data" is not an object;
//   - the data does not pass isLikelyTokenFactoryStatusData.
func fetchTokenFactoryStatus(baseURL string, key string) error {
	endpoint := strings.TrimRight(strings.TrimSpace(baseURL), "/") + "/api/status"
	req, err := http.NewRequest(http.MethodGet, endpoint, nil)
	if err != nil {
		return err
	}
	req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(key))
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("status code %d", resp.StatusCode)
	}
	var payload map[string]any
	if decodeErr := json.NewDecoder(resp.Body).Decode(&payload); decodeErr != nil {
		return decodeErr
	}
	// A missing or non-bool "success" is tolerated; only an explicit false fails.
	if succeeded, isBool := payload["success"].(bool); isBool && !succeeded {
		return fmt.Errorf("status 接口返回失败")
	}
	statusData, _ := payload["data"].(map[string]any)
	if statusData == nil {
		return fmt.Errorf("status 返回结构缺少 data")
	}
	systemName := strings.TrimSpace(common.Interface2String(statusData["system_name"]))
	if !isLikelyTokenFactoryStatusData(statusData, systemName) {
		return fmt.Errorf("status 特征不匹配 TokenFactory 平台system_name=%s", systemName)
	}
	return nil
}
// decodeUpstreamChannelPayload converts a decoded upstream JSON envelope of the form
// {"success": bool, "message": string, "data": {<itemsKey>: [ {...}, ... ]}} into sync items.
//
// It returns an error when "success" is missing, not a bool, or false (the upstream message is
// used, or a fallback when it is blank), and when "data" is not an object. Entries of
// data[itemsKey] that are not JSON objects are skipped. A missing or non-array data[itemsKey]
// yields an empty slice.
//
// Numeric fields are converted by upstreamItemInt / upstreamItemFloat. encoding/json decodes
// numbers into float64 (or json.Number with UseNumber), and those are handled directly instead of
// through a string round-trip.
func decodeUpstreamChannelPayload(payload map[string]any, itemsKey string) ([]upstreamChannelSyncItem, error) {
	successRaw, exists := payload["success"]
	if !exists {
		return nil, fmt.Errorf("上游响应缺少 success 字段")
	}
	success, ok := successRaw.(bool)
	if !ok {
		return nil, fmt.Errorf("上游 success 字段类型异常: %T", successRaw)
	}
	if !success {
		upstreamMessage := upstreamItemString(payload["message"])
		if upstreamMessage == "" {
			upstreamMessage = "上游返回失败(message 为空)"
		}
		return nil, fmt.Errorf("%s", upstreamMessage)
	}
	data, _ := payload["data"].(map[string]any)
	if data == nil {
		return nil, fmt.Errorf("上游响应缺少 data")
	}
	rawItems, _ := data[itemsKey].([]any)
	items := make([]upstreamChannelSyncItem, 0, len(rawItems))
	for _, raw := range rawItems {
		m, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		item := upstreamChannelSyncItem{
			ID:                  upstreamItemInt(m["id"]),
			Name:                upstreamItemString(m["name"]),
			Models:              upstreamItemString(m["models"]),
			Group:               upstreamItemString(m["group"]),
			Status:              upstreamItemInt(m["status"]),
			Type:                upstreamItemInt(m["type"]),
			ChannelNo:           upstreamItemString(m["channel_no"]),
			RouteSlug:           upstreamItemString(m["route_slug"]),
			SupplierApplication: upstreamItemInt(m["supplier_application_id"]),
			SupplierAlias:       upstreamItemString(m["supplier_alias"]),
			SupplierType:        upstreamItemString(m["supplier_type"]),
			CompanyLogoURL:      upstreamItemString(m["company_logo_url"]),
			ModelMapping:        decodeUpstreamModelMapping(m),
		}
		if f, ok := upstreamItemFloat(m["price_discount_percent"]); ok {
			item.PriceDiscountPercent = f
		}
		if f, ok := upstreamItemFloat(m["markup_discount_rate"]); ok {
			item.MarkupDiscountRate = f
		}
		if mp, ok := m["model_price"].(map[string]any); ok && len(mp) > 0 {
			item.ModelPrice = jsonAnyMapToFloatMap(mp)
		}
		if mr, ok := m["model_ratio"].(map[string]any); ok && len(mr) > 0 {
			item.ModelRatio = jsonAnyMapToFloatMap(mr)
		}
		items = append(items, item)
	}
	return items, nil
}

// upstreamItemString renders a decoded JSON value as a trimmed string. Absent/null values become
// "", and non-string values go through common.Interface2String.
func upstreamItemString(v any) string {
	switch x := v.(type) {
	case nil:
		return ""
	case string:
		return strings.TrimSpace(x)
	default:
		return strings.TrimSpace(common.Interface2String(v))
	}
}

// upstreamItemInt converts a decoded JSON value to int. float64 and json.Number are converted
// directly, strings are parsed as integers, and absent or unparseable values yield 0.
func upstreamItemInt(v any) int {
	switch x := v.(type) {
	case nil:
		return 0
	case float64:
		return int(x)
	case int:
		return x
	case int64:
		return int(x)
	case json.Number:
		if n, err := x.Int64(); err == nil {
			return int(n)
		}
		if f, err := x.Float64(); err == nil {
			return int(f)
		}
		return 0
	case string:
		return common.String2Int(strings.TrimSpace(x))
	default:
		return common.String2Int(strings.TrimSpace(common.Interface2String(v)))
	}
}

// upstreamItemFloat converts a decoded JSON value to float64, reporting whether the conversion
// succeeded. Absent/null values report false.
func upstreamItemFloat(v any) (float64, bool) {
	switch x := v.(type) {
	case nil:
		return 0, false
	case float64:
		return x, true
	case json.Number:
		f, err := x.Float64()
		return f, err == nil
	default:
		f, err := strconv.ParseFloat(strings.TrimSpace(common.Interface2String(v)), 64)
		return f, err == nil
	}
}
// fetchTokenFactoryUpstreamChannelsExport fetches the upstream channel list from
// GET {baseURL}/api/tf_open_sync/channels with a 45-second timeout.
//
// The trimmed key is sent both as a bearer token and in the X-TokenFactory-Open-Sync-Secret
// header. A 404 returns errTfOpenExportNotFound so the caller can fall back to the legacy
// endpoint. Any other non-200 status returns an error with up to 1 KiB of the response body. A 200
// body is decoded by decodeUpstreamChannelPayload using the "channels" items key.
func fetchTokenFactoryUpstreamChannelsExport(baseURL string, key string) ([]upstreamChannelSyncItem, error) {
	endpoint := strings.TrimRight(strings.TrimSpace(baseURL), "/") + "/api/tf_open_sync/channels"
	req, err := http.NewRequest(http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}
	secret := strings.TrimSpace(key)
	req.Header.Set("Authorization", "Bearer "+secret)
	req.Header.Set("X-TokenFactory-Open-Sync-Secret", secret)
	client := &http.Client{Timeout: 45 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	switch resp.StatusCode {
	case http.StatusOK:
		// handled below
	case http.StatusNotFound:
		return nil, errTfOpenExportNotFound
	default:
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
		return nil, fmt.Errorf("export 接口 status code %d: %s", resp.StatusCode, strings.TrimSpace(string(snippet)))
	}
	var envelope map[string]any
	if decodeErr := json.NewDecoder(resp.Body).Decode(&envelope); decodeErr != nil {
		return nil, decodeErr
	}
	return decodeUpstreamChannelPayload(envelope, "channels")
}
// errTfOpenExportNotFound is returned by fetchTokenFactoryUpstreamChannelsExport when the upstream
// answers 404 for /api/tf_open_sync/channels (presumably an upstream without the export endpoint).
// fetchTokenFactoryUpstreamChannels treats it as "fall back to the legacy channel listing".
var errTfOpenExportNotFound = errors.New("tf_open_sync channels export not found")
// jsonAnyMapToFloatMap converts a decoded JSON object into a map of float64 values.
//
// float64 and json.Number values are taken directly. Other values are rendered with
// common.Interface2String and parsed with strconv.ParseFloat after trimming. Entries that cannot be
// converted are dropped, and nil is returned when nothing converts.
func jsonAnyMapToFloatMap(raw map[string]any) map[string]float64 {
	converted := make(map[string]float64, len(raw))
	for name, value := range raw {
		var (
			f   float64
			err error
		)
		switch x := value.(type) {
		case float64:
			converted[name] = x
			continue
		case json.Number:
			f, err = x.Float64()
		default:
			f, err = strconv.ParseFloat(strings.TrimSpace(common.Interface2String(value)), 64)
		}
		if err == nil {
			converted[name] = f
		}
	}
	if len(converted) == 0 {
		return nil
	}
	return converted
}
// fetchTokenFactoryUpstreamChannelsLegacy fetches the upstream channel list through the regular
// listing endpoint: GET {baseURL}/api/channel/?p=1&page_size=100000&id_sort=true, with the trimmed
// key as a bearer token and a 15-second timeout.
//
// A non-200 status returns an error with up to 1 KiB of the response body. A 200 body is decoded
// by decodeUpstreamChannelPayload using the "items" items key.
func fetchTokenFactoryUpstreamChannelsLegacy(baseURL string, key string) ([]upstreamChannelSyncItem, error) {
	listURL := strings.TrimRight(strings.TrimSpace(baseURL), "/") + "/api/channel/?p=1&page_size=100000&id_sort=true"
	req, err := http.NewRequest(http.MethodGet, listURL, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(key))
	resp, err := (&http.Client{Timeout: 15 * time.Second}).Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		excerpt, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
		return nil, fmt.Errorf("upstream status code %d: %s", resp.StatusCode, strings.TrimSpace(string(excerpt)))
	}
	var envelope map[string]any
	if decodeErr := json.NewDecoder(resp.Body).Decode(&envelope); decodeErr != nil {
		return nil, decodeErr
	}
	return decodeUpstreamChannelPayload(envelope, "items")
}
// fetchTokenFactoryUpstreamChannels lists the upstream platform's channels for syncing.
//
// It tries the dedicated export endpoint first, and a non-empty export is returned directly. A 404
// (errTfOpenExportNotFound) or an empty export falls back to the legacy /api/channel/ listing.
// Any other export error is returned without trying the legacy endpoint.
func fetchTokenFactoryUpstreamChannels(baseURL string, key string) ([]upstreamChannelSyncItem, error) {
	items, err := fetchTokenFactoryUpstreamChannelsExport(baseURL, key)
	if err == nil && len(items) > 0 {
		return items, nil
	}
	if err != nil && !errors.Is(err, errTfOpenExportNotFound) {
		return nil, fmt.Errorf("拉取上游渠道export: %w", err)
	}
	legacy, err := fetchTokenFactoryUpstreamChannelsLegacy(baseURL, key)
	if err != nil {
		// Only name the endpoint here. The caller already wraps this error with its own
		// "failed to pull upstream channels" prefix, and repeating it produced a duplicated message.
		return nil, fmt.Errorf("legacy 接口: %w", err)
	}
	return legacy, nil
}
// tfOpenLocalChannelNo returns the local channel_no for a channel synced from upstream. It always
// returns "", ignoring the upstream item, so the local cN number is assigned by the regular local
// logic (per the original note, incrementing per supplier_application_id).
func tfOpenLocalChannelNo(_ upstreamChannelSyncItem) string {
	return ""
}
// buildTokenFactorySyncedChannels expands a TokenFactoryOpen "base" channel into one local channel
// per channel exposed by the upstream TokenFactory platform at base's base URL.
//
// First it checks three things: the base URL is an http(s) URL, the key is non-blank, and
// {baseURL}/api/status looks like a TokenFactory instance. Then it fetches the upstream channel
// list, and an empty list is an error.
//
// Each result is a shallow copy of *base with these changes:
//   - Id is reset, CreatedTime is set to now, and Type is forced to TokenFactoryOpen.
//   - Name, models, group, status, supplier type, logo, price discount and model mapping come from
//     the upstream item where present.
//   - ChannelNo and RouteSlug are cleared so they are assigned locally.
//   - OtherInfo holds JSON sync metadata.
//
// The second return value holds each channel's upstream per-model pricing, index-aligned with
// the first.
func buildTokenFactorySyncedChannels(base *model.Channel) ([]model.Channel, []model.TFOpenUpstreamPricing, error) {
	baseURL := base.GetBaseURL()
	if !isTokenFactoryOpenBaseURL(baseURL) {
		return nil, nil, fmt.Errorf("TokenFactoryOpen 渠道的 API 地址必须指向 TokenFactory 平台")
	}
	key := strings.TrimSpace(base.Key)
	if key == "" {
		return nil, nil, fmt.Errorf("TokenFactoryOpen 渠道密钥不能为空")
	}
	if err := fetchTokenFactoryStatus(baseURL, key); err != nil {
		return nil, nil, fmt.Errorf("TokenFactoryOpen 平台识别失败: %w", err)
	}
	upstreamChannels, err := fetchTokenFactoryUpstreamChannels(baseURL, key)
	if err != nil {
		return nil, nil, fmt.Errorf("拉取上游渠道失败: %w", err)
	}
	if len(upstreamChannels) == 0 {
		return nil, nil, fmt.Errorf("上游未返回可同步渠道")
	}
	now := common.GetTimestamp()
	result := make([]model.Channel, 0, len(upstreamChannels))
	pricing := make([]model.TFOpenUpstreamPricing, 0, len(upstreamChannels))
	for i, upstream := range upstreamChannels {
		clone := *base
		clone.Id = 0
		clone.CreatedTime = now
		// Every synced channel is a TokenFactoryOpen channel regardless of the upstream type. The
		// upstream type is only kept in the sync metadata below.
		clone.Type = constant.ChannelTypeTokenFactoryOpen

		// Use the upstream channel name as-is; fall back to a batch-local sequence name.
		seqIdx := model.EncodeBase62(int64(i))
		if upstreamName := strings.TrimSpace(upstream.Name); upstreamName != "" {
			clone.Name = upstreamName
		} else {
			clone.Name = fmt.Sprintf("upstream-%s", seqIdx)
		}
		clone.Models = strings.TrimSpace(upstream.Models)
		if upstreamGroup := strings.TrimSpace(upstream.Group); upstreamGroup != "" {
			clone.Group = upstreamGroup
		}
		if upstream.Status > 0 {
			clone.Status = upstream.Status
		}

		// Prefer a valid upstream supplier type; otherwise keep base's if valid, else the default.
		upstreamSupplierType := strings.TrimSpace(upstream.SupplierType)
		if upstreamSupplierType != "" && isValidChannelSupplierType(upstreamSupplierType) {
			clone.SupplierType = upstreamSupplierType
		} else if strings.TrimSpace(clone.SupplierType) == "" || !isValidChannelSupplierType(strings.TrimSpace(clone.SupplierType)) {
			clone.SupplierType = defaultChannelSupplierType
		}
		if upstreamLogoURL := strings.TrimSpace(upstream.CompanyLogoURL); upstreamLogoURL != "" {
			clone.CompanyLogoURL = upstreamLogoURL
		}
		if upstream.PriceDiscountPercent > 0 {
			// Point at a per-iteration copy. &upstream.PriceDiscountPercent would alias the range
			// variable, which before Go 1.22 is shared by every iteration, so all clones would
			// end up with the last item's discount.
			discount := upstream.PriceDiscountPercent
			clone.PriceDiscountPercent = &discount
		}
		if mm := strings.TrimSpace(upstream.ModelMapping); mm != "" {
			clone.ModelMapping = &mm
		} else {
			clone.ModelMapping = nil
		}
		clone.ChannelNo = tfOpenLocalChannelNo(upstream)
		clone.RouteSlug = ""

		syncMeta := map[string]any{
			"source":                   "tokenfactory_open",
			"upstream_channel_id":      upstream.ID,
			"upstream_channel_no":      strings.TrimSpace(upstream.ChannelNo),
			"upstream_route_slug":      strings.TrimSpace(upstream.RouteSlug),
			"upstream_supplier_app_id": upstream.SupplierApplication,
			"upstream_supplier_alias":  strings.TrimSpace(upstream.SupplierAlias),
			"upstream_channel_type":    upstream.Type,
			"local_channel_no":         clone.ChannelNo,
			"sync_seq_index":           seqIdx, // base-62 sequence number within this sync batch
			"synced_at":                now,
		}
		// The map holds only strings and integers, so marshaling cannot fail in practice.
		metaJSON, _ := common.Marshal(syncMeta)
		clone.OtherInfo = string(metaJSON)

		result = append(result, clone)
		pricing = append(pricing, model.TFOpenUpstreamPricing{
			ModelPrice: upstream.ModelPrice,
			ModelRatio: upstream.ModelRatio,
		})
	}
	return result, pricing, nil
}
// AddChannel creates one or more channels from an AddChannelRequest.
//
// The request Mode controls how Channel.Key is expanded:
//   - "single": one channel with the key as given;
//   - "batch": one channel per non-blank line. For Vertex AI service-account keys, it is one
//     channel per element returned by getVertexArrayKeys;
//   - "multi_to_single": one multi-key channel holding every key.
//
// For TokenFactoryOpen channels, the expanded list is replaced by the channels synced from the
// upstream platform using the first key, and their upstream pricing is stored alongside.
func AddChannel(c *gin.Context) {
	addChannelRequest := AddChannelRequest{}
	err := c.ShouldBindJSON(&addChannelRequest)
	if err != nil {
		common.ApiError(c, err)
		return
	}
	// Everything below dereferences Channel, so reject a body without one instead of panicking.
	if addChannelRequest.Channel == nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}
	// Shared channel validation.
	if err := validateChannel(addChannelRequest.Channel, true); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	// Default the discount to 100 when the client did not send one.
	if addChannelRequest.Channel.PriceDiscountPercent == nil {
		v := 100.0
		addChannelRequest.Channel.PriceDiscountPercent = &v
	}
	if err := applySupplierChannelOwnershipForCreate(c, addChannelRequest.Channel); err != nil {
		c.JSON(http.StatusForbidden, gin.H{
			"success": false,
			"message": "当前用户未通过供应商审核,无权创建渠道",
		})
		return
	}
	addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
	keys := make([]string, 0)
	switch addChannelRequest.Mode {
	case "multi_to_single":
		addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
		addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
		if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
			array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
			if err != nil {
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": err.Error(),
				})
				return
			}
			addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array)
			addChannelRequest.Channel.Key = strings.Join(array, "\n")
		} else {
			cleanKeys := make([]string, 0)
			for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
				// Trim before the emptiness check. Whitespace-only lines used to pass the check
				// and were stored as empty keys, inflating MultiKeySize.
				key = strings.TrimSpace(key)
				if key == "" {
					continue
				}
				cleanKeys = append(cleanKeys, key)
			}
			addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
			addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
		}
		keys = []string{addChannelRequest.Channel.Key}
	case "batch":
		if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
			// Multiple JSON credentials.
			keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
			if err != nil {
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": err.Error(),
				})
				return
			}
		} else {
			// Trim each line like multi_to_single does, so "\r\n" input or padded lines do not
			// produce keys with stray whitespace. Blank lines are skipped below.
			for _, line := range strings.Split(addChannelRequest.Channel.Key, "\n") {
				keys = append(keys, strings.TrimSpace(line))
			}
		}
	case "single":
		keys = []string{addChannelRequest.Channel.Key}
	default:
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "不支持的添加模式",
		})
		return
	}
	channels := make([]model.Channel, 0, len(keys))
	for _, key := range keys {
		if key == "" {
			continue
		}
		// Copy the template by value. Mutating through the pointer (as before) made each
		// iteration start from the previous one's Name, producing "name k1 k2 ...".
		localChannel := *addChannelRequest.Channel
		localChannel.Key = key
		if addChannelRequest.BatchAddSetKeyPrefix2Name && len(keys) > 1 {
			keyPrefix := localChannel.Key
			if len(localChannel.Key) > 8 {
				keyPrefix = localChannel.Key[:8]
			}
			localChannel.Name = fmt.Sprintf("%s %s", localChannel.Name, keyPrefix)
		}
		channels = append(channels, localChannel)
	}
	var tfOpenPricing []model.TFOpenUpstreamPricing
	if addChannelRequest.Channel.Type == constant.ChannelTypeTokenFactoryOpen {
		syncBase := *addChannelRequest.Channel
		if len(channels) > 0 {
			syncBase.Key = strings.TrimSpace(channels[0].Key)
		}
		syncedChannels, pricing, syncErr := buildTokenFactorySyncedChannels(&syncBase)
		if syncErr != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": syncErr.Error(),
			})
			return
		}
		channels = syncedChannels
		tfOpenPricing = pricing
	}
	// Several channels created at once cannot share a single route_slug.
	if len(channels) > 1 {
		for i := range channels {
			channels[i].RouteSlug = ""
		}
	}
	if addChannelRequest.Channel.Type == constant.ChannelTypeTokenFactoryOpen {
		err = model.BatchInsertChannelsWithTfOpenUpstreamPricing(channels, tfOpenPricing)
	} else {
		err = model.BatchInsertChannels(channels)
	}
	if err != nil {
		common.ApiError(c, err)
		return
	}
	service.ResetProxyClientCache()
	// Record the operation log.
	channelName := addChannelRequest.Channel.Name
	service.RecordCreateOperation(c, "channel", 0, channelName, "创建渠道: "+channelName, "")
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
}
// DeleteChannel deletes the channel identified by the ":id" path parameter, reloads the channel
// cache and records an operation-log entry.
//
// A non-numeric or non-positive id is rejected up front, matching UpdateChannel. Previously the
// Atoi error was discarded and the delete ran with id 0.
func DeleteChannel(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil || id <= 0 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "渠道ID无效",
		})
		return
	}
	channel := model.Channel{Id: id}
	if err := channel.Delete(); err != nil {
		common.ApiError(c, err)
		return
	}
	model.InitChannelCache()
	// Record the operation log.
	service.RecordDeleteOperation(c, "channel", id, "", fmt.Sprintf("删除渠道 (ID: %d)", id), "")
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
}
// DeleteDisabledChannel calls model.DeleteDisabledChannel, which presumably removes every
// disabled channel. On success it reloads the channel cache and returns that call's count as
// "data".
func DeleteDisabledChannel(c *gin.Context) {
	deleted, deleteErr := model.DeleteDisabledChannel()
	if deleteErr != nil {
		common.ApiError(c, deleteErr)
		return
	}
	model.InitChannelCache()
	c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": deleted})
}
// ChannelTag is the JSON body of the tag-level channel endpoints: DisableTagChannels,
// EnableTagChannels and EditTagChannels.
//
// Tag selects the channels and must be non-empty. The pointer fields are optional and are only
// read by EditTagChannels:
//   - ParamOverride and HeaderOverride are trimmed and must be valid JSON when non-empty;
//   - every field is forwarded to model.EditChannelByTag. NOTE(review): presumably nil means
//     "leave unchanged" — confirm there.
type ChannelTag struct {
	Tag            string  `json:"tag"`
	NewTag         *string `json:"new_tag"`
	Priority       *int64  `json:"priority"`
	Weight         *uint   `json:"weight"`
	ModelMapping   *string `json:"model_mapping"`
	Models         *string `json:"models"`
	Groups         *string `json:"groups"`
	ParamOverride  *string `json:"param_override"`
	HeaderOverride *string `json:"header_override"`
}
// DisableTagChannels disables every channel carrying the tag given in the JSON body, then reloads
// the channel cache. A malformed body or an empty tag yields "参数错误".
func DisableTagChannels(c *gin.Context) {
	var req ChannelTag
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil || req.Tag == "" {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "参数错误"})
		return
	}
	if disableErr := model.DisableChannelByTag(req.Tag); disableErr != nil {
		common.ApiError(c, disableErr)
		return
	}
	model.InitChannelCache()
	c.JSON(http.StatusOK, gin.H{"success": true, "message": ""})
}
// EnableTagChannels enables every channel carrying the tag given in the JSON body, then reloads
// the channel cache. A malformed body or an empty tag yields "参数错误".
func EnableTagChannels(c *gin.Context) {
	var body ChannelTag
	bindErr := c.ShouldBindJSON(&body)
	invalid := bindErr != nil || body.Tag == ""
	if invalid {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "参数错误"})
		return
	}
	if enableErr := model.EnableChannelByTag(body.Tag); enableErr != nil {
		common.ApiError(c, enableErr)
		return
	}
	model.InitChannelCache()
	c.JSON(http.StatusOK, gin.H{"success": true, "message": ""})
}
// EditTagChannels bulk-edits every channel carrying the given tag via model.EditChannelByTag, then
// reloads the channel cache.
//
// ParamOverride and HeaderOverride, when present, are trimmed; a non-empty trimmed value must be
// valid JSON, or the request is rejected. ParamOverride is checked first.
func EditTagChannels(c *gin.Context) {
	var req ChannelTag
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "参数错误"})
		return
	}
	if req.Tag == "" {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "tag不能为空"})
		return
	}
	// normalizeOverride trims *field in place when it is set. If the trimmed value is non-empty
	// but not valid JSON, it writes the error response and reports false.
	normalizeOverride := func(field **string, invalidMsg string) bool {
		if *field == nil {
			return true
		}
		trimmed := strings.TrimSpace(**field)
		if trimmed != "" && !json.Valid([]byte(trimmed)) {
			c.JSON(http.StatusOK, gin.H{"success": false, "message": invalidMsg})
			return false
		}
		*field = common.GetPointer[string](trimmed)
		return true
	}
	if !normalizeOverride(&req.ParamOverride, "参数覆盖必须是合法的 JSON 格式") ||
		!normalizeOverride(&req.HeaderOverride, "请求头覆盖必须是合法的 JSON 格式") {
		return
	}
	editErr := model.EditChannelByTag(req.Tag, req.NewTag, req.ModelMapping, req.Models, req.Groups, req.Priority, req.Weight, req.ParamOverride, req.HeaderOverride)
	if editErr != nil {
		common.ApiError(c, editErr)
		return
	}
	model.InitChannelCache()
	c.JSON(http.StatusOK, gin.H{"success": true, "message": ""})
}
// ChannelBatch is the JSON body of the batch endpoints DeleteChannelBatch and BatchSetChannelTag.
// Ids lists the target channel IDs, and both handlers reject an empty list. Tag is only used by
// BatchSetChannelTag, which passes it to model.BatchSetChannelTag.
type ChannelBatch struct {
	Ids []int   `json:"ids"`
	Tag *string `json:"tag"`
}
// DeleteChannelBatch deletes the channels whose IDs are listed in the body's "ids" and reloads the
// channel cache. "data" in the response is the number of IDs requested, not a count reported by
// the database.
func DeleteChannelBatch(c *gin.Context) {
	var batch ChannelBatch
	if bindErr := c.ShouldBindJSON(&batch); bindErr != nil || len(batch.Ids) == 0 {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "参数错误"})
		return
	}
	if deleteErr := model.BatchDeleteChannels(batch.Ids); deleteErr != nil {
		common.ApiError(c, deleteErr)
		return
	}
	model.InitChannelCache()
	c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": len(batch.Ids)})
}
// PatchChannel is the JSON body of UpdateChannel. It embeds the channel fields and adds two
// multi-key controls:
//   - MultiKeyMode, when non-empty, overrides the stored ChannelInfo.MultiKeyMode;
//   - KeyMode selects how Key is applied to a multi-key channel: "append" merges the new keys
//     into the stored ones, while "replace" (or nil) overwrites them.
type PatchChannel struct {
	model.Channel
	MultiKeyMode *string `json:"multi_key_mode"`
	KeyMode      *string `json:"key_mode"` // in multi-key mode: overwrite ("replace") or append ("append") keys
}
// UpdateChannel applies a (possibly partial) update to an existing channel from a PatchChannel
// body.
//
// Fields the request leaves empty are back-filled before validation:
//   - SupplierType: from the stored row, then the linked supplier application, then
//     defaultChannelSupplierType;
//   - RouteSlug: from the stored row.
//
// A changed RouteSlug must pass model.IsValidRouteSlug and be unused by any other channel.
//
// Callers below admin role cannot change OwnerUserID or SupplierApplicationID. ChannelInfo is
// always carried over from the stored row; only MultiKeyMode may be overridden.
//
// For multi-key channels, KeyMode "append" merges the submitted keys after the stored ones,
// skipping duplicates. Otherwise the submitted Key is stored as sent.
//
// The response echoes the channel with Key emptied and ChannelInfo passed through
// clearChannelInfo.
func UpdateChannel(c *gin.Context) {
	channel := PatchChannel{}
	err := c.ShouldBindJSON(&channel)
	if err != nil {
		common.ApiError(c, err)
		return
	}
	if channel.Id <= 0 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "渠道ID无效",
		})
		return
	}
	// Load the stored channel so multi-key state (ChannelInfo) is preserved even when the client
	// does not send ChannelInfo.
	originChannel, err := model.GetChannelById(channel.Id, true)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	if !validateSupplierChannelOwnershipForUpdate(c, originChannel) {
		c.JSON(http.StatusForbidden, gin.H{
			"success": false,
			"message": "无权修改其他供应商渠道",
		})
		return
	}
	oldBalance := originChannel.Balance
	// Partial updates (e.g. only status/priority/weight) may omit the supplier type. Reuse the
	// stored value, otherwise validateChannel would fail on the zero value.
	if strings.TrimSpace(channel.SupplierType) == "" {
		channel.SupplierType = strings.TrimSpace(originChannel.SupplierType)
	}
	if strings.TrimSpace(channel.SupplierType) == "" && originChannel.SupplierApplicationID > 0 {
		var app model.SupplierApplication
		if err := model.DB.Select("supplier_type").Where("id = ?", originChannel.SupplierApplicationID).First(&app).Error; err == nil {
			channel.SupplierType = strings.TrimSpace(app.SupplierType)
		}
	}
	if strings.TrimSpace(channel.SupplierType) == "" {
		channel.SupplierType = defaultChannelSupplierType
	}
	// An empty route_slug keeps the stored value. A changed one must be well-formed and globally
	// unique.
	if strings.TrimSpace(channel.RouteSlug) == "" {
		channel.RouteSlug = originChannel.RouteSlug
	} else {
		channel.RouteSlug = strings.TrimSpace(channel.RouteSlug)
		if channel.RouteSlug != originChannel.RouteSlug {
			if !model.IsValidRouteSlug(channel.RouteSlug) {
				// The message previously read "...格式无效232 位...)": the opening parenthesis and the
				// range dash were missing. NOTE(review): confirm the 2-32 range against
				// model.IsValidRouteSlug.
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": "route_slug 格式无效(2-32 位字母数字,且不能为 c 加纯数字)",
				})
				return
			}
			var cnt int64
			if err := model.DB.Model(&model.Channel{}).Where("route_slug = ? AND id <> ?", channel.RouteSlug, channel.Id).Count(&cnt).Error; err != nil {
				common.ApiError(c, err)
				return
			}
			if cnt > 0 {
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": "route_slug 已被其他渠道占用",
				})
				return
			}
		}
	}
	// Shared channel validation.
	if err := validateChannel(&channel.Channel, false); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	// Non-admin (supplier) updates keep the stored ownership, so the request body cannot
	// reassign owner/supplier links.
	if c.GetInt("role") < common.RoleAdminUser {
		channel.OwnerUserID = originChannel.OwnerUserID
		channel.SupplierApplicationID = originChannel.SupplierApplicationID
	}
	// Always copy the stored ChannelInfo so fields like IsMultiKey and MultiKeySize are retained.
	channel.ChannelInfo = originChannel.ChannelInfo
	// An explicit, non-empty MultiKeyMode in the request overrides the stored one.
	if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
		channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
	}
	// Multi-key channels: apply the append/replace key mode.
	if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey {
		switch *channel.KeyMode {
		case "append":
			// Append mode: add the submitted keys after the stored ones.
			if originChannel.Key != "" {
				var newKeys []string
				var existingKeys []string
				// Parse the stored keys.
				if strings.HasPrefix(strings.TrimSpace(originChannel.Key), "[") {
					// JSON array format: keep each element as its raw JSON text.
					var arr []json.RawMessage
					if err := json.Unmarshal([]byte(strings.TrimSpace(originChannel.Key)), &arr); err != nil {
						// A malformed stored array used to be ignored, leaving existingKeys empty.
						// The update then silently replaced every stored key with only the
						// appended ones.
						c.JSON(http.StatusOK, gin.H{
							"success": false,
							"message": "现有密钥解析失败: " + err.Error(),
						})
						return
					}
					existingKeys = make([]string, len(arr))
					for i, v := range arr {
						existingKeys[i] = string(v)
					}
				} else {
					// Newline-separated format.
					existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n")
				}
				// Vertex AI service-account (JSON) keys.
				if channel.Type == constant.ChannelTypeVertexAi && channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
					// The submitted value may be a JSON array of credentials...
					if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
						array, err := getVertexArrayKeys(channel.Key)
						if err != nil {
							c.JSON(http.StatusOK, gin.H{
								"success": false,
								"message": "追加密钥解析失败: " + err.Error(),
							})
							return
						}
						newKeys = array
					} else {
						// ...or a single JSON credential.
						newKeys = []string{channel.Key}
					}
				} else {
					// Regular channels: one key per non-blank line.
					for _, key := range strings.Split(channel.Key, "\n") {
						key = strings.TrimSpace(key)
						if key != "" {
							newKeys = append(newKeys, key)
						}
					}
				}
				// Skip submitted keys that already exist (compared after trimming) or repeat
				// within the submission.
				seen := make(map[string]struct{}, len(existingKeys)+len(newKeys))
				for _, key := range existingKeys {
					normalized := strings.TrimSpace(key)
					if normalized == "" {
						continue
					}
					seen[normalized] = struct{}{}
				}
				dedupedNewKeys := make([]string, 0, len(newKeys))
				for _, key := range newKeys {
					normalized := strings.TrimSpace(key)
					if normalized == "" {
						continue
					}
					if _, ok := seen[normalized]; ok {
						continue
					}
					seen[normalized] = struct{}{}
					dedupedNewKeys = append(dedupedNewKeys, normalized)
				}
				allKeys := append(existingKeys, dedupedNewKeys...)
				channel.Key = strings.Join(allKeys, "\n")
			}
		case "replace":
			// Replace mode: the submitted key is stored as-is (default behaviour, nothing to do).
		}
	}
	err = channel.Update()
	if err != nil {
		common.ApiError(c, err)
		return
	}
	notifyChannelBalanceAlertIfNeeded(originChannel, oldBalance, channel.Balance)
	model.InitChannelCache()
	service.ResetProxyClientCache()
	// Never echo key material back to the client.
	channel.Key = ""
	clearChannelInfo(&channel.Channel)
	// Record the operation log.
	service.RecordUpdateOperation(c, "channel", channel.Id, channel.Name, "更新渠道: "+channel.Name, "")
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    channel,
	})
}
// FetchModels lists the model IDs offered by an upstream provider, e.g. so a channel's model list
// can be pre-filled.
//
// Request body fields:
//   - base_url: falls back to constant.ChannelBaseURLs[type] when empty;
//   - type: the channel type;
//   - key: trimmed, and only its first line is used.
//
// Ollama and Gemini use their dedicated listing helpers. Every other type is queried as an
// OpenAI-compatible GET {base_url}/v1/models with a Bearer key.
func FetchModels(c *gin.Context) {
	var req struct {
		BaseURL string `json:"base_url"`
		Type    int    `json:"type"`
		Key     string `json:"key"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"message": "Invalid request",
		})
		return
	}
	baseURL := req.BaseURL
	if baseURL == "" {
		baseURL = constant.ChannelBaseURLs[req.Type]
	}
	// Remove line breaks and surrounding spaces; only the first key is used.
	key := strings.TrimSpace(req.Key)
	key = strings.Split(key, "\n")[0]
	if req.Type == constant.ChannelTypeOllama {
		models, err := ollama.FetchOllamaModels(c.Request.Context(), baseURL, key)
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": fmt.Sprintf("获取Ollama模型失败: %s", err.Error()),
			})
			return
		}
		names := make([]string, 0, len(models))
		for _, modelInfo := range models {
			names = append(names, modelInfo.Name)
		}
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data":    names,
		})
		return
	}
	if req.Type == constant.ChannelTypeGemini {
		models, err := gemini.FetchGeminiModels(c.Request.Context(), baseURL, key, "")
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": fmt.Sprintf("获取Gemini模型失败: %s", err.Error()),
			})
			return
		}
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data":    models,
		})
		return
	}
	// Bind the request to the caller's context and cap it with a client timeout, so a slow or
	// unreachable upstream cannot hold this handler open indefinitely.
	client := &http.Client{Timeout: 30 * time.Second}
	// Trim a trailing "/" so "https://host/" does not become "https://host//v1/models".
	// (Named modelsURL so it does not shadow the imported net/url package.)
	modelsURL := strings.TrimRight(baseURL, "/") + "/v1/models"
	request, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, modelsURL, nil)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	request.Header.Set("Authorization", "Bearer "+key)
	response, err := client.Do(request)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	// Close the body on every path. It used to be deferred only after the status check, which
	// leaked the body (and its connection) on non-200 responses.
	defer response.Body.Close()
	if response.StatusCode != http.StatusOK {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": "Failed to fetch models",
		})
		return
	}
	var result struct {
		Data []struct {
			ID string `json:"id"`
		} `json:"data"`
	}
	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	// Left nil when the list is empty, as before (serialized as null).
	var models []string
	for _, item := range result.Data {
		models = append(models, item.ID)
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    models,
	})
}
// BatchSetChannelTag assigns the body's "tag" to every channel listed in "ids" via
// model.BatchSetChannelTag, then reloads the channel cache. "data" in the response is the number
// of IDs requested.
func BatchSetChannelTag(c *gin.Context) {
	var batch ChannelBatch
	bindErr := c.ShouldBindJSON(&batch)
	if bindErr != nil || len(batch.Ids) == 0 {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "参数错误"})
		return
	}
	if tagErr := model.BatchSetChannelTag(batch.Ids, batch.Tag); tagErr != nil {
		common.ApiError(c, tagErr)
		return
	}
	model.InitChannelCache()
	c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": len(batch.Ids)})
}
// GetTagModels returns the Models string of the channel, among those carrying the "tag" query
// parameter, that lists the most comma-separated entries. The first such channel wins ties, and
// "" is returned when no channel has models. A missing tag is a 400; a lookup failure is a 500.
func GetTagModels(c *gin.Context) {
	tag := c.Query("tag")
	if tag == "" {
		c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "tag不能为空"})
		return
	}
	channels, err := model.GetChannelsByTag(tag, false, false) // idSort=false, selectAll=false
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
		return
	}
	best, bestCount := "", 0
	for _, ch := range channels {
		if ch.Models == "" {
			continue
		}
		// Count+1 equals the number of parts strings.Split(ch.Models, ",") would return.
		if entries := strings.Count(ch.Models, ",") + 1; entries > bestCount {
			best, bestCount = ch.Models, entries
		}
	}
	c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": best})
}
// CopyChannel clones an existing channel, including its key.
// POST /api/channel/copy/:id
// Optional query params:
//
//	suffix        - string appended to the original name (default "_复制")
//	reset_balance - bool; when true, balance and used_quota are reset to 0 (default true;
//	                unparsable values keep the default)
//
// The clone gets a fresh creation time, zeroed test/response times and an empty channel_no and
// route_slug. On success "data.id" carries the new channel's ID.
func CopyChannel(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "invalid id"})
		return
	}
	suffix := c.DefaultQuery("suffix", "_复制")
	resetBalance := true
	if rbStr := c.DefaultQuery("reset_balance", "true"); rbStr != "" {
		if v, err := strconv.ParseBool(rbStr); err == nil {
			resetBalance = v
		}
	}
	// Fetch the original channel including its key.
	origin, err := model.GetChannelById(id, true)
	if err != nil {
		common.SysError("failed to get channel by id: " + err.Error())
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道信息失败,请稍后重试"})
		return
	}
	// A shallow copy is sufficient: only value fields are overwritten below.
	clone := *origin
	clone.Id = 0 // let the DB auto-generate
	clone.CreatedTime = common.GetTimestamp()
	clone.Name = origin.Name + suffix
	clone.TestTime = 0
	clone.ResponseTime = 0
	if resetBalance {
		clone.Balance = 0
		clone.UsedQuota = 0
	}
	clone.ChannelNo = ""
	clone.RouteSlug = ""
	// Keep a handle on the slice that is actually inserted. The element is a copy of clone, so
	// clone.Id is never updated — the old response therefore always reported id 0.
	// NOTE(review): assumes BatchInsertChannels writes generated IDs back into the slice it is
	// given — confirm.
	batch := []model.Channel{clone}
	if err := model.BatchInsertChannels(batch); err != nil {
		common.SysError("failed to clone channel: " + err.Error())
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "复制渠道失败,请稍后重试"})
		return
	}
	model.InitChannelCache()
	c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": batch[0].Id}})
}
// MultiKeyManageRequest is the JSON body of ManageMultiKeys.
// Key statuses used by Status: 1 = enabled, 2 = manually disabled, 3 = auto disabled.
type MultiKeyManageRequest struct {
	ChannelId int    `json:"channel_id"`
	Action    string `json:"action"`              // "get_key_status", "disable_key", "enable_key", "enable_all_keys", "disable_all_keys", "delete_key", "delete_disabled_keys"
	KeyIndex  *int   `json:"key_index,omitempty"` // 0-based key position; required by disable_key, enable_key and delete_key
	Page      int    `json:"page,omitempty"`      // get_key_status pagination; values <= 0 mean page 1
	PageSize  int    `json:"page_size,omitempty"` // get_key_status pagination; values <= 0 mean 50
	Status    *int   `json:"status,omitempty"`    // get_key_status filter: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all
}
// MultiKeyStatusResponse is the "data" payload of ManageMultiKeys' get_key_status action.
// Keys, Total, Page and TotalPages reflect the status-filtered list. The *Count fields are
// computed over all keys, regardless of the filter.
type MultiKeyStatusResponse struct {
	Keys       []KeyStatus `json:"keys"`
	Total      int         `json:"total"`
	Page       int         `json:"page"`
	PageSize   int         `json:"page_size"`
	TotalPages int         `json:"total_pages"`
	// Statistics over every key (not affected by the status filter)
	EnabledCount        int `json:"enabled_count"`
	ManualDisabledCount int `json:"manual_disabled_count"`
	AutoDisabledCount   int `json:"auto_disabled_count"`
}
// KeyStatus describes one key of a multi-key channel in a get_key_status response.
type KeyStatus struct {
	Index        int    `json:"index"`                   // 0-based position of the key in the channel's key list
	Status       int    `json:"status"`                  // 1: enabled, 2: manually disabled, 3: auto disabled
	DisabledTime int64  `json:"disabled_time,omitempty"` // only filled for non-enabled keys
	Reason       string `json:"reason,omitempty"`        // only filled for non-enabled keys
	KeyPreview   string `json:"key_preview"`             // first 10 bytes of the key plus "..." when longer, for identification
}
// ManageMultiKeys handles multi-key management operations for one multi-key channel.
//
// The JSON body is a MultiKeyManageRequest. Supported actions:
//   - get_key_status: paginated key list (optionally filtered by status) plus overall counts;
//   - disable_key / enable_key: change a single key by index;
//   - enable_all_keys / disable_all_keys: change every key;
//   - delete_key: remove one key (the last remaining key cannot be removed);
//   - delete_disabled_keys: remove every auto-disabled key (at least one key must remain).
//
// Key status values: 1 = enabled, 2 = manually disabled, 3 = auto disabled. A key with no entry
// in MultiKeyStatusList counts as enabled. When keys are removed, the status, disabled-time and
// reason maps are re-indexed to the new key positions.
func ManageMultiKeys(c *gin.Context) {
	request := MultiKeyManageRequest{}
	err := c.ShouldBindJSON(&request)
	if err != nil {
		common.ApiError(c, err)
		return
	}
	// Acquire the channel's lock before loading it, so the whole load -> modify -> Update
	// sequence is serialized per channel. The channel used to be read before locking, so two
	// concurrent requests could start from the same snapshot and the later Update would discard
	// the earlier one's change.
	lock := model.GetChannelPollingLock(request.ChannelId)
	lock.Lock()
	defer lock.Unlock()
	channel, err := model.GetChannelById(request.ChannelId, true)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "渠道不存在",
		})
		return
	}
	if !channel.ChannelInfo.IsMultiKey {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "该渠道不是多密钥模式",
		})
		return
	}
	switch request.Action {
	case "get_key_status":
		keys := channel.GetKeys()
		// Default pagination parameters.
		page := request.Page
		pageSize := request.PageSize
		if page <= 0 {
			page = 1
		}
		if pageSize <= 0 {
			pageSize = 50 // default page size
		}
		// Statistics over all keys (not affected by the status filter).
		var enabledCount, manualDisabledCount, autoDisabledCount int
		// Build the status entry of every key first.
		var allKeyStatusList []KeyStatus
		for i, key := range keys {
			status := 1 // default: enabled
			var disabledTime int64
			var reason string
			if channel.ChannelInfo.MultiKeyStatusList != nil {
				if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
					status = s
				}
			}
			switch status {
			case 1:
				enabledCount++
			case 2:
				manualDisabledCount++
			case 3:
				autoDisabledCount++
			}
			if status != 1 {
				if channel.ChannelInfo.MultiKeyDisabledTime != nil {
					disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i]
				}
				if channel.ChannelInfo.MultiKeyDisabledReason != nil {
					reason = channel.ChannelInfo.MultiKeyDisabledReason[i]
				}
			}
			// Key preview: first 10 bytes.
			keyPreview := key
			if len(key) > 10 {
				keyPreview = key[:10] + "..."
			}
			allKeyStatusList = append(allKeyStatusList, KeyStatus{
				Index:        i,
				Status:       status,
				DisabledTime: disabledTime,
				Reason:       reason,
				KeyPreview:   keyPreview,
			})
		}
		// Apply the status filter, if any.
		var filteredKeyStatusList []KeyStatus
		if request.Status != nil {
			for _, keyStatus := range allKeyStatusList {
				if keyStatus.Status == *request.Status {
					filteredKeyStatusList = append(filteredKeyStatusList, keyStatus)
				}
			}
		} else {
			filteredKeyStatusList = allKeyStatusList
		}
		// Paginate the filtered results; an out-of-range page is clamped to the last page.
		filteredTotal := len(filteredKeyStatusList)
		totalPages := (filteredTotal + pageSize - 1) / pageSize
		if totalPages == 0 {
			totalPages = 1
		}
		if page > totalPages {
			page = totalPages
		}
		start := (page - 1) * pageSize
		end := start + pageSize
		if end > filteredTotal {
			end = filteredTotal
		}
		var pageKeyStatusList []KeyStatus
		if start < filteredTotal {
			pageKeyStatusList = filteredKeyStatusList[start:end]
		}
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"message": "",
			"data": MultiKeyStatusResponse{
				Keys:                pageKeyStatusList,
				Total:               filteredTotal, // total of the filtered results
				Page:                page,
				PageSize:            pageSize,
				TotalPages:          totalPages,
				EnabledCount:        enabledCount,        // overall statistics
				ManualDisabledCount: manualDisabledCount, // overall statistics
				AutoDisabledCount:   autoDisabledCount,   // overall statistics
			},
		})
		return
	case "disable_key":
		if request.KeyIndex == nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "未指定要禁用的密钥索引",
			})
			return
		}
		keyIndex := *request.KeyIndex
		if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "密钥索引超出范围",
			})
			return
		}
		if channel.ChannelInfo.MultiKeyStatusList == nil {
			channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
		}
		if channel.ChannelInfo.MultiKeyDisabledTime == nil {
			channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
		}
		if channel.ChannelInfo.MultiKeyDisabledReason == nil {
			channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
		}
		channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // manually disabled
		err = channel.Update()
		if err != nil {
			common.ApiError(c, err)
			return
		}
		model.InitChannelCache()
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"message": "密钥已禁用",
		})
		return
	case "enable_key":
		if request.KeyIndex == nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "未指定要启用的密钥索引",
			})
			return
		}
		keyIndex := *request.KeyIndex
		if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "密钥索引超出范围",
			})
			return
		}
		// Drop the key's entries so it falls back to the default (enabled) state.
		if channel.ChannelInfo.MultiKeyStatusList != nil {
			delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
		}
		if channel.ChannelInfo.MultiKeyDisabledTime != nil {
			delete(channel.ChannelInfo.MultiKeyDisabledTime, keyIndex)
		}
		if channel.ChannelInfo.MultiKeyDisabledReason != nil {
			delete(channel.ChannelInfo.MultiKeyDisabledReason, keyIndex)
		}
		err = channel.Update()
		if err != nil {
			common.ApiError(c, err)
			return
		}
		model.InitChannelCache()
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"message": "密钥已启用",
		})
		return
	case "enable_all_keys":
		// Clear every status entry so all keys return to the default (enabled) state. The count
		// reported is the number of status entries that existed.
		var enabledCount int
		if channel.ChannelInfo.MultiKeyStatusList != nil {
			enabledCount = len(channel.ChannelInfo.MultiKeyStatusList)
		}
		channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
		channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
		channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
		err = channel.Update()
		if err != nil {
			common.ApiError(c, err)
			return
		}
		model.InitChannelCache()
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"message": fmt.Sprintf("已启用 %d 个密钥", enabledCount),
		})
		return
	case "disable_all_keys":
		// Manually disable every currently enabled key.
		if channel.ChannelInfo.MultiKeyStatusList == nil {
			channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
		}
		if channel.ChannelInfo.MultiKeyDisabledTime == nil {
			channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
		}
		if channel.ChannelInfo.MultiKeyDisabledReason == nil {
			channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
		}
		var disabledCount int
		for i := 0; i < channel.ChannelInfo.MultiKeySize; i++ {
			status := 1 // default: enabled
			if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
				status = s
			}
			// Only keys that are currently enabled are touched.
			if status == 1 {
				channel.ChannelInfo.MultiKeyStatusList[i] = 2 // manually disabled
				disabledCount++
			}
		}
		if disabledCount == 0 {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "没有可禁用的密钥",
			})
			return
		}
		err = channel.Update()
		if err != nil {
			common.ApiError(c, err)
			return
		}
		model.InitChannelCache()
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"message": fmt.Sprintf("已禁用 %d 个密钥", disabledCount),
		})
		return
	case "delete_key":
		if request.KeyIndex == nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "未指定要删除的密钥索引",
			})
			return
		}
		keyIndex := *request.KeyIndex
		if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "密钥索引超出范围",
			})
			return
		}
		keys := channel.GetKeys()
		var remainingKeys []string
		var newStatusList = make(map[int]int)
		var newDisabledTime = make(map[int]int64)
		var newDisabledReason = make(map[int]string)
		newIndex := 0
		for i, key := range keys {
			// Skip the key being deleted.
			if i == keyIndex {
				continue
			}
			remainingKeys = append(remainingKeys, key)
			// Carry the other keys' state over, re-indexed to their new positions.
			if channel.ChannelInfo.MultiKeyStatusList != nil {
				if status, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists && status != 1 {
					newStatusList[newIndex] = status
				}
			}
			if channel.ChannelInfo.MultiKeyDisabledTime != nil {
				if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
					newDisabledTime[newIndex] = t
				}
			}
			if channel.ChannelInfo.MultiKeyDisabledReason != nil {
				if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
					newDisabledReason[newIndex] = r
				}
			}
			newIndex++
		}
		if len(remainingKeys) == 0 {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "不能删除最后一个密钥",
			})
			return
		}
		// Update the channel with the remaining keys.
		channel.Key = strings.Join(remainingKeys, "\n")
		channel.ChannelInfo.MultiKeySize = len(remainingKeys)
		channel.ChannelInfo.MultiKeyStatusList = newStatusList
		channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
		channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
		err = channel.Update()
		if err != nil {
			common.ApiError(c, err)
			return
		}
		model.InitChannelCache()
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"message": "密钥已删除",
		})
		return
	case "delete_disabled_keys":
		keys := channel.GetKeys()
		var remainingKeys []string
		var deletedCount int
		var newStatusList = make(map[int]int)
		var newDisabledTime = make(map[int]int64)
		var newDisabledReason = make(map[int]string)
		newIndex := 0
		for i, key := range keys {
			status := 1 // default: enabled
			if channel.ChannelInfo.MultiKeyStatusList != nil {
				if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
					status = s
				}
			}
			// Delete only auto-disabled keys (status 3). Enabled (1) and manually disabled (2)
			// keys are kept.
			if status == 3 {
				deletedCount++
			} else {
				remainingKeys = append(remainingKeys, key)
				// Carry over the state of kept non-enabled keys, re-indexed.
				if status != 1 {
					newStatusList[newIndex] = status
					if channel.ChannelInfo.MultiKeyDisabledTime != nil {
						if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
							newDisabledTime[newIndex] = t
						}
					}
					if channel.ChannelInfo.MultiKeyDisabledReason != nil {
						if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
							newDisabledReason[newIndex] = r
						}
					}
				}
				newIndex++
			}
		}
		if deletedCount == 0 {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "没有需要删除的自动禁用密钥",
			})
			return
		}
		// Mirror delete_key: never leave a multi-key channel with no keys. Previously, when every
		// key was auto-disabled, this stored an empty Key.
		if len(remainingKeys) == 0 {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "不能删除全部密钥,请至少保留一个密钥",
			})
			return
		}
		// Update the channel with the remaining keys.
		channel.Key = strings.Join(remainingKeys, "\n")
		channel.ChannelInfo.MultiKeySize = len(remainingKeys)
		channel.ChannelInfo.MultiKeyStatusList = newStatusList
		channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
		channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
		err = channel.Update()
		if err != nil {
			common.ApiError(c, err)
			return
		}
		model.InitChannelCache()
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"message": fmt.Sprintf("已删除 %d 个自动禁用的密钥", deletedCount),
			"data":    deletedCount,
		})
		return
	default:
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "不支持的操作",
		})
		return
	}
}
// OllamaPullModel pulls (downloads) a model onto the Ollama server behind an Ollama channel.
//
// Request body: channel_id and model_name, both required. The channel must exist and be of type
// Ollama. Its base URL falls back to the type's default when unset. Only the first line of the
// channel key is used.
func OllamaPullModel(c *gin.Context) {
	var req struct {
		ChannelID int    `json:"channel_id"`
		ModelName string `json:"model_name"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"message": "Invalid request parameters",
		})
		return
	}
	if req.ChannelID == 0 || req.ModelName == "" {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"message": "Channel ID and model name are required",
		})
		return
	}
	// Load the channel, including its key.
	channel, err := model.GetChannelById(req.ChannelID, true)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{
			"success": false,
			"message": "Channel not found",
		})
		return
	}
	// Only Ollama channels support pulling models.
	if channel.Type != constant.ChannelTypeOllama {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"message": "This operation is only supported for Ollama channels",
		})
		return
	}
	baseURL := constant.ChannelBaseURLs[channel.Type]
	if channel.GetBaseURL() != "" {
		baseURL = channel.GetBaseURL()
	}
	// Take the first key line and trim it, the way FetchModels does, so surrounding whitespace or
	// a trailing "\r" from CRLF input is not sent as part of the key.
	firstKey, _, _ := strings.Cut(strings.TrimSpace(channel.Key), "\n")
	key := strings.TrimSpace(firstKey)
	err = ollama.PullOllamaModel(baseURL, key, req.ModelName)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": fmt.Sprintf("Failed to pull model: %s", err.Error()),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": fmt.Sprintf("Model %s pulled successfully", req.ModelName),
	})
}
// OllamaPullModelStream pulls a model onto the Ollama server behind a channel
// and relays the upstream progress to the client as Server-Sent Events.
//
// Request body (JSON): {"channel_id": int, "model_name": string}.
//
// Validation failures (malformed body, missing fields, channel lookup error,
// non-Ollama channel) are answered with an ordinary JSON error response before
// any SSE header is written. Once streaming starts, each progress update from
// ollama.PullOllamaModelStream is sent as a `data: <json>` frame. That is
// followed by either a `{"error": ...}` or a `{"message": ...}` frame, and
// finally by `data: [DONE]`.
func OllamaPullModelStream(c *gin.Context) {
	var req struct {
		ChannelID int    `json:"channel_id"`
		ModelName string `json:"model_name"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"message": "Invalid request parameters",
		})
		return
	}
	if req.ChannelID == 0 || req.ModelName == "" {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"message": "Channel ID and model name are required",
		})
		return
	}

	channel, err := model.GetChannelById(req.ChannelID, true)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{
			"success": false,
			"message": "Channel not found",
		})
		return
	}
	if channel.Type != constant.ChannelTypeOllama {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"message": "This operation is only supported for Ollama channels",
		})
		return
	}

	// A channel-level base URL overrides the per-type default.
	baseURL := constant.ChannelBaseURLs[channel.Type]
	if channel.GetBaseURL() != "" {
		baseURL = channel.GetBaseURL()
	}
	// Multi-key channels store one key per line; only the first is used here.
	key := strings.Split(channel.Key, "\n")[0]

	// SSE response headers.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Access-Control-Allow-Origin", "*")
	// Without this, nginx-style reverse proxies buffer the response, and the
	// client receives every progress frame at once when the pull finishes.
	c.Header("X-Accel-Buffering", "no")

	// writeEvent emits a single SSE `data:` frame and flushes it immediately,
	// so the client sees progress as it happens.
	writeEvent := func(payload []byte) {
		fmt.Fprintf(c.Writer, "data: %s\n\n", payload)
		c.Writer.Flush()
	}

	reqCtx := c.Request.Context()
	progressCallback := func(progress ollama.OllamaPullResponse) {
		// This callback cannot abort the pull. Once the client has
		// disconnected, though, there is no point writing more frames.
		if reqCtx.Err() != nil {
			return
		}
		data, marshalErr := json.Marshal(progress)
		if marshalErr != nil {
			// Drop an update that cannot be encoded instead of sending an
			// empty `data: ` frame.
			return
		}
		writeEvent(data)
	}

	var final []byte
	if pullErr := ollama.PullOllamaModelStream(baseURL, key, req.ModelName, progressCallback); pullErr != nil {
		// Marshaling a map of plain strings cannot fail.
		final, _ = json.Marshal(gin.H{
			"error": pullErr.Error(),
		})
	} else {
		final, _ = json.Marshal(gin.H{
			"message": fmt.Sprintf("Model %s pulled successfully", req.ModelName),
		})
	}
	writeEvent(final)

	// End-of-stream marker.
	writeEvent([]byte("[DONE]"))
}
// OllamaDeleteModel removes a model from the Ollama server behind a channel.
//
// Request body (JSON): {"channel_id": int, "model_name": string}.
// Responses: 400 for a malformed body, missing fields, or a non-Ollama
// channel; 404 when the channel cannot be loaded; 500 when deletion fails;
// 200 with success=true otherwise.
func OllamaDeleteModel(c *gin.Context) {
	type deleteRequest struct {
		ChannelID int    `json:"channel_id"`
		ModelName string `json:"model_name"`
	}
	respondError := func(code int, msg string) {
		c.JSON(code, gin.H{
			"success": false,
			"message": msg,
		})
	}

	var in deleteRequest
	switch bindErr := c.ShouldBindJSON(&in); {
	case bindErr != nil:
		respondError(http.StatusBadRequest, "Invalid request parameters")
		return
	case in.ChannelID == 0 || in.ModelName == "":
		respondError(http.StatusBadRequest, "Channel ID and model name are required")
		return
	}

	target, err := model.GetChannelById(in.ChannelID, true)
	if err != nil {
		respondError(http.StatusNotFound, "Channel not found")
		return
	}
	if target.Type != constant.ChannelTypeOllama {
		respondError(http.StatusBadRequest, "This operation is only supported for Ollama channels")
		return
	}

	// A channel-level base URL overrides the per-type default.
	baseURL := target.GetBaseURL()
	if baseURL == "" {
		baseURL = constant.ChannelBaseURLs[target.Type]
	}
	// Multi-key channels store one key per line; only the first is used here.
	apiKey, _, _ := strings.Cut(target.Key, "\n")

	if err = ollama.DeleteOllamaModel(baseURL, apiKey, in.ModelName); err != nil {
		respondError(http.StatusInternalServerError, fmt.Sprintf("Failed to delete model: %s", err.Error()))
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": fmt.Sprintf("Model %s deleted successfully", in.ModelName),
	})
}
// OllamaVersion reports the version of the Ollama server behind the channel
// identified by the `id` path parameter.
//
// Responses: 400 for a non-numeric id or a non-Ollama channel; 404 when the
// channel cannot be loaded; 200 with success=false when fetching the version
// fails; 200 with success=true and data.version otherwise.
func OllamaVersion(c *gin.Context) {
	fail := func(status int, message string) {
		c.JSON(status, gin.H{
			"success": false,
			"message": message,
		})
	}

	channelID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		fail(http.StatusBadRequest, "Invalid channel id")
		return
	}

	ch, lookupErr := model.GetChannelById(channelID, true)
	if lookupErr != nil {
		fail(http.StatusNotFound, "Channel not found")
		return
	}
	if ch.Type != constant.ChannelTypeOllama {
		fail(http.StatusBadRequest, "This operation is only supported for Ollama channels")
		return
	}

	// A channel-level base URL overrides the per-type default.
	endpoint := ch.GetBaseURL()
	if endpoint == "" {
		endpoint = constant.ChannelBaseURLs[ch.Type]
	}
	// Multi-key channels store one key per line; only the first is used here.
	firstKey, _, _ := strings.Cut(ch.Key, "\n")

	ver, fetchErr := ollama.FetchOllamaVersion(endpoint, firstKey)
	if fetchErr != nil {
		// Upstream failures are reported with HTTP 200 and success=false.
		fail(http.StatusOK, fmt.Sprintf("获取Ollama版本失败: %s", fetchErr.Error()))
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"version": ver,
		},
	})
}