1702 lines
55 KiB
Go
1702 lines
55 KiB
Go
package model
|
||
|
||
import (
|
||
"database/sql/driver"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"math/rand"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
|
||
"github.com/QuantumNous/new-api/common"
|
||
"github.com/QuantumNous/new-api/constant"
|
||
"github.com/QuantumNous/new-api/dto"
|
||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||
"github.com/QuantumNous/new-api/types"
|
||
|
||
"github.com/samber/lo"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
type Channel struct {
|
||
Id int `json:"id"`
|
||
Type int `json:"type" gorm:"default:0"`
|
||
CompanyLogoURL string `json:"company_logo_url" gorm:"type:varchar(1024);not null;default:'';comment:企业Logo图片URL"`
|
||
SupplierType string `json:"supplier_type" gorm:"type:varchar(64);not null;default:'';comment:供应商类型"`
|
||
Key string `json:"key" gorm:"not null"`
|
||
OpenAIOrganization *string `json:"openai_organization"`
|
||
TestModel *string `json:"test_model"`
|
||
Status int `json:"status" gorm:"default:1"`
|
||
Name string `json:"name" gorm:"index"`
|
||
Weight *uint `json:"weight" gorm:"default:0"`
|
||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||
TestTime int64 `json:"test_time" gorm:"bigint"` // 最近一次渠道测试时间(Unix 秒级时间戳)
|
||
ResponseTime int `json:"response_time"` // 最近一次渠道测试响应耗时(毫秒)
|
||
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
|
||
Other string `json:"other"`
|
||
Balance float64 `json:"balance"` // 剩余额度(美元计价展示);同步/手动写入;计费扣减与 used_quota 同步累加
|
||
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
|
||
Models string `json:"models"`
|
||
Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
|
||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
||
ModelMapping *string `json:"model_mapping" gorm:"type:text"`
|
||
//MaxInputTokens *int `json:"max_input_tokens" gorm:"default:0"`
|
||
StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"`
|
||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||
OtherInfo string `json:"other_info"` // 渠道扩展信息(JSON),测试相关键:last_test_success/last_test_message/last_test_model/last_test_time
|
||
Tag *string `json:"tag" gorm:"index"`
|
||
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
||
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
||
HeaderOverride *string `json:"header_override" gorm:"type:text"`
|
||
Remark *string `json:"remark" gorm:"type:varchar(255)" validate:"max=255"`
|
||
// add after v0.8.5
|
||
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
||
|
||
OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置,存储azure版本等不需要检索的信息,详见dto.ChannelOtherSettings
|
||
OwnerUserID int `json:"owner_user_id" gorm:"type:int;index;default:0"` // 渠道归属用户ID(供应商场景)
|
||
SupplierApplicationID int `json:"supplier_application_id" gorm:"type:int;index;default:0"` // 关联 supplier_applications.id
|
||
ChannelNo string `json:"channel_no" gorm:"type:varchar(32);default:'';index;comment:供应商渠道编号 c1,c2 递增"`
|
||
// RouteSlug 全局唯一渠道路由后缀;调用格式 {model}/{route_slug} 强制该渠道(该渠道下所有模型共用此后缀)。
|
||
RouteSlug string `json:"route_slug" gorm:"type:varchar(32);not null;default:'';index"`
|
||
SupplierName string `json:"supplier_name,omitempty" gorm:"-"` // 供应商用户名(由控制器回填,不落库)
|
||
|
||
// 成本折扣率(百分数,100=原价无折扣,60=六折/按原价×0.6 计费)。nil=数据库默认/未设,按 100 处理。使用指针以便 GORM Updates 时可将 0% 写回。
|
||
PriceDiscountPercent *float64 `json:"price_discount_percent" gorm:"type:double precision;default:100"`
|
||
// 加价折扣率(百分数,0=不加价;如 5 表示在全局价格基础上加 5% 作为附加收益)。nil=数据库默认/未设,按 0 处理。
|
||
MarkupDiscountRate *float64 `json:"markup_discount_rate" gorm:"type:double precision;default:0"`
|
||
|
||
// cache info
|
||
Keys []string `json:"-" gorm:"-"`
|
||
}
|
||
|
||
type ChannelInfo struct {
|
||
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
|
||
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
|
||
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
||
MultiKeyDisabledReason map[int]string `json:"multi_key_disabled_reason,omitempty"` // key禁用原因列表,key index -> reason
|
||
MultiKeyDisabledTime map[int]int64 `json:"multi_key_disabled_time,omitempty"` // key禁用时间列表,key index -> time
|
||
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
|
||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||
}
|
||
|
||
// Value implements driver.Valuer interface
|
||
func (c ChannelInfo) Value() (driver.Value, error) {
|
||
return common.Marshal(&c)
|
||
}
|
||
|
||
// Scan implements sql.Scanner interface
|
||
func (c *ChannelInfo) Scan(value interface{}) error {
|
||
bytesValue, _ := value.([]byte)
|
||
return common.Unmarshal(bytesValue, c)
|
||
}
|
||
|
||
func (channel *Channel) GetKeys() []string {
|
||
if channel.Key == "" {
|
||
return []string{}
|
||
}
|
||
if len(channel.Keys) > 0 {
|
||
return channel.Keys
|
||
}
|
||
trimmed := strings.TrimSpace(channel.Key)
|
||
// If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios)
|
||
if strings.HasPrefix(trimmed, "[") {
|
||
var arr []json.RawMessage
|
||
if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||
res := make([]string, len(arr))
|
||
for i, v := range arr {
|
||
res[i] = string(v)
|
||
}
|
||
return res
|
||
}
|
||
}
|
||
// Otherwise, fall back to splitting by newline
|
||
keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n")
|
||
return keys
|
||
}
|
||
|
||
func (channel *Channel) GetNextEnabledKey() (string, int, *types.TokenFactoryError) {
|
||
// If not in multi-key mode, return the original key string directly.
|
||
if !channel.ChannelInfo.IsMultiKey {
|
||
return channel.Key, 0, nil
|
||
}
|
||
|
||
// Obtain all keys (split by \n)
|
||
keys := channel.GetKeys()
|
||
if len(keys) == 0 {
|
||
// No keys available, return error, should disable the channel
|
||
return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
|
||
}
|
||
|
||
lock := GetChannelPollingLock(channel.Id)
|
||
lock.Lock()
|
||
defer lock.Unlock()
|
||
|
||
statusList := channel.ChannelInfo.MultiKeyStatusList
|
||
// helper to get key status, default to enabled when missing
|
||
getStatus := func(idx int) int {
|
||
if statusList == nil {
|
||
return common.ChannelStatusEnabled
|
||
}
|
||
if status, ok := statusList[idx]; ok {
|
||
return status
|
||
}
|
||
return common.ChannelStatusEnabled
|
||
}
|
||
|
||
// Collect indexes of enabled keys
|
||
enabledIdx := make([]int, 0, len(keys))
|
||
for i := range keys {
|
||
if getStatus(i) == common.ChannelStatusEnabled {
|
||
enabledIdx = append(enabledIdx, i)
|
||
}
|
||
}
|
||
// If no specific status list or none enabled, return an explicit error so caller can
|
||
// properly handle a channel with no available keys (e.g. mark channel disabled).
|
||
// Returning the first key here caused requests to keep using an already-disabled key.
|
||
if len(enabledIdx) == 0 {
|
||
return "", 0, types.NewError(errors.New("no enabled keys"), types.ErrorCodeChannelNoAvailableKey)
|
||
}
|
||
|
||
switch channel.ChannelInfo.MultiKeyMode {
|
||
case constant.MultiKeyModeRandom:
|
||
// Randomly pick one enabled key
|
||
selectedIdx := enabledIdx[rand.Intn(len(enabledIdx))]
|
||
return keys[selectedIdx], selectedIdx, nil
|
||
case constant.MultiKeyModePolling:
|
||
// Use channel-specific lock to ensure thread-safe polling
|
||
|
||
channelInfo, err := CacheGetChannelInfo(channel.Id)
|
||
if err != nil {
|
||
return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||
}
|
||
//println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
|
||
defer func() {
|
||
if common.DebugEnabled {
|
||
println(fmt.Sprintf("channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex))
|
||
}
|
||
if !common.MemoryCacheEnabled {
|
||
_ = channel.SaveChannelInfo()
|
||
} else {
|
||
// CacheUpdateChannel(channel)
|
||
}
|
||
}()
|
||
// Start from the saved polling index and look for the next enabled key
|
||
start := channelInfo.MultiKeyPollingIndex
|
||
if start < 0 || start >= len(keys) {
|
||
start = 0
|
||
}
|
||
for i := 0; i < len(keys); i++ {
|
||
idx := (start + i) % len(keys)
|
||
if getStatus(idx) == common.ChannelStatusEnabled {
|
||
// update polling index for next call (point to the next position)
|
||
channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
|
||
return keys[idx], idx, nil
|
||
}
|
||
}
|
||
// Fallback – should not happen, but return first enabled key
|
||
return keys[enabledIdx[0]], enabledIdx[0], nil
|
||
default:
|
||
// Unknown mode, default to first enabled key (or original key string)
|
||
return keys[enabledIdx[0]], enabledIdx[0], nil
|
||
}
|
||
}
|
||
|
||
func (channel *Channel) SaveChannelInfo() error {
|
||
return DB.Model(channel).Update("channel_info", channel.ChannelInfo).Error
|
||
}
|
||
|
||
func (channel *Channel) GetModels() []string {
|
||
if channel.Models == "" {
|
||
return []string{}
|
||
}
|
||
return strings.Split(strings.Trim(channel.Models, ","), ",")
|
||
}
|
||
|
||
func (channel *Channel) GetGroups() []string {
|
||
if channel.Group == "" {
|
||
return []string{}
|
||
}
|
||
groups := strings.Split(strings.Trim(channel.Group, ","), ",")
|
||
for i, group := range groups {
|
||
groups[i] = strings.TrimSpace(group)
|
||
}
|
||
return groups
|
||
}
|
||
|
||
func (channel *Channel) GetOtherInfo() map[string]interface{} {
|
||
otherInfo := make(map[string]interface{})
|
||
if channel.OtherInfo != "" {
|
||
err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to unmarshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err))
|
||
}
|
||
}
|
||
return otherInfo
|
||
}
|
||
|
||
func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
|
||
otherInfoBytes, err := json.Marshal(otherInfo)
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to marshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err))
|
||
return
|
||
}
|
||
channel.OtherInfo = string(otherInfoBytes)
|
||
}
|
||
|
||
func (channel *Channel) GetTag() string {
|
||
if channel.Tag == nil {
|
||
return ""
|
||
}
|
||
return *channel.Tag
|
||
}
|
||
|
||
func (channel *Channel) SetTag(tag string) {
|
||
channel.Tag = &tag
|
||
}
|
||
|
||
func (channel *Channel) GetAutoBan() bool {
|
||
if channel.AutoBan == nil {
|
||
return false
|
||
}
|
||
return *channel.AutoBan == 1
|
||
}
|
||
|
||
func (channel *Channel) Save() error {
|
||
return DB.Save(channel).Error
|
||
}
|
||
|
||
func (channel *Channel) SaveWithoutKey() error {
|
||
if channel.Id == 0 {
|
||
return errors.New("channel ID is 0")
|
||
}
|
||
return DB.Omit("key").Save(channel).Error
|
||
}
|
||
|
||
func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
|
||
var channels []*Channel
|
||
var err error
|
||
order := "priority desc"
|
||
if idSort {
|
||
order = "id desc"
|
||
}
|
||
if selectAll {
|
||
err = DB.Order(order).Find(&channels).Error
|
||
} else {
|
||
err = DB.Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
|
||
}
|
||
return channels, err
|
||
}
|
||
|
||
// ListChannelsByOwnerUser 分页查询指定归属用户创建的渠道。
|
||
func ListChannelsByOwnerUser(ownerUserID int, startIdx int, num int) ([]*Channel, int64, error) {
|
||
var (
|
||
channels []*Channel
|
||
total int64
|
||
)
|
||
query := DB.Model(&Channel{}).Where("owner_user_id = ?", ownerUserID)
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
if err := query.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
return channels, total, nil
|
||
}
|
||
|
||
// ListAllSupplierChannels 分页查询所有供应商归属渠道(管理员视角)。
|
||
func ListAllSupplierChannels(startIdx int, num int) ([]*Channel, int64, error) {
|
||
var (
|
||
channels []*Channel
|
||
total int64
|
||
)
|
||
query := DB.Model(&Channel{}).Where("owner_user_id > ? AND supplier_application_id > ?", 0, 0)
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
if err := query.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
return channels, total, nil
|
||
}
|
||
|
||
// SupplierChannelSearchFilter 供应商渠道搜索过滤参数。
|
||
type SupplierChannelSearchFilter struct {
|
||
ChannelID int
|
||
Keyword string
|
||
Supplier string
|
||
Name string
|
||
Key string
|
||
BaseURL string
|
||
ModelKeyword string
|
||
Group string
|
||
}
|
||
|
||
// ChannelSimplePricingItem pricing 页面使用的渠道精简信息。
|
||
type ChannelSimplePricingItem struct {
|
||
ChannelID int `json:"channel_id"`
|
||
ChannelName string `json:"channel_name"`
|
||
ChannelNo string `json:"channel_no"`
|
||
SupplierAlias string `json:"supplier_alias"`
|
||
CompanyLogoURL string `json:"company_logo_url"`
|
||
SupplierType string `json:"supplier_type"`
|
||
}
|
||
|
||
// ChannelPricingMeta 定价接口计算渠道维度价格所需的渠道行(含供应商别名)。
|
||
type ChannelPricingMeta struct {
|
||
ChannelID int `gorm:"column:channel_id"`
|
||
SupplierApplicationID int `gorm:"column:supplier_application_id"`
|
||
ChannelNo string `gorm:"column:channel_no"`
|
||
Models string `gorm:"column:models"`
|
||
SupplierAlias *string `gorm:"column:supplier_alias"`
|
||
CompanyLogoURL string `gorm:"column:company_logo_url"`
|
||
SupplierType string `gorm:"column:supplier_type"`
|
||
PriceDiscountPercent *float64 `gorm:"column:price_discount_percent"`
|
||
MarkupDiscountRate *float64 `gorm:"column:markup_discount_rate"`
|
||
}
|
||
|
||
// ListChannelsForPricing 查询定价页使用的渠道列表。
|
||
func ListChannelsForPricing() ([]ChannelSimplePricingItem, error) {
|
||
items := make([]ChannelSimplePricingItem, 0)
|
||
err := DB.Model(&Channel{}).
|
||
Select("channels.id AS channel_id, channels.name AS channel_name, channels.channel_no, COALESCE(supplier_applications.supplier_alias, '') AS supplier_alias, COALESCE(supplier_applications.company_logo_url, '') AS company_logo_url, COALESCE(NULLIF(supplier_applications.supplier_type, ''), channels.supplier_type, '') AS supplier_type").
|
||
Joins("LEFT JOIN supplier_applications ON supplier_applications.id = channels.supplier_application_id").
|
||
Where("channels.status = ?", common.ChannelStatusEnabled).
|
||
Order("channels.id ASC").
|
||
Scan(&items).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return items, nil
|
||
}
|
||
|
||
// ListChannelPricingMeta 查询全部渠道的定价元数据(用于按模型汇总渠道价)。
|
||
func ListChannelPricingMeta() ([]ChannelPricingMeta, error) {
|
||
items := make([]ChannelPricingMeta, 0)
|
||
err := DB.Model(&Channel{}).
|
||
Select("channels.id AS channel_id, channels.supplier_application_id, channels.channel_no, channels.models, channels.price_discount_percent, channels.markup_discount_rate, supplier_applications.supplier_alias, COALESCE(NULLIF(supplier_applications.company_logo_url, ''), channels.company_logo_url, '') AS company_logo_url, COALESCE(NULLIF(supplier_applications.supplier_type, ''), channels.supplier_type, '') AS supplier_type").
|
||
Joins("LEFT JOIN supplier_applications ON supplier_applications.id = channels.supplier_application_id").
|
||
Where("channels.status = ?", common.ChannelStatusEnabled).
|
||
Order("channels.id ASC").
|
||
Scan(&items).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return items, nil
|
||
}
|
||
|
||
// ChannelModelsRawContains 判断 channels.models 逗号列表是否包含指定模型名(去空格精确匹配)。
|
||
func ChannelModelsRawContains(modelsRaw string, modelName string) bool {
|
||
if strings.TrimSpace(modelsRaw) == "" || strings.TrimSpace(modelName) == "" {
|
||
return false
|
||
}
|
||
for _, m := range strings.Split(modelsRaw, ",") {
|
||
if strings.TrimSpace(m) == modelName {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// SearchSupplierChannels 搜索供应商渠道(供应商只查自己,管理员可查全部供应商渠道)。
|
||
func SearchSupplierChannels(ownerUserID *int, startIdx int, num int, filter SupplierChannelSearchFilter) ([]*Channel, int64, error) {
|
||
var (
|
||
channels []*Channel
|
||
total int64
|
||
)
|
||
query := DB.Model(&Channel{})
|
||
if ownerUserID != nil {
|
||
query = query.Where("owner_user_id = ?", *ownerUserID)
|
||
} else {
|
||
query = query.Where("owner_user_id > ? AND supplier_application_id > ?", 0, 0)
|
||
}
|
||
if filter.ChannelID > 0 {
|
||
query = query.Where("id = ?", filter.ChannelID)
|
||
}
|
||
if filter.Keyword != "" {
|
||
keywordLike := "%" + filter.Keyword + "%"
|
||
query = query.Where("(name LIKE ? OR "+commonKeyCol+" LIKE ? OR base_url LIKE ?)", keywordLike, keywordLike, keywordLike)
|
||
}
|
||
if filter.Supplier != "" {
|
||
query = query.Joins("LEFT JOIN users ON users.id = channels.owner_user_id").Where("users.username LIKE ?", "%"+filter.Supplier+"%")
|
||
}
|
||
if filter.Name != "" {
|
||
query = query.Where("name LIKE ?", "%"+filter.Name+"%")
|
||
}
|
||
if filter.Key != "" {
|
||
// commonKeyCol 兼容不同数据库对保留字 key 的转义差异
|
||
query = query.Where(commonKeyCol+" = ? OR "+commonKeyCol+" LIKE ?", filter.Key, "%"+filter.Key+"%")
|
||
}
|
||
if filter.BaseURL != "" {
|
||
query = query.Where("base_url LIKE ?", "%"+filter.BaseURL+"%")
|
||
}
|
||
if filter.ModelKeyword != "" {
|
||
query = query.Where("models LIKE ?", "%"+filter.ModelKeyword+"%")
|
||
}
|
||
if filter.Group != "" && filter.Group != "null" {
|
||
var groupCondition string
|
||
if common.UsingMySQL {
|
||
groupCondition = "CONCAT(',', " + commonGroupCol + ", ',') LIKE ?"
|
||
} else {
|
||
groupCondition = "(',' || " + commonGroupCol + " || ',') LIKE ?"
|
||
}
|
||
query = query.Where(groupCondition, "%,"+filter.Group+",%")
|
||
}
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
if err := query.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
return channels, total, nil
|
||
}
|
||
|
||
// ParseSupplierChannelIDFilter 解析渠道ID筛选参数(支持空值)。
|
||
func ParseSupplierChannelIDFilter(raw string) (int, error) {
|
||
raw = strings.TrimSpace(raw)
|
||
if raw == "" {
|
||
return 0, nil
|
||
}
|
||
id, err := strconv.Atoi(raw)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return id, nil
|
||
}
|
||
|
||
func GetChannelsByTag(tag string, idSort bool, selectAll bool) ([]*Channel, error) {
|
||
var channels []*Channel
|
||
order := "priority desc"
|
||
if idSort {
|
||
order = "id desc"
|
||
}
|
||
query := DB.Where("tag = ?", tag).Order(order)
|
||
if !selectAll {
|
||
query = query.Omit("key")
|
||
}
|
||
err := query.Find(&channels).Error
|
||
return channels, err
|
||
}
|
||
|
||
func SearchChannels(keyword string, group string, model string, idSort bool) ([]*Channel, error) {
|
||
var channels []*Channel
|
||
modelsCol := "`models`"
|
||
|
||
// 如果是 PostgreSQL,使用双引号
|
||
if common.UsingPostgreSQL {
|
||
modelsCol = `"models"`
|
||
}
|
||
|
||
baseURLCol := "`base_url`"
|
||
// 如果是 PostgreSQL,使用双引号
|
||
if common.UsingPostgreSQL {
|
||
baseURLCol = `"base_url"`
|
||
}
|
||
|
||
order := "priority desc"
|
||
if idSort {
|
||
order = "id desc"
|
||
}
|
||
|
||
// 构造基础查询
|
||
baseQuery := DB.Model(&Channel{}).Omit("key")
|
||
|
||
// 构造WHERE子句
|
||
var whereClause string
|
||
var args []interface{}
|
||
if group != "" && group != "null" {
|
||
var groupCondition string
|
||
if common.UsingMySQL {
|
||
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
|
||
} else {
|
||
// sqlite, PostgreSQL
|
||
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
|
||
}
|
||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
|
||
} else {
|
||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
|
||
}
|
||
|
||
// 执行查询
|
||
err := baseQuery.Where(whereClause, args...).Order(order).Find(&channels).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return channels, nil
|
||
}
|
||
|
||
func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
||
channel := &Channel{Id: id}
|
||
var err error = nil
|
||
if selectAll {
|
||
err = DB.First(channel, "id = ?", id).Error
|
||
} else {
|
||
err = DB.Omit("key").First(channel, "id = ?", id).Error
|
||
}
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return channel, nil
|
||
}
|
||
|
||
// ResolveSupplierApplicationIDByAlias 根据供应商别名返回 supplier_application_id。
|
||
//
|
||
// - "P0"(不区分大小写)返回 0,代表未归属任何 supplier_applications 的渠道;
|
||
// - 其他别名到 supplier_applications.supplier_alias 精确匹配后返回其 id;
|
||
//
|
||
// 未找到别名时返回 (0, false, err);找到 P0 或匹配记录时 found=true。
|
||
func ResolveSupplierApplicationIDByAlias(alias string) (supplierApplicationID int, found bool, err error) {
|
||
aliasTrim := strings.TrimSpace(alias)
|
||
if aliasTrim == "" {
|
||
return 0, false, fmt.Errorf("alias 不能为空")
|
||
}
|
||
if strings.EqualFold(aliasTrim, "P0") {
|
||
return 0, true, nil
|
||
}
|
||
var app SupplierApplication
|
||
if err := DB.Select("id").Where("supplier_alias = ?", aliasTrim).First(&app).Error; err != nil {
|
||
return 0, false, fmt.Errorf("供应商别名未找到: %s", aliasTrim)
|
||
}
|
||
return app.ID, true, nil
|
||
}
|
||
|
||
// FindChannelIDBySupplierAliasAndNo 根据「供应商别名」+「渠道编号」定位具体渠道 ID。
|
||
//
|
||
// 支持两种别名形式:
|
||
// 1. "P0":特指未归属供应商申请(supplier_application_id = 0)的渠道;
|
||
// 2. 其他:先到 supplier_applications.supplier_alias 精确匹配,取得 id 后再按
|
||
// (supplier_application_id, channel_no) 查找渠道。
|
||
//
|
||
// 该方法仅返回启用状态的渠道。未找到时返回 0 与具体错误信息。
|
||
func FindChannelIDBySupplierAliasAndNo(alias string, channelNo string) (int, error) {
|
||
noTrim := strings.TrimSpace(channelNo)
|
||
if noTrim == "" {
|
||
return 0, fmt.Errorf("channel_no 不能为空")
|
||
}
|
||
supplierApplicationID, _, err := ResolveSupplierApplicationIDByAlias(alias)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
var channel Channel
|
||
if err := DB.Select("id, status").
|
||
Where("supplier_application_id = ? AND channel_no = ?", supplierApplicationID, noTrim).
|
||
First(&channel).Error; err != nil {
|
||
return 0, fmt.Errorf("未找到渠道: %s/%s", strings.TrimSpace(alias), noTrim)
|
||
}
|
||
if channel.Status != common.ChannelStatusEnabled {
|
||
return 0, fmt.Errorf("渠道已禁用: %s/%s", strings.TrimSpace(alias), noTrim)
|
||
}
|
||
return channel.Id, nil
|
||
}
|
||
|
||
// ValidateSupplierChannelNoUnique 校验同一 supplier_application_id 下 channel_no 不重复(空编号不校验)。
|
||
// supplier_application_id 为 0(P0 未归属)时同样校验,以支持 alias/cN 唯一路由。
|
||
// excludeChannelID 大于 0 时排除自身,用于更新;新建时传 0。
|
||
func ValidateSupplierChannelNoUnique(excludeChannelID int, supplierApplicationID int, channelNo string) error {
|
||
no := strings.TrimSpace(channelNo)
|
||
if no == "" {
|
||
return nil
|
||
}
|
||
q := DB.Model(&Channel{}).Where("supplier_application_id = ? AND channel_no = ?", supplierApplicationID, no)
|
||
if excludeChannelID > 0 {
|
||
q = q.Where("id <> ?", excludeChannelID)
|
||
}
|
||
var cnt int64
|
||
if err := q.Count(&cnt).Error; err != nil {
|
||
return err
|
||
}
|
||
if cnt > 0 {
|
||
return fmt.Errorf("该供应商下已存在相同渠道编号")
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func maxChannelNoNumericSuffixForSupplier(tx *gorm.DB, supplierApplicationID int) (int, error) {
|
||
var existing []string
|
||
if err := tx.Model(&Channel{}).Where("supplier_application_id = ?", supplierApplicationID).Pluck("channel_no", &existing).Error; err != nil {
|
||
return 0, err
|
||
}
|
||
maxN := 0
|
||
for _, no := range existing {
|
||
no = strings.TrimSpace(no)
|
||
if len(no) >= 2 && no[0] == 'c' {
|
||
if n, err := strconv.Atoi(no[1:]); err == nil && n > maxN {
|
||
maxN = n
|
||
}
|
||
}
|
||
}
|
||
return maxN, nil
|
||
}
|
||
|
||
// allocateSupplierChannelNosInBatch 保留为空实现:
|
||
// 兼容历史调用链,但不再为新渠道自动生成 channel_no。
|
||
func allocateSupplierChannelNosInBatch(tx *gorm.DB, batch []Channel) error {
|
||
_ = tx
|
||
_ = batch
|
||
return nil
|
||
}
|
||
|
||
// BackfillSupplierChannelNo 保留为空实现:
|
||
// 兼容启动流程,不再为历史数据补全 channel_no。
|
||
func BackfillSupplierChannelNo() error {
|
||
return nil
|
||
}
|
||
|
||
// TFOpenUpstreamPricing 与 BatchInsertChannelsWithTfOpenUpstreamPricing 中 channels 顺序一一对应。
|
||
type TFOpenUpstreamPricing struct {
|
||
ModelPrice map[string]float64
|
||
ModelRatio map[string]float64
|
||
}
|
||
|
||
func BatchInsertChannels(channels []Channel) error {
|
||
return batchInsertChannelsWithOptionalTfOpenPricing(channels, nil)
|
||
}
|
||
|
||
// BatchInsertChannelsWithTfOpenUpstreamPricing 批量插入渠道并在落库后合并上游渠道级定价/倍率(用于 TokenFactoryOpen 同步)。
|
||
func BatchInsertChannelsWithTfOpenUpstreamPricing(channels []Channel, pricing []TFOpenUpstreamPricing) error {
|
||
return batchInsertChannelsWithOptionalTfOpenPricing(channels, pricing)
|
||
}
|
||
|
||
func batchInsertChannelsWithOptionalTfOpenPricing(channels []Channel, tfOpenPricing []TFOpenUpstreamPricing) error {
|
||
if len(channels) == 0 {
|
||
return nil
|
||
}
|
||
tx := DB.Begin()
|
||
if tx.Error != nil {
|
||
return tx.Error
|
||
}
|
||
defer func() {
|
||
if r := recover(); r != nil {
|
||
tx.Rollback()
|
||
}
|
||
}()
|
||
|
||
createdChannels := make([]Channel, 0, len(channels))
|
||
for _, chunk := range lo.Chunk(channels, 50) {
|
||
if err := allocateSupplierChannelNosInBatch(tx, chunk); err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
if err := tx.Create(&chunk).Error; err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
for i := range chunk {
|
||
if chunk[i].Id <= 0 {
|
||
continue
|
||
}
|
||
assigned, err := assignRouteSlugInTx(tx, chunk[i].Id, chunk[i].RouteSlug)
|
||
if err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
chunk[i].RouteSlug = assigned
|
||
}
|
||
createdChannels = append(createdChannels, chunk...)
|
||
for _, channel_ := range chunk {
|
||
if err := channel_.AddAbilities(tx); err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
}
|
||
}
|
||
if err := tx.Commit().Error; err != nil {
|
||
return err
|
||
}
|
||
// Best effort: initialize channel-level model pricing entries so newly imported
|
||
// channels are visible in channel pricing editor without blank mappings.
|
||
ensureChannelModelPricingDefaults(createdChannels)
|
||
if len(tfOpenPricing) > 0 {
|
||
mergeTFOpenUpstreamPricingAfterInsert(createdChannels, tfOpenPricing)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func mergeTFOpenUpstreamPricingAfterInsert(created []Channel, pricing []TFOpenUpstreamPricing) {
|
||
if len(created) == 0 || len(pricing) == 0 {
|
||
return
|
||
}
|
||
n := len(created)
|
||
if len(pricing) < n {
|
||
n = len(pricing)
|
||
}
|
||
priceCopy := ratio_setting.GetChannelModelPriceCopy()
|
||
ratioCopy := ratio_setting.GetChannelModelRatioCopy()
|
||
changed := false
|
||
for i := 0; i < n; i++ {
|
||
ch := created[i]
|
||
if ch.Id <= 0 {
|
||
continue
|
||
}
|
||
p := pricing[i]
|
||
if len(p.ModelPrice) == 0 && len(p.ModelRatio) == 0 {
|
||
continue
|
||
}
|
||
cid := strconv.Itoa(ch.Id)
|
||
if _, ok := priceCopy[cid]; !ok {
|
||
priceCopy[cid] = make(map[string]float64)
|
||
}
|
||
if _, ok := ratioCopy[cid]; !ok {
|
||
ratioCopy[cid] = make(map[string]float64)
|
||
}
|
||
for k, v := range p.ModelPrice {
|
||
mk := ratio_setting.FormatMatchingModelName(k)
|
||
if mk == "" {
|
||
continue
|
||
}
|
||
priceCopy[cid][mk] = v
|
||
changed = true
|
||
}
|
||
for k, v := range p.ModelRatio {
|
||
mk := ratio_setting.FormatMatchingModelName(k)
|
||
if mk == "" {
|
||
continue
|
||
}
|
||
ratioCopy[cid][mk] = v
|
||
changed = true
|
||
}
|
||
}
|
||
if !changed {
|
||
return
|
||
}
|
||
priceJSONBytes, err := common.Marshal(priceCopy)
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("mergeTFOpen upstream price marshal: %v", err))
|
||
return
|
||
}
|
||
ratioJSONBytes, err := common.Marshal(ratioCopy)
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("mergeTFOpen upstream ratio marshal: %v", err))
|
||
return
|
||
}
|
||
if err := UpdateOption("ChannelModelPrice", string(priceJSONBytes)); err != nil {
|
||
common.SysLog(fmt.Sprintf("mergeTFOpen update ChannelModelPrice: %v", err))
|
||
}
|
||
if err := UpdateOption("ChannelModelRatio", string(ratioJSONBytes)); err != nil {
|
||
common.SysLog(fmt.Sprintf("mergeTFOpen update ChannelModelRatio: %v", err))
|
||
}
|
||
}
|
||
|
||
func ensureChannelModelPricingDefaults(channels []Channel) {
|
||
if len(channels) == 0 {
|
||
return
|
||
}
|
||
channelModelPrice := ratio_setting.GetChannelModelPriceCopy()
|
||
channelModelRatio := ratio_setting.GetChannelModelRatioCopy()
|
||
changed := false
|
||
|
||
for _, ch := range channels {
|
||
if ch.Id <= 0 {
|
||
continue
|
||
}
|
||
channelID := strconv.Itoa(ch.Id)
|
||
if _, ok := channelModelPrice[channelID]; !ok {
|
||
channelModelPrice[channelID] = make(map[string]float64)
|
||
}
|
||
if _, ok := channelModelRatio[channelID]; !ok {
|
||
channelModelRatio[channelID] = make(map[string]float64)
|
||
}
|
||
seen := make(map[string]struct{})
|
||
for _, rawModel := range ch.GetModels() {
|
||
modelName := strings.TrimSpace(rawModel)
|
||
if modelName == "" {
|
||
continue
|
||
}
|
||
modelKey := ratio_setting.FormatMatchingModelName(modelName)
|
||
if _, ok := seen[modelKey]; ok {
|
||
continue
|
||
}
|
||
seen[modelKey] = struct{}{}
|
||
|
||
needPrice := false
|
||
if _, exists := channelModelPrice[channelID][modelKey]; !exists {
|
||
needPrice = true
|
||
}
|
||
needRatio := false
|
||
if _, exists := channelModelRatio[channelID][modelKey]; !exists {
|
||
needRatio = true
|
||
}
|
||
if !needPrice && !needRatio {
|
||
continue
|
||
}
|
||
// 无渠道专属值时用全局同模型名(已 FormatMatching)兜底;与定价解析顺序一致。
|
||
if needPrice {
|
||
if modelPrice, ok := ratio_setting.GetModelPrice(modelKey, false); ok {
|
||
channelModelPrice[channelID][modelKey] = modelPrice
|
||
changed = true
|
||
}
|
||
}
|
||
if needRatio {
|
||
if modelRatio, ok, _ := ratio_setting.GetModelRatio(modelKey); ok {
|
||
channelModelRatio[channelID][modelKey] = modelRatio
|
||
changed = true
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if !changed {
|
||
return
|
||
}
|
||
priceJSONBytes, err := common.Marshal(channelModelPrice)
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to marshal ChannelModelPrice: %v", err))
|
||
return
|
||
}
|
||
ratioJSONBytes, err := common.Marshal(channelModelRatio)
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to marshal ChannelModelRatio: %v", err))
|
||
return
|
||
}
|
||
if err := UpdateOption("ChannelModelPrice", string(priceJSONBytes)); err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to update ChannelModelPrice option: %v", err))
|
||
}
|
||
if err := UpdateOption("ChannelModelRatio", string(ratioJSONBytes)); err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to update ChannelModelRatio option: %v", err))
|
||
}
|
||
}
|
||
|
||
func BatchDeleteChannels(ids []int) error {
|
||
if len(ids) == 0 {
|
||
return nil
|
||
}
|
||
// 使用事务 分批删除channel表和abilities表
|
||
tx := DB.Begin()
|
||
if tx.Error != nil {
|
||
return tx.Error
|
||
}
|
||
for _, chunk := range lo.Chunk(ids, 200) {
|
||
if err := tx.Where("id in (?)", chunk).Delete(&Channel{}).Error; err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
if err := tx.Where("channel_id in (?)", chunk).Delete(&Ability{}).Error; err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
}
|
||
return tx.Commit().Error
|
||
}
|
||
|
||
func (channel *Channel) GetPriority() int64 {
|
||
if channel.Priority == nil {
|
||
return 0
|
||
}
|
||
return *channel.Priority
|
||
}
|
||
|
||
func (channel *Channel) GetWeight() int {
|
||
if channel.Weight == nil {
|
||
return 0
|
||
}
|
||
return int(*channel.Weight)
|
||
}
|
||
|
||
func (channel *Channel) GetBaseURL() string {
|
||
if channel.BaseURL == nil {
|
||
return ""
|
||
}
|
||
url := *channel.BaseURL
|
||
if url == "" {
|
||
url = constant.ChannelBaseURLs[channel.Type]
|
||
}
|
||
return url
|
||
}
|
||
|
||
func (channel *Channel) GetModelMapping() string {
|
||
if channel.ModelMapping == nil {
|
||
return ""
|
||
}
|
||
return *channel.ModelMapping
|
||
}
|
||
|
||
func (channel *Channel) GetStatusCodeMapping() string {
|
||
if channel.StatusCodeMapping == nil {
|
||
return ""
|
||
}
|
||
return *channel.StatusCodeMapping
|
||
}
|
||
|
||
func (channel *Channel) Insert() error {
|
||
var assigned string
|
||
err := DB.Transaction(func(tx *gorm.DB) error {
|
||
batch := []Channel{*channel}
|
||
if err := allocateSupplierChannelNosInBatch(tx, batch); err != nil {
|
||
return err
|
||
}
|
||
*channel = batch[0]
|
||
if err := tx.Create(channel).Error; err != nil {
|
||
return err
|
||
}
|
||
var err2 error
|
||
assigned, err2 = assignRouteSlugInTx(tx, channel.Id, channel.RouteSlug)
|
||
if err2 != nil {
|
||
return err2
|
||
}
|
||
return channel.AddAbilities(tx)
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
channel.RouteSlug = assigned
|
||
return nil
|
||
}
|
||
|
||
func (channel *Channel) Update() error {
|
||
// If this is a multi-key channel, recalculate MultiKeySize based on the current key list to avoid inconsistency after editing keys
|
||
if channel.ChannelInfo.IsMultiKey {
|
||
var keyStr string
|
||
if channel.Key != "" {
|
||
keyStr = channel.Key
|
||
} else {
|
||
// If key is not provided, read the existing key from the database
|
||
if existing, err := GetChannelById(channel.Id, true); err == nil {
|
||
keyStr = existing.Key
|
||
}
|
||
}
|
||
// Parse the key list (supports newline separation or JSON array)
|
||
keys := []string{}
|
||
if keyStr != "" {
|
||
trimmed := strings.TrimSpace(keyStr)
|
||
if strings.HasPrefix(trimmed, "[") {
|
||
var arr []json.RawMessage
|
||
if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||
keys = make([]string, len(arr))
|
||
for i, v := range arr {
|
||
keys[i] = string(v)
|
||
}
|
||
}
|
||
}
|
||
if len(keys) == 0 { // fallback to newline split
|
||
keys = strings.Split(strings.Trim(keyStr, "\n"), "\n")
|
||
}
|
||
}
|
||
channel.ChannelInfo.MultiKeySize = len(keys)
|
||
// Clean up status data that exceeds the new key count to prevent index out of range
|
||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||
for idx := range channel.ChannelInfo.MultiKeyStatusList {
|
||
if idx >= channel.ChannelInfo.MultiKeySize {
|
||
delete(channel.ChannelInfo.MultiKeyStatusList, idx)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
var err error
|
||
err = DB.Model(channel).Updates(channel).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
DB.Model(channel).First(channel, "id = ?", channel.Id)
|
||
err = channel.UpdateAbilities(nil)
|
||
|
||
return err
|
||
}
|
||
|
||
func (channel *Channel) UpdateResponseTime(responseTime int64) {
|
||
err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
|
||
TestTime: common.GetTimestamp(),
|
||
ResponseTime: int(responseTime),
|
||
}).Error
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to update response time: channel_id=%d, error=%v", channel.Id, err))
|
||
}
|
||
}
|
||
|
||
// UpdateTestResult 持久化渠道测试结果(成功/失败)、响应时间与测试时间。
|
||
// 同时会把最近一次测试的状态信息写入 other_info,供前端与运维排查使用。
|
||
func (channel *Channel) UpdateTestResult(success bool, responseTime int64, message string, modelName string) {
|
||
err := DB.Transaction(func(tx *gorm.DB) error {
|
||
var dbChannel Channel
|
||
if err := tx.Select("id", "other_info").First(&dbChannel, "id = ?", channel.Id).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
otherInfo := dbChannel.GetOtherInfo()
|
||
otherInfo["last_test_success"] = success
|
||
otherInfo["last_test_message"] = message
|
||
otherInfo["last_test_model"] = modelName
|
||
otherInfo["last_test_time"] = common.GetTimestamp()
|
||
dbChannel.SetOtherInfo(otherInfo)
|
||
|
||
return tx.Model(&Channel{}).Where("id = ?", channel.Id).Select("response_time", "test_time", "other_info").Updates(Channel{
|
||
TestTime: common.GetTimestamp(),
|
||
ResponseTime: int(responseTime),
|
||
OtherInfo: dbChannel.OtherInfo,
|
||
}).Error
|
||
})
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to update test result: channel_id=%d, error=%v", channel.Id, err))
|
||
}
|
||
}
|
||
|
||
func (channel *Channel) UpdateBalance(balance float64) {
|
||
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
|
||
BalanceUpdatedTime: common.GetTimestamp(),
|
||
Balance: balance,
|
||
}).Error
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to update balance: channel_id=%d, error=%v", channel.Id, err))
|
||
}
|
||
}
|
||
|
||
func (channel *Channel) Delete() error {
|
||
var err error
|
||
err = DB.Delete(channel).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
err = channel.DeleteAbilities()
|
||
return err
|
||
}
|
||
|
||
var channelStatusLock sync.Mutex
|
||
|
||
// channelPollingLocks stores locks for each channel.id to ensure thread-safe polling
|
||
var channelPollingLocks sync.Map
|
||
|
||
// GetChannelPollingLock returns or creates a mutex for the given channel ID
|
||
func GetChannelPollingLock(channelId int) *sync.Mutex {
|
||
if lock, exists := channelPollingLocks.Load(channelId); exists {
|
||
return lock.(*sync.Mutex)
|
||
}
|
||
// Create new lock for this channel
|
||
newLock := &sync.Mutex{}
|
||
actual, _ := channelPollingLocks.LoadOrStore(channelId, newLock)
|
||
return actual.(*sync.Mutex)
|
||
}
|
||
|
||
// CleanupChannelPollingLocks removes locks for channels that no longer exist
|
||
// This is optional and can be called periodically to prevent memory leaks
|
||
func CleanupChannelPollingLocks() {
|
||
var activeChannelIds []int
|
||
DB.Model(&Channel{}).Pluck("id", &activeChannelIds)
|
||
|
||
activeChannelSet := make(map[int]bool)
|
||
for _, id := range activeChannelIds {
|
||
activeChannelSet[id] = true
|
||
}
|
||
|
||
channelPollingLocks.Range(func(key, value interface{}) bool {
|
||
channelId := key.(int)
|
||
if !activeChannelSet[channelId] {
|
||
channelPollingLocks.Delete(channelId)
|
||
}
|
||
return true
|
||
})
|
||
}
|
||
|
||
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason string) {
|
||
keys := channel.GetKeys()
|
||
if len(keys) == 0 {
|
||
channel.Status = status
|
||
} else {
|
||
var keyIndex int
|
||
for i, key := range keys {
|
||
if key == usingKey {
|
||
keyIndex = i
|
||
break
|
||
}
|
||
}
|
||
if channel.ChannelInfo.MultiKeyStatusList == nil {
|
||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||
}
|
||
if status == common.ChannelStatusEnabled {
|
||
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
|
||
} else {
|
||
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
|
||
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
|
||
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||
}
|
||
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
|
||
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||
}
|
||
channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason
|
||
channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp()
|
||
}
|
||
if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
|
||
channel.Status = common.ChannelStatusAutoDisabled
|
||
info := channel.GetOtherInfo()
|
||
info["status_reason"] = "All keys are disabled"
|
||
info["status_time"] = common.GetTimestamp()
|
||
channel.SetOtherInfo(info)
|
||
}
|
||
}
|
||
}
|
||
|
||
func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool {
|
||
if common.MemoryCacheEnabled {
|
||
channelStatusLock.Lock()
|
||
defer channelStatusLock.Unlock()
|
||
|
||
channelCache, _ := CacheGetChannel(channelId)
|
||
if channelCache == nil {
|
||
return false
|
||
}
|
||
if channelCache.ChannelInfo.IsMultiKey {
|
||
// Use per-channel lock to prevent concurrent map read/write with GetNextEnabledKey
|
||
pollingLock := GetChannelPollingLock(channelId)
|
||
pollingLock.Lock()
|
||
// 如果是多Key模式,更新缓存中的状态
|
||
handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
|
||
pollingLock.Unlock()
|
||
//CacheUpdateChannel(channelCache)
|
||
//return true
|
||
} else {
|
||
// 如果缓存渠道存在,且状态已是目标状态,直接返回
|
||
if channelCache.Status == status {
|
||
return false
|
||
}
|
||
CacheUpdateChannelStatus(channelId, status)
|
||
}
|
||
}
|
||
|
||
shouldUpdateAbilities := false
|
||
defer func() {
|
||
if shouldUpdateAbilities {
|
||
err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to update ability status: channel_id=%d, error=%v", channelId, err))
|
||
}
|
||
}
|
||
}()
|
||
channel, err := GetChannelById(channelId, true)
|
||
if err != nil {
|
||
return false
|
||
} else {
|
||
if channel.Status == status {
|
||
return false
|
||
}
|
||
|
||
if channel.ChannelInfo.IsMultiKey {
|
||
beforeStatus := channel.Status
|
||
// Protect map writes with the same per-channel lock used by readers
|
||
pollingLock := GetChannelPollingLock(channelId)
|
||
pollingLock.Lock()
|
||
handlerMultiKeyUpdate(channel, usingKey, status, reason)
|
||
pollingLock.Unlock()
|
||
if beforeStatus != channel.Status {
|
||
shouldUpdateAbilities = true
|
||
}
|
||
} else {
|
||
info := channel.GetOtherInfo()
|
||
info["status_reason"] = reason
|
||
info["status_time"] = common.GetTimestamp()
|
||
channel.SetOtherInfo(info)
|
||
channel.Status = status
|
||
shouldUpdateAbilities = true
|
||
}
|
||
err = channel.SaveWithoutKey()
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to update channel status: channel_id=%d, status=%d, error=%v", channel.Id, status, err))
|
||
return false
|
||
}
|
||
}
|
||
return true
|
||
}
|
||
|
||
func EnableChannelByTag(tag string) error {
|
||
err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusEnabled).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
err = UpdateAbilityStatusByTag(tag, true)
|
||
return err
|
||
}
|
||
|
||
func DisableChannelByTag(tag string) error {
|
||
err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusManuallyDisabled).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
err = UpdateAbilityStatusByTag(tag, false)
|
||
return err
|
||
}
|
||
|
||
func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *string, group *string, priority *int64, weight *uint, paramOverride *string, headerOverride *string) error {
|
||
updateData := Channel{}
|
||
shouldReCreateAbilities := false
|
||
updatedTag := tag
|
||
// 如果 newTag 不为空且不等于 tag,则更新 tag
|
||
if newTag != nil && *newTag != tag {
|
||
updateData.Tag = newTag
|
||
updatedTag = *newTag
|
||
}
|
||
if modelMapping != nil && *modelMapping != "" {
|
||
updateData.ModelMapping = modelMapping
|
||
}
|
||
if models != nil && *models != "" {
|
||
shouldReCreateAbilities = true
|
||
updateData.Models = *models
|
||
}
|
||
if group != nil && *group != "" {
|
||
shouldReCreateAbilities = true
|
||
updateData.Group = *group
|
||
}
|
||
if priority != nil {
|
||
updateData.Priority = priority
|
||
}
|
||
if weight != nil {
|
||
updateData.Weight = weight
|
||
}
|
||
if paramOverride != nil {
|
||
updateData.ParamOverride = paramOverride
|
||
}
|
||
if headerOverride != nil {
|
||
updateData.HeaderOverride = headerOverride
|
||
}
|
||
|
||
err := DB.Model(&Channel{}).Where("tag = ?", tag).Updates(updateData).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if shouldReCreateAbilities {
|
||
channels, err := GetChannelsByTag(updatedTag, false, false)
|
||
if err == nil {
|
||
for _, channel := range channels {
|
||
err = channel.UpdateAbilities(nil)
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to update abilities: channel_id=%d, tag=%s, error=%v", channel.Id, channel.GetTag(), err))
|
||
}
|
||
}
|
||
}
|
||
} else {
|
||
err := UpdateAbilityByTag(tag, newTag, priority, weight)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func UpdateChannelUsedQuota(id int, quota int) {
|
||
if common.BatchUpdateEnabled {
|
||
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
|
||
return
|
||
}
|
||
updateChannelUsedQuota(id, quota)
|
||
}
|
||
|
||
func updateChannelUsedQuota(id int, quota int) {
|
||
if quota == 0 {
|
||
return
|
||
}
|
||
var before Channel
|
||
if err := DB.Select("id", "balance", "used_quota", "name", "other_info").Where("id = ?", id).First(&before).Error; err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to load channel before used quota update: channel_id=%d, err=%v", id, err))
|
||
return
|
||
}
|
||
oldRemaining := before.Balance
|
||
|
||
deltaUSD := 0.0
|
||
if common.QuotaPerUnit > 0 {
|
||
deltaUSD = float64(quota) / common.QuotaPerUnit
|
||
}
|
||
updates := map[string]interface{}{
|
||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||
}
|
||
if deltaUSD != 0 {
|
||
updates["balance"] = gorm.Expr("CASE WHEN (balance - ?) < 0 THEN 0 ELSE (balance - ?) END", deltaUSD, deltaUSD)
|
||
}
|
||
if err := DB.Model(&Channel{}).Where("id = ?", id).Updates(updates).Error; err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to update channel used quota: channel_id=%d, delta_quota=%d, error=%v", id, quota, err))
|
||
return
|
||
}
|
||
|
||
var after Channel
|
||
if err := DB.Select("id", "balance", "used_quota", "name", "other_info").Where("id = ?", id).First(&after).Error; err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to load channel after used quota update: channel_id=%d, err=%v", id, err))
|
||
return
|
||
}
|
||
notifyChannelBalanceAlertOnUsageDelta(&after, oldRemaining, quota)
|
||
}
|
||
|
||
const (
|
||
channelBalanceAlertLevelNone = "none"
|
||
channelBalanceAlertLevelSoft = "soft"
|
||
channelBalanceAlertLevelRisk = "risk"
|
||
)
|
||
|
||
func getChannelBalanceAlertConfigForUsedQuota() (bool, float64, float64) {
|
||
enabled := false
|
||
softThreshold := 50.0
|
||
riskThreshold := 20.0
|
||
|
||
common.OptionMapRWMutex.RLock()
|
||
enabled = common.OptionMap["ChannelBalanceAlertEnabled"] == "true"
|
||
if raw, ok := common.OptionMap["ChannelBalanceSoftAlertThreshold"]; ok {
|
||
if val, err := strconv.ParseFloat(strings.TrimSpace(raw), 64); err == nil && val >= 0 {
|
||
softThreshold = val
|
||
}
|
||
}
|
||
if raw, ok := common.OptionMap["ChannelBalanceRiskAlertThreshold"]; ok {
|
||
if val, err := strconv.ParseFloat(strings.TrimSpace(raw), 64); err == nil && val >= 0 {
|
||
riskThreshold = val
|
||
}
|
||
}
|
||
common.OptionMapRWMutex.RUnlock()
|
||
|
||
if riskThreshold > softThreshold {
|
||
riskThreshold = softThreshold
|
||
}
|
||
return enabled, softThreshold, riskThreshold
|
||
}
|
||
|
||
func getChannelBalanceAlertLevelByRemaining(remaining float64, softThreshold float64, riskThreshold float64) string {
|
||
if remaining <= riskThreshold {
|
||
return channelBalanceAlertLevelRisk
|
||
}
|
||
if remaining <= softThreshold {
|
||
return channelBalanceAlertLevelSoft
|
||
}
|
||
return channelBalanceAlertLevelNone
|
||
}
|
||
|
||
func notifyChannelBalanceAlertOnUsageDelta(channel *Channel, oldRemaining float64, usedQuotaDelta int) {
|
||
enabled, softThreshold, riskThreshold := getChannelBalanceAlertConfigForUsedQuota()
|
||
if !enabled || channel == nil || channel.Id <= 0 || usedQuotaDelta == 0 {
|
||
return
|
||
}
|
||
|
||
newLevel := getChannelBalanceAlertLevelByRemaining(channel.Balance, softThreshold, riskThreshold)
|
||
oldLevel := getChannelBalanceAlertLevelByRemaining(oldRemaining, softThreshold, riskThreshold)
|
||
|
||
otherInfo := channel.GetOtherInfo()
|
||
if persistedLevel := strings.TrimSpace(common.Interface2String(otherInfo["balance_alert_level"])); persistedLevel != "" {
|
||
oldLevel = persistedLevel
|
||
}
|
||
otherInfo["balance_alert_level"] = newLevel
|
||
otherInfo["balance_alert_at"] = common.GetTimestamp()
|
||
channel.SetOtherInfo(otherInfo)
|
||
if err := DB.Model(&Channel{}).Where("id = ?", channel.Id).Update("other_info", channel.OtherInfo).Error; err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to persist used_quota alert level: channel_id=%d, err=%v", channel.Id, err))
|
||
}
|
||
if newLevel == channelBalanceAlertLevelNone || newLevel == oldLevel {
|
||
return
|
||
}
|
||
|
||
levelText := "柔和提示"
|
||
threshold := softThreshold
|
||
if newLevel == channelBalanceAlertLevelRisk {
|
||
levelText = "风险警告"
|
||
threshold = riskThreshold
|
||
}
|
||
err := CreateUserMessage(&UserMessage{
|
||
ReceiverMinRole: common.RoleAdminUser,
|
||
Type: "channel_balance_alert",
|
||
Title: fmt.Sprintf("渠道余额%s(%s)", levelText, channel.Name),
|
||
Content: fmt.Sprintf(
|
||
"渠道“%s”(ID:%d)剩余额度 %.2f,已低于阈值 %.2f,请及时处理。",
|
||
channel.Name, channel.Id, channel.Balance, threshold,
|
||
),
|
||
BizType: "channel_balance_alert",
|
||
BizID: channel.Id,
|
||
})
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to publish used_quota alert message: channel_id=%d, err=%v", channel.Id, err))
|
||
}
|
||
}
|
||
|
||
func DeleteChannelByStatus(status int64) (int64, error) {
|
||
result := DB.Where("status = ?", status).Delete(&Channel{})
|
||
return result.RowsAffected, result.Error
|
||
}
|
||
|
||
func DeleteDisabledChannel() (int64, error) {
|
||
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
|
||
return result.RowsAffected, result.Error
|
||
}
|
||
|
||
func GetPaginatedTags(offset int, limit int) ([]*string, error) {
|
||
var tags []*string
|
||
err := DB.Model(&Channel{}).Select("DISTINCT tag").Where("tag != ''").Offset(offset).Limit(limit).Find(&tags).Error
|
||
return tags, err
|
||
}
|
||
|
||
func SearchTags(keyword string, group string, model string, idSort bool) ([]*string, error) {
|
||
var tags []*string
|
||
modelsCol := "`models`"
|
||
|
||
// 如果是 PostgreSQL,使用双引号
|
||
if common.UsingPostgreSQL {
|
||
modelsCol = `"models"`
|
||
}
|
||
|
||
baseURLCol := "`base_url`"
|
||
// 如果是 PostgreSQL,使用双引号
|
||
if common.UsingPostgreSQL {
|
||
baseURLCol = `"base_url"`
|
||
}
|
||
|
||
order := "priority desc"
|
||
if idSort {
|
||
order = "id desc"
|
||
}
|
||
|
||
// 构造基础查询
|
||
baseQuery := DB.Model(&Channel{}).Omit("key")
|
||
|
||
// 构造WHERE子句
|
||
var whereClause string
|
||
var args []interface{}
|
||
if group != "" && group != "null" {
|
||
var groupCondition string
|
||
if common.UsingMySQL {
|
||
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
|
||
} else {
|
||
// sqlite, PostgreSQL
|
||
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
|
||
}
|
||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
|
||
} else {
|
||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
|
||
}
|
||
|
||
subQuery := baseQuery.Where(whereClause, args...).
|
||
Select("tag").
|
||
Where("tag != ''").
|
||
Order(order)
|
||
|
||
err := DB.Table("(?) as sub", subQuery).
|
||
Select("DISTINCT tag").
|
||
Find(&tags).Error
|
||
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return tags, nil
|
||
}
|
||
|
||
func (channel *Channel) ValidateSettings() error {
|
||
channelParams := &dto.ChannelSettings{}
|
||
if channel.Setting != nil && *channel.Setting != "" {
|
||
err := common.Unmarshal([]byte(*channel.Setting), channelParams)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (channel *Channel) GetSetting() dto.ChannelSettings {
|
||
setting := dto.ChannelSettings{}
|
||
if channel.Setting != nil && *channel.Setting != "" {
|
||
err := common.Unmarshal([]byte(*channel.Setting), &setting)
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err))
|
||
channel.Setting = nil // 清空设置以避免后续错误
|
||
_ = channel.Save() // 保存修改
|
||
}
|
||
}
|
||
return setting
|
||
}
|
||
|
||
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
|
||
settingBytes, err := common.Marshal(setting)
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err))
|
||
return
|
||
}
|
||
channel.Setting = common.GetPointer[string](string(settingBytes))
|
||
}
|
||
|
||
func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings {
|
||
setting := dto.ChannelOtherSettings{}
|
||
if channel.OtherSettings != "" {
|
||
err := common.UnmarshalJsonStr(channel.OtherSettings, &setting)
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err))
|
||
channel.OtherSettings = "{}" // 清空设置以避免后续错误
|
||
_ = channel.Save() // 保存修改
|
||
}
|
||
}
|
||
return setting
|
||
}
|
||
|
||
func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) {
|
||
settingBytes, err := common.Marshal(setting)
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err))
|
||
return
|
||
}
|
||
channel.OtherSettings = string(settingBytes)
|
||
}
|
||
|
||
func (channel *Channel) GetParamOverride() map[string]interface{} {
|
||
paramOverride := make(map[string]interface{})
|
||
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
|
||
err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to unmarshal param override: channel_id=%d, error=%v", channel.Id, err))
|
||
}
|
||
}
|
||
return paramOverride
|
||
}
|
||
|
||
func (channel *Channel) GetHeaderOverride() map[string]interface{} {
|
||
headerOverride := make(map[string]interface{})
|
||
if channel.HeaderOverride != nil && *channel.HeaderOverride != "" {
|
||
err := common.Unmarshal([]byte(*channel.HeaderOverride), &headerOverride)
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("failed to unmarshal header override: channel_id=%d, error=%v", channel.Id, err))
|
||
}
|
||
}
|
||
return headerOverride
|
||
}
|
||
|
||
func GetChannelsByIds(ids []int) ([]*Channel, error) {
|
||
var channels []*Channel
|
||
err := DB.Where("id in (?)", ids).Find(&channels).Error
|
||
return channels, err
|
||
}
|
||
|
||
func BatchSetChannelTag(ids []int, tag *string) error {
|
||
// 开启事务
|
||
tx := DB.Begin()
|
||
if tx.Error != nil {
|
||
return tx.Error
|
||
}
|
||
|
||
// 更新标签
|
||
err := tx.Model(&Channel{}).Where("id in (?)", ids).Update("tag", tag).Error
|
||
if err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
|
||
// update ability status
|
||
channels, err := GetChannelsByIds(ids)
|
||
if err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
|
||
for _, channel := range channels {
|
||
err = channel.UpdateAbilities(tx)
|
||
if err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
}
|
||
|
||
// 提交事务
|
||
return tx.Commit().Error
|
||
}
|
||
|
||
// CountAllChannels returns total channels in DB
|
||
func CountAllChannels() (int64, error) {
|
||
var total int64
|
||
err := DB.Model(&Channel{}).Count(&total).Error
|
||
return total, err
|
||
}
|
||
|
||
// CountAllTags returns number of non-empty distinct tags
|
||
func CountAllTags() (int64, error) {
|
||
var total int64
|
||
err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
|
||
return total, err
|
||
}
|
||
|
||
// Get channels of specified type with pagination
|
||
func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
|
||
var channels []*Channel
|
||
order := "priority desc"
|
||
if idSort {
|
||
order = "id desc"
|
||
}
|
||
err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
|
||
return channels, err
|
||
}
|
||
|
||
// Count channels of specific type
|
||
func CountChannelsByType(channelType int) (int64, error) {
|
||
var count int64
|
||
err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
|
||
return count, err
|
||
}
|
||
|
||
// Return map[type]count for all channels
|
||
func CountChannelsGroupByType() (map[int64]int64, error) {
|
||
type result struct {
|
||
Type int64 `gorm:"column:type"`
|
||
Count int64 `gorm:"column:count"`
|
||
}
|
||
var results []result
|
||
err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
counts := make(map[int64]int64)
|
||
for _, r := range results {
|
||
counts[r.Type] = r.Count
|
||
}
|
||
return counts, nil
|
||
}
|
||
|
||
// GetChannelIdNameMap 返回全量 channel_id(string) → channel_name 的映射,用于价格导出时将 ID 转换为名称。
|
||
func GetChannelIdNameMap() (map[string]string, error) {
|
||
type row struct {
|
||
Id int `gorm:"column:id"`
|
||
Name string `gorm:"column:name"`
|
||
}
|
||
var rows []row
|
||
if err := DB.Model(&Channel{}).Select("id, name").Find(&rows).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
out := make(map[string]string, len(rows))
|
||
for _, r := range rows {
|
||
out[fmt.Sprintf("%d", r.Id)] = r.Name
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
// GetChannelIDsByName 根据渠道名称精确匹配,返回所有同名渠道的 ID 列表(用于价格导入时按名称定位渠道)。
|
||
func GetChannelIDsByName(name string) ([]int, error) {
|
||
var ids []int
|
||
if err := DB.Model(&Channel{}).Where("name = ?", name).Pluck("id", &ids).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return ids, nil
|
||
}
|
||
|
||
// GetChannelsByIDs 根据 ID 列表批量获取渠道(含密钥),用于导出场景。
|
||
func GetChannelsByIDs(ids []int) ([]*Channel, error) {
|
||
if len(ids) == 0 {
|
||
return []*Channel{}, nil
|
||
}
|
||
var channels []*Channel
|
||
if err := DB.Where("id IN ?", ids).Find(&channels).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return channels, nil
|
||
}
|
||
|
||
// GetChannelByName 根据渠道名称精确匹配,返回第一个同名渠道(含密钥)。找不到返回 (nil, nil)。
|
||
func GetChannelByName(name string) (*Channel, error) {
|
||
var channel Channel
|
||
err := DB.Where("name = ?", name).First(&channel).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, nil
|
||
}
|
||
return nil, err
|
||
}
|
||
return &channel, nil
|
||
}
|
||
|
||
// PartialUpdateChannelFields 按字段名列表精确更新渠道指定列,不影响其他列。
|
||
// 使用 GORM Select + struct Updates,GORM 会按模型定义正确处理保留字(group/key)的方言转义。
|
||
func PartialUpdateChannelFields(id int, cols []string, updates *Channel) error {
|
||
if len(cols) == 0 {
|
||
return nil
|
||
}
|
||
return DB.Model(&Channel{}).Where("id = ?", id).Select(cols).Updates(updates).Error
|
||
}
|