diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index c5d9e5dd6..f46328e37 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -23,6 +23,11 @@ import ( "github.com/gin-gonic/gin" ) +const ( + contextKeyTTSRequest = "volcengine_tts_request" + contextKeyResponseFormat = "response_format" +) + type Adaptor struct { } @@ -50,7 +55,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf speedRatio := request.Speed encoding := mapEncoding(request.ResponseFormat) - c.Set("response_format", encoding) + c.Set(contextKeyResponseFormat, encoding) volcRequest := VolcengineTTSRequest{ App: VolcengineTTSApp{ @@ -70,18 +75,23 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf Request: VolcengineTTSReqInfo{ ReqID: generateRequestID(), Text: request.Input, - Operation: "query", + Operation: "submit", Model: info.OriginModelName, }, } - // 同步扩展字段的厂商自定义metadata if len(request.Metadata) > 0 { if err = json.Unmarshal(request.Metadata, &volcRequest); err != nil { return nil, fmt.Errorf("error unmarshalling metadata to volcengine request: %w", err) } } + c.Set(contextKeyTTSRequest, volcRequest) + + if volcRequest.Request.Operation == "submit" { + info.IsStream = true + } + jsonData, err := json.Marshal(volcRequest) if err != nil { return nil, fmt.Errorf("error marshalling volcengine request: %w", err) @@ -100,9 +110,8 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf writer := multipart.NewWriter(&requestBody) writer.WriteField("model", request.Model) - // 获取所有表单字段 + formData := c.Request.PostForm - // 遍历表单字段并打印输出 for key, values := range formData { if key == "model" { continue @@ -112,21 +121,16 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } } - // Parse the multipart form to handle both single image and multiple images - if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory + if err := c.Request.ParseMultipartForm(32 << 20); err != nil { return nil, errors.New("failed to parse multipart form") } if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil { - // Check if "image" field exists in any form, including array notation var imageFiles []*multipart.FileHeader var exists bool - // First check for standard "image" field if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 { - // If not found, check for "image[]" field if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 { - // If still not found, iterate through all fields to find any that start with "image[" foundArrayImages := false for fieldName, files := range c.Request.MultipartForm.File { if strings.HasPrefix(fieldName, "image[") && len(files) > 0 { @@ -137,14 +141,12 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } } - // If no image fields found at all if !foundArrayImages && (len(imageFiles) == 0) { return nil, errors.New("image is required") } } } - // Process all image files for i, fileHeader := range imageFiles { file, err := fileHeader.Open() if err != nil { @@ -152,16 +154,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } defer file.Close() - // If multiple images, use image[] as the field name fieldName := "image" if len(imageFiles) > 1 { fieldName = "image[]" } - // Determine MIME type based on file extension mimeType := detectImageMimeType(fileHeader.Filename) - // Create a form file with the appropriate content type h := make(textproto.MIMEHeader) h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename)) h.Set("Content-Type", mimeType) @@ -176,7 +175,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } } - // Handle mask file if present if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 { maskFile, err := maskFiles[0].Open() if err != nil { @@ -184,10 +182,8 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } defer maskFile.Close() - // Determine MIME type for mask file mimeType := detectImageMimeType(maskFiles[0].Filename) - // Create a form file with the appropriate content type h := make(textproto.MIMEHeader) h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename)) h.Set("Content-Type", mimeType) @@ -205,7 +201,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf return nil, errors.New("no multipart form data found") } - // 关闭 multipart 编写器以设置分界线 writer.Close() c.Request.Header.Set("Content-Type", writer.FormDataContentType()) return bytes.NewReader(requestBody.Bytes()), nil @@ -215,7 +210,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } } -// detectImageMimeType determines the MIME type based on the file extension func detectImageMimeType(filename string) string { ext := strings.ToLower(filepath.Ext(filename)) switch ext { @@ -226,11 +220,9 @@ func detectImageMimeType(filename string) string { case ".webp": return "image/webp" default: - // Try to detect from extension if possible if strings.HasPrefix(ext, ".jp") { return "image/jpeg" } - // Default to png as a fallback return "image/png" } } @@ -266,9 +258,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { case constant.RelayModeRerank: return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil case constant.RelayModeAudioSpeech: - // 只有当 baseUrl 是火山默认的官方Url时才改为官方的的TTS接口,否则走透传的New接口 if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] { - return "https://openspeech.bytedance.com/api/v1/tts", nil + return "wss://openspeech.bytedance.com/api/v1/tts/ws_binary", nil } return fmt.Sprintf("%s/v1/audio/speech", baseUrl), nil default: @@ -297,7 +288,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } - // 适配 方舟deepseek混合模型 的 thinking 后缀 + if strings.HasSuffix(info.UpstreamModelName, "-thinking") && strings.HasPrefix(info.UpstreamModelName, "deepseek") { info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking") request.Model = info.UpstreamModelName @@ -315,17 +306,58 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela } func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { - // TODO implement me return nil, errors.New("not implemented") } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { + if info.RelayMode == constant.RelayModeAudioSpeech { + baseUrl := info.ChannelBaseUrl + if baseUrl == "" { + baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] + } + + if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] { + if info.IsStream { + return nil, nil + } + } + } return channel.DoApiRequest(a, c, info, requestBody) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeAudioSpeech { - encoding := mapEncoding(c.GetString("response_format")) + encoding := mapEncoding(c.GetString(contextKeyResponseFormat)) + if info.IsStream { + volcRequestInterface, exists := c.Get(contextKeyTTSRequest) + if !exists { + return nil, types.NewErrorWithStatusCode( + errors.New("volcengine TTS request not found in context"), + types.ErrorCodeBadRequestBody, + http.StatusInternalServerError, + ) + } + + volcRequest, ok := volcRequestInterface.(VolcengineTTSRequest) + if !ok { + return nil, types.NewErrorWithStatusCode( + errors.New("invalid volcengine TTS request type"), + types.ErrorCodeBadRequestBody, + http.StatusInternalServerError, + ) + } + + // Get the WebSocket URL + requestURL, urlErr := a.GetRequestURL(info) + if urlErr != nil { + return nil, types.NewErrorWithStatusCode( + urlErr, + types.ErrorCodeBadRequestBody, + http.StatusInternalServerError, + ) + } + return handleTTSWebSocketResponse(c, requestURL, volcRequest, info, encoding) + } return handleTTSResponse(c, resp, info, encoding) } diff --git a/relay/channel/volcengine/protocols.go b/relay/channel/volcengine/protocols.go new file mode 100644 index 000000000..c978e1c76 --- /dev/null +++ b/relay/channel/volcengine/protocols.go @@ -0,0 +1,533 @@ +package volcengine + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "math" + + "github.com/gorilla/websocket" +) + +type ( + EventType int32 + MsgType uint8 + MsgTypeFlagBits uint8 + VersionBits uint8 + HeaderSizeBits uint8 + SerializationBits uint8 + CompressionBits uint8 +) + +const ( + MsgTypeFlagNoSeq MsgTypeFlagBits = 0 + MsgTypeFlagPositiveSeq MsgTypeFlagBits = 0b1 + MsgTypeFlagNegativeSeq MsgTypeFlagBits = 0b11 + MsgTypeFlagWithEvent MsgTypeFlagBits = 0b100 +) + +const ( + Version1 VersionBits = iota + 1 +) + +const ( + HeaderSize4 HeaderSizeBits = iota + 1 +) + +const ( + SerializationJSON SerializationBits = 0b1 +) + +const ( + CompressionNone CompressionBits = 0 +) + +const ( + MsgTypeFullClientRequest MsgType = 0b1 + MsgTypeAudioOnlyClient MsgType = 0b10 + MsgTypeFullServerResponse MsgType = 0b1001 + MsgTypeAudioOnlyServer MsgType = 0b1011 + MsgTypeFrontEndResultServer MsgType = 0b1100 + MsgTypeError MsgType = 0b1111 +) + +func (t MsgType) String() string { + switch t { + case MsgTypeFullClientRequest: + return "MsgType_FullClientRequest" + case MsgTypeAudioOnlyClient: + return "MsgType_AudioOnlyClient" + case MsgTypeFullServerResponse: + return "MsgType_FullServerResponse" + case MsgTypeAudioOnlyServer: + return "MsgType_AudioOnlyServer" + case MsgTypeError: + return "MsgType_Error" + case MsgTypeFrontEndResultServer: + return "MsgType_FrontEndResultServer" + default: + return fmt.Sprintf("MsgType_(%d)", t) + } +} + +const ( + EventType_None EventType = 0 + + EventType_StartConnection EventType = 1 + EventType_FinishConnection EventType = 2 + + EventType_ConnectionStarted EventType = 50 + EventType_ConnectionFailed EventType = 51 + EventType_ConnectionFinished EventType = 52 + + EventType_StartSession EventType = 100 + EventType_CancelSession EventType = 101 + EventType_FinishSession EventType = 102 + + EventType_SessionStarted EventType = 150 + EventType_SessionCanceled EventType = 151 + EventType_SessionFinished EventType = 152 + EventType_SessionFailed EventType = 153 + + EventType_UsageResponse EventType = 154 + + EventType_TaskRequest EventType = 200 + EventType_UpdateConfig EventType = 201 + + EventType_AudioMuted EventType = 250 + + EventType_SayHello EventType = 300 + + EventType_TTSSentenceStart EventType = 350 + EventType_TTSSentenceEnd EventType = 351 + EventType_TTSResponse EventType = 352 + EventType_TTSEnded EventType = 359 + EventType_PodcastRoundStart EventType = 360 + EventType_PodcastRoundResponse EventType = 361 + EventType_PodcastRoundEnd EventType = 362 + + EventType_ASRInfo EventType = 450 + EventType_ASRResponse EventType = 451 + EventType_ASREnded EventType = 459 + + EventType_ChatTTSText EventType = 500 + + EventType_ChatResponse EventType = 550 + EventType_ChatEnded EventType = 559 + + EventType_SourceSubtitleStart EventType = 650 + EventType_SourceSubtitleResponse EventType = 651 + EventType_SourceSubtitleEnd EventType = 652 + + EventType_TranslationSubtitleStart EventType = 653 + EventType_TranslationSubtitleResponse EventType = 654 + EventType_TranslationSubtitleEnd EventType = 655 +) + +func (t EventType) String() string { + switch t { + case EventType_None: + return "EventType_None" + case EventType_StartConnection: + return "EventType_StartConnection" + case EventType_FinishConnection: + return "EventType_FinishConnection" + case EventType_ConnectionStarted: + return "EventType_ConnectionStarted" + case EventType_ConnectionFailed: + return "EventType_ConnectionFailed" + case EventType_ConnectionFinished: + return "EventType_ConnectionFinished" + case EventType_StartSession: + return "EventType_StartSession" + case EventType_CancelSession: + return "EventType_CancelSession" + case EventType_FinishSession: + return "EventType_FinishSession" + case EventType_SessionStarted: + return "EventType_SessionStarted" + case EventType_SessionCanceled: + return "EventType_SessionCanceled" + case EventType_SessionFinished: + return "EventType_SessionFinished" + case EventType_SessionFailed: + return "EventType_SessionFailed" + case EventType_UsageResponse: + return "EventType_UsageResponse" + case EventType_TaskRequest: + return "EventType_TaskRequest" + case EventType_UpdateConfig: + return "EventType_UpdateConfig" + case EventType_AudioMuted: + return "EventType_AudioMuted" + case EventType_SayHello: + return "EventType_SayHello" + case EventType_TTSSentenceStart: + return "EventType_TTSSentenceStart" + case EventType_TTSSentenceEnd: + return "EventType_TTSSentenceEnd" + case EventType_TTSResponse: + return "EventType_TTSResponse" + case EventType_TTSEnded: + return "EventType_TTSEnded" + case EventType_PodcastRoundStart: + return "EventType_PodcastRoundStart" + case EventType_PodcastRoundResponse: + return "EventType_PodcastRoundResponse" + case EventType_PodcastRoundEnd: + return "EventType_PodcastRoundEnd" + case EventType_ASRInfo: + return "EventType_ASRInfo" + case EventType_ASRResponse: + return "EventType_ASRResponse" + case EventType_ASREnded: + return "EventType_ASREnded" + case EventType_ChatTTSText: + return "EventType_ChatTTSText" + case EventType_ChatResponse: + return "EventType_ChatResponse" + case EventType_ChatEnded: + return "EventType_ChatEnded" + case EventType_SourceSubtitleStart: + return "EventType_SourceSubtitleStart" + case EventType_SourceSubtitleResponse: + return "EventType_SourceSubtitleResponse" + case EventType_SourceSubtitleEnd: + return "EventType_SourceSubtitleEnd" + case EventType_TranslationSubtitleStart: + return "EventType_TranslationSubtitleStart" + case EventType_TranslationSubtitleResponse: + return "EventType_TranslationSubtitleResponse" + case EventType_TranslationSubtitleEnd: + return "EventType_TranslationSubtitleEnd" + default: + return fmt.Sprintf("EventType_(%d)", t) + } +} + +type Message struct { + Version VersionBits + HeaderSize HeaderSizeBits + MsgType MsgType + MsgTypeFlag MsgTypeFlagBits + Serialization SerializationBits + Compression CompressionBits + + EventType EventType + SessionID string + ConnectID string + Sequence int32 + ErrorCode uint32 + + Payload []byte +} + +func NewMessageFromBytes(data []byte) (*Message, error) { + if len(data) < 3 { + return nil, fmt.Errorf("data too short: expected at least 3 bytes, got %d", len(data)) + } + + typeAndFlag := data[1] + + msg, err := NewMessage(MsgType(typeAndFlag>>4), MsgTypeFlagBits(typeAndFlag&0b00001111)) + if err != nil { + return nil, err + } + + if err := msg.Unmarshal(data); err != nil { + return nil, err + } + + return msg, nil +} + +func NewMessage(msgType MsgType, flag MsgTypeFlagBits) (*Message, error) { + return &Message{ + MsgType: msgType, + MsgTypeFlag: flag, + Version: Version1, + HeaderSize: HeaderSize4, + Serialization: SerializationJSON, + Compression: CompressionNone, + }, nil +} + +func (m *Message) String() string { + switch m.MsgType { + case MsgTypeAudioOnlyServer, MsgTypeAudioOnlyClient: + if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq { + return fmt.Sprintf("%s, %s, Sequence: %d, PayloadSize: %d", m.MsgType, m.EventType, m.Sequence, len(m.Payload)) + } + return fmt.Sprintf("%s, %s, PayloadSize: %d", m.MsgType, m.EventType, len(m.Payload)) + case MsgTypeError: + return fmt.Sprintf("%s, %s, ErrorCode: %d, Payload: %s", m.MsgType, m.EventType, m.ErrorCode, string(m.Payload)) + default: + if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq { + return fmt.Sprintf("%s, %s, Sequence: %d, Payload: %s", + m.MsgType, m.EventType, m.Sequence, string(m.Payload)) + } + return fmt.Sprintf("%s, %s, Payload: %s", m.MsgType, m.EventType, string(m.Payload)) + } +} + +func (m *Message) Marshal() ([]byte, error) { + buf := new(bytes.Buffer) + + header := []uint8{ + uint8(m.Version)<<4 | uint8(m.HeaderSize), + uint8(m.MsgType)<<4 | uint8(m.MsgTypeFlag), + uint8(m.Serialization)<<4 | uint8(m.Compression), + } + + headerSize := 4 * int(m.HeaderSize) + if padding := headerSize - len(header); padding > 0 { + header = append(header, make([]uint8, padding)...) + } + + if err := binary.Write(buf, binary.BigEndian, header); err != nil { + return nil, err + } + + writers, err := m.writers() + if err != nil { + return nil, err + } + + for _, write := range writers { + if err := write(buf); err != nil { + return nil, err + } + } + + return buf.Bytes(), nil +} + +func (m *Message) Unmarshal(data []byte) error { + buf := bytes.NewBuffer(data) + + versionAndHeaderSize, err := buf.ReadByte() + if err != nil { + return err + } + + m.Version = VersionBits(versionAndHeaderSize >> 4) + m.HeaderSize = HeaderSizeBits(versionAndHeaderSize & 0b00001111) + + _, err = buf.ReadByte() + if err != nil { + return err + } + + serializationCompression, err := buf.ReadByte() + if err != nil { + return err + } + + m.Serialization = SerializationBits(serializationCompression & 0b11110000) + m.Compression = CompressionBits(serializationCompression & 0b00001111) + + headerSize := 4 * int(m.HeaderSize) + readSize := 3 + if paddingSize := headerSize - readSize; paddingSize > 0 { + if n, err := buf.Read(make([]byte, paddingSize)); err != nil || n < paddingSize { + return fmt.Errorf("insufficient header bytes: expected %d, got %d", paddingSize, n) + } + } + + readers, err := m.readers() + if err != nil { + return err + } + + for _, read := range readers { + if err := read(buf); err != nil { + return err + } + } + + if _, err := buf.ReadByte(); err != io.EOF { + return fmt.Errorf("unexpected data after message: %v", err) + } + + return nil +} + +func (m *Message) writers() (writers []func(*bytes.Buffer) error, _ error) { + if m.MsgTypeFlag == MsgTypeFlagWithEvent { + writers = append(writers, m.writeEvent, m.writeSessionID) + } + + switch m.MsgType { + case MsgTypeFullClientRequest, MsgTypeFullServerResponse, MsgTypeFrontEndResultServer, MsgTypeAudioOnlyClient, MsgTypeAudioOnlyServer: + if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq { + writers = append(writers, m.writeSequence) + } + case MsgTypeError: + writers = append(writers, m.writeErrorCode) + default: + return nil, fmt.Errorf("unsupported message type: %d", m.MsgType) + } + + writers = append(writers, m.writePayload) + return writers, nil +} + +func (m *Message) writeEvent(buf *bytes.Buffer) error { + return binary.Write(buf, binary.BigEndian, m.EventType) +} + +func (m *Message) writeSessionID(buf *bytes.Buffer) error { + switch m.EventType { + case EventType_StartConnection, EventType_FinishConnection, + EventType_ConnectionStarted, EventType_ConnectionFailed: + return nil + } + + size := len(m.SessionID) + if size > math.MaxUint32 { + return fmt.Errorf("session ID size (%d) exceeds max(uint32)", size) + } + + if err := binary.Write(buf, binary.BigEndian, uint32(size)); err != nil { + return err + } + + buf.WriteString(m.SessionID) + return nil +} + +func (m *Message) writeSequence(buf *bytes.Buffer) error { + return binary.Write(buf, binary.BigEndian, m.Sequence) +} + +func (m *Message) writeErrorCode(buf *bytes.Buffer) error { + return binary.Write(buf, binary.BigEndian, m.ErrorCode) +} + +func (m *Message) writePayload(buf *bytes.Buffer) error { + size := len(m.Payload) + if size > math.MaxUint32 { + return fmt.Errorf("payload size (%d) exceeds max(uint32)", size) + } + + if err := binary.Write(buf, binary.BigEndian, uint32(size)); err != nil { + return err + } + + buf.Write(m.Payload) + return nil +} + +func (m *Message) readers() (readers []func(*bytes.Buffer) error, _ error) { + switch m.MsgType { + case MsgTypeFullClientRequest, MsgTypeFullServerResponse, MsgTypeFrontEndResultServer, MsgTypeAudioOnlyClient, MsgTypeAudioOnlyServer: + if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq { + readers = append(readers, m.readSequence) + } + case MsgTypeError: + readers = append(readers, m.readErrorCode) + default: + return nil, fmt.Errorf("unsupported message type: %d", m.MsgType) + } + + if m.MsgTypeFlag == MsgTypeFlagWithEvent { + readers = append(readers, m.readEvent, m.readSessionID, m.readConnectID) + } + + readers = append(readers, m.readPayload) + return readers, nil +} + +func (m *Message) readEvent(buf *bytes.Buffer) error { + return binary.Read(buf, binary.BigEndian, &m.EventType) +} + +func (m *Message) readSessionID(buf *bytes.Buffer) error { + switch m.EventType { + case EventType_StartConnection, EventType_FinishConnection, + EventType_ConnectionStarted, EventType_ConnectionFailed, + EventType_ConnectionFinished: + return nil + } + + var size uint32 + if err := binary.Read(buf, binary.BigEndian, &size); err != nil { + return err + } + + if size > 0 { + m.SessionID = string(buf.Next(int(size))) + } + + return nil +} + +func (m *Message) readConnectID(buf *bytes.Buffer) error { + switch m.EventType { + case EventType_ConnectionStarted, EventType_ConnectionFailed, + EventType_ConnectionFinished: + default: + return nil + } + + var size uint32 + if err := binary.Read(buf, binary.BigEndian, &size); err != nil { + return err + } + + if size > 0 { + m.ConnectID = string(buf.Next(int(size))) + } + + return nil +} + +func (m *Message) readSequence(buf *bytes.Buffer) error { + return binary.Read(buf, binary.BigEndian, &m.Sequence) +} + +func (m *Message) readErrorCode(buf *bytes.Buffer) error { + return binary.Read(buf, binary.BigEndian, &m.ErrorCode) +} + +func (m *Message) readPayload(buf *bytes.Buffer) error { + var size uint32 + if err := binary.Read(buf, binary.BigEndian, &size); err != nil { + return err + } + + if size > 0 { + m.Payload = buf.Next(int(size)) + } + + return nil +} + +func ReceiveMessage(conn *websocket.Conn) (*Message, error) { + mt, frame, err := conn.ReadMessage() + if err != nil { + return nil, err + } + if mt != websocket.BinaryMessage && mt != websocket.TextMessage { + return nil, fmt.Errorf("unexpected Websocket message type: %d", mt) + } + msg, err := NewMessageFromBytes(frame) + if err != nil { + return nil, err + } + return msg, nil +} + +func FullClientRequest(conn *websocket.Conn, payload []byte) error { + msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagNoSeq) + if err != nil { + return err + } + msg.Payload = payload + frame, err := msg.Marshal() + if err != nil { + return err + } + return conn.WriteMessage(websocket.BinaryMessage, frame) +} diff --git a/relay/channel/volcengine/tts.go b/relay/channel/volcengine/tts.go index 328512845..166fab8ef 100644 --- a/relay/channel/volcengine/tts.go +++ b/relay/channel/volcengine/tts.go @@ -1,9 +1,11 @@ package volcengine import ( + "context" "encoding/base64" "encoding/json" "errors" + "fmt" "io" "net/http" "strings" @@ -13,6 +15,7 @@ import ( "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" "github.com/google/uuid" + "github.com/gorilla/websocket" ) type VolcengineTTSRequest struct { @@ -192,3 +195,111 @@ func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.Re func generateRequestID() string { return uuid.New().String() } + +func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest VolcengineTTSRequest, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) { + _, token, parseErr := parseVolcengineAuth(info.ApiKey) + if parseErr != nil { + return nil, types.NewErrorWithStatusCode( + parseErr, + types.ErrorCodeChannelInvalidKey, + http.StatusUnauthorized, + ) + } + + header := http.Header{} + header.Set("Authorization", fmt.Sprintf("Bearer;%s", token)) + + conn, resp, dialErr := websocket.DefaultDialer.DialContext(context.Background(), requestURL, header) + if dialErr != nil { + if resp != nil { + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("failed to connect to websocket: %w, status: %d", dialErr, resp.StatusCode), + types.ErrorCodeBadResponseStatusCode, + http.StatusBadGateway, + ) + } + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("failed to connect to websocket: %w", dialErr), + types.ErrorCodeBadResponseStatusCode, + http.StatusBadGateway, + ) + } + defer conn.Close() + + payload, marshalErr := json.Marshal(volcRequest) + if marshalErr != nil { + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("failed to marshal request: %w", marshalErr), + types.ErrorCodeBadRequestBody, + http.StatusInternalServerError, + ) + } + + if sendErr := FullClientRequest(conn, payload); sendErr != nil { + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("failed to send request: %w", sendErr), + types.ErrorCodeBadRequestBody, + http.StatusInternalServerError, + ) + } + + contentType := getContentTypeByEncoding(encoding) + c.Header("Content-Type", contentType) + c.Header("Transfer-Encoding", "chunked") + + for { + msg, recvErr := ReceiveMessage(conn) + if recvErr != nil { + if websocket.IsCloseError(recvErr, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + break + } + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("failed to receive message: %w", recvErr), + types.ErrorCodeBadResponse, + http.StatusInternalServerError, + ) + } + + switch msg.MsgType { + case MsgTypeError: + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("received error from server: code=%d, %s", msg.ErrorCode, string(msg.Payload)), + types.ErrorCodeBadResponse, + http.StatusBadRequest, + ) + case MsgTypeFrontEndResultServer: + continue + case MsgTypeAudioOnlyServer: + if len(msg.Payload) > 0 { + if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil { + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("failed to write audio data: %w", writeErr), + types.ErrorCodeBadResponse, + http.StatusInternalServerError, + ) + } + c.Writer.Flush() + } + + if msg.Sequence < 0 { + c.Status(http.StatusOK) + usage = &dto.Usage{ + PromptTokens: info.PromptTokens, + CompletionTokens: 0, + TotalTokens: info.PromptTokens, + } + return usage, nil + } + default: + continue + } + } + + c.Status(http.StatusOK) + usage = &dto.Usage{ + PromptTokens: info.PromptTokens, + CompletionTokens: 0, + TotalTokens: info.PromptTokens, + } + return usage, nil +}