314 lines
8.8 KiB
Go
314 lines
8.8 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
"github.com/QuantumNous/new-api/model"
|
|
"github.com/QuantumNous/new-api/service"
|
|
"github.com/QuantumNous/new-api/setting"
|
|
"github.com/gin-contrib/sessions"
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
var timeFormat = "2006-01-02T15:04:05.000Z"
|
|
|
|
var inMemoryRateLimiter common.InMemoryRateLimiter
|
|
|
|
var defNext = func(c *gin.Context) {
|
|
c.Next()
|
|
}
|
|
|
|
func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) {
|
|
if !shouldApplyGeneralRateLimit(c) {
|
|
return
|
|
}
|
|
if shouldBypassRateLimit(c) {
|
|
return
|
|
}
|
|
ctx := context.Background()
|
|
rdb := common.RDB
|
|
if rdb == nil {
|
|
return
|
|
}
|
|
userID := getRateLimitUserID(c)
|
|
key := "rateLimit:" + mark + ":" + buildRateLimitSubject(c)
|
|
listLength, err := rdb.LLen(ctx, key).Result()
|
|
if err != nil {
|
|
fmt.Println(err.Error())
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": "服务繁忙,请稍后重试",
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
if listLength < int64(maxRequestNum) {
|
|
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
|
|
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
|
|
} else {
|
|
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
|
|
oldTime, err := time.Parse(timeFormat, oldTimeStr)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": "服务繁忙,请稍后重试",
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
nowTimeStr := time.Now().Format(timeFormat)
|
|
nowTime, err := time.Parse(timeFormat, nowTimeStr)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": "服务繁忙,请稍后重试",
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
// time.Since will return negative number!
|
|
// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows
|
|
if int64(nowTime.Sub(oldTime).Seconds()) < duration {
|
|
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
|
|
if userID > 0 {
|
|
_ = service.AddUserRateLimitBlacklist(userID, duration, "api-rate-limit:"+mark)
|
|
}
|
|
c.Status(http.StatusTooManyRequests)
|
|
c.Abort()
|
|
return
|
|
} else {
|
|
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
|
|
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
|
|
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
|
|
}
|
|
}
|
|
}
|
|
|
|
func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) {
|
|
if !shouldApplyGeneralRateLimit(c) {
|
|
return
|
|
}
|
|
if shouldBypassRateLimit(c) {
|
|
return
|
|
}
|
|
userID := getRateLimitUserID(c)
|
|
key := mark + ":" + buildRateLimitSubject(c)
|
|
if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) {
|
|
if userID > 0 {
|
|
_ = service.AddUserRateLimitBlacklist(userID, duration, "api-rate-limit:"+mark)
|
|
}
|
|
c.Status(http.StatusTooManyRequests)
|
|
c.Abort()
|
|
return
|
|
}
|
|
}
|
|
|
|
// buildRateLimitSubject returns a stable per-user identifier when available.
|
|
// Priority: authenticated context id -> session id -> New-Api-User header -> client IP.
|
|
func buildRateLimitSubject(c *gin.Context) string {
|
|
if userID := getRateLimitUserID(c); userID > 0 {
|
|
return fmt.Sprintf("user:%d", userID)
|
|
}
|
|
return "ip:" + c.ClientIP()
|
|
}
|
|
|
|
func getRateLimitUserID(c *gin.Context) int {
|
|
if userID := c.GetInt("id"); userID > 0 {
|
|
return userID
|
|
}
|
|
session := sessions.Default(c)
|
|
if sessionID := session.Get("id"); sessionID != nil {
|
|
if id, ok := sessionID.(int); ok && id > 0 {
|
|
return id
|
|
}
|
|
if id, ok := sessionID.(int64); ok && id > 0 {
|
|
return int(id)
|
|
}
|
|
if id, ok := sessionID.(float64); ok && id > 0 {
|
|
return int(id)
|
|
}
|
|
if idStr, ok := sessionID.(string); ok {
|
|
if parsed, err := strconv.Atoi(idStr); err == nil && parsed > 0 {
|
|
return parsed
|
|
}
|
|
}
|
|
}
|
|
if apiUserIDStr := c.GetHeader("New-Api-User"); apiUserIDStr != "" {
|
|
if id, err := strconv.Atoi(apiUserIDStr); err == nil && id > 0 {
|
|
return id
|
|
}
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func shouldBypassRateLimit(c *gin.Context) bool {
|
|
userID := getRateLimitUserID(c)
|
|
if userID <= 0 {
|
|
return false
|
|
}
|
|
return model.IsAdmin(userID) || setting.IsUserInRateLimitWhitelist(userID)
|
|
}
|
|
|
|
// shouldApplyGeneralRateLimit ensures GA/GW/CT only apply to identified users.
|
|
// Public unauthenticated routes (e.g. homepage) are excluded.
|
|
func shouldApplyGeneralRateLimit(c *gin.Context) bool {
|
|
return getRateLimitUserID(c) > 0
|
|
}
|
|
|
|
func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) {
|
|
if common.RedisEnabled {
|
|
return func(c *gin.Context) {
|
|
redisRateLimiter(c, maxRequestNum, duration, mark)
|
|
}
|
|
} else {
|
|
// It's safe to call multi times.
|
|
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
|
|
return func(c *gin.Context) {
|
|
memoryRateLimiter(c, maxRequestNum, duration, mark)
|
|
}
|
|
}
|
|
}
|
|
|
|
func GlobalWebRateLimit() func(c *gin.Context) {
|
|
// Web-side global limiter is intentionally disabled.
|
|
// Use GlobalAPIRateLimit + CriticalRateLimit for authenticated API traffic.
|
|
return defNext
|
|
}
|
|
|
|
func GlobalAPIRateLimit() func(c *gin.Context) {
|
|
return func(c *gin.Context) {
|
|
if !common.GlobalApiRateLimitEnable {
|
|
c.Next()
|
|
return
|
|
}
|
|
rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA")(c)
|
|
}
|
|
}
|
|
|
|
func CriticalRateLimit() func(c *gin.Context) {
|
|
return func(c *gin.Context) {
|
|
if !common.CriticalRateLimitEnable {
|
|
c.Next()
|
|
return
|
|
}
|
|
rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT")(c)
|
|
}
|
|
}
|
|
|
|
func DownloadRateLimit() func(c *gin.Context) {
|
|
return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW")
|
|
}
|
|
|
|
func UploadRateLimit() func(c *gin.Context) {
|
|
return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
|
|
}
|
|
|
|
// userRateLimitFactory creates a rate limiter keyed by authenticated user ID
|
|
// instead of client IP, making it resistant to proxy rotation attacks.
|
|
// Must be used AFTER authentication middleware (UserAuth).
|
|
func userRateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) {
|
|
if common.RedisEnabled {
|
|
return func(c *gin.Context) {
|
|
userId := c.GetInt("id")
|
|
if userId == 0 {
|
|
c.Status(http.StatusUnauthorized)
|
|
c.Abort()
|
|
return
|
|
}
|
|
// Admin users are always whitelisted by default.
|
|
if model.IsAdmin(userId) || setting.IsUserInRateLimitWhitelist(userId) {
|
|
return
|
|
}
|
|
blacklisted, err := service.IsUserRateLimitBlacklisted(userId)
|
|
if err == nil && blacklisted {
|
|
c.Status(http.StatusTooManyRequests)
|
|
c.Abort()
|
|
return
|
|
}
|
|
key := fmt.Sprintf("rateLimit:%s:user:%d", mark, userId)
|
|
userRedisRateLimiter(c, maxRequestNum, duration, key)
|
|
}
|
|
}
|
|
// It's safe to call multi times.
|
|
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
|
|
return func(c *gin.Context) {
|
|
userId := c.GetInt("id")
|
|
if userId == 0 {
|
|
c.Status(http.StatusUnauthorized)
|
|
c.Abort()
|
|
return
|
|
}
|
|
if model.IsAdmin(userId) || setting.IsUserInRateLimitWhitelist(userId) {
|
|
return
|
|
}
|
|
key := fmt.Sprintf("%s:user:%d", mark, userId)
|
|
if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) {
|
|
c.Status(http.StatusTooManyRequests)
|
|
c.Abort()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// userRedisRateLimiter is like redisRateLimiter but accepts a pre-built key
|
|
// (to support user-ID-based keys).
|
|
func userRedisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, key string) {
|
|
ctx := context.Background()
|
|
rdb := common.RDB
|
|
listLength, err := rdb.LLen(ctx, key).Result()
|
|
if err != nil {
|
|
fmt.Println(err.Error())
|
|
c.Status(http.StatusInternalServerError)
|
|
c.Abort()
|
|
return
|
|
}
|
|
if listLength < int64(maxRequestNum) {
|
|
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
|
|
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
|
|
} else {
|
|
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
|
|
oldTime, err := time.Parse(timeFormat, oldTimeStr)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
c.Status(http.StatusInternalServerError)
|
|
c.Abort()
|
|
return
|
|
}
|
|
nowTimeStr := time.Now().Format(timeFormat)
|
|
nowTime, err := time.Parse(timeFormat, nowTimeStr)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
c.Status(http.StatusInternalServerError)
|
|
c.Abort()
|
|
return
|
|
}
|
|
if int64(nowTime.Sub(oldTime).Seconds()) < duration {
|
|
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
|
|
_ = service.AddUserRateLimitBlacklist(c.GetInt("id"), duration, "user-rate-limit")
|
|
c.Status(http.StatusTooManyRequests)
|
|
c.Abort()
|
|
return
|
|
} else {
|
|
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
|
|
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
|
|
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
|
|
}
|
|
}
|
|
}
|
|
|
|
// SearchRateLimit returns a per-user rate limiter for search endpoints.
|
|
// Configurable via SEARCH_RATE_LIMIT_ENABLE / SEARCH_RATE_LIMIT / SEARCH_RATE_LIMIT_DURATION.
|
|
func SearchRateLimit() func(c *gin.Context) {
|
|
if !common.SearchRateLimitEnable {
|
|
return defNext
|
|
}
|
|
return userRateLimitFactory(common.SearchRateLimitNum, common.SearchRateLimitDuration, "SR")
|
|
}
|