tokenFactory/service/file_service.go

587 lines
16 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"bytes"
"encoding/base64"
"encoding/binary"
"fmt"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"golang.org/x/image/webp"
)
// FileService 统一的文件处理服务
// 提供文件下载、解码、缓存等功能的统一入口
// getContextCacheKey 生成 context 缓存的 key
func getContextCacheKey(url string) string {
return fmt.Sprintf("file_cache_%s", common.GenerateHMAC(url))
}
// LoadFileSource 加载文件源数据
// 这是统一的入口,会自动处理缓存和不同的来源类型
func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string) (*types.CachedFileData, error) {
if source == nil {
return nil, fmt.Errorf("file source is nil")
}
if common.DebugEnabled {
logger.LogDebug(c, fmt.Sprintf("LoadFileSource starting for: %s", source.GetIdentifier()))
}
// 1. 快速检查内部缓存
if source.HasCache() {
// 即使命中内部缓存,也要确保注册到清理列表(如果尚未注册)
if c != nil {
registerSourceForCleanup(c, source)
}
return source.GetCache(), nil
}
// 2. 加锁保护加载过程
source.Mu().Lock()
defer source.Mu().Unlock()
// 3. 双重检查
if source.HasCache() {
if c != nil {
registerSourceForCleanup(c, source)
}
return source.GetCache(), nil
}
// 4. 如果是 URL检查 Context 缓存
var contextKey string
if source.IsURL() && c != nil {
contextKey = getContextCacheKey(source.URL)
if cachedData, exists := c.Get(contextKey); exists {
data := cachedData.(*types.CachedFileData)
source.SetCache(data)
registerSourceForCleanup(c, source)
return data, nil
}
}
// 5. 执行加载逻辑
var cachedData *types.CachedFileData
var err error
if source.IsURL() {
cachedData, err = loadFromURL(c, source.URL, reason...)
} else {
cachedData, err = loadFromBase64(source.Base64Data, source.MimeType)
}
if err != nil {
return nil, err
}
// 6. 设置缓存
source.SetCache(cachedData)
if contextKey != "" && c != nil {
c.Set(contextKey, cachedData)
}
// 7. 注册到 context 以便请求结束时自动清理
if c != nil {
registerSourceForCleanup(c, source)
}
return cachedData, nil
}
// registerSourceForCleanup 注册 FileSource 到 context 以便请求结束时清理
func registerSourceForCleanup(c *gin.Context, source *types.FileSource) {
if source.IsRegistered() {
return
}
key := string(constant.ContextKeyFileSourcesToCleanup)
var sources []*types.FileSource
if existing, exists := c.Get(key); exists {
sources = existing.([]*types.FileSource)
}
sources = append(sources, source)
c.Set(key, sources)
source.SetRegistered(true)
}
// CleanupFileSources 清理请求中所有注册的 FileSource
// 应在请求结束时调用(通常由中间件自动调用)
func CleanupFileSources(c *gin.Context) {
key := string(constant.ContextKeyFileSourcesToCleanup)
if sources, exists := c.Get(key); exists {
for _, source := range sources.([]*types.FileSource) {
if cache := source.GetCache(); cache != nil {
cache.Close()
}
}
c.Set(key, nil) // 清除引用
}
}
// loadFromURL 从 URL 加载文件
func loadFromURL(c *gin.Context, url string, reason ...string) (*types.CachedFileData, error) {
// 下载文件
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
if common.DebugEnabled {
logger.LogDebug(c, "loadFromURL: initiating download")
}
resp, err := DoDownloadRequest(url, reason...)
if err != nil {
return nil, fmt.Errorf("failed to download file from %s: %w", url, err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return nil, fmt.Errorf("failed to download file, status code: %d", resp.StatusCode)
}
// 读取文件内容(限制大小)
if common.DebugEnabled {
logger.LogDebug(c, "loadFromURL: reading response body")
}
fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
if err != nil {
return nil, fmt.Errorf("failed to read file content: %w", err)
}
if len(fileBytes) > maxFileSize {
return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
}
// 转换为 base64
base64Data := base64.StdEncoding.EncodeToString(fileBytes)
// 智能获取 MIME 类型
mimeType := smartDetectMimeType(resp, url, fileBytes)
// 判断是否使用磁盘缓存
base64Size := int64(len(base64Data))
var cachedData *types.CachedFileData
if shouldUseDiskCache(base64Size) {
// 使用磁盘缓存
diskPath, err := writeToDiskCache(base64Data)
if err != nil {
// 磁盘缓存失败,回退到内存
logger.LogWarn(c, fmt.Sprintf("Failed to write to disk cache, falling back to memory: %v", err))
cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes)))
} else {
cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(fileBytes)))
cachedData.DiskSize = base64Size
cachedData.OnClose = func(size int64) {
common.DecrementDiskFiles(size)
}
common.IncrementDiskFiles(base64Size)
if common.DebugEnabled {
logger.LogDebug(c, fmt.Sprintf("File cached to disk: %s, size: %d bytes", diskPath, base64Size))
}
}
} else {
// 使用内存缓存
cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes)))
}
// 如果是图片,尝试获取图片配置
if strings.HasPrefix(mimeType, "image/") {
if common.DebugEnabled {
logger.LogDebug(c, "loadFromURL: decoding image config")
}
config, format, err := decodeImageConfig(fileBytes)
if err == nil {
cachedData.ImageConfig = &config
cachedData.ImageFormat = format
// 如果通过图片解码获取了更准确的格式,更新 MIME 类型
if mimeType == "application/octet-stream" || mimeType == "" {
cachedData.MimeType = "image/" + format
}
}
}
return cachedData, nil
}
// shouldUseDiskCache 判断是否应该使用磁盘缓存
func shouldUseDiskCache(dataSize int64) bool {
return common.ShouldUseDiskCache(dataSize)
}
// writeToDiskCache 将数据写入磁盘缓存
func writeToDiskCache(base64Data string) (string, error) {
return common.WriteDiskCacheFileString(common.DiskCacheTypeFile, base64Data)
}
// smartDetectMimeType 智能检测 MIME 类型
func smartDetectMimeType(resp *http.Response, url string, fileBytes []byte) string {
// 1. 尝试从 Content-Type header 获取
mimeType := resp.Header.Get("Content-Type")
if idx := strings.Index(mimeType, ";"); idx != -1 {
mimeType = strings.TrimSpace(mimeType[:idx])
}
if mimeType != "" && mimeType != "application/octet-stream" {
return mimeType
}
// 2. 尝试从 Content-Disposition header 的 filename 获取
if cd := resp.Header.Get("Content-Disposition"); cd != "" {
parts := strings.Split(cd, ";")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(strings.ToLower(part), "filename=") {
name := strings.TrimSpace(strings.TrimPrefix(part, "filename="))
// 移除引号
if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' {
name = name[1 : len(name)-1]
}
if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) {
ext := strings.ToLower(name[dot+1:])
if ext != "" {
mt := GetMimeTypeByExtension(ext)
if mt != "application/octet-stream" {
return mt
}
}
}
break
}
}
}
// 3. 尝试从 URL 路径获取扩展名
mt := guessMimeTypeFromURL(url)
if mt != "application/octet-stream" {
return mt
}
// 4. 使用 http.DetectContentType 内容嗅探
if len(fileBytes) > 0 {
sniffed := http.DetectContentType(fileBytes)
if sniffed != "" && sniffed != "application/octet-stream" {
// 去除可能的 charset 参数
if idx := strings.Index(sniffed, ";"); idx != -1 {
sniffed = strings.TrimSpace(sniffed[:idx])
}
return sniffed
}
// 4.5 尝试 HEIF/HEIC 检测Go 标准库不识别)
if heifMime := detectHEIF(fileBytes); heifMime != "" {
return heifMime
}
}
// 5. 尝试作为图片解码获取格式
if len(fileBytes) > 0 {
if _, format, err := decodeImageConfig(fileBytes); err == nil && format != "" {
return "image/" + strings.ToLower(format)
}
}
// 最终回退
return "application/octet-stream"
}
// loadFromBase64 从 base64 字符串加载文件
func loadFromBase64(base64String string, providedMimeType string) (*types.CachedFileData, error) {
var mimeType string
var cleanBase64 string
// 处理 data: 前缀
if strings.HasPrefix(base64String, "data:") {
idx := strings.Index(base64String, ",")
if idx != -1 {
header := base64String[:idx]
cleanBase64 = base64String[idx+1:]
if strings.Contains(header, ":") && strings.Contains(header, ";") {
mimeStart := strings.Index(header, ":") + 1
mimeEnd := strings.Index(header, ";")
if mimeStart < mimeEnd {
mimeType = header[mimeStart:mimeEnd]
}
}
} else {
cleanBase64 = base64String
}
} else {
cleanBase64 = base64String
}
if providedMimeType != "" {
mimeType = providedMimeType
}
decodedData, err := base64.StdEncoding.DecodeString(cleanBase64)
if err != nil {
return nil, fmt.Errorf("failed to decode base64 data: %w", err)
}
base64Size := int64(len(cleanBase64))
var cachedData *types.CachedFileData
if shouldUseDiskCache(base64Size) {
diskPath, err := writeToDiskCache(cleanBase64)
if err != nil {
cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, int64(len(decodedData)))
} else {
cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(decodedData)))
cachedData.DiskSize = base64Size
cachedData.OnClose = func(size int64) {
common.DecrementDiskFiles(size)
}
common.IncrementDiskFiles(base64Size)
}
} else {
cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, int64(len(decodedData)))
}
if mimeType == "" || strings.HasPrefix(mimeType, "image/") {
config, format, err := decodeImageConfig(decodedData)
if err == nil {
cachedData.ImageConfig = &config
cachedData.ImageFormat = format
if mimeType == "" {
cachedData.MimeType = "image/" + format
}
}
}
return cachedData, nil
}
// GetImageConfig 获取图片配置
func GetImageConfig(c *gin.Context, source *types.FileSource) (image.Config, string, error) {
cachedData, err := LoadFileSource(c, source, "get_image_config")
if err != nil {
return image.Config{}, "", err
}
if cachedData.ImageConfig != nil {
return *cachedData.ImageConfig, cachedData.ImageFormat, nil
}
base64Str, err := cachedData.GetBase64Data()
if err != nil {
return image.Config{}, "", fmt.Errorf("failed to get base64 data: %w", err)
}
decodedData, err := base64.StdEncoding.DecodeString(base64Str)
if err != nil {
return image.Config{}, "", fmt.Errorf("failed to decode base64 for image config: %w", err)
}
config, format, err := decodeImageConfig(decodedData)
if err != nil {
return image.Config{}, "", err
}
cachedData.ImageConfig = &config
cachedData.ImageFormat = format
return config, format, nil
}
// GetBase64Data 获取 base64 编码的数据
func GetBase64Data(c *gin.Context, source *types.FileSource, reason ...string) (string, string, error) {
cachedData, err := LoadFileSource(c, source, reason...)
if err != nil {
return "", "", err
}
base64Str, err := cachedData.GetBase64Data()
if err != nil {
return "", "", fmt.Errorf("failed to get base64 data: %w", err)
}
return base64Str, cachedData.MimeType, nil
}
// GetMimeType 获取文件的 MIME 类型
func GetMimeType(c *gin.Context, source *types.FileSource) (string, error) {
if source.HasCache() {
return source.GetCache().MimeType, nil
}
if source.IsURL() {
mimeType, err := GetFileTypeFromUrl(c, source.URL, "get_mime_type")
if err == nil && mimeType != "" && mimeType != "application/octet-stream" {
return mimeType, nil
}
}
cachedData, err := LoadFileSource(c, source, "get_mime_type")
if err != nil {
return "", err
}
return cachedData.MimeType, nil
}
// DetectFileType 检测文件类型
func DetectFileType(mimeType string) types.FileType {
if strings.HasPrefix(mimeType, "image/") {
return types.FileTypeImage
}
if strings.HasPrefix(mimeType, "audio/") {
return types.FileTypeAudio
}
if strings.HasPrefix(mimeType, "video/") {
return types.FileTypeVideo
}
return types.FileTypeFile
}
// decodeImageConfig 从字节数据解码图片配置
func decodeImageConfig(data []byte) (image.Config, string, error) {
reader := bytes.NewReader(data)
config, format, err := image.DecodeConfig(reader)
if err == nil {
return config, format, nil
}
reader.Seek(0, io.SeekStart)
config, err = webp.DecodeConfig(reader)
if err == nil {
return config, "webp", nil
}
// Try HEIF/HEIC: parse ISOBMFF ispe box for dimensions
if heifMime := detectHEIF(data); heifMime != "" {
formatName := "heif"
if heifMime == "image/heic" {
formatName = "heic"
}
if w, h, ok := parseHEIFDimensions(data); ok {
return image.Config{Width: w, Height: h}, formatName, nil
}
return image.Config{}, "", fmt.Errorf("failed to decode HEIF/HEIC image dimensions")
}
return image.Config{}, "", fmt.Errorf("failed to decode image config: unsupported format")
}
// detectHEIF checks ISOBMFF magic bytes to detect HEIC/HEIF files.
// Returns "image/heic", "image/heif", or "" if not recognized.
func detectHEIF(data []byte) string {
if len(data) < 12 {
return ""
}
// ISOBMFF: bytes[4:8] must be "ftyp"
if string(data[4:8]) != "ftyp" {
return ""
}
brand := string(data[8:12])
switch brand {
case "heic", "heix", "hevc", "hevx", "heim", "heis":
return "image/heic"
case "mif1", "msf1":
return "image/heif"
default:
return ""
}
}
// parseHEIFDimensions parses ISOBMFF box tree to find the ispe box
// and extract image width/height. Returns (width, height, ok).
func parseHEIFDimensions(data []byte) (int, int, bool) {
size := len(data)
if size < 12 {
return 0, 0, false
}
// Walk top-level boxes to find "meta"
offset := 0
for offset+8 <= size {
boxSize := int(binary.BigEndian.Uint32(data[offset : offset+4]))
boxType := string(data[offset+4 : offset+8])
headerLen := 8
if boxSize == 1 {
// 64-bit extended size
if offset+16 > size {
break
}
boxSize = int(binary.BigEndian.Uint64(data[offset+8 : offset+16]))
headerLen = 16
} else if boxSize == 0 {
// box extends to end of data
boxSize = size - offset
}
if boxSize < headerLen || offset+boxSize > size {
break
}
if boxType == "meta" {
// meta is a full box: 4 bytes version/flags after header
metaData := data[offset+headerLen : offset+boxSize]
if len(metaData) < 4 {
return 0, 0, false
}
return findISPE(metaData[4:])
}
offset += boxSize
}
return 0, 0, false
}
// findISPE recursively searches for the ispe box within container boxes.
// Path: meta -> iprp -> ipco -> ispe
func findISPE(data []byte) (int, int, bool) {
offset := 0
size := len(data)
for offset+8 <= size {
boxSize := int(binary.BigEndian.Uint32(data[offset : offset+4]))
boxType := string(data[offset+4 : offset+8])
if boxSize < 8 || offset+boxSize > size {
break
}
content := data[offset+8 : offset+boxSize]
switch boxType {
case "iprp", "ipco":
if w, h, ok := findISPE(content); ok {
return w, h, true
}
case "ispe":
// ispe is a full box: 4 bytes version/flags, then 4 bytes width, 4 bytes height
if len(content) >= 12 {
w := int(binary.BigEndian.Uint32(content[4:8]))
h := int(binary.BigEndian.Uint32(content[8:12]))
if w > 0 && h > 0 {
return w, h, true
}
}
}
offset += boxSize
}
return 0, 0, false
}
// guessMimeTypeFromURL 从 URL 猜测 MIME 类型
func guessMimeTypeFromURL(url string) string {
cleanedURL := url
if q := strings.Index(cleanedURL, "?"); q != -1 {
cleanedURL = cleanedURL[:q]
}
if slash := strings.LastIndex(cleanedURL, "/"); slash != -1 && slash+1 < len(cleanedURL) {
last := cleanedURL[slash+1:]
if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) {
ext := strings.ToLower(last[dot+1:])
return GetMimeTypeByExtension(ext)
}
}
return "application/octet-stream"
}