tokenFactory/service/quota.go

566 lines
21 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"errors"
"fmt"
"log"
"math"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/QuantumNous/new-api/types"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/shopspring/decimal"
)
// TokenDetails holds the token counts for one direction of a request (input
// or output), split by modality.
type TokenDetails struct {
	// TextTokens counts text tokens and AudioTokens counts audio tokens.
	TextTokens, AudioTokens int
}
// QuotaInfo bundles the inputs calculateAudioQuota needs to price a request.
type QuotaInfo struct {
	// Token usage of the request and of the response.
	InputDetails, OutputDetails TokenDetails

	// ModelName is used to look up completion and audio ratios.
	ModelName string
	// UsePrice selects fixed-price billing (ModelPrice) instead of
	// ratio-based per-token billing (ModelRatio).
	UsePrice bool
	// Channel-side price, channel-side input ratio and group multiplier.
	ModelPrice, ModelRatio, GroupRatio float64

	// Fields of the new billing formula.
	CostDiscountPercent   float64 // cost discount rate in percent; default 100
	MarkupDiscountPercent float64 // markup discount rate in percent; default 0
	GlobalModelRatio      float64 // global model input ratio
	GlobalModelPrice      float64 // global fixed model price
}
// hasCustomModelRatio reports whether currentRatio differs from the default
// ratio registered for modelName. A model without a default entry is always
// reported as customised.
func hasCustomModelRatio(modelName string, currentRatio float64) bool {
	if builtin, known := ratio_setting.GetDefaultModelRatioMap()[modelName]; known {
		return builtin != currentRatio
	}
	return true
}
// calculateAudioQuota converts the usage (or fixed price) described by info
// into a quota amount.
//
// With info.UsePrice set, the result is the integer part of
// effective model price * common.QuotaPerUnit * group ratio.
//
// Otherwise every token bucket is weighted by its rate, the sum is multiplied
// by the group ratio and rounded to the nearest integer. A result that is not
// positive is raised to 1, but only when the request is billable at all, i.e.
// the effective input rate is positive and the group ratio is non-zero.
// Previously the floor ignored the group ratio, so a zero group ratio was
// still charged one unit per request on this path while the fixed-price path
// charged nothing.
func calculateAudioQuota(info QuotaInfo) int {
	// Legacy call paths leave CostDiscountPercent at zero; treat that as 100
	// (no discount).
	costDisc := info.CostDiscountPercent
	if costDisc == 0 {
		costDisc = 100
	}
	markupDisc := info.MarkupDiscountPercent
	groupRatio := decimal.NewFromFloat(info.GroupRatio)

	if info.UsePrice {
		quotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
		// New formula: fixed price = channel fixed price * cost discount % +
		// global fixed price * markup discount %.
		effModelPrice := model.EffectiveModelPrice(info.ModelPrice, info.GlobalModelPrice, costDisc, markupDisc)
		quota := decimal.NewFromFloat(effModelPrice).Mul(quotaPerUnit).Mul(groupRatio)
		return int(quota.IntPart())
	}

	completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(info.ModelName))
	audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(info.ModelName))
	audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(info.ModelName))

	// New formula: effective input/output rates.
	// For audio, the global output ratio and the channel output ratio are both
	// taken from GetCompletionRatio, so the same value is passed for both and
	// the markup side of the amount is unaffected.
	globalCompletionRatioForAudio := completionRatio.InexactFloat64()
	effInputRate := model.EffectiveInputRate(info.ModelRatio, info.GlobalModelRatio, costDisc, markupDisc)
	effOutputRate := model.EffectiveOutputRate(info.ModelRatio, completionRatio.InexactFloat64(), info.GlobalModelRatio, globalCompletionRatioForAudio, costDisc, markupDisc)
	dEffInputRate := decimal.NewFromFloat(effInputRate)
	dEffOutputRate := decimal.NewFromFloat(effOutputRate)

	inputTextTokens := decimal.NewFromInt(int64(info.InputDetails.TextTokens))
	outputTextTokens := decimal.NewFromInt(int64(info.OutputDetails.TextTokens))
	inputAudioTokens := decimal.NewFromInt(int64(info.InputDetails.AudioTokens))
	outputAudioTokens := decimal.NewFromInt(int64(info.OutputDetails.AudioTokens))

	// The audio ratios keep their original logic (multipliers on top of the
	// effective input rate); only the text side uses the effective output rate.
	quota := decimal.Zero
	quota = quota.Add(inputTextTokens.Mul(dEffInputRate))
	quota = quota.Add(outputTextTokens.Mul(dEffOutputRate))
	quota = quota.Add(inputAudioTokens.Mul(audioRatio).Mul(dEffInputRate))
	quota = quota.Add(outputAudioTokens.Mul(audioRatio).Mul(audioCompletionRatio).Mul(dEffInputRate))
	quota = quota.Mul(groupRatio)

	// Minimum charge of 1 for billable requests only. A zero group ratio
	// zeroes the whole amount and must not be bumped back up to 1.
	if effInputRate > 0 && !groupRatio.IsZero() && quota.LessThanOrEqual(decimal.Zero) {
		quota = decimal.NewFromInt(1)
	}
	return int(quota.Round(0).IntPart())
}
// PreWssConsumeQuota charges the quota for one usage report of a realtime
// (WebSocket) session.
//
// Fixed-price requests (relayInfo.UsePrice) are skipped and return nil.
// Otherwise the quota for usage is computed with calculateAudioQuota, checked
// against the user's quota and, unless the token is unlimited, against the
// token's remaining quota, and then deducted through PostConsumeQuota.
//
// Side effect: when the gin context carries an auto-group value,
// relayInfo.UsingGroup is overwritten with it.
//
// It returns the lookup or deduction error, or an error describing which
// balance is insufficient.
func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
	if relayInfo.UsePrice {
		return nil
	}
	userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
	if err != nil {
		return err
	}
	// NOTE(review): the "sk-" prefix is stripped here, while
	// PreConsumeTokenQuota passes relayInfo.TokenKey unchanged — confirm which
	// form GetTokenByKey expects.
	token, err := model.GetTokenByKey(strings.TrimPrefix(relayInfo.TokenKey, "sk-"), false)
	if err != nil {
		return err
	}

	modelName := relayInfo.OriginModelName
	groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
	// The original code looked this ratio up twice, once as the channel-side
	// ratio and once as the global ratio. One lookup now feeds both fields.
	modelRatio, _, _ := ratio_setting.GetModelRatio(modelName)

	if autoGroup, exists := common.GetContextKey(ctx, constant.ContextKeyAutoGroup); exists {
		// Checked assertion: a non-string context value is ignored instead of
		// panicking in the middle of a billing call.
		if group, isString := autoGroup.(string); isString {
			groupRatio = ratio_setting.GetGroupRatio(group)
			log.Printf("final group ratio: %f", groupRatio)
			relayInfo.UsingGroup = group
		}
	}

	// A ratio configured for the (user group, using group) pair replaces the
	// plain group ratio.
	actualGroupRatio := groupRatio
	if userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup); ok {
		actualGroupRatio = userGroupRatio
	}

	// ChannelId is only read when ChannelMeta is present; 0 is used otherwise.
	wssChID := 0
	if relayInfo.ChannelMeta != nil {
		wssChID = relayInfo.ChannelId
	}

	quotaInfo := QuotaInfo{
		InputDetails: TokenDetails{
			TextTokens:  usage.InputTokenDetails.TextTokens,
			AudioTokens: usage.InputTokenDetails.AudioTokens,
		},
		OutputDetails: TokenDetails{
			TextTokens:  usage.OutputTokenDetails.TextTokens,
			AudioTokens: usage.OutputTokenDetails.AudioTokens,
		},
		ModelName:             modelName,
		UsePrice:              relayInfo.UsePrice,
		ModelRatio:            modelRatio,
		GroupRatio:            actualGroupRatio,
		CostDiscountPercent:   model.ResolveChannelPriceDiscountPercent(wssChID),
		MarkupDiscountPercent: model.ResolveChannelMarkupDiscountRate(wssChID),
		GlobalModelRatio:      modelRatio,
	}
	quota := calculateAudioQuota(quotaInfo)

	if userQuota < quota {
		return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", logger.FormatQuota(userQuota), logger.FormatQuota(quota))
	}
	if !token.UnlimitedQuota && token.RemainQuota < quota {
		return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota))
	}
	if err = PostConsumeQuota(relayInfo, quota, 0, false); err != nil {
		return err
	}
	logger.LogInfo(ctx, fmt.Sprintf("realtime streaming consume quota success, quota: %d", quota))
	return nil
}
// PostWssConsumeQuota records the final accounting of a realtime (WebSocket)
// session: it recomputes the quota for usage, updates the user and channel
// usage statistics, and writes the consume log entry.
//
// This function does not deduct from the wallet or the token itself;
// presumably that already happened through PreWssConsumeQuota — verify
// against the caller.
//
// When usage.TotalTokens is 0 the quota is forced to 0, the statistics are not
// updated and an error is logged, but the consume log entry is still written.
// extraContent, when non-empty, is appended to the log content.
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
	usage *dto.RealtimeUsage, extraContent string) {
	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
	textInputTokens := usage.InputTokenDetails.TextTokens
	textOutTokens := usage.OutputTokenDetails.TextTokens
	audioInputTokens := usage.InputTokenDetails.AudioTokens
	audioOutTokens := usage.OutputTokenDetails.AudioTokens
	tokenName := ctx.GetString("token_name")

	// These ratios are only used for the log text and the "other" payload;
	// calculateAudioQuota looks its own copies up from quotaInfo.ModelName.
	// NOTE(review): audioRatio is read for relayInfo.OriginModelName while the
	// other two (and the billing itself) use modelName — confirm the two names
	// cannot diverge.
	completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(modelName))
	audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
	audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName))

	modelRatio := relayInfo.PriceData.ModelRatio
	groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
	modelPrice := relayInfo.PriceData.ModelPrice
	usePrice := relayInfo.PriceData.UsePrice

	// ChannelId is only read when ChannelMeta is present; 0 is used otherwise.
	audioWssChID := 0
	if relayInfo.ChannelMeta != nil {
		audioWssChID = relayInfo.ChannelId
	}

	// A zero cost discount in PriceData falls back to the channel setting.
	// NOTE(review): the markup discount has no such fallback, unlike
	// PreWssConsumeQuota which always resolves it from the channel.
	wssPostCostDisc := relayInfo.PriceData.CostDiscountPercent
	wssPostMarkupDisc := relayInfo.PriceData.MarkupDiscountPercent
	if wssPostCostDisc == 0 {
		wssPostCostDisc = model.ResolveChannelPriceDiscountPercent(audioWssChID)
	}
	wssPostGlobalRatio, _, _ := ratio_setting.GetModelRatio(modelName)
	wssPostGlobalPrice, _ := ratio_setting.GetModelPrice(modelName, false)

	quotaInfo := QuotaInfo{
		InputDetails: TokenDetails{
			TextTokens:  textInputTokens,
			AudioTokens: audioInputTokens,
		},
		OutputDetails: TokenDetails{
			TextTokens:  textOutTokens,
			AudioTokens: audioOutTokens,
		},
		ModelName:             modelName,
		UsePrice:              usePrice,
		ModelRatio:            modelRatio,
		ModelPrice:            modelPrice,
		GroupRatio:            groupRatio,
		CostDiscountPercent:   wssPostCostDisc,
		MarkupDiscountPercent: wssPostMarkupDisc,
		GlobalModelRatio:      wssPostGlobalRatio,
		GlobalModelPrice:      wssPostGlobalPrice,
	}
	quota := calculateAudioQuota(quotaInfo)
	totalTokens := usage.TotalTokens

	var logContent string
	if !usePrice {
		logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f",
			modelRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), groupRatio)
	} else {
		logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
	}

	// Record the consume log even if the quota is 0.
	if totalTokens == 0 {
		// Some error must have happened upstream. We cannot simply return,
		// because the pre-consumed quota may still have to be given back.
		quota = 0
		// Plain literal: the former fmt.Sprintf call had no verbs to format.
		logContent += "(可能是上游超时)"
		logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
			"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, audioWssChID, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
	} else {
		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
		model.UpdateChannelUsedQuota(audioWssChID, quota)
	}

	logModel := modelName
	if extraContent != "" {
		logContent += ", " + extraContent
	}
	other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
		completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
	model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
		ChannelId:        audioWssChID,
		PromptTokens:     usage.InputTokens,
		CompletionTokens: usage.OutputTokens,
		ModelName:        logModel,
		TokenName:        tokenName,
		Quota:            quota,
		Content:          logContent,
		TokenId:          relayInfo.TokenId,
		UseTimeSeconds:   int(useTimeSeconds),
		IsStream:         relayInfo.IsStream,
		Group:            relayInfo.UsingGroup,
		Other:            other,
	})
}
// CalcOpenRouterCacheCreateTokens estimates the number of cache-creation
// prompt tokens from the cost reported in usage.Cost.
//
// With q = ModelRatio / common.QuotaPerUnit as the base price per token, the
// function solves
//
//	cost = prompt*q - cacheRead*(q - q*CacheRatio)
//	     + completion*q*CompletionRatio
//	     + cacheCreate*(q*CacheCreationRatio - q)
//
// for cacheCreate and rounds to the nearest integer.
//
// It returns 0 when CacheCreationRatio is 1 (creation costs the same as a
// normal prompt token, so it cannot be told apart) and when the result is not
// a finite number, e.g. because the model ratio is 0. The result can be
// negative when the reported cost is lower than the computed base cost;
// callers should validate it. A usage.Cost that is not a float64 is treated
// as 0.
func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData types.PriceData) int {
	if priceData.CacheCreationRatio == 1 {
		return 0
	}
	quotaPrice := priceData.ModelRatio / common.QuotaPerUnit
	promptCacheCreatePrice := quotaPrice * priceData.CacheCreationRatio
	promptCacheReadPrice := quotaPrice * priceData.CacheRatio
	completionPrice := quotaPrice * priceData.CompletionRatio

	cost, _ := usage.Cost.(float64)
	totalPromptTokens := float64(usage.PromptTokens)
	completionTokens := float64(usage.CompletionTokens)
	promptCacheReadTokens := float64(usage.PromptTokensDetails.CachedTokens)

	estimate := math.Round((cost -
		totalPromptTokens*quotaPrice +
		promptCacheReadTokens*(quotaPrice-promptCacheReadPrice) -
		completionTokens*completionPrice) /
		(promptCacheCreatePrice - quotaPrice))

	// A zero denominator yields NaN or ±Inf, and converting those to int is
	// implementation-defined in Go. Report "no estimate" instead.
	if math.IsNaN(estimate) || math.IsInf(estimate, 0) {
		return 0
	}
	return int(estimate)
}
// PostAudioConsumeQuota performs the final accounting of an audio request: it
// computes the quota for usage, updates the user and channel usage
// statistics, settles billing through SettleBilling, and writes the consume
// log entry.
//
// When usage.TotalTokens is 0 the quota is forced to 0, the statistics are not
// updated and an error is logged; billing is still settled (with quota 0) and
// the consume log entry is still written. A SettleBilling failure is logged
// and does not stop the log entry from being recorded. extraContent, when
// non-empty, is appended to the log content.
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
	textInputTokens := usage.PromptTokensDetails.TextTokens
	textOutTokens := usage.CompletionTokenDetails.TextTokens
	audioInputTokens := usage.PromptTokensDetails.AudioTokens
	audioOutTokens := usage.CompletionTokenDetails.AudioTokens
	tokenName := ctx.GetString("token_name")

	// These ratios are only used for the log text and the "other" payload;
	// calculateAudioQuota looks its own copies up from quotaInfo.ModelName.
	completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(relayInfo.OriginModelName))
	audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
	audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))

	modelRatio := relayInfo.PriceData.ModelRatio
	groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
	modelPrice := relayInfo.PriceData.ModelPrice
	usePrice := relayInfo.PriceData.UsePrice

	// ChannelId is only read when ChannelMeta is present; 0 is used otherwise.
	audioChID := 0
	if relayInfo.ChannelMeta != nil {
		audioChID = relayInfo.ChannelId
	}

	// A zero cost discount in PriceData falls back to the channel setting.
	// NOTE(review): the markup discount has no such fallback — confirm that a
	// zero markup is always intended when PriceData does not carry one.
	audioCostDisc := relayInfo.PriceData.CostDiscountPercent
	audioMarkupDisc := relayInfo.PriceData.MarkupDiscountPercent
	if audioCostDisc == 0 {
		audioCostDisc = model.ResolveChannelPriceDiscountPercent(audioChID)
	}
	audioGlobalRatio, _, _ := ratio_setting.GetModelRatio(relayInfo.OriginModelName)
	audioGlobalPrice, _ := ratio_setting.GetModelPrice(relayInfo.OriginModelName, false)

	quotaInfo := QuotaInfo{
		InputDetails: TokenDetails{
			TextTokens:  textInputTokens,
			AudioTokens: audioInputTokens,
		},
		OutputDetails: TokenDetails{
			TextTokens:  textOutTokens,
			AudioTokens: audioOutTokens,
		},
		ModelName:             relayInfo.OriginModelName,
		UsePrice:              usePrice,
		ModelRatio:            modelRatio,
		ModelPrice:            modelPrice,
		GroupRatio:            groupRatio,
		CostDiscountPercent:   audioCostDisc,
		MarkupDiscountPercent: audioMarkupDisc,
		GlobalModelRatio:      audioGlobalRatio,
		GlobalModelPrice:      audioGlobalPrice,
	}
	quota := calculateAudioQuota(quotaInfo)
	totalTokens := usage.TotalTokens

	var logContent string
	if !usePrice {
		logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f",
			modelRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), groupRatio)
	} else {
		logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
	}

	// Record the consume log even if the quota is 0.
	if totalTokens == 0 {
		// Some error must have happened upstream. We cannot simply return,
		// because the pre-consumed quota may still have to be given back.
		quota = 0
		// Plain literal: the former fmt.Sprintf call had no verbs to format.
		logContent += "(可能是上游超时)"
		logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
			"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, audioChID, relayInfo.TokenId, relayInfo.OriginModelName, relayInfo.FinalPreConsumedQuota))
	} else {
		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
		model.UpdateChannelUsedQuota(audioChID, quota)
	}

	if err := SettleBilling(ctx, relayInfo, quota); err != nil {
		logger.LogError(ctx, "error settling billing: "+err.Error())
	}

	logModel := relayInfo.OriginModelName
	if extraContent != "" {
		logContent += ", " + extraContent
	}
	other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
		completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
	model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
		ChannelId:        audioChID,
		PromptTokens:     usage.PromptTokens,
		CompletionTokens: usage.CompletionTokens,
		ModelName:        logModel,
		TokenName:        tokenName,
		Quota:            quota,
		Content:          logContent,
		TokenId:          relayInfo.TokenId,
		UseTimeSeconds:   int(useTimeSeconds),
		IsStream:         relayInfo.IsStream,
		Group:            relayInfo.UsingGroup,
		Other:            other,
	})
}
// PreConsumeTokenQuota reserves quota on the API token before a request is
// relayed.
//
// A negative quota is rejected. Playground requests are not charged against a
// token and return nil. Otherwise the token is loaded, its remaining quota is
// checked unless relayInfo.TokenUnlimited is set, and quota is deducted.
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
	switch {
	case quota < 0:
		return errors.New("quota 不能为负数!")
	case relayInfo.IsPlayground:
		return nil
	}

	// Unlimited tokens are not short-circuited: the token is still looked up
	// and the deduction below still runs; only the balance check is skipped.
	tok, lookupErr := model.GetTokenByKey(relayInfo.TokenKey, false)
	if lookupErr != nil {
		return lookupErr
	}

	insufficient := tok.RemainQuota < quota
	if insufficient && !relayInfo.TokenUnlimited {
		return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(tok.RemainQuota), logger.FormatQuota(quota))
	}

	return model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
}
// PostConsumeQuota applies a quota delta to the funding source of the request
// and to the API token.
//
// quota may be negative, in which case the amount is given back. For
// subscription billing the delta is applied to the subscription and added to
// relayInfo.SubscriptionPostDelta; otherwise it is applied to the user's
// wallet. Unless the request comes from the playground, the same delta is
// then applied to the token. With sendEmail set and quota+preConsumedQuota
// non-zero, a low-quota notification check is started.
//
// It returns an error when relayInfo is nil, when subscription billing is
// selected without a subscription id, or when any of the updates fails. A nil
// relayInfo used to pass the subscription check and then panic on the wallet
// path; it is now rejected up front.
func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int, sendEmail bool) (err error) {
	if relayInfo == nil {
		return errors.New("relay info is nil")
	}

	// 1) Consume from the subscription item OR from the wallet quota.
	if relayInfo.BillingSource == BillingSourceSubscription {
		if relayInfo.SubscriptionId == 0 {
			return errors.New("subscription id is missing")
		}
		delta := int64(quota)
		if delta != 0 {
			if err := model.PostConsumeUserSubscriptionDelta(relayInfo.SubscriptionId, delta); err != nil {
				return err
			}
			relayInfo.SubscriptionPostDelta += delta
		}
	} else {
		// Wallet
		if quota > 0 {
			err = model.DecreaseUserQuota(relayInfo.UserId, quota)
		} else {
			err = model.IncreaseUserQuota(relayInfo.UserId, -quota, false)
		}
		if err != nil {
			return err
		}
	}

	// 2) Mirror the delta on the token; playground requests are skipped.
	if !relayInfo.IsPlayground {
		if quota > 0 {
			err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
		} else {
			err = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
		}
		if err != nil {
			return err
		}
	}

	// 3) Optional low-quota notification, skipped when nothing was consumed.
	if sendEmail && (quota+preConsumedQuota) != 0 {
		checkAndSendQuotaNotify(relayInfo, quota, preConsumedQuota)
	}
	return nil
}
// checkAndSendQuotaNotify asynchronously notifies the user when the wallet
// quota left after this request drops below the warning threshold.
//
// The threshold is the user's QuotaWarningThreshold when it is non-zero and
// common.QuotaRemindThreshold otherwise. The message format depends on the
// user's notify type, with an empty type treated as email. A delivery failure
// is logged and not returned.
func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int) {
	gopool.Go(func() {
		settings := relayInfo.UserSetting
		limit := common.QuotaRemindThreshold
		if settings.QuotaWarningThreshold != 0 {
			limit = int(settings.QuotaWarningThreshold)
		}

		// Nothing to do while the quota left after this request stays at or
		// above the threshold.
		spent := quota + preConsumedQuota
		if relayInfo.UserQuota-spent >= limit {
			return
		}

		prompt := "您的额度即将用尽"
		topUpLink := fmt.Sprintf("%s/console/topup", system_setting.ServerAddress)
		quotaText := logger.FormatQuota(relayInfo.UserQuota)

		channel := settings.NotifyType
		if channel == "" {
			channel = dto.NotifyTypeEmail
		}

		// Pick a message template that suits the notification channel.
		var (
			content string
			values  []interface{}
		)
		switch {
		case channel == dto.NotifyTypeBark:
			// Bark pushes use short plain text; HTML is not supported.
			content = "{{value}},剩余额度:{{value}},请及时充值"
			values = []interface{}{prompt, quotaText}
		case channel == dto.NotifyTypeGotify:
			content = "{{value}},当前剩余额度为 {{value}},请及时充值。"
			values = []interface{}{prompt, quotaText}
		default:
			// Default format for email and webhook, which support HTML.
			content = "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
			values = []interface{}{prompt, quotaText, topUpLink, topUpLink}
		}

		notification := dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, values)
		if sendErr := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, notification); sendErr != nil {
			common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, sendErr.Error()))
		}
	})
}
// checkAndSendSubscriptionQuotaNotify asynchronously notifies the user when
// the remaining subscription quota drops below the warning threshold.
//
// Nothing happens when relayInfo is nil, when no subscription is attached, or
// when the subscription total is not positive. The remaining amount is the
// total minus the amount used after pre-consumption and the post-consume
// delta. A delivery failure is logged and not returned.
func checkAndSendSubscriptionQuotaNotify(relayInfo *relaycommon.RelayInfo) {
	gopool.Go(func() {
		if relayInfo == nil || relayInfo.SubscriptionId == 0 || relayInfo.SubscriptionAmountTotal <= 0 {
			return
		}

		settings := relayInfo.UserSetting
		limit := common.QuotaRemindThreshold
		if settings.QuotaWarningThreshold != 0 {
			limit = int(settings.QuotaWarningThreshold)
		}

		consumed := relayInfo.SubscriptionAmountUsedAfterPreConsume + relayInfo.SubscriptionPostDelta
		left := relayInfo.SubscriptionAmountTotal - consumed
		if left >= int64(limit) {
			return
		}

		prompt := "您的订阅额度即将用尽"
		topUpLink := fmt.Sprintf("%s/console/topup", system_setting.ServerAddress)
		leftText := logger.FormatQuota(int(left))

		channel := settings.NotifyType
		if channel == "" {
			channel = dto.NotifyTypeEmail
		}

		// Pick a message template that suits the notification channel.
		var (
			content string
			values  []interface{}
		)
		switch {
		case channel == dto.NotifyTypeBark:
			content = "{{value}},剩余额度:{{value}},请及时充值"
			values = []interface{}{prompt, leftText}
		case channel == dto.NotifyTypeGotify:
			content = "{{value}},当前剩余额度为 {{value}},请及时充值。"
			values = []interface{}{prompt, leftText}
		default:
			content = "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
			values = []interface{}{prompt, leftText, topUpLink, topUpLink}
		}

		notification := dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, values)
		if sendErr := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, notification); sendErr != nil {
			common.SysError(fmt.Sprintf("failed to send subscription quota notify to user %d: %s", relayInfo.UserId, sendErr.Error()))
		}
	})
}