Compare commits

...

67 Commits

Author SHA1 Message Date
1808837298@qq.com
60aac77c08 fix: Correct Ollama channel authentication header setting 2025-02-20 01:28:15 +08:00
Calcium-Ion
6e0046f73c Merge pull request #773 from wellcoming/patch-1
fix: Fix Ollama channel authentication
2025-02-20 01:26:12 +08:00
Coming
a13f4d6c56 fix: Fix Ollama channel authentication 2025-02-20 00:52:30 +08:00
CalciumIon
4ce12ea6e3 feat: Improve mobile text truncation and sidebar visibility 2025-02-19 23:25:42 +08:00
1808837298@qq.com
971aea09ee feat: Improve image handling for Ollama channels 2025-02-19 20:45:42 +08:00
1808837298@qq.com
a4b2b9c935 feat: Enhance Ollama channel support with additional request parameters #771 2025-02-19 19:58:34 +08:00
1808837298@qq.com
ae5875d4c7 fix: Remove redundant error handling in distributor and relay modules 2025-02-19 18:47:28 +08:00
1808837298@qq.com
5937d850d9 refactor: Replace manual goroutine creation with gopool.Go 2025-02-19 18:38:29 +08:00
Calcium-Ion
2b7435500c Merge pull request #770 from Calcium-Ion/refactor_notify
feat: Add user notification settings and multiple notification methods
2025-02-19 14:54:54 +07:00
1808837298@qq.com
90191b8d5b chore: update env name and README 2025-02-19 15:54:33 +08:00
1808837298@qq.com
585c19fc70 docs: Add proxy usage information note in SystemSetting component 2025-02-19 15:45:09 +08:00
1808837298@qq.com
4e871507cf feat: Implement comprehensive webhook notification system 2025-02-19 15:40:54 +08:00
1808837298@qq.com
b1847509a4 refactor: Optimize user caching and token retrieval methods 2025-02-19 15:12:26 +08:00
Calcium-Ion
63f3412394 Merge pull request #768 from lgphone/main
bugfix: 配置文件 .env.example 示例配置错误
2025-02-18 19:35:08 +07:00
lgphone
a13bea5ffa Update .env.example
修复示例配置中MySQL的DSN错误问题
2025-02-18 19:18:54 +08:00
Calcium-Ion
2e3b920a2c Merge pull request #763 from Sh1n3zZ/support-imagen-3.0-generate-002
feat: add Gemini Imagen image generation support
2025-02-18 15:32:32 +07:00
1808837298@qq.com
812c188ab1 fix: Extend temperature handling for OpenAI-like models
- Add support for suppressing temperature for o1 models
- Expand model prefix check to include 'o1' alongside 'o3' models
2025-02-18 16:00:56 +08:00
1808837298@qq.com
0907a078b4 refactor: Simplify root user notification and remove global email variable
- Remove global `RootUserEmail` variable
- Modify channel testing and user notification methods to use `GetRootUser()`
- Update user cache and notification service to use more consistent user base type
- Add new channel test notification type
2025-02-18 15:59:17 +08:00
1808837298@qq.com
56f6b2ab56 feat: Implement notification rate limiting mechanism
- Add in-memory and Redis-based notification rate limiting
- Create configurable hourly notification limits
- Implement notification limit checking for user notifications
- Add environment variables for customizing notification limits
2025-02-18 15:30:43 +08:00
1808837298@qq.com
9d9c461c48 refactor: Improve CompletionRatio handling with thread-safe access and initialization 2025-02-18 15:01:43 +08:00
1808837298@qq.com
3da1344897 feat: Add user notification settings with quota warning and multiple notification methods
- Implement user notification settings with email and webhook options
- Add new user settings for quota warning threshold and notification preferences
- Create backend API and database support for user notification configuration
- Enhance frontend personal settings with notification configuration UI
- Support custom notification email and webhook URL
- Add service layer for sending user notifications
2025-02-18 14:54:21 +08:00
Sh1n3zZ
61d2a2f92d feat: add Gemini Imagen image generation support 2025-02-18 01:41:58 +08:00
1808837298@qq.com
995b3a2403 Merge remote-tracking branch 'origin/main' 2025-02-17 18:15:13 +08:00
1808837298@qq.com
7b384cb933 feat: Add support for DeepSeek completions endpoint 2025-02-17 18:15:01 +08:00
Calcium-Ion
78f19d4690 Merge pull request #735 from jyc001/main
feat:Add Supoorts to FIM
2025-02-17 14:37:06 +07:00
1808837298@qq.com
3239c60535 refactor: Optimize channel testing and model menu generation (fix #761) 2025-02-15 19:12:28 +08:00
1808837298@qq.com
e6f4587f6f refactor: Improve channel property update mechanism (fix #761) 2025-02-15 15:30:55 +08:00
Calcium-Ion
814be84500 Merge pull request #759 from nightcoffee/patch-1
feat: add 火山引擎 support stream options
2025-02-15 14:22:04 +07:00
nightcoffee
e7e5a16767 feat: add 火山引擎 support stream options 2025-02-15 04:55:57 +08:00
1808837298@qq.com
6bf99f218c feat: Enhance VolcEngine channel support with bot model routing (fix #757) 2025-02-15 00:10:58 +08:00
1808837298@qq.com
bd4ce9cd91 fix: Improve OpenAI stream data parsing and handling 2025-02-14 23:52:25 +08:00
1808837298@qq.com
9edb9f7a71 feat: Add automatic channel disabling based on configurable keywords
- Introduce AutomaticDisableKeywords setting to dynamically control channel disabling
- Implement AC search for matching error messages against disable keywords
- Add frontend UI for configuring automatic disable keywords
- Update localization with new keyword-based channel disabling feature
- Refactor sensitive word and AC search logic to support multiple keyword lists
2025-02-13 16:39:17 +08:00
1808837298@qq.com
bc62d1bb81 refactor: Optimize log retrieval with separate channel name fetching (fix #751)
- Remove inline channel join in log queries
- Implement separate channel name lookup for logs
- Improve performance by fetching channel names in a single query
- Ensure channel names are correctly associated with logs
2025-02-12 19:19:13 +08:00
1808837298@qq.com
6b923ef728 feat: Add invite link banner for specific channel type 2025-02-12 17:48:48 +08:00
1808837298@qq.com
81591f20e0 refactor: Optimize Dockerfile for Go build process
- Use alpine-based Golang image for smaller build size
- Simplify Go build command by removing static linking flag
- Improve Docker multi-stage build configuration
2025-02-12 17:18:23 +08:00
1808837298@qq.com
2072376694 docs: Update README with detailed Docker deployment and update instructions 2025-02-12 16:54:53 +08:00
1808837298@qq.com
871d73ecc9 fix: Update BaseURL placeholder text and label in channel edit page 2025-02-12 15:39:18 +08:00
1808837298@qq.com
f5e3063f33 feat: Improve embedding request handling and support across channels
- Update EmbeddingRequest DTO to support more flexible input types
- Add input parsing method to handle various input formats
- Implement ConvertEmbeddingRequest for multiple channel adaptors
- Remove relayMode parameter from EmbeddingHelper
- Add input validation for embedding requests
- Simplify embedding request conversion for different channels
2025-02-12 14:39:36 +08:00
1808837298@qq.com
eceb6afcdd feat: Add Baidu Qianfan V2 channel support #725
- Update channel constants to include Baidu V2 channel
- Create new Baidu V2 adaptor for relay
- Add Baidu V2 models and channel configuration
- Update relay adaptor to support Baidu V2 channel
- Modify web channel constants to include Baidu V2 option
2025-02-12 00:07:02 +08:00
1808837298@qq.com
28c13e5a0f feat: Add support for VolcEngine (Doubao) channel #313 #734 2025-02-11 23:47:15 +08:00
Calcium-Ion
81d11e5d31 Merge pull request #714 from NitroRCr/main
feat:  添加 AIaW 的聊天链接
2025-02-11 22:17:49 +07:00
Calcium-Ion
88bdedd2c9 Merge pull request #723 from kuwork/main
Support for MokaAI M3E
2025-02-11 22:16:18 +07:00
1808837298@qq.com
cf0ff0371b fix: adjust max tokens configuration in test request builder
- Update max tokens default value to 10
2025-02-11 20:00:05 +08:00
1808837298@qq.com
1f527ffc50 feat: enhance OpenAI request and response DTOs
- Add `Prefix` and `ReasoningContent` fields to Message struct
- Add getter and setter methods for `Prefix`
- Make `ToolCall.ID` field optional (fix #749)
2025-02-11 19:54:54 +08:00
1808837298@qq.com
cad8a83260 chore: disable cgo 2025-02-11 18:51:27 +08:00
1808837298@qq.com
40d878e8a9 chore: disable cgo 2025-02-11 18:51:09 +08:00
1808837298@qq.com
3a2e22443f chore: replace sqlite lib with prue go lib 2025-02-11 18:34:34 +08:00
1808837298@qq.com
13d1b8203c chore: update CI 2025-02-11 18:23:20 +08:00
1808837298@qq.com
7fce084aa5 update CI 2025-02-11 17:44:54 +08:00
1808837298@qq.com
cb4d40c3c8 feat: enhance session store security and configuration
- Add 30-day max age for session cookies
- Enable HttpOnly flag
- Set SameSite to strict mode
2025-02-11 17:06:51 +08:00
1808837298@qq.com
bbc1550a9e fix: update session store configuration
- Change session cookie path from "/api" to "/"
- Remove HttpOnly flag
2025-02-11 15:53:15 +08:00
1808837298@qq.com
6acc37cf27 feat: configure session store options for API routes
- Set session cookie path to "/api"
- Disable secure flag for local development
- Enable HttpOnly flag for improved security
2025-02-11 15:45:24 +08:00
Calcium-Ion
0e89939a12 Merge pull request #746 from zjjxwhh/main
fix: always use modelMapping in channel test
2025-02-11 12:21:06 +07:00
1808837298@qq.com
1b4fe8600e chore: update CI 2025-02-11 13:14:38 +08:00
zjjxwhh
882c5970d9 fix: always use modelMapping in channel test 2025-02-10 22:39:56 +08:00
1808837298@qq.com
d10b47005c chore: update CI 2025-02-10 21:59:41 +08:00
1808837298@qq.com
8418dbe7c4 fix: replace context-based user ID with session-based retrieval #741
- Update user and wechat controllers to use sessions for user ID
- Modify ID retrieval to use `session.Get("id")` instead of `c.GetInt("id")`
- Cast session ID to int when creating user object
2025-02-10 20:52:33 +08:00
e.
206dbfa45e Merge pull request #2 from jyc001/dev
fix: correct JSON tags for `Prompt` and `Suffix` in `GeneralOpenAIReq…
2025-02-08 00:37:37 +08:00
e.
1eb72f2f22 fix: correct JSON tags for Prompt and Suffix in GeneralOpenAIRequest 2025-02-08 00:36:42 +08:00
e.
68bd7f70a4 Merge pull request #1 from jyc001/dev
Dev
2025-02-08 00:25:49 +08:00
e.
8082905184 feat: add Suffix to GeneralOpenAIRequest in order to support FIM 2025-02-08 00:25:08 +08:00
e.
ce4269955e feat add FIM support for siliconflow 2025-02-08 00:23:35 +08:00
kuwork
89d48a6618 Merge branch 'main' into main 2025-02-04 22:52:37 +08:00
NitroRCr
324d127a88 feat: add chat link for AIaW 2025-01-25 11:57:54 +08:00
Jerry
7588c42b42 Fix M3E not working 2025-01-23 05:54:39 +08:00
Jerry
8a2d220cf4 fix : chanel test did not refresh 2025-01-22 13:16:06 +08:00
Jerry
126f04e08f Support for MokaAI M3E 2025-01-22 04:21:08 +08:00
103 changed files with 2711 additions and 2132 deletions

View File

@@ -10,9 +10,9 @@
# 数据库相关配置
# 数据库连接字符串
# SQL_DSN=mysql://user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
# SQL_DSN=user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
# 日志数据库连接字符串
# LOG_SQL_DSN=mysql://user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
# LOG_SQL_DSN=user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
# SQLite数据库路径
# SQLITE_PATH=/path/to/sqlite.db
# 数据库最大空闲连接数

View File

@@ -13,7 +13,7 @@ on:
jobs:
push_to_registries:
name: Push Docker image to multiple registries
runs-on: self-hosted
runs-on: ubuntu-latest
permissions:
packages: write
contents: read

View File

@@ -1,4 +1,4 @@
FROM oven/bun:latest as builder
FROM oven/bun:latest AS builder
WORKDIR /build
COPY web/package.json .
@@ -7,18 +7,20 @@ COPY ./web .
COPY ./VERSION .
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
FROM golang AS builder2
FROM golang:alpine AS builder2
ENV GO111MODULE=on \
CGO_ENABLED=1 \
CGO_ENABLED=0 \
GOOS=linux
WORKDIR /build
ADD go.mod go.sum ./
RUN go mod download
COPY . .
COPY --from=builder /build/dist ./web/dist
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-api
FROM alpine

View File

@@ -89,15 +89,14 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
- `CRYPTO_SECRET`: Encryption key for encrypting database content
- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Duration of notification limit in minutes, default `10`
- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications in the specified duration, default `2`
## Deployment
> [!TIP]
> Latest Docker image: `calciumion/new-api:latest`
> Default account: root, password: 123456
> Update command:
> ```
> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
> ```
> Default account: root, password: 123456
### Multi-Server Deployment
- Must set `SESSION_SECRET` environment variable, otherwise login state will not be consistent across multiple servers.
@@ -107,26 +106,58 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
- Local database (default): SQLite (Docker deployment must mount `/data` directory)
- Remote database: MySQL >= 5.7.8, PgSQL >= 9.6
### Deployment with BT Panel
Install BT Panel (**version 9.2.0** or above) from [BT Panel Official Website](https://www.bt.cn/new/download.html), choose the stable version script to download and install.
After installation, log in to BT Panel and click Docker in the menu bar. First-time access will prompt to install Docker service. Click Install Now and follow the prompts to complete installation.
After installation, find **New-API** in the app store, click install, configure basic options to complete installation.
[Pictorial Guide](BT.md)
### Docker Deployment
### Using Docker Compose (Recommended)
```shell
# Clone project
git clone https://github.com/Calcium-Ion/new-api.git
cd new-api
# Edit docker-compose.yml as needed
# nano docker-compose.yml
# vim docker-compose.yml
# Start
docker-compose up -d
```
#### Update Version
```shell
docker-compose pull
docker-compose up -d
```
### Direct Docker Image Usage
```shell
# SQLite deployment:
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
# MySQL deployment (add -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"), modify database connection parameters as needed
# Example:
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
#### Update Version
```shell
# Pull the latest image
docker pull calciumion/new-api:latest
# Stop and remove the old container
docker stop new-api
docker rm new-api
# Run the new container with the same parameters as before
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
Alternatively, you can use Watchtower for automatic updates (not recommended, may cause database incompatibility):
```shell
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
```
## Channel Retry
Channel retry is implemented, configurable in `Settings->Operation Settings->General Settings`. **Cache recommended**.
First retry uses same priority, second retry uses next priority, and so on.

View File

@@ -95,14 +95,14 @@
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB默认为 `20`
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。
- `AZURE_DEFAULT_API_VERSION`Azure渠道默认API版本如果渠道设置中未指定API版本则使用此版本默认为 `2024-12-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`
## 部署
> [!TIP]
> 最新版Docker镜像`calciumion/new-api:latest`
> 默认账号root 密码123456
> 更新指令:
> ```
> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
> ```
> 默认账号root 密码123456
### 多机部署
- 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致。
@@ -119,25 +119,54 @@
[图文教程](BT.md)
### 基于 Docker 进行部署
> [!TIP]
> 默认管理员账号root 密码123456
### 使用 Docker Compose 部署(推荐)
```shell
# 下载项目
git clone https://github.com/Calcium-Ion/new-api.git
cd new-api
# 按需编辑 docker-compose.yml
# nano docker-compose.yml
# vim docker-compose.yml
# 启动
docker-compose up -d
```
#### 更新版本
```shell
docker-compose pull
docker-compose up -d
```
### 直接使用 Docker 镜像
```shell
# 使用 SQLite 的部署命令:
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数。
# 例如:
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
#### 更新版本
```shell
# 拉取最新镜像
docker pull calciumion/new-api:latest
# 停止并删除旧容器
docker stop new-api
docker rm new-api
# 使用相同参数运行新容器
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
或者使用 Watchtower 自动更新(不推荐,可能会导致数据库不兼容):
```shell
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
```
## 渠道重试
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。

View File

@@ -101,7 +101,7 @@ var PreConsumedQuota = 500
var RetryTimes = 0
var RootUserEmail = ""
//var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
@@ -231,8 +231,10 @@ const (
ChannelTypeVertexAi = 41
ChannelTypeMistral = 42
ChannelTypeDeepSeek = 43
ChannelTypeDummy // this one is only for count, do not add any channel after this
ChannelTypeMokaAI = 44
ChannelTypeVolcEngine = 45
ChannelTypeBaiduV2 = 46
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
@@ -281,4 +283,7 @@ var ChannelBaseURLs = []string{
"", //41
"https://api.mistral.ai", //42
"https://api.deepseek.com", //43
"https://api.moka.ai", //44
"https://ark.cn-beijing.volces.com", //45
"https://qianfan.baidubce.com", //46
}

View File

@@ -3,5 +3,6 @@ package common
var UsingSQLite = false
var UsingPostgreSQL = false
var UsingMySQL = false
var UsingClickHouse = false
var SQLitePath = "one-api.db?_busy_timeout=5000"

View File

@@ -1,22 +1,9 @@
package common
import (
"fmt"
"runtime/debug"
"time"
)
func SafeGoroutine(f func()) {
go func() {
defer func() {
if r := recover(); r != nil {
SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack())))
}
}()
f()
}()
}
func SafeSendBool(ch chan bool, value bool) (closed bool) {
defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed.

View File

@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"io"
"log"
@@ -80,9 +81,9 @@ func logHelper(ctx context.Context, level string, msg string) {
if logCount > maxLogCount && !setupLogWorking {
logCount = 0
setupLogWorking = true
go func() {
gopool.Go(func() {
SetupLogger()
}()
})
}
}
@@ -100,6 +101,14 @@ func LogQuota(quota int) string {
}
}
func FormatQuota(quota int) string {
if DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f", float64(quota)/QuotaPerUnit)
} else {
return fmt.Sprintf("%d", quota)
}
}
// LogJson 仅供测试使用 only for test
func LogJson(ctx context.Context, msg string, obj any) {
jsonStr, err := json.Marshal(obj)

View File

@@ -233,7 +233,11 @@ var (
modelRatioMapMutex = sync.RWMutex{}
)
var CompletionRatio map[string]float64 = nil
var (
CompletionRatio map[string]float64 = nil
CompletionRatioMutex = sync.RWMutex{}
)
var defaultCompletionRatio = map[string]float64{
"gpt-4-gizmo-*": 2,
"gpt-4o-gizmo-*": 3,
@@ -334,10 +338,17 @@ func GetDefaultModelRatioMap() map[string]float64 {
return defaultModelRatio
}
func CompletionRatio2JSONString() string {
func GetCompletionRatioMap() map[string]float64 {
CompletionRatioMutex.Lock()
defer CompletionRatioMutex.Unlock()
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio
}
return CompletionRatio
}
func CompletionRatio2JSONString() string {
GetCompletionRatioMap()
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
SysError("error marshalling completion ratio: " + err.Error())
@@ -346,11 +357,15 @@ func CompletionRatio2JSONString() string {
}
func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatioMutex.Lock()
defer CompletionRatioMutex.Unlock()
CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
}
func GetCompletionRatio(name string) float64 {
GetCompletionRatioMap()
if strings.Contains(name, "/") {
if ratio, ok := CompletionRatio[name]; ok {
return ratio
@@ -476,24 +491,3 @@ func GetAudioCompletionRatio(name string) float64 {
}
return 2
}
//func GetAudioPricePerMinute(name string) float64 {
// if strings.HasPrefix(name, "gpt-4o-realtime") {
// return 0.06
// }
// return 0.06
//}
//
//func GetAudioCompletionPricePerMinute(name string) float64 {
// if strings.HasPrefix(name, "gpt-4o-realtime") {
// return 0.24
// }
// return 0.24
//}
func GetCompletionRatioMap() map[string]float64 {
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio
}
return CompletionRatio
}

View File

@@ -29,6 +29,9 @@ var GeminiModelMap = map[string]string{
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
func InitEnv() {
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
if modelVersionMapStr == "" {
@@ -44,5 +47,5 @@ func InitEnv() {
}
}
// 是否生成初始令牌,默认关闭。
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)

14
constant/user_setting.go Normal file
View File

@@ -0,0 +1,14 @@
package constant
var (
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
)
var (
NotifyTypeEmail = "email" // Email 邮件
NotifyTypeWebhook = "webhook" // Webhook
)

View File

@@ -41,9 +41,21 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
requestPath := "/v1/chat/completions"
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(testModel), "embedding") ||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
strings.Contains(testModel, "bge-") || // bge 系列模型
testModel == "text-embedding-v1" ||
channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
requestPath = "/v1/embeddings" // 修改请求路径
}
c.Request = &http.Request{
Method: "POST",
URL: &url.URL{Path: "/v1/chat/completions"},
URL: &url.URL{Path: requestPath}, // 使用动态路径
Body: nil,
Header: make(http.Header),
}
@@ -55,20 +67,20 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if len(channel.GetModels()) > 0 {
testModel = channel.GetModels()[0]
} else {
testModel = "gpt-3.5-turbo"
testModel = "gpt-4o-mini"
}
}
} else {
modelMapping := *channel.ModelMapping
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[testModel] != "" {
testModel = modelMap[testModel]
}
}
modelMapping := *channel.ModelMapping
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[testModel] != "" {
testModel = modelMap[testModel]
}
}
@@ -88,7 +100,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
request := buildTestRequest(testModel)
meta.UpstreamModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %v ", channel.Id, testModel, meta))
adaptor.Init(meta)
@@ -156,12 +168,21 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
Model: "", // this will be set later
Stream: false,
}
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(model), "embedding") ||
strings.HasPrefix(model, "m3e") || // m3e 系列模型
strings.Contains(model, "bge-") || // bge 系列模型
model == "text-embedding-v1" { // 其他 embedding 模型
// Embedding 请求
testRequest.Input = []string{"hello world"}
return testRequest
}
// 并非Embedding 模型
if strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") {
testRequest.MaxCompletionTokens = 10
} else if strings.HasPrefix(model, "gemini-2.0-flash-thinking") {
testRequest.MaxTokens = 10
} else {
testRequest.MaxTokens = 1
testRequest.MaxTokens = 10
}
content, _ := json.Marshal("hi")
testMessage := dto.Message{
@@ -217,9 +238,7 @@ var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
func testAllChannels(notify bool) error {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
testAllChannelsLock.Lock()
if testAllChannelsRunning {
testAllChannelsLock.Unlock()
@@ -274,10 +293,7 @@ func testAllChannels(notify bool) error {
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
if notify {
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
}
})
return nil

View File

@@ -17,7 +17,7 @@ func GetPricing(c *gin.Context) {
}
var group string
if exists {
user, err := model.GetUserById(userId.(int), false)
user, err := model.GetUserCache(userId.(int))
if err == nil {
group = user.Group
}

View File

@@ -33,6 +33,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
err = relay.AudioHelper(c)
case relayconstant.RelayModeRerank:
err = relay.RerankHelper(c, relayMode)
case relayconstant.RelayModeEmbeddings:
err = relay.EmbeddingHelper(c)
default:
err = relay.TextHelper(c)
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"one-api/common"
"one-api/model"
"one-api/setting"
@@ -471,7 +472,7 @@ func GetUserModels(c *gin.Context) {
if err != nil {
id = c.GetInt("id")
}
user, err := model.GetUserById(id, true)
user, err := model.GetUserCache(id)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -846,9 +847,10 @@ func EmailBind(c *gin.Context) {
})
return
}
id := c.GetInt("id")
session := sessions.Default(c)
id := session.Get("id")
user := model.User{
Id: id,
Id: id.(int),
}
err := user.FillUserById()
if err != nil {
@@ -868,9 +870,6 @@ func EmailBind(c *gin.Context) {
})
return
}
if user.Role == common.RoleRootUser {
common.RootUserEmail = email
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -912,3 +911,115 @@ func TopUp(c *gin.Context) {
})
return
}
type UpdateUserSettingRequest struct {
QuotaWarningType string `json:"notify_type"`
QuotaWarningThreshold int `json:"quota_warning_threshold"`
WebhookUrl string `json:"webhook_url,omitempty"`
WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"`
}
func UpdateUserSetting(c *gin.Context) {
var req UpdateUserSettingRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的参数",
})
return
}
// 验证预警类型
if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的预警类型",
})
return
}
// 验证预警阈值
if req.QuotaWarningThreshold <= 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "预警阈值必须大于0",
})
return
}
// 如果是webhook类型,验证webhook地址
if req.QuotaWarningType == constant.NotifyTypeWebhook {
if req.WebhookUrl == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "Webhook地址不能为空",
})
return
}
// 验证URL格式
if _, err := url.ParseRequestURI(req.WebhookUrl); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的Webhook地址",
})
return
}
}
// 如果是邮件类型,验证邮箱地址
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
// 验证邮箱格式
if !strings.Contains(req.NotificationEmail, "@") {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的邮箱地址",
})
return
}
}
userId := c.GetInt("id")
user, err := model.GetUserById(userId, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
// 构建设置
settings := map[string]interface{}{
constant.UserSettingNotifyType: req.QuotaWarningType,
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
}
// 如果是webhook类型,添加webhook相关设置
if req.QuotaWarningType == constant.NotifyTypeWebhook {
settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
if req.WebhookSecret != "" {
settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
}
}
// 如果提供了通知邮箱,添加到设置中
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
}
// 更新用户设置
user.SetSetting(settings)
if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "更新设置失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "设置已更新",
})
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
@@ -142,9 +143,10 @@ func WeChatBind(c *gin.Context) {
})
return
}
id := c.GetInt("id")
session := sessions.Default(c)
id := session.Get("id")
user := model.User{
Id: id,
Id: id.(int),
}
err = user.FillUserById()
if err != nil {

View File

@@ -24,7 +24,7 @@ services:
- redis
- mysql
healthcheck:
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $$2}'"]
interval: 30s
timeout: 10s
retries: 3

57
dto/embedding.go Normal file
View File

@@ -0,0 +1,57 @@
package dto
type EmbeddingOptions struct {
Seed int `json:"seed,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input any `json:"input"`
EncodingFormat string `json:"encoding_format,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
User string `json:"user,omitempty"`
Seed float64 `json:"seed,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
}
func (r EmbeddingRequest) ParseInput() []string {
if r.Input == nil {
return nil
}
var input []string
switch r.Input.(type) {
case string:
input = []string{r.Input.(string)}
case []any:
input = make([]string, 0, len(r.Input.([]any)))
for _, item := range r.Input.([]any) {
if str, ok := item.(string); ok {
input = append(input, str)
}
}
}
return input
}
type EmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
}
type EmbeddingResponse struct {
Object string `json:"object"`
Data []EmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
}

25
dto/notify.go Normal file
View File

@@ -0,0 +1,25 @@
package dto
type Notify struct {
Type string `json:"type"`
Title string `json:"title"`
Content string `json:"content"`
Values []interface{} `json:"values"`
}
const ContentValueParam = "{{value}}"
const (
NotifyTypeQuotaExceed = "quota_exceed"
NotifyTypeChannelUpdate = "channel_update"
NotifyTypeChannelTest = "channel_test"
)
func NewNotify(t string, title string, content string, values []interface{}) Notify {
return Notify{
Type: t,
Title: title,
Content: content,
Values: values,
}
}

View File

@@ -18,6 +18,8 @@ type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Prefix any `json:"prefix,omitempty"`
Suffix any `json:"suffix,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
@@ -86,11 +88,15 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
}
type Message struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
Name *string `json:"name,omitempty"`
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
Role string `json:"role"`
Content json.RawMessage `json:"content"`
// parsedContent not json field
parsedContent []MediaContent
Name *string `json:"name,omitempty"`
Prefix *bool `json:"prefix,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
}
type MediaContent struct {
@@ -116,6 +122,17 @@ const (
ContentTypeInputAudio = "input_audio"
)
func (m *Message) GetPrefix() bool {
if m.Prefix == nil {
return false
}
return *m.Prefix
}
func (m *Message) SetPrefix(prefix bool) {
m.Prefix = &prefix
}
func (m *Message) ParseToolCalls() []ToolCall {
if m.ToolCalls == nil {
return nil
@@ -145,6 +162,11 @@ func (m *Message) SetStringContent(content string) {
m.Content = jsonContent
}
func (m *Message) SetMediaContent(content []MediaContent) {
jsonContent, _ := json.Marshal(content)
m.Content = jsonContent
}
func (m *Message) IsStringContent() bool {
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
@@ -154,7 +176,15 @@ func (m *Message) IsStringContent() bool {
}
func (m *Message) ParseContent() []MediaContent {
if m.parsedContent != nil {
return m.parsedContent
}
var contentList []MediaContent
defer func() {
if len(contentList) > 0 {
m.parsedContent = contentList
}
}()
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
contentList = append(contentList, MediaContent{

View File

@@ -81,7 +81,7 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
type ToolCall struct {
// Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"`
ID string `json:"id"`
ID string `json:"id,omitempty"`
Type any `json:"type"`
Function FunctionCall `json:"function"`
}

16
go.mod
View File

@@ -16,6 +16,7 @@ require (
github.com/gin-contrib/sessions v0.0.5
github.com/gin-contrib/static v0.0.1
github.com/gin-gonic/gin v1.9.1
github.com/glebarez/sqlite v1.9.0
github.com/go-playground/validator/v10 v10.20.0
github.com/go-redis/redis/v8 v8.11.5
github.com/golang-jwt/jwt v3.2.2+incompatible
@@ -29,10 +30,10 @@ require (
github.com/shirou/gopsutil v3.21.11+incompatible
golang.org/x/crypto v0.27.0
golang.org/x/image v0.23.0
golang.org/x/net v0.28.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3
gorm.io/gorm v1.25.0
gorm.io/gorm v1.25.2
)
require (
@@ -48,12 +49,14 @@ require (
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/glebarez/go-sqlite v1.21.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/gorilla/context v1.1.1 // indirect
@@ -69,11 +72,11 @@ require (
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
@@ -81,10 +84,13 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.12.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/net v0.28.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.27.0 // indirect
golang.org/x/text v0.21.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.22.5 // indirect
modernc.org/mathutil v1.5.0 // indirect
modernc.org/memory v1.5.0 // indirect
modernc.org/sqlite v1.23.1 // indirect
)

32
go.sum
View File

@@ -40,6 +40,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
@@ -58,6 +60,10 @@ github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwv
github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
github.com/glebarez/sqlite v1.9.0 h1:Aj6bPA12ZEx5GbSF6XADmCkYXlljPNUY+Zf1EQxynXs=
github.com/glebarez/sqlite v1.9.0/go.mod h1:YBYCoyupOao60lzp1MVBLEjZfgkq0tdB1voAQ09K9zw=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
@@ -77,8 +83,9 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
@@ -90,6 +97,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
@@ -140,9 +149,6 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -167,6 +173,9 @@ github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQ
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
@@ -263,11 +272,16 @@ gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k=
gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c=
gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0=
gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8=
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE=
modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY=
modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM=
modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

11
main.go
View File

@@ -119,9 +119,9 @@ func main() {
}
if os.Getenv("ENABLE_PPROF") == "true" {
go func() {
gopool.Go(func() {
log.Println(http.ListenAndServe("0.0.0.0:8005", nil))
}()
})
go common.Monitor()
common.SysLog("pprof enabled")
}
@@ -145,6 +145,13 @@ func main() {
middleware.SetUpLogger(server)
// Initialize session store
store := cookie.NewStore([]byte(common.SessionSecret))
store.Options(sessions.Options{
Path: "/",
MaxAge: 2592000, // 30 days
HttpOnly: true,
Secure: false,
SameSite: http.SameSiteStrictMode,
})
server.Use(sessions.Sessions("session", store))
router.SetRouter(server, buildFS, indexPage)

View File

@@ -135,17 +135,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
midjourneyRequest := dto.MidjourneyRequest{}
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
if err != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
return nil, false, err
}
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
if mjErr != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
return nil, false, fmt.Errorf(mjErr.Description)
}
if midjourneyModel == "" {
if !success {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
} else {
// task fetch, task fetch by condition, notify
@@ -170,7 +167,6 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return nil, false, errors.New("无效的请求, " + err.Error())
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
@@ -239,5 +235,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("plugin", channel.Other)
case common.ChannelCloudflare:
c.Set("api_version", channel.Other)
case common.ChannelTypeMokaAI:
c.Set("api_version", channel.Other)
}
}

View File

@@ -11,106 +11,6 @@ import (
"time"
)
//func CacheGetUserGroup(id int) (group string, err error) {
// if !common.RedisEnabled {
// return GetUserGroup(id)
// }
// group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
// if err != nil {
// group, err = GetUserGroup(id)
// if err != nil {
// return "", err
// }
// err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
// if err != nil {
// common.SysError("Redis set user group error: " + err.Error())
// }
// }
// return group, err
//}
//
//func CacheGetUsername(id int) (username string, err error) {
// if !common.RedisEnabled {
// return GetUsernameById(id)
// }
// username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
// if err != nil {
// username, err = GetUsernameById(id)
// if err != nil {
// return "", err
// }
// err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
// if err != nil {
// common.SysError("Redis set user group error: " + err.Error())
// }
// }
// return username, err
//}
//
//func CacheGetUserQuota(id int) (quota int, err error) {
// if !common.RedisEnabled {
// return GetUserQuota(id)
// }
// quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
// if err != nil {
// quota, err = GetUserQuota(id)
// if err != nil {
// return 0, err
// }
// return quota, nil
// }
// quota, err = strconv.Atoi(quotaString)
// return quota, nil
//}
//
//func CacheUpdateUserQuota(id int) error {
// if !common.RedisEnabled {
// return nil
// }
// quota, err := GetUserQuota(id)
// if err != nil {
// return err
// }
// return cacheSetUserQuota(id, quota)
//}
//
//func cacheSetUserQuota(id int, quota int) error {
// err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second)
// return err
//}
//
//func CacheDecreaseUserQuota(id int, quota int) error {
// if !common.RedisEnabled {
// return nil
// }
// err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
// return err
//}
//
//func CacheIsUserEnabled(userId int) (bool, error) {
// if !common.RedisEnabled {
// return IsUserEnabled(userId)
// }
// enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
// if err == nil {
// return enabled == "1", nil
// }
//
// userEnabled, err := IsUserEnabled(userId)
// if err != nil {
// return false, err
// }
// enabled = "0"
// if userEnabled {
// enabled = "1"
// }
// err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(constant.UserId2StatusCacheSeconds)*time.Second)
// if err != nil {
// common.SysError("Redis set user enabled error: " + err.Error())
// }
// return userEnabled, err
//}
var group2model2channels map[string]map[string][]*Channel
var channelsIDM map[int]*Channel
var channelSyncLock sync.RWMutex

View File

@@ -133,9 +133,6 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
tx = LOG_DB.Where("logs.type = ?", logType)
}
tx = tx.Joins("LEFT JOIN channels ON logs.channel_id = channels.id")
tx = tx.Select("logs.*, channels.name as channel_name")
if modelName != "" {
tx = tx.Where("logs.model_name like ?", modelName)
}
@@ -165,6 +162,30 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if err != nil {
return nil, 0, err
}
channelIds := make([]int, 0)
channelMap := make(map[int]string)
for _, log := range logs {
if log.ChannelId != 0 {
channelIds = append(channelIds, log.ChannelId)
}
}
if len(channelIds) > 0 {
var channels []struct {
Id int `gorm:"column:id"`
Name string `gorm:"column:name"`
}
if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds).Find(&channels).Error; err != nil {
return logs, total, err
}
for _, channel := range channels {
channelMap[channel.Id] = channel.Name
}
for i := range logs {
logs[i].ChannelName = channelMap[logs[i].ChannelId]
}
}
return logs, total, err
}
@@ -176,9 +197,6 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
tx = LOG_DB.Where("logs.user_id = ? and logs.type = ?", userId, logType)
}
tx = tx.Joins("LEFT JOIN channels ON logs.channel_id = channels.id")
tx = tx.Select("logs.*, channels.name as channel_name")
if modelName != "" {
tx = tx.Where("logs.model_name like ?", modelName)
}
@@ -199,6 +217,10 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
return nil, 0, err
}
err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
if err != nil {
return nil, 0, err
}
formatUserLogs(logs)
return logs, total, err
}

View File

@@ -1,9 +1,9 @@
package model
import (
"github.com/glebarez/sqlite"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"log"
"one-api/common"

View File

@@ -110,6 +110,7 @@ func InitOptionMap() {
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString()
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
@@ -335,6 +336,8 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
case "SensitiveWords":
setting.SensitiveWordsFromString(value)
case "AutomaticDisableKeywords":
setting.AutomaticDisableKeywordsFromString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
}

View File

@@ -3,13 +3,11 @@ package model
import (
"errors"
"fmt"
"one-api/common"
"strings"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common"
relaycommon "one-api/relay/common"
"one-api/setting"
"strconv"
"strings"
)
type Token struct {
@@ -322,80 +320,3 @@ func decreaseTokenQuota(id int, quota int) (err error) {
).Error
return err
}
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if relayInfo.IsPlayground {
return nil
}
//if relayInfo.TokenUnlimited {
// return nil
//}
token, err := GetTokenById(relayInfo.TokenId)
if err != nil {
return err
}
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
return errors.New("令牌额度不足")
}
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
if err != nil {
return err
}
return nil
}
func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
if quota > 0 {
err = DecreaseUserQuota(relayInfo.UserId, quota)
} else {
err = IncreaseUserQuota(relayInfo.UserId, -quota)
}
if err != nil {
return err
}
if !relayInfo.IsPlayground {
if quota > 0 {
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
} else {
err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
}
if err != nil {
return err
}
}
if sendEmail {
if (quota + preConsumedQuota) != 0 {
quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold
noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
if quotaTooLow || noMoreQuota {
go func() {
email, err := GetUserEmail(relayInfo.UserId)
if err != nil {
common.SysError("failed to fetch user email: " + err.Error())
}
prompt := "您的额度即将用尽"
if noMoreQuota {
prompt = "您的额度已用尽"
}
if email != "" {
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
err = common.SendEmail(prompt, email,
fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil {
common.SysError("failed to send email" + err.Error())
}
common.SysLog("user quota is low, consumed quota: " + strconv.Itoa(quota) + ", user quota: " + strconv.Itoa(userQuota))
}
}()
}
}
}
return nil
}

View File

@@ -52,7 +52,7 @@ func cacheSetTokenField(key string, field string, value string) error {
func cacheGetTokenByKey(key string) (*Token, error) {
hmacKey := common.GenerateHMAC(key)
if !common.RedisEnabled {
return nil, nil
return nil, fmt.Errorf("redis is not enabled")
}
var token Token
err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)

View File

@@ -1,6 +1,7 @@
package model
import (
"encoding/json"
"errors"
"fmt"
"one-api/common"
@@ -38,6 +39,20 @@ type User struct {
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
DeletedAt gorm.DeletedAt `gorm:"index"`
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
Setting string `json:"setting" gorm:"type:text;column:setting"`
}
func (user *User) ToBaseUser() *UserBase {
cache := &UserBase{
Id: user.Id,
Group: user.Group,
Quota: user.Quota,
Status: user.Status,
Username: user.Username,
Setting: user.Setting,
Email: user.Email,
}
return cache
}
func (user *User) GetAccessToken() string {
@@ -51,6 +66,22 @@ func (user *User) SetAccessToken(token string) {
user.AccessToken = &token
}
func (user *User) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
}
return common.StrToMap(user.Setting)
}
func (user *User) SetSetting(setting map[string]interface{}) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
return
}
user.Setting = string(settingBytes)
}
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
var user User
@@ -315,8 +346,8 @@ func (user *User) Update(updatePassword bool) error {
return err
}
// 更新缓存
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
// Update cache
return updateUserCache(*user)
}
func (user *User) Edit(updatePassword bool) error {
@@ -344,8 +375,8 @@ func (user *User) Edit(updatePassword bool) error {
return err
}
// 更新缓存
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
// Update cache
return updateUserCache(*user)
}
func (user *User) Delete() error {
@@ -371,8 +402,8 @@ func (user *User) HardDelete() error {
// ValidateAndFill check password & user status
func (user *User) ValidateAndFill() (err error) {
// When querying with struct, GORM will only query with non-zero fields,
// that means if your fields value is 0, '', false or other zero values,
// it wont be used to build query conditions
// that means if your field's value is 0, '', false or other zero values,
// it won't be used to build query conditions
password := user.Password
username := strings.TrimSpace(user.Username)
if username == "" || password == "" {
@@ -531,7 +562,6 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
return quota, nil
}
// Don't return error - fall through to DB
//common.SysError("failed to get user quota from cache: " + err.Error())
}
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
@@ -580,6 +610,35 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
return group, nil
}
// GetUserSetting gets setting from Redis first, falls back to DB if needed
func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
var setting string
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserSettingCache(id, setting); err != nil {
common.SysError("failed to update user setting cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
setting, err := getUserSettingCache(id)
if err == nil {
return setting, nil
}
// Don't return error - fall through to DB
}
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
if err != nil {
return map[string]interface{}{}, err
}
return common.StrToMap(setting), nil
}
func IncreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
@@ -641,9 +700,14 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
}
}
func GetRootUserEmail() (email string) {
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
return email
//func GetRootUserEmail() (email string) {
// DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
// return email
//}
func GetRootUser() (user *User) {
DB.Where("role = ?", common.RoleRootUser).First(&user)
return user
}
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
@@ -725,10 +789,10 @@ func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
return !errors.Is(err, gorm.ErrRecordNotFound)
}
func (u *User) FillUserByLinuxDOId() error {
if u.LinuxDOId == "" {
func (user *User) FillUserByLinuxDOId() error {
if user.LinuxDOId == "" {
return errors.New("linux do id is empty")
}
err := DB.Where("linux_do_id = ?", u.LinuxDOId).First(u).Error
err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
return err
}

View File

@@ -1,206 +1,213 @@
package model
import (
"encoding/json"
"fmt"
"one-api/common"
"one-api/constant"
"strconv"
"time"
"github.com/bytedance/gopkg/util/gopool"
)
// Change UserCache struct to userCache
type userCache struct {
// UserBase struct remains the same as it represents the cached data structure
type UserBase struct {
Id int `json:"id"`
Group string `json:"group"`
Email string `json:"email"`
Quota int `json:"quota"`
Status int `json:"status"`
Role int `json:"role"`
Username string `json:"username"`
Setting string `json:"setting"`
}
// Rename all exported functions to private ones
// invalidateUserCache clears all user related cache
func (user *UserBase) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
}
return common.StrToMap(user.Setting)
}
func (user *UserBase) SetSetting(setting map[string]interface{}) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
return
}
user.Setting = string(settingBytes)
}
// getUserCacheKey returns the key for user cache
func getUserCacheKey(userId int) string {
return fmt.Sprintf("user:%d", userId)
}
// invalidateUserCache clears user cache
func invalidateUserCache(userId int) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHDelObj(getUserCacheKey(userId))
}
keys := []string{
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
// updateUserCache updates all user cache fields using hash
func updateUserCache(user User) error {
if !common.RedisEnabled {
return nil
}
for _, key := range keys {
if err := common.RedisDel(key); err != nil {
return fmt.Errorf("failed to delete cache key %s: %w", key, err)
return common.RedisHSetObj(
getUserCacheKey(user.Id),
user.ToBaseUser(),
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// GetUserCache gets complete user cache from hash
func GetUserCache(userId int) (userCache *UserBase, err error) {
var user *User
var fromDB bool
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) && user != nil {
gopool.Go(func() {
if err := updateUserCache(*user); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
}
})
}
}
return nil
}
}()
// updateUserGroupCache updates user group cache
func updateUserGroupCache(userId int, group string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
group,
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserQuotaCache updates user quota cache
func updateUserQuotaCache(userId int, quota int) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
fmt.Sprintf("%d", quota),
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserStatusCache updates user status cache
func updateUserStatusCache(userId int, userEnabled bool) error {
if !common.RedisEnabled {
return nil
}
enabled := "0"
if userEnabled {
enabled = "1"
}
return common.RedisSet(
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
enabled,
time.Duration(constant.UserId2StatusCacheSeconds)*time.Second,
)
}
// updateUserNameCache updates username cache
func updateUserNameCache(userId int, username string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
username,
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserCache updates all user cache fields
func updateUserCache(userId int, username string, userGroup string, quota int, status int) error {
if !common.RedisEnabled {
return nil
// Try getting from Redis first
userCache, err = cacheGetUserBase(userId)
if err == nil {
return userCache, nil
}
if err := updateUserGroupCache(userId, userGroup); err != nil {
return fmt.Errorf("update group cache: %w", err)
}
if err := updateUserQuotaCache(userId, quota); err != nil {
return fmt.Errorf("update quota cache: %w", err)
}
if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil {
return fmt.Errorf("update status cache: %w", err)
}
if err := updateUserNameCache(userId, username); err != nil {
return fmt.Errorf("update username cache: %w", err)
}
return nil
}
// getUserGroupCache gets user group from cache
func getUserGroupCache(userId int) (string, error) {
if !common.RedisEnabled {
return "", nil
}
return common.RedisGet(fmt.Sprintf(constant.UserGroupKeyFmt, userId))
}
// getUserQuotaCache gets user quota from cache
func getUserQuotaCache(userId int) (int, error) {
if !common.RedisEnabled {
return 0, nil
}
quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId))
// If Redis fails, get from DB
fromDB = true
user, err = GetUserById(userId, false)
if err != nil {
return 0, err
return nil, err // Return nil and error if DB lookup fails
}
return strconv.Atoi(quotaStr)
// Create cache object from user data
userCache = &UserBase{
Id: user.Id,
Group: user.Group,
Quota: user.Quota,
Status: user.Status,
Username: user.Username,
Setting: user.Setting,
Email: user.Email,
}
return userCache, nil
}
// getUserStatusCache gets user status from cache
func getUserStatusCache(userId int) (int, error) {
func cacheGetUserBase(userId int) (*UserBase, error) {
if !common.RedisEnabled {
return 0, nil
return nil, fmt.Errorf("redis is not enabled")
}
statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId))
var userCache UserBase
// Try getting from Redis first
err := common.RedisHGetObj(getUserCacheKey(userId), &userCache)
if err != nil {
return 0, err
return nil, err
}
return strconv.Atoi(statusStr)
return &userCache, nil
}
// getUserNameCache gets username from cache
func getUserNameCache(userId int) (string, error) {
if !common.RedisEnabled {
return "", nil
}
return common.RedisGet(fmt.Sprintf(constant.UserUsernameKeyFmt, userId))
}
// getUserCache gets complete user cache
func getUserCache(userId int) (*userCache, error) {
if !common.RedisEnabled {
return nil, nil
}
group, err := getUserGroupCache(userId)
if err != nil {
return nil, fmt.Errorf("get group cache: %w", err)
}
quota, err := getUserQuotaCache(userId)
if err != nil {
return nil, fmt.Errorf("get quota cache: %w", err)
}
status, err := getUserStatusCache(userId)
if err != nil {
return nil, fmt.Errorf("get status cache: %w", err)
}
username, err := getUserNameCache(userId)
if err != nil {
return nil, fmt.Errorf("get username cache: %w", err)
}
return &userCache{
Id: userId,
Group: group,
Quota: quota,
Status: status,
Username: username,
}, nil
}
// Add atomic quota operations
// Add atomic quota operations using hash fields
func cacheIncrUserQuota(userId int, delta int64) error {
if !common.RedisEnabled {
return nil
}
key := fmt.Sprintf(constant.UserQuotaKeyFmt, userId)
return common.RedisIncr(key, delta)
return common.RedisHIncrBy(getUserCacheKey(userId), "Quota", delta)
}
func cacheDecrUserQuota(userId int, delta int64) error {
return cacheIncrUserQuota(userId, -delta)
}
// Helper functions to get individual fields if needed
func getUserGroupCache(userId int) (string, error) {
cache, err := GetUserCache(userId)
if err != nil {
return "", err
}
return cache.Group, nil
}
func getUserQuotaCache(userId int) (int, error) {
cache, err := GetUserCache(userId)
if err != nil {
return 0, err
}
return cache.Quota, nil
}
func getUserStatusCache(userId int) (int, error) {
cache, err := GetUserCache(userId)
if err != nil {
return 0, err
}
return cache.Status, nil
}
func getUserNameCache(userId int) (string, error) {
cache, err := GetUserCache(userId)
if err != nil {
return "", err
}
return cache.Username, nil
}
func getUserSettingCache(userId int) (map[string]interface{}, error) {
setting := make(map[string]interface{})
cache, err := GetUserCache(userId)
if err != nil {
return setting, err
}
return cache.GetSetting(), nil
}
// New functions for individual field updates
func updateUserStatusCache(userId int, status bool) error {
if !common.RedisEnabled {
return nil
}
statusInt := common.UserStatusEnabled
if !status {
statusInt = common.UserStatusDisabled
}
return common.RedisHSetField(getUserCacheKey(userId), "Status", fmt.Sprintf("%d", statusInt))
}
func updateUserQuotaCache(userId int, quota int) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Quota", fmt.Sprintf("%d", quota))
}
func updateUserGroupCache(userId int, group string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Group", group)
}
func updateUserNameCache(userId int, username string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Username", username)
}
func updateUserSettingCache(userId int, setting string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Setting", setting)
}

View File

@@ -15,6 +15,7 @@ type Adaptor interface {
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)

View File

@@ -49,9 +49,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return nil, errors.New("request is nil")
}
switch info.RelayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
return baiduEmbeddingRequest, nil
default:
aliReq := requestOpenAI2Ali(*request)
return aliReq, nil
@@ -67,6 +64,10 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return embeddingRequestOpenAI2Ali(request), nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")

View File

@@ -25,9 +25,12 @@ func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReque
return &request
}
func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
if request.Model == "" {
request.Model = "text-embedding-v1"
}
return &AliEmbeddingRequest{
Model: "text-embedding-v1",
Model: request.Model,
Input: struct {
Texts []string `json:"texts"`
}{

View File

@@ -59,6 +59,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return nil, nil
}

View File

@@ -109,9 +109,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return nil, errors.New("request is nil")
}
switch info.RelayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
return baiduEmbeddingRequest, nil
default:
baiduRequest := requestOpenAI2Baidu(*request)
return baiduRequest, nil
@@ -122,6 +119,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(request)
return baiduEmbeddingRequest, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -87,7 +87,7 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.Cha
return &response
}
func embeddingRequestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduEmbeddingRequest {
func embeddingRequestOpenAI2Baidu(request dto.EmbeddingRequest) *BaiduEmbeddingRequest {
return &BaiduEmbeddingRequest{
Input: request.ParseInput(),
}

View File

@@ -0,0 +1,76 @@
package baidu_v2
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,29 @@
package baidu_v2
var ModelList = []string{
"ernie-4.0-8k-latest",
"ernie-4.0-8k-preview",
"ernie-4.0-8k",
"ernie-4.0-turbo-8k-latest",
"ernie-4.0-turbo-8k-preview",
"ernie-4.0-turbo-8k",
"ernie-4.0-turbo-128k",
"ernie-3.5-8k-preview",
"ernie-3.5-8k",
"ernie-3.5-128k",
"ernie-speed-8k",
"ernie-speed-128k",
"ernie-speed-pro-128k",
"ernie-lite-8k",
"ernie-lite-pro-128k",
"ernie-tiny-8k",
"ernie-char-8k",
"ernie-char-fiction-8k",
"ernie-novel-8k",
"deepseek-v3",
"deepseek-r1",
"deepseek-r1-distill-qwen-32b",
"deepseek-r1-distill-qwen-14b",
}
var ChannelName = "volcengine"

View File

@@ -73,6 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -4,13 +4,14 @@ import (
"bytes"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -56,6 +57,10 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return request, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
// 添加文件字段
file, _, err := c.Request.FormFile("file")

View File

@@ -54,6 +54,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return requestConvertRerank2Cohere(request), nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = cohereRerankHandler(c, resp, info)

View File

@@ -10,6 +10,7 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
@@ -29,7 +30,12 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
switch info.RelayMode {
case constant.RelayModeCompletions:
return fmt.Sprintf("%s/beta/completions", info.BaseUrl), nil
default:
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -49,6 +55,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -48,6 +48,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -1,15 +1,21 @@
package gemini
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -21,8 +27,36 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
if !strings.HasPrefix(info.UpstreamModelName, "imagen") {
return nil, errors.New("not supported model for image generation")
}
// convert size to aspect ratio
aspectRatio := "1:1" // default aspect ratio
switch request.Size {
case "1024x1024":
aspectRatio = "1:1"
case "1024x1792":
aspectRatio = "9:16"
case "1792x1024":
aspectRatio = "16:9"
}
// build gemini imagen request
geminiRequest := GeminiImageRequest{
Instances: []GeminiImageInstance{
{
Prompt: request.Prompt,
},
},
Parameters: GeminiImageParameters{
SampleCount: request.N,
AspectRatio: aspectRatio,
PersonGeneration: "allow_adult", // default allow adult
},
}
return geminiRequest, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -40,6 +74,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
}
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
}
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent?alt=sse"
@@ -68,11 +106,20 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return GeminiImageHandler(c, resp, info)
}
if info.IsStream {
err, usage = GeminiChatStreamHandler(c, resp, info)
} else {
@@ -81,6 +128,60 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
return
}
func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
}
_ = resp.Body.Close()
var geminiResponse GeminiImageResponse
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
if len(geminiResponse.Predictions) == 0 {
return nil, service.OpenAIErrorWrapper(errors.New("no images generated"), "no_images", http.StatusBadRequest)
}
// convert to openai format response
openAIResponse := dto.ImageResponse{
Created: common.GetTimestamp(),
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
}
for _, prediction := range geminiResponse.Predictions {
if prediction.RaiFilteredReason != "" {
continue // skip filtered image
}
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
jsonResponse, jsonErr := json.Marshal(openAIResponse)
if jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
// each image has fixed 258 tokens
const imageTokens = 258
generatedImages := len(openAIResponse.Data)
usage = &dto.Usage{
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
CompletionTokens: 0, // image generation does not calculate completion tokens
TotalTokens: imageTokens * generatedImages,
}
return usage, nil
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}

View File

@@ -16,6 +16,8 @@ var ModelList = []string{
"gemini-2.0-pro-exp",
// thinking exp
"gemini-2.0-flash-thinking-exp",
// imagen models
"imagen-3.0-generate-002",
}
var ChannelName = "google gemini"

View File

@@ -109,3 +109,30 @@ type GeminiUsageMetadata struct {
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}
// Imagen related structs
type GeminiImageRequest struct {
Instances []GeminiImageInstance `json:"instances"`
Parameters GeminiImageParameters `json:"parameters"`
}
type GeminiImageInstance struct {
Prompt string `json:"prompt"`
}
type GeminiImageParameters struct {
SampleCount int `json:"sampleCount,omitempty"`
AspectRatio string `json:"aspectRatio,omitempty"`
PersonGeneration string `json:"personGeneration,omitempty"`
}
type GeminiImageResponse struct {
Predictions []GeminiImagePrediction `json:"predictions"`
}
type GeminiImagePrediction struct {
MimeType string `json:"mimeType"`
BytesBase64Encoded string `json:"bytesBase64Encoded"`
RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
SafetyAttributes any `json:"safetyAttributes,omitempty"`
}

View File

@@ -55,6 +55,10 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return request, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = jinaRerankHandler(c, resp)

View File

@@ -50,6 +50,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -0,0 +1,93 @@
package mokaai
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"strings"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return request, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix := "chat/"
if strings.HasPrefix(info.UpstreamModelName, "m3e") {
suffix = "embeddings"
}
fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix)
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch info.RelayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Moka(*request)
return baiduEmbeddingRequest, nil
default:
return nil, errors.New("not implemented")
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
err, usage = mokaEmbeddingHandler(c, resp)
default:
// err, usage = mokaHandler(c, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,9 @@
package mokaai
var ModelList = []string{
"m3e-large",
"m3e-base",
"m3e-small",
}
var ChannelName = "mokaai"

View File

@@ -0,0 +1,83 @@
package mokaai
import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/service"
)
func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.EmbeddingRequest {
var input []string // Change input to []string
switch v := request.Input.(type) {
case string:
input = []string{v} // Convert string to []string
case []string:
input = v // Already a []string, no conversion needed
case []interface{}:
for _, part := range v {
if str, ok := part.(string); ok {
input = append(input, str) // Append each string to the slice
}
}
}
return &dto.EmbeddingRequest{
Input: input,
Model: request.Model,
}
}
func embeddingResponseMoka2OpenAI(response *dto.EmbeddingResponse) *dto.OpenAIEmbeddingResponse {
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
Object: "list",
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
Model: "baidu-embedding",
Usage: response.Usage,
}
for _, item := range response.Data {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
Object: item.Object,
Index: item.Index,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var baiduResponse dto.EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
// if baiduResponse.ErrorMsg != "" {
// return &dto.OpenAIErrorWithStatusCode{
// Error: dto.OpenAIError{
// Type: "baidu_error",
// Param: "",
// },
// StatusCode: resp.StatusCode,
// }, nil
// }
fullTextResponse := embeddingResponseMoka2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -39,6 +39,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
@@ -46,18 +47,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if request == nil {
return nil, errors.New("request is nil")
}
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
return requestOpenAI2Embeddings(*request), nil
default:
return requestOpenAI2Ollama(*request), nil
}
return requestOpenAI2Ollama(*request)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return requestOpenAI2Embeddings(request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -3,18 +3,21 @@ package ollama
import "one-api/dto"
type OllamaRequest struct {
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Suffix any `json:"suffix,omitempty"`
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
Prompt any `json:"prompt,omitempty"`
}
type Options struct {
@@ -35,7 +38,7 @@ type OllamaEmbeddingRequest struct {
}
type OllamaEmbeddingResponse struct {
Error string `json:"error,omitempty"`
Model string `json:"model"`
Error string `json:"error,omitempty"`
Model string `json:"model"`
Embedding [][]float64 `json:"embeddings,omitempty"`
}

View File

@@ -9,14 +9,36 @@ import (
"net/http"
"one-api/dto"
"one-api/service"
"strings"
)
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
if !message.IsStringContent() {
mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
// check if not base64
if strings.HasPrefix(imageUrl.Url, "http") {
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
if err != nil {
return nil, err
}
imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
}
mediaMessage.ImageUrl = imageUrl
mediaMessages[j] = mediaMessage
}
}
message.SetMediaContent(mediaMessages)
}
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,
Role: message.Role,
Content: message.Content,
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,
})
}
str, ok := request.Stop.(string)
@@ -39,10 +61,13 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
ResponseFormat: request.ResponseFormat,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
}
Prompt: request.Prompt,
StreamOptions: request.StreamOptions,
Suffix: request.Suffix,
}, nil
}
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
return &OllamaEmbeddingRequest{
Model: request.Model,
Input: request.ParseInput(),
@@ -123,9 +148,9 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
}
func flattenEmbeddings(embeddings [][]float64) []float64 {
flattened := []float64{}
for _, row := range embeddings {
flattened = append(flattened, row...)
flattened := []float64{}
for _, row := range embeddings {
flattened = append(flattened, row...)
}
return flattened
}
return flattened
}

View File

@@ -119,7 +119,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0
}
if strings.HasPrefix(request.Model, "o3") {
if strings.HasPrefix(request.Model, "o3") || strings.HasPrefix(request.Model, "o1") {
request.Temperature = nil
}
if strings.HasSuffix(request.Model, "-high") {
@@ -149,6 +149,10 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
a.ResponseFormat = request.ResponseFormat
if info.RelayMode == constant.RelayModeAudioSpeech {

View File

@@ -5,6 +5,9 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"io"
"math"
@@ -20,10 +23,6 @@ import (
"strings"
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
func sendStreamData(c *gin.Context, data string, forceFormat bool) error {
@@ -91,11 +90,12 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
if len(data) < 6 { // ignore blank line or wrong format
continue
}
if data[:6] != "data: " && data[:6] != "[DONE]" {
if data[:5] != "data:" && data[:6] != "[DONE]" {
continue
}
mu.Lock()
data = data[6:]
data = data[5:]
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "[DONE]") {
if lastStreamData != "" {
err := sendStreamData(c, lastStreamData, forceFormat)

View File

@@ -49,6 +49,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -52,6 +52,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -36,6 +36,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeChatCompletions {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeCompletions {
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
}
return "", errors.New("invalid relay mode")
}
@@ -58,6 +60,10 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return request, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeRerank:
@@ -68,6 +74,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
case constant.RelayModeCompletions:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -73,6 +73,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -151,6 +151,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -0,0 +1,92 @@
package volcengine
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"strings"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
if strings.HasPrefix(info.UpstreamModelName, "bot") {
return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.BaseUrl), nil
}
return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
default:
}
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,13 @@
package volcengine
var ModelList = []string{
"Doubao-pro-128k",
"Doubao-pro-32k",
"Doubao-pro-4k",
"Doubao-lite-128k",
"Doubao-lite-32k",
"Doubao-lite-4k",
"Doubao-embedding",
}
var ChannelName = "volcengine"

View File

@@ -50,6 +50,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
// xunfei's request is not http request, so we don't need to do anything here
dummyResp := &http.Response{}

View File

@@ -56,6 +56,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -53,6 +53,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -90,8 +90,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
mediaMessages[j] = mediaMessage
}
}
messageRaw, _ := json.Marshal(mediaMessages)
message.Content = messageRaw
message.SetMediaContent(mediaMessages)
}
messages = append(messages, dto.Message{
Role: message.Role,

View File

@@ -112,7 +112,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
}
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure {
info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure ||
info.ChannelType == common.ChannelTypeVolcEngine || info.ChannelType == common.ChannelTypeOllama {
info.SupportStreamOptions = true
}
return info

View File

@@ -27,7 +27,9 @@ const (
APITypeVertexAi
APITypeMistral
APITypeDeepSeek
APITypeMokaAI
APITypeVolcEngine
APITypeBaiduV2
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -78,6 +80,12 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeMistral
case common.ChannelTypeDeepSeek:
apiType = APITypeDeepSeek
case common.ChannelTypeMokaAI:
apiType = APITypeMokaAI
case common.ChannelTypeVolcEngine:
apiType = APITypeVolcEngine
case common.ChannelTypeBaiduV2:
apiType = APITypeBaiduV2
}
if apiType == -1 {
return APITypeOpenAI, false

View File

@@ -194,7 +194,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
}
defer func(ctx context.Context) {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
@@ -500,7 +500,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
defer func(ctx context.Context) {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"io"
"math"
"net/http"
@@ -272,7 +273,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
if userQuota-preConsumedQuota < 0 {
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest)
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
}
if userQuota > 100*preConsumedQuota {
// 用户额度充足,判断令牌额度是否充足
@@ -282,18 +283,18 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if tokenQuota > 100*preConsumedQuota {
// 令牌额度充足,信任令牌
preConsumedQuota = 0
common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
}
} else {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota)))
}
}
if preConsumedQuota > 0 {
err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
@@ -307,14 +308,14 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
if preConsumedQuota != 0 {
go func() {
gopool.Go(func() {
relayInfoCopy := *relayInfo
err := model.PostConsumeQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false)
err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
if err != nil {
common.SysError("error return pre-consumed quota: " + err.Error())
}
}()
})
}
}
@@ -368,7 +369,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
//}
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}

View File

@@ -6,6 +6,7 @@ import (
"one-api/relay/channel/ali"
"one-api/relay/channel/aws"
"one-api/relay/channel/baidu"
"one-api/relay/channel/baidu_v2"
"one-api/relay/channel/claude"
"one-api/relay/channel/cloudflare"
"one-api/relay/channel/cohere"
@@ -14,6 +15,7 @@ import (
"one-api/relay/channel/gemini"
"one-api/relay/channel/jina"
"one-api/relay/channel/mistral"
"one-api/relay/channel/mokaai"
"one-api/relay/channel/ollama"
"one-api/relay/channel/openai"
"one-api/relay/channel/palm"
@@ -22,6 +24,7 @@ import (
"one-api/relay/channel/task/suno"
"one-api/relay/channel/tencent"
"one-api/relay/channel/vertex"
"one-api/relay/channel/volcengine"
"one-api/relay/channel/xunfei"
"one-api/relay/channel/zhipu"
"one-api/relay/channel/zhipu_4v"
@@ -74,6 +77,12 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &mistral.Adaptor{}
case constant.APITypeDeepSeek:
return &deepseek.Adaptor{}
case constant.APITypeMokaAI:
return &mokaai.Adaptor{}
case constant.APITypeVolcEngine:
return &volcengine.Adaptor{}
case constant.APITypeBaiduV2:
return &baidu_v2.Adaptor{}
}
return nil
}

137
relay/relay_embedding.go Normal file
View File

@@ -0,0 +1,137 @@
package relay
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
)
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
return token
}
func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest dto.EmbeddingRequest) error {
if embeddingRequest.Input == nil {
return fmt.Errorf("input is empty")
}
if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
embeddingRequest.Model = "omni-moderation-latest"
}
if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
embeddingRequest.Model = c.Param("model")
}
return nil
}
func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c)
var embeddingRequest *dto.EmbeddingRequest
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
}
err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
}
// map model name
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[embeddingRequest.Model] != "" {
embeddingRequest.Model = modelMap[embeddingRequest.Model]
// set upstream model name
//isModelMapped = true
}
}
relayInfo.UpstreamModelName = embeddingRequest.Model
modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
promptToken := getEmbeddingPromptToken(*embeddingRequest)
if !success {
preConsumedTokens := promptToken
modelRatio = common.GetModelRatio(embeddingRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
relayInfo.PromptTokens = promptToken
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(relayInfo)
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody := bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
var httpResp *http.Response
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
}
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil {
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
return nil
}

View File

@@ -113,7 +113,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
// release quota
if relayInfo.ConsumeQuota && taskErr == nil {
err := model.PostConsumeQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}

View File

@@ -56,6 +56,7 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.POST("/pay", controller.RequestEpay)
selfRoute.POST("/amount", controller.RequestAmount)
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
selfRoute.PUT("/setting", controller.UpdateUserSetting)
}
adminRoute := userRoute.Group("/")

View File

@@ -2,6 +2,7 @@ package service
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"one-api/common"
@@ -9,19 +10,46 @@ import (
"strings"
)
// WorkerRequest Worker请求的数据结构
type WorkerRequest struct {
URL string `json:"url"`
Key string `json:"key"`
Method string `json:"method,omitempty"`
Headers map[string]string `json:"headers,omitempty"`
Body json.RawMessage `json:"body,omitempty"`
}
// DoWorkerRequest 通过Worker发送请求
func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
if !setting.EnableWorker() {
return nil, fmt.Errorf("worker not enabled")
}
if !strings.HasPrefix(req.URL, "https") {
return nil, fmt.Errorf("only support https url")
}
workerUrl := setting.WorkerUrl
if !strings.HasSuffix(workerUrl, "/") {
workerUrl += "/"
}
// 序列化worker请求数据
workerPayload, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal worker payload: %v", err)
}
return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
}
func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
if setting.EnableWorker() {
common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
if !strings.HasPrefix(originUrl, "https") {
return nil, fmt.Errorf("only support https url")
req := &WorkerRequest{
URL: originUrl,
Key: setting.WorkerValidKey,
}
workerUrl := setting.WorkerUrl
if !strings.HasSuffix(workerUrl, "/") {
workerUrl += "/"
}
// post request to worker
data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`)
return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data))
return DoWorkerRequest(req)
} else {
common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
return http.Get(originUrl)

View File

@@ -4,8 +4,9 @@ import (
"fmt"
"net/http"
"one-api/common"
relaymodel "one-api/dto"
"one-api/dto"
"one-api/model"
"one-api/setting"
"strings"
)
@@ -14,17 +15,17 @@ func DisableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
notifyRootUser(subject, content)
NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
}
func EnableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
notifyRootUser(subject, content)
NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
}
func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatusCode) bool {
func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool {
if !common.AutomaticDisableChannelEnabled {
return false
}
@@ -64,28 +65,17 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
case "forbidden":
return true
}
if strings.HasPrefix(err.Error.Message, "Your credit balance is too low") { // anthropic
return true
} else if strings.HasPrefix(err.Error.Message, "This organization has been disabled.") {
return true
} else if strings.HasPrefix(err.Error.Message, "You exceeded your current quota") {
return true
} else if strings.HasPrefix(err.Error.Message, "Permission denied") {
return true
}
if strings.Contains(err.Error.Message, "The security token included in the request is invalid") { // anthropic
return true
} else if strings.Contains(err.Error.Message, "Operation not allowed") {
return true
} else if strings.Contains(err.Error.Message, "Your account is not authorized") {
lowerMessage := strings.ToLower(err.Error.Message)
search, _ := AcSearch(lowerMessage, setting.AutomaticDisableKeywords, true)
if search {
return true
}
return false
}
func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool {
func ShouldEnableChannel(err error, openaiWithStatusErr *dto.OpenAIErrorWithStatusCode, status int) bool {
if !common.AutomaticEnableChannelEnabled {
return false
}

117
service/notify-limit.go Normal file
View File

@@ -0,0 +1,117 @@
package service
import (
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"one-api/common"
"one-api/constant"
"strconv"
"sync"
"time"
)
// notifyLimitStore is used for in-memory rate limiting when Redis is disabled
var (
notifyLimitStore sync.Map
cleanupOnce sync.Once
)
type limitCount struct {
Count int
Timestamp time.Time
}
func getDuration() time.Duration {
minute := constant.NotificationLimitDurationMinute
return time.Duration(minute) * time.Minute
}
// startCleanupTask starts a background task to clean up expired entries
func startCleanupTask() {
gopool.Go(func() {
for {
time.Sleep(time.Hour)
now := time.Now()
notifyLimitStore.Range(func(key, value interface{}) bool {
if limit, ok := value.(limitCount); ok {
if now.Sub(limit.Timestamp) >= getDuration() {
notifyLimitStore.Delete(key)
}
}
return true
})
}
})
}
// CheckNotificationLimit checks if the user has exceeded their notification limit
// Returns true if the user can send notification, false if limit exceeded
func CheckNotificationLimit(userId int, notifyType string) (bool, error) {
if common.RedisEnabled {
return checkRedisLimit(userId, notifyType)
}
return checkMemoryLimit(userId, notifyType)
}
func checkRedisLimit(userId int, notifyType string) (bool, error) {
key := fmt.Sprintf("notify_limit:%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
// Get current count
count, err := common.RedisGet(key)
if err != nil && err.Error() != "redis: nil" {
return false, fmt.Errorf("failed to get notification count: %w", err)
}
// If key doesn't exist, initialize it
if count == "" {
err = common.RedisSet(key, "1", getDuration())
return true, err
}
currentCount, _ := strconv.Atoi(count)
limit := constant.NotifyLimitCount
// Check if limit is already reached
if currentCount >= limit {
return false, nil
}
// Only increment if under limit
err = common.RedisIncr(key, 1)
if err != nil {
return false, fmt.Errorf("failed to increment notification count: %w", err)
}
return true, nil
}
func checkMemoryLimit(userId int, notifyType string) (bool, error) {
// Ensure cleanup task is started
cleanupOnce.Do(startCleanupTask)
key := fmt.Sprintf("%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
now := time.Now()
// Get current limit count or initialize new one
var currentLimit limitCount
if value, ok := notifyLimitStore.Load(key); ok {
currentLimit = value.(limitCount)
// Check if the entry has expired
if now.Sub(currentLimit.Timestamp) >= getDuration() {
currentLimit = limitCount{Count: 0, Timestamp: now}
}
} else {
currentLimit = limitCount{Count: 0, Timestamp: now}
}
// Increment count
currentLimit.Count++
// Check against limits
limit := constant.NotifyLimitCount
// Store updated count
notifyLimitStore.Store(key, currentLimit)
return currentLimit.Count <= limit, nil
}

View File

@@ -3,8 +3,10 @@ package service
import (
"errors"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"math"
"one-api/common"
constant2 "one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
@@ -99,7 +101,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota))
}
err = model.PostConsumeQuota(relayInfo, 0, quota, 0, false)
err = PostConsumeQuota(relayInfo, quota, 0, false)
if err != nil {
return err
}
@@ -222,7 +224,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
} else {
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
@@ -239,3 +241,88 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if relayInfo.IsPlayground {
return nil
}
//if relayInfo.TokenUnlimited {
// return nil
//}
token, err := model.GetTokenByKey(relayInfo.TokenKey, false)
if err != nil {
return err
}
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
return errors.New("令牌额度不足")
}
err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
if err != nil {
return err
}
return nil
}
func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int, sendEmail bool) (err error) {
if quota > 0 {
err = model.DecreaseUserQuota(relayInfo.UserId, quota)
} else {
err = model.IncreaseUserQuota(relayInfo.UserId, -quota)
}
if err != nil {
return err
}
if !relayInfo.IsPlayground {
if quota > 0 {
err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
} else {
err = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
}
if err != nil {
return err
}
}
if sendEmail {
if (quota + preConsumedQuota) != 0 {
checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota)
}
}
return nil
}
func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
gopool.Go(func() {
userCache, err := model.GetUserCache(userId)
if err != nil {
common.SysError("failed to get user cache: " + err.Error())
}
userSetting := userCache.GetSetting()
threshold := common.QuotaRemindThreshold
if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
threshold = int(userCustomThreshold.(float64))
}
//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
quotaTooLow := false
consumeQuota := quota + preConsumedQuota
if userCache.Quota-consumeQuota < threshold {
quotaTooLow = true
}
if quotaTooLow {
prompt := "您的额度即将用尽"
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink}))
if err != nil {
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error()))
}
}
})
}

View File

@@ -60,17 +60,7 @@ func SensitiveWordContains(text string) (bool, []string) {
return false, nil
}
checkText := strings.ToLower(text)
// 构建一个AC自动机
m := InitAc()
hits := m.MultiPatternSearch([]rune(checkText), false)
if len(hits) > 0 {
words := make([]string, 0)
for _, hit := range hits {
words = append(words, string(hit.Word))
}
return true, words
}
return false, nil
return AcSearch(checkText, setting.SensitiveWords, false)
}
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
@@ -79,7 +69,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
return false, nil, text
}
checkText := strings.ToLower(text)
m := InitAc()
m := InitAc(setting.SensitiveWords)
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
if len(hits) > 0 {
words := make([]string, 0)

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"fmt"
goahocorasick "github.com/anknown/ahocorasick"
"one-api/setting"
"strings"
)
@@ -57,9 +56,9 @@ func RemoveDuplicate(s []string) []string {
return result
}
func InitAc() *goahocorasick.Machine {
func InitAc(words []string) *goahocorasick.Machine {
m := new(goahocorasick.Machine)
dict := readRunes()
dict := readRunes(words)
if err := m.Build(dict); err != nil {
fmt.Println(err)
return nil
@@ -67,10 +66,10 @@ func InitAc() *goahocorasick.Machine {
return m
}
func readRunes() [][]rune {
func readRunes(words []string) [][]rune {
var dict [][]rune
for _, word := range setting.SensitiveWords {
for _, word := range words {
word = strings.ToLower(word)
l := bytes.TrimSpace([]byte(word))
dict = append(dict, bytes.Runes(l))
@@ -78,3 +77,25 @@ func readRunes() [][]rune {
return dict
}
func AcSearch(findText string, dict []string, stopImmediately bool) (bool, []string) {
if len(dict) == 0 {
return false, nil
}
if len(findText) == 0 {
return false, nil
}
m := InitAc(dict)
if m == nil {
return false, nil
}
hits := m.MultiPatternSearch([]rune(findText), stopImmediately)
if len(hits) > 0 {
words := make([]string, 0)
for _, hit := range hits {
words = append(words, string(hit.Word))
}
return true, words
}
return false, nil
}

View File

@@ -4,7 +4,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/pkoukk/tiktoken-go"
"image"
"log"
"math"
@@ -14,6 +13,8 @@ import (
relaycommon "one-api/relay/common"
"strings"
"unicode/utf8"
"github.com/pkoukk/tiktoken-go"
)
// tokenEncoderMap won't grow after initialization
@@ -323,6 +324,12 @@ func CountTokenInput(input any, model string) (int, error) {
text += s
}
return CountTextToken(text, model)
case []interface{}:
text := ""
for _, item := range v {
text += fmt.Sprintf("%v", item)
}
return CountTextToken(text, model)
}
return CountTokenInput(fmt.Sprintf("%v", input), model)
}

View File

@@ -3,15 +3,75 @@ package service
import (
"fmt"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
"strings"
)
func notifyRootUser(subject string, content string) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
func NotifyRootUser(t string, subject string, content string) {
user := model.GetRootUser().ToBaseUser()
_ = NotifyUser(user, dto.NewNotify(t, subject, content, nil))
}
func NotifyUser(user *model.UserBase, data dto.Notify) error {
userSetting := user.GetSetting()
notifyType, ok := userSetting[constant.UserSettingNotifyType]
if !ok {
notifyType = constant.NotifyTypeEmail
}
// Check notification limit
canSend, err := CheckNotificationLimit(user.Id, data.Type)
if err != nil {
common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
return err
}
if !canSend {
return fmt.Errorf("notification limit exceeded for user %d with type %s", user.Id, notifyType)
}
switch notifyType {
case constant.NotifyTypeEmail:
userEmail := user.Email
// check setting email
if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
userEmail = settingEmail.(string)
}
if userEmail == "" {
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", user.Id))
return nil
}
return sendEmailNotify(userEmail, data)
case constant.NotifyTypeWebhook:
webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
if !ok {
common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", user.Id))
return nil
}
webhookURLStr, ok := webhookURL.(string)
if !ok {
common.SysError(fmt.Sprintf("user %d webhook url is not string type", user.Id))
return nil
}
// 获取 webhook secret
var webhookSecret string
if secret, ok := userSetting[constant.UserSettingWebhookSecret]; ok {
webhookSecret, _ = secret.(string)
}
return SendWebhookNotify(webhookURLStr, webhookSecret, data)
}
return nil
}
func sendEmailNotify(userEmail string, data dto.Notify) error {
// make email content
content := data.Content
// 处理占位符
for _, value := range data.Values {
content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1)
}
return common.SendEmail(data.Title, userEmail, content)
}

118
service/webhook.go Normal file
View File

@@ -0,0 +1,118 @@
package service
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"one-api/dto"
"one-api/setting"
"time"
)
// WebhookPayload webhook 通知的负载数据
type WebhookPayload struct {
Type string `json:"type"`
Title string `json:"title"`
Content string `json:"content"`
Values []interface{} `json:"values,omitempty"`
Timestamp int64 `json:"timestamp"`
}
// generateSignature 生成 webhook 签名
func generateSignature(secret string, payload []byte) string {
h := hmac.New(sha256.New, []byte(secret))
h.Write(payload)
return hex.EncodeToString(h.Sum(nil))
}
// SendWebhookNotify 发送 webhook 通知
func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error {
// 处理占位符
content := data.Content
for _, value := range data.Values {
content = fmt.Sprintf(content, value)
}
// 构建 webhook 负载
payload := WebhookPayload{
Type: data.Type,
Title: data.Title,
Content: content,
Values: data.Values,
Timestamp: time.Now().Unix(),
}
// 序列化负载
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal webhook payload: %v", err)
}
// 创建 HTTP 请求
var req *http.Request
var resp *http.Response
if setting.EnableWorker() {
// 构建worker请求数据
workerReq := &WorkerRequest{
URL: webhookURL,
Key: setting.WorkerValidKey,
Method: http.MethodPost,
Headers: map[string]string{
"Content-Type": "application/json",
},
Body: payloadBytes,
}
// 如果有secret添加签名到headers
if secret != "" {
signature := generateSignature(secret, payloadBytes)
workerReq.Headers["X-Webhook-Signature"] = signature
workerReq.Headers["Authorization"] = "Bearer " + secret
}
resp, err = DoWorkerRequest(workerReq)
if err != nil {
return fmt.Errorf("failed to send webhook request through worker: %v", err)
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
}
} else {
req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes))
if err != nil {
return fmt.Errorf("failed to create webhook request: %v", err)
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
// 如果有 secret生成签名
if secret != "" {
signature := generateSignature(secret, payloadBytes)
req.Header.Set("X-Webhook-Signature", signature)
}
// 发送请求
client := GetImpatientHttpClient()
resp, err = client.Do(req)
if err != nil {
return fmt.Errorf("failed to send webhook request: %v", err)
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
}
}
return nil
}

View File

@@ -12,6 +12,9 @@ var Chats = []map[string]string{
{
"Lobe Chat 官方示例": "https://chat-preview.lobehub.com/?settings={\"keyVaults\":{\"openai\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\"}}}",
},
{
"AI as Workspace": "https://aiaw.app/set-provider?provider={\"type\":\"openai\",\"settings\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\",\"compatibility\":\"strict\"}}",
},
{
"AMA 问天": "ama://set-api-key?server={address}&key={key}",
},

View File

@@ -1,3 +1,30 @@
package setting
import "strings"
var DemoSiteEnabled = false
var AutomaticDisableKeywords = []string{
"Your credit balance is too low",
"This organization has been disabled.",
"You exceeded your current quota",
"Permission denied",
"The security token included in the request is invalid",
"Operation not allowed",
"Your account is not authorized",
}
func AutomaticDisableKeywordsToString() string {
return strings.Join(AutomaticDisableKeywords, "\n")
}
func AutomaticDisableKeywordsFromString(s string) {
AutomaticDisableKeywords = []string{}
ak := strings.Split(s, "\n")
for _, k := range ak {
k = strings.TrimSpace(k)
if k != "" {
AutomaticDisableKeywords = append(AutomaticDisableKeywords, k)
}
}
}

View File

@@ -44,7 +44,7 @@ function renderTimestamp(timestamp) {
const ChannelsTable = () => {
const { t } = useTranslation();
let type2label = undefined;
const renderType = (type) => {
@@ -53,11 +53,11 @@ const ChannelsTable = () => {
for (let i = 0; i < CHANNEL_OPTIONS.length; i++) {
type2label[CHANNEL_OPTIONS[i].value] = CHANNEL_OPTIONS[i];
}
type2label[0] = { value: 0, text: t('未知类型'), color: 'grey' };
type2label[0] = { value: 0, label: t('未知类型'), color: 'grey' };
}
return (
<Tag size="large" color={type2label[type]?.color}>
{type2label[type]?.text}
{type2label[type]?.label}
</Tag>
);
};
@@ -357,6 +357,13 @@ const ChannelsTable = () => {
dataIndex: 'operate',
render: (text, record, index) => {
if (record.children === undefined) {
// 构建模型测试菜单
const modelMenuItems = record.models.split(',').map(model => ({
node: 'item',
name: model,
onClick: () => testChannel(record, model)
}));
return (
<div>
<SplitButtonGroup
@@ -374,7 +381,7 @@ const ChannelsTable = () => {
<Dropdown
trigger="click"
position="bottomRight"
menu={record.test_models}
menu={modelMenuItems} // 使用即时生成的菜单项
>
<Button
style={{ padding: '8px 4px' }}
@@ -545,21 +552,10 @@ const ChannelsTable = () => {
let channelTags = {};
for (let i = 0; i < channels.length; i++) {
channels[i].key = '' + channels[i].id;
let test_models = [];
channels[i].models.split(',').forEach((item, index) => {
test_models.push({
node: 'item',
name: item,
onClick: () => {
testChannel(channels[i], item);
}
});
});
channels[i].test_models = test_models;
if (!enableTagMode) {
channelDates.push(channels[i]);
} else {
let tag = channels[i].tag?channels[i].tag:"";
let tag = channels[i].tag ? channels[i].tag : "";
// find from channelTags
let tagIndex = channelTags[tag];
let tagChannelDates = undefined;
@@ -798,18 +794,64 @@ const ChannelsTable = () => {
setSearching(false);
};
const updateChannelProperty = (channelId, updateFn) => {
// Create a new copy of channels array
const newChannels = [...channels];
let updated = false;
// Find and update the correct channel
newChannels.forEach(channel => {
if (channel.children !== undefined) {
// If this is a tag group, search in its children
channel.children.forEach(child => {
if (child.id === channelId) {
updateFn(child);
updated = true;
}
});
} else if (channel.id === channelId) {
// Direct channel match
updateFn(channel);
updated = true;
}
});
// Only update state if we actually modified a channel
if (updated) {
setChannels(newChannels);
}
};
const testChannel = async (record, model) => {
const res = await API.get(`/api/channel/test/${record.id}?model=${model}`);
const { success, message, time } = res.data;
if (success) {
record.response_time = time * 1000;
record.test_time = Date.now() / 1000;
// Also update the channels state to persist the change
updateChannelProperty(record.id, (channel) => {
channel.response_time = time * 1000;
channel.test_time = Date.now() / 1000;
});
showInfo(t('通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。').replace('${name}', record.name).replace('${time.toFixed(2)}', time.toFixed(2)));
} else {
showError(message);
}
};
const updateChannelBalance = async (record) => {
const res = await API.get(`/api/channel/update_balance/${record.id}/`);
const { success, message, balance } = res.data;
if (success) {
updateChannelProperty(record.id, (channel) => {
channel.balance = balance;
channel.balance_updated_time = Date.now() / 1000;
});
showInfo(t('通道 ${name} 余额更新成功!').replace('${name}', record.name));
} else {
showError(message);
}
};
const testAllChannels = async () => {
const res = await API.get(`/api/channel/test`);
const { success, message } = res.data;
@@ -831,18 +873,6 @@ const ChannelsTable = () => {
}
};
const updateChannelBalance = async (record) => {
const res = await API.get(`/api/channel/update_balance/${record.id}/`);
const { success, message, balance } = res.data;
if (success) {
record.balance = balance;
record.balance_updated_time = Date.now() / 1000;
showInfo(t('通道 ${name} 余额更新成功!').replace('${name}', record.name));
} else {
showError(message);
}
};
const updateAllChannelsBalance = async () => {
setUpdatingBalance(true);
const res = await API.get(`/api/channel/update_balance`);
@@ -1186,7 +1216,7 @@ const ChannelsTable = () => {
</Space>
</div>
<div style={{ marginTop: 20 }}>
<Space>
<Space>
<Typography.Text strong>{t('标签聚合模式')}</Typography.Text>
<Switch
checked={enableTagMode}
@@ -1199,14 +1229,14 @@ const ChannelsTable = () => {
}}
/>
<Button
disabled={!enableBatchDelete}
theme="light"
type="primary"
style={{ marginRight: 8 }}
onClick={() => setShowBatchSetTag(true)}
>
{t('批量设置标签')}
</Button>
disabled={!enableBatchDelete}
theme="light"
type="primary"
style={{ marginRight: 8 }}
onClick={() => setShowBatchSetTag(true)}
>
{t('批量设置标签')}
</Button>
</Space>
</div>

View File

@@ -59,6 +59,7 @@ const OperationSetting = () => {
RetryTimes: 0,
Chats: "[]",
DemoSiteEnabled: false,
AutomaticDisableKeywords: '',
});
let [loading, setLoading] = useState(false);

View File

@@ -26,6 +26,10 @@ import {
Tag,
Typography,
Collapsible,
Select,
Radio,
RadioGroup,
AutoComplete,
} from '@douyinfe/semi-ui';
import {
getQuotaPerUnit,
@@ -67,14 +71,16 @@ const PersonalSetting = () => {
const [transferAmount, setTransferAmount] = useState(0);
const [isModelsExpanded, setIsModelsExpanded] = useState(false);
const MODELS_DISPLAY_COUNT = 10; // 默认显示的模型数量
const [notificationSettings, setNotificationSettings] = useState({
warningType: 'email',
warningThreshold: 100000,
webhookUrl: '',
webhookSecret: '',
notificationEmail: ''
});
const [showWebhookDocs, setShowWebhookDocs] = useState(false);
useEffect(() => {
// let user = localStorage.getItem('user');
// if (user) {
// userDispatch({ type: 'login', payload: user });
// }
// console.log(localStorage.getItem('user'))
let status = localStorage.getItem('status');
if (status) {
status = JSON.parse(status);
@@ -105,6 +111,19 @@ const PersonalSetting = () => {
return () => clearInterval(countdownInterval); // Clean up on unmount
}, [disableButton, countdown]);
useEffect(() => {
if (userState?.user?.setting) {
const settings = JSON.parse(userState.user.setting);
setNotificationSettings({
warningType: settings.notify_type || 'email',
warningThreshold: settings.quota_warning_threshold || 500000,
webhookUrl: settings.webhook_url || '',
webhookSecret: settings.webhook_secret || '',
notificationEmail: settings.notification_email || ''
});
}
}, [userState?.user?.setting]);
const handleInputChange = (name, value) => {
setInputs((inputs) => ({...inputs, [name]: value}));
};
@@ -300,7 +319,36 @@ const PersonalSetting = () => {
}
};
const handleNotificationSettingChange = (type, value) => {
setNotificationSettings(prev => ({
...prev,
[type]: value.target ? value.target.value : value // 处理 Radio 事件对象
}));
};
const saveNotificationSettings = async () => {
try {
const res = await API.put('/api/user/setting', {
notify_type: notificationSettings.warningType,
quota_warning_threshold: notificationSettings.warningThreshold,
webhook_url: notificationSettings.webhookUrl,
webhook_secret: notificationSettings.webhookSecret,
notification_email: notificationSettings.notificationEmail
});
if (res.data.success) {
showSuccess(t('通知设置已更新'));
await getUserData();
} else {
showError(res.data.message);
}
} catch (error) {
showError(t('更新通知设置失败'));
}
};
return (
<div>
<Layout>
<Layout.Content>
@@ -526,9 +574,7 @@ const PersonalSetting = () => {
</div>
<div style={{marginTop: 10}}>
<Typography.Text strong>{t('微信')}</Typography.Text>
<div
style={{display: 'flex', justifyContent: 'space-between'}}
>
<div style={{display: 'flex', justifyContent: 'space-between'}}>
<div>
<Input
value={
@@ -541,12 +587,16 @@ const PersonalSetting = () => {
</div>
<div>
<Button
disabled={
(userState.user && userState.user.wechat_id !== '') ||
!status.wechat_login
}
disabled={!status.wechat_login}
onClick={() => {
setShowWeChatBindModal(true);
}}
>
{status.wechat_login ? t('绑定') : t('未启用')}
{userState.user && userState.user.wechat_id !== ''
? t('修改绑定')
: status.wechat_login
? t('绑定')
: t('未启用')}
</Button>
</div>
</div>
@@ -672,18 +722,8 @@ const PersonalSetting = () => {
style={{marginTop: '10px'}}
/>
)}
{status.wechat_login && (
<Button
onClick={() => {
setShowWeChatBindModal(true);
}}
>
{t('绑定微信账号')}
</Button>
)}
<Modal
onCancel={() => setShowWeChatBindModal(false)}
// onOpen={() => setShowWeChatBindModal(true)}
visible={showWeChatBindModal}
size={'small'}
>
@@ -707,9 +747,121 @@ const PersonalSetting = () => {
</Modal>
</div>
</Card>
<Card style={{marginTop: 10}}>
<Typography.Title heading={6}>{t('通知设置')}</Typography.Title>
<div style={{marginTop: 20}}>
<Typography.Text strong>{t('通知方式')}</Typography.Text>
<div style={{marginTop: 10}}>
<RadioGroup
value={notificationSettings.warningType}
onChange={value => handleNotificationSettingChange('warningType', value)}
>
<Radio value="email">{t('邮件通知')}</Radio>
<Radio value="webhook">{t('Webhook通知')}</Radio>
</RadioGroup>
</div>
</div>
{notificationSettings.warningType === 'webhook' && (
<>
<div style={{marginTop: 20}}>
<Typography.Text strong>{t('Webhook地址')}</Typography.Text>
<div style={{marginTop: 10}}>
<Input
value={notificationSettings.webhookUrl}
onChange={val => handleNotificationSettingChange('webhookUrl', val)}
placeholder={t('请输入Webhook地址例如: https://example.com/webhook')}
/>
<Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
{t('只支持https系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求')}
</Typography.Text>
<Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
<div style={{cursor: 'pointer'}} onClick={() => setShowWebhookDocs(!showWebhookDocs)}>
{t('Webhook请求结构')} {showWebhookDocs ? '▼' : '▶'}
</div>
<Collapsible isOpen={showWebhookDocs}>
<pre style={{marginTop: 4, background: 'var(--semi-color-fill-0)', padding: 8, borderRadius: 4}}>
{`{
"type": "quota_exceed", // 通知类型
"title": "标题", // 通知标题
"content": "通知内容", // 通知内容,支持 {{value}} 变量占位符
"values": ["值1", "值2"], // 按顺序替换content中的 {{value}} 占位符
"timestamp": 1739950503 // 时间戳
}
示例:
{
"type": "quota_exceed",
"title": "额度预警通知",
"content": "您的额度即将用尽,当前剩余额度为 {{value}}",
"values": ["$0.99"],
"timestamp": 1739950503
}`}
</pre>
</Collapsible>
</Typography.Text>
</div>
</div>
<div style={{marginTop: 20}}>
<Typography.Text strong>{t('接口凭证(可选)')}</Typography.Text>
<div style={{marginTop: 10}}>
<Input
value={notificationSettings.webhookSecret}
onChange={val => handleNotificationSettingChange('webhookSecret', val)}
placeholder={t('请输入密钥')}
/>
<Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
{t('密钥将以 Bearer 方式添加到请求头中用于验证webhook请求的合法性')}
</Typography.Text>
<Typography.Text type="secondary" style={{marginTop: 4, display: 'block'}}>
{t('Authorization: Bearer your-secret-key')}
</Typography.Text>
</div>
</div>
</>
)}
{notificationSettings.warningType === 'email' && (
<div style={{marginTop: 20}}>
<Typography.Text strong>{t('通知邮箱')}</Typography.Text>
<div style={{marginTop: 10}}>
<Input
value={notificationSettings.notificationEmail}
onChange={val => handleNotificationSettingChange('notificationEmail', val)}
placeholder={t('留空则使用账号绑定的邮箱')}
/>
<Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
{t('设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱')}
</Typography.Text>
</div>
</div>
)}
<div style={{marginTop: 20}}>
<Typography.Text strong>{t('额度预警阈值')} {renderQuotaWithPrompt(notificationSettings.warningThreshold)}</Typography.Text>
<div style={{marginTop: 10}}>
<AutoComplete
value={notificationSettings.warningThreshold}
onChange={val => handleNotificationSettingChange('warningThreshold', val)}
style={{width: 200}}
placeholder={t('请输入预警额度')}
data={[
{ value: 100000, label: '0.2$' },
{ value: 500000, label: '1$' },
{ value: 1000000, label: '5$' },
{ value: 5000000, label: '10$' }
]}
/>
</div>
<Typography.Text type="secondary" style={{marginTop: 10, display: 'block'}}>
{t('当剩余额度低于此数值时,系统将通过选择的方式发送通知')}
</Typography.Text>
</div>
<div style={{marginTop: 20}}>
<Button type="primary" onClick={saveNotificationSettings}>
{t('保存设置')}
</Button>
</div>
</Card>
<Modal
onCancel={() => setShowEmailBindModal(false)}
// onOpen={() => setShowEmailBindModal(true)}
onOk={bindEmail}
visible={showEmailBindModal}
size={'small'}

View File

@@ -80,7 +80,7 @@ const SiderBar = () => {
itemKey: 'channel',
to: '/channel',
icon: <IconLayers />,
className: isAdmin() ? 'semi-navigation-item-normal' : 'tableHiddle',
className: isAdmin() ? '' : 'tableHiddle',
},
{
text: t('聊天'),
@@ -101,7 +101,7 @@ const SiderBar = () => {
icon: <IconCalendarClock />,
className:
localStorage.getItem('enable_data_export') === 'true'
? 'semi-navigation-item-normal'
? ''
: 'tableHiddle',
},
{
@@ -109,7 +109,7 @@ const SiderBar = () => {
itemKey: 'redemption',
to: '/redemption',
icon: <IconGift />,
className: isAdmin() ? 'semi-navigation-item-normal' : 'tableHiddle',
className: isAdmin() ? '' : 'tableHiddle',
},
{
text: t('钱包'),
@@ -122,7 +122,7 @@ const SiderBar = () => {
itemKey: 'user',
to: '/user',
icon: <IconUser />,
className: isAdmin() ? 'semi-navigation-item-normal' : 'tableHiddle',
className: isAdmin() ? '' : 'tableHiddle',
},
{
text: t('日志'),
@@ -137,7 +137,7 @@ const SiderBar = () => {
icon: <IconImage />,
className:
localStorage.getItem('enable_drawing') === 'true'
? 'semi-navigation-item-normal'
? ''
: 'tableHiddle',
},
{
@@ -147,7 +147,7 @@ const SiderBar = () => {
icon: <IconChecklistStroked />,
className:
localStorage.getItem('enable_task') === 'true'
? 'semi-navigation-item-normal'
? ''
: 'tableHiddle',
},
{

View File

@@ -368,6 +368,17 @@ const SystemSetting = () => {
</a>
</Header>
<Message info>
注意代理功能仅对图片请求和 Webhook 请求生效不会影响其他 API 请求如需配置 API 请求代理请参考
<a
href='https://github.com/Calcium-Ion/new-api/blob/main/docs/channel/other_setting.md'
target='_blank'
rel='noreferrer'
>
{' '}API 代理设置文档
</a>
</Message>
<Form.Group widths='equal'>
<Form.Input
label='Worker地址不填写则不启用代理'

View File

@@ -1,129 +1,112 @@
export const CHANNEL_OPTIONS = [
{ key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI' },
{ value: 1, color: 'green', label: 'OpenAI' },
{
key: 2,
text: 'Midjourney Proxy',
value: 2,
color: 'light-blue',
label: 'Midjourney Proxy'
},
{
key: 5,
text: 'Midjourney Proxy Plus',
value: 5,
color: 'blue',
label: 'Midjourney Proxy Plus'
},
{
key: 36,
text: 'Suno API',
value: 36,
color: 'purple',
label: 'Suno API'
},
{ key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama' },
{ value: 4, color: 'grey', label: 'Ollama' },
{
key: 14,
text: 'Anthropic Claude',
value: 14,
color: 'indigo',
label: 'Anthropic Claude'
},
{
key: 33,
text: 'AWS Claude',
value: 33,
color: 'indigo',
label: 'AWS Claude'
},
{ key: 41, text: 'Vertex AI', value: 41, color: 'blue', label: 'Vertex AI' },
{ value: 41, color: 'blue', label: 'Vertex AI' },
{
key: 3,
text: 'Azure OpenAI',
value: 3,
color: 'teal',
label: 'Azure OpenAI'
},
{
key: 34,
text: 'Cohere',
value: 34,
color: 'purple',
label: 'Cohere'
},
{ key: 39, text: 'Cloudflare', value: 39, color: 'grey', label: 'Cloudflare' },
{ key: 43, text: 'DeepSeek', value: 43, color: 'blue', label: 'DeepSeek' },
{ value: 39, color: 'grey', label: 'Cloudflare' },
{ value: 43, color: 'blue', label: 'DeepSeek' },
{
key: 15,
text: '百度文心千帆',
value: 15,
color: 'blue',
label: '百度文心千帆'
},
{
key: 17,
text: '阿里通义千问',
value: 46,
color: 'blue',
label: '百度文心千帆V2'
},
{
value: 17,
color: 'orange',
label: '阿里通义千问'
},
{
key: 18,
text: '讯飞星火认知',
value: 18,
color: 'blue',
label: '讯飞星火认知'
},
{
key: 16,
text: '智谱 ChatGLM',
value: 16,
color: 'violet',
label: '智谱 ChatGLM'
},
{
key: 26,
text: '智谱 GLM-4V',
value: 26,
color: 'purple',
label: '智谱 GLM-4V'
},
{
key: 24,
text: 'Google Gemini',
value: 24,
color: 'orange',
label: 'Google Gemini'
},
{
key: 11,
text: 'Google PaLM2',
value: 11,
color: 'orange',
label: 'Google PaLM2'
},
{ key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
{ key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' },
{ key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
{ key: 31, text: '零一万物', value: 31, color: 'green', label: '零一万物' },
{ key: 35, text: 'MiniMax', value: 35, color: 'green', label: 'MiniMax' },
{ key: 37, text: 'Dify', value: 37, color: 'teal', label: 'Dify' },
{ key: 38, text: 'Jina', value: 38, color: 'blue', label: 'Jina' },
{ key: 40, text: 'SiliconCloud', value: 40, color: 'purple', label: 'SiliconCloud' },
{ key: 42, text: 'Mistral AI', value: 42, color: 'blue', label: 'Mistral AI' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道' },
{
key: 22,
text: '知识库FastGPT',
value: 45,
color: 'blue',
label: '火山方舟(豆包)'
},
{ value: 25, color: 'green', label: 'Moonshot' },
{ value: 19, color: 'blue', label: '360 智脑' },
{ value: 23, color: 'teal', label: '腾讯混元' },
{ value: 31, color: 'green', label: '零一万物' },
{ value: 35, color: 'green', label: 'MiniMax' },
{ value: 37, color: 'teal', label: 'Dify' },
{ value: 38, color: 'blue', label: 'Jina' },
{ value: 40, color: 'purple', label: 'SiliconCloud' },
{ value: 42, color: 'blue', label: 'Mistral AI' },
{ value: 8, color: 'pink', label: '自定义渠道' },
{
value: 22,
color: 'blue',
label: '知识库FastGPT'
},
{
key: 21,
text: '知识库AI Proxy',
value: 21,
color: 'purple',
label: '知识库AI Proxy'
},
{
value: 44,
color: 'purple',
label: '嵌入模型MokaAI M3E'
}
];

View File

@@ -1,6 +1,6 @@
import i18next from 'i18next';
import { Modal, Tag, Typography } from '@douyinfe/semi-ui';
import { copy, showSuccess } from './utils.js';
import { copy, isMobile, showSuccess } from './utils.js';
export function renderText(text, limit) {
if (text.length > limit) {
@@ -67,6 +67,73 @@ export function renderRatio(ratio) {
return <Tag color={color}>{ratio}x {i18next.t('倍率')}</Tag>;
}
const measureTextWidth = (text, style = {
fontSize: '14px',
fontFamily: '-apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif'
}, containerWidth) => {
const span = document.createElement('span');
span.style.visibility = 'hidden';
span.style.position = 'absolute';
span.style.whiteSpace = 'nowrap';
span.style.fontSize = style.fontSize;
span.style.fontFamily = style.fontFamily;
span.textContent = text;
document.body.appendChild(span);
const width = span.offsetWidth;
document.body.removeChild(span);
return width;
};
export function truncateText(text, maxWidth = 200) {
if (!isMobile()) {
return text;
}
if (!text) return text;
try {
// Handle percentage-based maxWidth
let actualMaxWidth = maxWidth;
if (typeof maxWidth === 'string' && maxWidth.endsWith('%')) {
const percentage = parseFloat(maxWidth) / 100;
// Use window width as fallback container width
actualMaxWidth = window.innerWidth * percentage;
}
const width = measureTextWidth(text);
if (width <= actualMaxWidth) return text;
let left = 0;
let right = text.length;
let result = text;
while (left <= right) {
const mid = Math.floor((left + right) / 2);
const truncated = text.slice(0, mid) + '...';
const currentWidth = measureTextWidth(truncated);
if (currentWidth <= actualMaxWidth) {
result = truncated;
left = mid + 1;
} else {
right = mid - 1;
}
}
return result;
} catch (error) {
console.warn('Text measurement failed, falling back to character count', error);
if (text.length > 20) {
return text.slice(0, 17) + '...';
}
return text;
}
}
export const renderGroupOption = (item) => {
const {
disabled,
@@ -386,7 +453,7 @@ export function renderQuotaWithPrompt(quota, digits) {
let displayInCurrency = localStorage.getItem('display_in_currency');
displayInCurrency = displayInCurrency === 'true';
if (displayInCurrency) {
return '|' + i18next.t('等价金额') + ': ' + renderQuota(quota, digits) + '';
return ' | ' + i18next.t('等价金额') + ': ' + renderQuota(quota, digits) + '';
}
return '';
}

File diff suppressed because it is too large Load Diff

View File

@@ -201,7 +201,7 @@
"相关 API 显示令牌额度而非用户额度": "Related APIs show token quota instead of user quota",
"保存通用设置": "Save General Settings",
"监控设置": "Monitoring Settings",
"最长响应时间": "Maximum Response Time",
"测试所有渠道的最长响应时间": "Maximum response time for testing all channels",
"单位秒": "Unit: seconds",
"当运行通道全部测试时": "When running all channel tests",
"超过此时间将自动禁用通道": "Channels exceeding this time will be automatically disabled",
@@ -498,8 +498,7 @@
"请输入用户名": "Please enter username",
"请输入显示名称": "Please enter display name",
"请输入密码": "Please enter password",
"模型部署名称必须和模型名称保持一致": "The model deployment name must be consistent with the model name",
",因为 One API 会把请求体中的 model": ", because One API will take the model in the request body",
"注意,模型部署名称必须和模型名称保持一致": "Note that the model deployment name must be consistent with the model name",
"请输入 AZURE_OPENAI_ENDPOINT": "Please enter AZURE_OPENAI_ENDPOINT",
"请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel",
"Homepage URL 填": "Fill in the Homepage URL",
@@ -1109,7 +1108,7 @@
"如果你对接的是上游One API或者New API等转发项目请使用OpenAI类型不要使用此类型除非你知道你在做什么。": "If you are connecting to upstream One API or New API forwarding projects, please use OpenAI type. Do not use this type unless you know what you are doing.",
"完整的 Base URL支持变量{model}": "Complete Base URL, supports variable {model}",
"请输入完整的URL例如https://api.openai.com/v1/chat/completions": "Please enter complete URL, e.g.: https://api.openai.com/v1/chat/completions",
"此项可选,用于通过代理站来进行 API 调用": "Optional, used for API calls through proxy sites",
"此项可选,用于通过代理站来进行 API 调用,末尾不要带/v1和/": "Optional for API calls through proxy sites, do not end with /v1 and /",
"私有部署地址": "Private Deployment Address",
"请输入私有部署地址格式为https://fastgpt.run/api/openapi": "Please enter private deployment address, format: https://fastgpt.run/api/openapi",
"注意非Chat API请务必填写正确的API地址否则可能导致无法使用": "Note: For non-Chat API, please make sure to enter the correct API address, otherwise it may not work",
@@ -1247,5 +1246,8 @@
"请输入要设置的标签名称": "Please enter the tag name to be set",
"请输入标签名称": "Please enter the tag name",
"支持搜索用户的 ID、用户名、显示名称和邮箱地址": "Support searching for user ID, username, display name, and email address",
"已注销": "Logged out"
"已注销": "Logged out",
"自动禁用关键词": "Automatic disable keywords",
"一行一个,不区分大小写": "One line per keyword, not case-sensitive",
"当上游通道返回错误中包含这些关键词时(不区分大小写),自动禁用通道": "When the upstream channel returns an error containing these keywords (not case-sensitive), automatically disable the channel"
}

Some files were not shown because too many files have changed in this diff Show More