diff --git a/common/body_storage.go b/common/body_storage.go index 13062bd06..ea37cda96 100644 --- a/common/body_storage.go +++ b/common/body_storage.go @@ -5,12 +5,9 @@ import ( "fmt" "io" "os" - "path/filepath" "sync" "sync/atomic" "time" - - "github.com/google/uuid" ) // BodyStorage 请求体存储接口 @@ -101,25 +98,10 @@ type diskStorage struct { } func newDiskStorage(data []byte, cachePath string) (*diskStorage, error) { - // 确定缓存目录 - dir := cachePath - if dir == "" { - dir = os.TempDir() - } - dir = filepath.Join(dir, "new-api-body-cache") - - // 确保目录存在 - if err := os.MkdirAll(dir, 0755); err != nil { - return nil, fmt.Errorf("failed to create cache directory: %w", err) - } - - // 创建临时文件 - filename := fmt.Sprintf("body-%s-%d.tmp", uuid.New().String()[:8], time.Now().UnixNano()) - filePath := filepath.Join(dir, filename) - - file, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR|os.O_EXCL, 0600) + // 使用统一的缓存目录管理 + filePath, file, err := CreateDiskCacheFile(DiskCacheTypeBody) if err != nil { - return nil, fmt.Errorf("failed to create temp file: %w", err) + return nil, err } // 写入数据 @@ -148,25 +130,10 @@ func newDiskStorage(data []byte, cachePath string) (*diskStorage, error) { } func newDiskStorageFromReader(reader io.Reader, maxBytes int64, cachePath string) (*diskStorage, error) { - // 确定缓存目录 - dir := cachePath - if dir == "" { - dir = os.TempDir() - } - dir = filepath.Join(dir, "new-api-body-cache") - - // 确保目录存在 - if err := os.MkdirAll(dir, 0755); err != nil { - return nil, fmt.Errorf("failed to create cache directory: %w", err) - } - - // 创建临时文件 - filename := fmt.Sprintf("body-%s-%d.tmp", uuid.New().String()[:8], time.Now().UnixNano()) - filePath := filepath.Join(dir, filename) - - file, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR|os.O_EXCL, 0600) + // 使用统一的缓存目录管理 + filePath, file, err := CreateDiskCacheFile(DiskCacheTypeBody) if err != nil { - return nil, fmt.Errorf("failed to create temp file: %w", err) + return nil, err } // 从 reader 读取并写入文件 @@ -337,29 +304,6 @@ func CreateBodyStorageFromReader(reader io.Reader, contentLength int64, maxBytes // CleanupOldCacheFiles 清理旧的缓存文件(用于启动时清理残留) func CleanupOldCacheFiles() { - cachePath := GetDiskCachePath() - if cachePath == "" { - cachePath = os.TempDir() - } - dir := filepath.Join(cachePath, "new-api-body-cache") - - entries, err := os.ReadDir(dir) - if err != nil { - return // 目录不存在或无法读取 - } - - now := time.Now() - for _, entry := range entries { - if entry.IsDir() { - continue - } - info, err := entry.Info() - if err != nil { - continue - } - // 删除超过 5 分钟的旧文件 - if now.Sub(info.ModTime()) > 5*time.Minute { - os.Remove(filepath.Join(dir, entry.Name())) - } - } + // 使用统一的缓存管理 + CleanupOldDiskCacheFiles(5 * time.Minute) } diff --git a/common/disk_cache.go b/common/disk_cache.go new file mode 100644 index 000000000..b41fdcb6a --- /dev/null +++ b/common/disk_cache.go @@ -0,0 +1,172 @@ +package common + +import ( + "fmt" + "os" + "path/filepath" + "time" + + "github.com/google/uuid" +) + +// DiskCacheType 磁盘缓存类型 +type DiskCacheType string + +const ( + DiskCacheTypeBody DiskCacheType = "body" // 请求体缓存 + DiskCacheTypeFile DiskCacheType = "file" // 文件数据缓存 +) + +// 统一的缓存目录名 +const diskCacheDir = "new-api-body-cache" + +// GetDiskCacheDir 获取统一的磁盘缓存目录 +// 注意:每次调用都会重新计算,以响应配置变化 +func GetDiskCacheDir() string { + cachePath := GetDiskCachePath() + if cachePath == "" { + cachePath = os.TempDir() + } + return filepath.Join(cachePath, diskCacheDir) +} + +// EnsureDiskCacheDir 确保缓存目录存在 +func EnsureDiskCacheDir() error { + dir := GetDiskCacheDir() + return os.MkdirAll(dir, 0755) +} + +// CreateDiskCacheFile 创建磁盘缓存文件 +// cacheType: 缓存类型(body/file) +// 返回文件路径和文件句柄 +func CreateDiskCacheFile(cacheType DiskCacheType) (string, *os.File, error) { + if err := EnsureDiskCacheDir(); err != nil { + return "", nil, fmt.Errorf("failed to create cache directory: %w", err) + } + + dir := GetDiskCacheDir() + filename := fmt.Sprintf("%s-%s-%d.tmp", cacheType, uuid.New().String()[:8], time.Now().UnixNano()) + filePath := filepath.Join(dir, filename) + + file, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR|os.O_EXCL, 0600) + if err != nil { + return "", nil, fmt.Errorf("failed to create cache file: %w", err) + } + + return filePath, file, nil +} + +// WriteDiskCacheFile 写入数据到磁盘缓存文件 +// 返回文件路径 +func WriteDiskCacheFile(cacheType DiskCacheType, data []byte) (string, error) { + filePath, file, err := CreateDiskCacheFile(cacheType) + if err != nil { + return "", err + } + + _, err = file.Write(data) + if err != nil { + file.Close() + os.Remove(filePath) + return "", fmt.Errorf("failed to write cache file: %w", err) + } + + if err := file.Close(); err != nil { + os.Remove(filePath) + return "", fmt.Errorf("failed to close cache file: %w", err) + } + + return filePath, nil +} + +// WriteDiskCacheFileString 写入字符串到磁盘缓存文件 +func WriteDiskCacheFileString(cacheType DiskCacheType, data string) (string, error) { + return WriteDiskCacheFile(cacheType, []byte(data)) +} + +// ReadDiskCacheFile 读取磁盘缓存文件 +func ReadDiskCacheFile(filePath string) ([]byte, error) { + return os.ReadFile(filePath) +} + +// ReadDiskCacheFileString 读取磁盘缓存文件为字符串 +func ReadDiskCacheFileString(filePath string) (string, error) { + data, err := os.ReadFile(filePath) + if err != nil { + return "", err + } + return string(data), nil +} + +// RemoveDiskCacheFile 删除磁盘缓存文件 +func RemoveDiskCacheFile(filePath string) error { + return os.Remove(filePath) +} + +// CleanupOldDiskCacheFiles 清理旧的缓存文件 +// maxAge: 文件最大存活时间 +// 注意:此函数只删除文件,不更新统计(因为无法知道每个文件的原始大小) +func CleanupOldDiskCacheFiles(maxAge time.Duration) error { + dir := GetDiskCacheDir() + + entries, err := os.ReadDir(dir) + if err != nil { + if os.IsNotExist(err) { + return nil // 目录不存在,无需清理 + } + return err + } + + now := time.Now() + for _, entry := range entries { + if entry.IsDir() { + continue + } + info, err := entry.Info() + if err != nil { + continue + } + if now.Sub(info.ModTime()) > maxAge { + os.Remove(filepath.Join(dir, entry.Name())) + } + } + return nil +} + +// GetDiskCacheInfo 获取磁盘缓存目录信息 +func GetDiskCacheInfo() (fileCount int, totalSize int64, err error) { + dir := GetDiskCacheDir() + + entries, err := os.ReadDir(dir) + if err != nil { + if os.IsNotExist(err) { + return 0, 0, nil + } + return 0, 0, err + } + + for _, entry := range entries { + if entry.IsDir() { + continue + } + info, err := entry.Info() + if err != nil { + continue + } + fileCount++ + totalSize += info.Size() + } + return fileCount, totalSize, nil +} + +// ShouldUseDiskCache 判断是否应该使用磁盘缓存 +func ShouldUseDiskCache(dataSize int64) bool { + if !IsDiskCacheEnabled() { + return false + } + threshold := GetDiskCacheThresholdBytes() + if dataSize < threshold { + return false + } + return IsDiskCacheAvailable(dataSize) +} diff --git a/common/disk_cache_config.go b/common/disk_cache_config.go index 416ec94f4..ea0b1e1ad 100644 --- a/common/disk_cache_config.go +++ b/common/disk_cache_config.go @@ -139,12 +139,29 @@ func IncrementMemoryCacheHits() { atomic.AddInt64(&diskCacheStats.MemoryCacheHits, 1) } -// ResetDiskCacheStats 重置统计信息(不重置当前使用量) +// ResetDiskCacheStats 重置命中统计信息(不重置当前使用量) func ResetDiskCacheStats() { atomic.StoreInt64(&diskCacheStats.DiskCacheHits, 0) atomic.StoreInt64(&diskCacheStats.MemoryCacheHits, 0) } +// ResetDiskCacheUsage 重置磁盘缓存使用量统计(用于清理缓存后) +func ResetDiskCacheUsage() { + atomic.StoreInt64(&diskCacheStats.ActiveDiskFiles, 0) + atomic.StoreInt64(&diskCacheStats.CurrentDiskUsageBytes, 0) +} + +// SyncDiskCacheStats 从实际磁盘状态同步统计信息 +// 用于修正统计与实际不符的情况 +func SyncDiskCacheStats() { + fileCount, totalSize, err := GetDiskCacheInfo() + if err != nil { + return + } + atomic.StoreInt64(&diskCacheStats.ActiveDiskFiles, int64(fileCount)) + atomic.StoreInt64(&diskCacheStats.CurrentDiskUsageBytes, totalSize) +} + // IsDiskCacheAvailable 检查是否可以创建新的磁盘缓存 func IsDiskCacheAvailable(requestSize int64) bool { if !IsDiskCacheEnabled() { diff --git a/constant/context_key.go b/constant/context_key.go index b494f3685..93a553c7a 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -56,6 +56,9 @@ const ( ContextKeySystemPromptOverride ContextKey = "system_prompt_override" + // ContextKeyFileSourcesToCleanup stores file sources that need cleanup when request ends + ContextKeyFileSourcesToCleanup ContextKey = "file_sources_to_cleanup" + // ContextKeyAdminRejectReason stores an admin-only reject/block reason extracted from upstream responses. // It is not returned to end users, but can be persisted into consume/error logs for debugging. ContextKeyAdminRejectReason ContextKey = "admin_reject_reason" diff --git a/controller/channel.go b/controller/channel.go index 3ac29d7c6..9fcc95e0b 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -89,7 +89,8 @@ func GetAllChannels(c *gin.Context) { if enableTagMode { tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize()) if err != nil { - c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) + common.SysError("failed to get paginated tags: " + err.Error()) + c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签失败,请稍后重试"}) return } for _, tag := range tags { @@ -136,7 +137,8 @@ func GetAllChannels(c *gin.Context) { err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error if err != nil { - c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) + common.SysError("failed to get channels: " + err.Error()) + c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道列表失败,请稍后重试"}) return } } @@ -641,7 +643,8 @@ func RefreshCodexChannelCredential(c *gin.Context) { oauthKey, ch, err := service.RefreshCodexChannelCredential(ctx, channelId, service.CodexCredentialRefreshOptions{ResetCaches: true}) if err != nil { - c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) + common.SysError("failed to refresh codex channel credential: " + err.Error()) + c.JSON(http.StatusOK, gin.H{"success": false, "message": "刷新凭证失败,请稍后重试"}) return } @@ -1315,7 +1318,8 @@ func CopyChannel(c *gin.Context) { // fetch original channel with key origin, err := model.GetChannelById(id, true) if err != nil { - c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) + common.SysError("failed to get channel by id: " + err.Error()) + c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道信息失败,请稍后重试"}) return } @@ -1333,7 +1337,8 @@ func CopyChannel(c *gin.Context) { // insert if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil { - c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) + common.SysError("failed to clone channel: " + err.Error()) + c.JSON(http.StatusOK, gin.H{"success": false, "message": "复制渠道失败,请稍后重试"}) return } model.InitChannelCache() diff --git a/controller/codex_oauth.go b/controller/codex_oauth.go index 3c881ebb5..3071413c6 100644 --- a/controller/codex_oauth.go +++ b/controller/codex_oauth.go @@ -132,7 +132,8 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) { code, state, err := parseCodexAuthorizationInput(req.Input) if err != nil { - c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) + common.SysError("failed to parse codex authorization input: " + err.Error()) + c.JSON(http.StatusOK, gin.H{"success": false, "message": "解析授权信息失败,请检查输入格式"}) return } if strings.TrimSpace(code) == "" { @@ -177,7 +178,8 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) { tokenRes, err := service.ExchangeCodexAuthorizationCode(ctx, code, verifier) if err != nil { - c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) + common.SysError("failed to exchange codex authorization code: " + err.Error()) + c.JSON(http.StatusOK, gin.H{"success": false, "message": "授权码交换失败,请重试"}) return } diff --git a/controller/codex_usage.go b/controller/codex_usage.go index 61614b460..62b7a754f 100644 --- a/controller/codex_usage.go +++ b/controller/codex_usage.go @@ -45,7 +45,8 @@ func GetCodexChannelUsage(c *gin.Context) { oauthKey, err := codex.ParseOAuthKey(strings.TrimSpace(ch.Key)) if err != nil { - c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) + common.SysError("failed to parse oauth key: " + err.Error()) + c.JSON(http.StatusOK, gin.H{"success": false, "message": "解析凭证失败,请检查渠道配置"}) return } accessToken := strings.TrimSpace(oauthKey.AccessToken) @@ -70,7 +71,8 @@ func GetCodexChannelUsage(c *gin.Context) { statusCode, body, err := service.FetchCodexWhamUsage(ctx, client, ch.GetBaseURL(), accessToken, accountID) if err != nil { - c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) + common.SysError("failed to fetch codex usage: " + err.Error()) + c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取用量信息失败,请稍后重试"}) return } @@ -99,7 +101,8 @@ func GetCodexChannelUsage(c *gin.Context) { defer cancel2() statusCode, body, err = service.FetchCodexWhamUsage(ctx2, client, ch.GetBaseURL(), oauthKey.AccessToken, accountID) if err != nil { - c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) + common.SysError("failed to fetch codex usage after refresh: " + err.Error()) + c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取用量信息失败,请稍后重试"}) return } } diff --git a/controller/console_migrate.go b/controller/console_migrate.go index 011ab09d4..458496104 100644 --- a/controller/console_migrate.go +++ b/controller/console_migrate.go @@ -17,7 +17,8 @@ func MigrateConsoleSetting(c *gin.Context) { // 读取全部 option opts, err := model.AllOption() if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()}) + common.SysError("failed to get all options: " + err.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "获取配置失败,请稍后重试"}) return } // 建立 map diff --git a/controller/model_sync.go b/controller/model_sync.go index 737f92d40..160d8f780 100644 --- a/controller/model_sync.go +++ b/controller/model_sync.go @@ -272,7 +272,8 @@ func SyncUpstreamModels(c *gin.Context) { // 1) 获取未配置模型列表 missing, err := model.GetMissingModels() if err != nil { - c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) + common.SysError("failed to get missing models: " + err.Error()) + c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取模型列表失败,请稍后重试"}) return } diff --git a/controller/performance.go b/controller/performance.go index a6fedc46c..8a261ad82 100644 --- a/controller/performance.go +++ b/controller/performance.go @@ -3,7 +3,6 @@ package controller import ( "net/http" "os" - "path/filepath" "runtime" "github.com/QuantumNous/new-api/common" @@ -78,6 +77,9 @@ type PerformanceConfig struct { // GetPerformanceStats 获取性能统计信息 func GetPerformanceStats(c *gin.Context) { + // 先同步磁盘缓存统计,确保显示准确 + common.SyncDiskCacheStats() + // 获取缓存统计 cacheStats := common.GetDiskCacheStats() @@ -123,11 +125,8 @@ func GetPerformanceStats(c *gin.Context) { // ClearDiskCache 清理磁盘缓存 func ClearDiskCache(c *gin.Context) { - cachePath := common.GetDiskCachePath() - if cachePath == "" { - cachePath = os.TempDir() - } - dir := filepath.Join(cachePath, "new-api-body-cache") + // 使用统一的缓存目录 + dir := common.GetDiskCacheDir() // 删除缓存目录 err := os.RemoveAll(dir) @@ -136,8 +135,9 @@ func ClearDiskCache(c *gin.Context) { return } - // 重置统计 + // 重置统计(包括命中次数和使用量) common.ResetDiskCacheStats() + common.ResetDiskCacheUsage() c.JSON(http.StatusOK, gin.H{ "success": true, @@ -167,11 +167,8 @@ func ForceGC(c *gin.Context) { // getDiskCacheInfo 获取磁盘缓存目录信息 func getDiskCacheInfo() DiskCacheInfo { - cachePath := common.GetDiskCachePath() - if cachePath == "" { - cachePath = os.TempDir() - } - dir := filepath.Join(cachePath, "new-api-body-cache") + // 使用统一的缓存目录 + dir := common.GetDiskCacheDir() info := DiskCacheInfo{ Path: dir, diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go index 0b6a6dff0..68b776e9a 100644 --- a/controller/ratio_sync.go +++ b/controller/ratio_sync.go @@ -56,7 +56,8 @@ type upstreamResult struct { func FetchUpstreamRatios(c *gin.Context) { var req dto.UpstreamRequest if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) + common.SysError("failed to bind upstream request: " + err.Error()) + c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "请求参数格式错误"}) return } diff --git a/controller/redemption.go b/controller/redemption.go index 945cefa35..33c17346c 100644 --- a/controller/redemption.go +++ b/controller/redemption.go @@ -103,9 +103,10 @@ func AddRedemption(c *gin.Context) { } err = cleanRedemption.Insert() if err != nil { + common.SysError("failed to insert redemption: " + err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, - "message": err.Error(), + "message": "创建兑换码失败,请稍后重试", "data": keys, }) return diff --git a/controller/token.go b/controller/token.go index c5dc5ec42..b683b730f 100644 --- a/controller/token.go +++ b/controller/token.go @@ -107,9 +107,10 @@ func GetTokenUsage(c *gin.Context) { token, err := model.GetTokenByKey(strings.TrimPrefix(tokenKey, "sk-"), false) if err != nil { + common.SysError("failed to get token by key: " + err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, - "message": err.Error(), + "message": "获取令牌信息失败,请稍后重试", }) return } diff --git a/dto/claude.go b/dto/claude.go index 14efd7315..8b6b495f6 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -214,6 +214,14 @@ type ClaudeRequest struct { ServiceTier string `json:"service_tier,omitempty"` } +// createClaudeFileSource 根据数据内容创建正确类型的 FileSource +func createClaudeFileSource(data string) *types.FileSource { + if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") { + return types.NewURLFileSource(data) + } + return types.NewBase64FileSource(data, "") +} + func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta { var tokenCountMeta = types.TokenCountMeta{ TokenType: types.TokenTypeTokenizer, @@ -243,7 +251,10 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta { data = common.Interface2String(media.Source.Data) } if data != "" { - fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data}) + fileMeta = append(fileMeta, &types.FileMeta{ + FileType: types.FileTypeImage, + Source: createClaudeFileSource(data), + }) } } } @@ -275,7 +286,10 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta { data = common.Interface2String(media.Source.Data) } if data != "" { - fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data}) + fileMeta = append(fileMeta, &types.FileMeta{ + FileType: types.FileTypeImage, + Source: createClaudeFileSource(data), + }) } } case "tool_use": diff --git a/dto/gemini.go b/dto/gemini.go index b330f8b1b..0fd74c639 100644 --- a/dto/gemini.go +++ b/dto/gemini.go @@ -64,6 +64,14 @@ type LatLng struct { Longitude *float64 `json:"longitude,omitempty"` } +// createGeminiFileSource 根据数据内容创建正确类型的 FileSource +func createGeminiFileSource(data string, mimeType string) *types.FileSource { + if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") { + return types.NewURLFileSource(data) + } + return types.NewBase64FileSource(data, mimeType) +} + func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta { var files []*types.FileMeta = make([]*types.FileMeta, 0) @@ -80,27 +88,23 @@ func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta { inputTexts = append(inputTexts, part.Text) } if part.InlineData != nil && part.InlineData.Data != "" { - if strings.HasPrefix(part.InlineData.MimeType, "image/") { - files = append(files, &types.FileMeta{ - FileType: types.FileTypeImage, - OriginData: part.InlineData.Data, - }) - } else if strings.HasPrefix(part.InlineData.MimeType, "audio/") { - files = append(files, &types.FileMeta{ - FileType: types.FileTypeAudio, - OriginData: part.InlineData.Data, - }) - } else if strings.HasPrefix(part.InlineData.MimeType, "video/") { - files = append(files, &types.FileMeta{ - FileType: types.FileTypeVideo, - OriginData: part.InlineData.Data, - }) + mimeType := part.InlineData.MimeType + source := createGeminiFileSource(part.InlineData.Data, mimeType) + var fileType types.FileType + if strings.HasPrefix(mimeType, "image/") { + fileType = types.FileTypeImage + } else if strings.HasPrefix(mimeType, "audio/") { + fileType = types.FileTypeAudio + } else if strings.HasPrefix(mimeType, "video/") { + fileType = types.FileTypeVideo } else { - files = append(files, &types.FileMeta{ - FileType: types.FileTypeFile, - OriginData: part.InlineData.Data, - }) + fileType = types.FileTypeFile } + files = append(files, &types.FileMeta{ + FileType: fileType, + Source: source, + MimeType: mimeType, + }) } } } diff --git a/dto/openai_request.go b/dto/openai_request.go index 5e40d5472..9113a086e 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -101,6 +101,14 @@ type GeneralOpenAIRequest struct { SearchMode string `json:"search_mode,omitempty"` } +// createFileSource 根据数据内容创建正确类型的 FileSource +func createFileSource(data string) *types.FileSource { + if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") { + return types.NewURLFileSource(data) + } + return types.NewBase64FileSource(data, "") +} + func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta { var tokenCountMeta types.TokenCountMeta var texts = make([]string, 0) @@ -144,42 +152,40 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta { for _, m := range arrayContent { if m.Type == ContentTypeImageURL { imageUrl := m.GetImageMedia() - if imageUrl != nil { - if imageUrl.Url != "" { - meta := &types.FileMeta{ - FileType: types.FileTypeImage, - } - meta.OriginData = imageUrl.Url - meta.Detail = imageUrl.Detail - fileMeta = append(fileMeta, meta) - } + if imageUrl != nil && imageUrl.Url != "" { + source := createFileSource(imageUrl.Url) + fileMeta = append(fileMeta, &types.FileMeta{ + FileType: types.FileTypeImage, + Source: source, + Detail: imageUrl.Detail, + }) } } else if m.Type == ContentTypeInputAudio { inputAudio := m.GetInputAudio() - if inputAudio != nil { - meta := &types.FileMeta{ + if inputAudio != nil && inputAudio.Data != "" { + source := createFileSource(inputAudio.Data) + fileMeta = append(fileMeta, &types.FileMeta{ FileType: types.FileTypeAudio, - } - meta.OriginData = inputAudio.Data - fileMeta = append(fileMeta, meta) + Source: source, + }) } } else if m.Type == ContentTypeFile { file := m.GetFile() - if file != nil { - meta := &types.FileMeta{ + if file != nil && file.FileData != "" { + source := createFileSource(file.FileData) + fileMeta = append(fileMeta, &types.FileMeta{ FileType: types.FileTypeFile, - } - meta.OriginData = file.FileData - fileMeta = append(fileMeta, meta) + Source: source, + }) } } else if m.Type == ContentTypeVideoUrl { videoUrl := m.GetVideoUrl() if videoUrl != nil && videoUrl.Url != "" { - meta := &types.FileMeta{ + source := createFileSource(videoUrl.Url) + fileMeta = append(fileMeta, &types.FileMeta{ FileType: types.FileTypeVideo, - } - meta.OriginData = videoUrl.Url - fileMeta = append(fileMeta, meta) + Source: source, + }) } } else { texts = append(texts, m.Text) @@ -833,16 +839,16 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta { if input.Type == "input_image" { if input.ImageUrl != "" { fileMeta = append(fileMeta, &types.FileMeta{ - FileType: types.FileTypeImage, - OriginData: input.ImageUrl, - Detail: input.Detail, + FileType: types.FileTypeImage, + Source: createFileSource(input.ImageUrl), + Detail: input.Detail, }) } } else if input.Type == "input_file" { if input.FileUrl != "" { fileMeta = append(fileMeta, &types.FileMeta{ - FileType: types.FileTypeFile, - OriginData: input.FileUrl, + FileType: types.FileTypeFile, + Source: createFileSource(input.FileUrl), }) } } else { diff --git a/middleware/body_cleanup.go b/middleware/body_cleanup.go index 5d06726f7..f7b7ab51a 100644 --- a/middleware/body_cleanup.go +++ b/middleware/body_cleanup.go @@ -2,6 +2,7 @@ package middleware import ( "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/service" "github.com/gin-gonic/gin" ) @@ -14,5 +15,8 @@ func BodyStorageCleanup() gin.HandlerFunc { // 请求结束后清理存储 common.CleanupBodyStorage(c) + + // 清理文件缓存(URL 下载的文件等) + service.CleanupFileSources(c) } } diff --git a/model/redemption.go b/model/redemption.go index 7dd2d9527..237561bec 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -148,7 +148,8 @@ func Redeem(key string, userId int) (quota int, err error) { return err }) if err != nil { - return 0, errors.New("兑换失败," + err.Error()) + common.SysError("redemption failed: " + err.Error()) + return 0, errors.New("兑换失败,请稍后重试") } RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", logger.LogQuota(redemption.Quota), redemption.Id)) return redemption.Quota, nil diff --git a/model/topup.go b/model/topup.go index 994094d9d..655d9b77a 100644 --- a/model/topup.go +++ b/model/topup.go @@ -95,7 +95,8 @@ func Recharge(referenceId string, customerId string) (err error) { }) if err != nil { - return errors.New("充值失败," + err.Error()) + common.SysError("topup failed: " + err.Error()) + return errors.New("充值失败,请稍后重试") } RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", logger.FormatQuota(int(quota)), topUp.Amount)) @@ -367,7 +368,8 @@ func RechargeCreem(referenceId string, customerEmail string, customerName string }) if err != nil { - return errors.New("充值失败," + err.Error()) + common.SysError("creem topup failed: " + err.Error()) + return errors.New("充值失败,请稍后重试") } RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用Creem充值成功,充值额度: %v,支付金额:%.2f", quota, topUp.Money)) diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index 423cf842d..e9e5fd913 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -49,12 +49,14 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn for i2, mediaMessage := range content { if mediaMessage.Source != nil { if mediaMessage.Source.Type == "url" { - fileData, err := service.GetFileBase64FromUrl(c, mediaMessage.Source.Url, "formatting image for Claude") + // 使用统一的文件服务获取图片数据 + source := types.NewURLFileSource(mediaMessage.Source.Url) + base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Claude") if err != nil { return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error()) } - mediaMessage.Source.MediaType = fileData.MimeType - mediaMessage.Source.Data = fileData.Base64Data + mediaMessage.Source.MediaType = mimeType + mediaMessage.Source.Data = base64Data mediaMessage.Source.Url = "" mediaMessage.Source.Type = "base64" content[i2] = mediaMessage diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 4fdf7c275..f7e8abd15 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -364,23 +364,19 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe claudeMediaMessage.Source = &dto.ClaudeMessageSource{ Type: "base64", } - // 判断是否是url + // 使用统一的文件服务获取图片数据 + var source *types.FileSource if strings.HasPrefix(imageUrl.Url, "http") { - // 是url,获取图片的类型和base64编码的数据 - fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Claude") - if err != nil { - return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error()) - } - claudeMediaMessage.Source.MediaType = fileData.MimeType - claudeMediaMessage.Source.Data = fileData.Base64Data + source = types.NewURLFileSource(imageUrl.Url) } else { - _, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url) - if err != nil { - return nil, err - } - claudeMediaMessage.Source.MediaType = "image/" + format - claudeMediaMessage.Source.Data = base64String + source = types.NewBase64FileSource(imageUrl.Url, "") } + base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Claude") + if err != nil { + return nil, fmt.Errorf("get file data failed: %s", err.Error()) + } + claudeMediaMessage.Source.MediaType = mimeType + claudeMediaMessage.Source.Data = base64Data } claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage) } diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index e39cf99e8..f4a5be1af 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -540,64 +540,58 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum { return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum) } - // 判断是否是url - if strings.HasPrefix(part.GetImageMedia().Url, "http") { - // 是url,获取文件的类型和base64编码的数据 - fileData, err := service.GetFileBase64FromUrl(c, part.GetImageMedia().Url, "formatting image for Gemini") - if err != nil { - return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err) - } - - // 校验 MimeType 是否在 Gemini 支持的白名单中 - if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok { - url := part.GetImageMedia().Url - return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList()) - } - - parts = append(parts, dto.GeminiPart{ - InlineData: &dto.GeminiInlineData{ - MimeType: fileData.MimeType, // 使用原始的 MimeType,因为大小写可能对API有意义 - Data: fileData.Base64Data, - }, - }) + // 使用统一的文件服务获取图片数据 + var source *types.FileSource + imageUrl := part.GetImageMedia().Url + if strings.HasPrefix(imageUrl, "http") { + source = types.NewURLFileSource(imageUrl) } else { - format, base64String, err := service.DecodeBase64FileData(part.GetImageMedia().Url) - if err != nil { - return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error()) - } - parts = append(parts, dto.GeminiPart{ - InlineData: &dto.GeminiInlineData{ - MimeType: format, - Data: base64String, - }, - }) + source = types.NewBase64FileSource(imageUrl, "") } + base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Gemini") + if err != nil { + return nil, fmt.Errorf("get file data from '%s' failed: %w", source.GetIdentifier(), err) + } + + // 校验 MimeType 是否在 Gemini 支持的白名单中 + if _, ok := geminiSupportedMimeTypes[strings.ToLower(mimeType)]; !ok { + return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", mimeType, source.GetIdentifier(), getSupportedMimeTypesList()) + } + + parts = append(parts, dto.GeminiPart{ + InlineData: &dto.GeminiInlineData{ + MimeType: mimeType, + Data: base64Data, + }, + }) } else if part.Type == dto.ContentTypeFile { if part.GetFile().FileId != "" { return nil, fmt.Errorf("only base64 file is supported in gemini") } - format, base64String, err := service.DecodeBase64FileData(part.GetFile().FileData) + fileSource := types.NewBase64FileSource(part.GetFile().FileData, "") + base64Data, mimeType, err := service.GetBase64Data(c, fileSource, "formatting file for Gemini") if err != nil { return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error()) } parts = append(parts, dto.GeminiPart{ InlineData: &dto.GeminiInlineData{ - MimeType: format, - Data: base64String, + MimeType: mimeType, + Data: base64Data, }, }) } else if part.Type == dto.ContentTypeInputAudio { if part.GetInputAudio().Data == "" { return nil, fmt.Errorf("only base64 audio is supported in gemini") } - base64String, err := service.DecodeBase64AudioData(part.GetInputAudio().Data) + audioSource := types.NewBase64FileSource(part.GetInputAudio().Data, "audio/"+part.GetInputAudio().Format) + base64Data, mimeType, err := service.GetBase64Data(c, audioSource, "formatting audio for Gemini") if err != nil { return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error()) } parts = append(parts, dto.GeminiPart{ InlineData: &dto.GeminiInlineData{ - MimeType: "audio/" + part.GetInputAudio().Format, - Data: base64String, + MimeType: mimeType, + Data: base64Data, }, }) } diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index 795e9c975..ccc19c67b 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -99,19 +99,16 @@ func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*Ollam if part.Type == dto.ContentTypeImageURL { img := part.GetImageMedia() if img != nil && img.Url != "" { - var base64Data string + // 使用统一的文件服务获取图片数据 + var source *types.FileSource if strings.HasPrefix(img.Url, "http") { - fileData, err := service.GetFileBase64FromUrl(c, img.Url, "fetch image for ollama chat") - if err != nil { - return nil, err - } - base64Data = fileData.Base64Data - } else if strings.HasPrefix(img.Url, "data:") { - if idx := strings.Index(img.Url, ","); idx != -1 && idx+1 < len(img.Url) { - base64Data = img.Url[idx+1:] - } + source = types.NewURLFileSource(img.Url) } else { - base64Data = img.Url + source = types.NewBase64FileSource(img.Url, "") + } + base64Data, _, err := service.GetBase64Data(c, source, "fetch image for ollama chat") + if err != nil { + return nil, err } if base64Data != "" { images = append(images, base64Data) diff --git a/service/file_decoder.go b/service/file_decoder.go index ae3bc581f..d5831d8c1 100644 --- a/service/file_decoder.go +++ b/service/file_decoder.go @@ -2,7 +2,6 @@ package service import ( "bytes" - "encoding/base64" "fmt" "image" _ "image/gif" @@ -13,7 +12,6 @@ import ( "strings" "github.com/QuantumNous/new-api/common" - "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/types" @@ -130,90 +128,27 @@ func GetFileTypeFromUrl(c *gin.Context, url string, reason ...string) (string, e return "application/octet-stream", nil } +// GetFileBase64FromUrl 从 URL 获取文件的 base64 编码数据 +// Deprecated: 请使用 GetBase64Data 配合 types.NewURLFileSource 替代 +// 此函数保留用于向后兼容,内部已重构为调用统一的文件服务 func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) { - contextKey := fmt.Sprintf("file_download_%s", common.GenerateHMAC(url)) - - // Check if the file has already been downloaded in this request - if cachedData, exists := c.Get(contextKey); exists { - if common.DebugEnabled { - logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url)) - } - return cachedData.(*types.LocalFileData), nil - } - - var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024 - - resp, err := DoDownloadRequest(url, reason...) + source := types.NewURLFileSource(url) + cachedData, err := LoadFileSource(c, source, reason...) if err != nil { return nil, err } - defer resp.Body.Close() - // Always use LimitReader to prevent oversized downloads - fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1))) + // 转换为旧的 LocalFileData 格式以保持兼容 + base64Data, err := cachedData.GetBase64Data() if err != nil { return nil, err } - // Check actual size after reading - if len(fileBytes) > maxFileSize { - return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB) - } - - // Convert to base64 - base64Data := base64.StdEncoding.EncodeToString(fileBytes) - - mimeType := resp.Header.Get("Content-Type") - if len(strings.Split(mimeType, ";")) > 1 { - // If Content-Type has parameters, take the first part - mimeType = strings.Split(mimeType, ";")[0] - } - if mimeType == "application/octet-stream" { - logger.LogDebug(c, fmt.Sprintf("MIME type is application/octet-stream for URL: %s", url)) - // try to guess the MIME type from the url last segment - urlParts := strings.Split(url, "/") - if len(urlParts) > 0 { - lastSegment := urlParts[len(urlParts)-1] - if strings.Contains(lastSegment, ".") { - // Extract the file extension - filename := strings.Split(lastSegment, ".") - if len(filename) > 1 { - ext := strings.ToLower(filename[len(filename)-1]) - // Guess MIME type based on file extension - mimeType = GetMimeTypeByExtension(ext) - } - } - } else { - // try to guess the MIME type from the file extension - fileName := resp.Header.Get("Content-Disposition") - if fileName != "" { - // Extract the filename from the Content-Disposition header - parts := strings.Split(fileName, ";") - for _, part := range parts { - if strings.HasPrefix(strings.TrimSpace(part), "filename=") { - fileName = strings.TrimSpace(strings.TrimPrefix(part, "filename=")) - // Remove quotes if present - if len(fileName) > 2 && fileName[0] == '"' && fileName[len(fileName)-1] == '"' { - fileName = fileName[1 : len(fileName)-1] - } - // Guess MIME type based on file extension - if ext := strings.ToLower(strings.TrimPrefix(fileName, ".")); ext != "" { - mimeType = GetMimeTypeByExtension(ext) - } - break - } - } - } - } - } - data := &types.LocalFileData{ + return &types.LocalFileData{ Base64Data: base64Data, - MimeType: mimeType, - Size: int64(len(fileBytes)), - } - // Store the file data in the context to avoid re-downloading - c.Set(contextKey, data) - - return data, nil + MimeType: cachedData.MimeType, + Size: cachedData.Size, + Url: url, + }, nil } func GetMimeTypeByExtension(ext string) string { diff --git a/service/file_service.go b/service/file_service.go new file mode 100644 index 000000000..a42a42bf2 --- /dev/null +++ b/service/file_service.go @@ -0,0 +1,451 @@ +package service + +import ( + "bytes" + "encoding/base64" + "fmt" + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" + "io" + "net/http" + "strings" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/types" + + "github.com/gin-gonic/gin" + "golang.org/x/image/webp" +) + +// FileService 统一的文件处理服务 +// 提供文件下载、解码、缓存等功能的统一入口 + +// getContextCacheKey 生成 context 缓存的 key +func getContextCacheKey(url string) string { + return fmt.Sprintf("file_cache_%s", common.GenerateHMAC(url)) +} + +// LoadFileSource 加载文件源数据 +// 这是统一的入口,会自动处理缓存和不同的来源类型 +func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string) (*types.CachedFileData, error) { + if source == nil { + return nil, fmt.Errorf("file source is nil") + } + + // 如果已有缓存,直接返回 + if source.HasCache() { + return source.GetCache(), nil + } + + var cachedData *types.CachedFileData + var err error + + if source.IsURL() { + cachedData, err = loadFromURL(c, source.URL, reason...) + } else { + cachedData, err = loadFromBase64(source.Base64Data, source.MimeType) + } + + if err != nil { + return nil, err + } + + // 设置缓存 + source.SetCache(cachedData) + + // 注册到 context 以便请求结束时自动清理 + if c != nil { + registerSourceForCleanup(c, source) + } + + return cachedData, nil +} + +// registerSourceForCleanup 注册 FileSource 到 context 以便请求结束时清理 +func registerSourceForCleanup(c *gin.Context, source *types.FileSource) { + key := string(constant.ContextKeyFileSourcesToCleanup) + var sources []*types.FileSource + if existing, exists := c.Get(key); exists { + sources = existing.([]*types.FileSource) + } + sources = append(sources, source) + c.Set(key, sources) +} + +// CleanupFileSources 清理请求中所有注册的 FileSource +// 应在请求结束时调用(通常由中间件自动调用) +func CleanupFileSources(c *gin.Context) { + key := string(constant.ContextKeyFileSourcesToCleanup) + if sources, exists := c.Get(key); exists { + for _, source := range sources.([]*types.FileSource) { + if cache := source.GetCache(); cache != nil { + if cache.IsDisk() { + common.DecrementDiskFiles(cache.Size) + } + cache.Close() + } + } + c.Set(key, nil) // 清除引用 + } +} + +// loadFromURL 从 URL 加载文件 +// 支持磁盘缓存:当文件大小超过阈值且磁盘缓存可用时,将数据存储到磁盘 +func loadFromURL(c *gin.Context, url string, reason ...string) (*types.CachedFileData, error) { + contextKey := getContextCacheKey(url) + + // 检查 context 缓存 + if cachedData, exists := c.Get(contextKey); exists { + if common.DebugEnabled { + logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url)) + } + return cachedData.(*types.CachedFileData), nil + } + + // 下载文件 + var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024 + + resp, err := DoDownloadRequest(url, reason...) + if err != nil { + return nil, fmt.Errorf("failed to download file from %s: %w", url, err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to download file, status code: %d", resp.StatusCode) + } + + // 读取文件内容(限制大小) + fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1))) + if err != nil { + return nil, fmt.Errorf("failed to read file content: %w", err) + } + if len(fileBytes) > maxFileSize { + return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB) + } + + // 转换为 base64 + base64Data := base64.StdEncoding.EncodeToString(fileBytes) + + // 智能获取 MIME 类型 + mimeType := smartDetectMimeType(resp, url, fileBytes) + + // 判断是否使用磁盘缓存 + base64Size := int64(len(base64Data)) + var cachedData *types.CachedFileData + + if shouldUseDiskCache(base64Size) { + // 使用磁盘缓存 + diskPath, err := writeToDiskCache(base64Data) + if err != nil { + // 磁盘缓存失败,回退到内存 + logger.LogWarn(c, fmt.Sprintf("Failed to write to disk cache, falling back to memory: %v", err)) + cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes))) + } else { + cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(fileBytes))) + common.IncrementDiskFiles(base64Size) + if common.DebugEnabled { + logger.LogDebug(c, fmt.Sprintf("File cached to disk: %s, size: %d bytes", diskPath, base64Size)) + } + } + } else { + // 使用内存缓存 + cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes))) + } + + // 如果是图片,尝试获取图片配置 + if strings.HasPrefix(mimeType, "image/") { + config, format, err := decodeImageConfig(fileBytes) + if err == nil { + cachedData.ImageConfig = &config + cachedData.ImageFormat = format + // 如果通过图片解码获取了更准确的格式,更新 MIME 类型 + if mimeType == "application/octet-stream" || mimeType == "" { + cachedData.MimeType = "image/" + format + } + } + } + + // 存入 context 缓存 + c.Set(contextKey, cachedData) + + return cachedData, nil +} + +// shouldUseDiskCache 判断是否应该使用磁盘缓存 +func shouldUseDiskCache(dataSize int64) bool { + return common.ShouldUseDiskCache(dataSize) +} + +// writeToDiskCache 将数据写入磁盘缓存 +func writeToDiskCache(base64Data string) (string, error) { + return common.WriteDiskCacheFileString(common.DiskCacheTypeFile, base64Data) +} + +// smartDetectMimeType 智能检测 MIME 类型 +// 优先级:Content-Type header > Content-Disposition filename > URL 路径 > 内容嗅探 > 图片解码 +func smartDetectMimeType(resp *http.Response, url string, fileBytes []byte) string { + // 1. 尝试从 Content-Type header 获取 + mimeType := resp.Header.Get("Content-Type") + if idx := strings.Index(mimeType, ";"); idx != -1 { + mimeType = strings.TrimSpace(mimeType[:idx]) + } + if mimeType != "" && mimeType != "application/octet-stream" { + return mimeType + } + + // 2. 尝试从 Content-Disposition header 的 filename 获取 + if cd := resp.Header.Get("Content-Disposition"); cd != "" { + parts := strings.Split(cd, ";") + for _, part := range parts { + part = strings.TrimSpace(part) + if strings.HasPrefix(strings.ToLower(part), "filename=") { + name := strings.TrimSpace(strings.TrimPrefix(part, "filename=")) + // 移除引号 + if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' { + name = name[1 : len(name)-1] + } + if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) { + ext := strings.ToLower(name[dot+1:]) + if ext != "" { + mt := GetMimeTypeByExtension(ext) + if mt != "application/octet-stream" { + return mt + } + } + } + break + } + } + } + + // 3. 尝试从 URL 路径获取扩展名 + mt := guessMimeTypeFromURL(url) + if mt != "application/octet-stream" { + return mt + } + + // 4. 使用 http.DetectContentType 内容嗅探 + if len(fileBytes) > 0 { + sniffed := http.DetectContentType(fileBytes) + if sniffed != "" && sniffed != "application/octet-stream" { + // 去除可能的 charset 参数 + if idx := strings.Index(sniffed, ";"); idx != -1 { + sniffed = strings.TrimSpace(sniffed[:idx]) + } + return sniffed + } + } + + // 5. 尝试作为图片解码获取格式 + if len(fileBytes) > 0 { + if _, format, err := decodeImageConfig(fileBytes); err == nil && format != "" { + return "image/" + strings.ToLower(format) + } + } + + // 最终回退 + return "application/octet-stream" +} + +// loadFromBase64 从 base64 字符串加载文件 +func loadFromBase64(base64String string, providedMimeType string) (*types.CachedFileData, error) { + var mimeType string + var cleanBase64 string + + // 处理 data: 前缀 + if strings.HasPrefix(base64String, "data:") { + // 格式: data:mime/type;base64,xxxxx + idx := strings.Index(base64String, ",") + if idx != -1 { + header := base64String[:idx] + cleanBase64 = base64String[idx+1:] + + // 从 header 提取 MIME 类型 + if strings.Contains(header, ":") && strings.Contains(header, ";") { + mimeStart := strings.Index(header, ":") + 1 + mimeEnd := strings.Index(header, ";") + if mimeStart < mimeEnd { + mimeType = header[mimeStart:mimeEnd] + } + } + } else { + cleanBase64 = base64String + } + } else { + cleanBase64 = base64String + } + + // 使用提供的 MIME 类型(如果有) + if providedMimeType != "" { + mimeType = providedMimeType + } + + // 解码 base64 + decodedData, err := base64.StdEncoding.DecodeString(cleanBase64) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 data: %w", err) + } + + // 判断是否使用磁盘缓存(对于 base64 内联数据也支持磁盘缓存) + base64Size := int64(len(cleanBase64)) + var cachedData *types.CachedFileData + + if shouldUseDiskCache(base64Size) { + // 使用磁盘缓存 + diskPath, err := writeToDiskCache(cleanBase64) + if err != nil { + // 磁盘缓存失败,回退到内存 + cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, int64(len(decodedData))) + } else { + cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(decodedData))) + common.IncrementDiskFiles(base64Size) + } + } else { + cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, int64(len(decodedData))) + } + + // 如果是图片或 MIME 类型未知,尝试解码图片获取更多信息 + if mimeType == "" || strings.HasPrefix(mimeType, "image/") { + config, format, err := decodeImageConfig(decodedData) + if err == nil { + cachedData.ImageConfig = &config + cachedData.ImageFormat = format + if mimeType == "" { + cachedData.MimeType = "image/" + format + } + } + } + + return cachedData, nil +} + +// GetImageConfig 获取图片配置(宽高等信息) +// 会自动处理缓存,避免重复下载/解码 +func GetImageConfig(c *gin.Context, source *types.FileSource) (image.Config, string, error) { + cachedData, err := LoadFileSource(c, source, "get_image_config") + if err != nil { + return image.Config{}, "", err + } + + if cachedData.ImageConfig != nil { + return *cachedData.ImageConfig, cachedData.ImageFormat, nil + } + + // 如果缓存中没有图片配置,尝试解码 + base64Str, err := cachedData.GetBase64Data() + if err != nil { + return image.Config{}, "", fmt.Errorf("failed to get base64 data: %w", err) + } + decodedData, err := base64.StdEncoding.DecodeString(base64Str) + if err != nil { + return image.Config{}, "", fmt.Errorf("failed to decode base64 for image config: %w", err) + } + + config, format, err := decodeImageConfig(decodedData) + if err != nil { + return image.Config{}, "", err + } + + // 更新缓存 + cachedData.ImageConfig = &config + cachedData.ImageFormat = format + + return config, format, nil +} + +// GetBase64Data 获取 base64 编码的数据 +// 会自动处理缓存,避免重复下载 +// 支持内存缓存和磁盘缓存 +func GetBase64Data(c *gin.Context, source *types.FileSource, reason ...string) (string, string, error) { + cachedData, err := LoadFileSource(c, source, reason...) + if err != nil { + return "", "", err + } + base64Str, err := cachedData.GetBase64Data() + if err != nil { + return "", "", fmt.Errorf("failed to get base64 data: %w", err) + } + return base64Str, cachedData.MimeType, nil +} + +// GetMimeType 获取文件的 MIME 类型 +func GetMimeType(c *gin.Context, source *types.FileSource) (string, error) { + // 如果已经有缓存,直接返回 + if source.HasCache() { + return source.GetCache().MimeType, nil + } + + // 如果是 URL,尝试只获取 header 而不下载完整文件 + if source.IsURL() { + mimeType, err := GetFileTypeFromUrl(c, source.URL, "get_mime_type") + if err == nil && mimeType != "" && mimeType != "application/octet-stream" { + return mimeType, nil + } + } + + // 否则加载完整数据 + cachedData, err := LoadFileSource(c, source, "get_mime_type") + if err != nil { + return "", err + } + return cachedData.MimeType, nil +} + +// DetectFileType 检测文件类型(image/audio/video/file) +func DetectFileType(mimeType string) types.FileType { + if strings.HasPrefix(mimeType, "image/") { + return types.FileTypeImage + } + if strings.HasPrefix(mimeType, "audio/") { + return types.FileTypeAudio + } + if strings.HasPrefix(mimeType, "video/") { + return types.FileTypeVideo + } + return types.FileTypeFile +} + +// decodeImageConfig 从字节数据解码图片配置 +func decodeImageConfig(data []byte) (image.Config, string, error) { + reader := bytes.NewReader(data) + + // 尝试标准格式 + config, format, err := image.DecodeConfig(reader) + if err == nil { + return config, format, nil + } + + // 尝试 webp + reader.Seek(0, io.SeekStart) + config, err = webp.DecodeConfig(reader) + if err == nil { + return config, "webp", nil + } + + return image.Config{}, "", fmt.Errorf("failed to decode image config: unsupported format") +} + +// guessMimeTypeFromURL 从 URL 猜测 MIME 类型 +func guessMimeTypeFromURL(url string) string { + // 移除查询参数 + cleanedURL := url + if q := strings.Index(cleanedURL, "?"); q != -1 { + cleanedURL = cleanedURL[:q] + } + + // 获取最后一段 + if slash := strings.LastIndex(cleanedURL, "/"); slash != -1 && slash+1 < len(cleanedURL) { + last := cleanedURL[slash+1:] + if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) { + ext := strings.ToLower(last[dot+1:]) + return GetMimeTypeByExtension(ext) + } + } + + return "application/octet-stream" +} diff --git a/service/token_counter.go b/service/token_counter.go index c70c54a88..2020845e3 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -3,10 +3,6 @@ package service import ( "errors" "fmt" - "image" - _ "image/gif" - _ "image/jpeg" - _ "image/png" "log" "math" "path/filepath" @@ -23,8 +19,8 @@ import ( "github.com/gin-gonic/gin" ) -func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) { - if fileMeta == nil { +func getImageToken(c *gin.Context, fileMeta *types.FileMeta, model string, stream bool) (int, error) { + if fileMeta == nil || fileMeta.Source == nil { return 0, fmt.Errorf("image_url_is_nil") } @@ -99,35 +95,20 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er fileMeta.Detail = "high" } - // Decode image to get dimensions - var config image.Config - var err error - var format string - var b64str string - - if fileMeta.ParsedData != nil { - config, format, b64str, err = DecodeBase64ImageData(fileMeta.ParsedData.Base64Data) - } else { - if strings.HasPrefix(fileMeta.OriginData, "http") { - config, format, err = DecodeUrlImageData(fileMeta.OriginData) - } else { - common.SysLog(fmt.Sprintf("decoding image")) - config, format, b64str, err = DecodeBase64ImageData(fileMeta.OriginData) - } - fileMeta.MimeType = format - } - + // 使用统一的文件服务获取图片配置 + config, format, err := GetImageConfig(c, fileMeta.Source) if err != nil { return 0, err } + fileMeta.MimeType = format if config.Width == 0 || config.Height == 0 { - // not an image - if format != "" && b64str != "" { + // not an image, but might be a valid file + if format != "" { // file type return 3 * baseTokens, nil } - return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.OriginData)) + return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", fileMeta.GetIdentifier())) } width := config.Width @@ -269,48 +250,24 @@ func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *rela shouldFetchFiles = false } + // 使用统一的文件服务获取文件类型 for _, file := range meta.Files { - if strings.HasPrefix(file.OriginData, "http") { - if shouldFetchFiles { - mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter") - if err != nil { - return 0, fmt.Errorf("error getting file base64 from url: %v", err) - } - if strings.HasPrefix(mineType, "image/") { - file.FileType = types.FileTypeImage - } else if strings.HasPrefix(mineType, "video/") { - file.FileType = types.FileTypeVideo - } else if strings.HasPrefix(mineType, "audio/") { - file.FileType = types.FileTypeAudio - } else { - file.FileType = types.FileTypeFile - } - file.MimeType = mineType - } - } else if strings.HasPrefix(file.OriginData, "data:") { - // get mime type from base64 header - parts := strings.SplitN(file.OriginData, ",", 2) - if len(parts) >= 1 { - header := parts[0] - // Extract mime type from "data:mime/type;base64" format - if strings.Contains(header, ":") && strings.Contains(header, ";") { - mimeStart := strings.Index(header, ":") + 1 - mimeEnd := strings.Index(header, ";") - if mimeStart < mimeEnd { - mineType := header[mimeStart:mimeEnd] - if strings.HasPrefix(mineType, "image/") { - file.FileType = types.FileTypeImage - } else if strings.HasPrefix(mineType, "video/") { - file.FileType = types.FileTypeVideo - } else if strings.HasPrefix(mineType, "audio/") { - file.FileType = types.FileTypeAudio - } else { - file.FileType = types.FileTypeFile - } - file.MimeType = mineType - } + if file.Source == nil { + continue + } + + // 如果文件类型未知且需要获取,通过 MIME 类型检测 + if file.FileType == "" || (file.Source.IsURL() && shouldFetchFiles) { + mimeType, err := GetMimeType(c, file.Source) + if err != nil { + if shouldFetchFiles { + return 0, fmt.Errorf("error getting file type: %v", err) } + // 如果不需要获取,使用默认类型 + continue } + file.MimeType = mimeType + file.FileType = DetectFileType(mimeType) } } @@ -318,9 +275,9 @@ func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *rela switch file.FileType { case types.FileTypeImage: if common.IsOpenAITextModel(model) { - token, err := getImageToken(file, model, info.IsStream) + token, err := getImageToken(c, file, model, info.IsStream) if err != nil { - return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err) + return 0, fmt.Errorf("error counting image token, media index[%d], identifier[%s], err: %v", i, file.GetIdentifier(), err) } tkm += token } else { diff --git a/types/file_source.go b/types/file_source.go new file mode 100644 index 000000000..d2d217fd4 --- /dev/null +++ b/types/file_source.go @@ -0,0 +1,213 @@ +package types + +import ( + "fmt" + "image" + "os" + "sync" +) + +// FileSourceType 文件来源类型 +type FileSourceType string + +const ( + FileSourceTypeURL FileSourceType = "url" // URL 来源 + FileSourceTypeBase64 FileSourceType = "base64" // Base64 内联数据 +) + +// FileSource 统一的文件来源抽象 +// 支持 URL 和 base64 两种来源,提供懒加载和缓存机制 +type FileSource struct { + Type FileSourceType `json:"type"` // 来源类型 + URL string `json:"url,omitempty"` // URL(当 Type 为 url 时) + Base64Data string `json:"base64_data,omitempty"` // Base64 数据(当 Type 为 base64 时) + MimeType string `json:"mime_type,omitempty"` // MIME 类型(可选,会自动检测) + + // 内部缓存(不导出,不序列化) + cachedData *CachedFileData + cacheMu sync.RWMutex + cacheLoaded bool +} + +// CachedFileData 缓存的文件数据 +// 支持内存缓存和磁盘缓存两种模式 +type CachedFileData struct { + base64Data string // 内存中的 base64 数据(小文件) + MimeType string // MIME 类型 + Size int64 // 文件大小(字节) + ImageConfig *image.Config // 图片配置(如果是图片) + ImageFormat string // 图片格式(如果是图片) + + // 磁盘缓存相关 + diskPath string // 磁盘缓存文件路径(大文件) + isDisk bool // 是否使用磁盘缓存 + diskMu sync.Mutex // 磁盘操作锁 + diskClosed bool // 是否已关闭/清理 +} + +// NewMemoryCachedData 创建内存缓存的数据 +func NewMemoryCachedData(base64Data string, mimeType string, size int64) *CachedFileData { + return &CachedFileData{ + base64Data: base64Data, + MimeType: mimeType, + Size: size, + isDisk: false, + } +} + +// NewDiskCachedData 创建磁盘缓存的数据 +func NewDiskCachedData(diskPath string, mimeType string, size int64) *CachedFileData { + return &CachedFileData{ + diskPath: diskPath, + MimeType: mimeType, + Size: size, + isDisk: true, + } +} + +// GetBase64Data 获取 base64 数据(自动处理内存/磁盘) +func (c *CachedFileData) GetBase64Data() (string, error) { + if !c.isDisk { + return c.base64Data, nil + } + + c.diskMu.Lock() + defer c.diskMu.Unlock() + + if c.diskClosed { + return "", fmt.Errorf("disk cache already closed") + } + + // 从磁盘读取 + data, err := os.ReadFile(c.diskPath) + if err != nil { + return "", fmt.Errorf("failed to read from disk cache: %w", err) + } + return string(data), nil +} + +// SetBase64Data 设置 base64 数据(仅用于内存模式) +func (c *CachedFileData) SetBase64Data(data string) { + if !c.isDisk { + c.base64Data = data + } +} + +// IsDisk 是否使用磁盘缓存 +func (c *CachedFileData) IsDisk() bool { + return c.isDisk +} + +// Close 关闭并清理资源 +func (c *CachedFileData) Close() error { + if !c.isDisk { + c.base64Data = "" // 释放内存 + return nil + } + + c.diskMu.Lock() + defer c.diskMu.Unlock() + + if c.diskClosed { + return nil + } + + c.diskClosed = true + if c.diskPath != "" { + return os.Remove(c.diskPath) + } + return nil +} + +// NewURLFileSource 创建 URL 来源的 FileSource +func NewURLFileSource(url string) *FileSource { + return &FileSource{ + Type: FileSourceTypeURL, + URL: url, + } +} + +// NewBase64FileSource 创建 base64 来源的 FileSource +func NewBase64FileSource(base64Data string, mimeType string) *FileSource { + return &FileSource{ + Type: FileSourceTypeBase64, + Base64Data: base64Data, + MimeType: mimeType, + } +} + +// IsURL 判断是否是 URL 来源 +func (f *FileSource) IsURL() bool { + return f.Type == FileSourceTypeURL +} + +// IsBase64 判断是否是 base64 来源 +func (f *FileSource) IsBase64() bool { + return f.Type == FileSourceTypeBase64 +} + +// GetIdentifier 获取文件标识符(用于日志和错误追踪) +func (f *FileSource) GetIdentifier() string { + if f.IsURL() { + if len(f.URL) > 100 { + return f.URL[:100] + "..." + } + return f.URL + } + if len(f.Base64Data) > 50 { + return "base64:" + f.Base64Data[:50] + "..." + } + return "base64:" + f.Base64Data +} + +// GetRawData 获取原始数据(URL 或完整的 base64 字符串) +func (f *FileSource) GetRawData() string { + if f.IsURL() { + return f.URL + } + return f.Base64Data +} + +// SetCache 设置缓存数据 +func (f *FileSource) SetCache(data *CachedFileData) { + f.cacheMu.Lock() + defer f.cacheMu.Unlock() + f.cachedData = data + f.cacheLoaded = true +} + +// GetCache 获取缓存数据 +func (f *FileSource) GetCache() *CachedFileData { + f.cacheMu.RLock() + defer f.cacheMu.RUnlock() + return f.cachedData +} + +// HasCache 是否有缓存 +func (f *FileSource) HasCache() bool { + f.cacheMu.RLock() + defer f.cacheMu.RUnlock() + return f.cacheLoaded && f.cachedData != nil +} + +// ClearCache 清除缓存,释放内存和磁盘文件 +func (f *FileSource) ClearCache() { + f.cacheMu.Lock() + defer f.cacheMu.Unlock() + + // 如果有缓存数据,先关闭它(会清理磁盘文件) + if f.cachedData != nil { + f.cachedData.Close() + } + f.cachedData = nil + f.cacheLoaded = false +} + +// ClearRawData 清除原始数据,只保留必要的元信息 +// 用于在处理完成后释放大文件的内存 +func (f *FileSource) ClearRawData() { + // 保留 URL(通常很短),只清除大的 base64 数据 + if f.IsBase64() && len(f.Base64Data) > 1024 { + f.Base64Data = "" + } +} diff --git a/types/request_meta.go b/types/request_meta.go index 18f80832b..2d909d0b8 100644 --- a/types/request_meta.go +++ b/types/request_meta.go @@ -32,10 +32,48 @@ type TokenCountMeta struct { type FileMeta struct { FileType - MimeType string - OriginData string // url or base64 data - Detail string - ParsedData *LocalFileData + MimeType string + Source *FileSource // 统一的文件来源(URL 或 base64) + Detail string // 图片细节级别(low/high/auto) +} + +// NewFileMeta 创建新的 FileMeta +func NewFileMeta(fileType FileType, source *FileSource) *FileMeta { + return &FileMeta{ + FileType: fileType, + Source: source, + } +} + +// NewImageFileMeta 创建图片类型的 FileMeta +func NewImageFileMeta(source *FileSource, detail string) *FileMeta { + return &FileMeta{ + FileType: FileTypeImage, + Source: source, + Detail: detail, + } +} + +// GetIdentifier 获取文件标识符(用于日志) +func (f *FileMeta) GetIdentifier() string { + if f.Source != nil { + return f.Source.GetIdentifier() + } + return "unknown" +} + +// IsURL 判断是否是 URL 来源 +func (f *FileMeta) IsURL() bool { + return f.Source != nil && f.Source.IsURL() +} + +// GetRawData 获取原始数据(兼容旧代码) +// Deprecated: 请使用 Source.GetRawData() +func (f *FileMeta) GetRawData() string { + if f.Source != nil { + return f.Source.GetRawData() + } + return "" } type RequestMeta struct {