feat: doubao tts support streaming realtime audio

This commit is contained in:
feitianbubu
2025-10-22 12:31:54 +08:00
parent 4661399639
commit 098e6e7f2b
3 changed files with 898 additions and 7 deletions

View File

@@ -70,7 +70,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
Request: VolcengineTTSReqInfo{
ReqID: generateRequestID(),
Text: request.Input,
Operation: "query",
Operation: "submit", // WebSocket uses "submit"
Model: info.OriginModelName,
},
}
@@ -82,12 +82,11 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
}
}
jsonData, err := json.Marshal(volcRequest)
if err != nil {
return nil, fmt.Errorf("error marshalling volcengine request: %w", err)
}
// Store the request in context for WebSocket handler
c.Set("volcengine_tts_request", volcRequest)
return bytes.NewReader(jsonData), nil
// Return nil as WebSocket doesn't use traditional request body
return nil, nil
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
@@ -268,7 +267,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
case constant.RelayModeAudioSpeech:
// 只有当 baseUrl 是火山默认的官方 Url 时才改为官方的 TTS 接口，否则走透传的 New 接口
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
return "https://openspeech.bytedance.com/api/v1/tts", nil
return "wss://openspeech.bytedance.com/api/v1/tts/ws_binary", nil
}
return fmt.Sprintf("%s/v1/audio/speech", baseUrl), nil
default:
@@ -320,12 +319,60 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
}
// DoRequest issues the upstream request for this adaptor.
//
// Audio-speech requests whose base URL resolves to the default Volcengine
// base URL (an empty ChannelBaseUrl falls back to that default) are not sent
// here: the method returns (nil, nil) and the WebSocket exchange is left to
// DoResponse. Every other request is forwarded through channel.DoApiRequest.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
	// Anything that is not TTS takes the regular HTTP path.
	if info.RelayMode != constant.RelayModeAudioSpeech {
		return channel.DoApiRequest(a, c, info, requestBody)
	}

	officialURL := channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
	target := info.ChannelBaseUrl
	if target == "" {
		target = officialURL
	}

	// A custom base URL keeps the traditional HTTP request.
	if target != officialURL {
		return channel.DoApiRequest(a, c, info, requestBody)
	}

	// Official endpoint: skip the HTTP request; WebSocket handling will be
	// done in DoResponse.
	return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeAudioSpeech {
encoding := mapEncoding(c.GetString("response_format"))
// Check if this is WebSocket mode (resp will be nil for WebSocket)
if resp == nil {
// Get the WebSocket URL
requestURL, urlErr := a.GetRequestURL(info)
if urlErr != nil {
return nil, types.NewErrorWithStatusCode(
urlErr,
types.ErrorCodeBadRequestBody,
http.StatusInternalServerError,
)
}
// Retrieve the volcengine request from context
volcRequestInterface, exists := c.Get("volcengine_tts_request")
if !exists {
return nil, types.NewErrorWithStatusCode(
errors.New("volcengine TTS request not found in context"),
types.ErrorCodeBadRequestBody,
http.StatusInternalServerError,
)
}
volcRequest, ok := volcRequestInterface.(VolcengineTTSRequest)
if !ok {
return nil, types.NewErrorWithStatusCode(
errors.New("invalid volcengine TTS request type"),
types.ErrorCodeBadRequestBody,
http.StatusInternalServerError,
)
}
// Handle WebSocket streaming
return handleTTSWebSocketResponse(c, requestURL, volcRequest, info, encoding)
}
// Handle traditional HTTP response
return handleTTSResponse(c, resp, info, encoding)
}