Files
new-api/relay/channel/volcengine/tts.go
CaIon f5b409d74f feat: refactor token estimation logic
- Introduced new OpenAI text models in `common/model.go`.
- Added `IsOpenAITextModel` function to check for OpenAI text models.
- Refactored token estimation methods across various channels to use estimated prompt tokens instead of direct prompt token counts.
- Updated related functions and structures to accommodate the new token estimation approach, enhancing overall token management.
2025-12-02 21:34:39 +08:00

306 lines
9.1 KiB
Go

package volcengine
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
// VolcengineTTSRequest is the top-level JSON payload for a Volcengine TTS
// request. It is composed of four sections that are always serialized:
// "app", "user", "audio" and "request".
type VolcengineTTSRequest struct {
	App     VolcengineTTSApp     `json:"app"`
	User    VolcengineTTSUser    `json:"user"`
	Audio   VolcengineTTSAudio   `json:"audio"`
	Request VolcengineTTSReqInfo `json:"request"`
}
// VolcengineTTSApp is the "app" section of a TTS request. All three fields
// are always serialized, even when empty.
type VolcengineTTSApp struct {
	AppID   string `json:"appid"`
	Token   string `json:"token"`
	Cluster string `json:"cluster"`
}
// VolcengineTTSUser is the "user" section of a TTS request; it carries a
// single caller-supplied identifier serialized as "uid".
type VolcengineTTSUser struct {
	UID string `json:"uid"`
}
// VolcengineTTSAudio is the "audio" section of a TTS request.
//
// The first four fields are always serialized; the remaining fields are
// optional and dropped from the JSON output when they hold their zero value.
type VolcengineTTSAudio struct {
	// Always-present settings.
	VoiceType  string  `json:"voice_type"`
	Encoding   string  `json:"encoding"`
	SpeedRatio float64 `json:"speed_ratio"`
	Rate       int     `json:"rate"`

	// Optional settings (omitted when zero).
	Bitrate          int     `json:"bitrate,omitempty"`
	LoudnessRatio    float64 `json:"loudness_ratio,omitempty"`
	EnableEmotion    bool    `json:"enable_emotion,omitempty"`
	Emotion          string  `json:"emotion,omitempty"`
	EmotionScale     float64 `json:"emotion_scale,omitempty"`
	ExplicitLanguage string  `json:"explicit_language,omitempty"`
	ContextLanguage  string  `json:"context_language,omitempty"`
}
// VolcengineTTSReqInfo is the "request" section of a TTS request.
//
// ReqID, Text and Operation are always serialized; every other field is
// optional and omitted from the JSON output when unset.
type VolcengineTTSReqInfo struct {
	// Always-present fields.
	ReqID     string `json:"reqid"`
	Text      string `json:"text"`
	Operation string `json:"operation"`

	// Optional fields (omitted when zero / nil).
	Model           string  `json:"model,omitempty"`
	TextType        string  `json:"text_type,omitempty"`
	SilenceDuration float64 `json:"silence_duration,omitempty"`
	// WithTimestamp is deliberately untyped and serialized as whatever value
	// it holds. NOTE(review): presumably the upstream accepts more than one
	// JSON type here — confirm against the Volcengine API documentation.
	WithTimestamp any                      `json:"with_timestamp,omitempty"`
	ExtraParam    *VolcengineTTSExtraParam `json:"extra_param,omitempty"`
}
// VolcengineTTSExtraParam holds the optional "extra_param" object nested in
// the request section. Every field is omitted from the JSON output when it
// holds its zero value (or nil for CacheConfig).
//
// Note that MuteCutThreshold and MuteCutRemainMs are serialized as JSON
// strings, not numbers.
type VolcengineTTSExtraParam struct {
	DisableMarkdownFilter      bool                      `json:"disable_markdown_filter,omitempty"`
	EnableLatexTn              bool                      `json:"enable_latex_tn,omitempty"`
	MuteCutThreshold           string                    `json:"mute_cut_threshold,omitempty"`
	MuteCutRemainMs            string                    `json:"mute_cut_remain_ms,omitempty"`
	DisableEmojiFilter         bool                      `json:"disable_emoji_filter,omitempty"`
	UnsupportedCharRatioThresh float64                   `json:"unsupported_char_ratio_thresh,omitempty"`
	AigcWatermark              bool                      `json:"aigc_watermark,omitempty"`
	CacheConfig                *VolcengineTTSCacheConfig `json:"cache_config,omitempty"`
}
// VolcengineTTSCacheConfig is the optional "cache_config" object inside
// extra_param. Both fields are omitted from the JSON output when zero, so an
// empty value serializes as "{}".
type VolcengineTTSCacheConfig struct {
	TextType int  `json:"text_type,omitempty"`
	UseCache bool `json:"use_cache,omitempty"`
}
// VolcengineTTSResponse is the JSON envelope returned by the Volcengine TTS
// HTTP endpoint.
//
// handleTTSResponse treats Code 3000 as success and any other value as a
// failure described by Message. Data carries the audio as a standard base64
// string.
type VolcengineTTSResponse struct {
	ReqID    string                     `json:"reqid"`
	Code     int                        `json:"code"`
	Message  string                     `json:"message"`
	Sequence int                        `json:"sequence"`
	Data     string                     `json:"data"`
	Addition *VolcengineTTSAdditionInfo `json:"addition,omitempty"`
}
// VolcengineTTSAdditionInfo is the optional "addition" object of a TTS
// response. Duration is delivered as a JSON string.
// NOTE(review): the unit of Duration is not visible in this file — confirm
// against the Volcengine API documentation before parsing it.
type VolcengineTTSAdditionInfo struct {
	Duration string `json:"duration"`
}
// openAIToVolcengineVoiceMap translates the OpenAI voice names accepted on the
// public API into Volcengine voice_type identifiers. Names that are not listed
// here are passed through unchanged by mapVoiceType.
var openAIToVolcengineVoiceMap = map[string]string{
	// Male voices.
	"alloy": "zh_male_M392_conversation_wvae_bigtts",
	"echo":  "zh_male_wenhao_mars_bigtts",
	"onyx":  "zh_male_zhibei_mars_bigtts",

	// Female voices.
	"fable":   "zh_female_tianmei_mars_bigtts",
	"nova":    "zh_female_shuangkuaisisi_mars_bigtts",
	"shimmer": "zh_female_cancan_mars_bigtts",
}
// responseFormatToEncodingMap translates an OpenAI response_format value into
// the encoding name sent to Volcengine. Formats without a direct counterpart
// in this table ("aac", "flac") are mapped to "mp3".
var responseFormatToEncodingMap = map[string]string{
	// Direct equivalents.
	"mp3":  "mp3",
	"wav":  "wav",
	"pcm":  "pcm",
	"opus": "ogg_opus",

	// No direct equivalent in this table: served as mp3.
	"aac":  "mp3",
	"flac": "mp3",
}
// parseVolcengineAuth splits a channel API key of the form
// "appid|access_token" into its two components.
//
// Whitespace surrounding either component is trimmed. An error is returned
// when the key does not contain exactly one "|" separator, or when either
// component is empty after trimming: such a key cannot yield a usable
// credential, so it is rejected here rather than being sent upstream.
func parseVolcengineAuth(apiKey string) (appID, token string, err error) {
	errInvalidFormat := errors.New("invalid api key format, expected: appid|access_token")

	parts := strings.Split(apiKey, "|")
	if len(parts) != 2 {
		return "", "", errInvalidFormat
	}

	appID = strings.TrimSpace(parts[0])
	token = strings.TrimSpace(parts[1])
	if appID == "" || token == "" {
		return "", "", errInvalidFormat
	}
	return appID, token, nil
}
// mapVoiceType resolves an OpenAI voice name to its Volcengine voice_type via
// openAIToVolcengineVoiceMap. A name with no entry in the table is returned
// unchanged, so callers can pass a Volcengine voice_type directly.
func mapVoiceType(openAIVoice string) string {
	mapped, known := openAIToVolcengineVoiceMap[openAIVoice]
	if !known {
		return openAIVoice
	}
	return mapped
}
// mapEncoding resolves an OpenAI response_format to the Volcengine encoding
// name via responseFormatToEncodingMap. Any format without an entry in the
// table (including the empty string) falls back to "mp3".
func mapEncoding(responseFormat string) string {
	const fallback = "mp3"

	mapped, known := responseFormatToEncodingMap[responseFormat]
	if !known {
		return fallback
	}
	return mapped
}
// getContentTypeByEncoding returns the HTTP Content-Type to use for audio in
// the given Volcengine encoding. Encodings other than "mp3", "ogg_opus",
// "wav" and "pcm" yield "application/octet-stream".
func getContentTypeByEncoding(encoding string) string {
	switch encoding {
	case "mp3":
		return "audio/mpeg"
	case "ogg_opus":
		return "audio/ogg"
	case "wav":
		return "audio/wav"
	case "pcm":
		return "audio/pcm"
	default:
		return "application/octet-stream"
	}
}
// handleTTSResponse relays a non-streaming Volcengine TTS HTTP response to the
// client.
//
// The upstream body is read in full and parsed as a VolcengineTTSResponse.
// When its Code is 3000 the base64 payload in Data is decoded and written to c
// with status 200 and a Content-Type derived from encoding.
//
// Returns a *dto.Usage whose prompt and total token counts are the estimated
// prompt tokens from info (completion tokens are 0), or a *types.NewAPIError
// when the body cannot be read, parsed or decoded, or when the upstream
// reports a code other than 3000.
func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
	// Register the close before reading so the body is released even when the
	// read itself fails.
	defer resp.Body.Close()

	body, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, types.NewErrorWithStatusCode(
			fmt.Errorf("failed to read volcengine response: %w", readErr),
			types.ErrorCodeReadResponseBodyFailed,
			http.StatusInternalServerError,
		)
	}

	var volcResp VolcengineTTSResponse
	if unmarshalErr := json.Unmarshal(body, &volcResp); unmarshalErr != nil {
		return nil, types.NewErrorWithStatusCode(
			fmt.Errorf("failed to parse volcengine response: %w", unmarshalErr),
			types.ErrorCodeBadResponseBody,
			http.StatusInternalServerError,
		)
	}

	// 3000 is the only code treated as success.
	if volcResp.Code != 3000 {
		message := volcResp.Message
		if message == "" {
			// Never surface an empty error to the caller.
			message = fmt.Sprintf("volcengine tts request failed with code %d", volcResp.Code)
		}
		return nil, types.NewErrorWithStatusCode(
			errors.New(message),
			types.ErrorCodeBadResponse,
			http.StatusBadRequest,
		)
	}

	audioData, decodeErr := base64.StdEncoding.DecodeString(volcResp.Data)
	if decodeErr != nil {
		return nil, types.NewErrorWithStatusCode(
			fmt.Errorf("failed to decode audio data: %w", decodeErr),
			types.ErrorCodeBadResponseBody,
			http.StatusInternalServerError,
		)
	}

	contentType := getContentTypeByEncoding(encoding)
	c.Header("Content-Type", contentType)
	c.Data(http.StatusOK, contentType, audioData)

	promptTokens := info.GetEstimatePromptTokens()
	usage = &dto.Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: 0,
		TotalTokens:      promptTokens,
	}
	return usage, nil
}
// generateRequestID returns a newly generated UUID in its string form, for use
// as a request identifier.
func generateRequestID() string {
	id := uuid.New()
	return id.String()
}
// handleTTSWebSocketResponse performs a TTS request over a WebSocket
// connection to requestURL and streams the returned audio to the client.
//
// The access token is taken from info.ApiKey ("appid|access_token") and sent
// as "Authorization: Bearer;<token>". volcRequest is marshalled to JSON and
// sent with FullClientRequest; messages are then read with ReceiveMessage:
//   - MsgTypeError aborts with the server's error code and payload;
//   - MsgTypeAudioOnlyServer payloads are written to the response and flushed
//     immediately, and a negative Sequence marks the final chunk;
//   - all other message types are ignored.
//
// A normal / going-away close from the server also ends the stream
// successfully.
//
// Returns a *dto.Usage whose prompt and total token counts are the estimated
// prompt tokens from info (completion tokens are 0), or a *types.NewAPIError
// describing the failure. Note that an error may be returned after part of the
// audio has already been written to the client.
func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest VolcengineTTSRequest, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
	_, token, parseErr := parseVolcengineAuth(info.ApiKey)
	if parseErr != nil {
		return nil, types.NewErrorWithStatusCode(
			parseErr,
			types.ErrorCodeChannelInvalidKey,
			http.StatusUnauthorized,
		)
	}

	header := http.Header{}
	header.Set("Authorization", fmt.Sprintf("Bearer;%s", token))

	// Bind the dial and handshake to the inbound request so that a client
	// which has already gone away does not keep us connecting upstream.
	ctx := context.Background()
	if c.Request != nil {
		ctx = c.Request.Context()
	}

	conn, resp, dialErr := websocket.DefaultDialer.DialContext(ctx, requestURL, header)
	if dialErr != nil {
		// On a failed handshake resp may be non-nil; include its status for diagnosis.
		if resp != nil {
			return nil, types.NewErrorWithStatusCode(
				fmt.Errorf("failed to connect to websocket: %w, status: %d", dialErr, resp.StatusCode),
				types.ErrorCodeBadResponseStatusCode,
				http.StatusBadGateway,
			)
		}
		return nil, types.NewErrorWithStatusCode(
			fmt.Errorf("failed to connect to websocket: %w", dialErr),
			types.ErrorCodeBadResponseStatusCode,
			http.StatusBadGateway,
		)
	}
	defer conn.Close()

	payload, marshalErr := json.Marshal(volcRequest)
	if marshalErr != nil {
		return nil, types.NewErrorWithStatusCode(
			fmt.Errorf("failed to marshal request: %w", marshalErr),
			types.ErrorCodeBadRequestBody,
			http.StatusInternalServerError,
		)
	}

	if sendErr := FullClientRequest(conn, payload); sendErr != nil {
		return nil, types.NewErrorWithStatusCode(
			fmt.Errorf("failed to send request: %w", sendErr),
			types.ErrorCodeBadRequestBody,
			http.StatusInternalServerError,
		)
	}

	contentType := getContentTypeByEncoding(encoding)
	c.Header("Content-Type", contentType)
	c.Header("Transfer-Encoding", "chunked")

stream:
	for {
		msg, recvErr := ReceiveMessage(conn)
		if recvErr != nil {
			// A clean close from the server is the normal end of the stream.
			if websocket.IsCloseError(recvErr, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
				break stream
			}
			return nil, types.NewErrorWithStatusCode(
				fmt.Errorf("failed to receive message: %w", recvErr),
				types.ErrorCodeBadResponse,
				http.StatusInternalServerError,
			)
		}

		switch msg.MsgType {
		case MsgTypeError:
			return nil, types.NewErrorWithStatusCode(
				fmt.Errorf("received error from server: code=%d, %s", msg.ErrorCode, string(msg.Payload)),
				types.ErrorCodeBadResponse,
				http.StatusBadRequest,
			)
		case MsgTypeAudioOnlyServer:
			if len(msg.Payload) > 0 {
				if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil {
					return nil, types.NewErrorWithStatusCode(
						fmt.Errorf("failed to write audio data: %w", writeErr),
						types.ErrorCodeBadResponse,
						http.StatusInternalServerError,
					)
				}
				// Push each chunk out right away so the client can start playback.
				c.Writer.Flush()
			}
			// A negative sequence number marks the last audio chunk.
			if msg.Sequence < 0 {
				break stream
			}
		default:
			// MsgTypeFrontEndResultServer and any other message types carry
			// no audio and are skipped.
		}
	}

	c.Status(http.StatusOK)
	promptTokens := info.GetEstimatePromptTokens()
	usage = &dto.Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: 0,
		TotalTokens:      promptTokens,
	}
	return usage, nil
}