mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:25:00 +00:00
390 lines
12 KiB
Go
390 lines
12 KiB
Go
package volcengine
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
channelconstant "github.com/QuantumNous/new-api/constant"
|
|
"github.com/QuantumNous/new-api/dto"
|
|
"github.com/QuantumNous/new-api/relay/channel"
|
|
"github.com/QuantumNous/new-api/relay/channel/openai"
|
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
|
"github.com/QuantumNous/new-api/relay/constant"
|
|
"github.com/QuantumNous/new-api/setting/model_setting"
|
|
"github.com/QuantumNous/new-api/types"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
const (
|
|
contextKeyTTSRequest = "volcengine_tts_request"
|
|
contextKeyResponseFormat = "response_format"
|
|
DoubaoCodingPlan = "doubao-coding-plan"
|
|
DoubaoCodingPlanClaudeBaseURL = "https://ark.cn-beijing.volces.com/api/coding"
|
|
DoubaoCodingPlanOpenAIBaseURL = "https://ark.cn-beijing.volces.com/api/coding/v3"
|
|
)
|
|
|
|
type Adaptor struct {
|
|
}
|
|
|
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
|
//TODO implement me
|
|
return nil, errors.New("not implemented")
|
|
}
|
|
|
|
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
|
adaptor := openai.Adaptor{}
|
|
return adaptor.ConvertClaudeRequest(c, info, req)
|
|
}
|
|
|
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
|
if info.RelayMode != constant.RelayModeAudioSpeech {
|
|
return nil, errors.New("unsupported audio relay mode")
|
|
}
|
|
|
|
appID, token, err := parseVolcengineAuth(info.ApiKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
voiceType := mapVoiceType(request.Voice)
|
|
speedRatio := request.Speed
|
|
encoding := mapEncoding(request.ResponseFormat)
|
|
|
|
c.Set(contextKeyResponseFormat, encoding)
|
|
|
|
volcRequest := VolcengineTTSRequest{
|
|
App: VolcengineTTSApp{
|
|
AppID: appID,
|
|
Token: token,
|
|
Cluster: "volcano_tts",
|
|
},
|
|
User: VolcengineTTSUser{
|
|
UID: "openai_relay_user",
|
|
},
|
|
Audio: VolcengineTTSAudio{
|
|
VoiceType: voiceType,
|
|
Encoding: encoding,
|
|
SpeedRatio: speedRatio,
|
|
Rate: 24000,
|
|
},
|
|
Request: VolcengineTTSReqInfo{
|
|
ReqID: generateRequestID(),
|
|
Text: request.Input,
|
|
Operation: "submit",
|
|
Model: info.OriginModelName,
|
|
},
|
|
}
|
|
|
|
if len(request.Metadata) > 0 {
|
|
if err = json.Unmarshal(request.Metadata, &volcRequest); err != nil {
|
|
return nil, fmt.Errorf("error unmarshalling metadata to volcengine request: %w", err)
|
|
}
|
|
}
|
|
|
|
c.Set(contextKeyTTSRequest, volcRequest)
|
|
|
|
if volcRequest.Request.Operation == "submit" {
|
|
info.IsStream = true
|
|
}
|
|
|
|
jsonData, err := json.Marshal(volcRequest)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error marshalling volcengine request: %w", err)
|
|
}
|
|
|
|
return bytes.NewReader(jsonData), nil
|
|
}
|
|
|
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
|
switch info.RelayMode {
|
|
case constant.RelayModeImagesGenerations:
|
|
return request, nil
|
|
// 根据官方文档,并没有发现豆包生图支持表单请求:https://www.volcengine.com/docs/82379/1824121
|
|
//case constant.RelayModeImagesEdits:
|
|
//
|
|
// var requestBody bytes.Buffer
|
|
// writer := multipart.NewWriter(&requestBody)
|
|
//
|
|
// writer.WriteField("model", request.Model)
|
|
//
|
|
// formData := c.Request.PostForm
|
|
// for key, values := range formData {
|
|
// if key == "model" {
|
|
// continue
|
|
// }
|
|
// for _, value := range values {
|
|
// writer.WriteField(key, value)
|
|
// }
|
|
// }
|
|
//
|
|
// if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
|
|
// return nil, errors.New("failed to parse multipart form")
|
|
// }
|
|
//
|
|
// if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
|
|
// var imageFiles []*multipart.FileHeader
|
|
// var exists bool
|
|
//
|
|
// if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
|
|
// if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
|
|
// foundArrayImages := false
|
|
// for fieldName, files := range c.Request.MultipartForm.File {
|
|
// if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
|
|
// foundArrayImages = true
|
|
// for _, file := range files {
|
|
// imageFiles = append(imageFiles, file)
|
|
// }
|
|
// }
|
|
// }
|
|
//
|
|
// if !foundArrayImages && (len(imageFiles) == 0) {
|
|
// return nil, errors.New("image is required")
|
|
// }
|
|
// }
|
|
// }
|
|
//
|
|
// for i, fileHeader := range imageFiles {
|
|
// file, err := fileHeader.Open()
|
|
// if err != nil {
|
|
// return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
|
|
// }
|
|
// defer file.Close()
|
|
//
|
|
// fieldName := "image"
|
|
// if len(imageFiles) > 1 {
|
|
// fieldName = "image[]"
|
|
// }
|
|
//
|
|
// mimeType := detectImageMimeType(fileHeader.Filename)
|
|
//
|
|
// h := make(textproto.MIMEHeader)
|
|
// h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
|
|
// h.Set("Content-Type", mimeType)
|
|
//
|
|
// part, err := writer.CreatePart(h)
|
|
// if err != nil {
|
|
// return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
|
|
// }
|
|
//
|
|
// if _, err := io.Copy(part, file); err != nil {
|
|
// return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
|
|
// }
|
|
// }
|
|
//
|
|
// if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
|
|
// maskFile, err := maskFiles[0].Open()
|
|
// if err != nil {
|
|
// return nil, errors.New("failed to open mask file")
|
|
// }
|
|
// defer maskFile.Close()
|
|
//
|
|
// mimeType := detectImageMimeType(maskFiles[0].Filename)
|
|
//
|
|
// h := make(textproto.MIMEHeader)
|
|
// h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
|
|
// h.Set("Content-Type", mimeType)
|
|
//
|
|
// maskPart, err := writer.CreatePart(h)
|
|
// if err != nil {
|
|
// return nil, errors.New("create form file failed for mask")
|
|
// }
|
|
//
|
|
// if _, err := io.Copy(maskPart, maskFile); err != nil {
|
|
// return nil, errors.New("copy mask file failed")
|
|
// }
|
|
// }
|
|
// } else {
|
|
// return nil, errors.New("no multipart form data found")
|
|
// }
|
|
//
|
|
// writer.Close()
|
|
// c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
|
// return bytes.NewReader(requestBody.Bytes()), nil
|
|
|
|
default:
|
|
return request, nil
|
|
}
|
|
}
|
|
|
|
func detectImageMimeType(filename string) string {
|
|
ext := strings.ToLower(filepath.Ext(filename))
|
|
switch ext {
|
|
case ".jpg", ".jpeg":
|
|
return "image/jpeg"
|
|
case ".png":
|
|
return "image/png"
|
|
case ".webp":
|
|
return "image/webp"
|
|
default:
|
|
if strings.HasPrefix(ext, ".jp") {
|
|
return "image/jpeg"
|
|
}
|
|
return "image/png"
|
|
}
|
|
}
|
|
|
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|
}
|
|
|
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|
baseUrl := info.ChannelBaseUrl
|
|
if baseUrl == "" {
|
|
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
|
|
}
|
|
|
|
switch info.RelayFormat {
|
|
case types.RelayFormatClaude:
|
|
if baseUrl == DoubaoCodingPlan {
|
|
return fmt.Sprintf("%s/v1/messages", DoubaoCodingPlanClaudeBaseURL), nil
|
|
}
|
|
if strings.HasPrefix(info.UpstreamModelName, "bot") {
|
|
return fmt.Sprintf("%s/api/v3/bots/chat/completions", baseUrl), nil
|
|
}
|
|
return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
|
|
default:
|
|
switch info.RelayMode {
|
|
case constant.RelayModeChatCompletions:
|
|
if baseUrl == DoubaoCodingPlan {
|
|
return fmt.Sprintf("%s/chat/completions", DoubaoCodingPlanOpenAIBaseURL), nil
|
|
}
|
|
if strings.HasPrefix(info.UpstreamModelName, "bot") {
|
|
return fmt.Sprintf("%s/api/v3/bots/chat/completions", baseUrl), nil
|
|
}
|
|
return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
|
|
case constant.RelayModeEmbeddings:
|
|
return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil
|
|
//豆包的图生图也走generations接口: https://www.volcengine.com/docs/82379/1824121
|
|
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
|
return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil
|
|
//case constant.RelayModeImagesEdits:
|
|
// return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
|
|
case constant.RelayModeRerank:
|
|
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
|
|
case constant.RelayModeAudioSpeech:
|
|
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
|
|
return "wss://openspeech.bytedance.com/api/v1/tts/ws_binary", nil
|
|
}
|
|
return fmt.Sprintf("%s/v1/audio/speech", baseUrl), nil
|
|
default:
|
|
}
|
|
}
|
|
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
|
}
|
|
|
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
|
channel.SetupApiRequestHeader(info, c, req)
|
|
|
|
if info.RelayMode == constant.RelayModeAudioSpeech {
|
|
parts := strings.Split(info.ApiKey, "|")
|
|
if len(parts) == 2 {
|
|
req.Set("Authorization", "Bearer;"+parts[1])
|
|
}
|
|
req.Set("Content-Type", "application/json")
|
|
return nil
|
|
} else if info.RelayMode == constant.RelayModeImagesEdits {
|
|
req.Set("Content-Type", gin.MIMEJSON)
|
|
}
|
|
|
|
req.Set("Authorization", "Bearer "+info.ApiKey)
|
|
return nil
|
|
}
|
|
|
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
|
if request == nil {
|
|
return nil, errors.New("request is nil")
|
|
}
|
|
|
|
if !model_setting.ShouldPreserveThinkingSuffix(info.OriginModelName) &&
|
|
strings.HasSuffix(info.UpstreamModelName, "-thinking") &&
|
|
strings.HasPrefix(info.UpstreamModelName, "deepseek") {
|
|
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
|
request.Model = info.UpstreamModelName
|
|
request.THINKING = json.RawMessage(`{"type": "enabled"}`)
|
|
}
|
|
return request, nil
|
|
}
|
|
|
|
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
|
return request, nil
|
|
}
|
|
|
|
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
|
return nil, errors.New("not implemented")
|
|
}
|
|
|
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
|
if info.RelayMode == constant.RelayModeAudioSpeech {
|
|
baseUrl := info.ChannelBaseUrl
|
|
if baseUrl == "" {
|
|
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
|
|
}
|
|
|
|
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
|
|
if info.IsStream {
|
|
return nil, nil
|
|
}
|
|
}
|
|
}
|
|
return channel.DoApiRequest(a, c, info, requestBody)
|
|
}
|
|
|
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
|
if info.RelayMode == constant.RelayModeAudioSpeech {
|
|
encoding := mapEncoding(c.GetString(contextKeyResponseFormat))
|
|
if info.IsStream {
|
|
volcRequestInterface, exists := c.Get(contextKeyTTSRequest)
|
|
if !exists {
|
|
return nil, types.NewErrorWithStatusCode(
|
|
errors.New("volcengine TTS request not found in context"),
|
|
types.ErrorCodeBadRequestBody,
|
|
http.StatusInternalServerError,
|
|
)
|
|
}
|
|
|
|
volcRequest, ok := volcRequestInterface.(VolcengineTTSRequest)
|
|
if !ok {
|
|
return nil, types.NewErrorWithStatusCode(
|
|
errors.New("invalid volcengine TTS request type"),
|
|
types.ErrorCodeBadRequestBody,
|
|
http.StatusInternalServerError,
|
|
)
|
|
}
|
|
|
|
// Get the WebSocket URL
|
|
requestURL, urlErr := a.GetRequestURL(info)
|
|
if urlErr != nil {
|
|
return nil, types.NewErrorWithStatusCode(
|
|
urlErr,
|
|
types.ErrorCodeBadRequestBody,
|
|
http.StatusInternalServerError,
|
|
)
|
|
}
|
|
return handleTTSWebSocketResponse(c, requestURL, volcRequest, info, encoding)
|
|
}
|
|
return handleTTSResponse(c, resp, info, encoding)
|
|
}
|
|
|
|
adaptor := openai.Adaptor{}
|
|
usage, err = adaptor.DoResponse(c, resp, info)
|
|
return
|
|
}
|
|
|
|
func (a *Adaptor) GetModelList() []string {
|
|
return ModelList
|
|
}
|
|
|
|
func (a *Adaptor) GetChannelName() string {
|
|
return ChannelName
|
|
}
|