mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-31 16:51:52 +00:00
Compare commits
66 Commits
refactor/c
...
update-git
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ce1fde8500 | ||
|
|
4661399639 | ||
|
|
78d8d458ca | ||
|
|
e20a287c4b | ||
|
|
c7ab0f4f3d | ||
|
|
0d1057830b | ||
|
|
dd1cac3f2e | ||
|
|
cdbc7a9510 | ||
|
|
c693bfee5e | ||
|
|
7156bf2382 | ||
|
|
c216527f23 | ||
|
|
b1de0f49df | ||
|
|
525ca09f2c | ||
|
|
92fc973bc3 | ||
|
|
22ff8e2cbe | ||
|
|
1ec664a348 | ||
|
|
6a24c37c0e | ||
|
|
8965fc49c9 | ||
|
|
735386c0b9 | ||
|
|
58c4da0ddf | ||
|
|
fe68488b1c | ||
|
|
25af6e6f77 | ||
|
|
e2d3b46a3a | ||
|
|
dd775167ab | ||
|
|
43f2a8ac06 | ||
|
|
bcf93a2c05 | ||
|
|
09ff878d88 | ||
|
|
d4749ba388 | ||
|
|
1f2bdb1402 | ||
|
|
64a97092c9 | ||
|
|
69b87b5d8e | ||
|
|
bd4160793e | ||
|
|
82e21972ec | ||
|
|
dce00141ce | ||
|
|
b2a057723a | ||
|
|
f023efdbfc | ||
|
|
8b65623726 | ||
|
|
aa35d8db69 | ||
|
|
64ed7dce4d | ||
|
|
67c321c4fb | ||
|
|
b3f50e9dd0 | ||
|
|
ea870a7846 | ||
|
|
fa21599fc8 | ||
|
|
e6c42bfbda | ||
|
|
7d480d5ff3 | ||
|
|
86c63ea4a7 | ||
|
|
2624c48113 | ||
|
|
384cba92cf | ||
|
|
7222265fee | ||
|
|
fdbc31eb9a | ||
|
|
3172c956f7 | ||
|
|
8b9188c584 | ||
|
|
5fc9152499 | ||
|
|
18b945b9c5 | ||
|
|
826ef2e5a6 | ||
|
|
7311c18d52 | ||
|
|
4a4238d830 | ||
|
|
9805b0f3b0 | ||
|
|
dfca9681c8 | ||
|
|
a6e6897f63 | ||
|
|
ec0633bdfb | ||
|
|
2d1534dc77 | ||
|
|
eebd7ca0f3 | ||
|
|
98e3e5ca2c | ||
|
|
e5dde67272 | ||
|
|
d2546cf9ec |
3
.github/workflows/docker-image-arm64.yml
vendored
3
.github/workflows/docker-image-arm64.yml
vendored
@@ -33,7 +33,8 @@ jobs:
|
||||
- name: Resolve tag & write VERSION
|
||||
run: |
|
||||
git fetch --tags --force --depth=1
|
||||
echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
|
||||
TAG=${GITHUB_REF#refs/tags/}
|
||||
echo "TAG=$TAG" >> $GITHUB_ENV
|
||||
echo "$TAG" > VERSION
|
||||
echo "Building tag: $TAG for ${{ matrix.arch }}"
|
||||
|
||||
|
||||
7
.github/workflows/electron-build.yml
vendored
7
.github/workflows/electron-build.yml
vendored
@@ -4,6 +4,8 @@ on:
|
||||
push:
|
||||
tags:
|
||||
- '*' # Triggers on version tags like v1.0.0
|
||||
- '!*-*' # Ignore pre-release tags like v1.0.0-beta
|
||||
- '!*-alpha*' # Ignore alpha tags like v1.0.0-alpha
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
jobs:
|
||||
@@ -130,13 +132,10 @@ jobs:
|
||||
- name: Download all artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
|
||||
- name: Create Release
|
||||
- name: Upload to Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
files: |
|
||||
windows-build/*
|
||||
draft: false
|
||||
prerelease: false
|
||||
overwrite_files: true
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
6
.github/workflows/release.yml
vendored
6
.github/workflows/release.yml
vendored
@@ -54,8 +54,6 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
new-api-*
|
||||
draft: true
|
||||
generate_release_notes: true
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
@@ -93,8 +91,6 @@ jobs:
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: new-api-macos-*
|
||||
draft: true
|
||||
generate_release_notes: true
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
@@ -134,8 +130,6 @@ jobs:
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: new-api-*.exe
|
||||
draft: true
|
||||
generate_release_notes: true
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1,5 +1,6 @@
|
||||
.idea
|
||||
.vscode
|
||||
.zed
|
||||
upload
|
||||
*.exe
|
||||
*.db
|
||||
@@ -10,10 +11,11 @@ web/dist
|
||||
.env
|
||||
one-api
|
||||
new-api
|
||||
/__debug_bin*
|
||||
.DS_Store
|
||||
tiktoken_cache
|
||||
.eslintcache
|
||||
.gocache
|
||||
|
||||
electron/node_modules
|
||||
electron/dist
|
||||
electron/dist
|
||||
|
||||
16
README.md
16
README.md
@@ -165,12 +165,18 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
|
||||
|
||||
#### 使用Docker Compose部署(推荐)
|
||||
```shell
|
||||
# 下载项目
|
||||
git clone https://github.com/Calcium-Ion/new-api.git
|
||||
# 下载项目源码
|
||||
git clone https://github.com/QuantumNous/new-api.git
|
||||
|
||||
# 进入项目目录
|
||||
cd new-api
|
||||
# 按需编辑docker-compose.yml
|
||||
# 启动
|
||||
docker-compose up -d
|
||||
|
||||
# 根据需要编辑 docker-compose.yml 文件
|
||||
# 使用nano编辑器
|
||||
nano docker-compose.yml
|
||||
# 或使用vim编辑器
|
||||
# vim docker-compose.yml
|
||||
|
||||
```
|
||||
|
||||
#### 直接使用Docker镜像
|
||||
|
||||
@@ -69,6 +69,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType = constant.APITypeMoonshot
|
||||
case constant.ChannelTypeSubmodel:
|
||||
apiType = constant.APITypeSubmodel
|
||||
case constant.ChannelTypeMiniMax:
|
||||
apiType = constant.APITypeMiniMax
|
||||
}
|
||||
if apiType == -1 {
|
||||
return constant.APITypeOpenAI, false
|
||||
|
||||
@@ -86,5 +86,8 @@ func SendEmail(subject string, receiver string, content string) error {
|
||||
} else {
|
||||
err = smtp.SendMail(addr, auth, SMTPFrom, to, mail)
|
||||
}
|
||||
if err != nil {
|
||||
SysError(fmt.Sprintf("failed to send email to %s: %v", receiver, err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -26,6 +26,8 @@ func GetEndpointTypesByChannelType(channelType int, modelName string) []constant
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
|
||||
case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
|
||||
case constant.ChannelTypeSora:
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIVideo}
|
||||
default:
|
||||
if IsOpenAIResponseOnlyModel(modelName) {
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
@@ -118,4 +119,17 @@ func initConstantEnv() {
|
||||
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
// 是否启用错误日志
|
||||
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||
|
||||
soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
|
||||
if soraPatchStr != "" {
|
||||
var taskPricePatches []string
|
||||
soraPatches := strings.Split(soraPatchStr, ",")
|
||||
for _, patch := range soraPatches {
|
||||
trimmedPatch := strings.TrimSpace(patch)
|
||||
if trimmedPatch != "" {
|
||||
taskPricePatches = append(taskPricePatches, trimmedPatch)
|
||||
}
|
||||
}
|
||||
constant.TaskPricePatches = taskPricePatches
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package common
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
)
|
||||
|
||||
func Unmarshal(data []byte, v any) error {
|
||||
@@ -13,7 +14,7 @@ func UnmarshalJsonStr(data string, v any) error {
|
||||
return json.Unmarshal(StringToByteSlice(data), v)
|
||||
}
|
||||
|
||||
func DecodeJson(reader *bytes.Reader, v any) error {
|
||||
func DecodeJson(reader io.Reader, v any) error {
|
||||
return json.NewDecoder(reader).Decode(v)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,161 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/core/interfaces"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// PluginConfig 插件配置结构
|
||||
type PluginConfig struct {
|
||||
Channels map[string]interfaces.ChannelConfig `yaml:"channels"`
|
||||
Middlewares []interfaces.MiddlewareConfig `yaml:"middlewares"`
|
||||
Hooks HooksConfig `yaml:"hooks"`
|
||||
}
|
||||
|
||||
// HooksConfig Hook配置
|
||||
type HooksConfig struct {
|
||||
Relay []interfaces.HookConfig `yaml:"relay"`
|
||||
}
|
||||
|
||||
var (
|
||||
// 全局配置实例
|
||||
globalPluginConfig *PluginConfig
|
||||
)
|
||||
|
||||
// LoadPluginConfig 加载插件配置
|
||||
func LoadPluginConfig(configPath string) (*PluginConfig, error) {
|
||||
// 如果没有指定配置文件路径,使用默认路径
|
||||
if configPath == "" {
|
||||
configPath = "config/plugins.yaml"
|
||||
}
|
||||
|
||||
// 检查文件是否存在
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
common.SysLog(fmt.Sprintf("Plugin config file not found: %s, using default configuration", configPath))
|
||||
return getDefaultConfig(), nil
|
||||
}
|
||||
|
||||
// 读取配置文件
|
||||
data, err := ioutil.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read plugin config: %w", err)
|
||||
}
|
||||
|
||||
// 解析YAML
|
||||
var config PluginConfig
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse plugin config: %w", err)
|
||||
}
|
||||
|
||||
// 环境变量替换
|
||||
expandEnvVars(&config)
|
||||
|
||||
common.SysLog(fmt.Sprintf("Loaded plugin config from: %s", configPath))
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// getDefaultConfig 返回默认配置
|
||||
func getDefaultConfig() *PluginConfig {
|
||||
return &PluginConfig{
|
||||
Channels: make(map[string]interfaces.ChannelConfig),
|
||||
Middlewares: make([]interfaces.MiddlewareConfig, 0),
|
||||
Hooks: HooksConfig{
|
||||
Relay: make([]interfaces.HookConfig, 0),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// expandEnvVars 展开环境变量
|
||||
func expandEnvVars(config *PluginConfig) {
|
||||
// 展开Hook配置中的环境变量
|
||||
for i := range config.Hooks.Relay {
|
||||
for key, value := range config.Hooks.Relay[i].Config {
|
||||
if strValue, ok := value.(string); ok {
|
||||
config.Hooks.Relay[i].Config[key] = os.ExpandEnv(strValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 展开Middleware配置中的环境变量
|
||||
for i := range config.Middlewares {
|
||||
for key, value := range config.Middlewares[i].Config {
|
||||
if strValue, ok := value.(string); ok {
|
||||
config.Middlewares[i].Config[key] = os.ExpandEnv(strValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetGlobalPluginConfig 获取全局配置
|
||||
func GetGlobalPluginConfig() *PluginConfig {
|
||||
if globalPluginConfig == nil {
|
||||
configPath := os.Getenv("PLUGIN_CONFIG_PATH")
|
||||
if configPath == "" {
|
||||
configPath = "config/plugins.yaml"
|
||||
}
|
||||
|
||||
config, err := LoadPluginConfig(configPath)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Failed to load plugin config: %v", err))
|
||||
config = getDefaultConfig()
|
||||
}
|
||||
|
||||
globalPluginConfig = config
|
||||
}
|
||||
|
||||
return globalPluginConfig
|
||||
}
|
||||
|
||||
// SavePluginConfig 保存插件配置
|
||||
func SavePluginConfig(config *PluginConfig, configPath string) error {
|
||||
if configPath == "" {
|
||||
configPath = "config/plugins.yaml"
|
||||
}
|
||||
|
||||
// 确保目录存在
|
||||
dir := filepath.Dir(configPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create config directory: %w", err)
|
||||
}
|
||||
|
||||
// 序列化为YAML
|
||||
data, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
if err := ioutil.WriteFile(configPath, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write config file: %w", err)
|
||||
}
|
||||
|
||||
common.SysLog(fmt.Sprintf("Saved plugin config to: %s", configPath))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReloadPluginConfig 重新加载配置
|
||||
func ReloadPluginConfig() error {
|
||||
configPath := os.Getenv("PLUGIN_CONFIG_PATH")
|
||||
if configPath == "" {
|
||||
configPath = "config/plugins.yaml"
|
||||
}
|
||||
|
||||
config, err := LoadPluginConfig(configPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
globalPluginConfig = config
|
||||
common.SysLog("Plugin config reloaded")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
# New-API 插件配置
|
||||
# 此文件用于配置所有插件的启用状态和参数
|
||||
|
||||
# Channel插件配置
|
||||
channels:
|
||||
openai:
|
||||
enabled: true
|
||||
priority: 100
|
||||
|
||||
claude:
|
||||
enabled: true
|
||||
priority: 90
|
||||
|
||||
gemini:
|
||||
enabled: true
|
||||
priority: 85
|
||||
|
||||
# Middleware插件配置
|
||||
middlewares:
|
||||
- name: auth
|
||||
enabled: true
|
||||
priority: 100
|
||||
|
||||
- name: ratelimit
|
||||
enabled: true
|
||||
priority: 90
|
||||
config:
|
||||
default_rate: 60
|
||||
|
||||
# Hook插件配置
|
||||
hooks:
|
||||
# Relay层Hook
|
||||
relay:
|
||||
# 联网搜索插件
|
||||
- name: web_search
|
||||
enabled: false # 默认禁用,需要配置API key后启用
|
||||
priority: 50
|
||||
config:
|
||||
provider: google
|
||||
api_key: ${WEB_SEARCH_API_KEY} # 从环境变量读取
|
||||
|
||||
# 内容过滤插件
|
||||
- name: content_filter
|
||||
enabled: false # 默认禁用,需要配置后启用
|
||||
priority: 100 # 高优先级,最后执行
|
||||
config:
|
||||
filter_nsfw: true
|
||||
filter_political: false
|
||||
sensitive_words:
|
||||
- "敏感词1"
|
||||
- "敏感词2"
|
||||
|
||||
@@ -33,5 +33,6 @@ const (
|
||||
APITypeJimeng
|
||||
APITypeMoonshot
|
||||
APITypeSubmodel
|
||||
APITypeMiniMax
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ const (
|
||||
EndpointTypeJinaRerank EndpointType = "jina-rerank"
|
||||
EndpointTypeImageGeneration EndpointType = "image-generation"
|
||||
EndpointTypeEmbeddings EndpointType = "embeddings"
|
||||
EndpointTypeOpenAIVideo EndpointType = "openai-video"
|
||||
//EndpointTypeMidjourney EndpointType = "midjourney-proxy"
|
||||
//EndpointTypeSuno EndpointType = "suno-proxy"
|
||||
//EndpointTypeKling EndpointType = "kling"
|
||||
|
||||
@@ -13,3 +13,6 @@ var NotifyLimitCount int
|
||||
var NotificationLimitDurationMinute int
|
||||
var GenerateDefaultToken bool
|
||||
var ErrorLogEnabled bool
|
||||
|
||||
// temporary variable for sora patch, will be removed in future
|
||||
var TaskPricePatches []string
|
||||
|
||||
@@ -229,7 +229,7 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
||||
return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if channel == nil {
|
||||
return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
if newAPIError != nil {
|
||||
@@ -299,6 +299,9 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t
|
||||
userGroup := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
other := make(map[string]interface{})
|
||||
if c.Request != nil && c.Request.URL != nil {
|
||||
other["request_path"] = c.Request.URL.Path
|
||||
}
|
||||
other["error_type"] = err.GetErrorType()
|
||||
other["error_code"] = err.GetErrorCode()
|
||||
other["status_code"] = err.StatusCode
|
||||
|
||||
@@ -88,10 +88,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody)))
|
||||
|
||||
taskResult := &relaycommon.TaskInfo{}
|
||||
// try parse as New API response format
|
||||
var responseItems dto.TaskResponse[model.Task]
|
||||
if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
|
||||
if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
|
||||
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems))
|
||||
t := responseItems.Data
|
||||
taskResult.TaskID = t.TaskID
|
||||
taskResult.Status = string(t.Status)
|
||||
@@ -105,10 +108,19 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
task.Data = redactVideoResponseBody(responseBody)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult))
|
||||
|
||||
now := time.Now().Unix()
|
||||
if taskResult.Status == "" {
|
||||
return fmt.Errorf("task %s status is empty", taskId)
|
||||
//return fmt.Errorf("task %s status is empty", taskId)
|
||||
taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
|
||||
}
|
||||
|
||||
// 记录原本的状态,防止重复退款
|
||||
shouldRefund := false
|
||||
quota := task.Quota
|
||||
preStatus := task.Status
|
||||
|
||||
task.Status = model.TaskStatus(taskResult.Status)
|
||||
switch taskResult.Status {
|
||||
case model.TaskStatusSubmitted:
|
||||
@@ -137,14 +149,19 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
if modelName, ok := taskData["model"].(string); ok && modelName != "" {
|
||||
// 获取模型价格和倍率
|
||||
modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
|
||||
|
||||
// 只有配置了倍率(非固定价格)时才按 token 重新计费
|
||||
if hasRatioSetting && modelRatio > 0 {
|
||||
// 获取用户和组的倍率信息
|
||||
user, err := model.GetUserById(task.UserId, false)
|
||||
if err == nil {
|
||||
groupRatio := ratio_setting.GetGroupRatio(user.Group)
|
||||
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(user.Group, user.Group)
|
||||
group := task.Group
|
||||
if group == "" {
|
||||
user, err := model.GetUserById(task.UserId, false)
|
||||
if err == nil {
|
||||
group = user.Group
|
||||
}
|
||||
}
|
||||
if group != "" {
|
||||
groupRatio := ratio_setting.GetGroupRatio(group)
|
||||
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
|
||||
|
||||
var finalGroupRatio float64
|
||||
if hasUserGroupRatio {
|
||||
@@ -214,6 +231,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
}
|
||||
}
|
||||
case model.TaskStatusFailure:
|
||||
logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
@@ -221,13 +239,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
}
|
||||
task.FailReason = taskResult.Reason
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||
quota := task.Quota
|
||||
taskResult.Progress = "100%"
|
||||
if quota != 0 {
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
logger.LogError(ctx, "Failed to increase user quota: "+err.Error())
|
||||
if preStatus != model.TaskStatusFailure {
|
||||
shouldRefund = true
|
||||
} else {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
|
||||
@@ -237,6 +255,16 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
}
|
||||
if err := task.Update(); err != nil {
|
||||
common.SysLog("UpdateVideoTask task error: " + err.Error())
|
||||
shouldRefund = false
|
||||
}
|
||||
|
||||
if shouldRefund {
|
||||
// 任务失败且之前状态不是失败才退还额度,防止重复退还
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ChannelPlugin 定义Channel插件接口
|
||||
// 继承原有的Adaptor接口,增加插件元数据
|
||||
type ChannelPlugin interface {
|
||||
// 插件元数据
|
||||
Name() string
|
||||
Version() string
|
||||
Priority() int
|
||||
|
||||
// 原有Adaptor接口方法
|
||||
Init(info *relaycommon.RelayInfo)
|
||||
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
|
||||
ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
|
||||
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
|
||||
ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
|
||||
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
|
||||
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
|
||||
ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error)
|
||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError)
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
|
||||
ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error)
|
||||
}
|
||||
|
||||
// TaskChannelPlugin 定义Task类型的Channel插件接口
|
||||
type TaskChannelPlugin interface {
|
||||
// 插件元数据
|
||||
Name() string
|
||||
Version() string
|
||||
Priority() int
|
||||
|
||||
// 原有TaskAdaptor接口方法
|
||||
Init(info *relaycommon.RelayInfo)
|
||||
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError
|
||||
BuildRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
|
||||
BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error)
|
||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, err *dto.TaskError)
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
|
||||
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
|
||||
}
|
||||
|
||||
// ChannelConfig 插件配置
|
||||
type ChannelConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Priority int `yaml:"priority"`
|
||||
Config map[string]interface{} `yaml:"config"`
|
||||
}
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// HookContext Relay Hook执行上下文
|
||||
type HookContext struct {
|
||||
// Gin Context
|
||||
GinContext *gin.Context
|
||||
|
||||
// Request相关
|
||||
Request *http.Request
|
||||
RequestBody []byte
|
||||
|
||||
// Response相关
|
||||
Response *http.Response
|
||||
ResponseBody []byte
|
||||
|
||||
// Channel信息
|
||||
ChannelID int
|
||||
ChannelType int
|
||||
ChannelName string
|
||||
|
||||
// Model信息
|
||||
Model string
|
||||
OriginalModel string
|
||||
|
||||
// User信息
|
||||
UserID int
|
||||
TokenID int
|
||||
Group string
|
||||
|
||||
// 扩展数据(插件间共享)
|
||||
Data map[string]interface{}
|
||||
|
||||
// 错误信息
|
||||
Error error
|
||||
|
||||
// 是否跳过后续处理
|
||||
ShouldSkip bool
|
||||
}
|
||||
|
||||
// RelayHook Relay Hook接口
|
||||
type RelayHook interface {
|
||||
// 插件元数据
|
||||
Name() string
|
||||
Priority() int
|
||||
Enabled() bool
|
||||
|
||||
// 生命周期钩子
|
||||
// OnBeforeRequest 在请求发送到上游之前执行
|
||||
OnBeforeRequest(ctx *HookContext) error
|
||||
|
||||
// OnAfterResponse 在收到上游响应之后执行
|
||||
OnAfterResponse(ctx *HookContext) error
|
||||
|
||||
// OnError 在发生错误时执行
|
||||
OnError(ctx *HookContext, err error) error
|
||||
}
|
||||
|
||||
// RequestModifier 请求修改器接口
|
||||
// 实现此接口的Hook可以修改请求内容
|
||||
type RequestModifier interface {
|
||||
RelayHook
|
||||
ModifyRequest(ctx *HookContext, body io.Reader) (io.Reader, error)
|
||||
}
|
||||
|
||||
// ResponseProcessor 响应处理器接口
|
||||
// 实现此接口的Hook可以处理响应内容
|
||||
type ResponseProcessor interface {
|
||||
RelayHook
|
||||
ProcessResponse(ctx *HookContext, body []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// StreamProcessor 流式响应处理器接口
|
||||
// 实现此接口的Hook可以处理流式响应
|
||||
type StreamProcessor interface {
|
||||
RelayHook
|
||||
ProcessStreamChunk(ctx *HookContext, chunk []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// HookConfig Hook配置
|
||||
type HookConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Priority int `yaml:"priority"`
|
||||
Config map[string]interface{} `yaml:"config"`
|
||||
}
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// MiddlewarePlugin 中间件插件接口
|
||||
type MiddlewarePlugin interface {
|
||||
// 插件元数据
|
||||
Name() string
|
||||
Priority() int
|
||||
Enabled() bool
|
||||
|
||||
// 返回Gin中间件处理函数
|
||||
Handler() gin.HandlerFunc
|
||||
|
||||
// 初始化(可选)
|
||||
Initialize(config MiddlewareConfig) error
|
||||
}
|
||||
|
||||
// MiddlewareConfig 中间件配置
|
||||
type MiddlewareConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Priority int `yaml:"priority"`
|
||||
Config map[string]interface{} `yaml:"config"`
|
||||
}
|
||||
|
||||
// MiddlewareFactory 中间件工厂函数类型
|
||||
type MiddlewareFactory func(config MiddlewareConfig) (MiddlewarePlugin, error)
|
||||
|
||||
@@ -1,171 +0,0 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/QuantumNous/new-api/core/interfaces"
|
||||
)
|
||||
|
||||
var (
|
||||
// 全局Channel注册表
|
||||
channelRegistry = &ChannelRegistry{plugins: make(map[int]interfaces.ChannelPlugin)}
|
||||
channelRegistryLock sync.RWMutex
|
||||
|
||||
// 全局TaskChannel注册表
|
||||
taskChannelRegistry = &TaskChannelRegistry{plugins: make(map[string]interfaces.TaskChannelPlugin)}
|
||||
taskChannelRegistryLock sync.RWMutex
|
||||
)
|
||||
|
||||
// ChannelRegistry Channel插件注册中心
|
||||
type ChannelRegistry struct {
|
||||
plugins map[int]interfaces.ChannelPlugin // channelType -> plugin
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Register 注册Channel插件
|
||||
func (r *ChannelRegistry) Register(channelType int, plugin interfaces.ChannelPlugin) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.plugins[channelType]; exists {
|
||||
return fmt.Errorf("channel plugin for type %d already registered", channelType)
|
||||
}
|
||||
|
||||
r.plugins[channelType] = plugin
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get 获取Channel插件
|
||||
func (r *ChannelRegistry) Get(channelType int) (interfaces.ChannelPlugin, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
plugin, exists := r.plugins[channelType]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("channel plugin for type %d not found", channelType)
|
||||
}
|
||||
|
||||
return plugin, nil
|
||||
}
|
||||
|
||||
// List 列出所有已注册的Channel插件
|
||||
func (r *ChannelRegistry) List() []interfaces.ChannelPlugin {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
plugins := make([]interfaces.ChannelPlugin, 0, len(r.plugins))
|
||||
for _, plugin := range r.plugins {
|
||||
plugins = append(plugins, plugin)
|
||||
}
|
||||
|
||||
return plugins
|
||||
}
|
||||
|
||||
// Has 检查是否存在指定的Channel插件
|
||||
func (r *ChannelRegistry) Has(channelType int) bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
_, exists := r.plugins[channelType]
|
||||
return exists
|
||||
}
|
||||
|
||||
// TaskChannelRegistry TaskChannel插件注册中心
|
||||
type TaskChannelRegistry struct {
|
||||
plugins map[string]interfaces.TaskChannelPlugin // platform -> plugin
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Register 注册TaskChannel插件
|
||||
func (r *TaskChannelRegistry) Register(platform string, plugin interfaces.TaskChannelPlugin) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.plugins[platform]; exists {
|
||||
return fmt.Errorf("task channel plugin for platform %s already registered", platform)
|
||||
}
|
||||
|
||||
r.plugins[platform] = plugin
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get 获取TaskChannel插件
|
||||
func (r *TaskChannelRegistry) Get(platform string) (interfaces.TaskChannelPlugin, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
plugin, exists := r.plugins[platform]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("task channel plugin for platform %s not found", platform)
|
||||
}
|
||||
|
||||
return plugin, nil
|
||||
}
|
||||
|
||||
// List 列出所有已注册的TaskChannel插件
|
||||
func (r *TaskChannelRegistry) List() []interfaces.TaskChannelPlugin {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
plugins := make([]interfaces.TaskChannelPlugin, 0, len(r.plugins))
|
||||
for _, plugin := range r.plugins {
|
||||
plugins = append(plugins, plugin)
|
||||
}
|
||||
|
||||
return plugins
|
||||
}
|
||||
|
||||
// 全局函数 - Channel Registry
|
||||
|
||||
// RegisterChannel 注册Channel插件
|
||||
func RegisterChannel(channelType int, plugin interfaces.ChannelPlugin) error {
|
||||
channelRegistryLock.Lock()
|
||||
defer channelRegistryLock.Unlock()
|
||||
return channelRegistry.Register(channelType, plugin)
|
||||
}
|
||||
|
||||
// GetChannel 获取Channel插件
|
||||
func GetChannel(channelType int) (interfaces.ChannelPlugin, error) {
|
||||
channelRegistryLock.RLock()
|
||||
defer channelRegistryLock.RUnlock()
|
||||
return channelRegistry.Get(channelType)
|
||||
}
|
||||
|
||||
// ListChannels 列出所有Channel插件
|
||||
func ListChannels() []interfaces.ChannelPlugin {
|
||||
channelRegistryLock.RLock()
|
||||
defer channelRegistryLock.RUnlock()
|
||||
return channelRegistry.List()
|
||||
}
|
||||
|
||||
// HasChannel 检查是否存在指定的Channel插件
|
||||
func HasChannel(channelType int) bool {
|
||||
channelRegistryLock.RLock()
|
||||
defer channelRegistryLock.RUnlock()
|
||||
return channelRegistry.Has(channelType)
|
||||
}
|
||||
|
||||
// 全局函数 - TaskChannel Registry
|
||||
|
||||
// RegisterTaskChannel 注册TaskChannel插件
|
||||
func RegisterTaskChannel(platform string, plugin interfaces.TaskChannelPlugin) error {
|
||||
taskChannelRegistryLock.Lock()
|
||||
defer taskChannelRegistryLock.Unlock()
|
||||
return taskChannelRegistry.Register(platform, plugin)
|
||||
}
|
||||
|
||||
// GetTaskChannel 获取TaskChannel插件
|
||||
func GetTaskChannel(platform string) (interfaces.TaskChannelPlugin, error) {
|
||||
taskChannelRegistryLock.RLock()
|
||||
defer taskChannelRegistryLock.RUnlock()
|
||||
return taskChannelRegistry.Get(platform)
|
||||
}
|
||||
|
||||
// ListTaskChannels 列出所有TaskChannel插件
|
||||
func ListTaskChannels() []interfaces.TaskChannelPlugin {
|
||||
taskChannelRegistryLock.RLock()
|
||||
defer taskChannelRegistryLock.RUnlock()
|
||||
return taskChannelRegistry.List()
|
||||
}
|
||||
|
||||
@@ -1,183 +0,0 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/QuantumNous/new-api/core/interfaces"
|
||||
)
|
||||
|
||||
var (
|
||||
// 全局Hook注册表
|
||||
hookRegistry = &HookRegistry{hooks: make([]interfaces.RelayHook, 0)}
|
||||
hookRegistryLock sync.RWMutex
|
||||
)
|
||||
|
||||
// HookRegistry Hook插件注册中心
|
||||
type HookRegistry struct {
|
||||
hooks []interfaces.RelayHook
|
||||
sorted bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Register 注册Hook插件
|
||||
func (r *HookRegistry) Register(hook interfaces.RelayHook) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// 检查是否已存在同名Hook
|
||||
for _, h := range r.hooks {
|
||||
if h.Name() == hook.Name() {
|
||||
return fmt.Errorf("hook %s already registered", hook.Name())
|
||||
}
|
||||
}
|
||||
|
||||
r.hooks = append(r.hooks, hook)
|
||||
r.sorted = false // 标记需要重新排序
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get 获取指定名称的Hook插件
|
||||
func (r *HookRegistry) Get(name string) (interfaces.RelayHook, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
for _, hook := range r.hooks {
|
||||
if hook.Name() == name {
|
||||
return hook, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("hook %s not found", name)
|
||||
}
|
||||
|
||||
// List 列出所有已注册且启用的Hook插件(按优先级排序)
|
||||
func (r *HookRegistry) List() []interfaces.RelayHook {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// 如果未排序,先排序
|
||||
if !r.sorted {
|
||||
r.sortHooks()
|
||||
}
|
||||
|
||||
// 只返回启用的hooks
|
||||
enabledHooks := make([]interfaces.RelayHook, 0)
|
||||
for _, hook := range r.hooks {
|
||||
if hook.Enabled() {
|
||||
enabledHooks = append(enabledHooks, hook)
|
||||
}
|
||||
}
|
||||
|
||||
return enabledHooks
|
||||
}
|
||||
|
||||
// ListAll 列出所有已注册的Hook插件(包括未启用的)
|
||||
func (r *HookRegistry) ListAll() []interfaces.RelayHook {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
hooks := make([]interfaces.RelayHook, len(r.hooks))
|
||||
copy(hooks, r.hooks)
|
||||
|
||||
return hooks
|
||||
}
|
||||
|
||||
// sortHooks 按优先级排序hooks(优先级数字越大越先执行)
|
||||
func (r *HookRegistry) sortHooks() {
|
||||
sort.SliceStable(r.hooks, func(i, j int) bool {
|
||||
return r.hooks[i].Priority() > r.hooks[j].Priority()
|
||||
})
|
||||
r.sorted = true
|
||||
}
|
||||
|
||||
// Has 检查是否存在指定的Hook插件
|
||||
func (r *HookRegistry) Has(name string) bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
for _, hook := range r.hooks {
|
||||
if hook.Name() == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Count 返回已注册的Hook数量
|
||||
func (r *HookRegistry) Count() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
return len(r.hooks)
|
||||
}
|
||||
|
||||
// EnabledCount 返回已启用的Hook数量
|
||||
func (r *HookRegistry) EnabledCount() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
count := 0
|
||||
for _, hook := range r.hooks {
|
||||
if hook.Enabled() {
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
// 全局函数
|
||||
|
||||
// RegisterHook 注册Hook插件
|
||||
func RegisterHook(hook interfaces.RelayHook) error {
|
||||
hookRegistryLock.Lock()
|
||||
defer hookRegistryLock.Unlock()
|
||||
return hookRegistry.Register(hook)
|
||||
}
|
||||
|
||||
// GetHook 获取Hook插件
|
||||
func GetHook(name string) (interfaces.RelayHook, error) {
|
||||
hookRegistryLock.RLock()
|
||||
defer hookRegistryLock.RUnlock()
|
||||
return hookRegistry.Get(name)
|
||||
}
|
||||
|
||||
// ListHooks 列出所有已启用的Hook插件
|
||||
func ListHooks() []interfaces.RelayHook {
|
||||
hookRegistryLock.RLock()
|
||||
defer hookRegistryLock.RUnlock()
|
||||
return hookRegistry.List()
|
||||
}
|
||||
|
||||
// ListAllHooks 列出所有Hook插件
|
||||
func ListAllHooks() []interfaces.RelayHook {
|
||||
hookRegistryLock.RLock()
|
||||
defer hookRegistryLock.RUnlock()
|
||||
return hookRegistry.ListAll()
|
||||
}
|
||||
|
||||
// HasHook 检查是否存在指定的Hook插件
|
||||
func HasHook(name string) bool {
|
||||
hookRegistryLock.RLock()
|
||||
defer hookRegistryLock.RUnlock()
|
||||
return hookRegistry.Has(name)
|
||||
}
|
||||
|
||||
// HookCount 返回已注册的Hook数量
|
||||
func HookCount() int {
|
||||
hookRegistryLock.RLock()
|
||||
defer hookRegistryLock.RUnlock()
|
||||
return hookRegistry.Count()
|
||||
}
|
||||
|
||||
// EnabledHookCount 返回已启用的Hook数量
|
||||
func EnabledHookCount() int {
|
||||
hookRegistryLock.RLock()
|
||||
defer hookRegistryLock.RUnlock()
|
||||
return hookRegistry.EnabledCount()
|
||||
}
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/QuantumNous/new-api/core/interfaces"
|
||||
)
|
||||
|
||||
var (
|
||||
// 全局Middleware注册表
|
||||
middlewareRegistry = &MiddlewareRegistry{plugins: make(map[string]interfaces.MiddlewarePlugin)}
|
||||
middlewareRegistryLock sync.RWMutex
|
||||
)
|
||||
|
||||
// MiddlewareRegistry 中间件插件注册中心
|
||||
type MiddlewareRegistry struct {
|
||||
plugins map[string]interfaces.MiddlewarePlugin
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Register 注册Middleware插件
|
||||
func (r *MiddlewareRegistry) Register(plugin interfaces.MiddlewarePlugin) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
name := plugin.Name()
|
||||
if _, exists := r.plugins[name]; exists {
|
||||
return fmt.Errorf("middleware plugin %s already registered", name)
|
||||
}
|
||||
|
||||
r.plugins[name] = plugin
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get 获取Middleware插件
|
||||
func (r *MiddlewareRegistry) Get(name string) (interfaces.MiddlewarePlugin, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
plugin, exists := r.plugins[name]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("middleware plugin %s not found", name)
|
||||
}
|
||||
|
||||
return plugin, nil
|
||||
}
|
||||
|
||||
// List 列出所有已注册的Middleware插件(按优先级排序)
|
||||
func (r *MiddlewareRegistry) List() []interfaces.MiddlewarePlugin {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
plugins := make([]interfaces.MiddlewarePlugin, 0, len(r.plugins))
|
||||
for _, plugin := range r.plugins {
|
||||
plugins = append(plugins, plugin)
|
||||
}
|
||||
|
||||
// 按优先级排序(优先级数字越大越先执行)
|
||||
sort.SliceStable(plugins, func(i, j int) bool {
|
||||
return plugins[i].Priority() > plugins[j].Priority()
|
||||
})
|
||||
|
||||
return plugins
|
||||
}
|
||||
|
||||
// ListEnabled 列出所有已启用的Middleware插件(按优先级排序)
|
||||
func (r *MiddlewareRegistry) ListEnabled() []interfaces.MiddlewarePlugin {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
plugins := make([]interfaces.MiddlewarePlugin, 0, len(r.plugins))
|
||||
for _, plugin := range r.plugins {
|
||||
if plugin.Enabled() {
|
||||
plugins = append(plugins, plugin)
|
||||
}
|
||||
}
|
||||
|
||||
// 按优先级排序
|
||||
sort.SliceStable(plugins, func(i, j int) bool {
|
||||
return plugins[i].Priority() > plugins[j].Priority()
|
||||
})
|
||||
|
||||
return plugins
|
||||
}
|
||||
|
||||
// Has 检查是否存在指定的Middleware插件
|
||||
func (r *MiddlewareRegistry) Has(name string) bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
_, exists := r.plugins[name]
|
||||
return exists
|
||||
}
|
||||
|
||||
// 全局函数
|
||||
|
||||
// RegisterMiddleware 注册Middleware插件
|
||||
func RegisterMiddleware(plugin interfaces.MiddlewarePlugin) error {
|
||||
middlewareRegistryLock.Lock()
|
||||
defer middlewareRegistryLock.Unlock()
|
||||
return middlewareRegistry.Register(plugin)
|
||||
}
|
||||
|
||||
// GetMiddleware 获取Middleware插件
|
||||
func GetMiddleware(name string) (interfaces.MiddlewarePlugin, error) {
|
||||
middlewareRegistryLock.RLock()
|
||||
defer middlewareRegistryLock.RUnlock()
|
||||
return middlewareRegistry.Get(name)
|
||||
}
|
||||
|
||||
// ListMiddlewares 列出所有Middleware插件
|
||||
func ListMiddlewares() []interfaces.MiddlewarePlugin {
|
||||
middlewareRegistryLock.RLock()
|
||||
defer middlewareRegistryLock.RUnlock()
|
||||
return middlewareRegistry.List()
|
||||
}
|
||||
|
||||
// ListEnabledMiddlewares 列出所有已启用的Middleware插件
|
||||
func ListEnabledMiddlewares() []interfaces.MiddlewarePlugin {
|
||||
middlewareRegistryLock.RLock()
|
||||
defer middlewareRegistryLock.RUnlock()
|
||||
return middlewareRegistry.ListEnabled()
|
||||
}
|
||||
|
||||
// HasMiddleware 检查是否存在指定的Middleware插件
|
||||
func HasMiddleware(name string) bool {
|
||||
middlewareRegistryLock.RLock()
|
||||
defer middlewareRegistryLock.RUnlock()
|
||||
return middlewareRegistry.Has(name)
|
||||
}
|
||||
|
||||
@@ -1,116 +0,0 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/core/interfaces"
|
||||
)
|
||||
|
||||
// Mock Hook实现
|
||||
type mockHook struct {
|
||||
name string
|
||||
priority int
|
||||
enabled bool
|
||||
}
|
||||
|
||||
func (m *mockHook) Name() string { return m.name }
|
||||
func (m *mockHook) Priority() int { return m.priority }
|
||||
func (m *mockHook) Enabled() bool { return m.enabled }
|
||||
func (m *mockHook) OnBeforeRequest(ctx *interfaces.HookContext) error { return nil }
|
||||
func (m *mockHook) OnAfterResponse(ctx *interfaces.HookContext) error { return nil }
|
||||
func (m *mockHook) OnError(ctx *interfaces.HookContext, err error) error { return nil }
|
||||
|
||||
func TestHookRegistry(t *testing.T) {
|
||||
// 创建新的注册表(用于测试)
|
||||
registry := &HookRegistry{hooks: make([]interfaces.RelayHook, 0)}
|
||||
|
||||
// 测试注册Hook
|
||||
hook1 := &mockHook{name: "test_hook_1", priority: 100, enabled: true}
|
||||
hook2 := &mockHook{name: "test_hook_2", priority: 50, enabled: true}
|
||||
hook3 := &mockHook{name: "test_hook_3", priority: 75, enabled: false}
|
||||
|
||||
if err := registry.Register(hook1); err != nil {
|
||||
t.Errorf("Failed to register hook1: %v", err)
|
||||
}
|
||||
|
||||
if err := registry.Register(hook2); err != nil {
|
||||
t.Errorf("Failed to register hook2: %v", err)
|
||||
}
|
||||
|
||||
if err := registry.Register(hook3); err != nil {
|
||||
t.Errorf("Failed to register hook3: %v", err)
|
||||
}
|
||||
|
||||
// 测试重复注册
|
||||
if err := registry.Register(hook1); err == nil {
|
||||
t.Error("Expected error when registering duplicate hook")
|
||||
}
|
||||
|
||||
// 测试获取Hook
|
||||
if hook, err := registry.Get("test_hook_1"); err != nil {
|
||||
t.Errorf("Failed to get hook: %v", err)
|
||||
} else if hook.Name() != "test_hook_1" {
|
||||
t.Errorf("Got wrong hook: %s", hook.Name())
|
||||
}
|
||||
|
||||
// 测试不存在的Hook
|
||||
if _, err := registry.Get("nonexistent"); err == nil {
|
||||
t.Error("Expected error when getting nonexistent hook")
|
||||
}
|
||||
|
||||
// 测试List(应该只返回enabled的hooks)
|
||||
hooks := registry.List()
|
||||
if len(hooks) != 2 {
|
||||
t.Errorf("Expected 2 enabled hooks, got %d", len(hooks))
|
||||
}
|
||||
|
||||
// 测试优先级排序(100应该在50之前)
|
||||
if hooks[0].Priority() != 100 {
|
||||
t.Errorf("Expected first hook to have priority 100, got %d", hooks[0].Priority())
|
||||
}
|
||||
|
||||
// 测试Count
|
||||
if count := registry.Count(); count != 3 {
|
||||
t.Errorf("Expected count 3, got %d", count)
|
||||
}
|
||||
|
||||
// 测试EnabledCount
|
||||
if count := registry.EnabledCount(); count != 2 {
|
||||
t.Errorf("Expected enabled count 2, got %d", count)
|
||||
}
|
||||
|
||||
// 测试Has
|
||||
if !registry.Has("test_hook_1") {
|
||||
t.Error("Expected to find test_hook_1")
|
||||
}
|
||||
|
||||
if registry.Has("nonexistent") {
|
||||
t.Error("Should not find nonexistent hook")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelRegistry(t *testing.T) {
|
||||
// 这里可以添加Channel Registry的测试
|
||||
// 但需要mock ChannelPlugin接口,比较复杂
|
||||
// 作为示例,我们只测试基本功能
|
||||
|
||||
registry := &ChannelRegistry{plugins: make(map[int]interfaces.ChannelPlugin)}
|
||||
|
||||
// 测试Has方法
|
||||
if registry.Has(1) {
|
||||
t.Error("Should not find channel type 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareRegistry(t *testing.T) {
|
||||
// Middleware Registry测试
|
||||
// 需要mock MiddlewarePlugin接口
|
||||
|
||||
registry := &MiddlewareRegistry{plugins: make(map[string]interfaces.MiddlewarePlugin)}
|
||||
|
||||
// 测试Has方法
|
||||
if registry.Has("test_middleware") {
|
||||
t.Error("Should not find test_middleware")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,11 +30,14 @@ services:
|
||||
# - SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service, uncomment if using MySQL
|
||||
- REDIS_CONN_STRING=redis://redis
|
||||
- TZ=Asia/Shanghai
|
||||
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
|
||||
- BATCH_UPDATE_ENABLED=true # 是否启用批量更新 batch update enabled
|
||||
# - STREAMING_TIMEOUT=300 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值 Streaming timeout in seconds, default is 120s. Increase if experiencing empty completions
|
||||
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!! multi-node deployment, set this to a random string!!!!!!!
|
||||
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录 (Whether to enable error log recording)
|
||||
- BATCH_UPDATE_ENABLED=true # 是否启用批量更新 (Whether to enable batch update)
|
||||
# - STREAMING_TIMEOUT=300 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值 (Streaming timeout in seconds, default is 120s. Increase if experiencing empty completions)
|
||||
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!! (multi-node deployment, set this to a random string!!!!!!!)
|
||||
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
||||
# - GOOGLE_ANALYTICS_ID=G-XXXXXXXXXX # Google Analytics 的测量 ID (Google Analytics Measurement ID)
|
||||
# - UMAMI_WEBSITE_ID=xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx # Umami 网站 ID (Umami Website ID)
|
||||
# - UMAMI_SCRIPT_URL=https://analytics.umami.is/script.js # Umami 脚本 URL,默认为官方地址 (Umami Script URL, defaults to official URL)
|
||||
|
||||
depends_on:
|
||||
- redis
|
||||
|
||||
@@ -1,359 +0,0 @@
|
||||
# New-API 插件化架构说明
|
||||
|
||||
## 完整目录结构
|
||||
|
||||
```
|
||||
new-api-2/
|
||||
├── core/ # 核心层(高性能,不可插件化)
|
||||
│ ├── interfaces/ # 插件接口定义
|
||||
│ │ ├── channel.go # Channel插件接口
|
||||
│ │ ├── hook.go # Hook插件接口
|
||||
│ │ └── middleware.go # Middleware插件接口
|
||||
│ └── registry/ # 插件注册中心
|
||||
│ ├── channel_registry.go # Channel注册器(线程安全)
|
||||
│ ├── hook_registry.go # Hook注册器(优先级排序)
|
||||
│ └── middleware_registry.go # Middleware注册器
|
||||
│
|
||||
├── plugins/ # 🔵 Tier 1: 编译时插件(已实施)
|
||||
│ ├── channels/ # Channel插件
|
||||
│ │ ├── base_plugin.go # 基础插件包装器
|
||||
│ │ └── registry.go # 自动注册31个AI Provider
|
||||
│ └── hooks/ # Hook插件
|
||||
│ ├── web_search/ # 联网搜索Hook
|
||||
│ │ ├── web_search_hook.go
|
||||
│ │ └── init.go
|
||||
│ └── content_filter/ # 内容过滤Hook
|
||||
│ ├── content_filter_hook.go
|
||||
│ └── init.go
|
||||
│
|
||||
├── marketplace/ # 🟣 Tier 2: 运行时插件(待实施,Phase 2)
|
||||
│ ├── loader/ # go-plugin加载器
|
||||
│ │ ├── plugin_client.go # 插件客户端
|
||||
│ │ ├── plugin_server.go # 插件服务器
|
||||
│ │ └── lifecycle.go # 生命周期管理
|
||||
│ ├── manager/ # 插件管理器
|
||||
│ │ ├── installer.go # 安装/卸载
|
||||
│ │ ├── updater.go # 版本更新
|
||||
│ │ └── registry.go # 插件注册表
|
||||
│ ├── security/ # 安全模块
|
||||
│ │ ├── signature.go # Ed25519签名验证
|
||||
│ │ ├── checksum.go # SHA256校验
|
||||
│ │ └── sandbox.go # 沙箱配置
|
||||
│ ├── store/ # 插件商店客户端
|
||||
│ │ ├── client.go # 商店API客户端
|
||||
│ │ ├── search.go # 搜索功能
|
||||
│ │ └── download.go # 下载管理
|
||||
│ └── proto/ # gRPC协议定义
|
||||
│ ├── hook.proto # Hook插件协议
|
||||
│ ├── channel.proto # Channel插件协议
|
||||
│ └── common.proto # 通用消息
|
||||
│
|
||||
├── plugins_external/ # 第三方插件安装目录
|
||||
│ ├── installed/ # 已安装插件
|
||||
│ │ ├── awesome-hook-v1.0.0/
|
||||
│ │ ├── custom-llm-v2.1.0/
|
||||
│ │ └── slack-notify-v1.5.0/
|
||||
│ ├── cache/ # 下载缓存
|
||||
│ └── temp/ # 临时文件
|
||||
│
|
||||
├── relay/ # Relay层
|
||||
│ ├── hooks/ # Hook执行链
|
||||
│ │ ├── chain.go # Hook链管理器
|
||||
│ │ ├── context.go # Hook上下文
|
||||
│ │ └── context_builder.go # 上下文构建器
|
||||
│ └── relay_adaptor.go # Channel适配器(优先从Registry获取)
|
||||
│
|
||||
├── config/ # 配置系统
|
||||
│ ├── plugins.yaml # 插件配置(Tier 1 + Tier 2)
|
||||
│ └── plugin_config.go # 配置加载器(支持环境变量)
|
||||
│
|
||||
└── (其他现有目录保持不变)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 完整架构图
|
||||
|
||||
### 系统架构总览
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "🌐 API层"
|
||||
Client[客户端请求]
|
||||
end
|
||||
|
||||
subgraph "🔐 中间件层"
|
||||
Auth[认证中间件]
|
||||
RateLimit[限流中间件]
|
||||
Cache[缓存中间件]
|
||||
end
|
||||
|
||||
subgraph "🎯 核心层 Core"
|
||||
Registry[插件注册中心]
|
||||
ChannelReg[Channel Registry]
|
||||
HookReg[Hook Registry]
|
||||
MidReg[Middleware Registry]
|
||||
|
||||
Registry --> ChannelReg
|
||||
Registry --> HookReg
|
||||
Registry --> MidReg
|
||||
end
|
||||
|
||||
subgraph "🔵 Tier 1: 编译时插件(已实施)"
|
||||
direction TB
|
||||
|
||||
Channels[31个 Channel Plugins]
|
||||
OpenAI[OpenAI]
|
||||
Claude[Claude]
|
||||
Gemini[Gemini]
|
||||
Others[其他28个...]
|
||||
|
||||
Channels --> OpenAI
|
||||
Channels --> Claude
|
||||
Channels --> Gemini
|
||||
Channels --> Others
|
||||
|
||||
Hooks[Hook Plugins]
|
||||
WebSearch[Web Search Hook]
|
||||
ContentFilter[Content Filter Hook]
|
||||
|
||||
Hooks --> WebSearch
|
||||
Hooks --> ContentFilter
|
||||
end
|
||||
|
||||
subgraph "🟣 Tier 2: 运行时插件(待实施)"
|
||||
direction TB
|
||||
|
||||
Marketplace[🏪 Plugin Marketplace]
|
||||
ExtHook[External Hooks<br/>Python/Go/Node.js]
|
||||
ExtChannel[External Channels<br/>小众AI提供商]
|
||||
ExtMid[External Middleware<br/>企业集成]
|
||||
ExtUI[UI Extensions<br/>自定义仪表板]
|
||||
|
||||
Marketplace --> ExtHook
|
||||
Marketplace --> ExtChannel
|
||||
Marketplace --> ExtMid
|
||||
Marketplace --> ExtUI
|
||||
end
|
||||
|
||||
subgraph "⚡ Relay执行流程"
|
||||
direction LR
|
||||
HookChain[Hook Chain]
|
||||
BeforeHook[OnBeforeRequest]
|
||||
ChannelAdaptor[Channel Adaptor]
|
||||
AfterHook[OnAfterResponse]
|
||||
|
||||
HookChain --> BeforeHook
|
||||
BeforeHook --> ChannelAdaptor
|
||||
ChannelAdaptor --> AfterHook
|
||||
end
|
||||
|
||||
subgraph "🌍 上游服务"
|
||||
Upstream[AI Provider APIs]
|
||||
end
|
||||
|
||||
Client --> Auth
|
||||
Auth --> RateLimit
|
||||
RateLimit --> Cache
|
||||
Cache --> Registry
|
||||
|
||||
Channels --> ChannelReg
|
||||
Hooks --> HookReg
|
||||
|
||||
Registry --> HookChain
|
||||
HookChain --> Upstream
|
||||
Upstream --> HookChain
|
||||
|
||||
Registry -.gRPC/RPC.-> ExtHook
|
||||
Registry -.gRPC/RPC.-> ExtChannel
|
||||
Registry -.gRPC/RPC.-> ExtMid
|
||||
|
||||
style Marketplace fill:#f9f,stroke:#333,stroke-width:4px
|
||||
style Registry fill:#bbf,stroke:#333,stroke-width:4px
|
||||
style Channels fill:#bfb,stroke:#333,stroke-width:2px
|
||||
style Hooks fill:#bfb,stroke:#333,stroke-width:2px
|
||||
```
|
||||
|
||||
### 双层插件系统架构
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
subgraph "🔵 Tier 1: 编译时插件"
|
||||
T1[性能: 100%<br/>语言: Go only<br/>部署: 编译到二进制]
|
||||
T1Chan[31 Channels]
|
||||
T1Hook[2 Hooks]
|
||||
|
||||
T1 --> T1Chan
|
||||
T1 --> T1Hook
|
||||
end
|
||||
|
||||
subgraph "🟣 Tier 2: 运行时插件"
|
||||
T2[性能: 90-95%<br/>语言: Go/Python/Node.js<br/>部署: 独立进程]
|
||||
T2Hook[External Hooks]
|
||||
T2Chan[External Channels]
|
||||
T2Mid[External Middleware]
|
||||
T2UI[UI Extensions]
|
||||
|
||||
T2 --> T2Hook
|
||||
T2 --> T2Chan
|
||||
T2 --> T2Mid
|
||||
T2 --> T2UI
|
||||
end
|
||||
|
||||
T1 -.进程内调用.-> Core[Core System]
|
||||
T2 -.gRPC/RPC.-> Core
|
||||
|
||||
style T1 fill:#bfb,stroke:#333,stroke-width:3px
|
||||
style T2 fill:#f9f,stroke:#333,stroke-width:3px
|
||||
style Core fill:#bbf,stroke:#333,stroke-width:3px
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 核心要点说明
|
||||
|
||||
### 1. 双层插件架构
|
||||
|
||||
| 层级 | 技术方案 | 性能 | 适用场景 | 开发语言 |
|
||||
|------|---------|------|---------|---------|
|
||||
| **Tier 1<br/>编译时插件** | 编译时链接 | 100%<br/>零损失 | • 核心Channel(OpenAI等)<br/>• 内置Hook<br/>• 高频调用路径 | Go only |
|
||||
| **Tier 2<br/>运行时插件** | go-plugin<br/>gRPC | 90-95%<br/>5-10%开销 | • 第三方扩展<br/>• 企业定制<br/>• 多语言集成 | Go/Python/<br/>Node.js/Rust |
|
||||
|
||||
### 2. 核心组件
|
||||
|
||||
#### Core层(核心引擎)
|
||||
- **interfaces/**: 定义ChannelPlugin、RelayHook、MiddlewarePlugin接口
|
||||
- **registry/**: 线程安全的插件注册中心,支持O(1)查找、优先级排序
|
||||
|
||||
#### Relay Hook链
|
||||
- **执行流程**: OnBeforeRequest → Channel.DoRequest → OnAfterResponse
|
||||
- **特性**: 优先级排序、短路机制、数据共享(HookContext.Data)
|
||||
- **应用场景**: 联网搜索、内容过滤、日志增强、缓存策略
|
||||
|
||||
### 3. Tier 1: 编译时插件(已实施 ✅)
|
||||
|
||||
**特点**:
|
||||
- 零性能损失,编译后与硬编码无差异
|
||||
- init()函数自动注册到Registry
|
||||
- YAML配置启用/禁用
|
||||
|
||||
**已实现**:
|
||||
- ✅ 31个Channel插件(OpenAI、Claude、Gemini等)
|
||||
- ✅ 2个Hook插件(web_search、content_filter)
|
||||
- ✅ Hook执行链
|
||||
- ✅ 配置系统(支持环境变量展开)
|
||||
|
||||
### 4. Tier 2: 运行时插件(待实施 🚧)
|
||||
|
||||
**基于**: [hashicorp/go-plugin](https://github.com/hashicorp/go-plugin)(Vault/Terraform使用)
|
||||
|
||||
**优势**:
|
||||
- ✅ 进程隔离(第三方代码崩溃不影响主程序)
|
||||
- ✅ 多语言支持(gRPC协议)
|
||||
- ✅ 热插拔(无需重启)
|
||||
- ✅ 安全验证(Ed25519签名 + SHA256校验 + TLS加密)
|
||||
- ✅ 独立分发(插件商店)
|
||||
|
||||
**适用场景**:
|
||||
- 第三方开发者扩展
|
||||
- 企业定制业务逻辑
|
||||
- Python ML模型集成
|
||||
- 第三方服务集成(Slack/钉钉/企业微信)
|
||||
- UI扩展
|
||||
|
||||
### 5. 安全机制
|
||||
|
||||
**Tier 1(编译时)**:
|
||||
- 内部代码审查
|
||||
- 编译期类型安全
|
||||
|
||||
**Tier 2(运行时)**:
|
||||
- Ed25519签名验证
|
||||
- SHA256校验和
|
||||
- gRPC TLS加密
|
||||
- 进程资源限制(内存/CPU)
|
||||
- 插件商店审核机制
|
||||
- 可信发布者白名单
|
||||
|
||||
### 6. 配置系统
|
||||
|
||||
**单一配置文件**: `config/plugins.yaml`
|
||||
|
||||
```yaml
|
||||
# Tier 1: 编译时插件
|
||||
plugins:
|
||||
hooks:
|
||||
- name: web_search
|
||||
enabled: false
|
||||
priority: 50
|
||||
config:
|
||||
api_key: ${WEB_SEARCH_API_KEY}
|
||||
|
||||
# Tier 2: 运行时插件(待实施)
|
||||
external_plugins:
|
||||
enabled: true
|
||||
hooks:
|
||||
- name: awesome_hook
|
||||
binary: awesome-hook-v1.0.0/awesome-hook
|
||||
checksum: sha256:abc123...
|
||||
|
||||
# 插件商店
|
||||
marketplace:
|
||||
enabled: true
|
||||
api_url: https://plugins.new-api.com
|
||||
```
|
||||
|
||||
### 7. 性能对比
|
||||
|
||||
| 场景 | Tier 1 | Tier 2 | RPC开销 |
|
||||
|------|--------|--------|--------|
|
||||
| 核心Channel | 100% | N/A | 0% |
|
||||
| 内置Hook | 100% | N/A | 0% |
|
||||
| 第三方Hook | N/A | 92-95% | 5-8% |
|
||||
| Python插件 | N/A | 88-92% | 8-12% |
|
||||
|
||||
### 8. 实施路线图
|
||||
|
||||
#### Phase 1: 编译时插件系统 ✅ 已完成
|
||||
- Core Registry + Hook Chain
|
||||
- 31个Channel插件 + 2个Hook示例
|
||||
- YAML配置系统
|
||||
|
||||
#### Phase 2: go-plugin基础
|
||||
- protobuf协议定义
|
||||
- PluginLoader实现
|
||||
- 签名验证系统
|
||||
- Python/Go SDK
|
||||
|
||||
#### Phase 3: 插件商店
|
||||
- 商店后端API
|
||||
- Web UI(搜索、安装、管理)
|
||||
- CLI工具
|
||||
- 多语言SDK
|
||||
|
||||
### 9. 扩展示例
|
||||
|
||||
**新增Tier 1插件(编译时)**:
|
||||
```go
|
||||
// 1. 实现接口
|
||||
type MyHook struct{}
|
||||
func (h *MyHook) OnBeforeRequest(ctx *HookContext) error { /*...*/ }
|
||||
|
||||
// 2. 注册
|
||||
func init() { registry.RegisterHook(&MyHook{}) }
|
||||
|
||||
// 3. 导入到main.go
|
||||
import _ "github.com/xxx/plugins/hooks/my_hook"
|
||||
```
|
||||
|
||||
**新增Tier 2插件(运行时)**:
|
||||
```python
|
||||
# external-plugin/my_hook.py
|
||||
from new_api_plugin_sdk import HookPlugin, serve
|
||||
|
||||
class MyHook(HookPlugin):
|
||||
def on_before_request(self, ctx):
|
||||
return {"modified_body": ctx.request_body}
|
||||
|
||||
serve(MyHook())
|
||||
```
|
||||
15
dto/audio.go
15
dto/audio.go
@@ -1,17 +1,22 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type AudioRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input string `json:"input"`
|
||||
Voice string `json:"voice"`
|
||||
Speed float64 `json:"speed,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Input string `json:"input"`
|
||||
Voice string `json:"voice"`
|
||||
Instructions string `json:"instructions,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Speed float64 `json:"speed,omitempty"`
|
||||
StreamFormat string `json:"stream_format,omitempty"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
|
||||
@@ -16,6 +16,13 @@ const (
|
||||
VertexKeyTypeAPIKey VertexKeyType = "api_key"
|
||||
)
|
||||
|
||||
type AwsKeyType string
|
||||
|
||||
const (
|
||||
AwsKeyTypeAKSK AwsKeyType = "ak_sk" // 默认
|
||||
AwsKeyTypeApiKey AwsKeyType = "api_key"
|
||||
)
|
||||
|
||||
type ChannelOtherSettings struct {
|
||||
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
||||
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
|
||||
@@ -23,6 +30,7 @@ type ChannelOtherSettings struct {
|
||||
AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费)
|
||||
DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用)
|
||||
AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私)
|
||||
AwsKeyType AwsKeyType `json:"aws_key_type,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool {
|
||||
|
||||
@@ -24,7 +24,7 @@ type ClaudeMediaMessage struct {
|
||||
StopReason *string `json:"stop_reason,omitempty"`
|
||||
PartialJson *string `json:"partial_json,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Thinking *string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
CacheControl json.RawMessage `json:"cache_control,omitempty"`
|
||||
@@ -148,6 +148,10 @@ func (c *ClaudeMessage) SetStringContent(content string) {
|
||||
c.Content = content
|
||||
}
|
||||
|
||||
func (c *ClaudeMessage) SetContent(content any) {
|
||||
c.Content = content
|
||||
}
|
||||
|
||||
func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
|
||||
return common.Any2Type[[]ClaudeMediaMessage](c.Content)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
)
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
Requests []GeminiChatRequest `json:"requests,omitempty"` // For batch requests
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
|
||||
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package common
|
||||
package dto
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
@@ -27,7 +27,7 @@ type OpenAIVideo struct {
|
||||
Size string `json:"size,omitempty"`
|
||||
RemixedFromVideoID string `json:"remixed_from_video_id,omitempty"`
|
||||
Error *OpenAIVideoError `json:"error,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
Metadata map[string]any `json:"meta_data,omitempty"`
|
||||
}
|
||||
|
||||
func (m *OpenAIVideo) SetProgressStr(progress string) {
|
||||
1
go.mod
1
go.mod
@@ -40,7 +40,6 @@ require (
|
||||
golang.org/x/image v0.23.0
|
||||
golang.org/x/net v0.43.0
|
||||
golang.org/x/sync v0.17.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
gorm.io/gorm v1.25.2
|
||||
|
||||
@@ -153,5 +153,5 @@ func LogJson(ctx context.Context, msg string, obj any) {
|
||||
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
|
||||
LogDebug(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
|
||||
}
|
||||
|
||||
89
main.go
89
main.go
@@ -21,13 +21,6 @@ import (
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
|
||||
// Plugin System
|
||||
coreregistry "github.com/QuantumNous/new-api/core/registry"
|
||||
_ "github.com/QuantumNous/new-api/plugins/channels" // 自动注册channel插件
|
||||
_ "github.com/QuantumNous/new-api/plugins/hooks/web_search" // 自动注册web_search hook
|
||||
_ "github.com/QuantumNous/new-api/plugins/hooks/content_filter" // 自动注册content_filter hook
|
||||
relayhooks "github.com/QuantumNous/new-api/relay/hooks"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-contrib/sessions/cookie"
|
||||
@@ -157,6 +150,26 @@ func main() {
|
||||
})
|
||||
server.Use(sessions.Sessions("session", store))
|
||||
|
||||
InjectUmamiAnalytics()
|
||||
InjectGoogleAnalytics()
|
||||
|
||||
// 设置路由
|
||||
router.SetRouter(server, buildFS, indexPage)
|
||||
var port = os.Getenv("PORT")
|
||||
if port == "" {
|
||||
port = strconv.Itoa(*common.Port)
|
||||
}
|
||||
|
||||
// Log startup success message
|
||||
common.LogStartupSuccess(startTime, port)
|
||||
|
||||
err = server.Run(":" + port)
|
||||
if err != nil {
|
||||
common.FatalLog("failed to start HTTP server: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func InjectUmamiAnalytics() {
|
||||
analyticsInjectBuilder := &strings.Builder{}
|
||||
if os.Getenv("UMAMI_WEBSITE_ID") != "" {
|
||||
umamiSiteID := os.Getenv("UMAMI_WEBSITE_ID")
|
||||
@@ -171,21 +184,28 @@ func main() {
|
||||
analyticsInjectBuilder.WriteString("\"></script>")
|
||||
}
|
||||
analyticsInject := analyticsInjectBuilder.String()
|
||||
indexPage = bytes.ReplaceAll(indexPage, []byte("<analytics></analytics>\n"), []byte(analyticsInject))
|
||||
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--umami-->\n"), []byte(analyticsInject))
|
||||
}
|
||||
|
||||
router.SetRouter(server, buildFS, indexPage)
|
||||
var port = os.Getenv("PORT")
|
||||
if port == "" {
|
||||
port = strconv.Itoa(*common.Port)
|
||||
}
|
||||
|
||||
// Log startup success message
|
||||
common.LogStartupSuccess(startTime, port)
|
||||
|
||||
err = server.Run(":" + port)
|
||||
if err != nil {
|
||||
common.FatalLog("failed to start HTTP server: " + err.Error())
|
||||
func InjectGoogleAnalytics() {
|
||||
analyticsInjectBuilder := &strings.Builder{}
|
||||
if os.Getenv("GOOGLE_ANALYTICS_ID") != "" {
|
||||
gaID := os.Getenv("GOOGLE_ANALYTICS_ID")
|
||||
// Google Analytics 4 (gtag.js)
|
||||
analyticsInjectBuilder.WriteString("<script async src=\"https://www.googletagmanager.com/gtag/js?id=")
|
||||
analyticsInjectBuilder.WriteString(gaID)
|
||||
analyticsInjectBuilder.WriteString("\"></script>")
|
||||
analyticsInjectBuilder.WriteString("<script>")
|
||||
analyticsInjectBuilder.WriteString("window.dataLayer = window.dataLayer || [];")
|
||||
analyticsInjectBuilder.WriteString("function gtag(){dataLayer.push(arguments);}")
|
||||
analyticsInjectBuilder.WriteString("gtag('js', new Date());")
|
||||
analyticsInjectBuilder.WriteString("gtag('config', '")
|
||||
analyticsInjectBuilder.WriteString(gaID)
|
||||
analyticsInjectBuilder.WriteString("');")
|
||||
analyticsInjectBuilder.WriteString("</script>")
|
||||
}
|
||||
analyticsInject := analyticsInjectBuilder.String()
|
||||
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--Google Analytics-->\n"), []byte(analyticsInject))
|
||||
}
|
||||
|
||||
func InitResources() error {
|
||||
@@ -236,34 +256,5 @@ func InitResources() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize Plugin System
|
||||
InitPluginSystem()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitPluginSystem 初始化插件系统
|
||||
func InitPluginSystem() {
|
||||
common.SysLog("Initializing plugin system...")
|
||||
|
||||
// 1. 加载插件配置
|
||||
// config.LoadPluginConfig() 会在各个插件的init()中自动调用
|
||||
|
||||
// 2. 注册Channel插件
|
||||
// 注意:这会在 plugins/channels/registry.go 的 init() 中自动完成
|
||||
// 但为了确保加载,我们显式导入
|
||||
common.SysLog("Registering channel plugins...")
|
||||
|
||||
// 3. 初始化Hook链
|
||||
common.SysLog("Initializing hook chain...")
|
||||
_ = relayhooks.GetGlobalChain()
|
||||
|
||||
hookCount := coreregistry.HookCount()
|
||||
enabledHookCount := coreregistry.EnabledHookCount()
|
||||
common.SysLog(fmt.Sprintf("Plugin system initialized: %d hooks registered (%d enabled)",
|
||||
hookCount, enabledHookCount))
|
||||
|
||||
channelCount := len(coreregistry.ListChannels())
|
||||
common.SysLog(fmt.Sprintf("Registered %d channel plugins", channelCount))
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ func Distribute() func(c *gin.Context) {
|
||||
if userGroup == "auto" {
|
||||
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
|
||||
}
|
||||
message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(数据库一致性已被破坏,distributor): %s", showGroup, modelRequest.Model, err.Error())
|
||||
message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(distributor): %s", showGroup, modelRequest.Model, err.Error())
|
||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||
//if channel != nil {
|
||||
// common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
|
||||
14
model/log.go
14
model/log.go
@@ -39,13 +39,15 @@ type Log struct {
|
||||
Other string `json:"other"`
|
||||
}
|
||||
|
||||
// don't use iota, avoid change log type value
|
||||
const (
|
||||
LogTypeUnknown = iota
|
||||
LogTypeTopup
|
||||
LogTypeConsume
|
||||
LogTypeManage
|
||||
LogTypeSystem
|
||||
LogTypeError
|
||||
LogTypeUnknown = 0
|
||||
LogTypeTopup = 1
|
||||
LogTypeConsume = 2
|
||||
LogTypeManage = 3
|
||||
LogTypeSystem = 4
|
||||
LogTypeError = 5
|
||||
LogTypeRefund = 6
|
||||
)
|
||||
|
||||
func formatUserLogs(logs []*Log) {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
commonRelay "github.com/QuantumNous/new-api/relay/common"
|
||||
)
|
||||
|
||||
@@ -15,15 +16,15 @@ func (t TaskStatus) ToVideoStatus() string {
|
||||
var status string
|
||||
switch t {
|
||||
case TaskStatusQueued, TaskStatusSubmitted:
|
||||
status = commonRelay.VideoStatusQueued
|
||||
status = dto.VideoStatusQueued
|
||||
case TaskStatusInProgress:
|
||||
status = commonRelay.VideoStatusInProgress
|
||||
status = dto.VideoStatusInProgress
|
||||
case TaskStatusSuccess:
|
||||
status = commonRelay.VideoStatusCompleted
|
||||
status = dto.VideoStatusCompleted
|
||||
case TaskStatusFailure:
|
||||
status = commonRelay.VideoStatusFailed
|
||||
status = dto.VideoStatusFailed
|
||||
default:
|
||||
status = commonRelay.VideoStatusUnknown // Default fallback
|
||||
status = dto.VideoStatusUnknown // Default fallback
|
||||
}
|
||||
return status
|
||||
}
|
||||
@@ -45,6 +46,7 @@ type Task struct {
|
||||
TaskID string `json:"task_id" gorm:"type:varchar(191);index"` // 第三方id,不一定有/ song id\ Task id
|
||||
Platform constant.TaskPlatform `json:"platform" gorm:"type:varchar(30);index"` // 平台
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Group string `json:"group" gorm:"type:varchar(50)"` // 修正计费用
|
||||
ChannelId int `json:"channel_id" gorm:"index"`
|
||||
Quota int `json:"quota"`
|
||||
Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
|
||||
@@ -98,6 +100,7 @@ type SyncTaskQueryParams struct {
|
||||
func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo) *Task {
|
||||
t := &Task{
|
||||
UserId: relayInfo.UserId,
|
||||
Group: relayInfo.UsingGroup,
|
||||
SubmitTime: time.Now().Unix(),
|
||||
Status: TaskStatusNotStart,
|
||||
Progress: "0%",
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// BaseChannelPlugin 基础Channel插件
|
||||
// 包装现有的Adaptor实现,使其符合ChannelPlugin接口
|
||||
type BaseChannelPlugin struct {
|
||||
adaptor channel.Adaptor
|
||||
name string
|
||||
version string
|
||||
priority int
|
||||
}
|
||||
|
||||
// NewBaseChannelPlugin 创建基础Channel插件
|
||||
func NewBaseChannelPlugin(adaptor channel.Adaptor, name, version string, priority int) *BaseChannelPlugin {
|
||||
return &BaseChannelPlugin{
|
||||
adaptor: adaptor,
|
||||
name: name,
|
||||
version: version,
|
||||
priority: priority,
|
||||
}
|
||||
}
|
||||
|
||||
// Name 返回插件名称
|
||||
func (p *BaseChannelPlugin) Name() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
// Version 返回插件版本
|
||||
func (p *BaseChannelPlugin) Version() string {
|
||||
return p.version
|
||||
}
|
||||
|
||||
// Priority 返回优先级
|
||||
func (p *BaseChannelPlugin) Priority() int {
|
||||
return p.priority
|
||||
}
|
||||
|
||||
// 以下方法直接委托给内部的Adaptor
|
||||
|
||||
func (p *BaseChannelPlugin) Init(info *relaycommon.RelayInfo) {
|
||||
p.adaptor.Init(info)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return p.adaptor.GetRequestURL(info)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
return p.adaptor.SetupRequestHeader(c, req, info)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return p.adaptor.ConvertOpenAIRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return p.adaptor.ConvertRerankRequest(c, relayMode, request)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
return p.adaptor.ConvertEmbeddingRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
return p.adaptor.ConvertAudioRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
return p.adaptor.ConvertImageRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
return p.adaptor.ConvertOpenAIResponsesRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return p.adaptor.DoRequest(c, info, requestBody)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
return p.adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) GetModelList() []string {
|
||||
return p.adaptor.GetModelList()
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) GetChannelName() string {
|
||||
return p.adaptor.GetChannelName()
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
return p.adaptor.ConvertClaudeRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
return p.adaptor.ConvertGeminiRequest(c, info, request)
|
||||
}
|
||||
|
||||
// GetAdaptor 获取内部的Adaptor(用于向后兼容)
|
||||
func (p *BaseChannelPlugin) GetAdaptor() channel.Adaptor {
|
||||
return p.adaptor
|
||||
}
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/core/registry"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/ali"
|
||||
"github.com/QuantumNous/new-api/relay/channel/aws"
|
||||
"github.com/QuantumNous/new-api/relay/channel/baidu"
|
||||
"github.com/QuantumNous/new-api/relay/channel/baidu_v2"
|
||||
"github.com/QuantumNous/new-api/relay/channel/claude"
|
||||
"github.com/QuantumNous/new-api/relay/channel/cloudflare"
|
||||
"github.com/QuantumNous/new-api/relay/channel/cohere"
|
||||
"github.com/QuantumNous/new-api/relay/channel/coze"
|
||||
"github.com/QuantumNous/new-api/relay/channel/deepseek"
|
||||
"github.com/QuantumNous/new-api/relay/channel/dify"
|
||||
"github.com/QuantumNous/new-api/relay/channel/gemini"
|
||||
"github.com/QuantumNous/new-api/relay/channel/jimeng"
|
||||
"github.com/QuantumNous/new-api/relay/channel/jina"
|
||||
"github.com/QuantumNous/new-api/relay/channel/mistral"
|
||||
"github.com/QuantumNous/new-api/relay/channel/mokaai"
|
||||
"github.com/QuantumNous/new-api/relay/channel/moonshot"
|
||||
"github.com/QuantumNous/new-api/relay/channel/ollama"
|
||||
"github.com/QuantumNous/new-api/relay/channel/openai"
|
||||
"github.com/QuantumNous/new-api/relay/channel/palm"
|
||||
"github.com/QuantumNous/new-api/relay/channel/perplexity"
|
||||
"github.com/QuantumNous/new-api/relay/channel/siliconflow"
|
||||
"github.com/QuantumNous/new-api/relay/channel/submodel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/tencent"
|
||||
"github.com/QuantumNous/new-api/relay/channel/vertex"
|
||||
"github.com/QuantumNous/new-api/relay/channel/volcengine"
|
||||
"github.com/QuantumNous/new-api/relay/channel/xai"
|
||||
"github.com/QuantumNous/new-api/relay/channel/xunfei"
|
||||
"github.com/QuantumNous/new-api/relay/channel/zhipu"
|
||||
"github.com/QuantumNous/new-api/relay/channel/zhipu_4v"
|
||||
)
|
||||
|
||||
// init 包初始化时自动注册所有Channel插件
|
||||
func init() {
|
||||
RegisterAllChannels()
|
||||
}
|
||||
|
||||
// RegisterAllChannels 注册所有Channel插件
|
||||
func RegisterAllChannels() {
|
||||
// 包装现有的Adaptor并注册为插件
|
||||
channels := []struct {
|
||||
channelType int
|
||||
adaptor channel.Adaptor
|
||||
name string
|
||||
}{
|
||||
{constant.APITypeOpenAI, &openai.Adaptor{}, "openai"},
|
||||
{constant.APITypeAnthropic, &claude.Adaptor{}, "claude"},
|
||||
{constant.APITypeGemini, &gemini.Adaptor{}, "gemini"},
|
||||
{constant.APITypeAli, &ali.Adaptor{}, "ali"},
|
||||
{constant.APITypeBaidu, &baidu.Adaptor{}, "baidu"},
|
||||
{constant.APITypeBaiduV2, &baidu_v2.Adaptor{}, "baidu_v2"},
|
||||
{constant.APITypeTencent, &tencent.Adaptor{}, "tencent"},
|
||||
{constant.APITypeXunfei, &xunfei.Adaptor{}, "xunfei"},
|
||||
{constant.APITypeZhipu, &zhipu.Adaptor{}, "zhipu"},
|
||||
{constant.APITypeZhipuV4, &zhipu_4v.Adaptor{}, "zhipu_v4"},
|
||||
{constant.APITypeOllama, &ollama.Adaptor{}, "ollama"},
|
||||
{constant.APITypePerplexity, &perplexity.Adaptor{}, "perplexity"},
|
||||
{constant.APITypeAws, &aws.Adaptor{}, "aws"},
|
||||
{constant.APITypeCohere, &cohere.Adaptor{}, "cohere"},
|
||||
{constant.APITypeDify, &dify.Adaptor{}, "dify"},
|
||||
{constant.APITypeJina, &jina.Adaptor{}, "jina"},
|
||||
{constant.APITypeCloudflare, &cloudflare.Adaptor{}, "cloudflare"},
|
||||
{constant.APITypeSiliconFlow, &siliconflow.Adaptor{}, "siliconflow"},
|
||||
{constant.APITypeVertexAi, &vertex.Adaptor{}, "vertex"},
|
||||
{constant.APITypeMistral, &mistral.Adaptor{}, "mistral"},
|
||||
{constant.APITypeDeepSeek, &deepseek.Adaptor{}, "deepseek"},
|
||||
{constant.APITypeMokaAI, &mokaai.Adaptor{}, "mokaai"},
|
||||
{constant.APITypeVolcEngine, &volcengine.Adaptor{}, "volcengine"},
|
||||
{constant.APITypeXai, &xai.Adaptor{}, "xai"},
|
||||
{constant.APITypeCoze, &coze.Adaptor{}, "coze"},
|
||||
{constant.APITypeJimeng, &jimeng.Adaptor{}, "jimeng"},
|
||||
{constant.APITypeMoonshot, &moonshot.Adaptor{}, "moonshot"},
|
||||
{constant.APITypeSubmodel, &submodel.Adaptor{}, "submodel"},
|
||||
{constant.APITypePaLM, &palm.Adaptor{}, "palm"},
|
||||
// OpenRouter 和 Xinference 使用 OpenAI adaptor
|
||||
{constant.APITypeOpenRouter, &openai.Adaptor{}, "openrouter"},
|
||||
{constant.APITypeXinference, &openai.Adaptor{}, "xinference"},
|
||||
}
|
||||
|
||||
registeredCount := 0
|
||||
for _, ch := range channels {
|
||||
plugin := NewBaseChannelPlugin(
|
||||
ch.adaptor,
|
||||
ch.name,
|
||||
"1.0.0",
|
||||
100, // 默认优先级
|
||||
)
|
||||
|
||||
if err := registry.RegisterChannel(ch.channelType, plugin); err != nil {
|
||||
common.SysError("Failed to register channel plugin: " + ch.name + ", error: " + err.Error())
|
||||
} else {
|
||||
registeredCount++
|
||||
}
|
||||
}
|
||||
|
||||
common.SysLog(fmt.Sprintf("Registered %d channel plugins", registeredCount))
|
||||
}
|
||||
|
||||
@@ -1,186 +0,0 @@
|
||||
package content_filter
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/core/interfaces"
|
||||
)
|
||||
|
||||
// ContentFilterHook 内容过滤Hook
|
||||
// 在响应返回前过滤敏感内容
|
||||
type ContentFilterHook struct {
|
||||
enabled bool
|
||||
priority int
|
||||
sensitiveWords []string
|
||||
filterNSFW bool
|
||||
filterPolitical bool
|
||||
replacementText string
|
||||
}
|
||||
|
||||
// NewContentFilterHook 创建ContentFilterHook实例
|
||||
func NewContentFilterHook(config map[string]interface{}) *ContentFilterHook {
|
||||
hook := &ContentFilterHook{
|
||||
enabled: true,
|
||||
priority: 100, // 高优先级,最后执行
|
||||
sensitiveWords: []string{},
|
||||
filterNSFW: true,
|
||||
filterPolitical: false,
|
||||
replacementText: "[已过滤]",
|
||||
}
|
||||
|
||||
if enabled, ok := config["enabled"].(bool); ok {
|
||||
hook.enabled = enabled
|
||||
}
|
||||
|
||||
if priority, ok := config["priority"].(int); ok {
|
||||
hook.priority = priority
|
||||
}
|
||||
|
||||
if filterNSFW, ok := config["filter_nsfw"].(bool); ok {
|
||||
hook.filterNSFW = filterNSFW
|
||||
}
|
||||
|
||||
if filterPolitical, ok := config["filter_political"].(bool); ok {
|
||||
hook.filterPolitical = filterPolitical
|
||||
}
|
||||
|
||||
if words, ok := config["sensitive_words"].([]interface{}); ok {
|
||||
for _, word := range words {
|
||||
if w, ok := word.(string); ok {
|
||||
hook.sensitiveWords = append(hook.sensitiveWords, w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return hook
|
||||
}
|
||||
|
||||
// Name 返回Hook名称
|
||||
func (h *ContentFilterHook) Name() string {
|
||||
return "content_filter"
|
||||
}
|
||||
|
||||
// Priority 返回优先级
|
||||
func (h *ContentFilterHook) Priority() int {
|
||||
return h.priority
|
||||
}
|
||||
|
||||
// Enabled 返回是否启用
|
||||
func (h *ContentFilterHook) Enabled() bool {
|
||||
return h.enabled
|
||||
}
|
||||
|
||||
// OnBeforeRequest 请求前处理(不需要处理)
|
||||
func (h *ContentFilterHook) OnBeforeRequest(ctx *interfaces.HookContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnAfterResponse 响应后处理 - 过滤内容
|
||||
func (h *ContentFilterHook) OnAfterResponse(ctx *interfaces.HookContext) error {
|
||||
if !h.Enabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 只处理chat completion响应
|
||||
if !strings.Contains(ctx.Request.URL.Path, "chat/completions") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 如果没有响应体,跳过
|
||||
if len(ctx.ResponseBody) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var response map[string]interface{}
|
||||
if err := json.Unmarshal(ctx.ResponseBody, &response); err != nil {
|
||||
return nil // 忽略解析错误
|
||||
}
|
||||
|
||||
// 过滤内容
|
||||
filtered := h.filterResponse(response)
|
||||
|
||||
// 如果内容被修改,更新响应体
|
||||
if filtered {
|
||||
modifiedBody, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.ResponseBody = modifiedBody
|
||||
|
||||
// 记录过滤事件
|
||||
ctx.Data["content_filtered"] = true
|
||||
common.SysLog("Content filter applied to response")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnError 错误处理
|
||||
func (h *ContentFilterHook) OnError(ctx *interfaces.HookContext, err error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// filterResponse 过滤响应内容
|
||||
func (h *ContentFilterHook) filterResponse(response map[string]interface{}) bool {
|
||||
modified := false
|
||||
|
||||
// 获取choices数组
|
||||
choices, ok := response["choices"].([]interface{})
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// 遍历每个choice
|
||||
for _, choice := range choices {
|
||||
choiceMap, ok := choice.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取message
|
||||
message, ok := choiceMap["message"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取content
|
||||
content, ok := message["content"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// 过滤内容
|
||||
filteredContent := h.filterText(content)
|
||||
|
||||
// 如果内容被修改
|
||||
if filteredContent != content {
|
||||
message["content"] = filteredContent
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
return modified
|
||||
}
|
||||
|
||||
// filterText 过滤文本内容
|
||||
func (h *ContentFilterHook) filterText(text string) string {
|
||||
filtered := text
|
||||
|
||||
// 过滤敏感词
|
||||
for _, word := range h.sensitiveWords {
|
||||
if strings.Contains(filtered, word) {
|
||||
filtered = strings.ReplaceAll(filtered, word, h.replacementText)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: 实现更复杂的过滤逻辑
|
||||
// - NSFW内容检测
|
||||
// - 政治敏感内容检测
|
||||
// - 使用AI模型进行内容分类
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
package content_filter
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/core/registry"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// 从环境变量读取配置
|
||||
config := map[string]interface{}{
|
||||
"enabled": os.Getenv("CONTENT_FILTER_ENABLED") == "true",
|
||||
"priority": 100,
|
||||
"filter_nsfw": os.Getenv("CONTENT_FILTER_NSFW") != "false",
|
||||
"filter_political": os.Getenv("CONTENT_FILTER_POLITICAL") == "true",
|
||||
}
|
||||
|
||||
// 读取敏感词列表
|
||||
if wordsEnv := os.Getenv("CONTENT_FILTER_WORDS"); wordsEnv != "" {
|
||||
words := strings.Split(wordsEnv, ",")
|
||||
config["sensitive_words"] = words
|
||||
}
|
||||
|
||||
// 创建并注册Hook
|
||||
hook := NewContentFilterHook(config)
|
||||
|
||||
if err := registry.RegisterHook(hook); err != nil {
|
||||
common.SysError("Failed to register content_filter hook: " + err.Error())
|
||||
} else {
|
||||
if hook.Enabled() {
|
||||
common.SysLog("Content filter hook registered and enabled")
|
||||
} else {
|
||||
common.SysLog("Content filter hook registered but disabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
package web_search
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/core/registry"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// 从环境变量读取配置
|
||||
config := map[string]interface{}{
|
||||
"enabled": os.Getenv("WEB_SEARCH_ENABLED") == "true",
|
||||
"api_key": os.Getenv("WEB_SEARCH_API_KEY"),
|
||||
"provider": getEnvOrDefault("WEB_SEARCH_PROVIDER", "google"),
|
||||
"priority": 50,
|
||||
}
|
||||
|
||||
// 创建并注册Hook
|
||||
hook := NewWebSearchHook(config)
|
||||
|
||||
if err := registry.RegisterHook(hook); err != nil {
|
||||
common.SysError("Failed to register web_search hook: " + err.Error())
|
||||
} else {
|
||||
if hook.Enabled() {
|
||||
common.SysLog("Web search hook registered and enabled")
|
||||
} else {
|
||||
common.SysLog("Web search hook registered but disabled (missing API key or not enabled)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getEnvOrDefault(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
@@ -1,281 +0,0 @@
|
||||
package web_search
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/core/interfaces"
|
||||
)
|
||||
|
||||
// WebSearchHook 联网搜索Hook插件
|
||||
// 在请求发送前检测是否需要联网搜索,如果需要则调用搜索API并将结果注入到请求中
|
||||
type WebSearchHook struct {
|
||||
enabled bool
|
||||
priority int
|
||||
apiKey string
|
||||
provider string // google, bing, etc
|
||||
}
|
||||
|
||||
// NewWebSearchHook 创建WebSearchHook实例
|
||||
func NewWebSearchHook(config map[string]interface{}) *WebSearchHook {
|
||||
hook := &WebSearchHook{
|
||||
enabled: true,
|
||||
priority: 50, // 中等优先级
|
||||
provider: "google",
|
||||
}
|
||||
|
||||
if apiKey, ok := config["api_key"].(string); ok {
|
||||
hook.apiKey = apiKey
|
||||
}
|
||||
|
||||
if provider, ok := config["provider"].(string); ok {
|
||||
hook.provider = provider
|
||||
}
|
||||
|
||||
if priority, ok := config["priority"].(int); ok {
|
||||
hook.priority = priority
|
||||
}
|
||||
|
||||
if enabled, ok := config["enabled"].(bool); ok {
|
||||
hook.enabled = enabled
|
||||
}
|
||||
|
||||
return hook
|
||||
}
|
||||
|
||||
// Name 返回Hook名称
|
||||
func (h *WebSearchHook) Name() string {
|
||||
return "web_search"
|
||||
}
|
||||
|
||||
// Priority 返回优先级
|
||||
func (h *WebSearchHook) Priority() int {
|
||||
return h.priority
|
||||
}
|
||||
|
||||
// Enabled 返回是否启用
|
||||
func (h *WebSearchHook) Enabled() bool {
|
||||
return h.enabled && h.apiKey != ""
|
||||
}
|
||||
|
||||
// OnBeforeRequest 请求前处理
|
||||
func (h *WebSearchHook) OnBeforeRequest(ctx *interfaces.HookContext) error {
|
||||
if !h.Enabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 只处理chat completion请求
|
||||
if !strings.Contains(ctx.Request.URL.Path, "chat/completions") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查请求体中是否包含搜索关键词
|
||||
if len(ctx.RequestBody) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 解析请求体
|
||||
var requestData map[string]interface{}
|
||||
if err := json.Unmarshal(ctx.RequestBody, &requestData); err != nil {
|
||||
return nil // 忽略解析错误
|
||||
}
|
||||
|
||||
// 检查是否需要搜索(简单示例:检查最后一条消息是否包含 [search] 标记)
|
||||
if !h.shouldSearch(requestData) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 执行搜索
|
||||
searchQuery := h.extractSearchQuery(requestData)
|
||||
if searchQuery == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
common.SysLog(fmt.Sprintf("Web search triggered for query: %s", searchQuery))
|
||||
|
||||
// 调用搜索API
|
||||
searchResults, err := h.performSearch(searchQuery)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Web search failed: %v", err))
|
||||
return nil // 不中断请求,只记录错误
|
||||
}
|
||||
|
||||
// 将搜索结果注入到请求中
|
||||
h.injectSearchResults(requestData, searchResults)
|
||||
|
||||
// 更新请求体
|
||||
modifiedBody, err := json.Marshal(requestData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.RequestBody = modifiedBody
|
||||
|
||||
// 存储到Data中供后续使用
|
||||
ctx.Data["web_search_performed"] = true
|
||||
ctx.Data["web_search_query"] = searchQuery
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnAfterResponse 响应后处理
|
||||
func (h *WebSearchHook) OnAfterResponse(ctx *interfaces.HookContext) error {
|
||||
// 可以在这里记录搜索使用情况等
|
||||
if performed, ok := ctx.Data["web_search_performed"].(bool); ok && performed {
|
||||
query := ctx.Data["web_search_query"].(string)
|
||||
common.SysLog(fmt.Sprintf("Web search completed for query: %s", query))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnError 错误处理
|
||||
func (h *WebSearchHook) OnError(ctx *interfaces.HookContext, err error) error {
|
||||
// 记录错误但不影响主流程
|
||||
if performed, ok := ctx.Data["web_search_performed"].(bool); ok && performed {
|
||||
common.SysError(fmt.Sprintf("Request failed after web search: %v", err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldSearch 判断是否需要搜索
|
||||
func (h *WebSearchHook) shouldSearch(requestData map[string]interface{}) bool {
|
||||
messages, ok := requestData["messages"].([]interface{})
|
||||
if !ok || len(messages) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查最后一条消息
|
||||
lastMessage, ok := messages[len(messages)-1].(map[string]interface{})
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
content, ok := lastMessage["content"].(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// 简单示例:检查是否包含 [search] 或 [联网] 标记
|
||||
return strings.Contains(content, "[search]") ||
|
||||
strings.Contains(content, "[联网]") ||
|
||||
strings.Contains(content, "[web]")
|
||||
}
|
||||
|
||||
// extractSearchQuery 提取搜索查询
|
||||
func (h *WebSearchHook) extractSearchQuery(requestData map[string]interface{}) string {
|
||||
messages, ok := requestData["messages"].([]interface{})
|
||||
if !ok || len(messages) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
lastMessage, ok := messages[len(messages)-1].(map[string]interface{})
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
content, ok := lastMessage["content"].(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 移除标记,保留实际查询内容
|
||||
query := strings.ReplaceAll(content, "[search]", "")
|
||||
query = strings.ReplaceAll(query, "[联网]", "")
|
||||
query = strings.ReplaceAll(query, "[web]", "")
|
||||
query = strings.TrimSpace(query)
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
// performSearch 执行搜索
|
||||
func (h *WebSearchHook) performSearch(query string) (string, error) {
|
||||
// 这里是示例实现,实际应该调用真实的搜索API
|
||||
// 例如:Google Custom Search API, Bing Search API等
|
||||
|
||||
if h.apiKey == "" {
|
||||
return "", fmt.Errorf("search API key not configured")
|
||||
}
|
||||
|
||||
// 示例:返回模拟结果
|
||||
// 实际实现需要调用真实API
|
||||
return h.mockSearch(query)
|
||||
}
|
||||
|
||||
// mockSearch 模拟搜索(示例)
|
||||
func (h *WebSearchHook) mockSearch(query string) (string, error) {
|
||||
// 这只是一个示例实现
|
||||
// 实际应该调用真实的搜索API
|
||||
|
||||
common.SysLog(fmt.Sprintf("[Mock] Searching for: %s", query))
|
||||
|
||||
// 返回模拟的搜索结果
|
||||
return fmt.Sprintf("搜索结果 (模拟):关于 '%s' 的最新信息...", query), nil
|
||||
}
|
||||
|
||||
// realSearch 真实搜索实现示例(需要配置API)
|
||||
func (h *WebSearchHook) realSearch(query string) (string, error) {
|
||||
// 示例:使用Google Custom Search API
|
||||
url := fmt.Sprintf("https://www.googleapis.com/customsearch/v1?key=%s&cx=YOUR_CX&q=%s",
|
||||
h.apiKey, query)
|
||||
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 解析搜索结果
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 提取搜索结果摘要
|
||||
// 这里需要根据实际API响应格式处理
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
// injectSearchResults 将搜索结果注入到请求中
|
||||
func (h *WebSearchHook) injectSearchResults(requestData map[string]interface{}, results string) {
|
||||
messages, ok := requestData["messages"].([]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// 在用户消息前插入系统消息,包含搜索结果
|
||||
systemMessage := map[string]interface{}{
|
||||
"role": "system",
|
||||
"content": fmt.Sprintf("以下是针对用户查询的最新搜索结果:\n\n%s\n\n请基于这些信息回答用户的问题。", results),
|
||||
}
|
||||
|
||||
// 插入到消息列表的适当位置
|
||||
updatedMessages := make([]interface{}, 0, len(messages)+1)
|
||||
|
||||
// 如果第一条是系统消息,在其后插入
|
||||
if len(messages) > 0 {
|
||||
if firstMsg, ok := messages[0].(map[string]interface{}); ok {
|
||||
if role, ok := firstMsg["role"].(string); ok && role == "system" {
|
||||
updatedMessages = append(updatedMessages, messages[0])
|
||||
updatedMessages = append(updatedMessages, systemMessage)
|
||||
updatedMessages = append(updatedMessages, messages[1:]...)
|
||||
requestData["messages"] = updatedMessages
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 否则插入到开头
|
||||
updatedMessages = append(updatedMessages, systemMessage)
|
||||
updatedMessages = append(updatedMessages, messages...)
|
||||
requestData["messages"] = updatedMessages
|
||||
}
|
||||
|
||||
@@ -53,5 +53,5 @@ type TaskAdaptor interface {
|
||||
}
|
||||
|
||||
type OpenAIVideoConverter interface {
|
||||
ConvertToOpenAIVideo(originTask *model.Task) (*relaycommon.OpenAIVideo, error)
|
||||
ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error)
|
||||
}
|
||||
|
||||
@@ -1,20 +1,7 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
|
||||
@@ -29,180 +16,3 @@ func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReque
|
||||
}
|
||||
return &request
|
||||
}
|
||||
|
||||
func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
|
||||
return &AliEmbeddingRequest{
|
||||
Model: request.Model,
|
||||
Input: struct {
|
||||
Texts []string `json:"texts"`
|
||||
}{
|
||||
Texts: request.ParseInput(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var fullTextResponse dto.FlexibleEmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
model := c.GetString("model")
|
||||
if model == "" {
|
||||
model = "text-embedding-v4"
|
||||
}
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse, model string) *dto.OpenAIEmbeddingResponse {
|
||||
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
|
||||
Model: model,
|
||||
Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
|
||||
}
|
||||
|
||||
for _, item := range response.Output.Embeddings {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
|
||||
Object: `embedding`,
|
||||
Index: item.TextIndex,
|
||||
Embedding: item.Embedding,
|
||||
})
|
||||
}
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: response.Output.Text,
|
||||
},
|
||||
FinishReason: response.Output.FinishReason,
|
||||
}
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: response.RequestId,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []dto.OpenAITextResponseChoice{choice},
|
||||
Usage: dto.Usage{
|
||||
PromptTokens: response.Usage.InputTokens,
|
||||
CompletionTokens: response.Usage.OutputTokens,
|
||||
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
|
||||
},
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.SetContentString(aliResponse.Output.Text)
|
||||
if aliResponse.Output.FinishReason != "null" {
|
||||
finishReason := aliResponse.Output.FinishReason
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Id: aliResponse.RequestId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "ernie-bot",
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var usage dto.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
helper.SetEventStreamHeaders(c)
|
||||
lastResponseText := ""
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var aliResponse AliResponse
|
||||
err := json.Unmarshal([]byte(data), &aliResponse)
|
||||
if err != nil {
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if aliResponse.Usage.OutputTokens != 0 {
|
||||
usage.PromptTokens = aliResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = aliResponse.Usage.OutputTokens
|
||||
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
||||
}
|
||||
response := streamResponseAli2OpenAI(&aliResponse)
|
||||
response.Choices[0].Delta.SetContentString(strings.TrimPrefix(response.Choices[0].Delta.GetContentString(), lastResponseText))
|
||||
lastResponseText = aliResponse.Output.Text
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysLog("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var aliResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
if aliResponse.Code != "" {
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: "ali_error",
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
}, resp.StatusCode), nil
|
||||
}
|
||||
fullTextResponse := responseAli2OpenAI(&aliResponse)
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
@@ -1,25 +1,36 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/claude"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ClientMode int
|
||||
|
||||
const (
|
||||
RequestModeCompletion = 1
|
||||
RequestModeMessage = 2
|
||||
ClientModeApiKey ClientMode = iota + 1
|
||||
ClientModeAKSK
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
RequestMode int
|
||||
ClientMode ClientMode
|
||||
AwsClient *bedrockruntime.Client
|
||||
AwsModelId string
|
||||
AwsReq any
|
||||
IsNova bool
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
@@ -28,8 +39,37 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
c.Set("request_model", request.Model)
|
||||
c.Set("converted_request", request)
|
||||
for i, message := range request.Messages {
|
||||
updated := false
|
||||
if !message.IsStringContent() {
|
||||
content, err := message.ParseContent()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse message content")
|
||||
}
|
||||
for i2, mediaMessage := range content {
|
||||
if mediaMessage.Source != nil {
|
||||
if mediaMessage.Source.Type == "url" {
|
||||
fileData, err := service.GetFileBase64FromUrl(c, mediaMessage.Source.Url, "formatting image for Claude")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
|
||||
}
|
||||
mediaMessage.Source.MediaType = fileData.MimeType
|
||||
mediaMessage.Source.Data = fileData.Base64Data
|
||||
mediaMessage.Source.Url = ""
|
||||
mediaMessage.Source.Type = "base64"
|
||||
content[i2] = mediaMessage
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if updated {
|
||||
message.SetContent(content)
|
||||
}
|
||||
}
|
||||
if updated {
|
||||
request.Messages[i] = message
|
||||
}
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
@@ -44,15 +84,28 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.RequestMode = RequestModeMessage
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return "", nil
|
||||
if info.ChannelOtherSettings.AwsKeyType == dto.AwsKeyTypeApiKey {
|
||||
awsModelId := getAwsModelID(info.UpstreamModelName)
|
||||
a.ClientMode = ClientModeApiKey
|
||||
awsSecret := strings.Split(info.ApiKey, "|")
|
||||
if len(awsSecret) != 2 {
|
||||
return "", errors.New("invalid aws api key, should be in format of <api-key>|<region>")
|
||||
}
|
||||
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/converse", awsModelId, awsSecret[1]), nil
|
||||
} else {
|
||||
a.ClientMode = ClientModeAKSK
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
claude.CommonClaudeHeadersOperation(c, req, info)
|
||||
if a.ClientMode == ClientModeApiKey {
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -63,22 +116,16 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
// 检查是否为Nova模型
|
||||
if isNovaModel(request.Model) {
|
||||
novaReq := convertToNovaRequest(request)
|
||||
c.Set("request_model", request.Model)
|
||||
c.Set("converted_request", novaReq)
|
||||
c.Set("is_nova_model", true)
|
||||
a.IsNova = true
|
||||
return novaReq, nil
|
||||
}
|
||||
|
||||
// 原有的Claude模型处理逻辑
|
||||
var claudeReq *dto.ClaudeRequest
|
||||
var err error
|
||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||||
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, "failed to convert openai request to claude request")
|
||||
}
|
||||
c.Set("request_model", claudeReq.Model)
|
||||
c.Set("converted_request", claudeReq)
|
||||
c.Set("is_nova_model", false)
|
||||
info.UpstreamModelName = claudeReq.Model
|
||||
return claudeReq, err
|
||||
}
|
||||
|
||||
@@ -97,14 +144,27 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return nil, nil
|
||||
if a.ClientMode == ClientModeApiKey {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
} else {
|
||||
return doAwsClientRequest(c, info, a, requestBody)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
|
||||
if a.ClientMode == ClientModeApiKey {
|
||||
claudeAdaptor := claude.Adaptor{}
|
||||
usage, err = claudeAdaptor.DoResponse(c, resp, info)
|
||||
} else {
|
||||
err, usage = awsHandler(c, info, a.RequestMode)
|
||||
if a.IsNova {
|
||||
err, usage = handleNovaRequest(c, info, a)
|
||||
} else {
|
||||
if info.IsStream {
|
||||
err, usage = awsStreamHandler(c, info, a)
|
||||
} else {
|
||||
err, usage = awsHandler(c, info, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -124,5 +124,5 @@ var ChannelName = "aws"
|
||||
|
||||
// 判断是否为Nova模型
|
||||
func isNovaModel(modelId string) bool {
|
||||
return strings.HasPrefix(modelId, "nova-")
|
||||
return strings.Contains(modelId, "nova-")
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
)
|
||||
|
||||
@@ -35,6 +38,16 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
|
||||
}
|
||||
}
|
||||
|
||||
func formatRequest(requestBody io.Reader) (*AwsClaudeRequest, error) {
|
||||
var awsClaudeRequest AwsClaudeRequest
|
||||
err := common.DecodeJson(requestBody, &awsClaudeRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
awsClaudeRequest.AnthropicVersion = "bedrock-2023-05-31"
|
||||
return &awsClaudeRequest, nil
|
||||
}
|
||||
|
||||
// NovaMessage Nova模型使用messages-v1格式
|
||||
type NovaMessage struct {
|
||||
Role string `json:"role"`
|
||||
|
||||
@@ -3,6 +3,7 @@ package aws
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@@ -49,16 +50,78 @@ func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func wrapErr(err error) *dto.OpenAIErrorWithStatusCode {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: dto.OpenAIError{
|
||||
Message: fmt.Sprintf("%s", err.Error()),
|
||||
},
|
||||
func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor, requestBody io.Reader) (any, error) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError)
|
||||
}
|
||||
a.AwsClient = awsCli
|
||||
|
||||
println(info.UpstreamModelName)
|
||||
// 获取对应的AWS模型ID
|
||||
awsModelId := getAwsModelID(info.UpstreamModelName)
|
||||
|
||||
awsRegionPrefix := getAwsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
|
||||
if isNovaModel(awsModelId) {
|
||||
var novaReq *NovaRequest
|
||||
err = common.DecodeJson(requestBody, &novaReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
|
||||
// 使用InvokeModel API,但使用Nova格式的请求体
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
reqBody, err := common.Marshal(novaReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
awsReq.Body = reqBody
|
||||
return nil, nil
|
||||
} else {
|
||||
awsClaudeReq, err := formatRequest(requestBody)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
|
||||
if info.IsStream {
|
||||
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
a.AwsReq = awsReq
|
||||
return nil, nil
|
||||
} else {
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
a.AwsReq = awsReq
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func awsRegionPrefix(awsRegionId string) string {
|
||||
func getAwsRegionPrefix(awsRegionId string) string {
|
||||
parts := strings.Split(awsRegionId, "-")
|
||||
regionPrefix := ""
|
||||
if len(parts) > 0 {
|
||||
@@ -80,58 +143,16 @@ func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
|
||||
return modelPrefix + "." + awsModelId
|
||||
}
|
||||
|
||||
func awsModelID(requestModel string) string {
|
||||
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
|
||||
return awsModelID
|
||||
func getAwsModelID(requestModel string) string {
|
||||
if awsModelIDName, ok := awsModelIDMap[requestModel]; ok {
|
||||
return awsModelIDName
|
||||
}
|
||||
|
||||
return requestModel
|
||||
}
|
||||
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
||||
|
||||
awsModelId := awsModelID(c.GetString("request_model"))
|
||||
// 检查是否为Nova模型
|
||||
isNova, _ := c.Get("is_nova_model")
|
||||
if isNova == true {
|
||||
// Nova模型也支持跨区域
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
return handleNovaRequest(c, awsCli, info, awsModelId)
|
||||
}
|
||||
|
||||
// 原有的Claude处理逻辑
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
claudeReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
|
||||
}
|
||||
@@ -149,46 +170,15 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
c.Writer.Header().Set("Content-Type", *awsResp.ContentType)
|
||||
}
|
||||
|
||||
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, RequestModeMessage)
|
||||
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, claude.RequestModeMessage)
|
||||
if handlerErr != nil {
|
||||
return handlerErr, nil
|
||||
}
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
awsModelId := awsModelID(c.GetString("request_model"))
|
||||
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
|
||||
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
claudeReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
|
||||
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
||||
awsResp, err := a.AwsClient.InvokeModelWithResponseStream(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
|
||||
}
|
||||
@@ -207,7 +197,7 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
switch v := event.(type) {
|
||||
case *bedrockruntimeTypes.ResponseStreamMemberChunk:
|
||||
info.SetFirstResponseTime()
|
||||
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
|
||||
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), claude.RequestModeMessage)
|
||||
if respErr != nil {
|
||||
return respErr, nil
|
||||
}
|
||||
@@ -220,32 +210,14 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
}
|
||||
|
||||
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
|
||||
claude.HandleStreamFinalResponse(c, info, claudeInfo, claude.RequestModeMessage)
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
// Nova模型处理函数
|
||||
func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
|
||||
novaReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
novaReq := novaReq_.(*NovaRequest)
|
||||
func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
||||
|
||||
// 使用InvokeModel API,但使用Nova格式的请求体
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(novaReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
awsReq.Body = reqBody
|
||||
|
||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
@@ -477,8 +477,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
|
||||
signatureContent := "\n"
|
||||
choice.Delta.ReasoningContent = &signatureContent
|
||||
case "thinking_delta":
|
||||
thinkingContent := claudeResponse.Delta.Thinking
|
||||
choice.Delta.ReasoningContent = &thinkingContent
|
||||
choice.Delta.ReasoningContent = claudeResponse.Delta.Thinking
|
||||
}
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
@@ -513,7 +512,9 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto
|
||||
var responseThinking string
|
||||
if len(claudeResponse.Content) > 0 {
|
||||
responseText = claudeResponse.Content[0].GetText()
|
||||
responseThinking = claudeResponse.Content[0].Thinking
|
||||
if claudeResponse.Content[0].Thinking != nil {
|
||||
responseThinking = *claudeResponse.Content[0].Thinking
|
||||
}
|
||||
}
|
||||
tools := make([]dto.ToolCallResponse, 0)
|
||||
thinkingContent := ""
|
||||
@@ -545,7 +546,9 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto
|
||||
})
|
||||
case "thinking":
|
||||
// 加密的不管, 只输出明文的推理过程
|
||||
thinkingContent = message.Thinking
|
||||
if message.Thinking != nil {
|
||||
thinkingContent = *message.Thinking
|
||||
}
|
||||
case "text":
|
||||
responseText = message.GetText()
|
||||
}
|
||||
@@ -598,8 +601,8 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
|
||||
if claudeResponse.Delta.Text != nil {
|
||||
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
|
||||
}
|
||||
if claudeResponse.Delta.Thinking != "" {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
|
||||
if claudeResponse.Delta.Thinking != nil {
|
||||
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Thinking)
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
// 最终的usage获取
|
||||
|
||||
@@ -211,7 +211,16 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
// eg. {"google":{"thinking_config":{"thinking_budget":5324,"include_thoughts":true}}}
|
||||
if googleBody, ok := extraBody["google"].(map[string]interface{}); ok {
|
||||
adaptorWithExtraBody = true
|
||||
// check error param name like thinkingConfig, should be thinking_config
|
||||
if _, hasErrorParam := googleBody["thinkingConfig"]; hasErrorParam {
|
||||
return nil, errors.New("extra_body.google.thinkingConfig is not supported, use extra_body.google.thinking_config instead")
|
||||
}
|
||||
|
||||
if thinkingConfig, ok := googleBody["thinking_config"].(map[string]interface{}); ok {
|
||||
// check error param name like thinkingBudget, should be thinking_budget
|
||||
if _, hasErrorParam := thinkingConfig["thinkingBudget"]; hasErrorParam {
|
||||
return nil, errors.New("extra_body.google.thinking_config.thinkingBudget is not supported, use extra_body.google.thinking_config.thinking_budget instead")
|
||||
}
|
||||
if budget, ok := thinkingConfig["thinking_budget"].(float64); ok {
|
||||
budgetInt := int(budget)
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
@@ -1052,11 +1061,11 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
}
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
//return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
|
||||
return nil, types.NewOpenAIError(errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason), types.ErrorCodePromptBlocked, http.StatusBadRequest)
|
||||
} else {
|
||||
return nil, types.NewOpenAIError(errors.New("empty response from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
}
|
||||
//if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
|
||||
// return nil, types.NewOpenAIError(errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason), types.ErrorCodePromptBlocked, http.StatusBadRequest)
|
||||
//} else {
|
||||
// return nil, types.NewOpenAIError(errors.New("empty response from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
//}
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
|
||||
fullTextResponse.Model = info.UpstreamModelName
|
||||
|
||||
132
relay/channel/minimax/adaptor.go
Normal file
132
relay/channel/minimax/adaptor.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package minimax
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/openai"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
if info.RelayMode != constant.RelayModeAudioSpeech {
|
||||
return nil, errors.New("unsupported audio relay mode")
|
||||
}
|
||||
|
||||
voiceID := request.Voice
|
||||
speed := request.Speed
|
||||
outputFormat := request.ResponseFormat
|
||||
|
||||
minimaxRequest := MiniMaxTTSRequest{
|
||||
Model: info.OriginModelName,
|
||||
Text: request.Input,
|
||||
VoiceSetting: VoiceSetting{
|
||||
VoiceID: voiceID,
|
||||
Speed: speed,
|
||||
},
|
||||
AudioSetting: &AudioSetting{
|
||||
Format: outputFormat,
|
||||
},
|
||||
OutputFormat: outputFormat,
|
||||
}
|
||||
|
||||
// 同步扩展字段的厂商自定义metadata
|
||||
if len(request.Metadata) > 0 {
|
||||
if err := json.Unmarshal(request.Metadata, &minimaxRequest); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling metadata to minimax request: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(minimaxRequest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling minimax request: %w", err)
|
||||
}
|
||||
if outputFormat != "hex" {
|
||||
outputFormat = "url"
|
||||
}
|
||||
|
||||
c.Set("response_format", outputFormat)
|
||||
|
||||
// Debug: log the request structure
|
||||
// fmt.Printf("MiniMax TTS Request: %s\n", string(jsonData))
|
||||
|
||||
return bytes.NewReader(jsonData), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return GetRequestURL(info)
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.RelayMode == constant.RelayModeAudioSpeech {
|
||||
return handleTTSResponse(c, resp, info)
|
||||
}
|
||||
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
@@ -8,6 +8,12 @@ var ModelList = []string{
|
||||
"abab6-chat",
|
||||
"abab5.5-chat",
|
||||
"abab5.5s-chat",
|
||||
"speech-2.5-hd-preview",
|
||||
"speech-2.5-turbo-preview",
|
||||
"speech-02-hd",
|
||||
"speech-02-turbo",
|
||||
"speech-01-hd",
|
||||
"speech-01-turbo",
|
||||
}
|
||||
|
||||
var ChannelName = "minimax"
|
||||
|
||||
@@ -3,9 +3,23 @@ package minimax
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
channelconstant "github.com/QuantumNous/new-api/constant"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/constant"
|
||||
)
|
||||
|
||||
func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.ChannelBaseUrl), nil
|
||||
baseUrl := info.ChannelBaseUrl
|
||||
if baseUrl == "" {
|
||||
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeMiniMax]
|
||||
}
|
||||
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", baseUrl), nil
|
||||
case constant.RelayModeAudioSpeech:
|
||||
return fmt.Sprintf("%s/v1/t2a_v2", baseUrl), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
||||
}
|
||||
}
|
||||
|
||||
194
relay/channel/minimax/tts.go
Normal file
194
relay/channel/minimax/tts.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package minimax
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type MiniMaxTTSRequest struct {
|
||||
Model string `json:"model"`
|
||||
Text string `json:"text"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
VoiceSetting VoiceSetting `json:"voice_setting"`
|
||||
PronunciationDict *PronunciationDict `json:"pronunciation_dict,omitempty"`
|
||||
AudioSetting *AudioSetting `json:"audio_setting,omitempty"`
|
||||
TimbreWeights []TimbreWeight `json:"timbre_weights,omitempty"`
|
||||
LanguageBoost string `json:"language_boost,omitempty"`
|
||||
VoiceModify *VoiceModify `json:"voice_modify,omitempty"`
|
||||
SubtitleEnable bool `json:"subtitle_enable,omitempty"`
|
||||
OutputFormat string `json:"output_format,omitempty"`
|
||||
AigcWatermark bool `json:"aigc_watermark,omitempty"`
|
||||
}
|
||||
|
||||
type StreamOptions struct {
|
||||
ExcludeAggregatedAudio bool `json:"exclude_aggregated_audio,omitempty"`
|
||||
}
|
||||
|
||||
type VoiceSetting struct {
|
||||
VoiceID string `json:"voice_id"`
|
||||
Speed float64 `json:"speed,omitempty"`
|
||||
Vol float64 `json:"vol,omitempty"`
|
||||
Pitch int `json:"pitch,omitempty"`
|
||||
Emotion string `json:"emotion,omitempty"`
|
||||
TextNormalization bool `json:"text_normalization,omitempty"`
|
||||
LatexRead bool `json:"latex_read,omitempty"`
|
||||
}
|
||||
|
||||
type PronunciationDict struct {
|
||||
Tone []string `json:"tone,omitempty"`
|
||||
}
|
||||
|
||||
type AudioSetting struct {
|
||||
SampleRate int `json:"sample_rate,omitempty"`
|
||||
Bitrate int `json:"bitrate,omitempty"`
|
||||
Format string `json:"format,omitempty"`
|
||||
Channel int `json:"channel,omitempty"`
|
||||
ForceCbr bool `json:"force_cbr,omitempty"`
|
||||
}
|
||||
|
||||
type TimbreWeight struct {
|
||||
VoiceID string `json:"voice_id"`
|
||||
Weight int `json:"weight"`
|
||||
}
|
||||
|
||||
type VoiceModify struct {
|
||||
Pitch int `json:"pitch,omitempty"`
|
||||
Intensity int `json:"intensity,omitempty"`
|
||||
Timbre int `json:"timbre,omitempty"`
|
||||
SoundEffects string `json:"sound_effects,omitempty"`
|
||||
}
|
||||
|
||||
type MiniMaxTTSResponse struct {
|
||||
Data MiniMaxTTSData `json:"data"`
|
||||
ExtraInfo MiniMaxExtraInfo `json:"extra_info"`
|
||||
TraceID string `json:"trace_id"`
|
||||
BaseResp MiniMaxBaseResp `json:"base_resp"`
|
||||
}
|
||||
|
||||
type MiniMaxTTSData struct {
|
||||
Audio string `json:"audio"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
type MiniMaxExtraInfo struct {
|
||||
UsageCharacters int64 `json:"usage_characters"`
|
||||
}
|
||||
|
||||
type MiniMaxBaseResp struct {
|
||||
StatusCode int64 `json:"status_code"`
|
||||
StatusMsg string `json:"status_msg"`
|
||||
}
|
||||
|
||||
func getContentTypeByFormat(format string) string {
|
||||
contentTypeMap := map[string]string{
|
||||
"mp3": "audio/mpeg",
|
||||
"wav": "audio/wav",
|
||||
"flac": "audio/flac",
|
||||
"aac": "audio/aac",
|
||||
"pcm": "audio/pcm",
|
||||
}
|
||||
if ct, ok := contentTypeMap[format]; ok {
|
||||
return ct
|
||||
}
|
||||
return "audio/mpeg" // default to mp3
|
||||
}
|
||||
|
||||
func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("failed to read minimax response: %w", readErr),
|
||||
types.ErrorCodeReadResponseBodyFailed,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Parse response
|
||||
var minimaxResp MiniMaxTTSResponse
|
||||
if unmarshalErr := json.Unmarshal(body, &minimaxResp); unmarshalErr != nil {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("failed to unmarshal minimax TTS response: %w", unmarshalErr),
|
||||
types.ErrorCodeBadResponseBody,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
// Check base_resp status code
|
||||
if minimaxResp.BaseResp.StatusCode != 0 {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("minimax TTS error: %d - %s", minimaxResp.BaseResp.StatusCode, minimaxResp.BaseResp.StatusMsg),
|
||||
types.ErrorCodeBadResponse,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
}
|
||||
|
||||
// Check if we have audio data
|
||||
if minimaxResp.Data.Audio == "" {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("no audio data in minimax TTS response"),
|
||||
types.ErrorCodeBadResponse,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(minimaxResp.Data.Audio, "http") {
|
||||
c.Redirect(http.StatusFound, minimaxResp.Data.Audio)
|
||||
} else {
|
||||
// Handle hex-encoded audio data
|
||||
audioData, decodeErr := hex.DecodeString(minimaxResp.Data.Audio)
|
||||
if decodeErr != nil {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("failed to decode hex audio data: %w", decodeErr),
|
||||
types.ErrorCodeBadResponse,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
// Determine content type - default to mp3
|
||||
contentType := "audio/mpeg"
|
||||
|
||||
c.Data(http.StatusOK, contentType, audioData)
|
||||
}
|
||||
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: int(minimaxResp.ExtraInfo.UsageCharacters),
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func handleChatCompletionResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
errors.New("failed to read minimax response"),
|
||||
types.ErrorCodeReadResponseBodyFailed,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Set response headers
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
c.Data(resp.StatusCode, "application/json", body)
|
||||
return nil, nil
|
||||
}
|
||||
@@ -121,7 +121,14 @@ func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
if chunk.Message != nil && len(chunk.Message.Thinking) > 0 {
|
||||
raw := strings.TrimSpace(string(chunk.Message.Thinking))
|
||||
if raw != "" && raw != "null" {
|
||||
delta.Choices[0].Delta.SetReasoningContent(raw)
|
||||
// Unmarshal the JSON string to get the actual content without quotes
|
||||
var thinkingContent string
|
||||
if err := json.Unmarshal(chunk.Message.Thinking, &thinkingContent); err == nil {
|
||||
delta.Choices[0].Delta.SetReasoningContent(thinkingContent)
|
||||
} else {
|
||||
// Fallback to raw string if it's not a JSON string
|
||||
delta.Choices[0].Delta.SetReasoningContent(raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
// tool calls
|
||||
@@ -209,7 +216,14 @@ func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
if ck.Message != nil && len(ck.Message.Thinking) > 0 {
|
||||
raw := strings.TrimSpace(string(ck.Message.Thinking))
|
||||
if raw != "" && raw != "null" {
|
||||
reasoningBuilder.WriteString(raw)
|
||||
// Unmarshal the JSON string to get the actual content without quotes
|
||||
var thinkingContent string
|
||||
if err := json.Unmarshal(ck.Message.Thinking, &thinkingContent); err == nil {
|
||||
reasoningBuilder.WriteString(thinkingContent)
|
||||
} else {
|
||||
// Fallback to raw string if it's not a JSON string
|
||||
reasoningBuilder.WriteString(raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
if ck.Message != nil && ck.Message.Content != "" {
|
||||
@@ -229,7 +243,14 @@ func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
if len(single.Message.Thinking) > 0 {
|
||||
raw := strings.TrimSpace(string(single.Message.Thinking))
|
||||
if raw != "" && raw != "null" {
|
||||
reasoningBuilder.WriteString(raw)
|
||||
// Unmarshal the JSON string to get the actual content without quotes
|
||||
var thinkingContent string
|
||||
if err := json.Unmarshal(single.Message.Thinking, &thinkingContent); err == nil {
|
||||
reasoningBuilder.WriteString(thinkingContent)
|
||||
} else {
|
||||
// Fallback to raw string if it's not a JSON string
|
||||
reasoningBuilder.WriteString(raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
aggContent.WriteString(single.Message.Content)
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/ai360"
|
||||
"github.com/QuantumNous/new-api/relay/channel/lingyiwanwu"
|
||||
"github.com/QuantumNous/new-api/relay/channel/minimax"
|
||||
//"github.com/QuantumNous/new-api/relay/channel/minimax"
|
||||
"github.com/QuantumNous/new-api/relay/channel/openrouter"
|
||||
"github.com/QuantumNous/new-api/relay/channel/xinference"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
@@ -161,8 +161,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
|
||||
}
|
||||
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil
|
||||
case constant.ChannelTypeMiniMax:
|
||||
return minimax.GetRequestURL(info)
|
||||
//case constant.ChannelTypeMiniMax:
|
||||
// return minimax.GetRequestURL(info)
|
||||
case constant.ChannelTypeCustom:
|
||||
url := info.ChannelBaseUrl
|
||||
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
|
||||
@@ -599,8 +599,8 @@ func (a *Adaptor) GetModelList() []string {
|
||||
return ai360.ModelList
|
||||
case constant.ChannelTypeLingYiWanWu:
|
||||
return lingyiwanwu.ModelList
|
||||
case constant.ChannelTypeMiniMax:
|
||||
return minimax.ModelList
|
||||
//case constant.ChannelTypeMiniMax:
|
||||
// return minimax.ModelList
|
||||
case constant.ChannelTypeXinference:
|
||||
return xinference.ModelList
|
||||
case constant.ChannelTypeOpenRouter:
|
||||
@@ -616,8 +616,8 @@ func (a *Adaptor) GetChannelName() string {
|
||||
return ai360.ChannelName
|
||||
case constant.ChannelTypeLingYiWanWu:
|
||||
return lingyiwanwu.ChannelName
|
||||
case constant.ChannelTypeMiniMax:
|
||||
return minimax.ChannelName
|
||||
//case constant.ChannelTypeMiniMax:
|
||||
// return minimax.ChannelName
|
||||
case constant.ChannelTypeXinference:
|
||||
return xinference.ChannelName
|
||||
case constant.ChannelTypeOpenRouter:
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/openai"
|
||||
@@ -35,8 +36,27 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertImageRequest(c, info, request)
|
||||
// 解析extra到SFImageRequest里,以填入SiliconFlow特殊字段。若失败重建一个空的。
|
||||
sfRequest := &SFImageRequest{}
|
||||
extra, err := common.Marshal(request.Extra)
|
||||
if err == nil {
|
||||
err = common.Unmarshal(extra, sfRequest)
|
||||
if err != nil {
|
||||
sfRequest = &SFImageRequest{}
|
||||
}
|
||||
}
|
||||
|
||||
sfRequest.Model = request.Model
|
||||
sfRequest.Prompt = request.Prompt
|
||||
// 优先使用image_size/batch_size,否则使用OpenAI标准的size/n
|
||||
if sfRequest.ImageSize == "" {
|
||||
sfRequest.ImageSize = request.Size
|
||||
}
|
||||
if sfRequest.BatchSize == 0 {
|
||||
sfRequest.BatchSize = request.N
|
||||
}
|
||||
|
||||
return sfRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -51,6 +71,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeCompletions {
|
||||
return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||
return fmt.Sprintf("%s/v1/images/generations", info.ChannelBaseUrl), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
}
|
||||
@@ -102,6 +124,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
fallthrough
|
||||
case constant.RelayModeChatCompletions:
|
||||
fallthrough
|
||||
case constant.RelayModeImagesGenerations:
|
||||
fallthrough
|
||||
default:
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
|
||||
@@ -15,3 +15,18 @@ type SFRerankResponse struct {
|
||||
Results []dto.RerankResponseResult `json:"results"`
|
||||
Meta SFMeta `json:"meta"`
|
||||
}
|
||||
|
||||
type SFImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
||||
ImageSize string `json:"image_size,omitempty"`
|
||||
BatchSize uint `json:"batch_size,omitempty"`
|
||||
Seed uint64 `json:"seed,omitempty"`
|
||||
NumInferenceSteps uint `json:"num_inference_steps,omitempty"`
|
||||
GuidanceScale float64 `json:"guidance_scale,omitempty"`
|
||||
Cfg float64 `json:"cfg,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Image2 string `json:"image2,omitempty"`
|
||||
Image3 string `json:"image3,omitempty"`
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -64,6 +66,11 @@ type responseTask struct {
|
||||
TimeElapsed string `json:"time_elapsed"`
|
||||
}
|
||||
|
||||
const (
|
||||
// 即梦限制单个文件最大4.7MB https://www.volcengine.com/docs/85621/1747301
|
||||
MaxFileSize int64 = 4*1024*1024 + 700*1024 // 4.7MB (4MB + 724KB)
|
||||
)
|
||||
|
||||
// ============================
|
||||
// Adaptor implementation
|
||||
// ============================
|
||||
@@ -89,7 +96,6 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||||
}
|
||||
|
||||
@@ -113,13 +119,49 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Jimeng specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
v, exists := c.Get("task_request")
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
req, ok := v.(relaycommon.TaskSubmitReq)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid request type in context")
|
||||
}
|
||||
// 支持openai sdk的图片上传方式
|
||||
if mf, err := c.MultipartForm(); err == nil {
|
||||
if files, exists := mf.File["input_reference"]; exists && len(files) > 0 {
|
||||
if len(files) == 1 {
|
||||
info.Action = constant.TaskActionGenerate
|
||||
} else if len(files) > 1 {
|
||||
info.Action = constant.TaskActionFirstTailGenerate
|
||||
}
|
||||
|
||||
// 将上传的文件转换为base64格式
|
||||
var images []string
|
||||
|
||||
for _, fileHeader := range files {
|
||||
// 检查文件大小
|
||||
if fileHeader.Size > MaxFileSize {
|
||||
return nil, fmt.Errorf("文件 %s 大小超过限制,最大允许 %d MB", fileHeader.Filename, MaxFileSize/(1024*1024))
|
||||
}
|
||||
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fileBytes, err := io.ReadAll(file)
|
||||
file.Close()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// 将文件内容转换为base64
|
||||
base64Str := base64.StdEncoding.EncodeToString(fileBytes)
|
||||
images = append(images, base64Str)
|
||||
}
|
||||
req.Images = images
|
||||
}
|
||||
}
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
if err != nil {
|
||||
@@ -158,7 +200,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return
|
||||
}
|
||||
|
||||
ov := relaycommon.NewOpenAIVideo()
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = jResp.Data.TaskID
|
||||
ov.TaskID = jResp.Data.TaskID
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
@@ -364,10 +406,10 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
// 即梦视频3.0 ReqKey转换
|
||||
// https://www.volcengine.com/docs/85621/1792707
|
||||
if strings.Contains(r.ReqKey, "jimeng_v30") {
|
||||
if len(r.ImageUrls) > 1 {
|
||||
if len(req.Images) > 1 {
|
||||
// 多张图片:首尾帧生成
|
||||
r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1)
|
||||
} else if len(r.ImageUrls) == 1 {
|
||||
} else if len(req.Images) == 1 {
|
||||
// 单张图片:图生视频
|
||||
r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1)
|
||||
} else {
|
||||
@@ -405,13 +447,13 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
return &taskResult, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*relaycommon.OpenAIVideo, error) {
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||||
var jimengResp responseTask
|
||||
if err := json.Unmarshal(originTask.Data, &jimengResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal jimeng task data failed")
|
||||
}
|
||||
|
||||
openAIVideo := relaycommon.NewOpenAIVideo()
|
||||
openAIVideo := dto.NewOpenAIVideo()
|
||||
openAIVideo.ID = originTask.TaskID
|
||||
openAIVideo.Status = originTask.Status.ToVideoStatus()
|
||||
openAIVideo.SetProgressStr(originTask.Progress)
|
||||
@@ -420,13 +462,14 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*relaycommon
|
||||
openAIVideo.CompletedAt = originTask.UpdatedAt
|
||||
|
||||
if jimengResp.Code != 10000 {
|
||||
openAIVideo.Error = &relaycommon.OpenAIVideoError{
|
||||
openAIVideo.Error = &dto.OpenAIVideoError{
|
||||
Message: jimengResp.Message,
|
||||
Code: fmt.Sprintf("%d", jimengResp.Code),
|
||||
}
|
||||
}
|
||||
|
||||
return openAIVideo, nil
|
||||
jsonData, _ := common.Marshal(openAIVideo)
|
||||
return jsonData, nil
|
||||
}
|
||||
|
||||
func isNewAPIRelay(apiKey string) bool {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
|
||||
"github.com/samber/lo"
|
||||
@@ -188,7 +189,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf(kResp.Message), "task_failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
ov := relaycommon.NewOpenAIVideo()
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = kResp.Data.TaskId
|
||||
ov.TaskID = kResp.Data.TaskId
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
@@ -367,13 +368,13 @@ func isNewAPIRelay(apiKey string) bool {
|
||||
return strings.HasPrefix(apiKey, "sk-")
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*relaycommon.OpenAIVideo, error) {
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||||
var klingResp responsePayload
|
||||
if err := json.Unmarshal(originTask.Data, &klingResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal kling task data failed")
|
||||
}
|
||||
|
||||
openAIVideo := relaycommon.NewOpenAIVideo()
|
||||
openAIVideo := dto.NewOpenAIVideo()
|
||||
openAIVideo.ID = originTask.TaskID
|
||||
openAIVideo.Status = originTask.Status.ToVideoStatus()
|
||||
openAIVideo.SetProgressStr(originTask.Progress)
|
||||
@@ -391,11 +392,11 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*relaycommon
|
||||
}
|
||||
|
||||
if klingResp.Code != 0 && klingResp.Message != "" {
|
||||
openAIVideo.Error = &relaycommon.OpenAIVideoError{
|
||||
openAIVideo.Error = &dto.OpenAIVideoError{
|
||||
Message: klingResp.Message,
|
||||
Code: fmt.Sprintf("%d", klingResp.Code),
|
||||
}
|
||||
}
|
||||
|
||||
return openAIVideo, nil
|
||||
jsonData, _ := common.Marshal(openAIVideo)
|
||||
return jsonData, nil
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package sora
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -107,7 +106,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco
|
||||
|
||||
// Parse Sora response
|
||||
var dResp responseTask
|
||||
if err := json.Unmarshal(responseBody, &dResp); err != nil {
|
||||
if err := common.Unmarshal(responseBody, &dResp); err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -154,7 +153,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
resTask := responseTask{}
|
||||
if err := json.Unmarshal(respBody, &resTask); err != nil {
|
||||
if err := common.Unmarshal(respBody, &resTask); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal task result failed")
|
||||
}
|
||||
|
||||
@@ -186,11 +185,6 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
return &taskResult, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) (*relaycommon.OpenAIVideo, error) {
|
||||
openAIVideo := &relaycommon.OpenAIVideo{}
|
||||
err := json.Unmarshal(task.Data, openAIVideo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal to OpenAIVideo failed")
|
||||
}
|
||||
return openAIVideo, nil
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
return task.Data, nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
@@ -155,7 +156,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return
|
||||
}
|
||||
|
||||
ov := relaycommon.NewOpenAIVideo()
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = vResp.TaskId
|
||||
ov.TaskID = vResp.TaskId
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
@@ -263,13 +264,13 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
return taskInfo, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*relaycommon.OpenAIVideo, error) {
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||||
var viduResp taskResultResponse
|
||||
if err := json.Unmarshal(originTask.Data, &viduResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal vidu task data failed")
|
||||
}
|
||||
|
||||
openAIVideo := relaycommon.NewOpenAIVideo()
|
||||
openAIVideo := dto.NewOpenAIVideo()
|
||||
openAIVideo.ID = originTask.TaskID
|
||||
openAIVideo.Status = originTask.Status.ToVideoStatus()
|
||||
openAIVideo.SetProgressStr(originTask.Progress)
|
||||
@@ -281,11 +282,12 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*relaycommon
|
||||
}
|
||||
|
||||
if viduResp.State == "failed" && viduResp.ErrCode != "" {
|
||||
openAIVideo.Error = &relaycommon.OpenAIVideoError{
|
||||
openAIVideo.Error = &dto.OpenAIVideoError{
|
||||
Message: viduResp.ErrCode,
|
||||
Code: viduResp.ErrCode,
|
||||
}
|
||||
}
|
||||
|
||||
return openAIVideo, nil
|
||||
jsonData, _ := common.Marshal(openAIVideo)
|
||||
return jsonData, nil
|
||||
}
|
||||
|
||||
@@ -37,8 +37,57 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
if info.RelayMode != constant.RelayModeAudioSpeech {
|
||||
return nil, errors.New("unsupported audio relay mode")
|
||||
}
|
||||
|
||||
appID, token, err := parseVolcengineAuth(info.ApiKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
voiceType := mapVoiceType(request.Voice)
|
||||
speedRatio := request.Speed
|
||||
encoding := mapEncoding(request.ResponseFormat)
|
||||
|
||||
c.Set("response_format", encoding)
|
||||
|
||||
volcRequest := VolcengineTTSRequest{
|
||||
App: VolcengineTTSApp{
|
||||
AppID: appID,
|
||||
Token: token,
|
||||
Cluster: "volcano_tts",
|
||||
},
|
||||
User: VolcengineTTSUser{
|
||||
UID: "openai_relay_user",
|
||||
},
|
||||
Audio: VolcengineTTSAudio{
|
||||
VoiceType: voiceType,
|
||||
Encoding: encoding,
|
||||
SpeedRatio: speedRatio,
|
||||
Rate: 24000,
|
||||
},
|
||||
Request: VolcengineTTSReqInfo{
|
||||
ReqID: generateRequestID(),
|
||||
Text: request.Input,
|
||||
Operation: "query",
|
||||
Model: info.OriginModelName,
|
||||
},
|
||||
}
|
||||
|
||||
// 同步扩展字段的厂商自定义metadata
|
||||
if len(request.Metadata) > 0 {
|
||||
if err = json.Unmarshal(request.Metadata, &volcRequest); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling metadata to volcengine request: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(volcRequest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling volcengine request: %w", err)
|
||||
}
|
||||
|
||||
return bytes.NewReader(jsonData), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
@@ -190,7 +239,6 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
// 支持自定义域名,如果未设置则使用默认域名
|
||||
baseUrl := info.ChannelBaseUrl
|
||||
if baseUrl == "" {
|
||||
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
|
||||
@@ -217,6 +265,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
|
||||
case constant.RelayModeRerank:
|
||||
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
|
||||
case constant.RelayModeAudioSpeech:
|
||||
// 只有当 baseUrl 是火山默认的官方Url时才改为官方的的TTS接口,否则走透传的New接口
|
||||
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
|
||||
return "https://openspeech.bytedance.com/api/v1/tts", nil
|
||||
}
|
||||
return fmt.Sprintf("%s/v1/audio/speech", baseUrl), nil
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -225,6 +279,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
|
||||
if info.RelayMode == constant.RelayModeAudioSpeech {
|
||||
parts := strings.Split(info.ApiKey, "|")
|
||||
if len(parts) == 2 {
|
||||
req.Set("Authorization", "Bearer;"+parts[1])
|
||||
}
|
||||
req.Set("Content-Type", "application/json")
|
||||
return nil
|
||||
}
|
||||
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
@@ -260,6 +324,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.RelayMode == constant.RelayModeAudioSpeech {
|
||||
encoding := mapEncoding(c.GetString("response_format"))
|
||||
return handleTTSResponse(c, resp, info, encoding)
|
||||
}
|
||||
|
||||
adaptor := openai.Adaptor{}
|
||||
usage, err = adaptor.DoResponse(c, resp, info)
|
||||
return
|
||||
|
||||
194
relay/channel/volcengine/tts.go
Normal file
194
relay/channel/volcengine/tts.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package volcengine
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type VolcengineTTSRequest struct {
|
||||
App VolcengineTTSApp `json:"app"`
|
||||
User VolcengineTTSUser `json:"user"`
|
||||
Audio VolcengineTTSAudio `json:"audio"`
|
||||
Request VolcengineTTSReqInfo `json:"request"`
|
||||
}
|
||||
|
||||
type VolcengineTTSApp struct {
|
||||
AppID string `json:"appid"`
|
||||
Token string `json:"token"`
|
||||
Cluster string `json:"cluster"`
|
||||
}
|
||||
|
||||
type VolcengineTTSUser struct {
|
||||
UID string `json:"uid"`
|
||||
}
|
||||
|
||||
type VolcengineTTSAudio struct {
|
||||
VoiceType string `json:"voice_type"`
|
||||
Encoding string `json:"encoding"`
|
||||
SpeedRatio float64 `json:"speed_ratio"`
|
||||
Rate int `json:"rate"`
|
||||
Bitrate int `json:"bitrate,omitempty"`
|
||||
LoudnessRatio float64 `json:"loudness_ratio,omitempty"`
|
||||
EnableEmotion bool `json:"enable_emotion,omitempty"`
|
||||
Emotion string `json:"emotion,omitempty"`
|
||||
EmotionScale float64 `json:"emotion_scale,omitempty"`
|
||||
ExplicitLanguage string `json:"explicit_language,omitempty"`
|
||||
ContextLanguage string `json:"context_language,omitempty"`
|
||||
}
|
||||
|
||||
type VolcengineTTSReqInfo struct {
|
||||
ReqID string `json:"reqid"`
|
||||
Text string `json:"text"`
|
||||
Operation string `json:"operation"`
|
||||
Model string `json:"model,omitempty"`
|
||||
TextType string `json:"text_type,omitempty"`
|
||||
SilenceDuration float64 `json:"silence_duration,omitempty"`
|
||||
WithTimestamp interface{} `json:"with_timestamp,omitempty"`
|
||||
ExtraParam *VolcengineTTSExtraParam `json:"extra_param,omitempty"`
|
||||
}
|
||||
|
||||
type VolcengineTTSExtraParam struct {
|
||||
DisableMarkdownFilter bool `json:"disable_markdown_filter,omitempty"`
|
||||
EnableLatexTn bool `json:"enable_latex_tn,omitempty"`
|
||||
MuteCutThreshold string `json:"mute_cut_threshold,omitempty"`
|
||||
MuteCutRemainMs string `json:"mute_cut_remain_ms,omitempty"`
|
||||
DisableEmojiFilter bool `json:"disable_emoji_filter,omitempty"`
|
||||
UnsupportedCharRatioThresh float64 `json:"unsupported_char_ratio_thresh,omitempty"`
|
||||
AigcWatermark bool `json:"aigc_watermark,omitempty"`
|
||||
CacheConfig *VolcengineTTSCacheConfig `json:"cache_config,omitempty"`
|
||||
}
|
||||
|
||||
type VolcengineTTSCacheConfig struct {
|
||||
TextType int `json:"text_type,omitempty"`
|
||||
UseCache bool `json:"use_cache,omitempty"`
|
||||
}
|
||||
|
||||
type VolcengineTTSResponse struct {
|
||||
ReqID string `json:"reqid"`
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Sequence int `json:"sequence"`
|
||||
Data string `json:"data"`
|
||||
Addition *VolcengineTTSAdditionInfo `json:"addition,omitempty"`
|
||||
}
|
||||
|
||||
type VolcengineTTSAdditionInfo struct {
|
||||
Duration string `json:"duration"`
|
||||
}
|
||||
|
||||
var openAIToVolcengineVoiceMap = map[string]string{
|
||||
"alloy": "zh_male_M392_conversation_wvae_bigtts",
|
||||
"echo": "zh_male_wenhao_mars_bigtts",
|
||||
"fable": "zh_female_tianmei_mars_bigtts",
|
||||
"onyx": "zh_male_zhibei_mars_bigtts",
|
||||
"nova": "zh_female_shuangkuaisisi_mars_bigtts",
|
||||
"shimmer": "zh_female_cancan_mars_bigtts",
|
||||
}
|
||||
|
||||
var responseFormatToEncodingMap = map[string]string{
|
||||
"mp3": "mp3",
|
||||
"opus": "ogg_opus",
|
||||
"aac": "mp3",
|
||||
"flac": "mp3",
|
||||
"wav": "wav",
|
||||
"pcm": "pcm",
|
||||
}
|
||||
|
||||
func parseVolcengineAuth(apiKey string) (appID, token string, err error) {
|
||||
parts := strings.Split(apiKey, "|")
|
||||
if len(parts) != 2 {
|
||||
return "", "", errors.New("invalid api key format, expected: appid|access_token")
|
||||
}
|
||||
return parts[0], parts[1], nil
|
||||
}
|
||||
|
||||
func mapVoiceType(openAIVoice string) string {
|
||||
if voice, ok := openAIToVolcengineVoiceMap[openAIVoice]; ok {
|
||||
return voice
|
||||
}
|
||||
return openAIVoice
|
||||
}
|
||||
|
||||
func mapEncoding(responseFormat string) string {
|
||||
if encoding, ok := responseFormatToEncodingMap[responseFormat]; ok {
|
||||
return encoding
|
||||
}
|
||||
return "mp3"
|
||||
}
|
||||
|
||||
func getContentTypeByEncoding(encoding string) string {
|
||||
contentTypeMap := map[string]string{
|
||||
"mp3": "audio/mpeg",
|
||||
"ogg_opus": "audio/ogg",
|
||||
"wav": "audio/wav",
|
||||
"pcm": "audio/pcm",
|
||||
}
|
||||
if ct, ok := contentTypeMap[encoding]; ok {
|
||||
return ct
|
||||
}
|
||||
return "application/octet-stream"
|
||||
}
|
||||
|
||||
func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
errors.New("failed to read volcengine response"),
|
||||
types.ErrorCodeReadResponseBodyFailed,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var volcResp VolcengineTTSResponse
|
||||
if unmarshalErr := json.Unmarshal(body, &volcResp); unmarshalErr != nil {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
errors.New("failed to parse volcengine response"),
|
||||
types.ErrorCodeBadResponseBody,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
if volcResp.Code != 3000 {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
errors.New(volcResp.Message),
|
||||
types.ErrorCodeBadResponse,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
}
|
||||
|
||||
audioData, decodeErr := base64.StdEncoding.DecodeString(volcResp.Data)
|
||||
if decodeErr != nil {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
errors.New("failed to decode audio data"),
|
||||
types.ErrorCodeBadResponseBody,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
contentType := getContentTypeByEncoding(encoding)
|
||||
c.Header("Content-Type", contentType)
|
||||
c.Data(http.StatusOK, contentType, audioData)
|
||||
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: info.PromptTokens,
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func generateRequestID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
@@ -263,6 +263,7 @@ var streamSupportedChannels = map[int]bool{
|
||||
constant.ChannelTypeDeepSeek: true,
|
||||
constant.ChannelTypeBaiduV2: true,
|
||||
constant.ChannelTypeZhipu_v4: true,
|
||||
constant.ChannelTypeAli: true,
|
||||
}
|
||||
|
||||
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
||||
@@ -512,6 +513,13 @@ type TaskInfo struct {
|
||||
TotalTokens int `json:"total_tokens,omitempty"` // 用于按倍率计费
|
||||
}
|
||||
|
||||
func FailTaskInfo(reason string) *TaskInfo {
|
||||
return &TaskInfo{
|
||||
Status: "FAILURE",
|
||||
Reason: reason,
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveDisabledFields 从请求 JSON 数据中移除渠道设置中禁用的字段
|
||||
// service_tier: 服务层级字段,可能导致额外计费(OpenAI、Claude、Responses API 支持)
|
||||
// store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
@@ -48,6 +49,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
logger.LogDebug(c, fmt.Sprintf("converted embedding request body: %s", string(jsonData)))
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
|
||||
@@ -240,6 +240,8 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
req.SetModelName("models/" + info.UpstreamModelName)
|
||||
|
||||
adaptor := GetAdaptor(info.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
@@ -264,6 +266,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
logger.LogDebug(c, "Gemini embedding request body: "+string(jsonData))
|
||||
requestBody = bytes.NewReader(jsonData)
|
||||
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
|
||||
@@ -22,8 +22,10 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt
|
||||
case types.RelayFormatOpenAI:
|
||||
request, err = GetAndValidateTextRequest(c, relayMode)
|
||||
case types.RelayFormatGemini:
|
||||
if strings.Contains(c.Request.URL.Path, ":embedContent") || strings.Contains(c.Request.URL.Path, ":batchEmbedContents") {
|
||||
if strings.Contains(c.Request.URL.Path, ":embedContent") {
|
||||
request, err = GetAndValidateGeminiEmbeddingRequest(c)
|
||||
} else if strings.Contains(c.Request.URL.Path, ":batchEmbedContents") {
|
||||
request, err = GetAndValidateGeminiBatchEmbeddingRequest(c)
|
||||
} else {
|
||||
request, err = GetAndValidateGeminiRequest(c)
|
||||
}
|
||||
@@ -300,7 +302,7 @@ func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(request.Contents) == 0 {
|
||||
if len(request.Contents) == 0 && len(request.Requests) == 0 {
|
||||
return nil, errors.New("contents is required")
|
||||
}
|
||||
|
||||
@@ -319,3 +321,12 @@ func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingR
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func GetAndValidateGeminiBatchEmbeddingRequest(c *gin.Context) (*dto.GeminiBatchEmbeddingRequest, error) {
|
||||
request := &dto.GeminiBatchEmbeddingRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
package hooks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/core/interfaces"
|
||||
"github.com/QuantumNous/new-api/core/registry"
|
||||
)
|
||||
|
||||
var (
|
||||
// 全局Hook链实例(单例)
|
||||
globalChain *HookChain
|
||||
globalChainOnce sync.Once
|
||||
)
|
||||
|
||||
// HookChain Hook执行链
|
||||
type HookChain struct {
|
||||
hooks []interfaces.RelayHook
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// GetGlobalChain 获取全局Hook链实例
|
||||
func GetGlobalChain() *HookChain {
|
||||
globalChainOnce.Do(func() {
|
||||
globalChain = &HookChain{
|
||||
hooks: make([]interfaces.RelayHook, 0),
|
||||
}
|
||||
// 从注册中心加载hooks
|
||||
globalChain.LoadHooks()
|
||||
})
|
||||
return globalChain
|
||||
}
|
||||
|
||||
// LoadHooks 从注册中心加载hooks
|
||||
func (c *HookChain) LoadHooks() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.hooks = registry.ListHooks()
|
||||
common.SysLog(fmt.Sprintf("Loaded %d enabled hooks", len(c.hooks)))
|
||||
}
|
||||
|
||||
// ReloadHooks 重新加载hooks
|
||||
func (c *HookChain) ReloadHooks() {
|
||||
c.LoadHooks()
|
||||
common.SysLog("Hooks reloaded")
|
||||
}
|
||||
|
||||
// ExecuteBeforeRequest 执行所有BeforeRequest钩子
|
||||
func (c *HookChain) ExecuteBeforeRequest(ctx *interfaces.HookContext) error {
|
||||
c.mu.RLock()
|
||||
hooks := c.hooks
|
||||
c.mu.RUnlock()
|
||||
|
||||
for _, hook := range hooks {
|
||||
if !hook.Enabled() {
|
||||
continue
|
||||
}
|
||||
|
||||
if ctx.ShouldSkip {
|
||||
break
|
||||
}
|
||||
|
||||
if err := hook.OnBeforeRequest(ctx); err != nil {
|
||||
common.SysError(fmt.Sprintf("Hook %s OnBeforeRequest error: %v", hook.Name(), err))
|
||||
return fmt.Errorf("hook %s failed: %w", hook.Name(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecuteAfterResponse 执行所有AfterResponse钩子
|
||||
func (c *HookChain) ExecuteAfterResponse(ctx *interfaces.HookContext) error {
|
||||
c.mu.RLock()
|
||||
hooks := c.hooks
|
||||
c.mu.RUnlock()
|
||||
|
||||
for _, hook := range hooks {
|
||||
if !hook.Enabled() {
|
||||
continue
|
||||
}
|
||||
|
||||
if ctx.ShouldSkip {
|
||||
break
|
||||
}
|
||||
|
||||
if err := hook.OnAfterResponse(ctx); err != nil {
|
||||
common.SysError(fmt.Sprintf("Hook %s OnAfterResponse error: %v", hook.Name(), err))
|
||||
return fmt.Errorf("hook %s failed: %w", hook.Name(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecuteOnError 执行所有OnError钩子
|
||||
func (c *HookChain) ExecuteOnError(ctx *interfaces.HookContext, err error) error {
|
||||
c.mu.RLock()
|
||||
hooks := c.hooks
|
||||
c.mu.RUnlock()
|
||||
|
||||
for _, hook := range hooks {
|
||||
if !hook.Enabled() {
|
||||
continue
|
||||
}
|
||||
|
||||
if hookErr := hook.OnError(ctx, err); hookErr != nil {
|
||||
common.SysError(fmt.Sprintf("Hook %s OnError failed: %v", hook.Name(), hookErr))
|
||||
// OnError钩子的错误不会中断执行
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// GetHooks 获取当前hook列表
|
||||
func (c *HookChain) GetHooks() []interfaces.RelayHook {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
hooks := make([]interfaces.RelayHook, len(c.hooks))
|
||||
copy(hooks, c.hooks)
|
||||
return hooks
|
||||
}
|
||||
|
||||
// Count 返回hook数量
|
||||
func (c *HookChain) Count() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
return len(c.hooks)
|
||||
}
|
||||
|
||||
@@ -1,212 +0,0 @@
|
||||
package hooks
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/core/interfaces"
|
||||
"github.com/QuantumNous/new-api/core/registry"
|
||||
)
|
||||
|
||||
// Mock Hook实现
|
||||
type testHook struct {
|
||||
name string
|
||||
priority int
|
||||
enabled bool
|
||||
beforeCalled bool
|
||||
afterCalled bool
|
||||
errorCalled bool
|
||||
shouldReturnError bool
|
||||
}
|
||||
|
||||
func (h *testHook) Name() string { return h.name }
|
||||
func (h *testHook) Priority() int { return h.priority }
|
||||
func (h *testHook) Enabled() bool { return h.enabled }
|
||||
|
||||
func (h *testHook) OnBeforeRequest(ctx *interfaces.HookContext) error {
|
||||
h.beforeCalled = true
|
||||
if h.shouldReturnError {
|
||||
return errors.New("test error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *testHook) OnAfterResponse(ctx *interfaces.HookContext) error {
|
||||
h.afterCalled = true
|
||||
if h.shouldReturnError {
|
||||
return errors.New("test error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *testHook) OnError(ctx *interfaces.HookContext, err error) error {
|
||||
h.errorCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestHookChainExecution(t *testing.T) {
|
||||
// 创建测试hooks
|
||||
hook1 := &testHook{name: "hook1", priority: 100, enabled: true}
|
||||
hook2 := &testHook{name: "hook2", priority: 50, enabled: true}
|
||||
hook3 := &testHook{name: "hook3", priority: 75, enabled: false} // disabled
|
||||
|
||||
// 创建Hook链
|
||||
chain := &HookChain{
|
||||
hooks: []interfaces.RelayHook{hook1, hook2, hook3},
|
||||
}
|
||||
|
||||
// 创建测试上下文
|
||||
ctx := &interfaces.HookContext{
|
||||
Data: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// 测试ExecuteBeforeRequest
|
||||
if err := chain.ExecuteBeforeRequest(ctx); err != nil {
|
||||
t.Errorf("ExecuteBeforeRequest failed: %v", err)
|
||||
}
|
||||
|
||||
// 检查enabled的hooks是否被调用
|
||||
if !hook1.beforeCalled {
|
||||
t.Error("hook1 OnBeforeRequest should be called")
|
||||
}
|
||||
|
||||
if !hook2.beforeCalled {
|
||||
t.Error("hook2 OnBeforeRequest should be called")
|
||||
}
|
||||
|
||||
// disabled的hook不应该被调用
|
||||
if hook3.beforeCalled {
|
||||
t.Error("hook3 OnBeforeRequest should not be called (disabled)")
|
||||
}
|
||||
|
||||
// 测试ExecuteAfterResponse
|
||||
if err := chain.ExecuteAfterResponse(ctx); err != nil {
|
||||
t.Errorf("ExecuteAfterResponse failed: %v", err)
|
||||
}
|
||||
|
||||
if !hook1.afterCalled {
|
||||
t.Error("hook1 OnAfterResponse should be called")
|
||||
}
|
||||
|
||||
if !hook2.afterCalled {
|
||||
t.Error("hook2 OnAfterResponse should be called")
|
||||
}
|
||||
|
||||
// 测试ExecuteOnError
|
||||
testErr := errors.New("test error")
|
||||
if err := chain.ExecuteOnError(ctx, testErr); err != testErr {
|
||||
t.Error("ExecuteOnError should return original error")
|
||||
}
|
||||
|
||||
if !hook1.errorCalled {
|
||||
t.Error("hook1 OnError should be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookChainErrorHandling(t *testing.T) {
|
||||
// 创建会返回错误的hook
|
||||
errorHook := &testHook{
|
||||
name: "error_hook",
|
||||
priority: 100,
|
||||
enabled: true,
|
||||
shouldReturnError: true,
|
||||
}
|
||||
|
||||
chain := &HookChain{
|
||||
hooks: []interfaces.RelayHook{errorHook},
|
||||
}
|
||||
|
||||
ctx := &interfaces.HookContext{
|
||||
Data: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// 测试错误处理
|
||||
if err := chain.ExecuteBeforeRequest(ctx); err == nil {
|
||||
t.Error("Expected error from ExecuteBeforeRequest")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookChainShouldSkip(t *testing.T) {
|
||||
hook1 := &testHook{name: "hook1", priority: 100, enabled: true}
|
||||
hook2 := &testHook{name: "hook2", priority: 50, enabled: true}
|
||||
|
||||
chain := &HookChain{
|
||||
hooks: []interfaces.RelayHook{hook1, hook2},
|
||||
}
|
||||
|
||||
ctx := &interfaces.HookContext{
|
||||
Data: make(map[string]interface{}),
|
||||
ShouldSkip: true, // 设置跳过标记
|
||||
}
|
||||
|
||||
// 执行
|
||||
if err := chain.ExecuteBeforeRequest(ctx); err != nil {
|
||||
t.Errorf("ExecuteBeforeRequest failed: %v", err)
|
||||
}
|
||||
|
||||
// 由于ShouldSkip为true,hooks不应该被调用
|
||||
// 注意:当前实现在第一个hook执行后才会检查ShouldSkip
|
||||
// 所以hook1会被调用,但hook2不会
|
||||
}
|
||||
|
||||
func TestHookChainCount(t *testing.T) {
|
||||
hook1 := &testHook{name: "hook1", priority: 100, enabled: true}
|
||||
hook2 := &testHook{name: "hook2", priority: 50, enabled: true}
|
||||
|
||||
chain := &HookChain{
|
||||
hooks: []interfaces.RelayHook{hook1, hook2},
|
||||
}
|
||||
|
||||
if count := chain.Count(); count != 2 {
|
||||
t.Errorf("Expected count 2, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookChainGetHooks(t *testing.T) {
|
||||
hook1 := &testHook{name: "hook1", priority: 100, enabled: true}
|
||||
hook2 := &testHook{name: "hook2", priority: 50, enabled: true}
|
||||
|
||||
chain := &HookChain{
|
||||
hooks: []interfaces.RelayHook{hook1, hook2},
|
||||
}
|
||||
|
||||
hooks := chain.GetHooks()
|
||||
if len(hooks) != 2 {
|
||||
t.Errorf("Expected 2 hooks, got %d", len(hooks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalChain(t *testing.T) {
|
||||
// 测试全局链的单例模式
|
||||
chain1 := GetGlobalChain()
|
||||
chain2 := GetGlobalChain()
|
||||
|
||||
if chain1 != chain2 {
|
||||
t.Error("GetGlobalChain should return the same instance")
|
||||
}
|
||||
}
|
||||
|
||||
// 集成测试:测试完整的注册和执行流程
|
||||
func TestIntegration(t *testing.T) {
|
||||
// 注册测试hook
|
||||
testHook := &testHook{
|
||||
name: "integration_test_hook",
|
||||
priority: 100,
|
||||
enabled: true,
|
||||
}
|
||||
|
||||
if err := registry.RegisterHook(testHook); err != nil {
|
||||
// 如果已注册,跳过错误
|
||||
t.Logf("Hook already registered (expected in some cases): %v", err)
|
||||
}
|
||||
|
||||
// 创建新的hook链并加载
|
||||
chain := &HookChain{hooks: make([]interfaces.RelayHook, 0)}
|
||||
chain.LoadHooks()
|
||||
|
||||
// 检查是否加载了hooks
|
||||
if chain.Count() == 0 {
|
||||
t.Log("No hooks loaded (expected if registry is clean)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
package hooks
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/core/interfaces"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// BuildHookContext 从Gin Context构建HookContext
|
||||
func BuildHookContext(c *gin.Context) *interfaces.HookContext {
|
||||
ctx := &interfaces.HookContext{
|
||||
GinContext: c,
|
||||
Request: c.Request,
|
||||
Data: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// 提取Channel信息
|
||||
if channelID, ok := common.GetContextKey(c, constant.ContextKeyChannelId); ok {
|
||||
if id, ok := channelID.(int); ok {
|
||||
ctx.ChannelID = id
|
||||
}
|
||||
}
|
||||
|
||||
if channelType, ok := common.GetContextKey(c, constant.ContextKeyChannelType); ok {
|
||||
if t, ok := channelType.(int); ok {
|
||||
ctx.ChannelType = t
|
||||
}
|
||||
}
|
||||
|
||||
if channelName, ok := common.GetContextKey(c, constant.ContextKeyChannelName); ok {
|
||||
if name, ok := channelName.(string); ok {
|
||||
ctx.ChannelName = name
|
||||
}
|
||||
}
|
||||
|
||||
// 提取Model信息
|
||||
if originalModel, ok := common.GetContextKey(c, constant.ContextKeyOriginalModel); ok {
|
||||
if m, ok := originalModel.(string); ok {
|
||||
ctx.OriginalModel = m
|
||||
ctx.Model = m // 使用OriginalModel作为Model
|
||||
}
|
||||
}
|
||||
|
||||
// 提取User信息
|
||||
if userID, ok := common.GetContextKey(c, constant.ContextKeyUserId); ok {
|
||||
if id, ok := userID.(int); ok {
|
||||
ctx.UserID = id
|
||||
}
|
||||
}
|
||||
|
||||
if tokenID, ok := common.GetContextKey(c, constant.ContextKeyTokenId); ok {
|
||||
if id, ok := tokenID.(int); ok {
|
||||
ctx.TokenID = id
|
||||
}
|
||||
}
|
||||
|
||||
if group, ok := common.GetContextKey(c, constant.ContextKeyUsingGroup); ok {
|
||||
if g, ok := group.(string); ok {
|
||||
ctx.Group = g
|
||||
}
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
// UpdateHookContextWithResponse 更新HookContext的Response信息
|
||||
func UpdateHookContextWithResponse(ctx *interfaces.HookContext, resp *http.Response, body []byte) {
|
||||
ctx.Response = resp
|
||||
ctx.ResponseBody = body
|
||||
}
|
||||
|
||||
// UpdateHookContextWithRequest 更新HookContext的Request信息
|
||||
func UpdateHookContextWithRequest(ctx *interfaces.HookContext, body []byte) {
|
||||
ctx.RequestBody = body
|
||||
}
|
||||
|
||||
@@ -218,7 +218,7 @@ func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyR
|
||||
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
|
||||
other := service.GenerateMjOtherInfo(priceData)
|
||||
other := service.GenerateMjOtherInfo(info, priceData)
|
||||
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: info.ChannelId,
|
||||
ModelName: modelName,
|
||||
@@ -518,7 +518,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dt
|
||||
}
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
|
||||
other := service.GenerateMjOtherInfo(priceData)
|
||||
other := service.GenerateMjOtherInfo(relayInfo, priceData)
|
||||
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
ModelName: modelName,
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/core/registry"
|
||||
pluginchannels "github.com/QuantumNous/new-api/plugins/channels"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/ali"
|
||||
"github.com/QuantumNous/new-api/relay/channel/aws"
|
||||
@@ -20,6 +18,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/relay/channel/gemini"
|
||||
"github.com/QuantumNous/new-api/relay/channel/jimeng"
|
||||
"github.com/QuantumNous/new-api/relay/channel/jina"
|
||||
"github.com/QuantumNous/new-api/relay/channel/minimax"
|
||||
"github.com/QuantumNous/new-api/relay/channel/mistral"
|
||||
"github.com/QuantumNous/new-api/relay/channel/mokaai"
|
||||
"github.com/QuantumNous/new-api/relay/channel/moonshot"
|
||||
@@ -46,19 +45,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetAdaptor 获取Channel适配器(优先从插件注册中心获取,保持向后兼容)
|
||||
func GetAdaptor(apiType int) channel.Adaptor {
|
||||
// 优先从插件注册中心获取
|
||||
if plugin, err := registry.GetChannel(apiType); err == nil {
|
||||
// 如果是BaseChannelPlugin,提取内部的Adaptor
|
||||
if basePlugin, ok := plugin.(*pluginchannels.BaseChannelPlugin); ok {
|
||||
return basePlugin.GetAdaptor()
|
||||
}
|
||||
// 否则直接返回plugin(它也实现了Adaptor接口)
|
||||
return plugin
|
||||
}
|
||||
|
||||
// 向后兼容:如果注册中心没有,使用原有的硬编码方式
|
||||
switch apiType {
|
||||
case constant.APITypeAli:
|
||||
return &ali.Adaptor{}
|
||||
@@ -122,6 +109,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
||||
return &moonshot.Adaptor{} // Moonshot uses Claude API
|
||||
case constant.APITypeSubmodel:
|
||||
return &submodel.Adaptor{}
|
||||
case constant.APITypeMiniMax:
|
||||
return &minimax.Adaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -153,7 +142,7 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
|
||||
return &taskVidu.TaskAdaptor{}
|
||||
case constant.ChannelTypeDoubaoVideo:
|
||||
return &taskdoubao.TaskAdaptor{}
|
||||
case constant.ChannelTypeSora:
|
||||
case constant.ChannelTypeSora, constant.ChannelTypeOpenAI:
|
||||
return &tasksora.TaskAdaptor{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,10 +72,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
} else {
|
||||
ratio = modelPrice * groupRatio
|
||||
}
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
for _, ra := range info.PriceData.OtherRatios {
|
||||
if 1.0 != ra {
|
||||
ratio *= ra
|
||||
// FIXME: 临时修补,支持任务仅按次计费
|
||||
if !common.StringsContains(constant.TaskPricePatches, modelName) {
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
for _, ra := range info.PriceData.OtherRatios {
|
||||
if 1.0 != ra {
|
||||
ratio *= ra
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -153,18 +156,26 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
// gRatio = userGroupRatio
|
||||
//}
|
||||
logContent := fmt.Sprintf("操作 %s", info.Action)
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
var contents []string
|
||||
for key, ra := range info.PriceData.OtherRatios {
|
||||
if 1.0 != ra {
|
||||
contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
|
||||
// FIXME: 临时修补,支持任务仅按次计费
|
||||
if common.StringsContains(constant.TaskPricePatches, modelName) {
|
||||
logContent = fmt.Sprintf("%s,按次计费", logContent)
|
||||
} else {
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
var contents []string
|
||||
for key, ra := range info.PriceData.OtherRatios {
|
||||
if 1.0 != ra {
|
||||
contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
|
||||
}
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
|
||||
}
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
|
||||
}
|
||||
}
|
||||
other := make(map[string]interface{})
|
||||
if c != nil && c.Request != nil && c.Request.URL != nil {
|
||||
other["request_path"] = c.Request.URL.Path
|
||||
}
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
if hasUserGroupRatio {
|
||||
@@ -394,12 +405,12 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
||||
return
|
||||
}
|
||||
if converter, ok := adaptor.(channel.OpenAIVideoConverter); ok {
|
||||
openAIVideo, err := converter.ConvertToOpenAIVideo(originTask)
|
||||
openAIVideoData, err := converter.ConvertToOpenAIVideo(originTask)
|
||||
if err != nil {
|
||||
taskResp = service.TaskErrorWrapper(err, "convert_to_openai_video_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
respBody, _ = json.Marshal(openAIVideo)
|
||||
respBody = openAIVideoData
|
||||
return
|
||||
}
|
||||
taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented)
|
||||
|
||||
@@ -352,7 +352,7 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Thinking: common.GetPointer[string](""),
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -360,7 +360,7 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
// text delta
|
||||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||||
Type: "thinking_delta",
|
||||
Thinking: reasoning,
|
||||
Thinking: &reasoning,
|
||||
}
|
||||
} else {
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
@@ -10,6 +12,25 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func appendRequestPath(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, other map[string]interface{}) {
|
||||
if other == nil {
|
||||
return
|
||||
}
|
||||
if ctx != nil && ctx.Request != nil && ctx.Request.URL != nil {
|
||||
if path := ctx.Request.URL.Path; path != "" {
|
||||
other["request_path"] = path
|
||||
return
|
||||
}
|
||||
}
|
||||
if relayInfo != nil && relayInfo.RequestURLPath != "" {
|
||||
path := relayInfo.RequestURLPath
|
||||
if idx := strings.Index(path, "?"); idx != -1 {
|
||||
path = path[:idx]
|
||||
}
|
||||
other["request_path"] = path
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64,
|
||||
cacheTokens int, cacheRatio float64, modelPrice float64, userGroupRatio float64) map[string]interface{} {
|
||||
other := make(map[string]interface{})
|
||||
@@ -42,6 +63,7 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
|
||||
adminInfo["multi_key_index"] = common.GetContextKeyInt(ctx, constant.ContextKeyChannelMultiKeyIndex)
|
||||
}
|
||||
other["admin_info"] = adminInfo
|
||||
appendRequestPath(ctx, relayInfo, other)
|
||||
return other
|
||||
}
|
||||
|
||||
@@ -78,12 +100,13 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
return info
|
||||
}
|
||||
|
||||
func GenerateMjOtherInfo(priceData types.PerCallPriceData) map[string]interface{} {
|
||||
func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PerCallPriceData) map[string]interface{} {
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = priceData.ModelPrice
|
||||
other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio
|
||||
if priceData.GroupRatioInfo.HasSpecialRatio {
|
||||
other["user_group_ratio"] = priceData.GroupRatioInfo.GroupSpecialRatio
|
||||
}
|
||||
appendRequestPath(nil, relayInfo, other)
|
||||
return other
|
||||
}
|
||||
|
||||
@@ -62,6 +62,9 @@ const (
|
||||
ErrorCodeConvertRequestFailed ErrorCode = "convert_request_failed"
|
||||
ErrorCodeAccessDenied ErrorCode = "access_denied"
|
||||
|
||||
// request error
|
||||
ErrorCodeBadRequestBody ErrorCode = "bad_request_body"
|
||||
|
||||
// response error
|
||||
ErrorCodeReadResponseBodyFailed ErrorCode = "read_response_body_failed"
|
||||
ErrorCodeBadResponseStatusCode ErrorCode = "bad_response_status_code"
|
||||
|
||||
@@ -10,7 +10,8 @@
|
||||
content="OpenAI 接口聚合管理,支持多种渠道包括 Azure,可用于二次分发管理 key,仅单可执行文件,已打包好 Docker 镜像,一键部署,开箱即用"
|
||||
/>
|
||||
<title>New API</title>
|
||||
<analytics></analytics>
|
||||
<!--umami-->
|
||||
<!--Google Analytics-->
|
||||
</head>
|
||||
|
||||
<body>
|
||||
|
||||
@@ -107,10 +107,12 @@ function type2secretPrompt(type) {
|
||||
return '按照如下格式输入:AppId|SecretId|SecretKey';
|
||||
case 33:
|
||||
return '按照如下格式输入:Ak|Sk|Region';
|
||||
case 45:
|
||||
return '请输入渠道对应的鉴权密钥, 豆包语音输入:AppId|AccessToken';
|
||||
case 50:
|
||||
return '按照如下格式输入: AccessKey|SecretKey, 如果上游是New API,则直接输ApiKey';
|
||||
case 51:
|
||||
return '按照如下格式输入: Access Key ID|Secret Access Key';
|
||||
return '按照如下格式输入: AccessKey|SecretAccessKey';
|
||||
default:
|
||||
return '请输入渠道对应的鉴权密钥';
|
||||
}
|
||||
@@ -153,6 +155,8 @@ const EditChannelModal = (props) => {
|
||||
settings: '',
|
||||
// 仅 Vertex: 密钥格式(存入 settings.vertex_key_type)
|
||||
vertex_key_type: 'json',
|
||||
// 仅 AWS: 密钥格式和区域(存入 settings.aws_key_type 和 settings.aws_region)
|
||||
aws_key_type: 'ak_sk',
|
||||
// 企业账户设置
|
||||
is_enterprise_account: false,
|
||||
// 字段透传控制默认值
|
||||
@@ -515,6 +519,8 @@ const EditChannelModal = (props) => {
|
||||
parsedSettings.azure_responses_version || '';
|
||||
// 读取 Vertex 密钥格式
|
||||
data.vertex_key_type = parsedSettings.vertex_key_type || 'json';
|
||||
// 读取 AWS 密钥格式和区域
|
||||
data.aws_key_type = parsedSettings.aws_key_type || 'ak_sk';
|
||||
// 读取企业账户设置
|
||||
data.is_enterprise_account =
|
||||
parsedSettings.openrouter_enterprise === true;
|
||||
@@ -528,6 +534,7 @@ const EditChannelModal = (props) => {
|
||||
data.azure_responses_version = '';
|
||||
data.region = '';
|
||||
data.vertex_key_type = 'json';
|
||||
data.aws_key_type = 'ak_sk';
|
||||
data.is_enterprise_account = false;
|
||||
data.allow_service_tier = false;
|
||||
data.disable_store = false;
|
||||
@@ -536,6 +543,7 @@ const EditChannelModal = (props) => {
|
||||
} else {
|
||||
// 兼容历史数据:老渠道没有 settings 时,默认按 json 展示
|
||||
data.vertex_key_type = 'json';
|
||||
data.aws_key_type = 'ak_sk';
|
||||
data.is_enterprise_account = false;
|
||||
data.allow_service_tier = false;
|
||||
data.disable_store = false;
|
||||
@@ -997,6 +1005,11 @@ const EditChannelModal = (props) => {
|
||||
localInputs.is_enterprise_account === true;
|
||||
}
|
||||
|
||||
// type === 33 (AWS): 保存 aws_key_type 到 settings
|
||||
if (localInputs.type === 33) {
|
||||
settings.aws_key_type = localInputs.aws_key_type || 'ak_sk';
|
||||
}
|
||||
|
||||
// type === 1 (OpenAI) 或 type === 14 (Claude): 设置字段透传控制(显式保存布尔值)
|
||||
if (localInputs.type === 1 || localInputs.type === 14) {
|
||||
settings.allow_service_tier = localInputs.allow_service_tier === true;
|
||||
@@ -1020,6 +1033,8 @@ const EditChannelModal = (props) => {
|
||||
delete localInputs.is_enterprise_account;
|
||||
// 顶层的 vertex_key_type 不应发送给后端
|
||||
delete localInputs.vertex_key_type;
|
||||
// 顶层的 aws_key_type 不应发送给后端
|
||||
delete localInputs.aws_key_type;
|
||||
// 清理字段透传控制的临时字段
|
||||
delete localInputs.allow_service_tier;
|
||||
delete localInputs.disable_store;
|
||||
@@ -1468,6 +1483,31 @@ const EditChannelModal = (props) => {
|
||||
autoComplete='new-password'
|
||||
/>
|
||||
|
||||
{inputs.type === 33 && (
|
||||
<>
|
||||
<Form.Select
|
||||
field='aws_key_type'
|
||||
label={t('密钥格式')}
|
||||
placeholder={t('请选择密钥格式')}
|
||||
optionList={[
|
||||
{
|
||||
label: 'AccessKey / SecretAccessKey',
|
||||
value: 'ak_sk',
|
||||
},
|
||||
{ label: 'API Key', value: 'api_key' },
|
||||
]}
|
||||
style={{ width: '100%' }}
|
||||
value={inputs.aws_key_type || 'ak_sk'}
|
||||
onChange={(value) => {
|
||||
handleChannelOtherSettingsChange('aws_key_type', value);
|
||||
}}
|
||||
extraText={t(
|
||||
'AK/SK 模式:使用 AccessKey 和 SecretAccessKey;API Key 模式:使用 API Key',
|
||||
)}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
{inputs.type === 41 && (
|
||||
<Form.Select
|
||||
field='vertex_key_type'
|
||||
@@ -1536,7 +1576,15 @@ const EditChannelModal = (props) => {
|
||||
<Form.TextArea
|
||||
field='key'
|
||||
label={t('密钥')}
|
||||
placeholder={t('请输入密钥,一行一个')}
|
||||
placeholder={
|
||||
inputs.type === 33
|
||||
? inputs.aws_key_type === 'api_key'
|
||||
? t('请输入 API Key,一行一个,格式:APIKey|Region')
|
||||
: t(
|
||||
'请输入密钥,一行一个,格式:AccessKey|SecretAccessKey|Region',
|
||||
)
|
||||
: t('请输入密钥,一行一个')
|
||||
}
|
||||
rules={
|
||||
isEdit
|
||||
? []
|
||||
@@ -1730,7 +1778,13 @@ const EditChannelModal = (props) => {
|
||||
? t('密钥(编辑模式下,保存的密钥不会显示)')
|
||||
: t('密钥')
|
||||
}
|
||||
placeholder={t(type2secretPrompt(inputs.type))}
|
||||
placeholder={
|
||||
inputs.type === 33
|
||||
? inputs.aws_key_type === 'api_key'
|
||||
? t('请输入 API Key,格式:APIKey|Region')
|
||||
: t('按照如下格式输入:AccessKey|SecretAccessKey|Region')
|
||||
: t(type2secretPrompt(inputs.type))
|
||||
}
|
||||
rules={
|
||||
isEdit
|
||||
? []
|
||||
|
||||
@@ -468,6 +468,12 @@ export const useLogsData = () => {
|
||||
});
|
||||
}
|
||||
}
|
||||
if (other?.request_path) {
|
||||
expandDataLocal.push({
|
||||
key: t('请求路径'),
|
||||
value: other.request_path,
|
||||
});
|
||||
}
|
||||
expandDatesLocal[logs[i].key] = expandDataLocal;
|
||||
}
|
||||
|
||||
|
||||
@@ -1675,6 +1675,7 @@
|
||||
"请求失败": "Request failed",
|
||||
"请求头覆盖": "Request header override",
|
||||
"请求并计费模型": "Request and charge model",
|
||||
"请求路径": "Request path",
|
||||
"请求时长: ${time}s": "Request time: ${time}s",
|
||||
"请求次数": "Number of Requests",
|
||||
"请求结束后多退少补": "Adjust after request completion",
|
||||
|
||||
@@ -1684,6 +1684,7 @@
|
||||
"请求失败": "Échec de la demande",
|
||||
"请求头覆盖": "Remplacement des en-têtes de demande",
|
||||
"请求并计费模型": "Modèle de demande et de facturation",
|
||||
"请求路径": "Chemin de requête",
|
||||
"请求时长: ${time}s": "Durée de la requête : ${time}s",
|
||||
"请求次数": "Nombre de demandes",
|
||||
"请求结束后多退少补": "Ajuster après la fin de la demande",
|
||||
@@ -2081,4 +2082,4 @@
|
||||
"默认测试模型": "Modèle de test par défaut",
|
||||
"默认补全倍率": "Taux de complétion par défaut"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1693,6 +1693,7 @@
|
||||
"请求失败": "Запрос не удался",
|
||||
"请求头覆盖": "Переопределение заголовков запроса",
|
||||
"请求并计费模型": "Запрос и выставление счёта модели",
|
||||
"请求路径": "Путь запроса",
|
||||
"请求时长: ${time}s": "Время запроса: ${time}s",
|
||||
"请求次数": "Количество запросов",
|
||||
"请求结束后多退少补": "После вывода запроса возврат излишков и доплата недостатка",
|
||||
|
||||
@@ -1666,6 +1666,7 @@
|
||||
"请求失败": "请求失败",
|
||||
"请求头覆盖": "请求头覆盖",
|
||||
"请求并计费模型": "请求并计费模型",
|
||||
"请求路径": "请求路径",
|
||||
"请求时长: ${time}s": "请求时长: ${time}s",
|
||||
"请求次数": "请求次数",
|
||||
"请求结束后多退少补": "请求结束后多退少补",
|
||||
|
||||
Reference in New Issue
Block a user