From 8e3f9b1faa0716b93254ab378c4dc3967259c254 Mon Sep 17 00:00:00 2001 From: t0ng7u Date: Tue, 16 Dec 2025 17:00:19 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=9B=A1=EF=B8=8F=20fix:=20prevent=20OOM=20?= =?UTF-8?q?on=20large/decompressed=20requests;=20skip=20heavy=20prompt=20m?= =?UTF-8?q?eta=20when=20token=20count=20is=20disabled?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Clamp request body size (including post-decompression) to avoid memory exhaustion caused by huge payloads/zip bombs, especially with large-context Claude requests. Add a configurable `MAX_REQUEST_BODY_MB` (default `32`) and document it. - Enforce max request body size after gzip/br decompression via `http.MaxBytesReader` - Add a secondary size guard in `common.GetRequestBody` and cache-safe handling - Return **413 Request Entity Too Large** on oversized bodies in relay entry - Avoid building large `TokenCountMeta.CombineText` when both token counting and sensitive check are disabled (use lightweight meta for pricing) - Update READMEs (CN/EN/FR/JA) with `MAX_REQUEST_BODY_MB` - Fix a handful of vet/formatting issues encountered during the change - `go test ./...` passes --- README.en.md | 1 + README.fr.md | 1 + README.ja.md | 1 + README.md | 1 + common/gin.go | 43 +++++++++++++++++++---- common/init.go | 2 ++ constant/env.go | 1 + controller/discord.go | 2 +- controller/relay.go | 47 +++++++++++++++++++++++-- controller/task.go | 4 +-- controller/topup_creem.go | 6 ++-- middleware/distributor.go | 2 +- middleware/gzip.go | 51 ++++++++++++++++++++++++---- relay/channel/aws/constants.go | 2 +- relay/channel/baidu/relay-baidu.go | 4 +-- relay/channel/coze/relay-coze.go | 2 +- relay/channel/task/jimeng/adaptor.go | 2 +- relay/channel/task/kling/adaptor.go | 2 +- relay/channel/task/suno/adaptor.go | 2 +- relay/relay_task.go | 2 +- setting/system_setting/discord.go | 6 ++-- 21 files changed, 149 insertions(+), 35 deletions(-) diff --git a/README.en.md b/README.en.md index e71f5e623..063d360b2 100644 --- a/README.en.md +++ b/README.en.md @@ -305,6 +305,7 @@ docker run --name new-api -d --restart always \ | `REDIS_CONN_STRING` | Redis connection string | - | | `STREAMING_TIMEOUT` | Streaming timeout (seconds) | `300` | | `STREAM_SCANNER_MAX_BUFFER_MB` | Max per-line buffer (MB) for the stream scanner; increase when upstream sends huge image/base64 payloads | `64` | +| `MAX_REQUEST_BODY_MB` | Max request body size (MB, counted **after decompression**; prevents huge requests/zip bombs from exhausting memory). Exceeding it returns `413` | `32` | | `AZURE_DEFAULT_API_VERSION` | Azure API version | `2025-04-01-preview` | | `ERROR_LOG_ENABLED` | Error log switch | `false` | diff --git a/README.fr.md b/README.fr.md index 35051223e..0aa212d1f 100644 --- a/README.fr.md +++ b/README.fr.md @@ -301,6 +301,7 @@ docker run --name new-api -d --restart always \ | `REDIS_CONN_STRING` | Chaine de connexion Redis | - | | `STREAMING_TIMEOUT` | Délai d'expiration du streaming (secondes) | `300` | | `STREAM_SCANNER_MAX_BUFFER_MB` | Taille max du buffer par ligne (Mo) pour le scanner SSE ; à augmenter quand les sorties image/base64 sont très volumineuses (ex. images 4K) | `64` | +| `MAX_REQUEST_BODY_MB` | Taille maximale du corps de requête (Mo, comptée **après décompression** ; évite les requêtes énormes/zip bombs qui saturent la mémoire). Dépassement ⇒ `413` | `32` | | `AZURE_DEFAULT_API_VERSION` | Version de l'API Azure | `2025-04-01-preview` | | `ERROR_LOG_ENABLED` | Interrupteur du journal d'erreurs | `false` | diff --git a/README.ja.md b/README.ja.md index 0c4b91f66..e76cd0ed4 100644 --- a/README.ja.md +++ b/README.ja.md @@ -310,6 +310,7 @@ docker run --name new-api -d --restart always \ | `REDIS_CONN_STRING` | Redis接続文字列 | - | | `STREAMING_TIMEOUT` | ストリーミング応答のタイムアウト時間(秒) | `300` | | `STREAM_SCANNER_MAX_BUFFER_MB` | ストリームスキャナの1行あたりバッファ上限(MB)。4K画像など巨大なbase64 `data:` ペイロードを扱う場合は値を増加させてください | `64` | +| `MAX_REQUEST_BODY_MB` | リクエストボディ最大サイズ(MB、**解凍後**に計測。巨大リクエスト/zip bomb によるメモリ枯渇を防止)。超過時は `413` | `32` | | `AZURE_DEFAULT_API_VERSION` | Azure APIバージョン | `2025-04-01-preview` | | `ERROR_LOG_ENABLED` | エラーログスイッチ | `false` | diff --git a/README.md b/README.md index 3d5b6923c..f1cb37480 100644 --- a/README.md +++ b/README.md @@ -306,6 +306,7 @@ docker run --name new-api -d --restart always \ | `REDIS_CONN_STRING` | Redis 连接字符串 | - | | `STREAMING_TIMEOUT` | 流式超时时间(秒) | `300` | | `STREAM_SCANNER_MAX_BUFFER_MB` | 流式扫描器单行最大缓冲(MB),图像生成等超大 `data:` 片段(如 4K 图片 base64)需适当调大 | `64` | +| `MAX_REQUEST_BODY_MB` | 请求体最大大小(MB,**解压后**计;防止超大请求/zip bomb 导致内存暴涨),超过将返回 `413` | `32` | | `AZURE_DEFAULT_API_VERSION` | Azure API 版本 | `2025-04-01-preview` | | `ERROR_LOG_ENABLED` | 错误日志开关 | `false` | diff --git a/common/gin.go b/common/gin.go index db299f293..e927962cf 100644 --- a/common/gin.go +++ b/common/gin.go @@ -18,18 +18,47 @@ import ( const KeyRequestBody = "key_request_body" -func GetRequestBody(c *gin.Context) ([]byte, error) { - requestBody, _ := c.Get(KeyRequestBody) - if requestBody != nil { - return requestBody.([]byte), nil +var ErrRequestBodyTooLarge = errors.New("request body too large") + +func IsRequestBodyTooLargeError(err error) bool { + if err == nil { + return false } - requestBody, err := io.ReadAll(c.Request.Body) + if errors.Is(err, ErrRequestBodyTooLarge) { + return true + } + var mbe *http.MaxBytesError + return errors.As(err, &mbe) +} + +func GetRequestBody(c *gin.Context) ([]byte, error) { + cached, exists := c.Get(KeyRequestBody) + if exists && cached != nil { + if b, ok := cached.([]byte); ok { + return b, nil + } + } + maxMB := constant.MaxRequestBodyMB + if maxMB <= 0 { + maxMB = 64 + } + maxBytes := int64(maxMB) << 20 + + limited := io.LimitReader(c.Request.Body, maxBytes+1) + body, err := io.ReadAll(limited) if err != nil { + _ = c.Request.Body.Close() + if IsRequestBodyTooLargeError(err) { + return nil, ErrRequestBodyTooLarge + } return nil, err } _ = c.Request.Body.Close() - c.Set(KeyRequestBody, requestBody) - return requestBody.([]byte), nil + if int64(len(body)) > maxBytes { + return nil, ErrRequestBodyTooLarge + } + c.Set(KeyRequestBody, body) + return body, nil } func UnmarshalBodyReusable(c *gin.Context, v any) error { diff --git a/common/init.go b/common/init.go index 3f3bd1df4..ac27fd2c2 100644 --- a/common/init.go +++ b/common/init.go @@ -117,6 +117,8 @@ func initConstantEnv() { constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true) constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20) constant.StreamScannerMaxBufferMB = GetEnvOrDefault("STREAM_SCANNER_MAX_BUFFER_MB", 64) + // MaxRequestBodyMB 请求体最大大小(解压后),用于防止超大请求/zip bomb导致内存暴涨 + constant.MaxRequestBodyMB = GetEnvOrDefault("MAX_REQUEST_BODY_MB", 32) // ForceStreamOption 覆盖请求参数,强制返回usage信息 constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true) constant.CountToken = GetEnvOrDefaultBool("CountToken", true) diff --git a/constant/env.go b/constant/env.go index 975bced7c..c561c207d 100644 --- a/constant/env.go +++ b/constant/env.go @@ -9,6 +9,7 @@ var CountToken bool var GetMediaToken bool var GetMediaTokenNotStream bool var UpdateTask bool +var MaxRequestBodyMB int var AzureDefaultAPIVersion string var GeminiVisionMaxImageNum int var NotifyLimitCount int diff --git a/controller/discord.go b/controller/discord.go index 41dd59808..a0865de51 100644 --- a/controller/discord.go +++ b/controller/discord.go @@ -114,7 +114,7 @@ func DiscordOAuth(c *gin.Context) { DiscordBind(c) return } - if !system_setting.GetDiscordSettings().Enabled { + if !system_setting.GetDiscordSettings().Enabled { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "管理员未开启通过 Discord 登录以及注册", diff --git a/controller/relay.go b/controller/relay.go index a0618452c..29fd209d2 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -2,6 +2,7 @@ package controller import ( "bytes" + "errors" "fmt" "io" "log" @@ -104,7 +105,12 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { request, err := helper.GetAndValidateRequest(c, relayFormat) if err != nil { - newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest) + // Map "request body too large" to 413 so clients can handle it correctly + if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) { + newAPIError = types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusRequestEntityTooLarge, types.ErrOptionWithSkipRetry()) + } else { + newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest) + } return } @@ -114,9 +120,17 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { return } - meta := request.GetTokenCountMeta() + needSensitiveCheck := setting.ShouldCheckPromptSensitive() + needCountToken := constant.CountToken + // Avoid building huge CombineText (strings.Join) when token counting and sensitive check are both disabled. + var meta *types.TokenCountMeta + if needSensitiveCheck || needCountToken { + meta = request.GetTokenCountMeta() + } else { + meta = fastTokenCountMetaForPricing(request) + } - if setting.ShouldCheckPromptSensitive() { + if needSensitiveCheck && meta != nil { contains, words := service.CheckSensitiveText(meta.CombineText) if contains { logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) @@ -218,6 +232,33 @@ func addUsedChannel(c *gin.Context, channelId int) { c.Set("use_channel", useChannel) } +func fastTokenCountMetaForPricing(request dto.Request) *types.TokenCountMeta { + if request == nil { + return &types.TokenCountMeta{} + } + meta := &types.TokenCountMeta{ + TokenType: types.TokenTypeTokenizer, + } + switch r := request.(type) { + case *dto.GeneralOpenAIRequest: + if r.MaxCompletionTokens > r.MaxTokens { + meta.MaxTokens = int(r.MaxCompletionTokens) + } else { + meta.MaxTokens = int(r.MaxTokens) + } + case *dto.OpenAIResponsesRequest: + meta.MaxTokens = int(r.MaxOutputTokens) + case *dto.ClaudeRequest: + meta.MaxTokens = int(r.MaxTokens) + case *dto.ImageRequest: + // Pricing for image requests depends on ImagePriceRatio; safe to compute even when CountToken is disabled. + return r.GetTokenCountMeta() + default: + // Best-effort: leave CombineText empty to avoid large allocations. + } + return meta +} + func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryParam *service.RetryParam) (*model.Channel, *types.NewAPIError) { if info.ChannelMeta == nil { autoBan := c.GetBool("auto_ban") diff --git a/controller/task.go b/controller/task.go index 16acc2269..244f9161c 100644 --- a/controller/task.go +++ b/controller/task.go @@ -88,7 +88,7 @@ func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM for channelId, taskIds := range taskChannelM { err := updateSunoTaskAll(ctx, channelId, taskIds, taskM) if err != nil { - logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error())) + logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error())) } } return nil @@ -141,7 +141,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas return err } if !responseItems.IsSuccess() { - common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody))) + common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody))) return err } diff --git a/controller/topup_creem.go b/controller/topup_creem.go index aab951c54..80a869673 100644 --- a/controller/topup_creem.go +++ b/controller/topup_creem.go @@ -7,12 +7,12 @@ import ( "encoding/hex" "encoding/json" "fmt" - "io" - "log" - "net/http" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/setting" + "io" + "log" + "net/http" "time" "github.com/gin-gonic/gin" diff --git a/middleware/distributor.go b/middleware/distributor.go index 390dc059f..a33404726 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -162,7 +162,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest) if mjErr != nil { - return nil, false, fmt.Errorf(mjErr.Description) + return nil, false, fmt.Errorf("%s", mjErr.Description) } if midjourneyModel == "" { if !success { diff --git a/middleware/gzip.go b/middleware/gzip.go index 7fe2f3be3..e86d2fffc 100644 --- a/middleware/gzip.go +++ b/middleware/gzip.go @@ -5,32 +5,69 @@ import ( "io" "net/http" + "github.com/QuantumNous/new-api/constant" "github.com/andybalholm/brotli" "github.com/gin-gonic/gin" ) +type readCloser struct { + io.Reader + closeFn func() error +} + +func (rc *readCloser) Close() error { + if rc.closeFn != nil { + return rc.closeFn() + } + return nil +} + func DecompressRequestMiddleware() gin.HandlerFunc { return func(c *gin.Context) { if c.Request.Body == nil || c.Request.Method == http.MethodGet { c.Next() return } + maxMB := constant.MaxRequestBodyMB + if maxMB <= 0 { + maxMB = 64 + } + maxBytes := int64(maxMB) << 20 + + origBody := c.Request.Body + wrapMaxBytes := func(body io.ReadCloser) io.ReadCloser { + return http.MaxBytesReader(c.Writer, body, maxBytes) + } + switch c.GetHeader("Content-Encoding") { case "gzip": - gzipReader, err := gzip.NewReader(c.Request.Body) + gzipReader, err := gzip.NewReader(origBody) if err != nil { + _ = origBody.Close() c.AbortWithStatus(http.StatusBadRequest) return } - defer gzipReader.Close() - - // Replace the request body with the decompressed data - c.Request.Body = io.NopCloser(gzipReader) + // Replace the request body with the decompressed data, and enforce a max size (post-decompression). + c.Request.Body = wrapMaxBytes(&readCloser{ + Reader: gzipReader, + closeFn: func() error { + _ = gzipReader.Close() + return origBody.Close() + }, + }) c.Request.Header.Del("Content-Encoding") case "br": - reader := brotli.NewReader(c.Request.Body) - c.Request.Body = io.NopCloser(reader) + reader := brotli.NewReader(origBody) + c.Request.Body = wrapMaxBytes(&readCloser{ + Reader: reader, + closeFn: func() error { + return origBody.Close() + }, + }) c.Request.Header.Del("Content-Encoding") + default: + // Even for uncompressed bodies, enforce a max size to avoid huge request allocations. + c.Request.Body = wrapMaxBytes(origBody) } // Continue processing the request diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go index 6323bb3b1..888d96eef 100644 --- a/relay/channel/aws/constants.go +++ b/relay/channel/aws/constants.go @@ -18,7 +18,7 @@ var awsModelIDMap = map[string]string{ "claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0", "claude-sonnet-4-5-20250929": "anthropic.claude-sonnet-4-5-20250929-v1:0", "claude-haiku-4-5-20251001": "anthropic.claude-haiku-4-5-20251001-v1:0", - "claude-opus-4-5-20251101": "anthropic.claude-opus-4-5-20251101-v1:0", + "claude-opus-4-5-20251101": "anthropic.claude-opus-4-5-20251101-v1:0", // Nova models "nova-micro-v1:0": "amazon.nova-micro-v1:0", "nova-lite-v1:0": "amazon.nova-lite-v1:0", diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index 8597e50ef..691d41888 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -150,7 +150,7 @@ func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon return types.NewError(err, types.ErrorCodeBadResponseBody), nil } if baiduResponse.ErrorMsg != "" { - return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil + return types.NewError(fmt.Errorf("%s", baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil } fullTextResponse := responseBaidu2OpenAI(&baiduResponse) jsonResponse, err := json.Marshal(fullTextResponse) @@ -175,7 +175,7 @@ func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht return types.NewError(err, types.ErrorCodeBadResponseBody), nil } if baiduResponse.ErrorMsg != "" { - return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil + return types.NewError(fmt.Errorf("%s", baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil } fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse) jsonResponse, err := json.Marshal(fullTextResponse) diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 7095a8b6d..2edeeee0d 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -208,7 +208,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st return } - common.SysLog(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message)) + common.SysLog(fmt.Sprintf("stream event error: %v %v", errorData.Code, errorData.Message)) } } diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index d6973531f..91d3f2361 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -196,7 +196,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } if jResp.Code != 10000 { - taskErr = service.TaskErrorWrapper(fmt.Errorf(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError) + taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError) return } diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index d00350652..4c3c9d61b 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -186,7 +186,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela return } if kResp.Code != 0 { - taskErr = service.TaskErrorWrapperLocal(fmt.Errorf(kResp.Message), "task_failed", http.StatusBadRequest) + taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("%s", kResp.Message), "task_failed", http.StatusBadRequest) return } ov := dto.NewOpenAIVideo() diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index f7c891723..8ea9a1c7f 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -105,7 +105,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela return } if !sunoResponse.IsSuccess() { - taskErr = service.TaskErrorWrapper(fmt.Errorf(sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError) + taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError) return } diff --git a/relay/relay_task.go b/relay/relay_task.go index bac05e0ee..04a905c7f 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -196,7 +196,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto. // handle response if resp != nil && resp.StatusCode != http.StatusOK { responseBody, _ := io.ReadAll(resp.Body) - taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode) + taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode) return } diff --git a/setting/system_setting/discord.go b/setting/system_setting/discord.go index f4e763ffa..f4789060b 100644 --- a/setting/system_setting/discord.go +++ b/setting/system_setting/discord.go @@ -3,9 +3,9 @@ package system_setting import "github.com/QuantumNous/new-api/setting/config" type DiscordSettings struct { - Enabled bool `json:"enabled"` - ClientId string `json:"client_id"` - ClientSecret string `json:"client_secret"` + Enabled bool `json:"enabled"` + ClientId string `json:"client_id"` + ClientSecret string `json:"client_secret"` } // 默认配置