mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 09:33:10 +00:00
* feat: 引入通用 HTTP BodyStorage/DiskCache 缓存配置与管理 - 新增 common/body_storage.go 提供 HTTP 请求体存储抽象和文件缓存能力 - 增加 common/disk_cache_config.go 支持全局磁盘缓存配置 - main.go 挂载缓存初始化流程 - 新增和补充 controller/performance.go (及 unix/windows) 用于缓存性能监控接口 - middleware/body_cleanup.go 自动清理缓存文件 - router 挂载相关接口 - 前端 settings 页面新增性能监控设置 PerformanceSetting - 优化缓存开关状态和模块热插拔能力 - 其他相关文件同步适配缓存扩展 * fix: 修复 BodyStorage 并发安全和错误处理问题 - 修复 diskStorage.Close() 竞态条件,先获取锁再执行 CAS - 为 memoryStorage 添加互斥锁和 closed 状态检查 - 修复 CreateBodyStorageFromReader 在磁盘存储失败时的回退逻辑 - 添加缓存命中统计调用 (IncrementDiskCacheHits/IncrementMemoryCacheHits) - 修复 gin.go 中 Seek 错误被忽略的问题 - 在 api-router 添加 BodyStorageCleanup 中间件 - 修复前端 formatBytes 对异常值的处理 Co-authored-by: Cursor <cursoragent@cursor.com> --------- Co-authored-by: Cursor <cursoragent@cursor.com>
330 lines
7.7 KiB
Go
330 lines
7.7 KiB
Go
package common
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"mime"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/QuantumNous/new-api/constant"
|
|
"github.com/pkg/errors"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// Context keys under which the request body is cached on a gin.Context.
const (
	// KeyRequestBody holds the body as a raw []byte (legacy cache entry).
	KeyRequestBody = "key_request_body"
	// KeyBodyStorage holds the BodyStorage object backing the request body.
	KeyBodyStorage = "key_body_storage"
)
|
|
|
|
// ErrRequestBodyTooLarge is the sentinel reported when a request body exceeds
// the configured size limit.
var ErrRequestBodyTooLarge = errors.New("request body too large")

// IsRequestBodyTooLargeError reports whether err signals an oversized request
// body: either ErrRequestBodyTooLarge or an *http.MaxBytesError appears
// somewhere in its chain. A nil error is never "too large".
func IsRequestBodyTooLargeError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, ErrRequestBodyTooLarge):
		return true
	}

	var limitErr *http.MaxBytesError
	return errors.As(err, &limitErr)
}
|
|
|
|
func GetRequestBody(c *gin.Context) ([]byte, error) {
|
|
// 首先检查是否有 BodyStorage 缓存
|
|
if storage, exists := c.Get(KeyBodyStorage); exists && storage != nil {
|
|
if bs, ok := storage.(BodyStorage); ok {
|
|
if _, err := bs.Seek(0, io.SeekStart); err != nil {
|
|
return nil, fmt.Errorf("failed to seek body storage: %w", err)
|
|
}
|
|
return bs.Bytes()
|
|
}
|
|
}
|
|
|
|
// 检查旧的缓存方式
|
|
cached, exists := c.Get(KeyRequestBody)
|
|
if exists && cached != nil {
|
|
if b, ok := cached.([]byte); ok {
|
|
return b, nil
|
|
}
|
|
}
|
|
|
|
maxMB := constant.MaxRequestBodyMB
|
|
if maxMB <= 0 {
|
|
maxMB = 128 // 默认 128MB
|
|
}
|
|
maxBytes := int64(maxMB) << 20
|
|
|
|
contentLength := c.Request.ContentLength
|
|
|
|
// 使用新的存储系统
|
|
storage, err := CreateBodyStorageFromReader(c.Request.Body, contentLength, maxBytes)
|
|
_ = c.Request.Body.Close()
|
|
|
|
if err != nil {
|
|
if IsRequestBodyTooLargeError(err) {
|
|
return nil, errors.Wrap(ErrRequestBodyTooLarge, fmt.Sprintf("request body exceeds %d MB", maxMB))
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// 缓存存储对象
|
|
c.Set(KeyBodyStorage, storage)
|
|
|
|
// 获取字节数据
|
|
body, err := storage.Bytes()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 同时设置旧的缓存键以保持兼容性
|
|
c.Set(KeyRequestBody, body)
|
|
|
|
return body, nil
|
|
}
|
|
|
|
// GetBodyStorage 获取请求体存储对象(用于需要多次读取的场景)
|
|
func GetBodyStorage(c *gin.Context) (BodyStorage, error) {
|
|
// 检查是否已有存储
|
|
if storage, exists := c.Get(KeyBodyStorage); exists && storage != nil {
|
|
if bs, ok := storage.(BodyStorage); ok {
|
|
if _, err := bs.Seek(0, io.SeekStart); err != nil {
|
|
return nil, fmt.Errorf("failed to seek body storage: %w", err)
|
|
}
|
|
return bs, nil
|
|
}
|
|
}
|
|
|
|
// 如果没有,调用 GetRequestBody 创建存储
|
|
_, err := GetRequestBody(c)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 再次获取存储
|
|
if storage, exists := c.Get(KeyBodyStorage); exists && storage != nil {
|
|
if bs, ok := storage.(BodyStorage); ok {
|
|
if _, err := bs.Seek(0, io.SeekStart); err != nil {
|
|
return nil, fmt.Errorf("failed to seek body storage: %w", err)
|
|
}
|
|
return bs, nil
|
|
}
|
|
}
|
|
|
|
return nil, errors.New("failed to get body storage")
|
|
}
|
|
|
|
// CleanupBodyStorage 清理请求体存储(应在请求结束时调用)
|
|
func CleanupBodyStorage(c *gin.Context) {
|
|
if storage, exists := c.Get(KeyBodyStorage); exists && storage != nil {
|
|
if bs, ok := storage.(BodyStorage); ok {
|
|
bs.Close()
|
|
}
|
|
c.Set(KeyBodyStorage, nil)
|
|
}
|
|
}
|
|
|
|
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
|
requestBody, err := GetRequestBody(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
//if DebugEnabled {
|
|
// println("UnmarshalBodyReusable request body:", string(requestBody))
|
|
//}
|
|
contentType := c.Request.Header.Get("Content-Type")
|
|
if strings.HasPrefix(contentType, "application/json") {
|
|
err = Unmarshal(requestBody, v)
|
|
} else if strings.Contains(contentType, gin.MIMEPOSTForm) {
|
|
err = parseFormData(requestBody, v)
|
|
} else if strings.Contains(contentType, gin.MIMEMultipartPOSTForm) {
|
|
err = parseMultipartFormData(c, requestBody, v)
|
|
} else {
|
|
// skip for now
|
|
// TODO: someday non json request have variant model, we will need to implementation this
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Reset request body
|
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
return nil
|
|
}
|
|
|
|
func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
|
|
c.Set(string(key), value)
|
|
}
|
|
|
|
func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
|
|
return c.Get(string(key))
|
|
}
|
|
|
|
func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
|
|
return c.GetString(string(key))
|
|
}
|
|
|
|
func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
|
|
return c.GetInt(string(key))
|
|
}
|
|
|
|
func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
|
|
return c.GetBool(string(key))
|
|
}
|
|
|
|
func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
|
|
return c.GetStringSlice(string(key))
|
|
}
|
|
|
|
func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
|
|
return c.GetStringMap(string(key))
|
|
}
|
|
|
|
func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
|
|
return c.GetTime(string(key))
|
|
}
|
|
|
|
func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
|
|
if value, ok := c.Get(string(key)); ok {
|
|
if v, ok := value.(T); ok {
|
|
return v, true
|
|
}
|
|
}
|
|
var t T
|
|
return t, false
|
|
}
|
|
|
|
func ApiError(c *gin.Context, err error) {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": err.Error(),
|
|
})
|
|
}
|
|
|
|
func ApiErrorMsg(c *gin.Context, msg string) {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": msg,
|
|
})
|
|
}
|
|
|
|
func ApiSuccess(c *gin.Context, data any) {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "",
|
|
"data": data,
|
|
})
|
|
}
|
|
|
|
func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
|
|
requestBody, err := GetRequestBody(c)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
contentType := c.Request.Header.Get("Content-Type")
|
|
boundary, err := parseBoundary(contentType)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
reader := multipart.NewReader(bytes.NewReader(requestBody), boundary)
|
|
form, err := reader.ReadForm(multipartMemoryLimit())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Reset request body
|
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
return form, nil
|
|
}
|
|
|
|
func processFormMap(formMap map[string]any, v any) error {
|
|
jsonData, err := Marshal(formMap)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = Unmarshal(jsonData, v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func parseFormData(data []byte, v any) error {
|
|
values, err := url.ParseQuery(string(data))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
formMap := make(map[string]any)
|
|
for key, vals := range values {
|
|
if len(vals) == 1 {
|
|
formMap[key] = vals[0]
|
|
} else {
|
|
formMap[key] = vals
|
|
}
|
|
}
|
|
|
|
return processFormMap(formMap, v)
|
|
}
|
|
|
|
func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
|
|
contentType := c.Request.Header.Get("Content-Type")
|
|
boundary, err := parseBoundary(contentType)
|
|
if err != nil {
|
|
if errors.Is(err, errBoundaryNotFound) {
|
|
return Unmarshal(data, v) // Fallback to JSON
|
|
}
|
|
return err
|
|
}
|
|
|
|
reader := multipart.NewReader(bytes.NewReader(data), boundary)
|
|
form, err := reader.ReadForm(multipartMemoryLimit())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer form.RemoveAll()
|
|
formMap := make(map[string]any)
|
|
for key, vals := range form.Value {
|
|
if len(vals) == 1 {
|
|
formMap[key] = vals[0]
|
|
} else {
|
|
formMap[key] = vals
|
|
}
|
|
}
|
|
|
|
return processFormMap(formMap, v)
|
|
}
|
|
|
|
// errBoundaryNotFound is returned when a Content-Type value carries no usable
// multipart boundary parameter.
var errBoundaryNotFound = errors.New("multipart boundary not found")

// parseBoundary extracts the multipart boundary parameter from a Content-Type
// header value using mime.ParseMediaType.
//
// It returns errBoundaryNotFound when the header is empty or has no non-empty
// "boundary" parameter, and the mime package's error when the header cannot be
// parsed at all.
func parseBoundary(contentType string) (string, error) {
	if len(contentType) == 0 {
		return "", errBoundaryNotFound
	}

	// Typical values look like Boundary-UUID or boundary-------xxxxxx.
	_, params, err := mime.ParseMediaType(contentType)
	if err != nil {
		return "", err
	}

	// A missing key yields "", so one check covers both absent and empty.
	if boundary := params["boundary"]; boundary != "" {
		return boundary, nil
	}
	return "", errBoundaryNotFound
}
|
|
|
|
// multipartMemoryLimit returns the configured multipart memory limit in bytes
|
|
func multipartMemoryLimit() int64 {
|
|
limitMB := constant.MaxFileDownloadMB
|
|
if limitMB <= 0 {
|
|
limitMB = 32
|
|
}
|
|
return int64(limitMB) << 20
|
|
}
|