mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 04:03:18 +00:00
272 lines
6.2 KiB
Go
272 lines
6.2 KiB
Go
package common
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"mime"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/QuantumNous/new-api/constant"
|
|
"github.com/pkg/errors"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// KeyRequestBody is the gin context key under which GetRequestBody caches the raw request body bytes.
const KeyRequestBody = "key_request_body"
|
|
|
|
// ErrRequestBodyTooLarge is the sentinel error reported when a request body exceeds the configured size limit.
var ErrRequestBodyTooLarge = errors.New("request body too large")

// IsRequestBodyTooLargeError reports whether err signals an oversized request body.
// It returns true when err's chain contains ErrRequestBodyTooLarge or an
// *http.MaxBytesError, and false for a nil error or anything else.
func IsRequestBodyTooLargeError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, ErrRequestBodyTooLarge):
		return true
	}
	var tooBig *http.MaxBytesError
	return errors.As(err, &tooBig)
}
|
|
|
|
func GetRequestBody(c *gin.Context) ([]byte, error) {
|
|
cached, exists := c.Get(KeyRequestBody)
|
|
if exists && cached != nil {
|
|
if b, ok := cached.([]byte); ok {
|
|
return b, nil
|
|
}
|
|
}
|
|
maxMB := constant.MaxRequestBodyMB
|
|
if maxMB <= 0 {
|
|
// no limit
|
|
body, err := io.ReadAll(c.Request.Body)
|
|
_ = c.Request.Body.Close()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c.Set(KeyRequestBody, body)
|
|
return body, nil
|
|
}
|
|
maxBytes := int64(maxMB) << 20
|
|
|
|
limited := io.LimitReader(c.Request.Body, maxBytes+1)
|
|
body, err := io.ReadAll(limited)
|
|
if err != nil {
|
|
_ = c.Request.Body.Close()
|
|
if IsRequestBodyTooLargeError(err) {
|
|
return nil, errors.Wrap(ErrRequestBodyTooLarge, fmt.Sprintf("request body exceeds %d MB", maxMB))
|
|
}
|
|
return nil, err
|
|
}
|
|
_ = c.Request.Body.Close()
|
|
if int64(len(body)) > maxBytes {
|
|
return nil, errors.Wrap(ErrRequestBodyTooLarge, fmt.Sprintf("request body exceeds %d MB", maxMB))
|
|
}
|
|
c.Set(KeyRequestBody, body)
|
|
return body, nil
|
|
}
|
|
|
|
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
|
requestBody, err := GetRequestBody(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
//if DebugEnabled {
|
|
// println("UnmarshalBodyReusable request body:", string(requestBody))
|
|
//}
|
|
contentType := c.Request.Header.Get("Content-Type")
|
|
if strings.HasPrefix(contentType, "application/json") {
|
|
err = Unmarshal(requestBody, v)
|
|
} else if strings.Contains(contentType, gin.MIMEPOSTForm) {
|
|
err = parseFormData(requestBody, v)
|
|
} else if strings.Contains(contentType, gin.MIMEMultipartPOSTForm) {
|
|
err = parseMultipartFormData(c, requestBody, v)
|
|
} else {
|
|
// skip for now
|
|
// TODO: someday non json request have variant model, we will need to implementation this
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Reset request body
|
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
return nil
|
|
}
|
|
|
|
func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
|
|
c.Set(string(key), value)
|
|
}
|
|
|
|
func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
|
|
return c.Get(string(key))
|
|
}
|
|
|
|
func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
|
|
return c.GetString(string(key))
|
|
}
|
|
|
|
func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
|
|
return c.GetInt(string(key))
|
|
}
|
|
|
|
func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
|
|
return c.GetBool(string(key))
|
|
}
|
|
|
|
func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
|
|
return c.GetStringSlice(string(key))
|
|
}
|
|
|
|
func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
|
|
return c.GetStringMap(string(key))
|
|
}
|
|
|
|
func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
|
|
return c.GetTime(string(key))
|
|
}
|
|
|
|
func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
|
|
if value, ok := c.Get(string(key)); ok {
|
|
if v, ok := value.(T); ok {
|
|
return v, true
|
|
}
|
|
}
|
|
var t T
|
|
return t, false
|
|
}
|
|
|
|
func ApiError(c *gin.Context, err error) {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": err.Error(),
|
|
})
|
|
}
|
|
|
|
func ApiErrorMsg(c *gin.Context, msg string) {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": msg,
|
|
})
|
|
}
|
|
|
|
func ApiSuccess(c *gin.Context, data any) {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "",
|
|
"data": data,
|
|
})
|
|
}
|
|
|
|
func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
|
|
requestBody, err := GetRequestBody(c)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
contentType := c.Request.Header.Get("Content-Type")
|
|
boundary, err := parseBoundary(contentType)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
reader := multipart.NewReader(bytes.NewReader(requestBody), boundary)
|
|
form, err := reader.ReadForm(multipartMemoryLimit())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Reset request body
|
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
return form, nil
|
|
}
|
|
|
|
func processFormMap(formMap map[string]any, v any) error {
|
|
jsonData, err := Marshal(formMap)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = Unmarshal(jsonData, v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func parseFormData(data []byte, v any) error {
|
|
values, err := url.ParseQuery(string(data))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
formMap := make(map[string]any)
|
|
for key, vals := range values {
|
|
if len(vals) == 1 {
|
|
formMap[key] = vals[0]
|
|
} else {
|
|
formMap[key] = vals
|
|
}
|
|
}
|
|
|
|
return processFormMap(formMap, v)
|
|
}
|
|
|
|
func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
|
|
contentType := c.Request.Header.Get("Content-Type")
|
|
boundary, err := parseBoundary(contentType)
|
|
if err != nil {
|
|
if errors.Is(err, errBoundaryNotFound) {
|
|
return Unmarshal(data, v) // Fallback to JSON
|
|
}
|
|
return err
|
|
}
|
|
|
|
reader := multipart.NewReader(bytes.NewReader(data), boundary)
|
|
form, err := reader.ReadForm(multipartMemoryLimit())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer form.RemoveAll()
|
|
formMap := make(map[string]any)
|
|
for key, vals := range form.Value {
|
|
if len(vals) == 1 {
|
|
formMap[key] = vals[0]
|
|
} else {
|
|
formMap[key] = vals
|
|
}
|
|
}
|
|
|
|
return processFormMap(formMap, v)
|
|
}
|
|
|
|
// errBoundaryNotFound is returned by parseBoundary when the Content-Type value
// is empty or carries no non-empty "boundary" parameter.
var errBoundaryNotFound = errors.New("multipart boundary not found")

// parseBoundary extracts the multipart boundary from a Content-Type header
// value using mime.ParseMediaType (e.g. Boundary-UUID / boundary-------xxxxxx).
// It returns errBoundaryNotFound for an empty header or a missing/empty
// boundary parameter, and the mime parsing error for a malformed header.
func parseBoundary(contentType string) (string, error) {
	if len(contentType) == 0 {
		return "", errBoundaryNotFound
	}

	_, params, err := mime.ParseMediaType(contentType)
	if err != nil {
		return "", err
	}

	// A missing map entry yields "", so one check covers absent and empty.
	if b := params["boundary"]; b != "" {
		return b, nil
	}
	return "", errBoundaryNotFound
}
|
|
|
|
// multipartMemoryLimit returns the configured multipart memory limit in bytes
|
|
func multipartMemoryLimit() int64 {
|
|
limitMB := constant.MaxFileDownloadMB
|
|
if limitMB <= 0 {
|
|
limitMB = 32
|
|
}
|
|
return int64(limitMB) << 20
|
|
}
|