tokenFactory/controller/tf_open_sync.go

197 lines
6.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package controller
import (
	"crypto/subtle"
	"net/http"
	"os"
	"strconv"
	"strings"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/model"
	"github.com/QuantumNous/new-api/setting/ratio_setting"
	"github.com/gin-gonic/gin"
)
// coalesceStr returns the first argument that is non-empty once surrounding
// whitespace has been trimmed, in its trimmed form. It returns "" when every
// argument is blank or when it is called with no arguments.
func coalesceStr(vals ...string) string {
	for i := range vals {
		if trimmed := strings.TrimSpace(vals[i]); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
// tfOpenSyncExportRow is the JSON shape of one channel in the cross-site sync
// export. It deliberately has no field for the channel key, so the key cannot
// be serialized through this type.
type tfOpenSyncExportRow struct {
	// Fields that are always present in the JSON output.
	ID                    int    `json:"id"`
	Name                  string `json:"name"`
	Models                string `json:"models"`
	Group                 string `json:"group"`
	Status                int    `json:"status"`
	Type                  int    `json:"type"`
	ChannelNo             string `json:"channel_no"`
	SupplierApplicationID int    `json:"supplier_application_id"`

	// Supplier metadata; each field is dropped from the JSON when empty.
	SupplierAlias  string `json:"supplier_alias,omitempty"`
	SupplierType   string `json:"supplier_type,omitempty"`
	CompanyLogoURL string `json:"company_logo_url,omitempty"`

	// Optional discount settings; a nil pointer is dropped from the JSON.
	PriceDiscountPercent *float64 `json:"price_discount_percent,omitempty"`
	MarkupDiscountRate   *float64 `json:"markup_discount_rate,omitempty"`

	// Model mapping and per-model price/ratio tables, keyed by model name;
	// each is dropped from the JSON when empty.
	ModelMapping string             `json:"model_mapping,omitempty"`
	ModelPrice   map[string]float64 `json:"model_price,omitempty"`
	ModelRatio   map[string]float64 `json:"model_ratio,omitempty"`
}
// authorizeTFOpenSyncExport reports whether the request may call the open-sync
// export endpoint. Credentials are tried in this order, and the first one that
// validates wins:
//
//  1. The X-TokenFactory-Open-Sync-Secret header, which must equal the
//     TOKENFACTORY_OPEN_SYNC_SECRET environment variable. Both values are
//     trimmed and both must be non-empty.
//  2. The Authorization header, with an optional case-insensitive "Bearer "
//     prefix and an optional "sk-" prefix removed, accepted when
//     model.ValidateUserToken returns a nil error.
//  3. The Authorization header (after "Bearer " stripping only), accepted when
//     model.ValidateAccessToken returns a non-nil result. No admin role is
//     required here.
//
// It returns false when none of the above succeeds, including when the
// Authorization header is missing.
func authorizeTFOpenSyncExport(c *gin.Context) bool {
	secretEnv := strings.TrimSpace(os.Getenv("TOKENFACTORY_OPEN_SYNC_SECRET"))
	hdr := strings.TrimSpace(c.GetHeader("X-TokenFactory-Open-Sync-Secret"))
	// Compare in constant time so response timing does not reveal how many
	// leading bytes of a guessed secret were correct.
	if secretEnv != "" && hdr != "" &&
		subtle.ConstantTimeCompare([]byte(hdr), []byte(secretEnv)) == 1 {
		return true
	}

	auth := strings.TrimSpace(c.GetHeader("Authorization"))
	if auth == "" {
		return false
	}
	if strings.HasPrefix(strings.ToLower(auth), "bearer ") {
		auth = strings.TrimSpace(auth[len("bearer "):])
	}

	// Try an ordinary API token ("sk-" prefix optional) first, so upstream
	// sites can hand out non-admin sync keys.
	if tokenKey := strings.TrimPrefix(auth, "sk-"); tokenKey != "" {
		if _, err := model.ValidateUserToken(tokenKey); err == nil {
			return true
		}
	}

	// Fall back to an access token; the admin role is no longer enforced.
	// NOTE(review): this only checks for a non-nil result — confirm whether
	// ValidateAccessToken itself rejects disabled users.
	return model.ValidateAccessToken(auth) != nil
}
// TFOpenSyncExportChannels serves the channel export used by downstream
// TokenFactoryOpen sites for synchronization. It returns every channel whose
// type is not constant.ChannelTypeTokenFactoryOpen, with the key column
// omitted from the query, plus supplier metadata and the channel-level model
// price/ratio tables.
//
// Authorization is decided by authorizeTFOpenSyncExport: the
// TOKENFACTORY_OPEN_SYNC_SECRET environment variable matched against the
// request header, or a Bearer value that is a usable API token (sk-) or a
// valid access token.
//
// Authorization and channel-query failures are answered with HTTP 200 and
// "success": false. A failed supplier_applications lookup is logged and the
// export continues without the supplier alias/logo/type from that table.
func TFOpenSyncExportChannels(c *gin.Context) {
	if !authorizeTFOpenSyncExport(c) {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "无权导出请使用同步密钥X-TokenFactory-Open-Sync-Secret或 Bearer 携带可用令牌sk- 或 access token",
		})
		return
	}

	var channels []*model.Channel
	q := model.DB.Model(&model.Channel{}).
		Omit("key").
		Where("type <> ?", constant.ChannelTypeTokenFactoryOpen).
		Order("supplier_application_id asc, channel_no asc, id asc")
	if err := q.Find(&channels).Error; err != nil {
		common.SysError("tf_open_sync export: " + err.Error())
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "查询渠道失败"})
		return
	}

	// Collect the distinct positive supplier application IDs, in first-seen
	// order, so the supplier rows can be fetched with a single query.
	appIDs := make([]int, 0)
	seen := make(map[int]struct{})
	for _, ch := range channels {
		if ch == nil || ch.SupplierApplicationID <= 0 {
			continue
		}
		if _, dup := seen[ch.SupplierApplicationID]; dup {
			continue
		}
		seen[ch.SupplierApplicationID] = struct{}{}
		appIDs = append(appIDs, ch.SupplierApplicationID)
	}

	type supplierApp struct {
		ID           int    `gorm:"column:id"`
		Alias        string `gorm:"column:supplier_alias"`
		LogoURL      string `gorm:"column:company_logo_url"`
		SupplierType string `gorm:"column:supplier_type"`
	}
	// A missing ID yields the zero value, i.e. empty strings for every field.
	suppliers := make(map[int]supplierApp, len(appIDs))
	if len(appIDs) > 0 {
		var apps []supplierApp
		err := model.DB.Table("supplier_applications").
			Select("id, supplier_alias, company_logo_url, supplier_type").
			Where("id IN ?", appIDs).
			Scan(&apps).Error
		if err != nil {
			// Supplier metadata is best-effort: record the failure instead of
			// dropping it silently, then export without it.
			common.SysError("tf_open_sync export: supplier_applications lookup: " + err.Error())
		} else {
			for _, a := range apps {
				suppliers[a.ID] = a
			}
		}
	}

	priceAll := ratio_setting.GetChannelModelPriceCopy()
	ratioAll := ratio_setting.GetChannelModelRatioCopy()

	out := make([]tfOpenSyncExportRow, 0, len(channels))
	for _, ch := range channels {
		if ch == nil {
			continue
		}

		// Export only the models named in this channel's model list, to keep
		// the payload small.
		wanted := make(map[string]struct{})
		for _, m := range ch.GetModels() {
			if name := ratio_setting.FormatMatchingModelName(m); name != "" {
				wanted[name] = struct{}{}
			}
		}

		// The price/ratio tables are keyed by the decimal channel ID.
		idStr := strconv.Itoa(ch.Id)
		sup := suppliers[ch.SupplierApplicationID]

		out = append(out, tfOpenSyncExportRow{
			ID:                    ch.Id,
			Name:                  ch.Name,
			Models:                ch.Models,
			Group:                 ch.Group,
			Status:                ch.Status,
			Type:                  ch.Type,
			ChannelNo:             strings.TrimSpace(ch.ChannelNo),
			SupplierApplicationID: ch.SupplierApplicationID,
			SupplierAlias:         strings.TrimSpace(sup.Alias),
			// Prefer the supplier application's value; fall back to the
			// value stored on the channel itself.
			SupplierType:         coalesceStr(sup.SupplierType, ch.SupplierType),
			CompanyLogoURL:       coalesceStr(sup.LogoURL, ch.CompanyLogoURL),
			PriceDiscountPercent: ch.PriceDiscountPercent,
			MarkupDiscountRate:   ch.MarkupDiscountRate,
			ModelMapping:         strings.TrimSpace(ch.GetModelMapping()),
			ModelPrice:           tfOpenSyncPickModels(priceAll[idStr], wanted),
			ModelRatio:           tfOpenSyncPickModels(ratioAll[idStr], wanted),
		})
	}
	common.ApiSuccess(c, gin.H{"channels": out})
}

// tfOpenSyncPickModels returns the entries of src whose key is in wanted.
// It returns nil when src is empty or nothing matches, and returns src
// unchanged when wanted is empty (no filtering is applied in that case).
func tfOpenSyncPickModels(src map[string]float64, wanted map[string]struct{}) map[string]float64 {
	if len(src) == 0 {
		return nil
	}
	if len(wanted) == 0 {
		return src
	}
	picked := make(map[string]float64)
	for name := range wanted {
		if v, ok := src[name]; ok {
			picked[name] = v
		}
	}
	if len(picked) == 0 {
		return nil
	}
	return picked
}