feat: add IsStream check for Doubao TTS

This commit is contained in:
feitianbubu
2025-10-22 13:29:01 +08:00
parent 098e6e7f2b
commit 431b3a84f6
2 changed files with 36 additions and 33 deletions

View File

@@ -23,6 +23,12 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// Gin context keys shared by ConvertAudioRequest (which sets them) and
// DoResponse (which reads them back).
const (
	// contextKeyTTSRequest is the key under which ConvertAudioRequest stores
	// the VolcengineTTSRequest it built.
	contextKeyTTSRequest = "volcengine_tts_request"

	// contextKeyResponseFormat is the key under which ConvertAudioRequest
	// stores the encoding returned by mapEncoding for the requested format.
	contextKeyResponseFormat = "response_format"
)
type Adaptor struct { type Adaptor struct {
} }
@@ -50,7 +56,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
speedRatio := request.Speed speedRatio := request.Speed
encoding := mapEncoding(request.ResponseFormat) encoding := mapEncoding(request.ResponseFormat)
c.Set("response_format", encoding) c.Set(contextKeyResponseFormat, encoding)
volcRequest := VolcengineTTSRequest{ volcRequest := VolcengineTTSRequest{
App: VolcengineTTSApp{ App: VolcengineTTSApp{
@@ -70,7 +76,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
Request: VolcengineTTSReqInfo{ Request: VolcengineTTSReqInfo{
ReqID: generateRequestID(), ReqID: generateRequestID(),
Text: request.Input, Text: request.Input,
Operation: "submit", // WebSocket uses "submit" Operation: "submit", // default WebSocket uses "submit"
Model: info.OriginModelName, Model: info.OriginModelName,
}, },
} }
@@ -83,10 +89,20 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
} }
// Store the request in context for WebSocket handler // Store the request in context for WebSocket handler
c.Set("volcengine_tts_request", volcRequest) c.Set(contextKeyTTSRequest, volcRequest)
// https://www.volcengine.com/docs/6561/1257584
// operation需要设置为submit才是流式返回
if volcRequest.Request.Operation == "submit" {
info.IsStream = true
}
// Return nil as WebSocket doesn't use traditional request body // Return nil as WebSocket doesn't use traditional request body
return nil, nil jsonData, err := json.Marshal(volcRequest)
if err != nil {
return nil, fmt.Errorf("error marshalling volcengine request: %w", err)
}
return bytes.NewReader(jsonData), nil
} }
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
@@ -327,7 +343,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
} }
// Only use WebSocket for official Volcengine endpoint // Only use WebSocket for official Volcengine endpoint
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] { if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
return nil, nil // WebSocket handling will be done in DoResponse if info.IsStream {
return nil, nil
}
} }
} }
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
@@ -335,22 +353,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeAudioSpeech { if info.RelayMode == constant.RelayModeAudioSpeech {
encoding := mapEncoding(c.GetString("response_format")) encoding := mapEncoding(c.GetString(contextKeyResponseFormat))
if info.IsStream {
// Check if this is WebSocket mode (resp will be nil for WebSocket) volcRequestInterface, exists := c.Get(contextKeyTTSRequest)
if resp == nil {
// Get the WebSocket URL
requestURL, urlErr := a.GetRequestURL(info)
if urlErr != nil {
return nil, types.NewErrorWithStatusCode(
urlErr,
types.ErrorCodeBadRequestBody,
http.StatusInternalServerError,
)
}
// Retrieve the volcengine request from context
volcRequestInterface, exists := c.Get("volcengine_tts_request")
if !exists { if !exists {
return nil, types.NewErrorWithStatusCode( return nil, types.NewErrorWithStatusCode(
errors.New("volcengine TTS request not found in context"), errors.New("volcengine TTS request not found in context"),
@@ -368,11 +373,17 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
) )
} }
// Handle WebSocket streaming // Get the WebSocket URL
requestURL, urlErr := a.GetRequestURL(info)
if urlErr != nil {
return nil, types.NewErrorWithStatusCode(
urlErr,
types.ErrorCodeBadRequestBody,
http.StatusInternalServerError,
)
}
return handleTTSWebSocketResponse(c, requestURL, volcRequest, info, encoding) return handleTTSWebSocketResponse(c, requestURL, volcRequest, info, encoding)
} }
// Handle traditional HTTP response
return handleTTSResponse(c, resp, info, encoding) return handleTTSResponse(c, resp, info, encoding)
} }

View File

@@ -230,10 +230,6 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
} }
defer conn.Close() defer conn.Close()
// Update request operation to "submit" for WebSocket
volcRequest.Request.Operation = "submit"
// Marshal request payload
payload, marshalErr := json.Marshal(volcRequest) payload, marshalErr := json.Marshal(volcRequest)
if marshalErr != nil { if marshalErr != nil {
return nil, types.NewErrorWithStatusCode( return nil, types.NewErrorWithStatusCode(
@@ -280,10 +276,8 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
http.StatusBadRequest, http.StatusBadRequest,
) )
case MsgTypeFrontEndResultServer: case MsgTypeFrontEndResultServer:
// Metadata response, can be logged or processed
continue continue
case MsgTypeAudioOnlyServer: case MsgTypeAudioOnlyServer:
// Stream audio chunk to client
if len(msg.Payload) > 0 { if len(msg.Payload) > 0 {
audioBuffer = append(audioBuffer, msg.Payload...) audioBuffer = append(audioBuffer, msg.Payload...)
if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil { if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil {
@@ -293,10 +287,10 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
http.StatusInternalServerError, http.StatusInternalServerError,
) )
} }
//logger.Infof("write audio chunk size: %d", len(msg.Payload))
c.Writer.Flush() c.Writer.Flush()
} }
// Check if this is the last packet (negative sequence)
if msg.Sequence < 0 { if msg.Sequence < 0 {
c.Status(http.StatusOK) c.Status(http.StatusOK)
usage = &dto.Usage{ usage = &dto.Usage{
@@ -307,12 +301,10 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
return usage, nil return usage, nil
} }
default: default:
// Unknown message type, log and continue
continue continue
} }
} }
// If we reach here, connection closed without final packet
c.Status(http.StatusOK) c.Status(http.StatusOK)
usage = &dto.Usage{ usage = &dto.Usage{
PromptTokens: info.PromptTokens, PromptTokens: info.PromptTokens,