diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 8049a6c1c..c7c2a92bc 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -23,6 +23,12 @@ import ( "github.com/gin-gonic/gin" ) +const ( + // Context keys for passing data between methods + contextKeyTTSRequest = "volcengine_tts_request" + contextKeyResponseFormat = "response_format" +) + type Adaptor struct { } @@ -50,7 +56,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf speedRatio := request.Speed encoding := mapEncoding(request.ResponseFormat) - c.Set("response_format", encoding) + c.Set(contextKeyResponseFormat, encoding) volcRequest := VolcengineTTSRequest{ App: VolcengineTTSApp{ @@ -70,7 +76,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf Request: VolcengineTTSReqInfo{ ReqID: generateRequestID(), Text: request.Input, - Operation: "submit", // WebSocket uses "submit" + Operation: "submit", // default WebSocket uses "submit" Model: info.OriginModelName, }, } @@ -83,10 +89,20 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf } // Store the request in context for WebSocket handler - c.Set("volcengine_tts_request", volcRequest) + c.Set(contextKeyTTSRequest, volcRequest) + // https://www.volcengine.com/docs/6561/1257584 + // operation must be set to submit for a streaming response + if volcRequest.Request.Operation == "submit" { + info.IsStream = true + } // Return nil as WebSocket doesn't use traditional request body - return nil, nil + jsonData, err := json.Marshal(volcRequest) + if err != nil { + return nil, fmt.Errorf("error marshalling volcengine request: %w", err) + } + + return bytes.NewReader(jsonData), nil } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { @@ -327,7 +343,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } // Only use 
WebSocket for official Volcengine endpoint if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] { - return nil, nil // WebSocket handling will be done in DoResponse + if info.IsStream { + return nil, nil + } } } return channel.DoApiRequest(a, c, info, requestBody) @@ -335,22 +353,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeAudioSpeech { - encoding := mapEncoding(c.GetString("response_format")) - - // Check if this is WebSocket mode (resp will be nil for WebSocket) - if resp == nil { - // Get the WebSocket URL - requestURL, urlErr := a.GetRequestURL(info) - if urlErr != nil { - return nil, types.NewErrorWithStatusCode( - urlErr, - types.ErrorCodeBadRequestBody, - http.StatusInternalServerError, - ) - } - - // Retrieve the volcengine request from context - volcRequestInterface, exists := c.Get("volcengine_tts_request") + encoding := mapEncoding(c.GetString(contextKeyResponseFormat)) + if info.IsStream { + volcRequestInterface, exists := c.Get(contextKeyTTSRequest) if !exists { return nil, types.NewErrorWithStatusCode( errors.New("volcengine TTS request not found in context"), @@ -368,11 +373,17 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom ) } - // Handle WebSocket streaming + // Get the WebSocket URL + requestURL, urlErr := a.GetRequestURL(info) + if urlErr != nil { + return nil, types.NewErrorWithStatusCode( + urlErr, + types.ErrorCodeBadRequestBody, + http.StatusInternalServerError, + ) + } return handleTTSWebSocketResponse(c, requestURL, volcRequest, info, encoding) } - - // Handle traditional HTTP response return handleTTSResponse(c, resp, info, encoding) } diff --git a/relay/channel/volcengine/tts.go b/relay/channel/volcengine/tts.go index 6b64c551e..033737a58 100644 --- 
a/relay/channel/volcengine/tts.go +++ b/relay/channel/volcengine/tts.go @@ -230,10 +230,6 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V } defer conn.Close() - // Update request operation to "submit" for WebSocket - volcRequest.Request.Operation = "submit" - - // Marshal request payload payload, marshalErr := json.Marshal(volcRequest) if marshalErr != nil { return nil, types.NewErrorWithStatusCode( @@ -280,10 +276,8 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V http.StatusBadRequest, ) case MsgTypeFrontEndResultServer: - // Metadata response, can be logged or processed continue case MsgTypeAudioOnlyServer: - // Stream audio chunk to client if len(msg.Payload) > 0 { audioBuffer = append(audioBuffer, msg.Payload...) if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil { @@ -293,10 +287,10 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V http.StatusInternalServerError, ) } + //logger.Infof("write audio chunk size: %d", len(msg.Payload)) c.Writer.Flush() } - // Check if this is the last packet (negative sequence) if msg.Sequence < 0 { c.Status(http.StatusOK) usage = &dto.Usage{ @@ -307,12 +301,10 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V return usage, nil } default: - // Unknown message type, log and continue continue } } - // If we reach here, connection closed without final packet c.Status(http.StatusOK) usage = &dto.Usage{ PromptTokens: info.PromptTokens,