mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 00:46:42 +00:00
Merge branch 'upstream-main' into feature/pyro
This commit is contained in:
@@ -71,15 +71,66 @@ func getMP3Duration(r io.Reader) (float64, error) {
|
||||
|
||||
// getWAVDuration 解析 WAV 文件头以获取时长。
|
||||
func getWAVDuration(r io.ReadSeeker) (float64, error) {
|
||||
// 1. 强制复位指针
|
||||
r.Seek(0, io.SeekStart)
|
||||
|
||||
dec := wav.NewDecoder(r)
|
||||
|
||||
// IsValidFile 会读取 fmt 块
|
||||
if !dec.IsValidFile() {
|
||||
return 0, errors.New("invalid wav file")
|
||||
}
|
||||
d, err := dec.Duration()
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to get wav duration")
|
||||
|
||||
// 尝试寻找 data 块
|
||||
if err := dec.FwdToPCM(); err != nil {
|
||||
return 0, errors.Wrap(err, "failed to find PCM data chunk")
|
||||
}
|
||||
return d.Seconds(), nil
|
||||
|
||||
pcmSize := int64(dec.PCMSize)
|
||||
|
||||
// 如果读出来的 Size 是 0,尝试用文件大小反推
|
||||
if pcmSize == 0 {
|
||||
// 获取文件总大小
|
||||
currentPos, _ := r.Seek(0, io.SeekCurrent) // 当前通常在 data chunk header 之后
|
||||
endPos, _ := r.Seek(0, io.SeekEnd)
|
||||
fileSize := endPos
|
||||
|
||||
// 恢复位置(虽然如果不继续读也没关系)
|
||||
r.Seek(currentPos, io.SeekStart)
|
||||
|
||||
// 数据区大小 ≈ 文件总大小 - 当前指针位置(即Header大小)
|
||||
// 注意:FwdToPCM 成功后,CurrentPos 应该刚好指向 Data 区数据的开始
|
||||
// 或者是 Data Chunk ID + Size 之后。
|
||||
// WAV Header 一般 44 字节。
|
||||
if fileSize > 44 {
|
||||
// 如果 FwdToPCM 成功,Reader 应该位于 data 块的数据起始处
|
||||
// 所以剩余的所有字节理论上都是音频数据
|
||||
pcmSize = fileSize - currentPos
|
||||
|
||||
// 简单的兜底:如果算出来还是负数或0,强制按文件大小-44计算
|
||||
if pcmSize <= 0 {
|
||||
pcmSize = fileSize - 44
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
numChans := int64(dec.NumChans)
|
||||
bitDepth := int64(dec.BitDepth)
|
||||
sampleRate := float64(dec.SampleRate)
|
||||
|
||||
if sampleRate == 0 || numChans == 0 || bitDepth == 0 {
|
||||
return 0, errors.New("invalid wav header metadata")
|
||||
}
|
||||
|
||||
bytesPerFrame := numChans * (bitDepth / 8)
|
||||
if bytesPerFrame == 0 {
|
||||
return 0, errors.New("invalid byte depth calculation")
|
||||
}
|
||||
|
||||
totalFrames := pcmSize / bytesPerFrame
|
||||
|
||||
durationSeconds := float64(totalFrames) / sampleRate
|
||||
return durationSeconds, nil
|
||||
}
|
||||
|
||||
// getFLACDuration 解析 FLAC 文件的 STREAMINFO 块。
|
||||
|
||||
@@ -2,7 +2,7 @@ package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
@@ -12,24 +12,61 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const KeyRequestBody = "key_request_body"
|
||||
|
||||
func GetRequestBody(c *gin.Context) ([]byte, error) {
|
||||
requestBody, _ := c.Get(KeyRequestBody)
|
||||
if requestBody != nil {
|
||||
return requestBody.([]byte), nil
|
||||
var ErrRequestBodyTooLarge = errors.New("request body too large")
|
||||
|
||||
func IsRequestBodyTooLargeError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
requestBody, err := io.ReadAll(c.Request.Body)
|
||||
if errors.Is(err, ErrRequestBodyTooLarge) {
|
||||
return true
|
||||
}
|
||||
var mbe *http.MaxBytesError
|
||||
return errors.As(err, &mbe)
|
||||
}
|
||||
|
||||
func GetRequestBody(c *gin.Context) ([]byte, error) {
|
||||
cached, exists := c.Get(KeyRequestBody)
|
||||
if exists && cached != nil {
|
||||
if b, ok := cached.([]byte); ok {
|
||||
return b, nil
|
||||
}
|
||||
}
|
||||
maxMB := constant.MaxRequestBodyMB
|
||||
if maxMB < 0 {
|
||||
// no limit
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
_ = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Set(KeyRequestBody, body)
|
||||
return body, nil
|
||||
}
|
||||
maxBytes := int64(maxMB) << 20
|
||||
|
||||
limited := io.LimitReader(c.Request.Body, maxBytes+1)
|
||||
body, err := io.ReadAll(limited)
|
||||
if err != nil {
|
||||
_ = c.Request.Body.Close()
|
||||
if IsRequestBodyTooLargeError(err) {
|
||||
return nil, errors.Wrap(ErrRequestBodyTooLarge, fmt.Sprintf("request body exceeds %d MB", maxMB))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_ = c.Request.Body.Close()
|
||||
c.Set(KeyRequestBody, requestBody)
|
||||
return requestBody.([]byte), nil
|
||||
if int64(len(body)) > maxBytes {
|
||||
return nil, errors.Wrap(ErrRequestBodyTooLarge, fmt.Sprintf("request body exceeds %d MB", maxMB))
|
||||
}
|
||||
c.Set(KeyRequestBody, body)
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
|
||||
@@ -117,6 +117,8 @@ func initConstantEnv() {
|
||||
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
constant.StreamScannerMaxBufferMB = GetEnvOrDefault("STREAM_SCANNER_MAX_BUFFER_MB", 64)
|
||||
// MaxRequestBodyMB 请求体最大大小(解压后),用于防止超大请求/zip bomb导致内存暴涨
|
||||
constant.MaxRequestBodyMB = GetEnvOrDefault("MAX_REQUEST_BODY_MB", 64)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
constant.CountToken = GetEnvOrDefaultBool("CountToken", true)
|
||||
|
||||
29
common/ip.go
29
common/ip.go
@@ -2,6 +2,15 @@ package common
|
||||
|
||||
import "net"
|
||||
|
||||
func IsIP(s string) bool {
|
||||
ip := net.ParseIP(s)
|
||||
return ip != nil
|
||||
}
|
||||
|
||||
func ParseIP(s string) net.IP {
|
||||
return net.ParseIP(s)
|
||||
}
|
||||
|
||||
func IsPrivateIP(ip net.IP) bool {
|
||||
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
return true
|
||||
@@ -20,3 +29,23 @@ func IsPrivateIP(ip net.IP) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func IsIpInCIDRList(ip net.IP, cidrList []string) bool {
|
||||
for _, cidr := range cidrList {
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
// 尝试作为单个IP处理
|
||||
if whitelistIP := net.ParseIP(cidr); whitelistIP != nil {
|
||||
if ip.Equal(whitelistIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -186,23 +186,7 @@ func isIPListed(ip net.IP, list []string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, whitelistCIDR := range list {
|
||||
_, network, err := net.ParseCIDR(whitelistCIDR)
|
||||
if err != nil {
|
||||
// 尝试作为单个IP处理
|
||||
if whitelistIP := net.ParseIP(whitelistCIDR); whitelistIP != nil {
|
||||
if ip.Equal(whitelistIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return IsIpInCIDRList(ip, list)
|
||||
}
|
||||
|
||||
// IsIPAccessAllowed 检查IP是否允许访问
|
||||
|
||||
@@ -217,11 +217,6 @@ func IntMax(a int, b int) int {
|
||||
}
|
||||
}
|
||||
|
||||
func IsIP(s string) bool {
|
||||
ip := net.ParseIP(s)
|
||||
return ip != nil
|
||||
}
|
||||
|
||||
func GetUUID() string {
|
||||
code := uuid.New().String()
|
||||
code = strings.Replace(code, "-", "", -1)
|
||||
|
||||
Reference in New Issue
Block a user