Compare commits

..

1 Commits

Author SHA1 Message Date
Apple\Apple
39c841e600 feat(architecture): Core+Plugin 2025-10-13 02:02:11 +08:00
157 changed files with 3793 additions and 9230 deletions

View File

@@ -33,8 +33,7 @@ jobs:
- name: Resolve tag & write VERSION
run: |
git fetch --tags --force --depth=1
TAG=${GITHUB_REF#refs/tags/}
echo "TAG=$TAG" >> $GITHUB_ENV
echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
echo "$TAG" > VERSION
echo "Building tag: $TAG for ${{ matrix.arch }}"

View File

@@ -4,8 +4,6 @@ on:
push:
tags:
- '*' # Triggers on version tags like v1.0.0
- '!*-*' # Ignore pre-release tags like v1.0.0-beta
- '!*-alpha*' # Ignore alpha tags like v1.0.0-alpha
workflow_dispatch: # Allows manual triggering
jobs:
@@ -132,10 +130,13 @@ jobs:
- name: Download all artifacts
uses: actions/download-artifact@v4
- name: Upload to Release
- name: Create Release
uses: softprops/action-gh-release@v2
with:
files: |
windows-build/*
draft: false
prerelease: false
overwrite_files: true
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -54,6 +54,8 @@ jobs:
with:
files: |
new-api-*
draft: true
generate_release_notes: true
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@@ -91,6 +93,8 @@ jobs:
if: startsWith(github.ref, 'refs/tags/')
with:
files: new-api-macos-*
draft: true
generate_release_notes: true
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@@ -130,6 +134,8 @@ jobs:
if: startsWith(github.ref, 'refs/tags/')
with:
files: new-api-*.exe
draft: true
generate_release_notes: true
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

4
.gitignore vendored
View File

@@ -1,6 +1,5 @@
.idea
.vscode
.zed
upload
*.exe
*.db
@@ -11,11 +10,10 @@ web/dist
.env
one-api
new-api
/__debug_bin*
.DS_Store
tiktoken_cache
.eslintcache
.gocache
electron/node_modules
electron/dist
electron/dist

View File

@@ -28,7 +28,7 @@ RUN go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$
FROM alpine
RUN apk upgrade --no-cache \
&& apk add --no-cache ca-certificates tzdata \
&& apk add --no-cache ca-certificates tzdata ffmpeg \
&& update-ca-certificates
COPY --from=builder2 /build/new-api /

View File

@@ -1,260 +0,0 @@
# 日志系统说明
本项目使用 Go 标准库的 `log/slog` 实现结构化日志记录。
## 📋 功能特性
### 1. 标准的文件存储结构
- **当前日志文件**: `oneapi.log` - 实时写入的日志文件
- **归档日志文件**: `oneapi.2024-01-02-153045.log` - 自动轮转后的历史日志
### 2. 自动日志轮转
日志文件会在以下情况自动轮转:
- **按大小轮转**: 当日志文件超过指定大小时(默认 100MB
- **启动时日期检查**: 程序启动时如果检测到日志文件是旧日期创建的,会自动轮转
- **自动清理**: 只保留最近 N 个日志文件(默认 7 个)
### 3. 结构化日志
所有日志都包含以下结构化字段:
```
time=2024-01-02T15:30:45 level=INFO msg="user logged in" request_id=abc123 user_id=1001
```
### 4. 多种输出格式
- **Text 格式** (默认): 人类可读的文本格式
- **JSON 格式**: 便于日志分析工具解析
### 5. 灵活的日志级别
支持四个日志级别:
- `DEBUG`: 调试信息
- `INFO`: 一般信息
- `WARN`: 警告信息
- `ERROR`: 错误信息
## ⚙️ 配置方式
### 环境变量配置
```bash
# 日志目录(必需,否则只输出到控制台)
--log-dir=./logs
# 日志级别(可选,默认: INFODEBUG 模式除外)
export LOG_LEVEL=DEBUG # 可选值: DEBUG, INFO, WARN, ERROR
# 日志格式(可选,默认: text
export LOG_FORMAT=json # 可选值: text, json
# 单个日志文件最大大小(可选,默认: 100单位: MB
export LOG_MAX_SIZE_MB=200
# 保留的日志文件数量(可选,默认: 7
export LOG_MAX_FILES=14
# 启用调试模式(会自动将日志级别设为 DEBUG
export DEBUG=true
```
### 命令行参数
```bash
# 启动时指定日志目录
./new-api --log-dir=./logs
# 如果不指定日志目录,日志只输出到控制台
./new-api
```
## 📝 使用示例
### 基础使用
```go
import (
"context"
"github.com/QuantumNous/new-api/logger"
)
// 记录信息日志
logger.LogInfo(ctx, "user registered successfully")
// 记录警告日志
logger.LogWarn(ctx, "API rate limit approaching")
// 记录错误日志
logger.LogError(ctx, "failed to connect to database")
// 记录调试日志(只在 DEBUG 模式下输出)
logger.LogDebug(ctx, "processing request with params: %v", params)
// 记录系统日志(无 context
logger.LogSystemInfo("application started")
logger.LogSystemError("critical system error")
```
### 日志输出示例
**Text 格式** (易读格式):
```
[INFO] 2024/01/02 - 15:30:45 | SYSTEM | application started
[INFO] 2024/01/02 - 15:30:46 | abc123 | user registered successfully
[WARN] 2024/01/02 - 15:30:47 | def456 | API rate limit approaching | remaining=10, limit=100
[ERROR] 2024/01/02 - 15:30:48 | ghi789 | failed to connect to database | error="connection timeout"
```
格式说明:`[级别] 时间 | 请求ID/组件 | 消息 | 额外属性(如有)`
**JSON 格式**:
```json
{"time":"2024-01-02 15:30:45","level":"INFO","msg":"application started","request_id":"SYSTEM"}
{"time":"2024-01-02 15:30:46","level":"INFO","msg":"user registered successfully","request_id":"abc123"}
{"time":"2024-01-02 15:30:47","level":"WARN","msg":"API rate limit approaching","request_id":"def456"}
```
## 📂 日志文件结构
```
logs/
├── oneapi.log # 当前活动日志文件
├── oneapi.2024-01-01-090000.log # 昨天的日志
├── oneapi.2024-01-01-150000.log # 昨天下午的日志(如果超过大小限制)
├── oneapi.2023-12-31-090000.log # 更早的日志
└── ... # 最多保留配置数量的历史文件
```
## 🔄 日志轮转机制
### 轮转触发条件
1. **文件大小检查**: 每写入 1000 条日志后检查一次文件大小
2. **启动时日期检查**: 程序启动时检查日志文件的修改日期,如果不是今天则轮转
3. **自动清理**: 轮转时自动删除超过保留数量的旧日志文件
> **注意**: 日志不会在运行时动态检查日期变化。如果需要每天自动轮转日志,建议:
> - 使用定时任务(如 cron每天重启服务
> - 或者配置较小的日志文件大小,让它自动按大小轮转
### 轮转流程
1. 检测到需要轮转时,关闭当前日志文件
2.`oneapi.log` 重命名为 `oneapi.YYYY-MM-DD-HHmmss.log`
3. 创建新的 `oneapi.log` 文件
4. 异步清理超过数量限制的旧日志文件
5. 记录轮转事件到新日志文件
## 🎯 最佳实践
### 1. 生产环境配置
```bash
# 使用 INFO 级别,避免过多调试信息
export LOG_LEVEL=INFO
# 使用 JSON 格式,便于日志分析工具处理
export LOG_FORMAT=json
# 设置合适的文件大小和保留数量
export LOG_MAX_SIZE_MB=500
export LOG_MAX_FILES=30
# 指定日志目录
./new-api --log-dir=/var/log/oneapi
```
### 2. 开发环境配置
```bash
# 使用 DEBUG 级别查看详细信息
export DEBUG=true
# 使用 Text 格式,便于阅读
export LOG_FORMAT=text
# 较小的文件大小和保留数量
export LOG_MAX_SIZE_MB=50
export LOG_MAX_FILES=7
./new-api --log-dir=./logs
```
### 3. 容器环境配置
```bash
# 只输出到标准输出,由容器运行时管理日志
./new-api
# 或者使用 JSON 格式便于日志收集系统处理
export LOG_FORMAT=json
./new-api
```
## 🔍 日志分析
### 使用 grep 分析文本日志
```bash
# 查找错误日志
grep '\[ERROR\]' logs/oneapi.log
# 查找特定请求的所有日志
grep 'abc123' logs/*.log
# 查看最近的警告和错误
tail -f logs/oneapi.log | grep -E '\[(WARN|ERROR)\]'
# 查找包含特定关键词的日志
grep 'database' logs/oneapi.log
# 查看今天的所有错误
grep "\[ERROR\] $(date +%Y/%m/%d)" logs/oneapi.log
```
### 使用 jq 分析 JSON 日志
```bash
# 提取所有错误日志
cat logs/oneapi.log | jq 'select(.level=="ERROR")'
# 统计各级别日志数量
cat logs/oneapi.log | jq -r '.level' | sort | uniq -c
# 查找特定时间范围的日志
cat logs/oneapi.log | jq 'select(.time >= "2024-01-02 15:00:00" and .time <= "2024-01-02 16:00:00")'
```
## 📊 性能优化
1. **异步日志轮转**: 轮转操作在后台 goroutine 中执行,不阻塞主程序
2. **批量写入检查**: 每 1000 次写入才检查一次轮转条件,减少 I/O 开销
3. **读写锁**: 使用 `sync.RWMutex` 保护日志器,提高并发性能
4. **零分配**: `slog` 库在大多数情况下实现零内存分配
## 🚨 故障排查
### 日志文件未创建
- 检查日志目录是否存在且有写入权限
- 确认启动时指定了 `--log-dir` 参数
### 日志文件过多
- 调整 `LOG_MAX_FILES` 环境变量
- 手动清理不需要的旧日志文件
### 日志级别不正确
- 检查 `LOG_LEVEL` 环境变量是否正确设置
- 确认 `DEBUG` 环境变量的值(会覆盖 LOG_LEVEL
## 📖 相关文档
- [Go slog 官方文档](https://pkg.go.dev/log/slog)
- [结构化日志最佳实践](https://go.dev/blog/slog)

View File

@@ -141,7 +141,6 @@ New API提供了丰富的功能详细特性请参考[特性说明](https://do
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:邮件等通知限制持续时间,默认 `10`分钟
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2`
- `ERROR_LOG_ENABLED=true`: 是否记录并显示错误日志,默认`false`
- `TASK_PRICE_PATCH=sora-2-all,sora-2-pro-all`: 异步任务设置某些模型按次计费,多个模型用逗号分隔,例如`sora-2-all,sora-2-pro-all`表示sora-2-all和sora-2-pro-all模型异步任务仅按次计费不按秒等计费。
## 部署
@@ -166,18 +165,12 @@ New API提供了丰富的功能详细特性请参考[特性说明](https://do
#### 使用Docker Compose部署推荐
```shell
# 下载项目源码
git clone https://github.com/QuantumNous/new-api.git
# 进入项目目录
# 下载项目
git clone https://github.com/Calcium-Ion/new-api.git
cd new-api
# 根据需要编辑 docker-compose.yml 文件
# 使用nano编辑器
nano docker-compose.yml
# 或使用vim编辑器
# vim docker-compose.yml
# 按需编辑docker-compose.yml
# 启动
docker-compose up -d
```
#### 直接使用Docker镜像

View File

@@ -69,8 +69,6 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = constant.APITypeMoonshot
case constant.ChannelTypeSubmodel:
apiType = constant.APITypeSubmodel
case constant.ChannelTypeMiniMax:
apiType = constant.APITypeMiniMax
}
if apiType == -1 {
return constant.APITypeOpenAI, false

View File

@@ -1,296 +0,0 @@
package common
import (
"context"
"encoding/binary"
"fmt"
"io"
"github.com/abema/go-mp4"
"github.com/go-audio/aiff"
"github.com/go-audio/wav"
"github.com/jfreymuth/oggvorbis"
"github.com/mewkiz/flac"
"github.com/pkg/errors"
"github.com/tcolgate/mp3"
"github.com/yapingcat/gomedia/go-codec"
)
// GetAudioDuration 使用纯 Go 库获取音频文件的时长(秒)。
// 它不再依赖外部的 ffmpeg 或 ffprobe 程序。
func GetAudioDuration(ctx context.Context, f io.ReadSeeker, ext string) (duration float64, err error) {
SysLog(fmt.Sprintf("GetAudioDuration: ext=%s", ext))
// 根据文件扩展名选择解析器
switch ext {
case ".mp3":
duration, err = getMP3Duration(f)
case ".wav":
duration, err = getWAVDuration(f)
case ".flac":
duration, err = getFLACDuration(f)
case ".m4a", ".mp4":
duration, err = getM4ADuration(f)
case ".ogg", ".oga", ".opus":
duration, err = getOGGDuration(f)
if err != nil {
duration, err = getOpusDuration(f)
}
case ".aiff", ".aif", ".aifc":
duration, err = getAIFFDuration(f)
case ".webm":
duration, err = getWebMDuration(f)
case ".aac":
duration, err = getAACDuration(f)
default:
return 0, fmt.Errorf("unsupported audio format: %s", ext)
}
SysLog(fmt.Sprintf("GetAudioDuration: duration=%f", duration))
return duration, err
}
// getMP3Duration 解析 MP3 文件以获取时长。
// 注意:对于 VBR (Variable Bitrate) MP3这个估算可能不完全精确但通常足够好。
// FFmpeg 在这种情况下会扫描整个文件来获得精确值,但这里的库提供了快速估算。
func getMP3Duration(r io.Reader) (float64, error) {
d := mp3.NewDecoder(r)
var f mp3.Frame
skipped := 0
duration := 0.0
for {
if err := d.Decode(&f, &skipped); err != nil {
if err == io.EOF {
break
}
return 0, errors.Wrap(err, "failed to decode mp3 frame")
}
duration += f.Duration().Seconds()
}
return duration, nil
}
// getWAVDuration 解析 WAV 文件头以获取时长。
func getWAVDuration(r io.ReadSeeker) (float64, error) {
dec := wav.NewDecoder(r)
if !dec.IsValidFile() {
return 0, errors.New("invalid wav file")
}
d, err := dec.Duration()
if err != nil {
return 0, errors.Wrap(err, "failed to get wav duration")
}
return d.Seconds(), nil
}
// getFLACDuration 解析 FLAC 文件的 STREAMINFO 块。
func getFLACDuration(r io.Reader) (float64, error) {
stream, err := flac.Parse(r)
if err != nil {
return 0, errors.Wrap(err, "failed to parse flac stream")
}
defer stream.Close()
// 时长 = 总采样数 / 采样率
duration := float64(stream.Info.NSamples) / float64(stream.Info.SampleRate)
return duration, nil
}
// getM4ADuration 解析 M4A/MP4 文件的 'mvhd' box。
func getM4ADuration(r io.ReadSeeker) (float64, error) {
// go-mp4 库需要 ReadSeeker 接口
info, err := mp4.Probe(r)
if err != nil {
return 0, errors.Wrap(err, "failed to probe m4a/mp4 file")
}
// 时长 = Duration / Timescale
return float64(info.Duration) / float64(info.Timescale), nil
}
// getOGGDuration 解析 OGG/Vorbis 文件以获取时长。
func getOGGDuration(r io.ReadSeeker) (float64, error) {
// 重置 reader 到开头
if _, err := r.Seek(0, io.SeekStart); err != nil {
return 0, errors.Wrap(err, "failed to seek ogg file")
}
reader, err := oggvorbis.NewReader(r)
if err != nil {
return 0, errors.Wrap(err, "failed to create ogg vorbis reader")
}
// 计算时长 = 总采样数 / 采样率
// 需要读取整个文件来获取总采样数
channels := reader.Channels()
sampleRate := reader.SampleRate()
// 估算方法:读取到文件结尾
var totalSamples int64
buf := make([]float32, 4096*channels)
for {
n, err := reader.Read(buf)
if err == io.EOF {
break
}
if err != nil {
return 0, errors.Wrap(err, "failed to read ogg samples")
}
totalSamples += int64(n / channels)
}
duration := float64(totalSamples) / float64(sampleRate)
return duration, nil
}
// getOpusDuration 解析 Opus 文件(在 OGG 容器中)以获取时长。
func getOpusDuration(r io.ReadSeeker) (float64, error) {
// Opus 通常封装在 OGG 容器中
// 我们需要解析 OGG 页面来获取时长信息
if _, err := r.Seek(0, io.SeekStart); err != nil {
return 0, errors.Wrap(err, "failed to seek opus file")
}
// 读取 OGG 页面头部
var totalGranulePos int64
buf := make([]byte, 27) // OGG 页面头部最小大小
for {
n, err := r.Read(buf)
if err == io.EOF {
break
}
if err != nil {
return 0, errors.Wrap(err, "failed to read opus/ogg page")
}
if n < 27 {
break
}
// 检查 OGG 页面标识 "OggS"
if string(buf[0:4]) != "OggS" {
// 跳过一些字节继续寻找
if _, err := r.Seek(-26, io.SeekCurrent); err != nil {
break
}
continue
}
// 读取 granule position (字节 6-13, 小端序)
granulePos := int64(binary.LittleEndian.Uint64(buf[6:14]))
if granulePos > totalGranulePos {
totalGranulePos = granulePos
}
// 读取段表大小
numSegments := int(buf[26])
segmentTable := make([]byte, numSegments)
if _, err := io.ReadFull(r, segmentTable); err != nil {
break
}
// 计算页面数据大小并跳过
var pageSize int
for _, segSize := range segmentTable {
pageSize += int(segSize)
}
if _, err := r.Seek(int64(pageSize), io.SeekCurrent); err != nil {
break
}
}
// Opus 的采样率固定为 48000 Hz
duration := float64(totalGranulePos) / 48000.0
return duration, nil
}
// getAIFFDuration 解析 AIFF 文件头以获取时长。
func getAIFFDuration(r io.ReadSeeker) (float64, error) {
if _, err := r.Seek(0, io.SeekStart); err != nil {
return 0, errors.Wrap(err, "failed to seek aiff file")
}
dec := aiff.NewDecoder(r)
if !dec.IsValidFile() {
return 0, errors.New("invalid aiff file")
}
d, err := dec.Duration()
if err != nil {
return 0, errors.Wrap(err, "failed to get aiff duration")
}
return d.Seconds(), nil
}
// getWebMDuration 解析 WebM 文件以获取时长。
// WebM 使用 Matroska 容器格式
func getWebMDuration(r io.ReadSeeker) (float64, error) {
if _, err := r.Seek(0, io.SeekStart); err != nil {
return 0, errors.Wrap(err, "failed to seek webm file")
}
// WebM/Matroska 文件的解析比较复杂
// 这里提供一个简化的实现,读取 EBML 头部
// 对于完整的 WebM 解析,可能需要使用专门的库
// 简单实现:查找 Duration 元素
// WebM Duration 的 Element ID 是 0x4489
// 这是一个简化版本,可能不适用于所有 WebM 文件
buf := make([]byte, 8192)
n, err := r.Read(buf)
if err != nil && err != io.EOF {
return 0, errors.Wrap(err, "failed to read webm file")
}
// 尝试查找 Duration 元素(这是一个简化的方法)
// 实际的 WebM 解析需要完整的 EBML 解析器
// 这里返回错误,建议使用专门的库
if n > 0 {
// 检查 EBML 标识
if len(buf) >= 4 && binary.BigEndian.Uint32(buf[0:4]) == 0x1A45DFA3 {
// 这是一个有效的 EBML 文件
// 但完整解析需要更复杂的逻辑
return 0, errors.New("webm duration parsing requires full EBML parser (consider using ffprobe for webm files)")
}
}
return 0, errors.New("failed to parse webm file")
}
// getAACDuration 解析 AAC (ADTS格式) 文件以获取时长。
// 使用 gomedia 库来解析 AAC ADTS 帧
func getAACDuration(r io.ReadSeeker) (float64, error) {
if _, err := r.Seek(0, io.SeekStart); err != nil {
return 0, errors.Wrap(err, "failed to seek aac file")
}
// 读取整个文件内容
data, err := io.ReadAll(r)
if err != nil {
return 0, errors.Wrap(err, "failed to read aac file")
}
var totalFrames int64
var sampleRate int
// 使用 gomedia 的 SplitAACFrame 函数来分割 AAC 帧
codec.SplitAACFrame(data, func(aac []byte) {
// 解析 ADTS 头部以获取采样率信息
if len(aac) >= 7 {
// 使用 ConvertADTSToASC 来获取音频配置信息
asc, err := codec.ConvertADTSToASC(aac)
if err == nil && sampleRate == 0 {
sampleRate = codec.AACSampleIdxToSample(int(asc.Sample_freq_index))
}
totalFrames++
}
})
if sampleRate == 0 || totalFrames == 0 {
return 0, errors.New("no valid aac frames found")
}
// 每个 AAC ADTS 帧包含 1024 个采样
totalSamples := totalFrames * 1024
duration := float64(totalSamples) / float64(sampleRate)
return duration, nil
}

View File

@@ -159,15 +159,14 @@ var (
GlobalWebRateLimitNum int
GlobalWebRateLimitDuration int64
CriticalRateLimitEnable bool
CriticalRateLimitNum = 20
CriticalRateLimitDuration int64 = 20 * 60
UploadRateLimitNum = 10
UploadRateLimitDuration int64 = 60
DownloadRateLimitNum = 10
DownloadRateLimitDuration int64 = 60
CriticalRateLimitNum = 20
CriticalRateLimitDuration int64 = 20 * 60
)
var RateLimitKeyExpirationDuration = 20 * time.Minute

View File

@@ -86,8 +86,5 @@ func SendEmail(subject string, receiver string, content string) error {
} else {
err = smtp.SendMail(addr, auth, SMTPFrom, to, mail)
}
if err != nil {
SysError(fmt.Sprintf("failed to send email to %s: %v", receiver, err))
}
return err
}

View File

@@ -26,8 +26,6 @@ func GetEndpointTypesByChannelType(channelType int, modelName string) []constant
endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
case constant.ChannelTypeSora:
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIVideo}
default:
if IsOpenAIResponseOnlyModel(modelName) {
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}

View File

@@ -5,7 +5,6 @@ import (
"io"
"mime/multipart"
"net/http"
"net/url"
"strings"
"time"
@@ -40,11 +39,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
//}
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = Unmarshal(requestBody, v)
} else if strings.Contains(contentType, gin.MIMEPOSTForm) {
err = parseFormData(requestBody, v)
} else if strings.Contains(contentType, gin.MIMEMultipartPOSTForm) {
err = parseMultipartFormData(c, requestBody, v)
err = Unmarshal(requestBody, &v)
} else {
// skip for now
// TODO: someday non json request have variant model, we will need to implementation this
@@ -143,63 +138,3 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return form, nil
}
func processFormMap(formMap map[string]any, v any) error {
jsonData, err := Marshal(formMap)
if err != nil {
return err
}
err = Unmarshal(jsonData, v)
if err != nil {
return err
}
return nil
}
func parseFormData(data []byte, v any) error {
values, err := url.ParseQuery(string(data))
if err != nil {
return err
}
formMap := make(map[string]any)
for key, vals := range values {
if len(vals) == 1 {
formMap[key] = vals[0]
} else {
formMap[key] = vals
}
}
return processFormMap(formMap, v)
}
func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
contentType := c.Request.Header.Get("Content-Type")
boundary := ""
if idx := strings.Index(contentType, "boundary="); idx != -1 {
boundary = contentType[idx+9:]
}
if boundary == "" {
return Unmarshal(data, v) // Fallback to JSON
}
reader := multipart.NewReader(bytes.NewReader(data), boundary)
form, err := reader.ReadForm(32 << 20) // 32 MB max memory
if err != nil {
return err
}
defer form.RemoveAll()
formMap := make(map[string]any)
for key, vals := range form.Value {
if len(vals) == 1 {
formMap[key] = vals[0]
} else {
formMap[key] = vals
}
}
return processFormMap(formMap, v)
}

View File

@@ -3,11 +3,10 @@ package common
import (
"flag"
"fmt"
"log/slog"
"log"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/constant"
@@ -43,10 +42,9 @@ func InitEnv() {
if os.Getenv("SESSION_SECRET") != "" {
ss := os.Getenv("SESSION_SECRET")
if ss == "random_string" {
slog.Warn("SESSION_SECRET is set to the default value 'random_string', please change it to a random string.")
slog.Warn("警告SESSION_SECRET被设置为默认值'random_string',请修改为随机字符串。")
slog.Error("Please set SESSION_SECRET to a random string.")
os.Exit(1)
log.Println("WARNING: SESSION_SECRET is set to the default value 'random_string', please change it to a random string.")
log.Println("警告SESSION_SECRET被设置为默认值'random_string',请修改为随机字符串。")
log.Fatal("Please set SESSION_SECRET to a random string.")
} else {
SessionSecret = ss
}
@@ -63,14 +61,12 @@ func InitEnv() {
var err error
*LogDir, err = filepath.Abs(*LogDir)
if err != nil {
slog.Error("failed to get absolute path for log directory", "error", err)
os.Exit(1)
log.Fatal(err)
}
if _, err := os.Stat(*LogDir); os.IsNotExist(err) {
err = os.Mkdir(*LogDir, 0777)
if err != nil {
slog.Error("failed to create log directory", "error", err)
os.Exit(1)
log.Fatal(err)
}
}
}
@@ -102,9 +98,6 @@ func InitEnv() {
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
CriticalRateLimitEnable = GetEnvOrDefaultBool("CRITICAL_RATE_LIMIT_ENABLE", true)
CriticalRateLimitNum = GetEnvOrDefault("CRITICAL_RATE_LIMIT", 20)
CriticalRateLimitDuration = int64(GetEnvOrDefault("CRITICAL_RATE_LIMIT_DURATION", 20*60))
initConstantEnv()
}
@@ -125,17 +118,4 @@ func initConstantEnv() {
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
// 是否启用错误日志
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
if soraPatchStr != "" {
var taskPricePatches []string
soraPatches := strings.Split(soraPatchStr, ",")
for _, patch := range soraPatches {
trimmedPatch := strings.TrimSpace(patch)
if trimmedPatch != "" {
taskPricePatches = append(taskPricePatches, trimmedPatch)
}
}
constant.TaskPricePatches = taskPricePatches
}
}

View File

@@ -3,7 +3,6 @@ package common
import (
"bytes"
"encoding/json"
"io"
)
func Unmarshal(data []byte, v any) error {
@@ -14,7 +13,7 @@ func UnmarshalJsonStr(data string, v any) error {
return json.Unmarshal(StringToByteSlice(data), v)
}
func DecodeJson(reader io.Reader, v any) error {
func DecodeJson(reader *bytes.Reader, v any) error {
return json.NewDecoder(reader).Decode(v)
}

View File

@@ -2,7 +2,6 @@ package common
import (
"fmt"
"log/slog"
"os"
"time"
@@ -10,16 +9,18 @@ import (
)
func SysLog(s string) {
slog.Info(s, "component", "system")
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func SysError(s string) {
slog.Error(s, "component", "system")
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func FatalLog(v ...any) {
msg := fmt.Sprint(v...)
slog.Error(msg, "component", "system", "level", "fatal")
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
os.Exit(1)
}

View File

@@ -1,6 +1,8 @@
package common
import (
"bytes"
"context"
crand "crypto/rand"
"encoding/base64"
"encoding/json"
@@ -230,6 +232,10 @@ func GetUUID() string {
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func init() {
rand.New(rand.NewSource(time.Now().UnixNano()))
}
func GenerateRandomCharsKey(length int) (string, error) {
b := make([]byte, length)
maxI := big.NewInt(int64(len(keyChars)))
@@ -323,6 +329,43 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
return f.Name(), nil
}
// GetAudioDuration returns the duration of an audio file in seconds.
func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
output, err := c.Output()
if err != nil {
return 0, errors.Wrap(err, "failed to get audio duration")
}
durationStr := string(bytes.TrimSpace(output))
if durationStr == "N/A" {
// Create a temporary output file name
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
if err != nil {
return 0, errors.Wrap(err, "failed to create temporary file")
}
tmpName := tmpFp.Name()
// Close immediately so ffmpeg can open the file on Windows.
_ = tmpFp.Close()
defer os.Remove(tmpName)
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
if err := ffmpegCmd.Run(); err != nil {
return 0, errors.Wrap(err, "failed to run ffmpeg")
}
// Recalculate the duration of the new file
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
output, err := c.Output()
if err != nil {
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
}
durationStr = string(bytes.TrimSpace(output))
}
return strconv.ParseFloat(durationStr, 64)
}
// BuildURL concatenates base and endpoint, returns the complete url string
func BuildURL(base string, endpoint string) string {
u, err := url.Parse(base)

161
config/plugin_config.go Normal file
View File

@@ -0,0 +1,161 @@
package config
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/core/interfaces"
"gopkg.in/yaml.v3"
)
// PluginConfig 插件配置结构
type PluginConfig struct {
Channels map[string]interfaces.ChannelConfig `yaml:"channels"`
Middlewares []interfaces.MiddlewareConfig `yaml:"middlewares"`
Hooks HooksConfig `yaml:"hooks"`
}
// HooksConfig Hook配置
type HooksConfig struct {
Relay []interfaces.HookConfig `yaml:"relay"`
}
var (
// 全局配置实例
globalPluginConfig *PluginConfig
)
// LoadPluginConfig 加载插件配置
func LoadPluginConfig(configPath string) (*PluginConfig, error) {
// 如果没有指定配置文件路径,使用默认路径
if configPath == "" {
configPath = "config/plugins.yaml"
}
// 检查文件是否存在
if _, err := os.Stat(configPath); os.IsNotExist(err) {
common.SysLog(fmt.Sprintf("Plugin config file not found: %s, using default configuration", configPath))
return getDefaultConfig(), nil
}
// 读取配置文件
data, err := ioutil.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf("failed to read plugin config: %w", err)
}
// 解析YAML
var config PluginConfig
if err := yaml.Unmarshal(data, &config); err != nil {
return nil, fmt.Errorf("failed to parse plugin config: %w", err)
}
// 环境变量替换
expandEnvVars(&config)
common.SysLog(fmt.Sprintf("Loaded plugin config from: %s", configPath))
return &config, nil
}
// getDefaultConfig 返回默认配置
func getDefaultConfig() *PluginConfig {
return &PluginConfig{
Channels: make(map[string]interfaces.ChannelConfig),
Middlewares: make([]interfaces.MiddlewareConfig, 0),
Hooks: HooksConfig{
Relay: make([]interfaces.HookConfig, 0),
},
}
}
// expandEnvVars 展开环境变量
func expandEnvVars(config *PluginConfig) {
// 展开Hook配置中的环境变量
for i := range config.Hooks.Relay {
for key, value := range config.Hooks.Relay[i].Config {
if strValue, ok := value.(string); ok {
config.Hooks.Relay[i].Config[key] = os.ExpandEnv(strValue)
}
}
}
// 展开Middleware配置中的环境变量
for i := range config.Middlewares {
for key, value := range config.Middlewares[i].Config {
if strValue, ok := value.(string); ok {
config.Middlewares[i].Config[key] = os.ExpandEnv(strValue)
}
}
}
}
// GetGlobalPluginConfig 获取全局配置
func GetGlobalPluginConfig() *PluginConfig {
if globalPluginConfig == nil {
configPath := os.Getenv("PLUGIN_CONFIG_PATH")
if configPath == "" {
configPath = "config/plugins.yaml"
}
config, err := LoadPluginConfig(configPath)
if err != nil {
common.SysError(fmt.Sprintf("Failed to load plugin config: %v", err))
config = getDefaultConfig()
}
globalPluginConfig = config
}
return globalPluginConfig
}
// SavePluginConfig 保存插件配置
func SavePluginConfig(config *PluginConfig, configPath string) error {
if configPath == "" {
configPath = "config/plugins.yaml"
}
// 确保目录存在
dir := filepath.Dir(configPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create config directory: %w", err)
}
// 序列化为YAML
data, err := yaml.Marshal(config)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// 写入文件
if err := ioutil.WriteFile(configPath, data, 0644); err != nil {
return fmt.Errorf("failed to write config file: %w", err)
}
common.SysLog(fmt.Sprintf("Saved plugin config to: %s", configPath))
return nil
}
// ReloadPluginConfig 重新加载配置
func ReloadPluginConfig() error {
configPath := os.Getenv("PLUGIN_CONFIG_PATH")
if configPath == "" {
configPath = "config/plugins.yaml"
}
config, err := LoadPluginConfig(configPath)
if err != nil {
return err
}
globalPluginConfig = config
common.SysLog("Plugin config reloaded")
return nil
}

52
config/plugins.yaml Normal file
View File

@@ -0,0 +1,52 @@
# New-API 插件配置
# 此文件用于配置所有插件的启用状态和参数
# Channel插件配置
channels:
openai:
enabled: true
priority: 100
claude:
enabled: true
priority: 90
gemini:
enabled: true
priority: 85
# Middleware插件配置
middlewares:
- name: auth
enabled: true
priority: 100
- name: ratelimit
enabled: true
priority: 90
config:
default_rate: 60
# Hook插件配置
hooks:
# Relay层Hook
relay:
# 联网搜索插件
- name: web_search
enabled: false # 默认禁用需要配置API key后启用
priority: 50
config:
provider: google
api_key: ${WEB_SEARCH_API_KEY} # 从环境变量读取
# 内容过滤插件
- name: content_filter
enabled: false # 默认禁用,需要配置后启用
priority: 100 # 高优先级,最后执行
config:
filter_nsfw: true
filter_political: false
sensitive_words:
- "敏感词1"
- "敏感词2"

View File

@@ -33,6 +33,5 @@ const (
APITypeJimeng
APITypeMoonshot
APITypeSubmodel
APITypeMiniMax
APITypeDummy // this one is only for count, do not add any channel after this
)

View File

@@ -10,7 +10,6 @@ const (
EndpointTypeJinaRerank EndpointType = "jina-rerank"
EndpointTypeImageGeneration EndpointType = "image-generation"
EndpointTypeEmbeddings EndpointType = "embeddings"
EndpointTypeOpenAIVideo EndpointType = "openai-video"
//EndpointTypeMidjourney EndpointType = "midjourney-proxy"
//EndpointTypeSuno EndpointType = "suno-proxy"
//EndpointTypeKling EndpointType = "kling"

View File

@@ -13,6 +13,3 @@ var NotifyLimitCount int
var NotificationLimitDurationMinute int
var GenerateDefaultToken bool
var ErrorLogEnabled bool
// temporary variable for sora patch, will be removed in future
var TaskPricePatches []string

View File

@@ -617,20 +617,16 @@ func TestAllChannels(c *gin.Context) {
var autoTestChannelsOnce sync.Once
func AutomaticallyTestChannels() {
// 只在Master节点定时测试渠道
if !common.IsMasterNode {
return
}
autoTestChannelsOnce.Do(func() {
for {
if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
time.Sleep(1 * time.Minute)
time.Sleep(10 * time.Minute)
continue
}
for {
frequency := operation_setting.GetMonitorSetting().AutoTestChannelMinutes
time.Sleep(time.Duration(int(math.Round(frequency))) * time.Minute)
common.SysLog(fmt.Sprintf("automatically test channels with interval %f minutes", frequency))
time.Sleep(time.Duration(frequency) * time.Minute)
common.SysLog(fmt.Sprintf("automatically test channels with interval %d minutes", frequency))
common.SysLog("automatically testing all channels")
_ = testAllChannels(false)
common.SysLog("automatically channel test finished")

View File

@@ -649,15 +649,13 @@ func DeleteDisabledChannel(c *gin.Context) {
}
type ChannelTag struct {
Tag string `json:"tag"`
NewTag *string `json:"new_tag"`
Priority *int64 `json:"priority"`
Weight *uint `json:"weight"`
ModelMapping *string `json:"model_mapping"`
Models *string `json:"models"`
Groups *string `json:"groups"`
ParamOverride *string `json:"param_override"`
HeaderOverride *string `json:"header_override"`
Tag string `json:"tag"`
NewTag *string `json:"new_tag"`
Priority *int64 `json:"priority"`
Weight *uint `json:"weight"`
ModelMapping *string `json:"model_mapping"`
Models *string `json:"models"`
Groups *string `json:"groups"`
}
func DisableTagChannels(c *gin.Context) {
@@ -723,29 +721,7 @@ func EditTagChannels(c *gin.Context) {
})
return
}
if channelTag.ParamOverride != nil {
trimmed := strings.TrimSpace(*channelTag.ParamOverride)
if trimmed != "" && !json.Valid([]byte(trimmed)) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数覆盖必须是合法的 JSON 格式",
})
return
}
channelTag.ParamOverride = common.GetPointer[string](trimmed)
}
if channelTag.HeaderOverride != nil {
trimmed := strings.TrimSpace(*channelTag.HeaderOverride)
if trimmed != "" && !json.Valid([]byte(trimmed)) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "请求头覆盖必须是合法的 JSON 格式",
})
return
}
channelTag.HeaderOverride = common.GetPointer[string](trimmed)
}
err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight, channelTag.ParamOverride, channelTag.HeaderOverride)
err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
if err != nil {
common.ApiError(c, err)
return

View File

@@ -4,7 +4,6 @@ import (
"net/http"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/ratio_setting"
@@ -28,17 +27,17 @@ func GetUserGroups(c *gin.Context) {
userGroup := ""
userId := c.GetInt("id")
userGroup, _ = model.GetUserGroup(userId, false)
userUsableGroups := service.GetUserUsableGroups(userGroup)
for groupName, _ := range ratio_setting.GetGroupRatioCopy() {
for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
// UserUsableGroups contains the groups that the user can use
userUsableGroups := setting.GetUserUsableGroups(userGroup)
if desc, ok := userUsableGroups[groupName]; ok {
usableGroups[groupName] = map[string]interface{}{
"ratio": service.GetUserGroupRatio(userGroup, groupName),
"ratio": ratio,
"desc": desc,
}
}
}
if _, ok := userUsableGroups["auto"]; ok {
if setting.GroupInUserUsableGroups("auto") {
usableGroups["auto"] = map[string]interface{}{
"ratio": "自动",
"desc": setting.GetUsableGroupDescription("auto"),

View File

@@ -15,7 +15,7 @@ import (
"github.com/QuantumNous/new-api/relay/channel/minimax"
"github.com/QuantumNous/new-api/relay/channel/moonshot"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
)
@@ -149,7 +149,7 @@ func ListModels(c *gin.Context, modelType int) {
}
var models []string
if tokenGroup == "auto" {
for _, autoGroup := range service.GetUserAutoGroup(userGroup) {
for _, autoGroup := range setting.AutoGroups {
groupModels := model.GetGroupEnabledModels(autoGroup)
for _, g := range groupModels {
if !common.StringsContains(models, g) {

View File

@@ -31,7 +31,7 @@ func Playground(c *gin.Context) {
return
}
group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
group := c.GetString("group")
modelName := c.GetString("original_model")
userId := c.GetInt("id")

View File

@@ -2,7 +2,7 @@ package controller
import (
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
@@ -30,7 +30,7 @@ func GetPricing(c *gin.Context) {
}
}
usableGroup = service.GetUserUsableGroups(group)
usableGroup = setting.GetUserUsableGroups(group)
// check groupRatio contains usableGroup
for group := range ratio_setting.GetGroupRatioCopy() {
if _, ok := usableGroup[group]; !ok {
@@ -45,7 +45,7 @@ func GetPricing(c *gin.Context) {
"group_ratio": groupRatio,
"usable_group": usableGroup,
"supported_endpoint": model.GetSupportedEndpointMap(),
"auto_groups": service.GetUserAutoGroup(group),
"auto_groups": setting.AutoGroups,
})
}

View File

@@ -84,7 +84,6 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
defer func() {
if newAPIError != nil {
logger.LogError(c, fmt.Sprintf("relay error: %s", newAPIError.Error()))
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
switch relayFormat {
case types.RelayFormatOpenAIRealtime:
@@ -225,12 +224,12 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
AutoBan: &autoBanInt,
}, nil
}
channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
if err != nil {
return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败retry: %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
}
if channel == nil {
return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在retry", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
}
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
if newAPIError != nil {
@@ -282,7 +281,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
}
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
// 不要使用context获取渠道信息异步处理时可能会出现渠道信息不一致的情况
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
@@ -300,9 +299,6 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t
userGroup := c.GetString("group")
channelId := c.GetInt("channel_id")
other := make(map[string]interface{})
if c.Request != nil && c.Request.URL != nil {
other["request_path"] = c.Request.URL.Path
}
other["error_type"] = err.GetErrorType()
other["error_code"] = err.GetErrorCode()
other["status_code"] = err.StatusCode

View File

@@ -88,13 +88,10 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
}
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody)))
taskResult := &relaycommon.TaskInfo{}
// try parse as New API response format
var responseItems dto.TaskResponse[model.Task]
if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems))
if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
t := responseItems.Data
taskResult.TaskID = t.TaskID
taskResult.Status = string(t.Status)
@@ -108,19 +105,10 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
task.Data = redactVideoResponseBody(responseBody)
}
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult))
now := time.Now().Unix()
if taskResult.Status == "" {
//return fmt.Errorf("task %s status is empty", taskId)
taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
return fmt.Errorf("task %s status is empty", taskId)
}
// 记录原本的状态,防止重复退款
shouldRefund := false
quota := task.Quota
preStatus := task.Status
task.Status = model.TaskStatus(taskResult.Status)
switch taskResult.Status {
case model.TaskStatusSubmitted:
@@ -149,19 +137,14 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
if modelName, ok := taskData["model"].(string); ok && modelName != "" {
// 获取模型价格和倍率
modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
// 只有配置了倍率(非固定价格)时才按 token 重新计费
if hasRatioSetting && modelRatio > 0 {
// 获取用户和组的倍率信息
group := task.Group
if group == "" {
user, err := model.GetUserById(task.UserId, false)
if err == nil {
group = user.Group
}
}
if group != "" {
groupRatio := ratio_setting.GetGroupRatio(group)
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
user, err := model.GetUserById(task.UserId, false)
if err == nil {
groupRatio := ratio_setting.GetGroupRatio(user.Group)
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(user.Group, user.Group)
var finalGroupRatio float64
if hasUserGroupRatio {
@@ -231,7 +214,6 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
}
}
case model.TaskStatusFailure:
logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
task.Status = model.TaskStatusFailure
task.Progress = "100%"
if task.FinishTime == 0 {
@@ -239,13 +221,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
}
task.FailReason = taskResult.Reason
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
taskResult.Progress = "100%"
quota := task.Quota
if quota != 0 {
if preStatus != model.TaskStatusFailure {
shouldRefund = true
} else {
logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
logger.LogError(ctx, "Failed to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
default:
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
@@ -255,16 +237,6 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
}
if err := task.Update(); err != nil {
common.SysLog("UpdateVideoTask task error: " + err.Error())
shouldRefund = false
}
if shouldRefund {
// 任务失败且之前状态不是失败才退还额度,防止重复退还
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
return nil

View File

@@ -51,8 +51,6 @@ func GetTopUpInfo(c *gin.Context) {
data := gin.H{
"enable_online_topup": operation_setting.PayAddress != "" && operation_setting.EpayId != "" && operation_setting.EpayKey != "",
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
"enable_creem_topup": setting.CreemApiKey != "" && setting.CreemProducts != "[]",
"creem_products": setting.CreemProducts,
"pay_methods": payMethods,
"min_topup": operation_setting.MinTopUp,
"stripe_min_topup": setting.StripeMinTopUp,

View File

@@ -1,461 +0,0 @@
package controller
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting"
"time"
"github.com/gin-gonic/gin"
"github.com/thanhpk/randstr"
)
const (
PaymentMethodCreem = "creem"
CreemSignatureHeader = "creem-signature"
)
var creemAdaptor = &CreemAdaptor{}
// 生成HMAC-SHA256签名
func generateCreemSignature(payload string, secret string) string {
h := hmac.New(sha256.New, []byte(secret))
h.Write([]byte(payload))
return hex.EncodeToString(h.Sum(nil))
}
// 验证Creem webhook签名
func verifyCreemSignature(payload string, signature string, secret string) bool {
if secret == "" {
log.Printf("Creem webhook secret not set")
if setting.CreemTestMode {
log.Printf("Skip Creem webhook sign verify in test mode")
return true
}
return false
}
expectedSignature := generateCreemSignature(payload, secret)
return hmac.Equal([]byte(signature), []byte(expectedSignature))
}
type CreemPayRequest struct {
ProductId string `json:"product_id"`
PaymentMethod string `json:"payment_method"`
}
type CreemProduct struct {
ProductId string `json:"productId"`
Name string `json:"name"`
Price float64 `json:"price"`
Currency string `json:"currency"`
Quota int64 `json:"quota"`
}
type CreemAdaptor struct {
}
func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
if req.PaymentMethod != PaymentMethodCreem {
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
return
}
if req.ProductId == "" {
c.JSON(200, gin.H{"message": "error", "data": "请选择产品"})
return
}
// 解析产品列表
var products []CreemProduct
err := json.Unmarshal([]byte(setting.CreemProducts), &products)
if err != nil {
log.Println("解析Creem产品列表失败", err)
c.JSON(200, gin.H{"message": "error", "data": "产品配置错误"})
return
}
// 查找对应的产品
var selectedProduct *CreemProduct
for _, product := range products {
if product.ProductId == req.ProductId {
selectedProduct = &product
break
}
}
if selectedProduct == nil {
c.JSON(200, gin.H{"message": "error", "data": "产品不存在"})
return
}
id := c.GetInt("id")
user, _ := model.GetUserById(id, false)
// 生成唯一的订单引用ID
reference := fmt.Sprintf("creem-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), randstr.String(4))
referenceId := "ref_" + common.Sha1([]byte(reference))
// 先创建订单记录,使用产品配置的金额和充值额度
topUp := &model.TopUp{
UserId: id,
Amount: selectedProduct.Quota, // 充值额度
Money: selectedProduct.Price, // 支付金额
TradeNo: referenceId,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
err = topUp.Insert()
if err != nil {
log.Printf("创建Creem订单失败: %v", err)
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
return
}
// 创建支付链接,传入用户邮箱
checkoutUrl, err := genCreemLink(referenceId, selectedProduct, user.Email, user.Username)
if err != nil {
log.Printf("获取Creem支付链接失败: %v", err)
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
log.Printf("Creem订单创建成功 - 用户ID: %d, 订单号: %s, 产品: %s, 充值额度: %d, 支付金额: %.2f",
id, referenceId, selectedProduct.Name, selectedProduct.Quota, selectedProduct.Price)
c.JSON(200, gin.H{
"message": "success",
"data": gin.H{
"checkout_url": checkoutUrl,
"order_id": referenceId,
},
})
}
func RequestCreemPay(c *gin.Context) {
var req CreemPayRequest
// 读取body内容用于打印同时保留原始数据供后续使用
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("read creem pay req body err: %v", err)
c.JSON(200, gin.H{"message": "error", "data": "read query error"})
return
}
// 打印body内容
log.Printf("creem pay request body: %s", string(bodyBytes))
// 重新设置body供后续的ShouldBindJSON使用
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
err = c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
return
}
creemAdaptor.RequestPay(c, &req)
}
// 新的Creem Webhook结构体匹配实际的webhook数据格式
type CreemWebhookEvent struct {
Id string `json:"id"`
EventType string `json:"eventType"`
CreatedAt int64 `json:"created_at"`
Object struct {
Id string `json:"id"`
Object string `json:"object"`
RequestId string `json:"request_id"`
Order struct {
Object string `json:"object"`
Id string `json:"id"`
Customer string `json:"customer"`
Product string `json:"product"`
Amount int `json:"amount"`
Currency string `json:"currency"`
SubTotal int `json:"sub_total"`
TaxAmount int `json:"tax_amount"`
AmountDue int `json:"amount_due"`
AmountPaid int `json:"amount_paid"`
Status string `json:"status"`
Type string `json:"type"`
Transaction string `json:"transaction"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
Mode string `json:"mode"`
} `json:"order"`
Product struct {
Id string `json:"id"`
Object string `json:"object"`
Name string `json:"name"`
Description string `json:"description"`
Price int `json:"price"`
Currency string `json:"currency"`
BillingType string `json:"billing_type"`
BillingPeriod string `json:"billing_period"`
Status string `json:"status"`
TaxMode string `json:"tax_mode"`
TaxCategory string `json:"tax_category"`
DefaultSuccessUrl *string `json:"default_success_url"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
Mode string `json:"mode"`
} `json:"product"`
Units int `json:"units"`
Customer struct {
Id string `json:"id"`
Object string `json:"object"`
Email string `json:"email"`
Name string `json:"name"`
Country string `json:"country"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
Mode string `json:"mode"`
} `json:"customer"`
Status string `json:"status"`
Metadata map[string]string `json:"metadata"`
Mode string `json:"mode"`
} `json:"object"`
}
// 保留旧的结构体作为兼容
type CreemWebhookData struct {
Type string `json:"type"`
Data struct {
RequestId string `json:"request_id"`
Status string `json:"status"`
Metadata map[string]string `json:"metadata"`
} `json:"data"`
}
func CreemWebhook(c *gin.Context) {
// 读取body内容用于打印同时保留原始数据供后续使用
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("读取Creem Webhook请求body失败: %v", err)
c.AbortWithStatus(http.StatusBadRequest)
return
}
// 获取签名头
signature := c.GetHeader(CreemSignatureHeader)
// 打印关键信息避免输出完整敏感payload
log.Printf("Creem Webhook - URI: %s", c.Request.RequestURI)
if setting.CreemTestMode {
log.Printf("Creem Webhook - Signature: %s , Body: %s", signature, bodyBytes)
} else if signature == "" {
log.Printf("Creem Webhook缺少签名头")
c.AbortWithStatus(http.StatusUnauthorized)
return
}
// 验证签名
if !verifyCreemSignature(string(bodyBytes), signature, setting.CreemWebhookSecret) {
log.Printf("Creem Webhook签名验证失败")
c.AbortWithStatus(http.StatusUnauthorized)
return
}
log.Printf("Creem Webhook签名验证成功")
// 重新设置body供后续的ShouldBindJSON使用
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// 解析新格式的webhook数据
var webhookEvent CreemWebhookEvent
if err := c.ShouldBindJSON(&webhookEvent); err != nil {
log.Printf("解析Creem Webhook参数失败: %v", err)
c.AbortWithStatus(http.StatusBadRequest)
return
}
log.Printf("Creem Webhook解析成功 - EventType: %s, EventId: %s", webhookEvent.EventType, webhookEvent.Id)
// 根据事件类型处理不同的webhook
switch webhookEvent.EventType {
case "checkout.completed":
handleCheckoutCompleted(c, &webhookEvent)
default:
log.Printf("忽略Creem Webhook事件类型: %s", webhookEvent.EventType)
c.Status(http.StatusOK)
}
}
// 处理支付完成事件
func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
// 验证订单状态
if event.Object.Order.Status != "paid" {
log.Printf("订单状态不是已支付: %s, 跳过处理", event.Object.Order.Status)
c.Status(http.StatusOK)
return
}
// 获取引用ID这是我们创建订单时传递的request_id
referenceId := event.Object.RequestId
if referenceId == "" {
log.Println("Creem Webhook缺少request_id字段")
c.AbortWithStatus(http.StatusBadRequest)
return
}
// 验证订单类型,目前只处理一次性付款
if event.Object.Order.Type != "onetime" {
log.Printf("暂不支持的订单类型: %s, 跳过处理", event.Object.Order.Type)
c.Status(http.StatusOK)
return
}
// 记录详细的支付信息
log.Printf("处理Creem支付完成 - 订单号: %s, Creem订单ID: %s, 支付金额: %d %s, 客户邮箱: <redacted>, 产品: %s",
referenceId,
event.Object.Order.Id,
event.Object.Order.AmountPaid,
event.Object.Order.Currency,
event.Object.Product.Name)
// 查询本地订单确认存在
topUp := model.GetTopUpByTradeNo(referenceId)
if topUp == nil {
log.Printf("Creem充值订单不存在: %s", referenceId)
c.AbortWithStatus(http.StatusBadRequest)
return
}
if topUp.Status != common.TopUpStatusPending {
log.Printf("Creem充值订单状态错误: %s, 当前状态: %s", referenceId, topUp.Status)
c.Status(http.StatusOK) // 已处理过的订单,返回成功避免重复处理
return
}
// 处理充值,传入客户邮箱和姓名信息
customerEmail := event.Object.Customer.Email
customerName := event.Object.Customer.Name
// 防护性检查,确保邮箱和姓名不为空字符串
if customerEmail == "" {
log.Printf("警告Creem回调中客户邮箱为空 - 订单号: %s", referenceId)
}
if customerName == "" {
log.Printf("警告Creem回调中客户姓名为空 - 订单号: %s", referenceId)
}
err := model.RechargeCreem(referenceId, customerEmail, customerName)
if err != nil {
log.Printf("Creem充值处理失败: %s, 订单号: %s", err.Error(), referenceId)
c.AbortWithStatus(http.StatusInternalServerError)
return
}
log.Printf("Creem充值成功 - 订单号: %s, 充值额度: %d, 支付金额: %.2f",
referenceId, topUp.Amount, topUp.Money)
c.Status(http.StatusOK)
}
type CreemCheckoutRequest struct {
ProductId string `json:"product_id"`
RequestId string `json:"request_id"`
Customer struct {
Email string `json:"email"`
} `json:"customer"`
Metadata map[string]string `json:"metadata,omitempty"`
}
type CreemCheckoutResponse struct {
CheckoutUrl string `json:"checkout_url"`
Id string `json:"id"`
}
func genCreemLink(referenceId string, product *CreemProduct, email string, username string) (string, error) {
if setting.CreemApiKey == "" {
return "", fmt.Errorf("未配置Creem API密钥")
}
// 根据测试模式选择 API 端点
apiUrl := "https://api.creem.io/v1/checkouts"
if setting.CreemTestMode {
apiUrl = "https://test-api.creem.io/v1/checkouts"
log.Printf("使用Creem测试环境: %s", apiUrl)
}
// 构建请求数据,确保包含用户邮箱
requestData := CreemCheckoutRequest{
ProductId: product.ProductId,
RequestId: referenceId, // 这个作为订单ID传递给Creem
Customer: struct {
Email string `json:"email"`
}{
Email: email, // 用户邮箱会在支付页面预填充
},
Metadata: map[string]string{
"username": username,
"reference_id": referenceId,
"product_name": product.Name,
"quota": fmt.Sprintf("%d", product.Quota),
},
}
// 序列化请求数据
jsonData, err := json.Marshal(requestData)
if err != nil {
return "", fmt.Errorf("序列化请求数据失败: %v", err)
}
// 创建 HTTP 请求
req, err := http.NewRequest("POST", apiUrl, bytes.NewBuffer(jsonData))
if err != nil {
return "", fmt.Errorf("创建HTTP请求失败: %v", err)
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-api-key", setting.CreemApiKey)
log.Printf("发送Creem支付请求 - URL: %s, 产品ID: %s, 用户邮箱: %s, 订单号: %s",
apiUrl, product.ProductId, email, referenceId)
// 发送请求
client := &http.Client{
Timeout: 30 * time.Second,
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("发送HTTP请求失败: %v", err)
}
defer resp.Body.Close()
// 读取响应
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("读取响应失败: %v", err)
}
log.Printf("Creem API resp - status code: %d, resp: %s", resp.StatusCode, string(body))
// 检查响应状态
if resp.StatusCode/100 != 2 {
return "", fmt.Errorf("Creem API http status %d ", resp.StatusCode)
}
// 解析响应
var checkoutResp CreemCheckoutResponse
err = json.Unmarshal(body, &checkoutResp)
if err != nil {
return "", fmt.Errorf("解析响应失败: %v", err)
}
if checkoutResp.CheckoutUrl == "" {
return "", fmt.Errorf("Creem API resp no checkout url ")
}
log.Printf("Creem 支付链接创建成功 - 订单号: %s, 支付链接: %s", referenceId, checkoutResp.CheckoutUrl)
return checkoutResp.CheckoutUrl, nil
}

View File

@@ -220,7 +220,7 @@ func genStripeLink(referenceId string, customerId string, email string, amount i
params := &stripe.CheckoutSessionParams{
ClientReferenceID: stripe.String(referenceId),
SuccessURL: stripe.String(system_setting.ServerAddress + "/console/log"),
CancelURL: stripe.String(system_setting.ServerAddress + "/console/topup"),
CancelURL: stripe.String(system_setting.ServerAddress + "/topup"),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(setting.StripePriceId),

View File

@@ -13,7 +13,6 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/constant"
@@ -580,7 +579,7 @@ func GetUserModels(c *gin.Context) {
common.ApiError(c, err)
return
}
groups := service.GetUserUsableGroups(user.Group)
groups := setting.GetUserUsableGroups(user.Group)
var models []string
for group := range groups {
for _, g := range model.GetGroupEnabledModels(group) {

View File

@@ -4,10 +4,8 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
@@ -38,7 +36,7 @@ func VideoProxy(c *gin.Context) {
return
}
if !exists || task == nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %v", taskID, err))
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %s", taskID, err.Error()))
c.JSON(http.StatusNotFound, gin.H{
"error": gin.H{
"message": "Task not found",
@@ -60,7 +58,7 @@ func VideoProxy(c *gin.Context) {
channel, err := model.CacheGetChannel(task.ChannelId)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: not found", taskID))
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get channel %d: %s", task.ChannelId, err.Error()))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to retrieve channel information",
@@ -73,15 +71,15 @@ func VideoProxy(c *gin.Context) {
if baseURL == "" {
baseURL = "https://api.openai.com"
}
videoURL := fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID)
var videoURL string
client := &http.Client{
Timeout: 60 * time.Second,
}
req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, "", nil)
req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, videoURL, nil)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error()))
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request for %s: %s", videoURL, err.Error()))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to create proxy request",
@@ -91,52 +89,7 @@ func VideoProxy(c *gin.Context) {
return
}
switch channel.Type {
case constant.ChannelTypeGemini:
apiKey := task.PrivateData.Key
if apiKey == "" {
logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "API key not stored for task",
"type": "server_error",
},
})
return
}
videoURL, err = getGeminiVideoURL(channel, task, apiKey)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Gemini video URL for task %s: %s", taskID, err.Error()))
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"message": "Failed to resolve Gemini video URL",
"type": "server_error",
},
})
return
}
req.Header.Set("x-goog-api-key", apiKey)
case constant.ChannelTypeAli:
// Video URL is directly in task.FailReason
videoURL = task.FailReason
default:
// Default (Sora, etc.): Use original logic
videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID)
req.Header.Set("Authorization", "Bearer "+channel.Key)
}
req.URL, err = url.Parse(videoURL)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to create proxy request",
"type": "server_error",
},
})
return
}
req.Header.Set("Authorization", "Bearer "+channel.Key)
resp, err := client.Do(req)
if err != nil {

View File

@@ -1,158 +0,0 @@
package controller
import (
"encoding/json"
"fmt"
"io"
"strconv"
"strings"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay"
)
func getGeminiVideoURL(channel *model.Channel, task *model.Task, apiKey string) (string, error) {
if channel == nil || task == nil {
return "", fmt.Errorf("invalid channel or task")
}
if url := extractGeminiVideoURLFromTaskData(task); url != "" {
return ensureAPIKey(url, apiKey), nil
}
baseURL := constant.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
adaptor := relay.GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channel.Type)))
if adaptor == nil {
return "", fmt.Errorf("gemini task adaptor not found")
}
if apiKey == "" {
return "", fmt.Errorf("api key not available for task")
}
resp, err := adaptor.FetchTask(baseURL, apiKey, map[string]any{
"task_id": task.TaskID,
"action": task.Action,
})
if err != nil {
return "", fmt.Errorf("fetch task failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("read task response failed: %w", err)
}
taskInfo, parseErr := adaptor.ParseTaskResult(body)
if parseErr == nil && taskInfo != nil && taskInfo.RemoteUrl != "" {
return ensureAPIKey(taskInfo.RemoteUrl, apiKey), nil
}
if url := extractGeminiVideoURLFromPayload(body); url != "" {
return ensureAPIKey(url, apiKey), nil
}
if parseErr != nil {
return "", fmt.Errorf("parse task result failed: %w", parseErr)
}
return "", fmt.Errorf("gemini video url not found")
}
func extractGeminiVideoURLFromTaskData(task *model.Task) string {
if task == nil || len(task.Data) == 0 {
return ""
}
var payload map[string]any
if err := json.Unmarshal(task.Data, &payload); err != nil {
return ""
}
return extractGeminiVideoURLFromMap(payload)
}
func extractGeminiVideoURLFromPayload(body []byte) string {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return ""
}
return extractGeminiVideoURLFromMap(payload)
}
func extractGeminiVideoURLFromMap(payload map[string]any) string {
if payload == nil {
return ""
}
if uri, ok := payload["uri"].(string); ok && uri != "" {
return uri
}
if resp, ok := payload["response"].(map[string]any); ok {
if uri := extractGeminiVideoURLFromResponse(resp); uri != "" {
return uri
}
}
return ""
}
func extractGeminiVideoURLFromResponse(resp map[string]any) string {
if resp == nil {
return ""
}
if gvr, ok := resp["generateVideoResponse"].(map[string]any); ok {
if uri := extractGeminiVideoURLFromGeneratedSamples(gvr); uri != "" {
return uri
}
}
if videos, ok := resp["videos"].([]any); ok {
for _, video := range videos {
if vm, ok := video.(map[string]any); ok {
if uri, ok := vm["uri"].(string); ok && uri != "" {
return uri
}
}
}
}
if uri, ok := resp["video"].(string); ok && uri != "" {
return uri
}
if uri, ok := resp["uri"].(string); ok && uri != "" {
return uri
}
return ""
}
func extractGeminiVideoURLFromGeneratedSamples(gvr map[string]any) string {
if gvr == nil {
return ""
}
if samples, ok := gvr["generatedSamples"].([]any); ok {
for _, sample := range samples {
if sm, ok := sample.(map[string]any); ok {
if video, ok := sm["video"].(map[string]any); ok {
if uri, ok := video["uri"].(string); ok && uri != "" {
return uri
}
}
}
}
}
return ""
}
func ensureAPIKey(uri, key string) string {
if key == "" || uri == "" {
return uri
}
if strings.Contains(uri, "key=") {
return uri
}
if strings.Contains(uri, "?") {
return fmt.Sprintf("%s&key=%s", uri, key)
}
return fmt.Sprintf("%s?key=%s", uri, key)
}

View File

@@ -0,0 +1,66 @@
package interfaces
import (
"io"
"net/http"
"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
// ChannelPlugin 定义Channel插件接口
// 继承原有的Adaptor接口增加插件元数据
type ChannelPlugin interface {
// 插件元数据
Name() string
Version() string
Priority() int
// 原有Adaptor接口方法
Init(info *relaycommon.RelayInfo)
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError)
GetModelList() []string
GetChannelName() string
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error)
}
// TaskChannelPlugin 定义Task类型的Channel插件接口
type TaskChannelPlugin interface {
// 插件元数据
Name() string
Version() string
Priority() int
// 原有TaskAdaptor接口方法
Init(info *relaycommon.RelayInfo)
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError
BuildRequestURL(info *relaycommon.RelayInfo) (string, error)
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, err *dto.TaskError)
GetModelList() []string
GetChannelName() string
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
}
// ChannelConfig 插件配置
type ChannelConfig struct {
Enabled bool `yaml:"enabled"`
Priority int `yaml:"priority"`
Config map[string]interface{} `yaml:"config"`
}

93
core/interfaces/hook.go Normal file
View File

@@ -0,0 +1,93 @@
package interfaces
import (
"io"
"net/http"
"github.com/gin-gonic/gin"
)
// HookContext Relay Hook执行上下文
type HookContext struct {
// Gin Context
GinContext *gin.Context
// Request相关
Request *http.Request
RequestBody []byte
// Response相关
Response *http.Response
ResponseBody []byte
// Channel信息
ChannelID int
ChannelType int
ChannelName string
// Model信息
Model string
OriginalModel string
// User信息
UserID int
TokenID int
Group string
// 扩展数据(插件间共享)
Data map[string]interface{}
// 错误信息
Error error
// 是否跳过后续处理
ShouldSkip bool
}
// RelayHook Relay Hook接口
type RelayHook interface {
// 插件元数据
Name() string
Priority() int
Enabled() bool
// 生命周期钩子
// OnBeforeRequest 在请求发送到上游之前执行
OnBeforeRequest(ctx *HookContext) error
// OnAfterResponse 在收到上游响应之后执行
OnAfterResponse(ctx *HookContext) error
// OnError 在发生错误时执行
OnError(ctx *HookContext, err error) error
}
// RequestModifier 请求修改器接口
// 实现此接口的Hook可以修改请求内容
type RequestModifier interface {
RelayHook
ModifyRequest(ctx *HookContext, body io.Reader) (io.Reader, error)
}
// ResponseProcessor 响应处理器接口
// 实现此接口的Hook可以处理响应内容
type ResponseProcessor interface {
RelayHook
ProcessResponse(ctx *HookContext, body []byte) ([]byte, error)
}
// StreamProcessor 流式响应处理器接口
// 实现此接口的Hook可以处理流式响应
type StreamProcessor interface {
RelayHook
ProcessStreamChunk(ctx *HookContext, chunk []byte) ([]byte, error)
}
// HookConfig Hook配置
type HookConfig struct {
Name string `yaml:"name"`
Enabled bool `yaml:"enabled"`
Priority int `yaml:"priority"`
Config map[string]interface{} `yaml:"config"`
}

View File

@@ -0,0 +1,31 @@
package interfaces
import (
"github.com/gin-gonic/gin"
)
// MiddlewarePlugin 中间件插件接口
type MiddlewarePlugin interface {
// 插件元数据
Name() string
Priority() int
Enabled() bool
// 返回Gin中间件处理函数
Handler() gin.HandlerFunc
// 初始化(可选)
Initialize(config MiddlewareConfig) error
}
// MiddlewareConfig 中间件配置
type MiddlewareConfig struct {
Name string `yaml:"name"`
Enabled bool `yaml:"enabled"`
Priority int `yaml:"priority"`
Config map[string]interface{} `yaml:"config"`
}
// MiddlewareFactory 中间件工厂函数类型
type MiddlewareFactory func(config MiddlewareConfig) (MiddlewarePlugin, error)

View File

@@ -0,0 +1,171 @@
package registry
import (
"fmt"
"sync"
"github.com/QuantumNous/new-api/core/interfaces"
)
var (
// 全局Channel注册表
channelRegistry = &ChannelRegistry{plugins: make(map[int]interfaces.ChannelPlugin)}
channelRegistryLock sync.RWMutex
// 全局TaskChannel注册表
taskChannelRegistry = &TaskChannelRegistry{plugins: make(map[string]interfaces.TaskChannelPlugin)}
taskChannelRegistryLock sync.RWMutex
)
// ChannelRegistry Channel插件注册中心
type ChannelRegistry struct {
plugins map[int]interfaces.ChannelPlugin // channelType -> plugin
mu sync.RWMutex
}
// Register 注册Channel插件
func (r *ChannelRegistry) Register(channelType int, plugin interfaces.ChannelPlugin) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.plugins[channelType]; exists {
return fmt.Errorf("channel plugin for type %d already registered", channelType)
}
r.plugins[channelType] = plugin
return nil
}
// Get 获取Channel插件
func (r *ChannelRegistry) Get(channelType int) (interfaces.ChannelPlugin, error) {
r.mu.RLock()
defer r.mu.RUnlock()
plugin, exists := r.plugins[channelType]
if !exists {
return nil, fmt.Errorf("channel plugin for type %d not found", channelType)
}
return plugin, nil
}
// List 列出所有已注册的Channel插件
func (r *ChannelRegistry) List() []interfaces.ChannelPlugin {
r.mu.RLock()
defer r.mu.RUnlock()
plugins := make([]interfaces.ChannelPlugin, 0, len(r.plugins))
for _, plugin := range r.plugins {
plugins = append(plugins, plugin)
}
return plugins
}
// Has 检查是否存在指定的Channel插件
func (r *ChannelRegistry) Has(channelType int) bool {
r.mu.RLock()
defer r.mu.RUnlock()
_, exists := r.plugins[channelType]
return exists
}
// TaskChannelRegistry TaskChannel插件注册中心
type TaskChannelRegistry struct {
plugins map[string]interfaces.TaskChannelPlugin // platform -> plugin
mu sync.RWMutex
}
// Register 注册TaskChannel插件
func (r *TaskChannelRegistry) Register(platform string, plugin interfaces.TaskChannelPlugin) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.plugins[platform]; exists {
return fmt.Errorf("task channel plugin for platform %s already registered", platform)
}
r.plugins[platform] = plugin
return nil
}
// Get 获取TaskChannel插件
func (r *TaskChannelRegistry) Get(platform string) (interfaces.TaskChannelPlugin, error) {
r.mu.RLock()
defer r.mu.RUnlock()
plugin, exists := r.plugins[platform]
if !exists {
return nil, fmt.Errorf("task channel plugin for platform %s not found", platform)
}
return plugin, nil
}
// List 列出所有已注册的TaskChannel插件
func (r *TaskChannelRegistry) List() []interfaces.TaskChannelPlugin {
r.mu.RLock()
defer r.mu.RUnlock()
plugins := make([]interfaces.TaskChannelPlugin, 0, len(r.plugins))
for _, plugin := range r.plugins {
plugins = append(plugins, plugin)
}
return plugins
}
// 全局函数 - Channel Registry
// RegisterChannel 注册Channel插件
func RegisterChannel(channelType int, plugin interfaces.ChannelPlugin) error {
channelRegistryLock.Lock()
defer channelRegistryLock.Unlock()
return channelRegistry.Register(channelType, plugin)
}
// GetChannel 获取Channel插件
func GetChannel(channelType int) (interfaces.ChannelPlugin, error) {
channelRegistryLock.RLock()
defer channelRegistryLock.RUnlock()
return channelRegistry.Get(channelType)
}
// ListChannels 列出所有Channel插件
func ListChannels() []interfaces.ChannelPlugin {
channelRegistryLock.RLock()
defer channelRegistryLock.RUnlock()
return channelRegistry.List()
}
// HasChannel 检查是否存在指定的Channel插件
func HasChannel(channelType int) bool {
channelRegistryLock.RLock()
defer channelRegistryLock.RUnlock()
return channelRegistry.Has(channelType)
}
// 全局函数 - TaskChannel Registry
// RegisterTaskChannel 注册TaskChannel插件
func RegisterTaskChannel(platform string, plugin interfaces.TaskChannelPlugin) error {
taskChannelRegistryLock.Lock()
defer taskChannelRegistryLock.Unlock()
return taskChannelRegistry.Register(platform, plugin)
}
// GetTaskChannel 获取TaskChannel插件
func GetTaskChannel(platform string) (interfaces.TaskChannelPlugin, error) {
taskChannelRegistryLock.RLock()
defer taskChannelRegistryLock.RUnlock()
return taskChannelRegistry.Get(platform)
}
// ListTaskChannels 列出所有TaskChannel插件
func ListTaskChannels() []interfaces.TaskChannelPlugin {
taskChannelRegistryLock.RLock()
defer taskChannelRegistryLock.RUnlock()
return taskChannelRegistry.List()
}

View File

@@ -0,0 +1,183 @@
package registry
import (
"fmt"
"sort"
"sync"
"github.com/QuantumNous/new-api/core/interfaces"
)
var (
// 全局Hook注册表
hookRegistry = &HookRegistry{hooks: make([]interfaces.RelayHook, 0)}
hookRegistryLock sync.RWMutex
)
// HookRegistry Hook插件注册中心
type HookRegistry struct {
hooks []interfaces.RelayHook
sorted bool
mu sync.RWMutex
}
// Register 注册Hook插件
func (r *HookRegistry) Register(hook interfaces.RelayHook) error {
r.mu.Lock()
defer r.mu.Unlock()
// 检查是否已存在同名Hook
for _, h := range r.hooks {
if h.Name() == hook.Name() {
return fmt.Errorf("hook %s already registered", hook.Name())
}
}
r.hooks = append(r.hooks, hook)
r.sorted = false // 标记需要重新排序
return nil
}
// Get 获取指定名称的Hook插件
func (r *HookRegistry) Get(name string) (interfaces.RelayHook, error) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, hook := range r.hooks {
if hook.Name() == name {
return hook, nil
}
}
return nil, fmt.Errorf("hook %s not found", name)
}
// List 列出所有已注册且启用的Hook插件按优先级排序
func (r *HookRegistry) List() []interfaces.RelayHook {
r.mu.Lock()
defer r.mu.Unlock()
// 如果未排序,先排序
if !r.sorted {
r.sortHooks()
}
// 只返回启用的hooks
enabledHooks := make([]interfaces.RelayHook, 0)
for _, hook := range r.hooks {
if hook.Enabled() {
enabledHooks = append(enabledHooks, hook)
}
}
return enabledHooks
}
// ListAll 列出所有已注册的Hook插件包括未启用的
func (r *HookRegistry) ListAll() []interfaces.RelayHook {
r.mu.RLock()
defer r.mu.RUnlock()
hooks := make([]interfaces.RelayHook, len(r.hooks))
copy(hooks, r.hooks)
return hooks
}
// sortHooks 按优先级排序hooks优先级数字越大越先执行
func (r *HookRegistry) sortHooks() {
sort.SliceStable(r.hooks, func(i, j int) bool {
return r.hooks[i].Priority() > r.hooks[j].Priority()
})
r.sorted = true
}
// Has 检查是否存在指定的Hook插件
func (r *HookRegistry) Has(name string) bool {
r.mu.RLock()
defer r.mu.RUnlock()
for _, hook := range r.hooks {
if hook.Name() == name {
return true
}
}
return false
}
// Count 返回已注册的Hook数量
func (r *HookRegistry) Count() int {
r.mu.RLock()
defer r.mu.RUnlock()
return len(r.hooks)
}
// EnabledCount 返回已启用的Hook数量
func (r *HookRegistry) EnabledCount() int {
r.mu.RLock()
defer r.mu.RUnlock()
count := 0
for _, hook := range r.hooks {
if hook.Enabled() {
count++
}
}
return count
}
// 全局函数
// RegisterHook 注册Hook插件
func RegisterHook(hook interfaces.RelayHook) error {
hookRegistryLock.Lock()
defer hookRegistryLock.Unlock()
return hookRegistry.Register(hook)
}
// GetHook 获取Hook插件
func GetHook(name string) (interfaces.RelayHook, error) {
hookRegistryLock.RLock()
defer hookRegistryLock.RUnlock()
return hookRegistry.Get(name)
}
// ListHooks 列出所有已启用的Hook插件
func ListHooks() []interfaces.RelayHook {
hookRegistryLock.RLock()
defer hookRegistryLock.RUnlock()
return hookRegistry.List()
}
// ListAllHooks 列出所有Hook插件
func ListAllHooks() []interfaces.RelayHook {
hookRegistryLock.RLock()
defer hookRegistryLock.RUnlock()
return hookRegistry.ListAll()
}
// HasHook 检查是否存在指定的Hook插件
func HasHook(name string) bool {
hookRegistryLock.RLock()
defer hookRegistryLock.RUnlock()
return hookRegistry.Has(name)
}
// HookCount 返回已注册的Hook数量
func HookCount() int {
hookRegistryLock.RLock()
defer hookRegistryLock.RUnlock()
return hookRegistry.Count()
}
// EnabledHookCount 返回已启用的Hook数量
func EnabledHookCount() int {
hookRegistryLock.RLock()
defer hookRegistryLock.RUnlock()
return hookRegistry.EnabledCount()
}

View File

@@ -0,0 +1,133 @@
package registry
import (
"fmt"
"sort"
"sync"
"github.com/QuantumNous/new-api/core/interfaces"
)
var (
// 全局Middleware注册表
middlewareRegistry = &MiddlewareRegistry{plugins: make(map[string]interfaces.MiddlewarePlugin)}
middlewareRegistryLock sync.RWMutex
)
// MiddlewareRegistry 中间件插件注册中心
type MiddlewareRegistry struct {
plugins map[string]interfaces.MiddlewarePlugin
mu sync.RWMutex
}
// Register 注册Middleware插件
func (r *MiddlewareRegistry) Register(plugin interfaces.MiddlewarePlugin) error {
r.mu.Lock()
defer r.mu.Unlock()
name := plugin.Name()
if _, exists := r.plugins[name]; exists {
return fmt.Errorf("middleware plugin %s already registered", name)
}
r.plugins[name] = plugin
return nil
}
// Get 获取Middleware插件
func (r *MiddlewareRegistry) Get(name string) (interfaces.MiddlewarePlugin, error) {
r.mu.RLock()
defer r.mu.RUnlock()
plugin, exists := r.plugins[name]
if !exists {
return nil, fmt.Errorf("middleware plugin %s not found", name)
}
return plugin, nil
}
// List 列出所有已注册的Middleware插件按优先级排序
func (r *MiddlewareRegistry) List() []interfaces.MiddlewarePlugin {
r.mu.RLock()
defer r.mu.RUnlock()
plugins := make([]interfaces.MiddlewarePlugin, 0, len(r.plugins))
for _, plugin := range r.plugins {
plugins = append(plugins, plugin)
}
// 按优先级排序(优先级数字越大越先执行)
sort.SliceStable(plugins, func(i, j int) bool {
return plugins[i].Priority() > plugins[j].Priority()
})
return plugins
}
// ListEnabled 列出所有已启用的Middleware插件按优先级排序
func (r *MiddlewareRegistry) ListEnabled() []interfaces.MiddlewarePlugin {
r.mu.RLock()
defer r.mu.RUnlock()
plugins := make([]interfaces.MiddlewarePlugin, 0, len(r.plugins))
for _, plugin := range r.plugins {
if plugin.Enabled() {
plugins = append(plugins, plugin)
}
}
// 按优先级排序
sort.SliceStable(plugins, func(i, j int) bool {
return plugins[i].Priority() > plugins[j].Priority()
})
return plugins
}
// Has 检查是否存在指定的Middleware插件
func (r *MiddlewareRegistry) Has(name string) bool {
r.mu.RLock()
defer r.mu.RUnlock()
_, exists := r.plugins[name]
return exists
}
// 全局函数
// RegisterMiddleware 注册Middleware插件
func RegisterMiddleware(plugin interfaces.MiddlewarePlugin) error {
middlewareRegistryLock.Lock()
defer middlewareRegistryLock.Unlock()
return middlewareRegistry.Register(plugin)
}
// GetMiddleware 获取Middleware插件
func GetMiddleware(name string) (interfaces.MiddlewarePlugin, error) {
middlewareRegistryLock.RLock()
defer middlewareRegistryLock.RUnlock()
return middlewareRegistry.Get(name)
}
// ListMiddlewares 列出所有Middleware插件
func ListMiddlewares() []interfaces.MiddlewarePlugin {
middlewareRegistryLock.RLock()
defer middlewareRegistryLock.RUnlock()
return middlewareRegistry.List()
}
// ListEnabledMiddlewares 列出所有已启用的Middleware插件
func ListEnabledMiddlewares() []interfaces.MiddlewarePlugin {
middlewareRegistryLock.RLock()
defer middlewareRegistryLock.RUnlock()
return middlewareRegistry.ListEnabled()
}
// HasMiddleware 检查是否存在指定的Middleware插件
func HasMiddleware(name string) bool {
middlewareRegistryLock.RLock()
defer middlewareRegistryLock.RUnlock()
return middlewareRegistry.Has(name)
}

View File

@@ -0,0 +1,116 @@
package registry
import (
"testing"
"github.com/QuantumNous/new-api/core/interfaces"
)
// Mock Hook实现
type mockHook struct {
name string
priority int
enabled bool
}
func (m *mockHook) Name() string { return m.name }
func (m *mockHook) Priority() int { return m.priority }
func (m *mockHook) Enabled() bool { return m.enabled }
func (m *mockHook) OnBeforeRequest(ctx *interfaces.HookContext) error { return nil }
func (m *mockHook) OnAfterResponse(ctx *interfaces.HookContext) error { return nil }
func (m *mockHook) OnError(ctx *interfaces.HookContext, err error) error { return nil }
func TestHookRegistry(t *testing.T) {
// 创建新的注册表(用于测试)
registry := &HookRegistry{hooks: make([]interfaces.RelayHook, 0)}
// 测试注册Hook
hook1 := &mockHook{name: "test_hook_1", priority: 100, enabled: true}
hook2 := &mockHook{name: "test_hook_2", priority: 50, enabled: true}
hook3 := &mockHook{name: "test_hook_3", priority: 75, enabled: false}
if err := registry.Register(hook1); err != nil {
t.Errorf("Failed to register hook1: %v", err)
}
if err := registry.Register(hook2); err != nil {
t.Errorf("Failed to register hook2: %v", err)
}
if err := registry.Register(hook3); err != nil {
t.Errorf("Failed to register hook3: %v", err)
}
// 测试重复注册
if err := registry.Register(hook1); err == nil {
t.Error("Expected error when registering duplicate hook")
}
// 测试获取Hook
if hook, err := registry.Get("test_hook_1"); err != nil {
t.Errorf("Failed to get hook: %v", err)
} else if hook.Name() != "test_hook_1" {
t.Errorf("Got wrong hook: %s", hook.Name())
}
// 测试不存在的Hook
if _, err := registry.Get("nonexistent"); err == nil {
t.Error("Expected error when getting nonexistent hook")
}
// 测试List应该只返回enabled的hooks
hooks := registry.List()
if len(hooks) != 2 {
t.Errorf("Expected 2 enabled hooks, got %d", len(hooks))
}
// 测试优先级排序100应该在50之前
if hooks[0].Priority() != 100 {
t.Errorf("Expected first hook to have priority 100, got %d", hooks[0].Priority())
}
// 测试Count
if count := registry.Count(); count != 3 {
t.Errorf("Expected count 3, got %d", count)
}
// 测试EnabledCount
if count := registry.EnabledCount(); count != 2 {
t.Errorf("Expected enabled count 2, got %d", count)
}
// 测试Has
if !registry.Has("test_hook_1") {
t.Error("Expected to find test_hook_1")
}
if registry.Has("nonexistent") {
t.Error("Should not find nonexistent hook")
}
}
func TestChannelRegistry(t *testing.T) {
// 这里可以添加Channel Registry的测试
// 但需要mock ChannelPlugin接口比较复杂
// 作为示例,我们只测试基本功能
registry := &ChannelRegistry{plugins: make(map[int]interfaces.ChannelPlugin)}
// 测试Has方法
if registry.Has(1) {
t.Error("Should not find channel type 1")
}
}
func TestMiddlewareRegistry(t *testing.T) {
// Middleware Registry测试
// 需要mock MiddlewarePlugin接口
registry := &MiddlewareRegistry{plugins: make(map[string]interfaces.MiddlewarePlugin)}
// 测试Has方法
if registry.Has("test_middleware") {
t.Error("Should not find test_middleware")
}
}

View File

@@ -30,14 +30,11 @@ services:
# - SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service, uncomment if using MySQL
- REDIS_CONN_STRING=redis://redis
- TZ=Asia/Shanghai
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录 (Whether to enable error log recording)
- BATCH_UPDATE_ENABLED=true # 是否启用批量更新 (Whether to enable batch update)
# - STREAMING_TIMEOUT=300 # 流模式无响应超时时间单位秒默认120秒如果出现空补全可以尝试改为更大值 Streaming timeout in seconds, default is 120s. Increase if experiencing empty completions
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!! multi-node deployment, set this to a random string!!!!!!!
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
- BATCH_UPDATE_ENABLED=true # 是否启用批量更新 batch update enabled
# - STREAMING_TIMEOUT=300 # 流模式无响应超时时间单位秒默认120秒如果出现空补全可以尝试改为更大值 Streaming timeout in seconds, default is 120s. Increase if experiencing empty completions
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!! multi-node deployment, set this to a random string!!!!!!!
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
# - GOOGLE_ANALYTICS_ID=G-XXXXXXXXXX # Google Analytics 的测量 ID (Google Analytics Measurement ID)
# - UMAMI_WEBSITE_ID=xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx # Umami 网站 ID (Umami Website ID)
# - UMAMI_SCRIPT_URL=https://analytics.umami.is/script.js # Umami 脚本 URL默认为官方地址 (Umami Script URL, defaults to official URL)
depends_on:
- redis

View File

@@ -0,0 +1,359 @@
# New-API 插件化架构说明
## 完整目录结构
```
new-api-2/
├── core/ # 核心层(高性能,不可插件化)
│ ├── interfaces/ # 插件接口定义
│ │ ├── channel.go # Channel插件接口
│ │ ├── hook.go # Hook插件接口
│ │ └── middleware.go # Middleware插件接口
│ └── registry/ # 插件注册中心
│ ├── channel_registry.go # Channel注册器线程安全
│ ├── hook_registry.go # Hook注册器优先级排序
│ └── middleware_registry.go # Middleware注册器
├── plugins/ # 🔵 Tier 1: 编译时插件(已实施)
│ ├── channels/ # Channel插件
│ │ ├── base_plugin.go # 基础插件包装器
│ │ └── registry.go # 自动注册31个AI Provider
│ └── hooks/ # Hook插件
│ ├── web_search/ # 联网搜索Hook
│ │ ├── web_search_hook.go
│ │ └── init.go
│ └── content_filter/ # 内容过滤Hook
│ ├── content_filter_hook.go
│ └── init.go
├── marketplace/ # 🟣 Tier 2: 运行时插件待实施Phase 2
│ ├── loader/ # go-plugin加载器
│ │ ├── plugin_client.go # 插件客户端
│ │ ├── plugin_server.go # 插件服务器
│ │ └── lifecycle.go # 生命周期管理
│ ├── manager/ # 插件管理器
│ │ ├── installer.go # 安装/卸载
│ │ ├── updater.go # 版本更新
│ │ └── registry.go # 插件注册表
│ ├── security/ # 安全模块
│ │ ├── signature.go # Ed25519签名验证
│ │ ├── checksum.go # SHA256校验
│ │ └── sandbox.go # 沙箱配置
│ ├── store/ # 插件商店客户端
│ │ ├── client.go # 商店API客户端
│ │ ├── search.go # 搜索功能
│ │ └── download.go # 下载管理
│ └── proto/ # gRPC协议定义
│ ├── hook.proto # Hook插件协议
│ ├── channel.proto # Channel插件协议
│ └── common.proto # 通用消息
├── plugins_external/ # 第三方插件安装目录
│ ├── installed/ # 已安装插件
│ │ ├── awesome-hook-v1.0.0/
│ │ ├── custom-llm-v2.1.0/
│ │ └── slack-notify-v1.5.0/
│ ├── cache/ # 下载缓存
│ └── temp/ # 临时文件
├── relay/ # Relay层
│ ├── hooks/ # Hook执行链
│ │ ├── chain.go # Hook链管理器
│ │ ├── context.go # Hook上下文
│ │ └── context_builder.go # 上下文构建器
│ └── relay_adaptor.go # Channel适配器优先从Registry获取
├── config/ # 配置系统
│ ├── plugins.yaml # 插件配置Tier 1 + Tier 2
│ └── plugin_config.go # 配置加载器(支持环境变量)
└── (其他现有目录保持不变)
```
---
## 完整架构图
### 系统架构总览
```mermaid
graph TB
subgraph "🌐 API层"
Client[客户端请求]
end
subgraph "🔐 中间件层"
Auth[认证中间件]
RateLimit[限流中间件]
Cache[缓存中间件]
end
subgraph "🎯 核心层 Core"
Registry[插件注册中心]
ChannelReg[Channel Registry]
HookReg[Hook Registry]
MidReg[Middleware Registry]
Registry --> ChannelReg
Registry --> HookReg
Registry --> MidReg
end
subgraph "🔵 Tier 1: 编译时插件(已实施)"
direction TB
Channels[31个 Channel Plugins]
OpenAI[OpenAI]
Claude[Claude]
Gemini[Gemini]
Others[其他28个...]
Channels --> OpenAI
Channels --> Claude
Channels --> Gemini
Channels --> Others
Hooks[Hook Plugins]
WebSearch[Web Search Hook]
ContentFilter[Content Filter Hook]
Hooks --> WebSearch
Hooks --> ContentFilter
end
subgraph "🟣 Tier 2: 运行时插件(待实施)"
direction TB
Marketplace[🏪 Plugin Marketplace]
ExtHook[External Hooks<br/>Python/Go/Node.js]
ExtChannel[External Channels<br/>小众AI提供商]
ExtMid[External Middleware<br/>企业集成]
ExtUI[UI Extensions<br/>自定义仪表板]
Marketplace --> ExtHook
Marketplace --> ExtChannel
Marketplace --> ExtMid
Marketplace --> ExtUI
end
subgraph "⚡ Relay执行流程"
direction LR
HookChain[Hook Chain]
BeforeHook[OnBeforeRequest]
ChannelAdaptor[Channel Adaptor]
AfterHook[OnAfterResponse]
HookChain --> BeforeHook
BeforeHook --> ChannelAdaptor
ChannelAdaptor --> AfterHook
end
subgraph "🌍 上游服务"
Upstream[AI Provider APIs]
end
Client --> Auth
Auth --> RateLimit
RateLimit --> Cache
Cache --> Registry
Channels --> ChannelReg
Hooks --> HookReg
Registry --> HookChain
HookChain --> Upstream
Upstream --> HookChain
Registry -.gRPC/RPC.-> ExtHook
Registry -.gRPC/RPC.-> ExtChannel
Registry -.gRPC/RPC.-> ExtMid
style Marketplace fill:#f9f,stroke:#333,stroke-width:4px
style Registry fill:#bbf,stroke:#333,stroke-width:4px
style Channels fill:#bfb,stroke:#333,stroke-width:2px
style Hooks fill:#bfb,stroke:#333,stroke-width:2px
```
### 双层插件系统架构
```mermaid
graph LR
subgraph "🔵 Tier 1: 编译时插件"
T1[性能: 100%<br/>语言: Go only<br/>部署: 编译到二进制]
T1Chan[31 Channels]
T1Hook[2 Hooks]
T1 --> T1Chan
T1 --> T1Hook
end
subgraph "🟣 Tier 2: 运行时插件"
T2[性能: 90-95%<br/>语言: Go/Python/Node.js<br/>部署: 独立进程]
T2Hook[External Hooks]
T2Chan[External Channels]
T2Mid[External Middleware]
T2UI[UI Extensions]
T2 --> T2Hook
T2 --> T2Chan
T2 --> T2Mid
T2 --> T2UI
end
T1 -.进程内调用.-> Core[Core System]
T2 -.gRPC/RPC.-> Core
style T1 fill:#bfb,stroke:#333,stroke-width:3px
style T2 fill:#f9f,stroke:#333,stroke-width:3px
style Core fill:#bbf,stroke:#333,stroke-width:3px
```
---
## 核心要点说明
### 1. 双层插件架构
| 层级 | 技术方案 | 性能 | 适用场景 | 开发语言 |
|------|---------|------|---------|---------|
| **Tier 1<br/>编译时插件** | 编译时链接 | 100%<br/>零损失 | • 核心ChannelOpenAI等<br/>• 内置Hook<br/>• 高频调用路径 | Go only |
| **Tier 2<br/>运行时插件** | go-plugin<br/>gRPC | 90-95%<br/>5-10%开销 | • 第三方扩展<br/>• 企业定制<br/>• 多语言集成 | Go/Python/<br/>Node.js/Rust |
### 2. 核心组件
#### Core层核心引擎
- **interfaces/**: 定义ChannelPlugin、RelayHook、MiddlewarePlugin接口
- **registry/**: 线程安全的插件注册中心支持O(1)查找、优先级排序
#### Relay Hook链
- **执行流程**: OnBeforeRequest → Channel.DoRequest → OnAfterResponse
- **特性**: 优先级排序、短路机制、数据共享HookContext.Data
- **应用场景**: 联网搜索、内容过滤、日志增强、缓存策略
### 3. Tier 1: 编译时插件(已实施 ✅)
**特点**:
- 零性能损失,编译后与硬编码无差异
- init()函数自动注册到Registry
- YAML配置启用/禁用
**已实现**:
- ✅ 31个Channel插件OpenAI、Claude、Gemini等
- ✅ 2个Hook插件web_search、content_filter
- ✅ Hook执行链
- ✅ 配置系统(支持环境变量展开)
### 4. Tier 2: 运行时插件(待实施 🚧)
**基于**: [hashicorp/go-plugin](https://github.com/hashicorp/go-plugin)Vault/Terraform使用
**优势**:
- ✅ 进程隔离(第三方代码崩溃不影响主程序)
- ✅ 多语言支持gRPC协议
- ✅ 热插拔(无需重启)
- ✅ 安全验证Ed25519签名 + SHA256校验 + TLS加密
- ✅ 独立分发(插件商店)
**适用场景**:
- 第三方开发者扩展
- 企业定制业务逻辑
- Python ML模型集成
- 第三方服务集成Slack/钉钉/企业微信)
- UI扩展
### 5. 安全机制
**Tier 1编译时**:
- 内部代码审查
- 编译期类型安全
**Tier 2运行时**:
- Ed25519签名验证
- SHA256校验和
- gRPC TLS加密
- 进程资源限制(内存/CPU
- 插件商店审核机制
- 可信发布者白名单
### 6. 配置系统
**单一配置文件**: `config/plugins.yaml`
```yaml
# Tier 1: 编译时插件
plugins:
hooks:
- name: web_search
enabled: false
priority: 50
config:
api_key: ${WEB_SEARCH_API_KEY}
# Tier 2: 运行时插件(待实施)
external_plugins:
enabled: true
hooks:
- name: awesome_hook
binary: awesome-hook-v1.0.0/awesome-hook
checksum: sha256:abc123...
# 插件商店
marketplace:
enabled: true
api_url: https://plugins.new-api.com
```
### 7. 性能对比
| 场景 | Tier 1 | Tier 2 | RPC开销 |
|------|--------|--------|--------|
| 核心Channel | 100% | N/A | 0% |
| 内置Hook | 100% | N/A | 0% |
| 第三方Hook | N/A | 92-95% | 5-8% |
| Python插件 | N/A | 88-92% | 8-12% |
### 8. 实施路线图
#### Phase 1: 编译时插件系统 ✅ 已完成
- Core Registry + Hook Chain
- 31个Channel插件 + 2个Hook示例
- YAML配置系统
#### Phase 2: go-plugin基础
- protobuf协议定义
- PluginLoader实现
- 签名验证系统
- Python/Go SDK
#### Phase 3: 插件商店
- 商店后端API
- Web UI搜索、安装、管理
- CLI工具
- 多语言SDK
### 9. 扩展示例
**新增Tier 1插件编译时**:
```go
// 1. 实现接口
type MyHook struct{}
func (h *MyHook) OnBeforeRequest(ctx *HookContext) error { /*...*/ }
// 2. 注册
func init() { registry.RegisterHook(&MyHook{}) }
// 3. 导入到main.go
import _ "github.com/xxx/plugins/hooks/my_hook"
```
**新增Tier 2插件运行时**:
```python
# external-plugin/my_hook.py
from new_api_plugin_sdk import HookPlugin, serve
class MyHook(HookPlugin):
def on_before_request(self, ctx):
return {"modified_body": ctx.request_body}
serve(MyHook())
```

View File

@@ -1,22 +1,17 @@
package dto
import (
"encoding/json"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
type AudioRequest struct {
Model string `json:"model"`
Input string `json:"input"`
Voice string `json:"voice"`
Instructions string `json:"instructions,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Speed float64 `json:"speed,omitempty"`
StreamFormat string `json:"stream_format,omitempty"`
Metadata json.RawMessage `json:"metadata,omitempty"`
Model string `json:"model"`
Input string `json:"input"`
Voice string `json:"voice"`
Speed float64 `json:"speed,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}
func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta {

View File

@@ -16,13 +16,6 @@ const (
VertexKeyTypeAPIKey VertexKeyType = "api_key"
)
type AwsKeyType string
const (
AwsKeyTypeAKSK AwsKeyType = "ak_sk" // 默认
AwsKeyTypeApiKey AwsKeyType = "api_key"
)
type ChannelOtherSettings struct {
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
@@ -30,7 +23,6 @@ type ChannelOtherSettings struct {
AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费)
DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用)
AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私)
AwsKeyType AwsKeyType `json:"aws_key_type,omitempty"`
}
func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool {

View File

@@ -24,7 +24,7 @@ type ClaudeMediaMessage struct {
StopReason *string `json:"stop_reason,omitempty"`
PartialJson *string `json:"partial_json,omitempty"`
Role string `json:"role,omitempty"`
Thinking *string `json:"thinking,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
Delta string `json:"delta,omitempty"`
CacheControl json.RawMessage `json:"cache_control,omitempty"`
@@ -148,10 +148,6 @@ func (c *ClaudeMessage) SetStringContent(content string) {
c.Content = content
}
func (c *ClaudeMessage) SetContent(content any) {
c.Content = content
}
func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
return common.Any2Type[[]ClaudeMediaMessage](c.Content)
}
@@ -510,44 +506,11 @@ func (c *ClaudeResponse) GetClaudeError() *types.ClaudeError {
}
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreation *ClaudeCacheCreationUsage `json:"cache_creation,omitempty"`
// claude cache 1h
ClaudeCacheCreation5mTokens int `json:"claude_cache_creation_5_m_tokens"`
ClaudeCacheCreation1hTokens int `json:"claude_cache_creation_1_h_tokens"`
ServerToolUse *ClaudeServerToolUse `json:"server_tool_use,omitempty"`
}
type ClaudeCacheCreationUsage struct {
Ephemeral5mInputTokens int `json:"ephemeral_5m_input_tokens,omitempty"`
Ephemeral1hInputTokens int `json:"ephemeral_1h_input_tokens,omitempty"`
}
func (u *ClaudeUsage) GetCacheCreation5mTokens() int {
if u == nil || u.CacheCreation == nil {
return 0
}
return u.CacheCreation.Ephemeral5mInputTokens
}
func (u *ClaudeUsage) GetCacheCreation1hTokens() int {
if u == nil || u.CacheCreation == nil {
return 0
}
return u.CacheCreation.Ephemeral1hInputTokens
}
func (u *ClaudeUsage) GetCacheCreationTotalTokens() int {
if u == nil {
return 0
}
if u.CacheCreationInputTokens > 0 {
return u.CacheCreationInputTokens
}
return u.GetCacheCreation5mTokens() + u.GetCacheCreation1hTokens()
InputTokens int `json:"input_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
OutputTokens int `json:"output_tokens"`
ServerToolUse *ClaudeServerToolUse `json:"server_tool_use,omitempty"`
}
type ClaudeServerToolUse struct {

View File

@@ -12,7 +12,6 @@ import (
)
type GeminiChatRequest struct {
Requests []GeminiChatRequest `json:"requests,omitempty"` // For batch requests
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`

View File

@@ -27,8 +27,7 @@ type ImageRequest struct {
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
PartialImages json.RawMessage `json:"partial_images,omitempty"`
// Stream bool `json:"stream,omitempty"`
Watermark *bool `json:"watermark,omitempty"`
Image json.RawMessage `json:"image,omitempty"`
Watermark *bool `json:"watermark,omitempty"`
// 用匿名参数接收额外参数
Extra map[string]json.RawMessage `json:"-"`
}

View File

@@ -232,13 +232,10 @@ func (r *GeneralOpenAIRequest) GetSystemRoleName() string {
return "system"
}
const CustomType = "custom"
type ToolCallRequest struct {
ID string `json:"id,omitempty"`
Type string `json:"type"`
Function FunctionRequest `json:"function,omitempty"`
Custom json.RawMessage `json:"custom,omitempty"`
Function FunctionRequest `json:"function"`
}
type FunctionRequest struct {

View File

@@ -230,11 +230,6 @@ type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
// claude cache 1h
ClaudeCacheCreation5mTokens int `json:"claude_cache_creation_5_m_tokens"`
ClaudeCacheCreation1hTokens int `json:"claude_cache_creation_1_h_tokens"`
// OpenRouter Params
Cost any `json:"cost,omitempty"`
}

14
go.mod
View File

@@ -5,7 +5,6 @@ go 1.25.1
require (
github.com/Calcium-Ion/go-epay v0.0.4
github.com/abema/go-mp4 v1.4.1
github.com/andybalholm/brotli v1.1.1
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
github.com/aws/aws-sdk-go-v2 v1.37.2
@@ -19,34 +18,29 @@ require (
github.com/gin-contrib/static v0.0.1
github.com/gin-gonic/gin v1.9.1
github.com/glebarez/sqlite v1.9.0
github.com/go-audio/aiff v1.1.0
github.com/go-audio/wav v1.1.0
github.com/go-playground/validator/v10 v10.20.0
github.com/go-redis/redis/v8 v8.11.5
github.com/go-webauthn/webauthn v0.14.0
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.0
github.com/jfreymuth/oggvorbis v1.0.5
github.com/jinzhu/copier v0.4.0
github.com/joho/godotenv v1.5.1
github.com/mewkiz/flac v1.0.13
github.com/pkg/errors v0.9.1
github.com/pquerna/otp v1.5.0
github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/shopspring/decimal v1.4.0
github.com/stripe/stripe-go/v81 v81.4.0
github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300
github.com/thanhpk/randstr v1.0.6
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/tiktoken-go/tokenizer v0.6.2
github.com/yapingcat/gomedia v0.0.0-20240906162731-17feea57090c
golang.org/x/crypto v0.42.0
golang.org/x/image v0.23.0
golang.org/x/net v0.43.0
golang.org/x/sync v0.17.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/gorm v1.25.2
@@ -69,8 +63,6 @@ require (
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/glebarez/go-sqlite v1.21.2 // indirect
github.com/go-audio/audio v1.0.0 // indirect
github.com/go-audio/riff v1.0.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
@@ -82,20 +74,16 @@ require (
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
github.com/icza/bitio v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.7.1 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jfreymuth/vorbis v1.0.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mewkiz/pkg v0.0.0-20250417130911-3f050ff8c56d // indirect
github.com/mewpkg/term v0.0.0-20241026122259-37a80af23985 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect

40
go.sum
View File

@@ -1,7 +1,5 @@
github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
github.com/abema/go-mp4 v1.4.1 h1:YoS4VRqd+pAmddRPLFf8vMk74kuGl6ULSjzhsIqwr6M=
github.com/abema/go-mp4 v1.4.1/go.mod h1:vPl9t5ZK7K0x68jh12/+ECWBCXoWuIDtNgPtU2f04ws=
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
@@ -35,7 +33,6 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@@ -70,15 +67,6 @@ github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9g
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
github.com/glebarez/sqlite v1.9.0 h1:Aj6bPA12ZEx5GbSF6XADmCkYXlljPNUY+Zf1EQxynXs=
github.com/glebarez/sqlite v1.9.0/go.mod h1:YBYCoyupOao60lzp1MVBLEjZfgkq0tdB1voAQ09K9zw=
github.com/go-audio/aiff v1.1.0 h1:m2LYgu/2BarpF2yZnFPWtY3Tp41k0A4y51gDRZZsEuU=
github.com/go-audio/aiff v1.1.0/go.mod h1:sDik1muYvhPiccClfri0fv6U2fyH/dy4VRWmUz0cz9Q=
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
github.com/go-audio/wav v1.0.0/go.mod h1:3yoReyQOsiARkvPl3ERCi8JFjihzG6WhjYpZCf5zAWE=
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
@@ -120,7 +108,6 @@ github.com/google/go-tpm v0.9.5/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
@@ -131,10 +118,6 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/icza/bitio v1.1.0 h1:ysX4vtldjdi3Ygai5m1cWy4oLkhWTAi+SyO6HC8L9T0=
github.com/icza/bitio v1.1.0/go.mod h1:0jGnlLAx8MKMr9VGnn/4YrvZiprkvBelsVIbA9Jjr9A=
github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6 h1:8UsGZ2rr2ksmEru6lToqnXgA8Mz1DP11X4zSJ159C3k=
github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6/go.mod h1:xQig96I1VNBDIWGCdTt54nHt6EeI639SmHycLYL7FkA=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -143,10 +126,6 @@ github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs=
github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jfreymuth/oggvorbis v1.0.5 h1:u+Ck+R0eLSRhgq8WTmffYnrVtSztJcYrl588DM4e3kQ=
github.com/jfreymuth/oggvorbis v1.0.5/go.mod h1:1U4pqWmghcoVsCJJ4fRBKv9peUJMBHixthRlBeD6uII=
github.com/jfreymuth/vorbis v1.0.2 h1:m1xH6+ZI4thH927pgKD8JOH4eaGRm18rEE9/0WKjvNE=
github.com/jfreymuth/vorbis v1.0.2/go.mod h1:DoftRo4AznKnShRl1GxiTFCseHr4zR9BN3TWXyuzrqQ=
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
@@ -166,7 +145,6 @@ github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfn
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
@@ -174,17 +152,10 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattetti/audio v0.0.0-20180912171649-01576cde1f21/go.mod h1:LlQmBGkOuV/SKzEDXBPKauvN2UqCgzXO2XjecTGj40s=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mewkiz/flac v1.0.13 h1:6wF8rRQKBFW159Daqx6Ro7K5ZnlVhHUKfS5aTsC4oXs=
github.com/mewkiz/flac v1.0.13/go.mod h1:HfPYDA+oxjyuqMu2V+cyKcxF51KM6incpw5eZXmfA6k=
github.com/mewkiz/pkg v0.0.0-20250417130911-3f050ff8c56d h1:IL2tii4jXLdhCeQN69HNzYYW1kl0meSG0wt5+sLwszU=
github.com/mewkiz/pkg v0.0.0-20250417130911-3f050ff8c56d/go.mod h1:SIpumAnUWSy0q9RzKD3pyH3g1t5vdawUAPcW5tQrUtI=
github.com/mewpkg/term v0.0.0-20241026122259-37a80af23985 h1:h8O1byDZ1uk6RUXMhj1QJU3VXFKXHDZxr4TXRPGeBa8=
github.com/mewpkg/term v0.0.0-20241026122259-37a80af23985/go.mod h1:uiPmbdUbdt1NkGApKl7htQjZ8S7XaGUAVulJUJ9v6q4=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -199,8 +170,6 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e h1:s2RNOM/IGdY0Y6qfTeUKhDawdHDpK9RGBdx80qN4Ttw=
github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e/go.mod h1:nBdnFKj15wFbf94Rwfq4m30eAcyY9V/IyKAGQFtqkW0=
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pelletier/go-toml/v2 v2.2.1 h1:9TA9+T8+8CUCO2+WYnDLCgrYi9+omqKXyjDtosvtEhg=
github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
@@ -240,9 +209,6 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJUzCLbw=
github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
github.com/sunfish-shogi/bufseekio v0.0.0-20210207115823-a4185644b365/go.mod h1:dEzdXgvImkQ3WLI+0KQpmEx8T/C/ma9KeS3AfmU899I=
github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300 h1:XQdibLKagjdevRB6vAjVY4qbSr8rQ610YzTkWcxzxSI=
github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300/go.mod h1:FNa/dfN95vAYCNFrIKRrlRo+MBLbwmR9Asa5f2ljmBI=
github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@@ -272,8 +238,6 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yapingcat/gomedia v0.0.0-20240906162731-17feea57090c h1:xA2TJS9Hu/ivzaZIrDcwvpJ3Fnpsk5fDOJ4iSnL6J0w=
github.com/yapingcat/gomedia v0.0.0-20240906162731-17feea57090c/go.mod h1:WSZ59bidJOO40JSJmLqlkBJrjZCtjbKKkygEMfzY/kc=
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
@@ -293,7 +257,6 @@ golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -307,7 +270,6 @@ golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -324,8 +286,6 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/src-d/go-billy.v4 v4.3.2 h1:0SQA1pRztfTFx2miS8sA97XvooFeNOmvUenF4o0EcVg=
gopkg.in/src-d/go-billy.v4 v4.3.2/go.mod h1:nDjArDMp+XMs1aFAESLRjfGSgfvoYN0hDfzEk0GjC98=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View File

@@ -5,11 +5,9 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"log"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
@@ -21,561 +19,78 @@ import (
)
const (
// 日志轮转配置
defaultMaxLogSize = 100 * 1024 * 1024 // 100MB
defaultMaxLogFiles = 7 // 保留最近7个日志文件
defaultLogFileName = "newapi.log"
checkRotateInterval = 1000 // 每1000次写入检查一次是否需要轮转
loggerINFO = "INFO"
loggerWarn = "WARN"
loggerError = "ERR"
loggerDebug = "DEBUG"
)
var (
logMutex sync.RWMutex
rotateCheckLock sync.Mutex
defaultLogger *slog.Logger
logFile *os.File
logFilePath string
logDirPath string
writeCount int64
maxLogSize int64 = defaultMaxLogSize
maxLogFiles int = defaultMaxLogFiles
useJSONFormat bool
)
const maxLogCount = 1000000
func init() {
// Initialize with a text handler to stdout
handler := createHandler(os.Stdout)
defaultLogger = slog.New(handler)
slog.SetDefault(defaultLogger)
}
var logCount int
var setupLogLock sync.Mutex
var setupLogWorking bool
// SetupLogger 初始化日志系统
func SetupLogger() {
logMutex.Lock()
defer logMutex.Unlock()
// 读取环境变量配置
if maxSize := os.Getenv("LOG_MAX_SIZE_MB"); maxSize != "" {
if size, err := fmt.Sscanf(maxSize, "%d", &maxLogSize); err == nil && size > 0 {
maxLogSize = maxLogSize * 1024 * 1024 // 转换为字节
defer func() {
setupLogWorking = false
}()
if *common.LogDir != "" {
ok := setupLogLock.TryLock()
if !ok {
log.Println("setup log is already working")
return
}
}
if maxFiles := os.Getenv("LOG_MAX_FILES"); maxFiles != "" {
fmt.Sscanf(maxFiles, "%d", &maxLogFiles)
}
if os.Getenv("LOG_FORMAT") == "json" {
useJSONFormat = true
}
if *common.LogDir == "" {
// 如果没有配置日志目录,只输出到标准输出
handler := createHandler(os.Stdout)
defaultLogger = slog.New(handler)
slog.SetDefault(defaultLogger)
return
}
logDirPath = *common.LogDir
logFilePath = filepath.Join(logDirPath, defaultLogFileName)
// 检查日志文件是否需要按日期轮转(仅在启动时检查)
if err := checkAndRotateOnStartup(); err != nil {
slog.Error("failed to check log file on startup", "error", err)
}
// 打开或创建日志文件
if err := openLogFile(); err != nil {
slog.Error("failed to open log file", "error", err)
return
}
// 创建多路输出(控制台 + 文件)
multiWriter := io.MultiWriter(os.Stdout, logFile)
// 更新 gin 的默认输出
gin.DefaultWriter = multiWriter
gin.DefaultErrorWriter = multiWriter
// 更新 slog handler
handler := createHandler(multiWriter)
defaultLogger = slog.New(handler)
slog.SetDefault(defaultLogger)
slog.Info("logger initialized",
"log_dir", logDirPath,
"max_size_mb", maxLogSize/(1024*1024),
"max_files", maxLogFiles,
"format", getLogFormat())
}
// createHandler 创建日志处理器
func createHandler(w io.Writer) slog.Handler {
if useJSONFormat {
opts := &slog.HandlerOptions{
Level: getLogLevel(),
defer func() {
setupLogLock.Unlock()
}()
logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
}
return slog.NewJSONHandler(w, opts)
}
return NewReadableTextHandler(w, getLogLevel())
}
// ReadableTextHandler 自定义的易读文本处理器
type ReadableTextHandler struct {
w io.Writer
level slog.Level
mu sync.Mutex
}
// NewReadableTextHandler 创建一个新的易读文本处理器
func NewReadableTextHandler(w io.Writer, level slog.Level) *ReadableTextHandler {
return &ReadableTextHandler{
w: w,
level: level,
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
}
}
// Enabled 检查是否启用该级别
func (h *ReadableTextHandler) Enabled(_ context.Context, level slog.Level) bool {
return level >= h.level
}
// Handle 处理日志记录
func (h *ReadableTextHandler) Handle(_ context.Context, r slog.Record) error {
h.mu.Lock()
defer h.mu.Unlock()
// 格式: [LEVEL] YYYY/MM/DD - HH:mm:ss | request_id | message | key=value ...
buf := make([]byte, 0, 256)
// 日志级别
level := r.Level.String()
switch r.Level {
case slog.LevelDebug:
level = "DEBUG"
case slog.LevelInfo:
level = "INFO"
case slog.LevelWarn:
level = "WARN"
case slog.LevelError:
level = "ERROR"
}
buf = append(buf, '[')
buf = append(buf, level...)
buf = append(buf, "] "...)
// 时间
buf = append(buf, r.Time.Format("2006/01/02 - 15:04:05")...)
buf = append(buf, " | "...)
// 提取 request_id 和 component
var requestID, component string
otherAttrs := make([]slog.Attr, 0)
r.Attrs(func(a slog.Attr) bool {
switch a.Key {
case "request_id":
requestID = a.Value.String()
case "component":
component = a.Value.String()
default:
otherAttrs = append(otherAttrs, a)
}
return true
})
// 输出 request_id 或 component
if requestID != "" {
buf = append(buf, requestID...)
buf = append(buf, " | "...)
} else if component != "" {
buf = append(buf, component...)
buf = append(buf, " | "...)
}
// 消息
buf = append(buf, r.Message...)
// 其他属性
if len(otherAttrs) > 0 {
buf = append(buf, " | "...)
for i, a := range otherAttrs {
if i > 0 {
buf = append(buf, ", "...)
}
buf = append(buf, a.Key...)
buf = append(buf, '=')
buf = appendValue(buf, a.Value)
}
}
buf = append(buf, '\n')
_, err := h.w.Write(buf)
return err
}
// appendValue 追加值到缓冲区
func appendValue(buf []byte, v slog.Value) []byte {
switch v.Kind() {
case slog.KindString:
s := v.String()
// 如果字符串包含空格或特殊字符,加引号
if strings.ContainsAny(s, " \t\n\r,=") {
buf = append(buf, '"')
buf = append(buf, s...)
buf = append(buf, '"')
} else {
buf = append(buf, s...)
}
case slog.KindInt64:
buf = append(buf, fmt.Sprintf("%d", v.Int64())...)
case slog.KindUint64:
buf = append(buf, fmt.Sprintf("%d", v.Uint64())...)
case slog.KindFloat64:
buf = append(buf, fmt.Sprintf("%g", v.Float64())...)
case slog.KindBool:
buf = append(buf, fmt.Sprintf("%t", v.Bool())...)
case slog.KindDuration:
buf = append(buf, v.Duration().String()...)
case slog.KindTime:
buf = append(buf, v.Time().Format("2006-01-02 15:04:05")...)
default:
buf = append(buf, fmt.Sprintf("%v", v.Any())...)
}
return buf
}
// WithAttrs 返回一个新的处理器,包含指定的属性
func (h *ReadableTextHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
// 简化实现:不支持 With
return h
}
// WithGroup 返回一个新的处理器,使用指定的组
func (h *ReadableTextHandler) WithGroup(name string) slog.Handler {
// 简化实现:不支持组
return h
}
// checkAndRotateOnStartup 启动时检查日志文件是否需要按日期轮转
func checkAndRotateOnStartup() error {
// 检查日志文件是否存在
fileInfo, err := os.Stat(logFilePath)
if err != nil {
if os.IsNotExist(err) {
// 文件不存在,不需要轮转
return nil
}
return fmt.Errorf("failed to stat log file: %w", err)
}
// 获取文件的修改时间
modTime := fileInfo.ModTime()
modDate := modTime.Format("2006-01-02")
today := time.Now().Format("2006-01-02")
// 如果文件的日期和今天不同,进行轮转
if modDate != today {
// 生成归档文件名(使用文件的修改日期)
timestamp := modTime.Format("2006-01-02-150405")
archivePath := filepath.Join(logDirPath, fmt.Sprintf("newapi.%s.log", timestamp))
// 重命名日志文件
if err := os.Rename(logFilePath, archivePath); err != nil {
return fmt.Errorf("failed to archive old log file: %w", err)
}
slog.Info("rotated old log file on startup",
"archive", archivePath,
"reason", "date changed")
// 清理旧的日志文件
gopool.Go(func() {
cleanOldLogFiles()
})
}
return nil
}
// openLogFile 打开日志文件
func openLogFile() error {
// 关闭旧的日志文件
if logFile != nil {
logFile.Close()
}
// 打开新的日志文件
fd, err := os.OpenFile(logFilePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return fmt.Errorf("failed to open log file: %w", err)
}
logFile = fd
writeCount = 0
return nil
}
// rotateLogFile 轮转日志文件
func rotateLogFile() error {
if logFile == nil {
return nil
}
rotateCheckLock.Lock()
defer rotateCheckLock.Unlock()
// 获取当前日志文件信息
fileInfo, err := logFile.Stat()
if err != nil {
return fmt.Errorf("failed to stat log file: %w", err)
}
// 检查文件大小是否需要轮转
if fileInfo.Size() < maxLogSize {
return nil
}
// 关闭当前日志文件
logFile.Close()
// 生成归档文件名
timestamp := time.Now().Format("2006-01-02-150405")
archivePath := filepath.Join(logDirPath, fmt.Sprintf("newapi.%s.log", timestamp))
// 重命名当前日志文件为归档文件
if err := os.Rename(logFilePath, archivePath); err != nil {
// 如果重命名失败,尝试复制
if copyErr := copyFile(logFilePath, archivePath); copyErr != nil {
return fmt.Errorf("failed to archive log file: %w", err)
}
os.Truncate(logFilePath, 0)
}
// 清理旧的日志文件
gopool.Go(func() {
cleanOldLogFiles()
})
// 打开新的日志文件
if err := openLogFile(); err != nil {
return err
}
// 重新设置日志输出
multiWriter := io.MultiWriter(os.Stdout, logFile)
gin.DefaultWriter = multiWriter
gin.DefaultErrorWriter = multiWriter
handler := createHandler(multiWriter)
logMutex.Lock()
defaultLogger = slog.New(handler)
slog.SetDefault(defaultLogger)
logMutex.Unlock()
slog.Info("log file rotated",
"reason", "size limit reached",
"archive", archivePath)
return nil
}
// cleanOldLogFiles 清理旧的日志文件
func cleanOldLogFiles() {
if logDirPath == "" {
return
}
files, err := os.ReadDir(logDirPath)
if err != nil {
slog.Error("failed to read log directory", "error", err)
return
}
// 收集所有归档日志文件
var logFiles []os.DirEntry
for _, file := range files {
if !file.IsDir() && strings.HasPrefix(file.Name(), "newapi.") &&
strings.HasSuffix(file.Name(), ".log") &&
file.Name() != defaultLogFileName {
logFiles = append(logFiles, file)
}
}
// 如果归档文件数量超过限制,删除最旧的
if len(logFiles) > maxLogFiles {
// 按名称排序(文件名包含时间戳)
sort.Slice(logFiles, func(i, j int) bool {
return logFiles[i].Name() < logFiles[j].Name()
})
// 删除最旧的文件
deleteCount := len(logFiles) - maxLogFiles
for i := 0; i < deleteCount; i++ {
filePath := filepath.Join(logDirPath, logFiles[i].Name())
if err := os.Remove(filePath); err != nil {
slog.Error("failed to remove old log file",
"file", filePath,
"error", err)
} else {
slog.Info("removed old log file", "file", logFiles[i].Name())
}
}
}
}
// copyFile 复制文件
func copyFile(src, dst string) error {
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
destFile, err := os.Create(dst)
if err != nil {
return err
}
defer destFile.Close()
_, err = io.Copy(destFile, sourceFile)
return err
}
// getLogLevel 获取日志级别
func getLogLevel() slog.Level {
// 支持环境变量配置
if level := os.Getenv("LOG_LEVEL"); level != "" {
switch strings.ToUpper(level) {
case "DEBUG":
return slog.LevelDebug
case "INFO":
return slog.LevelInfo
case "WARN", "WARNING":
return slog.LevelWarn
case "ERROR":
return slog.LevelError
}
}
if common.DebugEnabled {
return slog.LevelDebug
}
return slog.LevelInfo
}
// getLogFormat 获取日志格式
func getLogFormat() string {
if useJSONFormat {
return "json"
}
return "text"
}
// checkAndRotateLog 检查并轮转日志
func checkAndRotateLog() {
if logFile == nil {
return
}
writeCount++
if writeCount%checkRotateInterval == 0 {
gopool.Go(func() {
if err := rotateLogFile(); err != nil {
slog.Error("failed to rotate log file", "error", err)
}
})
}
}
// LogInfo 记录信息级别日志
func LogInfo(ctx context.Context, msg string) {
if ctx == nil {
ctx = context.Background()
}
id := getRequestID(ctx)
logMutex.RLock()
logger := defaultLogger
logMutex.RUnlock()
logger.InfoContext(ctx, msg, "request_id", id)
checkAndRotateLog()
logHelper(ctx, loggerINFO, msg)
}
// LogWarn 记录警告级别日志
func LogWarn(ctx context.Context, msg string) {
if ctx == nil {
ctx = context.Background()
}
id := getRequestID(ctx)
logMutex.RLock()
logger := defaultLogger
logMutex.RUnlock()
logger.WarnContext(ctx, msg, "request_id", id)
checkAndRotateLog()
logHelper(ctx, loggerWarn, msg)
}
// LogError 记录错误级别日志
func LogError(ctx context.Context, msg string) {
if ctx == nil {
ctx = context.Background()
}
id := getRequestID(ctx)
logMutex.RLock()
logger := defaultLogger
logMutex.RUnlock()
logger.ErrorContext(ctx, msg, "request_id", id)
checkAndRotateLog()
logHelper(ctx, loggerError, msg)
}
// LogSystemInfo 记录系统信息
func LogSystemInfo(msg string) {
logMutex.RLock()
logger := defaultLogger
logMutex.RUnlock()
logger.Info(msg, "request_id", "SYSTEM")
checkAndRotateLog()
func LogDebug(ctx context.Context, msg string) {
if common.DebugEnabled {
logHelper(ctx, loggerDebug, msg)
}
}
// LogSystemError 记录系统错误
func LogSystemError(msg string) {
logMutex.RLock()
logger := defaultLogger
logMutex.RUnlock()
logger.Error(msg, "request_id", "SYSTEM")
checkAndRotateLog()
}
// LogDebug 记录调试级别日志
func LogDebug(ctx context.Context, msg string, args ...any) {
if !common.DebugEnabled && getLogLevel() > slog.LevelDebug {
return
}
if ctx == nil {
ctx = context.Background()
}
id := getRequestID(ctx)
if len(args) > 0 {
msg = fmt.Sprintf(msg, args...)
}
logMutex.RLock()
logger := defaultLogger
logMutex.RUnlock()
logger.DebugContext(ctx, msg, "request_id", id)
checkAndRotateLog()
}
// getRequestID 从上下文中获取请求ID
func getRequestID(ctx context.Context) string {
if ctx == nil {
return "SYSTEM"
func logHelper(ctx context.Context, level string, msg string) {
writer := gin.DefaultErrorWriter
if level == loggerINFO {
writer = gin.DefaultWriter
}
id := ctx.Value(common.RequestIdKey)
if id == nil {
return "SYSTEM"
id = "SYSTEM"
}
if strID, ok := id.(string); ok {
return strID
now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
logCount++ // we don't need accurate count, so no lock here
if logCount > maxLogCount && !setupLogWorking {
logCount = 0
setupLogWorking = true
gopool.Go(func() {
SetupLogger()
})
}
return "SYSTEM"
}
func LogQuota(quota int) string {
@@ -638,5 +153,5 @@ func LogJson(ctx context.Context, msg string, obj any) {
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
return
}
LogDebug(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
}

97
main.go
View File

@@ -4,7 +4,7 @@ import (
"bytes"
"embed"
"fmt"
"log/slog"
"log"
"net/http"
"os"
"strconv"
@@ -21,6 +21,13 @@ import (
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/ratio_setting"
// Plugin System
coreregistry "github.com/QuantumNous/new-api/core/registry"
_ "github.com/QuantumNous/new-api/plugins/channels" // 自动注册channel插件
_ "github.com/QuantumNous/new-api/plugins/hooks/web_search" // 自动注册web_search hook
_ "github.com/QuantumNous/new-api/plugins/hooks/content_filter" // 自动注册content_filter hook
relayhooks "github.com/QuantumNous/new-api/relay/hooks"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
@@ -118,9 +125,7 @@ func main() {
if os.Getenv("ENABLE_PPROF") == "true" {
gopool.Go(func() {
if err := http.ListenAndServe("0.0.0.0:8005", nil); err != nil {
slog.Error("pprof server failed", "error", err)
}
log.Println(http.ListenAndServe("0.0.0.0:8005", nil))
})
go common.Monitor()
common.SysLog("pprof enabled")
@@ -129,7 +134,7 @@ func main() {
// Initialize HTTP server
server := gin.New()
server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
logger.LogSystemError(fmt.Sprintf("panic detected: %v", err))
common.SysLog(fmt.Sprintf("panic detected: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
@@ -152,26 +157,6 @@ func main() {
})
server.Use(sessions.Sessions("session", store))
InjectUmamiAnalytics()
InjectGoogleAnalytics()
// 设置路由
router.SetRouter(server, buildFS, indexPage)
var port = os.Getenv("PORT")
if port == "" {
port = strconv.Itoa(*common.Port)
}
// Log startup success message
common.LogStartupSuccess(startTime, port)
err = server.Run(":" + port)
if err != nil {
common.FatalLog("failed to start HTTP server: " + err.Error())
}
}
func InjectUmamiAnalytics() {
analyticsInjectBuilder := &strings.Builder{}
if os.Getenv("UMAMI_WEBSITE_ID") != "" {
umamiSiteID := os.Getenv("UMAMI_WEBSITE_ID")
@@ -186,28 +171,21 @@ func InjectUmamiAnalytics() {
analyticsInjectBuilder.WriteString("\"></script>")
}
analyticsInject := analyticsInjectBuilder.String()
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--umami-->\n"), []byte(analyticsInject))
}
indexPage = bytes.ReplaceAll(indexPage, []byte("<analytics></analytics>\n"), []byte(analyticsInject))
func InjectGoogleAnalytics() {
analyticsInjectBuilder := &strings.Builder{}
if os.Getenv("GOOGLE_ANALYTICS_ID") != "" {
gaID := os.Getenv("GOOGLE_ANALYTICS_ID")
// Google Analytics 4 (gtag.js)
analyticsInjectBuilder.WriteString("<script async src=\"https://www.googletagmanager.com/gtag/js?id=")
analyticsInjectBuilder.WriteString(gaID)
analyticsInjectBuilder.WriteString("\"></script>")
analyticsInjectBuilder.WriteString("<script>")
analyticsInjectBuilder.WriteString("window.dataLayer = window.dataLayer || [];")
analyticsInjectBuilder.WriteString("function gtag(){dataLayer.push(arguments);}")
analyticsInjectBuilder.WriteString("gtag('js', new Date());")
analyticsInjectBuilder.WriteString("gtag('config', '")
analyticsInjectBuilder.WriteString(gaID)
analyticsInjectBuilder.WriteString("');")
analyticsInjectBuilder.WriteString("</script>")
router.SetRouter(server, buildFS, indexPage)
var port = os.Getenv("PORT")
if port == "" {
port = strconv.Itoa(*common.Port)
}
// Log startup success message
common.LogStartupSuccess(startTime, port)
err = server.Run(":" + port)
if err != nil {
common.FatalLog("failed to start HTTP server: " + err.Error())
}
analyticsInject := analyticsInjectBuilder.String()
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--Google Analytics-->\n"), []byte(analyticsInject))
}
func InitResources() error {
@@ -258,5 +236,34 @@ func InitResources() error {
if err != nil {
return err
}
// Initialize Plugin System
InitPluginSystem()
return nil
}
// InitPluginSystem 初始化插件系统
func InitPluginSystem() {
common.SysLog("Initializing plugin system...")
// 1. 加载插件配置
// config.LoadPluginConfig() 会在各个插件的init()中自动调用
// 2. 注册Channel插件
// 注意:这会在 plugins/channels/registry.go 的 init() 中自动完成
// 但为了确保加载,我们显式导入
common.SysLog("Registering channel plugins...")
// 3. 初始化Hook链
common.SysLog("Initializing hook chain...")
_ = relayhooks.GetGlobalChain()
hookCount := coreregistry.HookCount()
enabledHookCount := coreregistry.EnabledHookCount()
common.SysLog(fmt.Sprintf("Plugin system initialized: %d hooks registered (%d enabled)",
hookCount, enabledHookCount))
channelCount := len(coreregistry.ListChannels())
common.SysLog(fmt.Sprintf("Registered %d channel plugins", channelCount))
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/gin-contrib/sessions"
@@ -266,8 +266,8 @@ func TokenAuth() func(c *gin.Context) {
tokenGroup := token.Group
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
if _, ok := service.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("无权访问 %s 分组", tokenGroup))
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
return
}
// check group in common.GroupRatio

View File

@@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"net/http"
"slices"
"strconv"
"strings"
"time"
@@ -15,6 +14,7 @@ import (
"github.com/QuantumNous/new-api/model"
relayconstant "github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/QuantumNous/new-api/types"
@@ -79,31 +79,30 @@ func Distribute() func(c *gin.Context) {
return
}
var selectGroup string
usingGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
// check path is /pg/chat/completions
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
playgroundRequest := &dto.PlayGroundRequest{}
err = common.UnmarshalBodyReusable(c, playgroundRequest)
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的playground请求, "+err.Error())
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return
}
if playgroundRequest.Group != "" {
if !service.GroupInUserUsableGroups(usingGroup, playgroundRequest.Group) && playgroundRequest.Group != usingGroup {
if !setting.GroupInUserUsableGroups(playgroundRequest.Group) && playgroundRequest.Group != userGroup {
abortWithOpenAiMessage(c, http.StatusForbidden, "无权访问该分组")
return
}
usingGroup = playgroundRequest.Group
common.SetContextKey(c, constant.ContextKeyUsingGroup, usingGroup)
userGroup = playgroundRequest.Group
}
}
channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(c, usingGroup, modelRequest.Model, 0)
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
if err != nil {
showGroup := usingGroup
if usingGroup == "auto" {
showGroup := userGroup
if userGroup == "auto" {
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
}
message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败distributor: %s", showGroup, modelRequest.Model, err.Error())
message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(数据库一致性已被破坏,distributor: %s", showGroup, modelRequest.Model, err.Error())
// 如果错误,但是渠道不为空,说明是数据库一致性问题
//if channel != nil {
// common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
@@ -113,7 +112,7 @@ func Distribute() func(c *gin.Context) {
return
}
if channel == nil {
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道distributor", usingGroup, modelRequest.Model), string(types.ErrorCodeModelNotFound))
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道distributor", userGroup, modelRequest.Model), string(types.ErrorCodeModelNotFound))
return
}
}
@@ -124,20 +123,6 @@ func Distribute() func(c *gin.Context) {
}
}
// getModelFromRequest 从请求中读取模型信息
// 根据 Content-Type 自动处理:
// - application/json
// - application/x-www-form-urlencoded
// - multipart/form-data
func getModelFromRequest(c *gin.Context) (*ModelRequest, error) {
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
return nil, errors.New("无效的请求, " + err.Error())
}
return &modelRequest, nil
}
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
var modelRequest ModelRequest
shouldSelectChannel := true
@@ -153,7 +138,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
midjourneyRequest := dto.MidjourneyRequest{}
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
if err != nil {
return nil, false, errors.New("无效的midjourney请求, " + err.Error())
return nil, false, err
}
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
if mjErr != nil {
@@ -190,12 +175,23 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
relayMode := relayconstant.RelayModeUnknown
if c.Request.Method == http.MethodPost {
relayMode = relayconstant.RelayModeVideoSubmit
req, err := getModelFromRequest(c)
if err != nil {
return nil, false, err
}
if req != nil {
modelRequest.Model = req.Model
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "multipart/form-data") {
form, err := common.ParseMultipartFormReusable(c)
if err != nil {
return nil, false, errors.New("无效的video请求, " + err.Error())
}
defer form.RemoveAll()
if form != nil {
if values, ok := form.Value["model"]; ok && len(values) > 0 {
modelRequest.Model = values[0]
}
}
} else if strings.HasPrefix(contentType, "application/json") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
return nil, false, errors.New("无效的video请求, " + err.Error())
}
}
} else if c.Request.Method == http.MethodGet {
relayMode = relayconstant.RelayModeVideoFetchByID
@@ -205,11 +201,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
relayMode := relayconstant.RelayModeUnknown
if c.Request.Method == http.MethodPost {
req, err := getModelFromRequest(c)
err = common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
return nil, false, err
return nil, false, errors.New("video无效的请求, " + err.Error())
}
modelRequest.Model = req.Model
relayMode = relayconstant.RelayModeVideoSubmit
} else if c.Request.Method == http.MethodGet {
relayMode = relayconstant.RelayModeVideoFetchByID
@@ -227,11 +222,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
c.Set("relay_mode", relayMode)
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
req, err := getModelFromRequest(c)
if err != nil {
return nil, false, err
}
modelRequest.Model = req.Model
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
return nil, false, errors.New("无效的请求, " + err.Error())
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
@@ -251,31 +245,20 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
//modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
contentType := c.ContentType()
if slices.Contains([]string{gin.MIMEPOSTForm, gin.MIMEMultipartPOSTForm}, contentType) {
req, err := getModelFromRequest(c)
if err == nil && req.Model != "" {
modelRequest.Model = req.Model
}
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
modelRequest.Model = c.PostForm("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
relayMode := relayconstant.RelayModeAudioSpeech
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1")
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
// 先尝试从请求读取
if req, err := getModelFromRequest(c); err == nil && req.Model != "" {
modelRequest.Model = req.Model
}
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
relayMode = relayconstant.RelayModeAudioTranslation
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
// 先尝试从请求读取
if req, err := getModelFromRequest(c); err == nil && req.Model != "" {
modelRequest.Model = req.Model
}
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
relayMode = relayconstant.RelayModeAudioTranscription
}
@@ -283,12 +266,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
// playground chat completions
req, err := getModelFromRequest(c)
err = common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
return nil, false, err
return nil, false, errors.New("无效的请求, " + err.Error())
}
modelRequest.Model = req.Model
modelRequest.Group = req.Group
common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group)
}
return &modelRequest, shouldSelectChannel, nil

View File

@@ -102,10 +102,7 @@ func GlobalAPIRateLimit() func(c *gin.Context) {
}
func CriticalRateLimit() func(c *gin.Context) {
if common.CriticalRateLimitEnable {
return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT")
}
return defNext
return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT")
}
func DownloadRateLimit() func(c *gin.Context) {

View File

@@ -103,7 +103,7 @@ func getChannelQuery(group string, model string, retry int) (*gorm.DB, error) {
return channelQuery, nil
}
func GetChannel(group string, model string, retry int) (*Channel, error) {
func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
var abilities []Ability
var err error = nil

View File

@@ -688,7 +688,7 @@ func DisableChannelByTag(tag string) error {
return err
}
func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *string, group *string, priority *int64, weight *uint, paramOverride *string, headerOverride *string) error {
func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *string, group *string, priority *int64, weight *uint) error {
updateData := Channel{}
shouldReCreateAbilities := false
updatedTag := tag
@@ -714,12 +714,6 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *
if weight != nil {
updateData.Weight = weight
}
if paramOverride != nil {
updateData.ParamOverride = paramOverride
}
if headerOverride != nil {
updateData.HeaderOverride = headerOverride
}
err := DB.Model(&Channel{}).Where("tag = ?", tag).Updates(updateData).Error
if err != nil {

View File

@@ -11,7 +11,10 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
var group2model2channels map[string]map[string][]int // enabled channel
@@ -93,10 +96,43 @@ func SyncChannelCache(frequency int) {
}
}
func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
var channel *Channel
var err error
selectGroup := group
if group == "auto" {
if len(setting.AutoGroups) == 0 {
return nil, selectGroup, errors.New("auto groups is not enabled")
}
for _, autoGroup := range setting.AutoGroups {
if common.DebugEnabled {
println("autoGroup:", autoGroup)
}
channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
if channel == nil {
continue
} else {
c.Set("auto_group", autoGroup)
selectGroup = autoGroup
if common.DebugEnabled {
println("selectGroup:", selectGroup)
}
break
}
}
} else {
channel, err = getRandomSatisfiedChannel(group, model, retry)
if err != nil {
return nil, group, err
}
}
return channel, selectGroup, nil
}
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
// if memory cache is disabled, get channel directly from database
if !common.MemoryCacheEnabled {
return GetChannel(group, model, retry)
return GetRandomSatisfiedChannel(group, model, retry)
}
channelSyncLock.RLock()
@@ -142,12 +178,10 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
targetPriority := int64(sortedUniquePriorities[retry])
// get the priority for the given retry number
var sumWeight = 0
var targetChannels []*Channel
for _, channelId := range channels {
if channel, ok := channelsIDM[channelId]; ok {
if channel.GetPriority() == targetPriority {
sumWeight += channel.GetWeight()
targetChannels = append(targetChannels, channel)
}
} else {
@@ -155,33 +189,19 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
}
}
if len(targetChannels) == 0 {
return nil, errors.New(fmt.Sprintf("no channel found, group: %s, model: %s, priority: %d", group, model, targetPriority))
}
// smoothing factor and adjustment
smoothingFactor := 1
smoothingAdjustment := 0
if sumWeight == 0 {
// when all channels have weight 0, set sumWeight to the number of channels and set smoothing adjustment to 100
// each channel's effective weight = 100
sumWeight = len(targetChannels) * 100
smoothingAdjustment = 100
} else if sumWeight/len(targetChannels) < 10 {
// when the average weight is less than 10, set smoothing factor to 100
smoothingFactor = 100
}
// 平滑系数
smoothingFactor := 10
// Calculate the total weight of all channels up to endIdx
totalWeight := sumWeight * smoothingFactor
totalWeight := 0
for _, channel := range targetChannels {
totalWeight += channel.GetWeight() + smoothingFactor
}
// Generate a random value in the range [0, totalWeight)
randomWeight := rand.Intn(totalWeight)
// Find a channel based on its weight
for _, channel := range targetChannels {
randomWeight -= channel.GetWeight()*smoothingFactor + smoothingAdjustment
randomWeight -= channel.GetWeight() + smoothingFactor
if randomWeight < 0 {
return channel, nil
}

View File

@@ -39,15 +39,13 @@ type Log struct {
Other string `json:"other"`
}
// don't use iota, avoid change log type value
const (
LogTypeUnknown = 0
LogTypeTopup = 1
LogTypeConsume = 2
LogTypeManage = 3
LogTypeSystem = 4
LogTypeError = 5
LogTypeRefund = 6
LogTypeUnknown = iota
LogTypeTopup
LogTypeConsume
LogTypeManage
LogTypeSystem
LogTypeError
)
func formatUserLogs(logs []*Log) {

View File

@@ -84,10 +84,6 @@ func InitOptionMap() {
common.OptionMap["StripePriceId"] = setting.StripePriceId
common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(setting.StripeUnitPrice, 'f', -1, 64)
common.OptionMap["StripePromotionCodesEnabled"] = strconv.FormatBool(setting.StripePromotionCodesEnabled)
common.OptionMap["CreemApiKey"] = setting.CreemApiKey
common.OptionMap["CreemProducts"] = setting.CreemProducts
common.OptionMap["CreemTestMode"] = strconv.FormatBool(setting.CreemTestMode)
common.OptionMap["CreemWebhookSecret"] = setting.CreemWebhookSecret
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = setting.Chats2JsonString()
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
@@ -346,14 +342,6 @@ func updateOptionMap(key string, value string) (err error) {
setting.StripeMinTopUp, _ = strconv.Atoi(value)
case "StripePromotionCodesEnabled":
setting.StripePromotionCodesEnabled = value == "true"
case "CreemApiKey":
setting.CreemApiKey = value
case "CreemProducts":
setting.CreemProducts = value
case "CreemTestMode":
setting.CreemTestMode = value == "true"
case "CreemWebhookSecret":
setting.CreemWebhookSecret = value
case "TopupGroupRatio":
err = common.UpdateTopupGroupRatioByJSONString(value)
case "GitHubClientId":

View File

@@ -6,7 +6,6 @@ import (
"time"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
commonRelay "github.com/QuantumNous/new-api/relay/common"
)
@@ -16,15 +15,15 @@ func (t TaskStatus) ToVideoStatus() string {
var status string
switch t {
case TaskStatusQueued, TaskStatusSubmitted:
status = dto.VideoStatusQueued
status = commonRelay.VideoStatusQueued
case TaskStatusInProgress:
status = dto.VideoStatusInProgress
status = commonRelay.VideoStatusInProgress
case TaskStatusSuccess:
status = dto.VideoStatusCompleted
status = commonRelay.VideoStatusCompleted
case TaskStatusFailure:
status = dto.VideoStatusFailed
status = commonRelay.VideoStatusFailed
default:
status = dto.VideoStatusUnknown // Default fallback
status = commonRelay.VideoStatusUnknown // Default fallback
}
return status
}
@@ -46,7 +45,6 @@ type Task struct {
TaskID string `json:"task_id" gorm:"type:varchar(191);index"` // 第三方id不一定有/ song id\ Task id
Platform constant.TaskPlatform `json:"platform" gorm:"type:varchar(30);index"` // 平台
UserId int `json:"user_id" gorm:"index"`
Group string `json:"group" gorm:"type:varchar(50)"` // 修正计费用
ChannelId int `json:"channel_id" gorm:"index"`
Quota int `json:"quota"`
Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
@@ -57,9 +55,8 @@ type Task struct {
FinishTime int64 `json:"finish_time" gorm:"index"`
Progress string `json:"progress" gorm:"type:varchar(20);index"`
Properties Properties `json:"properties" gorm:"type:json"`
// 禁止返回给用户内部可能包含key等隐私信息
PrivateData TaskPrivateData `json:"-" gorm:"column:private_data;type:json"`
Data json.RawMessage `json:"data" gorm:"type:json"`
Data json.RawMessage `json:"data" gorm:"type:json"`
}
func (t *Task) SetData(data any) {
@@ -73,46 +70,18 @@ func (t *Task) GetData(v any) error {
}
type Properties struct {
Input string `json:"input"`
UpstreamModelName string `json:"upstream_model_name,omitempty"`
OriginModelName string `json:"origin_model_name,omitempty"`
Input string `json:"input"`
}
func (m *Properties) Scan(val interface{}) error {
bytesValue, _ := val.([]byte)
if len(bytesValue) == 0 {
*m = Properties{}
return nil
}
return json.Unmarshal(bytesValue, m)
}
func (m Properties) Value() (driver.Value, error) {
if m == (Properties{}) {
return nil, nil
}
return json.Marshal(m)
}
type TaskPrivateData struct {
Key string `json:"key,omitempty"`
}
func (p *TaskPrivateData) Scan(val interface{}) error {
bytesValue, _ := val.([]byte)
if len(bytesValue) == 0 {
return nil
}
return json.Unmarshal(bytesValue, p)
}
func (p TaskPrivateData) Value() (driver.Value, error) {
if (p == TaskPrivateData{}) {
return nil, nil
}
return json.Marshal(p)
}
// SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
type SyncTaskQueryParams struct {
Platform constant.TaskPlatform
@@ -127,30 +96,13 @@ type SyncTaskQueryParams struct {
}
func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo) *Task {
properties := Properties{}
privateData := TaskPrivateData{}
if relayInfo != nil && relayInfo.ChannelMeta != nil {
if relayInfo.ChannelMeta.ChannelType == constant.ChannelTypeGemini {
privateData.Key = relayInfo.ChannelMeta.ApiKey
}
if relayInfo.UpstreamModelName != "" {
properties.UpstreamModelName = relayInfo.UpstreamModelName
}
if relayInfo.OriginModelName != "" {
properties.OriginModelName = relayInfo.OriginModelName
}
}
t := &Task{
UserId: relayInfo.UserId,
Group: relayInfo.UsingGroup,
SubmitTime: time.Now().Unix(),
Status: TaskStatusNotStart,
Progress: "0%",
ChannelId: relayInfo.ChannelId,
Platform: platform,
Properties: properties,
PrivateData: privateData,
UserId: relayInfo.UserId,
SubmitTime: time.Now().Unix(),
Status: TaskStatusNotStart,
Progress: "0%",
ChannelId: relayInfo.ChannelId,
Platform: platform,
}
return t
}

View File

@@ -305,72 +305,3 @@ func ManualCompleteTopUp(tradeNo string) error {
RecordLog(userId, LogTypeTopup, fmt.Sprintf("管理员补单成功,充值金额: %v支付金额%f", logger.FormatQuota(quotaToAdd), payMoney))
return nil
}
func RechargeCreem(referenceId string, customerEmail string, customerName string) (err error) {
if referenceId == "" {
return errors.New("未提供支付单号")
}
var quota int64
topUp := &TopUp{}
refCol := "`trade_no`"
if common.UsingPostgreSQL {
refCol = `"trade_no"`
}
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", referenceId).First(topUp).Error
if err != nil {
return errors.New("充值订单不存在")
}
if topUp.Status != common.TopUpStatusPending {
return errors.New("充值订单状态错误")
}
topUp.CompleteTime = common.GetTimestamp()
topUp.Status = common.TopUpStatusSuccess
err = tx.Save(topUp).Error
if err != nil {
return err
}
// Creem 直接使用 Amount 作为充值额度(整数)
quota = topUp.Amount
// 构建更新字段,优先使用邮箱,如果邮箱为空则使用用户名
updateFields := map[string]interface{}{
"quota": gorm.Expr("quota + ?", quota),
}
// 如果有客户邮箱,尝试更新用户邮箱(仅当用户邮箱为空时)
if customerEmail != "" {
// 先检查用户当前邮箱是否为空
var user User
err = tx.Where("id = ?", topUp.UserId).First(&user).Error
if err != nil {
return err
}
// 如果用户邮箱为空,则更新为支付时使用的邮箱
if user.Email == "" {
updateFields["email"] = customerEmail
}
}
err = tx.Model(&User{}).Where("id = ?", topUp.UserId).Updates(updateFields).Error
if err != nil {
return err
}
return nil
})
if err != nil {
return errors.New("充值失败," + err.Error())
}
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用Creem充值成功充值额度: %v支付金额%.2f", quota, topUp.Money))
return nil
}

View File

@@ -0,0 +1,114 @@
package channels
import (
"io"
"net/http"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
// BaseChannelPlugin 基础Channel插件
// 包装现有的Adaptor实现使其符合ChannelPlugin接口
type BaseChannelPlugin struct {
adaptor channel.Adaptor
name string
version string
priority int
}
// NewBaseChannelPlugin 创建基础Channel插件
func NewBaseChannelPlugin(adaptor channel.Adaptor, name, version string, priority int) *BaseChannelPlugin {
return &BaseChannelPlugin{
adaptor: adaptor,
name: name,
version: version,
priority: priority,
}
}
// Name 返回插件名称
func (p *BaseChannelPlugin) Name() string {
return p.name
}
// Version 返回插件版本
func (p *BaseChannelPlugin) Version() string {
return p.version
}
// Priority 返回优先级
func (p *BaseChannelPlugin) Priority() int {
return p.priority
}
// 以下方法直接委托给内部的Adaptor
func (p *BaseChannelPlugin) Init(info *relaycommon.RelayInfo) {
p.adaptor.Init(info)
}
func (p *BaseChannelPlugin) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return p.adaptor.GetRequestURL(info)
}
func (p *BaseChannelPlugin) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
return p.adaptor.SetupRequestHeader(c, req, info)
}
func (p *BaseChannelPlugin) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return p.adaptor.ConvertOpenAIRequest(c, info, request)
}
func (p *BaseChannelPlugin) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return p.adaptor.ConvertRerankRequest(c, relayMode, request)
}
func (p *BaseChannelPlugin) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return p.adaptor.ConvertEmbeddingRequest(c, info, request)
}
func (p *BaseChannelPlugin) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
return p.adaptor.ConvertAudioRequest(c, info, request)
}
func (p *BaseChannelPlugin) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
return p.adaptor.ConvertImageRequest(c, info, request)
}
func (p *BaseChannelPlugin) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
return p.adaptor.ConvertOpenAIResponsesRequest(c, info, request)
}
func (p *BaseChannelPlugin) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return p.adaptor.DoRequest(c, info, requestBody)
}
func (p *BaseChannelPlugin) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
return p.adaptor.DoResponse(c, resp, info)
}
func (p *BaseChannelPlugin) GetModelList() []string {
return p.adaptor.GetModelList()
}
func (p *BaseChannelPlugin) GetChannelName() string {
return p.adaptor.GetChannelName()
}
func (p *BaseChannelPlugin) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
return p.adaptor.ConvertClaudeRequest(c, info, request)
}
func (p *BaseChannelPlugin) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
return p.adaptor.ConvertGeminiRequest(c, info, request)
}
// GetAdaptor 获取内部的Adaptor用于向后兼容
func (p *BaseChannelPlugin) GetAdaptor() channel.Adaptor {
return p.adaptor
}

View File

@@ -0,0 +1,106 @@
package channels
import (
"fmt"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/core/registry"
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/ali"
"github.com/QuantumNous/new-api/relay/channel/aws"
"github.com/QuantumNous/new-api/relay/channel/baidu"
"github.com/QuantumNous/new-api/relay/channel/baidu_v2"
"github.com/QuantumNous/new-api/relay/channel/claude"
"github.com/QuantumNous/new-api/relay/channel/cloudflare"
"github.com/QuantumNous/new-api/relay/channel/cohere"
"github.com/QuantumNous/new-api/relay/channel/coze"
"github.com/QuantumNous/new-api/relay/channel/deepseek"
"github.com/QuantumNous/new-api/relay/channel/dify"
"github.com/QuantumNous/new-api/relay/channel/gemini"
"github.com/QuantumNous/new-api/relay/channel/jimeng"
"github.com/QuantumNous/new-api/relay/channel/jina"
"github.com/QuantumNous/new-api/relay/channel/mistral"
"github.com/QuantumNous/new-api/relay/channel/mokaai"
"github.com/QuantumNous/new-api/relay/channel/moonshot"
"github.com/QuantumNous/new-api/relay/channel/ollama"
"github.com/QuantumNous/new-api/relay/channel/openai"
"github.com/QuantumNous/new-api/relay/channel/palm"
"github.com/QuantumNous/new-api/relay/channel/perplexity"
"github.com/QuantumNous/new-api/relay/channel/siliconflow"
"github.com/QuantumNous/new-api/relay/channel/submodel"
"github.com/QuantumNous/new-api/relay/channel/tencent"
"github.com/QuantumNous/new-api/relay/channel/vertex"
"github.com/QuantumNous/new-api/relay/channel/volcengine"
"github.com/QuantumNous/new-api/relay/channel/xai"
"github.com/QuantumNous/new-api/relay/channel/xunfei"
"github.com/QuantumNous/new-api/relay/channel/zhipu"
"github.com/QuantumNous/new-api/relay/channel/zhipu_4v"
)
// init 包初始化时自动注册所有Channel插件
func init() {
RegisterAllChannels()
}
// RegisterAllChannels 注册所有Channel插件
func RegisterAllChannels() {
// 包装现有的Adaptor并注册为插件
channels := []struct {
channelType int
adaptor channel.Adaptor
name string
}{
{constant.APITypeOpenAI, &openai.Adaptor{}, "openai"},
{constant.APITypeAnthropic, &claude.Adaptor{}, "claude"},
{constant.APITypeGemini, &gemini.Adaptor{}, "gemini"},
{constant.APITypeAli, &ali.Adaptor{}, "ali"},
{constant.APITypeBaidu, &baidu.Adaptor{}, "baidu"},
{constant.APITypeBaiduV2, &baidu_v2.Adaptor{}, "baidu_v2"},
{constant.APITypeTencent, &tencent.Adaptor{}, "tencent"},
{constant.APITypeXunfei, &xunfei.Adaptor{}, "xunfei"},
{constant.APITypeZhipu, &zhipu.Adaptor{}, "zhipu"},
{constant.APITypeZhipuV4, &zhipu_4v.Adaptor{}, "zhipu_v4"},
{constant.APITypeOllama, &ollama.Adaptor{}, "ollama"},
{constant.APITypePerplexity, &perplexity.Adaptor{}, "perplexity"},
{constant.APITypeAws, &aws.Adaptor{}, "aws"},
{constant.APITypeCohere, &cohere.Adaptor{}, "cohere"},
{constant.APITypeDify, &dify.Adaptor{}, "dify"},
{constant.APITypeJina, &jina.Adaptor{}, "jina"},
{constant.APITypeCloudflare, &cloudflare.Adaptor{}, "cloudflare"},
{constant.APITypeSiliconFlow, &siliconflow.Adaptor{}, "siliconflow"},
{constant.APITypeVertexAi, &vertex.Adaptor{}, "vertex"},
{constant.APITypeMistral, &mistral.Adaptor{}, "mistral"},
{constant.APITypeDeepSeek, &deepseek.Adaptor{}, "deepseek"},
{constant.APITypeMokaAI, &mokaai.Adaptor{}, "mokaai"},
{constant.APITypeVolcEngine, &volcengine.Adaptor{}, "volcengine"},
{constant.APITypeXai, &xai.Adaptor{}, "xai"},
{constant.APITypeCoze, &coze.Adaptor{}, "coze"},
{constant.APITypeJimeng, &jimeng.Adaptor{}, "jimeng"},
{constant.APITypeMoonshot, &moonshot.Adaptor{}, "moonshot"},
{constant.APITypeSubmodel, &submodel.Adaptor{}, "submodel"},
{constant.APITypePaLM, &palm.Adaptor{}, "palm"},
// OpenRouter 和 Xinference 使用 OpenAI adaptor
{constant.APITypeOpenRouter, &openai.Adaptor{}, "openrouter"},
{constant.APITypeXinference, &openai.Adaptor{}, "xinference"},
}
registeredCount := 0
for _, ch := range channels {
plugin := NewBaseChannelPlugin(
ch.adaptor,
ch.name,
"1.0.0",
100, // 默认优先级
)
if err := registry.RegisterChannel(ch.channelType, plugin); err != nil {
common.SysError("Failed to register channel plugin: " + ch.name + ", error: " + err.Error())
} else {
registeredCount++
}
}
common.SysLog(fmt.Sprintf("Registered %d channel plugins", registeredCount))
}

View File

@@ -0,0 +1,186 @@
package content_filter
import (
"encoding/json"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/core/interfaces"
)
// ContentFilterHook 内容过滤Hook
// 在响应返回前过滤敏感内容
type ContentFilterHook struct {
enabled bool
priority int
sensitiveWords []string
filterNSFW bool
filterPolitical bool
replacementText string
}
// NewContentFilterHook 创建ContentFilterHook实例
func NewContentFilterHook(config map[string]interface{}) *ContentFilterHook {
hook := &ContentFilterHook{
enabled: true,
priority: 100, // 高优先级,最后执行
sensitiveWords: []string{},
filterNSFW: true,
filterPolitical: false,
replacementText: "[已过滤]",
}
if enabled, ok := config["enabled"].(bool); ok {
hook.enabled = enabled
}
if priority, ok := config["priority"].(int); ok {
hook.priority = priority
}
if filterNSFW, ok := config["filter_nsfw"].(bool); ok {
hook.filterNSFW = filterNSFW
}
if filterPolitical, ok := config["filter_political"].(bool); ok {
hook.filterPolitical = filterPolitical
}
if words, ok := config["sensitive_words"].([]interface{}); ok {
for _, word := range words {
if w, ok := word.(string); ok {
hook.sensitiveWords = append(hook.sensitiveWords, w)
}
}
}
return hook
}
// Name 返回Hook名称
func (h *ContentFilterHook) Name() string {
return "content_filter"
}
// Priority 返回优先级
func (h *ContentFilterHook) Priority() int {
return h.priority
}
// Enabled 返回是否启用
func (h *ContentFilterHook) Enabled() bool {
return h.enabled
}
// OnBeforeRequest 请求前处理(不需要处理)
func (h *ContentFilterHook) OnBeforeRequest(ctx *interfaces.HookContext) error {
return nil
}
// OnAfterResponse 响应后处理 - 过滤内容
func (h *ContentFilterHook) OnAfterResponse(ctx *interfaces.HookContext) error {
if !h.Enabled() {
return nil
}
// 只处理chat completion响应
if !strings.Contains(ctx.Request.URL.Path, "chat/completions") {
return nil
}
// 如果没有响应体,跳过
if len(ctx.ResponseBody) == 0 {
return nil
}
// 解析响应
var response map[string]interface{}
if err := json.Unmarshal(ctx.ResponseBody, &response); err != nil {
return nil // 忽略解析错误
}
// 过滤内容
filtered := h.filterResponse(response)
// 如果内容被修改,更新响应体
if filtered {
modifiedBody, err := json.Marshal(response)
if err != nil {
return err
}
ctx.ResponseBody = modifiedBody
// 记录过滤事件
ctx.Data["content_filtered"] = true
common.SysLog("Content filter applied to response")
}
return nil
}
// OnError 错误处理
func (h *ContentFilterHook) OnError(ctx *interfaces.HookContext, err error) error {
return nil
}
// filterResponse 过滤响应内容
func (h *ContentFilterHook) filterResponse(response map[string]interface{}) bool {
modified := false
// 获取choices数组
choices, ok := response["choices"].([]interface{})
if !ok {
return false
}
// 遍历每个choice
for _, choice := range choices {
choiceMap, ok := choice.(map[string]interface{})
if !ok {
continue
}
// 获取message
message, ok := choiceMap["message"].(map[string]interface{})
if !ok {
continue
}
// 获取content
content, ok := message["content"].(string)
if !ok {
continue
}
// 过滤内容
filteredContent := h.filterText(content)
// 如果内容被修改
if filteredContent != content {
message["content"] = filteredContent
modified = true
}
}
return modified
}
// filterText 过滤文本内容
func (h *ContentFilterHook) filterText(text string) string {
filtered := text
// 过滤敏感词
for _, word := range h.sensitiveWords {
if strings.Contains(filtered, word) {
filtered = strings.ReplaceAll(filtered, word, h.replacementText)
}
}
// TODO: 实现更复杂的过滤逻辑
// - NSFW内容检测
// - 政治敏感内容检测
// - 使用AI模型进行内容分类
return filtered
}

View File

@@ -0,0 +1,39 @@
package content_filter
import (
"os"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/core/registry"
)
func init() {
// 从环境变量读取配置
config := map[string]interface{}{
"enabled": os.Getenv("CONTENT_FILTER_ENABLED") == "true",
"priority": 100,
"filter_nsfw": os.Getenv("CONTENT_FILTER_NSFW") != "false",
"filter_political": os.Getenv("CONTENT_FILTER_POLITICAL") == "true",
}
// 读取敏感词列表
if wordsEnv := os.Getenv("CONTENT_FILTER_WORDS"); wordsEnv != "" {
words := strings.Split(wordsEnv, ",")
config["sensitive_words"] = words
}
// 创建并注册Hook
hook := NewContentFilterHook(config)
if err := registry.RegisterHook(hook); err != nil {
common.SysError("Failed to register content_filter hook: " + err.Error())
} else {
if hook.Enabled() {
common.SysLog("Content filter hook registered and enabled")
} else {
common.SysLog("Content filter hook registered but disabled")
}
}
}

View File

@@ -0,0 +1,39 @@
package web_search
import (
"os"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/core/registry"
)
func init() {
// 从环境变量读取配置
config := map[string]interface{}{
"enabled": os.Getenv("WEB_SEARCH_ENABLED") == "true",
"api_key": os.Getenv("WEB_SEARCH_API_KEY"),
"provider": getEnvOrDefault("WEB_SEARCH_PROVIDER", "google"),
"priority": 50,
}
// 创建并注册Hook
hook := NewWebSearchHook(config)
if err := registry.RegisterHook(hook); err != nil {
common.SysError("Failed to register web_search hook: " + err.Error())
} else {
if hook.Enabled() {
common.SysLog("Web search hook registered and enabled")
} else {
common.SysLog("Web search hook registered but disabled (missing API key or not enabled)")
}
}
}
func getEnvOrDefault(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}

View File

@@ -0,0 +1,281 @@
package web_search
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/core/interfaces"
)
// WebSearchHook 联网搜索Hook插件
// 在请求发送前检测是否需要联网搜索如果需要则调用搜索API并将结果注入到请求中
type WebSearchHook struct {
enabled bool
priority int
apiKey string
provider string // google, bing, etc
}
// NewWebSearchHook 创建WebSearchHook实例
func NewWebSearchHook(config map[string]interface{}) *WebSearchHook {
hook := &WebSearchHook{
enabled: true,
priority: 50, // 中等优先级
provider: "google",
}
if apiKey, ok := config["api_key"].(string); ok {
hook.apiKey = apiKey
}
if provider, ok := config["provider"].(string); ok {
hook.provider = provider
}
if priority, ok := config["priority"].(int); ok {
hook.priority = priority
}
if enabled, ok := config["enabled"].(bool); ok {
hook.enabled = enabled
}
return hook
}
// Name 返回Hook名称
func (h *WebSearchHook) Name() string {
return "web_search"
}
// Priority 返回优先级
func (h *WebSearchHook) Priority() int {
return h.priority
}
// Enabled 返回是否启用
func (h *WebSearchHook) Enabled() bool {
return h.enabled && h.apiKey != ""
}
// OnBeforeRequest 请求前处理
func (h *WebSearchHook) OnBeforeRequest(ctx *interfaces.HookContext) error {
if !h.Enabled() {
return nil
}
// 只处理chat completion请求
if !strings.Contains(ctx.Request.URL.Path, "chat/completions") {
return nil
}
// 检查请求体中是否包含搜索关键词
if len(ctx.RequestBody) == 0 {
return nil
}
// 解析请求体
var requestData map[string]interface{}
if err := json.Unmarshal(ctx.RequestBody, &requestData); err != nil {
return nil // 忽略解析错误
}
// 检查是否需要搜索(简单示例:检查最后一条消息是否包含 [search] 标记)
if !h.shouldSearch(requestData) {
return nil
}
// 执行搜索
searchQuery := h.extractSearchQuery(requestData)
if searchQuery == "" {
return nil
}
common.SysLog(fmt.Sprintf("Web search triggered for query: %s", searchQuery))
// 调用搜索API
searchResults, err := h.performSearch(searchQuery)
if err != nil {
common.SysError(fmt.Sprintf("Web search failed: %v", err))
return nil // 不中断请求,只记录错误
}
// 将搜索结果注入到请求中
h.injectSearchResults(requestData, searchResults)
// 更新请求体
modifiedBody, err := json.Marshal(requestData)
if err != nil {
return err
}
ctx.RequestBody = modifiedBody
// 存储到Data中供后续使用
ctx.Data["web_search_performed"] = true
ctx.Data["web_search_query"] = searchQuery
return nil
}
// OnAfterResponse 响应后处理
func (h *WebSearchHook) OnAfterResponse(ctx *interfaces.HookContext) error {
// 可以在这里记录搜索使用情况等
if performed, ok := ctx.Data["web_search_performed"].(bool); ok && performed {
query := ctx.Data["web_search_query"].(string)
common.SysLog(fmt.Sprintf("Web search completed for query: %s", query))
}
return nil
}
// OnError 错误处理
func (h *WebSearchHook) OnError(ctx *interfaces.HookContext, err error) error {
// 记录错误但不影响主流程
if performed, ok := ctx.Data["web_search_performed"].(bool); ok && performed {
common.SysError(fmt.Sprintf("Request failed after web search: %v", err))
}
return nil
}
// shouldSearch 判断是否需要搜索
func (h *WebSearchHook) shouldSearch(requestData map[string]interface{}) bool {
messages, ok := requestData["messages"].([]interface{})
if !ok || len(messages) == 0 {
return false
}
// 检查最后一条消息
lastMessage, ok := messages[len(messages)-1].(map[string]interface{})
if !ok {
return false
}
content, ok := lastMessage["content"].(string)
if !ok {
return false
}
// 简单示例:检查是否包含 [search] 或 [联网] 标记
return strings.Contains(content, "[search]") ||
strings.Contains(content, "[联网]") ||
strings.Contains(content, "[web]")
}
// extractSearchQuery 提取搜索查询
func (h *WebSearchHook) extractSearchQuery(requestData map[string]interface{}) string {
messages, ok := requestData["messages"].([]interface{})
if !ok || len(messages) == 0 {
return ""
}
lastMessage, ok := messages[len(messages)-1].(map[string]interface{})
if !ok {
return ""
}
content, ok := lastMessage["content"].(string)
if !ok {
return ""
}
// 移除标记,保留实际查询内容
query := strings.ReplaceAll(content, "[search]", "")
query = strings.ReplaceAll(query, "[联网]", "")
query = strings.ReplaceAll(query, "[web]", "")
query = strings.TrimSpace(query)
return query
}
// performSearch 执行搜索
func (h *WebSearchHook) performSearch(query string) (string, error) {
// 这里是示例实现实际应该调用真实的搜索API
// 例如Google Custom Search API, Bing Search API等
if h.apiKey == "" {
return "", fmt.Errorf("search API key not configured")
}
// 示例:返回模拟结果
// 实际实现需要调用真实API
return h.mockSearch(query)
}
// mockSearch 模拟搜索(示例)
func (h *WebSearchHook) mockSearch(query string) (string, error) {
// 这只是一个示例实现
// 实际应该调用真实的搜索API
common.SysLog(fmt.Sprintf("[Mock] Searching for: %s", query))
// 返回模拟的搜索结果
return fmt.Sprintf("搜索结果 (模拟):关于 '%s' 的最新信息...", query), nil
}
// realSearch 真实搜索实现示例需要配置API
func (h *WebSearchHook) realSearch(query string) (string, error) {
// 示例使用Google Custom Search API
url := fmt.Sprintf("https://www.googleapis.com/customsearch/v1?key=%s&cx=YOUR_CX&q=%s",
h.apiKey, query)
resp, err := http.Get(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
// 解析搜索结果
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
return "", err
}
// 提取搜索结果摘要
// 这里需要根据实际API响应格式处理
return string(body), nil
}
// injectSearchResults 将搜索结果注入到请求中
func (h *WebSearchHook) injectSearchResults(requestData map[string]interface{}, results string) {
messages, ok := requestData["messages"].([]interface{})
if !ok {
return
}
// 在用户消息前插入系统消息,包含搜索结果
systemMessage := map[string]interface{}{
"role": "system",
"content": fmt.Sprintf("以下是针对用户查询的最新搜索结果:\n\n%s\n\n请基于这些信息回答用户的问题。", results),
}
// 插入到消息列表的适当位置
updatedMessages := make([]interface{}, 0, len(messages)+1)
// 如果第一条是系统消息,在其后插入
if len(messages) > 0 {
if firstMsg, ok := messages[0].(map[string]interface{}); ok {
if role, ok := firstMsg["role"].(string); ok && role == "system" {
updatedMessages = append(updatedMessages, messages[0])
updatedMessages = append(updatedMessages, systemMessage)
updatedMessages = append(updatedMessages, messages[1:]...)
requestData["messages"] = updatedMessages
return
}
}
}
// 否则插入到开头
updatedMessages = append(updatedMessages, systemMessage)
updatedMessages = append(updatedMessages, messages...)
requestData["messages"] = updatedMessages
}

View File

@@ -53,5 +53,5 @@ type TaskAdaptor interface {
}
type OpenAIVideoConverter interface {
ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error)
ConvertToOpenAIVideo(originTask *model.Task) (*relaycommon.OpenAIVideo, error)
}

View File

@@ -98,9 +98,9 @@ func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, reque
return nil, errors.New("image is required")
}
//if len(imageFiles) > 1 {
// return nil, errors.New("only one image is supported for qwen edit")
//}
if len(imageFiles) > 1 {
return nil, errors.New("only one image is supported for qwen edit")
}
// 获取base64编码的图片
var imageBase64s []string

View File

@@ -1,7 +1,20 @@
package ali
import (
"bufio"
"encoding/json"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
@@ -16,3 +29,180 @@ func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReque
}
return &request
}
func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
return &AliEmbeddingRequest{
Model: request.Model,
Input: struct {
Texts []string `json:"texts"`
}{
Texts: request.ParseInput(),
},
}
}
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var fullTextResponse dto.FlexibleEmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
service.CloseResponseBodyGracefully(resp)
model := c.GetString("model")
if model == "" {
model = "text-embedding-v4"
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse, model string) *dto.OpenAIEmbeddingResponse {
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
Object: "list",
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: model,
Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
}
for _, item := range response.Output.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: response.Output.Text,
},
FinishReason: response.Output.FinishReason,
}
fullTextResponse := dto.OpenAITextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []dto.OpenAITextResponseChoice{choice},
Usage: dto.Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
},
}
return &fullTextResponse
}
func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(aliResponse.Output.Text)
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason
}
response := dto.ChatCompletionsStreamResponse{
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "ernie-bot",
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var usage dto.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
helper.SetEventStreamHeaders(c)
lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse AliResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
common.SysLog("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
usage.PromptTokens = aliResponse.Usage.InputTokens
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.SetContentString(strings.TrimPrefix(response.Choices[0].Delta.GetContentString(), lastResponseText))
lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysLog("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
service.CloseResponseBodyGracefully(resp)
return nil, &usage
}
func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var aliResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return types.WithOpenAIError(types.OpenAIError{
Message: aliResponse.Message,
Type: "ali_error",
Param: aliResponse.RequestId,
Code: aliResponse.Code,
}, resp.StatusCode), nil
}
fullTextResponse := responseAli2OpenAI(&aliResponse)
jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -1,36 +1,25 @@
package aws
import (
"fmt"
"errors"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/claude"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/pkg/errors"
"github.com/gin-gonic/gin"
)
type ClientMode int
const (
ClientModeApiKey ClientMode = iota + 1
ClientModeAKSK
RequestModeCompletion = 1
RequestModeMessage = 2
)
type Adaptor struct {
ClientMode ClientMode
AwsClient *bedrockruntime.Client
AwsModelId string
AwsReq any
IsNova bool
RequestMode int
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
@@ -39,37 +28,8 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
for i, message := range request.Messages {
updated := false
if !message.IsStringContent() {
content, err := message.ParseContent()
if err != nil {
return nil, errors.Wrap(err, "failed to parse message content")
}
for i2, mediaMessage := range content {
if mediaMessage.Source != nil {
if mediaMessage.Source.Type == "url" {
fileData, err := service.GetFileBase64FromUrl(c, mediaMessage.Source.Url, "formatting image for Claude")
if err != nil {
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
}
mediaMessage.Source.MediaType = fileData.MimeType
mediaMessage.Source.Data = fileData.Base64Data
mediaMessage.Source.Url = ""
mediaMessage.Source.Type = "base64"
content[i2] = mediaMessage
updated = true
}
}
}
if updated {
message.SetContent(content)
}
}
if updated {
request.Messages[i] = message
}
}
c.Set("request_model", request.Model)
c.Set("converted_request", request)
return request, nil
}
@@ -84,28 +44,15 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
a.RequestMode = RequestModeMessage
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.ChannelOtherSettings.AwsKeyType == dto.AwsKeyTypeApiKey {
awsModelId := getAwsModelID(info.UpstreamModelName)
a.ClientMode = ClientModeApiKey
awsSecret := strings.Split(info.ApiKey, "|")
if len(awsSecret) != 2 {
return "", errors.New("invalid aws api key, should be in format of <api-key>|<region>")
}
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/converse", awsModelId, awsSecret[1]), nil
} else {
a.ClientMode = ClientModeAKSK
return "", nil
}
return "", nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
claude.CommonClaudeHeadersOperation(c, req, info)
if a.ClientMode == ClientModeApiKey {
req.Set("Authorization", "Bearer "+info.ApiKey)
}
return nil
}
@@ -116,16 +63,22 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
// 检查是否为Nova模型
if isNovaModel(request.Model) {
novaReq := convertToNovaRequest(request)
a.IsNova = true
c.Set("request_model", request.Model)
c.Set("converted_request", novaReq)
c.Set("is_nova_model", true)
return novaReq, nil
}
// 原有的Claude模型处理逻辑
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
var claudeReq *dto.ClaudeRequest
var err error
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
if err != nil {
return nil, errors.Wrap(err, "failed to convert openai request to claude request")
return nil, err
}
info.UpstreamModelName = claudeReq.Model
c.Set("request_model", claudeReq.Model)
c.Set("converted_request", claudeReq)
c.Set("is_nova_model", false)
return claudeReq, err
}
@@ -144,27 +97,14 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
if a.ClientMode == ClientModeApiKey {
return channel.DoApiRequest(a, c, info, requestBody)
} else {
return doAwsClientRequest(c, info, a, requestBody)
}
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if a.ClientMode == ClientModeApiKey {
claudeAdaptor := claude.Adaptor{}
usage, err = claudeAdaptor.DoResponse(c, resp, info)
if info.IsStream {
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
} else {
if a.IsNova {
err, usage = handleNovaRequest(c, info, a)
} else {
if info.IsStream {
err, usage = awsStreamHandler(c, info, a)
} else {
err, usage = awsHandler(c, info, a)
}
}
err, usage = awsHandler(c, info, a.RequestMode)
}
return
}

View File

@@ -17,7 +17,6 @@ var awsModelIDMap = map[string]string{
"claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
"claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
"claude-sonnet-4-5-20250929": "anthropic.claude-sonnet-4-5-20250929-v1:0",
"claude-haiku-4-5-20251001": "anthropic.claude-haiku-4-5-20251001-v1:0",
// Nova models
"nova-micro-v1:0": "amazon.nova-micro-v1:0",
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
@@ -76,11 +75,6 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
"ap": true,
"eu": true,
},
"anthropic.claude-haiku-4-5-20251001-v1:0": {
"us": true,
"ap": true,
"eu": true,
},
// Nova models - all support three major regions
"amazon.nova-micro-v1:0": {
"us": true,
@@ -130,5 +124,5 @@ var ChannelName = "aws"
// 判断是否为Nova模型
func isNovaModel(modelId string) bool {
return strings.Contains(modelId, "nova-")
return strings.HasPrefix(modelId, "nova-")
}

View File

@@ -1,9 +1,6 @@
package aws
import (
"io"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
)
@@ -38,16 +35,6 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
}
}
func formatRequest(requestBody io.Reader) (*AwsClaudeRequest, error) {
var awsClaudeRequest AwsClaudeRequest
err := common.DecodeJson(requestBody, &awsClaudeRequest)
if err != nil {
return nil, err
}
awsClaudeRequest.AnthropicVersion = "bedrock-2023-05-31"
return &awsClaudeRequest, nil
}
// NovaMessage Nova模型使用messages-v1格式
type NovaMessage struct {
Role string `json:"role"`

View File

@@ -3,7 +3,6 @@ package aws
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
@@ -12,7 +11,6 @@ import (
"github.com/QuantumNous/new-api/relay/channel/claude"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
@@ -26,19 +24,6 @@ import (
)
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
var (
httpClient *http.Client
err error
)
if info.ChannelSetting.Proxy != "" {
httpClient, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
} else {
httpClient = service.GetHttpClient()
}
awsSecret := strings.Split(info.ApiKey, "|")
var client *bedrockruntime.Client
switch len(awsSecret) {
@@ -48,7 +33,6 @@ func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.
client = bedrockruntime.New(bedrockruntime.Options{
Region: region,
BearerAuthTokenProvider: bearer.StaticTokenProvider{Token: bearer.Token{Value: apiKey}},
HTTPClient: httpClient,
})
case 3:
ak := awsSecret[0]
@@ -57,7 +41,6 @@ func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.
client = bedrockruntime.New(bedrockruntime.Options{
Region: region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
HTTPClient: httpClient,
})
default:
return nil, errors.New("invalid aws secret key")
@@ -66,78 +49,16 @@ func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.
return client, nil
}
func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor, requestBody io.Reader) (any, error) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError)
}
a.AwsClient = awsCli
println(info.UpstreamModelName)
// 获取对应的AWS模型ID
awsModelId := getAwsModelID(info.UpstreamModelName)
awsRegionPrefix := getAwsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
if isNovaModel(awsModelId) {
var novaReq *NovaRequest
err = common.DecodeJson(requestBody, &novaReq)
if err != nil {
return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody)
}
// 使用InvokeModel API但使用Nova格式的请求体
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
reqBody, err := common.Marshal(novaReq)
if err != nil {
return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
}
awsReq.Body = reqBody
return nil, nil
} else {
awsClaudeReq, err := formatRequest(requestBody)
if err != nil {
return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody)
}
if info.IsStream {
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
awsReq.Body, err = common.Marshal(awsClaudeReq)
if err != nil {
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
}
a.AwsReq = awsReq
return nil, nil
} else {
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
awsReq.Body, err = common.Marshal(awsClaudeReq)
if err != nil {
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
}
a.AwsReq = awsReq
return nil, nil
}
func wrapErr(err error) *dto.OpenAIErrorWithStatusCode {
return &dto.OpenAIErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: dto.OpenAIError{
Message: fmt.Sprintf("%s", err.Error()),
},
}
}
func getAwsRegionPrefix(awsRegionId string) string {
func awsRegionPrefix(awsRegionId string) string {
parts := strings.Split(awsRegionId, "-")
regionPrefix := ""
if len(parts) > 0 {
@@ -159,16 +80,58 @@ func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
return modelPrefix + "." + awsModelId
}
func getAwsModelID(requestModel string) string {
if awsModelIDName, ok := awsModelIDMap[requestModel]; ok {
return awsModelIDName
func awsModelID(requestModel string) string {
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
return awsModelID
}
return requestModel
}
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
}
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
awsModelId := awsModelID(c.GetString("request_model"))
// 检查是否为Nova模型
isNova, _ := c.Get("is_nova_model")
if isNova == true {
// Nova模型也支持跨区域
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
return handleNovaRequest(c, awsCli, info, awsModelId)
}
// 原有的Claude处理逻辑
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
claudeReq_, ok := c.Get("converted_request")
if !ok {
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
}
claudeReq := claudeReq_.(*dto.ClaudeRequest)
awsClaudeReq := copyRequest(claudeReq)
awsReq.Body, err = common.Marshal(awsClaudeReq)
if err != nil {
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
}
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
}
@@ -186,15 +149,46 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types
c.Writer.Header().Set("Content-Type", *awsResp.ContentType)
}
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, claude.RequestModeMessage)
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, RequestModeMessage)
if handlerErr != nil {
return handlerErr, nil
}
return nil, claudeInfo.Usage
}
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
awsResp, err := a.AwsClient.InvokeModelWithResponseStream(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
}
awsModelId := awsModelID(c.GetString("request_model"))
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
claudeReq_, ok := c.Get("converted_request")
if !ok {
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
}
claudeReq := claudeReq_.(*dto.ClaudeRequest)
awsClaudeReq := copyRequest(claudeReq)
awsReq.Body, err = common.Marshal(awsClaudeReq)
if err != nil {
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
}
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil {
return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
}
@@ -213,7 +207,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (
switch v := event.(type) {
case *bedrockruntimeTypes.ResponseStreamMemberChunk:
info.SetFirstResponseTime()
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), claude.RequestModeMessage)
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
if respErr != nil {
return respErr, nil
}
@@ -226,14 +220,32 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (
}
}
claude.HandleStreamFinalResponse(c, info, claudeInfo, claude.RequestModeMessage)
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
return nil, claudeInfo.Usage
}
// Nova模型处理函数
func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
novaReq_, ok := c.Get("converted_request")
if !ok {
return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
}
novaReq := novaReq_.(*NovaRequest)
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
// 使用InvokeModel API但使用Nova格式的请求体
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
reqBody, err := json.Marshal(novaReq)
if err != nil {
return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
}
awsReq.Body = reqBody
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
}

View File

@@ -477,7 +477,8 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
signatureContent := "\n"
choice.Delta.ReasoningContent = &signatureContent
case "thinking_delta":
choice.Delta.ReasoningContent = claudeResponse.Delta.Thinking
thinkingContent := claudeResponse.Delta.Thinking
choice.Delta.ReasoningContent = &thinkingContent
}
}
} else if claudeResponse.Type == "message_delta" {
@@ -512,9 +513,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto
var responseThinking string
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].GetText()
if claudeResponse.Content[0].Thinking != nil {
responseThinking = *claudeResponse.Content[0].Thinking
}
responseThinking = claudeResponse.Content[0].Thinking
}
tools := make([]dto.ToolCallResponse, 0)
thinkingContent := ""
@@ -546,9 +545,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto
})
case "thinking":
// 加密的不管, 只输出明文的推理过程
if message.Thinking != nil {
thinkingContent = *message.Thinking
}
thinkingContent = message.Thinking
case "text":
responseText = message.GetText()
}
@@ -596,15 +593,13 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Message.Usage.GetCacheCreation5mTokens()
claudeInfo.Usage.ClaudeCacheCreation1hTokens = claudeResponse.Message.Usage.GetCacheCreation1hTokens()
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta.Text != nil {
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
}
if claudeResponse.Delta.Thinking != nil {
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Thinking)
if claudeResponse.Delta.Thinking != "" {
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
}
} else if claudeResponse.Type == "message_delta" {
// 最终的usage获取
@@ -742,8 +737,6 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Usage.GetCacheCreation5mTokens()
claudeInfo.Usage.ClaudeCacheCreation1hTokens = claudeResponse.Usage.GetCacheCreation1hTokens()
}
var responseData []byte
switch info.RelayFormat {

View File

@@ -211,16 +211,7 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
// eg. {"google":{"thinking_config":{"thinking_budget":5324,"include_thoughts":true}}}
if googleBody, ok := extraBody["google"].(map[string]interface{}); ok {
adaptorWithExtraBody = true
// check error param name like thinkingConfig, should be thinking_config
if _, hasErrorParam := googleBody["thinkingConfig"]; hasErrorParam {
return nil, errors.New("extra_body.google.thinkingConfig is not supported, use extra_body.google.thinking_config instead")
}
if thinkingConfig, ok := googleBody["thinking_config"].(map[string]interface{}); ok {
// check error param name like thinkingBudget, should be thinking_budget
if _, hasErrorParam := thinkingConfig["thinkingBudget"]; hasErrorParam {
return nil, errors.New("extra_body.google.thinking_config.thinkingBudget is not supported, use extra_body.google.thinking_config.thinking_budget instead")
}
if budget, ok := thinkingConfig["thinking_budget"].(float64); ok {
budgetInt := int(budget)
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
@@ -1061,11 +1052,11 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
}
if len(geminiResponse.Candidates) == 0 {
//return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
//if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
// return nil, types.NewOpenAIError(errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason), types.ErrorCodePromptBlocked, http.StatusBadRequest)
//} else {
// return nil, types.NewOpenAIError(errors.New("empty response from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
//}
if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
return nil, types.NewOpenAIError(errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason), types.ErrorCodePromptBlocked, http.StatusBadRequest)
} else {
return nil, types.NewOpenAIError(errors.New("empty response from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
}
}
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
fullTextResponse.Model = info.UpstreamModelName

View File

@@ -1,132 +0,0 @@
package minimax
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/openai"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
if info.RelayMode != constant.RelayModeAudioSpeech {
return nil, errors.New("unsupported audio relay mode")
}
voiceID := request.Voice
speed := request.Speed
outputFormat := request.ResponseFormat
minimaxRequest := MiniMaxTTSRequest{
Model: info.OriginModelName,
Text: request.Input,
VoiceSetting: VoiceSetting{
VoiceID: voiceID,
Speed: speed,
},
AudioSetting: &AudioSetting{
Format: outputFormat,
},
OutputFormat: outputFormat,
}
// 同步扩展字段的厂商自定义metadata
if len(request.Metadata) > 0 {
if err := json.Unmarshal(request.Metadata, &minimaxRequest); err != nil {
return nil, fmt.Errorf("error unmarshalling metadata to minimax request: %w", err)
}
}
jsonData, err := json.Marshal(minimaxRequest)
if err != nil {
return nil, fmt.Errorf("error marshalling minimax request: %w", err)
}
if outputFormat != "hex" {
outputFormat = "url"
}
c.Set("response_format", outputFormat)
// Debug: log the request structure
// fmt.Printf("MiniMax TTS Request: %s\n", string(jsonData))
return bytes.NewReader(jsonData), nil
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
return request, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return GetRequestURL(info)
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeAudioSpeech {
return handleTTSResponse(c, resp, info)
}
adaptor := openai.Adaptor{}
return adaptor.DoResponse(c, resp, info)
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -8,12 +8,6 @@ var ModelList = []string{
"abab6-chat",
"abab5.5-chat",
"abab5.5s-chat",
"speech-2.5-hd-preview",
"speech-2.5-turbo-preview",
"speech-02-hd",
"speech-02-turbo",
"speech-01-hd",
"speech-01-turbo",
}
var ChannelName = "minimax"

View File

@@ -3,23 +3,9 @@ package minimax
import (
"fmt"
channelconstant "github.com/QuantumNous/new-api/constant"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/constant"
)
func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
baseUrl := info.ChannelBaseUrl
if baseUrl == "" {
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeMiniMax]
}
switch info.RelayMode {
case constant.RelayModeChatCompletions:
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", baseUrl), nil
case constant.RelayModeAudioSpeech:
return fmt.Sprintf("%s/v1/t2a_v2", baseUrl), nil
default:
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
}
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.ChannelBaseUrl), nil
}

View File

@@ -1,194 +0,0 @@
package minimax
import (
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
type MiniMaxTTSRequest struct {
Model string `json:"model"`
Text string `json:"text"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
VoiceSetting VoiceSetting `json:"voice_setting"`
PronunciationDict *PronunciationDict `json:"pronunciation_dict,omitempty"`
AudioSetting *AudioSetting `json:"audio_setting,omitempty"`
TimbreWeights []TimbreWeight `json:"timbre_weights,omitempty"`
LanguageBoost string `json:"language_boost,omitempty"`
VoiceModify *VoiceModify `json:"voice_modify,omitempty"`
SubtitleEnable bool `json:"subtitle_enable,omitempty"`
OutputFormat string `json:"output_format,omitempty"`
AigcWatermark bool `json:"aigc_watermark,omitempty"`
}
type StreamOptions struct {
ExcludeAggregatedAudio bool `json:"exclude_aggregated_audio,omitempty"`
}
type VoiceSetting struct {
VoiceID string `json:"voice_id"`
Speed float64 `json:"speed,omitempty"`
Vol float64 `json:"vol,omitempty"`
Pitch int `json:"pitch,omitempty"`
Emotion string `json:"emotion,omitempty"`
TextNormalization bool `json:"text_normalization,omitempty"`
LatexRead bool `json:"latex_read,omitempty"`
}
type PronunciationDict struct {
Tone []string `json:"tone,omitempty"`
}
type AudioSetting struct {
SampleRate int `json:"sample_rate,omitempty"`
Bitrate int `json:"bitrate,omitempty"`
Format string `json:"format,omitempty"`
Channel int `json:"channel,omitempty"`
ForceCbr bool `json:"force_cbr,omitempty"`
}
type TimbreWeight struct {
VoiceID string `json:"voice_id"`
Weight int `json:"weight"`
}
type VoiceModify struct {
Pitch int `json:"pitch,omitempty"`
Intensity int `json:"intensity,omitempty"`
Timbre int `json:"timbre,omitempty"`
SoundEffects string `json:"sound_effects,omitempty"`
}
type MiniMaxTTSResponse struct {
Data MiniMaxTTSData `json:"data"`
ExtraInfo MiniMaxExtraInfo `json:"extra_info"`
TraceID string `json:"trace_id"`
BaseResp MiniMaxBaseResp `json:"base_resp"`
}
type MiniMaxTTSData struct {
Audio string `json:"audio"`
Status int `json:"status"`
}
type MiniMaxExtraInfo struct {
UsageCharacters int64 `json:"usage_characters"`
}
type MiniMaxBaseResp struct {
StatusCode int64 `json:"status_code"`
StatusMsg string `json:"status_msg"`
}
func getContentTypeByFormat(format string) string {
contentTypeMap := map[string]string{
"mp3": "audio/mpeg",
"wav": "audio/wav",
"flac": "audio/flac",
"aac": "audio/aac",
"pcm": "audio/pcm",
}
if ct, ok := contentTypeMap[format]; ok {
return ct
}
return "audio/mpeg" // default to mp3
}
func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, types.NewErrorWithStatusCode(
fmt.Errorf("failed to read minimax response: %w", readErr),
types.ErrorCodeReadResponseBodyFailed,
http.StatusInternalServerError,
)
}
defer resp.Body.Close()
// Parse response
var minimaxResp MiniMaxTTSResponse
if unmarshalErr := json.Unmarshal(body, &minimaxResp); unmarshalErr != nil {
return nil, types.NewErrorWithStatusCode(
fmt.Errorf("failed to unmarshal minimax TTS response: %w", unmarshalErr),
types.ErrorCodeBadResponseBody,
http.StatusInternalServerError,
)
}
// Check base_resp status code
if minimaxResp.BaseResp.StatusCode != 0 {
return nil, types.NewErrorWithStatusCode(
fmt.Errorf("minimax TTS error: %d - %s", minimaxResp.BaseResp.StatusCode, minimaxResp.BaseResp.StatusMsg),
types.ErrorCodeBadResponse,
http.StatusBadRequest,
)
}
// Check if we have audio data
if minimaxResp.Data.Audio == "" {
return nil, types.NewErrorWithStatusCode(
fmt.Errorf("no audio data in minimax TTS response"),
types.ErrorCodeBadResponse,
http.StatusBadRequest,
)
}
if strings.HasPrefix(minimaxResp.Data.Audio, "http") {
c.Redirect(http.StatusFound, minimaxResp.Data.Audio)
} else {
// Handle hex-encoded audio data
audioData, decodeErr := hex.DecodeString(minimaxResp.Data.Audio)
if decodeErr != nil {
return nil, types.NewErrorWithStatusCode(
fmt.Errorf("failed to decode hex audio data: %w", decodeErr),
types.ErrorCodeBadResponse,
http.StatusInternalServerError,
)
}
// Determine content type - default to mp3
contentType := "audio/mpeg"
c.Data(http.StatusOK, contentType, audioData)
}
usage = &dto.Usage{
PromptTokens: info.PromptTokens,
CompletionTokens: 0,
TotalTokens: int(minimaxResp.ExtraInfo.UsageCharacters),
}
return usage, nil
}
func handleChatCompletionResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, types.NewErrorWithStatusCode(
errors.New("failed to read minimax response"),
types.ErrorCodeReadResponseBodyFailed,
http.StatusInternalServerError,
)
}
defer resp.Body.Close()
// Set response headers
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
c.Data(resp.StatusCode, "application/json", body)
return nil, nil
}

View File

@@ -121,14 +121,7 @@ func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
if chunk.Message != nil && len(chunk.Message.Thinking) > 0 {
raw := strings.TrimSpace(string(chunk.Message.Thinking))
if raw != "" && raw != "null" {
// Unmarshal the JSON string to get the actual content without quotes
var thinkingContent string
if err := json.Unmarshal(chunk.Message.Thinking, &thinkingContent); err == nil {
delta.Choices[0].Delta.SetReasoningContent(thinkingContent)
} else {
// Fallback to raw string if it's not a JSON string
delta.Choices[0].Delta.SetReasoningContent(raw)
}
delta.Choices[0].Delta.SetReasoningContent(raw)
}
}
// tool calls
@@ -216,14 +209,7 @@ func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
if ck.Message != nil && len(ck.Message.Thinking) > 0 {
raw := strings.TrimSpace(string(ck.Message.Thinking))
if raw != "" && raw != "null" {
// Unmarshal the JSON string to get the actual content without quotes
var thinkingContent string
if err := json.Unmarshal(ck.Message.Thinking, &thinkingContent); err == nil {
reasoningBuilder.WriteString(thinkingContent)
} else {
// Fallback to raw string if it's not a JSON string
reasoningBuilder.WriteString(raw)
}
reasoningBuilder.WriteString(raw)
}
}
if ck.Message != nil && ck.Message.Content != "" {
@@ -243,14 +229,7 @@ func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
if len(single.Message.Thinking) > 0 {
raw := strings.TrimSpace(string(single.Message.Thinking))
if raw != "" && raw != "null" {
// Unmarshal the JSON string to get the actual content without quotes
var thinkingContent string
if err := json.Unmarshal(single.Message.Thinking, &thinkingContent); err == nil {
reasoningBuilder.WriteString(thinkingContent)
} else {
// Fallback to raw string if it's not a JSON string
reasoningBuilder.WriteString(raw)
}
reasoningBuilder.WriteString(raw)
}
}
aggContent.WriteString(single.Message.Content)

View File

@@ -15,12 +15,10 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/ai360"
"github.com/QuantumNous/new-api/relay/channel/lingyiwanwu"
//"github.com/QuantumNous/new-api/relay/channel/minimax"
"github.com/QuantumNous/new-api/relay/channel/minimax"
"github.com/QuantumNous/new-api/relay/channel/openrouter"
"github.com/QuantumNous/new-api/relay/channel/xinference"
relaycommon "github.com/QuantumNous/new-api/relay/common"
@@ -163,8 +161,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
}
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil
//case constant.ChannelTypeMiniMax:
// return minimax.GetRequestURL(info)
case constant.ChannelTypeMiniMax:
return minimax.GetRequestURL(info)
case constant.ChannelTypeCustom:
url := info.ChannelBaseUrl
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
@@ -354,43 +352,27 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
writer.WriteField("model", request.Model)
formData, err2 := common.ParseMultipartFormReusable(c)
if err2 != nil {
return nil, fmt.Errorf("error parsing multipart form: %w", err2)
}
// 打印类似 curl 命令格式的信息
logger.LogDebug(c.Request.Context(), fmt.Sprintf("--form 'model=\"%s\"'", request.Model))
// 获取所有表单字段
formData := c.Request.PostForm
// 遍历表单字段并打印输出
for key, values := range formData.Value {
for key, values := range formData {
if key == "model" {
continue
}
for _, value := range values {
writer.WriteField(key, value)
logger.LogDebug(c.Request.Context(), fmt.Sprintf("--form '%s=\"%s\"'", key, value))
}
}
// 从 formData 中获取文件
fileHeaders := formData.File["file"]
if len(fileHeaders) == 0 {
return nil, errors.New("file is required")
}
// 使用 formData 中的第一个文件
fileHeader := fileHeaders[0]
logger.LogDebug(c.Request.Context(), fmt.Sprintf("--form 'file=@\"%s\"' (size: %d bytes, content-type: %s)",
fileHeader.Filename, fileHeader.Size, fileHeader.Header.Get("Content-Type")))
file, err := fileHeader.Open()
// 添加文件字段
file, header, err := c.Request.FormFile("file")
if err != nil {
return nil, fmt.Errorf("error opening audio file: %v", err)
return nil, errors.New("file is required")
}
defer file.Close()
part, err := writer.CreateFormFile("file", fileHeader.Filename)
part, err := writer.CreateFormFile("file", header.Filename)
if err != nil {
return nil, errors.New("create form file failed")
}
@@ -401,7 +383,6 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
// 关闭 multipart 编写器以设置分界线
writer.Close()
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
logger.LogDebug(c.Request.Context(), fmt.Sprintf("--header 'Content-Type: %s'", writer.FormDataContentType()))
return &requestBody, nil
}
}
@@ -618,8 +599,8 @@ func (a *Adaptor) GetModelList() []string {
return ai360.ModelList
case constant.ChannelTypeLingYiWanWu:
return lingyiwanwu.ModelList
//case constant.ChannelTypeMiniMax:
// return minimax.ModelList
case constant.ChannelTypeMiniMax:
return minimax.ModelList
case constant.ChannelTypeXinference:
return xinference.ModelList
case constant.ChannelTypeOpenRouter:
@@ -635,8 +616,8 @@ func (a *Adaptor) GetChannelName() string {
return ai360.ChannelName
case constant.ChannelTypeLingYiWanWu:
return lingyiwanwu.ChannelName
//case constant.ChannelTypeMiniMax:
// return minimax.ChannelName
case constant.ChannelTypeMiniMax:
return minimax.ChannelName
case constant.ChannelTypeXinference:
return xinference.ChannelName
case constant.ChannelTypeOpenRouter:

View File

@@ -1,10 +1,15 @@
package openai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"math"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/QuantumNous/new-api/common"
@@ -21,6 +26,7 @@ import (
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
)
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
@@ -122,10 +128,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
var usage = &dto.Usage{}
var streamItems []string // store stream items
var lastStreamData string
var secondLastStreamData string // 存储倒数第二个stream data用于音频模型
// 检查是否为音频模型
isAudioModel := strings.Contains(strings.ToLower(model), "audio")
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
if lastStreamData != "" {
@@ -135,35 +137,12 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
}
}
if len(data) > 0 {
// 对音频模型保存倒数第二个stream data
if isAudioModel && lastStreamData != "" {
secondLastStreamData = lastStreamData
}
lastStreamData = data
streamItems = append(streamItems, data)
}
return true
})
// 对音频模型从倒数第二个stream data中提取usage信息
if isAudioModel && secondLastStreamData != "" {
var streamResp struct {
Usage *dto.Usage `json:"usage"`
}
err := json.Unmarshal([]byte(secondLastStreamData), &streamResp)
if err == nil && streamResp.Usage != nil && service.ValidUsage(streamResp.Usage) {
usage = streamResp.Usage
containStreamUsage = true
if common.DebugEnabled {
logger.LogDebug(c, fmt.Sprintf("Audio model usage extracted from second last SSE: PromptTokens=%d, CompletionTokens=%d, TotalTokens=%d, InputTokens=%d, OutputTokens=%d",
usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens,
usage.InputTokens, usage.OutputTokens))
}
}
}
// 处理最后的响应
shouldSendLastResp := true
if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
@@ -294,39 +273,6 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
return &simpleResponse.Usage, nil
}
func streamTTSResponse(c *gin.Context, resp *http.Response) {
c.Writer.WriteHeaderNow()
flusher, ok := c.Writer.(http.Flusher)
if !ok {
logger.LogWarn(c, "streaming not supported")
_, err := io.Copy(c.Writer, resp.Body)
if err != nil {
logger.LogWarn(c, err.Error())
}
return
}
buffer := make([]byte, 4096)
for {
n, err := resp.Body.Read(buffer)
//logger.LogInfo(c, fmt.Sprintf("streamTTSResponse read %d bytes", n))
if n > 0 {
if _, writeErr := c.Writer.Write(buffer[:n]); writeErr != nil {
logger.LogError(c, writeErr.Error())
break
}
flusher.Flush()
}
if err != nil {
if err != io.EOF {
logger.LogError(c, err.Error())
}
break
}
}
}
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
// the status code has been judged before, if there is a body reading failure,
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
@@ -342,16 +288,10 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
isStreaming := resp.ContentLength == -1 || resp.Header.Get("Content-Length") == ""
if isStreaming {
streamTTSResponse(c, resp)
} else {
c.Writer.WriteHeaderNow()
_, err := io.Copy(c.Writer, resp.Body)
if err != nil {
logger.LogError(c, err.Error())
}
c.Writer.WriteHeaderNow()
_, err := io.Copy(c.Writer, resp.Body)
if err != nil {
logger.LogError(c, err.Error())
}
return usage
}
@@ -382,13 +322,59 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
}
audioTokens, err := countAudioTokens(c)
if err != nil {
return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
}
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.PromptTokens = audioTokens
usage.CompletionTokens = 0
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage
}
func countAudioTokens(c *gin.Context) (int, error) {
body, err := common.GetRequestBody(c)
if err != nil {
return 0, errors.WithStack(err)
}
var reqBody struct {
File *multipart.FileHeader `form:"file" binding:"required"`
}
c.Request.Body = io.NopCloser(bytes.NewReader(body))
if err = c.ShouldBind(&reqBody); err != nil {
return 0, errors.WithStack(err)
}
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
reqFp, err := reqBody.File.Open()
if err != nil {
return 0, errors.WithStack(err)
}
defer reqFp.Close()
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
if err != nil {
return 0, errors.WithStack(err)
}
defer os.Remove(tmpFp.Name())
_, err = io.Copy(tmpFp, reqFp)
if err != nil {
return 0, errors.WithStack(err)
}
if err = tmpFp.Close(); err != nil {
return 0, errors.WithStack(err)
}
duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
if err != nil {
return 0, errors.WithStack(err)
}
return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
}
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil

View File

@@ -6,7 +6,6 @@ import (
"io"
"net/http"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/openai"
@@ -31,32 +30,13 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
adaptor := openai.Adaptor{}
return adaptor.ConvertAudioRequest(c, info, request)
//TODO implement me
return nil, errors.New("not supported")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
// 解析extra到SFImageRequest里以填入SiliconFlow特殊字段。若失败重建一个空的。
sfRequest := &SFImageRequest{}
extra, err := common.Marshal(request.Extra)
if err == nil {
err = common.Unmarshal(extra, sfRequest)
if err != nil {
sfRequest = &SFImageRequest{}
}
}
sfRequest.Model = request.Model
sfRequest.Prompt = request.Prompt
// 优先使用image_size/batch_size否则使用OpenAI标准的size/n
if sfRequest.ImageSize == "" {
sfRequest.ImageSize = request.Size
}
if sfRequest.BatchSize == 0 {
sfRequest.BatchSize = request.N
}
return sfRequest, nil
adaptor := openai.Adaptor{}
return adaptor.ConvertImageRequest(c, info, request)
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -65,8 +45,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeChatCompletions {
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeCompletions {
return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil
}
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -95,8 +81,7 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
adaptor := openai.Adaptor{}
return adaptor.DoRequest(c, info, requestBody)
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -111,9 +96,19 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
switch info.RelayMode {
case constant.RelayModeRerank:
usage, err = siliconflowRerankHandler(c, info, resp)
case constant.RelayModeEmbeddings:
usage, err = openai.OpenaiHandler(c, info, resp)
case constant.RelayModeCompletions:
fallthrough
case constant.RelayModeChatCompletions:
fallthrough
default:
adaptor := openai.Adaptor{}
usage, err = adaptor.DoResponse(c, resp, info)
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
}
return
}

View File

@@ -15,18 +15,3 @@ type SFRerankResponse struct {
Results []dto.RerankResponseResult `json:"results"`
Meta SFMeta `json:"meta"`
}
type SFImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt,omitempty"`
ImageSize string `json:"image_size,omitempty"`
BatchSize uint `json:"batch_size,omitempty"`
Seed uint64 `json:"seed,omitempty"`
NumInferenceSteps uint `json:"num_inference_steps,omitempty"`
GuidanceScale float64 `json:"guidance_scale,omitempty"`
Cfg float64 `json:"cfg,omitempty"`
Image string `json:"image,omitempty"`
Image2 string `json:"image2,omitempty"`
Image3 string `json:"image3,omitempty"`
}

View File

@@ -1,505 +0,0 @@
package ali
import (
"bytes"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
"github.com/samber/lo"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
)
// ============================
// Request / Response structures
// ============================
// AliVideoRequest 阿里通义万相视频生成请求
type AliVideoRequest struct {
Model string `json:"model"`
Input AliVideoInput `json:"input"`
Parameters *AliVideoParameters `json:"parameters,omitempty"`
}
// AliVideoInput 视频输入参数
type AliVideoInput struct {
Prompt string `json:"prompt,omitempty"` // 文本提示词
ImgURL string `json:"img_url,omitempty"` // 首帧图像URL或Base64图生视频
FirstFrameURL string `json:"first_frame_url,omitempty"` // 首帧图片URL首尾帧生视频
LastFrameURL string `json:"last_frame_url,omitempty"` // 尾帧图片URL首尾帧生视频
AudioURL string `json:"audio_url,omitempty"` // 音频URLwan2.5支持)
NegativePrompt string `json:"negative_prompt,omitempty"` // 反向提示词
Template string `json:"template,omitempty"` // 视频特效模板
}
// AliVideoParameters 视频参数
type AliVideoParameters struct {
Resolution string `json:"resolution,omitempty"` // 分辨率: 480P/720P/1080P图生视频、首尾帧生视频
Size string `json:"size,omitempty"` // 尺寸: 如 "832*480"(文生视频)
Duration int `json:"duration,omitempty"` // 时长: 3-10秒
PromptExtend bool `json:"prompt_extend,omitempty"` // 是否开启prompt智能改写
Watermark bool `json:"watermark,omitempty"` // 是否添加水印
Audio *bool `json:"audio,omitempty"` // 是否添加音频wan2.5
Seed int `json:"seed,omitempty"` // 随机数种子
}
// AliVideoResponse 阿里通义万相响应
type AliVideoResponse struct {
Output AliVideoOutput `json:"output"`
RequestID string `json:"request_id"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
Usage *AliUsage `json:"usage,omitempty"`
}
// AliVideoOutput 输出信息
type AliVideoOutput struct {
TaskID string `json:"task_id"`
TaskStatus string `json:"task_status"`
SubmitTime string `json:"submit_time,omitempty"`
ScheduledTime string `json:"scheduled_time,omitempty"`
EndTime string `json:"end_time,omitempty"`
OrigPrompt string `json:"orig_prompt,omitempty"`
ActualPrompt string `json:"actual_prompt,omitempty"`
VideoURL string `json:"video_url,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
// AliUsage 使用统计
type AliUsage struct {
Duration int `json:"duration,omitempty"`
VideoCount int `json:"video_count,omitempty"`
SR int `json:"SR,omitempty"`
}
type AliMetadata struct {
// Input 相关
AudioURL string `json:"audio_url,omitempty"` // 音频URL
ImgURL string `json:"img_url,omitempty"` // 图片URL图生视频
FirstFrameURL string `json:"first_frame_url,omitempty"` // 首帧图片URL首尾帧生视频
LastFrameURL string `json:"last_frame_url,omitempty"` // 尾帧图片URL首尾帧生视频
NegativePrompt string `json:"negative_prompt,omitempty"` // 反向提示词
Template string `json:"template,omitempty"` // 视频特效模板
// Parameters 相关
Resolution *string `json:"resolution,omitempty"` // 分辨率: 480P/720P/1080P
Size *string `json:"size,omitempty"` // 尺寸: 如 "832*480"
Duration *int `json:"duration,omitempty"` // 时长
PromptExtend *bool `json:"prompt_extend,omitempty"` // 是否开启prompt智能改写
Watermark *bool `json:"watermark,omitempty"` // 是否添加水印
Audio *bool `json:"audio,omitempty"` // 是否添加音频
Seed *int `json:"seed,omitempty"` // 随机数种子
}
// ============================
// Adaptor implementation
// ============================
type TaskAdaptor struct {
ChannelType int
apiKey string
baseURL string
aliReq *AliVideoRequest
}
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
a.ChannelType = info.ChannelType
a.baseURL = info.ChannelBaseUrl
a.apiKey = info.ApiKey
}
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
// 阿里通义万相支持 JSON 格式,不使用 multipart
var taskReq relaycommon.TaskSubmitReq
if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil {
return service.TaskErrorWrapper(err, "unmarshal_task_request_failed", http.StatusBadRequest)
}
aliReq, err := a.convertToAliRequest(info, taskReq)
if err != nil {
return service.TaskErrorWrapper(err, "convert_to_ali_request_failed", http.StatusInternalServerError)
}
a.aliReq = aliReq
logger.LogJson(c, "ali video request body", aliReq)
return relaycommon.ValidateMultipartDirect(c, info)
}
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/v1/services/aigc/video-generation/video-synthesis", a.baseURL), nil
}
// BuildRequestHeader sets required headers for Ali API
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
req.Header.Set("Authorization", "Bearer "+a.apiKey)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-DashScope-Async", "enable") // 阿里异步任务必须设置
return nil
}
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
bodyBytes, err := common.Marshal(a.aliReq)
if err != nil {
return nil, errors.Wrap(err, "marshal_ali_request_failed")
}
return bytes.NewReader(bodyBytes), nil
}
var (
size480p = []string{
"832*480",
"480*832",
"624*624",
}
size720p = []string{
"1280*720",
"720*1280",
"960*960",
"1088*832",
"832*1088",
}
size1080p = []string{
"1920*1080",
"1080*1920",
"1440*1440",
"1632*1248",
"1248*1632",
}
)
func sizeToResolution(size string) (string, error) {
if lo.Contains(size480p, size) {
return "480P", nil
} else if lo.Contains(size720p, size) {
return "720P", nil
} else if lo.Contains(size1080p, size) {
return "1080P", nil
}
return "", fmt.Errorf("invalid size: %s", size)
}
func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error) {
otherRatios := make(map[string]float64)
aliRatios := map[string]map[string]float64{
"wan2.5-t2v-preview": {
"480P": 1,
"720P": 2,
"1080P": 1 / 0.3,
},
"wan2.2-t2v-plus": {
"480P": 1,
"1080P": 0.7 / 0.14,
},
"wan2.5-i2v-preview": {
"480P": 1,
"720P": 2,
"1080P": 1 / 0.3,
},
"wan2.2-i2v-plus": {
"480P": 1,
"1080P": 0.7 / 0.14,
},
"wan2.2-kf2v-flash": {
"480P": 1,
"720P": 2,
"1080P": 4.8,
},
"wan2.2-i2v-flash": {
"480P": 1,
"720P": 2,
},
"wan2.2-s2v": {
"480P": 1,
"720P": 0.9 / 0.5,
},
}
var resolution string
// size match
if aliReq.Parameters.Size != "" {
toResolution, err := sizeToResolution(aliReq.Parameters.Size)
if err != nil {
return nil, err
}
resolution = toResolution
} else {
resolution = strings.ToUpper(aliReq.Parameters.Resolution)
if !strings.HasSuffix(resolution, "P") {
resolution = resolution + "P"
}
}
if otherRatio, ok := aliRatios[aliReq.Model]; ok {
if ratio, ok := otherRatio[resolution]; ok {
otherRatios[fmt.Sprintf("resolution-%s", resolution)] = ratio
}
}
return otherRatios, nil
}
func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) {
aliReq := &AliVideoRequest{
Model: req.Model,
Input: AliVideoInput{
Prompt: req.Prompt,
ImgURL: req.InputReference,
},
Parameters: &AliVideoParameters{
PromptExtend: true, // 默认开启智能改写
Watermark: false,
},
}
// 处理分辨率映射
if req.Size != "" {
// text to video size must be contained *
if strings.Contains(req.Model, "t2v") && !strings.Contains(req.Size, "*") {
return nil, fmt.Errorf("invalid size: %s, example: %s", req.Size, "1920*1080")
}
if strings.Contains(req.Size, "*") {
aliReq.Parameters.Size = req.Size
} else {
resolution := strings.ToUpper(req.Size)
// 支持 480p, 720p, 1080p 或 480P, 720P, 1080P
if !strings.HasSuffix(resolution, "P") {
resolution = resolution + "P"
}
aliReq.Parameters.Resolution = resolution
}
} else {
// 根据模型设置默认分辨率
if strings.Contains(req.Model, "t2v") { // image to video
if strings.HasPrefix(req.Model, "wan2.5") {
aliReq.Parameters.Size = "1920*1080"
} else if strings.HasPrefix(req.Model, "wan2.2") {
aliReq.Parameters.Size = "1920*1080"
} else {
aliReq.Parameters.Size = "1280*720"
}
} else {
if strings.HasPrefix(req.Model, "wan2.5") {
aliReq.Parameters.Resolution = "1080P"
} else if strings.HasPrefix(req.Model, "wan2.2-i2v-flash") {
aliReq.Parameters.Resolution = "720P"
} else if strings.HasPrefix(req.Model, "wan2.2-i2v-plus") {
aliReq.Parameters.Resolution = "1080P"
} else {
aliReq.Parameters.Resolution = "720P"
}
}
}
// 处理时长
if req.Duration > 0 {
aliReq.Parameters.Duration = req.Duration
} else if req.Seconds != "" {
seconds, err := strconv.Atoi(req.Seconds)
if err != nil {
return nil, errors.Wrap(err, "convert seconds to int failed")
} else {
aliReq.Parameters.Duration = seconds
}
} else {
aliReq.Parameters.Duration = 5 // 默认5秒
}
// 从 metadata 中提取额外参数
if req.Metadata != nil {
if metadataBytes, err := common.Marshal(req.Metadata); err == nil {
err = common.Unmarshal(metadataBytes, aliReq)
if err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
} else {
return nil, errors.Wrap(err, "marshal metadata failed")
}
}
if aliReq.Model != req.Model {
return nil, errors.New("can't change model with metadata")
}
info.PriceData.OtherRatios = map[string]float64{
"seconds": float64(aliReq.Parameters.Duration),
}
ratios, err := ProcessAliOtherRatios(aliReq)
if err != nil {
return nil, err
}
for s, f := range ratios {
info.PriceData.OtherRatios[s] = f
}
return aliReq, nil
}
// DoRequest delegates to common helper
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoTaskApiRequest(a, c, info, requestBody)
}
// DoResponse handles upstream response
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
return
}
_ = resp.Body.Close()
// 解析阿里响应
var aliResp AliVideoResponse
if err := common.Unmarshal(responseBody, &aliResp); err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
return
}
// 检查错误
if aliResp.Code != "" {
taskErr = service.TaskErrorWrapper(fmt.Errorf("%s: %s", aliResp.Code, aliResp.Message), "ali_api_error", resp.StatusCode)
return
}
if aliResp.Output.TaskID == "" {
taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
return
}
// 转换为 OpenAI 格式响应
openAIResp := dto.NewOpenAIVideo()
openAIResp.ID = aliResp.Output.TaskID
openAIResp.Model = c.GetString("model")
if openAIResp.Model == "" && info != nil {
openAIResp.Model = info.OriginModelName
}
openAIResp.Status = convertAliStatus(aliResp.Output.TaskStatus)
openAIResp.CreatedAt = common.GetTimestamp()
// 返回 OpenAI 格式
c.JSON(http.StatusOK, openAIResp)
return aliResp.Output.TaskID, responseBody, nil
}
// FetchTask 查询任务状态
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
}
uri := fmt.Sprintf("%s/api/v1/tasks/%s", baseUrl, taskID)
req, err := http.NewRequest(http.MethodGet, uri, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+key)
return service.GetHttpClient().Do(req)
}
func (a *TaskAdaptor) GetModelList() []string {
return ModelList
}
func (a *TaskAdaptor) GetChannelName() string {
return ChannelName
}
// ParseTaskResult 解析任务结果
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
var aliResp AliVideoResponse
if err := common.Unmarshal(respBody, &aliResp); err != nil {
return nil, errors.Wrap(err, "unmarshal task result failed")
}
taskResult := relaycommon.TaskInfo{
Code: 0,
}
// 状态映射
switch aliResp.Output.TaskStatus {
case "PENDING":
taskResult.Status = model.TaskStatusQueued
case "RUNNING":
taskResult.Status = model.TaskStatusInProgress
case "SUCCEEDED":
taskResult.Status = model.TaskStatusSuccess
// 阿里直接返回视频URL不需要额外的代理端点
taskResult.Url = aliResp.Output.VideoURL
case "FAILED", "CANCELED", "UNKNOWN":
taskResult.Status = model.TaskStatusFailure
if aliResp.Message != "" {
taskResult.Reason = aliResp.Message
} else if aliResp.Output.Message != "" {
taskResult.Reason = fmt.Sprintf("task failed, code: %s , message: %s", aliResp.Output.Code, aliResp.Output.Message)
} else {
taskResult.Reason = "task failed"
}
default:
taskResult.Status = model.TaskStatusQueued
}
return &taskResult, nil
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
var aliResp AliVideoResponse
if err := common.Unmarshal(task.Data, &aliResp); err != nil {
return nil, errors.Wrap(err, "unmarshal ali response failed")
}
openAIResp := dto.NewOpenAIVideo()
openAIResp.ID = task.TaskID
openAIResp.Status = convertAliStatus(aliResp.Output.TaskStatus)
openAIResp.Model = task.Properties.OriginModelName
openAIResp.SetProgressStr(task.Progress)
openAIResp.CreatedAt = task.CreatedAt
openAIResp.CompletedAt = task.UpdatedAt
// 设置视频URL核心字段
openAIResp.SetMetadata("url", aliResp.Output.VideoURL)
// 错误处理
if aliResp.Code != "" {
openAIResp.Error = &dto.OpenAIVideoError{
Code: aliResp.Code,
Message: aliResp.Message,
}
} else if aliResp.Output.Code != "" {
openAIResp.Error = &dto.OpenAIVideoError{
Code: aliResp.Output.Code,
Message: aliResp.Output.Message,
}
}
return common.Marshal(openAIResp)
}
func convertAliStatus(aliStatus string) string {
switch aliStatus {
case "PENDING":
return dto.VideoStatusQueued
case "RUNNING":
return dto.VideoStatusInProgress
case "SUCCEEDED":
return dto.VideoStatusCompleted
case "FAILED", "CANCELED", "UNKNOWN":
return dto.VideoStatusFailed
default:
return dto.VideoStatusUnknown
}
}

View File

@@ -1,11 +0,0 @@
package ali
var ModelList = []string{
"wan2.5-i2v-preview", // 万相2.5 preview有声视频推荐
"wan2.2-i2v-flash", // 万相2.2极速版(无声视频)
"wan2.2-i2v-plus", // 万相2.2专业版(无声视频)
"wanx2.1-i2v-plus", // 万相2.1专业版(无声视频)
"wanx2.1-i2v-turbo", // 万相2.1极速版(无声视频)
}
var ChannelName = "ali"

View File

@@ -1,324 +0,0 @@
package gemini
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
)
// ============================
// Request / Response structures
// ============================
// GeminiVideoGenerationConfig represents the video generation configuration
// Based on: https://ai.google.dev/gemini-api/docs/video
type GeminiVideoGenerationConfig struct {
AspectRatio string `json:"aspectRatio,omitempty"` // "16:9" or "9:16"
DurationSeconds float64 `json:"durationSeconds,omitempty"` // 4, 6, or 8 (as number)
NegativePrompt string `json:"negativePrompt,omitempty"` // unwanted elements
PersonGeneration string `json:"personGeneration,omitempty"` // "allow_all" for text-to-video, "allow_adult" for image-to-video
Resolution string `json:"resolution,omitempty"` // video resolution
}
// GeminiVideoRequest represents a single video generation instance
type GeminiVideoRequest struct {
Prompt string `json:"prompt"`
}
// GeminiVideoPayload represents the complete video generation request payload
type GeminiVideoPayload struct {
Instances []GeminiVideoRequest `json:"instances"`
Parameters GeminiVideoGenerationConfig `json:"parameters,omitempty"`
}
type submitResponse struct {
Name string `json:"name"`
}
type operationVideo struct {
MimeType string `json:"mimeType"`
BytesBase64Encoded string `json:"bytesBase64Encoded"`
Encoding string `json:"encoding"`
}
type operationResponse struct {
Name string `json:"name"`
Done bool `json:"done"`
Response struct {
Type string `json:"@type"`
RaiMediaFilteredCount int `json:"raiMediaFilteredCount"`
Videos []operationVideo `json:"videos"`
BytesBase64Encoded string `json:"bytesBase64Encoded"`
Encoding string `json:"encoding"`
Video string `json:"video"`
GenerateVideoResponse struct {
GeneratedSamples []struct {
Video struct {
URI string `json:"uri"`
} `json:"video"`
} `json:"generatedSamples"`
} `json:"generateVideoResponse"`
} `json:"response"`
Error struct {
Message string `json:"message"`
} `json:"error"`
}
// ============================
// Adaptor implementation
// ============================
type TaskAdaptor struct {
ChannelType int
apiKey string
baseURL string
}
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
a.ChannelType = info.ChannelType
a.baseURL = info.ChannelBaseUrl
a.apiKey = info.ApiKey
}
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
// Use the standard validation method for TaskSubmitReq
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionTextGenerate)
}
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
modelName := info.OriginModelName
version := model_setting.GetGeminiVersionSetting(modelName)
return fmt.Sprintf(
"%s/%s/models/%s:predictLongRunning",
a.baseURL,
version,
modelName,
), nil
}
// BuildRequestHeader sets required headers.
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("x-goog-api-key", a.apiKey)
return nil
}
// BuildRequestBody converts request into Gemini specific format.
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
v, ok := c.Get("task_request")
if !ok {
return nil, fmt.Errorf("request not found in context")
}
req, ok := v.(relaycommon.TaskSubmitReq)
if !ok {
return nil, fmt.Errorf("unexpected task_request type")
}
// Create structured video generation request
body := GeminiVideoPayload{
Instances: []GeminiVideoRequest{
{Prompt: req.Prompt},
},
Parameters: GeminiVideoGenerationConfig{},
}
metadata := req.Metadata
medaBytes, err := json.Marshal(metadata)
if err != nil {
return nil, errors.Wrap(err, "metadata marshal metadata failed")
}
err = json.Unmarshal(medaBytes, &body.Parameters)
if err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
data, err := json.Marshal(body)
if err != nil {
return nil, err
}
return bytes.NewReader(data), nil
}
// DoRequest delegates to common helper.
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoTaskApiRequest(a, c, info, requestBody)
}
// DoResponse handles upstream response, returns taskID etc.
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
_ = resp.Body.Close()
var s submitResponse
if err := json.Unmarshal(responseBody, &s); err != nil {
return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
}
if strings.TrimSpace(s.Name) == "" {
return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
}
taskID = encodeLocalTaskID(s.Name)
ov := dto.NewOpenAIVideo()
ov.ID = taskID
ov.TaskID = taskID
ov.CreatedAt = time.Now().Unix()
ov.Model = info.OriginModelName
c.JSON(http.StatusOK, ov)
return taskID, responseBody, nil
}
func (a *TaskAdaptor) GetModelList() []string {
return []string{"veo-3.0-generate-001", "veo-3.1-generate-preview", "veo-3.1-fast-generate-preview"}
}
func (a *TaskAdaptor) GetChannelName() string {
return "gemini"
}
// FetchTask fetch task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
}
upstreamName, err := decodeLocalTaskID(taskID)
if err != nil {
return nil, fmt.Errorf("decode task_id failed: %w", err)
}
// For Gemini API, we use GET request to the operations endpoint
version := model_setting.GetGeminiVersionSetting("default")
url := fmt.Sprintf("%s/%s/%s", baseUrl, version, upstreamName)
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/json")
req.Header.Set("x-goog-api-key", key)
return service.GetHttpClient().Do(req)
}
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
var op operationResponse
if err := json.Unmarshal(respBody, &op); err != nil {
return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
}
ti := &relaycommon.TaskInfo{}
if op.Error.Message != "" {
ti.Status = model.TaskStatusFailure
ti.Reason = op.Error.Message
ti.Progress = "100%"
return ti, nil
}
if !op.Done {
ti.Status = model.TaskStatusInProgress
ti.Progress = "50%"
return ti, nil
}
ti.Status = model.TaskStatusSuccess
ti.Progress = "100%"
taskID := encodeLocalTaskID(op.Name)
ti.TaskID = taskID
ti.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
// Extract URL from generateVideoResponse if available
if len(op.Response.GenerateVideoResponse.GeneratedSamples) > 0 {
if uri := op.Response.GenerateVideoResponse.GeneratedSamples[0].Video.URI; uri != "" {
ti.RemoteUrl = uri
}
}
return ti, nil
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
upstreamName, err := decodeLocalTaskID(task.TaskID)
if err != nil {
upstreamName = ""
}
modelName := extractModelFromOperationName(upstreamName)
if strings.TrimSpace(modelName) == "" {
modelName = "veo-3.0-generate-001"
}
video := dto.NewOpenAIVideo()
video.ID = task.TaskID
video.Model = modelName
video.Status = task.Status.ToVideoStatus()
video.SetProgressStr(task.Progress)
video.CreatedAt = task.CreatedAt
if task.FinishTime > 0 {
video.CompletedAt = task.FinishTime
} else if task.UpdatedAt > 0 {
video.CompletedAt = task.UpdatedAt
}
return common.Marshal(video)
}
// ============================
// helpers
// ============================
func encodeLocalTaskID(name string) string {
return base64.RawURLEncoding.EncodeToString([]byte(name))
}
func decodeLocalTaskID(local string) (string, error) {
b, err := base64.RawURLEncoding.DecodeString(local)
if err != nil {
return "", err
}
return string(b), nil
}
var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`)
func extractModelFromOperationName(name string) string {
if name == "" {
return ""
}
if m := modelRe.FindStringSubmatch(name); len(m) == 2 {
return m[1]
}
if idx := strings.Index(name, "models/"); idx >= 0 {
s := name[idx+len("models/"):]
if p := strings.Index(s, "/operations/"); p > 0 {
return s[:p]
}
}
return ""
}

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
@@ -15,7 +14,6 @@ import (
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin"
@@ -66,11 +64,6 @@ type responseTask struct {
TimeElapsed string `json:"time_elapsed"`
}
const (
// 即梦限制单个文件最大4.7MB https://www.volcengine.com/docs/85621/1747301
MaxFileSize int64 = 4*1024*1024 + 700*1024 // 4.7MB (4MB + 724KB)
)
// ============================
// Adaptor implementation
// ============================
@@ -96,6 +89,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
// Accept only POST /v1/video/generations as "generate" action.
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
}
@@ -119,49 +113,13 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
return nil
}
// BuildRequestBody converts request into Jimeng specific format.
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
v, exists := c.Get("task_request")
if !exists {
return nil, fmt.Errorf("request not found in context")
}
req, ok := v.(relaycommon.TaskSubmitReq)
if !ok {
return nil, fmt.Errorf("invalid request type in context")
}
// 支持openai sdk的图片上传方式
if mf, err := c.MultipartForm(); err == nil {
if files, exists := mf.File["input_reference"]; exists && len(files) > 0 {
if len(files) == 1 {
info.Action = constant.TaskActionGenerate
} else if len(files) > 1 {
info.Action = constant.TaskActionFirstTailGenerate
}
// 将上传的文件转换为base64格式
var images []string
for _, fileHeader := range files {
// 检查文件大小
if fileHeader.Size > MaxFileSize {
return nil, fmt.Errorf("文件 %s 大小超过限制,最大允许 %d MB", fileHeader.Filename, MaxFileSize/(1024*1024))
}
file, err := fileHeader.Open()
if err != nil {
continue
}
fileBytes, err := io.ReadAll(file)
file.Close()
if err != nil {
continue
}
// 将文件内容转换为base64
base64Str := base64.StdEncoding.EncodeToString(fileBytes)
images = append(images, base64Str)
}
req.Images = images
}
}
req := v.(relaycommon.TaskSubmitReq)
body, err := a.convertToRequestPayload(&req)
if err != nil {
@@ -200,7 +158,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
return
}
ov := dto.NewOpenAIVideo()
ov := relaycommon.NewOpenAIVideo()
ov.ID = jResp.Data.TaskID
ov.TaskID = jResp.Data.TaskID
ov.CreatedAt = time.Now().Unix()
@@ -406,15 +364,12 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
// 即梦视频3.0 ReqKey转换
// https://www.volcengine.com/docs/85621/1792707
if strings.Contains(r.ReqKey, "jimeng_v30") {
if r.ReqKey == "jimeng_v30_pro" {
// 3.0 pro只有固定的jimeng_ti2v_v30_pro
r.ReqKey = "jimeng_ti2v_v30_pro"
} else if len(req.Images) > 1 {
if len(r.ImageUrls) > 1 {
// 多张图片:首尾帧生成
r.ReqKey = strings.TrimSuffix(strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1), "p")
} else if len(req.Images) == 1 {
r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1)
} else if len(r.ImageUrls) == 1 {
// 单张图片:图生视频
r.ReqKey = strings.TrimSuffix(strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1), "p")
r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1)
} else {
// 无图片:文生视频
r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_t2v_v30", 1)
@@ -450,13 +405,13 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
return &taskResult, nil
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*relaycommon.OpenAIVideo, error) {
var jimengResp responseTask
if err := json.Unmarshal(originTask.Data, &jimengResp); err != nil {
return nil, errors.Wrap(err, "unmarshal jimeng task data failed")
}
openAIVideo := dto.NewOpenAIVideo()
openAIVideo := relaycommon.NewOpenAIVideo()
openAIVideo.ID = originTask.TaskID
openAIVideo.Status = originTask.Status.ToVideoStatus()
openAIVideo.SetProgressStr(originTask.Progress)
@@ -465,14 +420,13 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
openAIVideo.CompletedAt = originTask.UpdatedAt
if jimengResp.Code != 10000 {
openAIVideo.Error = &dto.OpenAIVideoError{
openAIVideo.Error = &relaycommon.OpenAIVideoError{
Message: jimengResp.Message,
Code: fmt.Sprintf("%d", jimengResp.Code),
}
}
jsonData, _ := common.Marshal(openAIVideo)
return jsonData, nil
return openAIVideo, nil
}
func isNewAPIRelay(apiKey string) bool {

View File

@@ -9,7 +9,6 @@ import (
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/samber/lo"
@@ -189,7 +188,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf(kResp.Message), "task_failed", http.StatusBadRequest)
return
}
ov := dto.NewOpenAIVideo()
ov := relaycommon.NewOpenAIVideo()
ov.ID = kResp.Data.TaskId
ov.TaskID = kResp.Data.TaskId
ov.CreatedAt = time.Now().Unix()
@@ -368,13 +367,13 @@ func isNewAPIRelay(apiKey string) bool {
return strings.HasPrefix(apiKey, "sk-")
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*relaycommon.OpenAIVideo, error) {
var klingResp responsePayload
if err := json.Unmarshal(originTask.Data, &klingResp); err != nil {
return nil, errors.Wrap(err, "unmarshal kling task data failed")
}
openAIVideo := dto.NewOpenAIVideo()
openAIVideo := relaycommon.NewOpenAIVideo()
openAIVideo.ID = originTask.TaskID
openAIVideo.Status = originTask.Status.ToVideoStatus()
openAIVideo.SetProgressStr(originTask.Progress)
@@ -392,11 +391,11 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
}
if klingResp.Code != 0 && klingResp.Message != "" {
openAIVideo.Error = &dto.OpenAIVideoError{
openAIVideo.Error = &relaycommon.OpenAIVideoError{
Message: klingResp.Message,
Code: fmt.Sprintf("%d", klingResp.Code),
}
}
jsonData, _ := common.Marshal(openAIVideo)
return jsonData, nil
return openAIVideo, nil
}

View File

@@ -2,6 +2,7 @@ package sora
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -106,7 +107,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco
// Parse Sora response
var dResp responseTask
if err := common.Unmarshal(responseBody, &dResp); err != nil {
if err := json.Unmarshal(responseBody, &dResp); err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
return
}
@@ -153,7 +154,7 @@ func (a *TaskAdaptor) GetChannelName() string {
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
resTask := responseTask{}
if err := common.Unmarshal(respBody, &resTask); err != nil {
if err := json.Unmarshal(respBody, &resTask); err != nil {
return nil, errors.Wrap(err, "unmarshal task result failed")
}
@@ -185,6 +186,11 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
return &taskResult, nil
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
return task.Data, nil
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) (*relaycommon.OpenAIVideo, error) {
openAIVideo := &relaycommon.OpenAIVideo{}
err := json.Unmarshal(task.Data, openAIVideo)
if err != nil {
return nil, errors.Wrap(err, "unmarshal to OpenAIVideo failed")
}
return openAIVideo, nil
}

View File

@@ -10,7 +10,6 @@ import (
"regexp"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin"
@@ -303,29 +302,6 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
return ti, nil
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
upstreamName, err := decodeLocalTaskID(task.TaskID)
if err != nil {
upstreamName = ""
}
modelName := extractModelFromOperationName(upstreamName)
if strings.TrimSpace(modelName) == "" {
modelName = "veo-3.0-generate-001"
}
v := dto.NewOpenAIVideo()
v.ID = task.TaskID
v.Model = modelName
v.Status = task.Status.ToVideoStatus()
v.SetProgressStr(task.Progress)
v.CreatedAt = task.CreatedAt
v.CompletedAt = task.UpdatedAt
if strings.HasPrefix(task.FailReason, "data:") && len(task.FailReason) > 0 {
v.SetMetadata("url", task.FailReason)
}
return common.Marshal(v)
}
// ============================
// helpers
// ============================

View File

@@ -8,7 +8,6 @@ import (
"net/http"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/gin-gonic/gin"
"github.com/QuantumNous/new-api/constant"
@@ -156,7 +155,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
return
}
ov := dto.NewOpenAIVideo()
ov := relaycommon.NewOpenAIVideo()
ov.ID = vResp.TaskId
ov.TaskID = vResp.TaskId
ov.CreatedAt = time.Now().Unix()
@@ -264,13 +263,13 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
return taskInfo, nil
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*relaycommon.OpenAIVideo, error) {
var viduResp taskResultResponse
if err := json.Unmarshal(originTask.Data, &viduResp); err != nil {
return nil, errors.Wrap(err, "unmarshal vidu task data failed")
}
openAIVideo := dto.NewOpenAIVideo()
openAIVideo := relaycommon.NewOpenAIVideo()
openAIVideo.ID = originTask.TaskID
openAIVideo.Status = originTask.Status.ToVideoStatus()
openAIVideo.SetProgressStr(originTask.Progress)
@@ -282,12 +281,11 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
}
if viduResp.State == "failed" && viduResp.ErrCode != "" {
openAIVideo.Error = &dto.OpenAIVideoError{
openAIVideo.Error = &relaycommon.OpenAIVideoError{
Message: viduResp.ErrCode,
Code: viduResp.ErrCode,
}
}
jsonData, _ := common.Marshal(openAIVideo)
return jsonData, nil
return openAIVideo, nil
}

View File

@@ -6,7 +6,9 @@ import (
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/textproto"
"path/filepath"
"strings"
@@ -21,11 +23,6 @@ import (
"github.com/gin-gonic/gin"
)
const (
contextKeyTTSRequest = "volcengine_tts_request"
contextKeyResponseFormat = "response_format"
)
type Adaptor struct {
}
@@ -40,175 +37,136 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
if info.RelayMode != constant.RelayModeAudioSpeech {
return nil, errors.New("unsupported audio relay mode")
}
appID, token, err := parseVolcengineAuth(info.ApiKey)
if err != nil {
return nil, err
}
voiceType := mapVoiceType(request.Voice)
speedRatio := request.Speed
encoding := mapEncoding(request.ResponseFormat)
c.Set(contextKeyResponseFormat, encoding)
volcRequest := VolcengineTTSRequest{
App: VolcengineTTSApp{
AppID: appID,
Token: token,
Cluster: "volcano_tts",
},
User: VolcengineTTSUser{
UID: "openai_relay_user",
},
Audio: VolcengineTTSAudio{
VoiceType: voiceType,
Encoding: encoding,
SpeedRatio: speedRatio,
Rate: 24000,
},
Request: VolcengineTTSReqInfo{
ReqID: generateRequestID(),
Text: request.Input,
Operation: "submit",
Model: info.OriginModelName,
},
}
if len(request.Metadata) > 0 {
if err = json.Unmarshal(request.Metadata, &volcRequest); err != nil {
return nil, fmt.Errorf("error unmarshalling metadata to volcengine request: %w", err)
}
}
c.Set(contextKeyTTSRequest, volcRequest)
if volcRequest.Request.Operation == "submit" {
info.IsStream = true
}
jsonData, err := json.Marshal(volcRequest)
if err != nil {
return nil, fmt.Errorf("error marshalling volcengine request: %w", err)
}
return bytes.NewReader(jsonData), nil
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
switch info.RelayMode {
case constant.RelayModeImagesGenerations:
return request, nil
// 根据官方文档,并没有发现豆包生图支持表单请求:https://www.volcengine.com/docs/82379/1824121
//case constant.RelayModeImagesEdits:
//
// var requestBody bytes.Buffer
// writer := multipart.NewWriter(&requestBody)
//
// writer.WriteField("model", request.Model)
//
// formData := c.Request.PostForm
// for key, values := range formData {
// if key == "model" {
// continue
// }
// for _, value := range values {
// writer.WriteField(key, value)
// }
// }
//
// if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
// return nil, errors.New("failed to parse multipart form")
// }
//
// if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
// var imageFiles []*multipart.FileHeader
// var exists bool
//
// if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
// if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
// foundArrayImages := false
// for fieldName, files := range c.Request.MultipartForm.File {
// if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
// foundArrayImages = true
// for _, file := range files {
// imageFiles = append(imageFiles, file)
// }
// }
// }
//
// if !foundArrayImages && (len(imageFiles) == 0) {
// return nil, errors.New("image is required")
// }
// }
// }
//
// for i, fileHeader := range imageFiles {
// file, err := fileHeader.Open()
// if err != nil {
// return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
// }
// defer file.Close()
//
// fieldName := "image"
// if len(imageFiles) > 1 {
// fieldName = "image[]"
// }
//
// mimeType := detectImageMimeType(fileHeader.Filename)
//
// h := make(textproto.MIMEHeader)
// h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
// h.Set("Content-Type", mimeType)
//
// part, err := writer.CreatePart(h)
// if err != nil {
// return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
// }
//
// if _, err := io.Copy(part, file); err != nil {
// return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
// }
// }
//
// if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
// maskFile, err := maskFiles[0].Open()
// if err != nil {
// return nil, errors.New("failed to open mask file")
// }
// defer maskFile.Close()
//
// mimeType := detectImageMimeType(maskFiles[0].Filename)
//
// h := make(textproto.MIMEHeader)
// h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
// h.Set("Content-Type", mimeType)
//
// maskPart, err := writer.CreatePart(h)
// if err != nil {
// return nil, errors.New("create form file failed for mask")
// }
//
// if _, err := io.Copy(maskPart, maskFile); err != nil {
// return nil, errors.New("copy mask file failed")
// }
// }
// } else {
// return nil, errors.New("no multipart form data found")
// }
//
// writer.Close()
// c.Request.Header.Set("Content-Type", writer.FormDataContentType())
// return bytes.NewReader(requestBody.Bytes()), nil
case constant.RelayModeImagesEdits:
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
writer.WriteField("model", request.Model)
// 获取所有表单字段
formData := c.Request.PostForm
// 遍历表单字段并打印输出
for key, values := range formData {
if key == "model" {
continue
}
for _, value := range values {
writer.WriteField(key, value)
}
}
// Parse the multipart form to handle both single image and multiple images
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
return nil, errors.New("failed to parse multipart form")
}
if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
// Check if "image" field exists in any form, including array notation
var imageFiles []*multipart.FileHeader
var exists bool
// First check for standard "image" field
if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
// If not found, check for "image[]" field
if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
// If still not found, iterate through all fields to find any that start with "image["
foundArrayImages := false
for fieldName, files := range c.Request.MultipartForm.File {
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
foundArrayImages = true
for _, file := range files {
imageFiles = append(imageFiles, file)
}
}
}
// If no image fields found at all
if !foundArrayImages && (len(imageFiles) == 0) {
return nil, errors.New("image is required")
}
}
}
// Process all image files
for i, fileHeader := range imageFiles {
file, err := fileHeader.Open()
if err != nil {
return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
}
defer file.Close()
// If multiple images, use image[] as the field name
fieldName := "image"
if len(imageFiles) > 1 {
fieldName = "image[]"
}
// Determine MIME type based on file extension
mimeType := detectImageMimeType(fileHeader.Filename)
// Create a form file with the appropriate content type
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
h.Set("Content-Type", mimeType)
part, err := writer.CreatePart(h)
if err != nil {
return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
}
if _, err := io.Copy(part, file); err != nil {
return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
}
}
// Handle mask file if present
if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
maskFile, err := maskFiles[0].Open()
if err != nil {
return nil, errors.New("failed to open mask file")
}
defer maskFile.Close()
// Determine MIME type for mask file
mimeType := detectImageMimeType(maskFiles[0].Filename)
// Create a form file with the appropriate content type
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
h.Set("Content-Type", mimeType)
maskPart, err := writer.CreatePart(h)
if err != nil {
return nil, errors.New("create form file failed for mask")
}
if _, err := io.Copy(maskPart, maskFile); err != nil {
return nil, errors.New("copy mask file failed")
}
}
} else {
return nil, errors.New("no multipart form data found")
}
// 关闭 multipart 编写器以设置分界线
writer.Close()
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
return bytes.NewReader(requestBody.Bytes()), nil
default:
return request, nil
}
}
// detectImageMimeType determines the MIME type based on the file extension
func detectImageMimeType(filename string) string {
ext := strings.ToLower(filepath.Ext(filename))
switch ext {
@@ -219,9 +177,11 @@ func detectImageMimeType(filename string) string {
case ".webp":
return "image/webp"
default:
// Try to detect from extension if possible
if strings.HasPrefix(ext, ".jp") {
return "image/jpeg"
}
// Default to png as a fallback
return "image/png"
}
}
@@ -230,6 +190,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// 支持自定义域名,如果未设置则使用默认域名
baseUrl := info.ChannelBaseUrl
if baseUrl == "" {
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
@@ -250,18 +211,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil
//豆包的图生图也走generations接口: https://www.volcengine.com/docs/82379/1824121
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
case constant.RelayModeImagesGenerations:
return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil
//case constant.RelayModeImagesEdits:
// return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
case constant.RelayModeImagesEdits:
return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
case constant.RelayModeRerank:
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
case constant.RelayModeAudioSpeech:
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
return "wss://openspeech.bytedance.com/api/v1/tts/ws_binary", nil
}
return fmt.Sprintf("%s/v1/audio/speech", baseUrl), nil
default:
}
}
@@ -270,18 +225,6 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
if info.RelayMode == constant.RelayModeAudioSpeech {
parts := strings.Split(info.ApiKey, "|")
if len(parts) == 2 {
req.Set("Authorization", "Bearer;"+parts[1])
}
req.Set("Content-Type", "application/json")
return nil
} else if info.RelayMode == constant.RelayModeImagesEdits {
req.Set("Content-Type", gin.MIMEJSON)
}
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
@@ -290,7 +233,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
// 适配 方舟deepseek混合模型 的 thinking 后缀
if strings.HasSuffix(info.UpstreamModelName, "-thinking") && strings.HasPrefix(info.UpstreamModelName, "deepseek") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
request.Model = info.UpstreamModelName
@@ -308,61 +251,15 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
if info.RelayMode == constant.RelayModeAudioSpeech {
baseUrl := info.ChannelBaseUrl
if baseUrl == "" {
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
}
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
if info.IsStream {
return nil, nil
}
}
}
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeAudioSpeech {
encoding := mapEncoding(c.GetString(contextKeyResponseFormat))
if info.IsStream {
volcRequestInterface, exists := c.Get(contextKeyTTSRequest)
if !exists {
return nil, types.NewErrorWithStatusCode(
errors.New("volcengine TTS request not found in context"),
types.ErrorCodeBadRequestBody,
http.StatusInternalServerError,
)
}
volcRequest, ok := volcRequestInterface.(VolcengineTTSRequest)
if !ok {
return nil, types.NewErrorWithStatusCode(
errors.New("invalid volcengine TTS request type"),
types.ErrorCodeBadRequestBody,
http.StatusInternalServerError,
)
}
// Get the WebSocket URL
requestURL, urlErr := a.GetRequestURL(info)
if urlErr != nil {
return nil, types.NewErrorWithStatusCode(
urlErr,
types.ErrorCodeBadRequestBody,
http.StatusInternalServerError,
)
}
return handleTTSWebSocketResponse(c, requestURL, volcRequest, info, encoding)
}
return handleTTSResponse(c, resp, info, encoding)
}
adaptor := openai.Adaptor{}
usage, err = adaptor.DoResponse(c, resp, info)
return

View File

@@ -1,533 +0,0 @@
package volcengine
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"math"
"github.com/gorilla/websocket"
)
type (
EventType int32
MsgType uint8
MsgTypeFlagBits uint8
VersionBits uint8
HeaderSizeBits uint8
SerializationBits uint8
CompressionBits uint8
)
const (
MsgTypeFlagNoSeq MsgTypeFlagBits = 0
MsgTypeFlagPositiveSeq MsgTypeFlagBits = 0b1
MsgTypeFlagNegativeSeq MsgTypeFlagBits = 0b11
MsgTypeFlagWithEvent MsgTypeFlagBits = 0b100
)
const (
Version1 VersionBits = iota + 1
)
const (
HeaderSize4 HeaderSizeBits = iota + 1
)
const (
SerializationJSON SerializationBits = 0b1
)
const (
CompressionNone CompressionBits = 0
)
const (
MsgTypeFullClientRequest MsgType = 0b1
MsgTypeAudioOnlyClient MsgType = 0b10
MsgTypeFullServerResponse MsgType = 0b1001
MsgTypeAudioOnlyServer MsgType = 0b1011
MsgTypeFrontEndResultServer MsgType = 0b1100
MsgTypeError MsgType = 0b1111
)
func (t MsgType) String() string {
switch t {
case MsgTypeFullClientRequest:
return "MsgType_FullClientRequest"
case MsgTypeAudioOnlyClient:
return "MsgType_AudioOnlyClient"
case MsgTypeFullServerResponse:
return "MsgType_FullServerResponse"
case MsgTypeAudioOnlyServer:
return "MsgType_AudioOnlyServer"
case MsgTypeError:
return "MsgType_Error"
case MsgTypeFrontEndResultServer:
return "MsgType_FrontEndResultServer"
default:
return fmt.Sprintf("MsgType_(%d)", t)
}
}
const (
EventType_None EventType = 0
EventType_StartConnection EventType = 1
EventType_FinishConnection EventType = 2
EventType_ConnectionStarted EventType = 50
EventType_ConnectionFailed EventType = 51
EventType_ConnectionFinished EventType = 52
EventType_StartSession EventType = 100
EventType_CancelSession EventType = 101
EventType_FinishSession EventType = 102
EventType_SessionStarted EventType = 150
EventType_SessionCanceled EventType = 151
EventType_SessionFinished EventType = 152
EventType_SessionFailed EventType = 153
EventType_UsageResponse EventType = 154
EventType_TaskRequest EventType = 200
EventType_UpdateConfig EventType = 201
EventType_AudioMuted EventType = 250
EventType_SayHello EventType = 300
EventType_TTSSentenceStart EventType = 350
EventType_TTSSentenceEnd EventType = 351
EventType_TTSResponse EventType = 352
EventType_TTSEnded EventType = 359
EventType_PodcastRoundStart EventType = 360
EventType_PodcastRoundResponse EventType = 361
EventType_PodcastRoundEnd EventType = 362
EventType_ASRInfo EventType = 450
EventType_ASRResponse EventType = 451
EventType_ASREnded EventType = 459
EventType_ChatTTSText EventType = 500
EventType_ChatResponse EventType = 550
EventType_ChatEnded EventType = 559
EventType_SourceSubtitleStart EventType = 650
EventType_SourceSubtitleResponse EventType = 651
EventType_SourceSubtitleEnd EventType = 652
EventType_TranslationSubtitleStart EventType = 653
EventType_TranslationSubtitleResponse EventType = 654
EventType_TranslationSubtitleEnd EventType = 655
)
func (t EventType) String() string {
switch t {
case EventType_None:
return "EventType_None"
case EventType_StartConnection:
return "EventType_StartConnection"
case EventType_FinishConnection:
return "EventType_FinishConnection"
case EventType_ConnectionStarted:
return "EventType_ConnectionStarted"
case EventType_ConnectionFailed:
return "EventType_ConnectionFailed"
case EventType_ConnectionFinished:
return "EventType_ConnectionFinished"
case EventType_StartSession:
return "EventType_StartSession"
case EventType_CancelSession:
return "EventType_CancelSession"
case EventType_FinishSession:
return "EventType_FinishSession"
case EventType_SessionStarted:
return "EventType_SessionStarted"
case EventType_SessionCanceled:
return "EventType_SessionCanceled"
case EventType_SessionFinished:
return "EventType_SessionFinished"
case EventType_SessionFailed:
return "EventType_SessionFailed"
case EventType_UsageResponse:
return "EventType_UsageResponse"
case EventType_TaskRequest:
return "EventType_TaskRequest"
case EventType_UpdateConfig:
return "EventType_UpdateConfig"
case EventType_AudioMuted:
return "EventType_AudioMuted"
case EventType_SayHello:
return "EventType_SayHello"
case EventType_TTSSentenceStart:
return "EventType_TTSSentenceStart"
case EventType_TTSSentenceEnd:
return "EventType_TTSSentenceEnd"
case EventType_TTSResponse:
return "EventType_TTSResponse"
case EventType_TTSEnded:
return "EventType_TTSEnded"
case EventType_PodcastRoundStart:
return "EventType_PodcastRoundStart"
case EventType_PodcastRoundResponse:
return "EventType_PodcastRoundResponse"
case EventType_PodcastRoundEnd:
return "EventType_PodcastRoundEnd"
case EventType_ASRInfo:
return "EventType_ASRInfo"
case EventType_ASRResponse:
return "EventType_ASRResponse"
case EventType_ASREnded:
return "EventType_ASREnded"
case EventType_ChatTTSText:
return "EventType_ChatTTSText"
case EventType_ChatResponse:
return "EventType_ChatResponse"
case EventType_ChatEnded:
return "EventType_ChatEnded"
case EventType_SourceSubtitleStart:
return "EventType_SourceSubtitleStart"
case EventType_SourceSubtitleResponse:
return "EventType_SourceSubtitleResponse"
case EventType_SourceSubtitleEnd:
return "EventType_SourceSubtitleEnd"
case EventType_TranslationSubtitleStart:
return "EventType_TranslationSubtitleStart"
case EventType_TranslationSubtitleResponse:
return "EventType_TranslationSubtitleResponse"
case EventType_TranslationSubtitleEnd:
return "EventType_TranslationSubtitleEnd"
default:
return fmt.Sprintf("EventType_(%d)", t)
}
}
type Message struct {
Version VersionBits
HeaderSize HeaderSizeBits
MsgType MsgType
MsgTypeFlag MsgTypeFlagBits
Serialization SerializationBits
Compression CompressionBits
EventType EventType
SessionID string
ConnectID string
Sequence int32
ErrorCode uint32
Payload []byte
}
func NewMessageFromBytes(data []byte) (*Message, error) {
if len(data) < 3 {
return nil, fmt.Errorf("data too short: expected at least 3 bytes, got %d", len(data))
}
typeAndFlag := data[1]
msg, err := NewMessage(MsgType(typeAndFlag>>4), MsgTypeFlagBits(typeAndFlag&0b00001111))
if err != nil {
return nil, err
}
if err := msg.Unmarshal(data); err != nil {
return nil, err
}
return msg, nil
}
func NewMessage(msgType MsgType, flag MsgTypeFlagBits) (*Message, error) {
return &Message{
MsgType: msgType,
MsgTypeFlag: flag,
Version: Version1,
HeaderSize: HeaderSize4,
Serialization: SerializationJSON,
Compression: CompressionNone,
}, nil
}
func (m *Message) String() string {
switch m.MsgType {
case MsgTypeAudioOnlyServer, MsgTypeAudioOnlyClient:
if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq {
return fmt.Sprintf("%s, %s, Sequence: %d, PayloadSize: %d", m.MsgType, m.EventType, m.Sequence, len(m.Payload))
}
return fmt.Sprintf("%s, %s, PayloadSize: %d", m.MsgType, m.EventType, len(m.Payload))
case MsgTypeError:
return fmt.Sprintf("%s, %s, ErrorCode: %d, Payload: %s", m.MsgType, m.EventType, m.ErrorCode, string(m.Payload))
default:
if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq {
return fmt.Sprintf("%s, %s, Sequence: %d, Payload: %s",
m.MsgType, m.EventType, m.Sequence, string(m.Payload))
}
return fmt.Sprintf("%s, %s, Payload: %s", m.MsgType, m.EventType, string(m.Payload))
}
}
func (m *Message) Marshal() ([]byte, error) {
buf := new(bytes.Buffer)
header := []uint8{
uint8(m.Version)<<4 | uint8(m.HeaderSize),
uint8(m.MsgType)<<4 | uint8(m.MsgTypeFlag),
uint8(m.Serialization)<<4 | uint8(m.Compression),
}
headerSize := 4 * int(m.HeaderSize)
if padding := headerSize - len(header); padding > 0 {
header = append(header, make([]uint8, padding)...)
}
if err := binary.Write(buf, binary.BigEndian, header); err != nil {
return nil, err
}
writers, err := m.writers()
if err != nil {
return nil, err
}
for _, write := range writers {
if err := write(buf); err != nil {
return nil, err
}
}
return buf.Bytes(), nil
}
func (m *Message) Unmarshal(data []byte) error {
buf := bytes.NewBuffer(data)
versionAndHeaderSize, err := buf.ReadByte()
if err != nil {
return err
}
m.Version = VersionBits(versionAndHeaderSize >> 4)
m.HeaderSize = HeaderSizeBits(versionAndHeaderSize & 0b00001111)
_, err = buf.ReadByte()
if err != nil {
return err
}
serializationCompression, err := buf.ReadByte()
if err != nil {
return err
}
m.Serialization = SerializationBits(serializationCompression & 0b11110000)
m.Compression = CompressionBits(serializationCompression & 0b00001111)
headerSize := 4 * int(m.HeaderSize)
readSize := 3
if paddingSize := headerSize - readSize; paddingSize > 0 {
if n, err := buf.Read(make([]byte, paddingSize)); err != nil || n < paddingSize {
return fmt.Errorf("insufficient header bytes: expected %d, got %d", paddingSize, n)
}
}
readers, err := m.readers()
if err != nil {
return err
}
for _, read := range readers {
if err := read(buf); err != nil {
return err
}
}
if _, err := buf.ReadByte(); err != io.EOF {
return fmt.Errorf("unexpected data after message: %v", err)
}
return nil
}
func (m *Message) writers() (writers []func(*bytes.Buffer) error, _ error) {
if m.MsgTypeFlag == MsgTypeFlagWithEvent {
writers = append(writers, m.writeEvent, m.writeSessionID)
}
switch m.MsgType {
case MsgTypeFullClientRequest, MsgTypeFullServerResponse, MsgTypeFrontEndResultServer, MsgTypeAudioOnlyClient, MsgTypeAudioOnlyServer:
if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq {
writers = append(writers, m.writeSequence)
}
case MsgTypeError:
writers = append(writers, m.writeErrorCode)
default:
return nil, fmt.Errorf("unsupported message type: %d", m.MsgType)
}
writers = append(writers, m.writePayload)
return writers, nil
}
func (m *Message) writeEvent(buf *bytes.Buffer) error {
return binary.Write(buf, binary.BigEndian, m.EventType)
}
func (m *Message) writeSessionID(buf *bytes.Buffer) error {
switch m.EventType {
case EventType_StartConnection, EventType_FinishConnection,
EventType_ConnectionStarted, EventType_ConnectionFailed:
return nil
}
size := len(m.SessionID)
if size > math.MaxUint32 {
return fmt.Errorf("session ID size (%d) exceeds max(uint32)", size)
}
if err := binary.Write(buf, binary.BigEndian, uint32(size)); err != nil {
return err
}
buf.WriteString(m.SessionID)
return nil
}
func (m *Message) writeSequence(buf *bytes.Buffer) error {
return binary.Write(buf, binary.BigEndian, m.Sequence)
}
func (m *Message) writeErrorCode(buf *bytes.Buffer) error {
return binary.Write(buf, binary.BigEndian, m.ErrorCode)
}
func (m *Message) writePayload(buf *bytes.Buffer) error {
size := len(m.Payload)
if size > math.MaxUint32 {
return fmt.Errorf("payload size (%d) exceeds max(uint32)", size)
}
if err := binary.Write(buf, binary.BigEndian, uint32(size)); err != nil {
return err
}
buf.Write(m.Payload)
return nil
}
func (m *Message) readers() (readers []func(*bytes.Buffer) error, _ error) {
switch m.MsgType {
case MsgTypeFullClientRequest, MsgTypeFullServerResponse, MsgTypeFrontEndResultServer, MsgTypeAudioOnlyClient, MsgTypeAudioOnlyServer:
if m.MsgTypeFlag == MsgTypeFlagPositiveSeq || m.MsgTypeFlag == MsgTypeFlagNegativeSeq {
readers = append(readers, m.readSequence)
}
case MsgTypeError:
readers = append(readers, m.readErrorCode)
default:
return nil, fmt.Errorf("unsupported message type: %d", m.MsgType)
}
if m.MsgTypeFlag == MsgTypeFlagWithEvent {
readers = append(readers, m.readEvent, m.readSessionID, m.readConnectID)
}
readers = append(readers, m.readPayload)
return readers, nil
}
func (m *Message) readEvent(buf *bytes.Buffer) error {
return binary.Read(buf, binary.BigEndian, &m.EventType)
}
func (m *Message) readSessionID(buf *bytes.Buffer) error {
switch m.EventType {
case EventType_StartConnection, EventType_FinishConnection,
EventType_ConnectionStarted, EventType_ConnectionFailed,
EventType_ConnectionFinished:
return nil
}
var size uint32
if err := binary.Read(buf, binary.BigEndian, &size); err != nil {
return err
}
if size > 0 {
m.SessionID = string(buf.Next(int(size)))
}
return nil
}
func (m *Message) readConnectID(buf *bytes.Buffer) error {
switch m.EventType {
case EventType_ConnectionStarted, EventType_ConnectionFailed,
EventType_ConnectionFinished:
default:
return nil
}
var size uint32
if err := binary.Read(buf, binary.BigEndian, &size); err != nil {
return err
}
if size > 0 {
m.ConnectID = string(buf.Next(int(size)))
}
return nil
}
func (m *Message) readSequence(buf *bytes.Buffer) error {
return binary.Read(buf, binary.BigEndian, &m.Sequence)
}
func (m *Message) readErrorCode(buf *bytes.Buffer) error {
return binary.Read(buf, binary.BigEndian, &m.ErrorCode)
}
func (m *Message) readPayload(buf *bytes.Buffer) error {
var size uint32
if err := binary.Read(buf, binary.BigEndian, &size); err != nil {
return err
}
if size > 0 {
m.Payload = buf.Next(int(size))
}
return nil
}
func ReceiveMessage(conn *websocket.Conn) (*Message, error) {
mt, frame, err := conn.ReadMessage()
if err != nil {
return nil, err
}
if mt != websocket.BinaryMessage && mt != websocket.TextMessage {
return nil, fmt.Errorf("unexpected Websocket message type: %d", mt)
}
msg, err := NewMessageFromBytes(frame)
if err != nil {
return nil, err
}
return msg, nil
}
func FullClientRequest(conn *websocket.Conn, payload []byte) error {
msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagNoSeq)
if err != nil {
return err
}
msg.Payload = payload
frame, err := msg.Marshal()
if err != nil {
return err
}
return conn.WriteMessage(websocket.BinaryMessage, frame)
}

Some files were not shown because too many files have changed in this diff Show More