mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 04:03:18 +00:00
🛡️ fix: prevent OOM on large/decompressed requests; skip heavy prompt meta when token count is disabled
Clamp request body size (including post-decompression) to avoid memory exhaustion caused by huge payloads/zip bombs, especially with large-context Claude requests. Add a configurable `MAX_REQUEST_BODY_MB` (default `32`) and document it. - Enforce max request body size after gzip/br decompression via `http.MaxBytesReader` - Add a secondary size guard in `common.GetRequestBody` and cache-safe handling - Return **413 Request Entity Too Large** on oversized bodies in relay entry - Avoid building large `TokenCountMeta.CombineText` when both token counting and sensitive check are disabled (use lightweight meta for pricing) - Update READMEs (CN/EN/FR/JA) with `MAX_REQUEST_BODY_MB` - Fix a handful of vet/formatting issues encountered during the change - `go test ./...` passes
This commit is contained in:
@@ -305,6 +305,7 @@ docker run --name new-api -d --restart always \
|
||||
| `REDIS_CONN_STRING` | Redis connection string | - |
|
||||
| `STREAMING_TIMEOUT` | Streaming timeout (seconds) | `300` |
|
||||
| `STREAM_SCANNER_MAX_BUFFER_MB` | Max per-line buffer (MB) for the stream scanner; increase when upstream sends huge image/base64 payloads | `64` |
|
||||
| `MAX_REQUEST_BODY_MB` | Max request body size (MB, counted **after decompression**; prevents huge requests/zip bombs from exhausting memory). Exceeding it returns `413` | `32` |
|
||||
| `AZURE_DEFAULT_API_VERSION` | Azure API version | `2025-04-01-preview` |
|
||||
| `ERROR_LOG_ENABLED` | Error log switch | `false` |
|
||||
|
||||
|
||||
@@ -301,6 +301,7 @@ docker run --name new-api -d --restart always \
|
||||
| `REDIS_CONN_STRING` | Chaine de connexion Redis | - |
|
||||
| `STREAMING_TIMEOUT` | Délai d'expiration du streaming (secondes) | `300` |
|
||||
| `STREAM_SCANNER_MAX_BUFFER_MB` | Taille max du buffer par ligne (Mo) pour le scanner SSE ; à augmenter quand les sorties image/base64 sont très volumineuses (ex. images 4K) | `64` |
|
||||
| `MAX_REQUEST_BODY_MB` | Taille maximale du corps de requête (Mo, comptée **après décompression** ; évite les requêtes énormes/zip bombs qui saturent la mémoire). Dépassement ⇒ `413` | `32` |
|
||||
| `AZURE_DEFAULT_API_VERSION` | Version de l'API Azure | `2025-04-01-preview` |
|
||||
| `ERROR_LOG_ENABLED` | Interrupteur du journal d'erreurs | `false` |
|
||||
|
||||
|
||||
@@ -310,6 +310,7 @@ docker run --name new-api -d --restart always \
|
||||
| `REDIS_CONN_STRING` | Redis接続文字列 | - |
|
||||
| `STREAMING_TIMEOUT` | ストリーミング応答のタイムアウト時間(秒) | `300` |
|
||||
| `STREAM_SCANNER_MAX_BUFFER_MB` | ストリームスキャナの1行あたりバッファ上限(MB)。4K画像など巨大なbase64 `data:` ペイロードを扱う場合は値を増加させてください | `64` |
|
||||
| `MAX_REQUEST_BODY_MB` | リクエストボディ最大サイズ(MB、**解凍後**に計測。巨大リクエスト/zip bomb によるメモリ枯渇を防止)。超過時は `413` | `32` |
|
||||
| `AZURE_DEFAULT_API_VERSION` | Azure APIバージョン | `2025-04-01-preview` |
|
||||
| `ERROR_LOG_ENABLED` | エラーログスイッチ | `false` |
|
||||
|
||||
|
||||
@@ -306,6 +306,7 @@ docker run --name new-api -d --restart always \
|
||||
| `REDIS_CONN_STRING` | Redis 连接字符串 | - |
|
||||
| `STREAMING_TIMEOUT` | 流式超时时间(秒) | `300` |
|
||||
| `STREAM_SCANNER_MAX_BUFFER_MB` | 流式扫描器单行最大缓冲(MB),图像生成等超大 `data:` 片段(如 4K 图片 base64)需适当调大 | `64` |
|
||||
| `MAX_REQUEST_BODY_MB` | 请求体最大大小(MB,**解压后**计;防止超大请求/zip bomb 导致内存暴涨),超过将返回 `413` | `32` |
|
||||
| `AZURE_DEFAULT_API_VERSION` | Azure API 版本 | `2025-04-01-preview` |
|
||||
| `ERROR_LOG_ENABLED` | 错误日志开关 | `false` |
|
||||
|
||||
|
||||
@@ -18,18 +18,47 @@ import (
|
||||
|
||||
const KeyRequestBody = "key_request_body"
|
||||
|
||||
func GetRequestBody(c *gin.Context) ([]byte, error) {
|
||||
requestBody, _ := c.Get(KeyRequestBody)
|
||||
if requestBody != nil {
|
||||
return requestBody.([]byte), nil
|
||||
var ErrRequestBodyTooLarge = errors.New("request body too large")
|
||||
|
||||
func IsRequestBodyTooLargeError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
requestBody, err := io.ReadAll(c.Request.Body)
|
||||
if errors.Is(err, ErrRequestBodyTooLarge) {
|
||||
return true
|
||||
}
|
||||
var mbe *http.MaxBytesError
|
||||
return errors.As(err, &mbe)
|
||||
}
|
||||
|
||||
func GetRequestBody(c *gin.Context) ([]byte, error) {
|
||||
cached, exists := c.Get(KeyRequestBody)
|
||||
if exists && cached != nil {
|
||||
if b, ok := cached.([]byte); ok {
|
||||
return b, nil
|
||||
}
|
||||
}
|
||||
maxMB := constant.MaxRequestBodyMB
|
||||
if maxMB <= 0 {
|
||||
maxMB = 64
|
||||
}
|
||||
maxBytes := int64(maxMB) << 20
|
||||
|
||||
limited := io.LimitReader(c.Request.Body, maxBytes+1)
|
||||
body, err := io.ReadAll(limited)
|
||||
if err != nil {
|
||||
_ = c.Request.Body.Close()
|
||||
if IsRequestBodyTooLargeError(err) {
|
||||
return nil, ErrRequestBodyTooLarge
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_ = c.Request.Body.Close()
|
||||
c.Set(KeyRequestBody, requestBody)
|
||||
return requestBody.([]byte), nil
|
||||
if int64(len(body)) > maxBytes {
|
||||
return nil, ErrRequestBodyTooLarge
|
||||
}
|
||||
c.Set(KeyRequestBody, body)
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
|
||||
@@ -117,6 +117,8 @@ func initConstantEnv() {
|
||||
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
constant.StreamScannerMaxBufferMB = GetEnvOrDefault("STREAM_SCANNER_MAX_BUFFER_MB", 64)
|
||||
// MaxRequestBodyMB 请求体最大大小(解压后),用于防止超大请求/zip bomb导致内存暴涨
|
||||
constant.MaxRequestBodyMB = GetEnvOrDefault("MAX_REQUEST_BODY_MB", 32)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
constant.CountToken = GetEnvOrDefaultBool("CountToken", true)
|
||||
|
||||
@@ -9,6 +9,7 @@ var CountToken bool
|
||||
var GetMediaToken bool
|
||||
var GetMediaTokenNotStream bool
|
||||
var UpdateTask bool
|
||||
var MaxRequestBodyMB int
|
||||
var AzureDefaultAPIVersion string
|
||||
var GeminiVisionMaxImageNum int
|
||||
var NotifyLimitCount int
|
||||
|
||||
@@ -114,7 +114,7 @@ func DiscordOAuth(c *gin.Context) {
|
||||
DiscordBind(c)
|
||||
return
|
||||
}
|
||||
if !system_setting.GetDiscordSettings().Enabled {
|
||||
if !system_setting.GetDiscordSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Discord 登录以及注册",
|
||||
|
||||
@@ -2,6 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -104,7 +105,12 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
|
||||
request, err := helper.GetAndValidateRequest(c, relayFormat)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
// Map "request body too large" to 413 so clients can handle it correctly
|
||||
if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) {
|
||||
newAPIError = types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusRequestEntityTooLarge, types.ErrOptionWithSkipRetry())
|
||||
} else {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -114,9 +120,17 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
return
|
||||
}
|
||||
|
||||
meta := request.GetTokenCountMeta()
|
||||
needSensitiveCheck := setting.ShouldCheckPromptSensitive()
|
||||
needCountToken := constant.CountToken
|
||||
// Avoid building huge CombineText (strings.Join) when token counting and sensitive check are both disabled.
|
||||
var meta *types.TokenCountMeta
|
||||
if needSensitiveCheck || needCountToken {
|
||||
meta = request.GetTokenCountMeta()
|
||||
} else {
|
||||
meta = fastTokenCountMetaForPricing(request)
|
||||
}
|
||||
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
if needSensitiveCheck && meta != nil {
|
||||
contains, words := service.CheckSensitiveText(meta.CombineText)
|
||||
if contains {
|
||||
logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
|
||||
@@ -218,6 +232,33 @@ func addUsedChannel(c *gin.Context, channelId int) {
|
||||
c.Set("use_channel", useChannel)
|
||||
}
|
||||
|
||||
func fastTokenCountMetaForPricing(request dto.Request) *types.TokenCountMeta {
|
||||
if request == nil {
|
||||
return &types.TokenCountMeta{}
|
||||
}
|
||||
meta := &types.TokenCountMeta{
|
||||
TokenType: types.TokenTypeTokenizer,
|
||||
}
|
||||
switch r := request.(type) {
|
||||
case *dto.GeneralOpenAIRequest:
|
||||
if r.MaxCompletionTokens > r.MaxTokens {
|
||||
meta.MaxTokens = int(r.MaxCompletionTokens)
|
||||
} else {
|
||||
meta.MaxTokens = int(r.MaxTokens)
|
||||
}
|
||||
case *dto.OpenAIResponsesRequest:
|
||||
meta.MaxTokens = int(r.MaxOutputTokens)
|
||||
case *dto.ClaudeRequest:
|
||||
meta.MaxTokens = int(r.MaxTokens)
|
||||
case *dto.ImageRequest:
|
||||
// Pricing for image requests depends on ImagePriceRatio; safe to compute even when CountToken is disabled.
|
||||
return r.GetTokenCountMeta()
|
||||
default:
|
||||
// Best-effort: leave CombineText empty to avoid large allocations.
|
||||
}
|
||||
return meta
|
||||
}
|
||||
|
||||
func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryParam *service.RetryParam) (*model.Channel, *types.NewAPIError) {
|
||||
if info.ChannelMeta == nil {
|
||||
autoBan := c.GetBool("auto_ban")
|
||||
|
||||
@@ -88,7 +88,7 @@ func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
|
||||
logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -141,7 +141,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
||||
return err
|
||||
}
|
||||
if !responseItems.IsSuccess() {
|
||||
common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
|
||||
common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody)))
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -7,12 +7,12 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
@@ -162,7 +162,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
}
|
||||
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
|
||||
if mjErr != nil {
|
||||
return nil, false, fmt.Errorf(mjErr.Description)
|
||||
return nil, false, fmt.Errorf("%s", mjErr.Description)
|
||||
}
|
||||
if midjourneyModel == "" {
|
||||
if !success {
|
||||
|
||||
@@ -5,32 +5,69 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type readCloser struct {
|
||||
io.Reader
|
||||
closeFn func() error
|
||||
}
|
||||
|
||||
func (rc *readCloser) Close() error {
|
||||
if rc.closeFn != nil {
|
||||
return rc.closeFn()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DecompressRequestMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request.Body == nil || c.Request.Method == http.MethodGet {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
maxMB := constant.MaxRequestBodyMB
|
||||
if maxMB <= 0 {
|
||||
maxMB = 64
|
||||
}
|
||||
maxBytes := int64(maxMB) << 20
|
||||
|
||||
origBody := c.Request.Body
|
||||
wrapMaxBytes := func(body io.ReadCloser) io.ReadCloser {
|
||||
return http.MaxBytesReader(c.Writer, body, maxBytes)
|
||||
}
|
||||
|
||||
switch c.GetHeader("Content-Encoding") {
|
||||
case "gzip":
|
||||
gzipReader, err := gzip.NewReader(c.Request.Body)
|
||||
gzipReader, err := gzip.NewReader(origBody)
|
||||
if err != nil {
|
||||
_ = origBody.Close()
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer gzipReader.Close()
|
||||
|
||||
// Replace the request body with the decompressed data
|
||||
c.Request.Body = io.NopCloser(gzipReader)
|
||||
// Replace the request body with the decompressed data, and enforce a max size (post-decompression).
|
||||
c.Request.Body = wrapMaxBytes(&readCloser{
|
||||
Reader: gzipReader,
|
||||
closeFn: func() error {
|
||||
_ = gzipReader.Close()
|
||||
return origBody.Close()
|
||||
},
|
||||
})
|
||||
c.Request.Header.Del("Content-Encoding")
|
||||
case "br":
|
||||
reader := brotli.NewReader(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(reader)
|
||||
reader := brotli.NewReader(origBody)
|
||||
c.Request.Body = wrapMaxBytes(&readCloser{
|
||||
Reader: reader,
|
||||
closeFn: func() error {
|
||||
return origBody.Close()
|
||||
},
|
||||
})
|
||||
c.Request.Header.Del("Content-Encoding")
|
||||
default:
|
||||
// Even for uncompressed bodies, enforce a max size to avoid huge request allocations.
|
||||
c.Request.Body = wrapMaxBytes(origBody)
|
||||
}
|
||||
|
||||
// Continue processing the request
|
||||
|
||||
@@ -18,7 +18,7 @@ var awsModelIDMap = map[string]string{
|
||||
"claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
"claude-sonnet-4-5-20250929": "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"claude-haiku-4-5-20251001": "anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"claude-opus-4-5-20251101": "anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"claude-opus-4-5-20251101": "anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
// Nova models
|
||||
"nova-micro-v1:0": "amazon.nova-micro-v1:0",
|
||||
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
|
||||
|
||||
@@ -150,7 +150,7 @@ func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
|
||||
return types.NewError(fmt.Errorf("%s", baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
@@ -175,7 +175,7 @@ func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
|
||||
return types.NewError(fmt.Errorf("%s", baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
|
||||
@@ -208,7 +208,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
|
||||
return
|
||||
}
|
||||
|
||||
common.SysLog(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
|
||||
common.SysLog(fmt.Sprintf("stream event error: %v %v", errorData.Code, errorData.Message))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -196,7 +196,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
if jResp.Code != 10000 {
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -186,7 +186,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return
|
||||
}
|
||||
if kResp.Code != 0 {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf(kResp.Message), "task_failed", http.StatusBadRequest)
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("%s", kResp.Message), "task_failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
ov := dto.NewOpenAIVideo()
|
||||
|
||||
@@ -105,7 +105,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return
|
||||
}
|
||||
if !sunoResponse.IsSuccess() {
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf(sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError)
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -196,7 +196,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
// handle response
|
||||
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||
responseBody, _ := io.ReadAll(resp.Body)
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -3,9 +3,9 @@ package system_setting
|
||||
import "github.com/QuantumNous/new-api/setting/config"
|
||||
|
||||
type DiscordSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ClientId string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
Enabled bool `json:"enabled"`
|
||||
ClientId string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
|
||||
Reference in New Issue
Block a user