tokenFactory/model/log.go

579 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package model
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/gin-gonic/gin"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
)
// Log is the GORM model for one row of the "logs" table and also the JSON
// shape returned by the log query functions in this file.
type Log struct {
// Id is the row id. formatUserLogs overwrites it with a per-page sequence
// number before logs are handed to regular users.
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1;index:idx_user_id_id,priority:2"`
// UserId is the id of the user the entry belongs to.
UserId int `json:"user_id" gorm:"index;index:idx_user_id_id,priority:1"`
// CreatedAt is filled from common.GetTimestamp() by the Record* functions
// and is compared against Unix-second values in the statistics queries.
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
// Type holds one of the LogType* constants.
Type int `json:"type" gorm:"index:idx_created_at_type"`
// Content is the free-form log message.
Content string `json:"content"`
Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"`
TokenName string `json:"token_name" gorm:"index;default:''"`
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
Quota int `json:"quota" gorm:"default:0"`
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
// UseTime is populated from the callers' "use time seconds" value.
UseTime int `json:"use_time" gorm:"default:0"`
IsStream bool `json:"is_stream"`
// ChannelId is serialized under the JSON key "channel".
ChannelId int `json:"channel" gorm:"index"`
// The channel display name is not exposed through the API (console logs
// show only the channel id). The field is excluded from JSON and is
// read-only on the GORM side.
ChannelName string `json:"-" gorm:"->"`
// ChannelDisplay is a read-only API display value, e.g. route_slug_supplier_type.
// It is not persisted; attachLogChannelDisplays fills it after a query.
ChannelDisplay string `json:"channel_display,omitempty" gorm:"-"`
TokenId int `json:"token_id" gorm:"default:0;index"`
Group string `json:"group" gorm:"index"`
// Ip is only filled when the user's RecordIpLog setting is enabled.
Ip string `json:"ip" gorm:"index;default:''"`
RequestId string `json:"request_id,omitempty" gorm:"type:varchar(64);index:idx_logs_request_id;default:''"`
// Other is a JSON-encoded map of extra fields (see common.MapToJsonStr).
Other string `json:"other"`
}
// Log type values stored in Log.Type.
// The values are spelled out instead of using iota so that inserting or
// reordering a constant can never change an existing log type value.
const (
LogTypeUnknown = 0 // treated as "no type filter" by GetAllLogs and GetUserLogs
LogTypeTopup = 1
LogTypeConsume = 2 // written by RecordConsumeLog; the only type the Sum* statistics count
LogTypeManage = 3
LogTypeSystem = 4
LogTypeError = 5 // written by RecordErrorLog
LogTypeRefund = 6
)
// formatUserLogs prepares log entries for display to a regular user.
//
// For every entry it clears the channel name and channel display, strips the
// admin-only keys "admin_info" and "stream_status" from the Other JSON, and
// replaces Id with a sequence number (startIdx + position + 1) so that real
// row ids are not returned. The "reject_reason" key is left untouched.
func formatUserLogs(logs []*Log, startIdx int) {
	for position, entry := range logs {
		entry.ChannelName = ""
		entry.ChannelDisplay = ""

		// A parse failure is ignored on purpose: the entry is then
		// re-encoded from a nil map.
		extras, _ := common.StrToMap(entry.Other)
		// delete on a nil map is a no-op, so no nil guard is needed.
		delete(extras, "admin_info")
		delete(extras, "stream_status")
		entry.Other = common.MapToJsonStr(extras)

		entry.Id = startIdx + position + 1
	}
}
// formatChannelDisplay builds the display string for a channel.
//
// Both routeSlug and supplierType are whitespace-trimmed first. When the slug
// is empty and channelID is positive, the slug falls back to
// DefaultRouteSlugFromChannelID(channelID). The result is:
//   - "" when no slug is available,
//   - the slug alone when supplierType is empty,
//   - "<slug>_<supplierType>" otherwise.
func formatChannelDisplay(routeSlug string, supplierType string, channelID int) string {
	slug := strings.TrimSpace(routeSlug)
	supplier := strings.TrimSpace(supplierType)

	if slug == "" && channelID > 0 {
		slug = DefaultRouteSlugFromChannelID(int64(channelID))
	}

	switch {
	case slug == "":
		return ""
	case supplier == "":
		return slug
	default:
		return slug + "_" + supplier
	}
}
// attachLogChannelDisplays fills ChannelDisplay on every log whose channel can
// be resolved.
//
// It collects the distinct positive channel ids referenced by logs, loads
// id/route_slug/supplier_type for them from the Channel model in one query,
// and assigns formatChannelDisplay's result to each matching log. Nil entries
// and entries without a positive channel id are skipped. A query failure is
// logged and leaves the logs unchanged.
func attachLogChannelDisplays(logs []*Log) {
	if len(logs) == 0 {
		return
	}

	// Collect each referenced channel id once.
	channelIDs := make([]int, 0, len(logs))
	seen := make(map[int]struct{}, len(logs))
	for _, entry := range logs {
		if entry == nil || entry.ChannelId <= 0 {
			continue
		}
		if _, ok := seen[entry.ChannelId]; ok {
			continue
		}
		seen[entry.ChannelId] = struct{}{}
		channelIDs = append(channelIDs, entry.ChannelId)
	}
	if len(channelIDs) == 0 {
		return
	}

	type channelDisplayRow struct {
		Id           int
		RouteSlug    string
		SupplierType string
	}
	var rows []channelDisplayRow
	if err := DB.Model(&Channel{}).
		Select("id", "route_slug", "supplier_type").
		Where("id IN ?", channelIDs).
		Find(&rows).Error; err != nil {
		common.SysError("failed to attach log channel display: " + err.Error())
		return
	}

	displayMap := make(map[int]string, len(rows))
	for _, row := range rows {
		if display := formatChannelDisplay(row.RouteSlug, row.SupplierType, row.Id); display != "" {
			displayMap[row.Id] = display
		}
	}

	for _, entry := range logs {
		// Nil entries were tolerated while collecting ids, so they must be
		// tolerated here too instead of panicking on the dereference.
		if entry == nil {
			continue
		}
		if display, ok := displayMap[entry.ChannelId]; ok {
			entry.ChannelDisplay = display
		}
	}
}
// GetLogByTokenId returns the most recent logs (newest first, at most
// common.MaxRecentItems) recorded for the given token id. The result is passed
// through formatUserLogs with a start index of 0 before being returned, even
// when the query reported an error.
func GetLogByTokenId(tokenId int) (logs []*Log, err error) {
	query := LOG_DB.Model(&Log{}).
		Where("token_id = ?", tokenId).
		Order("id desc").
		Limit(common.MaxRecentItems)
	err = query.Find(&logs).Error
	formatUserLogs(logs, 0)
	return logs, err
}
// RecordLog stores a simple log entry of the given type for a user.
//
// Consume-type entries are dropped when common.LogConsumeEnabled is false.
// The username is looked up by id (a lookup failure is ignored and leaves it
// empty) and, where applicable, prepended to content by
// prependUsernameBeforeRole. Insert failures are only written to the system
// log.
func RecordLog(userId int, logType int, content string) {
	consumeLoggingOff := logType == LogTypeConsume && !common.LogConsumeEnabled
	if consumeLoggingOff {
		return
	}

	username, _ := GetUsernameById(userId, false)
	message := prependUsernameBeforeRole(content, username)

	entry := &Log{
		UserId:    userId,
		Username:  username,
		CreatedAt: common.GetTimestamp(),
		Type:      logType,
		Content:   message,
	}
	if createErr := LOG_DB.Create(entry).Error; createErr != nil {
		common.SysLog("failed to record log: " + createErr.Error())
	}
}
// prependUsernameBeforeRole inserts username in front of content when content
// starts with a role word, so that the log shows who performed the action.
//
// Content is returned unchanged when either argument is empty, when content
// already starts with username immediately followed by any role word, or when
// content does not start with a role word.
func prependUsernameBeforeRole(content string, username string) string {
	if content == "" || username == "" {
		return content
	}
	rolePrefixes := []string{"管理员", "用户", "分销商", "供应商"}

	// First rule out "already prefixed" against every role. Doing this per
	// role inside the loop below would double-prefix usernames that themselves
	// begin with an earlier role word (e.g. username "用户1" with content
	// "用户1分销商...").
	for _, rolePrefix := range rolePrefixes {
		if strings.HasPrefix(content, username+rolePrefix) {
			return content
		}
	}
	for _, rolePrefix := range rolePrefixes {
		if strings.HasPrefix(content, rolePrefix) {
			return username + content
		}
	}
	return content
}
// RecordErrorLog stores a LogTypeError entry describing a failed request.
//
// The username and request id are read from the gin context; other is
// serialized to JSON and stored in the Other column. Token counts and quota
// are left at zero. The client IP is stored only when the user's settings
// load successfully and have RecordIpLog enabled. Insert failures are
// reported through the request logger and not returned.
func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
	isStream bool, group string, other map[string]interface{}) {
	logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))

	username := c.GetString("username")
	requestId := c.GetString(common.RequestIdKey)
	otherStr := common.MapToJsonStr(other)

	// Decide whether the client IP should be recorded for this user.
	recordIP := false
	if setting, settingErr := GetUserSetting(userId, false); settingErr == nil {
		recordIP = setting.RecordIpLog
	}

	entry := &Log{
		UserId:    userId,
		Username:  username,
		CreatedAt: common.GetTimestamp(),
		Type:      LogTypeError,
		Content:   content,
		TokenName: tokenName,
		ModelName: modelName,
		ChannelId: channelId,
		TokenId:   tokenId,
		UseTime:   useTimeSeconds,
		IsStream:  isStream,
		Group:     group,
		RequestId: requestId,
		Other:     otherStr,
	}
	if recordIP {
		entry.Ip = c.ClientIP()
	}

	if createErr := LOG_DB.Create(entry).Error; createErr != nil {
		logger.LogError(c, "failed to record log: "+createErr.Error())
	}
}
// RecordConsumeLogParams carries the per-request values that RecordConsumeLog
// copies into a consume-type Log entry.
type RecordConsumeLogParams struct {
ChannelId int `json:"channel_id"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
ModelName string `json:"model_name"`
TokenName string `json:"token_name"`
Quota int `json:"quota"`
Content string `json:"content"`
TokenId int `json:"token_id"`
// UseTimeSeconds is stored in Log.UseTime.
UseTimeSeconds int `json:"use_time_seconds"`
IsStream bool `json:"is_stream"`
Group string `json:"group"`
// Other is serialized to JSON and stored in Log.Other.
Other map[string]interface{} `json:"other"`
}
// RecordConsumeLog stores a LogTypeConsume entry for a completed request.
//
// Nothing is recorded when common.LogConsumeEnabled is false. The username
// and request id come from the gin context; the client IP is stored only when
// the user's settings load successfully and have RecordIpLog enabled. Insert
// failures are reported through the request logger and not returned. When
// common.DataExportEnabled is set, the quota and token usage are additionally
// handed to LogQuotaData on the goroutine pool.
func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
	if !common.LogConsumeEnabled {
		return
	}
	logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))

	username := c.GetString("username")
	requestId := c.GetString(common.RequestIdKey)
	otherStr := common.MapToJsonStr(params.Other)

	// Decide whether the client IP should be recorded for this user.
	recordIP := false
	if setting, settingErr := GetUserSetting(userId, false); settingErr == nil {
		recordIP = setting.RecordIpLog
	}

	entry := &Log{
		UserId:           userId,
		Username:         username,
		CreatedAt:        common.GetTimestamp(),
		Type:             LogTypeConsume,
		Content:          params.Content,
		PromptTokens:     params.PromptTokens,
		CompletionTokens: params.CompletionTokens,
		TokenName:        params.TokenName,
		ModelName:        params.ModelName,
		Quota:            params.Quota,
		ChannelId:        params.ChannelId,
		TokenId:          params.TokenId,
		UseTime:          params.UseTimeSeconds,
		IsStream:         params.IsStream,
		Group:            params.Group,
		RequestId:        requestId,
		Other:            otherStr,
	}
	if recordIP {
		entry.Ip = c.ClientIP()
	}

	if createErr := LOG_DB.Create(entry).Error; createErr != nil {
		logger.LogError(c, "failed to record log: "+createErr.Error())
	}

	if !common.DataExportEnabled {
		return
	}
	gopool.Go(func() {
		totalTokens := params.PromptTokens + params.CompletionTokens
		LogQuotaData(userId, username, params.ModelName, params.Quota, common.GetTimestamp(), totalTokens)
	})
}
// RecordTaskBillingLogParams carries the values that RecordTaskBillingLog
// copies into a Log entry.
type RecordTaskBillingLogParams struct {
UserId int
// LogType is stored in Log.Type; consume-type entries are skipped when
// consume logging is disabled.
LogType int
Content string
ChannelId int
ModelName string
// TokenName is used as given only when TokenId is not positive, or when the
// token lookup for a positive TokenId fails.
TokenName string
Quota int
TokenId int
Group string
// Other is serialized to JSON and stored in Log.Other.
Other map[string]interface{}
}
// RecordTaskBillingLog stores a billing log entry for a task.
//
// Consume-type entries are dropped when common.LogConsumeEnabled is false.
// The username is looked up by user id (a lookup failure is ignored). The
// token name is resolved as follows: for a positive TokenId the token's
// current name is used when the lookup succeeds, otherwise params.TokenName
// is kept; without a TokenId an empty name becomes "playground-default".
// Insert failures are only written to the system log.
func RecordTaskBillingLog(params RecordTaskBillingLogParams) {
	if params.LogType == LogTypeConsume && !common.LogConsumeEnabled {
		return
	}

	username, _ := GetUsernameById(params.UserId, false)

	tokenName := params.TokenName
	switch {
	case params.TokenId > 0:
		if token, lookupErr := GetTokenById(params.TokenId); lookupErr == nil {
			tokenName = token.Name
		}
	case tokenName == "":
		// Playground/default token: avoid an empty token name in the task
		// billing log, which would keep the frontend from showing it.
		tokenName = "playground-default"
	}

	entry := &Log{
		UserId:    params.UserId,
		Username:  username,
		CreatedAt: common.GetTimestamp(),
		Type:      params.LogType,
		Content:   params.Content,
		TokenName: tokenName,
		ModelName: params.ModelName,
		Quota:     params.Quota,
		ChannelId: params.ChannelId,
		TokenId:   params.TokenId,
		Group:     params.Group,
		Other:     common.MapToJsonStr(params.Other),
	}
	if createErr := LOG_DB.Create(entry).Error; createErr != nil {
		common.SysLog("failed to record task billing log: " + createErr.Error())
	}
}
// GetAllLogs returns one page of logs across all users together with the
// total number of rows matching the filters.
//
// A filter is applied only when its argument is non-zero / non-empty;
// logType == LogTypeUnknown means "any type". Results are ordered by id
// descending and paged with startIdx (offset) and num (limit). After loading,
// channel display values are attached and the "channel_name" key is removed
// from each entry's Other JSON (older error logs may still contain it, and
// the console does not return the channel display name). On a query error the
// function returns (nil, 0, err).
//
// NOTE(review): modelName is passed to LIKE as-is, whereas GetUserLogs runs
// it through sanitizeLikePattern first — confirm this difference is intended.
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string, requestId string) (logs []*Log, total int64, err error) {
	query := LOG_DB
	if logType != LogTypeUnknown {
		query = query.Where("logs.type = ?", logType)
	}
	if modelName != "" {
		query = query.Where("logs.model_name like ?", modelName)
	}
	if username != "" {
		query = query.Where("logs.username = ?", username)
	}
	if tokenName != "" {
		query = query.Where("logs.token_name = ?", tokenName)
	}
	if requestId != "" {
		query = query.Where("logs.request_id = ?", requestId)
	}
	if startTimestamp != 0 {
		query = query.Where("logs.created_at >= ?", startTimestamp)
	}
	if endTimestamp != 0 {
		query = query.Where("logs.created_at <= ?", endTimestamp)
	}
	if channel != 0 {
		query = query.Where("logs.channel_id = ?", channel)
	}
	if group != "" {
		query = query.Where("logs."+logGroupCol+" = ?", group)
	}

	if err = query.Model(&Log{}).Count(&total).Error; err != nil {
		return nil, 0, err
	}
	if err = query.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error; err != nil {
		return nil, 0, err
	}

	attachLogChannelDisplays(logs)

	for _, entry := range logs {
		if entry.Other == "" {
			continue
		}
		extras, parseErr := common.StrToMap(entry.Other)
		if parseErr != nil || extras == nil {
			// Leave unparsable payloads exactly as stored.
			continue
		}
		delete(extras, "channel_name")
		entry.Other = common.MapToJsonStr(extras)
	}
	return logs, total, nil
}
// logSearchCountLimit is the limit GetUserLogs applies to its count query.
const logSearchCountLimit = 10000
// GetUserLogs returns one page of a single user's logs together with the
// number of matching rows (the count query is limited by
// logSearchCountLimit).
//
// A filter is applied only when its argument is non-zero / non-empty;
// logType == LogTypeUnknown means "any type". modelName is matched with LIKE
// after being processed by sanitizeLikePattern, whose error is returned
// unchanged. Results are ordered by id descending, paged with startIdx
// (offset) and num (limit), and passed through formatUserLogs before being
// returned. Database errors are written to the system log and replaced by a
// generic error for the caller.
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string, requestId string) (logs []*Log, total int64, err error) {
	var query *gorm.DB
	switch logType {
	case LogTypeUnknown:
		query = LOG_DB.Where("logs.user_id = ?", userId)
	default:
		query = LOG_DB.Where("logs.user_id = ? and logs.type = ?", userId, logType)
	}

	if modelName != "" {
		pattern, patternErr := sanitizeLikePattern(modelName)
		if patternErr != nil {
			return nil, 0, patternErr
		}
		query = query.Where("logs.model_name LIKE ? ESCAPE '!'", pattern)
	}
	if tokenName != "" {
		query = query.Where("logs.token_name = ?", tokenName)
	}
	if requestId != "" {
		query = query.Where("logs.request_id = ?", requestId)
	}
	if startTimestamp != 0 {
		query = query.Where("logs.created_at >= ?", startTimestamp)
	}
	if endTimestamp != 0 {
		query = query.Where("logs.created_at <= ?", endTimestamp)
	}
	if group != "" {
		query = query.Where("logs."+logGroupCol+" = ?", group)
	}

	if countErr := query.Model(&Log{}).Limit(logSearchCountLimit).Count(&total).Error; countErr != nil {
		common.SysError("failed to count user logs: " + countErr.Error())
		return nil, 0, errors.New("查询日志失败")
	}
	if findErr := query.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error; findErr != nil {
		common.SysError("failed to search user logs: " + findErr.Error())
		return nil, 0, errors.New("查询日志失败")
	}

	formatUserLogs(logs, startIdx)
	return logs, total, nil
}
// Stat is the aggregate returned by SumUsedQuota and SumUsedQuotaByModelNames.
type Stat struct {
// Quota is sum(quota) over the matching consume logs.
Quota int `json:"quota"`
// Rpm is the number of matching consume logs created in the last 60 seconds.
Rpm int `json:"rpm"`
// Tpm is the sum of prompt and completion tokens of those same logs.
Tpm int `json:"tpm"`
}
// SumUsedQuota aggregates consume logs into a Stat.
//
// Two queries are run against the logs table: one summing quota, and one
// computing rpm/tpm over the last 60 seconds. The username, tokenName,
// modelName (LIKE, after sanitizeLikePattern), channel and group filters
// apply to both; startTimestamp and endTimestamp restrict only the quota sum.
// Only LogTypeConsume rows are counted; the logType argument is accepted but
// not used. Database errors are written to the system log and replaced by a
// generic error; a sanitizeLikePattern error is returned unchanged.
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat, err error) {
	quotaQuery := LOG_DB.Table("logs").Select("sum(quota) quota")
	// rpm and tpm use a separate query.
	rateQuery := LOG_DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")

	// filterBoth adds the same condition to both queries.
	filterBoth := func(cond string, arg interface{}) {
		quotaQuery = quotaQuery.Where(cond, arg)
		rateQuery = rateQuery.Where(cond, arg)
	}

	if username != "" {
		filterBoth("username = ?", username)
	}
	if tokenName != "" {
		filterBoth("token_name = ?", tokenName)
	}
	if startTimestamp != 0 {
		quotaQuery = quotaQuery.Where("created_at >= ?", startTimestamp)
	}
	if endTimestamp != 0 {
		quotaQuery = quotaQuery.Where("created_at <= ?", endTimestamp)
	}
	if modelName != "" {
		pattern, patternErr := sanitizeLikePattern(modelName)
		if patternErr != nil {
			return stat, patternErr
		}
		filterBoth("model_name LIKE ? ESCAPE '!'", pattern)
	}
	if channel != 0 {
		filterBoth("channel_id = ?", channel)
	}
	if group != "" {
		filterBoth(logGroupCol+" = ?", group)
	}
	filterBoth("type = ?", LogTypeConsume)

	// rpm and tpm only cover the most recent 60 seconds.
	rateQuery = rateQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix())

	if scanErr := quotaQuery.Scan(&stat).Error; scanErr != nil {
		common.SysError("failed to query log stat: " + scanErr.Error())
		return stat, errors.New("查询统计数据失败")
	}
	if scanErr := rateQuery.Scan(&stat).Error; scanErr != nil {
		common.SysError("failed to query rpm/tpm stat: " + scanErr.Error())
		return stat, errors.New("查询统计数据失败")
	}
	return stat, nil
}
// SumUsedQuotaByModelNames aggregates consume logs for a set of model names
// into a Stat: the quota sum plus rpm/tpm for the last 60 seconds.
//
// An empty modelNames yields a zero Stat without querying. startTimestamp and
// endTimestamp (when non-zero) restrict only the quota sum. Database errors
// are written to the system log and replaced by a generic error.
func SumUsedQuotaByModelNames(startTimestamp int64, endTimestamp int64, modelNames []string) (stat Stat, err error) {
	if len(modelNames) == 0 {
		return stat, nil
	}

	quotaQuery := LOG_DB.Table("logs").Select("sum(quota) quota")
	rateQuery := LOG_DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")

	// filterBoth adds the same condition to both queries.
	filterBoth := func(cond string, arg interface{}) {
		quotaQuery = quotaQuery.Where(cond, arg)
		rateQuery = rateQuery.Where(cond, arg)
	}

	if startTimestamp != 0 {
		quotaQuery = quotaQuery.Where("created_at >= ?", startTimestamp)
	}
	if endTimestamp != 0 {
		quotaQuery = quotaQuery.Where("created_at <= ?", endTimestamp)
	}
	filterBoth("model_name IN ?", modelNames)
	// rpm and tpm only cover the most recent 60 seconds.
	rateQuery = rateQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix())
	filterBoth("type = ?", LogTypeConsume)

	if scanErr := quotaQuery.Scan(&stat).Error; scanErr != nil {
		common.SysError("failed to query supplier log stat: " + scanErr.Error())
		return stat, errors.New("查询统计数据失败")
	}
	if scanErr := rateQuery.Scan(&stat).Error; scanErr != nil {
		common.SysError("failed to query supplier rpm/tpm stat: " + scanErr.Error())
		return stat, errors.New("查询统计数据失败")
	}
	return stat, nil
}
// SumUsedToken returns the total of prompt and completion tokens over the
// consume logs matching the given filters.
//
// A filter is applied only when its argument is non-zero / non-empty;
// modelName is matched exactly. Only LogTypeConsume rows are counted; the
// logType argument is accepted but not used. When no row matches the result
// is 0. A query failure is written to the system log and also yields 0.
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
	// coalesce is standard SQL; ifnull is not available on PostgreSQL.
	tx := LOG_DB.Table("logs").Select("coalesce(sum(prompt_tokens),0) + coalesce(sum(completion_tokens),0)")
	if username != "" {
		tx = tx.Where("username = ?", username)
	}
	if tokenName != "" {
		tx = tx.Where("token_name = ?", tokenName)
	}
	if startTimestamp != 0 {
		tx = tx.Where("created_at >= ?", startTimestamp)
	}
	if endTimestamp != 0 {
		tx = tx.Where("created_at <= ?", endTimestamp)
	}
	if modelName != "" {
		tx = tx.Where("model_name = ?", modelName)
	}
	if err := tx.Where("type = ?", LogTypeConsume).Scan(&token).Error; err != nil {
		// The signature has no error return, so surface the failure in the
		// system log instead of dropping it silently.
		common.SysError("failed to query used token sum: " + err.Error())
		return 0
	}
	return token
}
// DeleteOldLog deletes logs created before targetTimestamp in batches of at
// most limit rows and returns the number of rows deleted.
//
// The context is checked before every batch; on cancellation the rows deleted
// so far are returned together with the context's error. A database error
// likewise returns the running total and that error. The loop ends when a
// batch deletes fewer rows than limit, or deletes nothing at all.
func DeleteOldLog(ctx context.Context, targetTimestamp int64, limit int) (int64, error) {
	var total int64
	for {
		if err := ctx.Err(); err != nil {
			return total, err
		}
		result := LOG_DB.Where("created_at < ?", targetTimestamp).Limit(limit).Delete(&Log{})
		if result.Error != nil {
			return total, result.Error
		}
		total += result.RowsAffected
		// The zero check guarantees termination when limit <= 0, where
		// "fewer than limit" can never become true.
		if result.RowsAffected == 0 || result.RowsAffected < int64(limit) {
			return total, nil
		}
	}
}