From 098e6e7f2bfd0af0218e4d6f453871ffbbf4741b Mon Sep 17 00:00:00 2001 From: feitianbubu Date: Wed, 22 Oct 2025 12:31:54 +0800 Subject: [PATCH 1/3] feat: doubao tts support streaming realtime audio --- relay/channel/volcengine/adaptor.go | 61 ++- relay/channel/volcengine/protocols.go | 715 ++++++++++++++++++++++++++ relay/channel/volcengine/tts.go | 129 +++++ 3 files changed, 898 insertions(+), 7 deletions(-) create mode 100644 relay/channel/volcengine/protocols.go diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index c5d9e5dd6..8049a6c1c 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -70,7 +70,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf Request: VolcengineTTSReqInfo{ ReqID: generateRequestID(), Text: request.Input, - Operation: "query", + Operation: "submit", // WebSocket uses "submit" Model: info.OriginModelName, }, } @@ -82,12 +82,11 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf } } - jsonData, err := json.Marshal(volcRequest) - if err != nil { - return nil, fmt.Errorf("error marshalling volcengine request: %w", err) - } + // Store the request in context for WebSocket handler + c.Set("volcengine_tts_request", volcRequest) - return bytes.NewReader(jsonData), nil + // Return nil as WebSocket doesn't use traditional request body + return nil, nil } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { @@ -268,7 +267,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { case constant.RelayModeAudioSpeech: // 只有当 baseUrl 是火山默认的官方Url时才改为官方的的TTS接口,否则走透传的New接口 if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] { - return "https://openspeech.bytedance.com/api/v1/tts", nil + return "wss://openspeech.bytedance.com/api/v1/tts/ws_binary", nil } return fmt.Sprintf("%s/v1/audio/speech", baseUrl), nil default: @@ -320,12 +319,60 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { + // For TTS with WebSocket, skip traditional HTTP request + if info.RelayMode == constant.RelayModeAudioSpeech { + baseUrl := info.ChannelBaseUrl + if baseUrl == "" { + baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] + } + // Only use WebSocket for official Volcengine endpoint + if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] { + return nil, nil // WebSocket handling will be done in DoResponse + } + } return channel.DoApiRequest(a, c, info, requestBody) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeAudioSpeech { encoding := mapEncoding(c.GetString("response_format")) + + // Check if this is WebSocket mode (resp will be nil for WebSocket) + if resp == nil { + // Get the WebSocket URL + requestURL, urlErr := a.GetRequestURL(info) + if urlErr != nil { + return nil, types.NewErrorWithStatusCode( + urlErr, + types.ErrorCodeBadRequestBody, + http.StatusInternalServerError, + ) + } + + // Retrieve the volcengine request from context + volcRequestInterface, exists := c.Get("volcengine_tts_request") + if !exists { + return nil, types.NewErrorWithStatusCode( + errors.New("volcengine TTS request not found in context"), + types.ErrorCodeBadRequestBody, + http.StatusInternalServerError, + ) + } + + volcRequest, ok := volcRequestInterface.(VolcengineTTSRequest) + if !ok { + return nil, types.NewErrorWithStatusCode( + errors.New("invalid volcengine TTS request type"), + types.ErrorCodeBadRequestBody, + http.StatusInternalServerError, + ) + } + + // Handle WebSocket streaming + return handleTTSWebSocketResponse(c, requestURL, volcRequest, info, encoding) + } + + // Handle traditional HTTP response return handleTTSResponse(c, resp, info, encoding) } diff --git a/relay/channel/volcengine/protocols.go b/relay/channel/volcengine/protocols.go new file mode 100644 index 000000000..a41d87566 --- /dev/null +++ b/relay/channel/volcengine/protocols.go @@ -0,0 +1,715 @@ +package volcengine + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "math" + + "github.com/gorilla/websocket" +) + +type ( + // EventType defines the event type which determines the event of the message. + EventType int32 + // MsgType defines message type which determines how the message will be + // serialized with the protocol. + MsgType uint8 + // MsgTypeFlagBits defines the 4-bit message-type specific flags. The specific + // values should be defined in each specific usage scenario. + MsgTypeFlagBits uint8 + // VersionBits defines the 4-bit version type. + VersionBits uint8 + // HeaderSizeBits defines the 4-bit header-size type. + HeaderSizeBits uint8 + // SerializationBits defines the 4-bit serialization method type. + SerializationBits uint8 + // CompressionBits defines the 4-bit compression method type. + CompressionBits uint8 +) + +const ( + MsgTypeFlagNoSeq MsgTypeFlagBits = 0 // Non-terminal packet with no sequence + MsgTypeFlagPositiveSeq MsgTypeFlagBits = 0b1 // Non-terminal packet with sequence > 0 + MsgTypeFlagLastNoSeq MsgTypeFlagBits = 0b10 // last packet with no sequence + MsgTypeFlagNegativeSeq MsgTypeFlagBits = 0b11 // last packet with sequence < 0 + MsgTypeFlagWithEvent MsgTypeFlagBits = 0b100 // Payload contains event number (int32) +) + +const ( + Version1 VersionBits = iota + 1 + Version2 + Version3 + Version4 +) + +const ( + HeaderSize4 HeaderSizeBits = iota + 1 + HeaderSize8 + HeaderSize12 + HeaderSize16 +) + +const ( + SerializationRaw SerializationBits = 0 + SerializationJSON SerializationBits = 0b1 + SerializationThrift SerializationBits = 0b11 + SerializationCustom SerializationBits = 0b1111 +) + +const ( + CompressionNone CompressionBits = 0 + CompressionGzip CompressionBits = 0b1 + CompressionCustom CompressionBits = 0b1111 +) + +const ( + MsgTypeInvalid MsgType = 0 + MsgTypeFullClientRequest MsgType = 0b1 + MsgTypeAudioOnlyClient MsgType = 0b10 + MsgTypeFullServerResponse MsgType = 0b1001 + MsgTypeAudioOnlyServer MsgType = 0b1011 + MsgTypeFrontEndResultServer MsgType = 0b1100 + MsgTypeError MsgType = 0b1111 + + MsgTypeServerACK = MsgTypeAudioOnlyServer +) + +func (t MsgType) String() string { + switch t { + case MsgTypeFullClientRequest: + return "MsgType_FullClientRequest" + case MsgTypeAudioOnlyClient: + return "MsgType_AudioOnlyClient" + case MsgTypeFullServerResponse: + return "MsgType_FullServerResponse" + case MsgTypeAudioOnlyServer: + return "MsgType_AudioOnlyServer" // MsgTypeServerACK + case MsgTypeError: + return "MsgType_Error" + case MsgTypeFrontEndResultServer: + return "MsgType_FrontEndResultServer" + default: + return fmt.Sprintf("MsgType_(%d)", t) + } +} + +const ( + // Default event, applicable for scenarios not using events or not requiring event transmission, + // or for scenarios using events, non-zero values can be used to validate event legitimacy + EventType_None EventType = 0 + // 1 ~ 49 for upstream Connection events + EventType_StartConnection EventType = 1 + EventType_StartTask EventType = 1 // Alias of "StartConnection" + EventType_FinishConnection EventType = 2 + EventType_FinishTask EventType = 2 // Alias of "FinishConnection" + // 50 ~ 99 for downstream Connection events + // Connection established successfully + EventType_ConnectionStarted EventType = 50 + EventType_TaskStarted EventType = 50 // Alias of "ConnectionStarted" + // Connection failed (possibly due to authentication failure) + EventType_ConnectionFailed EventType = 51 + EventType_TaskFailed EventType = 51 // Alias of "ConnectionFailed" + // Connection ended + EventType_ConnectionFinished EventType = 52 + EventType_TaskFinished EventType = 52 // Alias of "ConnectionFinished" + // 100 ~ 149 for upstream Session events + EventType_StartSession EventType = 100 + EventType_CancelSession EventType = 101 + EventType_FinishSession EventType = 102 + // 150 ~ 199 for downstream Session events + EventType_SessionStarted EventType = 150 + EventType_SessionCanceled EventType = 151 + EventType_SessionFinished EventType = 152 + EventType_SessionFailed EventType = 153 + // Usage events + EventType_UsageResponse EventType = 154 + EventType_ChargeData EventType = 154 // Alias of "UsageResponse" + // 200 ~ 249 for upstream general events + EventType_TaskRequest EventType = 200 + EventType_UpdateConfig EventType = 201 + // 250 ~ 299 for downstream general events + EventType_AudioMuted EventType = 250 + // 300 ~ 349 for upstream TTS events + EventType_SayHello EventType = 300 + // 350 ~ 399 for downstream TTS events + EventType_TTSSentenceStart EventType = 350 + EventType_TTSSentenceEnd EventType = 351 + EventType_TTSResponse EventType = 352 + EventType_TTSEnded EventType = 359 + EventType_PodcastRoundStart EventType = 360 + EventType_PodcastRoundResponse EventType = 361 + EventType_PodcastRoundEnd EventType = 362 + // 450 ~ 499 for downstream ASR events + EventType_ASRInfo EventType = 450 + EventType_ASRResponse EventType = 451 + EventType_ASREnded EventType = 459 + // 500 ~ 549 for upstream dialogue events + // (Ground-Truth-Alignment) text for speech synthesis + EventType_ChatTTSText EventType = 500 + // 550 ~ 599 for downstream dialogue events + EventType_ChatResponse EventType = 550 + EventType_ChatEnded EventType = 559 + // 650 ~ 699 for downstream dialogue events + // Events for source (original) language subtitle. + EventType_SourceSubtitleStart EventType = 650 + EventType_SourceSubtitleResponse EventType = 651 + EventType_SourceSubtitleEnd EventType = 652 + // Events for target (translation) language subtitle. + EventType_TranslationSubtitleStart EventType = 653 + EventType_TranslationSubtitleResponse EventType = 654 + EventType_TranslationSubtitleEnd EventType = 655 +) + +func (t EventType) String() string { + switch t { + case EventType_None: + return "EventType_None" + case EventType_StartConnection: + return "EventType_StartConnection" + case EventType_FinishConnection: + return "EventType_FinishConnection" + case EventType_ConnectionStarted: + return "EventType_ConnectionStarted" + case EventType_ConnectionFailed: + return "EventType_ConnectionFailed" + case EventType_ConnectionFinished: + return "EventType_ConnectionFinished" + case EventType_StartSession: + return "EventType_StartSession" + case EventType_CancelSession: + return "EventType_CancelSession" + case EventType_FinishSession: + return "EventType_FinishSession" + case EventType_SessionStarted: + return "EventType_SessionStarted" + case EventType_SessionCanceled: + return "EventType_SessionCanceled" + case EventType_SessionFinished: + return "EventType_SessionFinished" + case EventType_SessionFailed: + return "EventType_SessionFailed" + case EventType_UsageResponse: + return "EventType_UsageResponse" + case EventType_TaskRequest: + return "EventType_TaskRequest" + case EventType_UpdateConfig: + return "EventType_UpdateConfig" + case EventType_AudioMuted: + return "EventType_AudioMuted" + case EventType_SayHello: + return "EventType_SayHello" + case EventType_TTSSentenceStart: + return "EventType_TTSSentenceStart" + case EventType_TTSSentenceEnd: + return "EventType_TTSSentenceEnd" + case EventType_TTSResponse: + return "EventType_TTSResponse" + case EventType_TTSEnded: + return "EventType_TTSEnded" + case EventType_PodcastRoundStart: + return "EventType_PodcastRoundStart" + case EventType_PodcastRoundResponse: + return "EventType_PodcastRoundResponse" + case EventType_PodcastRoundEnd: + return "EventType_PodcastRoundEnd" + case EventType_ASRInfo: + return "EventType_ASRInfo" + case EventType_ASRResponse: + return "EventType_ASRResponse" + case EventType_ASREnded: + return "EventType_ASREnded" + case EventType_ChatTTSText: + return "EventType_ChatTTSText" + case EventType_ChatResponse: + return "EventType_ChatResponse" + case EventType_ChatEnded: + return "EventType_ChatEnded" + case EventType_SourceSubtitleStart: + return "EventType_SourceSubtitleStart" + case EventType_SourceSubtitleResponse: + return "EventType_SourceSubtitleResponse" + case EventType_SourceSubtitleEnd: + return "EventType_SourceSubtitleEnd" + case EventType_TranslationSubtitleStart: + return "EventType_TranslationSubtitleStart" + case EventType_TranslationSubtitleResponse: + return "EventType_TranslationSubtitleResponse" + case EventType_TranslationSubtitleEnd: + return "EventType_TranslationSubtitleEnd" + default: + return fmt.Sprintf("EventType_(%d)", t) + } +} + +// 0 1 2 3 +// | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Version | Header Size | Msg Type | Flags | +// | (4 bits) | (4 bits) | (4 bits) | (4 bits) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Serialization | Compression | Reserved | +// | (4 bits) | (4 bits) | (8 bits) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | | +// | Optional Header Extensions | +// | (if Header Size > 1) | +// | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | | +// | Payload | +// | (variable length) | +// | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +type Message struct { + Version VersionBits + HeaderSize HeaderSizeBits + MsgType MsgType + MsgTypeFlag MsgTypeFlagBits + Serialization SerializationBits + Compression CompressionBits + + EventType EventType + SessionID string + ConnectID string + Sequence int32 + ErrorCode uint32 + + Payload []byte +} + +func NewMessageFromBytes(data []byte) (*Message, error) { + if len(data) < 3 { + return nil, fmt.Errorf("data too short: expected at least 3 bytes, got %d", len(data)) + } + + typeAndFlag := data[1] + + msg, err := NewMessage(MsgType(typeAndFlag>>4), MsgTypeFlagBits(typeAndFlag&0b00001111)) + if err != nil { + return nil, err + } + + if err := msg.Unmarshal(data); err != nil { + return nil, err + } + + return msg, nil +} + +func NewMessage(msgType MsgType, flag MsgTypeFlagBits) (*Message, error) { + return &Message{ + MsgType: msgType, + MsgTypeFlag: flag, + Version: Version1, + HeaderSize: HeaderSize4, + Serialization: SerializationJSON, + Compression: CompressionNone, + }, nil +} + +func (m *Message) String() string { + switch m.MsgType { + case MsgTypeAudioOnlyServer, MsgTypeAudioOnlyClient: + if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq { + return fmt.Sprintf("%s, %s, Sequence: %d, PayloadSize: %d", m.MsgType, m.EventType, m.Sequence, len(m.Payload)) + } + return fmt.Sprintf("%s, %s, PayloadSize: %d", m.MsgType, m.EventType, len(m.Payload)) + case MsgTypeError: + return fmt.Sprintf("%s, %s, ErrorCode: %d, Payload: %s", m.MsgType, m.EventType, m.ErrorCode, string(m.Payload)) + default: + if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq { + return fmt.Sprintf("%s, %s, Sequence: %d, Payload: %s", + m.MsgType, m.EventType, m.Sequence, string(m.Payload)) + } + return fmt.Sprintf("%s, %s, Payload: %s", m.MsgType, m.EventType, string(m.Payload)) + } +} + +func (m *Message) Marshal() ([]byte, error) { + buf := new(bytes.Buffer) + + header := []uint8{ + uint8(m.Version)<<4 | uint8(m.HeaderSize), + uint8(m.MsgType)<<4 | uint8(m.MsgTypeFlag), + uint8(m.Serialization)<<4 | uint8(m.Compression), + } + + headerSize := 4 * int(m.HeaderSize) + if padding := headerSize - len(header); padding > 0 { + header = append(header, make([]uint8, padding)...) + } + + if err := binary.Write(buf, binary.BigEndian, header); err != nil { + return nil, err + } + + writers, err := m.writers() + if err != nil { + return nil, err + } + + for _, write := range writers { + if err := write(buf); err != nil { + return nil, err + } + } + + return buf.Bytes(), nil +} + +func (m *Message) Unmarshal(data []byte) error { + buf := bytes.NewBuffer(data) + + versionAndHeaderSize, err := buf.ReadByte() + if err != nil { + return err + } + + m.Version = VersionBits(versionAndHeaderSize >> 4) + m.HeaderSize = HeaderSizeBits(versionAndHeaderSize & 0b00001111) + + _, err = buf.ReadByte() + if err != nil { + return err + } + + serializationCompression, err := buf.ReadByte() + if err != nil { + return err + } + + m.Serialization = SerializationBits(serializationCompression & 0b11110000) + m.Compression = CompressionBits(serializationCompression & 0b00001111) + + headerSize := 4 * int(m.HeaderSize) + readSize := 3 + if paddingSize := headerSize - readSize; paddingSize > 0 { + if n, err := buf.Read(make([]byte, paddingSize)); err != nil || n < paddingSize { + return fmt.Errorf("insufficient header bytes: expected %d, got %d", paddingSize, n) + } + } + + readers, err := m.readers() + if err != nil { + return err + } + + for _, read := range readers { + if err := read(buf); err != nil { + return err + } + } + + if _, err := buf.ReadByte(); err != io.EOF { + return fmt.Errorf("unexpected data after message: %v", err) + } + + return nil +} + +func (m *Message) writers() (writers []func(*bytes.Buffer) error, _ error) { + if m.MsgTypeFlag == MsgTypeFlagWithEvent { + writers = append(writers, m.writeEvent, m.writeSessionID) + } + + switch m.MsgType { + case MsgTypeFullClientRequest, MsgTypeFullServerResponse, MsgTypeFrontEndResultServer, MsgTypeAudioOnlyClient, MsgTypeAudioOnlyServer: + if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq { + writers = append(writers, m.writeSequence) + } + case MsgTypeError: + writers = append(writers, m.writeErrorCode) + default: + return nil, fmt.Errorf("unsupported message type: %d", m.MsgType) + } + + writers = append(writers, m.writePayload) + return writers, nil +} + +func (m *Message) writeEvent(buf *bytes.Buffer) error { + return binary.Write(buf, binary.BigEndian, m.EventType) +} + +func (m *Message) writeSessionID(buf *bytes.Buffer) error { + switch m.EventType { + case EventType_StartConnection, EventType_FinishConnection, + EventType_ConnectionStarted, EventType_ConnectionFailed: + return nil + } + + size := len(m.SessionID) + if size > math.MaxUint32 { + return fmt.Errorf("session ID size (%d) exceeds max(uint32)", size) + } + + if err := binary.Write(buf, binary.BigEndian, uint32(size)); err != nil { + return err + } + + buf.WriteString(m.SessionID) + return nil +} + +func (m *Message) writeSequence(buf *bytes.Buffer) error { + return binary.Write(buf, binary.BigEndian, m.Sequence) +} + +func (m *Message) writeErrorCode(buf *bytes.Buffer) error { + return binary.Write(buf, binary.BigEndian, m.ErrorCode) +} + +func (m *Message) writePayload(buf *bytes.Buffer) error { + size := len(m.Payload) + if size > math.MaxUint32 { + return fmt.Errorf("payload size (%d) exceeds max(uint32)", size) + } + + if err := binary.Write(buf, binary.BigEndian, uint32(size)); err != nil { + return err + } + + buf.Write(m.Payload) + return nil +} + +func (m *Message) readers() (readers []func(*bytes.Buffer) error, _ error) { + switch m.MsgType { + case MsgTypeFullClientRequest, MsgTypeFullServerResponse, MsgTypeFrontEndResultServer, MsgTypeAudioOnlyClient, MsgTypeAudioOnlyServer: + if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq { + readers = append(readers, m.readSequence) + } + case MsgTypeError: + readers = append(readers, m.readErrorCode) + default: + return nil, fmt.Errorf("unsupported message type: %d", m.MsgType) + } + + if m.MsgTypeFlag == MsgTypeFlagWithEvent { + readers = append(readers, m.readEvent, m.readSessionID, m.readConnectID) + } + + readers = append(readers, m.readPayload) + return readers, nil +} + +func (m *Message) readEvent(buf *bytes.Buffer) error { + return binary.Read(buf, binary.BigEndian, &m.EventType) +} + +func (m *Message) readSessionID(buf *bytes.Buffer) error { + switch m.EventType { + case EventType_StartConnection, EventType_FinishConnection, + EventType_ConnectionStarted, EventType_ConnectionFailed, + EventType_ConnectionFinished: + return nil + } + + var size uint32 + if err := binary.Read(buf, binary.BigEndian, &size); err != nil { + return err + } + + if size > 0 { + m.SessionID = string(buf.Next(int(size))) + } + + return nil +} + +func (m *Message) readConnectID(buf *bytes.Buffer) error { + switch m.EventType { + case EventType_ConnectionStarted, EventType_ConnectionFailed, + EventType_ConnectionFinished: + default: + return nil + } + + var size uint32 + if err := binary.Read(buf, binary.BigEndian, &size); err != nil { + return err + } + + if size > 0 { + m.ConnectID = string(buf.Next(int(size))) + } + + return nil +} + +func (m *Message) readSequence(buf *bytes.Buffer) error { + return binary.Read(buf, binary.BigEndian, &m.Sequence) +} + +func (m *Message) readErrorCode(buf *bytes.Buffer) error { + return binary.Read(buf, binary.BigEndian, &m.ErrorCode) +} + +func (m *Message) readPayload(buf *bytes.Buffer) error { + var size uint32 + if err := binary.Read(buf, binary.BigEndian, &size); err != nil { + return err + } + + if size > 0 { + m.Payload = buf.Next(int(size)) + } + + return nil +} + +func ReceiveMessage(conn *websocket.Conn) (*Message, error) { + mt, frame, err := conn.ReadMessage() + if err != nil { + return nil, err + } + if mt != websocket.BinaryMessage && mt != websocket.TextMessage { + return nil, fmt.Errorf("unexpected Websocket message type: %d", mt) + } + msg, err := NewMessageFromBytes(frame) + if err != nil { + return nil, err + } + // Log: receive msg + return msg, nil +} + +func WaitForEvent(conn *websocket.Conn, msgType MsgType, eventType EventType) (*Message, error) { + for { + msg, err := ReceiveMessage(conn) + if err != nil { + return nil, err + } + if msg.MsgType != msgType || msg.EventType != eventType { + return nil, fmt.Errorf("unexpected message: %s", msg) + } + if msg.MsgType == msgType && msg.EventType == eventType { + return msg, nil + } + } +} + +func FullClientRequest(conn *websocket.Conn, payload []byte) error { + msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagNoSeq) + if err != nil { + return err + } + msg.Payload = payload + // Log: send msg + frame, err := msg.Marshal() + if err != nil { + return err + } + return conn.WriteMessage(websocket.BinaryMessage, frame) +} + +func AudioOnlyClient(conn *websocket.Conn, payload []byte, flag MsgTypeFlagBits) error { + msg, err := NewMessage(MsgTypeAudioOnlyClient, flag) + if err != nil { + return err + } + msg.Payload = payload + // Log: send msg + frame, err := msg.Marshal() + if err != nil { + return err + } + return conn.WriteMessage(websocket.BinaryMessage, frame) +} + +func StartConnection(conn *websocket.Conn) error { + msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) + if err != nil { + return err + } + msg.EventType = EventType_StartConnection + msg.Payload = []byte("{}") + // Log: send msg + frame, err := msg.Marshal() + if err != nil { + return err + } + return conn.WriteMessage(websocket.BinaryMessage, frame) +} + +func FinishConnection(conn *websocket.Conn) error { + msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) + if err != nil { + return err + } + msg.EventType = EventType_FinishConnection + msg.Payload = []byte("{}") + // Log: send msg + frame, err := msg.Marshal() + if err != nil { + return err + } + return conn.WriteMessage(websocket.BinaryMessage, frame) +} + +func StartSession(conn *websocket.Conn, payload []byte, sessionID string) error { + msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) + if err != nil { + return err + } + msg.EventType = EventType_StartSession + msg.SessionID = sessionID + msg.Payload = payload + // Log: send msg + frame, err := msg.Marshal() + if err != nil { + return err + } + return conn.WriteMessage(websocket.BinaryMessage, frame) +} + +func FinishSession(conn *websocket.Conn, sessionID string) error { + msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) + if err != nil { + return err + } + msg.EventType = EventType_FinishSession + msg.SessionID = sessionID + msg.Payload = []byte("{}") + // Log: send msg + frame, err := msg.Marshal() + if err != nil { + return err + } + return conn.WriteMessage(websocket.BinaryMessage, frame) +} + +func CancelSession(conn *websocket.Conn, sessionID string) error { + msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) + if err != nil { + return err + } + msg.EventType = EventType_CancelSession + msg.SessionID = sessionID + msg.Payload = []byte("{}") + // Log: send msg + frame, err := msg.Marshal() + if err != nil { + return err + } + return conn.WriteMessage(websocket.BinaryMessage, frame) +} + +func TaskRequest(conn *websocket.Conn, payload []byte, sessionID string) error { + msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) + if err != nil { + return err + } + msg.EventType = EventType_TaskRequest + msg.SessionID = sessionID + msg.Payload = payload + // Log: send msg + frame, err := msg.Marshal() + if err != nil { + return err + } + return conn.WriteMessage(websocket.BinaryMessage, frame) +} diff --git a/relay/channel/volcengine/tts.go b/relay/channel/volcengine/tts.go index 328512845..6b64c551e 100644 --- a/relay/channel/volcengine/tts.go +++ b/relay/channel/volcengine/tts.go @@ -1,9 +1,11 @@ package volcengine import ( + "context" "encoding/base64" "encoding/json" "errors" + "fmt" "io" "net/http" "strings" @@ -13,6 +15,7 @@ import ( "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" "github.com/google/uuid" + "github.com/gorilla/websocket" ) type VolcengineTTSRequest struct { @@ -192,3 +195,129 @@ func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.Re func generateRequestID() string { return uuid.New().String() } + +// handleTTSWebSocketResponse handles streaming TTS response via WebSocket +func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest VolcengineTTSRequest, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) { + // Parse API key for auth + _, token, parseErr := parseVolcengineAuth(info.ApiKey) + if parseErr != nil { + return nil, types.NewErrorWithStatusCode( + parseErr, + types.ErrorCodeChannelInvalidKey, + http.StatusUnauthorized, + ) + } + + // Setup WebSocket headers + header := http.Header{} + header.Set("Authorization", fmt.Sprintf("Bearer;%s", token)) + + // Dial WebSocket connection + conn, resp, dialErr := websocket.DefaultDialer.DialContext(context.Background(), requestURL, header) + if dialErr != nil { + if resp != nil { + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("failed to connect to websocket: %w, status: %d", dialErr, resp.StatusCode), + types.ErrorCodeBadResponseStatusCode, + http.StatusBadGateway, + ) + } + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("failed to connect to websocket: %w", dialErr), + types.ErrorCodeBadResponseStatusCode, + http.StatusBadGateway, + ) + } + defer conn.Close() + + // Update request operation to "submit" for WebSocket + volcRequest.Request.Operation = "submit" + + // Marshal request payload + payload, marshalErr := json.Marshal(volcRequest) + if marshalErr != nil { + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("failed to marshal request: %w", marshalErr), + types.ErrorCodeBadRequestBody, + http.StatusInternalServerError, + ) + } + + // Send full client request + if sendErr := FullClientRequest(conn, payload); sendErr != nil { + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("failed to send request: %w", sendErr), + types.ErrorCodeBadRequestBody, + http.StatusInternalServerError, + ) + } + + // Set response headers + contentType := getContentTypeByEncoding(encoding) + c.Header("Content-Type", contentType) + c.Header("Transfer-Encoding", "chunked") + + // Stream audio data + var audioBuffer []byte + for { + msg, recvErr := ReceiveMessage(conn) + if recvErr != nil { + if websocket.IsCloseError(recvErr, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + break + } + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("failed to receive message: %w", recvErr), + types.ErrorCodeBadResponse, + http.StatusInternalServerError, + ) + } + + switch msg.MsgType { + case MsgTypeError: + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("received error from server: code=%d, %s", msg.ErrorCode, string(msg.Payload)), + types.ErrorCodeBadResponse, + http.StatusBadRequest, + ) + case MsgTypeFrontEndResultServer: + // Metadata response, can be logged or processed + continue + case MsgTypeAudioOnlyServer: + // Stream audio chunk to client + if len(msg.Payload) > 0 { + audioBuffer = append(audioBuffer, msg.Payload...) + if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil { + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("failed to write audio data: %w", writeErr), + types.ErrorCodeBadResponse, + http.StatusInternalServerError, + ) + } + c.Writer.Flush() + } + + // Check if this is the last packet (negative sequence) + if msg.Sequence < 0 { + c.Status(http.StatusOK) + usage = &dto.Usage{ + PromptTokens: info.PromptTokens, + CompletionTokens: 0, + TotalTokens: info.PromptTokens, + } + return usage, nil + } + default: + // Unknown message type, log and continue + continue + } + } + + // If we reach here, connection closed without final packet + c.Status(http.StatusOK) + usage = &dto.Usage{ + PromptTokens: info.PromptTokens, + CompletionTokens: 0, + TotalTokens: info.PromptTokens, + } + return usage, nil +} From 431b3a84f65bdc7885ebdc3cb6e07e4dcb30ef6a Mon Sep 17 00:00:00 2001 From: feitianbubu Date: Wed, 22 Oct 2025 13:29:01 +0800 Subject: [PATCH 2/3] feat: doubao tts add is stream check --- relay/channel/volcengine/adaptor.go | 59 +++++++++++++++++------------ relay/channel/volcengine/tts.go | 10 +---- 2 files changed, 36 insertions(+), 33 deletions(-) diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 8049a6c1c..c7c2a92bc 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -23,6 +23,12 @@ import ( "github.com/gin-gonic/gin" ) +const ( + // Context keys for passing data between methods + contextKeyTTSRequest = "volcengine_tts_request" + contextKeyResponseFormat = "response_format" +) + type Adaptor struct { } @@ -50,7 +56,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf speedRatio := request.Speed encoding := mapEncoding(request.ResponseFormat) - c.Set("response_format", encoding) + c.Set(contextKeyResponseFormat, encoding) volcRequest := VolcengineTTSRequest{ App: VolcengineTTSApp{ @@ -70,7 +76,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf Request: VolcengineTTSReqInfo{ ReqID: generateRequestID(), Text: request.Input, - Operation: "submit", // WebSocket uses "submit" + Operation: "submit", // default WebSocket uses "submit" Model: info.OriginModelName, }, } @@ -83,10 +89,20 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf } // Store the request in context for WebSocket handler - c.Set("volcengine_tts_request", volcRequest) + c.Set(contextKeyTTSRequest, volcRequest) + // https://www.volcengine.com/docs/6561/1257584 + // operation需要设置为submit才是流式返回 + if volcRequest.Request.Operation == "submit" { + info.IsStream = true + } // Return nil as WebSocket doesn't use traditional request body - return nil, nil + jsonData, err := json.Marshal(volcRequest) + if err != nil { + return nil, fmt.Errorf("error marshalling volcengine request: %w", err) + } + + return bytes.NewReader(jsonData), nil } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { @@ -327,7 +343,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } // Only use WebSocket for official Volcengine endpoint if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] { - return nil, nil // WebSocket handling will be done in DoResponse + if info.IsStream { + return nil, nil + } } } return channel.DoApiRequest(a, c, info, requestBody) @@ -335,22 +353,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeAudioSpeech { - encoding := mapEncoding(c.GetString("response_format")) - - // Check if this is WebSocket mode (resp will be nil for WebSocket) - if resp == nil { - // Get the WebSocket URL - requestURL, urlErr := a.GetRequestURL(info) - if urlErr != nil { - return nil, types.NewErrorWithStatusCode( - urlErr, - types.ErrorCodeBadRequestBody, - http.StatusInternalServerError, - ) - } - - // Retrieve the volcengine request from context - volcRequestInterface, exists := c.Get("volcengine_tts_request") + encoding := mapEncoding(c.GetString(contextKeyResponseFormat)) + if info.IsStream { + volcRequestInterface, exists := c.Get(contextKeyTTSRequest) if !exists { return nil, types.NewErrorWithStatusCode( errors.New("volcengine TTS request not found in context"), @@ -368,11 +373,17 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom ) } - // Handle WebSocket streaming + // Get the WebSocket URL + requestURL, urlErr := a.GetRequestURL(info) + if urlErr != nil { + return nil, types.NewErrorWithStatusCode( + urlErr, + types.ErrorCodeBadRequestBody, + http.StatusInternalServerError, + ) + } return handleTTSWebSocketResponse(c, requestURL, volcRequest, info, encoding) } - - // Handle traditional HTTP response return handleTTSResponse(c, resp, info, encoding) } diff --git a/relay/channel/volcengine/tts.go b/relay/channel/volcengine/tts.go index 6b64c551e..033737a58 100644 --- a/relay/channel/volcengine/tts.go +++ b/relay/channel/volcengine/tts.go @@ -230,10 +230,6 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V } defer conn.Close() - // Update request operation to "submit" for WebSocket - volcRequest.Request.Operation = "submit" - - // Marshal request payload payload, marshalErr := json.Marshal(volcRequest) if marshalErr != nil { return nil, types.NewErrorWithStatusCode( @@ -280,10 +276,8 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V http.StatusBadRequest, ) case MsgTypeFrontEndResultServer: - // Metadata response, can be logged or processed continue case MsgTypeAudioOnlyServer: - // Stream audio chunk to client if len(msg.Payload) > 0 { audioBuffer = append(audioBuffer, msg.Payload...) if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil { @@ -293,10 +287,10 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V http.StatusInternalServerError, ) } + //logger.Infof("write audio chunk size: %d", len(msg.Payload)) c.Writer.Flush() } - // Check if this is the last packet (negative sequence) if msg.Sequence < 0 { c.Status(http.StatusOK) usage = &dto.Usage{ @@ -307,12 +301,10 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V return usage, nil } default: - // Unknown message type, log and continue continue } } - // If we reach here, connection closed without final packet c.Status(http.StatusOK) usage = &dto.Usage{ PromptTokens: info.PromptTokens, From bf66bbe5fa1798999761f65d1fcbacb1b75313c8 Mon Sep 17 00:00:00 2001 From: feitianbubu Date: Wed, 22 Oct 2025 16:48:00 +0800 Subject: [PATCH 3/3] refactor: clean up doubao tts code --- relay/channel/volcengine/adaptor.go | 38 +--- relay/channel/volcengine/protocols.go | 240 ++++---------------------- relay/channel/volcengine/tts.go | 10 -- 3 files changed, 35 insertions(+), 253 deletions(-) diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index c7c2a92bc..f46328e37 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -24,7 +24,6 @@ import ( ) const ( - // Context keys for passing data between methods contextKeyTTSRequest = "volcengine_tts_request" contextKeyResponseFormat = "response_format" ) @@ -76,27 +75,23 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf Request: VolcengineTTSReqInfo{ ReqID: generateRequestID(), Text: request.Input, - Operation: "submit", // default WebSocket uses "submit" + Operation: "submit", Model: info.OriginModelName, }, } - // 同步扩展字段的厂商自定义metadata if len(request.Metadata) > 0 { if err = json.Unmarshal(request.Metadata, &volcRequest); err != nil { return nil, fmt.Errorf("error unmarshalling metadata to volcengine request: %w", err) } } - // Store the request in context for WebSocket handler c.Set(contextKeyTTSRequest, volcRequest) - // https://www.volcengine.com/docs/6561/1257584 - // operation需要设置为submit才是流式返回 + if volcRequest.Request.Operation == "submit" { info.IsStream = true } - // Return nil as WebSocket doesn't use traditional request body jsonData, err := json.Marshal(volcRequest) if err != nil { return nil, fmt.Errorf("error marshalling volcengine request: %w", err) @@ -115,9 +110,8 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf writer := multipart.NewWriter(&requestBody) writer.WriteField("model", request.Model) - // 获取所有表单字段 + formData := c.Request.PostForm - // 遍历表单字段并打印输出 for key, values := range formData { if key == "model" { continue @@ -127,21 +121,16 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } } - // Parse the multipart form to handle both single image and multiple images - if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory + if err := c.Request.ParseMultipartForm(32 << 20); err != nil { return nil, errors.New("failed to parse multipart form") } if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil { - // Check if "image" field exists in any form, including array notation var imageFiles []*multipart.FileHeader var exists bool - // First check for standard "image" field if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 { - // If not found, check for "image[]" field if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 { - // If still not found, iterate through all fields to find any that start with "image[" foundArrayImages := false for fieldName, files := range c.Request.MultipartForm.File { if strings.HasPrefix(fieldName, "image[") && len(files) > 0 { @@ -152,14 +141,12 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } } - // If no image fields found at all if !foundArrayImages && (len(imageFiles) == 0) { return nil, errors.New("image is required") } } } - // Process all image files for i, fileHeader := range imageFiles { file, err := fileHeader.Open() if err != nil { @@ -167,16 +154,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } defer file.Close() - // If multiple images, use image[] as the field name fieldName := "image" if len(imageFiles) > 1 { fieldName = "image[]" } - // Determine MIME type based on file extension mimeType := detectImageMimeType(fileHeader.Filename) - // Create a form file with the appropriate content type h := make(textproto.MIMEHeader) h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename)) h.Set("Content-Type", mimeType) @@ -191,7 +175,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } } - // Handle mask file if present if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 { maskFile, err := maskFiles[0].Open() if err != nil { @@ -199,10 +182,8 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } defer maskFile.Close() - // Determine MIME type for mask file mimeType := detectImageMimeType(maskFiles[0].Filename) - // Create a form file with the appropriate content type h := make(textproto.MIMEHeader) h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename)) h.Set("Content-Type", mimeType) @@ -220,7 +201,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf return nil, errors.New("no multipart form data found") } - // 关闭 multipart 编写器以设置分界线 writer.Close() c.Request.Header.Set("Content-Type", writer.FormDataContentType()) return bytes.NewReader(requestBody.Bytes()), nil @@ -230,7 +210,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } } -// detectImageMimeType determines the MIME type based on the file extension func detectImageMimeType(filename string) string { ext := strings.ToLower(filepath.Ext(filename)) switch ext { @@ -241,11 +220,9 @@ func detectImageMimeType(filename string) string { case ".webp": return "image/webp" default: - // Try to detect from extension if possible if strings.HasPrefix(ext, ".jp") { return "image/jpeg" } - // Default to png as a fallback return "image/png" } } @@ -281,7 +258,6 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { case constant.RelayModeRerank: return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil case constant.RelayModeAudioSpeech: - // 只有当 baseUrl 是火山默认的官方Url时才改为官方的的TTS接口,否则走透传的New接口 if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] { return "wss://openspeech.bytedance.com/api/v1/tts/ws_binary", nil } @@ -312,7 +288,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } - // 适配 方舟deepseek混合模型 的 thinking 后缀 + if strings.HasSuffix(info.UpstreamModelName, "-thinking") && strings.HasPrefix(info.UpstreamModelName, "deepseek") { info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking") request.Model = info.UpstreamModelName @@ -330,18 +306,16 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela } func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { - // TODO implement me return nil, errors.New("not implemented") } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { - // For TTS with WebSocket, skip traditional HTTP request if info.RelayMode == constant.RelayModeAudioSpeech { baseUrl := info.ChannelBaseUrl if baseUrl == "" { baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] } - // Only use WebSocket for official Volcengine endpoint + if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] { if info.IsStream { return nil, nil diff --git a/relay/channel/volcengine/protocols.go b/relay/channel/volcengine/protocols.go index a41d87566..c978e1c76 100644 --- a/relay/channel/volcengine/protocols.go +++ b/relay/channel/volcengine/protocols.go @@ -11,69 +11,45 @@ import ( ) type ( - // EventType defines the event type which determines the event of the message. - EventType int32 - // MsgType defines message type which determines how the message will be - // serialized with the protocol. - MsgType uint8 - // MsgTypeFlagBits defines the 4-bit message-type specific flags. The specific - // values should be defined in each specific usage scenario. - MsgTypeFlagBits uint8 - // VersionBits defines the 4-bit version type. - VersionBits uint8 - // HeaderSizeBits defines the 4-bit header-size type. - HeaderSizeBits uint8 - // SerializationBits defines the 4-bit serialization method type. + EventType int32 + MsgType uint8 + MsgTypeFlagBits uint8 + VersionBits uint8 + HeaderSizeBits uint8 SerializationBits uint8 - // CompressionBits defines the 4-bit compression method type. - CompressionBits uint8 + CompressionBits uint8 ) const ( - MsgTypeFlagNoSeq MsgTypeFlagBits = 0 // Non-terminal packet with no sequence - MsgTypeFlagPositiveSeq MsgTypeFlagBits = 0b1 // Non-terminal packet with sequence > 0 - MsgTypeFlagLastNoSeq MsgTypeFlagBits = 0b10 // last packet with no sequence - MsgTypeFlagNegativeSeq MsgTypeFlagBits = 0b11 // last packet with sequence < 0 - MsgTypeFlagWithEvent MsgTypeFlagBits = 0b100 // Payload contains event number (int32) + MsgTypeFlagNoSeq MsgTypeFlagBits = 0 + MsgTypeFlagPositiveSeq MsgTypeFlagBits = 0b1 + MsgTypeFlagNegativeSeq MsgTypeFlagBits = 0b11 + MsgTypeFlagWithEvent MsgTypeFlagBits = 0b100 ) const ( Version1 VersionBits = iota + 1 - Version2 - Version3 - Version4 ) const ( HeaderSize4 HeaderSizeBits = iota + 1 - HeaderSize8 - HeaderSize12 - HeaderSize16 ) const ( - SerializationRaw SerializationBits = 0 - SerializationJSON SerializationBits = 0b1 - SerializationThrift SerializationBits = 0b11 - SerializationCustom SerializationBits = 0b1111 + SerializationJSON SerializationBits = 0b1 ) const ( - CompressionNone CompressionBits = 0 - CompressionGzip CompressionBits = 0b1 - CompressionCustom CompressionBits = 0b1111 + CompressionNone CompressionBits = 0 ) const ( - MsgTypeInvalid MsgType = 0 MsgTypeFullClientRequest MsgType = 0b1 MsgTypeAudioOnlyClient MsgType = 0b10 MsgTypeFullServerResponse MsgType = 0b1001 MsgTypeAudioOnlyServer MsgType = 0b1011 MsgTypeFrontEndResultServer MsgType = 0b1100 MsgTypeError MsgType = 0b1111 - - MsgTypeServerACK = MsgTypeAudioOnlyServer ) func (t MsgType) String() string { @@ -85,7 +61,7 @@ func (t MsgType) String() string { case MsgTypeFullServerResponse: return "MsgType_FullServerResponse" case MsgTypeAudioOnlyServer: - return "MsgType_AudioOnlyServer" // MsgTypeServerACK + return "MsgType_AudioOnlyServer" case MsgTypeError: return "MsgType_Error" case MsgTypeFrontEndResultServer: @@ -96,44 +72,33 @@ func (t MsgType) String() string { } const ( - // Default event, applicable for scenarios not using events or not requiring event transmission, - // or for scenarios using events, non-zero values can be used to validate event legitimacy EventType_None EventType = 0 - // 1 ~ 49 for upstream Connection events + EventType_StartConnection EventType = 1 - EventType_StartTask EventType = 1 // Alias of "StartConnection" EventType_FinishConnection EventType = 2 - EventType_FinishTask EventType = 2 // Alias of "FinishConnection" - // 50 ~ 99 for downstream Connection events - // Connection established successfully - EventType_ConnectionStarted EventType = 50 - EventType_TaskStarted EventType = 50 // Alias of "ConnectionStarted" - // Connection failed (possibly due to authentication failure) - EventType_ConnectionFailed EventType = 51 - EventType_TaskFailed EventType = 51 // Alias of "ConnectionFailed" - // Connection ended + + EventType_ConnectionStarted EventType = 50 + EventType_ConnectionFailed EventType = 51 EventType_ConnectionFinished EventType = 52 - EventType_TaskFinished EventType = 52 // Alias of "ConnectionFinished" - // 100 ~ 149 for upstream Session events + EventType_StartSession EventType = 100 EventType_CancelSession EventType = 101 EventType_FinishSession EventType = 102 - // 150 ~ 199 for downstream Session events + EventType_SessionStarted EventType = 150 EventType_SessionCanceled EventType = 151 EventType_SessionFinished EventType = 152 EventType_SessionFailed EventType = 153 - // Usage events + EventType_UsageResponse EventType = 154 - EventType_ChargeData EventType = 154 // Alias of "UsageResponse" - // 200 ~ 249 for upstream general events + EventType_TaskRequest EventType = 200 EventType_UpdateConfig EventType = 201 - // 250 ~ 299 for downstream general events + EventType_AudioMuted EventType = 250 - // 300 ~ 349 for upstream TTS events + EventType_SayHello EventType = 300 - // 350 ~ 399 for downstream TTS events + EventType_TTSSentenceStart EventType = 350 EventType_TTSSentenceEnd EventType = 351 EventType_TTSResponse EventType = 352 @@ -141,22 +106,20 @@ const ( EventType_PodcastRoundStart EventType = 360 EventType_PodcastRoundResponse EventType = 361 EventType_PodcastRoundEnd EventType = 362 - // 450 ~ 499 for downstream ASR events + EventType_ASRInfo EventType = 450 EventType_ASRResponse EventType = 451 EventType_ASREnded EventType = 459 - // 500 ~ 549 for upstream dialogue events - // (Ground-Truth-Alignment) text for speech synthesis + EventType_ChatTTSText EventType = 500 - // 550 ~ 599 for downstream dialogue events + EventType_ChatResponse EventType = 550 EventType_ChatEnded EventType = 559 - // 650 ~ 699 for downstream dialogue events - // Events for source (original) language subtitle. + EventType_SourceSubtitleStart EventType = 650 EventType_SourceSubtitleResponse EventType = 651 EventType_SourceSubtitleEnd EventType = 652 - // Events for target (translation) language subtitle. + EventType_TranslationSubtitleStart EventType = 653 EventType_TranslationSubtitleResponse EventType = 654 EventType_TranslationSubtitleEnd EventType = 655 @@ -243,26 +206,6 @@ func (t EventType) String() string { } } -// 0 1 2 3 -// | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | Version | Header Size | Msg Type | Flags | -// | (4 bits) | (4 bits) | (4 bits) | (4 bits) | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | Serialization | Compression | Reserved | -// | (4 bits) | (4 bits) | (8 bits) | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | | -// | Optional Header Extensions | -// | (if Header Size > 1) | -// | | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | | -// | Payload | -// | (variable length) | -// | | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - type Message struct { Version VersionBits HeaderSize HeaderSizeBits @@ -573,140 +516,15 @@ func ReceiveMessage(conn *websocket.Conn) (*Message, error) { if err != nil { return nil, err } - // Log: receive msg return msg, nil } -func WaitForEvent(conn *websocket.Conn, msgType MsgType, eventType EventType) (*Message, error) { - for { - msg, err := ReceiveMessage(conn) - if err != nil { - return nil, err - } - if msg.MsgType != msgType || msg.EventType != eventType { - return nil, fmt.Errorf("unexpected message: %s", msg) - } - if msg.MsgType == msgType && msg.EventType == eventType { - return msg, nil - } - } -} - func FullClientRequest(conn *websocket.Conn, payload []byte) error { msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagNoSeq) if err != nil { return err } msg.Payload = payload - // Log: send msg - frame, err := msg.Marshal() - if err != nil { - return err - } - return conn.WriteMessage(websocket.BinaryMessage, frame) -} - -func AudioOnlyClient(conn *websocket.Conn, payload []byte, flag MsgTypeFlagBits) error { - msg, err := NewMessage(MsgTypeAudioOnlyClient, flag) - if err != nil { - return err - } - msg.Payload = payload - // Log: send msg - frame, err := msg.Marshal() - if err != nil { - return err - } - return conn.WriteMessage(websocket.BinaryMessage, frame) -} - -func StartConnection(conn *websocket.Conn) error { - msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) - if err != nil { - return err - } - msg.EventType = EventType_StartConnection - msg.Payload = []byte("{}") - // Log: send msg - frame, err := msg.Marshal() - if err != nil { - return err - } - return conn.WriteMessage(websocket.BinaryMessage, frame) -} - -func FinishConnection(conn *websocket.Conn) error { - msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) - if err != nil { - return err - } - msg.EventType = EventType_FinishConnection - msg.Payload = []byte("{}") - // Log: send msg - frame, err := msg.Marshal() - if err != nil { - return err - } - return conn.WriteMessage(websocket.BinaryMessage, frame) -} - -func StartSession(conn *websocket.Conn, payload []byte, sessionID string) error { - msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) - if err != nil { - return err - } - msg.EventType = EventType_StartSession - msg.SessionID = sessionID - msg.Payload = payload - // Log: send msg - frame, err := msg.Marshal() - if err != nil { - return err - } - return conn.WriteMessage(websocket.BinaryMessage, frame) -} - -func FinishSession(conn *websocket.Conn, sessionID string) error { - msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) - if err != nil { - return err - } - msg.EventType = EventType_FinishSession - msg.SessionID = sessionID - msg.Payload = []byte("{}") - // Log: send msg - frame, err := msg.Marshal() - if err != nil { - return err - } - return conn.WriteMessage(websocket.BinaryMessage, frame) -} - -func CancelSession(conn *websocket.Conn, sessionID string) error { - msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) - if err != nil { - return err - } - msg.EventType = EventType_CancelSession - msg.SessionID = sessionID - msg.Payload = []byte("{}") - // Log: send msg - frame, err := msg.Marshal() - if err != nil { - return err - } - return conn.WriteMessage(websocket.BinaryMessage, frame) -} - -func TaskRequest(conn *websocket.Conn, payload []byte, sessionID string) error { - msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) - if err != nil { - return err - } - msg.EventType = EventType_TaskRequest - msg.SessionID = sessionID - msg.Payload = payload - // Log: send msg frame, err := msg.Marshal() if err != nil { return err diff --git a/relay/channel/volcengine/tts.go b/relay/channel/volcengine/tts.go index 033737a58..166fab8ef 100644 --- a/relay/channel/volcengine/tts.go +++ b/relay/channel/volcengine/tts.go @@ -196,9 +196,7 @@ func generateRequestID() string { return uuid.New().String() } -// handleTTSWebSocketResponse handles streaming TTS response via WebSocket func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest VolcengineTTSRequest, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) { - // Parse API key for auth _, token, parseErr := parseVolcengineAuth(info.ApiKey) if parseErr != nil { return nil, types.NewErrorWithStatusCode( @@ -208,11 +206,9 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V ) } - // Setup WebSocket headers header := http.Header{} header.Set("Authorization", fmt.Sprintf("Bearer;%s", token)) - // Dial WebSocket connection conn, resp, dialErr := websocket.DefaultDialer.DialContext(context.Background(), requestURL, header) if dialErr != nil { if resp != nil { @@ -239,7 +235,6 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V ) } - // Send full client request if sendErr := FullClientRequest(conn, payload); sendErr != nil { return nil, types.NewErrorWithStatusCode( fmt.Errorf("failed to send request: %w", sendErr), @@ -248,13 +243,10 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V ) } - // Set response headers contentType := getContentTypeByEncoding(encoding) c.Header("Content-Type", contentType) c.Header("Transfer-Encoding", "chunked") - // Stream audio data - var audioBuffer []byte for { msg, recvErr := ReceiveMessage(conn) if recvErr != nil { @@ -279,7 +271,6 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V continue case MsgTypeAudioOnlyServer: if len(msg.Payload) > 0 { - audioBuffer = append(audioBuffer, msg.Payload...) if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil { return nil, types.NewErrorWithStatusCode( fmt.Errorf("failed to write audio data: %w", writeErr), @@ -287,7 +278,6 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V http.StatusInternalServerError, ) } - //logger.Infof("write audio chunk size: %d", len(msg.Payload)) c.Writer.Flush() }