package controller import ( "context" "encoding/json" "errors" "fmt" "io" "net/http" "net/url" "strconv" "strings" "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" relaychannel "github.com/QuantumNous/new-api/relay/channel" "github.com/QuantumNous/new-api/relay/channel/gemini" "github.com/QuantumNous/new-api/relay/channel/ollama" "github.com/QuantumNous/new-api/service" "github.com/gin-gonic/gin" ) type OpenAIModel struct { ID string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` OwnedBy string `json:"owned_by"` Metadata map[string]any `json:"metadata,omitempty"` Permission []struct { ID string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` AllowCreateEngine bool `json:"allow_create_engine"` AllowSampling bool `json:"allow_sampling"` AllowLogprobs bool `json:"allow_logprobs"` AllowSearchIndices bool `json:"allow_search_indices"` AllowView bool `json:"allow_view"` AllowFineTuning bool `json:"allow_fine_tuning"` Organization string `json:"organization"` Group string `json:"group"` IsBlocking bool `json:"is_blocking"` } `json:"permission"` Root string `json:"root"` Parent string `json:"parent"` } type OpenAIModelsResponse struct { Data []OpenAIModel `json:"data"` Success bool `json:"success"` } var channelAllowedSupplierTypes = map[string]struct{}{ "公有云": {}, "AIDC": {}, "企业中转站": {}, "个人中转站": {}, } // defaultChannelSupplierType 当渠道行与关联供应商申请均未提供 supplier_type 时的兜底值(须为 channelAllowedSupplierTypes 之一)。 const defaultChannelSupplierType = "公有云" // isValidChannelSupplierType 校验供应商类型是否属于预定义枚举值。 func isValidChannelSupplierType(supplierType string) bool { _, ok := channelAllowedSupplierTypes[supplierType] return ok } func parseStatusFilter(statusParam string) int { switch strings.ToLower(statusParam) { case "enabled", "1": return common.ChannelStatusEnabled case "disabled", "0": return 0 default: return -1 } } func clearChannelInfo(channel *model.Channel) { if channel.ChannelInfo.IsMultiKey { channel.ChannelInfo.MultiKeyDisabledReason = nil channel.ChannelInfo.MultiKeyDisabledTime = nil } } // attachSupplierNames 为渠道列表补齐供应商用户名(owner_user_id 对应 users.username)。 func attachSupplierNames(channels []*model.Channel) { ownerIDs := make([]int, 0) ownerSet := make(map[int]struct{}) for _, channel := range channels { if channel == nil || channel.OwnerUserID <= 0 { continue } if _, ok := ownerSet[channel.OwnerUserID]; ok { continue } ownerSet[channel.OwnerUserID] = struct{}{} ownerIDs = append(ownerIDs, channel.OwnerUserID) } if len(ownerIDs) == 0 { return } var users []model.User if err := model.DB.Select("id, username").Where("id IN ?", ownerIDs).Find(&users).Error; err != nil { return } userMap := make(map[int]string, len(users)) for _, user := range users { userMap[user.Id] = user.Username } for _, channel := range channels { if channel == nil || channel.OwnerUserID <= 0 { continue } channel.SupplierName = userMap[channel.OwnerUserID] } } func GetAllChannels(c *gin.Context) { pageInfo := common.GetPageQuery(c) channelData := make([]*model.Channel, 0) idSort, _ := strconv.ParseBool(c.Query("id_sort")) enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode")) supplierKeyword := strings.TrimSpace(c.Query("supplier")) statusParam := c.Query("status") // statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual) statusFilter := parseStatusFilter(statusParam) // type filter typeStr := c.Query("type") typeFilter := -1 if typeStr != "" { if t, err := strconv.Atoi(typeStr); err == nil { typeFilter = t } } // 供应商复用原渠道列表接口:仅查看自己渠道,管理员保持原有全量逻辑。 if c.GetInt("role") < common.RoleAdminUser { baseQuery := model.DB.Model(&model.Channel{}).Where("owner_user_id = ?", c.GetInt("id")) if typeFilter >= 0 { baseQuery = baseQuery.Where("type = ?", typeFilter) } if statusFilter == common.ChannelStatusEnabled { baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled) } else if statusFilter == 0 { baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled) } var total int64 if err := baseQuery.Count(&total).Error; err != nil { common.ApiError(c, err) return } order := "priority desc" if idSort { order = "id desc" } if err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error; err != nil { common.ApiError(c, err) return } for _, datum := range channelData { clearChannelInfo(datum) } attachSupplierNames(channelData) typeCounts := make(map[int64]int64) for _, channel := range channelData { typeCounts[int64(channel.Type)]++ } common.ApiSuccess(c, gin.H{ "items": channelData, "total": total, "page": pageInfo.GetPage(), "page_size": pageInfo.GetPageSize(), "type_counts": typeCounts, }) return } var total int64 if enableTagMode { tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize()) if err != nil { common.SysError("failed to get paginated tags: " + err.Error()) c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签失败,请稍后重试"}) return } for _, tag := range tags { if tag == nil || *tag == "" { continue } tagChannels, err := model.GetChannelsByTag(*tag, idSort, false) if err != nil { continue } filtered := make([]*model.Channel, 0) for _, ch := range tagChannels { if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled { continue } if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled { continue } if typeFilter >= 0 && ch.Type != typeFilter { continue } filtered = append(filtered, ch) } channelData = append(channelData, filtered...) } total, _ = model.CountAllTags() } else { baseQuery := model.DB.Model(&model.Channel{}) if supplierKeyword != "" { baseQuery = baseQuery.Joins("LEFT JOIN users ON users.id = channels.owner_user_id").Where("users.username LIKE ?", "%"+supplierKeyword+"%") } if typeFilter >= 0 { baseQuery = baseQuery.Where("type = ?", typeFilter) } if statusFilter == common.ChannelStatusEnabled { baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled) } else if statusFilter == 0 { baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled) } baseQuery.Count(&total) order := "priority desc" if idSort { order = "id desc" } err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error if err != nil { common.SysError("failed to get channels: " + err.Error()) c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道列表失败,请稍后重试"}) return } } for _, datum := range channelData { clearChannelInfo(datum) } attachSupplierNames(channelData) countQuery := model.DB.Model(&model.Channel{}) if statusFilter == common.ChannelStatusEnabled { countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled) } else if statusFilter == 0 { countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled) } var results []struct { Type int64 Count int64 } _ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error typeCounts := make(map[int64]int64) for _, r := range results { typeCounts[r.Type] = r.Count } common.ApiSuccess(c, gin.H{ "items": channelData, "total": total, "page": pageInfo.GetPage(), "page_size": pageInfo.GetPageSize(), "type_counts": typeCounts, }) return } func buildFetchModelsHeaders(channel *model.Channel, key string) (http.Header, error) { var headers http.Header switch channel.Type { case constant.ChannelTypeAnthropic: headers = GetClaudeAuthHeader(key) default: headers = GetAuthHeader(key) } headerOverride := channel.GetHeaderOverride() for k, v := range headerOverride { if relaychannel.IsHeaderPassthroughRuleKey(k) { continue } str, ok := v.(string) if !ok { return nil, fmt.Errorf("invalid header override for key %s", k) } if strings.Contains(str, "{api_key}") { str = strings.ReplaceAll(str, "{api_key}", key) } headers.Set(k, str) } return headers, nil } func FetchUpstreamModels(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { common.ApiError(c, err) return } channel, err := model.GetChannelById(id, true) if err != nil { common.ApiError(c, err) return } // 供应商只允许拉取自己渠道的上游模型,防止跨供应商越权读取。 if c.GetInt("role") < common.RoleAdminUser && channel.OwnerUserID != c.GetInt("id") { c.JSON(http.StatusForbidden, gin.H{ "success": false, "message": "无权访问其他供应商渠道", }) return } ids, err := fetchChannelUpstreamModelIDs(channel) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": fmt.Sprintf("获取模型列表失败: %s", err.Error()), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": ids, }) } func FixChannelsAbilities(c *gin.Context) { success, fails, err := model.FixAbility() if err != nil { common.ApiError(c, err) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": gin.H{ "success": success, "fails": fails, }, }) } func SearchChannels(c *gin.Context) { keyword := c.Query("keyword") supplierKeyword := strings.TrimSpace(c.Query("supplier")) group := c.Query("group") modelKeyword := c.Query("model") statusParam := c.Query("status") statusFilter := parseStatusFilter(statusParam) idSort, _ := strconv.ParseBool(c.Query("id_sort")) enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode")) // 供应商复用原渠道搜索接口:仅查询自己渠道。 if c.GetInt("role") < common.RoleAdminUser { channelID, _ := model.ParseSupplierChannelIDFilter(keyword) filter := model.SupplierChannelSearchFilter{ ChannelID: channelID, Keyword: keyword, Supplier: supplierKeyword, ModelKeyword: modelKeyword, Group: group, } ownerUserID := c.GetInt("id") channelData, total, err := model.SearchSupplierChannels(&ownerUserID, 0, 100000, filter) if err != nil { common.ApiError(c, err) return } if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 { filtered := make([]*model.Channel, 0, len(channelData)) for _, ch := range channelData { if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled { continue } if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled { continue } filtered = append(filtered, ch) } channelData = filtered total = int64(len(filtered)) } typeParam := c.Query("type") typeFilter := -1 if typeParam != "" { if tp, err := strconv.Atoi(typeParam); err == nil { typeFilter = tp } } if typeFilter >= 0 { filtered := make([]*model.Channel, 0, len(channelData)) for _, ch := range channelData { if ch.Type == typeFilter { filtered = append(filtered, ch) } } channelData = filtered total = int64(len(filtered)) } typeCounts := make(map[int64]int64) for _, channel := range channelData { typeCounts[int64(channel.Type)]++ } page, _ := strconv.Atoi(c.DefaultQuery("p", "1")) pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) if page < 1 { page = 1 } if pageSize <= 0 { pageSize = 20 } startIdx := (page - 1) * pageSize if startIdx > len(channelData) { startIdx = len(channelData) } endIdx := startIdx + pageSize if endIdx > len(channelData) { endIdx = len(channelData) } pagedData := channelData[startIdx:endIdx] for _, datum := range pagedData { clearChannelInfo(datum) } attachSupplierNames(pagedData) c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": gin.H{ "items": pagedData, "total": total, "type_counts": typeCounts, }, }) return } channelData := make([]*model.Channel, 0) if enableTagMode { tags, err := model.SearchTags(keyword, group, modelKeyword, idSort) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } for _, tag := range tags { if tag != nil && *tag != "" { tagChannel, err := model.GetChannelsByTag(*tag, idSort, false) if err == nil { channelData = append(channelData, tagChannel...) } } } } else { channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } channelData = channels } attachSupplierNames(channelData) if supplierKeyword != "" { filteredBySupplier := make([]*model.Channel, 0, len(channelData)) for _, ch := range channelData { if strings.Contains(strings.ToLower(ch.SupplierName), strings.ToLower(supplierKeyword)) { filteredBySupplier = append(filteredBySupplier, ch) } } channelData = filteredBySupplier } if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 { filtered := make([]*model.Channel, 0, len(channelData)) for _, ch := range channelData { if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled { continue } if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled { continue } filtered = append(filtered, ch) } channelData = filtered } // calculate type counts for search results typeCounts := make(map[int64]int64) for _, channel := range channelData { typeCounts[int64(channel.Type)]++ } typeParam := c.Query("type") typeFilter := -1 if typeParam != "" { if tp, err := strconv.Atoi(typeParam); err == nil { typeFilter = tp } } if typeFilter >= 0 { filtered := make([]*model.Channel, 0, len(channelData)) for _, ch := range channelData { if ch.Type == typeFilter { filtered = append(filtered, ch) } } channelData = filtered } page, _ := strconv.Atoi(c.DefaultQuery("p", "1")) pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) if page < 1 { page = 1 } if pageSize <= 0 { pageSize = 20 } total := len(channelData) startIdx := (page - 1) * pageSize if startIdx > total { startIdx = total } endIdx := startIdx + pageSize if endIdx > total { endIdx = total } pagedData := channelData[startIdx:endIdx] for _, datum := range pagedData { clearChannelInfo(datum) } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": gin.H{ "items": pagedData, "total": total, "type_counts": typeCounts, }, }) return } func GetChannel(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { common.ApiError(c, err) return } channel, err := model.GetChannelById(id, false) if err != nil { common.ApiError(c, err) return } // 供应商仅允许查看自己归属的渠道。 if c.GetInt("role") < common.RoleAdminUser && channel.OwnerUserID != c.GetInt("id") { c.JSON(http.StatusForbidden, gin.H{ "success": false, "message": "无权访问其他供应商渠道", }) return } if channel != nil { clearChannelInfo(channel) } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": channel, }) return } // GetChannelKey 获取渠道密钥(需要通过安全验证中间件) // 此函数依赖 SecureVerificationRequired 中间件,确保用户已通过安全验证 func GetChannelKey(c *gin.Context) { userId := c.GetInt("id") channelId, err := strconv.Atoi(c.Param("id")) if err != nil { common.ApiError(c, fmt.Errorf("渠道ID格式错误: %v", err)) return } // 获取渠道信息(包含密钥) channel, err := model.GetChannelById(channelId, true) if err != nil { common.ApiError(c, fmt.Errorf("获取渠道信息失败: %v", err)) return } if channel == nil { common.ApiError(c, fmt.Errorf("渠道不存在")) return } // 记录操作日志 model.RecordLog(userId, model.LogTypeSystem, fmt.Sprintf("查看渠道密钥信息 (渠道ID: %d)", channelId)) // 返回渠道密钥 c.JSON(http.StatusOK, gin.H{ "success": true, "message": "获取成功", "data": map[string]interface{}{ "key": channel.Key, }, }) } // validateTwoFactorAuth 统一的2FA验证函数 func validateTwoFactorAuth(twoFA *model.TwoFA, code string) bool { // 尝试验证TOTP if cleanCode, err := common.ValidateNumericCode(code); err == nil { if isValid, _ := twoFA.ValidateTOTPAndUpdateUsage(cleanCode); isValid { return true } } // 尝试验证备用码 if isValid, err := twoFA.ValidateBackupCodeAndUpdateUsage(code); err == nil && isValid { return true } return false } // validateChannel 通用的渠道校验函数 func validateChannel(channel *model.Channel, isAdd bool) error { if channel == nil { return fmt.Errorf("channel cannot be empty") } channel.CompanyLogoURL = strings.TrimSpace(channel.CompanyLogoURL) channel.SupplierType = strings.TrimSpace(channel.SupplierType) // TokenFactoryOpen (type=60) 渠道的 supplier_type 由上游同步继承,创建时允许为空 if channel.SupplierType == "" && channel.Type != constant.ChannelTypeTokenFactoryOpen { return fmt.Errorf("供应商类型不能为空") } if channel.SupplierType != "" && !isValidChannelSupplierType(channel.SupplierType) { return fmt.Errorf("供应商类型无效") } // 校验 channel settings if err := channel.ValidateSettings(); err != nil { return fmt.Errorf("渠道额外设置[channel setting] 格式错误:%s", err.Error()) } // 如果是添加操作,检查 channel 和 key 是否为空 if isAdd { if channel.Key == "" { return fmt.Errorf("channel cannot be empty") } // 检查模型名称长度是否超过 255 for _, m := range channel.GetModels() { if len(m) > 255 { return fmt.Errorf("模型名称过长: %s", m) } } } // VertexAI 特殊校验 if channel.Type == constant.ChannelTypeVertexAi { if channel.Other == "" { return fmt.Errorf("部署地区不能为空") } regionMap, err := common.StrToMap(channel.Other) if err != nil { return fmt.Errorf("部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}") } if regionMap["default"] == nil { return fmt.Errorf("部署地区必须包含default字段") } } // Codex OAuth key validation (optional, only when JSON object is provided) if channel.Type == constant.ChannelTypeCodex { trimmedKey := strings.TrimSpace(channel.Key) if isAdd || trimmedKey != "" { if !strings.HasPrefix(trimmedKey, "{") { return fmt.Errorf("Codex key must be a valid JSON object") } var keyMap map[string]any if err := common.Unmarshal([]byte(trimmedKey), &keyMap); err != nil { return fmt.Errorf("Codex key must be a valid JSON object") } if v, ok := keyMap["access_token"]; !ok || v == nil || strings.TrimSpace(fmt.Sprintf("%v", v)) == "" { return fmt.Errorf("Codex key JSON must include access_token") } if v, ok := keyMap["account_id"]; !ok || v == nil || strings.TrimSpace(fmt.Sprintf("%v", v)) == "" { return fmt.Errorf("Codex key JSON must include account_id") } } } if channel != nil && channel.PriceDiscountPercent != nil { v := *channel.PriceDiscountPercent if v < 0 || v > 1000 { return fmt.Errorf("价格折扣(百分比)须介于 0 与 1000 之间,100 表示无折扣,60 表示按原价 60%% 计费") } } if rs := strings.TrimSpace(channel.RouteSlug); rs != "" && !model.IsValidRouteSlug(rs) { return fmt.Errorf("route_slug 格式无效(2~32 位字母数字,且不能为 c 加纯数字)") } return nil } func RefreshCodexChannelCredential(c *gin.Context) { channelId, err := strconv.Atoi(c.Param("id")) if err != nil { common.ApiError(c, fmt.Errorf("invalid channel id: %w", err)) return } ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Second) defer cancel() oauthKey, ch, err := service.RefreshCodexChannelCredential(ctx, channelId, service.CodexCredentialRefreshOptions{ResetCaches: true}) if err != nil { common.SysError("failed to refresh codex channel credential: " + err.Error()) c.JSON(http.StatusOK, gin.H{"success": false, "message": "刷新凭证失败,请稍后重试"}) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "refreshed", "data": gin.H{ "expires_at": oauthKey.Expired, "last_refresh": oauthKey.LastRefresh, "account_id": oauthKey.AccountID, "email": oauthKey.Email, "channel_id": ch.Id, "channel_type": ch.Type, "channel_name": ch.Name, }, }) } type AddChannelRequest struct { Mode string `json:"mode"` MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` BatchAddSetKeyPrefix2Name bool `json:"batch_add_set_key_prefix_2_name"` Channel *model.Channel `json:"channel"` } // applySupplierChannelOwnershipForCreate 在供应商创建渠道时强制写入归属信息,防止越权伪造 owner 字段。 func applySupplierChannelOwnershipForCreate(c *gin.Context, channel *model.Channel) error { if c.GetInt("role") >= common.RoleAdminUser { return nil } app, err := model.GetApprovedSupplierApplicationByApplicant(c.GetInt("id")) if err != nil { return err } channel.OwnerUserID = c.GetInt("id") channel.SupplierApplicationID = app.ID return nil } // validateSupplierChannelOwnershipForUpdate 校验供应商仅可更新自己的渠道,管理员不受限制。 func validateSupplierChannelOwnershipForUpdate(c *gin.Context, originChannel *model.Channel) bool { if c.GetInt("role") >= common.RoleAdminUser { return true } return originChannel.OwnerUserID == c.GetInt("id") } func getVertexArrayKeys(keys string) ([]string, error) { if keys == "" { return nil, nil } var keyArray []interface{} err := common.Unmarshal([]byte(keys), &keyArray) if err != nil { return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err) } cleanKeys := make([]string, 0, len(keyArray)) for _, key := range keyArray { var keyStr string switch v := key.(type) { case string: keyStr = strings.TrimSpace(v) default: bytes, err := json.Marshal(v) if err != nil { return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err) } keyStr = string(bytes) } if keyStr != "" { cleanKeys = append(cleanKeys, keyStr) } } if len(cleanKeys) == 0 { return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空") } return cleanKeys, nil } type upstreamChannelSyncItem struct { ID int `json:"id"` Name string `json:"name"` Models string `json:"models"` Group string `json:"group"` Status int `json:"status"` Type int `json:"type"` ChannelNo string `json:"channel_no"` RouteSlug string `json:"route_slug"` SupplierApplication int `json:"supplier_application_id"` SupplierAlias string `json:"supplier_alias"` SupplierType string `json:"supplier_type"` CompanyLogoURL string `json:"company_logo_url"` PriceDiscountPercent float64 `json:"price_discount_percent"` MarkupDiscountRate float64 `json:"markup_discount_rate"` ModelMapping string `json:"model_mapping"` ModelPrice map[string]float64 `json:"model_price"` ModelRatio map[string]float64 `json:"model_ratio"` } func decodeUpstreamModelMapping(m map[string]any) string { raw, ok := m["model_mapping"] if !ok || raw == nil { return "" } switch x := raw.(type) { case string: return strings.TrimSpace(x) case map[string]any: b, err := json.Marshal(x) if err != nil { return "" } return strings.TrimSpace(string(b)) default: b, err := json.Marshal(raw) if err != nil { return strings.TrimSpace(common.Interface2String(raw)) } return strings.TrimSpace(string(b)) } } func isTokenFactoryOpenBaseURL(raw string) bool { parsed, err := url.Parse(strings.TrimSpace(raw)) if err != nil { return false } scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme)) if scheme != "http" && scheme != "https" { return false } return strings.TrimSpace(parsed.Hostname()) != "" } func isLikelyTokenFactoryStatusData(data map[string]any, systemName string) bool { name := strings.ToLower(strings.TrimSpace(systemName)) if strings.Contains(name, "tokenfactory") || strings.Contains(name, "词元工厂") || strings.Contains(name, "开放词元工厂") { return true } score := 0 if strings.TrimSpace(common.Interface2String(data["version"])) != "" { score++ } if startTimeRaw := strings.TrimSpace(common.Interface2String(data["start_time"])); startTimeRaw != "" { if startTime, err := strconv.ParseInt(startTimeRaw, 10, 64); err == nil && startTime > 0 { score++ } } if strings.TrimSpace(common.Interface2String(data["quota_display_type"])) != "" || strings.TrimSpace(common.Interface2String(data["quota_per_unit"])) != "" { score++ } if _, ok := data["enable_drawing"]; ok { score++ } if _, ok := data["enable_task"]; ok { score++ } if _, ok := data["system_name"]; ok { score++ } // 命中特征达到阈值即视为 TokenFactory 平台实例,避免只依赖 system_name 英文名。 return score >= 4 } func fetchTokenFactoryStatus(baseURL string, key string) error { client := &http.Client{Timeout: 10 * time.Second} u := strings.TrimRight(strings.TrimSpace(baseURL), "/") + "/api/status" req, err := http.NewRequest("GET", u, nil) if err != nil { return err } req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(key)) resp, err := client.Do(req) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("status code %d", resp.StatusCode) } var payload map[string]any if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { return err } if success, ok := payload["success"].(bool); ok && !success { return fmt.Errorf("status 接口返回失败") } var systemName string var statusData map[string]any if parsedData, ok := payload["data"].(map[string]any); ok { statusData = parsedData systemName = strings.TrimSpace(common.Interface2String(statusData["system_name"])) } if statusData == nil { return fmt.Errorf("status 返回结构缺少 data") } if !isLikelyTokenFactoryStatusData(statusData, systemName) { return fmt.Errorf("status 特征不匹配 TokenFactory 平台(system_name=%s)", systemName) } return nil } func decodeUpstreamChannelPayload(payload map[string]any, itemsKey string) ([]upstreamChannelSyncItem, error) { successRaw, exists := payload["success"] if !exists { return nil, fmt.Errorf("上游响应缺少 success 字段") } success, ok := successRaw.(bool) if !ok { return nil, fmt.Errorf("上游 success 字段类型异常: %T", successRaw) } if !success { upstreamMessage := strings.TrimSpace(common.Interface2String(payload["message"])) if upstreamMessage == "" { upstreamMessage = "上游返回失败(message 为空)" } return nil, fmt.Errorf("%s", upstreamMessage) } data, _ := payload["data"].(map[string]any) if data == nil { return nil, fmt.Errorf("上游响应缺少 data") } rawItems, _ := data[itemsKey].([]any) items := make([]upstreamChannelSyncItem, 0, len(rawItems)) for _, raw := range rawItems { m, ok := raw.(map[string]any) if !ok { continue } item := upstreamChannelSyncItem{ ID: common.String2Int(common.Interface2String(m["id"])), Name: strings.TrimSpace(common.Interface2String(m["name"])), Models: strings.TrimSpace(common.Interface2String(m["models"])), Group: strings.TrimSpace(common.Interface2String(m["group"])), Status: common.String2Int(common.Interface2String(m["status"])), Type: common.String2Int(common.Interface2String(m["type"])), ChannelNo: strings.TrimSpace(common.Interface2String(m["channel_no"])), RouteSlug: strings.TrimSpace(common.Interface2String(m["route_slug"])), SupplierApplication: common.String2Int(common.Interface2String(m["supplier_application_id"])), SupplierAlias: strings.TrimSpace(common.Interface2String(m["supplier_alias"])), SupplierType: strings.TrimSpace(common.Interface2String(m["supplier_type"])), CompanyLogoURL: strings.TrimSpace(common.Interface2String(m["company_logo_url"])), } if v, ok := m["price_discount_percent"]; ok && v != nil { switch x := v.(type) { case float64: item.PriceDiscountPercent = x case json.Number: if f, err := x.Float64(); err == nil { item.PriceDiscountPercent = f } default: if f, err := strconv.ParseFloat(strings.TrimSpace(common.Interface2String(v)), 64); err == nil { item.PriceDiscountPercent = f } } } if v, ok := m["markup_discount_rate"]; ok && v != nil { switch x := v.(type) { case float64: item.MarkupDiscountRate = x case json.Number: if f, err := x.Float64(); err == nil { item.MarkupDiscountRate = f } default: if f, err := strconv.ParseFloat(strings.TrimSpace(common.Interface2String(v)), 64); err == nil { item.MarkupDiscountRate = f } } } if mp, ok := m["model_price"].(map[string]any); ok && len(mp) > 0 { item.ModelPrice = jsonAnyMapToFloatMap(mp) } if mr, ok := m["model_ratio"].(map[string]any); ok && len(mr) > 0 { item.ModelRatio = jsonAnyMapToFloatMap(mr) } item.ModelMapping = decodeUpstreamModelMapping(m) items = append(items, item) } return items, nil } func fetchTokenFactoryUpstreamChannelsExport(baseURL string, key string) ([]upstreamChannelSyncItem, error) { client := &http.Client{Timeout: 45 * time.Second} u := strings.TrimRight(strings.TrimSpace(baseURL), "/") + "/api/tf_open_sync/channels" req, err := http.NewRequest("GET", u, nil) if err != nil { return nil, err } k := strings.TrimSpace(key) req.Header.Set("Authorization", "Bearer "+k) req.Header.Set("X-TokenFactory-Open-Sync-Secret", k) resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode == http.StatusNotFound { return nil, errTfOpenExportNotFound } if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) return nil, fmt.Errorf("export 接口 status code %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) } var payload map[string]any if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { return nil, err } return decodeUpstreamChannelPayload(payload, "channels") } var errTfOpenExportNotFound = errors.New("tf_open_sync channels export not found") func jsonAnyMapToFloatMap(raw map[string]any) map[string]float64 { out := make(map[string]float64) for k, v := range raw { switch x := v.(type) { case float64: out[k] = x case json.Number: if f, err := x.Float64(); err == nil { out[k] = f } default: if f, err := strconv.ParseFloat(strings.TrimSpace(common.Interface2String(v)), 64); err == nil { out[k] = f } } } if len(out) == 0 { return nil } return out } func fetchTokenFactoryUpstreamChannelsLegacy(baseURL string, key string) ([]upstreamChannelSyncItem, error) { client := &http.Client{Timeout: 15 * time.Second} u := strings.TrimRight(strings.TrimSpace(baseURL), "/") + "/api/channel/?p=1&page_size=100000&id_sort=true" req, err := http.NewRequest("GET", u, nil) if err != nil { return nil, err } req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(key)) resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) return nil, fmt.Errorf("upstream status code %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) } var payload map[string]any if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { return nil, err } return decodeUpstreamChannelPayload(payload, "items") } func fetchTokenFactoryUpstreamChannels(baseURL string, key string) ([]upstreamChannelSyncItem, error) { items, err := fetchTokenFactoryUpstreamChannelsExport(baseURL, key) if err == nil && len(items) > 0 { return items, nil } if err != nil && !errors.Is(err, errTfOpenExportNotFound) { return nil, fmt.Errorf("拉取上游渠道(export): %w", err) } legacy, err2 := fetchTokenFactoryUpstreamChannelsLegacy(baseURL, key) if err2 != nil { return nil, fmt.Errorf("拉取上游渠道失败: %w", err2) } return legacy, nil } func tfOpenLocalChannelNo(up upstreamChannelSyncItem) string { // 留空让本地按既有逻辑分配 cN(按 supplier_application_id 递增)。 return "" } func buildTokenFactorySyncedChannels(base *model.Channel) ([]model.Channel, []model.TFOpenUpstreamPricing, error) { baseURL := base.GetBaseURL() if !isTokenFactoryOpenBaseURL(baseURL) { return nil, nil, fmt.Errorf("TokenFactoryOpen 渠道的 API 地址必须指向 TokenFactory 平台") } key := strings.TrimSpace(base.Key) if key == "" { return nil, nil, fmt.Errorf("TokenFactoryOpen 渠道密钥不能为空") } if err := fetchTokenFactoryStatus(baseURL, key); err != nil { return nil, nil, fmt.Errorf("TokenFactoryOpen 平台识别失败: %w", err) } upstreamChannels, err := fetchTokenFactoryUpstreamChannels(baseURL, key) if err != nil { return nil, nil, fmt.Errorf("拉取上游渠道失败: %w", err) } if len(upstreamChannels) == 0 { return nil, nil, fmt.Errorf("上游未返回可同步渠道") } now := common.GetTimestamp() result := make([]model.Channel, 0, len(upstreamChannels)) pricing := make([]model.TFOpenUpstreamPricing, 0, len(upstreamChannels)) for i, upstream := range upstreamChannels { clone := *base clone.Id = 0 clone.CreatedTime = now if upstream.Type > 0 { clone.Type = constant.ChannelTypeTokenFactoryOpen } else { clone.Type = constant.ChannelTypeTokenFactoryOpen } // 直接使用上游渠道名称,不再拼接序号后缀 upstreamName := strings.TrimSpace(upstream.Name) seqIdx := model.EncodeBase62(int64(i)) if upstreamName != "" { clone.Name = upstreamName } else { clone.Name = fmt.Sprintf("upstream-%s", seqIdx) } clone.Models = strings.TrimSpace(upstream.Models) if strings.TrimSpace(upstream.Group) != "" { clone.Group = strings.TrimSpace(upstream.Group) } if upstream.Status > 0 { clone.Status = upstream.Status } upstreamSupplierType := strings.TrimSpace(upstream.SupplierType) if upstreamSupplierType != "" && isValidChannelSupplierType(upstreamSupplierType) { clone.SupplierType = upstreamSupplierType } else if strings.TrimSpace(clone.SupplierType) == "" || !isValidChannelSupplierType(strings.TrimSpace(clone.SupplierType)) { clone.SupplierType = defaultChannelSupplierType } upstreamLogoURL := strings.TrimSpace(upstream.CompanyLogoURL) if upstreamLogoURL != "" { clone.CompanyLogoURL = upstreamLogoURL } if upstream.PriceDiscountPercent > 0 { clone.PriceDiscountPercent = &upstream.PriceDiscountPercent } mm := strings.TrimSpace(upstream.ModelMapping) if mm != "" { clone.ModelMapping = &mm } else { clone.ModelMapping = nil } clone.ChannelNo = tfOpenLocalChannelNo(upstream) clone.RouteSlug = "" syncMeta := map[string]any{ "source": "tokenfactory_open", "upstream_channel_id": upstream.ID, "upstream_channel_no": strings.TrimSpace(upstream.ChannelNo), "upstream_route_slug": strings.TrimSpace(upstream.RouteSlug), "upstream_supplier_app_id": upstream.SupplierApplication, "upstream_supplier_alias": strings.TrimSpace(upstream.SupplierAlias), "upstream_channel_type": upstream.Type, "local_channel_no": clone.ChannelNo, "sync_seq_index": seqIdx, // 本次同步批次内的 base-62 顺序编号 "synced_at": now, } metaJSON, _ := common.Marshal(syncMeta) clone.OtherInfo = string(metaJSON) result = append(result, clone) pricing = append(pricing, model.TFOpenUpstreamPricing{ ModelPrice: upstream.ModelPrice, ModelRatio: upstream.ModelRatio, }) } return result, pricing, nil } func AddChannel(c *gin.Context) { addChannelRequest := AddChannelRequest{} err := c.ShouldBindJSON(&addChannelRequest) if err != nil { common.ApiError(c, err) return } // 使用统一的校验函数 if err := validateChannel(addChannelRequest.Channel, true); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } if addChannelRequest.Channel != nil && addChannelRequest.Channel.PriceDiscountPercent == nil { v := 100.0 addChannelRequest.Channel.PriceDiscountPercent = &v } if err := applySupplierChannelOwnershipForCreate(c, addChannelRequest.Channel); err != nil { c.JSON(http.StatusForbidden, gin.H{ "success": false, "message": "当前用户未通过供应商审核,无权创建渠道", }) return } addChannelRequest.Channel.CreatedTime = common.GetTimestamp() keys := make([]string, 0) switch addChannelRequest.Mode { case "multi_to_single": addChannelRequest.Channel.ChannelInfo.IsMultiKey = true addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey { array, err := getVertexArrayKeys(addChannelRequest.Channel.Key) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array) addChannelRequest.Channel.Key = strings.Join(array, "\n") } else { cleanKeys := make([]string, 0) for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") { if key == "" { continue } key = strings.TrimSpace(key) cleanKeys = append(cleanKeys, key) } addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys) addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n") } keys = []string{addChannelRequest.Channel.Key} case "batch": if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey { // multi json keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } } else { keys = strings.Split(addChannelRequest.Channel.Key, "\n") } case "single": keys = []string{addChannelRequest.Channel.Key} default: c.JSON(http.StatusOK, gin.H{ "success": false, "message": "不支持的添加模式", }) return } channels := make([]model.Channel, 0, len(keys)) for _, key := range keys { if key == "" { continue } localChannel := addChannelRequest.Channel localChannel.Key = key if addChannelRequest.BatchAddSetKeyPrefix2Name && len(keys) > 1 { keyPrefix := localChannel.Key if len(localChannel.Key) > 8 { keyPrefix = localChannel.Key[:8] } localChannel.Name = fmt.Sprintf("%s %s", localChannel.Name, keyPrefix) } channels = append(channels, *localChannel) } var tfOpenPricing []model.TFOpenUpstreamPricing if addChannelRequest.Channel.Type == constant.ChannelTypeTokenFactoryOpen { syncBase := *addChannelRequest.Channel if len(channels) > 0 { syncBase.Key = strings.TrimSpace(channels[0].Key) } syncedChannels, pricing, syncErr := buildTokenFactorySyncedChannels(&syncBase) if syncErr != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": syncErr.Error(), }) return } channels = syncedChannels tfOpenPricing = pricing } if len(channels) > 1 { for i := range channels { channels[i].RouteSlug = "" } } if addChannelRequest.Channel.Type == constant.ChannelTypeTokenFactoryOpen { err = model.BatchInsertChannelsWithTfOpenUpstreamPricing(channels, tfOpenPricing) } else { err = model.BatchInsertChannels(channels) } if err != nil { common.ApiError(c, err) return } service.ResetProxyClientCache() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", }) return } func DeleteChannel(c *gin.Context) { id, _ := strconv.Atoi(c.Param("id")) channel := model.Channel{Id: id} err := channel.Delete() if err != nil { common.ApiError(c, err) return } model.InitChannelCache() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", }) return } func DeleteDisabledChannel(c *gin.Context) { rows, err := model.DeleteDisabledChannel() if err != nil { common.ApiError(c, err) return } model.InitChannelCache() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": rows, }) return } type ChannelTag struct { Tag string `json:"tag"` NewTag *string `json:"new_tag"` Priority *int64 `json:"priority"` Weight *uint `json:"weight"` ModelMapping *string `json:"model_mapping"` Models *string `json:"models"` Groups *string `json:"groups"` ParamOverride *string `json:"param_override"` HeaderOverride *string `json:"header_override"` } func DisableTagChannels(c *gin.Context) { channelTag := ChannelTag{} err := c.ShouldBindJSON(&channelTag) if err != nil || channelTag.Tag == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "参数错误", }) return } err = model.DisableChannelByTag(channelTag.Tag) if err != nil { common.ApiError(c, err) return } model.InitChannelCache() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", }) return } func EnableTagChannels(c *gin.Context) { channelTag := ChannelTag{} err := c.ShouldBindJSON(&channelTag) if err != nil || channelTag.Tag == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "参数错误", }) return } err = model.EnableChannelByTag(channelTag.Tag) if err != nil { common.ApiError(c, err) return } model.InitChannelCache() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", }) return } func EditTagChannels(c *gin.Context) { channelTag := ChannelTag{} err := c.ShouldBindJSON(&channelTag) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "参数错误", }) return } if channelTag.Tag == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "tag不能为空", }) return } if channelTag.ParamOverride != nil { trimmed := strings.TrimSpace(*channelTag.ParamOverride) if trimmed != "" && !json.Valid([]byte(trimmed)) { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "参数覆盖必须是合法的 JSON 格式", }) return } channelTag.ParamOverride = common.GetPointer[string](trimmed) } if channelTag.HeaderOverride != nil { trimmed := strings.TrimSpace(*channelTag.HeaderOverride) if trimmed != "" && !json.Valid([]byte(trimmed)) { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "请求头覆盖必须是合法的 JSON 格式", }) return } channelTag.HeaderOverride = common.GetPointer[string](trimmed) } err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight, channelTag.ParamOverride, channelTag.HeaderOverride) if err != nil { common.ApiError(c, err) return } model.InitChannelCache() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", }) return } type ChannelBatch struct { Ids []int `json:"ids"` Tag *string `json:"tag"` } func DeleteChannelBatch(c *gin.Context) { channelBatch := ChannelBatch{} err := c.ShouldBindJSON(&channelBatch) if err != nil || len(channelBatch.Ids) == 0 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "参数错误", }) return } err = model.BatchDeleteChannels(channelBatch.Ids) if err != nil { common.ApiError(c, err) return } model.InitChannelCache() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": len(channelBatch.Ids), }) return } type PatchChannel struct { model.Channel MultiKeyMode *string `json:"multi_key_mode"` KeyMode *string `json:"key_mode"` // 多key模式下密钥覆盖或者追加 } func UpdateChannel(c *gin.Context) { channel := PatchChannel{} err := c.ShouldBindJSON(&channel) if err != nil { common.ApiError(c, err) return } if channel.Id <= 0 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "渠道ID无效", }) return } // Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request. originChannel, err := model.GetChannelById(channel.Id, true) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } if !validateSupplierChannelOwnershipForUpdate(c, originChannel) { c.JSON(http.StatusForbidden, gin.H{ "success": false, "message": "无权修改其他供应商渠道", }) return } oldBalance := originChannel.Balance // 部分更新(如仅改状态/优先级/权重):请求未带供应商类型时沿用库中值,否则 validateChannel 会因零值失败。 if strings.TrimSpace(channel.SupplierType) == "" { channel.SupplierType = strings.TrimSpace(originChannel.SupplierType) } if strings.TrimSpace(channel.SupplierType) == "" && originChannel.SupplierApplicationID > 0 { var app model.SupplierApplication if err := model.DB.Select("supplier_type").Where("id = ?", originChannel.SupplierApplicationID).First(&app).Error; err == nil { channel.SupplierType = strings.TrimSpace(app.SupplierType) } } if strings.TrimSpace(channel.SupplierType) == "" { channel.SupplierType = defaultChannelSupplierType } // route_slug:空则沿用库中值;非空变更时校验格式与全局唯一。 if strings.TrimSpace(channel.RouteSlug) == "" { channel.RouteSlug = originChannel.RouteSlug } else { channel.RouteSlug = strings.TrimSpace(channel.RouteSlug) if channel.RouteSlug != originChannel.RouteSlug { if !model.IsValidRouteSlug(channel.RouteSlug) { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "route_slug 格式无效(2~32 位字母数字,且不能为 c 加纯数字)", }) return } var cnt int64 if err := model.DB.Model(&model.Channel{}).Where("route_slug = ? AND id <> ?", channel.RouteSlug, channel.Id).Count(&cnt).Error; err != nil { common.ApiError(c, err) return } if cnt > 0 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "route_slug 已被其他渠道占用", }) return } } } // 使用统一的校验函数 if err := validateChannel(&channel.Channel, false); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } // 供应商更新时强制保持归属信息不变,防止通过请求体篡改 owner/supplier 关联。 if c.GetInt("role") < common.RoleAdminUser { channel.OwnerUserID = originChannel.OwnerUserID channel.SupplierApplicationID = originChannel.SupplierApplicationID } // Always copy the original ChannelInfo so that fields like IsMultiKey and MultiKeySize are retained. channel.ChannelInfo = originChannel.ChannelInfo // If the request explicitly specifies a new MultiKeyMode, apply it on top of the original info. if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" { channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode) } // 处理多key模式下的密钥追加/覆盖逻辑 if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey { switch *channel.KeyMode { case "append": // 追加模式:将新密钥添加到现有密钥列表 if originChannel.Key != "" { var newKeys []string var existingKeys []string // 解析现有密钥 if strings.HasPrefix(strings.TrimSpace(originChannel.Key), "[") { // JSON数组格式 var arr []json.RawMessage if err := json.Unmarshal([]byte(strings.TrimSpace(originChannel.Key)), &arr); err == nil { existingKeys = make([]string, len(arr)) for i, v := range arr { existingKeys[i] = string(v) } } } else { // 换行分隔格式 existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n") } // 处理 Vertex AI 的特殊情况 if channel.Type == constant.ChannelTypeVertexAi && channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey { // 尝试解析新密钥为JSON数组 if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") { array, err := getVertexArrayKeys(channel.Key) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "追加密钥解析失败: " + err.Error(), }) return } newKeys = array } else { // 单个JSON密钥 newKeys = []string{channel.Key} } } else { // 普通渠道的处理 inputKeys := strings.Split(channel.Key, "\n") for _, key := range inputKeys { key = strings.TrimSpace(key) if key != "" { newKeys = append(newKeys, key) } } } seen := make(map[string]struct{}, len(existingKeys)+len(newKeys)) for _, key := range existingKeys { normalized := strings.TrimSpace(key) if normalized == "" { continue } seen[normalized] = struct{}{} } dedupedNewKeys := make([]string, 0, len(newKeys)) for _, key := range newKeys { normalized := strings.TrimSpace(key) if normalized == "" { continue } if _, ok := seen[normalized]; ok { continue } seen[normalized] = struct{}{} dedupedNewKeys = append(dedupedNewKeys, normalized) } allKeys := append(existingKeys, dedupedNewKeys...) channel.Key = strings.Join(allKeys, "\n") } case "replace": // 覆盖模式:直接使用新密钥(默认行为,不需要特殊处理) } } err = channel.Update() if err != nil { common.ApiError(c, err) return } notifyChannelBalanceAlertIfNeeded(originChannel, oldBalance, channel.Balance) model.InitChannelCache() service.ResetProxyClientCache() channel.Key = "" clearChannelInfo(&channel.Channel) c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": channel, }) return } func FetchModels(c *gin.Context) { var req struct { BaseURL string `json:"base_url"` Type int `json:"type"` Key string `json:"key"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ "success": false, "message": "Invalid request", }) return } baseURL := req.BaseURL if baseURL == "" { baseURL = constant.ChannelBaseURLs[req.Type] } // remove line breaks and extra spaces. key := strings.TrimSpace(req.Key) key = strings.Split(key, "\n")[0] if req.Type == constant.ChannelTypeOllama { models, err := ollama.FetchOllamaModels(c.Request.Context(), baseURL, key) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": fmt.Sprintf("获取Ollama模型失败: %s", err.Error()), }) return } names := make([]string, 0, len(models)) for _, modelInfo := range models { names = append(names, modelInfo.Name) } c.JSON(http.StatusOK, gin.H{ "success": true, "data": names, }) return } if req.Type == constant.ChannelTypeGemini { models, err := gemini.FetchGeminiModels(c.Request.Context(), baseURL, key, "") if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": fmt.Sprintf("获取Gemini模型失败: %s", err.Error()), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "data": models, }) return } client := &http.Client{} url := fmt.Sprintf("%s/v1/models", baseURL) request, err := http.NewRequest("GET", url, nil) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "success": false, "message": err.Error(), }) return } request.Header.Set("Authorization", "Bearer "+key) response, err := client.Do(request) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "success": false, "message": err.Error(), }) return } //check status code if response.StatusCode != http.StatusOK { c.JSON(http.StatusInternalServerError, gin.H{ "success": false, "message": "Failed to fetch models", }) return } defer response.Body.Close() var result struct { Data []struct { ID string `json:"id"` } `json:"data"` } if err := json.NewDecoder(response.Body).Decode(&result); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "success": false, "message": err.Error(), }) return } var models []string for _, model := range result.Data { models = append(models, model.ID) } c.JSON(http.StatusOK, gin.H{ "success": true, "data": models, }) } func BatchSetChannelTag(c *gin.Context) { channelBatch := ChannelBatch{} err := c.ShouldBindJSON(&channelBatch) if err != nil || len(channelBatch.Ids) == 0 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "参数错误", }) return } err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag) if err != nil { common.ApiError(c, err) return } model.InitChannelCache() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": len(channelBatch.Ids), }) return } func GetTagModels(c *gin.Context) { tag := c.Query("tag") if tag == "" { c.JSON(http.StatusBadRequest, gin.H{ "success": false, "message": "tag不能为空", }) return } channels, err := model.GetChannelsByTag(tag, false, false) // idSort=false, selectAll=false if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "success": false, "message": err.Error(), }) return } var longestModels string maxLength := 0 // Find the longest models string among all channels with the given tag for _, channel := range channels { if channel.Models != "" { currentModels := strings.Split(channel.Models, ",") if len(currentModels) > maxLength { maxLength = len(currentModels) longestModels = channel.Models } } } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": longestModels, }) return } // CopyChannel handles cloning an existing channel with its key. // POST /api/channel/copy/:id // Optional query params: // // suffix - string appended to the original name (default "_复制") // reset_balance - bool, when true will reset balance & used_quota to 0 (default true) func CopyChannel(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { c.JSON(http.StatusOK, gin.H{"success": false, "message": "invalid id"}) return } suffix := c.DefaultQuery("suffix", "_复制") resetBalance := true if rbStr := c.DefaultQuery("reset_balance", "true"); rbStr != "" { if v, err := strconv.ParseBool(rbStr); err == nil { resetBalance = v } } // fetch original channel with key origin, err := model.GetChannelById(id, true) if err != nil { common.SysError("failed to get channel by id: " + err.Error()) c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道信息失败,请稍后重试"}) return } // clone channel clone := *origin // shallow copy is sufficient as we will overwrite primitives clone.Id = 0 // let DB auto-generate clone.CreatedTime = common.GetTimestamp() clone.Name = origin.Name + suffix clone.TestTime = 0 clone.ResponseTime = 0 if resetBalance { clone.Balance = 0 clone.UsedQuota = 0 } clone.ChannelNo = "" clone.RouteSlug = "" // insert if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil { common.SysError("failed to clone channel: " + err.Error()) c.JSON(http.StatusOK, gin.H{"success": false, "message": "复制渠道失败,请稍后重试"}) return } model.InitChannelCache() // success c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}}) } // MultiKeyManageRequest represents the request for multi-key management operations type MultiKeyManageRequest struct { ChannelId int `json:"channel_id"` Action string `json:"action"` // "disable_key", "enable_key", "delete_key", "delete_disabled_keys", "get_key_status" KeyIndex *int `json:"key_index,omitempty"` // for disable_key, enable_key, and delete_key actions Page int `json:"page,omitempty"` // for get_key_status pagination PageSize int `json:"page_size,omitempty"` // for get_key_status pagination Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all } // MultiKeyStatusResponse represents the response for key status query type MultiKeyStatusResponse struct { Keys []KeyStatus `json:"keys"` Total int `json:"total"` Page int `json:"page"` PageSize int `json:"page_size"` TotalPages int `json:"total_pages"` // Statistics EnabledCount int `json:"enabled_count"` ManualDisabledCount int `json:"manual_disabled_count"` AutoDisabledCount int `json:"auto_disabled_count"` } type KeyStatus struct { Index int `json:"index"` Status int `json:"status"` // 1: enabled, 2: disabled DisabledTime int64 `json:"disabled_time,omitempty"` Reason string `json:"reason,omitempty"` KeyPreview string `json:"key_preview"` // first 10 chars of key for identification } // ManageMultiKeys handles multi-key management operations func ManageMultiKeys(c *gin.Context) { request := MultiKeyManageRequest{} err := c.ShouldBindJSON(&request) if err != nil { common.ApiError(c, err) return } channel, err := model.GetChannelById(request.ChannelId, true) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "渠道不存在", }) return } if !channel.ChannelInfo.IsMultiKey { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "该渠道不是多密钥模式", }) return } lock := model.GetChannelPollingLock(channel.Id) lock.Lock() defer lock.Unlock() switch request.Action { case "get_key_status": keys := channel.GetKeys() // Default pagination parameters page := request.Page pageSize := request.PageSize if page <= 0 { page = 1 } if pageSize <= 0 { pageSize = 50 // Default page size } // Statistics for all keys (unchanged by filtering) var enabledCount, manualDisabledCount, autoDisabledCount int // Build all key status data first var allKeyStatusList []KeyStatus for i, key := range keys { status := 1 // default enabled var disabledTime int64 var reason string if channel.ChannelInfo.MultiKeyStatusList != nil { if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists { status = s } } // Count for statistics (all keys) switch status { case 1: enabledCount++ case 2: manualDisabledCount++ case 3: autoDisabledCount++ } if status != 1 { if channel.ChannelInfo.MultiKeyDisabledTime != nil { disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i] } if channel.ChannelInfo.MultiKeyDisabledReason != nil { reason = channel.ChannelInfo.MultiKeyDisabledReason[i] } } // Create key preview (first 10 chars) keyPreview := key if len(key) > 10 { keyPreview = key[:10] + "..." } allKeyStatusList = append(allKeyStatusList, KeyStatus{ Index: i, Status: status, DisabledTime: disabledTime, Reason: reason, KeyPreview: keyPreview, }) } // Apply status filter if specified var filteredKeyStatusList []KeyStatus if request.Status != nil { for _, keyStatus := range allKeyStatusList { if keyStatus.Status == *request.Status { filteredKeyStatusList = append(filteredKeyStatusList, keyStatus) } } } else { filteredKeyStatusList = allKeyStatusList } // Calculate pagination based on filtered results filteredTotal := len(filteredKeyStatusList) totalPages := (filteredTotal + pageSize - 1) / pageSize if totalPages == 0 { totalPages = 1 } if page > totalPages { page = totalPages } // Calculate range for current page start := (page - 1) * pageSize end := start + pageSize if end > filteredTotal { end = filteredTotal } // Get the page data var pageKeyStatusList []KeyStatus if start < filteredTotal { pageKeyStatusList = filteredKeyStatusList[start:end] } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": MultiKeyStatusResponse{ Keys: pageKeyStatusList, Total: filteredTotal, // Total of filtered results Page: page, PageSize: pageSize, TotalPages: totalPages, EnabledCount: enabledCount, // Overall statistics ManualDisabledCount: manualDisabledCount, // Overall statistics AutoDisabledCount: autoDisabledCount, // Overall statistics }, }) return case "disable_key": if request.KeyIndex == nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "未指定要禁用的密钥索引", }) return } keyIndex := *request.KeyIndex if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "密钥索引超出范围", }) return } if channel.ChannelInfo.MultiKeyStatusList == nil { channel.ChannelInfo.MultiKeyStatusList = make(map[int]int) } if channel.ChannelInfo.MultiKeyDisabledTime == nil { channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64) } if channel.ChannelInfo.MultiKeyDisabledReason == nil { channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string) } channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // disabled err = channel.Update() if err != nil { common.ApiError(c, err) return } model.InitChannelCache() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "密钥已禁用", }) return case "enable_key": if request.KeyIndex == nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "未指定要启用的密钥索引", }) return } keyIndex := *request.KeyIndex if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "密钥索引超出范围", }) return } // 从状态列表中删除该密钥的记录,使其回到默认启用状态 if channel.ChannelInfo.MultiKeyStatusList != nil { delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex) } if channel.ChannelInfo.MultiKeyDisabledTime != nil { delete(channel.ChannelInfo.MultiKeyDisabledTime, keyIndex) } if channel.ChannelInfo.MultiKeyDisabledReason != nil { delete(channel.ChannelInfo.MultiKeyDisabledReason, keyIndex) } err = channel.Update() if err != nil { common.ApiError(c, err) return } model.InitChannelCache() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "密钥已启用", }) return case "enable_all_keys": // 清空所有禁用状态,使所有密钥回到默认启用状态 var enabledCount int if channel.ChannelInfo.MultiKeyStatusList != nil { enabledCount = len(channel.ChannelInfo.MultiKeyStatusList) } channel.ChannelInfo.MultiKeyStatusList = make(map[int]int) channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64) channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string) err = channel.Update() if err != nil { common.ApiError(c, err) return } model.InitChannelCache() c.JSON(http.StatusOK, gin.H{ "success": true, "message": fmt.Sprintf("已启用 %d 个密钥", enabledCount), }) return case "disable_all_keys": // 禁用所有启用的密钥 if channel.ChannelInfo.MultiKeyStatusList == nil { channel.ChannelInfo.MultiKeyStatusList = make(map[int]int) } if channel.ChannelInfo.MultiKeyDisabledTime == nil { channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64) } if channel.ChannelInfo.MultiKeyDisabledReason == nil { channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string) } var disabledCount int for i := 0; i < channel.ChannelInfo.MultiKeySize; i++ { status := 1 // default enabled if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists { status = s } // 只禁用当前启用的密钥 if status == 1 { channel.ChannelInfo.MultiKeyStatusList[i] = 2 // disabled disabledCount++ } } if disabledCount == 0 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "没有可禁用的密钥", }) return } err = channel.Update() if err != nil { common.ApiError(c, err) return } model.InitChannelCache() c.JSON(http.StatusOK, gin.H{ "success": true, "message": fmt.Sprintf("已禁用 %d 个密钥", disabledCount), }) return case "delete_key": if request.KeyIndex == nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "未指定要删除的密钥索引", }) return } keyIndex := *request.KeyIndex if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "密钥索引超出范围", }) return } keys := channel.GetKeys() var remainingKeys []string var newStatusList = make(map[int]int) var newDisabledTime = make(map[int]int64) var newDisabledReason = make(map[int]string) newIndex := 0 for i, key := range keys { // 跳过要删除的密钥 if i == keyIndex { continue } remainingKeys = append(remainingKeys, key) // 保留其他密钥的状态信息,重新索引 if channel.ChannelInfo.MultiKeyStatusList != nil { if status, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists && status != 1 { newStatusList[newIndex] = status } } if channel.ChannelInfo.MultiKeyDisabledTime != nil { if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists { newDisabledTime[newIndex] = t } } if channel.ChannelInfo.MultiKeyDisabledReason != nil { if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists { newDisabledReason[newIndex] = r } } newIndex++ } if len(remainingKeys) == 0 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "不能删除最后一个密钥", }) return } // Update channel with remaining keys channel.Key = strings.Join(remainingKeys, "\n") channel.ChannelInfo.MultiKeySize = len(remainingKeys) channel.ChannelInfo.MultiKeyStatusList = newStatusList channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason err = channel.Update() if err != nil { common.ApiError(c, err) return } model.InitChannelCache() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "密钥已删除", }) return case "delete_disabled_keys": keys := channel.GetKeys() var remainingKeys []string var deletedCount int var newStatusList = make(map[int]int) var newDisabledTime = make(map[int]int64) var newDisabledReason = make(map[int]string) newIndex := 0 for i, key := range keys { status := 1 // default enabled if channel.ChannelInfo.MultiKeyStatusList != nil { if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists { status = s } } // 只删除自动禁用(status == 3)的密钥,保留启用(status == 1)和手动禁用(status == 2)的密钥 if status == 3 { deletedCount++ } else { remainingKeys = append(remainingKeys, key) // 保留非自动禁用密钥的状态信息,重新索引 if status != 1 { newStatusList[newIndex] = status if channel.ChannelInfo.MultiKeyDisabledTime != nil { if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists { newDisabledTime[newIndex] = t } } if channel.ChannelInfo.MultiKeyDisabledReason != nil { if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists { newDisabledReason[newIndex] = r } } } newIndex++ } } if deletedCount == 0 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "没有需要删除的自动禁用密钥", }) return } // Update channel with remaining keys channel.Key = strings.Join(remainingKeys, "\n") channel.ChannelInfo.MultiKeySize = len(remainingKeys) channel.ChannelInfo.MultiKeyStatusList = newStatusList channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason err = channel.Update() if err != nil { common.ApiError(c, err) return } model.InitChannelCache() c.JSON(http.StatusOK, gin.H{ "success": true, "message": fmt.Sprintf("已删除 %d 个自动禁用的密钥", deletedCount), "data": deletedCount, }) return default: c.JSON(http.StatusOK, gin.H{ "success": false, "message": "不支持的操作", }) return } } // OllamaPullModel 拉取 Ollama 模型 func OllamaPullModel(c *gin.Context) { var req struct { ChannelID int `json:"channel_id"` ModelName string `json:"model_name"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ "success": false, "message": "Invalid request parameters", }) return } if req.ChannelID == 0 || req.ModelName == "" { c.JSON(http.StatusBadRequest, gin.H{ "success": false, "message": "Channel ID and model name are required", }) return } // 获取渠道信息 channel, err := model.GetChannelById(req.ChannelID, true) if err != nil { c.JSON(http.StatusNotFound, gin.H{ "success": false, "message": "Channel not found", }) return } // 检查是否是 Ollama 渠道 if channel.Type != constant.ChannelTypeOllama { c.JSON(http.StatusBadRequest, gin.H{ "success": false, "message": "This operation is only supported for Ollama channels", }) return } baseURL := constant.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } key := strings.Split(channel.Key, "\n")[0] err = ollama.PullOllamaModel(baseURL, key, req.ModelName) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "success": false, "message": fmt.Sprintf("Failed to pull model: %s", err.Error()), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": fmt.Sprintf("Model %s pulled successfully", req.ModelName), }) } // OllamaPullModelStream 流式拉取 Ollama 模型 func OllamaPullModelStream(c *gin.Context) { var req struct { ChannelID int `json:"channel_id"` ModelName string `json:"model_name"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ "success": false, "message": "Invalid request parameters", }) return } if req.ChannelID == 0 || req.ModelName == "" { c.JSON(http.StatusBadRequest, gin.H{ "success": false, "message": "Channel ID and model name are required", }) return } // 获取渠道信息 channel, err := model.GetChannelById(req.ChannelID, true) if err != nil { c.JSON(http.StatusNotFound, gin.H{ "success": false, "message": "Channel not found", }) return } // 检查是否是 Ollama 渠道 if channel.Type != constant.ChannelTypeOllama { c.JSON(http.StatusBadRequest, gin.H{ "success": false, "message": "This operation is only supported for Ollama channels", }) return } baseURL := constant.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } // 设置 SSE 头部 c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") c.Header("Access-Control-Allow-Origin", "*") key := strings.Split(channel.Key, "\n")[0] // 创建进度回调函数 progressCallback := func(progress ollama.OllamaPullResponse) { data, _ := json.Marshal(progress) fmt.Fprintf(c.Writer, "data: %s\n\n", string(data)) c.Writer.Flush() } // 执行拉取 err = ollama.PullOllamaModelStream(baseURL, key, req.ModelName, progressCallback) if err != nil { errorData, _ := json.Marshal(gin.H{ "error": err.Error(), }) fmt.Fprintf(c.Writer, "data: %s\n\n", string(errorData)) } else { successData, _ := json.Marshal(gin.H{ "message": fmt.Sprintf("Model %s pulled successfully", req.ModelName), }) fmt.Fprintf(c.Writer, "data: %s\n\n", string(successData)) } // 发送结束标志 fmt.Fprintf(c.Writer, "data: [DONE]\n\n") c.Writer.Flush() } // OllamaDeleteModel 删除 Ollama 模型 func OllamaDeleteModel(c *gin.Context) { var req struct { ChannelID int `json:"channel_id"` ModelName string `json:"model_name"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ "success": false, "message": "Invalid request parameters", }) return } if req.ChannelID == 0 || req.ModelName == "" { c.JSON(http.StatusBadRequest, gin.H{ "success": false, "message": "Channel ID and model name are required", }) return } // 获取渠道信息 channel, err := model.GetChannelById(req.ChannelID, true) if err != nil { c.JSON(http.StatusNotFound, gin.H{ "success": false, "message": "Channel not found", }) return } // 检查是否是 Ollama 渠道 if channel.Type != constant.ChannelTypeOllama { c.JSON(http.StatusBadRequest, gin.H{ "success": false, "message": "This operation is only supported for Ollama channels", }) return } baseURL := constant.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } key := strings.Split(channel.Key, "\n")[0] err = ollama.DeleteOllamaModel(baseURL, key, req.ModelName) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "success": false, "message": fmt.Sprintf("Failed to delete model: %s", err.Error()), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": fmt.Sprintf("Model %s deleted successfully", req.ModelName), }) } // OllamaVersion 获取 Ollama 服务版本信息 func OllamaVersion(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { c.JSON(http.StatusBadRequest, gin.H{ "success": false, "message": "Invalid channel id", }) return } channel, err := model.GetChannelById(id, true) if err != nil { c.JSON(http.StatusNotFound, gin.H{ "success": false, "message": "Channel not found", }) return } if channel.Type != constant.ChannelTypeOllama { c.JSON(http.StatusBadRequest, gin.H{ "success": false, "message": "This operation is only supported for Ollama channels", }) return } baseURL := constant.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } key := strings.Split(channel.Key, "\n")[0] version, err := ollama.FetchOllamaVersion(baseURL, key) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": fmt.Sprintf("获取Ollama版本失败: %s", err.Error()), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "data": gin.H{ "version": version, }, }) }