diff --git a/common/body_storage.go b/common/body_storage.go index ea37cda96..094dbda36 100644 --- a/common/body_storage.go +++ b/common/body_storage.go @@ -302,6 +302,12 @@ func CreateBodyStorageFromReader(reader io.Reader, contentLength int64, maxBytes return storage, nil } +// ReaderOnly wraps an io.Reader to hide io.Closer, preventing http.NewRequest +// from type-asserting io.ReadCloser and closing the underlying BodyStorage. +func ReaderOnly(r io.Reader) io.Reader { + return struct{ io.Reader }{r} +} + // CleanupOldCacheFiles 清理旧的缓存文件(用于启动时清理残留) func CleanupOldCacheFiles() { // 使用统一的缓存管理 diff --git a/common/gin.go b/common/gin.go index e40279723..48971c130 100644 --- a/common/gin.go +++ b/common/gin.go @@ -33,14 +33,14 @@ func IsRequestBodyTooLargeError(err error) bool { return errors.As(err, &mbe) } -func GetRequestBody(c *gin.Context) ([]byte, error) { +func GetRequestBody(c *gin.Context) (io.Seeker, error) { // 首先检查是否有 BodyStorage 缓存 if storage, exists := c.Get(KeyBodyStorage); exists && storage != nil { if bs, ok := storage.(BodyStorage); ok { if _, err := bs.Seek(0, io.SeekStart); err != nil { return nil, fmt.Errorf("failed to seek body storage: %w", err) } - return bs.Bytes() + return bs, nil } } @@ -48,7 +48,12 @@ func GetRequestBody(c *gin.Context) ([]byte, error) { cached, exists := c.Get(KeyRequestBody) if exists && cached != nil { if b, ok := cached.([]byte); ok { - return b, nil + bs, err := CreateBodyStorage(b) + if err != nil { + return nil, err + } + c.Set(KeyBodyStorage, bs) + return bs, nil } } @@ -74,47 +79,20 @@ func GetRequestBody(c *gin.Context) ([]byte, error) { // 缓存存储对象 c.Set(KeyBodyStorage, storage) - // 获取字节数据 - body, err := storage.Bytes() - if err != nil { - return nil, err - } - - // 同时设置旧的缓存键以保持兼容性 - c.Set(KeyRequestBody, body) - - return body, nil + return storage, nil } // GetBodyStorage 获取请求体存储对象(用于需要多次读取的场景) func GetBodyStorage(c *gin.Context) (BodyStorage, error) { - // 检查是否已有存储 - if storage, exists := c.Get(KeyBodyStorage); exists && storage != nil { - if bs, ok := storage.(BodyStorage); ok { - if _, err := bs.Seek(0, io.SeekStart); err != nil { - return nil, fmt.Errorf("failed to seek body storage: %w", err) - } - return bs, nil - } - } - - // 如果没有,调用 GetRequestBody 创建存储 - _, err := GetRequestBody(c) + seeker, err := GetRequestBody(c) if err != nil { return nil, err } - - // 再次获取存储 - if storage, exists := c.Get(KeyBodyStorage); exists && storage != nil { - if bs, ok := storage.(BodyStorage); ok { - if _, err := bs.Seek(0, io.SeekStart); err != nil { - return nil, fmt.Errorf("failed to seek body storage: %w", err) - } - return bs, nil - } + bs, ok := seeker.(BodyStorage) + if !ok { + return nil, errors.New("unexpected body storage type") } - - return nil, errors.New("failed to get body storage") + return bs, nil } // CleanupBodyStorage 清理请求体存储(应在请求结束时调用) @@ -128,13 +106,14 @@ func CleanupBodyStorage(c *gin.Context) { } func UnmarshalBodyReusable(c *gin.Context, v any) error { - requestBody, err := GetRequestBody(c) + storage, err := GetBodyStorage(c) + if err != nil { + return err + } + requestBody, err := storage.Bytes() if err != nil { return err } - //if DebugEnabled { - // println("UnmarshalBodyReusable request body:", string(requestBody)) - //} contentType := c.Request.Header.Get("Content-Type") if strings.HasPrefix(contentType, "application/json") { err = Unmarshal(requestBody, v) @@ -150,7 +129,10 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { return err } // Reset request body - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil { + return seekErr + } + c.Request.Body = io.NopCloser(storage) return nil } @@ -252,7 +234,11 @@ func init() { } func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) { - requestBody, err := GetRequestBody(c) + storage, err := GetBodyStorage(c) + if err != nil { + return nil, err + } + requestBody, err := storage.Bytes() if err != nil { return nil, err } @@ -270,7 +256,10 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) { } // Reset request body - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil { + return nil, seekErr + } + c.Request.Body = io.NopCloser(storage) return form, nil } diff --git a/controller/relay.go b/controller/relay.go index 2d5ae7df6..0b30e6e9e 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -1,7 +1,6 @@ package controller import ( - "bytes" "errors" "fmt" "io" @@ -193,7 +192,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { } addUsedChannel(c, channel.Id) - requestBody, bodyErr := common.GetRequestBody(c) + bodyStorage, bodyErr := common.GetBodyStorage(c) if bodyErr != nil { // Ensure consistent 413 for oversized bodies even when error occurs later (e.g., retry path) if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) { @@ -203,7 +202,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { } break } - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + c.Request.Body = io.NopCloser(bodyStorage) switch relayFormat { case types.RelayFormatOpenAIRealtime: @@ -483,7 +482,7 @@ func RelayTask(c *gin.Context) { logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry())) //middleware.SetupContextForSelectedChannel(c, channel, originalModel) - requestBody, err := common.GetRequestBody(c) + bodyStorage, err := common.GetBodyStorage(c) if err != nil { if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) { taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusRequestEntityTooLarge) @@ -492,7 +491,7 @@ func RelayTask(c *gin.Context) { } break } - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + c.Request.Body = io.NopCloser(bodyStorage) taskErr = taskRelayHandler(c, relayInfo) } useChannel := c.GetStringSlice("use_channel") diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index aab4baea0..c2a676738 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -165,10 +165,14 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor, // buildAwsRequestBody prepares the payload for AWS requests, applying passthrough rules when enabled. func buildAwsRequestBody(c *gin.Context, info *relaycommon.RelayInfo, awsClaudeReq any) ([]byte, error) { if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { - body, err := common.GetRequestBody(c) + storage, err := common.GetBodyStorage(c) if err != nil { return nil, errors.Wrap(err, "get request body for pass-through fail") } + body, err := storage.Bytes() + if err != nil { + return nil, errors.Wrap(err, "get request body bytes fail") + } var data map[string]interface{} if err := common.Unmarshal(body, &data); err != nil { return nil, errors.Wrap(err, "pass-through unmarshal request body fail") diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index 9dc03796c..c149f9663 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -1,7 +1,6 @@ package sora import ( - "bytes" "fmt" "io" "net/http" @@ -104,11 +103,11 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info } func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { - cachedBody, err := common.GetRequestBody(c) + storage, err := common.GetBodyStorage(c) if err != nil { return nil, errors.Wrap(err, "get_request_body_failed") } - return bytes.NewReader(cachedBody), nil + return common.ReaderOnly(storage), nil } // DoRequest delegates to common helper. diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 518ff3f87..81adb276a 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -129,11 +129,11 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ var requestBody io.Reader if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { - body, err := common.GetRequestBody(c) + storage, err := common.GetBodyStorage(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } - requestBody = bytes.NewBuffer(body) + requestBody = common.ReaderOnly(storage) } else { convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, request) if err != nil { diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index eeb7b7aab..e7adddbbf 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -100,14 +100,16 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types var requestBody io.Reader if passThroughGlobal || info.ChannelSetting.PassThroughBodyEnabled { - body, err := common.GetRequestBody(c) + storage, err := common.GetBodyStorage(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } if common.DebugEnabled { - println("requestBody: ", string(body)) + if debugBytes, bErr := storage.Bytes(); bErr == nil { + println("requestBody: ", string(debugBytes)) + } } - requestBody = bytes.NewBuffer(body) + requestBody = common.ReaderOnly(storage) } else { convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request) if err != nil { diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 779670b9e..a1b8e592e 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -138,11 +138,11 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ var requestBody io.Reader if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { - body, err := common.GetRequestBody(c) + storage, err := common.GetBodyStorage(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } - requestBody = bytes.NewReader(body) + requestBody = common.ReaderOnly(storage) } else { // 使用 ConvertGeminiRequest 转换请求格式 convertedRequest, err := adaptor.ConvertGeminiRequest(c, info, request) diff --git a/relay/image_handler.go b/relay/image_handler.go index 1ee790b74..e83294268 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -47,11 +47,11 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type var requestBody io.Reader if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { - body, err := common.GetRequestBody(c) + storage, err := common.GetBodyStorage(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } - requestBody = bytes.NewBuffer(body) + requestBody = common.ReaderOnly(storage) } else { convertedRequest, err := adaptor.ConvertImageRequest(c, info, *request) if err != nil { diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index 35c66a291..8fe2930e9 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -43,11 +43,11 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ var requestBody io.Reader if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { - body, err := common.GetRequestBody(c) + storage, err := common.GetBodyStorage(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } - requestBody = bytes.NewBuffer(body) + requestBody = common.ReaderOnly(storage) } else { convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *request) if err != nil { diff --git a/relay/responses_handler.go b/relay/responses_handler.go index 8954bd5cc..04fc3470e 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -72,11 +72,11 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * adaptor.Init(info) var requestBody io.Reader if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { - body, err := common.GetRequestBody(c) + storage, err := common.GetBodyStorage(c) if err != nil { return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry()) } - requestBody = bytes.NewBuffer(body) + requestBody = common.ReaderOnly(storage) } else { convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, info, *request) if err != nil { diff --git a/service/channel_affinity.go b/service/channel_affinity.go index a94eb29e8..fe1524c59 100644 --- a/service/channel_affinity.go +++ b/service/channel_affinity.go @@ -288,7 +288,11 @@ func extractChannelAffinityValue(c *gin.Context, src operation_setting.ChannelAf if src.Path == "" { return "" } - body, err := common.GetRequestBody(c) + storage, err := common.GetBodyStorage(c) + if err != nil { + return "" + } + body, err := storage.Bytes() if err != nil || len(body) == 0 { return "" }