376 lines
15 KiB
Go
376 lines
15 KiB
Go
package controller
|
||
|
||
import (
|
||
"fmt"
|
||
"net/http"
|
||
"strings"
|
||
|
||
"github.com/QuantumNous/new-api/common"
|
||
"github.com/QuantumNous/new-api/model"
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// ─── 导出/导入共用的数据结构 ──────────────────────────────────────────────────
|
||
|
||
// PriceExportModelMaps 一组模型定价映射(字段名与全局 Option key 保持一致)。
|
||
type PriceExportModelMaps struct {
|
||
ModelPrice map[string]float64 `json:"ModelPrice"`
|
||
ModelRatio map[string]float64 `json:"ModelRatio"`
|
||
CompletionRatio map[string]float64 `json:"CompletionRatio"`
|
||
CacheRatio map[string]float64 `json:"CacheRatio"`
|
||
CreateCacheRatio map[string]float64 `json:"CreateCacheRatio"`
|
||
ImageRatio map[string]float64 `json:"ImageRatio"`
|
||
AudioRatio map[string]float64 `json:"AudioRatio"`
|
||
AudioCompletionRatio map[string]float64 `json:"AudioCompletionRatio"`
|
||
}
|
||
|
||
// PriceExportChannelEntry 单渠道价格导出/导入条目(用 channel_name 标识,不含 ID)。
|
||
type PriceExportChannelEntry struct {
|
||
ChannelName string `json:"channel_name"`
|
||
Models PriceExportModelMaps `json:"models"`
|
||
}
|
||
|
||
// PriceExportData 完整导出结构(可直接用于后续导入)。
|
||
type PriceExportData struct {
|
||
GlobalPrices PriceExportModelMaps `json:"global_prices"`
|
||
Channels []PriceExportChannelEntry `json:"channels"`
|
||
}
|
||
|
||
// PriceImportChannelStat 单渠道导入统计。
|
||
type PriceImportChannelStat struct {
|
||
ChannelName string `json:"channel_name"`
|
||
Updated int `json:"updated"`
|
||
Added int `json:"added"`
|
||
}
|
||
|
||
// PriceImportResult 导入结果统计(返回给前端展示)。
|
||
type PriceImportResult struct {
|
||
GlobalUpdated int `json:"global_updated"`
|
||
GlobalAdded int `json:"global_added"`
|
||
ChannelStats []PriceImportChannelStat `json:"channel_stats"`
|
||
SkippedChannels []string `json:"skipped_channels"`
|
||
}
|
||
|
||
// ─── 内部工具函数 ──────────────────────────────────────────────────────────────
|
||
|
||
// readOptionStr 从内存 OptionMap 安全读取字符串值(只读锁)。
|
||
func readOptionStr(key string) string {
|
||
common.OptionMapRWMutex.RLock()
|
||
defer common.OptionMapRWMutex.RUnlock()
|
||
return common.Interface2String(common.OptionMap[key])
|
||
}
|
||
|
||
// parseFloatMapSafe 将 JSON 字符串解析为 map[string]float64,失败时返回空 map。
|
||
func parseFloatMapSafe(raw string) map[string]float64 {
|
||
out := map[string]float64{}
|
||
if strings.TrimSpace(raw) == "" {
|
||
return out
|
||
}
|
||
_ = common.UnmarshalJsonStr(raw, &out)
|
||
return out
|
||
}
|
||
|
||
// parseNestedFloatMapSafe 将 JSON 字符串解析为 map[string]map[string]float64,失败时返回空 map。
|
||
func parseNestedFloatMapSafe(raw string) map[string]map[string]float64 {
|
||
out := map[string]map[string]float64{}
|
||
if strings.TrimSpace(raw) == "" {
|
||
return out
|
||
}
|
||
_ = common.UnmarshalJsonStr(raw, &out)
|
||
return out
|
||
}
|
||
|
||
// safeFloatMap 确保返回非 nil 的 map。
|
||
func safeFloatMap(m map[string]float64) map[string]float64 {
|
||
if m == nil {
|
||
return map[string]float64{}
|
||
}
|
||
return m
|
||
}
|
||
|
||
// marshalToJSON 将值序列化为 JSON 字符串,失败返回 "{}"。
|
||
func marshalToJSON(v any) string {
|
||
b, err := common.Marshal(v)
|
||
if err != nil {
|
||
return "{}"
|
||
}
|
||
return string(b)
|
||
}
|
||
|
||
// mergeFloatMapCounting 将 src 中的键值增量合并到 dst(不删除 dst 中已有键),返回 added/updated 数量。
|
||
func mergeFloatMapCounting(dst, src map[string]float64) (added, updated int) {
|
||
for k, v := range src {
|
||
if _, exists := dst[k]; exists {
|
||
dst[k] = v
|
||
updated++
|
||
} else {
|
||
dst[k] = v
|
||
added++
|
||
}
|
||
}
|
||
return
|
||
}
|
||
|
||
// isModelMapsEmpty 判断 PriceExportModelMaps 是否所有子 map 均为空。
|
||
func isModelMapsEmpty(m PriceExportModelMaps) bool {
|
||
return len(m.ModelPrice) == 0 &&
|
||
len(m.ModelRatio) == 0 &&
|
||
len(m.CompletionRatio) == 0 &&
|
||
len(m.CacheRatio) == 0 &&
|
||
len(m.CreateCacheRatio) == 0 &&
|
||
len(m.ImageRatio) == 0 &&
|
||
len(m.AudioRatio) == 0 &&
|
||
len(m.AudioCompletionRatio) == 0
|
||
}
|
||
|
||
// globalPriceFields 全局价格 Option 键与 PriceExportModelMaps 字段的绑定关系。
|
||
var globalPriceFields = []struct {
|
||
optionKey string
|
||
getField func(*PriceExportModelMaps) map[string]float64
|
||
}{
|
||
{"ModelPrice", func(m *PriceExportModelMaps) map[string]float64 { return m.ModelPrice }},
|
||
{"ModelRatio", func(m *PriceExportModelMaps) map[string]float64 { return m.ModelRatio }},
|
||
{"CompletionRatio", func(m *PriceExportModelMaps) map[string]float64 { return m.CompletionRatio }},
|
||
{"CacheRatio", func(m *PriceExportModelMaps) map[string]float64 { return m.CacheRatio }},
|
||
{"CreateCacheRatio", func(m *PriceExportModelMaps) map[string]float64 { return m.CreateCacheRatio }},
|
||
{"ImageRatio", func(m *PriceExportModelMaps) map[string]float64 { return m.ImageRatio }},
|
||
{"AudioRatio", func(m *PriceExportModelMaps) map[string]float64 { return m.AudioRatio }},
|
||
{"AudioCompletionRatio", func(m *PriceExportModelMaps) map[string]float64 { return m.AudioCompletionRatio }},
|
||
}
|
||
|
||
// channelPriceFields 渠道价格 Option 键(Channel 前缀)与 PriceExportModelMaps 字段的绑定关系。
|
||
var channelPriceFields = []struct {
|
||
optionKey string
|
||
getField func(*PriceExportModelMaps) map[string]float64
|
||
}{
|
||
{"ChannelModelPrice", func(m *PriceExportModelMaps) map[string]float64 { return m.ModelPrice }},
|
||
{"ChannelModelRatio", func(m *PriceExportModelMaps) map[string]float64 { return m.ModelRatio }},
|
||
{"ChannelCompletionRatio", func(m *PriceExportModelMaps) map[string]float64 { return m.CompletionRatio }},
|
||
{"ChannelCacheRatio", func(m *PriceExportModelMaps) map[string]float64 { return m.CacheRatio }},
|
||
{"ChannelCreateCacheRatio", func(m *PriceExportModelMaps) map[string]float64 { return m.CreateCacheRatio }},
|
||
{"ChannelImageRatio", func(m *PriceExportModelMaps) map[string]float64 { return m.ImageRatio }},
|
||
{"ChannelAudioRatio", func(m *PriceExportModelMaps) map[string]float64 { return m.AudioRatio }},
|
||
{"ChannelAudioCompletionRatio", func(m *PriceExportModelMaps) map[string]float64 { return m.AudioCompletionRatio }},
|
||
}
|
||
|
||
// ─── 导出 ─────────────────────────────────────────────────────────────────────
|
||
|
||
// ExportPrices 导出全局及各渠道模型价格配置。
|
||
// GET /api/admin/price/export
|
||
func ExportPrices(c *gin.Context) {
|
||
// 读取全局价格
|
||
globalPrices := PriceExportModelMaps{
|
||
ModelPrice: parseFloatMapSafe(readOptionStr("ModelPrice")),
|
||
ModelRatio: parseFloatMapSafe(readOptionStr("ModelRatio")),
|
||
CompletionRatio: parseFloatMapSafe(readOptionStr("CompletionRatio")),
|
||
CacheRatio: parseFloatMapSafe(readOptionStr("CacheRatio")),
|
||
CreateCacheRatio: parseFloatMapSafe(readOptionStr("CreateCacheRatio")),
|
||
ImageRatio: parseFloatMapSafe(readOptionStr("ImageRatio")),
|
||
AudioRatio: parseFloatMapSafe(readOptionStr("AudioRatio")),
|
||
AudioCompletionRatio: parseFloatMapSafe(readOptionStr("AudioCompletionRatio")),
|
||
}
|
||
|
||
// 读取渠道维度价格(结构:channel_id(str) → model_name → value)
|
||
chModelPrice := parseNestedFloatMapSafe(readOptionStr("ChannelModelPrice"))
|
||
chModelRatio := parseNestedFloatMapSafe(readOptionStr("ChannelModelRatio"))
|
||
chCompletionRatio := parseNestedFloatMapSafe(readOptionStr("ChannelCompletionRatio"))
|
||
chCacheRatio := parseNestedFloatMapSafe(readOptionStr("ChannelCacheRatio"))
|
||
chCreateCacheRatio := parseNestedFloatMapSafe(readOptionStr("ChannelCreateCacheRatio"))
|
||
chImageRatio := parseNestedFloatMapSafe(readOptionStr("ChannelImageRatio"))
|
||
chAudioRatio := parseNestedFloatMapSafe(readOptionStr("ChannelAudioRatio"))
|
||
chAudioCompletionRatio := parseNestedFloatMapSafe(readOptionStr("ChannelAudioCompletionRatio"))
|
||
|
||
// 收集所有出现过的 channel_id(字符串形式)
|
||
channelIDSet := map[string]struct{}{}
|
||
for _, nm := range []map[string]map[string]float64{
|
||
chModelPrice, chModelRatio, chCompletionRatio, chCacheRatio,
|
||
chCreateCacheRatio, chImageRatio, chAudioRatio, chAudioCompletionRatio,
|
||
} {
|
||
for id := range nm {
|
||
channelIDSet[id] = struct{}{}
|
||
}
|
||
}
|
||
|
||
// 查询 channel_id → channel_name 映射
|
||
idNameMap, err := model.GetChannelIdNameMap()
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
|
||
// 构建渠道导出条目(每个 channel_id 对应一个条目,避免同名渠道数据混淆)
|
||
channelEntries := make([]PriceExportChannelEntry, 0, len(channelIDSet))
|
||
for idStr := range channelIDSet {
|
||
name, ok := idNameMap[idStr]
|
||
if !ok {
|
||
// 渠道已删除:保留占位符,导入时会被自动跳过
|
||
name = fmt.Sprintf("__deleted__channel_id_%s", idStr)
|
||
}
|
||
channelEntries = append(channelEntries, PriceExportChannelEntry{
|
||
ChannelName: name,
|
||
Models: PriceExportModelMaps{
|
||
ModelPrice: safeFloatMap(chModelPrice[idStr]),
|
||
ModelRatio: safeFloatMap(chModelRatio[idStr]),
|
||
CompletionRatio: safeFloatMap(chCompletionRatio[idStr]),
|
||
CacheRatio: safeFloatMap(chCacheRatio[idStr]),
|
||
CreateCacheRatio: safeFloatMap(chCreateCacheRatio[idStr]),
|
||
ImageRatio: safeFloatMap(chImageRatio[idStr]),
|
||
AudioRatio: safeFloatMap(chAudioRatio[idStr]),
|
||
AudioCompletionRatio: safeFloatMap(chAudioCompletionRatio[idStr]),
|
||
},
|
||
})
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"data": PriceExportData{
|
||
GlobalPrices: globalPrices,
|
||
Channels: channelEntries,
|
||
},
|
||
})
|
||
}
|
||
|
||
// ─── 导入 ─────────────────────────────────────────────────────────────────────
|
||
|
||
// ImportPrices 导入价格配置(增量同步,仅新增/更新,不删除已有数据)。
|
||
// POST /api/admin/price/import
|
||
func ImportPrices(c *gin.Context) {
|
||
var payload PriceExportData
|
||
if err := common.DecodeJson(c.Request.Body, &payload); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"success": false,
|
||
"message": "JSON 格式错误,请上传合法的导出文件",
|
||
})
|
||
return
|
||
}
|
||
|
||
// 防止空数据写入
|
||
globalEmpty := isModelMapsEmpty(payload.GlobalPrices)
|
||
if globalEmpty && len(payload.Channels) == 0 {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"success": false,
|
||
"message": "导入文件中未包含任何价格数据,已取消导入",
|
||
})
|
||
return
|
||
}
|
||
|
||
result := &PriceImportResult{
|
||
ChannelStats: []PriceImportChannelStat{},
|
||
SkippedChannels: []string{},
|
||
}
|
||
|
||
// ── 1. 同步全局模型价格 ────────────────────────────────────────────────────
|
||
if !globalEmpty {
|
||
added, updated, err := doSyncGlobalPrices(payload.GlobalPrices)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"success": false,
|
||
"message": fmt.Sprintf("全局价格同步失败: %v", err),
|
||
})
|
||
return
|
||
}
|
||
result.GlobalAdded = added
|
||
result.GlobalUpdated = updated
|
||
}
|
||
|
||
// ── 2. 同步渠道模型价格 ────────────────────────────────────────────────────
|
||
for _, chEntry := range payload.Channels {
|
||
chName := strings.TrimSpace(chEntry.ChannelName)
|
||
if chName == "" {
|
||
continue
|
||
}
|
||
// 跳过已删除渠道的占位符(导出时自动生成的前缀)
|
||
if strings.HasPrefix(chName, "__deleted__channel_id_") {
|
||
result.SkippedChannels = append(result.SkippedChannels, chName)
|
||
continue
|
||
}
|
||
// 渠道模型数据为空时跳过
|
||
if isModelMapsEmpty(chEntry.Models) {
|
||
continue
|
||
}
|
||
|
||
channelIDs, err := model.GetChannelIDsByName(chName)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"success": false,
|
||
"message": fmt.Sprintf("查询渠道 '%s' 失败: %v", chName, err),
|
||
})
|
||
return
|
||
}
|
||
if len(channelIDs) == 0 {
|
||
result.SkippedChannels = append(result.SkippedChannels, chName)
|
||
continue
|
||
}
|
||
|
||
// 对所有同名渠道执行增量同步
|
||
stat := PriceImportChannelStat{ChannelName: chName}
|
||
for _, channelID := range channelIDs {
|
||
added, updated, err := doSyncChannelPrices(channelID, chEntry.Models)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"success": false,
|
||
"message": fmt.Sprintf("渠道 '%s'(id=%d) 价格同步失败: %v", chName, channelID, err),
|
||
})
|
||
return
|
||
}
|
||
stat.Added += added
|
||
stat.Updated += updated
|
||
}
|
||
result.ChannelStats = append(result.ChannelStats, stat)
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "价格导入成功",
|
||
"data": result,
|
||
})
|
||
}
|
||
|
||
// ─── 内部同步实现 ──────────────────────────────────────────────────────────────
|
||
|
||
// doSyncGlobalPrices 增量合并全局模型价格到 Options,返回 (added, updated, error)。
|
||
// 逐 Option 键处理:读取当前值 → 合并 → 通过 model.UpdateOption 写回(同时刷新内存缓存与 ratio_setting)。
|
||
func doSyncGlobalPrices(incoming PriceExportModelMaps) (totalAdded, totalUpdated int, err error) {
|
||
for _, field := range globalPriceFields {
|
||
src := field.getField(&incoming)
|
||
if len(src) == 0 {
|
||
continue
|
||
}
|
||
current := parseFloatMapSafe(readOptionStr(field.optionKey))
|
||
added, updated := mergeFloatMapCounting(current, src)
|
||
totalAdded += added
|
||
totalUpdated += updated
|
||
|
||
if err = model.UpdateOption(field.optionKey, marshalToJSON(current)); err != nil {
|
||
return 0, 0, fmt.Errorf("写入 Option[%s] 失败: %w", field.optionKey, err)
|
||
}
|
||
}
|
||
return
|
||
}
|
||
|
||
// doSyncChannelPrices 增量合并单渠道模型价格到对应的渠道 Option,返回 (added, updated, error)。
|
||
// 每个渠道 Option 的 value 为 map[channel_id(str)]map[model_name]float64 的嵌套结构。
|
||
func doSyncChannelPrices(channelID int, incoming PriceExportModelMaps) (totalAdded, totalUpdated int, err error) {
|
||
idStr := fmt.Sprintf("%d", channelID)
|
||
|
||
for _, field := range channelPriceFields {
|
||
src := field.getField(&incoming)
|
||
if len(src) == 0 {
|
||
continue
|
||
}
|
||
// 读取整个渠道 Option 的当前嵌套 map
|
||
fullMap := parseNestedFloatMapSafe(readOptionStr(field.optionKey))
|
||
if fullMap[idStr] == nil {
|
||
fullMap[idStr] = map[string]float64{}
|
||
}
|
||
added, updated := mergeFloatMapCounting(fullMap[idStr], src)
|
||
totalAdded += added
|
||
totalUpdated += updated
|
||
|
||
if err = model.UpdateOption(field.optionKey, marshalToJSON(fullMap)); err != nil {
|
||
return 0, 0, fmt.Errorf("写入 Option[%s] 失败: %w", field.optionKey, err)
|
||
}
|
||
}
|
||
return
|
||
}
|