mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:44:40 +00:00
feat(file): unify file handling with a new FileSource abstraction for URL and base64 data
This commit is contained in:
@@ -5,12 +5,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// BodyStorage 请求体存储接口
|
// BodyStorage 请求体存储接口
|
||||||
@@ -101,25 +98,10 @@ type diskStorage struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newDiskStorage(data []byte, cachePath string) (*diskStorage, error) {
|
func newDiskStorage(data []byte, cachePath string) (*diskStorage, error) {
|
||||||
// 确定缓存目录
|
// 使用统一的缓存目录管理
|
||||||
dir := cachePath
|
filePath, file, err := CreateDiskCacheFile(DiskCacheTypeBody)
|
||||||
if dir == "" {
|
|
||||||
dir = os.TempDir()
|
|
||||||
}
|
|
||||||
dir = filepath.Join(dir, "new-api-body-cache")
|
|
||||||
|
|
||||||
// 确保目录存在
|
|
||||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create cache directory: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建临时文件
|
|
||||||
filename := fmt.Sprintf("body-%s-%d.tmp", uuid.New().String()[:8], time.Now().UnixNano())
|
|
||||||
filePath := filepath.Join(dir, filename)
|
|
||||||
|
|
||||||
file, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR|os.O_EXCL, 0600)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create temp file: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 写入数据
|
// 写入数据
|
||||||
@@ -148,25 +130,10 @@ func newDiskStorage(data []byte, cachePath string) (*diskStorage, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newDiskStorageFromReader(reader io.Reader, maxBytes int64, cachePath string) (*diskStorage, error) {
|
func newDiskStorageFromReader(reader io.Reader, maxBytes int64, cachePath string) (*diskStorage, error) {
|
||||||
// 确定缓存目录
|
// 使用统一的缓存目录管理
|
||||||
dir := cachePath
|
filePath, file, err := CreateDiskCacheFile(DiskCacheTypeBody)
|
||||||
if dir == "" {
|
|
||||||
dir = os.TempDir()
|
|
||||||
}
|
|
||||||
dir = filepath.Join(dir, "new-api-body-cache")
|
|
||||||
|
|
||||||
// 确保目录存在
|
|
||||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create cache directory: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建临时文件
|
|
||||||
filename := fmt.Sprintf("body-%s-%d.tmp", uuid.New().String()[:8], time.Now().UnixNano())
|
|
||||||
filePath := filepath.Join(dir, filename)
|
|
||||||
|
|
||||||
file, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR|os.O_EXCL, 0600)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create temp file: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 从 reader 读取并写入文件
|
// 从 reader 读取并写入文件
|
||||||
@@ -337,29 +304,6 @@ func CreateBodyStorageFromReader(reader io.Reader, contentLength int64, maxBytes
|
|||||||
|
|
||||||
// CleanupOldCacheFiles 清理旧的缓存文件(用于启动时清理残留)
|
// CleanupOldCacheFiles 清理旧的缓存文件(用于启动时清理残留)
|
||||||
func CleanupOldCacheFiles() {
|
func CleanupOldCacheFiles() {
|
||||||
cachePath := GetDiskCachePath()
|
// 使用统一的缓存管理
|
||||||
if cachePath == "" {
|
CleanupOldDiskCacheFiles(5 * time.Minute)
|
||||||
cachePath = os.TempDir()
|
|
||||||
}
|
|
||||||
dir := filepath.Join(cachePath, "new-api-body-cache")
|
|
||||||
|
|
||||||
entries, err := os.ReadDir(dir)
|
|
||||||
if err != nil {
|
|
||||||
return // 目录不存在或无法读取
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
for _, entry := range entries {
|
|
||||||
if entry.IsDir() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
info, err := entry.Info()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// 删除超过 5 分钟的旧文件
|
|
||||||
if now.Sub(info.ModTime()) > 5*time.Minute {
|
|
||||||
os.Remove(filepath.Join(dir, entry.Name()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
172
common/disk_cache.go
Normal file
172
common/disk_cache.go
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DiskCacheType 磁盘缓存类型
|
||||||
|
type DiskCacheType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
DiskCacheTypeBody DiskCacheType = "body" // 请求体缓存
|
||||||
|
DiskCacheTypeFile DiskCacheType = "file" // 文件数据缓存
|
||||||
|
)
|
||||||
|
|
||||||
|
// 统一的缓存目录名
|
||||||
|
const diskCacheDir = "new-api-body-cache"
|
||||||
|
|
||||||
|
// GetDiskCacheDir 获取统一的磁盘缓存目录
|
||||||
|
// 注意:每次调用都会重新计算,以响应配置变化
|
||||||
|
func GetDiskCacheDir() string {
|
||||||
|
cachePath := GetDiskCachePath()
|
||||||
|
if cachePath == "" {
|
||||||
|
cachePath = os.TempDir()
|
||||||
|
}
|
||||||
|
return filepath.Join(cachePath, diskCacheDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureDiskCacheDir 确保缓存目录存在
|
||||||
|
func EnsureDiskCacheDir() error {
|
||||||
|
dir := GetDiskCacheDir()
|
||||||
|
return os.MkdirAll(dir, 0755)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateDiskCacheFile 创建磁盘缓存文件
|
||||||
|
// cacheType: 缓存类型(body/file)
|
||||||
|
// 返回文件路径和文件句柄
|
||||||
|
func CreateDiskCacheFile(cacheType DiskCacheType) (string, *os.File, error) {
|
||||||
|
if err := EnsureDiskCacheDir(); err != nil {
|
||||||
|
return "", nil, fmt.Errorf("failed to create cache directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dir := GetDiskCacheDir()
|
||||||
|
filename := fmt.Sprintf("%s-%s-%d.tmp", cacheType, uuid.New().String()[:8], time.Now().UnixNano())
|
||||||
|
filePath := filepath.Join(dir, filename)
|
||||||
|
|
||||||
|
file, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR|os.O_EXCL, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("failed to create cache file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return filePath, file, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteDiskCacheFile 写入数据到磁盘缓存文件
|
||||||
|
// 返回文件路径
|
||||||
|
func WriteDiskCacheFile(cacheType DiskCacheType, data []byte) (string, error) {
|
||||||
|
filePath, file, err := CreateDiskCacheFile(cacheType)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = file.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
file.Close()
|
||||||
|
os.Remove(filePath)
|
||||||
|
return "", fmt.Errorf("failed to write cache file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := file.Close(); err != nil {
|
||||||
|
os.Remove(filePath)
|
||||||
|
return "", fmt.Errorf("failed to close cache file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return filePath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteDiskCacheFileString 写入字符串到磁盘缓存文件
|
||||||
|
func WriteDiskCacheFileString(cacheType DiskCacheType, data string) (string, error) {
|
||||||
|
return WriteDiskCacheFile(cacheType, []byte(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadDiskCacheFile 读取磁盘缓存文件
|
||||||
|
func ReadDiskCacheFile(filePath string) ([]byte, error) {
|
||||||
|
return os.ReadFile(filePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadDiskCacheFileString 读取磁盘缓存文件为字符串
|
||||||
|
func ReadDiskCacheFileString(filePath string) (string, error) {
|
||||||
|
data, err := os.ReadFile(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveDiskCacheFile 删除磁盘缓存文件
|
||||||
|
func RemoveDiskCacheFile(filePath string) error {
|
||||||
|
return os.Remove(filePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupOldDiskCacheFiles 清理旧的缓存文件
|
||||||
|
// maxAge: 文件最大存活时间
|
||||||
|
// 注意:此函数只删除文件,不更新统计(因为无法知道每个文件的原始大小)
|
||||||
|
func CleanupOldDiskCacheFiles(maxAge time.Duration) error {
|
||||||
|
dir := GetDiskCacheDir()
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil // 目录不存在,无需清理
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
info, err := entry.Info()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if now.Sub(info.ModTime()) > maxAge {
|
||||||
|
os.Remove(filepath.Join(dir, entry.Name()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDiskCacheInfo 获取磁盘缓存目录信息
|
||||||
|
func GetDiskCacheInfo() (fileCount int, totalSize int64, err error) {
|
||||||
|
dir := GetDiskCacheDir()
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return 0, 0, nil
|
||||||
|
}
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
info, err := entry.Info()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fileCount++
|
||||||
|
totalSize += info.Size()
|
||||||
|
}
|
||||||
|
return fileCount, totalSize, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShouldUseDiskCache 判断是否应该使用磁盘缓存
|
||||||
|
func ShouldUseDiskCache(dataSize int64) bool {
|
||||||
|
if !IsDiskCacheEnabled() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
threshold := GetDiskCacheThresholdBytes()
|
||||||
|
if dataSize < threshold {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return IsDiskCacheAvailable(dataSize)
|
||||||
|
}
|
||||||
@@ -139,12 +139,29 @@ func IncrementMemoryCacheHits() {
|
|||||||
atomic.AddInt64(&diskCacheStats.MemoryCacheHits, 1)
|
atomic.AddInt64(&diskCacheStats.MemoryCacheHits, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResetDiskCacheStats 重置统计信息(不重置当前使用量)
|
// ResetDiskCacheStats 重置命中统计信息(不重置当前使用量)
|
||||||
func ResetDiskCacheStats() {
|
func ResetDiskCacheStats() {
|
||||||
atomic.StoreInt64(&diskCacheStats.DiskCacheHits, 0)
|
atomic.StoreInt64(&diskCacheStats.DiskCacheHits, 0)
|
||||||
atomic.StoreInt64(&diskCacheStats.MemoryCacheHits, 0)
|
atomic.StoreInt64(&diskCacheStats.MemoryCacheHits, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResetDiskCacheUsage 重置磁盘缓存使用量统计(用于清理缓存后)
|
||||||
|
func ResetDiskCacheUsage() {
|
||||||
|
atomic.StoreInt64(&diskCacheStats.ActiveDiskFiles, 0)
|
||||||
|
atomic.StoreInt64(&diskCacheStats.CurrentDiskUsageBytes, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncDiskCacheStats 从实际磁盘状态同步统计信息
|
||||||
|
// 用于修正统计与实际不符的情况
|
||||||
|
func SyncDiskCacheStats() {
|
||||||
|
fileCount, totalSize, err := GetDiskCacheInfo()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
atomic.StoreInt64(&diskCacheStats.ActiveDiskFiles, int64(fileCount))
|
||||||
|
atomic.StoreInt64(&diskCacheStats.CurrentDiskUsageBytes, totalSize)
|
||||||
|
}
|
||||||
|
|
||||||
// IsDiskCacheAvailable 检查是否可以创建新的磁盘缓存
|
// IsDiskCacheAvailable 检查是否可以创建新的磁盘缓存
|
||||||
func IsDiskCacheAvailable(requestSize int64) bool {
|
func IsDiskCacheAvailable(requestSize int64) bool {
|
||||||
if !IsDiskCacheEnabled() {
|
if !IsDiskCacheEnabled() {
|
||||||
|
|||||||
@@ -56,6 +56,9 @@ const (
|
|||||||
|
|
||||||
ContextKeySystemPromptOverride ContextKey = "system_prompt_override"
|
ContextKeySystemPromptOverride ContextKey = "system_prompt_override"
|
||||||
|
|
||||||
|
// ContextKeyFileSourcesToCleanup stores file sources that need cleanup when request ends
|
||||||
|
ContextKeyFileSourcesToCleanup ContextKey = "file_sources_to_cleanup"
|
||||||
|
|
||||||
// ContextKeyAdminRejectReason stores an admin-only reject/block reason extracted from upstream responses.
|
// ContextKeyAdminRejectReason stores an admin-only reject/block reason extracted from upstream responses.
|
||||||
// It is not returned to end users, but can be persisted into consume/error logs for debugging.
|
// It is not returned to end users, but can be persisted into consume/error logs for debugging.
|
||||||
ContextKeyAdminRejectReason ContextKey = "admin_reject_reason"
|
ContextKeyAdminRejectReason ContextKey = "admin_reject_reason"
|
||||||
|
|||||||
@@ -89,7 +89,8 @@ func GetAllChannels(c *gin.Context) {
|
|||||||
if enableTagMode {
|
if enableTagMode {
|
||||||
tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
common.SysError("failed to get paginated tags: " + err.Error())
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签失败,请稍后重试"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, tag := range tags {
|
for _, tag := range tags {
|
||||||
@@ -136,7 +137,8 @@ func GetAllChannels(c *gin.Context) {
|
|||||||
|
|
||||||
err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
|
err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
common.SysError("failed to get channels: " + err.Error())
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道列表失败,请稍后重试"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -641,7 +643,8 @@ func RefreshCodexChannelCredential(c *gin.Context) {
|
|||||||
|
|
||||||
oauthKey, ch, err := service.RefreshCodexChannelCredential(ctx, channelId, service.CodexCredentialRefreshOptions{ResetCaches: true})
|
oauthKey, ch, err := service.RefreshCodexChannelCredential(ctx, channelId, service.CodexCredentialRefreshOptions{ResetCaches: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
common.SysError("failed to refresh codex channel credential: " + err.Error())
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "刷新凭证失败,请稍后重试"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1315,7 +1318,8 @@ func CopyChannel(c *gin.Context) {
|
|||||||
// fetch original channel with key
|
// fetch original channel with key
|
||||||
origin, err := model.GetChannelById(id, true)
|
origin, err := model.GetChannelById(id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
common.SysError("failed to get channel by id: " + err.Error())
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道信息失败,请稍后重试"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1333,7 +1337,8 @@ func CopyChannel(c *gin.Context) {
|
|||||||
|
|
||||||
// insert
|
// insert
|
||||||
if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil {
|
if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
common.SysError("failed to clone channel: " + err.Error())
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "复制渠道失败,请稍后重试"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
model.InitChannelCache()
|
model.InitChannelCache()
|
||||||
|
|||||||
@@ -132,7 +132,8 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
|
|||||||
|
|
||||||
code, state, err := parseCodexAuthorizationInput(req.Input)
|
code, state, err := parseCodexAuthorizationInput(req.Input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
common.SysError("failed to parse codex authorization input: " + err.Error())
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "解析授权信息失败,请检查输入格式"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(code) == "" {
|
if strings.TrimSpace(code) == "" {
|
||||||
@@ -177,7 +178,8 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
|
|||||||
|
|
||||||
tokenRes, err := service.ExchangeCodexAuthorizationCode(ctx, code, verifier)
|
tokenRes, err := service.ExchangeCodexAuthorizationCode(ctx, code, verifier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
common.SysError("failed to exchange codex authorization code: " + err.Error())
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "授权码交换失败,请重试"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -45,7 +45,8 @@ func GetCodexChannelUsage(c *gin.Context) {
|
|||||||
|
|
||||||
oauthKey, err := codex.ParseOAuthKey(strings.TrimSpace(ch.Key))
|
oauthKey, err := codex.ParseOAuthKey(strings.TrimSpace(ch.Key))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
common.SysError("failed to parse oauth key: " + err.Error())
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "解析凭证失败,请检查渠道配置"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
accessToken := strings.TrimSpace(oauthKey.AccessToken)
|
accessToken := strings.TrimSpace(oauthKey.AccessToken)
|
||||||
@@ -70,7 +71,8 @@ func GetCodexChannelUsage(c *gin.Context) {
|
|||||||
|
|
||||||
statusCode, body, err := service.FetchCodexWhamUsage(ctx, client, ch.GetBaseURL(), accessToken, accountID)
|
statusCode, body, err := service.FetchCodexWhamUsage(ctx, client, ch.GetBaseURL(), accessToken, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
common.SysError("failed to fetch codex usage: " + err.Error())
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取用量信息失败,请稍后重试"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,7 +101,8 @@ func GetCodexChannelUsage(c *gin.Context) {
|
|||||||
defer cancel2()
|
defer cancel2()
|
||||||
statusCode, body, err = service.FetchCodexWhamUsage(ctx2, client, ch.GetBaseURL(), oauthKey.AccessToken, accountID)
|
statusCode, body, err = service.FetchCodexWhamUsage(ctx2, client, ch.GetBaseURL(), oauthKey.AccessToken, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
common.SysError("failed to fetch codex usage after refresh: " + err.Error())
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取用量信息失败,请稍后重试"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,8 @@ func MigrateConsoleSetting(c *gin.Context) {
|
|||||||
// 读取全部 option
|
// 读取全部 option
|
||||||
opts, err := model.AllOption()
|
opts, err := model.AllOption()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
|
common.SysError("failed to get all options: " + err.Error())
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "获取配置失败,请稍后重试"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 建立 map
|
// 建立 map
|
||||||
|
|||||||
@@ -272,7 +272,8 @@ func SyncUpstreamModels(c *gin.Context) {
|
|||||||
// 1) 获取未配置模型列表
|
// 1) 获取未配置模型列表
|
||||||
missing, err := model.GetMissingModels()
|
missing, err := model.GetMissingModels()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
common.SysError("failed to get missing models: " + err.Error())
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取模型列表失败,请稍后重试"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
"github.com/QuantumNous/new-api/common"
|
"github.com/QuantumNous/new-api/common"
|
||||||
@@ -78,6 +77,9 @@ type PerformanceConfig struct {
|
|||||||
|
|
||||||
// GetPerformanceStats 获取性能统计信息
|
// GetPerformanceStats 获取性能统计信息
|
||||||
func GetPerformanceStats(c *gin.Context) {
|
func GetPerformanceStats(c *gin.Context) {
|
||||||
|
// 先同步磁盘缓存统计,确保显示准确
|
||||||
|
common.SyncDiskCacheStats()
|
||||||
|
|
||||||
// 获取缓存统计
|
// 获取缓存统计
|
||||||
cacheStats := common.GetDiskCacheStats()
|
cacheStats := common.GetDiskCacheStats()
|
||||||
|
|
||||||
@@ -123,11 +125,8 @@ func GetPerformanceStats(c *gin.Context) {
|
|||||||
|
|
||||||
// ClearDiskCache 清理磁盘缓存
|
// ClearDiskCache 清理磁盘缓存
|
||||||
func ClearDiskCache(c *gin.Context) {
|
func ClearDiskCache(c *gin.Context) {
|
||||||
cachePath := common.GetDiskCachePath()
|
// 使用统一的缓存目录
|
||||||
if cachePath == "" {
|
dir := common.GetDiskCacheDir()
|
||||||
cachePath = os.TempDir()
|
|
||||||
}
|
|
||||||
dir := filepath.Join(cachePath, "new-api-body-cache")
|
|
||||||
|
|
||||||
// 删除缓存目录
|
// 删除缓存目录
|
||||||
err := os.RemoveAll(dir)
|
err := os.RemoveAll(dir)
|
||||||
@@ -136,8 +135,9 @@ func ClearDiskCache(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 重置统计
|
// 重置统计(包括命中次数和使用量)
|
||||||
common.ResetDiskCacheStats()
|
common.ResetDiskCacheStats()
|
||||||
|
common.ResetDiskCacheUsage()
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
@@ -167,11 +167,8 @@ func ForceGC(c *gin.Context) {
|
|||||||
|
|
||||||
// getDiskCacheInfo 获取磁盘缓存目录信息
|
// getDiskCacheInfo 获取磁盘缓存目录信息
|
||||||
func getDiskCacheInfo() DiskCacheInfo {
|
func getDiskCacheInfo() DiskCacheInfo {
|
||||||
cachePath := common.GetDiskCachePath()
|
// 使用统一的缓存目录
|
||||||
if cachePath == "" {
|
dir := common.GetDiskCacheDir()
|
||||||
cachePath = os.TempDir()
|
|
||||||
}
|
|
||||||
dir := filepath.Join(cachePath, "new-api-body-cache")
|
|
||||||
|
|
||||||
info := DiskCacheInfo{
|
info := DiskCacheInfo{
|
||||||
Path: dir,
|
Path: dir,
|
||||||
|
|||||||
@@ -56,7 +56,8 @@ type upstreamResult struct {
|
|||||||
func FetchUpstreamRatios(c *gin.Context) {
|
func FetchUpstreamRatios(c *gin.Context) {
|
||||||
var req dto.UpstreamRequest
|
var req dto.UpstreamRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
common.SysError("failed to bind upstream request: " + err.Error())
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "请求参数格式错误"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -103,9 +103,10 @@ func AddRedemption(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
err = cleanRedemption.Insert()
|
err = cleanRedemption.Insert()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
common.SysError("failed to insert redemption: " + err.Error())
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": "创建兑换码失败,请稍后重试",
|
||||||
"data": keys,
|
"data": keys,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -107,9 +107,10 @@ func GetTokenUsage(c *gin.Context) {
|
|||||||
|
|
||||||
token, err := model.GetTokenByKey(strings.TrimPrefix(tokenKey, "sk-"), false)
|
token, err := model.GetTokenByKey(strings.TrimPrefix(tokenKey, "sk-"), false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
common.SysError("failed to get token by key: " + err.Error())
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": "获取令牌信息失败,请稍后重试",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -214,6 +214,14 @@ type ClaudeRequest struct {
|
|||||||
ServiceTier string `json:"service_tier,omitempty"`
|
ServiceTier string `json:"service_tier,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createClaudeFileSource 根据数据内容创建正确类型的 FileSource
|
||||||
|
func createClaudeFileSource(data string) *types.FileSource {
|
||||||
|
if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") {
|
||||||
|
return types.NewURLFileSource(data)
|
||||||
|
}
|
||||||
|
return types.NewBase64FileSource(data, "")
|
||||||
|
}
|
||||||
|
|
||||||
func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
var tokenCountMeta = types.TokenCountMeta{
|
var tokenCountMeta = types.TokenCountMeta{
|
||||||
TokenType: types.TokenTypeTokenizer,
|
TokenType: types.TokenTypeTokenizer,
|
||||||
@@ -243,7 +251,10 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
|||||||
data = common.Interface2String(media.Source.Data)
|
data = common.Interface2String(media.Source.Data)
|
||||||
}
|
}
|
||||||
if data != "" {
|
if data != "" {
|
||||||
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
|
fileMeta = append(fileMeta, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeImage,
|
||||||
|
Source: createClaudeFileSource(data),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -275,7 +286,10 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
|||||||
data = common.Interface2String(media.Source.Data)
|
data = common.Interface2String(media.Source.Data)
|
||||||
}
|
}
|
||||||
if data != "" {
|
if data != "" {
|
||||||
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
|
fileMeta = append(fileMeta, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeImage,
|
||||||
|
Source: createClaudeFileSource(data),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "tool_use":
|
case "tool_use":
|
||||||
|
|||||||
@@ -64,6 +64,14 @@ type LatLng struct {
|
|||||||
Longitude *float64 `json:"longitude,omitempty"`
|
Longitude *float64 `json:"longitude,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createGeminiFileSource 根据数据内容创建正确类型的 FileSource
|
||||||
|
func createGeminiFileSource(data string, mimeType string) *types.FileSource {
|
||||||
|
if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") {
|
||||||
|
return types.NewURLFileSource(data)
|
||||||
|
}
|
||||||
|
return types.NewBase64FileSource(data, mimeType)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
var files []*types.FileMeta = make([]*types.FileMeta, 0)
|
var files []*types.FileMeta = make([]*types.FileMeta, 0)
|
||||||
|
|
||||||
@@ -80,27 +88,23 @@ func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
|||||||
inputTexts = append(inputTexts, part.Text)
|
inputTexts = append(inputTexts, part.Text)
|
||||||
}
|
}
|
||||||
if part.InlineData != nil && part.InlineData.Data != "" {
|
if part.InlineData != nil && part.InlineData.Data != "" {
|
||||||
if strings.HasPrefix(part.InlineData.MimeType, "image/") {
|
mimeType := part.InlineData.MimeType
|
||||||
files = append(files, &types.FileMeta{
|
source := createGeminiFileSource(part.InlineData.Data, mimeType)
|
||||||
FileType: types.FileTypeImage,
|
var fileType types.FileType
|
||||||
OriginData: part.InlineData.Data,
|
if strings.HasPrefix(mimeType, "image/") {
|
||||||
})
|
fileType = types.FileTypeImage
|
||||||
} else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
|
} else if strings.HasPrefix(mimeType, "audio/") {
|
||||||
files = append(files, &types.FileMeta{
|
fileType = types.FileTypeAudio
|
||||||
FileType: types.FileTypeAudio,
|
} else if strings.HasPrefix(mimeType, "video/") {
|
||||||
OriginData: part.InlineData.Data,
|
fileType = types.FileTypeVideo
|
||||||
})
|
|
||||||
} else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
|
|
||||||
files = append(files, &types.FileMeta{
|
|
||||||
FileType: types.FileTypeVideo,
|
|
||||||
OriginData: part.InlineData.Data,
|
|
||||||
})
|
|
||||||
} else {
|
} else {
|
||||||
files = append(files, &types.FileMeta{
|
fileType = types.FileTypeFile
|
||||||
FileType: types.FileTypeFile,
|
|
||||||
OriginData: part.InlineData.Data,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
files = append(files, &types.FileMeta{
|
||||||
|
FileType: fileType,
|
||||||
|
Source: source,
|
||||||
|
MimeType: mimeType,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -101,6 +101,14 @@ type GeneralOpenAIRequest struct {
|
|||||||
SearchMode string `json:"search_mode,omitempty"`
|
SearchMode string `json:"search_mode,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createFileSource 根据数据内容创建正确类型的 FileSource
|
||||||
|
func createFileSource(data string) *types.FileSource {
|
||||||
|
if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") {
|
||||||
|
return types.NewURLFileSource(data)
|
||||||
|
}
|
||||||
|
return types.NewBase64FileSource(data, "")
|
||||||
|
}
|
||||||
|
|
||||||
func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
var tokenCountMeta types.TokenCountMeta
|
var tokenCountMeta types.TokenCountMeta
|
||||||
var texts = make([]string, 0)
|
var texts = make([]string, 0)
|
||||||
@@ -144,42 +152,40 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
|||||||
for _, m := range arrayContent {
|
for _, m := range arrayContent {
|
||||||
if m.Type == ContentTypeImageURL {
|
if m.Type == ContentTypeImageURL {
|
||||||
imageUrl := m.GetImageMedia()
|
imageUrl := m.GetImageMedia()
|
||||||
if imageUrl != nil {
|
if imageUrl != nil && imageUrl.Url != "" {
|
||||||
if imageUrl.Url != "" {
|
source := createFileSource(imageUrl.Url)
|
||||||
meta := &types.FileMeta{
|
fileMeta = append(fileMeta, &types.FileMeta{
|
||||||
FileType: types.FileTypeImage,
|
FileType: types.FileTypeImage,
|
||||||
}
|
Source: source,
|
||||||
meta.OriginData = imageUrl.Url
|
Detail: imageUrl.Detail,
|
||||||
meta.Detail = imageUrl.Detail
|
})
|
||||||
fileMeta = append(fileMeta, meta)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else if m.Type == ContentTypeInputAudio {
|
} else if m.Type == ContentTypeInputAudio {
|
||||||
inputAudio := m.GetInputAudio()
|
inputAudio := m.GetInputAudio()
|
||||||
if inputAudio != nil {
|
if inputAudio != nil && inputAudio.Data != "" {
|
||||||
meta := &types.FileMeta{
|
source := createFileSource(inputAudio.Data)
|
||||||
|
fileMeta = append(fileMeta, &types.FileMeta{
|
||||||
FileType: types.FileTypeAudio,
|
FileType: types.FileTypeAudio,
|
||||||
}
|
Source: source,
|
||||||
meta.OriginData = inputAudio.Data
|
})
|
||||||
fileMeta = append(fileMeta, meta)
|
|
||||||
}
|
}
|
||||||
} else if m.Type == ContentTypeFile {
|
} else if m.Type == ContentTypeFile {
|
||||||
file := m.GetFile()
|
file := m.GetFile()
|
||||||
if file != nil {
|
if file != nil && file.FileData != "" {
|
||||||
meta := &types.FileMeta{
|
source := createFileSource(file.FileData)
|
||||||
|
fileMeta = append(fileMeta, &types.FileMeta{
|
||||||
FileType: types.FileTypeFile,
|
FileType: types.FileTypeFile,
|
||||||
}
|
Source: source,
|
||||||
meta.OriginData = file.FileData
|
})
|
||||||
fileMeta = append(fileMeta, meta)
|
|
||||||
}
|
}
|
||||||
} else if m.Type == ContentTypeVideoUrl {
|
} else if m.Type == ContentTypeVideoUrl {
|
||||||
videoUrl := m.GetVideoUrl()
|
videoUrl := m.GetVideoUrl()
|
||||||
if videoUrl != nil && videoUrl.Url != "" {
|
if videoUrl != nil && videoUrl.Url != "" {
|
||||||
meta := &types.FileMeta{
|
source := createFileSource(videoUrl.Url)
|
||||||
|
fileMeta = append(fileMeta, &types.FileMeta{
|
||||||
FileType: types.FileTypeVideo,
|
FileType: types.FileTypeVideo,
|
||||||
}
|
Source: source,
|
||||||
meta.OriginData = videoUrl.Url
|
})
|
||||||
fileMeta = append(fileMeta, meta)
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
texts = append(texts, m.Text)
|
texts = append(texts, m.Text)
|
||||||
@@ -833,16 +839,16 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
|||||||
if input.Type == "input_image" {
|
if input.Type == "input_image" {
|
||||||
if input.ImageUrl != "" {
|
if input.ImageUrl != "" {
|
||||||
fileMeta = append(fileMeta, &types.FileMeta{
|
fileMeta = append(fileMeta, &types.FileMeta{
|
||||||
FileType: types.FileTypeImage,
|
FileType: types.FileTypeImage,
|
||||||
OriginData: input.ImageUrl,
|
Source: createFileSource(input.ImageUrl),
|
||||||
Detail: input.Detail,
|
Detail: input.Detail,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
} else if input.Type == "input_file" {
|
} else if input.Type == "input_file" {
|
||||||
if input.FileUrl != "" {
|
if input.FileUrl != "" {
|
||||||
fileMeta = append(fileMeta, &types.FileMeta{
|
fileMeta = append(fileMeta, &types.FileMeta{
|
||||||
FileType: types.FileTypeFile,
|
FileType: types.FileTypeFile,
|
||||||
OriginData: input.FileUrl,
|
Source: createFileSource(input.FileUrl),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/QuantumNous/new-api/common"
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
"github.com/QuantumNous/new-api/service"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -14,5 +15,8 @@ func BodyStorageCleanup() gin.HandlerFunc {
|
|||||||
|
|
||||||
// 请求结束后清理存储
|
// 请求结束后清理存储
|
||||||
common.CleanupBodyStorage(c)
|
common.CleanupBodyStorage(c)
|
||||||
|
|
||||||
|
// 清理文件缓存(URL 下载的文件等)
|
||||||
|
service.CleanupFileSources(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -148,7 +148,8 @@ func Redeem(key string, userId int) (quota int, err error) {
|
|||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.New("兑换失败," + err.Error())
|
common.SysError("redemption failed: " + err.Error())
|
||||||
|
return 0, errors.New("兑换失败,请稍后重试")
|
||||||
}
|
}
|
||||||
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", logger.LogQuota(redemption.Quota), redemption.Id))
|
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", logger.LogQuota(redemption.Quota), redemption.Id))
|
||||||
return redemption.Quota, nil
|
return redemption.Quota, nil
|
||||||
|
|||||||
@@ -95,7 +95,8 @@ func Recharge(referenceId string, customerId string) (err error) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("充值失败," + err.Error())
|
common.SysError("topup failed: " + err.Error())
|
||||||
|
return errors.New("充值失败,请稍后重试")
|
||||||
}
|
}
|
||||||
|
|
||||||
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", logger.FormatQuota(int(quota)), topUp.Amount))
|
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", logger.FormatQuota(int(quota)), topUp.Amount))
|
||||||
@@ -367,7 +368,8 @@ func RechargeCreem(referenceId string, customerEmail string, customerName string
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("充值失败," + err.Error())
|
common.SysError("creem topup failed: " + err.Error())
|
||||||
|
return errors.New("充值失败,请稍后重试")
|
||||||
}
|
}
|
||||||
|
|
||||||
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用Creem充值成功,充值额度: %v,支付金额:%.2f", quota, topUp.Money))
|
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用Creem充值成功,充值额度: %v,支付金额:%.2f", quota, topUp.Money))
|
||||||
|
|||||||
@@ -49,12 +49,14 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
for i2, mediaMessage := range content {
|
for i2, mediaMessage := range content {
|
||||||
if mediaMessage.Source != nil {
|
if mediaMessage.Source != nil {
|
||||||
if mediaMessage.Source.Type == "url" {
|
if mediaMessage.Source.Type == "url" {
|
||||||
fileData, err := service.GetFileBase64FromUrl(c, mediaMessage.Source.Url, "formatting image for Claude")
|
// 使用统一的文件服务获取图片数据
|
||||||
|
source := types.NewURLFileSource(mediaMessage.Source.Url)
|
||||||
|
base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
|
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
|
||||||
}
|
}
|
||||||
mediaMessage.Source.MediaType = fileData.MimeType
|
mediaMessage.Source.MediaType = mimeType
|
||||||
mediaMessage.Source.Data = fileData.Base64Data
|
mediaMessage.Source.Data = base64Data
|
||||||
mediaMessage.Source.Url = ""
|
mediaMessage.Source.Url = ""
|
||||||
mediaMessage.Source.Type = "base64"
|
mediaMessage.Source.Type = "base64"
|
||||||
content[i2] = mediaMessage
|
content[i2] = mediaMessage
|
||||||
|
|||||||
@@ -364,23 +364,19 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
|
|||||||
claudeMediaMessage.Source = &dto.ClaudeMessageSource{
|
claudeMediaMessage.Source = &dto.ClaudeMessageSource{
|
||||||
Type: "base64",
|
Type: "base64",
|
||||||
}
|
}
|
||||||
// 判断是否是url
|
// 使用统一的文件服务获取图片数据
|
||||||
|
var source *types.FileSource
|
||||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||||
// 是url,获取图片的类型和base64编码的数据
|
source = types.NewURLFileSource(imageUrl.Url)
|
||||||
fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Claude")
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
|
|
||||||
}
|
|
||||||
claudeMediaMessage.Source.MediaType = fileData.MimeType
|
|
||||||
claudeMediaMessage.Source.Data = fileData.Base64Data
|
|
||||||
} else {
|
} else {
|
||||||
_, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
|
source = types.NewBase64FileSource(imageUrl.Url, "")
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
claudeMediaMessage.Source.MediaType = "image/" + format
|
|
||||||
claudeMediaMessage.Source.Data = base64String
|
|
||||||
}
|
}
|
||||||
|
base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Claude")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get file data failed: %s", err.Error())
|
||||||
|
}
|
||||||
|
claudeMediaMessage.Source.MediaType = mimeType
|
||||||
|
claudeMediaMessage.Source.Data = base64Data
|
||||||
}
|
}
|
||||||
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
|
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -540,64 +540,58 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
|||||||
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
|
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
|
||||||
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
|
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
|
||||||
}
|
}
|
||||||
// 判断是否是url
|
// 使用统一的文件服务获取图片数据
|
||||||
if strings.HasPrefix(part.GetImageMedia().Url, "http") {
|
var source *types.FileSource
|
||||||
// 是url,获取文件的类型和base64编码的数据
|
imageUrl := part.GetImageMedia().Url
|
||||||
fileData, err := service.GetFileBase64FromUrl(c, part.GetImageMedia().Url, "formatting image for Gemini")
|
if strings.HasPrefix(imageUrl, "http") {
|
||||||
if err != nil {
|
source = types.NewURLFileSource(imageUrl)
|
||||||
return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 校验 MimeType 是否在 Gemini 支持的白名单中
|
|
||||||
if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
|
|
||||||
url := part.GetImageMedia().Url
|
|
||||||
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
|
|
||||||
}
|
|
||||||
|
|
||||||
parts = append(parts, dto.GeminiPart{
|
|
||||||
InlineData: &dto.GeminiInlineData{
|
|
||||||
MimeType: fileData.MimeType, // 使用原始的 MimeType,因为大小写可能对API有意义
|
|
||||||
Data: fileData.Base64Data,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
} else {
|
} else {
|
||||||
format, base64String, err := service.DecodeBase64FileData(part.GetImageMedia().Url)
|
source = types.NewBase64FileSource(imageUrl, "")
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
|
||||||
}
|
|
||||||
parts = append(parts, dto.GeminiPart{
|
|
||||||
InlineData: &dto.GeminiInlineData{
|
|
||||||
MimeType: format,
|
|
||||||
Data: base64String,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Gemini")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get file data from '%s' failed: %w", source.GetIdentifier(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 校验 MimeType 是否在 Gemini 支持的白名单中
|
||||||
|
if _, ok := geminiSupportedMimeTypes[strings.ToLower(mimeType)]; !ok {
|
||||||
|
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", mimeType, source.GetIdentifier(), getSupportedMimeTypesList())
|
||||||
|
}
|
||||||
|
|
||||||
|
parts = append(parts, dto.GeminiPart{
|
||||||
|
InlineData: &dto.GeminiInlineData{
|
||||||
|
MimeType: mimeType,
|
||||||
|
Data: base64Data,
|
||||||
|
},
|
||||||
|
})
|
||||||
} else if part.Type == dto.ContentTypeFile {
|
} else if part.Type == dto.ContentTypeFile {
|
||||||
if part.GetFile().FileId != "" {
|
if part.GetFile().FileId != "" {
|
||||||
return nil, fmt.Errorf("only base64 file is supported in gemini")
|
return nil, fmt.Errorf("only base64 file is supported in gemini")
|
||||||
}
|
}
|
||||||
format, base64String, err := service.DecodeBase64FileData(part.GetFile().FileData)
|
fileSource := types.NewBase64FileSource(part.GetFile().FileData, "")
|
||||||
|
base64Data, mimeType, err := service.GetBase64Data(c, fileSource, "formatting file for Gemini")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
|
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
|
||||||
}
|
}
|
||||||
parts = append(parts, dto.GeminiPart{
|
parts = append(parts, dto.GeminiPart{
|
||||||
InlineData: &dto.GeminiInlineData{
|
InlineData: &dto.GeminiInlineData{
|
||||||
MimeType: format,
|
MimeType: mimeType,
|
||||||
Data: base64String,
|
Data: base64Data,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
} else if part.Type == dto.ContentTypeInputAudio {
|
} else if part.Type == dto.ContentTypeInputAudio {
|
||||||
if part.GetInputAudio().Data == "" {
|
if part.GetInputAudio().Data == "" {
|
||||||
return nil, fmt.Errorf("only base64 audio is supported in gemini")
|
return nil, fmt.Errorf("only base64 audio is supported in gemini")
|
||||||
}
|
}
|
||||||
base64String, err := service.DecodeBase64AudioData(part.GetInputAudio().Data)
|
audioSource := types.NewBase64FileSource(part.GetInputAudio().Data, "audio/"+part.GetInputAudio().Format)
|
||||||
|
base64Data, mimeType, err := service.GetBase64Data(c, audioSource, "formatting audio for Gemini")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
|
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
|
||||||
}
|
}
|
||||||
parts = append(parts, dto.GeminiPart{
|
parts = append(parts, dto.GeminiPart{
|
||||||
InlineData: &dto.GeminiInlineData{
|
InlineData: &dto.GeminiInlineData{
|
||||||
MimeType: "audio/" + part.GetInputAudio().Format,
|
MimeType: mimeType,
|
||||||
Data: base64String,
|
Data: base64Data,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -99,19 +99,16 @@ func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*Ollam
|
|||||||
if part.Type == dto.ContentTypeImageURL {
|
if part.Type == dto.ContentTypeImageURL {
|
||||||
img := part.GetImageMedia()
|
img := part.GetImageMedia()
|
||||||
if img != nil && img.Url != "" {
|
if img != nil && img.Url != "" {
|
||||||
var base64Data string
|
// 使用统一的文件服务获取图片数据
|
||||||
|
var source *types.FileSource
|
||||||
if strings.HasPrefix(img.Url, "http") {
|
if strings.HasPrefix(img.Url, "http") {
|
||||||
fileData, err := service.GetFileBase64FromUrl(c, img.Url, "fetch image for ollama chat")
|
source = types.NewURLFileSource(img.Url)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
base64Data = fileData.Base64Data
|
|
||||||
} else if strings.HasPrefix(img.Url, "data:") {
|
|
||||||
if idx := strings.Index(img.Url, ","); idx != -1 && idx+1 < len(img.Url) {
|
|
||||||
base64Data = img.Url[idx+1:]
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
base64Data = img.Url
|
source = types.NewBase64FileSource(img.Url, "")
|
||||||
|
}
|
||||||
|
base64Data, _, err := service.GetBase64Data(c, source, "fetch image for ollama chat")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
if base64Data != "" {
|
if base64Data != "" {
|
||||||
images = append(images, base64Data)
|
images = append(images, base64Data)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"image"
|
"image"
|
||||||
_ "image/gif"
|
_ "image/gif"
|
||||||
@@ -13,7 +12,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/QuantumNous/new-api/common"
|
"github.com/QuantumNous/new-api/common"
|
||||||
"github.com/QuantumNous/new-api/constant"
|
|
||||||
"github.com/QuantumNous/new-api/logger"
|
"github.com/QuantumNous/new-api/logger"
|
||||||
"github.com/QuantumNous/new-api/types"
|
"github.com/QuantumNous/new-api/types"
|
||||||
|
|
||||||
@@ -130,90 +128,27 @@ func GetFileTypeFromUrl(c *gin.Context, url string, reason ...string) (string, e
|
|||||||
return "application/octet-stream", nil
|
return "application/octet-stream", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetFileBase64FromUrl 从 URL 获取文件的 base64 编码数据
|
||||||
|
// Deprecated: 请使用 GetBase64Data 配合 types.NewURLFileSource 替代
|
||||||
|
// 此函数保留用于向后兼容,内部已重构为调用统一的文件服务
|
||||||
func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) {
|
func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) {
|
||||||
contextKey := fmt.Sprintf("file_download_%s", common.GenerateHMAC(url))
|
source := types.NewURLFileSource(url)
|
||||||
|
cachedData, err := LoadFileSource(c, source, reason...)
|
||||||
// Check if the file has already been downloaded in this request
|
|
||||||
if cachedData, exists := c.Get(contextKey); exists {
|
|
||||||
if common.DebugEnabled {
|
|
||||||
logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url))
|
|
||||||
}
|
|
||||||
return cachedData.(*types.LocalFileData), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
|
|
||||||
|
|
||||||
resp, err := DoDownloadRequest(url, reason...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
// Always use LimitReader to prevent oversized downloads
|
// 转换为旧的 LocalFileData 格式以保持兼容
|
||||||
fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
|
base64Data, err := cachedData.GetBase64Data()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// Check actual size after reading
|
return &types.LocalFileData{
|
||||||
if len(fileBytes) > maxFileSize {
|
|
||||||
return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to base64
|
|
||||||
base64Data := base64.StdEncoding.EncodeToString(fileBytes)
|
|
||||||
|
|
||||||
mimeType := resp.Header.Get("Content-Type")
|
|
||||||
if len(strings.Split(mimeType, ";")) > 1 {
|
|
||||||
// If Content-Type has parameters, take the first part
|
|
||||||
mimeType = strings.Split(mimeType, ";")[0]
|
|
||||||
}
|
|
||||||
if mimeType == "application/octet-stream" {
|
|
||||||
logger.LogDebug(c, fmt.Sprintf("MIME type is application/octet-stream for URL: %s", url))
|
|
||||||
// try to guess the MIME type from the url last segment
|
|
||||||
urlParts := strings.Split(url, "/")
|
|
||||||
if len(urlParts) > 0 {
|
|
||||||
lastSegment := urlParts[len(urlParts)-1]
|
|
||||||
if strings.Contains(lastSegment, ".") {
|
|
||||||
// Extract the file extension
|
|
||||||
filename := strings.Split(lastSegment, ".")
|
|
||||||
if len(filename) > 1 {
|
|
||||||
ext := strings.ToLower(filename[len(filename)-1])
|
|
||||||
// Guess MIME type based on file extension
|
|
||||||
mimeType = GetMimeTypeByExtension(ext)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// try to guess the MIME type from the file extension
|
|
||||||
fileName := resp.Header.Get("Content-Disposition")
|
|
||||||
if fileName != "" {
|
|
||||||
// Extract the filename from the Content-Disposition header
|
|
||||||
parts := strings.Split(fileName, ";")
|
|
||||||
for _, part := range parts {
|
|
||||||
if strings.HasPrefix(strings.TrimSpace(part), "filename=") {
|
|
||||||
fileName = strings.TrimSpace(strings.TrimPrefix(part, "filename="))
|
|
||||||
// Remove quotes if present
|
|
||||||
if len(fileName) > 2 && fileName[0] == '"' && fileName[len(fileName)-1] == '"' {
|
|
||||||
fileName = fileName[1 : len(fileName)-1]
|
|
||||||
}
|
|
||||||
// Guess MIME type based on file extension
|
|
||||||
if ext := strings.ToLower(strings.TrimPrefix(fileName, ".")); ext != "" {
|
|
||||||
mimeType = GetMimeTypeByExtension(ext)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
data := &types.LocalFileData{
|
|
||||||
Base64Data: base64Data,
|
Base64Data: base64Data,
|
||||||
MimeType: mimeType,
|
MimeType: cachedData.MimeType,
|
||||||
Size: int64(len(fileBytes)),
|
Size: cachedData.Size,
|
||||||
}
|
Url: url,
|
||||||
// Store the file data in the context to avoid re-downloading
|
}, nil
|
||||||
c.Set(contextKey, data)
|
|
||||||
|
|
||||||
return data, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetMimeTypeByExtension(ext string) string {
|
func GetMimeTypeByExtension(ext string) string {
|
||||||
|
|||||||
451
service/file_service.go
Normal file
451
service/file_service.go
Normal file
@@ -0,0 +1,451 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
_ "image/gif"
|
||||||
|
_ "image/jpeg"
|
||||||
|
_ "image/png"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
"github.com/QuantumNous/new-api/constant"
|
||||||
|
"github.com/QuantumNous/new-api/logger"
|
||||||
|
"github.com/QuantumNous/new-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"golang.org/x/image/webp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FileService 统一的文件处理服务
|
||||||
|
// 提供文件下载、解码、缓存等功能的统一入口
|
||||||
|
|
||||||
|
// getContextCacheKey 生成 context 缓存的 key
|
||||||
|
func getContextCacheKey(url string) string {
|
||||||
|
return fmt.Sprintf("file_cache_%s", common.GenerateHMAC(url))
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadFileSource 加载文件源数据
|
||||||
|
// 这是统一的入口,会自动处理缓存和不同的来源类型
|
||||||
|
func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string) (*types.CachedFileData, error) {
|
||||||
|
if source == nil {
|
||||||
|
return nil, fmt.Errorf("file source is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果已有缓存,直接返回
|
||||||
|
if source.HasCache() {
|
||||||
|
return source.GetCache(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cachedData *types.CachedFileData
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if source.IsURL() {
|
||||||
|
cachedData, err = loadFromURL(c, source.URL, reason...)
|
||||||
|
} else {
|
||||||
|
cachedData, err = loadFromBase64(source.Base64Data, source.MimeType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置缓存
|
||||||
|
source.SetCache(cachedData)
|
||||||
|
|
||||||
|
// 注册到 context 以便请求结束时自动清理
|
||||||
|
if c != nil {
|
||||||
|
registerSourceForCleanup(c, source)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cachedData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerSourceForCleanup 注册 FileSource 到 context 以便请求结束时清理
|
||||||
|
func registerSourceForCleanup(c *gin.Context, source *types.FileSource) {
|
||||||
|
key := string(constant.ContextKeyFileSourcesToCleanup)
|
||||||
|
var sources []*types.FileSource
|
||||||
|
if existing, exists := c.Get(key); exists {
|
||||||
|
sources = existing.([]*types.FileSource)
|
||||||
|
}
|
||||||
|
sources = append(sources, source)
|
||||||
|
c.Set(key, sources)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupFileSources 清理请求中所有注册的 FileSource
|
||||||
|
// 应在请求结束时调用(通常由中间件自动调用)
|
||||||
|
func CleanupFileSources(c *gin.Context) {
|
||||||
|
key := string(constant.ContextKeyFileSourcesToCleanup)
|
||||||
|
if sources, exists := c.Get(key); exists {
|
||||||
|
for _, source := range sources.([]*types.FileSource) {
|
||||||
|
if cache := source.GetCache(); cache != nil {
|
||||||
|
if cache.IsDisk() {
|
||||||
|
common.DecrementDiskFiles(cache.Size)
|
||||||
|
}
|
||||||
|
cache.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Set(key, nil) // 清除引用
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadFromURL 从 URL 加载文件
|
||||||
|
// 支持磁盘缓存:当文件大小超过阈值且磁盘缓存可用时,将数据存储到磁盘
|
||||||
|
func loadFromURL(c *gin.Context, url string, reason ...string) (*types.CachedFileData, error) {
|
||||||
|
contextKey := getContextCacheKey(url)
|
||||||
|
|
||||||
|
// 检查 context 缓存
|
||||||
|
if cachedData, exists := c.Get(contextKey); exists {
|
||||||
|
if common.DebugEnabled {
|
||||||
|
logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url))
|
||||||
|
}
|
||||||
|
return cachedData.(*types.CachedFileData), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 下载文件
|
||||||
|
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
|
||||||
|
|
||||||
|
resp, err := DoDownloadRequest(url, reason...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to download file from %s: %w", url, err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
return nil, fmt.Errorf("failed to download file, status code: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 读取文件内容(限制大小)
|
||||||
|
fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read file content: %w", err)
|
||||||
|
}
|
||||||
|
if len(fileBytes) > maxFileSize {
|
||||||
|
return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为 base64
|
||||||
|
base64Data := base64.StdEncoding.EncodeToString(fileBytes)
|
||||||
|
|
||||||
|
// 智能获取 MIME 类型
|
||||||
|
mimeType := smartDetectMimeType(resp, url, fileBytes)
|
||||||
|
|
||||||
|
// 判断是否使用磁盘缓存
|
||||||
|
base64Size := int64(len(base64Data))
|
||||||
|
var cachedData *types.CachedFileData
|
||||||
|
|
||||||
|
if shouldUseDiskCache(base64Size) {
|
||||||
|
// 使用磁盘缓存
|
||||||
|
diskPath, err := writeToDiskCache(base64Data)
|
||||||
|
if err != nil {
|
||||||
|
// 磁盘缓存失败,回退到内存
|
||||||
|
logger.LogWarn(c, fmt.Sprintf("Failed to write to disk cache, falling back to memory: %v", err))
|
||||||
|
cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes)))
|
||||||
|
} else {
|
||||||
|
cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(fileBytes)))
|
||||||
|
common.IncrementDiskFiles(base64Size)
|
||||||
|
if common.DebugEnabled {
|
||||||
|
logger.LogDebug(c, fmt.Sprintf("File cached to disk: %s, size: %d bytes", diskPath, base64Size))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 使用内存缓存
|
||||||
|
cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果是图片,尝试获取图片配置
|
||||||
|
if strings.HasPrefix(mimeType, "image/") {
|
||||||
|
config, format, err := decodeImageConfig(fileBytes)
|
||||||
|
if err == nil {
|
||||||
|
cachedData.ImageConfig = &config
|
||||||
|
cachedData.ImageFormat = format
|
||||||
|
// 如果通过图片解码获取了更准确的格式,更新 MIME 类型
|
||||||
|
if mimeType == "application/octet-stream" || mimeType == "" {
|
||||||
|
cachedData.MimeType = "image/" + format
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 存入 context 缓存
|
||||||
|
c.Set(contextKey, cachedData)
|
||||||
|
|
||||||
|
return cachedData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldUseDiskCache 判断是否应该使用磁盘缓存
|
||||||
|
func shouldUseDiskCache(dataSize int64) bool {
|
||||||
|
return common.ShouldUseDiskCache(dataSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeToDiskCache 将数据写入磁盘缓存
|
||||||
|
func writeToDiskCache(base64Data string) (string, error) {
|
||||||
|
return common.WriteDiskCacheFileString(common.DiskCacheTypeFile, base64Data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// smartDetectMimeType 智能检测 MIME 类型
|
||||||
|
// 优先级:Content-Type header > Content-Disposition filename > URL 路径 > 内容嗅探 > 图片解码
|
||||||
|
func smartDetectMimeType(resp *http.Response, url string, fileBytes []byte) string {
|
||||||
|
// 1. 尝试从 Content-Type header 获取
|
||||||
|
mimeType := resp.Header.Get("Content-Type")
|
||||||
|
if idx := strings.Index(mimeType, ";"); idx != -1 {
|
||||||
|
mimeType = strings.TrimSpace(mimeType[:idx])
|
||||||
|
}
|
||||||
|
if mimeType != "" && mimeType != "application/octet-stream" {
|
||||||
|
return mimeType
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 尝试从 Content-Disposition header 的 filename 获取
|
||||||
|
if cd := resp.Header.Get("Content-Disposition"); cd != "" {
|
||||||
|
parts := strings.Split(cd, ";")
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if strings.HasPrefix(strings.ToLower(part), "filename=") {
|
||||||
|
name := strings.TrimSpace(strings.TrimPrefix(part, "filename="))
|
||||||
|
// 移除引号
|
||||||
|
if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' {
|
||||||
|
name = name[1 : len(name)-1]
|
||||||
|
}
|
||||||
|
if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) {
|
||||||
|
ext := strings.ToLower(name[dot+1:])
|
||||||
|
if ext != "" {
|
||||||
|
mt := GetMimeTypeByExtension(ext)
|
||||||
|
if mt != "application/octet-stream" {
|
||||||
|
return mt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 尝试从 URL 路径获取扩展名
|
||||||
|
mt := guessMimeTypeFromURL(url)
|
||||||
|
if mt != "application/octet-stream" {
|
||||||
|
return mt
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 使用 http.DetectContentType 内容嗅探
|
||||||
|
if len(fileBytes) > 0 {
|
||||||
|
sniffed := http.DetectContentType(fileBytes)
|
||||||
|
if sniffed != "" && sniffed != "application/octet-stream" {
|
||||||
|
// 去除可能的 charset 参数
|
||||||
|
if idx := strings.Index(sniffed, ";"); idx != -1 {
|
||||||
|
sniffed = strings.TrimSpace(sniffed[:idx])
|
||||||
|
}
|
||||||
|
return sniffed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. 尝试作为图片解码获取格式
|
||||||
|
if len(fileBytes) > 0 {
|
||||||
|
if _, format, err := decodeImageConfig(fileBytes); err == nil && format != "" {
|
||||||
|
return "image/" + strings.ToLower(format)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 最终回退
|
||||||
|
return "application/octet-stream"
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadFromBase64 从 base64 字符串加载文件
|
||||||
|
func loadFromBase64(base64String string, providedMimeType string) (*types.CachedFileData, error) {
|
||||||
|
var mimeType string
|
||||||
|
var cleanBase64 string
|
||||||
|
|
||||||
|
// 处理 data: 前缀
|
||||||
|
if strings.HasPrefix(base64String, "data:") {
|
||||||
|
// 格式: data:mime/type;base64,xxxxx
|
||||||
|
idx := strings.Index(base64String, ",")
|
||||||
|
if idx != -1 {
|
||||||
|
header := base64String[:idx]
|
||||||
|
cleanBase64 = base64String[idx+1:]
|
||||||
|
|
||||||
|
// 从 header 提取 MIME 类型
|
||||||
|
if strings.Contains(header, ":") && strings.Contains(header, ";") {
|
||||||
|
mimeStart := strings.Index(header, ":") + 1
|
||||||
|
mimeEnd := strings.Index(header, ";")
|
||||||
|
if mimeStart < mimeEnd {
|
||||||
|
mimeType = header[mimeStart:mimeEnd]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cleanBase64 = base64String
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cleanBase64 = base64String
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用提供的 MIME 类型(如果有)
|
||||||
|
if providedMimeType != "" {
|
||||||
|
mimeType = providedMimeType
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解码 base64
|
||||||
|
decodedData, err := base64.StdEncoding.DecodeString(cleanBase64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode base64 data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 判断是否使用磁盘缓存(对于 base64 内联数据也支持磁盘缓存)
|
||||||
|
base64Size := int64(len(cleanBase64))
|
||||||
|
var cachedData *types.CachedFileData
|
||||||
|
|
||||||
|
if shouldUseDiskCache(base64Size) {
|
||||||
|
// 使用磁盘缓存
|
||||||
|
diskPath, err := writeToDiskCache(cleanBase64)
|
||||||
|
if err != nil {
|
||||||
|
// 磁盘缓存失败,回退到内存
|
||||||
|
cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, int64(len(decodedData)))
|
||||||
|
} else {
|
||||||
|
cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(decodedData)))
|
||||||
|
common.IncrementDiskFiles(base64Size)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, int64(len(decodedData)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果是图片或 MIME 类型未知,尝试解码图片获取更多信息
|
||||||
|
if mimeType == "" || strings.HasPrefix(mimeType, "image/") {
|
||||||
|
config, format, err := decodeImageConfig(decodedData)
|
||||||
|
if err == nil {
|
||||||
|
cachedData.ImageConfig = &config
|
||||||
|
cachedData.ImageFormat = format
|
||||||
|
if mimeType == "" {
|
||||||
|
cachedData.MimeType = "image/" + format
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cachedData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetImageConfig 获取图片配置(宽高等信息)
|
||||||
|
// 会自动处理缓存,避免重复下载/解码
|
||||||
|
func GetImageConfig(c *gin.Context, source *types.FileSource) (image.Config, string, error) {
|
||||||
|
cachedData, err := LoadFileSource(c, source, "get_image_config")
|
||||||
|
if err != nil {
|
||||||
|
return image.Config{}, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if cachedData.ImageConfig != nil {
|
||||||
|
return *cachedData.ImageConfig, cachedData.ImageFormat, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果缓存中没有图片配置,尝试解码
|
||||||
|
base64Str, err := cachedData.GetBase64Data()
|
||||||
|
if err != nil {
|
||||||
|
return image.Config{}, "", fmt.Errorf("failed to get base64 data: %w", err)
|
||||||
|
}
|
||||||
|
decodedData, err := base64.StdEncoding.DecodeString(base64Str)
|
||||||
|
if err != nil {
|
||||||
|
return image.Config{}, "", fmt.Errorf("failed to decode base64 for image config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config, format, err := decodeImageConfig(decodedData)
|
||||||
|
if err != nil {
|
||||||
|
return image.Config{}, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新缓存
|
||||||
|
cachedData.ImageConfig = &config
|
||||||
|
cachedData.ImageFormat = format
|
||||||
|
|
||||||
|
return config, format, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBase64Data 获取 base64 编码的数据
|
||||||
|
// 会自动处理缓存,避免重复下载
|
||||||
|
// 支持内存缓存和磁盘缓存
|
||||||
|
func GetBase64Data(c *gin.Context, source *types.FileSource, reason ...string) (string, string, error) {
|
||||||
|
cachedData, err := LoadFileSource(c, source, reason...)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
base64Str, err := cachedData.GetBase64Data()
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("failed to get base64 data: %w", err)
|
||||||
|
}
|
||||||
|
return base64Str, cachedData.MimeType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMimeType 获取文件的 MIME 类型
|
||||||
|
func GetMimeType(c *gin.Context, source *types.FileSource) (string, error) {
|
||||||
|
// 如果已经有缓存,直接返回
|
||||||
|
if source.HasCache() {
|
||||||
|
return source.GetCache().MimeType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果是 URL,尝试只获取 header 而不下载完整文件
|
||||||
|
if source.IsURL() {
|
||||||
|
mimeType, err := GetFileTypeFromUrl(c, source.URL, "get_mime_type")
|
||||||
|
if err == nil && mimeType != "" && mimeType != "application/octet-stream" {
|
||||||
|
return mimeType, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 否则加载完整数据
|
||||||
|
cachedData, err := LoadFileSource(c, source, "get_mime_type")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return cachedData.MimeType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetectFileType 检测文件类型(image/audio/video/file)
|
||||||
|
func DetectFileType(mimeType string) types.FileType {
|
||||||
|
if strings.HasPrefix(mimeType, "image/") {
|
||||||
|
return types.FileTypeImage
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(mimeType, "audio/") {
|
||||||
|
return types.FileTypeAudio
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(mimeType, "video/") {
|
||||||
|
return types.FileTypeVideo
|
||||||
|
}
|
||||||
|
return types.FileTypeFile
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeImageConfig 从字节数据解码图片配置
|
||||||
|
func decodeImageConfig(data []byte) (image.Config, string, error) {
|
||||||
|
reader := bytes.NewReader(data)
|
||||||
|
|
||||||
|
// 尝试标准格式
|
||||||
|
config, format, err := image.DecodeConfig(reader)
|
||||||
|
if err == nil {
|
||||||
|
return config, format, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试 webp
|
||||||
|
reader.Seek(0, io.SeekStart)
|
||||||
|
config, err = webp.DecodeConfig(reader)
|
||||||
|
if err == nil {
|
||||||
|
return config, "webp", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return image.Config{}, "", fmt.Errorf("failed to decode image config: unsupported format")
|
||||||
|
}
|
||||||
|
|
||||||
|
// guessMimeTypeFromURL 从 URL 猜测 MIME 类型
|
||||||
|
func guessMimeTypeFromURL(url string) string {
|
||||||
|
// 移除查询参数
|
||||||
|
cleanedURL := url
|
||||||
|
if q := strings.Index(cleanedURL, "?"); q != -1 {
|
||||||
|
cleanedURL = cleanedURL[:q]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取最后一段
|
||||||
|
if slash := strings.LastIndex(cleanedURL, "/"); slash != -1 && slash+1 < len(cleanedURL) {
|
||||||
|
last := cleanedURL[slash+1:]
|
||||||
|
if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) {
|
||||||
|
ext := strings.ToLower(last[dot+1:])
|
||||||
|
return GetMimeTypeByExtension(ext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "application/octet-stream"
|
||||||
|
}
|
||||||
@@ -3,10 +3,6 @@ package service
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"image"
|
|
||||||
_ "image/gif"
|
|
||||||
_ "image/jpeg"
|
|
||||||
_ "image/png"
|
|
||||||
"log"
|
"log"
|
||||||
"math"
|
"math"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -23,8 +19,8 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
|
func getImageToken(c *gin.Context, fileMeta *types.FileMeta, model string, stream bool) (int, error) {
|
||||||
if fileMeta == nil {
|
if fileMeta == nil || fileMeta.Source == nil {
|
||||||
return 0, fmt.Errorf("image_url_is_nil")
|
return 0, fmt.Errorf("image_url_is_nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,35 +95,20 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
|||||||
fileMeta.Detail = "high"
|
fileMeta.Detail = "high"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode image to get dimensions
|
// 使用统一的文件服务获取图片配置
|
||||||
var config image.Config
|
config, format, err := GetImageConfig(c, fileMeta.Source)
|
||||||
var err error
|
|
||||||
var format string
|
|
||||||
var b64str string
|
|
||||||
|
|
||||||
if fileMeta.ParsedData != nil {
|
|
||||||
config, format, b64str, err = DecodeBase64ImageData(fileMeta.ParsedData.Base64Data)
|
|
||||||
} else {
|
|
||||||
if strings.HasPrefix(fileMeta.OriginData, "http") {
|
|
||||||
config, format, err = DecodeUrlImageData(fileMeta.OriginData)
|
|
||||||
} else {
|
|
||||||
common.SysLog(fmt.Sprintf("decoding image"))
|
|
||||||
config, format, b64str, err = DecodeBase64ImageData(fileMeta.OriginData)
|
|
||||||
}
|
|
||||||
fileMeta.MimeType = format
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
fileMeta.MimeType = format
|
||||||
|
|
||||||
if config.Width == 0 || config.Height == 0 {
|
if config.Width == 0 || config.Height == 0 {
|
||||||
// not an image
|
// not an image, but might be a valid file
|
||||||
if format != "" && b64str != "" {
|
if format != "" {
|
||||||
// file type
|
// file type
|
||||||
return 3 * baseTokens, nil
|
return 3 * baseTokens, nil
|
||||||
}
|
}
|
||||||
return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.OriginData))
|
return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", fileMeta.GetIdentifier()))
|
||||||
}
|
}
|
||||||
|
|
||||||
width := config.Width
|
width := config.Width
|
||||||
@@ -269,48 +250,24 @@ func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *rela
|
|||||||
shouldFetchFiles = false
|
shouldFetchFiles = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 使用统一的文件服务获取文件类型
|
||||||
for _, file := range meta.Files {
|
for _, file := range meta.Files {
|
||||||
if strings.HasPrefix(file.OriginData, "http") {
|
if file.Source == nil {
|
||||||
if shouldFetchFiles {
|
continue
|
||||||
mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter")
|
}
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("error getting file base64 from url: %v", err)
|
// 如果文件类型未知且需要获取,通过 MIME 类型检测
|
||||||
}
|
if file.FileType == "" || (file.Source.IsURL() && shouldFetchFiles) {
|
||||||
if strings.HasPrefix(mineType, "image/") {
|
mimeType, err := GetMimeType(c, file.Source)
|
||||||
file.FileType = types.FileTypeImage
|
if err != nil {
|
||||||
} else if strings.HasPrefix(mineType, "video/") {
|
if shouldFetchFiles {
|
||||||
file.FileType = types.FileTypeVideo
|
return 0, fmt.Errorf("error getting file type: %v", err)
|
||||||
} else if strings.HasPrefix(mineType, "audio/") {
|
|
||||||
file.FileType = types.FileTypeAudio
|
|
||||||
} else {
|
|
||||||
file.FileType = types.FileTypeFile
|
|
||||||
}
|
|
||||||
file.MimeType = mineType
|
|
||||||
}
|
|
||||||
} else if strings.HasPrefix(file.OriginData, "data:") {
|
|
||||||
// get mime type from base64 header
|
|
||||||
parts := strings.SplitN(file.OriginData, ",", 2)
|
|
||||||
if len(parts) >= 1 {
|
|
||||||
header := parts[0]
|
|
||||||
// Extract mime type from "data:mime/type;base64" format
|
|
||||||
if strings.Contains(header, ":") && strings.Contains(header, ";") {
|
|
||||||
mimeStart := strings.Index(header, ":") + 1
|
|
||||||
mimeEnd := strings.Index(header, ";")
|
|
||||||
if mimeStart < mimeEnd {
|
|
||||||
mineType := header[mimeStart:mimeEnd]
|
|
||||||
if strings.HasPrefix(mineType, "image/") {
|
|
||||||
file.FileType = types.FileTypeImage
|
|
||||||
} else if strings.HasPrefix(mineType, "video/") {
|
|
||||||
file.FileType = types.FileTypeVideo
|
|
||||||
} else if strings.HasPrefix(mineType, "audio/") {
|
|
||||||
file.FileType = types.FileTypeAudio
|
|
||||||
} else {
|
|
||||||
file.FileType = types.FileTypeFile
|
|
||||||
}
|
|
||||||
file.MimeType = mineType
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
// 如果不需要获取,使用默认类型
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
file.MimeType = mimeType
|
||||||
|
file.FileType = DetectFileType(mimeType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -318,9 +275,9 @@ func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *rela
|
|||||||
switch file.FileType {
|
switch file.FileType {
|
||||||
case types.FileTypeImage:
|
case types.FileTypeImage:
|
||||||
if common.IsOpenAITextModel(model) {
|
if common.IsOpenAITextModel(model) {
|
||||||
token, err := getImageToken(file, model, info.IsStream)
|
token, err := getImageToken(c, file, model, info.IsStream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err)
|
return 0, fmt.Errorf("error counting image token, media index[%d], identifier[%s], err: %v", i, file.GetIdentifier(), err)
|
||||||
}
|
}
|
||||||
tkm += token
|
tkm += token
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
213
types/file_source.go
Normal file
213
types/file_source.go
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FileSourceType 文件来源类型
|
||||||
|
type FileSourceType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
FileSourceTypeURL FileSourceType = "url" // URL 来源
|
||||||
|
FileSourceTypeBase64 FileSourceType = "base64" // Base64 内联数据
|
||||||
|
)
|
||||||
|
|
||||||
|
// FileSource 统一的文件来源抽象
|
||||||
|
// 支持 URL 和 base64 两种来源,提供懒加载和缓存机制
|
||||||
|
type FileSource struct {
|
||||||
|
Type FileSourceType `json:"type"` // 来源类型
|
||||||
|
URL string `json:"url,omitempty"` // URL(当 Type 为 url 时)
|
||||||
|
Base64Data string `json:"base64_data,omitempty"` // Base64 数据(当 Type 为 base64 时)
|
||||||
|
MimeType string `json:"mime_type,omitempty"` // MIME 类型(可选,会自动检测)
|
||||||
|
|
||||||
|
// 内部缓存(不导出,不序列化)
|
||||||
|
cachedData *CachedFileData
|
||||||
|
cacheMu sync.RWMutex
|
||||||
|
cacheLoaded bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// CachedFileData 缓存的文件数据
|
||||||
|
// 支持内存缓存和磁盘缓存两种模式
|
||||||
|
type CachedFileData struct {
|
||||||
|
base64Data string // 内存中的 base64 数据(小文件)
|
||||||
|
MimeType string // MIME 类型
|
||||||
|
Size int64 // 文件大小(字节)
|
||||||
|
ImageConfig *image.Config // 图片配置(如果是图片)
|
||||||
|
ImageFormat string // 图片格式(如果是图片)
|
||||||
|
|
||||||
|
// 磁盘缓存相关
|
||||||
|
diskPath string // 磁盘缓存文件路径(大文件)
|
||||||
|
isDisk bool // 是否使用磁盘缓存
|
||||||
|
diskMu sync.Mutex // 磁盘操作锁
|
||||||
|
diskClosed bool // 是否已关闭/清理
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemoryCachedData 创建内存缓存的数据
|
||||||
|
func NewMemoryCachedData(base64Data string, mimeType string, size int64) *CachedFileData {
|
||||||
|
return &CachedFileData{
|
||||||
|
base64Data: base64Data,
|
||||||
|
MimeType: mimeType,
|
||||||
|
Size: size,
|
||||||
|
isDisk: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDiskCachedData 创建磁盘缓存的数据
|
||||||
|
func NewDiskCachedData(diskPath string, mimeType string, size int64) *CachedFileData {
|
||||||
|
return &CachedFileData{
|
||||||
|
diskPath: diskPath,
|
||||||
|
MimeType: mimeType,
|
||||||
|
Size: size,
|
||||||
|
isDisk: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBase64Data 获取 base64 数据(自动处理内存/磁盘)
|
||||||
|
func (c *CachedFileData) GetBase64Data() (string, error) {
|
||||||
|
if !c.isDisk {
|
||||||
|
return c.base64Data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.diskMu.Lock()
|
||||||
|
defer c.diskMu.Unlock()
|
||||||
|
|
||||||
|
if c.diskClosed {
|
||||||
|
return "", fmt.Errorf("disk cache already closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从磁盘读取
|
||||||
|
data, err := os.ReadFile(c.diskPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to read from disk cache: %w", err)
|
||||||
|
}
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBase64Data 设置 base64 数据(仅用于内存模式)
|
||||||
|
func (c *CachedFileData) SetBase64Data(data string) {
|
||||||
|
if !c.isDisk {
|
||||||
|
c.base64Data = data
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsDisk 是否使用磁盘缓存
|
||||||
|
func (c *CachedFileData) IsDisk() bool {
|
||||||
|
return c.isDisk
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 关闭并清理资源
|
||||||
|
func (c *CachedFileData) Close() error {
|
||||||
|
if !c.isDisk {
|
||||||
|
c.base64Data = "" // 释放内存
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.diskMu.Lock()
|
||||||
|
defer c.diskMu.Unlock()
|
||||||
|
|
||||||
|
if c.diskClosed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.diskClosed = true
|
||||||
|
if c.diskPath != "" {
|
||||||
|
return os.Remove(c.diskPath)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewURLFileSource 创建 URL 来源的 FileSource
|
||||||
|
func NewURLFileSource(url string) *FileSource {
|
||||||
|
return &FileSource{
|
||||||
|
Type: FileSourceTypeURL,
|
||||||
|
URL: url,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBase64FileSource 创建 base64 来源的 FileSource
|
||||||
|
func NewBase64FileSource(base64Data string, mimeType string) *FileSource {
|
||||||
|
return &FileSource{
|
||||||
|
Type: FileSourceTypeBase64,
|
||||||
|
Base64Data: base64Data,
|
||||||
|
MimeType: mimeType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsURL 判断是否是 URL 来源
|
||||||
|
func (f *FileSource) IsURL() bool {
|
||||||
|
return f.Type == FileSourceTypeURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsBase64 判断是否是 base64 来源
|
||||||
|
func (f *FileSource) IsBase64() bool {
|
||||||
|
return f.Type == FileSourceTypeBase64
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIdentifier 获取文件标识符(用于日志和错误追踪)
|
||||||
|
func (f *FileSource) GetIdentifier() string {
|
||||||
|
if f.IsURL() {
|
||||||
|
if len(f.URL) > 100 {
|
||||||
|
return f.URL[:100] + "..."
|
||||||
|
}
|
||||||
|
return f.URL
|
||||||
|
}
|
||||||
|
if len(f.Base64Data) > 50 {
|
||||||
|
return "base64:" + f.Base64Data[:50] + "..."
|
||||||
|
}
|
||||||
|
return "base64:" + f.Base64Data
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRawData 获取原始数据(URL 或完整的 base64 字符串)
|
||||||
|
func (f *FileSource) GetRawData() string {
|
||||||
|
if f.IsURL() {
|
||||||
|
return f.URL
|
||||||
|
}
|
||||||
|
return f.Base64Data
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCache 设置缓存数据
|
||||||
|
func (f *FileSource) SetCache(data *CachedFileData) {
|
||||||
|
f.cacheMu.Lock()
|
||||||
|
defer f.cacheMu.Unlock()
|
||||||
|
f.cachedData = data
|
||||||
|
f.cacheLoaded = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCache 获取缓存数据
|
||||||
|
func (f *FileSource) GetCache() *CachedFileData {
|
||||||
|
f.cacheMu.RLock()
|
||||||
|
defer f.cacheMu.RUnlock()
|
||||||
|
return f.cachedData
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasCache 是否有缓存
|
||||||
|
func (f *FileSource) HasCache() bool {
|
||||||
|
f.cacheMu.RLock()
|
||||||
|
defer f.cacheMu.RUnlock()
|
||||||
|
return f.cacheLoaded && f.cachedData != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearCache 清除缓存,释放内存和磁盘文件
|
||||||
|
func (f *FileSource) ClearCache() {
|
||||||
|
f.cacheMu.Lock()
|
||||||
|
defer f.cacheMu.Unlock()
|
||||||
|
|
||||||
|
// 如果有缓存数据,先关闭它(会清理磁盘文件)
|
||||||
|
if f.cachedData != nil {
|
||||||
|
f.cachedData.Close()
|
||||||
|
}
|
||||||
|
f.cachedData = nil
|
||||||
|
f.cacheLoaded = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearRawData 清除原始数据,只保留必要的元信息
|
||||||
|
// 用于在处理完成后释放大文件的内存
|
||||||
|
func (f *FileSource) ClearRawData() {
|
||||||
|
// 保留 URL(通常很短),只清除大的 base64 数据
|
||||||
|
if f.IsBase64() && len(f.Base64Data) > 1024 {
|
||||||
|
f.Base64Data = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -32,10 +32,48 @@ type TokenCountMeta struct {
|
|||||||
|
|
||||||
type FileMeta struct {
|
type FileMeta struct {
|
||||||
FileType
|
FileType
|
||||||
MimeType string
|
MimeType string
|
||||||
OriginData string // url or base64 data
|
Source *FileSource // 统一的文件来源(URL 或 base64)
|
||||||
Detail string
|
Detail string // 图片细节级别(low/high/auto)
|
||||||
ParsedData *LocalFileData
|
}
|
||||||
|
|
||||||
|
// NewFileMeta 创建新的 FileMeta
|
||||||
|
func NewFileMeta(fileType FileType, source *FileSource) *FileMeta {
|
||||||
|
return &FileMeta{
|
||||||
|
FileType: fileType,
|
||||||
|
Source: source,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewImageFileMeta 创建图片类型的 FileMeta
|
||||||
|
func NewImageFileMeta(source *FileSource, detail string) *FileMeta {
|
||||||
|
return &FileMeta{
|
||||||
|
FileType: FileTypeImage,
|
||||||
|
Source: source,
|
||||||
|
Detail: detail,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIdentifier 获取文件标识符(用于日志)
|
||||||
|
func (f *FileMeta) GetIdentifier() string {
|
||||||
|
if f.Source != nil {
|
||||||
|
return f.Source.GetIdentifier()
|
||||||
|
}
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsURL 判断是否是 URL 来源
|
||||||
|
func (f *FileMeta) IsURL() bool {
|
||||||
|
return f.Source != nil && f.Source.IsURL()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRawData 获取原始数据(兼容旧代码)
|
||||||
|
// Deprecated: 请使用 Source.GetRawData()
|
||||||
|
func (f *FileMeta) GetRawData() string {
|
||||||
|
if f.Source != nil {
|
||||||
|
return f.Source.GetRawData()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
type RequestMeta struct {
|
type RequestMeta struct {
|
||||||
|
|||||||
Reference in New Issue
Block a user