tokenFactory/model/user.go

1468 lines
46 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package model
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
)
// UserNameMaxLength is presumably the maximum accepted username length; it
// matches the max=20 validate tag on User.Username (verify against callers).
const UserNameMaxLength = 20
// User is the GORM model for an account row.
//
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
// Otherwise, the sensitive information will be saved on local storage in plain text!
type User struct {
Id int `json:"id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
LastLoginAt *time.Time `json:"last_login_at,omitempty" gorm:"column:last_login_at"`
CreatedBy string `json:"created_by,omitempty" gorm:"column:created_by;type:varchar(32)"`
Username string `json:"username" gorm:"unique;index" validate:"max=20"`
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
OriginalPassword string `json:"original_password" gorm:"-:all"` // this field is only for Password change verification, don't save it to database!
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
Role int `json:"role" gorm:"type:int;default:1"` // admin, common
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"`
Phone string `json:"phone" gorm:"column:phone;type:varchar(20);index"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
DiscordId string `json:"discord_id" gorm:"column:discord_id;index"`
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken *string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
AffCount int `json:"aff_count" gorm:"type:int;default:0;column:aff_count"`
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // remaining invitation quota
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // historical invitation quota
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
DistributorCommissionBps int `json:"distributor_commission_bps" gorm:"type:int;default:0;column:distributor_commission_bps"` // default commission, in basis points (1/10000), for new invite relations under this distributor; 0 means follow the system-wide AffiliateDefaultCommissionBps
// IsDistributor is the distributor eligibility flag (0/1), decoupled from role;
// an ordinary user (role=1) may also be a distributor. The legacy role=5 has
// been migrated to role=1 + is_distributor=1.
IsDistributor int `json:"is_distributor" gorm:"column:is_distributor;type:integer;default:0;index"`
IsStudent int `json:"is_student" gorm:"column:is_student;type:integer;default:0;index"`
StudentStatus int `json:"student_status" gorm:"column:student_status;type:integer;default:0;index"`
StudentApplied *time.Time `json:"student_applied_at,omitempty" gorm:"column:student_applied_at"`
StudentApprovedAt *time.Time `json:"student_approved_at,omitempty" gorm:"column:student_approved_at"`
StudentApprovedBy int `json:"student_approved_by" gorm:"column:student_approved_by;type:int;default:0;index"`
DeletedAt gorm.DeletedAt `gorm:"index"`
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
Setting string `json:"setting" gorm:"type:text;column:setting"`
Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"`
StripeCustomer string `json:"stripe_customer" gorm:"type:varchar(64);column:stripe_customer;index"`
SupplierID int `json:"supplier_id" gorm:"type:int;column:supplier_id;index;default:0;comment:供应商申请ID 0表示非供应商"`
// AdminInitialSetupCompleted must be false for admin-created accounts until
// their first login, and true for self-registration and similar flows.
// Note: GORM Create omits a false bool, so the admin-created branch must
// explicitly UPDATE the column to 0 inside Insert.
AdminInitialSetupCompleted bool `json:"admin_initial_setup_completed" gorm:"column:admin_initial_setup_completed;type:boolean;not null;default:true"`
}
// ToBaseUser returns a new UserBase populated with the user's id, group,
// quota, status, username, setting and email.
func (user *User) ToBaseUser() *UserBase {
	return &UserBase{
		Id:       user.Id,
		Username: user.Username,
		Email:    user.Email,
		Group:    user.Group,
		Status:   user.Status,
		Quota:    user.Quota,
		Setting:  user.Setting,
	}
}
// GetAccessToken returns the access token, or "" when none is set.
func (user *User) GetAccessToken() string {
	if token := user.AccessToken; token != nil {
		return *token
	}
	return ""
}
// SetAccessToken stores a pointer to a copy of token on the user.
func (user *User) SetAccessToken(token string) {
	stored := token
	user.AccessToken = &stored
}
// GetSetting decodes the JSON stored in user.Setting into a dto.UserSetting.
// An empty Setting yields the zero value; a decode failure is logged and
// whatever was decoded so far is returned.
func (user *User) GetSetting() dto.UserSetting {
	var parsed dto.UserSetting
	if user.Setting == "" {
		return parsed
	}
	if decodeErr := json.Unmarshal([]byte(user.Setting), &parsed); decodeErr != nil {
		common.SysLog("failed to unmarshal setting: " + decodeErr.Error())
	}
	return parsed
}
// SetSetting serializes setting to JSON and stores it in user.Setting.
// If marshalling fails the error is logged and user.Setting is left untouched.
func (user *User) SetSetting(setting dto.UserSetting) {
	encoded, marshalErr := json.Marshal(setting)
	if marshalErr == nil {
		user.Setting = string(encoded)
		return
	}
	common.SysLog("failed to marshal setting: " + marshalErr.Error())
}
// generateDefaultSidebarConfigForRole builds the default sidebar configuration
// for the given role and returns it as a JSON string. It returns "" (after
// logging) if the configuration cannot be marshalled.
func generateDefaultSidebarConfigForRole(userRole int) string {
	// The chat, console and personal areas are present for every role.
	sidebar := map[string]interface{}{
		"chat": map[string]interface{}{
			"enabled":    true,
			"playground": true,
			"chat":       true,
		},
		"console": map[string]interface{}{
			"enabled":    true,
			"detail":     true,
			"token":      true,
			"log":        true,
			"midjourney": true,
			"task":       true,
		},
		"personal": map[string]interface{}{
			"enabled":  true,
			"topup":    true,
			"personal": true,
		},
	}

	// The admin area is only added for admin and root users; ordinary users
	// get no "admin" key at all. Admins cannot open system settings, root can.
	isAdmin := userRole == common.RoleAdminUser
	if isAdmin || userRole == common.RoleRootUser {
		sidebar["admin"] = map[string]interface{}{
			"enabled":    true,
			"channel":    true,
			"models":     true,
			"redemption": true,
			"user":       true,
			"setting":    !isAdmin,
		}
	}

	encoded, marshalErr := json.Marshal(sidebar)
	if marshalErr != nil {
		common.SysLog("生成默认边栏配置失败: " + marshalErr.Error())
		return ""
	}
	return string(encoded)
}
// CheckUserExistOrDeleted reports whether any user, soft-deleted ones
// included, already has the given username or, when email is non-empty, the
// given email. A lookup failure other than "record not found" is returned as
// the error with a false result.
//
// Per the original note, the registration endpoint now uses
// IsUsernameTakenUnscoped and IsEmailTakenByActiveUser instead, so that each
// conflict can be reported separately.
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
	var found User
	var queryErr error
	if email != "" {
		queryErr = DB.Unscoped().First(&found, "username = ? or email = ?", username, email).Error
	} else {
		// An empty email must not be matched against rows that have no email.
		queryErr = DB.Unscoped().First(&found, "username = ?", username).Error
	}

	switch {
	case queryErr == nil:
		return true, nil
	case errors.Is(queryErr, gorm.ErrRecordNotFound):
		return false, nil
	default:
		return false, queryErr
	}
}
// IsUsernameTakenUnscoped reports whether the trimmed username is already used
// by any user, including soft-deleted ones. A blank username is never taken.
func IsUsernameTakenUnscoped(username string) (bool, error) {
	name := strings.TrimSpace(username)
	if name == "" {
		return false, nil
	}
	var found User
	queryErr := DB.Unscoped().Where("username = ?", name).First(&found).Error
	switch {
	case errors.Is(queryErr, gorm.ErrRecordNotFound):
		return false, nil
	case queryErr != nil:
		return false, queryErr
	default:
		return true, nil
	}
}
// IsEmailTakenUnscoped reports whether the trimmed email has ever been used by
// any user, including soft-deleted ones. A blank email is never taken.
// For registration/binding conflicts use IsEmailTakenByActiveUser instead.
func IsEmailTakenUnscoped(email string) (bool, error) {
	address := strings.TrimSpace(email)
	if address == "" {
		return false, nil
	}
	var found User
	queryErr := DB.Unscoped().Where("email = ?", address).First(&found).Error
	switch {
	case errors.Is(queryErr, gorm.ErrRecordNotFound):
		return false, nil
	case queryErr != nil:
		return false, queryErr
	default:
		return true, nil
	}
}
// IsEmailTakenByActiveUser reports whether the trimmed email belongs to a user
// that has not been soft-deleted. Deleted accounts do not reserve the address,
// so it can be registered again. A blank email is never taken.
func IsEmailTakenByActiveUser(email string) (bool, error) {
	address := strings.TrimSpace(email)
	if address == "" {
		return false, nil
	}
	var found User
	queryErr := DB.Where("email = ?", address).First(&found).Error
	switch {
	case errors.Is(queryErr, gorm.ErrRecordNotFound):
		return false, nil
	case queryErr != nil:
		return false, queryErr
	default:
		return true, nil
	}
}
// IsEmailTakenByOtherUser reports whether the trimmed email belongs to a
// non-deleted user other than excludeUserId. A blank email is never taken.
// Query errors are not surfaced; they read as "not taken".
func IsEmailTakenByOtherUser(email string, excludeUserId int) bool {
	address := strings.TrimSpace(email)
	if address == "" {
		return false
	}
	matches := DB.Where("email = ? AND id <> ?", address, excludeUserId).Find(&User{})
	return matches.RowsAffected > 0
}
// NormalizeAndValidateAdminUserEmail trims the email supplied when an admin
// creates or edits a user. An empty result means "no email bound" and is
// returned without error. A non-empty email is validated against the
// "email,max=50" rule and checked for conflicts: excludeUserId == 0 means a
// new user (any active owner conflicts), otherwise the given user is excluded
// from the conflict check.
func NormalizeAndValidateAdminUserEmail(email string, excludeUserId int) (string, error) {
	normalized := strings.TrimSpace(email)
	if normalized == "" {
		return "", nil
	}
	if validateErr := common.Validate.Var(normalized, "email,max=50"); validateErr != nil {
		return "", fmt.Errorf("邮箱格式无效")
	}

	var taken bool
	if excludeUserId != 0 {
		taken = IsEmailTakenByOtherUser(normalized, excludeUserId)
	} else {
		var lookupErr error
		taken, lookupErr = IsEmailTakenByActiveUser(normalized)
		if lookupErr != nil {
			return "", lookupErr
		}
	}
	if taken {
		return "", fmt.Errorf("邮箱已被占用")
	}
	return normalized, nil
}
// GetMaxUserId returns the Id of the last user row (soft-deleted rows
// included). The query error is not checked, so 0 is returned when the lookup
// fails or finds nothing.
func GetMaxUserId() int {
	latest := new(User)
	DB.Unscoped().Last(latest)
	return latest.Id
}
// TouchUserLastLogin sets last_login_at to the current time for the given
// user; intended to be called once a login session has been established.
// Non-positive ids are ignored and update failures are only logged.
func TouchUserLastLogin(userId int) {
	if userId <= 0 {
		return
	}
	updateErr := DB.Model(&User{}).Where("id = ?", userId).Update("last_login_at", time.Now()).Error
	if updateErr == nil {
		return
	}
	common.SysLog("TouchUserLastLogin: " + updateErr.Error())
}
// applyStudentViewFilter narrows query according to the trimmed studentView:
// "pending" keeps rows whose student_status is pending, "students" keeps
// approved students, and any other value leaves the query unchanged.
func applyStudentViewFilter(query *gorm.DB, studentView string) *gorm.DB {
	view := strings.TrimSpace(studentView)
	if view == "pending" {
		return query.Where("student_status = ?", common.StudentStatusPending)
	}
	if view == "students" {
		return query.Where("is_student = ? AND student_status = ?", 1, common.StudentStatusApproved)
	}
	return query
}
// GetAllUsers returns one page of users (soft-deleted rows included, password
// column omitted) ordered by id descending, together with the total number of
// rows matching the filters. The count and the page are read inside a single
// transaction.
//
// studentView is interpreted by applyStudentViewFilter; a non-empty tag keeps
// only rows whose tags column contains it (ILIKE on PostgreSQL, LIKE
// elsewhere).
//
// A panic raised while querying (for example a nil pageInfo) rolls the
// transaction back and is reported through err instead of being swallowed,
// so callers never receive an empty result with a nil error.
func GetAllUsers(pageInfo *common.PageInfo, studentView string, tag string) (users []*User, total int64, err error) {
	tx := DB.Begin()
	if tx.Error != nil {
		return nil, 0, tx.Error
	}
	defer func() {
		if r := recover(); r != nil {
			tx.Rollback()
			users, total = nil, 0
			err = fmt.Errorf("GetAllUsers panicked: %v", r)
		}
	}()

	baseQuery := applyStudentViewFilter(tx.Unscoped().Model(&User{}), studentView)
	if tag != "" {
		if common.UsingPostgreSQL {
			baseQuery = baseQuery.Where("tags ILIKE ?", "%"+tag+"%")
		} else {
			baseQuery = baseQuery.Where("tags LIKE ?", "%"+tag+"%")
		}
	}

	// Total count first, then the page, both within the same transaction.
	if err = baseQuery.Count(&total).Error; err != nil {
		tx.Rollback()
		return nil, 0, err
	}
	err = baseQuery.Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("password").Find(&users).Error
	if err != nil {
		tx.Rollback()
		return nil, 0, err
	}

	if err = tx.Commit().Error; err != nil {
		return nil, 0, err
	}
	return users, total, nil
}
// SearchUsers returns up to num users starting at offset startIdx (soft-deleted
// rows included, password column omitted, ordered by id descending) plus the
// total number of matches. The count and the page are read inside a single
// transaction.
//
// keyword is matched as a substring against username, email, display_name and
// phone; when it parses as an integer it is additionally matched exactly
// against id. A non-empty group restricts results to that group, studentView
// is interpreted by applyStudentViewFilter, and a non-empty tag keeps only
// rows whose tags column contains it.
//
// A panic raised while querying rolls the transaction back and is reported
// through err instead of being swallowed.
func SearchUsers(keyword string, group string, studentView string, tag string, startIdx int, num int) (users []*User, total int64, err error) {
	tx := DB.Begin()
	if tx.Error != nil {
		return nil, 0, tx.Error
	}
	defer func() {
		if r := recover(); r != nil {
			tx.Rollback()
			users, total = nil, 0
			err = fmt.Errorf("SearchUsers panicked: %v", r)
		}
	}()

	query := applyStudentViewFilter(tx.Unscoped().Model(&User{}), studentView)
	if tag != "" {
		if common.UsingPostgreSQL {
			query = query.Where("tags ILIKE ?", "%"+tag+"%")
		} else {
			query = query.Where("tags LIKE ?", "%"+tag+"%")
		}
	}

	// Build the keyword condition once; the id and group parts are optional.
	pattern := "%" + keyword + "%"
	condition := "username LIKE ? OR email LIKE ? OR display_name LIKE ? OR phone LIKE ?"
	args := []interface{}{pattern, pattern, pattern, pattern}
	if keywordInt, convErr := strconv.Atoi(keyword); convErr == nil {
		// Numeric keyword: also match the primary key exactly.
		condition = "id = ? OR " + condition
		args = append([]interface{}{keywordInt}, args...)
	}
	if group != "" {
		condition = "(" + condition + ") AND " + commonGroupCol + " = ?"
		args = append(args, group)
	}
	query = query.Where(condition, args...)

	if err = query.Count(&total).Error; err != nil {
		tx.Rollback()
		return nil, 0, err
	}
	err = query.Omit("password").Order("id desc").Limit(num).Offset(startIdx).Find(&users).Error
	if err != nil {
		tx.Rollback()
		return nil, 0, err
	}

	if err = tx.Commit().Error; err != nil {
		return nil, 0, err
	}
	return users, total, nil
}
// GetUserById loads the non-deleted user with the given id. When selectAll is
// false the password column is omitted. A zero id is rejected. On a lookup
// failure the returned pointer is still non-nil (it carries only the id), so
// callers must check the error.
func GetUserById(id int, selectAll bool) (*User, error) {
	if id == 0 {
		return nil, errors.New("id 为空!")
	}
	found := User{Id: id}
	query := DB
	if !selectAll {
		query = DB.Omit("password")
	}
	lookupErr := query.First(&found, "id = ?", id).Error
	return &found, lookupErr
}
// GetUserIdByAffCode returns the id of the non-deleted user that owns affCode.
// An empty code is rejected; a failed lookup returns the query error together
// with whatever id was loaded (0 when nothing was found).
func GetUserIdByAffCode(affCode string) (int, error) {
	if affCode == "" {
		return 0, errors.New("affCode 为空!")
	}
	var owner User
	lookupErr := DB.Select("id").First(&owner, "aff_code = ?", affCode).Error
	return owner.Id, lookupErr
}
// EnsureAffCode assigns a random aff_code to the user when it has none,
// retrying a few times on collisions. This prevents duplicate-key errors on
// the unique aff_code index when several users would otherwise share
// aff_code = ''.
//
// The collision check runs unscoped: the unique index also covers soft-deleted
// rows, so a code held by a deleted user must count as taken. A failed check
// is logged and treated as a collision rather than as "free".
//
// If every attempt fails, the code falls back to four random characters
// followed by the user id. NOTE(review): for a user that has not been created
// yet (Id == 0) this fallback is not guaranteed to be unique.
func (user *User) EnsureAffCode() {
	if user.AffCode != "" {
		return
	}
	const maxRetries = 5
	for i := 0; i < maxRetries; i++ {
		code := common.GetRandomString(6)
		var count int64
		err := DB.Unscoped().Model(&User{}).Where("aff_code = ? AND id != ?", code, user.Id).Count(&count).Error
		if err != nil {
			common.SysLog("EnsureAffCode: " + err.Error())
			continue
		}
		if count == 0 {
			user.AffCode = code
			return
		}
	}
	// Fallback: append the user id to the random prefix.
	user.AffCode = common.GetRandomString(4) + fmt.Sprintf("%d", user.Id)
}
// BackfillEmptyAffCodes assigns an aff_code to every user (soft-deleted ones
// included) whose aff_code is empty. This is needed because aff_code has a
// unique index and multiple rows with aff_code = '' violate it on update.
// Only the initial lookup can fail the call; per-user failures are logged and
// skipped.
func BackfillEmptyAffCodes() error {
	var pending []User
	findErr := DB.Unscoped().Select("id").Where("aff_code = ''").Find(&pending).Error
	if findErr != nil {
		return findErr
	}
	if len(pending) == 0 {
		return nil
	}
	common.SysLog(fmt.Sprintf("backfill empty aff_code: %d user(s) need assignment", len(pending)))

	for idx := range pending {
		target := &pending[idx]
		target.EnsureAffCode()
		if target.AffCode == "" {
			common.SysError(fmt.Sprintf("backfill empty aff_code: failed to generate code for user %d", target.Id))
			continue
		}
		updateErr := DB.Model(&User{}).Where("id = ?", target.Id).UpdateColumn("aff_code", target.AffCode).Error
		if updateErr != nil {
			common.SysError(fmt.Sprintf("backfill empty aff_code: user %d: %s", target.Id, updateErr.Error()))
		}
	}
	return nil
}
// DeleteUserById deletes the user with the given id through User.Delete.
// A zero id is rejected.
func DeleteUserById(id int) (err error) {
	if id == 0 {
		return errors.New("id 为空!")
	}
	target := &User{Id: id}
	return target.Delete()
}
// HardDeleteUserById permanently removes the user row with the given id,
// bypassing soft delete. A zero id is rejected.
func HardDeleteUserById(id int) error {
	if id == 0 {
		return errors.New("id 为空!")
	}
	return DB.Unscoped().Delete(&User{}, "id = ?", id).Error
}
// inviteUser is called after a new user has registered through an invitation.
// It increments the inviter's invite count (aff_count) by one and, when an
// inviter registration reward (QuotaForInviter) is configured, adds that
// reward directly to the inviter's usable quota (quota).
//
// Note: registration rewards are kept separate from distribution top-up
// commissions. The latter are still written to aff_quota / aff_history via
// IncreaseUserAffCommissionQuota; this function no longer touches aff_quota or
// aff_history, to avoid mixing them with pending/historical distribution
// statistics. Data already written to aff_* in the past is not migrated; only
// newly generated registration rewards go to quota.
//
// A non-positive inviterId is a no-op. Any lookup, update, commit or cache
// error is returned to the caller.
func inviteUser(inviterId int) (err error) {
if inviterId <= 0 {
return nil
}
// Make sure the inviter exists before opening the transaction.
if _, err = GetUserById(inviterId, true); err != nil {
return err
}
reward := common.QuotaForInviter
// Same as IncreaseUserQuota: in batch mode quota writes go through the batch
// queue, so quota must not be changed directly inside the transaction.
useBatchQuota := reward > 0 && common.BatchUpdateEnabled
tx := DB.Begin()
if tx.Error != nil {
return tx.Error
}
// Rolls back on every early return; runs after Commit on the success path.
defer tx.Rollback()
if err = tx.Model(&User{}).Where("id = ?", inviterId).UpdateColumn("aff_count", gorm.Expr("aff_count + ?", 1)).Error; err != nil {
return err
}
if reward > 0 && !useBatchQuota {
if err = tx.Model(&User{}).Where("id = ?", inviterId).UpdateColumn("quota", gorm.Expr("quota + ?", reward)).Error; err != nil {
return err
}
}
if err = tx.Commit().Error; err != nil {
return err
}
if useBatchQuota {
// Batch mode: hand the reward to IncreaseUserQuota after the commit.
if err = IncreaseUserQuota(inviterId, reward, true); err != nil {
return err
}
} else if reward > 0 {
// Quota was written in the transaction; bump the cached quota asynchronously.
gopool.Go(func() {
if err := cacheIncrUserQuota(inviterId, int64(reward)); err != nil {
common.SysLog("inviteUser cacheIncrUserQuota: " + err.Error())
}
})
}
// Reload the inviter and refresh the user cache with the stored row.
inviter, err := GetUserById(inviterId, true)
if err != nil {
return err
}
return updateUserCache(*inviter)
}
// TransferAffQuotaToQuota moves quota units from the user's invitation quota
// (aff_quota) to the usable quota (quota).
//
// The amount must be at least common.QuotaPerUnit. The transfer is performed
// as one conditional UPDATE (aff_quota >= quota) inside a transaction, so it
// does not depend on a row lock and cannot overdraw aff_quota when two
// transfers race. Only the two quota columns are written; the previous
// read-modify-Save wrote the whole row back and could overwrite quota changes
// made concurrently by other requests.
//
// On success the receiver is reloaded from the database so its in-memory
// fields reflect the stored row. It returns gorm.ErrRecordNotFound when the
// user does not exist and an "insufficient invitation quota" error when
// aff_quota is too small.
func (user *User) TransferAffQuotaToQuota(quota int) error {
	// Reject amounts below the minimum transferable unit.
	if float64(quota) < common.QuotaPerUnit {
		return fmt.Errorf("转移额度最小为%s", logger.LogQuota(int(common.QuotaPerUnit)))
	}

	tx := DB.Begin()
	if tx.Error != nil {
		return tx.Error
	}
	defer tx.Rollback() // no-op once the transaction has been committed

	// Atomic check-and-move: the WHERE clause guarantees sufficient balance.
	result := tx.Model(&User{}).
		Where("id = ? AND aff_quota >= ?", user.Id, quota).
		Updates(map[string]interface{}{
			"aff_quota": gorm.Expr("aff_quota - ?", quota),
			"quota":     gorm.Expr("quota + ?", quota),
		})
	if result.Error != nil {
		return result.Error
	}
	if result.RowsAffected == 0 {
		// Nothing matched: tell a missing user apart from a short balance.
		var count int64
		if err := tx.Model(&User{}).Where("id = ?", user.Id).Count(&count).Error; err != nil {
			return err
		}
		if count == 0 {
			return gorm.ErrRecordNotFound
		}
		return errors.New("邀请额度不足!")
	}

	// Refresh the receiver so callers see the new balances.
	if err := tx.First(user, user.Id).Error; err != nil {
		return err
	}
	return tx.Commit().Error
}
// Insert creates the user outside of any caller-managed transaction and then
// runs the post-creation tasks.
//
// The row creation (password hashing, initial quota, aff_code, default
// setting, created_by, and the explicit admin_initial_setup_completed = false
// write for admin-created accounts) is delegated to InsertWithTx using the
// global DB handle. The follow-up work (role-based sidebar configuration,
// registration/invitation logs and invitation rewards) is delegated to
// FinalizeOAuthUserCreation. Both helpers previously existed as line-for-line
// copies of this method's body; delegating keeps the three paths from
// drifting apart.
//
// Only creation errors are returned; post-creation tasks are best-effort.
func (user *User) Insert(inviterId int) error {
	if err := user.InsertWithTx(DB, inviterId); err != nil {
		return err
	}
	user.FinalizeOAuthUserCreation(inviterId)
	return nil
}
// InsertWithTx creates the user row using the supplied transaction handle.
// It is used for OAuth registration, where user creation and account binding
// must be atomic. Post-creation tasks (sidebar config, logs, inviter rewards)
// are expected to run after the transaction commits; inviterId is not used
// here.
//
// Before the insert it hashes a non-empty password, sets the initial quota,
// ensures an aff_code, fills an empty Setting with the default and defaults
// CreatedBy to registration.
func (user *User) InsertWithTx(tx *gorm.DB, inviterId int) error {
	if user.Password != "" {
		hashed, hashErr := common.Password2Hash(user.Password)
		user.Password = hashed
		if hashErr != nil {
			return hashErr
		}
	}
	user.Quota = common.QuotaForNewUser
	user.EnsureAffCode()

	// Start from an empty default setting when none was supplied.
	if user.Setting == "" {
		user.SetSetting(dto.UserSetting{})
	}
	if user.CreatedBy == "" {
		user.CreatedBy = common.UserCreatedByRegistration
	}
	// Accounts not created by an admin are usable immediately.
	if user.CreatedBy != common.UserCreatedByAdmin {
		user.AdminInitialSetupCompleted = true
	}

	if createErr := tx.Create(user).Error; createErr != nil {
		return createErr
	}

	// Create does not persist a false bool, so admin-created accounts need an
	// explicit column write to store admin_initial_setup_completed = false.
	if user.CreatedBy != common.UserCreatedByAdmin || user.Id <= 0 {
		return nil
	}
	flagErr := tx.Model(&User{}).Where("id = ?", user.Id).UpdateColumn("admin_initial_setup_completed", false).Error
	if flagErr != nil {
		return flagErr
	}
	user.AdminInitialSetupCompleted = false
	return nil
}
// FinalizeOAuthUserCreation performs the post-transaction tasks for a newly
// created user and should be called after the creating transaction has
// committed successfully:
//
//   - initialise the role-based default sidebar configuration;
//   - record the new-user quota gift when QuotaForNewUser > 0;
//   - when inviterId is non-zero, ensure the invite relation, grant and record
//     the invitee/inviter rewards and update the inviter's counters.
//
// All steps are best-effort: failures never abort the remaining steps, but
// they are written to the system log instead of being silently discarded.
func (user *User) FinalizeOAuthUserCreation(inviterId int) {
	// Reload the row so the sidebar is derived from the stored role.
	var createdUser User
	if err := DB.Where("id = ?", user.Id).First(&createdUser).Error; err != nil {
		common.SysLog(fmt.Sprintf("FinalizeOAuthUserCreation: load user %d: %s", user.Id, err.Error()))
	} else if defaultSidebarConfig := generateDefaultSidebarConfigForRole(createdUser.Role); defaultSidebarConfig != "" {
		currentSetting := createdUser.GetSetting()
		currentSetting.SidebarModules = defaultSidebarConfig
		createdUser.SetSetting(currentSetting)
		if err := createdUser.Update(false); err != nil {
			common.SysLog(fmt.Sprintf("FinalizeOAuthUserCreation: save sidebar config for user %d: %s", createdUser.Id, err.Error()))
		} else {
			common.SysLog(fmt.Sprintf("为新用户 %s (角色: %d) 初始化边栏配置", createdUser.Username, createdUser.Role))
		}
	}

	if common.QuotaForNewUser > 0 {
		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
	}

	if inviterId == 0 {
		return
	}
	// Best-effort: the result of ensuring the invite relation is intentionally
	// ignored so that a failure here does not block the rewards below.
	_ = EnsureAffInviteRelation(inviterId, user.Id)
	if common.QuotaForInvitee > 0 {
		if err := IncreaseUserQuota(user.Id, common.QuotaForInvitee, true); err != nil {
			common.SysLog(fmt.Sprintf("FinalizeOAuthUserCreation: invitee quota for user %d: %s", user.Id, err.Error()))
		}
		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
	}
	if common.QuotaForInviter > 0 {
		RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
	}
	if err := inviteUser(inviterId); err != nil {
		common.SysLog(fmt.Sprintf("FinalizeOAuthUserCreation: inviteUser %d: %s", inviterId, err.Error()))
	}
}
// Update writes the user row. To stop callers that pass a partially built
// User from overwriting key fields such as username/password/email with zero
// values (the cause of a past mass data-wipe via Select("*").Updates), it
// first loads the full row from the database and uses it as the base, then
// merges the fields the caller set explicitly:
//
//   - String identity fields (username, email, phone, the OAuth ids,
//     display_name, group, setting, remark, stripe_customer, aff_code) keep
//     their stored value when the caller passes "". To clear one, use
//     ClearBinding or a column-level API; clearing through Update is not
//     supported.
//   - Numeric/boolean fields (role, status, quota, is_distributor, is_student,
//     student_status, ...) take the caller's value, so an explicit 0/false is
//     honoured (e.g. revoking distributor status).
//   - System fields (created_at, last_login_at, created_by, deleted_at) always
//     keep the stored value.
//   - The password is replaced only when updatePassword is true AND a
//     non-empty password was supplied; otherwise the stored hash is kept. The
//     emptiness check is made on the raw value before hashing, because a
//     check made after hashing can never see an empty string.
//
// NOTE(review): Tags is a string column that is not restored from the stored
// row here; confirm whether callers rely on clearing tags through Update.
func (user *User) Update(updatePassword bool) error {
	if user.Id == 0 {
		return errors.New("user id is empty")
	}

	var existing User
	if err := DB.First(&existing, user.Id).Error; err != nil {
		return err
	}

	if updatePassword && user.Password != "" {
		hashed, err := common.Password2Hash(user.Password)
		if err != nil {
			return err
		}
		user.Password = hashed
	} else {
		// Not changing the password, or an empty one was supplied: keep the
		// stored hash so the column is never wiped.
		user.Password = existing.Password
	}

	restoreIfEmpty := func(field, fallback *string) {
		if *field == "" {
			*field = *fallback
		}
	}
	restoreIfEmpty(&user.Username, &existing.Username)
	restoreIfEmpty(&user.DisplayName, &existing.DisplayName)
	restoreIfEmpty(&user.Email, &existing.Email)
	restoreIfEmpty(&user.Phone, &existing.Phone)
	restoreIfEmpty(&user.GitHubId, &existing.GitHubId)
	restoreIfEmpty(&user.DiscordId, &existing.DiscordId)
	restoreIfEmpty(&user.OidcId, &existing.OidcId)
	restoreIfEmpty(&user.WeChatId, &existing.WeChatId)
	restoreIfEmpty(&user.TelegramId, &existing.TelegramId)
	restoreIfEmpty(&user.LinuxDOId, &existing.LinuxDOId)
	restoreIfEmpty(&user.Group, &existing.Group)
	restoreIfEmpty(&user.Setting, &existing.Setting)
	restoreIfEmpty(&user.Remark, &existing.Remark)
	restoreIfEmpty(&user.StripeCustomer, &existing.StripeCustomer)
	restoreIfEmpty(&user.AffCode, &existing.AffCode)

	if user.AccessToken == nil {
		user.AccessToken = existing.AccessToken
	}

	// System-managed fields always come from the stored row.
	user.CreatedAt = existing.CreatedAt
	user.CreatedBy = existing.CreatedBy
	user.LastLoginAt = existing.LastLoginAt
	user.DeletedAt = existing.DeletedAt
	user.UpdatedAt = time.Now()

	// Belt and braces: avoid violating the unique index when aff_code is empty
	// both in the stored row and in the caller's value.
	user.EnsureAffCode()

	if err := DB.Save(user).Error; err != nil {
		return err
	}
	return updateUserCache(*user)
}
// Edit applies an admin edit to the user: username, display_name, group,
// quota, remark, phone, email and tags are written, plus the password when
// updatePassword is true. Phone and email are normalized and validated
// (format and conflicts with other users) before anything is written.
//
// The receiver is reloaded from the database before the update so that the
// cache entry written afterwards is built from the stored row plus the edited
// columns. A failed reload (for example a user that does not exist) is
// returned instead of being ignored, so no update is attempted and no cache
// entry is written for a missing user.
func (user *User) Edit(updatePassword bool) error {
	if user.Id == 0 {
		return errors.New("user id is empty")
	}
	var err error
	if updatePassword {
		user.Password, err = common.Password2Hash(user.Password)
		if err != nil {
			return err
		}
	}

	// Snapshot the requested values; the receiver is overwritten by the reload.
	newUser := *user
	normalizedPhone, err := NormalizeAndValidateAdminUserPhone(newUser.Phone, newUser.Id)
	if err != nil {
		return err
	}
	normalizedEmail, err := NormalizeAndValidateAdminUserEmail(newUser.Email, newUser.Id)
	if err != nil {
		return err
	}

	updates := map[string]interface{}{
		"username":     newUser.Username,
		"display_name": newUser.DisplayName,
		"group":        newUser.Group,
		"quota":        newUser.Quota,
		"remark":       newUser.Remark,
		"phone":        normalizedPhone,
		"email":        normalizedEmail,
		"tags":         newUser.Tags,
		"updated_at":   time.Now(),
	}
	if updatePassword {
		updates["password"] = newUser.Password
	}

	if err = DB.First(user, user.Id).Error; err != nil {
		return err
	}
	if err = DB.Model(user).Updates(updates).Error; err != nil {
		return err
	}
	return updateUserCache(*user)
}
// ClearBinding empties the column behind the given binding type ("email",
// "github", "discord", "oidc", "wechat", "telegram" or "linuxdo"), reloads the
// user from the database and refreshes the user cache. An unknown binding type
// or a zero user id is rejected.
func (user *User) ClearBinding(bindingType string) error {
	if user.Id == 0 {
		return errors.New("user id is empty")
	}

	var column string
	switch bindingType {
	case "email":
		column = "email"
	case "github":
		column = "github_id"
	case "discord":
		column = "discord_id"
	case "oidc":
		column = "oidc_id"
	case "wechat":
		column = "wechat_id"
	case "telegram":
		column = "telegram_id"
	case "linuxdo":
		column = "linux_do_id"
	default:
		return errors.New("invalid binding type")
	}

	clearErr := DB.Model(&User{}).Where("id = ?", user.Id).Update(column, "").Error
	if clearErr != nil {
		return clearErr
	}
	reloadErr := DB.Where("id = ?", user.Id).First(user).Error
	if reloadErr != nil {
		return reloadErr
	}
	return updateUserCache(*user)
}
// Delete removes the user through the default (soft-delete) scope and then
// invalidates its cache entry. A zero id is rejected.
func (user *User) Delete() error {
	if user.Id == 0 {
		return errors.New("id 为空!")
	}
	deleteErr := DB.Delete(user).Error
	if deleteErr != nil {
		return deleteErr
	}
	// Drop the cached copy of the deleted user.
	return invalidateUserCache(user.Id)
}
// HardDelete permanently removes the user row, bypassing soft delete.
// A zero id is rejected.
func (user *User) HardDelete() error {
	if user.Id == 0 {
		return errors.New("id 为空!")
	}
	return DB.Unscoped().Delete(user).Error
}
// ValidateAndFill checks the supplied credentials and the account status, and
// fills the receiver with the stored user row. The trimmed Username is matched
// against either the username or the email column.
//
// When querying with a struct, GORM only uses non-zero fields as conditions,
// which is why the lookup is written with explicit placeholders.
func (user *User) ValidateAndFill() (err error) {
	rawPassword := user.Password
	login := strings.TrimSpace(user.Username)
	if login == "" || rawPassword == "" {
		return errors.New("用户名或密码为空")
	}

	// Find by username or email; the lookup result overwrites the receiver.
	DB.Where("username = ? OR email = ?", login, login).First(user)

	passwordMatches := common.ValidatePasswordAndHash(rawPassword, user.Password)
	if passwordMatches && user.Status == common.UserStatusEnabled {
		return nil
	}
	return errors.New("用户名或密码错误,或用户已被封禁")
}
// FillUserById loads the user by primary key into the receiver.
//
// The historical implementation swallowed the error from First, so
// ErrRecordNotFound or a transient connection error left callers with a User
// holding only {Id: X}; combined with Update's full-row write this wiped
// fields such as username. The error is now always returned and callers must
// handle ErrRecordNotFound themselves.
func (user *User) FillUserById() error {
	if user.Id == 0 {
		return errors.New("id 为空!")
	}
	filter := User{Id: user.Id}
	return DB.Where(filter).First(user).Error
}
// FillUserByEmail loads the user whose email equals the receiver's Email into
// the receiver. An empty Email is rejected.
func (user *User) FillUserByEmail() error {
	if user.Email == "" {
		return errors.New("email 为空!")
	}
	filter := User{Email: user.Email}
	return DB.Where(filter).First(user).Error
}
// FillUserByGitHubId loads the user whose github_id equals the receiver's
// GitHubId into the receiver. An empty GitHubId is rejected.
func (user *User) FillUserByGitHubId() error {
	if user.GitHubId == "" {
		return errors.New("GitHub id 为空!")
	}
	filter := User{GitHubId: user.GitHubId}
	return DB.Where(filter).First(user).Error
}
// UpdateGitHubId writes newGitHubId to the user's github_id column (used when
// migrating from the GitHub login name to the numeric id). A zero user id is
// rejected.
func (user *User) UpdateGitHubId(newGitHubId string) error {
	if user.Id == 0 {
		return errors.New("user id is empty")
	}
	result := DB.Model(user).Update("github_id", newGitHubId)
	return result.Error
}
// FillUserByDiscordId loads the user whose discord_id equals the receiver's
// DiscordId into the receiver. An empty DiscordId is rejected.
func (user *User) FillUserByDiscordId() error {
	if user.DiscordId == "" {
		return errors.New("discord id 为空!")
	}
	filter := User{DiscordId: user.DiscordId}
	return DB.Where(filter).First(user).Error
}
// FillUserByOidcId loads the user whose oidc_id equals the receiver's OidcId
// into the receiver. An empty OidcId is rejected.
func (user *User) FillUserByOidcId() error {
	if user.OidcId == "" {
		return errors.New("oidc id 为空!")
	}
	filter := User{OidcId: user.OidcId}
	return DB.Where(filter).First(user).Error
}
// FillUserByWeChatId loads the user whose wechat_id equals the receiver's
// WeChatId into the receiver. An empty WeChatId is rejected.
func (user *User) FillUserByWeChatId() error {
	if user.WeChatId == "" {
		return errors.New("WeChat id 为空!")
	}
	filter := User{WeChatId: user.WeChatId}
	return DB.Where(filter).First(user).Error
}
// FillUserByTelegramId loads the user whose telegram_id equals the receiver's
// TelegramId into the receiver. An empty TelegramId is rejected, and a missing
// row is reported as a dedicated "Telegram account not bound" error.
func (user *User) FillUserByTelegramId() error {
	if user.TelegramId == "" {
		return errors.New("Telegram id 为空!")
	}
	filter := User{TelegramId: user.TelegramId}
	switch lookupErr := DB.Where(filter).First(user).Error; {
	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		return errors.New("该 Telegram 账户未绑定")
	default:
		return lookupErr
	}
}
// IsEmailAlreadyTaken reports whether the trimmed email belongs to a
// non-deleted user. A blank email is never taken; query errors are not
// surfaced and read as "not taken".
func IsEmailAlreadyTaken(email string) bool {
	address := strings.TrimSpace(email)
	if address == "" {
		return false
	}
	matches := DB.Where("email = ?", address).Find(&User{})
	return matches.RowsAffected > 0
}
// IsPhoneTakenByActiveUser reports whether the normalized phone number belongs
// to a user that has not been soft-deleted. Deleted accounts do not reserve
// the number, so it can be registered again. A number that normalizes to ""
// is never taken.
func IsPhoneTakenByActiveUser(phone string) bool {
	normalized := common.NormalizePhone(phone)
	if normalized == "" {
		return false
	}
	matches := DB.Where("phone = ?", normalized).Find(&User{})
	return matches.RowsAffected > 0
}
// IsPhoneAlreadyTaken reports whether the phone number belongs to a
// non-deleted user; it is an alias for IsPhoneTakenByActiveUser.
func IsPhoneAlreadyTaken(phone string) bool {
	taken := IsPhoneTakenByActiveUser(phone)
	return taken
}
// IsPhoneTakenByOtherUser reports whether the normalized phone number belongs
// to a non-deleted user other than excludeUserId. A number that normalizes to
// "" is never taken.
func IsPhoneTakenByOtherUser(phone string, excludeUserId int) bool {
	normalized := common.NormalizePhone(phone)
	if normalized == "" {
		return false
	}
	matches := DB.Where("phone = ? AND id <> ?", normalized, excludeUserId).Find(&User{})
	return matches.RowsAffected > 0
}
// NormalizeAndValidateAdminUserPhone normalizes the phone number supplied when
// an admin creates or edits a user. An empty result means "no phone bound" and
// is returned without error. A non-empty number is checked for format, the SMS
// blacklist and conflicts: excludeUserId == 0 means a new user (any active
// owner conflicts), otherwise the given user is excluded from the conflict
// check.
func NormalizeAndValidateAdminUserPhone(phone string, excludeUserId int) (string, error) {
	normalized := common.NormalizePhone(phone)
	if normalized == "" {
		return "", nil
	}
	if !common.ValidateMainlandChinaPhone(normalized) {
		return "", fmt.Errorf("手机号格式无效,请输入 11 位中国大陆手机号")
	}
	if common.IsSMSPhoneBlacklisted(normalized) {
		return "", fmt.Errorf("该手机号已被加入短信黑名单")
	}

	var taken bool
	if excludeUserId != 0 {
		taken = IsPhoneTakenByOtherUser(normalized, excludeUserId)
	} else {
		taken = IsPhoneAlreadyTaken(normalized)
	}
	if taken {
		return "", fmt.Errorf("手机号已被占用")
	}
	return normalized, nil
}
// IsWeChatIdAlreadyTaken reports whether a user row (soft-deleted rows
// included) already carries the given wechat_id.
func IsWeChatIdAlreadyTaken(wechatId string) bool {
	matches := DB.Unscoped().Where("wechat_id = ?", wechatId).Find(&User{})
	return matches.RowsAffected == 1
}
// IsGitHubIdAlreadyTaken reports whether a user row (soft-deleted rows
// included) already carries the given github_id.
func IsGitHubIdAlreadyTaken(githubId string) bool {
	matches := DB.Unscoped().Where("github_id = ?", githubId).Find(&User{})
	return matches.RowsAffected == 1
}
// IsDiscordIdAlreadyTaken reports whether a user row (soft-deleted rows
// included) already carries the given discord_id.
func IsDiscordIdAlreadyTaken(discordId string) bool {
	matches := DB.Unscoped().Where("discord_id = ?", discordId).Find(&User{})
	return matches.RowsAffected == 1
}
// IsOidcIdAlreadyTaken reports whether a non-deleted user already carries the
// given oidc_id. Unlike the sibling checks this one uses the default scope, so
// soft-deleted rows are not considered.
func IsOidcIdAlreadyTaken(oidcId string) bool {
	matches := DB.Where("oidc_id = ?", oidcId).Find(&User{})
	return matches.RowsAffected == 1
}
// IsTelegramIdAlreadyTaken reports whether a user row (soft-deleted rows
// included) already carries the given telegram_id.
func IsTelegramIdAlreadyTaken(telegramId string) bool {
	matches := DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{})
	return matches.RowsAffected == 1
}
// ResetUserPasswordByEmail hashes password and stores it for the non-deleted
// user(s) with the given email. Empty arguments are rejected.
func ResetUserPasswordByEmail(email string, password string) error {
	if email == "" || password == "" {
		return errors.New("邮箱地址或密码为空!")
	}
	hashed, hashErr := common.Password2Hash(password)
	if hashErr != nil {
		return hashErr
	}
	return DB.Model(&User{}).Where("email = ?", email).Update("password", hashed).Error
}
// ResetUserPasswordByPhone hashes password and stores it for the non-deleted
// user(s) with the given phone number, which is normalized first. A number
// that normalizes to "" or an empty password is rejected.
func ResetUserPasswordByPhone(phone string, password string) error {
	normalized := common.NormalizePhone(phone)
	if normalized == "" || password == "" {
		return errors.New("手机号或密码为空!")
	}
	hashed, hashErr := common.Password2Hash(password)
	if hashErr != nil {
		return hashErr
	}
	return DB.Model(&User{}).Where("phone = ?", normalized).Update("password", hashed).Error
}
// IsAdmin reports whether the user's role is at least common.RoleAdminUser.
// A zero id or a failed query yields false; query errors are logged.
func IsAdmin(userId int) bool {
	if userId == 0 {
		return false
	}
	var target User
	if queryErr := DB.Where("id = ?", userId).Select("role").Find(&target).Error; queryErr != nil {
		common.SysLog("no such user " + queryErr.Error())
		return false
	}
	return target.Role >= common.RoleAdminUser
}
//// IsUserEnabled checks user status from Redis first, falls back to DB if needed
//func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
// defer func() {
// // Update Redis cache asynchronously on successful DB read
// if shouldUpdateRedis(fromDB, err) {
// gopool.Go(func() {
// if err := updateUserStatusCache(id, status); err != nil {
// common.SysError("failed to update user status cache: " + err.Error())
// }
// })
// }
// }()
// if !fromDB && common.RedisEnabled {
// // Try Redis first
// status, err := getUserStatusCache(id)
// if err == nil {
// return status == common.UserStatusEnabled, nil
// }
// // Don't return error - fall through to DB
// }
// fromDB = true
// var user User
// err = DB.Where("id = ?", id).Select("status").Find(&user).Error
// if err != nil {
// return false, err
// }
//
// return user.Status == common.UserStatusEnabled, nil
//}
// ValidateAccessToken looks up the user that owns the given access token.
//
// A leading "Bearer " scheme prefix is stripped before the lookup. Only a
// prefix is removed; the sequence is left alone elsewhere in the string.
// The function returns nil when the token is empty (including a token that
// consists solely of the prefix) or when no user row matches; otherwise it
// returns the matching user.
func ValidateAccessToken(token string) (user *User) {
	token = strings.TrimPrefix(token, "Bearer ")
	// Re-check after stripping so that a bare "Bearer " header never turns into
	// a lookup for an empty access_token value.
	if token == "" {
		return nil
	}
	user = &User{}
	if DB.Where("access_token = ?", token).First(user).RowsAffected == 1 {
		return user
	}
	return nil
}
// GetUserQuota returns the user's quota, consulting the cache first and
// falling back to the database.
//
// When fromDB is false and common.RedisEnabled is set, a successful cache read
// is returned directly. Otherwise (or when the cache read fails) the value is
// read from the users table. After a database read, the deferred hook asks
// shouldUpdateRedis whether to refresh the cache and, if so, does it
// asynchronously.
func GetUserQuota(id int, fromDB bool) (quota int, err error) {
	defer func() {
		// Runs after the results are set, so it sees the final quota and err.
		if !shouldUpdateRedis(fromDB, err) {
			return
		}
		gopool.Go(func() {
			if cacheErr := updateUserQuotaCache(id, quota); cacheErr != nil {
				common.SysLog("failed to update user quota cache: " + cacheErr.Error())
			}
		})
	}()

	if !fromDB && common.RedisEnabled {
		if cached, cacheErr := getUserQuotaCache(id); cacheErr == nil {
			return cached, nil
		}
		// A cache miss or cache error is not fatal: fall through to the database.
	}

	// Mark the result as database-sourced for the deferred cache refresh.
	fromDB = true
	if err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error; err != nil {
		return 0, err
	}
	return quota, nil
}
// GetUserUsedQuota reads the used_quota column of the user with the given ID
// straight from the database and returns it together with any query error.
func GetUserUsedQuota(id int) (quota int, err error) {
	query := DB.Model(&User{}).Where("id = ?", id).Select("used_quota")
	err = query.Find(&quota).Error
	return quota, err
}
// GetUserEmail reads the email column of the user with the given ID straight
// from the database and returns it together with any query error.
func GetUserEmail(id int) (email string, err error) {
	query := DB.Model(&User{}).Where("id = ?", id).Select("email")
	err = query.Find(&email).Error
	return email, err
}
// GetUserGroup returns the user's group, consulting the cache first and
// falling back to the database.
//
// When fromDB is false and common.RedisEnabled is set, a successful cache read
// is returned directly. Otherwise (or when the cache read fails) the value is
// read from the users table using the commonGroupCol column. After a database
// read, the deferred hook asks shouldUpdateRedis whether to refresh the cache
// and, if so, does it asynchronously.
func GetUserGroup(id int, fromDB bool) (group string, err error) {
	defer func() {
		// Runs after the results are set, so it sees the final group and err.
		if !shouldUpdateRedis(fromDB, err) {
			return
		}
		gopool.Go(func() {
			if cacheErr := updateUserGroupCache(id, group); cacheErr != nil {
				common.SysLog("failed to update user group cache: " + cacheErr.Error())
			}
		})
	}()

	if !fromDB && common.RedisEnabled {
		if cached, cacheErr := getUserGroupCache(id); cacheErr == nil {
			return cached, nil
		}
		// A cache miss or cache error is not fatal: fall through to the database.
	}

	// Mark the result as database-sourced for the deferred cache refresh.
	fromDB = true
	if err = DB.Model(&User{}).Where("id = ?", id).Select(commonGroupCol).Find(&group).Error; err != nil {
		return "", err
	}
	return group, nil
}
// GetUserSetting returns the user's parsed settings, consulting the cache
// first and falling back to the database.
//
// When fromDB is false and common.RedisEnabled is set, a successful cache read
// is returned directly. Otherwise (or when the cache read fails) the raw
// setting column is read from the users table, a NULL value is treated as the
// empty string, and the result is parsed through UserBase.GetSetting. After a
// database read, the deferred hook asks shouldUpdateRedis whether to refresh
// the cache with the raw setting string and, if so, does it asynchronously.
func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) {
	// Raw setting text; shared with the deferred cache refresh below.
	var setting string
	defer func() {
		if !shouldUpdateRedis(fromDB, err) {
			return
		}
		gopool.Go(func() {
			if cacheErr := updateUserSettingCache(id, setting); cacheErr != nil {
				common.SysLog("failed to update user setting cache: " + cacheErr.Error())
			}
		})
	}()

	if !fromDB && common.RedisEnabled {
		if cached, cacheErr := getUserSettingCache(id); cacheErr == nil {
			return cached, nil
		}
		// A cache miss or cache error is not fatal: fall through to the database.
	}

	// Mark the result as database-sourced for the deferred cache refresh.
	fromDB = true

	// The column may be NULL, so scan into a nullable string.
	var raw sql.NullString
	if err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&raw).Error; err != nil {
		return settingMap, err
	}
	if raw.Valid {
		setting = raw.String
	}
	return (&UserBase{Setting: setting}).GetSetting(), nil
}
// IncreaseUserQuota adds quota to the user's balance.
//
// A negative quota is rejected. The cached quota is incremented asynchronously
// and a failure there is only logged. The persistent update is written
// immediately when db is true or batch updates are disabled; otherwise it is
// queued as a batch record and nil is returned.
func IncreaseUserQuota(id int, quota int, db bool) (err error) {
	if quota < 0 {
		return errors.New("quota 不能为负数!")
	}
	gopool.Go(func() {
		if cacheErr := cacheIncrUserQuota(id, int64(quota)); cacheErr != nil {
			common.SysLog("failed to increase user quota: " + cacheErr.Error())
		}
	})
	if db || !common.BatchUpdateEnabled {
		return increaseUserQuota(id, quota)
	}
	addNewRecord(BatchUpdateTypeUserQuota, id, quota)
	return nil
}
// increaseUserQuota adds quota to the quota column of the given user with a
// single SQL expression update and returns the statement's error, if any.
func increaseUserQuota(id int, quota int) (err error) {
	return DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
}
// IncreaseUserAffCommissionQuota 分销提成计入邀请人待使用收益(aff_quota)与累计总收益(aff_history),不增加 quota。
func IncreaseUserAffCommissionQuota(inviterId int, delta int) error {
if inviterId <= 0 || delta <= 0 {
return nil
}
tx := DB.Model(&User{}).Where("id = ?", inviterId).Updates(map[string]interface{}{
"aff_quota": gorm.Expr("aff_quota + ?", delta),
"aff_history": gorm.Expr("aff_history + ?", delta),
})
if tx.Error != nil {
return tx.Error
}
if tx.RowsAffected == 0 {
return fmt.Errorf("inviter user not found: %d", inviterId)
}
return nil
}
// DecreaseUserQuota subtracts quota from the user's balance.
//
// A negative quota is rejected. The cached quota is decremented asynchronously
// and a failure there is only logged. When batch updates are enabled the
// change is queued as a negative batch record and nil is returned; otherwise
// the database is updated immediately.
func DecreaseUserQuota(id int, quota int) (err error) {
	if quota < 0 {
		return errors.New("quota 不能为负数!")
	}
	gopool.Go(func() {
		if cacheErr := cacheDecrUserQuota(id, int64(quota)); cacheErr != nil {
			common.SysLog("failed to decrease user quota: " + cacheErr.Error())
		}
	})
	if !common.BatchUpdateEnabled {
		return decreaseUserQuota(id, quota)
	}
	addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
	return nil
}
// decreaseUserQuota subtracts quota from the quota column of the given user
// with a single SQL expression update and returns the statement's error, if
// any.
func decreaseUserQuota(id int, quota int) (err error) {
	return DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
}
// DeltaUpdateUserQuota applies a signed quota change to the user: a positive
// delta goes through IncreaseUserQuota (with db=false), a negative delta goes
// through DecreaseUserQuota with its absolute value, and zero is a no-op.
func DeltaUpdateUserQuota(id int, delta int) (err error) {
	switch {
	case delta > 0:
		return IncreaseUserQuota(id, delta, false)
	case delta < 0:
		return DecreaseUserQuota(id, -delta)
	default:
		return nil
	}
}
//func GetRootUserEmail() (email string) {
// DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
// return email
//}
// GetRootUser loads the first user whose role equals common.RoleRootUser.
// The query error is not surfaced to the caller.
//
// NOTE(review): the destination is a pointer to the (initially nil) result
// pointer, so the value returned when no row matches depends on how GORM
// handles that destination — confirm before relying on a nil check.
func GetRootUser() (user *User) {
	rootQuery := DB.Where("role = ?", common.RoleRootUser)
	rootQuery.First(&user)
	return user
}
// UpdateUserUsedQuotaAndRequestCount records one request that consumed quota
// for the given user. With batch updates enabled both counters are queued as
// batch records; otherwise the database is updated immediately.
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
	if !common.BatchUpdateEnabled {
		updateUserUsedQuotaAndRequestCount(id, quota, 1)
		return
	}
	addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
	addNewRecord(BatchUpdateTypeRequestCount, id, 1)
}
func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota),
"request_count": gorm.Expr("request_count + ?", count),
},
).Error
if err != nil {
common.SysLog("failed to update user used quota and request count: " + err.Error())
return
}
//// 更新缓存
//if err := invalidateUserCache(id); err != nil {
// common.SysError("failed to invalidate user cache: " + err.Error())
//}
}
func updateUserUsedQuota(id int, quota int) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota),
},
).Error
if err != nil {
common.SysLog("failed to update user used quota: " + err.Error())
}
}
func updateUserRequestCount(id int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
if err != nil {
common.SysLog("failed to update user request count: " + err.Error())
}
}
// GetUsernameById returns the username for the given user ID, consulting the
// cache first and falling back to the database.
//
// When fromDB is false and common.RedisEnabled is set, a successful cache read
// is returned directly. Otherwise (or when the cache read fails) the value is
// read from the users table. After a database read, the deferred hook asks
// shouldUpdateRedis whether to refresh the cache and, if so, does it
// asynchronously.
func GetUsernameById(id int, fromDB bool) (username string, err error) {
	defer func() {
		// Runs after the results are set, so it sees the final username and err.
		if !shouldUpdateRedis(fromDB, err) {
			return
		}
		gopool.Go(func() {
			if cacheErr := updateUserNameCache(id, username); cacheErr != nil {
				common.SysLog("failed to update user name cache: " + cacheErr.Error())
			}
		})
	}()

	if !fromDB && common.RedisEnabled {
		if cached, cacheErr := getUserNameCache(id); cacheErr == nil {
			return cached, nil
		}
		// A cache miss or cache error is not fatal: fall through to the database.
	}

	// Mark the result as database-sourced for the deferred cache refresh.
	fromDB = true
	if err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error; err != nil {
		return "", err
	}
	return username, nil
}
func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
var user User
err := DB.Unscoped().Where("linux_do_id = ?", linuxDOId).First(&user).Error
return !errors.Is(err, gorm.ErrRecordNotFound)
}
// FillUserByLinuxDOId loads into the receiver the first user row whose
// linux_do_id equals the receiver's LinuxDOId. It returns an error when that
// field is empty or when the query fails.
func (user *User) FillUserByLinuxDOId() error {
	if user.LinuxDOId == "" {
		return errors.New("linux do id is empty")
	}
	return DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
}
// RootUserExists reports whether a user with role common.RoleRootUser can be
// loaded. Any query error, including "record not found", yields false.
func RootUserExists() bool {
	var root User
	lookupErr := DB.Where("role = ?", common.RoleRootUser).First(&root).Error
	return lookupErr == nil
}
// UserIsDistributor reports whether the user has distributor capability:
// is_distributor=1 and not an admin or super admin.
//
// For compatibility, a user that still has the not-yet-migrated distributor
// role (role=5, which the startup migration is said to convert to
// role=1 + is_distributor=1) is also accepted. A nil user is never a
// distributor.
func UserIsDistributor(u *User) bool {
	switch {
	case u == nil:
		return false
	case u.Role >= common.RoleAdminUser:
		// Admins and above are never treated as distributors.
		return false
	case u.Role == common.RoleDistributorUser:
		// Legacy role value that predates the is_distributor flag.
		return true
	default:
		return u.IsDistributor == common.DistributorFlagYes
	}
}