mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-19 05:37:27 +00:00
feat: doubao tts add is stream check
This commit is contained in:
@@ -23,6 +23,12 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Context keys for passing data between methods
|
||||||
|
contextKeyTTSRequest = "volcengine_tts_request"
|
||||||
|
contextKeyResponseFormat = "response_format"
|
||||||
|
)
|
||||||
|
|
||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,7 +56,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
speedRatio := request.Speed
|
speedRatio := request.Speed
|
||||||
encoding := mapEncoding(request.ResponseFormat)
|
encoding := mapEncoding(request.ResponseFormat)
|
||||||
|
|
||||||
c.Set("response_format", encoding)
|
c.Set(contextKeyResponseFormat, encoding)
|
||||||
|
|
||||||
volcRequest := VolcengineTTSRequest{
|
volcRequest := VolcengineTTSRequest{
|
||||||
App: VolcengineTTSApp{
|
App: VolcengineTTSApp{
|
||||||
@@ -70,7 +76,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
Request: VolcengineTTSReqInfo{
|
Request: VolcengineTTSReqInfo{
|
||||||
ReqID: generateRequestID(),
|
ReqID: generateRequestID(),
|
||||||
Text: request.Input,
|
Text: request.Input,
|
||||||
Operation: "submit", // WebSocket uses "submit"
|
Operation: "submit", // default WebSocket uses "submit"
|
||||||
Model: info.OriginModelName,
|
Model: info.OriginModelName,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -83,10 +89,20 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Store the request in context for WebSocket handler
|
// Store the request in context for WebSocket handler
|
||||||
c.Set("volcengine_tts_request", volcRequest)
|
c.Set(contextKeyTTSRequest, volcRequest)
|
||||||
|
// https://www.volcengine.com/docs/6561/1257584
|
||||||
|
// operation需要设置为submit才是流式返回
|
||||||
|
if volcRequest.Request.Operation == "submit" {
|
||||||
|
info.IsStream = true
|
||||||
|
}
|
||||||
|
|
||||||
// Return nil as WebSocket doesn't use traditional request body
|
// Return nil as WebSocket doesn't use traditional request body
|
||||||
return nil, nil
|
jsonData, err := json.Marshal(volcRequest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error marshalling volcengine request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.NewReader(jsonData), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||||
@@ -327,7 +343,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
}
|
}
|
||||||
// Only use WebSocket for official Volcengine endpoint
|
// Only use WebSocket for official Volcengine endpoint
|
||||||
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
|
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
|
||||||
return nil, nil // WebSocket handling will be done in DoResponse
|
if info.IsStream {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
@@ -335,22 +353,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.RelayMode == constant.RelayModeAudioSpeech {
|
if info.RelayMode == constant.RelayModeAudioSpeech {
|
||||||
encoding := mapEncoding(c.GetString("response_format"))
|
encoding := mapEncoding(c.GetString(contextKeyResponseFormat))
|
||||||
|
if info.IsStream {
|
||||||
// Check if this is WebSocket mode (resp will be nil for WebSocket)
|
volcRequestInterface, exists := c.Get(contextKeyTTSRequest)
|
||||||
if resp == nil {
|
|
||||||
// Get the WebSocket URL
|
|
||||||
requestURL, urlErr := a.GetRequestURL(info)
|
|
||||||
if urlErr != nil {
|
|
||||||
return nil, types.NewErrorWithStatusCode(
|
|
||||||
urlErr,
|
|
||||||
types.ErrorCodeBadRequestBody,
|
|
||||||
http.StatusInternalServerError,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve the volcengine request from context
|
|
||||||
volcRequestInterface, exists := c.Get("volcengine_tts_request")
|
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil, types.NewErrorWithStatusCode(
|
return nil, types.NewErrorWithStatusCode(
|
||||||
errors.New("volcengine TTS request not found in context"),
|
errors.New("volcengine TTS request not found in context"),
|
||||||
@@ -368,11 +373,17 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle WebSocket streaming
|
// Get the WebSocket URL
|
||||||
|
requestURL, urlErr := a.GetRequestURL(info)
|
||||||
|
if urlErr != nil {
|
||||||
|
return nil, types.NewErrorWithStatusCode(
|
||||||
|
urlErr,
|
||||||
|
types.ErrorCodeBadRequestBody,
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
|
}
|
||||||
return handleTTSWebSocketResponse(c, requestURL, volcRequest, info, encoding)
|
return handleTTSWebSocketResponse(c, requestURL, volcRequest, info, encoding)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle traditional HTTP response
|
|
||||||
return handleTTSResponse(c, resp, info, encoding)
|
return handleTTSResponse(c, resp, info, encoding)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -230,10 +230,6 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
|
|||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
// Update request operation to "submit" for WebSocket
|
|
||||||
volcRequest.Request.Operation = "submit"
|
|
||||||
|
|
||||||
// Marshal request payload
|
|
||||||
payload, marshalErr := json.Marshal(volcRequest)
|
payload, marshalErr := json.Marshal(volcRequest)
|
||||||
if marshalErr != nil {
|
if marshalErr != nil {
|
||||||
return nil, types.NewErrorWithStatusCode(
|
return nil, types.NewErrorWithStatusCode(
|
||||||
@@ -280,10 +276,8 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
|
|||||||
http.StatusBadRequest,
|
http.StatusBadRequest,
|
||||||
)
|
)
|
||||||
case MsgTypeFrontEndResultServer:
|
case MsgTypeFrontEndResultServer:
|
||||||
// Metadata response, can be logged or processed
|
|
||||||
continue
|
continue
|
||||||
case MsgTypeAudioOnlyServer:
|
case MsgTypeAudioOnlyServer:
|
||||||
// Stream audio chunk to client
|
|
||||||
if len(msg.Payload) > 0 {
|
if len(msg.Payload) > 0 {
|
||||||
audioBuffer = append(audioBuffer, msg.Payload...)
|
audioBuffer = append(audioBuffer, msg.Payload...)
|
||||||
if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil {
|
if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil {
|
||||||
@@ -293,10 +287,10 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
|
|||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
//logger.Infof("write audio chunk size: %d", len(msg.Payload))
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if this is the last packet (negative sequence)
|
|
||||||
if msg.Sequence < 0 {
|
if msg.Sequence < 0 {
|
||||||
c.Status(http.StatusOK)
|
c.Status(http.StatusOK)
|
||||||
usage = &dto.Usage{
|
usage = &dto.Usage{
|
||||||
@@ -307,12 +301,10 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
|
|||||||
return usage, nil
|
return usage, nil
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
// Unknown message type, log and continue
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we reach here, connection closed without final packet
|
|
||||||
c.Status(http.StatusOK)
|
c.Status(http.StatusOK)
|
||||||
usage = &dto.Usage{
|
usage = &dto.Usage{
|
||||||
PromptTokens: info.PromptTokens,
|
PromptTokens: info.PromptTokens,
|
||||||
|
|||||||
Reference in New Issue
Block a user