tokenFactory/controller/channel-billing.go

666 lines
19 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package controller
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/types"
"github.com/shopspring/decimal"
"github.com/gin-gonic/gin"
)
// OpenAISubscriptionResponse is the JSON payload that updateChannelBalance
// decodes from the "/v1/dashboard/billing/subscription" endpoint.
// Background: https://github.com/songquanpeng/one-api/issues/79
type OpenAISubscriptionResponse struct {
Object string `json:"object"`
// HasPaymentMethod decides which usage window updateChannelBalance queries.
HasPaymentMethod bool `json:"has_payment_method"`
SoftLimitUSD float64 `json:"soft_limit_usd"`
// HardLimitUSD is the amount usage is subtracted from to derive the balance.
HardLimitUSD float64 `json:"hard_limit_usd"`
SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
AccessUntil int64 `json:"access_until"`
}
// OpenAIUsageDailyCost is the element type of the (currently commented-out)
// DailyCosts field of OpenAIUsageResponse: one timestamped entry holding a
// list of named cost line items.
type OpenAIUsageDailyCost struct {
	Timestamp float64 `json:"timestamp"`
	// LineItems needs an explicit tag: without one, encoding/json only
	// matches the key "LineItems" (case-insensitively), so a snake_case
	// "line_items" key would silently leave this slice empty.
	LineItems []struct {
		Name string  `json:"name"`
		Cost float64 `json:"cost"`
	} `json:"line_items"`
}
// OpenAICreditGrants is the JSON payload that updateChannelCloseAIBalance
// decodes from "<base URL>/dashboard/billing/credit_grants".
type OpenAICreditGrants struct {
Object string `json:"object"`
TotalGranted float64 `json:"total_granted"`
TotalUsed float64 `json:"total_used"`
// TotalAvailable is the value stored as the channel balance.
TotalAvailable float64 `json:"total_available"`
}
// OpenAIUsageResponse is the JSON payload that updateChannelBalance decodes
// from the "/v1/dashboard/billing/usage" endpoint.
type OpenAIUsageResponse struct {
Object string `json:"object"`
//DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
// TotalUsage is divided by 100 by updateChannelBalance before it is
// subtracted from the hard limit.
TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
}
// OpenAISBUsageResponse is the JSON payload that updateChannelOpenAISBBalance
// decodes from the api.openai-sb.com user status endpoint.
type OpenAISBUsageResponse struct {
// Msg is used as the error text when Data is absent.
Msg string `json:"msg"`
// Data is a pointer so that a missing "data" object can be detected as nil.
Data *struct {
// Credit is a decimal number encoded as a string; the caller parses it
// with strconv.ParseFloat.
Credit string `json:"credit"`
} `json:"data"`
}
// AIProxyUserOverviewResponse is the JSON payload that
// updateChannelAIProxyBalance decodes from the aiproxy.io getUserOverview
// endpoint.
type AIProxyUserOverviewResponse struct {
// Success must be true for the payload to be accepted; otherwise ErrorCode
// and Message are reported as the error.
Success bool `json:"success"`
Message string `json:"message"`
ErrorCode int `json:"error_code"`
Data struct {
// TotalPoints is the value stored as the channel balance.
TotalPoints float64 `json:"totalPoints"`
} `json:"data"`
}
// API2GPTUsageResponse is the JSON payload that updateChannelAPI2GPTBalance
// decodes from the api.api2gpt.com credit_grants endpoint.
type API2GPTUsageResponse struct {
Object string `json:"object"`
TotalGranted float64 `json:"total_granted"`
TotalUsed float64 `json:"total_used"`
// TotalRemaining is the value stored as the channel balance.
TotalRemaining float64 `json:"total_remaining"`
}
// APGC2DGPTUsageResponse is the JSON payload that updateChannelAIGC2DBalance
// decodes from the api.aigc2d.com credit_grants endpoint.
type APGC2DGPTUsageResponse struct {
//Grants interface{} `json:"grants"`
Object string `json:"object"`
// TotalAvailable is the value stored as the channel balance.
TotalAvailable float64 `json:"total_available"`
TotalGranted float64 `json:"total_granted"`
TotalUsed float64 `json:"total_used"`
}
// SiliconFlowUsageResponse is the JSON payload that
// updateChannelSiliconFlowBalance decodes from the api.siliconflow.cn
// "/v1/user/info" endpoint.
type SiliconFlowUsageResponse struct {
// Code is compared against 20000 by the caller; any other value is
// reported as an error together with Message.
Code int `json:"code"`
Message string `json:"message"`
Status bool `json:"status"`
Data struct {
ID string `json:"id"`
Name string `json:"name"`
Image string `json:"image"`
Email string `json:"email"`
IsAdmin bool `json:"isAdmin"`
Balance string `json:"balance"`
Status string `json:"status"`
Introduction string `json:"introduction"`
Role string `json:"role"`
ChargeBalance string `json:"chargeBalance"`
// TotalBalance is a decimal number encoded as a string; it is the field
// the caller parses and stores as the channel balance.
TotalBalance string `json:"totalBalance"`
Category string `json:"category"`
} `json:"data"`
}
// DeepSeekUsageResponse is the JSON payload that updateChannelDeepSeekBalance
// decodes from the api.deepseek.com "/user/balance" endpoint.
type DeepSeekUsageResponse struct {
IsAvailable bool `json:"is_available"`
// BalanceInfos holds one entry per currency; the caller uses the first
// entry whose Currency is "CNY".
BalanceInfos []struct {
Currency string `json:"currency"`
// TotalBalance is a decimal number encoded as a string; the caller parses
// it with strconv.ParseFloat.
TotalBalance string `json:"total_balance"`
GrantedBalance string `json:"granted_balance"`
ToppedUpBalance string `json:"topped_up_balance"`
} `json:"balance_infos"`
}
// OpenRouterCreditResponse is the JSON payload that
// updateChannelOpenRouterBalance decodes from the openrouter.ai
// "/api/v1/credits" endpoint. The caller computes the balance as
// TotalCredits minus TotalUsage.
type OpenRouterCreditResponse struct {
Data struct {
TotalCredits float64 `json:"total_credits"`
TotalUsage float64 `json:"total_usage"`
} `json:"data"`
}
// GetAuthHeader returns a new header set whose only entry is
// "Authorization: Bearer <token>".
func GetAuthHeader(token string) http.Header {
	header := make(http.Header, 1)
	header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
	return header
}
// GetClaudeAuthHeader returns a new header set carrying token in "x-api-key"
// together with a fixed "anthropic-version" of "2023-06-01".
func GetClaudeAuthHeader(token string) http.Header {
	header := make(http.Header, 2)
	pairs := [][2]string{
		{"x-api-key", token},
		{"anthropic-version", "2023-06-01"},
	}
	for _, pair := range pairs {
		header.Set(pair[0], pair[1])
	}
	return header
}
// GetResponseBody performs the request described by method, url and headers
// on behalf of channel and returns the response body. It delegates to
// GetResponseBodyWithContext with a background context, so the request has
// no caller-supplied cancellation or deadline.
func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
	ctx := context.Background()
	return GetResponseBodyWithContext(ctx, method, url, channel, headers)
}
// GetResponseBodyWithContext is the same as GetResponseBody, but binds the
// request to ctx so that it can be cancelled or given a timeout.
//
// It sends a body-less request with the given method, url and headers
// through the client built by service.NewProxyHttpClient from the channel's
// proxy setting. The body is returned only for a 200 response; any other
// status yields a "status code: N" error.
func GetResponseBodyWithContext(ctx context.Context, method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
	req, err := http.NewRequestWithContext(ctx, method, url, nil)
	if err != nil {
		return nil, err
	}
	// Copy every value of every key; headers.Get would keep only the first
	// value and would miss keys that are not in canonical form.
	for key, values := range headers {
		for _, value := range values {
			req.Header.Add(key, value)
		}
	}
	client, err := service.NewProxyHttpClient(channel.GetSetting().Proxy)
	if err != nil {
		return nil, err
	}
	res, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	// Close the body on every path, including the non-200 and read-error
	// returns below, so the underlying connection is not leaked. The close
	// error of a body that was only read carries no useful information.
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("status code: %d", res.StatusCode)
	}
	body, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, err
	}
	return body, nil
}
// updateChannelCloseAIBalance fetches "/dashboard/billing/credit_grants"
// from the channel's base URL using bearer authentication, records the
// reported total_available through channel.UpdateBalance and returns it.
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
	endpoint := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
	raw, err := GetResponseBody(http.MethodGet, endpoint, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	var grants OpenAICreditGrants
	if err := json.Unmarshal(raw, &grants); err != nil {
		return 0, err
	}
	available := grants.TotalAvailable
	channel.UpdateBalance(available)
	return available, nil
}
// updateChannelOpenAISBBalance queries the api.openai-sb.com user status
// endpoint, passing the channel key both in the query string (verbatim) and
// as a bearer token. The "credit" string in the reply is parsed as a float,
// recorded through channel.UpdateBalance and returned. When the reply has no
// "data" object, its "msg" becomes the error.
func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
	endpoint := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
	raw, err := GetResponseBody(http.MethodGet, endpoint, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	var status OpenAISBUsageResponse
	if err := json.Unmarshal(raw, &status); err != nil {
		return 0, err
	}
	if status.Data == nil {
		return 0, errors.New(status.Msg)
	}
	credit, err := strconv.ParseFloat(status.Data.Credit, 64)
	if err != nil {
		return 0, err
	}
	channel.UpdateBalance(credit)
	return credit, nil
}
// updateChannelAIProxyBalance queries the aiproxy.io getUserOverview
// endpoint with the channel key in the "Api-Key" header. On a successful
// reply it records totalPoints through channel.UpdateBalance and returns it;
// otherwise the reply's error code and message form the error.
func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) {
	const endpoint = "https://aiproxy.io/api/report/getUserOverview"
	authHeader := make(http.Header, 1)
	authHeader.Set("Api-Key", channel.Key)
	raw, err := GetResponseBody(http.MethodGet, endpoint, channel, authHeader)
	if err != nil {
		return 0, err
	}
	var overview AIProxyUserOverviewResponse
	if err := json.Unmarshal(raw, &overview); err != nil {
		return 0, err
	}
	if !overview.Success {
		return 0, fmt.Errorf("code: %d, message: %s", overview.ErrorCode, overview.Message)
	}
	points := overview.Data.TotalPoints
	channel.UpdateBalance(points)
	return points, nil
}
// updateChannelAPI2GPTBalance queries the api.api2gpt.com credit_grants
// endpoint with bearer authentication, records the reported total_remaining
// through channel.UpdateBalance and returns it.
func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) {
	const endpoint = "https://api.api2gpt.com/dashboard/billing/credit_grants"
	raw, err := GetResponseBody(http.MethodGet, endpoint, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	var usage API2GPTUsageResponse
	if err := json.Unmarshal(raw, &usage); err != nil {
		return 0, err
	}
	remaining := usage.TotalRemaining
	channel.UpdateBalance(remaining)
	return remaining, nil
}
// updateChannelSiliconFlowBalance queries the api.siliconflow.cn user info
// endpoint with bearer authentication. A reply whose code is not 20000 is
// turned into an error; otherwise the totalBalance string is parsed as a
// float, recorded through channel.UpdateBalance and returned.
func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
	const endpoint = "https://api.siliconflow.cn/v1/user/info"
	raw, err := GetResponseBody(http.MethodGet, endpoint, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	var info SiliconFlowUsageResponse
	if err := json.Unmarshal(raw, &info); err != nil {
		return 0, err
	}
	if info.Code != 20000 {
		return 0, fmt.Errorf("code: %d, message: %s", info.Code, info.Message)
	}
	total, err := strconv.ParseFloat(info.Data.TotalBalance, 64)
	if err != nil {
		return 0, err
	}
	channel.UpdateBalance(total)
	return total, nil
}
// updateChannelDeepSeekBalance queries the api.deepseek.com balance endpoint
// with bearer authentication. The first balance entry whose currency is
// "CNY" is used: its total_balance string is parsed as a float, recorded
// through channel.UpdateBalance and returned. A reply without a CNY entry is
// an error.
func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) {
	const endpoint = "https://api.deepseek.com/user/balance"
	raw, err := GetResponseBody(http.MethodGet, endpoint, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	var usage DeepSeekUsageResponse
	if err := json.Unmarshal(raw, &usage); err != nil {
		return 0, err
	}
	for i := range usage.BalanceInfos {
		entry := &usage.BalanceInfos[i]
		if entry.Currency != "CNY" {
			continue
		}
		total, err := strconv.ParseFloat(entry.TotalBalance, 64)
		if err != nil {
			return 0, err
		}
		channel.UpdateBalance(total)
		return total, nil
	}
	return 0, errors.New("currency CNY not found")
}
// updateChannelAIGC2DBalance queries the api.aigc2d.com credit_grants
// endpoint with bearer authentication, records the reported total_available
// through channel.UpdateBalance and returns it.
func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
	const endpoint = "https://api.aigc2d.com/dashboard/billing/credit_grants"
	raw, err := GetResponseBody(http.MethodGet, endpoint, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	var usage APGC2DGPTUsageResponse
	if err := json.Unmarshal(raw, &usage); err != nil {
		return 0, err
	}
	available := usage.TotalAvailable
	channel.UpdateBalance(available)
	return available, nil
}
// updateChannelOpenRouterBalance queries the openrouter.ai credits endpoint
// with bearer authentication. The balance is total_credits minus
// total_usage; it is recorded through channel.UpdateBalance and returned.
func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
	const endpoint = "https://openrouter.ai/api/v1/credits"
	raw, err := GetResponseBody(http.MethodGet, endpoint, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	var credits OpenRouterCreditResponse
	if err := json.Unmarshal(raw, &credits); err != nil {
		return 0, err
	}
	remaining := credits.Data.TotalCredits - credits.Data.TotalUsage
	channel.UpdateBalance(remaining)
	return remaining, nil
}
// updateChannelMoonshotBalance queries the api.moonshot.cn balance endpoint
// with bearer authentication. The reply is accepted only when its status is
// true and its code is 0. The available balance is then divided by
// operation_setting.Price, recorded through channel.UpdateBalance and
// returned.
//
// NOTE(review): the variable names suggest the upstream figure is in CNY and
// Price is a CNY-per-USD rate; confirm against operation_setting.
//
// An error is returned, rather than a panic raised, when Price is zero.
func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
	url := "https://api.moonshot.cn/v1/users/me/balance"
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	type MoonshotBalanceData struct {
		AvailableBalance float64 `json:"available_balance"`
		VoucherBalance   float64 `json:"voucher_balance"`
		CashBalance      float64 `json:"cash_balance"`
	}
	type MoonshotBalanceResponse struct {
		Code   int                 `json:"code"`
		Data   MoonshotBalanceData `json:"data"`
		Scode  string              `json:"scode"`
		Status bool                `json:"status"`
	}
	response := MoonshotBalanceResponse{}
	err = json.Unmarshal(body, &response)
	if err != nil {
		return 0, err
	}
	if !response.Status || response.Code != 0 {
		return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
	}
	// decimal.Div panics on a zero divisor. This function also runs inside
	// the background updater loop, so a misconfigured Price of 0 must be
	// reported as an error instead of crashing the caller.
	rate := decimal.NewFromFloat(operation_setting.Price)
	if rate.IsZero() {
		return 0, errors.New("failed to update moonshot balance: operation_setting.Price is zero, cannot convert balance")
	}
	availableBalanceCny := response.Data.AvailableBalance
	availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(rate).InexactFloat64()
	channel.UpdateBalance(availableBalanceUsd)
	return availableBalanceUsd, nil
}
// upstreamChannelBalanceResponse is the JSON payload that
// tryUpdateTFOpenMirroredChannelBalance decodes from the upstream platform's
// "/api/channel/update_balance/<id>" endpoint.
type upstreamChannelBalanceResponse struct {
// Success must be true for Balance to be used; otherwise Message is
// reported as the error.
Success bool `json:"success"`
Message string `json:"message"`
Balance float64 `json:"balance"`
}
// Balance alert levels returned by getChannelBalanceAlertLevel and stored
// under the "balance_alert_level" key of a channel's other-info map.
const (
// channelBalanceAlertLevelNone: remaining balance is above both thresholds.
channelBalanceAlertLevelNone = "none"
// channelBalanceAlertLevelSoft: remaining balance is at or below the soft
// threshold but above the risk threshold.
channelBalanceAlertLevelSoft = "soft"
// channelBalanceAlertLevelRisk: remaining balance is at or below the risk
// threshold.
channelBalanceAlertLevelRisk = "risk"
)
// getChannelBalanceAlertConfig reads the balance alert settings from
// common.OptionMap under its read lock.
//
// enabled is true only when "ChannelBalanceAlertEnabled" is exactly "true".
// The soft and risk thresholds default to 50 and 20; each is replaced by its
// option value when that value parses as a non-negative number. Finally the
// risk threshold is capped so that it never exceeds the soft threshold.
func getChannelBalanceAlertConfig() (enabled bool, softThreshold float64, riskThreshold float64) {
	// thresholdOr returns the option stored under key when it is a valid
	// non-negative number and fallback otherwise. It must be called with the
	// option map's read lock held.
	thresholdOr := func(key string, fallback float64) float64 {
		raw, present := common.OptionMap[key]
		if !present {
			return fallback
		}
		parsed, err := strconv.ParseFloat(strings.TrimSpace(raw), 64)
		if err == nil && parsed >= 0 {
			return parsed
		}
		return fallback
	}

	common.OptionMapRWMutex.RLock()
	enabled = common.OptionMap["ChannelBalanceAlertEnabled"] == "true"
	softThreshold = thresholdOr("ChannelBalanceSoftAlertThreshold", 50)
	riskThreshold = thresholdOr("ChannelBalanceRiskAlertThreshold", 20)
	common.OptionMapRWMutex.RUnlock()

	if softThreshold < riskThreshold {
		riskThreshold = softThreshold
	}
	return enabled, softThreshold, riskThreshold
}
// getChannelBalanceAlertLevel classifies the remaining quota against the two
// thresholds: "risk" when remaining is at or below riskThreshold, "soft"
// when it is at or below softThreshold, and "none" otherwise. The original
// author notes that a channel's balance field is itself the remaining amount
// (billing deducts from it in step).
func getChannelBalanceAlertLevel(remaining float64, softThreshold float64, riskThreshold float64) string {
	switch {
	case remaining <= riskThreshold:
		return channelBalanceAlertLevelRisk
	case remaining <= softThreshold:
		return channelBalanceAlertLevelSoft
	default:
		return channelBalanceAlertLevelNone
	}
}
// persistChannelBalanceAlertLevel stores level and the current timestamp in
// the channel's other-info map (keys "balance_alert_level" and
// "balance_alert_at") and writes the resulting other_info column to the
// database row with the channel's id. A nil channel or a non-positive id is
// ignored. A database failure is logged, not returned.
func persistChannelBalanceAlertLevel(channel *model.Channel, level string) {
	if channel == nil || channel.Id <= 0 {
		return
	}
	info := channel.GetOtherInfo()
	info["balance_alert_level"] = level
	info["balance_alert_at"] = common.GetTimestamp()
	channel.SetOtherInfo(info)

	result := model.DB.Model(&model.Channel{}).
		Where("id = ?", channel.Id).
		Update("other_info", channel.OtherInfo)
	if result.Error == nil {
		return
	}
	common.SysLog(fmt.Sprintf("failed to persist balance alert level: channel_id=%d, err=%v", channel.Id, result.Error))
}
// notifyChannelBalanceAlertIfNeeded publishes an admin message when the
// channel's balance alert level changes to "soft" or "risk".
//
// Nothing happens for a nil channel, a non-positive id, or when alerts are
// disabled. The previous level is taken from the channel's stored
// "balance_alert_level"; if none is stored it is derived from oldBalance.
// The level derived from newBalance is always persisted. A message is then
// published only when the new level is not "none" and differs from the
// previous one. A publish failure is logged, not returned.
func notifyChannelBalanceAlertIfNeeded(channel *model.Channel, oldBalance float64, newBalance float64) {
	if channel == nil || channel.Id <= 0 {
		return
	}
	alertsOn, softLimit, riskLimit := getChannelBalanceAlertConfig()
	if !alertsOn {
		return
	}
	current := getChannelBalanceAlertLevel(newBalance, softLimit, riskLimit)

	// Read the stored level before persisting, which overwrites it.
	stored := channel.GetOtherInfo()["balance_alert_level"]
	previous := strings.TrimSpace(common.Interface2String(stored))
	if previous == "" {
		previous = getChannelBalanceAlertLevel(oldBalance, softLimit, riskLimit)
	}
	persistChannelBalanceAlertLevel(channel, current)

	if current == channelBalanceAlertLevelNone || current == previous {
		return
	}

	var (
		levelText string
		threshold float64
	)
	switch current {
	case channelBalanceAlertLevelRisk:
		levelText, threshold = "风险警告", riskLimit
	default:
		levelText, threshold = "柔和提示", softLimit
	}

	message := &model.UserMessage{
		ReceiverMinRole: common.RoleAdminUser,
		Type:            "channel_balance_alert",
		Title:           fmt.Sprintf("渠道余额%s%s", levelText, channel.Name),
		Content: fmt.Sprintf(
			"渠道“%s”ID:%d剩余额度 %.2f,已低于阈值 %.2f,请及时处理。",
			channel.Name,
			channel.Id,
			newBalance,
			threshold,
		),
		BizType: "channel_balance_alert",
		BizID:   channel.Id,
	}
	if err := service.PublishUserMessage(message); err != nil {
		common.SysLog(fmt.Sprintf("failed to publish channel balance alert message: channel_id=%d, err=%v", channel.Id, err))
	}
}
// tryUpdateTFOpenMirroredChannelBalance handles channels whose other-info
// "source" is "tokenfactory_open". For any other channel it returns
// (0, false, nil) so that the caller falls through to its own logic.
//
// For a matching channel the second result is always true. The balance is
// fetched from "<base URL>/api/channel/update_balance/<upstream_channel_id>"
// with the channel key sent both as a bearer token and, trimmed, in the
// "X-TokenFactory-Open-Sync-Secret" header. On success the balance is
// recorded through channel.UpdateBalance and returned. A missing upstream
// id, a missing base URL, a transport or decode failure, or an unsuccessful
// upstream reply yields an error.
func tryUpdateTFOpenMirroredChannelBalance(channel *model.Channel) (float64, bool, error) {
	info := channel.GetOtherInfo()
	source := strings.TrimSpace(common.Interface2String(info["source"]))
	if source != "tokenfactory_open" {
		return 0, false, nil
	}

	upstreamID := common.String2Int(common.Interface2String(info["upstream_channel_id"]))
	if upstreamID <= 0 {
		return 0, true, errors.New("同步渠道缺少 upstream_channel_id")
	}
	base := strings.TrimRight(strings.TrimSpace(channel.GetBaseURL()), "/")
	if base == "" {
		return 0, true, errors.New("同步渠道缺少上游平台地址")
	}

	endpoint := fmt.Sprintf("%s/api/channel/update_balance/%d", base, upstreamID)
	headers := GetAuthHeader(channel.Key)
	headers.Set("X-TokenFactory-Open-Sync-Secret", strings.TrimSpace(channel.Key))
	raw, err := GetResponseBody(http.MethodGet, endpoint, channel, headers)
	if err != nil {
		return 0, true, err
	}

	var reply upstreamChannelBalanceResponse
	if err := json.Unmarshal(raw, &reply); err != nil {
		return 0, true, err
	}
	if reply.Success {
		channel.UpdateBalance(reply.Balance)
		return reply.Balance, true, nil
	}

	// Fall back to a generic message when the upstream gave no reason.
	reason := strings.TrimSpace(reply.Message)
	if reason == "" {
		reason = "上游余额接口返回失败"
	}
	return 0, true, errors.New(reason)
}
// updateChannelBalance refreshes the balance of a single channel and returns
// the new value.
//
// Channels mirrored from a TokenFactory Open upstream are handled first by
// tryUpdateTFOpenMirroredChannelBalance. Otherwise the channel type selects
// a provider-specific updater. OpenAI and custom channels fall through to
// the OpenAI-style billing endpoints, where the balance is the
// subscription's hard limit minus the usage of the queried window. Azure and
// any unlisted type return a "not implemented" error.
func updateChannelBalance(channel *model.Channel) (float64, error) {
// Mirrored channels are fully handled by the helper, including its errors.
if balance, handled, err := tryUpdateTFOpenMirroredChannelBalance(channel); handled {
return balance, err
}
// Start from the default base URL registered for this channel type.
baseURL := constant.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" {
// NOTE: this points the channel's BaseURL at the local variable, so the
// channel object is mutated and later writes to baseURL are visible
// through it.
channel.BaseURL = &baseURL
}
switch channel.Type {
case constant.ChannelTypeOpenAI:
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
case constant.ChannelTypeAzure:
return 0, errors.New("尚未实现")
case constant.ChannelTypeCustom:
baseURL = channel.GetBaseURL()
//case common.ChannelTypeOpenAISB:
// return updateChannelOpenAISBBalance(channel)
case constant.ChannelTypeAIProxy:
return updateChannelAIProxyBalance(channel)
case constant.ChannelTypeAPI2GPT:
return updateChannelAPI2GPTBalance(channel)
case constant.ChannelTypeAIGC2D:
return updateChannelAIGC2DBalance(channel)
case constant.ChannelTypeSiliconFlow:
return updateChannelSiliconFlowBalance(channel)
case constant.ChannelTypeDeepSeek:
return updateChannelDeepSeekBalance(channel)
case constant.ChannelTypeOpenRouter:
return updateChannelOpenRouterBalance(channel)
case constant.ChannelTypeMoonshot:
return updateChannelMoonshotBalance(channel)
default:
return 0, errors.New("尚未实现")
}
// Only OpenAI and custom channels reach this point.
// Step 1: fetch the subscription to learn the hard limit.
url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
subscription := OpenAISubscriptionResponse{}
err = json.Unmarshal(body, &subscription)
if err != nil {
return 0, err
}
// Step 2: pick the usage window. By default it runs from the first day of
// the current month to today; without a payment method it starts 100 days
// ago instead.
now := time.Now()
startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
endDate := now.Format("2006-01-02")
if !subscription.HasPaymentMethod {
startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
}
// Step 3: fetch the usage for that window.
url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
usage := OpenAIUsageResponse{}
err = json.Unmarshal(body, &usage)
if err != nil {
return 0, err
}
// TotalUsage is in units of 0.01 dollar, hence the division by 100.
balance := subscription.HardLimitUSD - usage.TotalUsage/100
channel.UpdateBalance(balance)
return balance, nil
}
// UpdateChannelBalance is the HTTP handler that refreshes the balance of the
// channel named by the "id" path parameter.
//
// An invalid id, an unknown channel or a failed refresh is reported through
// common.ApiError. Multi-key channels are rejected with a success=false
// JSON reply. On success the balance alert check runs and the new balance is
// returned as JSON.
func UpdateChannelBalance(c *gin.Context) {
	channelID, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		common.ApiError(c, err)
		return
	}
	channel, err := model.CacheGetChannel(channelID)
	if err != nil {
		common.ApiError(c, err)
		return
	}
	if channel.ChannelInfo.IsMultiKey {
		rejection := gin.H{
			"success": false,
			"message": "多密钥渠道不支持余额查询",
		}
		c.JSON(http.StatusOK, rejection)
		return
	}

	// Capture the balance before the refresh so the alert check can compare.
	previous := channel.Balance
	current, err := updateChannelBalance(channel)
	if err != nil {
		common.ApiError(c, err)
		return
	}
	notifyChannelBalanceAlertIfNeeded(channel, previous, current)

	reply := gin.H{
		"success": true,
		"message": "",
		"balance": current,
	}
	c.JSON(http.StatusOK, reply)
}
// updateAllChannelsBalance refreshes the balance of every enabled,
// single-key channel returned by model.GetAllChannels.
//
// A channel whose refresh fails is skipped immediately, without the
// per-channel pause. After a successful refresh the balance alert check
// runs, the channel is disabled when its balance is zero or negative, and
// the loop sleeps for common.RequestInterval. Only a failure to load the
// channel list is returned as an error.
func updateAllChannelsBalance() error {
	channels, err := model.GetAllChannels(0, 0, true, false)
	if err != nil {
		return err
	}
	for _, ch := range channels {
		// Disabled and multi-key channels are not refreshed.
		if ch.Status != common.ChannelStatusEnabled || ch.ChannelInfo.IsMultiKey {
			continue
		}
		// TODO: support Azure
		previous := ch.Balance
		current, err := updateChannelBalance(ch)
		if err != nil {
			continue
		}
		notifyChannelBalanceAlertIfNeeded(ch, previous, current)
		// A successful refresh with a non-positive balance means the quota
		// is used up.
		if current <= 0 {
			channelErr := types.NewChannelError(ch.Id, ch.Type, ch.Name, ch.ChannelInfo.IsMultiKey, "", ch.GetAutoBan())
			service.DisableChannel(*channelErr, "余额不足")
		}
		time.Sleep(common.RequestInterval)
	}
	return nil
}
// UpdateAllChannelsBalance is the HTTP handler that refreshes all channel
// balances synchronously. A failure is reported through common.ApiError;
// otherwise a success=true JSON reply is sent.
func UpdateAllChannelsBalance(c *gin.Context) {
	// TODO: make it async
	if err := updateAllChannelsBalance(); err != nil {
		common.ApiError(c, err)
		return
	}
	reply := gin.H{
		"success": true,
		"message": "",
	}
	c.JSON(http.StatusOK, reply)
}
// AutomaticallyUpdateChannels refreshes all channel balances forever,
// sleeping frequency minutes before each round. It never returns, so it is
// meant to be run in its own goroutine.
//
// A failed round is logged instead of being silently discarded; the loop
// then carries on with the next round.
func AutomaticallyUpdateChannels(frequency int) {
	for {
		time.Sleep(time.Duration(frequency) * time.Minute)
		common.SysLog("updating all channels")
		if err := updateAllChannelsBalance(); err != nil {
			// Do not report "done" for a round that could not run.
			common.SysLog(fmt.Sprintf("failed to update all channels balance: %v", err))
			continue
		}
		common.SysLog("channels update done")
	}
}