Merge pull request #2067 from feitianbubu/pr/add-doubao-audio

新增支持豆包语音合成2.0功能
This commit is contained in:
IcedTangerine
2025-10-18 00:14:11 +08:00
committed by GitHub
3 changed files with 275 additions and 3 deletions

View File

@@ -37,8 +37,50 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
if info.RelayMode != constant.RelayModeAudioSpeech {
return nil, errors.New("unsupported audio relay mode")
}
appID, token, err := parseVolcengineAuth(info.ApiKey)
if err != nil {
return nil, err
}
voiceType := mapVoiceType(request.Voice)
speedRatio := mapSpeedRatio(request.Speed)
encoding := mapEncoding(request.ResponseFormat)
c.Set("response_format", encoding)
volcRequest := VolcengineTTSRequest{
App: VolcengineTTSApp{
AppID: appID,
Token: token,
Cluster: "volcano_tts",
},
User: VolcengineTTSUser{
UID: "openai_relay_user",
},
Audio: VolcengineTTSAudio{
VoiceType: voiceType,
Encoding: encoding,
SpeedRatio: speedRatio,
Rate: 24000,
},
Request: VolcengineTTSReqInfo{
ReqID: generateRequestID(),
Text: request.Input,
Operation: "query",
Model: info.OriginModelName,
},
}
jsonData, err := json.Marshal(volcRequest)
if err != nil {
return nil, fmt.Errorf("error marshalling volcengine request: %w", err)
}
return bytes.NewReader(jsonData), nil
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
@@ -190,7 +232,6 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// 支持自定义域名,如果未设置则使用默认域名
baseUrl := info.ChannelBaseUrl
if baseUrl == "" {
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
@@ -217,6 +258,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
case constant.RelayModeRerank:
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
case constant.RelayModeAudioSpeech:
// 只有当 baseUrl 是火山默认的官方Url时才改为官方的的TTS接口否则走透传的New接口
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
return "https://openspeech.bytedance.com/api/v1/tts", nil
}
return fmt.Sprintf("%s/v1/audio/speech", baseUrl), nil
default:
}
}
@@ -225,6 +272,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
if info.RelayMode == constant.RelayModeAudioSpeech {
parts := strings.Split(info.ApiKey, "|")
if len(parts) == 2 {
req.Set("Authorization", "Bearer;"+parts[1])
}
req.Set("Content-Type", "application/json")
return nil
}
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
@@ -260,6 +317,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeAudioSpeech {
encoding := mapEncoding(c.GetString("response_format"))
return handleTTSResponse(c, resp, info, encoding)
}
adaptor := openai.Adaptor{}
usage, err = adaptor.DoResponse(c, resp, info)
return

View File

@@ -0,0 +1,208 @@
package volcengine
import (
"encoding/base64"
"encoding/json"
"errors"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type VolcengineTTSRequest struct {
App VolcengineTTSApp `json:"app"`
User VolcengineTTSUser `json:"user"`
Audio VolcengineTTSAudio `json:"audio"`
Request VolcengineTTSReqInfo `json:"request"`
}
type VolcengineTTSApp struct {
AppID string `json:"appid"`
Token string `json:"token"`
Cluster string `json:"cluster"`
}
type VolcengineTTSUser struct {
UID string `json:"uid"`
}
type VolcengineTTSAudio struct {
VoiceType string `json:"voice_type"`
Encoding string `json:"encoding"`
SpeedRatio float64 `json:"speed_ratio"`
Rate int `json:"rate"`
Bitrate int `json:"bitrate,omitempty"`
LoudnessRatio float64 `json:"loudness_ratio,omitempty"`
EnableEmotion bool `json:"enable_emotion,omitempty"`
Emotion string `json:"emotion,omitempty"`
EmotionScale float64 `json:"emotion_scale,omitempty"`
ExplicitLanguage string `json:"explicit_language,omitempty"`
ContextLanguage string `json:"context_language,omitempty"`
}
type VolcengineTTSReqInfo struct {
ReqID string `json:"reqid"`
Text string `json:"text"`
Operation string `json:"operation"`
Model string `json:"model,omitempty"`
TextType string `json:"text_type,omitempty"`
SilenceDuration float64 `json:"silence_duration,omitempty"`
WithTimestamp interface{} `json:"with_timestamp,omitempty"`
ExtraParam *VolcengineTTSExtraParam `json:"extra_param,omitempty"`
}
type VolcengineTTSExtraParam struct {
DisableMarkdownFilter bool `json:"disable_markdown_filter,omitempty"`
EnableLatexTn bool `json:"enable_latex_tn,omitempty"`
MuteCutThreshold string `json:"mute_cut_threshold,omitempty"`
MuteCutRemainMs string `json:"mute_cut_remain_ms,omitempty"`
DisableEmojiFilter bool `json:"disable_emoji_filter,omitempty"`
UnsupportedCharRatioThresh float64 `json:"unsupported_char_ratio_thresh,omitempty"`
AigcWatermark bool `json:"aigc_watermark,omitempty"`
CacheConfig *VolcengineTTSCacheConfig `json:"cache_config,omitempty"`
}
type VolcengineTTSCacheConfig struct {
TextType int `json:"text_type,omitempty"`
UseCache bool `json:"use_cache,omitempty"`
}
type VolcengineTTSResponse struct {
ReqID string `json:"reqid"`
Code int `json:"code"`
Message string `json:"message"`
Sequence int `json:"sequence"`
Data string `json:"data"`
Addition *VolcengineTTSAdditionInfo `json:"addition,omitempty"`
}
type VolcengineTTSAdditionInfo struct {
Duration string `json:"duration"`
}
var openAIToVolcengineVoiceMap = map[string]string{
"alloy": "zh_male_M392_conversation_wvae_bigtts",
"echo": "zh_male_wenhao_mars_bigtts",
"fable": "zh_female_tianmei_mars_bigtts",
"onyx": "zh_male_zhibei_mars_bigtts",
"nova": "zh_female_shuangkuaisisi_mars_bigtts",
"shimmer": "zh_female_cancan_mars_bigtts",
}
var responseFormatToEncodingMap = map[string]string{
"mp3": "mp3",
"opus": "ogg_opus",
"aac": "mp3",
"flac": "mp3",
"wav": "wav",
"pcm": "pcm",
}
func parseVolcengineAuth(apiKey string) (appID, token string, err error) {
parts := strings.Split(apiKey, "|")
if len(parts) != 2 {
return "", "", errors.New("invalid api key format, expected: appid|access_token")
}
return parts[0], parts[1], nil
}
func mapVoiceType(openAIVoice string) string {
if voice, ok := openAIToVolcengineVoiceMap[openAIVoice]; ok {
return voice
}
return openAIVoice
}
// [0.1,2],默认为 1通常保留一位小数即可
func mapSpeedRatio(speed float64) float64 {
if speed == 0 {
return 1.0
}
if speed < 0.1 {
return 0.1
}
if speed > 2.0 {
return 2.0
}
return speed
}
func mapEncoding(responseFormat string) string {
if encoding, ok := responseFormatToEncodingMap[responseFormat]; ok {
return encoding
}
return "mp3"
}
func getContentTypeByEncoding(encoding string) string {
contentTypeMap := map[string]string{
"mp3": "audio/mpeg",
"ogg_opus": "audio/ogg",
"wav": "audio/wav",
"pcm": "audio/pcm",
}
if ct, ok := contentTypeMap[encoding]; ok {
return ct
}
return "application/octet-stream"
}
func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, types.NewErrorWithStatusCode(
errors.New("failed to read volcengine response"),
types.ErrorCodeReadResponseBodyFailed,
http.StatusInternalServerError,
)
}
defer resp.Body.Close()
var volcResp VolcengineTTSResponse
if unmarshalErr := json.Unmarshal(body, &volcResp); unmarshalErr != nil {
return nil, types.NewErrorWithStatusCode(
errors.New("failed to parse volcengine response"),
types.ErrorCodeBadResponseBody,
http.StatusInternalServerError,
)
}
if volcResp.Code != 3000 {
return nil, types.NewErrorWithStatusCode(
errors.New(volcResp.Message),
types.ErrorCodeBadResponse,
http.StatusBadRequest,
)
}
audioData, decodeErr := base64.StdEncoding.DecodeString(volcResp.Data)
if decodeErr != nil {
return nil, types.NewErrorWithStatusCode(
errors.New("failed to decode audio data"),
types.ErrorCodeBadResponseBody,
http.StatusInternalServerError,
)
}
contentType := getContentTypeByEncoding(encoding)
c.Header("Content-Type", contentType)
c.Data(http.StatusOK, contentType, audioData)
usage = &dto.Usage{
PromptTokens: info.PromptTokens,
CompletionTokens: 0,
TotalTokens: info.PromptTokens,
}
return usage, nil
}
func generateRequestID() string {
return uuid.New().String()
}

View File

@@ -107,6 +107,8 @@ function type2secretPrompt(type) {
return '按照如下格式输入AppId|SecretId|SecretKey';
case 33:
return '按照如下格式输入Ak|Sk|Region';
case 45:
return '请输入渠道对应的鉴权密钥, 豆包语音输入AppId|AccessToken';
case 50:
return '按照如下格式输入: AccessKey|SecretKey, 如果上游是New API则直接输ApiKey';
case 51: