From 8cb56fc319aa79ee848648331dcae5de1c5ecb45 Mon Sep 17 00:00:00 2001 From: t0ng7u Date: Tue, 16 Dec 2025 18:10:00 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=B9=20fix:=20harden=20request-body=20s?= =?UTF-8?q?ize=20handling=20and=20error=20unwrapping?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tighten oversized request handling across relay paths and make error matching reliable. - Align `MAX_REQUEST_BODY_MB` fallback to `32` in request body reader and decompression middleware - Stop ignoring `GetRequestBody` errors in relay retry paths; return consistent **413** on oversized bodies (400 for other read errors) - Add `Unwrap()` to `types.NewAPIError` so `errors.Is/As` can match wrapped underlying errors - `go test ./...` passes --- common/gin.go | 2 +- controller/relay.go | 29 +++++++++++++++++++++++------ middleware/gzip.go | 2 +- types/error.go | 8 ++++++++ 4 files changed, 33 insertions(+), 8 deletions(-) diff --git a/common/gin.go b/common/gin.go index e927962cf..95996b619 100644 --- a/common/gin.go +++ b/common/gin.go @@ -40,7 +40,7 @@ func GetRequestBody(c *gin.Context) ([]byte, error) { } maxMB := constant.MaxRequestBodyMB if maxMB <= 0 { - maxMB = 64 + maxMB = 32 } maxBytes := int64(maxMB) << 20 diff --git a/controller/relay.go b/controller/relay.go index 29fd209d2..9759fa30c 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -179,15 +179,24 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { } for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() { - channel, err := getChannel(c, relayInfo, retryParam) - if err != nil { - logger.LogError(c, err.Error()) - newAPIError = err + channel, channelErr := getChannel(c, relayInfo, retryParam) + if channelErr != nil { + logger.LogError(c, channelErr.Error()) + newAPIError = channelErr break } addUsedChannel(c, channel.Id) - requestBody, _ := common.GetRequestBody(c) + requestBody, bodyErr := common.GetRequestBody(c) + if bodyErr != nil { + // Ensure consistent 413 for oversized bodies even when error occurs later (e.g., retry path) + if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) { + newAPIError = types.NewErrorWithStatusCode(bodyErr, types.ErrorCodeReadRequestBodyFailed, http.StatusRequestEntityTooLarge, types.ErrOptionWithSkipRetry()) + } else { + newAPIError = types.NewErrorWithStatusCode(bodyErr, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) + } + break + } c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) switch relayFormat { @@ -473,7 +482,15 @@ func RelayTask(c *gin.Context) { logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry())) //middleware.SetupContextForSelectedChannel(c, channel, originalModel) - requestBody, _ := common.GetRequestBody(c) + requestBody, err := common.GetRequestBody(c) + if err != nil { + if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) { + taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusRequestEntityTooLarge) + } else { + taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusBadRequest) + } + break + } c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) taskErr = taskRelayHandler(c, relayInfo) } diff --git a/middleware/gzip.go b/middleware/gzip.go index e86d2fffc..5e5682532 100644 --- a/middleware/gzip.go +++ b/middleware/gzip.go @@ -30,7 +30,7 @@ func DecompressRequestMiddleware() gin.HandlerFunc { } maxMB := constant.MaxRequestBodyMB if maxMB <= 0 { - maxMB = 64 + maxMB = 32 } maxBytes := int64(maxMB) << 20 diff --git a/types/error.go b/types/error.go index 9c12034e1..3bfd0399a 100644 --- a/types/error.go +++ b/types/error.go @@ -94,6 +94,14 @@ type NewAPIError struct { StatusCode int } +// Unwrap enables errors.Is / errors.As to work with NewAPIError by exposing the underlying error. +func (e *NewAPIError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} + func (e *NewAPIError) GetErrorCode() ErrorCode { if e == nil { return ""