mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:05:21 +00:00
Merge pull request #2087 from feitianbubu/pr/doubao-tts-stream
feat: doubao tts support streaming realtime audio
This commit is contained in:
@@ -23,6 +23,11 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
contextKeyTTSRequest = "volcengine_tts_request"
|
||||
contextKeyResponseFormat = "response_format"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
@@ -50,7 +55,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
speedRatio := request.Speed
|
||||
encoding := mapEncoding(request.ResponseFormat)
|
||||
|
||||
c.Set("response_format", encoding)
|
||||
c.Set(contextKeyResponseFormat, encoding)
|
||||
|
||||
volcRequest := VolcengineTTSRequest{
|
||||
App: VolcengineTTSApp{
|
||||
@@ -70,18 +75,23 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
Request: VolcengineTTSReqInfo{
|
||||
ReqID: generateRequestID(),
|
||||
Text: request.Input,
|
||||
Operation: "query",
|
||||
Operation: "submit",
|
||||
Model: info.OriginModelName,
|
||||
},
|
||||
}
|
||||
|
||||
// 同步扩展字段的厂商自定义metadata
|
||||
if len(request.Metadata) > 0 {
|
||||
if err = json.Unmarshal(request.Metadata, &volcRequest); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling metadata to volcengine request: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.Set(contextKeyTTSRequest, volcRequest)
|
||||
|
||||
if volcRequest.Request.Operation == "submit" {
|
||||
info.IsStream = true
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(volcRequest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling volcengine request: %w", err)
|
||||
@@ -100,9 +110,8 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
writer := multipart.NewWriter(&requestBody)
|
||||
|
||||
writer.WriteField("model", request.Model)
|
||||
// 获取所有表单字段
|
||||
|
||||
formData := c.Request.PostForm
|
||||
// 遍历表单字段并打印输出
|
||||
for key, values := range formData {
|
||||
if key == "model" {
|
||||
continue
|
||||
@@ -112,21 +121,16 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the multipart form to handle both single image and multiple images
|
||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
|
||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
|
||||
return nil, errors.New("failed to parse multipart form")
|
||||
}
|
||||
|
||||
if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
|
||||
// Check if "image" field exists in any form, including array notation
|
||||
var imageFiles []*multipart.FileHeader
|
||||
var exists bool
|
||||
|
||||
// First check for standard "image" field
|
||||
if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
|
||||
// If not found, check for "image[]" field
|
||||
if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
|
||||
// If still not found, iterate through all fields to find any that start with "image["
|
||||
foundArrayImages := false
|
||||
for fieldName, files := range c.Request.MultipartForm.File {
|
||||
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
|
||||
@@ -137,14 +141,12 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
}
|
||||
|
||||
// If no image fields found at all
|
||||
if !foundArrayImages && (len(imageFiles) == 0) {
|
||||
return nil, errors.New("image is required")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process all image files
|
||||
for i, fileHeader := range imageFiles {
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
@@ -152,16 +154,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// If multiple images, use image[] as the field name
|
||||
fieldName := "image"
|
||||
if len(imageFiles) > 1 {
|
||||
fieldName = "image[]"
|
||||
}
|
||||
|
||||
// Determine MIME type based on file extension
|
||||
mimeType := detectImageMimeType(fileHeader.Filename)
|
||||
|
||||
// Create a form file with the appropriate content type
|
||||
h := make(textproto.MIMEHeader)
|
||||
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
|
||||
h.Set("Content-Type", mimeType)
|
||||
@@ -176,7 +175,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
}
|
||||
|
||||
// Handle mask file if present
|
||||
if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
|
||||
maskFile, err := maskFiles[0].Open()
|
||||
if err != nil {
|
||||
@@ -184,10 +182,8 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
defer maskFile.Close()
|
||||
|
||||
// Determine MIME type for mask file
|
||||
mimeType := detectImageMimeType(maskFiles[0].Filename)
|
||||
|
||||
// Create a form file with the appropriate content type
|
||||
h := make(textproto.MIMEHeader)
|
||||
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
|
||||
h.Set("Content-Type", mimeType)
|
||||
@@ -205,7 +201,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
return nil, errors.New("no multipart form data found")
|
||||
}
|
||||
|
||||
// 关闭 multipart 编写器以设置分界线
|
||||
writer.Close()
|
||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return bytes.NewReader(requestBody.Bytes()), nil
|
||||
@@ -215,7 +210,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
}
|
||||
|
||||
// detectImageMimeType determines the MIME type based on the file extension
|
||||
func detectImageMimeType(filename string) string {
|
||||
ext := strings.ToLower(filepath.Ext(filename))
|
||||
switch ext {
|
||||
@@ -226,11 +220,9 @@ func detectImageMimeType(filename string) string {
|
||||
case ".webp":
|
||||
return "image/webp"
|
||||
default:
|
||||
// Try to detect from extension if possible
|
||||
if strings.HasPrefix(ext, ".jp") {
|
||||
return "image/jpeg"
|
||||
}
|
||||
// Default to png as a fallback
|
||||
return "image/png"
|
||||
}
|
||||
}
|
||||
@@ -266,9 +258,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
case constant.RelayModeRerank:
|
||||
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
|
||||
case constant.RelayModeAudioSpeech:
|
||||
// 只有当 baseUrl 是火山默认的官方Url时才改为官方的的TTS接口,否则走透传的New接口
|
||||
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
|
||||
return "https://openspeech.bytedance.com/api/v1/tts", nil
|
||||
return "wss://openspeech.bytedance.com/api/v1/tts/ws_binary", nil
|
||||
}
|
||||
return fmt.Sprintf("%s/v1/audio/speech", baseUrl), nil
|
||||
default:
|
||||
@@ -297,7 +288,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
// 适配 方舟deepseek混合模型 的 thinking 后缀
|
||||
|
||||
if strings.HasSuffix(info.UpstreamModelName, "-thinking") && strings.HasPrefix(info.UpstreamModelName, "deepseek") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
request.Model = info.UpstreamModelName
|
||||
@@ -315,17 +306,58 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
// TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
if info.RelayMode == constant.RelayModeAudioSpeech {
|
||||
baseUrl := info.ChannelBaseUrl
|
||||
if baseUrl == "" {
|
||||
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
|
||||
}
|
||||
|
||||
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
|
||||
if info.IsStream {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.RelayMode == constant.RelayModeAudioSpeech {
|
||||
encoding := mapEncoding(c.GetString("response_format"))
|
||||
encoding := mapEncoding(c.GetString(contextKeyResponseFormat))
|
||||
if info.IsStream {
|
||||
volcRequestInterface, exists := c.Get(contextKeyTTSRequest)
|
||||
if !exists {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
errors.New("volcengine TTS request not found in context"),
|
||||
types.ErrorCodeBadRequestBody,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
volcRequest, ok := volcRequestInterface.(VolcengineTTSRequest)
|
||||
if !ok {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
errors.New("invalid volcengine TTS request type"),
|
||||
types.ErrorCodeBadRequestBody,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
// Get the WebSocket URL
|
||||
requestURL, urlErr := a.GetRequestURL(info)
|
||||
if urlErr != nil {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
urlErr,
|
||||
types.ErrorCodeBadRequestBody,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
return handleTTSWebSocketResponse(c, requestURL, volcRequest, info, encoding)
|
||||
}
|
||||
return handleTTSResponse(c, resp, info, encoding)
|
||||
}
|
||||
|
||||
|
||||
533
relay/channel/volcengine/protocols.go
Normal file
533
relay/channel/volcengine/protocols.go
Normal file
@@ -0,0 +1,533 @@
|
||||
package volcengine
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type (
|
||||
EventType int32
|
||||
MsgType uint8
|
||||
MsgTypeFlagBits uint8
|
||||
VersionBits uint8
|
||||
HeaderSizeBits uint8
|
||||
SerializationBits uint8
|
||||
CompressionBits uint8
|
||||
)
|
||||
|
||||
const (
|
||||
MsgTypeFlagNoSeq MsgTypeFlagBits = 0
|
||||
MsgTypeFlagPositiveSeq MsgTypeFlagBits = 0b1
|
||||
MsgTypeFlagNegativeSeq MsgTypeFlagBits = 0b11
|
||||
MsgTypeFlagWithEvent MsgTypeFlagBits = 0b100
|
||||
)
|
||||
|
||||
const (
|
||||
Version1 VersionBits = iota + 1
|
||||
)
|
||||
|
||||
const (
|
||||
HeaderSize4 HeaderSizeBits = iota + 1
|
||||
)
|
||||
|
||||
const (
|
||||
SerializationJSON SerializationBits = 0b1
|
||||
)
|
||||
|
||||
const (
|
||||
CompressionNone CompressionBits = 0
|
||||
)
|
||||
|
||||
const (
|
||||
MsgTypeFullClientRequest MsgType = 0b1
|
||||
MsgTypeAudioOnlyClient MsgType = 0b10
|
||||
MsgTypeFullServerResponse MsgType = 0b1001
|
||||
MsgTypeAudioOnlyServer MsgType = 0b1011
|
||||
MsgTypeFrontEndResultServer MsgType = 0b1100
|
||||
MsgTypeError MsgType = 0b1111
|
||||
)
|
||||
|
||||
func (t MsgType) String() string {
|
||||
switch t {
|
||||
case MsgTypeFullClientRequest:
|
||||
return "MsgType_FullClientRequest"
|
||||
case MsgTypeAudioOnlyClient:
|
||||
return "MsgType_AudioOnlyClient"
|
||||
case MsgTypeFullServerResponse:
|
||||
return "MsgType_FullServerResponse"
|
||||
case MsgTypeAudioOnlyServer:
|
||||
return "MsgType_AudioOnlyServer"
|
||||
case MsgTypeError:
|
||||
return "MsgType_Error"
|
||||
case MsgTypeFrontEndResultServer:
|
||||
return "MsgType_FrontEndResultServer"
|
||||
default:
|
||||
return fmt.Sprintf("MsgType_(%d)", t)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
EventType_None EventType = 0
|
||||
|
||||
EventType_StartConnection EventType = 1
|
||||
EventType_FinishConnection EventType = 2
|
||||
|
||||
EventType_ConnectionStarted EventType = 50
|
||||
EventType_ConnectionFailed EventType = 51
|
||||
EventType_ConnectionFinished EventType = 52
|
||||
|
||||
EventType_StartSession EventType = 100
|
||||
EventType_CancelSession EventType = 101
|
||||
EventType_FinishSession EventType = 102
|
||||
|
||||
EventType_SessionStarted EventType = 150
|
||||
EventType_SessionCanceled EventType = 151
|
||||
EventType_SessionFinished EventType = 152
|
||||
EventType_SessionFailed EventType = 153
|
||||
|
||||
EventType_UsageResponse EventType = 154
|
||||
|
||||
EventType_TaskRequest EventType = 200
|
||||
EventType_UpdateConfig EventType = 201
|
||||
|
||||
EventType_AudioMuted EventType = 250
|
||||
|
||||
EventType_SayHello EventType = 300
|
||||
|
||||
EventType_TTSSentenceStart EventType = 350
|
||||
EventType_TTSSentenceEnd EventType = 351
|
||||
EventType_TTSResponse EventType = 352
|
||||
EventType_TTSEnded EventType = 359
|
||||
EventType_PodcastRoundStart EventType = 360
|
||||
EventType_PodcastRoundResponse EventType = 361
|
||||
EventType_PodcastRoundEnd EventType = 362
|
||||
|
||||
EventType_ASRInfo EventType = 450
|
||||
EventType_ASRResponse EventType = 451
|
||||
EventType_ASREnded EventType = 459
|
||||
|
||||
EventType_ChatTTSText EventType = 500
|
||||
|
||||
EventType_ChatResponse EventType = 550
|
||||
EventType_ChatEnded EventType = 559
|
||||
|
||||
EventType_SourceSubtitleStart EventType = 650
|
||||
EventType_SourceSubtitleResponse EventType = 651
|
||||
EventType_SourceSubtitleEnd EventType = 652
|
||||
|
||||
EventType_TranslationSubtitleStart EventType = 653
|
||||
EventType_TranslationSubtitleResponse EventType = 654
|
||||
EventType_TranslationSubtitleEnd EventType = 655
|
||||
)
|
||||
|
||||
func (t EventType) String() string {
|
||||
switch t {
|
||||
case EventType_None:
|
||||
return "EventType_None"
|
||||
case EventType_StartConnection:
|
||||
return "EventType_StartConnection"
|
||||
case EventType_FinishConnection:
|
||||
return "EventType_FinishConnection"
|
||||
case EventType_ConnectionStarted:
|
||||
return "EventType_ConnectionStarted"
|
||||
case EventType_ConnectionFailed:
|
||||
return "EventType_ConnectionFailed"
|
||||
case EventType_ConnectionFinished:
|
||||
return "EventType_ConnectionFinished"
|
||||
case EventType_StartSession:
|
||||
return "EventType_StartSession"
|
||||
case EventType_CancelSession:
|
||||
return "EventType_CancelSession"
|
||||
case EventType_FinishSession:
|
||||
return "EventType_FinishSession"
|
||||
case EventType_SessionStarted:
|
||||
return "EventType_SessionStarted"
|
||||
case EventType_SessionCanceled:
|
||||
return "EventType_SessionCanceled"
|
||||
case EventType_SessionFinished:
|
||||
return "EventType_SessionFinished"
|
||||
case EventType_SessionFailed:
|
||||
return "EventType_SessionFailed"
|
||||
case EventType_UsageResponse:
|
||||
return "EventType_UsageResponse"
|
||||
case EventType_TaskRequest:
|
||||
return "EventType_TaskRequest"
|
||||
case EventType_UpdateConfig:
|
||||
return "EventType_UpdateConfig"
|
||||
case EventType_AudioMuted:
|
||||
return "EventType_AudioMuted"
|
||||
case EventType_SayHello:
|
||||
return "EventType_SayHello"
|
||||
case EventType_TTSSentenceStart:
|
||||
return "EventType_TTSSentenceStart"
|
||||
case EventType_TTSSentenceEnd:
|
||||
return "EventType_TTSSentenceEnd"
|
||||
case EventType_TTSResponse:
|
||||
return "EventType_TTSResponse"
|
||||
case EventType_TTSEnded:
|
||||
return "EventType_TTSEnded"
|
||||
case EventType_PodcastRoundStart:
|
||||
return "EventType_PodcastRoundStart"
|
||||
case EventType_PodcastRoundResponse:
|
||||
return "EventType_PodcastRoundResponse"
|
||||
case EventType_PodcastRoundEnd:
|
||||
return "EventType_PodcastRoundEnd"
|
||||
case EventType_ASRInfo:
|
||||
return "EventType_ASRInfo"
|
||||
case EventType_ASRResponse:
|
||||
return "EventType_ASRResponse"
|
||||
case EventType_ASREnded:
|
||||
return "EventType_ASREnded"
|
||||
case EventType_ChatTTSText:
|
||||
return "EventType_ChatTTSText"
|
||||
case EventType_ChatResponse:
|
||||
return "EventType_ChatResponse"
|
||||
case EventType_ChatEnded:
|
||||
return "EventType_ChatEnded"
|
||||
case EventType_SourceSubtitleStart:
|
||||
return "EventType_SourceSubtitleStart"
|
||||
case EventType_SourceSubtitleResponse:
|
||||
return "EventType_SourceSubtitleResponse"
|
||||
case EventType_SourceSubtitleEnd:
|
||||
return "EventType_SourceSubtitleEnd"
|
||||
case EventType_TranslationSubtitleStart:
|
||||
return "EventType_TranslationSubtitleStart"
|
||||
case EventType_TranslationSubtitleResponse:
|
||||
return "EventType_TranslationSubtitleResponse"
|
||||
case EventType_TranslationSubtitleEnd:
|
||||
return "EventType_TranslationSubtitleEnd"
|
||||
default:
|
||||
return fmt.Sprintf("EventType_(%d)", t)
|
||||
}
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Version VersionBits
|
||||
HeaderSize HeaderSizeBits
|
||||
MsgType MsgType
|
||||
MsgTypeFlag MsgTypeFlagBits
|
||||
Serialization SerializationBits
|
||||
Compression CompressionBits
|
||||
|
||||
EventType EventType
|
||||
SessionID string
|
||||
ConnectID string
|
||||
Sequence int32
|
||||
ErrorCode uint32
|
||||
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
func NewMessageFromBytes(data []byte) (*Message, error) {
|
||||
if len(data) < 3 {
|
||||
return nil, fmt.Errorf("data too short: expected at least 3 bytes, got %d", len(data))
|
||||
}
|
||||
|
||||
typeAndFlag := data[1]
|
||||
|
||||
msg, err := NewMessage(MsgType(typeAndFlag>>4), MsgTypeFlagBits(typeAndFlag&0b00001111))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := msg.Unmarshal(data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func NewMessage(msgType MsgType, flag MsgTypeFlagBits) (*Message, error) {
|
||||
return &Message{
|
||||
MsgType: msgType,
|
||||
MsgTypeFlag: flag,
|
||||
Version: Version1,
|
||||
HeaderSize: HeaderSize4,
|
||||
Serialization: SerializationJSON,
|
||||
Compression: CompressionNone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *Message) String() string {
|
||||
switch m.MsgType {
|
||||
case MsgTypeAudioOnlyServer, MsgTypeAudioOnlyClient:
|
||||
if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq {
|
||||
return fmt.Sprintf("%s, %s, Sequence: %d, PayloadSize: %d", m.MsgType, m.EventType, m.Sequence, len(m.Payload))
|
||||
}
|
||||
return fmt.Sprintf("%s, %s, PayloadSize: %d", m.MsgType, m.EventType, len(m.Payload))
|
||||
case MsgTypeError:
|
||||
return fmt.Sprintf("%s, %s, ErrorCode: %d, Payload: %s", m.MsgType, m.EventType, m.ErrorCode, string(m.Payload))
|
||||
default:
|
||||
if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq {
|
||||
return fmt.Sprintf("%s, %s, Sequence: %d, Payload: %s",
|
||||
m.MsgType, m.EventType, m.Sequence, string(m.Payload))
|
||||
}
|
||||
return fmt.Sprintf("%s, %s, Payload: %s", m.MsgType, m.EventType, string(m.Payload))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Message) Marshal() ([]byte, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
header := []uint8{
|
||||
uint8(m.Version)<<4 | uint8(m.HeaderSize),
|
||||
uint8(m.MsgType)<<4 | uint8(m.MsgTypeFlag),
|
||||
uint8(m.Serialization)<<4 | uint8(m.Compression),
|
||||
}
|
||||
|
||||
headerSize := 4 * int(m.HeaderSize)
|
||||
if padding := headerSize - len(header); padding > 0 {
|
||||
header = append(header, make([]uint8, padding)...)
|
||||
}
|
||||
|
||||
if err := binary.Write(buf, binary.BigEndian, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
writers, err := m.writers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, write := range writers {
|
||||
if err := write(buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (m *Message) Unmarshal(data []byte) error {
|
||||
buf := bytes.NewBuffer(data)
|
||||
|
||||
versionAndHeaderSize, err := buf.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Version = VersionBits(versionAndHeaderSize >> 4)
|
||||
m.HeaderSize = HeaderSizeBits(versionAndHeaderSize & 0b00001111)
|
||||
|
||||
_, err = buf.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
serializationCompression, err := buf.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Serialization = SerializationBits(serializationCompression & 0b11110000)
|
||||
m.Compression = CompressionBits(serializationCompression & 0b00001111)
|
||||
|
||||
headerSize := 4 * int(m.HeaderSize)
|
||||
readSize := 3
|
||||
if paddingSize := headerSize - readSize; paddingSize > 0 {
|
||||
if n, err := buf.Read(make([]byte, paddingSize)); err != nil || n < paddingSize {
|
||||
return fmt.Errorf("insufficient header bytes: expected %d, got %d", paddingSize, n)
|
||||
}
|
||||
}
|
||||
|
||||
readers, err := m.readers()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, read := range readers {
|
||||
if err := read(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := buf.ReadByte(); err != io.EOF {
|
||||
return fmt.Errorf("unexpected data after message: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Message) writers() (writers []func(*bytes.Buffer) error, _ error) {
|
||||
if m.MsgTypeFlag == MsgTypeFlagWithEvent {
|
||||
writers = append(writers, m.writeEvent, m.writeSessionID)
|
||||
}
|
||||
|
||||
switch m.MsgType {
|
||||
case MsgTypeFullClientRequest, MsgTypeFullServerResponse, MsgTypeFrontEndResultServer, MsgTypeAudioOnlyClient, MsgTypeAudioOnlyServer:
|
||||
if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq {
|
||||
writers = append(writers, m.writeSequence)
|
||||
}
|
||||
case MsgTypeError:
|
||||
writers = append(writers, m.writeErrorCode)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported message type: %d", m.MsgType)
|
||||
}
|
||||
|
||||
writers = append(writers, m.writePayload)
|
||||
return writers, nil
|
||||
}
|
||||
|
||||
func (m *Message) writeEvent(buf *bytes.Buffer) error {
|
||||
return binary.Write(buf, binary.BigEndian, m.EventType)
|
||||
}
|
||||
|
||||
func (m *Message) writeSessionID(buf *bytes.Buffer) error {
|
||||
switch m.EventType {
|
||||
case EventType_StartConnection, EventType_FinishConnection,
|
||||
EventType_ConnectionStarted, EventType_ConnectionFailed:
|
||||
return nil
|
||||
}
|
||||
|
||||
size := len(m.SessionID)
|
||||
if size > math.MaxUint32 {
|
||||
return fmt.Errorf("session ID size (%d) exceeds max(uint32)", size)
|
||||
}
|
||||
|
||||
if err := binary.Write(buf, binary.BigEndian, uint32(size)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf.WriteString(m.SessionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Message) writeSequence(buf *bytes.Buffer) error {
|
||||
return binary.Write(buf, binary.BigEndian, m.Sequence)
|
||||
}
|
||||
|
||||
func (m *Message) writeErrorCode(buf *bytes.Buffer) error {
|
||||
return binary.Write(buf, binary.BigEndian, m.ErrorCode)
|
||||
}
|
||||
|
||||
func (m *Message) writePayload(buf *bytes.Buffer) error {
|
||||
size := len(m.Payload)
|
||||
if size > math.MaxUint32 {
|
||||
return fmt.Errorf("payload size (%d) exceeds max(uint32)", size)
|
||||
}
|
||||
|
||||
if err := binary.Write(buf, binary.BigEndian, uint32(size)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf.Write(m.Payload)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Message) readers() (readers []func(*bytes.Buffer) error, _ error) {
|
||||
switch m.MsgType {
|
||||
case MsgTypeFullClientRequest, MsgTypeFullServerResponse, MsgTypeFrontEndResultServer, MsgTypeAudioOnlyClient, MsgTypeAudioOnlyServer:
|
||||
if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq {
|
||||
readers = append(readers, m.readSequence)
|
||||
}
|
||||
case MsgTypeError:
|
||||
readers = append(readers, m.readErrorCode)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported message type: %d", m.MsgType)
|
||||
}
|
||||
|
||||
if m.MsgTypeFlag == MsgTypeFlagWithEvent {
|
||||
readers = append(readers, m.readEvent, m.readSessionID, m.readConnectID)
|
||||
}
|
||||
|
||||
readers = append(readers, m.readPayload)
|
||||
return readers, nil
|
||||
}
|
||||
|
||||
func (m *Message) readEvent(buf *bytes.Buffer) error {
|
||||
return binary.Read(buf, binary.BigEndian, &m.EventType)
|
||||
}
|
||||
|
||||
func (m *Message) readSessionID(buf *bytes.Buffer) error {
|
||||
switch m.EventType {
|
||||
case EventType_StartConnection, EventType_FinishConnection,
|
||||
EventType_ConnectionStarted, EventType_ConnectionFailed,
|
||||
EventType_ConnectionFinished:
|
||||
return nil
|
||||
}
|
||||
|
||||
var size uint32
|
||||
if err := binary.Read(buf, binary.BigEndian, &size); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if size > 0 {
|
||||
m.SessionID = string(buf.Next(int(size)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Message) readConnectID(buf *bytes.Buffer) error {
|
||||
switch m.EventType {
|
||||
case EventType_ConnectionStarted, EventType_ConnectionFailed,
|
||||
EventType_ConnectionFinished:
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
var size uint32
|
||||
if err := binary.Read(buf, binary.BigEndian, &size); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if size > 0 {
|
||||
m.ConnectID = string(buf.Next(int(size)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Message) readSequence(buf *bytes.Buffer) error {
|
||||
return binary.Read(buf, binary.BigEndian, &m.Sequence)
|
||||
}
|
||||
|
||||
func (m *Message) readErrorCode(buf *bytes.Buffer) error {
|
||||
return binary.Read(buf, binary.BigEndian, &m.ErrorCode)
|
||||
}
|
||||
|
||||
func (m *Message) readPayload(buf *bytes.Buffer) error {
|
||||
var size uint32
|
||||
if err := binary.Read(buf, binary.BigEndian, &size); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if size > 0 {
|
||||
m.Payload = buf.Next(int(size))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ReceiveMessage(conn *websocket.Conn) (*Message, error) {
|
||||
mt, frame, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if mt != websocket.BinaryMessage && mt != websocket.TextMessage {
|
||||
return nil, fmt.Errorf("unexpected Websocket message type: %d", mt)
|
||||
}
|
||||
msg, err := NewMessageFromBytes(frame)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func FullClientRequest(conn *websocket.Conn, payload []byte) error {
|
||||
msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagNoSeq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msg.Payload = payload
|
||||
frame, err := msg.Marshal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.WriteMessage(websocket.BinaryMessage, frame)
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
package volcengine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -13,6 +15,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type VolcengineTTSRequest struct {
|
||||
@@ -192,3 +195,111 @@ func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
func generateRequestID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest VolcengineTTSRequest, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
|
||||
_, token, parseErr := parseVolcengineAuth(info.ApiKey)
|
||||
if parseErr != nil {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
parseErr,
|
||||
types.ErrorCodeChannelInvalidKey,
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
}
|
||||
|
||||
header := http.Header{}
|
||||
header.Set("Authorization", fmt.Sprintf("Bearer;%s", token))
|
||||
|
||||
conn, resp, dialErr := websocket.DefaultDialer.DialContext(context.Background(), requestURL, header)
|
||||
if dialErr != nil {
|
||||
if resp != nil {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("failed to connect to websocket: %w, status: %d", dialErr, resp.StatusCode),
|
||||
types.ErrorCodeBadResponseStatusCode,
|
||||
http.StatusBadGateway,
|
||||
)
|
||||
}
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("failed to connect to websocket: %w", dialErr),
|
||||
types.ErrorCodeBadResponseStatusCode,
|
||||
http.StatusBadGateway,
|
||||
)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
payload, marshalErr := json.Marshal(volcRequest)
|
||||
if marshalErr != nil {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("failed to marshal request: %w", marshalErr),
|
||||
types.ErrorCodeBadRequestBody,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
if sendErr := FullClientRequest(conn, payload); sendErr != nil {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("failed to send request: %w", sendErr),
|
||||
types.ErrorCodeBadRequestBody,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
contentType := getContentTypeByEncoding(encoding)
|
||||
c.Header("Content-Type", contentType)
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
|
||||
for {
|
||||
msg, recvErr := ReceiveMessage(conn)
|
||||
if recvErr != nil {
|
||||
if websocket.IsCloseError(recvErr, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
break
|
||||
}
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("failed to receive message: %w", recvErr),
|
||||
types.ErrorCodeBadResponse,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
switch msg.MsgType {
|
||||
case MsgTypeError:
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("received error from server: code=%d, %s", msg.ErrorCode, string(msg.Payload)),
|
||||
types.ErrorCodeBadResponse,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
case MsgTypeFrontEndResultServer:
|
||||
continue
|
||||
case MsgTypeAudioOnlyServer:
|
||||
if len(msg.Payload) > 0 {
|
||||
if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("failed to write audio data: %w", writeErr),
|
||||
types.ErrorCodeBadResponse,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
if msg.Sequence < 0 {
|
||||
c.Status(http.StatusOK)
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: info.PromptTokens,
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
c.Status(http.StatusOK)
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: info.PromptTokens,
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user