diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index c5d9e5dd6..8049a6c1c 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -70,7 +70,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf Request: VolcengineTTSReqInfo{ ReqID: generateRequestID(), Text: request.Input, - Operation: "query", + Operation: "submit", // WebSocket uses "submit" Model: info.OriginModelName, }, } @@ -82,12 +82,11 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf } } - jsonData, err := json.Marshal(volcRequest) - if err != nil { - return nil, fmt.Errorf("error marshalling volcengine request: %w", err) - } + // Store the request in context for WebSocket handler + c.Set("volcengine_tts_request", volcRequest) - return bytes.NewReader(jsonData), nil + // Return nil as WebSocket doesn't use traditional request body + return nil, nil } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { @@ -268,7 +267,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { case constant.RelayModeAudioSpeech: // 只有当 baseUrl 是火山默认的官方Url时才改为官方的的TTS接口,否则走透传的New接口 if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] { - return "https://openspeech.bytedance.com/api/v1/tts", nil + return "wss://openspeech.bytedance.com/api/v1/tts/ws_binary", nil } return fmt.Sprintf("%s/v1/audio/speech", baseUrl), nil default: @@ -320,12 +319,60 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { + // For TTS with WebSocket, skip traditional HTTP request + if info.RelayMode == constant.RelayModeAudioSpeech { + baseUrl := info.ChannelBaseUrl + if baseUrl == "" { + baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] + } + // Only use WebSocket for official Volcengine endpoint + if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] { + return nil, nil // WebSocket handling will be done in DoResponse + } + } return channel.DoApiRequest(a, c, info, requestBody) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeAudioSpeech { encoding := mapEncoding(c.GetString("response_format")) + + // Check if this is WebSocket mode (resp will be nil for WebSocket) + if resp == nil { + // Get the WebSocket URL + requestURL, urlErr := a.GetRequestURL(info) + if urlErr != nil { + return nil, types.NewErrorWithStatusCode( + urlErr, + types.ErrorCodeBadRequestBody, + http.StatusInternalServerError, + ) + } + + // Retrieve the volcengine request from context + volcRequestInterface, exists := c.Get("volcengine_tts_request") + if !exists { + return nil, types.NewErrorWithStatusCode( + errors.New("volcengine TTS request not found in context"), + types.ErrorCodeBadRequestBody, + http.StatusInternalServerError, + ) + } + + volcRequest, ok := volcRequestInterface.(VolcengineTTSRequest) + if !ok { + return nil, types.NewErrorWithStatusCode( + errors.New("invalid volcengine TTS request type"), + types.ErrorCodeBadRequestBody, + http.StatusInternalServerError, + ) + } + + // Handle WebSocket streaming + return handleTTSWebSocketResponse(c, requestURL, volcRequest, info, encoding) + } + + // Handle traditional HTTP response return handleTTSResponse(c, resp, info, encoding) } diff --git a/relay/channel/volcengine/protocols.go b/relay/channel/volcengine/protocols.go new file mode 100644 index 000000000..a41d87566 --- /dev/null +++ b/relay/channel/volcengine/protocols.go @@ -0,0 +1,715 @@ +package volcengine + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "math" + + "github.com/gorilla/websocket" +) + +type ( + // EventType defines the event type which determines the event of the message. + EventType int32 + // MsgType defines message type which determines how the message will be + // serialized with the protocol. + MsgType uint8 + // MsgTypeFlagBits defines the 4-bit message-type specific flags. The specific + // values should be defined in each specific usage scenario. + MsgTypeFlagBits uint8 + // VersionBits defines the 4-bit version type. + VersionBits uint8 + // HeaderSizeBits defines the 4-bit header-size type. + HeaderSizeBits uint8 + // SerializationBits defines the 4-bit serialization method type. + SerializationBits uint8 + // CompressionBits defines the 4-bit compression method type. + CompressionBits uint8 +) + +const ( + MsgTypeFlagNoSeq MsgTypeFlagBits = 0 // Non-terminal packet with no sequence + MsgTypeFlagPositiveSeq MsgTypeFlagBits = 0b1 // Non-terminal packet with sequence > 0 + MsgTypeFlagLastNoSeq MsgTypeFlagBits = 0b10 // last packet with no sequence + MsgTypeFlagNegativeSeq MsgTypeFlagBits = 0b11 // last packet with sequence < 0 + MsgTypeFlagWithEvent MsgTypeFlagBits = 0b100 // Payload contains event number (int32) +) + +const ( + Version1 VersionBits = iota + 1 + Version2 + Version3 + Version4 +) + +const ( + HeaderSize4 HeaderSizeBits = iota + 1 + HeaderSize8 + HeaderSize12 + HeaderSize16 +) + +const ( + SerializationRaw SerializationBits = 0 + SerializationJSON SerializationBits = 0b1 + SerializationThrift SerializationBits = 0b11 + SerializationCustom SerializationBits = 0b1111 +) + +const ( + CompressionNone CompressionBits = 0 + CompressionGzip CompressionBits = 0b1 + CompressionCustom CompressionBits = 0b1111 +) + +const ( + MsgTypeInvalid MsgType = 0 + MsgTypeFullClientRequest MsgType = 0b1 + MsgTypeAudioOnlyClient MsgType = 0b10 + MsgTypeFullServerResponse MsgType = 0b1001 + MsgTypeAudioOnlyServer MsgType = 0b1011 + MsgTypeFrontEndResultServer MsgType = 0b1100 + MsgTypeError MsgType = 0b1111 + + MsgTypeServerACK = MsgTypeAudioOnlyServer +) + +func (t MsgType) String() string { + switch t { + case MsgTypeFullClientRequest: + return "MsgType_FullClientRequest" + case MsgTypeAudioOnlyClient: + return "MsgType_AudioOnlyClient" + case MsgTypeFullServerResponse: + return "MsgType_FullServerResponse" + case MsgTypeAudioOnlyServer: + return "MsgType_AudioOnlyServer" // MsgTypeServerACK + case MsgTypeError: + return "MsgType_Error" + case MsgTypeFrontEndResultServer: + return "MsgType_FrontEndResultServer" + default: + return fmt.Sprintf("MsgType_(%d)", t) + } +} + +const ( + // Default event, applicable for scenarios not using events or not requiring event transmission, + // or for scenarios using events, non-zero values can be used to validate event legitimacy + EventType_None EventType = 0 + // 1 ~ 49 for upstream Connection events + EventType_StartConnection EventType = 1 + EventType_StartTask EventType = 1 // Alias of "StartConnection" + EventType_FinishConnection EventType = 2 + EventType_FinishTask EventType = 2 // Alias of "FinishConnection" + // 50 ~ 99 for downstream Connection events + // Connection established successfully + EventType_ConnectionStarted EventType = 50 + EventType_TaskStarted EventType = 50 // Alias of "ConnectionStarted" + // Connection failed (possibly due to authentication failure) + EventType_ConnectionFailed EventType = 51 + EventType_TaskFailed EventType = 51 // Alias of "ConnectionFailed" + // Connection ended + EventType_ConnectionFinished EventType = 52 + EventType_TaskFinished EventType = 52 // Alias of "ConnectionFinished" + // 100 ~ 149 for upstream Session events + EventType_StartSession EventType = 100 + EventType_CancelSession EventType = 101 + EventType_FinishSession EventType = 102 + // 150 ~ 199 for downstream Session events + EventType_SessionStarted EventType = 150 + EventType_SessionCanceled EventType = 151 + EventType_SessionFinished EventType = 152 + EventType_SessionFailed EventType = 153 + // Usage events + EventType_UsageResponse EventType = 154 + EventType_ChargeData EventType = 154 // Alias of "UsageResponse" + // 200 ~ 249 for upstream general events + EventType_TaskRequest EventType = 200 + EventType_UpdateConfig EventType = 201 + // 250 ~ 299 for downstream general events + EventType_AudioMuted EventType = 250 + // 300 ~ 349 for upstream TTS events + EventType_SayHello EventType = 300 + // 350 ~ 399 for downstream TTS events + EventType_TTSSentenceStart EventType = 350 + EventType_TTSSentenceEnd EventType = 351 + EventType_TTSResponse EventType = 352 + EventType_TTSEnded EventType = 359 + EventType_PodcastRoundStart EventType = 360 + EventType_PodcastRoundResponse EventType = 361 + EventType_PodcastRoundEnd EventType = 362 + // 450 ~ 499 for downstream ASR events + EventType_ASRInfo EventType = 450 + EventType_ASRResponse EventType = 451 + EventType_ASREnded EventType = 459 + // 500 ~ 549 for upstream dialogue events + // (Ground-Truth-Alignment) text for speech synthesis + EventType_ChatTTSText EventType = 500 + // 550 ~ 599 for downstream dialogue events + EventType_ChatResponse EventType = 550 + EventType_ChatEnded EventType = 559 + // 650 ~ 699 for downstream dialogue events + // Events for source (original) language subtitle. + EventType_SourceSubtitleStart EventType = 650 + EventType_SourceSubtitleResponse EventType = 651 + EventType_SourceSubtitleEnd EventType = 652 + // Events for target (translation) language subtitle. + EventType_TranslationSubtitleStart EventType = 653 + EventType_TranslationSubtitleResponse EventType = 654 + EventType_TranslationSubtitleEnd EventType = 655 +) + +func (t EventType) String() string { + switch t { + case EventType_None: + return "EventType_None" + case EventType_StartConnection: + return "EventType_StartConnection" + case EventType_FinishConnection: + return "EventType_FinishConnection" + case EventType_ConnectionStarted: + return "EventType_ConnectionStarted" + case EventType_ConnectionFailed: + return "EventType_ConnectionFailed" + case EventType_ConnectionFinished: + return "EventType_ConnectionFinished" + case EventType_StartSession: + return "EventType_StartSession" + case EventType_CancelSession: + return "EventType_CancelSession" + case EventType_FinishSession: + return "EventType_FinishSession" + case EventType_SessionStarted: + return "EventType_SessionStarted" + case EventType_SessionCanceled: + return "EventType_SessionCanceled" + case EventType_SessionFinished: + return "EventType_SessionFinished" + case EventType_SessionFailed: + return "EventType_SessionFailed" + case EventType_UsageResponse: + return "EventType_UsageResponse" + case EventType_TaskRequest: + return "EventType_TaskRequest" + case EventType_UpdateConfig: + return "EventType_UpdateConfig" + case EventType_AudioMuted: + return "EventType_AudioMuted" + case EventType_SayHello: + return "EventType_SayHello" + case EventType_TTSSentenceStart: + return "EventType_TTSSentenceStart" + case EventType_TTSSentenceEnd: + return "EventType_TTSSentenceEnd" + case EventType_TTSResponse: + return "EventType_TTSResponse" + case EventType_TTSEnded: + return "EventType_TTSEnded" + case EventType_PodcastRoundStart: + return "EventType_PodcastRoundStart" + case EventType_PodcastRoundResponse: + return "EventType_PodcastRoundResponse" + case EventType_PodcastRoundEnd: + return "EventType_PodcastRoundEnd" + case EventType_ASRInfo: + return "EventType_ASRInfo" + case EventType_ASRResponse: + return "EventType_ASRResponse" + case EventType_ASREnded: + return "EventType_ASREnded" + case EventType_ChatTTSText: + return "EventType_ChatTTSText" + case EventType_ChatResponse: + return "EventType_ChatResponse" + case EventType_ChatEnded: + return "EventType_ChatEnded" + case EventType_SourceSubtitleStart: + return "EventType_SourceSubtitleStart" + case EventType_SourceSubtitleResponse: + return "EventType_SourceSubtitleResponse" + case EventType_SourceSubtitleEnd: + return "EventType_SourceSubtitleEnd" + case EventType_TranslationSubtitleStart: + return "EventType_TranslationSubtitleStart" + case EventType_TranslationSubtitleResponse: + return "EventType_TranslationSubtitleResponse" + case EventType_TranslationSubtitleEnd: + return "EventType_TranslationSubtitleEnd" + default: + return fmt.Sprintf("EventType_(%d)", t) + } +} + +// 0 1 2 3 +// | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Version | Header Size | Msg Type | Flags | +// | (4 bits) | (4 bits) | (4 bits) | (4 bits) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Serialization | Compression | Reserved | +// | (4 bits) | (4 bits) | (8 bits) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | | +// | Optional Header Extensions | +// | (if Header Size > 1) | +// | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | | +// | Payload | +// | (variable length) | +// | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +type Message struct { + Version VersionBits + HeaderSize HeaderSizeBits + MsgType MsgType + MsgTypeFlag MsgTypeFlagBits + Serialization SerializationBits + Compression CompressionBits + + EventType EventType + SessionID string + ConnectID string + Sequence int32 + ErrorCode uint32 + + Payload []byte +} + +func NewMessageFromBytes(data []byte) (*Message, error) { + if len(data) < 3 { + return nil, fmt.Errorf("data too short: expected at least 3 bytes, got %d", len(data)) + } + + typeAndFlag := data[1] + + msg, err := NewMessage(MsgType(typeAndFlag>>4), MsgTypeFlagBits(typeAndFlag&0b00001111)) + if err != nil { + return nil, err + } + + if err := msg.Unmarshal(data); err != nil { + return nil, err + } + + return msg, nil +} + +func NewMessage(msgType MsgType, flag MsgTypeFlagBits) (*Message, error) { + return &Message{ + MsgType: msgType, + MsgTypeFlag: flag, + Version: Version1, + HeaderSize: HeaderSize4, + Serialization: SerializationJSON, + Compression: CompressionNone, + }, nil +} + +func (m *Message) String() string { + switch m.MsgType { + case MsgTypeAudioOnlyServer, MsgTypeAudioOnlyClient: + if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq { + return fmt.Sprintf("%s, %s, Sequence: %d, PayloadSize: %d", m.MsgType, m.EventType, m.Sequence, len(m.Payload)) + } + return fmt.Sprintf("%s, %s, PayloadSize: %d", m.MsgType, m.EventType, len(m.Payload)) + case MsgTypeError: + return fmt.Sprintf("%s, %s, ErrorCode: %d, Payload: %s", m.MsgType, m.EventType, m.ErrorCode, string(m.Payload)) + default: + if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq { + return fmt.Sprintf("%s, %s, Sequence: %d, Payload: %s", + m.MsgType, m.EventType, m.Sequence, string(m.Payload)) + } + return fmt.Sprintf("%s, %s, Payload: %s", m.MsgType, m.EventType, string(m.Payload)) + } +} + +func (m *Message) Marshal() ([]byte, error) { + buf := new(bytes.Buffer) + + header := []uint8{ + uint8(m.Version)<<4 | uint8(m.HeaderSize), + uint8(m.MsgType)<<4 | uint8(m.MsgTypeFlag), + uint8(m.Serialization)<<4 | uint8(m.Compression), + } + + headerSize := 4 * int(m.HeaderSize) + if padding := headerSize - len(header); padding > 0 { + header = append(header, make([]uint8, padding)...) + } + + if err := binary.Write(buf, binary.BigEndian, header); err != nil { + return nil, err + } + + writers, err := m.writers() + if err != nil { + return nil, err + } + + for _, write := range writers { + if err := write(buf); err != nil { + return nil, err + } + } + + return buf.Bytes(), nil +} + +func (m *Message) Unmarshal(data []byte) error { + buf := bytes.NewBuffer(data) + + versionAndHeaderSize, err := buf.ReadByte() + if err != nil { + return err + } + + m.Version = VersionBits(versionAndHeaderSize >> 4) + m.HeaderSize = HeaderSizeBits(versionAndHeaderSize & 0b00001111) + + _, err = buf.ReadByte() + if err != nil { + return err + } + + serializationCompression, err := buf.ReadByte() + if err != nil { + return err + } + + m.Serialization = SerializationBits(serializationCompression & 0b11110000) + m.Compression = CompressionBits(serializationCompression & 0b00001111) + + headerSize := 4 * int(m.HeaderSize) + readSize := 3 + if paddingSize := headerSize - readSize; paddingSize > 0 { + if n, err := buf.Read(make([]byte, paddingSize)); err != nil || n < paddingSize { + return fmt.Errorf("insufficient header bytes: expected %d, got %d", paddingSize, n) + } + } + + readers, err := m.readers() + if err != nil { + return err + } + + for _, read := range readers { + if err := read(buf); err != nil { + return err + } + } + + if _, err := buf.ReadByte(); err != io.EOF { + return fmt.Errorf("unexpected data after message: %v", err) + } + + return nil +} + +func (m *Message) writers() (writers []func(*bytes.Buffer) error, _ error) { + if m.MsgTypeFlag == MsgTypeFlagWithEvent { + writers = append(writers, m.writeEvent, m.writeSessionID) + } + + switch m.MsgType { + case MsgTypeFullClientRequest, MsgTypeFullServerResponse, MsgTypeFrontEndResultServer, MsgTypeAudioOnlyClient, MsgTypeAudioOnlyServer: + if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq { + writers = append(writers, m.writeSequence) + } + case MsgTypeError: + writers = append(writers, m.writeErrorCode) + default: + return nil, fmt.Errorf("unsupported message type: %d", m.MsgType) + } + + writers = append(writers, m.writePayload) + return writers, nil +} + +func (m *Message) writeEvent(buf *bytes.Buffer) error { + return binary.Write(buf, binary.BigEndian, m.EventType) +} + +func (m *Message) writeSessionID(buf *bytes.Buffer) error { + switch m.EventType { + case EventType_StartConnection, EventType_FinishConnection, + EventType_ConnectionStarted, EventType_ConnectionFailed: + return nil + } + + size := len(m.SessionID) + if size > math.MaxUint32 { + return fmt.Errorf("session ID size (%d) exceeds max(uint32)", size) + } + + if err := binary.Write(buf, binary.BigEndian, uint32(size)); err != nil { + return err + } + + buf.WriteString(m.SessionID) + return nil +} + +func (m *Message) writeSequence(buf *bytes.Buffer) error { + return binary.Write(buf, binary.BigEndian, m.Sequence) +} + +func (m *Message) writeErrorCode(buf *bytes.Buffer) error { + return binary.Write(buf, binary.BigEndian, m.ErrorCode) +} + +func (m *Message) writePayload(buf *bytes.Buffer) error { + size := len(m.Payload) + if size > math.MaxUint32 { + return fmt.Errorf("payload size (%d) exceeds max(uint32)", size) + } + + if err := binary.Write(buf, binary.BigEndian, uint32(size)); err != nil { + return err + } + + buf.Write(m.Payload) + return nil +} + +func (m *Message) readers() (readers []func(*bytes.Buffer) error, _ error) { + switch m.MsgType { + case MsgTypeFullClientRequest, MsgTypeFullServerResponse, MsgTypeFrontEndResultServer, MsgTypeAudioOnlyClient, MsgTypeAudioOnlyServer: + if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq { + readers = append(readers, m.readSequence) + } + case MsgTypeError: + readers = append(readers, m.readErrorCode) + default: + return nil, fmt.Errorf("unsupported message type: %d", m.MsgType) + } + + if m.MsgTypeFlag == MsgTypeFlagWithEvent { + readers = append(readers, m.readEvent, m.readSessionID, m.readConnectID) + } + + readers = append(readers, m.readPayload) + return readers, nil +} + +func (m *Message) readEvent(buf *bytes.Buffer) error { + return binary.Read(buf, binary.BigEndian, &m.EventType) +} + +func (m *Message) readSessionID(buf *bytes.Buffer) error { + switch m.EventType { + case EventType_StartConnection, EventType_FinishConnection, + EventType_ConnectionStarted, EventType_ConnectionFailed, + EventType_ConnectionFinished: + return nil + } + + var size uint32 + if err := binary.Read(buf, binary.BigEndian, &size); err != nil { + return err + } + + if size > 0 { + m.SessionID = string(buf.Next(int(size))) + } + + return nil +} + +func (m *Message) readConnectID(buf *bytes.Buffer) error { + switch m.EventType { + case EventType_ConnectionStarted, EventType_ConnectionFailed, + EventType_ConnectionFinished: + default: + return nil + } + + var size uint32 + if err := binary.Read(buf, binary.BigEndian, &size); err != nil { + return err + } + + if size > 0 { + m.ConnectID = string(buf.Next(int(size))) + } + + return nil +} + +func (m *Message) readSequence(buf *bytes.Buffer) error { + return binary.Read(buf, binary.BigEndian, &m.Sequence) +} + +func (m *Message) readErrorCode(buf *bytes.Buffer) error { + return binary.Read(buf, binary.BigEndian, &m.ErrorCode) +} + +func (m *Message) readPayload(buf *bytes.Buffer) error { + var size uint32 + if err := binary.Read(buf, binary.BigEndian, &size); err != nil { + return err + } + + if size > 0 { + m.Payload = buf.Next(int(size)) + } + + return nil +} + +func ReceiveMessage(conn *websocket.Conn) (*Message, error) { + mt, frame, err := conn.ReadMessage() + if err != nil { + return nil, err + } + if mt != websocket.BinaryMessage && mt != websocket.TextMessage { + return nil, fmt.Errorf("unexpected Websocket message type: %d", mt) + } + msg, err := NewMessageFromBytes(frame) + if err != nil { + return nil, err + } + // Log: receive msg + return msg, nil +} + +func WaitForEvent(conn *websocket.Conn, msgType MsgType, eventType EventType) (*Message, error) { + for { + msg, err := ReceiveMessage(conn) + if err != nil { + return nil, err + } + if msg.MsgType != msgType || msg.EventType != eventType { + return nil, fmt.Errorf("unexpected message: %s", msg) + } + if msg.MsgType == msgType && msg.EventType == eventType { + return msg, nil + } + } +} + +func FullClientRequest(conn *websocket.Conn, payload []byte) error { + msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagNoSeq) + if err != nil { + return err + } + msg.Payload = payload + // Log: send msg + frame, err := msg.Marshal() + if err != nil { + return err + } + return conn.WriteMessage(websocket.BinaryMessage, frame) +} + +func AudioOnlyClient(conn *websocket.Conn, payload []byte, flag MsgTypeFlagBits) error { + msg, err := NewMessage(MsgTypeAudioOnlyClient, flag) + if err != nil { + return err + } + msg.Payload = payload + // Log: send msg + frame, err := msg.Marshal() + if err != nil { + return err + } + return conn.WriteMessage(websocket.BinaryMessage, frame) +} + +func StartConnection(conn *websocket.Conn) error { + msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) + if err != nil { + return err + } + msg.EventType = EventType_StartConnection + msg.Payload = []byte("{}") + // Log: send msg + frame, err := msg.Marshal() + if err != nil { + return err + } + return conn.WriteMessage(websocket.BinaryMessage, frame) +} + +func FinishConnection(conn *websocket.Conn) error { + msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) + if err != nil { + return err + } + msg.EventType = EventType_FinishConnection + msg.Payload = []byte("{}") + // Log: send msg + frame, err := msg.Marshal() + if err != nil { + return err + } + return conn.WriteMessage(websocket.BinaryMessage, frame) +} + +func StartSession(conn *websocket.Conn, payload []byte, sessionID string) error { + msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) + if err != nil { + return err + } + msg.EventType = EventType_StartSession + msg.SessionID = sessionID + msg.Payload = payload + // Log: send msg + frame, err := msg.Marshal() + if err != nil { + return err + } + return conn.WriteMessage(websocket.BinaryMessage, frame) +} + +func FinishSession(conn *websocket.Conn, sessionID string) error { + msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) + if err != nil { + return err + } + msg.EventType = EventType_FinishSession + msg.SessionID = sessionID + msg.Payload = []byte("{}") + // Log: send msg + frame, err := msg.Marshal() + if err != nil { + return err + } + return conn.WriteMessage(websocket.BinaryMessage, frame) +} + +func CancelSession(conn *websocket.Conn, sessionID string) error { + msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) + if err != nil { + return err + } + msg.EventType = EventType_CancelSession + msg.SessionID = sessionID + msg.Payload = []byte("{}") + // Log: send msg + frame, err := msg.Marshal() + if err != nil { + return err + } + return conn.WriteMessage(websocket.BinaryMessage, frame) +} + +func TaskRequest(conn *websocket.Conn, payload []byte, sessionID string) error { + msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) + if err != nil { + return err + } + msg.EventType = EventType_TaskRequest + msg.SessionID = sessionID + msg.Payload = payload + // Log: send msg + frame, err := msg.Marshal() + if err != nil { + return err + } + return conn.WriteMessage(websocket.BinaryMessage, frame) +} diff --git a/relay/channel/volcengine/tts.go b/relay/channel/volcengine/tts.go index 328512845..6b64c551e 100644 --- a/relay/channel/volcengine/tts.go +++ b/relay/channel/volcengine/tts.go @@ -1,9 +1,11 @@ package volcengine import ( + "context" "encoding/base64" "encoding/json" "errors" + "fmt" "io" "net/http" "strings" @@ -13,6 +15,7 @@ import ( "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" "github.com/google/uuid" + "github.com/gorilla/websocket" ) type VolcengineTTSRequest struct { @@ -192,3 +195,129 @@ func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.Re func generateRequestID() string { return uuid.New().String() } + +// handleTTSWebSocketResponse handles streaming TTS response via WebSocket +func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest VolcengineTTSRequest, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) { + // Parse API key for auth + _, token, parseErr := parseVolcengineAuth(info.ApiKey) + if parseErr != nil { + return nil, types.NewErrorWithStatusCode( + parseErr, + types.ErrorCodeChannelInvalidKey, + http.StatusUnauthorized, + ) + } + + // Setup WebSocket headers + header := http.Header{} + header.Set("Authorization", fmt.Sprintf("Bearer;%s", token)) + + // Dial WebSocket connection + conn, resp, dialErr := websocket.DefaultDialer.DialContext(context.Background(), requestURL, header) + if dialErr != nil { + if resp != nil { + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("failed to connect to websocket: %w, status: %d", dialErr, resp.StatusCode), + types.ErrorCodeBadResponseStatusCode, + http.StatusBadGateway, + ) + } + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("failed to connect to websocket: %w", dialErr), + types.ErrorCodeBadResponseStatusCode, + http.StatusBadGateway, + ) + } + defer conn.Close() + + // Update request operation to "submit" for WebSocket + volcRequest.Request.Operation = "submit" + + // Marshal request payload + payload, marshalErr := json.Marshal(volcRequest) + if marshalErr != nil { + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("failed to marshal request: %w", marshalErr), + types.ErrorCodeBadRequestBody, + http.StatusInternalServerError, + ) + } + + // Send full client request + if sendErr := FullClientRequest(conn, payload); sendErr != nil { + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("failed to send request: %w", sendErr), + types.ErrorCodeBadRequestBody, + http.StatusInternalServerError, + ) + } + + // Set response headers + contentType := getContentTypeByEncoding(encoding) + c.Header("Content-Type", contentType) + c.Header("Transfer-Encoding", "chunked") + + // Stream audio data + var audioBuffer []byte + for { + msg, recvErr := ReceiveMessage(conn) + if recvErr != nil { + if websocket.IsCloseError(recvErr, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + break + } + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("failed to receive message: %w", recvErr), + types.ErrorCodeBadResponse, + http.StatusInternalServerError, + ) + } + + switch msg.MsgType { + case MsgTypeError: + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("received error from server: code=%d, %s", msg.ErrorCode, string(msg.Payload)), + types.ErrorCodeBadResponse, + http.StatusBadRequest, + ) + case MsgTypeFrontEndResultServer: + // Metadata response, can be logged or processed + continue + case MsgTypeAudioOnlyServer: + // Stream audio chunk to client + if len(msg.Payload) > 0 { + audioBuffer = append(audioBuffer, msg.Payload...) + if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil { + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("failed to write audio data: %w", writeErr), + types.ErrorCodeBadResponse, + http.StatusInternalServerError, + ) + } + c.Writer.Flush() + } + + // Check if this is the last packet (negative sequence) + if msg.Sequence < 0 { + c.Status(http.StatusOK) + usage = &dto.Usage{ + PromptTokens: info.PromptTokens, + CompletionTokens: 0, + TotalTokens: info.PromptTokens, + } + return usage, nil + } + default: + // Unknown message type, log and continue + continue + } + } + + // If we reach here, connection closed without final packet + c.Status(http.StatusOK) + usage = &dto.Usage{ + PromptTokens: info.PromptTokens, + CompletionTokens: 0, + TotalTokens: info.PromptTokens, + } + return usage, nil +}