Compare commits

..

14 Commits

Author SHA1 Message Date
CaIon
2b70095b47 feat: implement audio duration retrieval without ffmpeg dependencies 2025-10-28 15:50:45 +08:00
CaIon
6791eb72ba feat: add support for Submodel channel type in relay info 2025-10-25 22:10:26 +08:00
IcedTangerine
cb3537f529 Merge pull request #2103 from feitianbubu/pr/doubao-image-watermark
fix: correct bool value for watermark
2025-10-25 11:44:34 +08:00
feitianbubu
471fd3a3b2 fix: correct bool value for watermark 2025-10-25 11:26:03 +08:00
IcedTangerine
032f159509 Merge pull request #2090 from feitianbubu/pr/doubao-image-edit
修复豆包图像编辑(图生图)功能
2025-10-23 22:39:18 +08:00
feitianbubu
95a2d02df9 fix: fail get moel by
multipart/form-data; boundary
2025-10-23 22:15:02 +08:00
feitianbubu
3ac9ff6028 feat: doubao-seedream support image edit 2025-10-23 21:19:33 +08:00
feitianbubu
fcf0f952b1 feat: doubao-seedream-4-0-250828 image to image 2025-10-23 21:19:32 +08:00
IcedTangerine
b99099fcbe Merge pull request #2087 from feitianbubu/pr/doubao-tts-stream
feat: doubao tts support streaming realtime audio
2025-10-22 17:44:01 +08:00
feitianbubu
bf66bbe5fa refactor: clean up doubao tts code 2025-10-22 17:06:13 +08:00
IcedTangerine
e80b442dd6 Merge pull request #2086 from feitianbubu/pr/openai-tts-stream
feat: openai tts support streaming realtime audio
2025-10-22 14:07:25 +08:00
feitianbubu
431b3a84f6 feat: doubao tts add is stream check 2025-10-22 13:39:16 +08:00
feitianbubu
098e6e7f2b feat: doubao tts support streaming realtime audio 2025-10-22 13:39:16 +08:00
feitianbubu
afcbff6644 feat: openai tts support streaming realtime audio 2025-10-22 13:33:01 +08:00
16 changed files with 1348 additions and 277 deletions

4
.gitignore vendored
View File

@@ -1,6 +1,5 @@
.idea
.vscode
.zed
upload
*.exe
*.db
@@ -11,11 +10,10 @@ web/dist
.env
one-api
new-api
/__debug_bin*
.DS_Store
tiktoken_cache
.eslintcache
.gocache
electron/node_modules
electron/dist
electron/dist

View File

@@ -28,7 +28,7 @@ RUN go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$
FROM alpine
RUN apk upgrade --no-cache \
&& apk add --no-cache ca-certificates tzdata ffmpeg \
&& apk add --no-cache ca-certificates tzdata \
&& update-ca-certificates
COPY --from=builder2 /build/new-api /

295
common/audio.go Normal file
View File

@@ -0,0 +1,295 @@
package common
import (
"context"
"encoding/binary"
"fmt"
"io"
"github.com/abema/go-mp4"
"github.com/go-audio/aiff"
"github.com/go-audio/wav"
"github.com/jfreymuth/oggvorbis"
"github.com/mewkiz/flac"
"github.com/pkg/errors"
"github.com/tcolgate/mp3"
"github.com/yapingcat/gomedia/go-codec"
)
// GetAudioDuration 使用纯 Go 库获取音频文件的时长(秒)。
// 它不再依赖外部的 ffmpeg 或 ffprobe 程序。
func GetAudioDuration(ctx context.Context, f io.ReadSeeker, ext string) (duration float64, err error) {
SysLog(fmt.Sprintf("GetAudioDuration: ext=%s", ext))
// 根据文件扩展名选择解析器
switch ext {
case ".mp3":
duration, err = getMP3Duration(f)
case ".wav":
duration, err = getWAVDuration(f)
case ".flac":
duration, err = getFLACDuration(f)
case ".m4a", ".mp4":
duration, err = getM4ADuration(f)
case ".ogg", ".oga":
duration, err = getOGGDuration(f)
case ".opus":
duration, err = getOpusDuration(f)
case ".aiff", ".aif", ".aifc":
duration, err = getAIFFDuration(f)
case ".webm":
duration, err = getWebMDuration(f)
case ".aac":
duration, err = getAACDuration(f)
default:
return 0, fmt.Errorf("unsupported audio format: %s", ext)
}
SysLog(fmt.Sprintf("GetAudioDuration: duration=%f", duration))
return duration, err
}
// getMP3Duration 解析 MP3 文件以获取时长。
// 注意:对于 VBR (Variable Bitrate) MP3这个估算可能不完全精确但通常足够好。
// FFmpeg 在这种情况下会扫描整个文件来获得精确值,但这里的库提供了快速估算。
func getMP3Duration(r io.Reader) (float64, error) {
d := mp3.NewDecoder(r)
var f mp3.Frame
skipped := 0
duration := 0.0
for {
if err := d.Decode(&f, &skipped); err != nil {
if err == io.EOF {
break
}
return 0, errors.Wrap(err, "failed to decode mp3 frame")
}
duration += f.Duration().Seconds()
}
return duration, nil
}
// getWAVDuration 解析 WAV 文件头以获取时长。
func getWAVDuration(r io.ReadSeeker) (float64, error) {
dec := wav.NewDecoder(r)
if !dec.IsValidFile() {
return 0, errors.New("invalid wav file")
}
d, err := dec.Duration()
if err != nil {
return 0, errors.Wrap(err, "failed to get wav duration")
}
return d.Seconds(), nil
}
// getFLACDuration 解析 FLAC 文件的 STREAMINFO 块。
func getFLACDuration(r io.Reader) (float64, error) {
stream, err := flac.Parse(r)
if err != nil {
return 0, errors.Wrap(err, "failed to parse flac stream")
}
defer stream.Close()
// 时长 = 总采样数 / 采样率
duration := float64(stream.Info.NSamples) / float64(stream.Info.SampleRate)
return duration, nil
}
// getM4ADuration 解析 M4A/MP4 文件的 'mvhd' box。
func getM4ADuration(r io.ReadSeeker) (float64, error) {
// go-mp4 库需要 ReadSeeker 接口
info, err := mp4.Probe(r)
if err != nil {
return 0, errors.Wrap(err, "failed to probe m4a/mp4 file")
}
// 时长 = Duration / Timescale
return float64(info.Duration) / float64(info.Timescale), nil
}
// getOGGDuration 解析 OGG/Vorbis 文件以获取时长。
func getOGGDuration(r io.ReadSeeker) (float64, error) {
// 重置 reader 到开头
if _, err := r.Seek(0, io.SeekStart); err != nil {
return 0, errors.Wrap(err, "failed to seek ogg file")
}
reader, err := oggvorbis.NewReader(r)
if err != nil {
return 0, errors.Wrap(err, "failed to create ogg vorbis reader")
}
// 计算时长 = 总采样数 / 采样率
// 需要读取整个文件来获取总采样数
channels := reader.Channels()
sampleRate := reader.SampleRate()
// 估算方法:读取到文件结尾
var totalSamples int64
buf := make([]float32, 4096*channels)
for {
n, err := reader.Read(buf)
if err == io.EOF {
break
}
if err != nil {
return 0, errors.Wrap(err, "failed to read ogg samples")
}
totalSamples += int64(n / channels)
}
duration := float64(totalSamples) / float64(sampleRate)
return duration, nil
}
// getOpusDuration 解析 Opus 文件(在 OGG 容器中)以获取时长。
func getOpusDuration(r io.ReadSeeker) (float64, error) {
// Opus 通常封装在 OGG 容器中
// 我们需要解析 OGG 页面来获取时长信息
if _, err := r.Seek(0, io.SeekStart); err != nil {
return 0, errors.Wrap(err, "failed to seek opus file")
}
// 读取 OGG 页面头部
var totalGranulePos int64
buf := make([]byte, 27) // OGG 页面头部最小大小
for {
n, err := r.Read(buf)
if err == io.EOF {
break
}
if err != nil {
return 0, errors.Wrap(err, "failed to read opus/ogg page")
}
if n < 27 {
break
}
// 检查 OGG 页面标识 "OggS"
if string(buf[0:4]) != "OggS" {
// 跳过一些字节继续寻找
if _, err := r.Seek(-26, io.SeekCurrent); err != nil {
break
}
continue
}
// 读取 granule position (字节 6-13, 小端序)
granulePos := int64(binary.LittleEndian.Uint64(buf[6:14]))
if granulePos > totalGranulePos {
totalGranulePos = granulePos
}
// 读取段表大小
numSegments := int(buf[26])
segmentTable := make([]byte, numSegments)
if _, err := io.ReadFull(r, segmentTable); err != nil {
break
}
// 计算页面数据大小并跳过
var pageSize int
for _, segSize := range segmentTable {
pageSize += int(segSize)
}
if _, err := r.Seek(int64(pageSize), io.SeekCurrent); err != nil {
break
}
}
// Opus 的采样率固定为 48000 Hz
duration := float64(totalGranulePos) / 48000.0
return duration, nil
}
// getAIFFDuration 解析 AIFF 文件头以获取时长。
func getAIFFDuration(r io.ReadSeeker) (float64, error) {
if _, err := r.Seek(0, io.SeekStart); err != nil {
return 0, errors.Wrap(err, "failed to seek aiff file")
}
dec := aiff.NewDecoder(r)
if !dec.IsValidFile() {
return 0, errors.New("invalid aiff file")
}
d, err := dec.Duration()
if err != nil {
return 0, errors.Wrap(err, "failed to get aiff duration")
}
return d.Seconds(), nil
}
// getWebMDuration 解析 WebM 文件以获取时长。
// WebM 使用 Matroska 容器格式
func getWebMDuration(r io.ReadSeeker) (float64, error) {
if _, err := r.Seek(0, io.SeekStart); err != nil {
return 0, errors.Wrap(err, "failed to seek webm file")
}
// WebM/Matroska 文件的解析比较复杂
// 这里提供一个简化的实现,读取 EBML 头部
// 对于完整的 WebM 解析,可能需要使用专门的库
// 简单实现:查找 Duration 元素
// WebM Duration 的 Element ID 是 0x4489
// 这是一个简化版本,可能不适用于所有 WebM 文件
buf := make([]byte, 8192)
n, err := r.Read(buf)
if err != nil && err != io.EOF {
return 0, errors.Wrap(err, "failed to read webm file")
}
// 尝试查找 Duration 元素(这是一个简化的方法)
// 实际的 WebM 解析需要完整的 EBML 解析器
// 这里返回错误,建议使用专门的库
if n > 0 {
// 检查 EBML 标识
if len(buf) >= 4 && binary.BigEndian.Uint32(buf[0:4]) == 0x1A45DFA3 {
// 这是一个有效的 EBML 文件
// 但完整解析需要更复杂的逻辑
return 0, errors.New("webm duration parsing requires full EBML parser (consider using ffprobe for webm files)")
}
}
return 0, errors.New("failed to parse webm file")
}
// getAACDuration 解析 AAC (ADTS格式) 文件以获取时长。
// 使用 gomedia 库来解析 AAC ADTS 帧
func getAACDuration(r io.ReadSeeker) (float64, error) {
if _, err := r.Seek(0, io.SeekStart); err != nil {
return 0, errors.Wrap(err, "failed to seek aac file")
}
// 读取整个文件内容
data, err := io.ReadAll(r)
if err != nil {
return 0, errors.Wrap(err, "failed to read aac file")
}
var totalFrames int64
var sampleRate int
// 使用 gomedia 的 SplitAACFrame 函数来分割 AAC 帧
codec.SplitAACFrame(data, func(aac []byte) {
// 解析 ADTS 头部以获取采样率信息
if len(aac) >= 7 {
// 使用 ConvertADTSToASC 来获取音频配置信息
asc, err := codec.ConvertADTSToASC(aac)
if err == nil && sampleRate == 0 {
sampleRate = codec.AACSampleIdxToSample(int(asc.Sample_freq_index))
}
totalFrames++
}
})
if sampleRate == 0 || totalFrames == 0 {
return 0, errors.New("no valid aac frames found")
}
// 每个 AAC ADTS 帧包含 1024 个采样
totalSamples := totalFrames * 1024
duration := float64(totalSamples) / float64(sampleRate)
return duration, nil
}

View File

@@ -2,9 +2,11 @@ package common
import (
"bytes"
"encoding/json"
"io"
"mime/multipart"
"net/http"
"net/url"
"strings"
"time"
@@ -40,6 +42,10 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = Unmarshal(requestBody, &v)
} else if strings.Contains(contentType, gin.MIMEPOSTForm) {
err = parseFormData(requestBody, &v)
} else if strings.Contains(contentType, gin.MIMEMultipartPOSTForm) {
err = parseMultipartFormData(c, requestBody, &v)
} else {
// skip for now
// TODO: someday non json request have variant model, we will need to implementation this
@@ -138,3 +144,57 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return form, nil
}
func parseFormData(data []byte, v any) error {
values, err := url.ParseQuery(string(data))
if err != nil {
return err
}
formMap := make(map[string]any)
for key, vals := range values {
if len(vals) == 1 {
formMap[key] = vals[0]
} else {
formMap[key] = vals
}
}
jsonData, err := json.Marshal(formMap)
if err != nil {
return err
}
return Unmarshal(jsonData, v)
}
func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
contentType := c.Request.Header.Get("Content-Type")
boundary := ""
if idx := strings.Index(contentType, "boundary="); idx != -1 {
boundary = contentType[idx+9:]
}
if boundary == "" {
return Unmarshal(data, v) // Fallback to JSON
}
reader := multipart.NewReader(bytes.NewReader(data), boundary)
form, err := reader.ReadForm(32 << 20) // 32 MB max memory
if err != nil {
return err
}
defer form.RemoveAll()
formMap := make(map[string]any)
for key, vals := range form.Value {
if len(vals) == 1 {
formMap[key] = vals[0]
} else {
formMap[key] = vals
}
}
jsonData, err := Marshal(formMap)
if err != nil {
return err
}
return Unmarshal(jsonData, v)
}

View File

@@ -1,8 +1,6 @@
package common
import (
"bytes"
"context"
crand "crypto/rand"
"encoding/base64"
"encoding/json"
@@ -329,43 +327,6 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
return f.Name(), nil
}
// GetAudioDuration returns the duration of an audio file in seconds.
func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
output, err := c.Output()
if err != nil {
return 0, errors.Wrap(err, "failed to get audio duration")
}
durationStr := string(bytes.TrimSpace(output))
if durationStr == "N/A" {
// Create a temporary output file name
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
if err != nil {
return 0, errors.Wrap(err, "failed to create temporary file")
}
tmpName := tmpFp.Name()
// Close immediately so ffmpeg can open the file on Windows.
_ = tmpFp.Close()
defer os.Remove(tmpName)
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
if err := ffmpegCmd.Run(); err != nil {
return 0, errors.Wrap(err, "failed to run ffmpeg")
}
// Recalculate the duration of the new file
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
output, err := c.Output()
if err != nil {
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
}
durationStr = string(bytes.TrimSpace(output))
}
return strconv.ParseFloat(durationStr, 64)
}
// BuildURL concatenates base and endpoint, returns the complete url string
func BuildURL(base string, endpoint string) string {
u, err := url.Parse(base)

View File

@@ -27,7 +27,8 @@ type ImageRequest struct {
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
PartialImages json.RawMessage `json:"partial_images,omitempty"`
// Stream bool `json:"stream,omitempty"`
Watermark *bool `json:"watermark,omitempty"`
Watermark *bool `json:"watermark,omitempty"`
Image json.RawMessage `json:"image,omitempty"`
// 用匿名参数接收额外参数
Extra map[string]json.RawMessage `json:"-"`
}

13
go.mod
View File

@@ -5,6 +5,7 @@ go 1.25.1
require (
github.com/Calcium-Ion/go-epay v0.0.4
github.com/abema/go-mp4 v1.4.1
github.com/andybalholm/brotli v1.1.1
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
github.com/aws/aws-sdk-go-v2 v1.37.2
@@ -18,24 +19,30 @@ require (
github.com/gin-contrib/static v0.0.1
github.com/gin-gonic/gin v1.9.1
github.com/glebarez/sqlite v1.9.0
github.com/go-audio/aiff v1.1.0
github.com/go-audio/wav v1.1.0
github.com/go-playground/validator/v10 v10.20.0
github.com/go-redis/redis/v8 v8.11.5
github.com/go-webauthn/webauthn v0.14.0
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.0
github.com/jfreymuth/oggvorbis v1.0.5
github.com/jinzhu/copier v0.4.0
github.com/joho/godotenv v1.5.1
github.com/mewkiz/flac v1.0.13
github.com/pkg/errors v0.9.1
github.com/pquerna/otp v1.5.0
github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/shopspring/decimal v1.4.0
github.com/stripe/stripe-go/v81 v81.4.0
github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300
github.com/thanhpk/randstr v1.0.6
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/tiktoken-go/tokenizer v0.6.2
github.com/yapingcat/gomedia v0.0.0-20240906162731-17feea57090c
golang.org/x/crypto v0.42.0
golang.org/x/image v0.23.0
golang.org/x/net v0.43.0
@@ -62,6 +69,8 @@ require (
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/glebarez/go-sqlite v1.21.2 // indirect
github.com/go-audio/audio v1.0.0 // indirect
github.com/go-audio/riff v1.0.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
@@ -73,16 +82,20 @@ require (
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
github.com/icza/bitio v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.7.1 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jfreymuth/vorbis v1.0.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mewkiz/pkg v0.0.0-20250417130911-3f050ff8c56d // indirect
github.com/mewpkg/term v0.0.0-20241026122259-37a80af23985 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect

40
go.sum
View File

@@ -1,5 +1,7 @@
github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
github.com/abema/go-mp4 v1.4.1 h1:YoS4VRqd+pAmddRPLFf8vMk74kuGl6ULSjzhsIqwr6M=
github.com/abema/go-mp4 v1.4.1/go.mod h1:vPl9t5ZK7K0x68jh12/+ECWBCXoWuIDtNgPtU2f04ws=
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
@@ -33,6 +35,7 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@@ -67,6 +70,15 @@ github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9g
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
github.com/glebarez/sqlite v1.9.0 h1:Aj6bPA12ZEx5GbSF6XADmCkYXlljPNUY+Zf1EQxynXs=
github.com/glebarez/sqlite v1.9.0/go.mod h1:YBYCoyupOao60lzp1MVBLEjZfgkq0tdB1voAQ09K9zw=
github.com/go-audio/aiff v1.1.0 h1:m2LYgu/2BarpF2yZnFPWtY3Tp41k0A4y51gDRZZsEuU=
github.com/go-audio/aiff v1.1.0/go.mod h1:sDik1muYvhPiccClfri0fv6U2fyH/dy4VRWmUz0cz9Q=
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
github.com/go-audio/wav v1.0.0/go.mod h1:3yoReyQOsiARkvPl3ERCi8JFjihzG6WhjYpZCf5zAWE=
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
@@ -108,6 +120,7 @@ github.com/google/go-tpm v0.9.5/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
@@ -118,6 +131,10 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/icza/bitio v1.1.0 h1:ysX4vtldjdi3Ygai5m1cWy4oLkhWTAi+SyO6HC8L9T0=
github.com/icza/bitio v1.1.0/go.mod h1:0jGnlLAx8MKMr9VGnn/4YrvZiprkvBelsVIbA9Jjr9A=
github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6 h1:8UsGZ2rr2ksmEru6lToqnXgA8Mz1DP11X4zSJ159C3k=
github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6/go.mod h1:xQig96I1VNBDIWGCdTt54nHt6EeI639SmHycLYL7FkA=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -126,6 +143,10 @@ github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs=
github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jfreymuth/oggvorbis v1.0.5 h1:u+Ck+R0eLSRhgq8WTmffYnrVtSztJcYrl588DM4e3kQ=
github.com/jfreymuth/oggvorbis v1.0.5/go.mod h1:1U4pqWmghcoVsCJJ4fRBKv9peUJMBHixthRlBeD6uII=
github.com/jfreymuth/vorbis v1.0.2 h1:m1xH6+ZI4thH927pgKD8JOH4eaGRm18rEE9/0WKjvNE=
github.com/jfreymuth/vorbis v1.0.2/go.mod h1:DoftRo4AznKnShRl1GxiTFCseHr4zR9BN3TWXyuzrqQ=
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
@@ -145,6 +166,7 @@ github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfn
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
@@ -152,10 +174,17 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattetti/audio v0.0.0-20180912171649-01576cde1f21/go.mod h1:LlQmBGkOuV/SKzEDXBPKauvN2UqCgzXO2XjecTGj40s=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mewkiz/flac v1.0.13 h1:6wF8rRQKBFW159Daqx6Ro7K5ZnlVhHUKfS5aTsC4oXs=
github.com/mewkiz/flac v1.0.13/go.mod h1:HfPYDA+oxjyuqMu2V+cyKcxF51KM6incpw5eZXmfA6k=
github.com/mewkiz/pkg v0.0.0-20250417130911-3f050ff8c56d h1:IL2tii4jXLdhCeQN69HNzYYW1kl0meSG0wt5+sLwszU=
github.com/mewkiz/pkg v0.0.0-20250417130911-3f050ff8c56d/go.mod h1:SIpumAnUWSy0q9RzKD3pyH3g1t5vdawUAPcW5tQrUtI=
github.com/mewpkg/term v0.0.0-20241026122259-37a80af23985 h1:h8O1byDZ1uk6RUXMhj1QJU3VXFKXHDZxr4TXRPGeBa8=
github.com/mewpkg/term v0.0.0-20241026122259-37a80af23985/go.mod h1:uiPmbdUbdt1NkGApKl7htQjZ8S7XaGUAVulJUJ9v6q4=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -170,6 +199,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e h1:s2RNOM/IGdY0Y6qfTeUKhDawdHDpK9RGBdx80qN4Ttw=
github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e/go.mod h1:nBdnFKj15wFbf94Rwfq4m30eAcyY9V/IyKAGQFtqkW0=
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pelletier/go-toml/v2 v2.2.1 h1:9TA9+T8+8CUCO2+WYnDLCgrYi9+omqKXyjDtosvtEhg=
github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
@@ -209,6 +240,9 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJUzCLbw=
github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
github.com/sunfish-shogi/bufseekio v0.0.0-20210207115823-a4185644b365/go.mod h1:dEzdXgvImkQ3WLI+0KQpmEx8T/C/ma9KeS3AfmU899I=
github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300 h1:XQdibLKagjdevRB6vAjVY4qbSr8rQ610YzTkWcxzxSI=
github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300/go.mod h1:FNa/dfN95vAYCNFrIKRrlRo+MBLbwmR9Asa5f2ljmBI=
github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@@ -238,6 +272,8 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yapingcat/gomedia v0.0.0-20240906162731-17feea57090c h1:xA2TJS9Hu/ivzaZIrDcwvpJ3Fnpsk5fDOJ4iSnL6J0w=
github.com/yapingcat/gomedia v0.0.0-20240906162731-17feea57090c/go.mod h1:WSZ59bidJOO40JSJmLqlkBJrjZCtjbKKkygEMfzY/kc=
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
@@ -257,6 +293,7 @@ golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -270,6 +307,7 @@ golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -286,6 +324,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/src-d/go-billy.v4 v4.3.2 h1:0SQA1pRztfTFx2miS8sA97XvooFeNOmvUenF4o0EcVg=
gopkg.in/src-d/go-billy.v4 v4.3.2/go.mod h1:nDjArDMp+XMs1aFAESLRjfGSgfvoYN0hDfzEk0GjC98=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"net/http"
"slices"
"strconv"
"strings"
"time"
@@ -85,7 +86,7 @@ func Distribute() func(c *gin.Context) {
playgroundRequest := &dto.PlayGroundRequest{}
err = common.UnmarshalBodyReusable(c, playgroundRequest)
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的playground请求, "+err.Error())
return
}
if playgroundRequest.Group != "" {
@@ -123,6 +124,20 @@ func Distribute() func(c *gin.Context) {
}
}
// getModelFromRequest 从请求中读取模型信息
// 根据 Content-Type 自动处理:
// - application/json
// - application/x-www-form-urlencoded
// - multipart/form-data
func getModelFromRequest(c *gin.Context) (*ModelRequest, error) {
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
return nil, errors.New("无效的请求, " + err.Error())
}
return &modelRequest, nil
}
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
var modelRequest ModelRequest
shouldSelectChannel := true
@@ -138,7 +153,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
midjourneyRequest := dto.MidjourneyRequest{}
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
if err != nil {
return nil, false, err
return nil, false, errors.New("无效的midjourney请求, " + err.Error())
}
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
if mjErr != nil {
@@ -175,23 +190,12 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
relayMode := relayconstant.RelayModeUnknown
if c.Request.Method == http.MethodPost {
relayMode = relayconstant.RelayModeVideoSubmit
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "multipart/form-data") {
form, err := common.ParseMultipartFormReusable(c)
if err != nil {
return nil, false, errors.New("无效的video请求, " + err.Error())
}
defer form.RemoveAll()
if form != nil {
if values, ok := form.Value["model"]; ok && len(values) > 0 {
modelRequest.Model = values[0]
}
}
} else if strings.HasPrefix(contentType, "application/json") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
return nil, false, errors.New("无效的video请求, " + err.Error())
}
req, err := getModelFromRequest(c)
if err != nil {
return nil, false, err
}
if req != nil {
modelRequest.Model = req.Model
}
} else if c.Request.Method == http.MethodGet {
relayMode = relayconstant.RelayModeVideoFetchByID
@@ -201,10 +205,11 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
relayMode := relayconstant.RelayModeUnknown
if c.Request.Method == http.MethodPost {
err = common.UnmarshalBodyReusable(c, &modelRequest)
req, err := getModelFromRequest(c)
if err != nil {
return nil, false, errors.New("video无效的请求, " + err.Error())
return nil, false, err
}
modelRequest.Model = req.Model
relayMode = relayconstant.RelayModeVideoSubmit
} else if c.Request.Method == http.MethodGet {
relayMode = relayconstant.RelayModeVideoFetchByID
@@ -222,10 +227,11 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
c.Set("relay_mode", relayMode)
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
return nil, false, errors.New("无效的请求, " + err.Error())
req, err := getModelFromRequest(c)
if err != nil {
return nil, false, err
}
modelRequest.Model = req.Model
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
@@ -245,20 +251,31 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
//modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
modelRequest.Model = c.PostForm("model")
contentType := c.ContentType()
if slices.Contains([]string{gin.MIMEPOSTForm, gin.MIMEMultipartPOSTForm}, contentType) {
req, err := getModelFromRequest(c)
if err == nil && req.Model != "" {
modelRequest.Model = req.Model
}
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
relayMode := relayconstant.RelayModeAudioSpeech
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1")
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
// 先尝试从请求读取
if req, err := getModelFromRequest(c); err == nil && req.Model != "" {
modelRequest.Model = req.Model
}
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
relayMode = relayconstant.RelayModeAudioTranslation
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
// 先尝试从请求读取
if req, err := getModelFromRequest(c); err == nil && req.Model != "" {
modelRequest.Model = req.Model
}
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
relayMode = relayconstant.RelayModeAudioTranscription
}
@@ -266,10 +283,12 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
// playground chat completions
err = common.UnmarshalBodyReusable(c, &modelRequest)
req, err := getModelFromRequest(c)
if err != nil {
return nil, false, errors.New("无效的请求, " + err.Error())
return nil, false, err
}
modelRequest.Model = req.Model
modelRequest.Group = req.Group
common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group)
}
return &modelRequest, shouldSelectChannel, nil

View File

@@ -1,15 +1,10 @@
package openai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"math"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/QuantumNous/new-api/common"
@@ -26,7 +21,6 @@ import (
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
)
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
@@ -273,6 +267,39 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
return &simpleResponse.Usage, nil
}
func streamTTSResponse(c *gin.Context, resp *http.Response) {
c.Writer.WriteHeaderNow()
flusher, ok := c.Writer.(http.Flusher)
if !ok {
logger.LogWarn(c, "streaming not supported")
_, err := io.Copy(c.Writer, resp.Body)
if err != nil {
logger.LogWarn(c, err.Error())
}
return
}
buffer := make([]byte, 4096)
for {
n, err := resp.Body.Read(buffer)
//logger.LogInfo(c, fmt.Sprintf("streamTTSResponse read %d bytes", n))
if n > 0 {
if _, writeErr := c.Writer.Write(buffer[:n]); writeErr != nil {
logger.LogError(c, writeErr.Error())
break
}
flusher.Flush()
}
if err != nil {
if err != io.EOF {
logger.LogError(c, err.Error())
}
break
}
}
}
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
// the status code has been judged before, if there is a body reading failure,
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
@@ -288,10 +315,16 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
c.Writer.WriteHeaderNow()
_, err := io.Copy(c.Writer, resp.Body)
if err != nil {
logger.LogError(c, err.Error())
isStreaming := resp.ContentLength == -1 || resp.Header.Get("Content-Length") == ""
if isStreaming {
streamTTSResponse(c, resp)
} else {
c.Writer.WriteHeaderNow()
_, err := io.Copy(c.Writer, resp.Body)
if err != nil {
logger.LogError(c, err.Error())
}
}
return usage
}
@@ -322,59 +355,13 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
}
audioTokens, err := countAudioTokens(c)
if err != nil {
return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
}
usage := &dto.Usage{}
usage.PromptTokens = audioTokens
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens = 0
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage
}
func countAudioTokens(c *gin.Context) (int, error) {
body, err := common.GetRequestBody(c)
if err != nil {
return 0, errors.WithStack(err)
}
var reqBody struct {
File *multipart.FileHeader `form:"file" binding:"required"`
}
c.Request.Body = io.NopCloser(bytes.NewReader(body))
if err = c.ShouldBind(&reqBody); err != nil {
return 0, errors.WithStack(err)
}
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
reqFp, err := reqBody.File.Open()
if err != nil {
return 0, errors.WithStack(err)
}
defer reqFp.Close()
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
if err != nil {
return 0, errors.WithStack(err)
}
defer os.Remove(tmpFp.Name())
_, err = io.Copy(tmpFp, reqFp)
if err != nil {
return 0, errors.WithStack(err)
}
if err = tmpFp.Close(); err != nil {
return 0, errors.WithStack(err)
}
duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
if err != nil {
return 0, errors.WithStack(err)
}
return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
}
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil

View File

@@ -6,9 +6,7 @@ import (
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/textproto"
"path/filepath"
"strings"
@@ -23,6 +21,11 @@ import (
"github.com/gin-gonic/gin"
)
const (
contextKeyTTSRequest = "volcengine_tts_request"
contextKeyResponseFormat = "response_format"
)
type Adaptor struct {
}
@@ -50,7 +53,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
speedRatio := request.Speed
encoding := mapEncoding(request.ResponseFormat)
c.Set("response_format", encoding)
c.Set(contextKeyResponseFormat, encoding)
volcRequest := VolcengineTTSRequest{
App: VolcengineTTSApp{
@@ -70,18 +73,23 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
Request: VolcengineTTSReqInfo{
ReqID: generateRequestID(),
Text: request.Input,
Operation: "query",
Operation: "submit",
Model: info.OriginModelName,
},
}
// 同步扩展字段的厂商自定义metadata
if len(request.Metadata) > 0 {
if err = json.Unmarshal(request.Metadata, &volcRequest); err != nil {
return nil, fmt.Errorf("error unmarshalling metadata to volcengine request: %w", err)
}
}
c.Set(contextKeyTTSRequest, volcRequest)
if volcRequest.Request.Operation == "submit" {
info.IsStream = true
}
jsonData, err := json.Marshal(volcRequest)
if err != nil {
return nil, fmt.Errorf("error marshalling volcengine request: %w", err)
@@ -94,128 +102,113 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
switch info.RelayMode {
case constant.RelayModeImagesGenerations:
return request, nil
case constant.RelayModeImagesEdits:
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
writer.WriteField("model", request.Model)
// 获取所有表单字段
formData := c.Request.PostForm
// 遍历表单字段并打印输出
for key, values := range formData {
if key == "model" {
continue
}
for _, value := range values {
writer.WriteField(key, value)
}
}
// Parse the multipart form to handle both single image and multiple images
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
return nil, errors.New("failed to parse multipart form")
}
if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
// Check if "image" field exists in any form, including array notation
var imageFiles []*multipart.FileHeader
var exists bool
// First check for standard "image" field
if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
// If not found, check for "image[]" field
if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
// If still not found, iterate through all fields to find any that start with "image["
foundArrayImages := false
for fieldName, files := range c.Request.MultipartForm.File {
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
foundArrayImages = true
for _, file := range files {
imageFiles = append(imageFiles, file)
}
}
}
// If no image fields found at all
if !foundArrayImages && (len(imageFiles) == 0) {
return nil, errors.New("image is required")
}
}
}
// Process all image files
for i, fileHeader := range imageFiles {
file, err := fileHeader.Open()
if err != nil {
return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
}
defer file.Close()
// If multiple images, use image[] as the field name
fieldName := "image"
if len(imageFiles) > 1 {
fieldName = "image[]"
}
// Determine MIME type based on file extension
mimeType := detectImageMimeType(fileHeader.Filename)
// Create a form file with the appropriate content type
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
h.Set("Content-Type", mimeType)
part, err := writer.CreatePart(h)
if err != nil {
return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
}
if _, err := io.Copy(part, file); err != nil {
return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
}
}
// Handle mask file if present
if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
maskFile, err := maskFiles[0].Open()
if err != nil {
return nil, errors.New("failed to open mask file")
}
defer maskFile.Close()
// Determine MIME type for mask file
mimeType := detectImageMimeType(maskFiles[0].Filename)
// Create a form file with the appropriate content type
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
h.Set("Content-Type", mimeType)
maskPart, err := writer.CreatePart(h)
if err != nil {
return nil, errors.New("create form file failed for mask")
}
if _, err := io.Copy(maskPart, maskFile); err != nil {
return nil, errors.New("copy mask file failed")
}
}
} else {
return nil, errors.New("no multipart form data found")
}
// 关闭 multipart 编写器以设置分界线
writer.Close()
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
return bytes.NewReader(requestBody.Bytes()), nil
// 根据官方文档,并没有发现豆包生图支持表单请求:https://www.volcengine.com/docs/82379/1824121
//case constant.RelayModeImagesEdits:
//
// var requestBody bytes.Buffer
// writer := multipart.NewWriter(&requestBody)
//
// writer.WriteField("model", request.Model)
//
// formData := c.Request.PostForm
// for key, values := range formData {
// if key == "model" {
// continue
// }
// for _, value := range values {
// writer.WriteField(key, value)
// }
// }
//
// if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
// return nil, errors.New("failed to parse multipart form")
// }
//
// if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
// var imageFiles []*multipart.FileHeader
// var exists bool
//
// if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
// if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
// foundArrayImages := false
// for fieldName, files := range c.Request.MultipartForm.File {
// if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
// foundArrayImages = true
// for _, file := range files {
// imageFiles = append(imageFiles, file)
// }
// }
// }
//
// if !foundArrayImages && (len(imageFiles) == 0) {
// return nil, errors.New("image is required")
// }
// }
// }
//
// for i, fileHeader := range imageFiles {
// file, err := fileHeader.Open()
// if err != nil {
// return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
// }
// defer file.Close()
//
// fieldName := "image"
// if len(imageFiles) > 1 {
// fieldName = "image[]"
// }
//
// mimeType := detectImageMimeType(fileHeader.Filename)
//
// h := make(textproto.MIMEHeader)
// h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
// h.Set("Content-Type", mimeType)
//
// part, err := writer.CreatePart(h)
// if err != nil {
// return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
// }
//
// if _, err := io.Copy(part, file); err != nil {
// return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
// }
// }
//
// if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
// maskFile, err := maskFiles[0].Open()
// if err != nil {
// return nil, errors.New("failed to open mask file")
// }
// defer maskFile.Close()
//
// mimeType := detectImageMimeType(maskFiles[0].Filename)
//
// h := make(textproto.MIMEHeader)
// h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
// h.Set("Content-Type", mimeType)
//
// maskPart, err := writer.CreatePart(h)
// if err != nil {
// return nil, errors.New("create form file failed for mask")
// }
//
// if _, err := io.Copy(maskPart, maskFile); err != nil {
// return nil, errors.New("copy mask file failed")
// }
// }
// } else {
// return nil, errors.New("no multipart form data found")
// }
//
// writer.Close()
// c.Request.Header.Set("Content-Type", writer.FormDataContentType())
// return bytes.NewReader(requestBody.Bytes()), nil
default:
return request, nil
}
}
// detectImageMimeType determines the MIME type based on the file extension
func detectImageMimeType(filename string) string {
ext := strings.ToLower(filepath.Ext(filename))
switch ext {
@@ -226,11 +219,9 @@ func detectImageMimeType(filename string) string {
case ".webp":
return "image/webp"
default:
// Try to detect from extension if possible
if strings.HasPrefix(ext, ".jp") {
return "image/jpeg"
}
// Default to png as a fallback
return "image/png"
}
}
@@ -259,16 +250,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil
case constant.RelayModeImagesGenerations:
//豆包的图生图也走generations接口: https://www.volcengine.com/docs/82379/1824121
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil
case constant.RelayModeImagesEdits:
return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
//case constant.RelayModeImagesEdits:
// return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
case constant.RelayModeRerank:
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
case constant.RelayModeAudioSpeech:
// 只有当 baseUrl 是火山默认的官方Url时才改为官方的的TTS接口否则走透传的New接口
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
return "https://openspeech.bytedance.com/api/v1/tts", nil
return "wss://openspeech.bytedance.com/api/v1/tts/ws_binary", nil
}
return fmt.Sprintf("%s/v1/audio/speech", baseUrl), nil
default:
@@ -287,6 +278,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
}
req.Set("Content-Type", "application/json")
return nil
} else if info.RelayMode == constant.RelayModeImagesEdits {
req.Set("Content-Type", gin.MIMEJSON)
}
req.Set("Authorization", "Bearer "+info.ApiKey)
@@ -297,7 +290,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
// 适配 方舟deepseek混合模型 的 thinking 后缀
if strings.HasSuffix(info.UpstreamModelName, "-thinking") && strings.HasPrefix(info.UpstreamModelName, "deepseek") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
request.Model = info.UpstreamModelName
@@ -315,17 +308,58 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
if info.RelayMode == constant.RelayModeAudioSpeech {
baseUrl := info.ChannelBaseUrl
if baseUrl == "" {
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
}
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
if info.IsStream {
return nil, nil
}
}
}
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeAudioSpeech {
encoding := mapEncoding(c.GetString("response_format"))
encoding := mapEncoding(c.GetString(contextKeyResponseFormat))
if info.IsStream {
volcRequestInterface, exists := c.Get(contextKeyTTSRequest)
if !exists {
return nil, types.NewErrorWithStatusCode(
errors.New("volcengine TTS request not found in context"),
types.ErrorCodeBadRequestBody,
http.StatusInternalServerError,
)
}
volcRequest, ok := volcRequestInterface.(VolcengineTTSRequest)
if !ok {
return nil, types.NewErrorWithStatusCode(
errors.New("invalid volcengine TTS request type"),
types.ErrorCodeBadRequestBody,
http.StatusInternalServerError,
)
}
// Get the WebSocket URL
requestURL, urlErr := a.GetRequestURL(info)
if urlErr != nil {
return nil, types.NewErrorWithStatusCode(
urlErr,
types.ErrorCodeBadRequestBody,
http.StatusInternalServerError,
)
}
return handleTTSWebSocketResponse(c, requestURL, volcRequest, info, encoding)
}
return handleTTSResponse(c, resp, info, encoding)
}

View File

@@ -0,0 +1,533 @@
package volcengine
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"math"
"github.com/gorilla/websocket"
)
type (
EventType int32
MsgType uint8
MsgTypeFlagBits uint8
VersionBits uint8
HeaderSizeBits uint8
SerializationBits uint8
CompressionBits uint8
)
const (
MsgTypeFlagNoSeq MsgTypeFlagBits = 0
MsgTypeFlagPositiveSeq MsgTypeFlagBits = 0b1
MsgTypeFlagNegativeSeq MsgTypeFlagBits = 0b11
MsgTypeFlagWithEvent MsgTypeFlagBits = 0b100
)
const (
Version1 VersionBits = iota + 1
)
const (
HeaderSize4 HeaderSizeBits = iota + 1
)
const (
SerializationJSON SerializationBits = 0b1
)
const (
CompressionNone CompressionBits = 0
)
const (
MsgTypeFullClientRequest MsgType = 0b1
MsgTypeAudioOnlyClient MsgType = 0b10
MsgTypeFullServerResponse MsgType = 0b1001
MsgTypeAudioOnlyServer MsgType = 0b1011
MsgTypeFrontEndResultServer MsgType = 0b1100
MsgTypeError MsgType = 0b1111
)
func (t MsgType) String() string {
switch t {
case MsgTypeFullClientRequest:
return "MsgType_FullClientRequest"
case MsgTypeAudioOnlyClient:
return "MsgType_AudioOnlyClient"
case MsgTypeFullServerResponse:
return "MsgType_FullServerResponse"
case MsgTypeAudioOnlyServer:
return "MsgType_AudioOnlyServer"
case MsgTypeError:
return "MsgType_Error"
case MsgTypeFrontEndResultServer:
return "MsgType_FrontEndResultServer"
default:
return fmt.Sprintf("MsgType_(%d)", t)
}
}
const (
EventType_None EventType = 0
EventType_StartConnection EventType = 1
EventType_FinishConnection EventType = 2
EventType_ConnectionStarted EventType = 50
EventType_ConnectionFailed EventType = 51
EventType_ConnectionFinished EventType = 52
EventType_StartSession EventType = 100
EventType_CancelSession EventType = 101
EventType_FinishSession EventType = 102
EventType_SessionStarted EventType = 150
EventType_SessionCanceled EventType = 151
EventType_SessionFinished EventType = 152
EventType_SessionFailed EventType = 153
EventType_UsageResponse EventType = 154
EventType_TaskRequest EventType = 200
EventType_UpdateConfig EventType = 201
EventType_AudioMuted EventType = 250
EventType_SayHello EventType = 300
EventType_TTSSentenceStart EventType = 350
EventType_TTSSentenceEnd EventType = 351
EventType_TTSResponse EventType = 352
EventType_TTSEnded EventType = 359
EventType_PodcastRoundStart EventType = 360
EventType_PodcastRoundResponse EventType = 361
EventType_PodcastRoundEnd EventType = 362
EventType_ASRInfo EventType = 450
EventType_ASRResponse EventType = 451
EventType_ASREnded EventType = 459
EventType_ChatTTSText EventType = 500
EventType_ChatResponse EventType = 550
EventType_ChatEnded EventType = 559
EventType_SourceSubtitleStart EventType = 650
EventType_SourceSubtitleResponse EventType = 651
EventType_SourceSubtitleEnd EventType = 652
EventType_TranslationSubtitleStart EventType = 653
EventType_TranslationSubtitleResponse EventType = 654
EventType_TranslationSubtitleEnd EventType = 655
)
func (t EventType) String() string {
switch t {
case EventType_None:
return "EventType_None"
case EventType_StartConnection:
return "EventType_StartConnection"
case EventType_FinishConnection:
return "EventType_FinishConnection"
case EventType_ConnectionStarted:
return "EventType_ConnectionStarted"
case EventType_ConnectionFailed:
return "EventType_ConnectionFailed"
case EventType_ConnectionFinished:
return "EventType_ConnectionFinished"
case EventType_StartSession:
return "EventType_StartSession"
case EventType_CancelSession:
return "EventType_CancelSession"
case EventType_FinishSession:
return "EventType_FinishSession"
case EventType_SessionStarted:
return "EventType_SessionStarted"
case EventType_SessionCanceled:
return "EventType_SessionCanceled"
case EventType_SessionFinished:
return "EventType_SessionFinished"
case EventType_SessionFailed:
return "EventType_SessionFailed"
case EventType_UsageResponse:
return "EventType_UsageResponse"
case EventType_TaskRequest:
return "EventType_TaskRequest"
case EventType_UpdateConfig:
return "EventType_UpdateConfig"
case EventType_AudioMuted:
return "EventType_AudioMuted"
case EventType_SayHello:
return "EventType_SayHello"
case EventType_TTSSentenceStart:
return "EventType_TTSSentenceStart"
case EventType_TTSSentenceEnd:
return "EventType_TTSSentenceEnd"
case EventType_TTSResponse:
return "EventType_TTSResponse"
case EventType_TTSEnded:
return "EventType_TTSEnded"
case EventType_PodcastRoundStart:
return "EventType_PodcastRoundStart"
case EventType_PodcastRoundResponse:
return "EventType_PodcastRoundResponse"
case EventType_PodcastRoundEnd:
return "EventType_PodcastRoundEnd"
case EventType_ASRInfo:
return "EventType_ASRInfo"
case EventType_ASRResponse:
return "EventType_ASRResponse"
case EventType_ASREnded:
return "EventType_ASREnded"
case EventType_ChatTTSText:
return "EventType_ChatTTSText"
case EventType_ChatResponse:
return "EventType_ChatResponse"
case EventType_ChatEnded:
return "EventType_ChatEnded"
case EventType_SourceSubtitleStart:
return "EventType_SourceSubtitleStart"
case EventType_SourceSubtitleResponse:
return "EventType_SourceSubtitleResponse"
case EventType_SourceSubtitleEnd:
return "EventType_SourceSubtitleEnd"
case EventType_TranslationSubtitleStart:
return "EventType_TranslationSubtitleStart"
case EventType_TranslationSubtitleResponse:
return "EventType_TranslationSubtitleResponse"
case EventType_TranslationSubtitleEnd:
return "EventType_TranslationSubtitleEnd"
default:
return fmt.Sprintf("EventType_(%d)", t)
}
}
type Message struct {
Version VersionBits
HeaderSize HeaderSizeBits
MsgType MsgType
MsgTypeFlag MsgTypeFlagBits
Serialization SerializationBits
Compression CompressionBits
EventType EventType
SessionID string
ConnectID string
Sequence int32
ErrorCode uint32
Payload []byte
}
func NewMessageFromBytes(data []byte) (*Message, error) {
if len(data) < 3 {
return nil, fmt.Errorf("data too short: expected at least 3 bytes, got %d", len(data))
}
typeAndFlag := data[1]
msg, err := NewMessage(MsgType(typeAndFlag>>4), MsgTypeFlagBits(typeAndFlag&0b00001111))
if err != nil {
return nil, err
}
if err := msg.Unmarshal(data); err != nil {
return nil, err
}
return msg, nil
}
func NewMessage(msgType MsgType, flag MsgTypeFlagBits) (*Message, error) {
return &Message{
MsgType: msgType,
MsgTypeFlag: flag,
Version: Version1,
HeaderSize: HeaderSize4,
Serialization: SerializationJSON,
Compression: CompressionNone,
}, nil
}
func (m *Message) String() string {
switch m.MsgType {
case MsgTypeAudioOnlyServer, MsgTypeAudioOnlyClient:
if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq {
return fmt.Sprintf("%s, %s, Sequence: %d, PayloadSize: %d", m.MsgType, m.EventType, m.Sequence, len(m.Payload))
}
return fmt.Sprintf("%s, %s, PayloadSize: %d", m.MsgType, m.EventType, len(m.Payload))
case MsgTypeError:
return fmt.Sprintf("%s, %s, ErrorCode: %d, Payload: %s", m.MsgType, m.EventType, m.ErrorCode, string(m.Payload))
default:
if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq {
return fmt.Sprintf("%s, %s, Sequence: %d, Payload: %s",
m.MsgType, m.EventType, m.Sequence, string(m.Payload))
}
return fmt.Sprintf("%s, %s, Payload: %s", m.MsgType, m.EventType, string(m.Payload))
}
}
func (m *Message) Marshal() ([]byte, error) {
buf := new(bytes.Buffer)
header := []uint8{
uint8(m.Version)<<4 | uint8(m.HeaderSize),
uint8(m.MsgType)<<4 | uint8(m.MsgTypeFlag),
uint8(m.Serialization)<<4 | uint8(m.Compression),
}
headerSize := 4 * int(m.HeaderSize)
if padding := headerSize - len(header); padding > 0 {
header = append(header, make([]uint8, padding)...)
}
if err := binary.Write(buf, binary.BigEndian, header); err != nil {
return nil, err
}
writers, err := m.writers()
if err != nil {
return nil, err
}
for _, write := range writers {
if err := write(buf); err != nil {
return nil, err
}
}
return buf.Bytes(), nil
}
func (m *Message) Unmarshal(data []byte) error {
buf := bytes.NewBuffer(data)
versionAndHeaderSize, err := buf.ReadByte()
if err != nil {
return err
}
m.Version = VersionBits(versionAndHeaderSize >> 4)
m.HeaderSize = HeaderSizeBits(versionAndHeaderSize & 0b00001111)
_, err = buf.ReadByte()
if err != nil {
return err
}
serializationCompression, err := buf.ReadByte()
if err != nil {
return err
}
m.Serialization = SerializationBits(serializationCompression & 0b11110000)
m.Compression = CompressionBits(serializationCompression & 0b00001111)
headerSize := 4 * int(m.HeaderSize)
readSize := 3
if paddingSize := headerSize - readSize; paddingSize > 0 {
if n, err := buf.Read(make([]byte, paddingSize)); err != nil || n < paddingSize {
return fmt.Errorf("insufficient header bytes: expected %d, got %d", paddingSize, n)
}
}
readers, err := m.readers()
if err != nil {
return err
}
for _, read := range readers {
if err := read(buf); err != nil {
return err
}
}
if _, err := buf.ReadByte(); err != io.EOF {
return fmt.Errorf("unexpected data after message: %v", err)
}
return nil
}
func (m *Message) writers() (writers []func(*bytes.Buffer) error, _ error) {
if m.MsgTypeFlag == MsgTypeFlagWithEvent {
writers = append(writers, m.writeEvent, m.writeSessionID)
}
switch m.MsgType {
case MsgTypeFullClientRequest, MsgTypeFullServerResponse, MsgTypeFrontEndResultServer, MsgTypeAudioOnlyClient, MsgTypeAudioOnlyServer:
if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq {
writers = append(writers, m.writeSequence)
}
case MsgTypeError:
writers = append(writers, m.writeErrorCode)
default:
return nil, fmt.Errorf("unsupported message type: %d", m.MsgType)
}
writers = append(writers, m.writePayload)
return writers, nil
}
func (m *Message) writeEvent(buf *bytes.Buffer) error {
return binary.Write(buf, binary.BigEndian, m.EventType)
}
func (m *Message) writeSessionID(buf *bytes.Buffer) error {
switch m.EventType {
case EventType_StartConnection, EventType_FinishConnection,
EventType_ConnectionStarted, EventType_ConnectionFailed:
return nil
}
size := len(m.SessionID)
if size > math.MaxUint32 {
return fmt.Errorf("session ID size (%d) exceeds max(uint32)", size)
}
if err := binary.Write(buf, binary.BigEndian, uint32(size)); err != nil {
return err
}
buf.WriteString(m.SessionID)
return nil
}
func (m *Message) writeSequence(buf *bytes.Buffer) error {
return binary.Write(buf, binary.BigEndian, m.Sequence)
}
func (m *Message) writeErrorCode(buf *bytes.Buffer) error {
return binary.Write(buf, binary.BigEndian, m.ErrorCode)
}
func (m *Message) writePayload(buf *bytes.Buffer) error {
size := len(m.Payload)
if size > math.MaxUint32 {
return fmt.Errorf("payload size (%d) exceeds max(uint32)", size)
}
if err := binary.Write(buf, binary.BigEndian, uint32(size)); err != nil {
return err
}
buf.Write(m.Payload)
return nil
}
func (m *Message) readers() (readers []func(*bytes.Buffer) error, _ error) {
switch m.MsgType {
case MsgTypeFullClientRequest, MsgTypeFullServerResponse, MsgTypeFrontEndResultServer, MsgTypeAudioOnlyClient, MsgTypeAudioOnlyServer:
if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq {
readers = append(readers, m.readSequence)
}
case MsgTypeError:
readers = append(readers, m.readErrorCode)
default:
return nil, fmt.Errorf("unsupported message type: %d", m.MsgType)
}
if m.MsgTypeFlag == MsgTypeFlagWithEvent {
readers = append(readers, m.readEvent, m.readSessionID, m.readConnectID)
}
readers = append(readers, m.readPayload)
return readers, nil
}
func (m *Message) readEvent(buf *bytes.Buffer) error {
return binary.Read(buf, binary.BigEndian, &m.EventType)
}
func (m *Message) readSessionID(buf *bytes.Buffer) error {
switch m.EventType {
case EventType_StartConnection, EventType_FinishConnection,
EventType_ConnectionStarted, EventType_ConnectionFailed,
EventType_ConnectionFinished:
return nil
}
var size uint32
if err := binary.Read(buf, binary.BigEndian, &size); err != nil {
return err
}
if size > 0 {
m.SessionID = string(buf.Next(int(size)))
}
return nil
}
func (m *Message) readConnectID(buf *bytes.Buffer) error {
switch m.EventType {
case EventType_ConnectionStarted, EventType_ConnectionFailed,
EventType_ConnectionFinished:
default:
return nil
}
var size uint32
if err := binary.Read(buf, binary.BigEndian, &size); err != nil {
return err
}
if size > 0 {
m.ConnectID = string(buf.Next(int(size)))
}
return nil
}
func (m *Message) readSequence(buf *bytes.Buffer) error {
return binary.Read(buf, binary.BigEndian, &m.Sequence)
}
func (m *Message) readErrorCode(buf *bytes.Buffer) error {
return binary.Read(buf, binary.BigEndian, &m.ErrorCode)
}
func (m *Message) readPayload(buf *bytes.Buffer) error {
var size uint32
if err := binary.Read(buf, binary.BigEndian, &size); err != nil {
return err
}
if size > 0 {
m.Payload = buf.Next(int(size))
}
return nil
}
func ReceiveMessage(conn *websocket.Conn) (*Message, error) {
mt, frame, err := conn.ReadMessage()
if err != nil {
return nil, err
}
if mt != websocket.BinaryMessage && mt != websocket.TextMessage {
return nil, fmt.Errorf("unexpected Websocket message type: %d", mt)
}
msg, err := NewMessageFromBytes(frame)
if err != nil {
return nil, err
}
return msg, nil
}
func FullClientRequest(conn *websocket.Conn, payload []byte) error {
msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagNoSeq)
if err != nil {
return err
}
msg.Payload = payload
frame, err := msg.Marshal()
if err != nil {
return err
}
return conn.WriteMessage(websocket.BinaryMessage, frame)
}

View File

@@ -1,9 +1,11 @@
package volcengine
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
@@ -13,6 +15,7 @@ import (
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
type VolcengineTTSRequest struct {
@@ -192,3 +195,111 @@ func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.Re
func generateRequestID() string {
return uuid.New().String()
}
func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest VolcengineTTSRequest, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
_, token, parseErr := parseVolcengineAuth(info.ApiKey)
if parseErr != nil {
return nil, types.NewErrorWithStatusCode(
parseErr,
types.ErrorCodeChannelInvalidKey,
http.StatusUnauthorized,
)
}
header := http.Header{}
header.Set("Authorization", fmt.Sprintf("Bearer;%s", token))
conn, resp, dialErr := websocket.DefaultDialer.DialContext(context.Background(), requestURL, header)
if dialErr != nil {
if resp != nil {
return nil, types.NewErrorWithStatusCode(
fmt.Errorf("failed to connect to websocket: %w, status: %d", dialErr, resp.StatusCode),
types.ErrorCodeBadResponseStatusCode,
http.StatusBadGateway,
)
}
return nil, types.NewErrorWithStatusCode(
fmt.Errorf("failed to connect to websocket: %w", dialErr),
types.ErrorCodeBadResponseStatusCode,
http.StatusBadGateway,
)
}
defer conn.Close()
payload, marshalErr := json.Marshal(volcRequest)
if marshalErr != nil {
return nil, types.NewErrorWithStatusCode(
fmt.Errorf("failed to marshal request: %w", marshalErr),
types.ErrorCodeBadRequestBody,
http.StatusInternalServerError,
)
}
if sendErr := FullClientRequest(conn, payload); sendErr != nil {
return nil, types.NewErrorWithStatusCode(
fmt.Errorf("failed to send request: %w", sendErr),
types.ErrorCodeBadRequestBody,
http.StatusInternalServerError,
)
}
contentType := getContentTypeByEncoding(encoding)
c.Header("Content-Type", contentType)
c.Header("Transfer-Encoding", "chunked")
for {
msg, recvErr := ReceiveMessage(conn)
if recvErr != nil {
if websocket.IsCloseError(recvErr, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
break
}
return nil, types.NewErrorWithStatusCode(
fmt.Errorf("failed to receive message: %w", recvErr),
types.ErrorCodeBadResponse,
http.StatusInternalServerError,
)
}
switch msg.MsgType {
case MsgTypeError:
return nil, types.NewErrorWithStatusCode(
fmt.Errorf("received error from server: code=%d, %s", msg.ErrorCode, string(msg.Payload)),
types.ErrorCodeBadResponse,
http.StatusBadRequest,
)
case MsgTypeFrontEndResultServer:
continue
case MsgTypeAudioOnlyServer:
if len(msg.Payload) > 0 {
if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil {
return nil, types.NewErrorWithStatusCode(
fmt.Errorf("failed to write audio data: %w", writeErr),
types.ErrorCodeBadResponse,
http.StatusInternalServerError,
)
}
c.Writer.Flush()
}
if msg.Sequence < 0 {
c.Status(http.StatusOK)
usage = &dto.Usage{
PromptTokens: info.PromptTokens,
CompletionTokens: 0,
TotalTokens: info.PromptTokens,
}
return usage, nil
}
default:
continue
}
}
c.Status(http.StatusOK)
usage = &dto.Usage{
PromptTokens: info.PromptTokens,
CompletionTokens: 0,
TotalTokens: info.PromptTokens,
}
return usage, nil
}

View File

@@ -264,6 +264,7 @@ var streamSupportedChannels = map[int]bool{
constant.ChannelTypeBaiduV2: true,
constant.ChannelTypeZhipu_v4: true,
constant.ChannelTypeAli: true,
constant.ChannelTypeSubmodel: true,
}
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {

View File

@@ -62,19 +62,9 @@ func GetAndValidAudioRequest(c *gin.Context, relayMode int) (*dto.AudioRequest,
return nil, errors.New("model is required")
}
default:
err = c.Request.ParseForm()
if err != nil {
return nil, err
}
formData := c.Request.PostForm
if audioRequest.Model == "" {
audioRequest.Model = formData.Get("model")
}
if audioRequest.Model == "" {
return nil, errors.New("model is required")
}
audioRequest.ResponseFormat = formData.Get("response_format")
if audioRequest.ResponseFormat == "" {
audioRequest.ResponseFormat = "json"
}
@@ -160,8 +150,9 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq
imageRequest.N = 1
}
watermark := formData.Has("watermark")
if watermark {
hasWatermark := formData.Has("watermark")
if hasWatermark {
watermark := formData.Get("watermark") == "true"
imageRequest.Watermark = &watermark
}
break

View File

@@ -10,6 +10,7 @@ import (
_ "image/png"
"log"
"math"
"path/filepath"
"strings"
"sync"
"unicode/utf8"
@@ -18,6 +19,7 @@ import (
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
constant2 "github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
@@ -254,6 +256,10 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
}
func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
if meta == nil {
return 0, errors.New("token count meta is nil")
}
if !constant.GetMediaToken {
return 0, nil
}
@@ -263,8 +269,29 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
if info.RelayFormat == types.RelayFormatOpenAIRealtime {
return 0, nil
}
if meta == nil {
return 0, errors.New("token count meta is nil")
if info.RelayMode == constant2.RelayModeAudioTranscription || info.RelayMode == constant2.RelayModeAudioTranslation {
multiForm, err := common.ParseMultipartFormReusable(c)
if err != nil {
return 0, fmt.Errorf("error parsing multipart form: %v", err)
}
fileHeaders := multiForm.File["file"]
totalAudioToken := 0
for _, fileHeader := range fileHeaders {
file, err := fileHeader.Open()
if err != nil {
return 0, fmt.Errorf("error opening audio file: %v", err)
}
defer file.Close()
// get ext and io.seeker
ext := filepath.Ext(fileHeader.Filename)
duration, err := common.GetAudioDuration(c.Request.Context(), file, ext)
if err != nil {
return 0, fmt.Errorf("error getting audio duration: %v", err)
}
// 一分钟 1000 token与 $price / minute 对齐
totalAudioToken += int(math.Round(math.Ceil(duration) / 60.0 * 1000))
}
return totalAudioToken, nil
}
model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)