package controller

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/model"
	"github.com/QuantumNous/new-api/service"
	"github.com/QuantumNous/new-api/setting/operation_setting"
	"github.com/QuantumNous/new-api/types"
	"github.com/shopspring/decimal"

	"github.com/gin-gonic/gin"
)

// https://github.com/songquanpeng/one-api/issues/79

// OpenAISubscriptionResponse is the JSON payload decoded from the
// "/v1/dashboard/billing/subscription" endpoint.
type OpenAISubscriptionResponse struct {
	Object             string  `json:"object"`
	HasPaymentMethod   bool    `json:"has_payment_method"`
	SoftLimitUSD       float64 `json:"soft_limit_usd"`
	HardLimitUSD       float64 `json:"hard_limit_usd"`
	SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
	AccessUntil        int64   `json:"access_until"`
}

// OpenAIUsageDailyCost describes one daily cost entry of a usage report.
// It is not referenced by the code in this file.
type OpenAIUsageDailyCost struct {
	Timestamp float64 `json:"timestamp"`
	LineItems []struct {
		Name string  `json:"name"`
		Cost float64 `json:"cost"`
	}
}

// OpenAICreditGrants is the JSON payload decoded from a
// "/dashboard/billing/credit_grants" endpoint.
type OpenAICreditGrants struct {
	Object         string  `json:"object"`
	TotalGranted   float64 `json:"total_granted"`
	TotalUsed      float64 `json:"total_used"`
	TotalAvailable float64 `json:"total_available"`
}

// OpenAIUsageResponse is the JSON payload decoded from the
// "/v1/dashboard/billing/usage" endpoint.
type OpenAIUsageResponse struct {
	Object string `json:"object"`
	//DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
	TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
}

// OpenAISBUsageResponse is the JSON payload returned by the openai-sb status
// endpoint. Data is nil when the upstream reports a failure through Msg.
type OpenAISBUsageResponse struct {
	Msg  string `json:"msg"`
	Data *struct {
		Credit string `json:"credit"`
	} `json:"data"`
}

// AIProxyUserOverviewResponse is the JSON payload returned by the AIProxy
// user overview endpoint.
type AIProxyUserOverviewResponse struct {
	Success   bool   `json:"success"`
	Message   string `json:"message"`
	ErrorCode int    `json:"error_code"`
	Data      struct {
		TotalPoints float64 `json:"totalPoints"`
	} `json:"data"`
}

// API2GPTUsageResponse is the JSON payload returned by the API2GPT credit
// grants endpoint.
type API2GPTUsageResponse struct {
	Object         string  `json:"object"`
	TotalGranted   float64 `json:"total_granted"`
	TotalUsed      float64 `json:"total_used"`
	TotalRemaining float64 `json:"total_remaining"`
}

// APGC2DGPTUsageResponse is the JSON payload returned by the AIGC2D credit
// grants endpoint.
type APGC2DGPTUsageResponse struct {
	//Grants interface{} `json:"grants"`
	Object         string  `json:"object"`
	TotalAvailable float64 `json:"total_available"`
	TotalGranted   float64 `json:"total_granted"`
	TotalUsed      float64 `json:"total_used"`
}

// SiliconFlowUsageResponse is the JSON payload returned by the SiliconFlow
// user info endpoint. Balances are transported as strings.
type SiliconFlowUsageResponse struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Status  bool   `json:"status"`
	Data    struct {
		ID            string `json:"id"`
		Name          string `json:"name"`
		Image         string `json:"image"`
		Email         string `json:"email"`
		IsAdmin       bool   `json:"isAdmin"`
		Balance       string `json:"balance"`
		Status        string `json:"status"`
		Introduction  string `json:"introduction"`
		Role          string `json:"role"`
		ChargeBalance string `json:"chargeBalance"`
		TotalBalance  string `json:"totalBalance"`
		Category      string `json:"category"`
	} `json:"data"`
}

// DeepSeekUsageResponse is the JSON payload returned by the DeepSeek balance
// endpoint; it carries one entry per currency.
type DeepSeekUsageResponse struct {
	IsAvailable  bool `json:"is_available"`
	BalanceInfos []struct {
		Currency        string `json:"currency"`
		TotalBalance    string `json:"total_balance"`
		GrantedBalance  string `json:"granted_balance"`
		ToppedUpBalance string `json:"topped_up_balance"`
	} `json:"balance_infos"`
}

// OpenRouterCreditResponse is the JSON payload returned by the OpenRouter
// credits endpoint.
type OpenRouterCreditResponse struct {
	Data struct {
		TotalCredits float64 `json:"total_credits"`
		TotalUsage   float64 `json:"total_usage"`
	} `json:"data"`
}

// GetAuthHeader returns a header set carrying token as a Bearer
// Authorization header.
func GetAuthHeader(token string) http.Header {
	h := http.Header{}
	h.Add("Authorization", "Bearer "+token)
	return h
}

// GetClaudeAuthHeader returns a header set carrying token as "x-api-key"
// together with a fixed "anthropic-version" header.
func GetClaudeAuthHeader(token string) http.Header {
	h := http.Header{}
	h.Add("x-api-key", token)
	h.Add("anthropic-version", "2023-06-01")
	return h
}

// GetResponseBody performs the request with a background context and returns
// the response body. See GetResponseBodyWithContext for details.
func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
	return GetResponseBodyWithContext(context.Background(), method, url, channel, headers)
}

// GetResponseBodyWithContext is the same as GetResponseBody, but binds the
// request to ctx (for cancellation and timeouts).
//
// The request is sent without a body through an HTTP client built from the
// channel's proxy setting. Any status other than 200 is reported as an error.
// The response body is always closed before returning.
func GetResponseBodyWithContext(ctx context.Context, method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
	req, err := http.NewRequestWithContext(ctx, method, url, nil)
	if err != nil {
		return nil, err
	}
	// Copy every value of every header, not only the first one per key.
	for k, values := range headers {
		for _, v := range values {
			req.Header.Add(k, v)
		}
	}
	client, err := service.NewProxyHttpClient(channel.GetSetting().Proxy)
	if err != nil {
		return nil, err
	}
	res, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	// Close on every exit path, including the non-200 and read-error paths,
	// so the underlying connection is not leaked.
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("status code: %d", res.StatusCode)
	}
	body, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, err
	}
	return body, nil
}

// updateChannelCloseAIBalance queries the credit grants endpoint under the
// channel's base URL, stores the available total on the channel and returns it.
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	response := OpenAICreditGrants{}
	if err := json.Unmarshal(body, &response); err != nil {
		return 0, err
	}
	channel.UpdateBalance(response.TotalAvailable)
	return response.TotalAvailable, nil
}

// updateChannelOpenAISBBalance queries the openai-sb status endpoint, parses
// the credit string as a float, stores it on the channel and returns it.
// When the response has no data, the upstream message is returned as the error.
func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
	url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	response := OpenAISBUsageResponse{}
	if err := json.Unmarshal(body, &response); err != nil {
		return 0, err
	}
	if response.Data == nil {
		return 0, errors.New(response.Msg)
	}
	balance, err := strconv.ParseFloat(response.Data.Credit, 64)
	if err != nil {
		return 0, err
	}
	channel.UpdateBalance(balance)
	return balance, nil
}

// updateChannelAIProxyBalance queries the AIProxy user overview endpoint
// (authenticated with an "Api-Key" header), stores the total points on the
// channel and returns them.
func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) {
	url := "https://aiproxy.io/api/report/getUserOverview"
	headers := http.Header{}
	headers.Add("Api-Key", channel.Key)
	body, err := GetResponseBody("GET", url, channel, headers)
	if err != nil {
		return 0, err
	}
	response := AIProxyUserOverviewResponse{}
	if err := json.Unmarshal(body, &response); err != nil {
		return 0, err
	}
	if !response.Success {
		return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
	}
	channel.UpdateBalance(response.Data.TotalPoints)
	return response.Data.TotalPoints, nil
}

// updateChannelAPI2GPTBalance queries the API2GPT credit grants endpoint,
// stores the remaining total on the channel and returns it.
func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) {
	url := "https://api.api2gpt.com/dashboard/billing/credit_grants"
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	response := API2GPTUsageResponse{}
	if err := json.Unmarshal(body, &response); err != nil {
		return 0, err
	}
	channel.UpdateBalance(response.TotalRemaining)
	return response.TotalRemaining, nil
}

// updateChannelSiliconFlowBalance queries the SiliconFlow user info endpoint,
// parses the total balance string, stores it on the channel and returns it.
// A response code other than 20000 is reported as an error.
func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
	url := "https://api.siliconflow.cn/v1/user/info"
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	response := SiliconFlowUsageResponse{}
	if err := json.Unmarshal(body, &response); err != nil {
		return 0, err
	}
	if response.Code != 20000 {
		return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
	}
	balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64)
	if err != nil {
		return 0, err
	}
	channel.UpdateBalance(balance)
	return balance, nil
}

// updateChannelDeepSeekBalance queries the DeepSeek balance endpoint and uses
// the total balance of the first entry whose currency is "CNY". It returns an
// error when no such entry exists.
func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) {
	url := "https://api.deepseek.com/user/balance"
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	response := DeepSeekUsageResponse{}
	if err := json.Unmarshal(body, &response); err != nil {
		return 0, err
	}
	index := -1
	for i, balanceInfo := range response.BalanceInfos {
		if balanceInfo.Currency == "CNY" {
			index = i
			break
		}
	}
	if index == -1 {
		return 0, errors.New("currency CNY not found")
	}
	balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64)
	if err != nil {
		return 0, err
	}
	channel.UpdateBalance(balance)
	return balance, nil
}

// updateChannelAIGC2DBalance queries the AIGC2D credit grants endpoint,
// stores the available total on the channel and returns it.
func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
	url := "https://api.aigc2d.com/dashboard/billing/credit_grants"
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	response := APGC2DGPTUsageResponse{}
	if err := json.Unmarshal(body, &response); err != nil {
		return 0, err
	}
	channel.UpdateBalance(response.TotalAvailable)
	return response.TotalAvailable, nil
}

// updateChannelOpenRouterBalance queries the OpenRouter credits endpoint and
// uses total credits minus total usage as the balance.
func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
	url := "https://openrouter.ai/api/v1/credits"
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	response := OpenRouterCreditResponse{}
	if err := json.Unmarshal(body, &response); err != nil {
		return 0, err
	}
	balance := response.Data.TotalCredits - response.Data.TotalUsage
	channel.UpdateBalance(balance)
	return balance, nil
}

// updateChannelMoonshotBalance queries the Moonshot balance endpoint and
// stores the available balance divided by operation_setting.Price on the
// channel (the original local names label this a CNY-to-USD conversion).
// It returns an error when the upstream reports a failure or when the
// configured price is zero.
func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
	url := "https://api.moonshot.cn/v1/users/me/balance"
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}

	type MoonshotBalanceData struct {
		AvailableBalance float64 `json:"available_balance"`
		VoucherBalance   float64 `json:"voucher_balance"`
		CashBalance      float64 `json:"cash_balance"`
	}

	type MoonshotBalanceResponse struct {
		Code   int                 `json:"code"`
		Data   MoonshotBalanceData `json:"data"`
		Scode  string              `json:"scode"`
		Status bool                `json:"status"`
	}

	response := MoonshotBalanceResponse{}
	if err := json.Unmarshal(body, &response); err != nil {
		return 0, err
	}
	if !response.Status || response.Code != 0 {
		return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
	}

	// decimal.Div panics on a zero divisor, so reject a zero price up front.
	price := decimal.NewFromFloat(operation_setting.Price)
	if price.IsZero() {
		return 0, errors.New("failed to update moonshot balance: configured price is zero")
	}
	availableBalanceCny := response.Data.AvailableBalance
	availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(price).InexactFloat64()
	channel.UpdateBalance(availableBalanceUsd)
	return availableBalanceUsd, nil
}

// upstreamChannelBalanceResponse is the JSON payload decoded from the
// upstream "/api/channel/update_balance/{id}" endpoint.
type upstreamChannelBalanceResponse struct {
	Success bool    `json:"success"`
	Message string  `json:"message"`
	Balance float64 `json:"balance"`
}

// Balance alert levels stored in a channel's other-info under
// "balance_alert_level".
const (
	channelBalanceAlertLevelNone = "none"
	channelBalanceAlertLevelSoft = "soft"
	channelBalanceAlertLevelRisk = "risk"
)

// getChannelBalanceAlertConfig reads the balance alert settings from
// common.OptionMap under its read lock.
//
// The thresholds default to 50 (soft) and 20 (risk); an option value replaces
// the default only when it parses as a non-negative float. The risk threshold
// is clamped so that it never exceeds the soft threshold.
func getChannelBalanceAlertConfig() (enabled bool, softThreshold float64, riskThreshold float64) {
	softThreshold = 50
	riskThreshold = 20

	common.OptionMapRWMutex.RLock()
	enabled = common.OptionMap["ChannelBalanceAlertEnabled"] == "true"
	if raw, ok := common.OptionMap["ChannelBalanceSoftAlertThreshold"]; ok {
		if val, err := strconv.ParseFloat(strings.TrimSpace(raw), 64); err == nil && val >= 0 {
			softThreshold = val
		}
	}
	if raw, ok := common.OptionMap["ChannelBalanceRiskAlertThreshold"]; ok {
		if val, err := strconv.ParseFloat(strings.TrimSpace(raw), 64); err == nil && val >= 0 {
			riskThreshold = val
		}
	}
	common.OptionMapRWMutex.RUnlock()

	if riskThreshold > softThreshold {
		riskThreshold = softThreshold
	}
	return enabled, softThreshold, riskThreshold
}

// getChannelBalanceAlertLevel compares the remaining quota against the
// thresholds; the channel's balance field is the remaining amount (billing
// deducts from it in sync). The risk threshold is checked first.
func getChannelBalanceAlertLevel(remaining float64, softThreshold float64, riskThreshold float64) string {
	switch {
	case remaining <= riskThreshold:
		return channelBalanceAlertLevelRisk
	case remaining <= softThreshold:
		return channelBalanceAlertLevelSoft
	default:
		return channelBalanceAlertLevelNone
	}
}

// persistChannelBalanceAlertLevel records level and the current timestamp in
// the channel's other-info and writes the "other_info" column to the
// database. A database failure is logged, not returned. Nil channels and
// channels without a positive ID are ignored.
func persistChannelBalanceAlertLevel(channel *model.Channel, level string) {
	if channel == nil || channel.Id <= 0 {
		return
	}
	otherInfo := channel.GetOtherInfo()
	otherInfo["balance_alert_level"] = level
	otherInfo["balance_alert_at"] = common.GetTimestamp()
	channel.SetOtherInfo(otherInfo)
	err := model.DB.Model(&model.Channel{}).
		Where("id = ?", channel.Id).
		Update("other_info", channel.OtherInfo).Error
	if err != nil {
		common.SysLog(fmt.Sprintf("failed to persist balance alert level: channel_id=%d, err=%v", channel.Id, err))
	}
}

// notifyChannelBalanceAlertIfNeeded publishes an admin message when the
// channel's alert level, derived from newBalance, differs from the previously
// stored level and is not "none".
//
// The previous level is read from the channel's other-info; when none is
// stored it is derived from oldBalance. The new level is persisted on every
// call while alerts are enabled, whether or not a message is published.
func notifyChannelBalanceAlertIfNeeded(channel *model.Channel, oldBalance float64, newBalance float64) {
	if channel == nil || channel.Id <= 0 {
		return
	}
	enabled, softThreshold, riskThreshold := getChannelBalanceAlertConfig()
	if !enabled {
		return
	}

	newLevel := getChannelBalanceAlertLevel(newBalance, softThreshold, riskThreshold)
	otherInfo := channel.GetOtherInfo()
	oldLevel := strings.TrimSpace(common.Interface2String(otherInfo["balance_alert_level"]))
	if oldLevel == "" {
		oldLevel = getChannelBalanceAlertLevel(oldBalance, softThreshold, riskThreshold)
	}
	persistChannelBalanceAlertLevel(channel, newLevel)

	if newLevel == channelBalanceAlertLevelNone || newLevel == oldLevel {
		return
	}

	levelText := "柔和提示"
	threshold := softThreshold
	if newLevel == channelBalanceAlertLevelRisk {
		levelText = "风险警告"
		threshold = riskThreshold
	}
	title := fmt.Sprintf("渠道余额%s(%s)", levelText, channel.Name)
	content := fmt.Sprintf(
		"渠道“%s”(ID:%d)剩余额度 %.2f,已低于阈值 %.2f,请及时处理。",
		channel.Name,
		channel.Id,
		newBalance,
		threshold,
	)
	err := service.PublishUserMessage(&model.UserMessage{
		ReceiverMinRole: common.RoleAdminUser,
		Type:            "channel_balance_alert",
		Title:           title,
		Content:         content,
		BizType:         "channel_balance_alert",
		BizID:           channel.Id,
	})
	if err != nil {
		common.SysLog(fmt.Sprintf("failed to publish channel balance alert message: channel_id=%d, err=%v", channel.Id, err))
	}
}

// tryUpdateTFOpenMirroredChannelBalance refreshes the balance of a channel
// whose other-info "source" is "tokenfactory_open" by asking the upstream
// platform at the channel's base URL.
//
// The second result reports whether the channel was handled here: it is false
// (with a nil error) for any other source, and true otherwise, including when
// an error is returned.
func tryUpdateTFOpenMirroredChannelBalance(channel *model.Channel) (float64, bool, error) {
	otherInfo := channel.GetOtherInfo()
	if strings.TrimSpace(common.Interface2String(otherInfo["source"])) != "tokenfactory_open" {
		return 0, false, nil
	}
	upstreamID := common.String2Int(common.Interface2String(otherInfo["upstream_channel_id"]))
	if upstreamID <= 0 {
		return 0, true, errors.New("同步渠道缺少 upstream_channel_id")
	}
	baseURL := strings.TrimRight(strings.TrimSpace(channel.GetBaseURL()), "/")
	if baseURL == "" {
		return 0, true, errors.New("同步渠道缺少上游平台地址")
	}

	url := fmt.Sprintf("%s/api/channel/update_balance/%d", baseURL, upstreamID)
	// The channel key is sent both as the Bearer token and as the sync secret.
	headers := GetAuthHeader(channel.Key)
	headers.Set("X-TokenFactory-Open-Sync-Secret", strings.TrimSpace(channel.Key))
	body, err := GetResponseBody("GET", url, channel, headers)
	if err != nil {
		return 0, true, err
	}

	resp := upstreamChannelBalanceResponse{}
	if err := json.Unmarshal(body, &resp); err != nil {
		return 0, true, err
	}
	if !resp.Success {
		msg := strings.TrimSpace(resp.Message)
		if msg == "" {
			msg = "上游余额接口返回失败"
		}
		return 0, true, errors.New(msg)
	}
	channel.UpdateBalance(resp.Balance)
	return resp.Balance, true, nil
}

// updateChannelBalance refreshes and returns the balance of channel.
//
// Mirrored "tokenfactory_open" channels are handled first. Otherwise the
// channel type selects a provider-specific updater; OpenAI and custom
// channels fall through to the subscription/usage endpoints, where the
// balance is the hard limit minus the reported usage. Unsupported types
// return an error. When the channel has no base URL, it is set to the default
// for its type.
func updateChannelBalance(channel *model.Channel) (float64, error) {
	if balance, handled, err := tryUpdateTFOpenMirroredChannelBalance(channel); handled {
		return balance, err
	}

	baseURL := constant.ChannelBaseURLs[channel.Type]
	if channel.GetBaseURL() == "" {
		channel.BaseURL = &baseURL
	}
	switch channel.Type {
	case constant.ChannelTypeOpenAI:
		if channel.GetBaseURL() != "" {
			baseURL = channel.GetBaseURL()
		}
	case constant.ChannelTypeAzure:
		return 0, errors.New("尚未实现")
	case constant.ChannelTypeCustom:
		baseURL = channel.GetBaseURL()
	//case common.ChannelTypeOpenAISB:
	//	return updateChannelOpenAISBBalance(channel)
	case constant.ChannelTypeAIProxy:
		return updateChannelAIProxyBalance(channel)
	case constant.ChannelTypeAPI2GPT:
		return updateChannelAPI2GPTBalance(channel)
	case constant.ChannelTypeAIGC2D:
		return updateChannelAIGC2DBalance(channel)
	case constant.ChannelTypeSiliconFlow:
		return updateChannelSiliconFlowBalance(channel)
	case constant.ChannelTypeDeepSeek:
		return updateChannelDeepSeekBalance(channel)
	case constant.ChannelTypeOpenRouter:
		return updateChannelOpenRouterBalance(channel)
	case constant.ChannelTypeMoonshot:
		return updateChannelMoonshotBalance(channel)
	default:
		return 0, errors.New("尚未实现")
	}

	url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	subscription := OpenAISubscriptionResponse{}
	if err := json.Unmarshal(body, &subscription); err != nil {
		return 0, err
	}

	// The usage window starts on the first day of the current month, or 100
	// days back when the subscription has no payment method.
	now := time.Now()
	startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
	endDate := now.Format("2006-01-02")
	if !subscription.HasPaymentMethod {
		startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
	}
	url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
	body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	usage := OpenAIUsageResponse{}
	if err := json.Unmarshal(body, &usage); err != nil {
		return 0, err
	}

	// TotalUsage is in units of 0.01 dollar.
	balance := subscription.HardLimitUSD - usage.TotalUsage/100
	channel.UpdateBalance(balance)
	return balance, nil
}

// UpdateChannelBalance is the gin handler that refreshes the balance of the
// channel identified by the "id" path parameter and responds with the new
// balance. Multi-key channels are rejected with an unsuccessful response.
func UpdateChannelBalance(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		common.ApiError(c, err)
		return
	}
	channel, err := model.CacheGetChannel(id)
	if err != nil {
		common.ApiError(c, err)
		return
	}
	if channel.ChannelInfo.IsMultiKey {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "多密钥渠道不支持余额查询",
		})
		return
	}
	oldBalance := channel.Balance
	balance, err := updateChannelBalance(channel)
	if err != nil {
		common.ApiError(c, err)
		return
	}
	notifyChannelBalanceAlertIfNeeded(channel, oldBalance, balance)
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"balance": balance,
	})
}

// updateAllChannelsBalance refreshes the balance of every enabled,
// single-key channel. A channel whose refresh fails is skipped; a channel
// whose refreshed balance is not positive is passed to service.DisableChannel.
// It only returns an error when the channel list cannot be loaded.
func updateAllChannelsBalance() error {
	channels, err := model.GetAllChannels(0, 0, true, false)
	if err != nil {
		return err
	}
	for _, channel := range channels {
		if channel.Status != common.ChannelStatusEnabled {
			continue
		}
		if channel.ChannelInfo.IsMultiKey {
			continue // skip multi-key channels
		}
		// TODO: support Azure
		//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
		//	continue
		//}
		oldBalance := channel.Balance
		balance, err := updateChannelBalance(channel)
		if err != nil {
			// Best effort: a failing channel is skipped and the loop moves
			// straight on to the next one.
			continue
		}
		notifyChannelBalanceAlertIfNeeded(channel, oldBalance, balance)
		// err is nil & balance <= 0 means quota is used up
		if balance <= 0 {
			service.DisableChannel(*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, "", channel.GetAutoBan()), "余额不足")
		}
		time.Sleep(common.RequestInterval)
	}
	return nil
}

// UpdateAllChannelsBalance is the gin handler that synchronously refreshes
// the balance of all channels and reports success or the error.
func UpdateAllChannelsBalance(c *gin.Context) {
	// TODO: make it async
	err := updateAllChannelsBalance()
	if err != nil {
		common.ApiError(c, err)
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
}

// AutomaticallyUpdateChannels loops forever: it sleeps for frequency minutes
// and then refreshes the balance of all channels. A failed refresh is logged
// and the loop continues. It never returns.
func AutomaticallyUpdateChannels(frequency int) {
	for {
		time.Sleep(time.Duration(frequency) * time.Minute)
		common.SysLog("updating all channels")
		if err := updateAllChannelsBalance(); err != nil {
			common.SysLog(fmt.Sprintf("failed to update all channels balance: %v", err))
		}
		common.SysLog("channels update done")
	}
}