Compare commits

..

33 Commits

Author SHA1 Message Date
1808837298@qq.com
971aea09ee feat: Improve image handling for Ollama channels 2025-02-19 20:45:42 +08:00
1808837298@qq.com
a4b2b9c935 feat: Enhance Ollama channel support with additional request parameters #771 2025-02-19 19:58:34 +08:00
1808837298@qq.com
ae5875d4c7 fix: Remove redundant error handling in distributor and relay modules 2025-02-19 18:47:28 +08:00
1808837298@qq.com
5937d850d9 refactor: Replace manual goroutine creation with gopool.Go 2025-02-19 18:38:29 +08:00
Calcium-Ion
2b7435500c Merge pull request #770 from Calcium-Ion/refactor_notify
feat: Add user notification settings and multiple notification methods
2025-02-19 14:54:54 +07:00
1808837298@qq.com
90191b8d5b chore: update env name and README 2025-02-19 15:54:33 +08:00
1808837298@qq.com
585c19fc70 docs: Add proxy usage information note in SystemSetting component 2025-02-19 15:45:09 +08:00
1808837298@qq.com
4e871507cf feat: Implement comprehensive webhook notification system 2025-02-19 15:40:54 +08:00
1808837298@qq.com
b1847509a4 refactor: Optimize user caching and token retrieval methods 2025-02-19 15:12:26 +08:00
Calcium-Ion
63f3412394 Merge pull request #768 from lgphone/main
bugfix: 配置文件 .env.example 示例配置错误
2025-02-18 19:35:08 +07:00
lgphone
a13bea5ffa Update .env.example
修复示例配置中MySQL的DSN错误问题
2025-02-18 19:18:54 +08:00
Calcium-Ion
2e3b920a2c Merge pull request #763 from Sh1n3zZ/support-imagen-3.0-generate-002
feat: add Gemini Imagen image generation support
2025-02-18 15:32:32 +07:00
1808837298@qq.com
812c188ab1 fix: Extend temperature handling for OpenAI-like models
- Add support for suppressing temperature for o1 models
- Expand model prefix check to include 'o1' alongside 'o3' models
2025-02-18 16:00:56 +08:00
1808837298@qq.com
0907a078b4 refactor: Simplify root user notification and remove global email variable
- Remove global `RootUserEmail` variable
- Modify channel testing and user notification methods to use `GetRootUser()`
- Update user cache and notification service to use more consistent user base type
- Add new channel test notification type
2025-02-18 15:59:17 +08:00
1808837298@qq.com
56f6b2ab56 feat: Implement notification rate limiting mechanism
- Add in-memory and Redis-based notification rate limiting
- Create configurable hourly notification limits
- Implement notification limit checking for user notifications
- Add environment variables for customizing notification limits
2025-02-18 15:30:43 +08:00
1808837298@qq.com
9d9c461c48 refactor: Improve CompletionRatio handling with thread-safe access and initialization 2025-02-18 15:01:43 +08:00
1808837298@qq.com
3da1344897 feat: Add user notification settings with quota warning and multiple notification methods
- Implement user notification settings with email and webhook options
- Add new user settings for quota warning threshold and notification preferences
- Create backend API and database support for user notification configuration
- Enhance frontend personal settings with notification configuration UI
- Support custom notification email and webhook URL
- Add service layer for sending user notifications
2025-02-18 14:54:21 +08:00
Sh1n3zZ
61d2a2f92d feat: add Gemini Imagen image generation support 2025-02-18 01:41:58 +08:00
1808837298@qq.com
995b3a2403 Merge remote-tracking branch 'origin/main' 2025-02-17 18:15:13 +08:00
1808837298@qq.com
7b384cb933 feat: Add support for DeepSeek completions endpoint 2025-02-17 18:15:01 +08:00
Calcium-Ion
78f19d4690 Merge pull request #735 from jyc001/main
feat:Add Supoorts to FIM
2025-02-17 14:37:06 +07:00
1808837298@qq.com
3239c60535 refactor: Optimize channel testing and model menu generation (fix #761) 2025-02-15 19:12:28 +08:00
1808837298@qq.com
e6f4587f6f refactor: Improve channel property update mechanism (fix #761) 2025-02-15 15:30:55 +08:00
Calcium-Ion
814be84500 Merge pull request #759 from nightcoffee/patch-1
feat: add 火山引擎 support stream options
2025-02-15 14:22:04 +07:00
nightcoffee
e7e5a16767 feat: add 火山引擎 support stream options 2025-02-15 04:55:57 +08:00
1808837298@qq.com
6bf99f218c feat: Enhance VolcEngine channel support with bot model routing (fix #757) 2025-02-15 00:10:58 +08:00
1808837298@qq.com
bd4ce9cd91 fix: Improve OpenAI stream data parsing and handling 2025-02-14 23:52:25 +08:00
1808837298@qq.com
9edb9f7a71 feat: Add automatic channel disabling based on configurable keywords
- Introduce AutomaticDisableKeywords setting to dynamically control channel disabling
- Implement AC search for matching error messages against disable keywords
- Add frontend UI for configuring automatic disable keywords
- Update localization with new keyword-based channel disabling feature
- Refactor sensitive word and AC search logic to support multiple keyword lists
2025-02-13 16:39:17 +08:00
e.
206dbfa45e Merge pull request #2 from jyc001/dev
fix: correct JSON tags for `Prompt` and `Suffix` in `GeneralOpenAIReq…
2025-02-08 00:37:37 +08:00
e.
1eb72f2f22 fix: correct JSON tags for Prompt and Suffix in GeneralOpenAIRequest 2025-02-08 00:36:42 +08:00
e.
68bd7f70a4 Merge pull request #1 from jyc001/dev
Dev
2025-02-08 00:25:49 +08:00
e.
8082905184 feat: add Suffix to GeneralOpenAIRequest in order to support FIM 2025-02-08 00:25:08 +08:00
e.
ce4269955e feat add FIM support for siliconflow 2025-02-08 00:23:35 +08:00
58 changed files with 1460 additions and 495 deletions

View File

@@ -10,9 +10,9 @@
# 数据库相关配置
# 数据库连接字符串
# SQL_DSN=mysql://user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
# SQL_DSN=user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
# 日志数据库连接字符串
# LOG_SQL_DSN=mysql://user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
# LOG_SQL_DSN=user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
# SQLite数据库路径
# SQLITE_PATH=/path/to/sqlite.db
# 数据库最大空闲连接数

View File

@@ -89,6 +89,8 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
- `CRYPTO_SECRET`: Encryption key for encrypting database content
- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Duration of notification limit in minutes, default `10`
- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications in the specified duration, default `2`
## Deployment

View File

@@ -95,6 +95,9 @@
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB默认为 `20`
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。
- `AZURE_DEFAULT_API_VERSION`Azure渠道默认API版本如果渠道设置中未指定API版本则使用此版本默认为 `2024-12-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`
## 部署
> [!TIP]

View File

@@ -101,7 +101,7 @@ var PreConsumedQuota = 500
var RetryTimes = 0
var RootUserEmail = ""
//var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"

View File

@@ -3,5 +3,6 @@ package common
var UsingSQLite = false
var UsingPostgreSQL = false
var UsingMySQL = false
var UsingClickHouse = false
var SQLitePath = "one-api.db?_busy_timeout=5000"

View File

@@ -1,22 +1,9 @@
package common
import (
"fmt"
"runtime/debug"
"time"
)
func SafeGoroutine(f func()) {
go func() {
defer func() {
if r := recover(); r != nil {
SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack())))
}
}()
f()
}()
}
func SafeSendBool(ch chan bool, value bool) (closed bool) {
defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed.

View File

@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"io"
"log"
@@ -80,9 +81,9 @@ func logHelper(ctx context.Context, level string, msg string) {
if logCount > maxLogCount && !setupLogWorking {
logCount = 0
setupLogWorking = true
go func() {
gopool.Go(func() {
SetupLogger()
}()
})
}
}
@@ -100,6 +101,14 @@ func LogQuota(quota int) string {
}
}
func FormatQuota(quota int) string {
if DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f", float64(quota)/QuotaPerUnit)
} else {
return fmt.Sprintf("%d", quota)
}
}
// LogJson 仅供测试使用 only for test
func LogJson(ctx context.Context, msg string, obj any) {
jsonStr, err := json.Marshal(obj)

View File

@@ -233,7 +233,11 @@ var (
modelRatioMapMutex = sync.RWMutex{}
)
var CompletionRatio map[string]float64 = nil
var (
CompletionRatio map[string]float64 = nil
CompletionRatioMutex = sync.RWMutex{}
)
var defaultCompletionRatio = map[string]float64{
"gpt-4-gizmo-*": 2,
"gpt-4o-gizmo-*": 3,
@@ -334,10 +338,17 @@ func GetDefaultModelRatioMap() map[string]float64 {
return defaultModelRatio
}
func CompletionRatio2JSONString() string {
func GetCompletionRatioMap() map[string]float64 {
CompletionRatioMutex.Lock()
defer CompletionRatioMutex.Unlock()
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio
}
return CompletionRatio
}
func CompletionRatio2JSONString() string {
GetCompletionRatioMap()
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
SysError("error marshalling completion ratio: " + err.Error())
@@ -346,11 +357,15 @@ func CompletionRatio2JSONString() string {
}
func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatioMutex.Lock()
defer CompletionRatioMutex.Unlock()
CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
}
func GetCompletionRatio(name string) float64 {
GetCompletionRatioMap()
if strings.Contains(name, "/") {
if ratio, ok := CompletionRatio[name]; ok {
return ratio
@@ -476,24 +491,3 @@ func GetAudioCompletionRatio(name string) float64 {
}
return 2
}
//func GetAudioPricePerMinute(name string) float64 {
// if strings.HasPrefix(name, "gpt-4o-realtime") {
// return 0.06
// }
// return 0.06
//}
//
//func GetAudioCompletionPricePerMinute(name string) float64 {
// if strings.HasPrefix(name, "gpt-4o-realtime") {
// return 0.24
// }
// return 0.24
//}
func GetCompletionRatioMap() map[string]float64 {
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio
}
return CompletionRatio
}

View File

@@ -29,6 +29,9 @@ var GeminiModelMap = map[string]string{
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
func InitEnv() {
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
if modelVersionMapStr == "" {
@@ -44,5 +47,5 @@ func InitEnv() {
}
}
// 是否生成初始令牌,默认关闭。
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)

14
constant/user_setting.go Normal file
View File

@@ -0,0 +1,14 @@
package constant
var (
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
)
var (
NotifyTypeEmail = "email" // Email 邮件
NotifyTypeWebhook = "webhook" // Webhook
)

View File

@@ -41,36 +41,34 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
requestPath := "/v1/chat/completions"
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(testModel), "embedding") ||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
strings.Contains(testModel, "bge-") || // bge 系列模型
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
strings.Contains(testModel, "bge-") || // bge 系列模型
testModel == "text-embedding-v1" ||
channel.Type == common.ChannelTypeMokaAI{ // 其他 embedding 模型
requestPath = "/v1/embeddings" // 修改请求路径
channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
requestPath = "/v1/embeddings" // 修改请求路径
}
c.Request = &http.Request{
Method: "POST",
URL: &url.URL{Path: requestPath}, // 使用动态路径
URL: &url.URL{Path: requestPath}, // 使用动态路径
Body: nil,
Header: make(http.Header),
}
if testModel == "" {
common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 是 %s", string(*channel.TestModel)))
if channel.TestModel != nil && *channel.TestModel != "" {
testModel = *channel.TestModel
} else {
if len(channel.GetModels()) > 0 {
testModel = channel.GetModels()[0]
} else {
testModel = "gpt-3.5-turbo"
testModel = "gpt-4o-mini"
}
common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 为空:", string(testModel)))
}
}
@@ -102,7 +100,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
request := buildTestRequest(testModel)
meta.UpstreamModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %s ", channel.Id, testModel, meta))
common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %v ", channel.Id, testModel, meta))
adaptor.Init(meta)
@@ -173,9 +171,9 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(model), "embedding") ||
strings.HasPrefix(model, "m3e") || // m3e 系列模型
strings.Contains(model, "bge-") || // bge 系列模型
model == "text-embedding-v1" { // 其他 embedding 模型
strings.HasPrefix(model, "m3e") || // m3e 系列模型
strings.Contains(model, "bge-") || // bge 系列模型
model == "text-embedding-v1" { // 其他 embedding 模型
// Embedding 请求
testRequest.Input = []string{"hello world"}
return testRequest
@@ -240,9 +238,7 @@ var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
func testAllChannels(notify bool) error {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
testAllChannelsLock.Lock()
if testAllChannelsRunning {
testAllChannelsLock.Unlock()
@@ -297,10 +293,7 @@ func testAllChannels(notify bool) error {
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
if notify {
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
}
})
return nil

View File

@@ -17,7 +17,7 @@ func GetPricing(c *gin.Context) {
}
var group string
if exists {
user, err := model.GetUserById(userId.(int), false)
user, err := model.GetUserCache(userId.(int))
if err == nil {
group = user.Group
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"one-api/common"
"one-api/model"
"one-api/setting"
@@ -471,7 +472,7 @@ func GetUserModels(c *gin.Context) {
if err != nil {
id = c.GetInt("id")
}
user, err := model.GetUserById(id, true)
user, err := model.GetUserCache(id)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -869,9 +870,6 @@ func EmailBind(c *gin.Context) {
})
return
}
if user.Role == common.RoleRootUser {
common.RootUserEmail = email
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -913,3 +911,115 @@ func TopUp(c *gin.Context) {
})
return
}
type UpdateUserSettingRequest struct {
QuotaWarningType string `json:"notify_type"`
QuotaWarningThreshold int `json:"quota_warning_threshold"`
WebhookUrl string `json:"webhook_url,omitempty"`
WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"`
}
func UpdateUserSetting(c *gin.Context) {
var req UpdateUserSettingRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的参数",
})
return
}
// 验证预警类型
if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的预警类型",
})
return
}
// 验证预警阈值
if req.QuotaWarningThreshold <= 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "预警阈值必须大于0",
})
return
}
// 如果是webhook类型,验证webhook地址
if req.QuotaWarningType == constant.NotifyTypeWebhook {
if req.WebhookUrl == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "Webhook地址不能为空",
})
return
}
// 验证URL格式
if _, err := url.ParseRequestURI(req.WebhookUrl); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的Webhook地址",
})
return
}
}
// 如果是邮件类型,验证邮箱地址
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
// 验证邮箱格式
if !strings.Contains(req.NotificationEmail, "@") {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的邮箱地址",
})
return
}
}
userId := c.GetInt("id")
user, err := model.GetUserById(userId, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
// 构建设置
settings := map[string]interface{}{
constant.UserSettingNotifyType: req.QuotaWarningType,
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
}
// 如果是webhook类型,添加webhook相关设置
if req.QuotaWarningType == constant.NotifyTypeWebhook {
settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
if req.WebhookSecret != "" {
settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
}
}
// 如果提供了通知邮箱,添加到设置中
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
}
// 更新用户设置
user.SetSetting(settings)
if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "更新设置失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "设置已更新",
})
}

View File

@@ -24,7 +24,7 @@ services:
- redis
- mysql
healthcheck:
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $$2}'"]
interval: 30s
timeout: 10s
retries: 3

25
dto/notify.go Normal file
View File

@@ -0,0 +1,25 @@
package dto
type Notify struct {
Type string `json:"type"`
Title string `json:"title"`
Content string `json:"content"`
Values []interface{} `json:"values"`
}
const ContentValueParam = "{{value}}"
const (
NotifyTypeQuotaExceed = "quota_exceed"
NotifyTypeChannelUpdate = "channel_update"
NotifyTypeChannelTest = "channel_test"
)
func NewNotify(t string, title string, content string, values []interface{}) Notify {
return Notify{
Type: t,
Title: title,
Content: content,
Values: values,
}
}

View File

@@ -18,6 +18,8 @@ type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Prefix any `json:"prefix,omitempty"`
Suffix any `json:"suffix,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
@@ -86,8 +88,10 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
}
type Message struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
Role string `json:"role"`
Content json.RawMessage `json:"content"`
// parsedContent not json field
parsedContent []MediaContent
Name *string `json:"name,omitempty"`
Prefix *bool `json:"prefix,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
@@ -158,6 +162,11 @@ func (m *Message) SetStringContent(content string) {
m.Content = jsonContent
}
func (m *Message) SetMediaContent(content []MediaContent) {
jsonContent, _ := json.Marshal(content)
m.Content = jsonContent
}
func (m *Message) IsStringContent() bool {
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
@@ -167,7 +176,15 @@ func (m *Message) IsStringContent() bool {
}
func (m *Message) ParseContent() []MediaContent {
if m.parsedContent != nil {
return m.parsedContent
}
var contentList []MediaContent
defer func() {
if len(contentList) > 0 {
m.parsedContent = contentList
}
}()
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
contentList = append(contentList, MediaContent{

View File

@@ -119,9 +119,9 @@ func main() {
}
if os.Getenv("ENABLE_PPROF") == "true" {
go func() {
gopool.Go(func() {
log.Println(http.ListenAndServe("0.0.0.0:8005", nil))
}()
})
go common.Monitor()
common.SysLog("pprof enabled")
}

View File

@@ -135,17 +135,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
midjourneyRequest := dto.MidjourneyRequest{}
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
if err != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
return nil, false, err
}
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
if mjErr != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
return nil, false, fmt.Errorf(mjErr.Description)
}
if midjourneyModel == "" {
if !success {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
} else {
// task fetch, task fetch by condition, notify
@@ -170,7 +167,6 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return nil, false, errors.New("无效的请求, " + err.Error())
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {

View File

@@ -110,6 +110,7 @@ func InitOptionMap() {
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString()
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
@@ -335,6 +336,8 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
case "SensitiveWords":
setting.SensitiveWordsFromString(value)
case "AutomaticDisableKeywords":
setting.AutomaticDisableKeywordsFromString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
}

View File

@@ -3,13 +3,11 @@ package model
import (
"errors"
"fmt"
"one-api/common"
"strings"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common"
relaycommon "one-api/relay/common"
"one-api/setting"
"strconv"
"strings"
)
type Token struct {
@@ -322,80 +320,3 @@ func decreaseTokenQuota(id int, quota int) (err error) {
).Error
return err
}
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if relayInfo.IsPlayground {
return nil
}
//if relayInfo.TokenUnlimited {
// return nil
//}
token, err := GetTokenById(relayInfo.TokenId)
if err != nil {
return err
}
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
return errors.New("令牌额度不足")
}
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
if err != nil {
return err
}
return nil
}
func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
if quota > 0 {
err = DecreaseUserQuota(relayInfo.UserId, quota)
} else {
err = IncreaseUserQuota(relayInfo.UserId, -quota)
}
if err != nil {
return err
}
if !relayInfo.IsPlayground {
if quota > 0 {
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
} else {
err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
}
if err != nil {
return err
}
}
if sendEmail {
if (quota + preConsumedQuota) != 0 {
quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold
noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
if quotaTooLow || noMoreQuota {
go func() {
email, err := GetUserEmail(relayInfo.UserId)
if err != nil {
common.SysError("failed to fetch user email: " + err.Error())
}
prompt := "您的额度即将用尽"
if noMoreQuota {
prompt = "您的额度已用尽"
}
if email != "" {
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
err = common.SendEmail(prompt, email,
fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil {
common.SysError("failed to send email" + err.Error())
}
common.SysLog("user quota is low, consumed quota: " + strconv.Itoa(quota) + ", user quota: " + strconv.Itoa(userQuota))
}
}()
}
}
}
return nil
}

View File

@@ -52,7 +52,7 @@ func cacheSetTokenField(key string, field string, value string) error {
func cacheGetTokenByKey(key string) (*Token, error) {
hmacKey := common.GenerateHMAC(key)
if !common.RedisEnabled {
return nil, nil
return nil, fmt.Errorf("redis is not enabled")
}
var token Token
err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)

View File

@@ -1,6 +1,7 @@
package model
import (
"encoding/json"
"errors"
"fmt"
"one-api/common"
@@ -38,6 +39,20 @@ type User struct {
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
DeletedAt gorm.DeletedAt `gorm:"index"`
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
Setting string `json:"setting" gorm:"type:text;column:setting"`
}
func (user *User) ToBaseUser() *UserBase {
cache := &UserBase{
Id: user.Id,
Group: user.Group,
Quota: user.Quota,
Status: user.Status,
Username: user.Username,
Setting: user.Setting,
Email: user.Email,
}
return cache
}
func (user *User) GetAccessToken() string {
@@ -51,6 +66,22 @@ func (user *User) SetAccessToken(token string) {
user.AccessToken = &token
}
func (user *User) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
}
return common.StrToMap(user.Setting)
}
func (user *User) SetSetting(setting map[string]interface{}) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
return
}
user.Setting = string(settingBytes)
}
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
var user User
@@ -315,8 +346,8 @@ func (user *User) Update(updatePassword bool) error {
return err
}
// 更新缓存
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
// Update cache
return updateUserCache(*user)
}
func (user *User) Edit(updatePassword bool) error {
@@ -344,8 +375,8 @@ func (user *User) Edit(updatePassword bool) error {
return err
}
// 更新缓存
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
// Update cache
return updateUserCache(*user)
}
func (user *User) Delete() error {
@@ -371,8 +402,8 @@ func (user *User) HardDelete() error {
// ValidateAndFill check password & user status
func (user *User) ValidateAndFill() (err error) {
// When querying with struct, GORM will only query with non-zero fields,
// that means if your fields value is 0, '', false or other zero values,
// it wont be used to build query conditions
// that means if your field's value is 0, '', false or other zero values,
// it won't be used to build query conditions
password := user.Password
username := strings.TrimSpace(user.Username)
if username == "" || password == "" {
@@ -531,7 +562,6 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
return quota, nil
}
// Don't return error - fall through to DB
//common.SysError("failed to get user quota from cache: " + err.Error())
}
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
@@ -580,6 +610,35 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
return group, nil
}
// GetUserSetting gets setting from Redis first, falls back to DB if needed
func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
var setting string
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserSettingCache(id, setting); err != nil {
common.SysError("failed to update user setting cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
setting, err := getUserSettingCache(id)
if err == nil {
return setting, nil
}
// Don't return error - fall through to DB
}
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
if err != nil {
return map[string]interface{}{}, err
}
return common.StrToMap(setting), nil
}
func IncreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
@@ -641,9 +700,14 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
}
}
func GetRootUserEmail() (email string) {
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
return email
//func GetRootUserEmail() (email string) {
// DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
// return email
//}
func GetRootUser() (user *User) {
DB.Where("role = ?", common.RoleRootUser).First(&user)
return user
}
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
@@ -725,10 +789,10 @@ func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
return !errors.Is(err, gorm.ErrRecordNotFound)
}
func (u *User) FillUserByLinuxDOId() error {
if u.LinuxDOId == "" {
func (user *User) FillUserByLinuxDOId() error {
if user.LinuxDOId == "" {
return errors.New("linux do id is empty")
}
err := DB.Where("linux_do_id = ?", u.LinuxDOId).First(u).Error
err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
return err
}

View File

@@ -1,206 +1,213 @@
package model
import (
"encoding/json"
"fmt"
"one-api/common"
"one-api/constant"
"strconv"
"time"
"github.com/bytedance/gopkg/util/gopool"
)
// Change UserCache struct to userCache
type userCache struct {
// UserBase struct remains the same as it represents the cached data structure
type UserBase struct {
Id int `json:"id"`
Group string `json:"group"`
Email string `json:"email"`
Quota int `json:"quota"`
Status int `json:"status"`
Role int `json:"role"`
Username string `json:"username"`
Setting string `json:"setting"`
}
// Rename all exported functions to private ones
// invalidateUserCache clears all user related cache
func (user *UserBase) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
}
return common.StrToMap(user.Setting)
}
func (user *UserBase) SetSetting(setting map[string]interface{}) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
return
}
user.Setting = string(settingBytes)
}
// getUserCacheKey returns the key for user cache
func getUserCacheKey(userId int) string {
return fmt.Sprintf("user:%d", userId)
}
// invalidateUserCache clears user cache
func invalidateUserCache(userId int) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHDelObj(getUserCacheKey(userId))
}
keys := []string{
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
// updateUserCache updates all user cache fields using hash
func updateUserCache(user User) error {
if !common.RedisEnabled {
return nil
}
for _, key := range keys {
if err := common.RedisDel(key); err != nil {
return fmt.Errorf("failed to delete cache key %s: %w", key, err)
return common.RedisHSetObj(
getUserCacheKey(user.Id),
user.ToBaseUser(),
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// GetUserCache gets complete user cache from hash
func GetUserCache(userId int) (userCache *UserBase, err error) {
var user *User
var fromDB bool
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) && user != nil {
gopool.Go(func() {
if err := updateUserCache(*user); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
}
})
}
}
return nil
}
}()
// updateUserGroupCache updates user group cache
func updateUserGroupCache(userId int, group string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
group,
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserQuotaCache updates user quota cache
func updateUserQuotaCache(userId int, quota int) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
fmt.Sprintf("%d", quota),
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserStatusCache updates user status cache
func updateUserStatusCache(userId int, userEnabled bool) error {
if !common.RedisEnabled {
return nil
}
enabled := "0"
if userEnabled {
enabled = "1"
}
return common.RedisSet(
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
enabled,
time.Duration(constant.UserId2StatusCacheSeconds)*time.Second,
)
}
// updateUserNameCache updates username cache
func updateUserNameCache(userId int, username string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
username,
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserCache updates all user cache fields
func updateUserCache(userId int, username string, userGroup string, quota int, status int) error {
if !common.RedisEnabled {
return nil
// Try getting from Redis first
userCache, err = cacheGetUserBase(userId)
if err == nil {
return userCache, nil
}
if err := updateUserGroupCache(userId, userGroup); err != nil {
return fmt.Errorf("update group cache: %w", err)
}
if err := updateUserQuotaCache(userId, quota); err != nil {
return fmt.Errorf("update quota cache: %w", err)
}
if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil {
return fmt.Errorf("update status cache: %w", err)
}
if err := updateUserNameCache(userId, username); err != nil {
return fmt.Errorf("update username cache: %w", err)
}
return nil
}
// getUserGroupCache gets user group from cache
func getUserGroupCache(userId int) (string, error) {
if !common.RedisEnabled {
return "", nil
}
return common.RedisGet(fmt.Sprintf(constant.UserGroupKeyFmt, userId))
}
// getUserQuotaCache gets user quota from cache
func getUserQuotaCache(userId int) (int, error) {
if !common.RedisEnabled {
return 0, nil
}
quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId))
// If Redis fails, get from DB
fromDB = true
user, err = GetUserById(userId, false)
if err != nil {
return 0, err
return nil, err // Return nil and error if DB lookup fails
}
return strconv.Atoi(quotaStr)
// Create cache object from user data
userCache = &UserBase{
Id: user.Id,
Group: user.Group,
Quota: user.Quota,
Status: user.Status,
Username: user.Username,
Setting: user.Setting,
Email: user.Email,
}
return userCache, nil
}
// getUserStatusCache gets user status from cache
func getUserStatusCache(userId int) (int, error) {
func cacheGetUserBase(userId int) (*UserBase, error) {
if !common.RedisEnabled {
return 0, nil
return nil, fmt.Errorf("redis is not enabled")
}
statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId))
var userCache UserBase
// Try getting from Redis first
err := common.RedisHGetObj(getUserCacheKey(userId), &userCache)
if err != nil {
return 0, err
return nil, err
}
return strconv.Atoi(statusStr)
return &userCache, nil
}
// getUserNameCache gets username from cache
func getUserNameCache(userId int) (string, error) {
if !common.RedisEnabled {
return "", nil
}
return common.RedisGet(fmt.Sprintf(constant.UserUsernameKeyFmt, userId))
}
// getUserCache gets complete user cache
func getUserCache(userId int) (*userCache, error) {
if !common.RedisEnabled {
return nil, nil
}
group, err := getUserGroupCache(userId)
if err != nil {
return nil, fmt.Errorf("get group cache: %w", err)
}
quota, err := getUserQuotaCache(userId)
if err != nil {
return nil, fmt.Errorf("get quota cache: %w", err)
}
status, err := getUserStatusCache(userId)
if err != nil {
return nil, fmt.Errorf("get status cache: %w", err)
}
username, err := getUserNameCache(userId)
if err != nil {
return nil, fmt.Errorf("get username cache: %w", err)
}
return &userCache{
Id: userId,
Group: group,
Quota: quota,
Status: status,
Username: username,
}, nil
}
// Add atomic quota operations
// Add atomic quota operations using hash fields
func cacheIncrUserQuota(userId int, delta int64) error {
if !common.RedisEnabled {
return nil
}
key := fmt.Sprintf(constant.UserQuotaKeyFmt, userId)
return common.RedisIncr(key, delta)
return common.RedisHIncrBy(getUserCacheKey(userId), "Quota", delta)
}
func cacheDecrUserQuota(userId int, delta int64) error {
return cacheIncrUserQuota(userId, -delta)
}
// Helper functions to get individual fields if needed
func getUserGroupCache(userId int) (string, error) {
cache, err := GetUserCache(userId)
if err != nil {
return "", err
}
return cache.Group, nil
}
func getUserQuotaCache(userId int) (int, error) {
cache, err := GetUserCache(userId)
if err != nil {
return 0, err
}
return cache.Quota, nil
}
func getUserStatusCache(userId int) (int, error) {
cache, err := GetUserCache(userId)
if err != nil {
return 0, err
}
return cache.Status, nil
}
func getUserNameCache(userId int) (string, error) {
cache, err := GetUserCache(userId)
if err != nil {
return "", err
}
return cache.Username, nil
}
func getUserSettingCache(userId int) (map[string]interface{}, error) {
setting := make(map[string]interface{})
cache, err := GetUserCache(userId)
if err != nil {
return setting, err
}
return cache.GetSetting(), nil
}
// New functions for individual field updates
func updateUserStatusCache(userId int, status bool) error {
if !common.RedisEnabled {
return nil
}
statusInt := common.UserStatusEnabled
if !status {
statusInt = common.UserStatusDisabled
}
return common.RedisHSetField(getUserCacheKey(userId), "Status", fmt.Sprintf("%d", statusInt))
}
func updateUserQuotaCache(userId int, quota int) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Quota", fmt.Sprintf("%d", quota))
}
func updateUserGroupCache(userId int, group string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Group", group)
}
func updateUserNameCache(userId int, username string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Username", username)
}
func updateUserSettingCache(userId int, setting string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Setting", setting)
}

View File

@@ -4,13 +4,14 @@ import (
"bytes"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"github.com/gin-gonic/gin"
)
type Adaptor struct {

View File

@@ -10,6 +10,7 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
@@ -29,7 +30,12 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
switch info.RelayMode {
case constant.RelayModeCompletions:
return fmt.Sprintf("%s/beta/completions", info.BaseUrl), nil
default:
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -1,15 +1,21 @@
package gemini
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -21,8 +27,36 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
if !strings.HasPrefix(info.UpstreamModelName, "imagen") {
return nil, errors.New("not supported model for image generation")
}
// convert size to aspect ratio
aspectRatio := "1:1" // default aspect ratio
switch request.Size {
case "1024x1024":
aspectRatio = "1:1"
case "1024x1792":
aspectRatio = "9:16"
case "1792x1024":
aspectRatio = "16:9"
}
// build gemini imagen request
geminiRequest := GeminiImageRequest{
Instances: []GeminiImageInstance{
{
Prompt: request.Prompt,
},
},
Parameters: GeminiImageParameters{
SampleCount: request.N,
AspectRatio: aspectRatio,
PersonGeneration: "allow_adult", // default allow adult
},
}
return geminiRequest, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -40,6 +74,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
}
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
}
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent?alt=sse"
@@ -73,12 +111,15 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return GeminiImageHandler(c, resp, info)
}
if info.IsStream {
err, usage = GeminiChatStreamHandler(c, resp, info)
} else {
@@ -87,6 +128,60 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
return
}
func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
}
_ = resp.Body.Close()
var geminiResponse GeminiImageResponse
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
if len(geminiResponse.Predictions) == 0 {
return nil, service.OpenAIErrorWrapper(errors.New("no images generated"), "no_images", http.StatusBadRequest)
}
// convert to openai format response
openAIResponse := dto.ImageResponse{
Created: common.GetTimestamp(),
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
}
for _, prediction := range geminiResponse.Predictions {
if prediction.RaiFilteredReason != "" {
continue // skip filtered image
}
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
jsonResponse, jsonErr := json.Marshal(openAIResponse)
if jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
// each image has fixed 258 tokens
const imageTokens = 258
generatedImages := len(openAIResponse.Data)
usage = &dto.Usage{
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
CompletionTokens: 0, // image generation does not calculate completion tokens
TotalTokens: imageTokens * generatedImages,
}
return usage, nil
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}

View File

@@ -16,6 +16,8 @@ var ModelList = []string{
"gemini-2.0-pro-exp",
// thinking exp
"gemini-2.0-flash-thinking-exp",
// imagen models
"imagen-3.0-generate-002",
}
var ChannelName = "google gemini"

View File

@@ -109,3 +109,30 @@ type GeminiUsageMetadata struct {
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}
// Imagen related structs
type GeminiImageRequest struct {
Instances []GeminiImageInstance `json:"instances"`
Parameters GeminiImageParameters `json:"parameters"`
}
type GeminiImageInstance struct {
Prompt string `json:"prompt"`
}
type GeminiImageParameters struct {
SampleCount int `json:"sampleCount,omitempty"`
AspectRatio string `json:"aspectRatio,omitempty"`
PersonGeneration string `json:"personGeneration,omitempty"`
}
type GeminiImageResponse struct {
Predictions []GeminiImagePrediction `json:"predictions"`
}
type GeminiImagePrediction struct {
MimeType string `json:"mimeType"`
BytesBase64Encoded string `json:"bytesBase64Encoded"`
RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
SafetyAttributes any `json:"safetyAttributes,omitempty"`
}

View File

@@ -46,7 +46,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if request == nil {
return nil, errors.New("request is nil")
}
return requestOpenAI2Ollama(*request), nil
return requestOpenAI2Ollama(*request)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {

View File

@@ -3,18 +3,21 @@ package ollama
import "one-api/dto"
type OllamaRequest struct {
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Suffix any `json:"suffix,omitempty"`
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
Prompt any `json:"prompt,omitempty"`
}
type Options struct {
@@ -35,7 +38,7 @@ type OllamaEmbeddingRequest struct {
}
type OllamaEmbeddingResponse struct {
Error string `json:"error,omitempty"`
Model string `json:"model"`
Error string `json:"error,omitempty"`
Model string `json:"model"`
Embedding [][]float64 `json:"embeddings,omitempty"`
}

View File

@@ -9,14 +9,36 @@ import (
"net/http"
"one-api/dto"
"one-api/service"
"strings"
)
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
if !message.IsStringContent() {
mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
// check if not base64
if strings.HasPrefix(imageUrl.Url, "http") {
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
if err != nil {
return nil, err
}
imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
}
mediaMessage.ImageUrl = imageUrl
mediaMessages[j] = mediaMessage
}
}
message.SetMediaContent(mediaMessages)
}
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,
Role: message.Role,
Content: message.Content,
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,
})
}
str, ok := request.Stop.(string)
@@ -39,7 +61,10 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
ResponseFormat: request.ResponseFormat,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
}
Prompt: request.Prompt,
StreamOptions: request.StreamOptions,
Suffix: request.Suffix,
}, nil
}
func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {

View File

@@ -119,7 +119,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0
}
if strings.HasPrefix(request.Model, "o3") {
if strings.HasPrefix(request.Model, "o3") || strings.HasPrefix(request.Model, "o1") {
request.Temperature = nil
}
if strings.HasSuffix(request.Model, "-high") {

View File

@@ -5,6 +5,9 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"io"
"math"
@@ -20,10 +23,6 @@ import (
"strings"
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
func sendStreamData(c *gin.Context, data string, forceFormat bool) error {
@@ -91,11 +90,12 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
if len(data) < 6 { // ignore blank line or wrong format
continue
}
if data[:6] != "data: " && data[:6] != "[DONE]" {
if data[:5] != "data:" && data[:6] != "[DONE]" {
continue
}
mu.Lock()
data = data[6:]
data = data[5:]
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "[DONE]") {
if lastStreamData != "" {
err := sendStreamData(c, lastStreamData, forceFormat)

View File

@@ -36,6 +36,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeChatCompletions {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeCompletions {
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
}
return "", errors.New("invalid relay mode")
}
@@ -72,6 +74,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
case constant.RelayModeCompletions:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -11,6 +11,7 @@ import (
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"strings"
)
type Adaptor struct {
@@ -32,6 +33,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
if strings.HasPrefix(info.UpstreamModelName, "bot") {
return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.BaseUrl), nil
}
return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil

View File

@@ -90,8 +90,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
mediaMessages[j] = mediaMessage
}
}
messageRaw, _ := json.Marshal(mediaMessages)
message.Content = messageRaw
message.SetMediaContent(mediaMessages)
}
messages = append(messages, dto.Message{
Role: message.Role,

View File

@@ -112,7 +112,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
}
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure {
info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure ||
info.ChannelType == common.ChannelTypeVolcEngine || info.ChannelType == common.ChannelTypeOllama {
info.SupportStreamOptions = true
}
return info

View File

@@ -194,7 +194,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
}
defer func(ctx context.Context) {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
@@ -500,7 +500,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
defer func(ctx context.Context) {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"io"
"math"
"net/http"
@@ -272,7 +273,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
if userQuota-preConsumedQuota < 0 {
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest)
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
}
if userQuota > 100*preConsumedQuota {
// 用户额度充足,判断令牌额度是否充足
@@ -282,18 +283,18 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if tokenQuota > 100*preConsumedQuota {
// 令牌额度充足,信任令牌
preConsumedQuota = 0
common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
}
} else {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota)))
}
}
if preConsumedQuota > 0 {
err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
@@ -307,14 +308,14 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
if preConsumedQuota != 0 {
go func() {
gopool.Go(func() {
relayInfoCopy := *relayInfo
err := model.PostConsumeQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false)
err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
if err != nil {
common.SysError("error return pre-consumed quota: " + err.Error())
}
}()
})
}
}
@@ -368,7 +369,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
//}
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}

View File

@@ -113,7 +113,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
// release quota
if relayInfo.ConsumeQuota && taskErr == nil {
err := model.PostConsumeQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}

View File

@@ -56,6 +56,7 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.POST("/pay", controller.RequestEpay)
selfRoute.POST("/amount", controller.RequestAmount)
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
selfRoute.PUT("/setting", controller.UpdateUserSetting)
}
adminRoute := userRoute.Group("/")

View File

@@ -2,6 +2,7 @@ package service
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"one-api/common"
@@ -9,19 +10,46 @@ import (
"strings"
)
// WorkerRequest Worker请求的数据结构
type WorkerRequest struct {
URL string `json:"url"`
Key string `json:"key"`
Method string `json:"method,omitempty"`
Headers map[string]string `json:"headers,omitempty"`
Body json.RawMessage `json:"body,omitempty"`
}
// DoWorkerRequest 通过Worker发送请求
func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
if !setting.EnableWorker() {
return nil, fmt.Errorf("worker not enabled")
}
if !strings.HasPrefix(req.URL, "https") {
return nil, fmt.Errorf("only support https url")
}
workerUrl := setting.WorkerUrl
if !strings.HasSuffix(workerUrl, "/") {
workerUrl += "/"
}
// 序列化worker请求数据
workerPayload, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal worker payload: %v", err)
}
return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
}
func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
if setting.EnableWorker() {
common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
if !strings.HasPrefix(originUrl, "https") {
return nil, fmt.Errorf("only support https url")
req := &WorkerRequest{
URL: originUrl,
Key: setting.WorkerValidKey,
}
workerUrl := setting.WorkerUrl
if !strings.HasSuffix(workerUrl, "/") {
workerUrl += "/"
}
// post request to worker
data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`)
return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data))
return DoWorkerRequest(req)
} else {
common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
return http.Get(originUrl)

View File

@@ -4,8 +4,9 @@ import (
"fmt"
"net/http"
"one-api/common"
relaymodel "one-api/dto"
"one-api/dto"
"one-api/model"
"one-api/setting"
"strings"
)
@@ -14,17 +15,17 @@ func DisableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
notifyRootUser(subject, content)
NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
}
func EnableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
notifyRootUser(subject, content)
NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
}
func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatusCode) bool {
func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool {
if !common.AutomaticDisableChannelEnabled {
return false
}
@@ -64,28 +65,17 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
case "forbidden":
return true
}
if strings.HasPrefix(err.Error.Message, "Your credit balance is too low") { // anthropic
return true
} else if strings.HasPrefix(err.Error.Message, "This organization has been disabled.") {
return true
} else if strings.HasPrefix(err.Error.Message, "You exceeded your current quota") {
return true
} else if strings.HasPrefix(err.Error.Message, "Permission denied") {
return true
}
if strings.Contains(err.Error.Message, "The security token included in the request is invalid") { // anthropic
return true
} else if strings.Contains(err.Error.Message, "Operation not allowed") {
return true
} else if strings.Contains(err.Error.Message, "Your account is not authorized") {
lowerMessage := strings.ToLower(err.Error.Message)
search, _ := AcSearch(lowerMessage, setting.AutomaticDisableKeywords, true)
if search {
return true
}
return false
}
func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool {
func ShouldEnableChannel(err error, openaiWithStatusErr *dto.OpenAIErrorWithStatusCode, status int) bool {
if !common.AutomaticEnableChannelEnabled {
return false
}

117
service/notify-limit.go Normal file
View File

@@ -0,0 +1,117 @@
package service
import (
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"one-api/common"
"one-api/constant"
"strconv"
"sync"
"time"
)
// notifyLimitStore is used for in-memory rate limiting when Redis is disabled
var (
notifyLimitStore sync.Map
cleanupOnce sync.Once
)
type limitCount struct {
Count int
Timestamp time.Time
}
func getDuration() time.Duration {
minute := constant.NotificationLimitDurationMinute
return time.Duration(minute) * time.Minute
}
// startCleanupTask starts a background task to clean up expired entries
func startCleanupTask() {
gopool.Go(func() {
for {
time.Sleep(time.Hour)
now := time.Now()
notifyLimitStore.Range(func(key, value interface{}) bool {
if limit, ok := value.(limitCount); ok {
if now.Sub(limit.Timestamp) >= getDuration() {
notifyLimitStore.Delete(key)
}
}
return true
})
}
})
}
// CheckNotificationLimit checks if the user has exceeded their notification limit
// Returns true if the user can send notification, false if limit exceeded
func CheckNotificationLimit(userId int, notifyType string) (bool, error) {
if common.RedisEnabled {
return checkRedisLimit(userId, notifyType)
}
return checkMemoryLimit(userId, notifyType)
}
func checkRedisLimit(userId int, notifyType string) (bool, error) {
key := fmt.Sprintf("notify_limit:%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
// Get current count
count, err := common.RedisGet(key)
if err != nil && err.Error() != "redis: nil" {
return false, fmt.Errorf("failed to get notification count: %w", err)
}
// If key doesn't exist, initialize it
if count == "" {
err = common.RedisSet(key, "1", getDuration())
return true, err
}
currentCount, _ := strconv.Atoi(count)
limit := constant.NotifyLimitCount
// Check if limit is already reached
if currentCount >= limit {
return false, nil
}
// Only increment if under limit
err = common.RedisIncr(key, 1)
if err != nil {
return false, fmt.Errorf("failed to increment notification count: %w", err)
}
return true, nil
}
func checkMemoryLimit(userId int, notifyType string) (bool, error) {
// Ensure cleanup task is started
cleanupOnce.Do(startCleanupTask)
key := fmt.Sprintf("%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
now := time.Now()
// Get current limit count or initialize new one
var currentLimit limitCount
if value, ok := notifyLimitStore.Load(key); ok {
currentLimit = value.(limitCount)
// Check if the entry has expired
if now.Sub(currentLimit.Timestamp) >= getDuration() {
currentLimit = limitCount{Count: 0, Timestamp: now}
}
} else {
currentLimit = limitCount{Count: 0, Timestamp: now}
}
// Increment count
currentLimit.Count++
// Check against limits
limit := constant.NotifyLimitCount
// Store updated count
notifyLimitStore.Store(key, currentLimit)
return currentLimit.Count <= limit, nil
}

View File

@@ -3,8 +3,10 @@ package service
import (
"errors"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"math"
"one-api/common"
constant2 "one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
@@ -99,7 +101,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota))
}
err = model.PostConsumeQuota(relayInfo, 0, quota, 0, false)
err = PostConsumeQuota(relayInfo, quota, 0, false)
if err != nil {
return err
}
@@ -222,7 +224,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
} else {
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
@@ -239,3 +241,88 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if relayInfo.IsPlayground {
return nil
}
//if relayInfo.TokenUnlimited {
// return nil
//}
token, err := model.GetTokenByKey(relayInfo.TokenKey, false)
if err != nil {
return err
}
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
return errors.New("令牌额度不足")
}
err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
if err != nil {
return err
}
return nil
}
func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int, sendEmail bool) (err error) {
if quota > 0 {
err = model.DecreaseUserQuota(relayInfo.UserId, quota)
} else {
err = model.IncreaseUserQuota(relayInfo.UserId, -quota)
}
if err != nil {
return err
}
if !relayInfo.IsPlayground {
if quota > 0 {
err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
} else {
err = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
}
if err != nil {
return err
}
}
if sendEmail {
if (quota + preConsumedQuota) != 0 {
checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota)
}
}
return nil
}
func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
gopool.Go(func() {
userCache, err := model.GetUserCache(userId)
if err != nil {
common.SysError("failed to get user cache: " + err.Error())
}
userSetting := userCache.GetSetting()
threshold := common.QuotaRemindThreshold
if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
threshold = int(userCustomThreshold.(float64))
}
//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
quotaTooLow := false
consumeQuota := quota + preConsumedQuota
if userCache.Quota-consumeQuota < threshold {
quotaTooLow = true
}
if quotaTooLow {
prompt := "您的额度即将用尽"
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink}))
if err != nil {
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error()))
}
}
})
}

View File

@@ -60,17 +60,7 @@ func SensitiveWordContains(text string) (bool, []string) {
return false, nil
}
checkText := strings.ToLower(text)
// 构建一个AC自动机
m := InitAc()
hits := m.MultiPatternSearch([]rune(checkText), false)
if len(hits) > 0 {
words := make([]string, 0)
for _, hit := range hits {
words = append(words, string(hit.Word))
}
return true, words
}
return false, nil
return AcSearch(checkText, setting.SensitiveWords, false)
}
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
@@ -79,7 +69,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
return false, nil, text
}
checkText := strings.ToLower(text)
m := InitAc()
m := InitAc(setting.SensitiveWords)
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
if len(hits) > 0 {
words := make([]string, 0)

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"fmt"
goahocorasick "github.com/anknown/ahocorasick"
"one-api/setting"
"strings"
)
@@ -57,9 +56,9 @@ func RemoveDuplicate(s []string) []string {
return result
}
func InitAc() *goahocorasick.Machine {
func InitAc(words []string) *goahocorasick.Machine {
m := new(goahocorasick.Machine)
dict := readRunes()
dict := readRunes(words)
if err := m.Build(dict); err != nil {
fmt.Println(err)
return nil
@@ -67,10 +66,10 @@ func InitAc() *goahocorasick.Machine {
return m
}
func readRunes() [][]rune {
func readRunes(words []string) [][]rune {
var dict [][]rune
for _, word := range setting.SensitiveWords {
for _, word := range words {
word = strings.ToLower(word)
l := bytes.TrimSpace([]byte(word))
dict = append(dict, bytes.Runes(l))
@@ -78,3 +77,25 @@ func readRunes() [][]rune {
return dict
}
func AcSearch(findText string, dict []string, stopImmediately bool) (bool, []string) {
if len(dict) == 0 {
return false, nil
}
if len(findText) == 0 {
return false, nil
}
m := InitAc(dict)
if m == nil {
return false, nil
}
hits := m.MultiPatternSearch([]rune(findText), stopImmediately)
if len(hits) > 0 {
words := make([]string, 0)
for _, hit := range hits {
words = append(words, string(hit.Word))
}
return true, words
}
return false, nil
}

View File

@@ -3,15 +3,75 @@ package service
import (
"fmt"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
"strings"
)
func notifyRootUser(subject string, content string) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
func NotifyRootUser(t string, subject string, content string) {
user := model.GetRootUser().ToBaseUser()
_ = NotifyUser(user, dto.NewNotify(t, subject, content, nil))
}
func NotifyUser(user *model.UserBase, data dto.Notify) error {
userSetting := user.GetSetting()
notifyType, ok := userSetting[constant.UserSettingNotifyType]
if !ok {
notifyType = constant.NotifyTypeEmail
}
// Check notification limit
canSend, err := CheckNotificationLimit(user.Id, data.Type)
if err != nil {
common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
return err
}
if !canSend {
return fmt.Errorf("notification limit exceeded for user %d with type %s", user.Id, notifyType)
}
switch notifyType {
case constant.NotifyTypeEmail:
userEmail := user.Email
// check setting email
if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
userEmail = settingEmail.(string)
}
if userEmail == "" {
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", user.Id))
return nil
}
return sendEmailNotify(userEmail, data)
case constant.NotifyTypeWebhook:
webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
if !ok {
common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", user.Id))
return nil
}
webhookURLStr, ok := webhookURL.(string)
if !ok {
common.SysError(fmt.Sprintf("user %d webhook url is not string type", user.Id))
return nil
}
// 获取 webhook secret
var webhookSecret string
if secret, ok := userSetting[constant.UserSettingWebhookSecret]; ok {
webhookSecret, _ = secret.(string)
}
return SendWebhookNotify(webhookURLStr, webhookSecret, data)
}
return nil
}
func sendEmailNotify(userEmail string, data dto.Notify) error {
// make email content
content := data.Content
// 处理占位符
for _, value := range data.Values {
content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1)
}
return common.SendEmail(data.Title, userEmail, content)
}

118
service/webhook.go Normal file
View File

@@ -0,0 +1,118 @@
package service
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"one-api/dto"
"one-api/setting"
"time"
)
// WebhookPayload webhook 通知的负载数据
type WebhookPayload struct {
Type string `json:"type"`
Title string `json:"title"`
Content string `json:"content"`
Values []interface{} `json:"values,omitempty"`
Timestamp int64 `json:"timestamp"`
}
// generateSignature 生成 webhook 签名
func generateSignature(secret string, payload []byte) string {
h := hmac.New(sha256.New, []byte(secret))
h.Write(payload)
return hex.EncodeToString(h.Sum(nil))
}
// SendWebhookNotify 发送 webhook 通知
func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error {
// 处理占位符
content := data.Content
for _, value := range data.Values {
content = fmt.Sprintf(content, value)
}
// 构建 webhook 负载
payload := WebhookPayload{
Type: data.Type,
Title: data.Title,
Content: content,
Values: data.Values,
Timestamp: time.Now().Unix(),
}
// 序列化负载
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal webhook payload: %v", err)
}
// 创建 HTTP 请求
var req *http.Request
var resp *http.Response
if setting.EnableWorker() {
// 构建worker请求数据
workerReq := &WorkerRequest{
URL: webhookURL,
Key: setting.WorkerValidKey,
Method: http.MethodPost,
Headers: map[string]string{
"Content-Type": "application/json",
},
Body: payloadBytes,
}
// 如果有secret添加签名到headers
if secret != "" {
signature := generateSignature(secret, payloadBytes)
workerReq.Headers["X-Webhook-Signature"] = signature
workerReq.Headers["Authorization"] = "Bearer " + secret
}
resp, err = DoWorkerRequest(workerReq)
if err != nil {
return fmt.Errorf("failed to send webhook request through worker: %v", err)
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
}
} else {
req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes))
if err != nil {
return fmt.Errorf("failed to create webhook request: %v", err)
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
// 如果有 secret生成签名
if secret != "" {
signature := generateSignature(secret, payloadBytes)
req.Header.Set("X-Webhook-Signature", signature)
}
// 发送请求
client := GetImpatientHttpClient()
resp, err = client.Do(req)
if err != nil {
return fmt.Errorf("failed to send webhook request: %v", err)
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
}
}
return nil
}

View File

@@ -1,3 +1,30 @@
package setting
import "strings"
var DemoSiteEnabled = false
var AutomaticDisableKeywords = []string{
"Your credit balance is too low",
"This organization has been disabled.",
"You exceeded your current quota",
"Permission denied",
"The security token included in the request is invalid",
"Operation not allowed",
"Your account is not authorized",
}
func AutomaticDisableKeywordsToString() string {
return strings.Join(AutomaticDisableKeywords, "\n")
}
func AutomaticDisableKeywordsFromString(s string) {
AutomaticDisableKeywords = []string{}
ak := strings.Split(s, "\n")
for _, k := range ak {
k = strings.TrimSpace(k)
if k != "" {
AutomaticDisableKeywords = append(AutomaticDisableKeywords, k)
}
}
}

View File

@@ -357,6 +357,13 @@ const ChannelsTable = () => {
dataIndex: 'operate',
render: (text, record, index) => {
if (record.children === undefined) {
// 构建模型测试菜单
const modelMenuItems = record.models.split(',').map(model => ({
node: 'item',
name: model,
onClick: () => testChannel(record, model)
}));
return (
<div>
<SplitButtonGroup
@@ -374,7 +381,7 @@ const ChannelsTable = () => {
<Dropdown
trigger="click"
position="bottomRight"
menu={record.test_models}
menu={modelMenuItems} // 使用即时生成的菜单项
>
<Button
style={{ padding: '8px 4px' }}
@@ -545,17 +552,6 @@ const ChannelsTable = () => {
let channelTags = {};
for (let i = 0; i < channels.length; i++) {
channels[i].key = '' + channels[i].id;
let test_models = [];
channels[i].models.split(',').forEach((item, index) => {
test_models.push({
node: 'item',
name: item,
onClick: () => {
testChannel(channels[i], item);
}
});
});
channels[i].test_models = test_models;
if (!enableTagMode) {
channelDates.push(channels[i]);
} else {
@@ -798,16 +794,59 @@ const ChannelsTable = () => {
setSearching(false);
};
const updateChannelProperty = (channelId, updateFn) => {
// Create a new copy of channels array
const newChannels = [...channels];
let updated = false;
// Find and update the correct channel
newChannels.forEach(channel => {
if (channel.children !== undefined) {
// If this is a tag group, search in its children
channel.children.forEach(child => {
if (child.id === channelId) {
updateFn(child);
updated = true;
}
});
} else if (channel.id === channelId) {
// Direct channel match
updateFn(channel);
updated = true;
}
});
// Only update state if we actually modified a channel
if (updated) {
setChannels(newChannels);
}
};
const testChannel = async (record, model) => {
const res = await API.get(`/api/channel/test/${record.id}?model=${model}`);
const { success, message, time } = res.data;
if (success) {
record.response_time = time * 1000;
record.test_time = Date.now() / 1000;
// Also update the channels state to persist the change
updateChannelProperty(record.id, (channel) => {
channel.response_time = time * 1000;
channel.test_time = Date.now() / 1000;
});
showInfo(t('通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。').replace('${name}', record.name).replace('${time.toFixed(2)}', time.toFixed(2)));
} else {
showError(message);
}
};
// 刷新列表
await refresh();
const updateChannelBalance = async (record) => {
const res = await API.get(`/api/channel/update_balance/${record.id}/`);
const { success, message, balance } = res.data;
if (success) {
updateChannelProperty(record.id, (channel) => {
channel.balance = balance;
channel.balance_updated_time = Date.now() / 1000;
});
showInfo(t('通道 ${name} 余额更新成功!').replace('${name}', record.name));
} else {
showError(message);
}
@@ -834,20 +873,6 @@ const ChannelsTable = () => {
}
};
const updateChannelBalance = async (record) => {
const res = await API.get(`/api/channel/update_balance/${record.id}/`);
const { success, message, balance } = res.data;
if (success) {
record.balance = balance;
record.balance_updated_time = Date.now() / 1000;
showInfo(t('通道 ${name} 余额更新成功!').replace('${name}', record.name));
// 刷新列表
await refresh();
} else {
showError(message);
}
};
const updateAllChannelsBalance = async () => {
setUpdatingBalance(true);
const res = await API.get(`/api/channel/update_balance`);

View File

@@ -59,6 +59,7 @@ const OperationSetting = () => {
RetryTimes: 0,
Chats: "[]",
DemoSiteEnabled: false,
AutomaticDisableKeywords: '',
});
let [loading, setLoading] = useState(false);

View File

@@ -26,6 +26,10 @@ import {
Tag,
Typography,
Collapsible,
Select,
Radio,
RadioGroup,
AutoComplete,
} from '@douyinfe/semi-ui';
import {
getQuotaPerUnit,
@@ -67,14 +71,16 @@ const PersonalSetting = () => {
const [transferAmount, setTransferAmount] = useState(0);
const [isModelsExpanded, setIsModelsExpanded] = useState(false);
const MODELS_DISPLAY_COUNT = 10; // 默认显示的模型数量
const [notificationSettings, setNotificationSettings] = useState({
warningType: 'email',
warningThreshold: 100000,
webhookUrl: '',
webhookSecret: '',
notificationEmail: ''
});
const [showWebhookDocs, setShowWebhookDocs] = useState(false);
useEffect(() => {
// let user = localStorage.getItem('user');
// if (user) {
// userDispatch({ type: 'login', payload: user });
// }
// console.log(localStorage.getItem('user'))
let status = localStorage.getItem('status');
if (status) {
status = JSON.parse(status);
@@ -105,6 +111,19 @@ const PersonalSetting = () => {
return () => clearInterval(countdownInterval); // Clean up on unmount
}, [disableButton, countdown]);
useEffect(() => {
if (userState?.user?.setting) {
const settings = JSON.parse(userState.user.setting);
setNotificationSettings({
warningType: settings.notify_type || 'email',
warningThreshold: settings.quota_warning_threshold || 500000,
webhookUrl: settings.webhook_url || '',
webhookSecret: settings.webhook_secret || '',
notificationEmail: settings.notification_email || ''
});
}
}, [userState?.user?.setting]);
const handleInputChange = (name, value) => {
setInputs((inputs) => ({...inputs, [name]: value}));
};
@@ -300,7 +319,36 @@ const PersonalSetting = () => {
}
};
const handleNotificationSettingChange = (type, value) => {
setNotificationSettings(prev => ({
...prev,
[type]: value.target ? value.target.value : value // 处理 Radio 事件对象
}));
};
const saveNotificationSettings = async () => {
try {
const res = await API.put('/api/user/setting', {
notify_type: notificationSettings.warningType,
quota_warning_threshold: notificationSettings.warningThreshold,
webhook_url: notificationSettings.webhookUrl,
webhook_secret: notificationSettings.webhookSecret,
notification_email: notificationSettings.notificationEmail
});
if (res.data.success) {
showSuccess(t('通知设置已更新'));
await getUserData();
} else {
showError(res.data.message);
}
} catch (error) {
showError(t('更新通知设置失败'));
}
};
return (
<div>
<Layout>
<Layout.Content>
@@ -526,9 +574,7 @@ const PersonalSetting = () => {
</div>
<div style={{marginTop: 10}}>
<Typography.Text strong>{t('微信')}</Typography.Text>
<div
style={{display: 'flex', justifyContent: 'space-between'}}
>
<div style={{display: 'flex', justifyContent: 'space-between'}}>
<div>
<Input
value={
@@ -541,12 +587,16 @@ const PersonalSetting = () => {
</div>
<div>
<Button
disabled={
(userState.user && userState.user.wechat_id !== '') ||
!status.wechat_login
}
disabled={!status.wechat_login}
onClick={() => {
setShowWeChatBindModal(true);
}}
>
{status.wechat_login ? t('绑定') : t('未启用')}
{userState.user && userState.user.wechat_id !== ''
? t('修改绑定')
: status.wechat_login
? t('绑定')
: t('未启用')}
</Button>
</div>
</div>
@@ -672,18 +722,8 @@ const PersonalSetting = () => {
style={{marginTop: '10px'}}
/>
)}
{status.wechat_login && (
<Button
onClick={() => {
setShowWeChatBindModal(true);
}}
>
{t('绑定微信账号')}
</Button>
)}
<Modal
onCancel={() => setShowWeChatBindModal(false)}
// onOpen={() => setShowWeChatBindModal(true)}
visible={showWeChatBindModal}
size={'small'}
>
@@ -707,9 +747,121 @@ const PersonalSetting = () => {
</Modal>
</div>
</Card>
<Card style={{marginTop: 10}}>
<Typography.Title heading={6}>{t('通知设置')}</Typography.Title>
<div style={{marginTop: 20}}>
<Typography.Text strong>{t('通知方式')}</Typography.Text>
<div style={{marginTop: 10}}>
<RadioGroup
value={notificationSettings.warningType}
onChange={value => handleNotificationSettingChange('warningType', value)}
>
<Radio value="email">{t('邮件通知')}</Radio>
<Radio value="webhook">{t('Webhook通知')}</Radio>
</RadioGroup>
</div>
</div>
{notificationSettings.warningType === 'webhook' && (
<>
<div style={{marginTop: 20}}>
<Typography.Text strong>{t('Webhook地址')}</Typography.Text>
<div style={{marginTop: 10}}>
<Input
value={notificationSettings.webhookUrl}
onChange={val => handleNotificationSettingChange('webhookUrl', val)}
placeholder={t('请输入Webhook地址例如: https://example.com/webhook')}
/>
<Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
{t('只支持https系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求')}
</Typography.Text>
<Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
<div style={{cursor: 'pointer'}} onClick={() => setShowWebhookDocs(!showWebhookDocs)}>
{t('Webhook请求结构')} {showWebhookDocs ? '▼' : '▶'}
</div>
<Collapsible isOpen={showWebhookDocs}>
<pre style={{marginTop: 4, background: 'var(--semi-color-fill-0)', padding: 8, borderRadius: 4}}>
{`{
"type": "quota_exceed", // 通知类型
"title": "标题", // 通知标题
"content": "通知内容", // 通知内容,支持 {{value}} 变量占位符
"values": ["值1", "值2"], // 按顺序替换content中的 {{value}} 占位符
"timestamp": 1739950503 // 时间戳
}
示例:
{
"type": "quota_exceed",
"title": "额度预警通知",
"content": "您的额度即将用尽,当前剩余额度为 {{value}}",
"values": ["$0.99"],
"timestamp": 1739950503
}`}
</pre>
</Collapsible>
</Typography.Text>
</div>
</div>
<div style={{marginTop: 20}}>
<Typography.Text strong>{t('接口凭证(可选)')}</Typography.Text>
<div style={{marginTop: 10}}>
<Input
value={notificationSettings.webhookSecret}
onChange={val => handleNotificationSettingChange('webhookSecret', val)}
placeholder={t('请输入密钥')}
/>
<Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
{t('密钥将以 Bearer 方式添加到请求头中用于验证webhook请求的合法性')}
</Typography.Text>
<Typography.Text type="secondary" style={{marginTop: 4, display: 'block'}}>
{t('Authorization: Bearer your-secret-key')}
</Typography.Text>
</div>
</div>
</>
)}
{notificationSettings.warningType === 'email' && (
<div style={{marginTop: 20}}>
<Typography.Text strong>{t('通知邮箱')}</Typography.Text>
<div style={{marginTop: 10}}>
<Input
value={notificationSettings.notificationEmail}
onChange={val => handleNotificationSettingChange('notificationEmail', val)}
placeholder={t('留空则使用账号绑定的邮箱')}
/>
<Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
{t('设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱')}
</Typography.Text>
</div>
</div>
)}
<div style={{marginTop: 20}}>
<Typography.Text strong>{t('额度预警阈值')} {renderQuotaWithPrompt(notificationSettings.warningThreshold)}</Typography.Text>
<div style={{marginTop: 10}}>
<AutoComplete
value={notificationSettings.warningThreshold}
onChange={val => handleNotificationSettingChange('warningThreshold', val)}
style={{width: 200}}
placeholder={t('请输入预警额度')}
data={[
{ value: 100000, label: '0.2$' },
{ value: 500000, label: '1$' },
{ value: 1000000, label: '5$' },
{ value: 5000000, label: '10$' }
]}
/>
</div>
<Typography.Text type="secondary" style={{marginTop: 10, display: 'block'}}>
{t('当剩余额度低于此数值时,系统将通过选择的方式发送通知')}
</Typography.Text>
</div>
<div style={{marginTop: 20}}>
<Button type="primary" onClick={saveNotificationSettings}>
{t('保存设置')}
</Button>
</div>
</Card>
<Modal
onCancel={() => setShowEmailBindModal(false)}
// onOpen={() => setShowEmailBindModal(true)}
onOk={bindEmail}
visible={showEmailBindModal}
size={'small'}

View File

@@ -368,6 +368,17 @@ const SystemSetting = () => {
</a>
</Header>
<Message info>
注意代理功能仅对图片请求和 Webhook 请求生效不会影响其他 API 请求如需配置 API 请求代理请参考
<a
href='https://github.com/Calcium-Ion/new-api/blob/main/docs/channel/other_setting.md'
target='_blank'
rel='noreferrer'
>
{' '}API 代理设置文档
</a>
</Message>
<Form.Group widths='equal'>
<Form.Input
label='Worker地址不填写则不启用代理'

View File

@@ -386,7 +386,7 @@ export function renderQuotaWithPrompt(quota, digits) {
let displayInCurrency = localStorage.getItem('display_in_currency');
displayInCurrency = displayInCurrency === 'true';
if (displayInCurrency) {
return '|' + i18next.t('等价金额') + ': ' + renderQuota(quota, digits) + '';
return ' | ' + i18next.t('等价金额') + ': ' + renderQuota(quota, digits) + '';
}
return '';
}

View File

@@ -201,7 +201,7 @@
"相关 API 显示令牌额度而非用户额度": "Related APIs show token quota instead of user quota",
"保存通用设置": "Save General Settings",
"监控设置": "Monitoring Settings",
"最长响应时间": "Maximum Response Time",
"测试所有渠道的最长响应时间": "Maximum response time for testing all channels",
"单位秒": "Unit: seconds",
"当运行通道全部测试时": "When running all channel tests",
"超过此时间将自动禁用通道": "Channels exceeding this time will be automatically disabled",
@@ -1246,5 +1246,8 @@
"请输入要设置的标签名称": "Please enter the tag name to be set",
"请输入标签名称": "Please enter the tag name",
"支持搜索用户的 ID、用户名、显示名称和邮箱地址": "Support searching for user ID, username, display name, and email address",
"已注销": "Logged out"
"已注销": "Logged out",
"自动禁用关键词": "Automatic disable keywords",
"一行一个,不区分大小写": "One line per keyword, not case-sensitive",
"当上游通道返回错误中包含这些关键词时(不区分大小写),自动禁用通道": "When the upstream channel returns an error containing these keywords (not case-sensitive), automatically disable the channel"
}

View File

@@ -5,7 +5,7 @@ import {
API,
showError,
showSuccess,
showWarning,
showWarning, verifyJSON
} from '../../../helpers';
import { useTranslation } from 'react-i18next';
@@ -17,6 +17,7 @@ export default function SettingsMonitoring(props) {
QuotaRemindThreshold: '',
AutomaticDisableChannelEnabled: false,
AutomaticEnableChannelEnabled: false,
AutomaticDisableKeywords: '',
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -79,7 +80,7 @@ export default function SettingsMonitoring(props) {
<Row gutter={16}>
<Col span={8}>
<Form.InputNumber
label={t('最长响应时间')}
label={t('测试所有渠道的最长响应时间')}
step={1}
min={0}
suffix={t('秒')}
@@ -144,6 +145,18 @@ export default function SettingsMonitoring(props) {
/>
</Col>
</Row>
<Row gutter={16}>
<Col span={16}>
<Form.TextArea
label={t('自动禁用关键词')}
placeholder={t('一行一个,不区分大小写')}
extraText={t('当上游通道返回错误中包含这些关键词时(不区分大小写),自动禁用通道')}
field={'AutomaticDisableKeywords'}
autosize={{ minRows: 6, maxRows: 12 }}
onChange={(value) => setInputs({ ...inputs, AutomaticDisableKeywords: value })}
/>
</Col>
</Row>
<Row>
<Button size='default' onClick={onSubmit}>
{t('保存监控设置')}