package controller

import (
	"fmt"
	"math"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/model"
	"github.com/QuantumNous/new-api/setting/system_setting"
	"github.com/gin-gonic/gin"
)

// ─── Export field keys ───────────────────────────────────────────────────────

// JSON keys used for every channel entry in an export/import document.
const (
	chFieldName           = "name"
	chFieldDiscountRate   = "discountRate"
	chFieldMarkupDiscount = "markupDiscountRate"
	chFieldRouteSlug      = "routeSlug"
	chFieldQuota          = "quota"
	chFieldDisabled       = "disabled"
	chFieldSupplierName   = "supplierName"
	chFieldType           = "type"
	chFieldLogo           = "logo"
	chFieldProviderType   = "providerType"
	chFieldApiKey         = "apiKey"
	chFieldApiBaseUrl     = "apiBaseUrl"
	chFieldModels         = "models"
	chFieldGroups         = "groups"
	chFieldModelRedirect  = "modelRedirect"
	chFieldOtherInfo      = "otherInfo"
)

// chAllowedExportFields is the whitelist of field keys a client may request in
// an export; anything else in the request is silently dropped.
var chAllowedExportFields = map[string]bool{
	chFieldName:           true,
	chFieldDiscountRate:   true,
	chFieldMarkupDiscount: true, // was missing, which made this field impossible to export
	chFieldRouteSlug:      true,
	chFieldQuota:          true,
	chFieldDisabled:       true,
	chFieldSupplierName:   true,
	chFieldType:           true,
	chFieldLogo:           true,
	chFieldProviderType:   true,
	chFieldApiKey:         true,
	chFieldApiBaseUrl:     true,
	chFieldModels:         true,
	chFieldGroups:         true,
	chFieldModelRedirect:  true,
	chFieldOtherInfo:      true,
}

// ─── DTOs ────────────────────────────────────────────────────────────────────

// ChannelExportRequest is the request body of the channel export endpoint.
type ChannelExportRequest struct {
	ChannelIDs []int    `json:"channel_ids"` // IDs of the channels to export
	Fields     []string `json:"fields"`      // field keys selected by the user
	Mode       string   `json:"mode"`        // "standard" (default) or "site_builder"
}

// ChannelExportPayload is the export document returned to the client. It has
// the same shape as ChannelImportRequest so it can be fed back into an import.
type ChannelExportPayload struct {
	Version    string                   `json:"version"`
	ExportTime string                   `json:"exportTime"`
	Channels   []map[string]interface{} `json:"channels"`
}

// ChannelImportRequest is the request body of the channel import endpoint
// (compatible with ChannelExportPayload).
type ChannelImportRequest struct {
	Version    string                   `json:"version"`
	ExportTime string                   `json:"exportTime"`
	Channels   []map[string]interface{} `json:"channels"`

	// SiteBuilderApiKey is the shared site-builder key. It is used as the
	// channel key for entries of type constant.ChannelTypeTokenFactoryOpen
	// whose apiKey is blank.
	SiteBuilderApiKey string `json:"site_builder_api_key,omitempty"`
}

// ChannelImportResult summarizes the outcome of an import run.
type ChannelImportResult struct {
	Added    int                    `json:"added"`
	Updated  int                    `json:"updated"`
	Failed   int                    `json:"failed"`
	Failures []ChannelImportFailure `json:"failures"`
}

// ChannelImportFailure describes a single entry that could not be imported.
type ChannelImportFailure struct {
	Name   string `json:"name"`
	Reason string `json:"reason"`
}

// ─── Export endpoint ─────────────────────────────────────────────────────────

// ExportChannels exports the selected fields of the channels identified by
// the request's channel_ids.
//
// POST /api/channel/export
//
// With mode "standard" (the default) channel data is exported as stored. With
// mode "site_builder" every entry is rewritten by buildSiteBuilderExportItem.
// The name field is always included so that an import can match by name.
func ExportChannels(c *gin.Context) {
	var req ChannelExportRequest
	if err := common.DecodeJson(c.Request.Body, &req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "请求格式错误"})
		return
	}
	if len(req.ChannelIDs) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "请先选择需要导出的渠道"})
		return
	}

	// Keep only whitelisted fields.
	fieldSet := make(map[string]bool, len(req.Fields)+1)
	for _, f := range req.Fields {
		if chAllowedExportFields[f] {
			fieldSet[f] = true
		}
	}
	if len(fieldSet) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "请至少选择一个导出字段"})
		return
	}
	// The name is the import matching key, so it is always exported.
	fieldSet[chFieldName] = true

	channels, err := model.GetChannelsByIDs(req.ChannelIDs)
	if err != nil {
		common.ApiError(c, err)
		return
	}

	isSiteBuilder := req.Mode == "site_builder"
	items := make([]map[string]interface{}, 0, len(channels))
	for _, ch := range channels {
		if isSiteBuilder {
			items = append(items, buildSiteBuilderExportItem(c, ch, fieldSet))
		} else {
			items = append(items, buildChannelExportItem(ch, fieldSet))
		}
	}

	payload := ChannelExportPayload{
		Version:    "1.0",
		ExportTime: time.Now().UTC().Format(time.RFC3339),
		Channels:   items,
	}
	c.JSON(http.StatusOK, gin.H{"success": true, "data": payload})
}

// buildChannelExportItem builds the export map for one channel. Only keys
// enabled in fields appear in the result; fields is not modified.
func buildChannelExportItem(ch *model.Channel, fields map[string]bool) map[string]interface{} {
	item := make(map[string]interface{}, len(fields))

	if fields[chFieldName] {
		item[chFieldName] = ch.Name
	}
	if fields[chFieldDiscountRate] {
		item[chFieldDiscountRate] = ch.PriceDiscountPercent
	}
	if fields[chFieldMarkupDiscount] {
		item[chFieldMarkupDiscount] = ch.MarkupDiscountRate
	}
	if fields[chFieldRouteSlug] {
		item[chFieldRouteSlug] = ch.RouteSlug
	}
	if fields[chFieldQuota] {
		item[chFieldQuota] = ch.Balance
	}
	if fields[chFieldDisabled] {
		// Status 2 is exported as disabled; every other status as enabled.
		item[chFieldDisabled] = ch.Status == 2
	}
	if fields[chFieldSupplierName] {
		// NOTE(review): supplierName is exported but no import path in this
		// file reads it back — confirm that is intentional.
		item[chFieldSupplierName] = ch.SupplierName
	}
	if fields[chFieldType] {
		item[chFieldType] = ch.Type
	}
	if fields[chFieldLogo] {
		item[chFieldLogo] = ch.CompanyLogoURL
	}
	if fields[chFieldProviderType] {
		item[chFieldProviderType] = ch.SupplierType
	}
	if fields[chFieldApiKey] {
		item[chFieldApiKey] = ch.Key
	}
	if fields[chFieldApiBaseUrl] {
		baseURL := ""
		if ch.BaseURL != nil {
			baseURL = *ch.BaseURL
		}
		item[chFieldApiBaseUrl] = baseURL
	}
	if fields[chFieldModels] {
		// The import side expects an array here.
		item[chFieldModels] = ch.GetModels()
	}
	if fields[chFieldGroups] {
		// The import side expects an array here.
		item[chFieldGroups] = ch.GetGroups()
	}
	if fields[chFieldModelRedirect] {
		redirect := map[string]string{}
		if ch.ModelMapping != nil && *ch.ModelMapping != "" {
			// Best effort: a mapping that cannot be decoded is exported as
			// whatever was decoded so far rather than failing the export.
			_ = common.UnmarshalJsonStr(*ch.ModelMapping, &redirect)
		}
		item[chFieldModelRedirect] = redirect
	}
	if fields[chFieldOtherInfo] {
		if otherInfo := ch.GetOtherInfo(); len(otherInfo) > 0 {
			item[chFieldOtherInfo] = otherInfo
		}
	}
	return item
}

// buildSiteBuilderExportItem builds the export map for one channel in
// site-builder mode. On top of the standard item it always sets:
//   - type to constant.ChannelTypeTokenFactoryOpen, because the import side
//     only substitutes site_builder_api_key for entries of that type;
//   - apiBaseUrl to this platform's ServerAddress (derived from the request
//     when ServerAddress is not configured);
//   - apiKey to "", so the importer supplies the key via site_builder_api_key;
//   - otherInfo with source "tokenfactory_open" plus the upstream route slug
//     and supplier name when the channel has them.
//
// fields is only read; it is shared between all channels of one export.
func buildSiteBuilderExportItem(c *gin.Context, ch *model.Channel, fields map[string]bool) map[string]interface{} {
	item := buildChannelExportItem(ch, fields)

	item[chFieldType] = constant.ChannelTypeTokenFactoryOpen

	serverAddr := strings.TrimRight(system_setting.ServerAddress, "/")
	if serverAddr == "" {
		// Fallback: derive the address from the incoming request.
		scheme := "http"
		if c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" {
			scheme = "https"
		}
		serverAddr = fmt.Sprintf("%s://%s", scheme, c.Request.Host)
	}
	item[chFieldApiBaseUrl] = serverAddr

	// Never export the real upstream key in site-builder mode.
	item[chFieldApiKey] = ""

	otherInfo := ch.GetOtherInfo()
	if otherInfo == nil {
		otherInfo = make(map[string]interface{})
	}
	otherInfo["source"] = "tokenfactory_open"
	if ch.RouteSlug != "" {
		otherInfo["upstream_route_slug"] = ch.RouteSlug
	}
	if ch.SupplierName != "" {
		otherInfo["upstream_supplier_alias"] = ch.SupplierName
	}
	item[chFieldOtherInfo] = otherInfo

	return item
}

// ─── Import endpoint ─────────────────────────────────────────────────────────

// ImportChannels imports channel configuration, matching entries by name only.
//
// POST /api/channel/import
//
// An entry whose name matches an existing channel updates only the fields
// present in the entry; any other entry creates a new channel. Existing
// channels are never deleted and absent fields are never cleared. Entries are
// processed independently: a failing entry is recorded in the result and the
// remaining entries are still processed.
func ImportChannels(c *gin.Context) {
	var req ChannelImportRequest
	if err := common.DecodeJson(c.Request.Body, &req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "JSON 格式错误,请上传合法的导出文件"})
		return
	}
	if req.Channels == nil {
		c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "channels 字段不能为空"})
		return
	}

	// Failures starts as an empty slice so it serializes as [] rather than null.
	result := &ChannelImportResult{Failures: []ChannelImportFailure{}}
	fail := func(name, reason string) {
		result.Failed++
		result.Failures = append(result.Failures, ChannelImportFailure{Name: name, Reason: reason})
	}

	for _, item := range req.Channels {
		name, ok := chGetStr(item, chFieldName)
		name = strings.TrimSpace(name)
		if !ok || name == "" {
			fail("(未知)", "缺少或无效的 name 字段")
			continue
		}

		if err := chValidateItem(item); err != nil {
			fail(name, err.Error())
			continue
		}

		existing, err := model.GetChannelByName(name)
		if err != nil {
			fail(name, "查询渠道失败: "+err.Error())
			continue
		}

		if existing != nil {
			if err := chApplyToExisting(existing, item, req.SiteBuilderApiKey); err != nil {
				fail(name, "更新失败: "+err.Error())
				continue
			}
			result.Updated++
			continue
		}

		newCh := &model.Channel{}
		if err := chApplyToNew(newCh, item, req.SiteBuilderApiKey); err != nil {
			fail(name, "构建新增数据失败: "+err.Error())
			continue
		}
		if err := newCh.Insert(); err != nil {
			fail(name, "创建渠道失败: "+err.Error())
			continue
		}
		result.Added++
	}

	common.SysLog(fmt.Sprintf("渠道导入完成:新增 %d,更新 %d,失败 %d", result.Added, result.Updated, result.Failed))
	c.JSON(http.StatusOK, gin.H{"success": true, "message": "渠道导入完成", "data": result})
}

// ─── Internal helpers ────────────────────────────────────────────────────────

// chGetStr reads key from m and reports whether it was present and a string.
func chGetStr(m map[string]interface{}, key string) (string, bool) {
	v, ok := m[key]
	if !ok {
		return "", false
	}
	s, ok := v.(string)
	return s, ok
}

// chNumber converts a decoded JSON value to float64. It accepts the Go numeric
// types handled below and numeric strings, and reports false for nil, for any
// other type, and for strings that do not parse to a finite number.
func chNumber(v interface{}) (float64, bool) {
	switch val := v.(type) {
	case float64:
		return val, true
	case float32:
		return float64(val), true
	case int:
		return float64(val), true
	case int64:
		return float64(val), true
	case string:
		f, err := strconv.ParseFloat(val, 64)
		if err != nil || math.IsNaN(f) || math.IsInf(f, 0) {
			return 0, false
		}
		return f, true
	}
	return 0, false
}

// chToFloat64 converts v to float64 and returns 0 when v is not numeric.
// Use chNumber when "not numeric" must be told apart from a real 0.
func chToFloat64(v interface{}) float64 {
	f, _ := chNumber(v)
	return f
}

// chValidateItem checks the value types of one import entry: models and
// groups must be arrays, modelRedirect and otherInfo must be objects, and
// discountRate, markupDiscountRate, quota and type must be numeric. Absent
// keys and JSON null values are accepted. The caller skips an entry that
// fails validation and keeps processing the others.
func chValidateItem(item map[string]interface{}) error {
	for _, key := range []string{chFieldModels, chFieldGroups} {
		if v := item[key]; v != nil {
			if _, isArray := v.([]interface{}); !isArray {
				return fmt.Errorf("%s 字段必须为数组", key)
			}
		}
	}
	for _, key := range []string{chFieldModelRedirect, chFieldOtherInfo} {
		if v := item[key]; v != nil {
			if _, isObject := v.(map[string]interface{}); !isObject {
				return fmt.Errorf("%s 字段必须为对象", key)
			}
		}
	}
	for _, key := range []string{chFieldDiscountRate, chFieldMarkupDiscount, chFieldQuota, chFieldType} {
		if v := item[key]; v != nil {
			if _, isNumber := chNumber(v); !isNumber {
				return fmt.Errorf("%s 字段必须为数字", key)
			}
		}
	}
	return nil
}

// chStringList joins the non-blank string elements of arr with commas,
// trimming each element and dropping non-string elements.
func chStringList(arr []interface{}) string {
	parts := make([]string, 0, len(arr))
	for _, e := range arr {
		s, ok := e.(string)
		if !ok {
			continue
		}
		if s = strings.TrimSpace(s); s != "" {
			parts = append(parts, s)
		}
	}
	return strings.Join(parts, ",")
}

// chMarshalRedirect serializes the string-valued entries of m as JSON;
// entries whose value is not a string are dropped.
func chMarshalRedirect(m map[string]interface{}) (string, error) {
	redirect := make(map[string]string, len(m))
	for k, v := range m {
		if s, ok := v.(string); ok {
			redirect[k] = s
		}
	}
	b, err := common.Marshal(redirect)
	if err != nil {
		return "", fmt.Errorf("序列化 modelRedirect 失败: %w", err)
	}
	return string(b), nil
}

// chMarshalOtherInfo serializes m as JSON.
func chMarshalOtherInfo(m map[string]interface{}) (string, error) {
	b, err := common.Marshal(m)
	if err != nil {
		return "", fmt.Errorf("序列化 otherInfo 失败: %w", err)
	}
	return string(b), nil
}

// chResolveAPIKey returns the key to store for an entry: when key is blank,
// a site-builder key was supplied and the channel type is
// constant.ChannelTypeTokenFactoryOpen, the trimmed site-builder key is used;
// otherwise key is returned unchanged.
func chResolveAPIKey(key string, channelType int, siteBuilderApiKey string) string {
	if strings.TrimSpace(key) == "" && siteBuilderApiKey != "" &&
		channelType == constant.ChannelTypeTokenFactoryOpen {
		return strings.TrimSpace(siteBuilderApiKey)
	}
	return key
}

// chApplyToExisting applies an import entry to an already existing channel.
//
// Only keys present in item with a usable value are written; the list of
// touched columns is passed to model.PartialUpdateChannelFields so that other
// columns are left alone. JSON null values and values of the wrong type are
// treated as "not provided". A blank apiKey that cannot be replaced by
// siteBuilderApiKey does not overwrite the stored key.
//
// siteBuilderApiKey is the shared site-builder key; see chResolveAPIKey.
func chApplyToExisting(ch *model.Channel, item map[string]interface{}, siteBuilderApiKey string) error {
	cols := make([]string, 0, len(item))
	updates := &model.Channel{}

	if f, ok := chNumber(item[chFieldDiscountRate]); ok {
		updates.PriceDiscountPercent = &f
		cols = append(cols, "price_discount_percent")
	}
	if f, ok := chNumber(item[chFieldMarkupDiscount]); ok {
		updates.MarkupDiscountRate = &f
		cols = append(cols, "markup_discount_rate")
	}
	if disabled, ok := item[chFieldDisabled].(bool); ok {
		updates.Status = 1 // enabled
		if disabled {
			updates.Status = 2 // disabled
		}
		cols = append(cols, "status")
	}

	// The effective type is the imported one when given, else the stored one.
	channelType := ch.Type
	if t, ok := chNumber(item[chFieldType]); ok {
		channelType = int(t)
		updates.Type = channelType
		cols = append(cols, "type")
	}

	if s, ok := item[chFieldLogo].(string); ok {
		updates.CompanyLogoURL = s
		cols = append(cols, "company_logo_url")
	}
	if s, ok := item[chFieldProviderType].(string); ok {
		updates.SupplierType = s
		cols = append(cols, "supplier_type")
	}
	if s, ok := item[chFieldApiKey].(string); ok {
		key := chResolveAPIKey(s, channelType, siteBuilderApiKey)
		// Site-builder exports always carry apiKey "", so a blank key means
		// "unspecified" and must not wipe the key already stored.
		if strings.TrimSpace(key) != "" {
			updates.Key = key
			cols = append(cols, "key")
		}
	}
	if s, ok := item[chFieldApiBaseUrl].(string); ok {
		updates.BaseURL = &s
		cols = append(cols, "base_url")
	}
	if arr, ok := item[chFieldModels].([]interface{}); ok {
		updates.Models = chStringList(arr)
		cols = append(cols, "models")
	}
	if arr, ok := item[chFieldGroups].([]interface{}); ok {
		updates.Group = chStringList(arr)
		cols = append(cols, "group")
	}
	if m, ok := item[chFieldModelRedirect].(map[string]interface{}); ok {
		s, err := chMarshalRedirect(m)
		if err != nil {
			return err
		}
		updates.ModelMapping = &s
		cols = append(cols, "model_mapping")
	}
	if f, ok := chNumber(item[chFieldQuota]); ok {
		updates.Balance = f
		cols = append(cols, "balance")
	}
	if s, ok := item[chFieldRouteSlug].(string); ok {
		updates.RouteSlug = s
		cols = append(cols, "route_slug")
	}
	if m, ok := item[chFieldOtherInfo].(map[string]interface{}); ok {
		s, err := chMarshalOtherInfo(m)
		if err != nil {
			return err
		}
		updates.OtherInfo = s
		cols = append(cols, "other_info")
	}

	if len(cols) == 0 {
		// Nothing to update; this is not an error.
		return nil
	}
	return model.PartialUpdateChannelFields(ch.Id, cols, updates)
}

// chApplyToNew fills ch, a fresh channel, from an import entry.
//
// The channel is enabled unless the entry has disabled=true, and it is put in
// the "default" group when the entry yields no group. JSON null values and
// values of the wrong type are ignored. It returns an error when the name is
// missing or blank, or when modelRedirect/otherInfo cannot be serialized.
//
// siteBuilderApiKey is the shared site-builder key; see chResolveAPIKey.
func chApplyToNew(ch *model.Channel, item map[string]interface{}, siteBuilderApiKey string) error {
	name, ok := chGetStr(item, chFieldName)
	if !ok || strings.TrimSpace(name) == "" {
		return fmt.Errorf("name 字段缺失")
	}
	ch.Name = strings.TrimSpace(name)

	ch.Status = 1 // enabled by default
	if disabled, ok := item[chFieldDisabled].(bool); ok && disabled {
		ch.Status = 2
	}

	if f, ok := chNumber(item[chFieldDiscountRate]); ok {
		ch.PriceDiscountPercent = &f
	}
	if f, ok := chNumber(item[chFieldMarkupDiscount]); ok {
		ch.MarkupDiscountRate = &f
	}
	if s, ok := item[chFieldRouteSlug].(string); ok {
		ch.RouteSlug = s
	}
	if f, ok := chNumber(item[chFieldQuota]); ok {
		ch.Balance = f
	}
	// The type must be set before the key is resolved below.
	if t, ok := chNumber(item[chFieldType]); ok {
		ch.Type = int(t)
	}
	if s, ok := item[chFieldLogo].(string); ok {
		ch.CompanyLogoURL = s
	}
	if s, ok := item[chFieldProviderType].(string); ok {
		ch.SupplierType = s
	}
	if s, ok := item[chFieldApiKey].(string); ok {
		ch.Key = chResolveAPIKey(s, ch.Type, siteBuilderApiKey)
	}
	if s, ok := item[chFieldApiBaseUrl].(string); ok {
		ch.BaseURL = &s
	}
	if arr, ok := item[chFieldModels].([]interface{}); ok {
		ch.Models = chStringList(arr)
	}
	if arr, ok := item[chFieldGroups].([]interface{}); ok {
		ch.Group = chStringList(arr)
	}
	if ch.Group == "" {
		// Covers an absent key as well as null or an empty array.
		ch.Group = "default"
	}
	if m, ok := item[chFieldModelRedirect].(map[string]interface{}); ok {
		s, err := chMarshalRedirect(m)
		if err != nil {
			return err
		}
		ch.ModelMapping = &s
	}
	if m, ok := item[chFieldOtherInfo].(map[string]interface{}); ok {
		s, err := chMarshalOtherInfo(m)
		if err != nil {
			return err
		}
		ch.OtherInfo = s
	}
	return nil
}