package controller import ( "net/http" "strconv" "strings" "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/service" "github.com/QuantumNous/new-api/setting/ratio_setting" "github.com/gin-gonic/gin" ) // OnboardResult 渠道上架诊断结果,前端根据此结构引导用户完成各上架步骤。 type OnboardResult struct { // 上游可拉取的模型列表(拉取失败时为空列表) ModelsAvailable []string `json:"models_available"` // 当前渠道已启用的模型列表 ModelsImported []string `json:"models_imported"` // 已有 model_meta 记录的模型(类型/描述已配置) MetaLinked []string `json:"meta_linked"` // 缺少 model_meta 记录的模型(需去 /console/models 配置) MetaMissing []string `json:"meta_missing"` // 已有定价配置的模型 RatioConfigured []string `json:"ratio_configured"` // 缺少定价配置的模型(需同步或手动配置) RatioMissing []string `json:"ratio_missing"` // 该渠道是否支持上游倍率同步(有 http base_url) CanSyncRatio bool `json:"can_sync_ratio"` // 满足测试条件:已导入模型 + 所有模型均有定价 ReadyToTest bool `json:"ready_to_test"` // 为加速响应未请求上游模型列表;前端可带 fetch_upstream=1 再拉取 UpstreamSkipped bool `json:"upstream_skipped,omitempty"` // 非阻断性警告信息 Warnings []string `json:"warnings,omitempty"` } // OnboardChannel 渠道上架状态诊断(只读)。 // 拉取上游模型列表、检查 model_meta 配置状态、检查定价配置状态, // 返回 OnboardResult 供前端引导用户完成各步骤。 func OnboardChannel(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { common.ApiError(c, err) return } channel, err := model.GetChannelById(id, true) if err != nil { common.ApiError(c, err) return } if c.GetInt("role") < common.RoleAdminUser && channel.OwnerUserID != c.GetInt("id") { c.JSON(http.StatusForbidden, gin.H{ "success": false, "message": "无权访问其他供应商渠道", }) return } result := OnboardResult{ ModelsAvailable: []string{}, ModelsImported: []string{}, MetaLinked: []string{}, MetaMissing: []string{}, RatioConfigured: []string{}, RatioMissing: []string{}, Warnings: []string{}, } // 1. 拉取上游模型列表(已有导入时默认跳过以将响应控制在毫秒~百毫秒级;需列表时加 ?fetch_upstream=1) importedForSkip := channel.GetModels() fetchUpstream := strings.EqualFold(strings.TrimSpace(c.Query("fetch_upstream")), "1") || strings.EqualFold(strings.TrimSpace(c.Query("fetch_upstream")), "true") skipUpstream := len(importedForSkip) > 0 && !fetchUpstream if skipUpstream { result.UpstreamSkipped = true } else { upstreamModelIDs, fetchErr := fetchChannelUpstreamModelIDs(channel) if fetchErr != nil { result.Warnings = append(result.Warnings, "拉取上游模型列表失败: "+fetchErr.Error()) } else { result.ModelsAvailable = upstreamModelIDs } } // 2. 当前渠道已启用的模型 channelModels := channel.GetModels() if channelModels != nil { result.ModelsImported = channelModels } // 3. 诊断目标:优先用已导入的模型,否则用上游可用模型 diagModels := result.ModelsImported if len(diagModels) == 0 { diagModels = result.ModelsAvailable } // 4. 检查 model_meta 记录 if len(diagModels) > 0 { existingNames, _ := model.GetExistingModelNames(diagModels) existingSet := make(map[string]bool, len(existingNames)) for _, name := range existingNames { existingSet[name] = true } for _, m := range diagModels { if existingSet[m] { result.MetaLinked = append(result.MetaLinked, m) } else { result.MetaMissing = append(result.MetaMissing, m) } } } // 5. 检查定价配置(渠道级优先,再查全局) for _, m := range diagModels { // 渠道级 price 优先(通过 ratio_sync 配置的渠道专属定价) if _, ok := ratio_setting.GetChannelModelPrice(channel.Id, m); ok { result.RatioConfigured = append(result.RatioConfigured, m) continue } // 渠道级 ratio if _, ok := ratio_setting.GetChannelModelRatio(channel.Id, m); ok { result.RatioConfigured = append(result.RatioConfigured, m) continue } // 全局 model_price / model_ratio 兜底 if _, _, exist := ratio_setting.GetModelRatioOrPrice(m); exist { result.RatioConfigured = append(result.RatioConfigured, m) continue } result.RatioMissing = append(result.RatioMissing, m) } // 6. 是否支持上游 ratio_sync if base := channel.GetBaseURL(); strings.HasPrefix(base, "http") { result.CanSyncRatio = true } // 7. 就绪状态 result.ReadyToTest = len(result.ModelsImported) > 0 && len(result.RatioMissing) == 0 c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": result, }) } // UpdateChannelModelsRequest 更新渠道模型列表请求。 type UpdateChannelModelsRequest struct { Models []string `json:"models" binding:"required"` } // UpdateChannelModels 仅更新渠道的模型列表,同步更新 abilities 表。 // 不需要传输完整渠道信息(包括密钥),适合用于上架向导模型导入场景。 func UpdateChannelModels(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { common.ApiError(c, err) return } var req UpdateChannelModelsRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ "success": false, "message": "请求参数格式错误: " + err.Error(), }) return } channel, err := model.GetChannelById(id, true) if err != nil { common.ApiError(c, err) return } if c.GetInt("role") < common.RoleAdminUser && channel.OwnerUserID != c.GetInt("id") { c.JSON(http.StatusForbidden, gin.H{ "success": false, "message": "无权修改其他供应商渠道", }) return } // 去重并过滤空值 seen := make(map[string]bool) clean := make([]string, 0, len(req.Models)) for _, m := range req.Models { m = strings.TrimSpace(m) if m != "" && !seen[m] { seen[m] = true clean = append(clean, m) } } channel.Models = strings.Join(clean, ",") if err := model.DB.Model(channel).Update("models", channel.Models).Error; err != nil { common.ApiError(c, err) return } if err := channel.UpdateAbilities(nil); err != nil { common.SysError("onboard: failed to update abilities after model patch: " + err.Error()) } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", }) } // AutoMetaRequest 自动推断元数据请求:对指定模型名列表执行自动创建。 // 若 Models 为空,则使用当前渠道的已导入模型列表。 type AutoMetaRequest struct { Models []string `json:"models"` } // AutoMetaChannelModels 为渠道中缺少 model_meta 的模型自动推断并创建元数据。 // 推断优先级:① 官方预设精确匹配 → ② 模型名称规则推断。 // 已有记录的模型直接跳过,幂等安全。 func AutoMetaChannelModels(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { common.ApiError(c, err) return } channel, err := model.GetChannelById(id, false) if err != nil { common.ApiError(c, err) return } if c.GetInt("role") < common.RoleAdminUser && channel.OwnerUserID != c.GetInt("id") { c.JSON(http.StatusForbidden, gin.H{ "success": false, "message": "无权访问其他供应商渠道", }) return } var req AutoMetaRequest _ = c.ShouldBindJSON(&req) // 允许空体 // 目标模型列表:优先用请求体,否则取渠道已导入模型 targets := req.Models if len(targets) == 0 { targets = channel.GetModels() } if len(targets) == 0 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "渠道尚未导入任何模型,请先导入模型后再执行自动推断", }) return } results := service.AutoCreateMissingModelMeta(c.Request.Context(), targets) // 统计摘要 var created, skipped, failed int for _, r := range results { switch r.Source { case "exists": skipped++ default: if r.Err != "" { failed++ } else { created++ } } } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": gin.H{ "created": created, "skipped": skipped, "failed": failed, "items": results, }, }) } // BulkTestModelItem 批量测试单个模型的结果。 type BulkTestModelItem struct { ModelName string `json:"model_name"` Success bool `json:"success"` Time float64 `json:"time"` // 秒 Message string `json:"message"` } // BulkTestChannelModels 批量测试渠道的指定模型列表,每个模型串行执行, // 避免前端发出大量并发请求触发全局限流。 // POST /api/channel/:id/onboard/test func BulkTestChannelModels(c *gin.Context) { channelId, err := strconv.Atoi(c.Param("id")) if err != nil { common.ApiError(c, err) return } channel, err := model.GetChannelById(channelId, true) if err != nil { common.ApiError(c, err) return } // 解析请求体 var req struct { Models []string `json:"models"` } if err := c.ShouldBindJSON(&req); err != nil { common.ApiError(c, err) return } targets := req.Models if len(targets) == 0 { targets = channel.GetModels() } if len(targets) == 0 { c.JSON(http.StatusOK, gin.H{"success": false, "message": "no models to test"}) return } results := make([]BulkTestModelItem, 0, len(targets)) for _, modelName := range targets { modelName = strings.TrimSpace(modelName) if modelName == "" { continue } tik := time.Now() res := testChannel(channel, modelName, "", false) elapsed := float64(time.Since(tik).Milliseconds()) / 1000.0 // 判断成功与否 success := res.localErr == nil && res.tokenFactoryError == nil msg := "" if res.localErr != nil { msg = res.localErr.Error() } else if res.tokenFactoryError != nil { msg = res.tokenFactoryError.Error() } // 持久化(与单测保持一致) ms := int64(elapsed * 1000) go func(ch *model.Channel, mn string, ok bool, ms int64, m string) { ch.UpdateTestResult(ok, ms, m, mn) _ = model.UpsertModelTestResult(ch.Id, mn, ok, ms, m) }(channel, modelName, success, ms, msg) results = append(results, BulkTestModelItem{ ModelName: modelName, Success: success, Time: elapsed, Message: msg, }) } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": results, }) } // GetChannelTestResults 返回某渠道在 model_test_results 表中的全部历史测试记录。 // GET /api/channel/:id/test_results func GetChannelTestResults(c *gin.Context) { channelId, err := strconv.Atoi(c.Param("id")) if err != nil { common.ApiError(c, err) return } rows, err := model.GetAllModelTestResultsByChannelID(channelId) if err != nil { common.ApiError(c, err) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": rows, }) }