tokenFactory/model/pricing.go

804 lines
29 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package model
import (
"encoding/json"
"fmt"
"sort"
"strings"
"sync"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/QuantumNous/new-api/types"
)
// Pricing is the per-model pricing entry built by updatePricing and returned
// (cached) by GetPricing.
type Pricing struct {
	ModelName       string `json:"model_name"`
	Description     string `json:"description,omitempty"`
	DocIntroduction string `json:"doc_introduction,omitempty"`
	ApiDocs         string `json:"api_docs,omitempty"`
	Icon            string `json:"icon,omitempty"`
	Tags            string `json:"tags,omitempty"`
	VendorID        int    `json:"vendor_id,omitempty"`
	// QuotaType is set to 1 by updatePricing when a fixed ModelPrice is
	// configured for the model, otherwise 0 (ModelRatio applies).
	QuotaType  int     `json:"quota_type"`
	ModelRatio float64 `json:"model_ratio"`
	ModelPrice float64 `json:"model_price"`
	OwnerBy    string  `json:"owner_by"`
	// Optional ratios/prices: left nil (and omitted from JSON) when
	// updatePricing finds no explicit configuration for the model.
	CompletionRatio      *float64 `json:"completion_ratio,omitempty"`
	CacheRatio           *float64 `json:"cache_ratio,omitempty"`
	CreateCacheRatio     *float64 `json:"create_cache_ratio,omitempty"`
	ImageRatio           *float64 `json:"image_ratio,omitempty"`
	AudioRatio           *float64 `json:"audio_ratio,omitempty"`
	AudioCompletionRatio *float64 `json:"audio_completion_ratio,omitempty"`
	VideoRatio           *float64 `json:"video_ratio,omitempty"`
	VideoCompletionRatio *float64 `json:"video_completion_ratio,omitempty"`
	VideoPrice           *float64 `json:"video_price,omitempty"`
	// EnableGroup lists the groups that have an enabled ability for the model.
	EnableGroup []string `json:"enable_groups"`
	// SupportedEndpointTypes is ordered; the first entry is the preferred endpoint.
	SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
	// PricingVersion is only set on the first element of the cached list.
	PricingVersion string `json:"pricing_version,omitempty"`
}
// PricingVendor is the vendor summary exposed alongside the pricing list
// (built from the vendors table in updatePricing and returned by GetVendors).
type PricingVendor struct {
	ID          int    `json:"id"`
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	Icon        string `json:"icon,omitempty"`
}
// PricingSupplierItem is the supplier summary embedded in a pricing data entry.
// SupplierID carries the supplier application ID of the channel.
type PricingSupplierItem struct {
	SupplierID     int    `json:"supplier_id"`
	SupplierAlias  string `json:"supplier_alias"`
	CompanyLogoURL string `json:"company_logo_url"`
	SupplierType   string `json:"supplier_type"`
}
// PricingChannelItem is the pricing summary of one model on one channel.
type PricingChannelItem struct {
	ChannelID             int    `json:"channel_id"`
	SupplierApplicationID int    `json:"supplier_application_id"`
	ChannelNo             string `json:"channel_no"`
	SupplierAlias         string `json:"supplier_alias"`
	CompanyLogoURL        string `json:"company_logo_url"`
	SupplierType          string `json:"supplier_type"`
	// RouteSlug is the channel-wide route suffix. Combined with a model name as
	// {model}/{route_slug} it forces routing to this channel; every model on the
	// channel shares the same slug.
	RouteSlug string `json:"route_slug,omitempty"`
	// TestResponseTimeMs is the channel's latest displayable test latency in
	// milliseconds. 0 means untested or the test failed; the field is then
	// omitted from the JSON output.
	TestResponseTimeMs int `json:"test_response_time_ms,omitempty"`
	// Base pricing values. As filled by BuildPricingAPIItems these are not yet
	// multiplied by the discount/markup rates below; the frontend applies them.
	ModelPrice       float64 `json:"model_price"`
	ModelRatio       float64 `json:"model_ratio"`
	CompletionRatio  float64 `json:"completion_ratio"`
	CacheRatio       float64 `json:"cache_ratio"`
	CreateCacheRatio float64 `json:"create_cache_ratio"`
	// Tiered ratios; BuildPricingAPIItems only sets them when a tier
	// configuration is resolved for the channel/model (otherwise nil, omitted).
	ModelTierRatio       any     `json:"model_tier_ratio,omitempty"`
	CompletionTierRatio  any     `json:"completion_tier_ratio,omitempty"`
	CacheTierRatio       any     `json:"cache_tier_ratio,omitempty"`
	CreateCacheTierRatio any     `json:"create_cache_tier_ratio,omitempty"`
	PriceDiscountPercent float64 `json:"price_discount_percent"` // cost discount as a percentage (100 = no discount)
	MarkupDiscountRate   float64 `json:"markup_discount_rate"`   // markup rate as a percentage (0 = no markup)
	// QuotaType is 1 when ModelPrice > 0 (fixed price), otherwise 0 (ratio-based).
	QuotaType int `json:"quota_type"`
	// OptionModelRatio and the other Option* fields hold only values explicitly
	// configured in the Option "channel model pricing" settings (no supplier or
	// global fallback); they feed the homepage cost-price display.
	OptionModelRatio           *float64 `json:"option_model_ratio,omitempty"`
	OptionCompletionRatio      *float64 `json:"option_completion_ratio,omitempty"`
	OptionCacheRatio           *float64 `json:"option_cache_ratio,omitempty"`
	OptionCreateCacheRatio     *float64 `json:"option_create_cache_ratio,omitempty"`
	OptionModelPrice           *float64 `json:"option_model_price,omitempty"`
	OptionImageRatio           *float64 `json:"option_image_ratio,omitempty"`
	OptionImagePrice           *float64 `json:"option_image_price,omitempty"`
	OptionAudioRatio           *float64 `json:"option_audio_ratio,omitempty"`
	OptionAudioCompletionRatio *float64 `json:"option_audio_completion_ratio,omitempty"`
	OptionVideoRatio           *float64 `json:"option_video_ratio,omitempty"`
	OptionVideoCompletionRatio *float64 `json:"option_video_completion_ratio,omitempty"`
	OptionVideoPrice           *float64 `json:"option_video_price,omitempty"`
	// Popularity ("hot") sorting fields.
	SortWeight         float64 `json:"sort_weight"`           // channel weight
	ManualBaseReqCount int64   `json:"manual_base_req_count"` // manually configured base request count
	AutoReqCount       int64   `json:"auto_req_count"`        // automatically counted requests
	FinalReqCount      int64   `json:"final_req_count"`       // final request count (= manual + auto)
	ChannelHeatScore   float64 `json:"channel_heat_score"`    // channel heat score (= final * weight)
}
// PricingAPIItem extends Pricing with channel-level fields; it is the element
// type of the pricing API data list. BuildPricingAPIItems emits one item per
// (model, channel) pair, so SupplierList and ChannelList each hold exactly one
// entry there.
type PricingAPIItem struct {
	Pricing
	SupplierList      []PricingSupplierItem     `json:"supplier_list"`
	ChannelList       []PricingChannelItem      `json:"channel_list"`
	VideoFlatClipHint *VideoFlatClipPricingHint `json:"video_flat_clip_hint,omitempty"`
	ImagePerImageHint *ImagePerImagePricingHint `json:"image_per_image_hint,omitempty"`
}
// resolveChannelPricingTriple resolves the base (undiscounted) pricing of
// modelName on one channel.
//
// The completion ratio cr is always resolved. When a supplier-scoped fixed
// price exists it is returned as mp and mr is 0; otherwise mp is 0 and mr is
// the supplier-scoped model ratio. Lookup priority, as documented for these
// resolvers: supplier channel table > supplier global table > Option channel >
// platform global > legacy SupplierOption.
func resolveChannelPricingTriple(channelID int, supplierApplicationID int, modelName string) (mp, mr, cr float64) {
	cr = ResolveSupplierScopedCompletionRatio(channelID, supplierApplicationID, modelName)
	fixedPrice, isFixed := ResolveSupplierScopedFixedModelPrice(channelID, supplierApplicationID, modelName)
	if !isFixed {
		mr, _, _ = ResolveSupplierScopedModelRatio(channelID, supplierApplicationID, modelName)
		return 0, mr, cr
	}
	return fixedPrice, 0, cr
}
// resolveChannelCachePair returns the supplier-scoped cache ratio and
// create-cache ratio of modelName on the given channel.
func resolveChannelCachePair(channelID int, supplierApplicationID int, modelName string) (cacheRatio, createCacheRatio float64) {
	cacheRatio, createCacheRatio = ResolveSupplierScopedCacheRatios(channelID, supplierApplicationID, modelName)
	return cacheRatio, createCacheRatio
}
// fillOptionChannelPricingFields copies the values explicitly configured in the
// Option "channel model pricing" settings for (channelID, modelName) into the
// Option* fields of item, matching the admin channel-model-pricing settings.
// A field whose getter reports no explicit entry is left untouched; each set
// field points at its own copy of the value.
func fillOptionChannelPricingFields(item *PricingChannelItem, channelID int, modelName string) {
	// into returns a sink that, when ok is true, points *dst at a private copy of v.
	into := func(dst **float64) func(float64, bool) {
		return func(v float64, ok bool) {
			if ok {
				*dst = &v
			}
		}
	}

	into(&item.OptionModelRatio)(ratio_setting.GetChannelModelRatio(channelID, modelName))
	into(&item.OptionCompletionRatio)(ratio_setting.GetChannelCompletionRatio(channelID, modelName))
	into(&item.OptionCacheRatio)(ratio_setting.GetChannelCacheRatio(channelID, modelName))
	into(&item.OptionCreateCacheRatio)(ratio_setting.GetChannelCreateCacheRatio(channelID, modelName))
	into(&item.OptionModelPrice)(ratio_setting.GetChannelModelPrice(channelID, modelName))
	into(&item.OptionImageRatio)(ratio_setting.GetChannelImageRatio(channelID, modelName))
	into(&item.OptionImagePrice)(ratio_setting.GetChannelImagePrice(channelID, modelName))
	into(&item.OptionAudioRatio)(ratio_setting.GetChannelAudioRatio(channelID, modelName))
	into(&item.OptionAudioCompletionRatio)(ratio_setting.GetChannelAudioCompletionRatio(channelID, modelName))
	into(&item.OptionVideoRatio)(ratio_setting.GetChannelVideoRatio(channelID, modelName))
	into(&item.OptionVideoCompletionRatio)(ratio_setting.GetChannelVideoCompletionRatio(channelID, modelName))
	into(&item.OptionVideoPrice)(ratio_setting.GetChannelVideoPrice(channelID, modelName))
}
// pricingSupplierAliasFromMeta returns the display alias of a channel's supplier.
//
// It returns "P0" when supplierApplicationID is 0 or the trimmed alias is "0".
// Otherwise it returns the trimmed alias when non-empty. A nil or blank alias
// falls back to SupplierApplicationAutoAlias.
func pricingSupplierAliasFromMeta(supplierApplicationID int, alias *string) string {
	if supplierApplicationID == 0 {
		return "P0"
	}
	trimmed := ""
	if alias != nil {
		trimmed = strings.TrimSpace(*alias)
	}
	switch trimmed {
	case "0":
		return "P0"
	case "":
		return SupplierApplicationAutoAlias(supplierApplicationID)
	default:
		return trimmed
	}
}
// BuildPricingAPIItems assembles the pricing API data list, flattened to one
// item per (model, channel) pair. Each item carries exactly one channel_list
// and one supplier_list entry, so the frontend can render per channel directly.
//
// Channel prices are the base pricing (resolveChannelPricingTriple /
// resolveChannelCachePair) sent as-is. The channel's cost discount
// (price_discount_percent) and markup rate (markup_discount_rate) are sent
// alongside for the frontend to apply. User/group ratios are also applied by
// the frontend via group_ratio (consistent with calculateModelPrice).
//
// When includeUntestedChannelPricingRows is false (default behavior), a row is
// kept only if the channel has a valid test latency for the model. In addition,
// if the channel already has successful test records, this model's test must
// match them. When true, these test gates are skipped; callers such as
// /api/price_sync use this to get complete channel pricing, including untested
// model × channel rows.
//
// Models with no qualifying channel are omitted. Within a model, rows are
// ordered by sortPricingChannelItems.
func BuildPricingAPIItems(filtered []Pricing, visibleChannelIDs map[int]struct{}, metas []ChannelPricingMeta, includeUntestedChannelPricingRows bool) []PricingAPIItem {
	testSuccessByChannel, err := LoadChannelPricingTestSuccessIndex()
	if err != nil {
		common.SysLog(fmt.Sprintf("LoadChannelPricingTestSuccessIndex error: %v", err))
		testSuccessByChannel = nil
	}

	visibleIDs := make([]int, 0, len(visibleChannelIDs))
	for id := range visibleChannelIDs {
		visibleIDs = append(visibleIDs, id)
	}
	// Batch-load route slugs of all visible channels once to avoid N+1 queries.
	channelSlugMap := GetRouteSlugsByChannelIDs(visibleIDs)

	out := make([]PricingAPIItem, 0, len(filtered))
	for _, p := range filtered {
		modelName := p.ModelName
		testResponseTimeByChannel := loadPricingTestResponseTimes(modelName, visibleIDs)

		var chItems []PricingChannelItem
		for _, row := range metas {
			if row.ChannelID <= 0 {
				continue
			}
			if _, ok := visibleChannelIDs[row.ChannelID]; !ok {
				continue
			}
			if !ChannelModelsRawContains(row.Models, modelName) {
				continue
			}
			// Test gate: require a successful test for this model only when the
			// channel already has at least one successful test record. Otherwise a
			// channel/supplier that never ran tests would hide every model, leaving
			// suppliers that only see their own channels with empty data.
			if !includeUntestedChannelPricingRows && testSuccessByChannel != nil {
				if len(testSuccessByChannel[row.ChannelID]) > 0 &&
					!ChannelPricingRowMatchesLastTestSuccess(testSuccessByChannel, row.ChannelID, modelName) {
					continue
				}
			}
			// Rows without a valid test latency (0 = untested/failed) are hidden
			// from the pricing page unless the caller asked for untested rows.
			testMs := testResponseTimeByChannel[row.ChannelID]
			if !includeUntestedChannelPricingRows && testMs <= 0 {
				continue
			}
			routeSlug := ""
			if channelSlugMap != nil {
				routeSlug = channelSlugMap[row.ChannelID]
			}
			chItems = append(chItems, buildPricingChannelItem(row, modelName, testMs, routeSlug))
		}
		if len(chItems) == 0 {
			continue
		}

		sortPricingChannelItems(chItems, p.QuotaType)
		for _, ch := range chItems {
			out = append(out, pricingAPIItemForChannel(p, ch))
		}
	}
	return out
}

// loadPricingTestResponseTimes maps channel ID -> displayable test latency (ms)
// for modelName over visibleIDs. A manual display override (> 0) wins;
// otherwise the last successful test's response time (> 0) is used. Lookup
// errors are logged and yield an empty map.
func loadPricingTestResponseTimes(modelName string, visibleIDs []int) map[int]int {
	byChannel := make(map[int]int)
	if len(visibleIDs) == 0 {
		return byChannel
	}
	rows, err := GetModelTestResultsByModelNameAndChannelIDs(modelName, visibleIDs)
	if err != nil {
		common.SysLog(fmt.Sprintf("GetModelTestResultsByModelNameAndChannelIDs error: model=%s err=%v", modelName, err))
		return byChannel
	}
	for i := range rows {
		r := rows[i]
		if r.ChannelId <= 0 {
			continue
		}
		if r.ManualDisplayResponseTime > 0 {
			byChannel[r.ChannelId] = r.ManualDisplayResponseTime
			continue
		}
		if r.LastTestSuccess && r.LastResponseTime > 0 {
			byChannel[r.ChannelId] = r.LastResponseTime
		}
	}
	return byChannel
}

// buildPricingChannelItem resolves the base pricing, tiered ratios and
// Option-only overrides of modelName on the channel described by row.
// Prices are not discounted here; discount and markup rates are passed through.
func buildPricingChannelItem(row ChannelPricingMeta, modelName string, testMs int, routeSlug string) PricingChannelItem {
	baseMp, baseMr, cr := resolveChannelPricingTriple(row.ChannelID, row.SupplierApplicationID, modelName)
	chCache, chCreate := resolveChannelCachePair(row.ChannelID, row.SupplierApplicationID, modelName)
	modelTierRatio, hasModelTierRatio := ratio_setting.ResolveModelTierRatio(row.ChannelID, modelName)
	completionTierRatio, hasCompletionTierRatio := ratio_setting.ResolveCompletionTierRatio(row.ChannelID, modelName)
	cacheTierRatio, hasCacheTierRatio := ratio_setting.ResolveCacheTierRatio(row.ChannelID, modelName)
	createCacheTierRatio, hasCreateCacheTierRatio := ratio_setting.ResolveCreateCacheTierRatio(row.ChannelID, modelName)
	alias := pricingSupplierAliasFromMeta(row.SupplierApplicationID, row.SupplierAlias)

	// Unset rates default to "no discount" (100%) and "no markup" (0%).
	discountPercent := 100.0
	if row.PriceDiscountPercent != nil {
		discountPercent = *row.PriceDiscountPercent
	}
	markupRate := 0.0
	if row.MarkupDiscountRate != nil {
		markupRate = *row.MarkupDiscountRate
	}
	quotaType := 0
	if baseMp > 0 {
		quotaType = 1
	}

	item := PricingChannelItem{
		ChannelID:             row.ChannelID,
		SupplierApplicationID: row.SupplierApplicationID,
		ChannelNo:             row.ChannelNo,
		SupplierAlias:         alias,
		CompanyLogoURL:        strings.TrimSpace(row.CompanyLogoURL),
		SupplierType:          strings.TrimSpace(row.SupplierType),
		RouteSlug:             routeSlug,
		TestResponseTimeMs:    testMs,
		ModelPrice:            baseMp,
		ModelRatio:            baseMr,
		CompletionRatio:       cr,
		CacheRatio:            chCache,
		CreateCacheRatio:      chCreate,
		PriceDiscountPercent:  discountPercent,
		MarkupDiscountRate:    markupRate,
		QuotaType:             quotaType,
	}
	if hasModelTierRatio {
		item.ModelTierRatio = modelTierRatio
	}
	if hasCompletionTierRatio {
		item.CompletionTierRatio = completionTierRatio
	}
	if hasCacheTierRatio {
		item.CacheTierRatio = cacheTierRatio
	}
	if hasCreateCacheTierRatio {
		item.CreateCacheTierRatio = createCacheTierRatio
	}
	fillOptionChannelPricingFields(&item, row.ChannelID, modelName)
	return item
}

// sortPricingChannelItems orders items ascending by ModelPrice when quotaType
// is 1, otherwise by ModelRatio; ties are broken by ChannelID.
//
// NOTE(review): the key follows the model-level quota type, while each item's
// own QuotaType is derived from its resolved price. Items whose type differs
// (e.g. ModelPrice == 0 under a fixed-price model) sort by a zero key. Confirm
// this is intended.
func sortPricingChannelItems(items []PricingChannelItem, quotaType int) {
	key := func(it *PricingChannelItem) float64 {
		if quotaType == 1 {
			return it.ModelPrice
		}
		return it.ModelRatio
	}
	sort.Slice(items, func(i, j int) bool {
		ki, kj := key(&items[i]), key(&items[j])
		if ki != kj {
			return ki < kj
		}
		return items[i].ChannelID < items[j].ChannelID
	})
}

// pricingAPIItemForChannel builds the flattened API item for model p on channel ch.
func pricingAPIItemForChannel(p Pricing, ch PricingChannelItem) PricingAPIItem {
	return PricingAPIItem{
		Pricing:     p,
		ChannelList: []PricingChannelItem{ch},
		SupplierList: []PricingSupplierItem{{
			SupplierID:     ch.SupplierApplicationID,
			SupplierAlias:  ch.SupplierAlias,
			CompanyLogoURL: ch.CompanyLogoURL,
			SupplierType:   ch.SupplierType,
		}},
		VideoFlatClipHint: BuildVideoFlatClipHint(ch.ChannelID, p.ModelName, ch.PriceDiscountPercent, ch.MarkupDiscountRate),
		ImagePerImageHint: BuildImagePerImageHint(ch.ChannelID, p.ModelName, ch.PriceDiscountPercent, ch.MarkupDiscountRate),
	}
}
// Pricing cache state, rebuilt by updatePricing. GetPricing serializes
// refreshes with updatePricingLock.
var (
	pricingMap           []Pricing
	vendorsList          []PricingVendor
	supportedEndpointMap map[string]common.EndpointInfo
	lastGetPricingTime   time.Time
	updatePricingLock    sync.Mutex

	// Lookup caches: model name -> enabled groups / quota type.
	modelEnableGroups     = make(map[string][]string)
	modelQuotaTypeMap     = make(map[string]int)
	modelEnableGroupsLock sync.RWMutex

	// Model name -> supported endpoint types (first entry is the preferred one).
	modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
	modelSupportEndpointsLock sync.RWMutex
)
// GetPricing returns the cached pricing list. It rebuilds the list via
// updatePricing when the list is empty or older than one minute.
//
// The staleness check, the refresh and the read of pricingMap all happen while
// holding updatePricingLock, so they are synchronized with updatePricing's
// writes. An unlocked check of lastGetPricingTime/pricingMap would be a data
// race. modelSupportEndpointsLock is held during the refresh because
// updatePricing replaces modelSupportEndpointTypes.
//
// The returned slice is the cached slice itself; callers must not modify it.
func GetPricing() []Pricing {
	updatePricingLock.Lock()
	defer updatePricingLock.Unlock()
	if time.Since(lastGetPricingTime) > time.Minute || len(pricingMap) == 0 {
		modelSupportEndpointsLock.Lock()
		defer modelSupportEndpointsLock.Unlock()
		updatePricing()
	}
	return pricingMap
}
// GetVendors returns the vendor list built by the most recent pricing refresh.
//
// GetPricing is called first; it refreshes the cache only when stale. The list
// is then read under updatePricingLock, which is the lock held while
// updatePricing replaces it.
func GetVendors() []PricingVendor {
	GetPricing()
	updatePricingLock.Lock()
	defer updatePricingLock.Unlock()
	return vendorsList
}
// GetModelSupportEndpointTypes returns the endpoint types recorded for model by
// the last pricing refresh; the first entry is the preferred endpoint.
// An empty or unknown model yields an empty, non-nil slice. A found slice is
// returned as stored in the cache (not copied).
func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
	if model != "" {
		modelSupportEndpointsLock.RLock()
		endpoints, found := modelSupportEndpointTypes[model]
		modelSupportEndpointsLock.RUnlock()
		if found {
			return endpoints
		}
	}
	return make([]constant.EndpointType, 0)
}
func updatePricing() {
//modelRatios := common.GetModelRatios()
enableAbilities, err := GetAllEnableAbilityWithChannels()
if err != nil {
common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
return
}
// 预加载模型元数据与供应商一次,避免循环查询
var allMeta []Model
_ = DB.Find(&allMeta).Error
metaMap := make(map[string]*Model)
prefixList := make([]*Model, 0)
suffixList := make([]*Model, 0)
containsList := make([]*Model, 0)
for i := range allMeta {
m := &allMeta[i]
if m.NameRule == NameRuleExact {
metaMap[m.ModelName] = m
} else {
switch m.NameRule {
case NameRulePrefix:
prefixList = append(prefixList, m)
case NameRuleSuffix:
suffixList = append(suffixList, m)
case NameRuleContains:
containsList = append(containsList, m)
}
}
}
// 将非精确规则模型匹配到 metaMap
for _, m := range prefixList {
for _, pricingModel := range enableAbilities {
if strings.HasPrefix(pricingModel.Model, m.ModelName) {
if _, exists := metaMap[pricingModel.Model]; !exists {
metaMap[pricingModel.Model] = m
}
}
}
}
for _, m := range suffixList {
for _, pricingModel := range enableAbilities {
if strings.HasSuffix(pricingModel.Model, m.ModelName) {
if _, exists := metaMap[pricingModel.Model]; !exists {
metaMap[pricingModel.Model] = m
}
}
}
}
for _, m := range containsList {
for _, pricingModel := range enableAbilities {
if strings.Contains(pricingModel.Model, m.ModelName) {
if _, exists := metaMap[pricingModel.Model]; !exists {
metaMap[pricingModel.Model] = m
}
}
}
}
// 预加载供应商
var vendors []Vendor
_ = DB.Find(&vendors).Error
vendorMap := make(map[int]*Vendor)
for i := range vendors {
vendorMap[vendors[i].Id] = &vendors[i]
}
// 初始化默认供应商映射
initDefaultVendorMapping(metaMap, vendorMap, enableAbilities)
// 构建对前端友好的供应商列表
vendorsList = make([]PricingVendor, 0, len(vendorMap))
for _, v := range vendorMap {
vendorsList = append(vendorsList, PricingVendor{
ID: v.Id,
Name: v.Name,
Description: v.Description,
Icon: v.Icon,
})
}
modelGroupsMap := make(map[string]*types.Set[string])
for _, ability := range enableAbilities {
groups, ok := modelGroupsMap[ability.Model]
if !ok {
groups = types.NewSet[string]()
modelGroupsMap[ability.Model] = groups
}
groups.Add(ability.Group)
}
//这里使用切片而不是Set因为一个模型可能支持多个端点类型并且第一个端点是优先使用端点
modelSupportEndpointsStr := make(map[string][]string)
// 先根据已有能力填充原生端点
for _, ability := range enableAbilities {
endpoints := modelSupportEndpointsStr[ability.Model]
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
for _, channelType := range channelTypes {
if !common.StringsContains(endpoints, string(channelType)) {
endpoints = append(endpoints, string(channelType))
}
}
modelSupportEndpointsStr[ability.Model] = endpoints
}
// 再补充模型自定义端点:若配置有效则替换默认端点,不做合并
for modelName, meta := range metaMap {
if strings.TrimSpace(meta.Endpoints) == "" {
continue
}
var raw map[string]interface{}
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
endpoints := make([]string, 0, len(raw))
for k, v := range raw {
switch v.(type) {
case string, map[string]interface{}:
if !common.StringsContains(endpoints, k) {
endpoints = append(endpoints, k)
}
}
}
if len(endpoints) > 0 {
modelSupportEndpointsStr[modelName] = endpoints
}
}
}
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
for model, endpoints := range modelSupportEndpointsStr {
supportedEndpoints := make([]constant.EndpointType, 0)
for _, endpointStr := range endpoints {
endpointType := constant.EndpointType(endpointStr)
supportedEndpoints = append(supportedEndpoints, endpointType)
}
modelSupportEndpointTypes[model] = supportedEndpoints
}
// 构建全局 supportedEndpointMap默认 + 自定义覆盖)
supportedEndpointMap = make(map[string]common.EndpointInfo)
// 1. 默认端点
for _, endpoints := range modelSupportEndpointTypes {
for _, et := range endpoints {
if info, ok := common.GetDefaultEndpointInfo(et); ok {
if _, exists := supportedEndpointMap[string(et)]; !exists {
supportedEndpointMap[string(et)] = info
}
}
}
}
// 2. 自定义端点models 表)覆盖默认
for _, meta := range metaMap {
if strings.TrimSpace(meta.Endpoints) == "" {
continue
}
var raw map[string]interface{}
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
for k, v := range raw {
switch val := v.(type) {
case string:
supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"}
case map[string]interface{}:
ep := common.EndpointInfo{Method: "POST"}
if p, ok := val["path"].(string); ok {
ep.Path = p
}
if m, ok := val["method"].(string); ok {
ep.Method = strings.ToUpper(m)
}
supportedEndpointMap[k] = ep
default:
// ignore unsupported types
}
}
}
}
pricingMap = make([]Pricing, 0)
for model, groups := range modelGroupsMap {
pricing := Pricing{
ModelName: model,
EnableGroup: groups.Items(),
SupportedEndpointTypes: modelSupportEndpointTypes[model],
}
// 补充模型元数据(描述、标签、供应商、状态)
if meta, ok := metaMap[model]; ok {
// 若模型被禁用(status!=1),则直接跳过,不返回给前端
if meta.Status != 1 {
continue
}
pricing.Description = meta.Description
pricing.DocIntroduction = meta.DocIntroduction
pricing.ApiDocs = meta.ApiDocs
pricing.Icon = meta.Icon
pricing.Tags = meta.Tags
pricing.VendorID = meta.VendorID
}
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
if findPrice {
pricing.ModelPrice = modelPrice
pricing.QuotaType = 1
} else {
modelRatio, _, _ := ratio_setting.GetModelRatio(model)
pricing.ModelRatio = modelRatio
// 仅当模型有显式配置的输出倍率时才返回,否则前端不展示输出价格
if ratio_setting.ContainsCompletionRatio(model) {
cr := ratio_setting.GetCompletionRatio(model)
pricing.CompletionRatio = &cr
}
pricing.QuotaType = 0
}
if cacheRatio, ok := ratio_setting.GetCacheRatio(model); ok {
pricing.CacheRatio = &cacheRatio
}
if createCacheRatio, ok := ratio_setting.GetCreateCacheRatio(model); ok {
pricing.CreateCacheRatio = &createCacheRatio
}
if imageRatio, ok := ratio_setting.GetImageRatio(model); ok {
pricing.ImageRatio = &imageRatio
}
if ratio_setting.ContainsAudioRatio(model) {
audioRatio := ratio_setting.GetAudioRatio(model)
pricing.AudioRatio = &audioRatio
}
if ratio_setting.ContainsAudioCompletionRatio(model) {
audioCompletionRatio := ratio_setting.GetAudioCompletionRatio(model)
pricing.AudioCompletionRatio = &audioCompletionRatio
}
if ratio_setting.ContainsVideoRatio(model) {
videoRatio := ratio_setting.GetVideoRatio(model)
pricing.VideoRatio = &videoRatio
}
if ratio_setting.ContainsVideoCompletionRatio(model) {
videoCompletionRatio := ratio_setting.GetVideoCompletionRatio(model)
pricing.VideoCompletionRatio = &videoCompletionRatio
}
if ratio_setting.ContainsVideoPrice(model) {
videoPrice, _ := ratio_setting.GetVideoPrice(model)
pricing.VideoPrice = &videoPrice
}
pricingMap = append(pricingMap, pricing)
}
// 防止大更新后数据不通用
if len(pricingMap) > 0 {
pricingMap[0].PricingVersion = "5a90f2b86c08bd983a9a2e6d66c255f4eaef9c4bc934386d2b6ae84ef0ff1f1f"
}
// 刷新缓存映射,供高并发快速查询
modelEnableGroupsLock.Lock()
modelEnableGroups = make(map[string][]string)
modelQuotaTypeMap = make(map[string]int)
for _, p := range pricingMap {
modelEnableGroups[p.ModelName] = p.EnableGroup
modelQuotaTypeMap[p.ModelName] = p.QuotaType
}
modelEnableGroupsLock.Unlock()
lastGetPricingTime = time.Now()
}
// GetSupportedEndpointMap returns the global endpoint-type -> path/method map
// built by the most recent pricing refresh.
//
// The map is read under updatePricingLock, which GetPricing holds while
// updatePricing replaces it, so the read is synchronized with refreshes.
// The returned map is shared; callers must not modify it.
func GetSupportedEndpointMap() map[string]common.EndpointInfo {
	updatePricingLock.Lock()
	defer updatePricingLock.Unlock()
	return supportedEndpointMap
}
// ModelRequestStats is a per-model request count row scanned from the logs table.
// As queried by GetModelRequestStatsByPeriod, both RequestCount7d and
// RequestCount30d hold the count for the requested period, whatever the
// column names suggest.
type ModelRequestStats struct {
	ModelName       string `gorm:"column:model_name"`
	RequestCount7d  int64  `gorm:"column:req_count_7d"`
	RequestCount30d int64  `gorm:"column:req_count_30d"`
}
// ChannelModelRequestStats is a per-(channel, model) request count row scanned
// from the logs table. As queried by GetChannelModelRequestStatsByPeriod, both
// count fields hold the count for the requested period.
type ChannelModelRequestStats struct {
	ChannelID       int    `gorm:"column:channel_id"`
	ModelName       string `gorm:"column:model_name"`
	RequestCount7d  int64  `gorm:"column:req_count_7d"`
	RequestCount30d int64  `gorm:"column:req_count_30d"`
}
// Heat statistics periods accepted by the *RequestStatsByPeriod functions:
// "7d" (last 7 days), "30d" (last 30 days) or "all" (no time bound).
const (
	HeatStatPeriod7d  = "7d"
	HeatStatPeriod30d = "30d"
	HeatStatPeriodAll = "all"
)
// GetModelRequestStatsByPeriod returns per-model counts of consume logs for
// period ("7d", "30d" or "all"; any other value is treated as "7d").
// Both count columns hold the count for the requested period.
func GetModelRequestStatsByPeriod(period string) ([]ModelRequestStats, error) {
	query := `
		SELECT model_name,
			COUNT(*) as req_count_7d,
			COUNT(*) as req_count_30d
		FROM logs
		WHERE type = ?
			AND model_name != ''`
	args := []any{LogTypeConsume}
	query, args = appendHeatStatWindow(query, args, period)
	query += `
		GROUP BY model_name`

	var stats []ModelRequestStats
	err := DB.Raw(query, args...).Scan(&stats).Error
	return stats, err
}

// GetModelRequestStats returns per-model request counts for the last 7 days.
func GetModelRequestStats() ([]ModelRequestStats, error) {
	return GetModelRequestStatsByPeriod(HeatStatPeriod7d)
}

// GetChannelModelRequestStatsByPeriod returns per-(channel, model) counts of
// consume logs restricted to channelIDs for period ("7d", "30d" or "all"; any
// other value is treated as "7d"). An empty channelIDs returns an empty,
// non-nil slice without querying. Both count columns hold the period count.
func GetChannelModelRequestStatsByPeriod(channelIDs []int, period string) ([]ChannelModelRequestStats, error) {
	if len(channelIDs) == 0 {
		return []ChannelModelRequestStats{}, nil
	}
	query := `
		SELECT channel_id,
			model_name,
			COUNT(*) as req_count_7d,
			COUNT(*) as req_count_30d
		FROM logs
		WHERE type = ?
			AND model_name != ''
			AND channel_id IN ?`
	args := []any{LogTypeConsume, channelIDs}
	query, args = appendHeatStatWindow(query, args, period)
	query += `
		GROUP BY channel_id, model_name`

	var stats []ChannelModelRequestStats
	err := DB.Raw(query, args...).Scan(&stats).Error
	return stats, err
}

// GetChannelModelRequestStats returns per-(channel, model) request counts for
// the last 7 days.
func GetChannelModelRequestStats(channelIDs []int) ([]ChannelModelRequestStats, error) {
	return GetChannelModelRequestStatsByPeriod(channelIDs, HeatStatPeriod7d)
}

// appendHeatStatWindow appends a "created_at >= ?" condition (Unix seconds) for
// period to a query whose WHERE clause is already open. HeatStatPeriodAll adds
// nothing; "30d" looks back 30 days; every other value looks back 7 days.
// Only constant SQL text is appended; the bound is passed as a parameter.
func appendHeatStatWindow(query string, args []any, period string) (string, []any) {
	now := time.Now()
	var start time.Time
	switch period {
	case HeatStatPeriodAll:
		return query, args
	case HeatStatPeriod30d:
		start = now.AddDate(0, 0, -30)
	default: // HeatStatPeriod7d and unknown values
		start = now.AddDate(0, 0, -7)
	}
	return query + `
			AND created_at >= ?`, append(args, start.Unix())
}