From 2b70095b475c205c23560aae3f9a53b7151460c2 Mon Sep 17 00:00:00 2001 From: CaIon Date: Tue, 28 Oct 2025 15:50:45 +0800 Subject: [PATCH] feat: implement audio duration retrieval without ffmpeg dependencies --- Dockerfile | 2 +- common/audio.go | 295 +++++++++++++++++++++++++++ common/gin.go | 8 +- common/utils.go | 39 ---- go.mod | 13 ++ go.sum | 40 ++++ middleware/distributor.go | 77 ++++--- relay/channel/openai/relay-openai.go | 54 +---- relay/helper/valid_request.go | 10 - service/token_counter.go | 31 ++- 10 files changed, 430 insertions(+), 139 deletions(-) create mode 100644 common/audio.go diff --git a/Dockerfile b/Dockerfile index 89f1bc75a..c7348add8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,7 +28,7 @@ RUN go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$ FROM alpine RUN apk upgrade --no-cache \ - && apk add --no-cache ca-certificates tzdata ffmpeg \ + && apk add --no-cache ca-certificates tzdata \ && update-ca-certificates COPY --from=builder2 /build/new-api / diff --git a/common/audio.go b/common/audio.go new file mode 100644 index 000000000..e41b61653 --- /dev/null +++ b/common/audio.go @@ -0,0 +1,295 @@ +package common + +import ( + "context" + "encoding/binary" + "fmt" + "io" + + "github.com/abema/go-mp4" + "github.com/go-audio/aiff" + "github.com/go-audio/wav" + "github.com/jfreymuth/oggvorbis" + "github.com/mewkiz/flac" + "github.com/pkg/errors" + "github.com/tcolgate/mp3" + "github.com/yapingcat/gomedia/go-codec" +) + +// GetAudioDuration 使用纯 Go 库获取音频文件的时长(秒)。 +// 它不再依赖外部的 ffmpeg 或 ffprobe 程序。 +func GetAudioDuration(ctx context.Context, f io.ReadSeeker, ext string) (duration float64, err error) { + SysLog(fmt.Sprintf("GetAudioDuration: ext=%s", ext)) + // 根据文件扩展名选择解析器 + switch ext { + case ".mp3": + duration, err = getMP3Duration(f) + case ".wav": + duration, err = getWAVDuration(f) + case ".flac": + duration, err = getFLACDuration(f) + case ".m4a", ".mp4": + duration, err = getM4ADuration(f) + case ".ogg", ".oga": + duration, err = getOGGDuration(f) + case ".opus": + duration, err = getOpusDuration(f) + case ".aiff", ".aif", ".aifc": + duration, err = getAIFFDuration(f) + case ".webm": + duration, err = getWebMDuration(f) + case ".aac": + duration, err = getAACDuration(f) + default: + return 0, fmt.Errorf("unsupported audio format: %s", ext) + } + SysLog(fmt.Sprintf("GetAudioDuration: duration=%f", duration)) + return duration, err +} + +// getMP3Duration 解析 MP3 文件以获取时长。 +// 注意:对于 VBR (Variable Bitrate) MP3,这个估算可能不完全精确,但通常足够好。 +// FFmpeg 在这种情况下会扫描整个文件来获得精确值,但这里的库提供了快速估算。 +func getMP3Duration(r io.Reader) (float64, error) { + d := mp3.NewDecoder(r) + var f mp3.Frame + skipped := 0 + duration := 0.0 + + for { + if err := d.Decode(&f, &skipped); err != nil { + if err == io.EOF { + break + } + return 0, errors.Wrap(err, "failed to decode mp3 frame") + } + duration += f.Duration().Seconds() + } + return duration, nil +} + +// getWAVDuration 解析 WAV 文件头以获取时长。 +func getWAVDuration(r io.ReadSeeker) (float64, error) { + dec := wav.NewDecoder(r) + if !dec.IsValidFile() { + return 0, errors.New("invalid wav file") + } + d, err := dec.Duration() + if err != nil { + return 0, errors.Wrap(err, "failed to get wav duration") + } + return d.Seconds(), nil +} + +// getFLACDuration 解析 FLAC 文件的 STREAMINFO 块。 +func getFLACDuration(r io.Reader) (float64, error) { + stream, err := flac.Parse(r) + if err != nil { + return 0, errors.Wrap(err, "failed to parse flac stream") + } + defer stream.Close() + + // 时长 = 总采样数 / 采样率 + duration := float64(stream.Info.NSamples) / float64(stream.Info.SampleRate) + return duration, nil +} + +// getM4ADuration 解析 M4A/MP4 文件的 'mvhd' box。 +func getM4ADuration(r io.ReadSeeker) (float64, error) { + // go-mp4 库需要 ReadSeeker 接口 + info, err := mp4.Probe(r) + if err != nil { + return 0, errors.Wrap(err, "failed to probe m4a/mp4 file") + } + // 时长 = Duration / Timescale + return float64(info.Duration) / float64(info.Timescale), nil +} + +// getOGGDuration 解析 OGG/Vorbis 文件以获取时长。 +func getOGGDuration(r io.ReadSeeker) (float64, error) { + // 重置 reader 到开头 + if _, err := r.Seek(0, io.SeekStart); err != nil { + return 0, errors.Wrap(err, "failed to seek ogg file") + } + + reader, err := oggvorbis.NewReader(r) + if err != nil { + return 0, errors.Wrap(err, "failed to create ogg vorbis reader") + } + + // 计算时长 = 总采样数 / 采样率 + // 需要读取整个文件来获取总采样数 + channels := reader.Channels() + sampleRate := reader.SampleRate() + + // 估算方法:读取到文件结尾 + var totalSamples int64 + buf := make([]float32, 4096*channels) + for { + n, err := reader.Read(buf) + if err == io.EOF { + break + } + if err != nil { + return 0, errors.Wrap(err, "failed to read ogg samples") + } + totalSamples += int64(n / channels) + } + + duration := float64(totalSamples) / float64(sampleRate) + return duration, nil +} + +// getOpusDuration 解析 Opus 文件(在 OGG 容器中)以获取时长。 +func getOpusDuration(r io.ReadSeeker) (float64, error) { + // Opus 通常封装在 OGG 容器中 + // 我们需要解析 OGG 页面来获取时长信息 + if _, err := r.Seek(0, io.SeekStart); err != nil { + return 0, errors.Wrap(err, "failed to seek opus file") + } + + // 读取 OGG 页面头部 + var totalGranulePos int64 + buf := make([]byte, 27) // OGG 页面头部最小大小 + + for { + n, err := r.Read(buf) + if err == io.EOF { + break + } + if err != nil { + return 0, errors.Wrap(err, "failed to read opus/ogg page") + } + if n < 27 { + break + } + + // 检查 OGG 页面标识 "OggS" + if string(buf[0:4]) != "OggS" { + // 跳过一些字节继续寻找 + if _, err := r.Seek(-26, io.SeekCurrent); err != nil { + break + } + continue + } + + // 读取 granule position (字节 6-13, 小端序) + granulePos := int64(binary.LittleEndian.Uint64(buf[6:14])) + if granulePos > totalGranulePos { + totalGranulePos = granulePos + } + + // 读取段表大小 + numSegments := int(buf[26]) + segmentTable := make([]byte, numSegments) + if _, err := io.ReadFull(r, segmentTable); err != nil { + break + } + + // 计算页面数据大小并跳过 + var pageSize int + for _, segSize := range segmentTable { + pageSize += int(segSize) + } + if _, err := r.Seek(int64(pageSize), io.SeekCurrent); err != nil { + break + } + } + + // Opus 的采样率固定为 48000 Hz + duration := float64(totalGranulePos) / 48000.0 + return duration, nil +} + +// getAIFFDuration 解析 AIFF 文件头以获取时长。 +func getAIFFDuration(r io.ReadSeeker) (float64, error) { + if _, err := r.Seek(0, io.SeekStart); err != nil { + return 0, errors.Wrap(err, "failed to seek aiff file") + } + + dec := aiff.NewDecoder(r) + if !dec.IsValidFile() { + return 0, errors.New("invalid aiff file") + } + + d, err := dec.Duration() + if err != nil { + return 0, errors.Wrap(err, "failed to get aiff duration") + } + + return d.Seconds(), nil +} + +// getWebMDuration 解析 WebM 文件以获取时长。 +// WebM 使用 Matroska 容器格式 +func getWebMDuration(r io.ReadSeeker) (float64, error) { + if _, err := r.Seek(0, io.SeekStart); err != nil { + return 0, errors.Wrap(err, "failed to seek webm file") + } + + // WebM/Matroska 文件的解析比较复杂 + // 这里提供一个简化的实现,读取 EBML 头部 + // 对于完整的 WebM 解析,可能需要使用专门的库 + + // 简单实现:查找 Duration 元素 + // WebM Duration 的 Element ID 是 0x4489 + // 这是一个简化版本,可能不适用于所有 WebM 文件 + buf := make([]byte, 8192) + n, err := r.Read(buf) + if err != nil && err != io.EOF { + return 0, errors.Wrap(err, "failed to read webm file") + } + + // 尝试查找 Duration 元素(这是一个简化的方法) + // 实际的 WebM 解析需要完整的 EBML 解析器 + // 这里返回错误,建议使用专门的库 + if n > 0 { + // 检查 EBML 标识 + if len(buf) >= 4 && binary.BigEndian.Uint32(buf[0:4]) == 0x1A45DFA3 { + // 这是一个有效的 EBML 文件 + // 但完整解析需要更复杂的逻辑 + return 0, errors.New("webm duration parsing requires full EBML parser (consider using ffprobe for webm files)") + } + } + + return 0, errors.New("failed to parse webm file") +} + +// getAACDuration 解析 AAC (ADTS格式) 文件以获取时长。 +// 使用 gomedia 库来解析 AAC ADTS 帧 +func getAACDuration(r io.ReadSeeker) (float64, error) { + if _, err := r.Seek(0, io.SeekStart); err != nil { + return 0, errors.Wrap(err, "failed to seek aac file") + } + + // 读取整个文件内容 + data, err := io.ReadAll(r) + if err != nil { + return 0, errors.Wrap(err, "failed to read aac file") + } + + var totalFrames int64 + var sampleRate int + + // 使用 gomedia 的 SplitAACFrame 函数来分割 AAC 帧 + codec.SplitAACFrame(data, func(aac []byte) { + // 解析 ADTS 头部以获取采样率信息 + if len(aac) >= 7 { + // 使用 ConvertADTSToASC 来获取音频配置信息 + asc, err := codec.ConvertADTSToASC(aac) + if err == nil && sampleRate == 0 { + sampleRate = codec.AACSampleIdxToSample(int(asc.Sample_freq_index)) + } + totalFrames++ + } + }) + + if sampleRate == 0 || totalFrames == 0 { + return 0, errors.New("no valid aac frames found") + } + + // 每个 AAC ADTS 帧包含 1024 个采样 + totalSamples := totalFrames * 1024 + duration := float64(totalSamples) / float64(sampleRate) + return duration, nil +} diff --git a/common/gin.go b/common/gin.go index cc83a5f98..e8d8bda3a 100644 --- a/common/gin.go +++ b/common/gin.go @@ -163,7 +163,7 @@ func parseFormData(data []byte, v any) error { return err } - return json.Unmarshal(jsonData, v) + return Unmarshal(jsonData, v) } func parseMultipartFormData(c *gin.Context, data []byte, v any) error { @@ -174,7 +174,7 @@ func parseMultipartFormData(c *gin.Context, data []byte, v any) error { } if boundary == "" { - return json.Unmarshal(data, v) // Fallback to JSON + return Unmarshal(data, v) // Fallback to JSON } reader := multipart.NewReader(bytes.NewReader(data), boundary) @@ -191,10 +191,10 @@ func parseMultipartFormData(c *gin.Context, data []byte, v any) error { formMap[key] = vals } } - jsonData, err := json.Marshal(formMap) + jsonData, err := Marshal(formMap) if err != nil { return err } - return json.Unmarshal(jsonData, v) + return Unmarshal(jsonData, v) } diff --git a/common/utils.go b/common/utils.go index 21f72ec6a..3492f7e44 100644 --- a/common/utils.go +++ b/common/utils.go @@ -1,8 +1,6 @@ package common import ( - "bytes" - "context" crand "crypto/rand" "encoding/base64" "encoding/json" @@ -329,43 +327,6 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) { return f.Name(), nil } -// GetAudioDuration returns the duration of an audio file in seconds. -func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) { - // ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}} - c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename) - output, err := c.Output() - if err != nil { - return 0, errors.Wrap(err, "failed to get audio duration") - } - durationStr := string(bytes.TrimSpace(output)) - if durationStr == "N/A" { - // Create a temporary output file name - tmpFp, err := os.CreateTemp("", "audio-*"+ext) - if err != nil { - return 0, errors.Wrap(err, "failed to create temporary file") - } - tmpName := tmpFp.Name() - // Close immediately so ffmpeg can open the file on Windows. - _ = tmpFp.Close() - defer os.Remove(tmpName) - - // ffmpeg -y -i filename -vcodec copy -acodec copy - ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName) - if err := ffmpegCmd.Run(); err != nil { - return 0, errors.Wrap(err, "failed to run ffmpeg") - } - - // Recalculate the duration of the new file - c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName) - output, err := c.Output() - if err != nil { - return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg") - } - durationStr = string(bytes.TrimSpace(output)) - } - return strconv.ParseFloat(durationStr, 64) -} - // BuildURL concatenates base and endpoint, returns the complete url string func BuildURL(base string, endpoint string) string { u, err := url.Parse(base) diff --git a/go.mod b/go.mod index b15bbadb2..87c494370 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.25.1 require ( github.com/Calcium-Ion/go-epay v0.0.4 + github.com/abema/go-mp4 v1.4.1 github.com/andybalholm/brotli v1.1.1 github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 github.com/aws/aws-sdk-go-v2 v1.37.2 @@ -18,24 +19,30 @@ require ( github.com/gin-contrib/static v0.0.1 github.com/gin-gonic/gin v1.9.1 github.com/glebarez/sqlite v1.9.0 + github.com/go-audio/aiff v1.1.0 + github.com/go-audio/wav v1.1.0 github.com/go-playground/validator/v10 v10.20.0 github.com/go-redis/redis/v8 v8.11.5 github.com/go-webauthn/webauthn v0.14.0 github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.0 + github.com/jfreymuth/oggvorbis v1.0.5 github.com/jinzhu/copier v0.4.0 github.com/joho/godotenv v1.5.1 + github.com/mewkiz/flac v1.0.13 github.com/pkg/errors v0.9.1 github.com/pquerna/otp v1.5.0 github.com/samber/lo v1.39.0 github.com/shirou/gopsutil v3.21.11+incompatible github.com/shopspring/decimal v1.4.0 github.com/stripe/stripe-go/v81 v81.4.0 + github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300 github.com/thanhpk/randstr v1.0.6 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/tiktoken-go/tokenizer v0.6.2 + github.com/yapingcat/gomedia v0.0.0-20240906162731-17feea57090c golang.org/x/crypto v0.42.0 golang.org/x/image v0.23.0 golang.org/x/net v0.43.0 @@ -62,6 +69,8 @@ require ( github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/glebarez/go-sqlite v1.21.2 // indirect + github.com/go-audio/audio v1.0.0 // indirect + github.com/go-audio/riff v1.0.0 // indirect github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect @@ -73,16 +82,20 @@ require ( github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect github.com/gorilla/sessions v1.2.1 // indirect + github.com/icza/bitio v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgx/v5 v5.7.1 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/jfreymuth/vorbis v1.0.2 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mewkiz/pkg v0.0.0-20250417130911-3f050ff8c56d // indirect + github.com/mewpkg/term v0.0.0-20241026122259-37a80af23985 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect diff --git a/go.sum b/go.sum index bd6bae025..011c86994 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A= github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U= +github.com/abema/go-mp4 v1.4.1 h1:YoS4VRqd+pAmddRPLFf8vMk74kuGl6ULSjzhsIqwr6M= +github.com/abema/go-mp4 v1.4.1/go.mod h1:vPl9t5ZK7K0x68jh12/+ECWBCXoWuIDtNgPtU2f04ws= github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs= @@ -33,6 +35,7 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -67,6 +70,15 @@ github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9g github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k= github.com/glebarez/sqlite v1.9.0 h1:Aj6bPA12ZEx5GbSF6XADmCkYXlljPNUY+Zf1EQxynXs= github.com/glebarez/sqlite v1.9.0/go.mod h1:YBYCoyupOao60lzp1MVBLEjZfgkq0tdB1voAQ09K9zw= +github.com/go-audio/aiff v1.1.0 h1:m2LYgu/2BarpF2yZnFPWtY3Tp41k0A4y51gDRZZsEuU= +github.com/go-audio/aiff v1.1.0/go.mod h1:sDik1muYvhPiccClfri0fv6U2fyH/dy4VRWmUz0cz9Q= +github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4= +github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs= +github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA= +github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498= +github.com/go-audio/wav v1.0.0/go.mod h1:3yoReyQOsiARkvPl3ERCi8JFjihzG6WhjYpZCf5zAWE= +github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g= +github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= @@ -108,6 +120,7 @@ github.com/google/go-tpm v0.9.5/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= @@ -118,6 +131,10 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/icza/bitio v1.1.0 h1:ysX4vtldjdi3Ygai5m1cWy4oLkhWTAi+SyO6HC8L9T0= +github.com/icza/bitio v1.1.0/go.mod h1:0jGnlLAx8MKMr9VGnn/4YrvZiprkvBelsVIbA9Jjr9A= +github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6 h1:8UsGZ2rr2ksmEru6lToqnXgA8Mz1DP11X4zSJ159C3k= +github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6/go.mod h1:xQig96I1VNBDIWGCdTt54nHt6EeI639SmHycLYL7FkA= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -126,6 +143,10 @@ github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs= github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jfreymuth/oggvorbis v1.0.5 h1:u+Ck+R0eLSRhgq8WTmffYnrVtSztJcYrl588DM4e3kQ= +github.com/jfreymuth/oggvorbis v1.0.5/go.mod h1:1U4pqWmghcoVsCJJ4fRBKv9peUJMBHixthRlBeD6uII= +github.com/jfreymuth/vorbis v1.0.2 h1:m1xH6+ZI4thH927pgKD8JOH4eaGRm18rEE9/0WKjvNE= +github.com/jfreymuth/vorbis v1.0.2/go.mod h1:DoftRo4AznKnShRl1GxiTFCseHr4zR9BN3TWXyuzrqQ= github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= @@ -145,6 +166,7 @@ github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfn github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= @@ -152,10 +174,17 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/mattetti/audio v0.0.0-20180912171649-01576cde1f21/go.mod h1:LlQmBGkOuV/SKzEDXBPKauvN2UqCgzXO2XjecTGj40s= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mewkiz/flac v1.0.13 h1:6wF8rRQKBFW159Daqx6Ro7K5ZnlVhHUKfS5aTsC4oXs= +github.com/mewkiz/flac v1.0.13/go.mod h1:HfPYDA+oxjyuqMu2V+cyKcxF51KM6incpw5eZXmfA6k= +github.com/mewkiz/pkg v0.0.0-20250417130911-3f050ff8c56d h1:IL2tii4jXLdhCeQN69HNzYYW1kl0meSG0wt5+sLwszU= +github.com/mewkiz/pkg v0.0.0-20250417130911-3f050ff8c56d/go.mod h1:SIpumAnUWSy0q9RzKD3pyH3g1t5vdawUAPcW5tQrUtI= +github.com/mewpkg/term v0.0.0-20241026122259-37a80af23985 h1:h8O1byDZ1uk6RUXMhj1QJU3VXFKXHDZxr4TXRPGeBa8= +github.com/mewpkg/term v0.0.0-20241026122259-37a80af23985/go.mod h1:uiPmbdUbdt1NkGApKl7htQjZ8S7XaGUAVulJUJ9v6q4= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -170,6 +199,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs= +github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e h1:s2RNOM/IGdY0Y6qfTeUKhDawdHDpK9RGBdx80qN4Ttw= +github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e/go.mod h1:nBdnFKj15wFbf94Rwfq4m30eAcyY9V/IyKAGQFtqkW0= github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pelletier/go-toml/v2 v2.2.1 h1:9TA9+T8+8CUCO2+WYnDLCgrYi9+omqKXyjDtosvtEhg= github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= @@ -209,6 +240,9 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJUzCLbw= github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo= +github.com/sunfish-shogi/bufseekio v0.0.0-20210207115823-a4185644b365/go.mod h1:dEzdXgvImkQ3WLI+0KQpmEx8T/C/ma9KeS3AfmU899I= +github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300 h1:XQdibLKagjdevRB6vAjVY4qbSr8rQ610YzTkWcxzxSI= +github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300/go.mod h1:FNa/dfN95vAYCNFrIKRrlRo+MBLbwmR9Asa5f2ljmBI= github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o= github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -238,6 +272,8 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yapingcat/gomedia v0.0.0-20240906162731-17feea57090c h1:xA2TJS9Hu/ivzaZIrDcwvpJ3Fnpsk5fDOJ4iSnL6J0w= +github.com/yapingcat/gomedia v0.0.0-20240906162731-17feea57090c/go.mod h1:WSZ59bidJOO40JSJmLqlkBJrjZCtjbKKkygEMfzY/kc= github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw= github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= @@ -257,6 +293,7 @@ golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -270,6 +307,7 @@ golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -286,6 +324,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/src-d/go-billy.v4 v4.3.2 h1:0SQA1pRztfTFx2miS8sA97XvooFeNOmvUenF4o0EcVg= +gopkg.in/src-d/go-billy.v4 v4.3.2/go.mod h1:nDjArDMp+XMs1aFAESLRjfGSgfvoYN0hDfzEk0GjC98= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/middleware/distributor.go b/middleware/distributor.go index 9faee8013..598d086ec 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -86,7 +86,7 @@ func Distribute() func(c *gin.Context) { playgroundRequest := &dto.PlayGroundRequest{} err = common.UnmarshalBodyReusable(c, playgroundRequest) if err != nil { - abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) + abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的playground请求, "+err.Error()) return } if playgroundRequest.Group != "" { @@ -124,6 +124,20 @@ func Distribute() func(c *gin.Context) { } } +// getModelFromRequest 从请求中读取模型信息 +// 根据 Content-Type 自动处理: +// - application/json +// - application/x-www-form-urlencoded +// - multipart/form-data +func getModelFromRequest(c *gin.Context) (*ModelRequest, error) { + var modelRequest ModelRequest + err := common.UnmarshalBodyReusable(c, &modelRequest) + if err != nil { + return nil, errors.New("无效的请求, " + err.Error()) + } + return &modelRequest, nil +} + func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { var modelRequest ModelRequest shouldSelectChannel := true @@ -139,7 +153,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { midjourneyRequest := dto.MidjourneyRequest{} err = common.UnmarshalBodyReusable(c, &midjourneyRequest) if err != nil { - return nil, false, err + return nil, false, errors.New("无效的midjourney请求, " + err.Error()) } midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest) if mjErr != nil { @@ -176,23 +190,12 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { relayMode := relayconstant.RelayModeUnknown if c.Request.Method == http.MethodPost { relayMode = relayconstant.RelayModeVideoSubmit - contentType := c.Request.Header.Get("Content-Type") - if strings.HasPrefix(contentType, "multipart/form-data") { - form, err := common.ParseMultipartFormReusable(c) - if err != nil { - return nil, false, errors.New("无效的video请求, " + err.Error()) - } - defer form.RemoveAll() - if form != nil { - if values, ok := form.Value["model"]; ok && len(values) > 0 { - modelRequest.Model = values[0] - } - } - } else if strings.HasPrefix(contentType, "application/json") { - err = common.UnmarshalBodyReusable(c, &modelRequest) - if err != nil { - return nil, false, errors.New("无效的video请求, " + err.Error()) - } + req, err := getModelFromRequest(c) + if err != nil { + return nil, false, err + } + if req != nil { + modelRequest.Model = req.Model } } else if c.Request.Method == http.MethodGet { relayMode = relayconstant.RelayModeVideoFetchByID @@ -202,10 +205,11 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") { relayMode := relayconstant.RelayModeUnknown if c.Request.Method == http.MethodPost { - err = common.UnmarshalBodyReusable(c, &modelRequest) + req, err := getModelFromRequest(c) if err != nil { - return nil, false, errors.New("video无效的请求, " + err.Error()) + return nil, false, err } + modelRequest.Model = req.Model relayMode = relayconstant.RelayModeVideoSubmit } else if c.Request.Method == http.MethodGet { relayMode = relayconstant.RelayModeVideoFetchByID @@ -223,10 +227,11 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } c.Set("relay_mode", relayMode) } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") { - err = common.UnmarshalBodyReusable(c, &modelRequest) - } - if err != nil { - return nil, false, errors.New("无效的请求, " + err.Error()) + req, err := getModelFromRequest(c) + if err != nil { + return nil, false, err + } + modelRequest.Model = req.Model } if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") { //wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01 @@ -248,19 +253,29 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { //modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1") contentType := c.ContentType() if slices.Contains([]string{gin.MIMEPOSTForm, gin.MIMEMultipartPOSTForm}, contentType) { - modelRequest.Model = c.PostForm("model") + req, err := getModelFromRequest(c) + if err == nil && req.Model != "" { + modelRequest.Model = req.Model + } } } if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { relayMode := relayconstant.RelayModeAudioSpeech if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1") } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { - modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model")) + // 先尝试从请求读取 + if req, err := getModelFromRequest(c); err == nil && req.Model != "" { + modelRequest.Model = req.Model + } modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1") relayMode = relayconstant.RelayModeAudioTranslation } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { - modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model")) + // 先尝试从请求读取 + if req, err := getModelFromRequest(c); err == nil && req.Model != "" { + modelRequest.Model = req.Model + } modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1") relayMode = relayconstant.RelayModeAudioTranscription } @@ -268,10 +283,12 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") { // playground chat completions - err = common.UnmarshalBodyReusable(c, &modelRequest) + req, err := getModelFromRequest(c) if err != nil { - return nil, false, errors.New("无效的请求, " + err.Error()) + return nil, false, err } + modelRequest.Model = req.Model + modelRequest.Group = req.Group common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group) } return &modelRequest, shouldSelectChannel, nil diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index c2bc2b0d1..c08e396fe 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -1,15 +1,10 @@ package openai import ( - "bytes" "encoding/json" "fmt" "io" - "math" - "mime/multipart" "net/http" - "os" - "path/filepath" "strings" "github.com/QuantumNous/new-api/common" @@ -26,7 +21,6 @@ import ( "github.com/bytedance/gopkg/util/gopool" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" - "github.com/pkg/errors" ) func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error { @@ -361,59 +355,13 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } } - audioTokens, err := countAudioTokens(c) - if err != nil { - return types.NewError(err, types.ErrorCodeCountTokenFailed), nil - } usage := &dto.Usage{} - usage.PromptTokens = audioTokens + usage.PromptTokens = info.PromptTokens usage.CompletionTokens = 0 usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return nil, usage } -func countAudioTokens(c *gin.Context) (int, error) { - body, err := common.GetRequestBody(c) - if err != nil { - return 0, errors.WithStack(err) - } - - var reqBody struct { - File *multipart.FileHeader `form:"file" binding:"required"` - } - c.Request.Body = io.NopCloser(bytes.NewReader(body)) - if err = c.ShouldBind(&reqBody); err != nil { - return 0, errors.WithStack(err) - } - ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名 - reqFp, err := reqBody.File.Open() - if err != nil { - return 0, errors.WithStack(err) - } - defer reqFp.Close() - - tmpFp, err := os.CreateTemp("", "audio-*"+ext) - if err != nil { - return 0, errors.WithStack(err) - } - defer os.Remove(tmpFp.Name()) - - _, err = io.Copy(tmpFp, reqFp) - if err != nil { - return 0, errors.WithStack(err) - } - if err = tmpFp.Close(); err != nil { - return 0, errors.WithStack(err) - } - - duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext) - if err != nil { - return 0, errors.WithStack(err) - } - - return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens -} - func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) { if info == nil || info.ClientWs == nil || info.TargetWs == nil { return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go index 4dcf3070c..f59ad2444 100644 --- a/relay/helper/valid_request.go +++ b/relay/helper/valid_request.go @@ -62,19 +62,9 @@ func GetAndValidAudioRequest(c *gin.Context, relayMode int) (*dto.AudioRequest, return nil, errors.New("model is required") } default: - err = c.Request.ParseForm() - if err != nil { - return nil, err - } - formData := c.Request.PostForm - if audioRequest.Model == "" { - audioRequest.Model = formData.Get("model") - } - if audioRequest.Model == "" { return nil, errors.New("model is required") } - audioRequest.ResponseFormat = formData.Get("response_format") if audioRequest.ResponseFormat == "" { audioRequest.ResponseFormat = "json" } diff --git a/service/token_counter.go b/service/token_counter.go index 87ef3b3ec..325fbd7ab 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -10,6 +10,7 @@ import ( _ "image/png" "log" "math" + "path/filepath" "strings" "sync" "unicode/utf8" @@ -18,6 +19,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" relaycommon "github.com/QuantumNous/new-api/relay/common" + constant2 "github.com/QuantumNous/new-api/relay/constant" "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" @@ -254,6 +256,10 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er } func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) { + if meta == nil { + return 0, errors.New("token count meta is nil") + } + if !constant.GetMediaToken { return 0, nil } @@ -263,8 +269,29 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco if info.RelayFormat == types.RelayFormatOpenAIRealtime { return 0, nil } - if meta == nil { - return 0, errors.New("token count meta is nil") + if info.RelayMode == constant2.RelayModeAudioTranscription || info.RelayMode == constant2.RelayModeAudioTranslation { + multiForm, err := common.ParseMultipartFormReusable(c) + if err != nil { + return 0, fmt.Errorf("error parsing multipart form: %v", err) + } + fileHeaders := multiForm.File["file"] + totalAudioToken := 0 + for _, fileHeader := range fileHeaders { + file, err := fileHeader.Open() + if err != nil { + return 0, fmt.Errorf("error opening audio file: %v", err) + } + defer file.Close() + // get ext and io.seeker + ext := filepath.Ext(fileHeader.Filename) + duration, err := common.GetAudioDuration(c.Request.Context(), file, ext) + if err != nil { + return 0, fmt.Errorf("error getting audio duration: %v", err) + } + // 一分钟 1000 token,与 $price / minute 对齐 + totalAudioToken += int(math.Round(math.Ceil(duration) / 60.0 * 1000)) + } + return totalAudioToken, nil } model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)