Compare commits

..

39 Commits

Author SHA1 Message Date
creamlike1024
ce1fde8500 chore: Ignore .zed and debug binaries in .gitignore 2025-10-21 16:40:22 +08:00
Seefs
4661399639 Merge pull request #2070 from QuantumNous/ali-channel-support-stream-options
Ali channel support stream options
2025-10-20 23:24:33 +08:00
IcedTangerine
78d8d458ca Merge pull request #2081 from feitianbubu/pr/add-miniMax-tts
增加MiniMax语音合成TTS支持
2025-10-20 17:48:35 +08:00
IcedTangerine
e20a287c4b chore: Comment out debug log in adaptor.go
Comment out the debug log for MiniMax TTS Request.
2025-10-20 17:48:08 +08:00
feitianbubu
c7ab0f4f3d feat: opt minimax tts req struct 2025-10-20 16:26:50 +08:00
feitianbubu
0d1057830b feat: add minimax api adaptor 2025-10-20 16:26:50 +08:00
feitianbubu
dd1cac3f2e feat: add minimax tts 2025-10-20 16:26:50 +08:00
creamlike1024
cdbc7a9510 refactor: remove unused functions and imports from ali text handler 2025-10-18 17:00:28 +08:00
creamlike1024
c693bfee5e feat: add support for Ali channel in streamSupportedChannels 2025-10-18 17:00:08 +08:00
IcedTangerine
7156bf2382 Merge pull request #2068 from feitianbubu/pr/doubao-speech-emotion
豆包语音2.0音色支持情感,情绪,音量
2025-10-18 14:30:17 +08:00
Seefs
c216527f23 Merge pull request #2065 from somnifex/main
fix: handle JSON parsing for thinking content in ollama stream
2025-10-18 13:02:56 +08:00
Seefs
b1de0f49df Merge pull request #2061 from QuantumNous/fix-gemini-batch-embedding-token-count
fix: gemini batch embedding token not counted
2025-10-18 12:54:44 +08:00
feitianbubu
525ca09f2c fix: doubao audio speedRadio to speed 2025-10-18 01:48:36 +08:00
feitianbubu
92fc973bc3 feat: AudioRequest add metadata support custom params 2025-10-18 01:48:36 +08:00
feitianbubu
22ff8e2cbe feat: sync latest openai speech struct
https://platform.openai.com/docs/api-reference/audio/createSpeech
2025-10-18 01:48:36 +08:00
IcedTangerine
1ec664a348 Merge pull request #2067 from feitianbubu/pr/add-doubao-audio
新增支持豆包语音合成2.0功能
2025-10-18 00:14:11 +08:00
IcedTangerine
6a24c37c0e Fix error message for invalid API key format 2025-10-18 00:13:28 +08:00
feitianbubu
8965fc49c9 feat: add doubao audio token input prompt 2025-10-17 22:06:46 +08:00
feitianbubu
735386c0b9 feat: add doubao tts usage token 2025-10-17 22:06:45 +08:00
feitianbubu
58c4da0ddf feat: switch to official TTS only when baseUrl is Volcano's official URL 2025-10-17 22:06:45 +08:00
feitianbubu
fe68488b1c feat: add doubao audio tts 2025-10-17 22:06:45 +08:00
somnifex
25af6e6f77 fix: handle JSON parsing for thinking content in ollama stream and chat handlers 2025-10-17 18:35:08 +08:00
creamlike1024
e2d3b46a3a fix: gemini batch embedding token not counted 2025-10-17 15:51:04 +08:00
CaIon
dd775167ab feat: support OpenAI channel type in sora relay adaptor 2025-10-17 13:53:15 +08:00
CaIon
43f2a8ac06 feat: add temporary TASK_PRICE_PATCH configuration to environment variables 2025-10-16 21:59:21 +08:00
CaIon
bcf93a2c05 fix: prevent refund on video task update error 2025-10-16 12:46:07 +08:00
CaIon
09ff878d88 fix: prevent duplicate refunds on task failure #2050 2025-10-16 12:38:21 +08:00
CaIon
d4749ba388 refactor: rename AWS model ID and region prefix functions for clarity 2025-10-16 12:10:55 +08:00
CaIon
1f2bdb1402 fix: gemini embedding 2025-10-15 21:48:36 +08:00
CaIon
64a97092c9 CI: ignore pre-release and alpha tags in electron build workflow 2025-10-15 19:56:07 +08:00
CaIon
69b87b5d8e refactor: replace iota with explicit values for log type constants 2025-10-15 19:54:13 +08:00
CaIon
bd4160793e fix 2025-10-15 19:46:06 +08:00
CaIon
82e21972ec feat: 修复aws渠道-thinking后缀不生效的问题 2025-10-15 18:49:27 +08:00
CaIon
dce00141ce feat: 临时兼容aws使用链接媒体 2025-10-15 18:21:19 +08:00
CaIon
b2a057723a refactor: update AWS key format in EditChannelModal for consistency 2025-10-15 17:38:21 +08:00
CaIon
f023efdbfc feat: support aws bedrock api-keys-use 2025-10-15 17:29:10 +08:00
CaIon
8b65623726 refactor: aws 2025-10-15 16:44:33 +08:00
CaIon
aa35d8db69 refactor: update ConvertToOpenAIVideo method to return byte array and improve error handling 2025-10-14 23:03:17 +08:00
CaIon
64ed7dce4d docs: update README for project cloning and Docker Compose instructions 2025-10-14 17:51:33 +08:00
42 changed files with 1048 additions and 405 deletions

View File

@@ -4,6 +4,8 @@ on:
push:
tags:
- '*' # Triggers on version tags like v1.0.0
- '!*-*' # Ignore pre-release tags like v1.0.0-beta
- '!*-alpha*' # Ignore alpha tags like v1.0.0-alpha
workflow_dispatch: # Allows manual triggering
jobs:

4
.gitignore vendored
View File

@@ -1,5 +1,6 @@
.idea
.vscode
.zed
upload
*.exe
*.db
@@ -10,10 +11,11 @@ web/dist
.env
one-api
new-api
/__debug_bin*
.DS_Store
tiktoken_cache
.eslintcache
.gocache
electron/node_modules
electron/dist
electron/dist

View File

@@ -165,12 +165,18 @@ New API提供了丰富的功能详细特性请参考[特性说明](https://do
#### 使用Docker Compose部署推荐
```shell
# 下载项目
git clone https://github.com/Calcium-Ion/new-api.git
# 下载项目源码
git clone https://github.com/QuantumNous/new-api.git
# 进入项目目录
cd new-api
# 按需编辑docker-compose.yml
# 启动
docker-compose up -d
# 根据需要编辑 docker-compose.yml 文件
# 使用nano编辑器
nano docker-compose.yml
# 或使用vim编辑器
# vim docker-compose.yml
```
#### 直接使用Docker镜像

View File

@@ -69,6 +69,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = constant.APITypeMoonshot
case constant.ChannelTypeSubmodel:
apiType = constant.APITypeSubmodel
case constant.ChannelTypeMiniMax:
apiType = constant.APITypeMiniMax
}
if apiType == -1 {
return constant.APITypeOpenAI, false

View File

@@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/constant"
@@ -118,4 +119,17 @@ func initConstantEnv() {
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
// 是否启用错误日志
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
if soraPatchStr != "" {
var taskPricePatches []string
soraPatches := strings.Split(soraPatchStr, ",")
for _, patch := range soraPatches {
trimmedPatch := strings.TrimSpace(patch)
if trimmedPatch != "" {
taskPricePatches = append(taskPricePatches, trimmedPatch)
}
}
constant.TaskPricePatches = taskPricePatches
}
}

View File

@@ -3,6 +3,7 @@ package common
import (
"bytes"
"encoding/json"
"io"
)
func Unmarshal(data []byte, v any) error {
@@ -13,7 +14,7 @@ func UnmarshalJsonStr(data string, v any) error {
return json.Unmarshal(StringToByteSlice(data), v)
}
func DecodeJson(reader *bytes.Reader, v any) error {
func DecodeJson(reader io.Reader, v any) error {
return json.NewDecoder(reader).Decode(v)
}

View File

@@ -33,5 +33,6 @@ const (
APITypeJimeng
APITypeMoonshot
APITypeSubmodel
APITypeMiniMax
APITypeDummy // this one is only for count, do not add any channel after this
)

View File

@@ -13,3 +13,6 @@ var NotifyLimitCount int
var NotificationLimitDurationMinute int
var GenerateDefaultToken bool
var ErrorLogEnabled bool
// temporary variable for sora patch, will be removed in future
var TaskPricePatches []string

View File

@@ -88,10 +88,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
}
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody)))
taskResult := &relaycommon.TaskInfo{}
// try parse as New API response format
var responseItems dto.TaskResponse[model.Task]
if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems))
t := responseItems.Data
taskResult.TaskID = t.TaskID
taskResult.Status = string(t.Status)
@@ -105,10 +108,19 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
task.Data = redactVideoResponseBody(responseBody)
}
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult))
now := time.Now().Unix()
if taskResult.Status == "" {
return fmt.Errorf("task %s status is empty", taskId)
//return fmt.Errorf("task %s status is empty", taskId)
taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
}
// 记录原本的状态,防止重复退款
shouldRefund := false
quota := task.Quota
preStatus := task.Status
task.Status = model.TaskStatus(taskResult.Status)
switch taskResult.Status {
case model.TaskStatusSubmitted:
@@ -219,7 +231,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
}
}
case model.TaskStatusFailure:
preStatus := task.Status
logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
task.Status = model.TaskStatusFailure
task.Progress = "100%"
if task.FinishTime == 0 {
@@ -227,16 +239,10 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
}
task.FailReason = taskResult.Reason
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
quota := task.Quota
taskResult.Progress = "100%"
if quota != 0 {
if preStatus != model.TaskStatusFailure {
// 任务失败且之前状态不是失败才退还额度,防止重复退还
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
shouldRefund = true
} else {
logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
}
@@ -249,6 +255,16 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
}
if err := task.Update(); err != nil {
common.SysLog("UpdateVideoTask task error: " + err.Error())
shouldRefund = false
}
if shouldRefund {
// 任务失败且之前状态不是失败才退还额度,防止重复退还
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
return nil

View File

@@ -1,17 +1,22 @@
package dto
import (
"encoding/json"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
type AudioRequest struct {
Model string `json:"model"`
Input string `json:"input"`
Voice string `json:"voice"`
Speed float64 `json:"speed,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Model string `json:"model"`
Input string `json:"input"`
Voice string `json:"voice"`
Instructions string `json:"instructions,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Speed float64 `json:"speed,omitempty"`
StreamFormat string `json:"stream_format,omitempty"`
Metadata json.RawMessage `json:"metadata,omitempty"`
}
func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta {

View File

@@ -16,6 +16,13 @@ const (
VertexKeyTypeAPIKey VertexKeyType = "api_key"
)
// AwsKeyType names the credential format configured for an AWS channel.
type AwsKeyType string
// Supported AwsKeyType values.
const (
AwsKeyTypeAKSK AwsKeyType = "ak_sk" // default
AwsKeyTypeApiKey AwsKeyType = "api_key"
)
type ChannelOtherSettings struct {
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
@@ -23,6 +30,7 @@ type ChannelOtherSettings struct {
AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费)
DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用)
AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私)
AwsKeyType AwsKeyType `json:"aws_key_type,omitempty"`
}
func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool {

View File

@@ -148,6 +148,10 @@ func (c *ClaudeMessage) SetStringContent(content string) {
c.Content = content
}
// SetContent replaces the message content with the given value as-is,
// without any conversion or validation.
func (c *ClaudeMessage) SetContent(content any) {
c.Content = content
}
func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
return common.Any2Type[[]ClaudeMediaMessage](c.Content)
}

View File

@@ -12,6 +12,7 @@ import (
)
type GeminiChatRequest struct {
Requests []GeminiChatRequest `json:"requests,omitempty"` // For batch requests
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`

View File

@@ -27,7 +27,7 @@ type OpenAIVideo struct {
Size string `json:"size,omitempty"`
RemixedFromVideoID string `json:"remixed_from_video_id,omitempty"`
Error *OpenAIVideoError `json:"error,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
Metadata map[string]any `json:"meta_data,omitempty"`
}
func (m *OpenAIVideo) SetProgressStr(progress string) {

View File

@@ -153,5 +153,5 @@ func LogJson(ctx context.Context, msg string, obj any) {
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
return
}
LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
LogDebug(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
}

View File

@@ -39,13 +39,15 @@ type Log struct {
Other string `json:"other"`
}
// Log type identifiers.
//
// The values are written out explicitly instead of using iota so that adding
// or reordering entries can never change the value of an existing type.
const (
	LogTypeUnknown = 0
	LogTypeTopup   = 1
	LogTypeConsume = 2
	LogTypeManage  = 3
	LogTypeSystem  = 4
	LogTypeError   = 5
	LogTypeRefund  = 6
)
func formatUserLogs(logs []*Log) {

View File

@@ -53,5 +53,5 @@ type TaskAdaptor interface {
}
type OpenAIVideoConverter interface {
ConvertToOpenAIVideo(originTask *model.Task) (*dto.OpenAIVideo, error)
ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error)
}

View File

@@ -1,20 +1,7 @@
package ali
import (
"bufio"
"encoding/json"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
@@ -29,180 +16,3 @@ func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReque
}
return &request
}
func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
return &AliEmbeddingRequest{
Model: request.Model,
Input: struct {
Texts []string `json:"texts"`
}{
Texts: request.ParseInput(),
},
}
}
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var fullTextResponse dto.FlexibleEmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
service.CloseResponseBodyGracefully(resp)
model := c.GetString("model")
if model == "" {
model = "text-embedding-v4"
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse, model string) *dto.OpenAIEmbeddingResponse {
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
Object: "list",
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: model,
Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
}
for _, item := range response.Output.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: response.Output.Text,
},
FinishReason: response.Output.FinishReason,
}
fullTextResponse := dto.OpenAITextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []dto.OpenAITextResponseChoice{choice},
Usage: dto.Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
},
}
return &fullTextResponse
}
func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(aliResponse.Output.Text)
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason
}
response := dto.ChatCompletionsStreamResponse{
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "ernie-bot",
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var usage dto.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
helper.SetEventStreamHeaders(c)
lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse AliResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
common.SysLog("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
usage.PromptTokens = aliResponse.Usage.InputTokens
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.SetContentString(strings.TrimPrefix(response.Choices[0].Delta.GetContentString(), lastResponseText))
lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysLog("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
service.CloseResponseBodyGracefully(resp)
return nil, &usage
}
func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var aliResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return types.WithOpenAIError(types.OpenAIError{
Message: aliResponse.Message,
Type: "ali_error",
Param: aliResponse.RequestId,
Code: aliResponse.Code,
}, resp.StatusCode), nil
}
fullTextResponse := responseAli2OpenAI(&aliResponse)
jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -1,25 +1,36 @@
package aws
import (
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/claude"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/pkg/errors"
"github.com/gin-gonic/gin"
)
type ClientMode int
const (
RequestModeCompletion = 1
RequestModeMessage = 2
ClientModeApiKey ClientMode = iota + 1
ClientModeAKSK
)
type Adaptor struct {
RequestMode int
ClientMode ClientMode
AwsClient *bedrockruntime.Client
AwsModelId string
AwsReq any
IsNova bool
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
@@ -28,8 +39,37 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
}
// ConvertClaudeRequest prepares a native Claude request for this adaptor.
//
// For every message whose content is not a plain string, each media block with
// a source of type "url" is fetched and rewritten as inline base64 data: the
// media type and data are filled in, the URL is cleared and the source type is
// set to "base64". Messages without such blocks are left untouched.
//
// Returns the same request pointer that was passed in. An error is returned
// when a message's content cannot be parsed or a referenced file cannot be
// fetched; the fetch error is wrapped so callers can still inspect it.
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
	c.Set("request_model", request.Model)
	c.Set("converted_request", request)

	for i := range request.Messages {
		message := request.Messages[i]
		if message.IsStringContent() {
			continue
		}
		content, err := message.ParseContent()
		if err != nil {
			return nil, errors.Wrap(err, "failed to parse message content")
		}

		updated := false
		for j := range content {
			// Source is a pointer, so the rewrite below is visible through content[j].
			source := content[j].Source
			if source == nil || source.Type != "url" {
				continue
			}
			fileData, err := service.GetFileBase64FromUrl(c, source.Url, "formatting image for Claude")
			if err != nil {
				return nil, fmt.Errorf("get file base64 from url failed: %w", err)
			}
			source.MediaType = fileData.MimeType
			source.Data = fileData.Base64Data
			source.Url = ""
			source.Type = "base64"
			updated = true
		}

		// Only write the message back when at least one block was rewritten.
		if updated {
			message.SetContent(content)
			request.Messages[i] = message
		}
	}
	return request, nil
}
@@ -44,15 +84,28 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
a.RequestMode = RequestModeMessage
}
// GetRequestURL selects the client mode for this request and returns the HTTP
// endpoint to call.
//
// When the channel's AwsKeyType is "api_key", the adaptor switches to
// ClientModeApiKey and returns the Bedrock runtime URL for the resolved model.
// The channel key must have the form "<api-key>|<region>"; anything else is
// rejected with an error. For every other key type the adaptor switches to
// ClientModeAKSK and returns an empty URL, because that mode does not go
// through a plain HTTP request.
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
	if info.ChannelOtherSettings.AwsKeyType != dto.AwsKeyTypeApiKey {
		a.ClientMode = ClientModeAKSK
		return "", nil
	}

	a.ClientMode = ClientModeApiKey
	awsSecret := strings.Split(info.ApiKey, "|")
	if len(awsSecret) != 2 {
		return "", errors.New("invalid aws api key, should be in format of <api-key>|<region>")
	}
	region := awsSecret[1]
	awsModelId := getAwsModelID(info.UpstreamModelName)

	// The host carries the region and the path carries the model ID; passing
	// them the other way round produces an unresolvable host name.
	// NOTE(review): the path targets the Converse operation while the response
	// is parsed by the Claude adaptor — confirm this is the intended operation.
	return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/converse", region, awsModelId), nil
}
// SetupRequestHeader applies the common Claude headers and, in API-key mode,
// adds the bearer Authorization header. It always returns nil.
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
	claude.CommonClaudeHeadersOperation(c, req, info)
	if a.ClientMode == ClientModeApiKey {
		// The channel key is stored as "<api-key>|<region>"; only the part
		// before the separator is the credential, so the region suffix must
		// not be sent as part of the bearer token.
		apiKey, _, _ := strings.Cut(info.ApiKey, "|")
		req.Set("Authorization", "Bearer "+apiKey)
	}
	return nil
}
@@ -63,22 +116,16 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
// 检查是否为Nova模型
if isNovaModel(request.Model) {
novaReq := convertToNovaRequest(request)
c.Set("request_model", request.Model)
c.Set("converted_request", novaReq)
c.Set("is_nova_model", true)
a.IsNova = true
return novaReq, nil
}
// 原有的Claude模型处理逻辑
var claudeReq *dto.ClaudeRequest
var err error
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
if err != nil {
return nil, err
return nil, errors.Wrap(err, "failed to convert openai request to claude request")
}
c.Set("request_model", claudeReq.Model)
c.Set("converted_request", claudeReq)
c.Set("is_nova_model", false)
info.UpstreamModelName = claudeReq.Model
return claudeReq, err
}
@@ -97,14 +144,27 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
}
// DoRequest dispatches the upstream call according to the client mode chosen
// earlier: API-key mode goes through the generic HTTP request helper, every
// other mode goes through the AWS SDK client path.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
	if a.ClientMode == ClientModeApiKey {
		return channel.DoApiRequest(a, c, info, requestBody)
	}
	return doAwsClientRequest(c, info, a, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
if a.ClientMode == ClientModeApiKey {
claudeAdaptor := claude.Adaptor{}
usage, err = claudeAdaptor.DoResponse(c, resp, info)
} else {
err, usage = awsHandler(c, info, a.RequestMode)
if a.IsNova {
err, usage = handleNovaRequest(c, info, a)
} else {
if info.IsStream {
err, usage = awsStreamHandler(c, info, a)
} else {
err, usage = awsHandler(c, info, a)
}
}
}
return
}

View File

@@ -124,5 +124,5 @@ var ChannelName = "aws"
// isNovaModel reports whether modelId refers to a Nova model.
//
// The check is a substring match rather than a prefix match so that IDs which
// carry something in front of the "nova-" part (for example "amazon.nova-…")
// are recognised as well.
func isNovaModel(modelId string) bool {
	return strings.Contains(modelId, "nova-")
}

View File

@@ -1,6 +1,9 @@
package aws
import (
"io"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
)
@@ -35,6 +38,16 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
}
}
// formatRequest decodes the JSON in requestBody into an AwsClaudeRequest and
// stamps it with the fixed Bedrock Anthropic version string.
// A decoding failure is returned unchanged together with a nil request.
func formatRequest(requestBody io.Reader) (*AwsClaudeRequest, error) {
	decoded := new(AwsClaudeRequest)
	if decodeErr := common.DecodeJson(requestBody, decoded); decodeErr != nil {
		return nil, decodeErr
	}
	decoded.AnthropicVersion = "bedrock-2023-05-31"
	return decoded, nil
}
// NovaMessage Nova模型使用messages-v1格式
type NovaMessage struct {
Role string `json:"role"`

View File

@@ -3,6 +3,7 @@ package aws
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
@@ -49,16 +50,78 @@ func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.
return client, nil
}
func wrapErr(err error) *dto.OpenAIErrorWithStatusCode {
return &dto.OpenAIErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: dto.OpenAIError{
Message: fmt.Sprintf("%s", err.Error()),
},
// doAwsClientRequest prepares an AWS SDK invocation for the request.
//
// It creates the Bedrock client, resolves the AWS model ID (switching to the
// cross-region form when the model supports it for the client's region),
// decodes requestBody and stores both the client and the ready-to-send SDK
// input on the adaptor (a.AwsClient, a.AwsReq). Nothing is sent here: the
// response handlers perform the actual invocation using those stored values.
//
// The stored input is an *InvokeModelInput for Nova models and for
// non-streaming requests, and an *InvokeModelWithResponseStreamInput for
// streaming requests. The first return value is always nil.
func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor, requestBody io.Reader) (any, error) {
	awsCli, err := newAwsClient(c, info)
	if err != nil {
		return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError)
	}
	a.AwsClient = awsCli

	// Resolve the AWS model ID, preferring the cross-region form when available.
	awsModelId := getAwsModelID(info.UpstreamModelName)
	awsRegionPrefix := getAwsRegionPrefix(awsCli.Options().Region)
	if awsModelCanCrossRegion(awsModelId, awsRegionPrefix) {
		awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
	}

	if isNovaModel(awsModelId) {
		var novaReq *NovaRequest
		if err = common.DecodeJson(requestBody, &novaReq); err != nil {
			return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody)
		}
		reqBody, err := common.Marshal(novaReq)
		if err != nil {
			return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
		}
		// Nova goes through the InvokeModel API with a Nova-format body. The
		// input must be stored on the adaptor, otherwise the response handler
		// has nothing to send.
		a.AwsReq = &bedrockruntime.InvokeModelInput{
			ModelId:     aws.String(awsModelId),
			Accept:      aws.String("application/json"),
			ContentType: aws.String("application/json"),
			Body:        reqBody,
		}
		return nil, nil
	}

	awsClaudeReq, err := formatRequest(requestBody)
	if err != nil {
		return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody)
	}
	reqBody, err := common.Marshal(awsClaudeReq)
	if err != nil {
		return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
	}

	// Streaming and non-streaming calls use different SDK input types; the
	// response handlers type-assert a.AwsReq accordingly.
	if info.IsStream {
		a.AwsReq = &bedrockruntime.InvokeModelWithResponseStreamInput{
			ModelId:     aws.String(awsModelId),
			Accept:      aws.String("application/json"),
			ContentType: aws.String("application/json"),
			Body:        reqBody,
		}
	} else {
		a.AwsReq = &bedrockruntime.InvokeModelInput{
			ModelId:     aws.String(awsModelId),
			Accept:      aws.String("application/json"),
			ContentType: aws.String("application/json"),
			Body:        reqBody,
		}
	}
	return nil, nil
}
func awsRegionPrefix(awsRegionId string) string {
func getAwsRegionPrefix(awsRegionId string) string {
parts := strings.Split(awsRegionId, "-")
regionPrefix := ""
if len(parts) > 0 {
@@ -80,58 +143,16 @@ func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
return modelPrefix + "." + awsModelId
}
func awsModelID(requestModel string) string {
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
return awsModelID
func getAwsModelID(requestModel string) string {
if awsModelIDName, ok := awsModelIDMap[requestModel]; ok {
return awsModelIDName
}
return requestModel
}
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
}
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
awsModelId := awsModelID(c.GetString("request_model"))
// 检查是否为Nova模型
isNova, _ := c.Get("is_nova_model")
if isNova == true {
// Nova模型也支持跨区域
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
return handleNovaRequest(c, awsCli, info, awsModelId)
}
// 原有的Claude处理逻辑
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
claudeReq_, ok := c.Get("converted_request")
if !ok {
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
}
claudeReq := claudeReq_.(*dto.ClaudeRequest)
awsClaudeReq := copyRequest(claudeReq)
awsReq.Body, err = common.Marshal(awsClaudeReq)
if err != nil {
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
}
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
if err != nil {
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
}
@@ -149,46 +170,15 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
c.Writer.Header().Set("Content-Type", *awsResp.ContentType)
}
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, RequestModeMessage)
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, claude.RequestModeMessage)
if handlerErr != nil {
return handlerErr, nil
}
return nil, claudeInfo.Usage
}
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
}
awsModelId := awsModelID(c.GetString("request_model"))
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
claudeReq_, ok := c.Get("converted_request")
if !ok {
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
}
claudeReq := claudeReq_.(*dto.ClaudeRequest)
awsClaudeReq := copyRequest(claudeReq)
awsReq.Body, err = common.Marshal(awsClaudeReq)
if err != nil {
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
}
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
awsResp, err := a.AwsClient.InvokeModelWithResponseStream(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
if err != nil {
return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
}
@@ -207,7 +197,7 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
switch v := event.(type) {
case *bedrockruntimeTypes.ResponseStreamMemberChunk:
info.SetFirstResponseTime()
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), claude.RequestModeMessage)
if respErr != nil {
return respErr, nil
}
@@ -220,32 +210,14 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
}
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
claude.HandleStreamFinalResponse(c, info, claudeInfo, claude.RequestModeMessage)
return nil, claudeInfo.Usage
}
// Nova模型处理函数
func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
novaReq_, ok := c.Get("converted_request")
if !ok {
return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
}
novaReq := novaReq_.(*NovaRequest)
func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
// 使用InvokeModel API但使用Nova格式的请求体
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
reqBody, err := json.Marshal(novaReq)
if err != nil {
return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
}
awsReq.Body = reqBody
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
if err != nil {
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
}

View File

@@ -0,0 +1,132 @@
package minimax
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/openai"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
// Adaptor is the MiniMax channel adaptor. The struct has no fields; every
// method receives the per-request state through its arguments.
type Adaptor struct {
}
// ConvertGeminiRequest is not supported by the MiniMax adaptor; it always
// returns a "not implemented" error.
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
return nil, errors.New("not implemented")
}
// ConvertClaudeRequest is not supported by the MiniMax adaptor; it always
// returns a "not implemented" error.
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
return nil, errors.New("not implemented")
}
// ConvertAudioRequest builds the JSON body for a MiniMax text-to-speech call
// from an OpenAI-style audio request.
//
// Only constant.RelayModeAudioSpeech is accepted; any other relay mode yields
// an error. The voice, speed and response format of the incoming request seed
// the MiniMax request, and any JSON in request.Metadata is then unmarshalled
// on top of it. It returns a reader over the marshalled request, or an error
// if the metadata cannot be unmarshalled or the request cannot be marshalled.
//
// Side effect: stores "hex" or "url" in the gin context under the key
// "response_format".
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
if info.RelayMode != constant.RelayModeAudioSpeech {
return nil, errors.New("unsupported audio relay mode")
}
voiceID := request.Voice
speed := request.Speed
outputFormat := request.ResponseFormat
// The same response-format string is used for both audio_setting.format and
// output_format.
minimaxRequest := MiniMaxTTSRequest{
Model: info.OriginModelName,
Text: request.Input,
VoiceSetting: VoiceSetting{
VoiceID: voiceID,
Speed: speed,
},
AudioSetting: &AudioSetting{
Format: outputFormat,
},
OutputFormat: outputFormat,
}
// Merge vendor-specific custom fields supplied through the request metadata;
// keys present in the metadata overwrite the values set above.
if len(request.Metadata) > 0 {
if err := json.Unmarshal(request.Metadata, &minimaxRequest); err != nil {
return nil, fmt.Errorf("error unmarshalling metadata to minimax request: %w", err)
}
}
jsonData, err := json.Marshal(minimaxRequest)
if err != nil {
return nil, fmt.Errorf("error marshalling minimax request: %w", err)
}
// NOTE(review): this normalisation runs after the request was marshalled, so
// it only affects the context value below, not the output_format sent
// upstream, and it ignores any output_format set via metadata. Confirm this
// is intended.
if outputFormat != "hex" {
outputFormat = "url"
}
c.Set("response_format", outputFormat)
// Debug: log the request structure
// fmt.Printf("MiniMax TTS Request: %s\n", string(jsonData))
return bytes.NewReader(jsonData), nil
}
// ConvertImageRequest performs no conversion: the image request is returned
// unchanged.
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
return request, nil
}
// Init is a no-op; the adaptor has no state to initialise.
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
// GetRequestURL returns the upstream URL for the request by delegating to the
// package-level GetRequestURL.
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return GetRequestURL(info)
}
// SetupRequestHeader applies the shared API request headers and then sets the
// Authorization header to a bearer token built from the channel API key.
// It always returns nil.
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
	channel.SetupApiRequestHeader(info, c, req)

	bearer := "Bearer " + info.ApiKey
	req.Set("Authorization", bearer)

	return nil
}
// ConvertOpenAIRequest passes an OpenAI-style request through unchanged.
// A nil request is rejected with an error.
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
	if request != nil {
		return request, nil
	}
	return nil, errors.New("request is nil")
}
// ConvertRerankRequest performs no conversion and returns a nil request with
// a nil error.
// NOTE(review): the other unsupported conversions in this adaptor return a
// "not implemented" error instead; confirm the nil/nil result is intended.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
// ConvertEmbeddingRequest performs no conversion: the embedding request is
// returned unchanged.
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
// ConvertOpenAIResponsesRequest is not supported by the MiniMax adaptor; it
// always returns a "not implemented" error.
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
return nil, errors.New("not implemented")
}
// DoRequest sends the prepared request body upstream by delegating to
// channel.DoApiRequest with this adaptor.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
// DoResponse handles the upstream response. Text-to-speech responses
// (constant.RelayModeAudioSpeech) go through handleTTSResponse; every other
// relay mode is delegated to the OpenAI adaptor's DoResponse.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
	switch info.RelayMode {
	case constant.RelayModeAudioSpeech:
		return handleTTSResponse(c, resp, info)
	default:
		fallback := openai.Adaptor{}
		return fallback.DoResponse(c, resp, info)
	}
}
// GetModelList returns the package-level ModelList.
func (a *Adaptor) GetModelList() []string {
return ModelList
}
// GetChannelName returns the package-level ChannelName.
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -8,6 +8,12 @@ var ModelList = []string{
"abab6-chat",
"abab5.5-chat",
"abab5.5s-chat",
"speech-2.5-hd-preview",
"speech-2.5-turbo-preview",
"speech-02-hd",
"speech-02-turbo",
"speech-01-hd",
"speech-01-turbo",
}
var ChannelName = "minimax"

View File

@@ -3,9 +3,23 @@ package minimax
import (
"fmt"
channelconstant "github.com/QuantumNous/new-api/constant"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/constant"
)
func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.ChannelBaseUrl), nil
baseUrl := info.ChannelBaseUrl
if baseUrl == "" {
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeMiniMax]
}
switch info.RelayMode {
case constant.RelayModeChatCompletions:
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", baseUrl), nil
case constant.RelayModeAudioSpeech:
return fmt.Sprintf("%s/v1/t2a_v2", baseUrl), nil
default:
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
}
}

View File

@@ -0,0 +1,194 @@
package minimax
import (
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
// MiniMaxTTSRequest is the JSON request body for a MiniMax text-to-speech
// call. Fields tagged omitempty are left out of the JSON when they hold their
// zero value; model, text and voice_setting are always emitted.
type MiniMaxTTSRequest struct {
Model string `json:"model"`
Text string `json:"text"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
VoiceSetting VoiceSetting `json:"voice_setting"`
PronunciationDict *PronunciationDict `json:"pronunciation_dict,omitempty"`
AudioSetting *AudioSetting `json:"audio_setting,omitempty"`
TimbreWeights []TimbreWeight `json:"timbre_weights,omitempty"`
LanguageBoost string `json:"language_boost,omitempty"`
VoiceModify *VoiceModify `json:"voice_modify,omitempty"`
SubtitleEnable bool `json:"subtitle_enable,omitempty"`
OutputFormat string `json:"output_format,omitempty"`
AigcWatermark bool `json:"aigc_watermark,omitempty"`
}
// StreamOptions is the "stream_options" object of MiniMaxTTSRequest.
type StreamOptions struct {
ExcludeAggregatedAudio bool `json:"exclude_aggregated_audio,omitempty"`
}
// VoiceSetting is the "voice_setting" object of MiniMaxTTSRequest. voice_id is
// always emitted; the remaining fields are omitted when zero.
type VoiceSetting struct {
VoiceID string `json:"voice_id"`
Speed float64 `json:"speed,omitempty"`
Vol float64 `json:"vol,omitempty"`
Pitch int `json:"pitch,omitempty"`
Emotion string `json:"emotion,omitempty"`
TextNormalization bool `json:"text_normalization,omitempty"`
LatexRead bool `json:"latex_read,omitempty"`
}
// PronunciationDict is the "pronunciation_dict" object of MiniMaxTTSRequest.
type PronunciationDict struct {
Tone []string `json:"tone,omitempty"`
}
// AudioSetting is the "audio_setting" object of MiniMaxTTSRequest. Every field
// is omitted from the JSON when zero.
type AudioSetting struct {
SampleRate int `json:"sample_rate,omitempty"`
Bitrate int `json:"bitrate,omitempty"`
Format string `json:"format,omitempty"`
Channel int `json:"channel,omitempty"`
ForceCbr bool `json:"force_cbr,omitempty"`
}
// TimbreWeight is one entry of the "timbre_weights" array of
// MiniMaxTTSRequest, pairing a voice id with a weight.
type TimbreWeight struct {
VoiceID string `json:"voice_id"`
Weight int `json:"weight"`
}
// VoiceModify is the "voice_modify" object of MiniMaxTTSRequest. Every field
// is omitted from the JSON when zero.
type VoiceModify struct {
Pitch int `json:"pitch,omitempty"`
Intensity int `json:"intensity,omitempty"`
Timbre int `json:"timbre,omitempty"`
SoundEffects string `json:"sound_effects,omitempty"`
}
// MiniMaxTTSResponse is the JSON response body that handleTTSResponse decodes
// from a MiniMax text-to-speech call.
type MiniMaxTTSResponse struct {
Data MiniMaxTTSData `json:"data"`
ExtraInfo MiniMaxExtraInfo `json:"extra_info"`
TraceID string `json:"trace_id"`
BaseResp MiniMaxBaseResp `json:"base_resp"`
}
// MiniMaxTTSData is the "data" object of MiniMaxTTSResponse.
// handleTTSResponse treats Audio as a URL when it starts with "http" and as
// hex-encoded audio bytes otherwise.
type MiniMaxTTSData struct {
Audio string `json:"audio"`
Status int `json:"status"`
}
// MiniMaxExtraInfo is the "extra_info" object of MiniMaxTTSResponse.
// handleTTSResponse reports UsageCharacters as the usage's TotalTokens.
type MiniMaxExtraInfo struct {
UsageCharacters int64 `json:"usage_characters"`
}
// MiniMaxBaseResp is the "base_resp" object of MiniMaxTTSResponse.
// handleTTSResponse treats any non-zero StatusCode as an upstream error and
// includes StatusMsg in the error text.
type MiniMaxBaseResp struct {
StatusCode int64 `json:"status_code"`
StatusMsg string `json:"status_msg"`
}
// getContentTypeByFormat maps an audio format name ("mp3", "wav", "flac",
// "aac", "pcm") to its MIME type. Any other value, including the empty
// string, falls back to "audio/mpeg".
func getContentTypeByFormat(format string) string {
	switch format {
	case "wav":
		return "audio/wav"
	case "flac":
		return "audio/flac"
	case "aac":
		return "audio/aac"
	case "pcm":
		return "audio/pcm"
	default:
		// Covers "mp3" as well as every unrecognised format.
		return "audio/mpeg"
	}
}
// handleTTSResponse turns a MiniMax text-to-speech response into the HTTP
// response sent to the client and returns the usage for the call.
//
// The upstream body is decoded as MiniMaxTTSResponse. A non-zero
// base_resp.status_code or an empty data.audio is reported as an error. When
// data.audio starts with "http" the client is redirected to it (302);
// otherwise it is hex-decoded and written as "audio/mpeg".
//
// On success the returned usage is a *dto.Usage whose PromptTokens comes from
// info and whose TotalTokens is extra_info.usage_characters.
func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
	// Register the close before reading so the body is released even when the
	// read itself fails and the function returns early.
	defer resp.Body.Close()

	body, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, types.NewErrorWithStatusCode(
			fmt.Errorf("failed to read minimax response: %w", readErr),
			types.ErrorCodeReadResponseBodyFailed,
			http.StatusInternalServerError,
		)
	}

	// Parse response
	var minimaxResp MiniMaxTTSResponse
	if unmarshalErr := json.Unmarshal(body, &minimaxResp); unmarshalErr != nil {
		return nil, types.NewErrorWithStatusCode(
			fmt.Errorf("failed to unmarshal minimax TTS response: %w", unmarshalErr),
			types.ErrorCodeBadResponseBody,
			http.StatusInternalServerError,
		)
	}

	// Check base_resp status code
	if minimaxResp.BaseResp.StatusCode != 0 {
		return nil, types.NewErrorWithStatusCode(
			fmt.Errorf("minimax TTS error: %d - %s", minimaxResp.BaseResp.StatusCode, minimaxResp.BaseResp.StatusMsg),
			types.ErrorCodeBadResponse,
			http.StatusBadRequest,
		)
	}

	// Check if we have audio data
	if minimaxResp.Data.Audio == "" {
		return nil, types.NewErrorWithStatusCode(
			fmt.Errorf("no audio data in minimax TTS response"),
			types.ErrorCodeBadResponse,
			http.StatusBadRequest,
		)
	}

	if strings.HasPrefix(minimaxResp.Data.Audio, "http") {
		// The audio field holds a URL: send the client there.
		c.Redirect(http.StatusFound, minimaxResp.Data.Audio)
	} else {
		// Handle hex-encoded audio data
		audioData, decodeErr := hex.DecodeString(minimaxResp.Data.Audio)
		if decodeErr != nil {
			return nil, types.NewErrorWithStatusCode(
				fmt.Errorf("failed to decode hex audio data: %w", decodeErr),
				types.ErrorCodeBadResponse,
				http.StatusInternalServerError,
			)
		}

		// Determine content type - default to mp3
		contentType := "audio/mpeg"
		c.Data(http.StatusOK, contentType, audioData)
	}

	usage = &dto.Usage{
		PromptTokens:     info.PromptTokens,
		CompletionTokens: 0,
		TotalTokens:      int(minimaxResp.ExtraInfo.UsageCharacters),
	}
	return usage, nil
}
// handleChatCompletionResponse relays an upstream chat-completion response to
// the client: it copies the upstream response headers, then writes the body
// with the upstream status code and an "application/json" content type.
// It returns a nil usage and a nil error on success, or an error if the
// upstream body cannot be read.
func handleChatCompletionResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
	// Register the close before reading so the body is released even when the
	// read itself fails and the function returns early.
	defer resp.Body.Close()

	body, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, types.NewErrorWithStatusCode(
			errors.New("failed to read minimax response"),
			types.ErrorCodeReadResponseBodyFailed,
			http.StatusInternalServerError,
		)
	}

	// Set response headers
	for key, values := range resp.Header {
		for _, value := range values {
			c.Header(key, value)
		}
	}

	c.Data(resp.StatusCode, "application/json", body)
	return nil, nil
}

View File

@@ -121,7 +121,14 @@ func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
if chunk.Message != nil && len(chunk.Message.Thinking) > 0 {
raw := strings.TrimSpace(string(chunk.Message.Thinking))
if raw != "" && raw != "null" {
delta.Choices[0].Delta.SetReasoningContent(raw)
// Unmarshal the JSON string to get the actual content without quotes
var thinkingContent string
if err := json.Unmarshal(chunk.Message.Thinking, &thinkingContent); err == nil {
delta.Choices[0].Delta.SetReasoningContent(thinkingContent)
} else {
// Fallback to raw string if it's not a JSON string
delta.Choices[0].Delta.SetReasoningContent(raw)
}
}
}
// tool calls
@@ -209,7 +216,14 @@ func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
if ck.Message != nil && len(ck.Message.Thinking) > 0 {
raw := strings.TrimSpace(string(ck.Message.Thinking))
if raw != "" && raw != "null" {
reasoningBuilder.WriteString(raw)
// Unmarshal the JSON string to get the actual content without quotes
var thinkingContent string
if err := json.Unmarshal(ck.Message.Thinking, &thinkingContent); err == nil {
reasoningBuilder.WriteString(thinkingContent)
} else {
// Fallback to raw string if it's not a JSON string
reasoningBuilder.WriteString(raw)
}
}
}
if ck.Message != nil && ck.Message.Content != "" {
@@ -229,7 +243,14 @@ func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
if len(single.Message.Thinking) > 0 {
raw := strings.TrimSpace(string(single.Message.Thinking))
if raw != "" && raw != "null" {
reasoningBuilder.WriteString(raw)
// Unmarshal the JSON string to get the actual content without quotes
var thinkingContent string
if err := json.Unmarshal(single.Message.Thinking, &thinkingContent); err == nil {
reasoningBuilder.WriteString(thinkingContent)
} else {
// Fallback to raw string if it's not a JSON string
reasoningBuilder.WriteString(raw)
}
}
}
aggContent.WriteString(single.Message.Content)

View File

@@ -18,7 +18,7 @@ import (
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/ai360"
"github.com/QuantumNous/new-api/relay/channel/lingyiwanwu"
"github.com/QuantumNous/new-api/relay/channel/minimax"
//"github.com/QuantumNous/new-api/relay/channel/minimax"
"github.com/QuantumNous/new-api/relay/channel/openrouter"
"github.com/QuantumNous/new-api/relay/channel/xinference"
relaycommon "github.com/QuantumNous/new-api/relay/common"
@@ -161,8 +161,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
}
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil
case constant.ChannelTypeMiniMax:
return minimax.GetRequestURL(info)
//case constant.ChannelTypeMiniMax:
// return minimax.GetRequestURL(info)
case constant.ChannelTypeCustom:
url := info.ChannelBaseUrl
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
@@ -599,8 +599,8 @@ func (a *Adaptor) GetModelList() []string {
return ai360.ModelList
case constant.ChannelTypeLingYiWanWu:
return lingyiwanwu.ModelList
case constant.ChannelTypeMiniMax:
return minimax.ModelList
//case constant.ChannelTypeMiniMax:
// return minimax.ModelList
case constant.ChannelTypeXinference:
return xinference.ModelList
case constant.ChannelTypeOpenRouter:
@@ -616,8 +616,8 @@ func (a *Adaptor) GetChannelName() string {
return ai360.ChannelName
case constant.ChannelTypeLingYiWanWu:
return lingyiwanwu.ChannelName
case constant.ChannelTypeMiniMax:
return minimax.ChannelName
//case constant.ChannelTypeMiniMax:
// return minimax.ChannelName
case constant.ChannelTypeXinference:
return xinference.ChannelName
case constant.ChannelTypeOpenRouter:

View File

@@ -15,6 +15,7 @@ import (
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin"
@@ -446,7 +447,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
return &taskResult, nil
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*dto.OpenAIVideo, error) {
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
var jimengResp responseTask
if err := json.Unmarshal(originTask.Data, &jimengResp); err != nil {
return nil, errors.Wrap(err, "unmarshal jimeng task data failed")
@@ -467,7 +468,8 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*dto.OpenAIV
}
}
return openAIVideo, nil
jsonData, _ := common.Marshal(openAIVideo)
return jsonData, nil
}
func isNewAPIRelay(apiKey string) bool {

View File

@@ -9,6 +9,7 @@ import (
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/samber/lo"
@@ -367,7 +368,7 @@ func isNewAPIRelay(apiKey string) bool {
return strings.HasPrefix(apiKey, "sk-")
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*dto.OpenAIVideo, error) {
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
var klingResp responsePayload
if err := json.Unmarshal(originTask.Data, &klingResp); err != nil {
return nil, errors.Wrap(err, "unmarshal kling task data failed")
@@ -396,6 +397,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*dto.OpenAIV
Code: fmt.Sprintf("%d", klingResp.Code),
}
}
return openAIVideo, nil
jsonData, _ := common.Marshal(openAIVideo)
return jsonData, nil
}

View File

@@ -2,7 +2,6 @@ package sora
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -107,7 +106,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco
// Parse Sora response
var dResp responseTask
if err := json.Unmarshal(responseBody, &dResp); err != nil {
if err := common.Unmarshal(responseBody, &dResp); err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
return
}
@@ -154,7 +153,7 @@ func (a *TaskAdaptor) GetChannelName() string {
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
resTask := responseTask{}
if err := json.Unmarshal(respBody, &resTask); err != nil {
if err := common.Unmarshal(respBody, &resTask); err != nil {
return nil, errors.Wrap(err, "unmarshal task result failed")
}
@@ -186,11 +185,6 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
return &taskResult, nil
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) (*dto.OpenAIVideo, error) {
openAIVideo := &dto.OpenAIVideo{}
err := json.Unmarshal(task.Data, openAIVideo)
if err != nil {
return nil, errors.Wrap(err, "unmarshal to OpenAIVideo failed")
}
return openAIVideo, nil
// ConvertToOpenAIVideo returns the task's stored Data bytes unchanged, with a
// nil error.
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
return task.Data, nil
}

View File

@@ -8,6 +8,7 @@ import (
"net/http"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/gin-gonic/gin"
"github.com/QuantumNous/new-api/constant"
@@ -263,7 +264,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
return taskInfo, nil
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*dto.OpenAIVideo, error) {
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
var viduResp taskResultResponse
if err := json.Unmarshal(originTask.Data, &viduResp); err != nil {
return nil, errors.Wrap(err, "unmarshal vidu task data failed")
@@ -287,5 +288,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*dto.OpenAIV
}
}
return openAIVideo, nil
jsonData, _ := common.Marshal(openAIVideo)
return jsonData, nil
}

View File

@@ -37,8 +37,57 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
if info.RelayMode != constant.RelayModeAudioSpeech {
return nil, errors.New("unsupported audio relay mode")
}
appID, token, err := parseVolcengineAuth(info.ApiKey)
if err != nil {
return nil, err
}
voiceType := mapVoiceType(request.Voice)
speedRatio := request.Speed
encoding := mapEncoding(request.ResponseFormat)
c.Set("response_format", encoding)
volcRequest := VolcengineTTSRequest{
App: VolcengineTTSApp{
AppID: appID,
Token: token,
Cluster: "volcano_tts",
},
User: VolcengineTTSUser{
UID: "openai_relay_user",
},
Audio: VolcengineTTSAudio{
VoiceType: voiceType,
Encoding: encoding,
SpeedRatio: speedRatio,
Rate: 24000,
},
Request: VolcengineTTSReqInfo{
ReqID: generateRequestID(),
Text: request.Input,
Operation: "query",
Model: info.OriginModelName,
},
}
// 同步扩展字段的厂商自定义metadata
if len(request.Metadata) > 0 {
if err = json.Unmarshal(request.Metadata, &volcRequest); err != nil {
return nil, fmt.Errorf("error unmarshalling metadata to volcengine request: %w", err)
}
}
jsonData, err := json.Marshal(volcRequest)
if err != nil {
return nil, fmt.Errorf("error marshalling volcengine request: %w", err)
}
return bytes.NewReader(jsonData), nil
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
@@ -190,7 +239,6 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// 支持自定义域名,如果未设置则使用默认域名
baseUrl := info.ChannelBaseUrl
if baseUrl == "" {
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
@@ -217,6 +265,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
case constant.RelayModeRerank:
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
case constant.RelayModeAudioSpeech:
// 只有当 baseUrl 是火山默认的官方Url时才改为官方的的TTS接口否则走透传的New接口
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
return "https://openspeech.bytedance.com/api/v1/tts", nil
}
return fmt.Sprintf("%s/v1/audio/speech", baseUrl), nil
default:
}
}
@@ -225,6 +279,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
if info.RelayMode == constant.RelayModeAudioSpeech {
parts := strings.Split(info.ApiKey, "|")
if len(parts) == 2 {
req.Set("Authorization", "Bearer;"+parts[1])
}
req.Set("Content-Type", "application/json")
return nil
}
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
@@ -260,6 +324,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeAudioSpeech {
encoding := mapEncoding(c.GetString("response_format"))
return handleTTSResponse(c, resp, info, encoding)
}
adaptor := openai.Adaptor{}
usage, err = adaptor.DoResponse(c, resp, info)
return

View File

@@ -0,0 +1,194 @@
package volcengine
import (
"encoding/base64"
"encoding/json"
"errors"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// VolcengineTTSRequest is the JSON request body for a Volcengine
// text-to-speech call, made of four always-emitted sections: app, user, audio
// and request.
type VolcengineTTSRequest struct {
App VolcengineTTSApp `json:"app"`
User VolcengineTTSUser `json:"user"`
Audio VolcengineTTSAudio `json:"audio"`
Request VolcengineTTSReqInfo `json:"request"`
}
// VolcengineTTSApp is the "app" section of VolcengineTTSRequest.
type VolcengineTTSApp struct {
AppID string `json:"appid"`
Token string `json:"token"`
Cluster string `json:"cluster"`
}
// VolcengineTTSUser is the "user" section of VolcengineTTSRequest.
type VolcengineTTSUser struct {
UID string `json:"uid"`
}
// VolcengineTTSAudio is the "audio" section of VolcengineTTSRequest.
// voice_type, encoding, speed_ratio and rate are always emitted; the remaining
// fields are omitted from the JSON when zero.
type VolcengineTTSAudio struct {
VoiceType string `json:"voice_type"`
Encoding string `json:"encoding"`
SpeedRatio float64 `json:"speed_ratio"`
Rate int `json:"rate"`
Bitrate int `json:"bitrate,omitempty"`
LoudnessRatio float64 `json:"loudness_ratio,omitempty"`
EnableEmotion bool `json:"enable_emotion,omitempty"`
Emotion string `json:"emotion,omitempty"`
EmotionScale float64 `json:"emotion_scale,omitempty"`
ExplicitLanguage string `json:"explicit_language,omitempty"`
ContextLanguage string `json:"context_language,omitempty"`
}
// VolcengineTTSReqInfo is the "request" section of VolcengineTTSRequest.
// reqid, text and operation are always emitted; the remaining fields are
// omitted from the JSON when zero or nil.
type VolcengineTTSReqInfo struct {
ReqID string `json:"reqid"`
Text string `json:"text"`
Operation string `json:"operation"`
Model string `json:"model,omitempty"`
TextType string `json:"text_type,omitempty"`
SilenceDuration float64 `json:"silence_duration,omitempty"`
// WithTimestamp is untyped, so any JSON value is accepted and passed through.
WithTimestamp interface{} `json:"with_timestamp,omitempty"`
ExtraParam *VolcengineTTSExtraParam `json:"extra_param,omitempty"`
}
// VolcengineTTSExtraParam is the optional "extra_param" object of
// VolcengineTTSReqInfo. Every field is omitted from the JSON when zero or nil.
type VolcengineTTSExtraParam struct {
DisableMarkdownFilter bool `json:"disable_markdown_filter,omitempty"`
EnableLatexTn bool `json:"enable_latex_tn,omitempty"`
MuteCutThreshold string `json:"mute_cut_threshold,omitempty"`
MuteCutRemainMs string `json:"mute_cut_remain_ms,omitempty"`
DisableEmojiFilter bool `json:"disable_emoji_filter,omitempty"`
UnsupportedCharRatioThresh float64 `json:"unsupported_char_ratio_thresh,omitempty"`
AigcWatermark bool `json:"aigc_watermark,omitempty"`
CacheConfig *VolcengineTTSCacheConfig `json:"cache_config,omitempty"`
}
// VolcengineTTSCacheConfig is the optional "cache_config" object of
// VolcengineTTSExtraParam.
type VolcengineTTSCacheConfig struct {
TextType int `json:"text_type,omitempty"`
UseCache bool `json:"use_cache,omitempty"`
}
// VolcengineTTSResponse is the JSON response body that handleTTSResponse
// decodes from a Volcengine text-to-speech call. handleTTSResponse treats
// Code 3000 as success and base64-decodes Data into the audio bytes.
type VolcengineTTSResponse struct {
ReqID string `json:"reqid"`
Code int `json:"code"`
Message string `json:"message"`
Sequence int `json:"sequence"`
Data string `json:"data"`
Addition *VolcengineTTSAdditionInfo `json:"addition,omitempty"`
}
// VolcengineTTSAdditionInfo is the optional "addition" object of
// VolcengineTTSResponse.
type VolcengineTTSAdditionInfo struct {
Duration string `json:"duration"`
}
// openAIToVolcengineVoiceMap translates OpenAI voice names into Volcengine
// voice type identifiers. mapVoiceType passes through names not listed here.
var openAIToVolcengineVoiceMap = map[string]string{
"alloy": "zh_male_M392_conversation_wvae_bigtts",
"echo": "zh_male_wenhao_mars_bigtts",
"fable": "zh_female_tianmei_mars_bigtts",
"onyx": "zh_male_zhibei_mars_bigtts",
"nova": "zh_female_shuangkuaisisi_mars_bigtts",
"shimmer": "zh_female_cancan_mars_bigtts",
}
// responseFormatToEncodingMap translates an OpenAI response_format into a
// Volcengine audio encoding. Note that "aac" and "flac" are both mapped to
// "mp3". mapEncoding falls back to "mp3" for formats not listed here.
var responseFormatToEncodingMap = map[string]string{
"mp3": "mp3",
"opus": "ogg_opus",
"aac": "mp3",
"flac": "mp3",
"wav": "wav",
"pcm": "pcm",
}
// parseVolcengineAuth splits a channel API key of the form
// "appid|access_token" into its two parts. The key must contain exactly one
// '|' separator; otherwise an error is returned. Either part may be empty.
func parseVolcengineAuth(apiKey string) (appID, token string, err error) {
	// Exactly one separator means exactly two parts.
	if strings.Count(apiKey, "|") != 1 {
		return "", "", errors.New("invalid api key format, expected: appid|access_token")
	}
	sep := strings.IndexByte(apiKey, '|')
	appID = apiKey[:sep]
	token = apiKey[sep+1:]
	return appID, token, nil
}
// mapVoiceType resolves an OpenAI voice name to a Volcengine voice type via
// openAIToVolcengineVoiceMap. Names without a mapping are returned unchanged.
func mapVoiceType(openAIVoice string) string {
	mapped, known := openAIToVolcengineVoiceMap[openAIVoice]
	if !known {
		return openAIVoice
	}
	return mapped
}
// mapEncoding resolves an OpenAI response_format to a Volcengine audio
// encoding via responseFormatToEncodingMap, defaulting to "mp3" for formats
// without a mapping.
func mapEncoding(responseFormat string) string {
	mapped, known := responseFormatToEncodingMap[responseFormat]
	if !known {
		return "mp3"
	}
	return mapped
}
// getContentTypeByEncoding maps a Volcengine audio encoding ("mp3",
// "ogg_opus", "wav", "pcm") to its MIME type. Any other value falls back to
// "application/octet-stream".
func getContentTypeByEncoding(encoding string) string {
	switch encoding {
	case "mp3":
		return "audio/mpeg"
	case "ogg_opus":
		return "audio/ogg"
	case "wav":
		return "audio/wav"
	case "pcm":
		return "audio/pcm"
	default:
		return "application/octet-stream"
	}
}
// handleTTSResponse turns a Volcengine text-to-speech response into the HTTP
// response sent to the client and returns the usage for the call.
//
// The upstream body is decoded as VolcengineTTSResponse. Any code other than
// 3000 is reported as an error carrying the upstream message. Otherwise the
// base64 "data" field is decoded and written with the content type that
// corresponds to encoding (see getContentTypeByEncoding).
//
// On success the returned usage is a *dto.Usage whose PromptTokens and
// TotalTokens are both info.PromptTokens.
func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
	// Register the close before reading so the body is released even when the
	// read itself fails and the function returns early.
	defer resp.Body.Close()

	body, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, types.NewErrorWithStatusCode(
			errors.New("failed to read volcengine response"),
			types.ErrorCodeReadResponseBodyFailed,
			http.StatusInternalServerError,
		)
	}

	var volcResp VolcengineTTSResponse
	if unmarshalErr := json.Unmarshal(body, &volcResp); unmarshalErr != nil {
		return nil, types.NewErrorWithStatusCode(
			errors.New("failed to parse volcengine response"),
			types.ErrorCodeBadResponseBody,
			http.StatusInternalServerError,
		)
	}

	// 3000 is the only code accepted as success here.
	if volcResp.Code != 3000 {
		return nil, types.NewErrorWithStatusCode(
			errors.New(volcResp.Message),
			types.ErrorCodeBadResponse,
			http.StatusBadRequest,
		)
	}

	audioData, decodeErr := base64.StdEncoding.DecodeString(volcResp.Data)
	if decodeErr != nil {
		return nil, types.NewErrorWithStatusCode(
			errors.New("failed to decode audio data"),
			types.ErrorCodeBadResponseBody,
			http.StatusInternalServerError,
		)
	}

	contentType := getContentTypeByEncoding(encoding)
	c.Header("Content-Type", contentType)
	c.Data(http.StatusOK, contentType, audioData)

	usage = &dto.Usage{
		PromptTokens:     info.PromptTokens,
		CompletionTokens: 0,
		TotalTokens:      info.PromptTokens,
	}
	return usage, nil
}
// generateRequestID returns a newly generated UUID in string form.
func generateRequestID() string {
return uuid.New().String()
}

View File

@@ -263,6 +263,7 @@ var streamSupportedChannels = map[int]bool{
constant.ChannelTypeDeepSeek: true,
constant.ChannelTypeBaiduV2: true,
constant.ChannelTypeZhipu_v4: true,
constant.ChannelTypeAli: true,
}
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
@@ -512,6 +513,13 @@ type TaskInfo struct {
TotalTokens int `json:"total_tokens,omitempty"` // 用于按倍率计费
}
// FailTaskInfo builds a TaskInfo whose Status is "FAILURE" and whose Reason is
// the given text. All other fields keep their zero values.
func FailTaskInfo(reason string) *TaskInfo {
	failed := new(TaskInfo)
	failed.Status = "FAILURE"
	failed.Reason = reason
	return failed
}
// RemoveDisabledFields 从请求 JSON 数据中移除渠道设置中禁用的字段
// service_tier: 服务层级字段可能导致额外计费OpenAI、Claude、Responses API 支持)
// store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用)

View File

@@ -8,6 +8,7 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
@@ -48,6 +49,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
logger.LogDebug(c, fmt.Sprintf("converted embedding request body: %s", string(jsonData)))
requestBody := bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, info, requestBody)

View File

@@ -240,6 +240,8 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
req.SetModelName("models/" + info.UpstreamModelName)
adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
@@ -264,6 +266,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}
}
logger.LogDebug(c, "Gemini embedding request body: "+string(jsonData))
requestBody = bytes.NewReader(jsonData)
resp, err := adaptor.DoRequest(c, info, requestBody)

View File

@@ -22,8 +22,10 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt
case types.RelayFormatOpenAI:
request, err = GetAndValidateTextRequest(c, relayMode)
case types.RelayFormatGemini:
if strings.Contains(c.Request.URL.Path, ":embedContent") || strings.Contains(c.Request.URL.Path, ":batchEmbedContents") {
if strings.Contains(c.Request.URL.Path, ":embedContent") {
request, err = GetAndValidateGeminiEmbeddingRequest(c)
} else if strings.Contains(c.Request.URL.Path, ":batchEmbedContents") {
request, err = GetAndValidateGeminiBatchEmbeddingRequest(c)
} else {
request, err = GetAndValidateGeminiRequest(c)
}
@@ -300,7 +302,7 @@ func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error)
if err != nil {
return nil, err
}
if len(request.Contents) == 0 {
if len(request.Contents) == 0 && len(request.Requests) == 0 {
return nil, errors.New("contents is required")
}
@@ -319,3 +321,12 @@ func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingR
}
return request, nil
}
// GetAndValidateGeminiBatchEmbeddingRequest decodes the body of the incoming
// request into a dto.GeminiBatchEmbeddingRequest using
// common.UnmarshalBodyReusable.
//
// It returns the decoded request, or nil together with the decoding error when
// the body cannot be unmarshalled. No field-level validation is performed here
// beyond what the decoding step itself enforces.
func GetAndValidateGeminiBatchEmbeddingRequest(c *gin.Context) (*dto.GeminiBatchEmbeddingRequest, error) {
var batchRequest dto.GeminiBatchEmbeddingRequest
if decodeErr := common.UnmarshalBodyReusable(c, &batchRequest); decodeErr != nil {
return nil, decodeErr
}
return &batchRequest, nil
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/QuantumNous/new-api/relay/channel/gemini"
"github.com/QuantumNous/new-api/relay/channel/jimeng"
"github.com/QuantumNous/new-api/relay/channel/jina"
"github.com/QuantumNous/new-api/relay/channel/minimax"
"github.com/QuantumNous/new-api/relay/channel/mistral"
"github.com/QuantumNous/new-api/relay/channel/mokaai"
"github.com/QuantumNous/new-api/relay/channel/moonshot"
@@ -108,6 +109,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &moonshot.Adaptor{} // Moonshot uses Claude API
case constant.APITypeSubmodel:
return &submodel.Adaptor{}
case constant.APITypeMiniMax:
return &minimax.Adaptor{}
}
return nil
}
@@ -139,7 +142,7 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
return &taskVidu.TaskAdaptor{}
case constant.ChannelTypeDoubaoVideo:
return &taskdoubao.TaskAdaptor{}
case constant.ChannelTypeSora:
case constant.ChannelTypeSora, constant.ChannelTypeOpenAI:
return &tasksora.TaskAdaptor{}
}
}

View File

@@ -72,10 +72,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
} else {
ratio = modelPrice * groupRatio
}
if len(info.PriceData.OtherRatios) > 0 {
for _, ra := range info.PriceData.OtherRatios {
if 1.0 != ra {
ratio *= ra
// FIXME: 临时修补,支持任务仅按次计费
if !common.StringsContains(constant.TaskPricePatches, modelName) {
if len(info.PriceData.OtherRatios) > 0 {
for _, ra := range info.PriceData.OtherRatios {
if 1.0 != ra {
ratio *= ra
}
}
}
}
@@ -153,15 +156,20 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
// gRatio = userGroupRatio
//}
logContent := fmt.Sprintf("操作 %s", info.Action)
if len(info.PriceData.OtherRatios) > 0 {
var contents []string
for key, ra := range info.PriceData.OtherRatios {
if 1.0 != ra {
contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
// FIXME: 临时修补,支持任务仅按次计费
if common.StringsContains(constant.TaskPricePatches, modelName) {
logContent = fmt.Sprintf("%s按次计费", logContent)
} else {
if len(info.PriceData.OtherRatios) > 0 {
var contents []string
for key, ra := range info.PriceData.OtherRatios {
if 1.0 != ra {
contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
}
}
if len(contents) > 0 {
logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
}
}
if len(contents) > 0 {
logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
}
}
other := make(map[string]interface{})
@@ -397,12 +405,12 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
return
}
if converter, ok := adaptor.(channel.OpenAIVideoConverter); ok {
openAIVideo, err := converter.ConvertToOpenAIVideo(originTask)
openAIVideoData, err := converter.ConvertToOpenAIVideo(originTask)
if err != nil {
taskResp = service.TaskErrorWrapper(err, "convert_to_openai_video_failed", http.StatusInternalServerError)
return
}
respBody, _ = json.Marshal(openAIVideo)
respBody = openAIVideoData
return
}
taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented)

View File

@@ -62,6 +62,9 @@ const (
ErrorCodeConvertRequestFailed ErrorCode = "convert_request_failed"
ErrorCodeAccessDenied ErrorCode = "access_denied"
// request error
ErrorCodeBadRequestBody ErrorCode = "bad_request_body"
// response error
ErrorCodeReadResponseBodyFailed ErrorCode = "read_response_body_failed"
ErrorCodeBadResponseStatusCode ErrorCode = "bad_response_status_code"

View File

@@ -107,10 +107,12 @@ function type2secretPrompt(type) {
return '按照如下格式输入AppId|SecretId|SecretKey';
case 33:
return '按照如下格式输入Ak|Sk|Region';
case 45:
return '请输入渠道对应的鉴权密钥, 豆包语音输入AppId|AccessToken';
case 50:
return '按照如下格式输入: AccessKey|SecretKey, 如果上游是New API则直接输ApiKey';
case 51:
return '按照如下格式输入: Access Key ID|Secret Access Key';
return '按照如下格式输入: AccessKey|SecretAccessKey';
default:
return '请输入渠道对应的鉴权密钥';
}
@@ -153,6 +155,8 @@ const EditChannelModal = (props) => {
settings: '',
// 仅 Vertex: 密钥格式(存入 settings.vertex_key_type)
vertex_key_type: 'json',
// 仅 AWS: 密钥格式和区域(存入 settings.aws_key_type 和 settings.aws_region)
aws_key_type: 'ak_sk',
// 企业账户设置
is_enterprise_account: false,
// 字段透传控制默认值
@@ -515,6 +519,8 @@ const EditChannelModal = (props) => {
parsedSettings.azure_responses_version || '';
// 读取 Vertex 密钥格式
data.vertex_key_type = parsedSettings.vertex_key_type || 'json';
// 读取 AWS 密钥格式和区域
data.aws_key_type = parsedSettings.aws_key_type || 'ak_sk';
// 读取企业账户设置
data.is_enterprise_account =
parsedSettings.openrouter_enterprise === true;
@@ -528,6 +534,7 @@ const EditChannelModal = (props) => {
data.azure_responses_version = '';
data.region = '';
data.vertex_key_type = 'json';
data.aws_key_type = 'ak_sk';
data.is_enterprise_account = false;
data.allow_service_tier = false;
data.disable_store = false;
@@ -536,6 +543,7 @@ const EditChannelModal = (props) => {
} else {
// 兼容历史数据:老渠道没有 settings 时,默认按 json 展示
data.vertex_key_type = 'json';
data.aws_key_type = 'ak_sk';
data.is_enterprise_account = false;
data.allow_service_tier = false;
data.disable_store = false;
@@ -997,6 +1005,11 @@ const EditChannelModal = (props) => {
localInputs.is_enterprise_account === true;
}
// type === 33 (AWS): 保存 aws_key_type 到 settings
if (localInputs.type === 33) {
settings.aws_key_type = localInputs.aws_key_type || 'ak_sk';
}
// type === 1 (OpenAI) 或 type === 14 (Claude): 设置字段透传控制(显式保存布尔值)
if (localInputs.type === 1 || localInputs.type === 14) {
settings.allow_service_tier = localInputs.allow_service_tier === true;
@@ -1020,6 +1033,8 @@ const EditChannelModal = (props) => {
delete localInputs.is_enterprise_account;
// 顶层的 vertex_key_type 不应发送给后端
delete localInputs.vertex_key_type;
// 顶层的 aws_key_type 不应发送给后端
delete localInputs.aws_key_type;
// 清理字段透传控制的临时字段
delete localInputs.allow_service_tier;
delete localInputs.disable_store;
@@ -1468,6 +1483,31 @@ const EditChannelModal = (props) => {
autoComplete='new-password'
/>
{inputs.type === 33 && (
<>
<Form.Select
field='aws_key_type'
label={t('密钥格式')}
placeholder={t('请选择密钥格式')}
optionList={[
{
label: 'AccessKey / SecretAccessKey',
value: 'ak_sk',
},
{ label: 'API Key', value: 'api_key' },
]}
style={{ width: '100%' }}
value={inputs.aws_key_type || 'ak_sk'}
onChange={(value) => {
handleChannelOtherSettingsChange('aws_key_type', value);
}}
extraText={t(
'AK/SK 模式:使用 AccessKey 和 SecretAccessKeyAPI Key 模式:使用 API Key',
)}
/>
</>
)}
{inputs.type === 41 && (
<Form.Select
field='vertex_key_type'
@@ -1536,7 +1576,15 @@ const EditChannelModal = (props) => {
<Form.TextArea
field='key'
label={t('密钥')}
placeholder={t('请输入密钥,一行一个')}
placeholder={
inputs.type === 33
? inputs.aws_key_type === 'api_key'
? t('请输入 API Key一行一个格式APIKey|Region')
: t(
'请输入密钥一行一个格式AccessKey|SecretAccessKey|Region',
)
: t('请输入密钥,一行一个')
}
rules={
isEdit
? []
@@ -1730,7 +1778,13 @@ const EditChannelModal = (props) => {
? t('密钥(编辑模式下,保存的密钥不会显示)')
: t('密钥')
}
placeholder={t(type2secretPrompt(inputs.type))}
placeholder={
inputs.type === 33
? inputs.aws_key_type === 'api_key'
? t('请输入 API Key格式APIKey|Region')
: t('按照如下格式输入AccessKey|SecretAccessKey|Region')
: t(type2secretPrompt(inputs.type))
}
rules={
isEdit
? []