Compare commits

...

72 Commits

Author SHA1 Message Date
CaIon
2f3acd9d22 feat: 添加流模式下的SSE保活机制 #945 2025-04-14 19:40:23 +08:00
CaIon
dcf7878772 fix: update model name handling in UI and localization 2025-04-12 17:44:29 +08:00
CaIon
ef8ae4db80 fix: xAI usage 2025-04-11 23:31:32 +08:00
CaIon
90576d0261 feat: enhance Claude to OpenAI request conversion with additional relay info support 2025-04-11 19:13:38 +08:00
CaIon
4b3e30e669 feat: 完善openai转claude支持 2025-04-11 18:28:50 +08:00
CaIon
75570af967 chore: update .gitignore and docker-compose.yml to include tiktoken_cache directory 2025-04-11 16:24:27 +08:00
CaIon
cca9c0479f feat: enhance file handling and logging in the application 2025-04-11 16:23:54 +08:00
CaIon
8a2332074f refactor: move maxFileSize variable inside GetFileBase64FromUrl function 2025-04-11 15:53:23 +08:00
CaIon
2ec4565601 feat: implement parameter cleaning for Gemini functions 2025-04-10 22:35:03 +08:00
CaIon
a4fb33957f feat: support zhipu_4v embeddings path 2025-04-10 20:53:51 +08:00
Calcium-Ion
909c5eb276 Merge pull request #959 from Praying/main
fix(relay): 优化数据流处理
2025-04-10 17:21:55 +08:00
CaIon
8723e3f239 feat: add xAI handling and response processing 2025-04-10 17:20:59 +08:00
quran
9328b907f2 fix(relay): 优化数据流处理
- 移除了 bufio 的无效使用
- 在 StreamScannerHandler 中增加了初始和最大缓冲区大小的常量设置
- 调整 StreamScannerHandler 中缓冲区大小,避免出现token too long报错
2025-04-10 16:56:16 +08:00
Calcium-Ion
8efa12b941 Merge pull request #953 from wkxu/main
fix: .env文件配置DEBUG=true等参数不起作用的fix
2025-04-10 16:14:11 +08:00
Calcium-Ion
7b997b3a2c Merge pull request #956 from HynoR/feat/xai
feat: add xAI channel
2025-04-10 16:13:48 +08:00
HynoR
700c05b826 feat: update adaptor methods and add new image model 2025-04-10 15:08:12 +08:00
HynoR
c5103237b0 feat: add xai grok-3-mini reasoning effort 2025-04-10 13:31:43 +08:00
HynoR
f500eb17a8 feat: add xai channel
feat: add xai channel

feat: add xai channel
2025-04-10 13:04:43 +08:00
wkxu
86f6bb7abe refactor: 把common/instants.go里的从Getenv获取的参数,放到init.go的LoadEnv函数里获取
把constant/env.go里的从Getenv获取的参数,放到env.go的InitEnv函数里获取。以避免.env文件配置参数不起作用的情况
2025-04-10 09:02:19 +08:00
Calcium-Ion
c4c1099ae5 Merge pull request #944 from lamcodes/main
Update: Gemini channel fetch_models
2025-04-10 00:09:54 +08:00
CaIon
c869455456 fix: Update model ratios for gemini-2.5-pro 2025-04-10 00:09:11 +08:00
CaIon
f89d8a0fe5 refactor: Remove duplicate model settings initialization in main function 2025-04-10 00:07:34 +08:00
CaIon
3d6d19903b refactor: Update localization keys for API address in English translations and adjust related UI labels 2025-04-09 22:22:19 +08:00
zkp
524d4a65bf Update: Gemini channel fetch_models 2025-04-08 22:43:13 +08:00
CaIon
082218173a feat: Add CheckSetup function call in main to ensure proper initialization #942 2025-04-08 18:14:36 +08:00
Calcium-Ion
67cbbc2266 Merge pull request #930 from Yiffyi/main
fix: save OIDC settings
2025-04-08 17:39:42 +08:00
CaIon
79b35e385f Update MaxTokens for gemini model to 300 in test request 2025-04-08 17:37:25 +08:00
Calcium-Ion
03e8ab4126 Merge pull request #936 from lamcodes/main
fix: gemini test MaxTokens
2025-04-08 17:33:31 +08:00
Calcium-Ion
30f32c6a6d Set MaxTokens to 50 for gemini 2025-04-08 17:33:10 +08:00
CaIon
5813ca780f feat: Integrate SetupCheck component for improved setup validation in routing 2025-04-08 17:31:46 +08:00
CaIon
aa34c3035a feat: Initialize model settings and improve concurrency control in operation settings 2025-04-07 22:20:47 +08:00
CaIon
fb9f595044 feat: Add concurrency control to group ratio management with mutexes 2025-04-07 21:55:54 +08:00
zkp
f24de65626 fix: gemini test MaxTokens 2025-04-06 23:24:47 +08:00
Yiffyi Jia
e34dccbc65 fix: cannot save OIDC settings 2025-04-05 04:24:38 +00:00
CaIon
f6e8887482 Update model-ratio.go 2025-04-04 23:43:14 +08:00
CaIon
a29f4d88c5 Update model-ratio.go 2025-04-04 23:41:41 +08:00
CaIon
a6bb30af41 fix: Improve setup check logic and logging for system initialization 2025-04-04 21:27:24 +08:00
CaIon
424424c160 Update model-ratio.go 2025-04-04 00:31:24 +08:00
CaIon
e5baa6ee1c feat: Enhance ModelSettingsVisualEditor with pricing modes and improved model management features 2025-04-03 20:42:08 +08:00
CaIon
9207d729ca feat: Add new localization strings for system initialization 2025-04-03 19:27:25 +08:00
CaIon
27933da884 fix: Update option key from SelfUseModeEnabled to DemoSiteEnabled in PostSetup function 2025-04-03 19:21:53 +08:00
CaIon
454dac17ea feat: Add timestamp and version to setup initialization in PostSetup function 2025-04-03 19:16:17 +08:00
CaIon
1921ac3692 fix: Correct option key for SelfUseModeEnabled in setup controller 2025-04-03 19:15:04 +08:00
CaIon
42a2418d9a Merge remote-tracking branch 'origin/main' 2025-04-03 19:09:26 +08:00
CaIon
5cb317bdbd Update README.md 2025-04-03 19:09:13 +08:00
Calcium-Ion
37dd1ef099 Merge pull request #925 from Calcium-Ion/setup
 feat: Implement system setup functionality
2025-04-03 19:01:45 +08:00
CaIon
5fa6462412 feat: Refine personal mode description in setup page for clarity 2025-04-03 19:01:16 +08:00
CaIon
a882e680ae feat: Implement system setup functionality 2025-04-03 18:57:15 +08:00
CaIon
552e2850c5 Merge remote-tracking branch 'origin/main' 2025-04-03 17:33:03 +08:00
CaIon
c418d9ed9a feat: Enhance user settings and notification options 2025-04-03 17:32:48 +08:00
Calcium-Ion
1dc2284d57 Merge pull request #909 from jasinliu/feature/fix-dify-thinking
feat: fix dify thinking
2025-04-03 16:23:12 +08:00
Calcium-Ion
f4cc90c8d6 Merge pull request #893 from wizcas/replace-linux-do-icon
替换登录界面的 Linux.do OAuth 图标
2025-03-31 22:38:41 +08:00
Calcium-Ion
140d3a974b Merge pull request #895 from Feiyuyu0503/main
docs: fix a typo
2025-03-31 22:38:25 +08:00
Calcium-Ion
2ecb742e47 Merge pull request #912 from OrdinarySF/main
fix: fixed bug where target.id was null when clicking 'x' icon
2025-03-31 22:38:08 +08:00
Calcium-Ion
9066cfa8a0 Merge pull request #914 from JoeyLearnsToCode/main
feat: Add Parameters Override
2025-03-31 22:37:26 +08:00
Calcium-Ion
4f437f30e0 Merge pull request #916 from xifan2333/fix/systemSettingsUI
 feat: Update option handling in SystemSetting
2025-03-31 22:36:14 +08:00
xifan
3c2a86f94d feat: Update option handling in SystemSetting
-  Add backend validation for OIDC & Telegram OAuth config
- ♻️ Refactor frontend option updates with batch processing
2025-03-31 00:46:13 +08:00
JoeyLearnsToCode
1b07282153 feat: Add Parameters Override 2025-03-29 14:39:39 +08:00
Ordinary
af7f886c39 refactor: use handleFieldChange function on change event 2025-03-28 12:44:40 +00:00
Ordinary
9cfa138796 fix: fixed bug where target.id was null when clicking 'x' icon 2025-03-28 12:43:26 +00:00
jasinliu
dc132655a6 fix dify thinking 2025-03-28 00:21:27 +08:00
1808837298@qq.com
a378665b8c feat: Add new cache ratios for o3-mini and gpt-4.5-preview models 2025-03-27 18:47:50 +08:00
1808837298@qq.com
3516aad349 update model ratio 2025-03-27 17:02:09 +08:00
1808837298@qq.com
58525c574b feat: Enhance GetCompletionRatio function 2025-03-27 16:38:29 +08:00
1808837298@qq.com
1df39e5a7f update model ratio 2025-03-27 16:24:30 +08:00
feiyuyu
be6ffd3c60 docs: fix a typo 2025-03-22 21:28:25 +08:00
Wizcas Chen
a9522075c6 replace the linuxdo icon in the login form 2025-03-22 17:16:07 +08:00
Calcium-Ion
983d31bfd3 Merge pull request #886 from seefs001/main
fix: claude function calling type
2025-03-20 23:22:20 +08:00
Seefs
20c043f584 fix: claude function calling type 2025-03-19 22:48:49 +08:00
1808837298@qq.com
73263e02d6 fix: Adjust MaxTokens logic for non-Claude models in test request 2025-03-17 23:44:32 +08:00
1808837298@qq.com
7143b0f160 feat: Add support for cross-region AWS model handling in awsStreamHandler 2025-03-17 23:41:00 +08:00
1808837298@qq.com
dd82618c05 refactor: Improve token quota consumption logic 2025-03-17 17:52:54 +08:00
72 changed files with 6760 additions and 4749 deletions

3
.gitignore vendored
View File

@@ -9,4 +9,5 @@ logs
web/dist
.env
one-api
.DS_Store
.DS_Store
tiktoken_cache

View File

@@ -117,7 +117,6 @@ New API提供了丰富的功能详细特性请参考[特性说明](https://do
> [!TIP]
> 最新版Docker镜像`calciumion/new-api:latest`
> 默认账号root 密码123456
### 多机部署注意事项
- 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致

View File

@@ -1,8 +1,8 @@
package common
import (
"os"
"strconv"
//"os"
//"strconv"
"sync"
"time"
@@ -63,8 +63,8 @@ var EmailDomainWhitelist = []string{
"foxmail.com",
}
var DebugEnabled = os.Getenv("DEBUG") == "true"
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
var DebugEnabled bool
var MemoryCacheEnabled bool
var LogConsumeEnabled = true
@@ -103,22 +103,22 @@ var RetryTimes = 0
//var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var IsMasterNode bool
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var requestInterval int
var RequestInterval time.Duration
var SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60) // unit is second
var SyncFrequency int // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
var BatchUpdateInterval int
var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second
var RelayTimeout int // unit is second
var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var GeminiSafetySetting string
// https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT
var CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
var CohereSafetySetting string
const (
RequestIdKey = "X-Oneapi-Request-Id"
@@ -145,13 +145,13 @@ var (
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
GlobalApiRateLimitEnable bool
GlobalApiRateLimitNum int
GlobalApiRateLimitDuration int64
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
GlobalWebRateLimitEnable bool
GlobalWebRateLimitNum int
GlobalWebRateLimitDuration int64
UploadRateLimitNum = 10
UploadRateLimitDuration int64 = 60
@@ -235,6 +235,7 @@ const (
ChannelTypeVolcEngine = 45
ChannelTypeBaiduV2 = 46
ChannelTypeXinference = 47
ChannelTypeXai = 48
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
@@ -288,4 +289,5 @@ var ChannelBaseURLs = []string{
"https://ark.cn-beijing.volces.com", //45
"https://qianfan.baidubce.com", //46
"", //47
"https://api.x.ai", //48
}

View File

@@ -6,6 +6,8 @@ import (
"log"
"os"
"path/filepath"
"strconv"
"time"
)
var (
@@ -66,4 +68,31 @@ func LoadEnv() {
}
}
}
// Initialize variables from constants.go that were using environment variables
DebugEnabled = os.Getenv("DEBUG") == "true"
MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
// Parse requestInterval and set RequestInterval
requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
RequestInterval = time.Duration(requestInterval) * time.Second
// Initialize variables with GetEnvOrDefault
SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
// Initialize string variables with GetEnvOrDefaultString
GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
// Initialize rate limit variables
GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
}

View File

@@ -12,3 +12,7 @@ func DecodeJson(data []byte, v any) error {
func DecodeJsonStr(data string, v any) error {
return DecodeJson(StringToByteSlice(data), v)
}
func EncodeJson(v any) ([]byte, error) {
return json.Marshal(v)
}

View File

@@ -4,32 +4,39 @@ import (
"one-api/common"
)
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
var MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
// ForceStreamOption 覆盖请求参数强制返回usage信息
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
var StreamingTimeout int
var DifyDebug bool
var MaxFileDownloadMB int
var ForceStreamOption bool
var GetMediaToken bool
var GetMediaTokenNotStream bool
var UpdateTask bool
var AzureDefaultAPIVersion string
var GeminiVisionMaxImageNum int
var NotifyLimitCount int
var NotificationLimitDurationMinute int
var GenerateDefaultToken bool
//var GeminiModelMap = map[string]string{
// "gemini-1.0-pro": "v1",
//}
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
func InitEnv() {
StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
// ForceStreamOption 覆盖请求参数强制返回usage信息
ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
//if modelVersionMapStr == "" {
// return
@@ -43,6 +50,3 @@ func InitEnv() {
// }
//}
}
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)

3
constant/setup.go Normal file
View File

@@ -0,0 +1,3 @@
package constant
var Setup = false

View File

@@ -1,11 +1,12 @@
package constant
var (
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
)
var (

View File

@@ -105,6 +105,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
request := buildTestRequest(testModel)
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %v ", channel.Id, testModel, info))
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
if err != nil {
return err, nil
}
adaptor.Init(info)
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
@@ -143,10 +148,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
return err, nil
}
info.PromptTokens = usage.PromptTokens
priceData, err := helper.ModelPriceHelper(c, info, usage.PromptTokens, int(request.MaxTokens))
if err != nil {
return err, nil
}
quota := 0
if !priceData.UsePrice {
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
@@ -187,7 +189,11 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
if strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") {
testRequest.MaxCompletionTokens = 10
} else if strings.Contains(model, "thinking") {
testRequest.MaxTokens = 50
if !strings.Contains(model, "claude") {
testRequest.MaxTokens = 50
}
} else if strings.Contains(model, "gemini") {
testRequest.MaxTokens = 300
} else {
testRequest.MaxTokens = 10
}

View File

@@ -119,6 +119,9 @@ func FetchUpstreamModels(c *gin.Context) {
baseURL = channel.GetBaseURL()
}
url := fmt.Sprintf("%s/v1/models", baseURL)
if channel.Type == common.ChannelTypeGemini {
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
}
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -139,7 +142,11 @@ func FetchUpstreamModels(c *gin.Context) {
var ids []string
for _, model := range result.Data {
ids = append(ids, model.ID)
id := model.ID
if channel.Type == common.ChannelTypeGemini {
id = strings.TrimPrefix(id, "models/")
}
ids = append(ids, id)
}
c.JSON(http.StatusOK, gin.H{

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/setting"
"one-api/setting/operation_setting"
@@ -72,6 +73,7 @@ func GetStatus(c *gin.Context) {
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
"setup": constant.Setup,
},
})
return

View File

@@ -53,11 +53,12 @@ func UpdateOption(c *gin.Context) {
return
}
case "oidc.enabled":
if option.Value == "true" && system_setting.GetOIDCSettings().Enabled {
if option.Value == "true" && system_setting.GetOIDCSettings().ClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 OIDC 登录,请先填入 OIDC Client Id 以及 OIDC Client Secret",
})
return
}
case "LinuxDOOAuthEnabled":
if option.Value == "true" && common.LinuxDOClientId == "" {
@@ -89,6 +90,15 @@ func UpdateOption(c *gin.Context) {
"success": false,
"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",
})
return
}
case "TelegramOAuthEnabled":
if option.Value == "true" && common.TelegramBotToken == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 Telegram OAuth请先填入 Telegram Bot Token",
})
return
}
case "GroupRatio":
@@ -100,6 +110,7 @@ func UpdateOption(c *gin.Context) {
})
return
}
}
err = model.UpdateOption(option.Key, option.Value)
if err != nil {

173
controller/setup.go Normal file
View File

@@ -0,0 +1,173 @@
package controller
import (
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/setting/operation_setting"
"time"
)
type Setup struct {
Status bool `json:"status"`
RootInit bool `json:"root_init"`
DatabaseType string `json:"database_type"`
}
type SetupRequest struct {
Username string `json:"username"`
Password string `json:"password"`
ConfirmPassword string `json:"confirmPassword"`
SelfUseModeEnabled bool `json:"SelfUseModeEnabled"`
DemoSiteEnabled bool `json:"DemoSiteEnabled"`
}
func GetSetup(c *gin.Context) {
setup := Setup{
Status: constant.Setup,
}
if constant.Setup {
c.JSON(200, gin.H{
"success": true,
"data": setup,
})
return
}
setup.RootInit = model.RootUserExists()
if common.UsingMySQL {
setup.DatabaseType = "mysql"
}
if common.UsingPostgreSQL {
setup.DatabaseType = "postgres"
}
if common.UsingSQLite {
setup.DatabaseType = "sqlite"
}
c.JSON(200, gin.H{
"success": true,
"data": setup,
})
}
func PostSetup(c *gin.Context) {
// Check if setup is already completed
if constant.Setup {
c.JSON(400, gin.H{
"success": false,
"message": "系统已经初始化完成",
})
return
}
// Check if root user already exists
rootExists := model.RootUserExists()
var req SetupRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(400, gin.H{
"success": false,
"message": "请求参数有误",
})
return
}
// If root doesn't exist, validate and create admin account
if !rootExists {
// Validate password
if req.Password != req.ConfirmPassword {
c.JSON(400, gin.H{
"success": false,
"message": "两次输入的密码不一致",
})
return
}
if len(req.Password) < 8 {
c.JSON(400, gin.H{
"success": false,
"message": "密码长度至少为8个字符",
})
return
}
// Create root user
hashedPassword, err := common.Password2Hash(req.Password)
if err != nil {
c.JSON(500, gin.H{
"success": false,
"message": "系统错误: " + err.Error(),
})
return
}
rootUser := model.User{
Username: req.Username,
Password: hashedPassword,
Role: common.RoleRootUser,
Status: common.UserStatusEnabled,
DisplayName: "Root User",
AccessToken: nil,
Quota: 100000000,
}
err = model.DB.Create(&rootUser).Error
if err != nil {
c.JSON(500, gin.H{
"success": false,
"message": "创建管理员账号失败: " + err.Error(),
})
return
}
}
// Set operation modes
operation_setting.SelfUseModeEnabled = req.SelfUseModeEnabled
operation_setting.DemoSiteEnabled = req.DemoSiteEnabled
// Save operation modes to database for persistence
err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled))
if err != nil {
c.JSON(500, gin.H{
"success": false,
"message": "保存自用模式设置失败: " + err.Error(),
})
return
}
err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled))
if err != nil {
c.JSON(500, gin.H{
"success": false,
"message": "保存演示站点模式设置失败: " + err.Error(),
})
return
}
// Update setup status
constant.Setup = true
setup := model.Setup{
Version: common.Version,
InitializedAt: time.Now().Unix(),
}
err = model.DB.Create(&setup).Error
if err != nil {
c.JSON(500, gin.H{
"success": false,
"message": "系统初始化失败: " + err.Error(),
})
return
}
c.JSON(200, gin.H{
"success": true,
"message": "系统初始化成功",
})
}
func boolToString(b bool) string {
if b {
return "true"
}
return "false"
}

View File

@@ -913,11 +913,12 @@ func TopUp(c *gin.Context) {
}
type UpdateUserSettingRequest struct {
QuotaWarningType string `json:"notify_type"`
QuotaWarningThreshold float64 `json:"quota_warning_threshold"`
WebhookUrl string `json:"webhook_url,omitempty"`
WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"`
QuotaWarningType string `json:"notify_type"`
QuotaWarningThreshold float64 `json:"quota_warning_threshold"`
WebhookUrl string `json:"webhook_url,omitempty"`
WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"`
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
}
func UpdateUserSetting(c *gin.Context) {
@@ -993,6 +994,7 @@ func UpdateUserSetting(c *gin.Context) {
settings := map[string]interface{}{
constant.UserSettingNotifyType: req.QuotaWarningType,
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
"accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel,
}
// 如果是webhook类型,添加webhook相关设置

View File

@@ -15,6 +15,7 @@ services:
- SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service
- REDIS_CONN_STRING=redis://redis
- TZ=Asia/Shanghai
# - TIKTOKEN_CACHE_DIR=./tiktoken_cache # 如果需要使用tiktoken_cache请取消注释
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed

View File

@@ -11,7 +11,7 @@
- 类型为字符串,填写代理地址(例如 socks5 协议的代理地址)
3. thinking_to_content
- 用于标识是否将思考内容`reasoning_conetnt`转换为`<think>`标签拼接到内容中返回
- 用于标识是否将思考内容`reasoning_content`转换为`<think>`标签拼接到内容中返回
- 类型为布尔值,设置为 true 时启用思考内容转换
--------------------------------------------------------------
@@ -30,4 +30,4 @@
--------------------------------------------------------------
通过调整上述 JSON 配置中的值,可以灵活控制渠道的额外行为,比如是否进行格式化以及使用特定的网络代理。
通过调整上述 JSON 配置中的值,可以灵活控制渠道的额外行为,比如是否进行格式化以及使用特定的网络代理。

View File

@@ -7,7 +7,7 @@ type ClaudeMetadata struct {
}
type ClaudeMediaMessage struct {
Type string `json:"type"`
Type string `json:"type,omitempty"`
Text *string `json:"text,omitempty"`
Model string `json:"model,omitempty"`
Source *ClaudeMessageSource `json:"source,omitempty"`
@@ -50,6 +50,11 @@ func (c *ClaudeMediaMessage) GetStringContent() string {
return ""
}
func (c *ClaudeMediaMessage) GetJsonRowString() string {
jsonContent, _ := json.Marshal(c)
return string(jsonContent)
}
func (c *ClaudeMediaMessage) SetContent(content any) {
jsonContent, _ := json.Marshal(content)
c.Content = jsonContent

View File

@@ -111,6 +111,7 @@ type MediaContent struct {
Text string `json:"text,omitempty"`
ImageUrl any `json:"image_url,omitempty"`
InputAudio any `json:"input_audio,omitempty"`
File any `json:"file,omitempty"`
}
func (m *MediaContent) GetImageMedia() *MessageImageUrl {
@@ -120,6 +121,20 @@ func (m *MediaContent) GetImageMedia() *MessageImageUrl {
return nil
}
func (m *MediaContent) GetInputAudio() *MessageInputAudio {
if m.InputAudio != nil {
return m.InputAudio.(*MessageInputAudio)
}
return nil
}
func (m *MediaContent) GetFile() *MessageFile {
if m.File != nil {
return m.File.(*MessageFile)
}
return nil
}
type MessageImageUrl struct {
Url string `json:"url"`
Detail string `json:"detail"`
@@ -135,10 +150,17 @@ type MessageInputAudio struct {
Format string `json:"format"`
}
type MessageFile struct {
FileName string `json:"filename,omitempty"`
FileData string `json:"file_data,omitempty"`
FileId string `json:"file_id,omitempty"`
}
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
ContentTypeInputAudio = "input_audio"
ContentTypeFile = "file"
)
func (m *Message) GetPrefix() bool {
@@ -192,6 +214,12 @@ func (m *Message) StringContent() string {
return stringContent
}
func (m *Message) SetNullContent() {
m.Content = nil
m.parsedStringContent = nil
m.parsedContent = nil
}
func (m *Message) SetStringContent(content string) {
jsonContent, _ := json.Marshal(content)
m.Content = jsonContent
@@ -292,6 +320,30 @@ func (m *Message) ParseContent() []MediaContent {
})
}
}
case ContentTypeFile:
if fileData, ok := contentItem["file"].(map[string]interface{}); ok {
fileId, ok3 := fileData["file_id"].(string)
if ok3 {
contentList = append(contentList, MediaContent{
Type: ContentTypeFile,
File: &MessageFile{
FileId: fileId,
},
})
} else {
fileName, ok1 := fileData["filename"].(string)
fileDataStr, ok2 := fileData["file_data"].(string)
if ok1 && ok2 {
contentList = append(contentList, MediaContent{
Type: ContentTypeFile,
File: &MessageFile{
FileName: fileName,
FileData: fileDataStr,
},
})
}
}
}
}
}
}

View File

@@ -45,15 +45,16 @@ type RealtimeUsage struct {
type InputTokenDetails struct {
CachedTokens int `json:"cached_tokens"`
CachedCreationTokens int
CachedCreationTokens int `json:"-"`
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
ImageTokens int `json:"image_tokens"`
}
type OutputTokenDetails struct {
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
ReasoningTokens int `json:"reasoning_tokens"`
}
type RealtimeSession struct {

View File

@@ -12,6 +12,7 @@ import (
"one-api/model"
"one-api/router"
"one-api/service"
"one-api/setting/operation_setting"
"os"
"strconv"
@@ -33,7 +34,7 @@ var indexPage []byte
func main() {
err := godotenv.Load(".env")
if err != nil {
common.SysLog("Support for .env file is disabled")
common.SysLog("Support for .env file is disabled: " + err.Error())
}
common.LoadEnv()
@@ -51,6 +52,9 @@ func main() {
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
}
model.CheckSetup()
// Initialize SQL Database
err = model.InitLogDB()
if err != nil {
@@ -69,10 +73,13 @@ func main() {
common.FatalLog("failed to initialize Redis: " + err.Error())
}
// Initialize model settings
operation_setting.InitModelSettings()
// Initialize constants
constant.InitEnv()
// Initialize options
model.InitOptionMap()
if common.RedisEnabled {
// for compatibility with old versions
common.MemoryCacheEnabled = true

View File

@@ -212,6 +212,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("channel_name", channel.Name)
c.Set("channel_type", channel.Type)
c.Set("channel_setting", channel.GetSetting())
c.Set("param_override", channel.GetParamOverride())
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization)
}

View File

@@ -36,6 +36,7 @@ type Channel struct {
OtherInfo string `json:"other_info"`
Tag *string `json:"tag" gorm:"index"`
Setting *string `json:"setting" gorm:"type:text"`
ParamOverride *string `json:"param_override" gorm:"type:text"`
}
func (channel *Channel) GetModels() []string {
@@ -511,6 +512,17 @@ func (channel *Channel) SetSetting(setting map[string]interface{}) {
channel.Setting = common.GetPointer[string](string(settingBytes))
}
func (channel *Channel) GetParamOverride() map[string]interface{} {
paramOverride := make(map[string]interface{})
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
err := json.Unmarshal([]byte(*channel.ParamOverride), &paramOverride)
if err != nil {
common.SysError("failed to unmarshal param override: " + err.Error())
}
}
return paramOverride
}
func GetChannelsByIds(ids []int) ([]*Channel, error) {
var channels []*Channel
err := DB.Where("id in (?)", ids).Find(&channels).Error

View File

@@ -3,6 +3,7 @@ package model
import (
"log"
"one-api/common"
"one-api/constant"
"os"
"strings"
"sync"
@@ -55,6 +56,33 @@ func createRootAccountIfNeed() error {
return nil
}
func CheckSetup() {
setup := GetSetup()
if setup == nil {
// No setup record exists, check if we have a root user
if RootUserExists() {
common.SysLog("system is not initialized, but root user exists")
// Create setup record
newSetup := Setup{
Version: common.Version,
InitializedAt: time.Now().Unix(),
}
err := DB.Create(&newSetup).Error
if err != nil {
common.SysLog("failed to create setup record: " + err.Error())
}
constant.Setup = true
} else {
common.SysLog("system is not initialized and no root user exists")
constant.Setup = false
}
} else {
// Setup record exists, system is initialized
common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
constant.Setup = true
}
}
func chooseDB(envName string) (*gorm.DB, error) {
defer func() {
initCol()
@@ -214,8 +242,9 @@ func migrateDB() error {
if err != nil {
return err
}
err = DB.AutoMigrate(&Setup{})
common.SysLog("database migrated")
err = createRootAccountIfNeed()
//err = createRootAccountIfNeed()
return err
}

16
model/setup.go Normal file
View File

@@ -0,0 +1,16 @@
package model
type Setup struct {
ID uint `json:"id" gorm:"primaryKey"`
Version string `json:"version" gorm:"type:varchar(50);not null"`
InitializedAt int64 `json:"initialized_at" gorm:"type:bigint;not null"`
}
func GetSetup() *Setup {
var setup Setup
err := DB.First(&setup).Error
if err != nil {
return nil
}
return &setup
}

View File

@@ -808,3 +808,12 @@ func (user *User) FillUserByLinuxDOId() error {
err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
return err
}
func RootUserExists() bool {
var user User
err := DB.Where("role = ?", common.RoleRootUser).First(&user).Error
if err != nil {
return false
}
return true
}

View File

@@ -135,6 +135,12 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),

View File

@@ -70,7 +70,9 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
Description: tool.Function.Description,
}
claudeTool.InputSchema = make(map[string]interface{})
claudeTool.InputSchema["type"] = params["type"].(string)
if params["type"] != nil {
claudeTool.InputSchema["type"] = params["type"].(string)
}
claudeTool.InputSchema["properties"] = params["properties"]
claudeTool.InputSchema["required"] = params["required"]
for s, a := range params {

View File

@@ -1,7 +1,6 @@
package dify
import (
"bufio"
"bytes"
"encoding/base64"
"encoding/json"
@@ -198,6 +197,12 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
choice.Delta.SetReasoningContent(text + "\n")
}
} else if difyResponse.Event == "message" || difyResponse.Event == "agent_message" {
if difyResponse.Answer == "<details style=\"color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;\" open> <summary> Thinking... </summary>\n" {
difyResponse.Answer = "<think>"
} else if difyResponse.Answer == "</details>" {
difyResponse.Answer = "</think>"
}
choice.Delta.SetContentString(difyResponse.Answer)
}
response.Choices = append(response.Choices, choice)
@@ -207,12 +212,8 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var responseText string
usage := &dto.Usage{}
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
var nodeToken int
helper.SetEventStreamHeaders(c)
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var difyResponse DifyChunkChatCompletionResponse
err := json.Unmarshal([]byte(data), &difyResponse)
@@ -241,13 +242,10 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
}
return true
})
if err := scanner.Err(); err != nil {
common.SysError("error reading stream: " + err.Error())
}
helper.Done(c)
err := resp.Body.Close()
if err != nil {
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
common.SysError("close_response_body_failed: " + err.Error())
}
if usage.TotalTokens == 0 {

View File

@@ -16,6 +16,8 @@ var ModelList = []string{
"gemini-2.0-pro-exp",
// thinking exp
"gemini-2.0-flash-thinking-exp",
"gemini-2.5-pro-exp-03-25",
"gemini-2.5-pro-preview-03-25",
// imagen models
"imagen-3.0-generate-002",
// embedding models

View File

@@ -56,6 +56,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
continue
}
if tool.Function.Parameters != nil {
params, ok := tool.Function.Parameters.(map[string]interface{})
if ok {
if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
@@ -65,6 +66,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
}
}
}
// Clean the parameters before appending
cleanedParams := cleanFunctionParameters(tool.Function.Parameters)
tool.Function.Parameters = cleanedParams
functions = append(functions, tool.Function)
}
if codeExecution {
@@ -86,11 +90,11 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
// json_data, _ := json.Marshal(geminiRequest.Tools)
// common.SysLog("tools_json: " + string(json_data))
} else if textRequest.Functions != nil {
geminiRequest.Tools = []GeminiChatTool{
{
FunctionDeclarations: textRequest.Functions,
},
}
//geminiRequest.Tools = []GeminiChatTool{
// {
// FunctionDeclarations: textRequest.Functions,
// },
//}
}
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
@@ -204,6 +208,34 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
},
})
}
} else if part.Type == dto.ContentTypeFile {
if part.GetFile().FileId != "" {
return nil, fmt.Errorf("only base64 file is supported in gemini")
}
format, base64String, err := service.DecodeBase64FileData(part.GetFile().FileData)
if err != nil {
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: format,
Data: base64String,
},
})
} else if part.Type == dto.ContentTypeInputAudio {
if part.GetInputAudio().Data == "" {
return nil, fmt.Errorf("only base64 audio is supported in gemini")
}
format, base64String, err := service.DecodeBase64FileData(part.GetInputAudio().Data)
if err != nil {
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: format,
Data: base64String,
},
})
}
}
@@ -229,6 +261,93 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
return &geminiRequest, nil
}
// cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
func cleanFunctionParameters(params interface{}) interface{} {
if params == nil {
return nil
}
paramMap, ok := params.(map[string]interface{})
if !ok {
// Not a map, return as is (e.g., could be an array or primitive)
return params
}
// Create a copy to avoid modifying the original
cleanedMap := make(map[string]interface{})
for k, v := range paramMap {
cleanedMap[k] = v
}
// Clean properties
if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
cleanedProps := make(map[string]interface{})
for propName, propValue := range props {
propMap, ok := propValue.(map[string]interface{})
if !ok {
cleanedProps[propName] = propValue // Keep non-map properties
continue
}
// Create a copy of the property map
cleanedPropMap := make(map[string]interface{})
for k, v := range propMap {
cleanedPropMap[k] = v
}
// Remove unsupported fields
delete(cleanedPropMap, "default")
delete(cleanedPropMap, "exclusiveMaximum")
delete(cleanedPropMap, "exclusiveMinimum")
// Check and clean 'format' for string types
if propType, typeExists := cleanedPropMap["type"].(string); typeExists && propType == "string" {
if formatValue, formatExists := cleanedPropMap["format"].(string); formatExists {
if formatValue != "enum" && formatValue != "date-time" {
delete(cleanedPropMap, "format")
}
}
}
// Recursively clean nested properties within this property if it's an object/array
// Check the type before recursing
if propType, typeExists := cleanedPropMap["type"].(string); typeExists && (propType == "object" || propType == "array") {
cleanedProps[propName] = cleanFunctionParameters(cleanedPropMap)
} else {
cleanedProps[propName] = cleanedPropMap // Assign the cleaned map back if not recursing
}
}
cleanedMap["properties"] = cleanedProps
}
// Recursively clean items in arrays if needed (e.g., type: array, items: { ... })
if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
cleanedMap["items"] = cleanFunctionParameters(items)
}
// Also handle items if it's an array of schemas
if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
cleanedItemsArray := make([]interface{}, len(itemsArray))
for i, item := range itemsArray {
cleanedItemsArray[i] = cleanFunctionParameters(item)
}
cleanedMap["items"] = cleanedItemsArray
}
// Recursively clean other schema composition keywords if necessary
for _, field := range []string{"allOf", "anyOf", "oneOf"} {
if nested, ok := cleanedMap[field].([]interface{}); ok {
cleanedNested := make([]interface{}, len(nested))
for i, item := range nested {
cleanedNested[i] = cleanFunctionParameters(item)
}
cleanedMap[field] = cleanedNested
}
}
return cleanedMap
}
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
if depth >= 5 {
return schema

View File

@@ -36,7 +36,7 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
if !strings.Contains(request.Model, "claude") {
return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
}
aiRequest, err := service.ClaudeToOpenAIRequest(*request)
aiRequest, err := service.ClaudeToOpenAIRequest(*request, info)
if err != nil {
return nil, err
}

View File

@@ -31,6 +31,9 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
return err
}
if streamResponse.Usage != nil {
info.ClaudeConvertInfo.Usage = streamResponse.Usage
}
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
helper.ClaudeData(c, *resp)
@@ -38,12 +41,7 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
return nil
}
func processStreamResponse(item string, responseTextBuilder *strings.Builder, toolCount *int) error {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
return err
}
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
@@ -78,7 +76,11 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
if err := processStreamResponse(item, responseTextBuilder, toolCount); err != nil {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
return err
}
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
common.SysError("error processing stream response: " + err.Error())
}
}
@@ -170,15 +172,14 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
helper.Done(c)
case relaycommon.RelayFormatClaude:
info.ClaudeConvertInfo.Done = true
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return
}
if !containStreamUsage {
streamResponse.Usage = usage
}
info.ClaudeConvertInfo.Usage = usage
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {

View File

@@ -117,6 +117,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
model := info.UpstreamModelName
var responseTextBuilder strings.Builder
var toolCount int
var usage = &dto.Usage{}
var streamItems []string // store stream items
var forceFormat bool
@@ -130,8 +131,6 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
thinkToContent = think2Content
}
toolCount := 0
var (
lastStreamData string
)
@@ -142,7 +141,6 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
if err != nil {
common.SysError("error handling stream format: " + err.Error())
}
info.SetFirstResponseTime()
}
lastStreamData = data
streamItems = append(streamItems, data)
@@ -170,8 +168,10 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
}
}
if shouldSendLastResp {
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
//err = handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
}
// 处理token计算

View File

@@ -0,0 +1,104 @@
package xai
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"strings"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
//panic("implement me")
return nil, errors.New("not available")
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//not available
return nil, errors.New("not available")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
request.Size = ""
return request, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if strings.HasPrefix(request.Model, "grok-3-mini") {
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0
}
if strings.HasSuffix(request.Model, "-high") {
request.ReasoningEffort = "high"
request.Model = strings.TrimSuffix(request.Model, "-high")
} else if strings.HasSuffix(request.Model, "-low") {
request.ReasoningEffort = "low"
request.Model = strings.TrimSuffix(request.Model, "-low")
} else if strings.HasSuffix(request.Model, "-medium") {
request.ReasoningEffort = "medium"
request.Model = strings.TrimSuffix(request.Model, "-medium")
}
info.ReasoningEffort = request.ReasoningEffort
info.UpstreamModelName = request.Model
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//not available
return nil, errors.New("not available")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = xAIStreamHandler(c, resp, info)
} else {
err, usage = xAIHandler(c, resp, info)
}
//if _, ok := usage.(*dto.Usage); ok && usage != nil {
// usage.(*dto.Usage).CompletionTokens = usage.(*dto.Usage).TotalTokens - usage.(*dto.Usage).PromptTokens
//}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,18 @@
package xai
var ModelList = []string{
// grok-3
"grok-3-beta", "grok-3-mini-beta",
// grok-3 mini
"grok-3-fast-beta", "grok-3-mini-fast-beta",
// extend grok-3-mini reasoning
"grok-3-mini-beta-high", "grok-3-mini-beta-low", "grok-3-mini-beta-medium",
"grok-3-mini-fast-beta-high", "grok-3-mini-fast-beta-low", "grok-3-mini-fast-beta-medium",
// image model
"grok-2-image",
// legacy models
"grok-2", "grok-2-vision",
"grok-beta", "grok-vision-beta",
}
var ChannelName = "xai"

14
relay/channel/xai/dto.go Normal file
View File

@@ -0,0 +1,14 @@
package xai
import "one-api/dto"
// ChatCompletionResponse represents the response from XAI chat completion API
type ChatCompletionResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []dto.ChatCompletionsStreamResponseChoice
Usage *dto.Usage `json:"usage"`
SystemFingerprint string `json:"system_fingerprint"`
}

119
relay/channel/xai/text.go Normal file
View File

@@ -0,0 +1,119 @@
package xai
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings"
)
func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
if xAIResp == nil {
return nil
}
if xAIResp.Usage != nil {
xAIResp.Usage.CompletionTokens = usage.CompletionTokens
}
openAIResp := &dto.ChatCompletionsStreamResponse{
Id: xAIResp.Id,
Object: xAIResp.Object,
Created: xAIResp.Created,
Model: xAIResp.Model,
Choices: xAIResp.Choices,
Usage: xAIResp.Usage,
}
return openAIResp
}
func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
usage := &dto.Usage{}
var responseTextBuilder strings.Builder
var toolCount int
var containStreamUsage bool
helper.SetEventStreamHeaders(c)
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var xAIResp *dto.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &xAIResp)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
// 把 xAI 的usage转换为 OpenAI 的usage
if xAIResp.Usage != nil {
containStreamUsage = true
usage.PromptTokens = xAIResp.Usage.PromptTokens
usage.TotalTokens = xAIResp.Usage.TotalTokens
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
}
openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
err = helper.ObjectData(c, openaiResponse)
if err != nil {
common.SysError(err.Error())
}
return true
})
if !containStreamUsage {
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
helper.Done(c)
err := resp.Body.Close()
if err != nil {
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
common.SysError("close_response_body_failed: " + err.Error())
}
return nil, usage
}
func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
var response *dto.TextResponse
err = common.DecodeJson(responseBody, &response)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return nil, nil
}
response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens
// new body
encodeJson, err := common.EncodeJson(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return nil, nil
}
// set new body
resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &response.Usage
}

View File

@@ -10,6 +10,7 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
)
type Adaptor struct {
@@ -35,7 +36,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil
baseUrl := fmt.Sprintf("%s/api/paas/v4", info.BaseUrl)
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
return fmt.Sprintf("%s/embeddings", baseUrl), nil
default:
return fmt.Sprintf("%s/chat/completions", baseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -60,8 +67,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {

View File

@@ -1,17 +1,9 @@
package zhipu_4v
import (
"bufio"
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/helper"
"one-api/service"
"strings"
"sync"
"time"
@@ -119,163 +111,3 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
ToolChoice: request.ToolChoice,
}
}
//func responseZhipu2OpenAI(response *dto.OpenAITextResponse) *dto.OpenAITextResponse {
// fullTextResponse := dto.OpenAITextResponse{
// Id: response.Id,
// Object: "chat.completion",
// Created: common.GetTimestamp(),
// Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.TextResponseChoices)),
// Usage: response.Usage,
// }
// for i, choice := range response.TextResponseChoices {
// content, _ := json.Marshal(strings.Trim(choice.Content, "\""))
// openaiChoice := dto.OpenAITextResponseChoice{
// Index: i,
// Message: dto.Message{
// Role: choice.Role,
// Content: content,
// },
// FinishReason: "",
// }
// if i == len(response.TextResponseChoices)-1 {
// openaiChoice.FinishReason = "stop"
// }
// fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
// }
// return &fullTextResponse
//}
func streamResponseZhipu2OpenAI(zhipuResponse *ZhipuV4StreamResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = zhipuResponse.Choices[0].Delta.Content
choice.Delta.Role = zhipuResponse.Choices[0].Delta.Role
choice.Delta.ToolCalls = zhipuResponse.Choices[0].Delta.ToolCalls
choice.Index = zhipuResponse.Choices[0].Index
choice.FinishReason = zhipuResponse.Choices[0].FinishReason
response := dto.ChatCompletionsStreamResponse{
Id: zhipuResponse.Id,
Object: "chat.completion.chunk",
Created: zhipuResponse.Created,
Model: "glm-4v",
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func lastStreamResponseZhipuV42OpenAI(zhipuResponse *ZhipuV4StreamResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
response := streamResponseZhipu2OpenAI(zhipuResponse)
return response, &zhipuResponse.Usage
}
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var usage *dto.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
dataChan <- data
}
stopChan <- true
}()
helper.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var streamResponse ZhipuV4StreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
}
var response *dto.ChatCompletionsStreamResponse
if strings.Contains(data, "prompt_tokens") {
response, usage = lastStreamResponseZhipuV42OpenAI(&streamResponse)
} else {
response = streamResponseZhipu2OpenAI(&streamResponse)
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
return false
}
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, usage
}
func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var textResponse ZhipuV4Response
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the HTTPClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &textResponse.Usage
}

View File

@@ -6,6 +6,7 @@ import (
"one-api/dto"
relayconstant "one-api/relay/constant"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
@@ -19,13 +20,18 @@ type ThinkingContentInfo struct {
}
const (
LastMessageTypeText = "text"
LastMessageTypeTools = "tools"
LastMessageTypeNone = "none"
LastMessageTypeText = "text"
LastMessageTypeTools = "tools"
LastMessageTypeThinking = "thinking"
)
type ClaudeConvertInfo struct {
LastMessagesType string
Index int
Usage *dto.Usage
FinishReason string
Done bool
}
const (
@@ -49,6 +55,7 @@ type RelayInfo struct {
StartTime time.Time
FirstResponseTime time.Time
isFirstResponse bool
responseMutex sync.Mutex // Add mutex for protecting concurrent access
//SendLastReasoningResponse bool
ApiType int
IsStream bool
@@ -76,6 +83,7 @@ type RelayInfo struct {
AudioUsage bool
ReasoningEffort string
ChannelSetting map[string]interface{}
ParamOverride map[string]interface{}
UserSetting map[string]interface{}
UserEmail string
UserQuota int
@@ -96,6 +104,7 @@ var streamSupportedChannels = map[int]bool{
common.ChannelTypeAzure: true,
common.ChannelTypeVolcEngine: true,
common.ChannelTypeOllama: true,
common.ChannelTypeXai: true,
}
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
@@ -112,7 +121,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
info.RelayFormat = RelayFormatClaude
info.ShouldIncludeUsage = false
info.ClaudeConvertInfo = ClaudeConvertInfo{
LastMessagesType: LastMessageTypeText,
LastMessagesType: LastMessageTypeNone,
}
return info
}
@@ -131,6 +140,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
channelSetting := c.GetStringMap("channel_setting")
paramOverride := c.GetStringMap("param_override")
tokenId := c.GetInt("token_id")
tokenKey := c.GetString("token_key")
@@ -168,6 +178,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
ChannelSetting: channelSetting,
ParamOverride: paramOverride,
RelayFormat: RelayFormatOpenAI,
ThinkingContentInfo: ThinkingContentInfo{
IsFirstThinkingContent: true,
@@ -203,12 +214,19 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
}
func (info *RelayInfo) SetFirstResponseTime() {
info.responseMutex.Lock()
defer info.responseMutex.Unlock()
if info.isFirstResponse {
info.FirstResponseTime = time.Now()
info.isFirstResponse = false
}
}
func (info *RelayInfo) HasSendResponse() bool {
return info.FirstResponseTime.After(info.StartTime)
}
type TaskRelayInfo struct {
*RelayInfo
Action string

View File

@@ -32,6 +32,7 @@ const (
APITypeBaiduV2
APITypeOpenRouter
APITypeXinference
APITypeXai
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -92,6 +93,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeOpenRouter
case common.ChannelTypeXinference:
apiType = APITypeXinference
case common.ChannelTypeXai:
apiType = APITypeXai
}
if apiType == -1 {
return APITypeOpenAI, false

View File

@@ -55,7 +55,20 @@ func StringData(c *gin.Context, str string) error {
return nil
}
func PingData(c *gin.Context) error {
c.Writer.Write([]byte(": PING\n\n"))
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
return nil
}
func ObjectData(c *gin.Context, object interface{}) error {
if object == nil {
return errors.New("object is nil")
}
jsonData, err := json.Marshal(object)
if err != nil {
return fmt.Errorf("error marshalling object: %w", err)

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
constant2 "one-api/constant"
relaycommon "one-api/relay/common"
"one-api/setting"
"one-api/setting/operation_setting"
@@ -20,6 +21,10 @@ type PriceData struct {
ShouldPreConsumedQuota int
}
func (p PriceData) ToSetting() string {
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota)
}
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
groupRatio := setting.GetGroupRatio(info.Group)
@@ -36,10 +41,19 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
var success bool
modelRatio, success = operation_setting.GetModelRatio(info.OriginModelName)
if !success {
if info.UserId == 1 {
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置请设置或开始自用模式Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
} else {
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName)
acceptUnsetRatio := false
if accept, ok := info.UserSetting[constant2.UserAcceptUnsetRatioModel]; ok {
b, ok := accept.(bool)
if ok {
acceptUnsetRatio = b
}
}
if !acceptUnsetRatio {
if info.UserId == 1 {
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置请设置或开始自用模式Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
} else {
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName)
}
}
}
completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName)
@@ -50,7 +64,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
return PriceData{
priceData := PriceData{
ModelPrice: modelPrice,
ModelRatio: modelRatio,
CompletionRatio: completionRatio,
@@ -59,5 +74,11 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
CacheRatio: cacheRatio,
CacheCreationRatio: cacheCreationRatio,
ShouldPreConsumedQuota: preConsumedQuota,
}, nil
}
if common.DebugEnabled {
println(fmt.Sprintf("model_price_helper result: %s", priceData.ToSetting()))
}
return priceData, nil
}

View File

@@ -3,20 +3,29 @@ package helper
import (
"bufio"
"context"
"github.com/bytedance/gopkg/util/gopool"
"io"
"net/http"
"one-api/common"
"one-api/constant"
relaycommon "one-api/relay/common"
"one-api/setting/operation_setting"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
const (
InitialScannerBufferSize = 1 << 20 // 1MB (1*1024*1024)
MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
DefaultPingInterval = 10 * time.Second
)
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
if resp == nil {
if resp == nil || dataHandler == nil {
return
}
@@ -29,16 +38,32 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
}
var (
stopChan = make(chan bool, 2)
scanner = bufio.NewScanner(resp.Body)
ticker = time.NewTicker(streamingTimeout)
stopChan = make(chan bool, 2)
scanner = bufio.NewScanner(resp.Body)
ticker = time.NewTicker(streamingTimeout)
pingTicker *time.Ticker
writeMutex sync.Mutex // Mutex to protect concurrent writes
)
generalSettings := operation_setting.GetGeneralSetting()
pingEnabled := generalSettings.PingIntervalEnabled
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
if pingInterval <= 0 {
pingInterval = DefaultPingInterval
}
if pingEnabled {
pingTicker = time.NewTicker(pingInterval)
}
defer func() {
ticker.Stop()
if pingTicker != nil {
pingTicker.Stop()
}
close(stopChan)
}()
scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
scanner.Split(bufio.ScanLines)
SetEventStreamHeaders(c)
@@ -46,6 +71,34 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
defer cancel()
ctx = context.WithValue(ctx, "stop_chan", stopChan)
// Handle ping data sending
if pingEnabled && pingTicker != nil {
gopool.Go(func() {
for {
select {
case <-pingTicker.C:
writeMutex.Lock() // Lock before writing
err := PingData(c)
writeMutex.Unlock() // Unlock after writing
if err != nil {
common.LogError(c, "ping data error: "+err.Error())
common.SafeSendBool(stopChan, true)
return
}
if common.DebugEnabled {
println("ping data sent")
}
case <-ctx.Done():
if common.DebugEnabled {
println("ping data goroutine stopped")
}
return
}
}
})
}
common.RelayCtxGo(ctx, func() {
for scanner.Scan() {
ticker.Reset(streamingTimeout)
@@ -65,7 +118,9 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
data = strings.TrimSuffix(data, "\"")
if !strings.HasPrefix(data, "[DONE]") {
info.SetFirstResponseTime()
writeMutex.Lock() // Lock before writing
success := dataHandler(data)
writeMutex.Unlock() // Unlock after writing
if !success {
break
}
@@ -85,7 +140,9 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
case <-ticker.C:
// 超时处理逻辑
common.LogError(c, "streaming timeout")
common.SafeSendBool(stopChan, true)
case <-stopChan:
// 正常结束
common.LogInfo(c, "streaming finished")
}
}

View File

@@ -109,7 +109,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
c.Set("prompt_tokens", promptTokens)
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
@@ -168,6 +168,23 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
// apply param override
if len(relayInfo.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
err = json.Unmarshal(jsonData, &reqMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "param_override_unmarshal_failed", http.StatusInternalServerError)
}
for key, value := range relayInfo.ParamOverride {
reqMap[key] = value
}
jsonData, err = json.Marshal(reqMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "param_override_marshal_failed", http.StatusInternalServerError)
}
}
if common.DebugEnabled {
println("requestBody: ", string(jsonData))
}
@@ -372,17 +389,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
} else {
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
logModel := modelName
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
logModel = "gpt-4-gizmo-*"

View File

@@ -25,6 +25,7 @@ import (
"one-api/relay/channel/tencent"
"one-api/relay/channel/vertex"
"one-api/relay/channel/volcengine"
"one-api/relay/channel/xai"
"one-api/relay/channel/xunfei"
"one-api/relay/channel/zhipu"
"one-api/relay/channel/zhipu_4v"
@@ -85,6 +86,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &openai.Adaptor{}
case constant.APITypeXinference:
return &openai.Adaptor{}
case constant.APITypeXai:
return &xai.Adaptor{}
}
return nil
}

View File

@@ -13,6 +13,8 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
apiRouter.Use(middleware.GlobalAPIRateLimit())
{
apiRouter.GET("/setup", controller.GetSetup)
apiRouter.POST("/setup", controller.PostSetup)
apiRouter.GET("/status", controller.GetStatus)
apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)

View File

@@ -6,9 +6,10 @@ import (
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"strings"
)
func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIRequest, error) {
func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
openAIRequest := dto.GeneralOpenAIRequest{
Model: claudeRequest.Model,
MaxTokens: claudeRequest.MaxTokens,
@@ -17,6 +18,13 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
Stream: claudeRequest.Stream,
}
if claudeRequest.Thinking != nil {
if strings.HasSuffix(info.OriginModelName, "-thinking") &&
!strings.HasSuffix(claudeRequest.Model, "-thinking") {
openAIRequest.Model = openAIRequest.Model + "-thinking"
}
}
// Convert stop sequences
if len(claudeRequest.StopSequences) == 1 {
openAIRequest.Stop = claudeRequest.StopSequences[0]
@@ -45,7 +53,7 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
// Add system message if present
if claudeRequest.System != nil {
if claudeRequest.IsStringSystem() {
if claudeRequest.IsStringSystem() && claudeRequest.GetStringSystem() != "" {
openAIMessage := dto.Message{
Role: "system",
}
@@ -122,23 +130,22 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
oaiToolMessage.SetStringContent(mediaMsg.GetStringContent())
} else {
mediaContents := mediaMsg.ParseMediaContent()
if len(mediaContents) > 0 && mediaContents[0].Text != nil {
oaiToolMessage.SetStringContent(*mediaContents[0].Text)
}
encodeJson, _ := common.EncodeJson(mediaContents)
oaiToolMessage.SetStringContent(string(encodeJson))
}
openAIMessages = append(openAIMessages, oaiToolMessage)
}
}
if len(mediaMessages) > 0 {
openAIMessage.SetMediaContent(mediaMessages)
}
if len(toolCalls) > 0 {
openAIMessage.SetToolCalls(toolCalls)
}
if len(mediaMessages) > 0 && len(toolCalls) == 0 {
openAIMessage.SetMediaContent(mediaMessages)
}
}
if len(openAIMessage.ParseContent()) > 0 {
if len(openAIMessage.ParseContent()) > 0 || len(openAIMessage.ToolCalls) > 0 {
openAIMessages = append(openAIMessages, openAIMessage)
}
}
@@ -211,15 +218,15 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
resp.SetIndex(0)
claudeResponses = append(claudeResponses, resp)
} else {
resp := &dto.ClaudeResponse{
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Type: "text",
Text: common.GetPointer[string](""),
},
}
resp.SetIndex(0)
claudeResponses = append(claudeResponses, resp)
//resp := &dto.ClaudeResponse{
// Type: "content_block_start",
// ContentBlock: &dto.ClaudeMediaMessage{
// Type: "text",
// Text: common.GetPointer[string](""),
// },
//}
//resp.SetIndex(0)
//claudeResponses = append(claudeResponses, resp)
}
return claudeResponses
}
@@ -232,16 +239,20 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
chosenChoice := openAIResponse.Choices[0]
if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
// should be done
info.FinishReason = *chosenChoice.FinishReason
return claudeResponses
}
if info.Done {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
if openAIResponse.Usage != nil {
if info.ClaudeConvertInfo.Usage != nil {
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "message_delta",
Usage: &dto.ClaudeUsage{
InputTokens: openAIResponse.Usage.PromptTokens,
OutputTokens: openAIResponse.Usage.CompletionTokens,
InputTokens: info.ClaudeConvertInfo.Usage.PromptTokens,
OutputTokens: info.ClaudeConvertInfo.Usage.CompletionTokens,
},
Delta: &dto.ClaudeMediaMessage{
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(*chosenChoice.FinishReason)),
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
},
})
}
@@ -250,10 +261,10 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
})
} else {
var claudeResponse dto.ClaudeResponse
claudeResponse.SetIndex(0)
var isEmpty bool
claudeResponse.Type = "content_block_delta"
if len(chosenChoice.Delta.ToolCalls) > 0 {
if info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeText {
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
info.ClaudeConvertInfo.Index++
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
@@ -274,15 +285,57 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments,
}
} else {
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
// text delta
claudeResponse.Delta = &dto.ClaudeMediaMessage{
Type: "text_delta",
Text: common.GetPointer[string](chosenChoice.Delta.GetContentString()),
reasoning := chosenChoice.Delta.GetReasoningContent()
textContent := chosenChoice.Delta.GetContentString()
if reasoning != "" || textContent != "" {
if reasoning != "" {
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
//info.ClaudeConvertInfo.Index++
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &info.ClaudeConvertInfo.Index,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Type: "thinking",
Thinking: "",
},
})
}
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
// text delta
claudeResponse.Delta = &dto.ClaudeMediaMessage{
Type: "thinking_delta",
Thinking: reasoning,
}
} else {
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
if info.LastMessagesType == relaycommon.LastMessageTypeThinking || info.LastMessagesType == relaycommon.LastMessageTypeTools {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
info.ClaudeConvertInfo.Index++
}
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &info.ClaudeConvertInfo.Index,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Type: "text",
Text: common.GetPointer[string](""),
},
})
}
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
// text delta
claudeResponse.Delta = &dto.ClaudeMediaMessage{
Type: "text_delta",
Text: common.GetPointer[string](textContent),
}
}
} else {
isEmpty = true
}
}
claudeResponse.Index = &info.ClaudeConvertInfo.Index
claudeResponses = append(claudeResponses, &claudeResponse)
if !isEmpty {
claudeResponses = append(claudeResponses, &claudeResponse)
}
}
}

View File

@@ -8,9 +8,9 @@ import (
"one-api/dto"
)
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
resp, err := DoDownloadRequest(url)
if err != nil {
return nil, err
@@ -22,7 +22,6 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
if err != nil {
return nil, err
}
// Check actual size after reading
if len(fileBytes) > maxFileSize {
return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)

View File

@@ -243,20 +243,18 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
} else {
//if sensitiveResp != nil {
// logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
//}
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
@@ -318,17 +316,18 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
} else {
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
logModel := relayInfo.OriginModelName
if extraContent != "" {
logContent += ", " + extraContent

View File

@@ -398,6 +398,8 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
} else if m.Type == dto.ContentTypeInputAudio {
// TODO: 音频token数量计算
tokenNum += 100
} else if m.Type == dto.ContentTypeFile {
tokenNum += 5000
} else {
tokenNum += getTokenNum(tokenEncoder, m.Text)
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"one-api/common"
"sync"
)
var groupRatio = map[string]float64{
@@ -11,8 +12,12 @@ var groupRatio = map[string]float64{
"vip": 1,
"svip": 1,
}
var groupRatioMutex sync.RWMutex
func GetGroupRatioCopy() map[string]float64 {
groupRatioMutex.RLock()
defer groupRatioMutex.RUnlock()
groupRatioCopy := make(map[string]float64)
for k, v := range groupRatio {
groupRatioCopy[k] = v
@@ -21,11 +26,17 @@ func GetGroupRatioCopy() map[string]float64 {
}
func ContainsGroupRatio(name string) bool {
groupRatioMutex.RLock()
defer groupRatioMutex.RUnlock()
_, ok := groupRatio[name]
return ok
}
func GroupRatio2JSONString() string {
groupRatioMutex.RLock()
defer groupRatioMutex.RUnlock()
jsonBytes, err := json.Marshal(groupRatio)
if err != nil {
common.SysError("error marshalling model ratio: " + err.Error())
@@ -34,11 +45,17 @@ func GroupRatio2JSONString() string {
}
func UpdateGroupRatioByJSONString(jsonStr string) error {
groupRatioMutex.Lock()
defer groupRatioMutex.Unlock()
groupRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &groupRatio)
}
func GetGroupRatio(name string) float64 {
groupRatioMutex.RLock()
defer groupRatioMutex.RUnlock()
ratio, ok := groupRatio[name]
if !ok {
common.SysError("group ratio not found: " + name)

View File

@@ -14,6 +14,8 @@ var defaultCacheRatio = map[string]float64{
"o1-preview": 0.5,
"o1-mini-2024-09-12": 0.5,
"o1-mini": 0.5,
"o3-mini": 0.5,
"o3-mini-2025-01-31": 0.5,
"gpt-4o-2024-11-20": 0.5,
"gpt-4o-2024-08-06": 0.5,
"gpt-4o": 0.5,
@@ -21,6 +23,8 @@ var defaultCacheRatio = map[string]float64{
"gpt-4o-mini": 0.5,
"gpt-4o-realtime-preview": 0.5,
"gpt-4o-mini-realtime-preview": 0.5,
"gpt-4.5-preview": 0.5,
"gpt-4.5-preview-2025-02-27": 0.5,
"deepseek-chat": 0.25,
"deepseek-reasoner": 0.25,
"deepseek-coder": 0.25,
@@ -52,17 +56,15 @@ var cacheRatioMapMutex sync.RWMutex
// GetCacheRatioMap returns the cache ratio map
func GetCacheRatioMap() map[string]float64 {
cacheRatioMapMutex.Lock()
defer cacheRatioMapMutex.Unlock()
if cacheRatioMap == nil {
cacheRatioMap = defaultCacheRatio
}
cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
return cacheRatioMap
}
// CacheRatio2JSONString converts the cache ratio map to a JSON string
func CacheRatio2JSONString() string {
GetCacheRatioMap()
cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
jsonBytes, err := json.Marshal(cacheRatioMap)
if err != nil {
common.SysError("error marshalling cache ratio: " + err.Error())
@@ -80,10 +82,11 @@ func UpdateCacheRatioByJSONString(jsonStr string) error {
// GetCacheRatio returns the cache ratio for a model
func GetCacheRatio(name string) (float64, bool) {
GetCacheRatioMap()
cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
ratio, ok := cacheRatioMap[name]
if !ok {
return 1, false // Default to 0.5 if not found
return 1, false // Default to 1 if not found
}
return ratio, true
}

View File

@@ -3,12 +3,16 @@ package operation_setting
import "one-api/setting/config"
type GeneralSetting struct {
DocsLink string `json:"docs_link"`
DocsLink string `json:"docs_link"`
PingIntervalEnabled bool `json:"ping_interval_enabled"`
PingIntervalSeconds int `json:"ping_interval_seconds"`
}
// 默认配置
var generalSetting = GeneralSetting{
DocsLink: "https://docs.newapi.pro",
DocsLink: "https://docs.newapi.pro",
PingIntervalEnabled: false,
PingIntervalSeconds: 60,
}
func init() {

View File

@@ -131,17 +131,12 @@ var defaultModelRatio = map[string]float64{
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens
"gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens
"gemini-1.5-flash-latest": 1,
"gemini-1.5-flash-exp-0827": 1,
"gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1,
"gemini-1.5-pro-latest": 1.25, // $3.5 / 1M tokens
"gemini-1.5-flash-latest": 0.075,
"gemini-2.0-flash": 0.05,
"gemini-2.5-pro-exp-03-25": 0.625,
"gemini-2.5-pro-preview-03-25": 0.625,
"text-embedding-004": 0.001,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
@@ -204,29 +199,39 @@ var defaultModelRatio = map[string]float64{
"llama-3-sonar-small-32k-online": 0.2 / 1000 * USD,
"llama-3-sonar-large-32k-chat": 1 / 1000 * USD,
"llama-3-sonar-large-32k-online": 1 / 1000 * USD,
// grok
"grok-3-beta": 1.5,
"grok-3-mini-beta": 0.15,
"grok-2": 1,
"grok-2-vision": 1,
"grok-beta": 2.5,
"grok-vision-beta": 2.5,
"grok-3-fast-beta": 2.5,
"grok-3-mini-fast-beta": 0.3,
}
var defaultModelPrice = map[string]float64{
"suno_music": 0.1,
"suno_lyrics": 0.01,
"dall-e-3": 0.04,
"gpt-4-gizmo-*": 0.1,
"mj_imagine": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
"mj_modal": 0.1,
"mj_zoom": 0.1,
"mj_shorten": 0.1,
"mj_high_variation": 0.1,
"mj_low_variation": 0.1,
"mj_pan": 0.1,
"mj_inpaint": 0,
"mj_custom_zoom": 0,
"mj_describe": 0.05,
"mj_upscale": 0.05,
"swap_face": 0.05,
"mj_upload": 0.05,
"suno_music": 0.1,
"suno_lyrics": 0.01,
"dall-e-3": 0.04,
"imagen-3.0-generate-002": 0.03,
"gpt-4-gizmo-*": 0.1,
"mj_imagine": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
"mj_modal": 0.1,
"mj_zoom": 0.1,
"mj_shorten": 0.1,
"mj_high_variation": 0.1,
"mj_low_variation": 0.1,
"mj_pan": 0.1,
"mj_inpaint": 0,
"mj_custom_zoom": 0,
"mj_describe": 0.05,
"mj_upscale": 0.05,
"swap_face": 0.05,
"mj_upload": 0.05,
}
var (
@@ -249,17 +254,41 @@ var defaultCompletionRatio = map[string]float64{
"gpt-4-all": 2,
}
func GetModelPriceMap() map[string]float64 {
// InitModelSettings initializes all model related settings maps
func InitModelSettings() {
// Initialize modelPriceMap
modelPriceMapMutex.Lock()
defer modelPriceMapMutex.Unlock()
if modelPriceMap == nil {
modelPriceMap = defaultModelPrice
}
modelPriceMap = defaultModelPrice
modelPriceMapMutex.Unlock()
// Initialize modelRatioMap
modelRatioMapMutex.Lock()
modelRatioMap = defaultModelRatio
modelRatioMapMutex.Unlock()
// Initialize CompletionRatio
CompletionRatioMutex.Lock()
CompletionRatio = defaultCompletionRatio
CompletionRatioMutex.Unlock()
// Initialize cacheRatioMap
cacheRatioMapMutex.Lock()
cacheRatioMap = defaultCacheRatio
cacheRatioMapMutex.Unlock()
common.SysLog("model settings initialized")
}
func GetModelPriceMap() map[string]float64 {
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
return modelPriceMap
}
func ModelPrice2JSONString() string {
GetModelPriceMap()
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
jsonBytes, err := json.Marshal(modelPriceMap)
if err != nil {
common.SysError("error marshalling model price: " + err.Error())
@@ -276,7 +305,9 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1false
func GetModelPrice(name string, printErr bool) (float64, bool) {
GetModelPriceMap()
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
@@ -293,24 +324,6 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
return price, true
}
func GetModelRatioMap() map[string]float64 {
modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock()
if modelRatioMap == nil {
modelRatioMap = defaultModelRatio
}
return modelRatioMap
}
func ModelRatio2JSONString() string {
GetModelRatioMap()
jsonBytes, err := json.Marshal(modelRatioMap)
if err != nil {
common.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateModelRatioByJSONString(jsonStr string) error {
modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock()
@@ -319,7 +332,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
}
func GetModelRatio(name string) (float64, bool) {
GetModelRatioMap()
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
@@ -343,16 +358,15 @@ func GetDefaultModelRatioMap() map[string]float64 {
}
func GetCompletionRatioMap() map[string]float64 {
CompletionRatioMutex.Lock()
defer CompletionRatioMutex.Unlock()
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio
}
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
return CompletionRatio
}
func CompletionRatio2JSONString() string {
GetCompletionRatioMap()
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
common.SysError("error marshalling completion ratio: " + err.Error())
@@ -368,13 +382,25 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
}
func GetCompletionRatio(name string) float64 {
GetCompletionRatioMap()
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
if strings.Contains(name, "/") {
if ratio, ok := CompletionRatio[name]; ok {
return ratio
}
}
hardCodedRatio, contain := getHardcodedCompletionModelRatio(name)
if contain {
return hardCodedRatio
}
if ratio, ok := CompletionRatio[name]; ok {
return ratio
}
return hardCodedRatio
}
func getHardcodedCompletionModelRatio(name string) (float64, bool) {
lowercaseName := strings.ToLower(name)
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
@@ -385,87 +411,93 @@ func GetCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-4") && !strings.HasSuffix(name, "-all") && !strings.HasSuffix(name, "-gizmo-*") {
if strings.HasPrefix(name, "gpt-4o") {
if name == "gpt-4o-2024-05-13" {
return 3
return 3, true
}
return 4
return 4, true
}
if strings.HasPrefix(name, "gpt-4.5") {
return 2
// gpt-4.5-preview匹配
if strings.HasPrefix(name, "gpt-4.5-preview") {
return 2, true
}
if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") {
return 3
if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "gpt-4-1106") || strings.HasSuffix(name, "gpt-4-1105") {
return 3, true
}
return 2
// 没有特殊标记的 gpt-4 模型默认倍率为 2
return 2, false
}
if strings.HasPrefix(name, "o1") || strings.HasPrefix(name, "o3") {
return 4
return 4, true
}
if name == "chatgpt-4o-latest" {
return 3
return 3, true
}
if strings.Contains(name, "claude-instant-1") {
return 3
return 3, true
} else if strings.Contains(name, "claude-2") {
return 3
return 3, true
} else if strings.Contains(name, "claude-3") {
return 5
return 5, true
}
if strings.HasPrefix(name, "gpt-3.5") {
if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
// https://openai.com/blog/new-embedding-models-and-api-updates
// Updated GPT-3.5 Turbo model and lower pricing
return 3
return 3, true
}
if strings.HasSuffix(name, "1106") {
return 2
return 2, true
}
return 4.0 / 3.0
return 4.0 / 3.0, true
}
if strings.HasPrefix(name, "mistral-") {
return 3
return 3, true
}
if strings.HasPrefix(name, "gemini-") {
return 4
if strings.HasPrefix(name, "gemini-1.5") {
return 4, true
} else if strings.HasPrefix(name, "gemini-2.0") {
return 4, true
} else if strings.HasPrefix(name, "gemini-2.5-pro-preview") {
return 8, true
}
return 4, false
}
if strings.HasPrefix(name, "command") {
switch name {
case "command-r":
return 3
return 3, true
case "command-r-plus":
return 5
return 5, true
case "command-r-08-2024":
return 4
return 4, true
case "command-r-plus-08-2024":
return 4
return 4, true
default:
return 4
return 4, false
}
}
// hint 只给官方上4倍率由于开源模型供应商自行定价不对其进行补全倍率进行强制对齐
if lowercaseName == "deepseek-chat" || lowercaseName == "deepseek-reasoner" {
return 4
return 4, true
}
if strings.HasPrefix(name, "ERNIE-Speed-") {
return 2
return 2, true
} else if strings.HasPrefix(name, "ERNIE-Lite-") {
return 2
return 2, true
} else if strings.HasPrefix(name, "ERNIE-Character") {
return 2
return 2, true
} else if strings.HasPrefix(name, "ERNIE-Functions") {
return 2
return 2, true
}
switch name {
case "llama2-70b-4096":
return 0.8 / 0.64
return 0.8 / 0.64, true
case "llama3-8b-8192":
return 2
return 2, true
case "llama3-70b-8192":
return 0.79 / 0.59
return 0.79 / 0.59, true
}
if ratio, ok := CompletionRatio[name]; ok {
return ratio
}
return 1
return 1, false
}
func GetAudioRatio(name string) float64 {
@@ -498,3 +530,14 @@ func GetAudioCompletionRatio(name string) float64 {
}
return 2
}
func ModelRatio2JSONString() string {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
jsonBytes, err := json.Marshal(modelRatioMap)
if err != nil {
common.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}

View File

@@ -23,7 +23,7 @@
"react-turnstile": "^1.0.5",
"semantic-ui-offline": "^2.5.0",
"semantic-ui-react": "^2.1.3",
"sse": "github:mpetazzoni/sse.js",
"sse": "https://github.com/mpetazzoni/sse.js",
"i18next": "^23.16.8",
"react-i18next": "^13.0.0",
"i18next-browser-languagedetector": "^7.2.0"

5319
web/pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -25,6 +25,8 @@ import Task from "./pages/Task/index.js";
import Playground from './pages/Playground/Playground.js';
import OAuth2Callback from "./components/OAuth2Callback.js";
import PersonalSetting from './components/PersonalSetting.js';
import Setup from './pages/Setup/index.js';
import SetupCheck from './components/SetupCheck';
const Home = lazy(() => import('./pages/Home'));
const Detail = lazy(() => import('./pages/Detail'));
@@ -34,7 +36,7 @@ function App() {
const location = useLocation();
return (
<>
<SetupCheck>
<Routes>
<Route
path='/'
@@ -44,6 +46,14 @@ function App() {
</Suspense>
}
/>
<Route
path='/setup'
element={
<Suspense fallback={<Loading></Loading>} key={location.pathname}>
<Setup />
</Suspense>
}
/>
<Route
path='/channel'
element={
@@ -277,7 +287,7 @@ function App() {
/>
<Route path='*' element={<NotFound />} />
</Routes>
</>
</SetupCheck>
);
}

View File

@@ -6,17 +6,27 @@ const LinuxDoIcon = (props) => {
return (
<svg
className='icon'
viewBox='0 0 24 24'
viewBox='0 0 16 16'
version='1.1'
xmlns='http://www.w3.org/2000/svg'
width='1em'
height='1em'
{...props}
>
<path
d='M19.7,17.6c-0.1-0.2-0.2-0.4-0.2-0.6c0-0.4-0.2-0.7-0.5-1c-0.1-0.1-0.3-0.2-0.4-0.2c0.6-1.8-0.3-3.6-1.3-4.9c0,0,0,0,0,0c-0.8-1.2-2-2.1-1.9-3.7c0-1.9,0.2-5.4-3.3-5.1C8.5,2.3,9.5,6,9.4,7.3c0,1.1-0.5,2.2-1.3,3.1c-0.2,0.2-0.4,0.5-0.5,0.7c-1,1.2-1.5,2.8-1.5,4.3c-0.2,0.2-0.4,0.4-0.5,0.6c-0.1,0.1-0.2,0.2-0.2,0.3c-0.1,0.1-0.3,0.2-0.5,0.3c-0.4,0.1-0.7,0.3-0.9,0.7c-0.1,0.3-0.2,0.7-0.1,1.1c0.1,0.2,0.1,0.4,0,0.7c-0.2,0.4-0.2,0.9,0,1.4c0.3,0.4,0.8,0.5,1.5,0.6c0.5,0,1.1,0.2,1.6,0.4l0,0c0.5,0.3,1.1,0.5,1.7,0.5c0.3,0,0.7-0.1,1-0.2c0.3-0.2,0.5-0.4,0.6-0.7c0.4,0,1-0.2,1.7-0.2c0.6,0,1.2,0.2,2,0.1c0,0.1,0,0.2,0.1,0.3c0.2,0.5,0.7,0.9,1.3,1c0.1,0,0.1,0,0.2,0c0.8-0.1,1.6-0.5,2.1-1.1l0,0c0.4-0.4,0.9-0.7,1.4-0.9c0.6-0.3,1-0.5,1.1-1C20.3,18.6,20.1,18.2,19.7,17.6z M12.8,4.8c0.6,0.1,1.1,0.6,1,1.2c0,0.3-0.1,0.6-0.3,0.9c0,0,0,0-0.1,0c-0.2-0.1-0.3-0.1-0.4-0.2c0.1-0.1,0.1-0.3,0.2-0.5c0-0.4-0.2-0.7-0.4-0.7c-0.3,0-0.5,0.3-0.5,0.7c0,0,0,0.1,0,0.1c-0.1-0.1-0.3-0.1-0.4-0.2c0,0,0-0.1,0-0.1C11.8,5.5,12.2,4.9,12.8,4.8z M12.5,6.8c0.1,0.1,0.3,0.2,0.4,0.2c0.1,0,0.3,0.1,0.4,0.2c0.2,0.1,0.4,0.2,0.4,0.5c0,0.3-0.3,0.6-0.9,0.8c-0.2,0.1-0.3,0.1-0.4,0.2c-0.3,0.2-0.6,0.3-1,0.3c-0.3,0-0.6-0.2-0.8-0.4c-0.1-0.1-0.2-0.2-0.4-0.3C10.1,8.2,9.9,8,9.8,7.7c0-0.1,0.1-0.2,0.2-0.3c0.3-0.2,0.4-0.3,0.5-0.4l0.1-0.1c0.2-0.3,0.6-0.5,1-0.5C11.9,6.5,12.2,6.6,12.5,6.8z M10.4,5c0.4,0,0.7,0.4,0.8,1.1c0,0.1,0,0.1,0,0.2c-0.1,0-0.3,0.1-0.4,0.2c0,0,0-0.1,0-0.2c0-0.3-0.2-0.6-0.4-0.5c-0.2,0-0.3,0.3-0.3,0.6c0,0.2,0.1,0.3,0.2,0.4l0,0c0,0-0.1,0.1-0.2,0.1C9.9,6.7,9.7,6.4,9.7,6.1C9.7,5.5,10,5,10.4,5z M9.4,21.1c-0.7,0.3-1.6,0.2-2.2-0.2c-0.6-0.3-1.1-0.4-1.8-0.4c-0.5-0.1-1-0.1-1.1-0.3c-0.1-0.2-0.1-0.5,0.1-1c0.1-0.3,0.1-0.6,0-0.9c-0.1-0.3-0.1-0.5,0-0.8C4.5,17.2,4.7,17.1,5,17c0.3-0.1,0.5-0.2,0.7-0.4c0.1-0.1,0.2-0.2,0.3-0.4c0.3-0.4,0.5-0.6,0.8-0.6c0.6,0.1,1.1,1,1.5,1.9c0.2,0.3,0.4,0.7,0.7,1c0.4,0.5,0.9,1.2,0.9,1.6C9.9,20.6,9.7,20.9,9.4,21.1z M14.3,18.9c0,0.1,0,0.1-0.1,0.2c-1.2,0.9-2.8,1-4.1,0.3c-0.2-0.3-0.4-0.6-0.6-0.9c0.9-0.1,0.7-1.3-1.2-2.5c-2-1.3-0.6-3.7,0.1-4.8c0.1-0.1,0.1,0-0.3,0.8c-0.3,0.6-0.9,2.1-0.1,3.2c0-0.8,0.2-1.6,0.5-2.4c0.7-1.3,1.2-2.8,1.5-4.3c0.1,0.1,0.1,0.1,0.2,0.1c0.1,0.1,0.2,0.2,0.3,0.2c0.2,0.3,0.6,0.4,0.9,0.4c0,0,0.1,0,0.1,0c0.4,0,0.8-0.1,1.1-0.4c0.1-0.1,0.2-0.2,0.4-0.2c0.3-0.1,0.6-0.3,0.9-0.6c0.4,1.3,0.8,2.5,1.4,3.6c0.4,0.8,0.7,1.6,0.9,2.5c0.3,0,0.7,0.1,1,0.3c0.8,0.4,1.1,0.7,1,1.2c-0.1,0-0.1,0-0.2,0c0-0.3-0.2-0.6-0.9-0.9c-0.7-0.3-1.3-0.3-1.5,0.4c-0.1,0-0.2,0.1-0.3,0.1c-0.8,0.4-0.8,1.5-0.9,2.6C14.5,18.2,14.4,18.5,14.3,18.9z M18.9,19.5c-0.6,0.2-1.1,0.6-1.5,1.1c-0.4,0.6-1.1,1-1.9,0.9c-0.4,0-0.8-0.3-0.9-0.7c-0.1-0.6-0.1-1.2,0.2-1.8c0.1-0.4,0.2-0.7,0.3-1.1c0.1-1.2,0.1-1.9,0.6-2.2h0c0,0.5,0.3,0.8,0.7,1c0.5,0,1-0.1,1.4-0.5c0.1,0,0.1,0,0.2,0c0.3,0,0.5,0,0.7,0.2c0.2,0.2,0.3,0.5,0.3,0.7c0,0.3,0.2,0.6,0.3,0.9c0.5,0.5,0.5,0.8,0.5,0.9C19.7,19.1,19.3,19.3,18.9,19.5z M9.9,7.5c-0.1,0-0.1,0-0.1,0.1c0,0,0,0.1,0.1,0.1c0,0,0,0,0,0c0.1,0,0.1,0.1,0.1,0.1c0.3,0.4,0.8,0.6,1.4,0.7c0.5-0.1,1-0.2,1.5-0.6c0.2-0.1,0.4-0.2,0.6-0.3c0.1,0,0.1-0.1,0.1-0.1c0-0.1,0-0.1-0.1-0.1l0,0c-0.2,0.1-0.5,0.2-0.7,0.3c-0.4,0.3-0.9,0.5-1.4,0.5c-0.5,0-0.9-0.3-1.2-0.6C10.1,7.6,10,7.5,9.9,7.5z'
fill='currentColor'
/>
<g id='linuxdo_icon' data-name='linuxdo_icon'>
<path
d='m7.44,0s.09,0,.13,0c.09,0,.19,0,.28,0,.14,0,.29,0,.43,0,.09,0,.18,0,.27,0q.12,0,.25,0t.26.08c.15.03.29.06.44.08,1.97.38,3.78,1.47,4.95,3.11.04.06.09.12.13.18.67.96,1.15,2.11,1.3,3.28q0,.19.09.26c0,.15,0,.29,0,.44,0,.04,0,.09,0,.13,0,.09,0,.19,0,.28,0,.14,0,.29,0,.43,0,.09,0,.18,0,.27,0,.08,0,.17,0,.25q0,.19-.08.26c-.03.15-.06.29-.08.44-.38,1.97-1.47,3.78-3.11,4.95-.06.04-.12.09-.18.13-.96.67-2.11,1.15-3.28,1.3q-.19,0-.26.09c-.15,0-.29,0-.44,0-.04,0-.09,0-.13,0-.09,0-.19,0-.28,0-.14,0-.29,0-.43,0-.09,0-.18,0-.27,0-.08,0-.17,0-.25,0q-.19,0-.26-.08c-.15-.03-.29-.06-.44-.08-1.97-.38-3.78-1.47-4.95-3.11q-.07-.09-.13-.18c-.67-.96-1.15-2.11-1.3-3.28q0-.19-.09-.26c0-.15,0-.29,0-.44,0-.04,0-.09,0-.13,0-.09,0-.19,0-.28,0-.14,0-.29,0-.43,0-.09,0-.18,0-.27,0-.08,0-.17,0-.25q0-.19.08-.26c.03-.15.06-.29.08-.44.38-1.97,1.47-3.78,3.11-4.95.06-.04.12-.09.18-.13C4.42.73,5.57.26,6.74.1,7,.07,7.15,0,7.44,0Z'
fill='#EFEFEF'
/>
<path
d='m1.27,11.33h13.45c-.94,1.89-2.51,3.21-4.51,3.88-1.99.59-3.96.37-5.8-.57-1.25-.7-2.67-1.9-3.14-3.3Z'
fill='#FEB005'
/>
<path
d='m12.54,1.99c.87.7,1.82,1.59,2.18,2.68H1.27c.87-1.74,2.33-3.13,4.2-3.78,2.44-.79,5-.47,7.07,1.1Z'
fill='#1D1D1F'
/>
</g>
</svg>
);
}
@@ -24,4 +34,4 @@ const LinuxDoIcon = (props) => {
return <Icon svg={<CustomIcon />} />;
};
export default LinuxDoIcon;
export default LinuxDoIcon;

View File

@@ -18,6 +18,8 @@ const ModelSetting = () => {
'claude.default_max_tokens': '',
'claude.thinking_adapter_budget_tokens_percentage': 0.8,
'global.pass_through_request_enabled': false,
'general_setting.ping_interval_enabled': false,
'general_setting.ping_interval_seconds': 60,
});
let [loading, setLoading] = useState(false);

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,18 @@
import React, { useContext, useEffect } from 'react';
import { Navigate, useLocation } from 'react-router-dom';
import { StatusContext } from '../context/Status';
const SetupCheck = ({ children }) => {
const [statusState] = useContext(StatusContext);
const location = useLocation();
useEffect(() => {
if (statusState?.status?.setup === false && location.pathname !== '/setup') {
window.location.href = '/setup';
}
}, [statusState?.status?.setup, location.pathname]);
return children;
};
export default SetupCheck;

File diff suppressed because it is too large Load Diff

View File

@@ -115,4 +115,9 @@ export const CHANNEL_OPTIONS = [
color: 'blue',
label: '字节火山方舟、豆包、DeepSeek通用'
},
{
value: 48,
color: 'blue',
label: 'xAI'
}
];

View File

@@ -60,6 +60,9 @@ export const StyleProvider = ({ children }) => {
if (pathname === '' || pathname === '/' || pathname.includes('/home') || pathname.includes('/chat')) {
dispatch({ type: 'SET_SIDER', payload: false });
dispatch({ type: 'SET_INNER_PADDING', payload: false });
} else if (pathname === '/setup') {
dispatch({ type: 'SET_SIDER', payload: false });
dispatch({ type: 'SET_INNER_PADDING', payload: false });
} else {
// Only show sidebar on non-mobile devices by default
dispatch({ type: 'SET_SIDER', payload: !isMobile() });

View File

@@ -492,7 +492,7 @@
"请输入默认 API 版本例如2023-03-15-preview该配置可以被实际的请求查询参数所覆盖": "Please enter the default API version, for example: 2023-03-15-preview, this configuration can be overridden by the actual request query parameters",
"默认": "default",
"图片演示": "Image demo",
"参数替换为你的部署名称(模型名称中的点会被剔除)": "Replace the parameter with your deployment name (dots in the model name will be removed)",
"注意系统请求的时模型名称中的点会被剔除例如gpt-4.5-preview会请求为gpt-45-preview所以部署的模型名称需要去掉点": "Note that the dot in the model name requested by the system will be removed, for example: gpt-4.5-preview will be requested as gpt-45-preview, so the deployed model name needs to remove the dot",
"模型映射必须是合法的 JSON 格式!": "Model mapping must be in valid JSON format!",
"取消无限额度": "Cancel unlimited quota",
"取消": "Cancel",
@@ -514,7 +514,7 @@
",图片演示。": "related image demo.",
"令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!",
"代理": "Proxy",
"此项可选,用于通过代理站来进行 API 调用,请输入代理站地址格式为https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com",
"此项可选,用于通过自定义API地址来进行 API 调用,请输入API地址格式为https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com",
"取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?",
"按照如下格式输入:": "Enter in the following format:",
"模型版本": "Model version",
@@ -1111,7 +1111,7 @@
"如果你对接的是上游One API或者New API等转发项目请使用OpenAI类型不要使用此类型除非你知道你在做什么。": "If you are connecting to upstream One API or New API forwarding projects, please use OpenAI type. Do not use this type unless you know what you are doing.",
"完整的 Base URL支持变量{model}": "Complete Base URL, supports variable {model}",
"请输入完整的URL例如https://api.openai.com/v1/chat/completions": "Please enter complete URL, e.g.: https://api.openai.com/v1/chat/completions",
"此项可选,用于通过代理站来进行 API 调用,末尾不要带/v1和/": "Optional for API calls through proxy sites, do not end with /v1 and /",
"此项可选,用于通过自定义API地址来进行 API 调用,末尾不要带/v1和/": "Optional for API calls through custom API address, do not add /v1 and / at the end",
"私有部署地址": "Private Deployment Address",
"请输入私有部署地址格式为https://fastgpt.run/api/openapi": "Please enter private deployment address, format: https://fastgpt.run/api/openapi",
"注意非Chat API请务必填写正确的API地址否则可能导致无法使用": "Note: For non-Chat API, please make sure to enter the correct API address, otherwise it may not work",
@@ -1272,9 +1272,10 @@
"通知邮箱": "Notification email",
"设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱": "Set the email address for receiving quota warning notifications, if not set, the email address bound to the account will be used",
"留空则使用账号绑定的邮箱": "If left blank, the email address bound to the account will be used",
"代理站地址": "Base URL",
"API地址": "Base URL",
"对于官方渠道new-api已经内置地址除非是第三方代理站点或者Azure的特殊接入地址否则不需要填写": "For official channels, the new-api has a built-in address. Unless it is a third-party proxy site or a special Azure access address, there is no need to fill it in",
"渠道额外设置": "Channel extra settings",
"参数覆盖": "Parameters override",
"模型请求速率限制": "Model request rate limit",
"启用用户模型请求速率限制(可能会影响高并发性能)": "Enable user model request rate limit (may affect high concurrency performance)",
"限制周期": "Limit period",
@@ -1345,5 +1346,26 @@
"提示缓存倍率": "Prompt cache ratio",
"缓存:${{price}} * {{ratio}} = ${{total}} / 1M tokens (缓存倍率: {{cacheRatio}})": "Cache: ${{price}} * {{ratio}} = ${{total}} / 1M tokens (cache ratio: {{cacheRatio}})",
"提示 {{nonCacheInput}} tokens + 缓存 {{cacheInput}} tokens * {{cacheRatio}} / 1M tokens * ${{price}} + 补全 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}": "Prompt {{nonCacheInput}} tokens + cache {{cacheInput}} tokens * {{cacheRatio}} / 1M tokens * ${{price}} + completion {{completion}} tokens / 1M tokens * ${{compPrice}} * group {{ratio}} = ${{total}}",
"缓存 Tokens": "Cache Tokens"
"缓存 Tokens": "Cache Tokens",
"系统初始化": "System initialization",
"管理员账号已经初始化过,请继续设置系统参数": "The admin account has already been initialized, please continue to set the system parameters",
"管理员账号": "Admin account",
"请输入管理员用户名": "Please enter the admin username",
"请输入管理员密码": "Please enter the admin password",
"请确认管理员密码": "Please confirm the admin password",
"请选择使用模式": "Please select the usage mode",
"数据库警告": "Database warning",
"您正在使用 SQLite 数据库。如果您在容器环境中运行,请确保已正确设置数据库文件的持久化映射,否则容器重启后所有数据将丢失!": "You are using the SQLite database. If you are running in a container environment, please ensure that the database file persistence mapping is correctly set, otherwise all data will be lost after container restart!",
"建议在生产环境中使用 MySQL 或 PostgreSQL 数据库,或确保 SQLite 数据库文件已映射到宿主机的持久化存储。": "It is recommended to use MySQL or PostgreSQL databases in production environments, or ensure that the SQLite database file is mapped to the persistent storage of the host machine.",
"使用模式": "Usage mode",
"对外运营模式": "Default mode",
"密码长度至少为8个字符": "Password must be at least 8 characters long",
"表单引用错误,请刷新页面重试": "Form reference error, please refresh the page and try again",
"默认模式,适用于为多个用户提供服务的场景。": "Default mode, suitable for scenarios where multiple users are provided.",
"此模式下,系统将计算每次调用的用量,您需要对每个模型都设置价格,如果没有设置价格,用户将无法使用该模型。": "In this mode, the system will calculate the usage of each call, you need to set the price for each model, if the price is not set, the user will not be able to use the model.",
"适用于个人使用的场景。": "Suitable for personal use.",
"不需要设置模型价格,系统将弱化用量计算,您可专注于使用模型。": "No need to set the model price, the system will weaken the usage calculation, you can focus on using the model.",
"适用于展示系统功能的场景。": "Suitable for scenarios where the system functions are displayed.",
"可在初始化后修改": "Can be modified after initialization",
"初始化系统": "Initialize system"
}

View File

@@ -473,7 +473,7 @@ const EditChannel = (props) => {
<div style={{ marginTop: 10 }}>
<Banner
type={'warning'}
description={t('注意,模型部署名称必须和模型名称保持一致')}
description={t('注意,系统请求的时模型名称中的点会被剔除例如gpt-4.5-preview会请求为gpt-45-preview所以部署的模型名称需要去掉点')}
></Banner>
</div>
<div style={{ marginTop: 10 }}>
@@ -556,13 +556,13 @@ const EditChannel = (props) => {
{inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && inputs.type !== 45 && (
<>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>{t('代理站地址')}</Typography.Text>
<Typography.Text strong>{t('API地址')}</Typography.Text>
</div>
<Tooltip content={t('对于官方渠道new-api已经内置地址除非是第三方代理站点或者Azure的特殊接入地址否则不需要填写')}>
<Input
label={t('代理站地址')}
label={t('API地址')}
name="base_url"
placeholder={t('此项可选,用于通过代理站来进行 API 调用,末尾不要带/v1和/')}
placeholder={t('此项可选,用于通过自定义API地址来进行 API 调用,末尾不要带/v1和/')}
onChange={(value) => {
handleInputChange('base_url', value);
}}
@@ -983,6 +983,23 @@ const EditChannel = (props) => {
</Typography.Text>
</Space>
</>
<>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>
{t('参数覆盖')}
</Typography.Text>
</div>
<TextArea
placeholder={t('此项可选,用于覆盖请求参数。不支持覆盖 stream 参数。为一个 JSON 字符串,例如:') + '\n{\n "temperature": 0\n}'}
name="setting"
onChange={(value) => {
handleInputChange('param_override', value);
}}
autosize
value={inputs.param_override}
autoComplete="new-password"
/>
</>
{inputs.type === 1 && (
<>
<div style={{ marginTop: 10 }}>

View File

@@ -68,7 +68,7 @@ const Home = () => {
useEffect(() => {
displayNotice().then();
displayHomePageContent().then();
});
}, []);
return (
<>
@@ -112,6 +112,7 @@ const Home = () => {
https://github.com/Calcium-Ion/new-api
</a>
</p>
<p>
{t('协议')}
<a

View File

@@ -1,5 +1,5 @@
import React, { useEffect, useState, useRef } from 'react';
import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui';
import { Button, Col, Form, Row, Spin, Banner } from '@douyinfe/semi-ui';
import {
compareObjects,
API,
@@ -15,6 +15,8 @@ export default function SettingGlobalModel(props) {
const [loading, setLoading] = useState(false);
const [inputs, setInputs] = useState({
'global.pass_through_request_enabled': false,
'general_setting.ping_interval_enabled': false,
'general_setting.ping_interval_seconds': 60,
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -23,12 +25,8 @@ export default function SettingGlobalModel(props) {
const updateArray = compareObjects(inputs, inputsRow);
if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
const requestQueue = updateArray.map((item) => {
let value = '';
if (typeof inputs[item.key] === 'boolean') {
value = String(inputs[item.key]);
} else {
value = inputs[item.key];
}
let value = String(inputs[item.key]);
return API.put('/api/option/', {
key: item.key,
value,
@@ -84,6 +82,36 @@ export default function SettingGlobalModel(props) {
/>
</Col>
</Row>
<Form.Section text={t('连接保活设置')}>
<Row style={{ marginTop: 10 }}>
<Col span={24}>
<Banner
type="warning"
description="警告启用保活后如果已经写入保活数据后渠道出错系统无法重试如果必须开启推荐设置尽可能大的Ping间隔"
/>
</Col>
</Row>
<Row>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.Switch
label={t('启用Ping间隔')}
field={'general_setting.ping_interval_enabled'}
onChange={(value) => setInputs({ ...inputs, 'general_setting.ping_interval_enabled': value })}
extraText={'开启后将定期发送ping数据保持连接活跃'}
/>
</Col>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.InputNumber
label={t('Ping间隔')}
field={'general_setting.ping_interval_seconds'}
onChange={(value) => setInputs({ ...inputs, 'general_setting.ping_interval_seconds': value })}
min={1}
disabled={!inputs['general_setting.ping_interval_enabled']}
/>
</Col>
</Row>
</Form.Section>
<Row>
<Button size='default' onClick={onSubmit}>

View File

@@ -1,10 +1,12 @@
// ModelSettingsVisualEditor.js
import React, { useEffect, useState } from 'react';
import { Table, Button, Input, Modal, Form, Space } from '@douyinfe/semi-ui';
import { IconDelete, IconPlus, IconSearch, IconSave } from '@douyinfe/semi-icons';
import React, { useContext, useEffect, useState, useRef } from 'react';
import { Table, Button, Input, Modal, Form, Space, RadioGroup, Radio, Tabs, TabPane } from '@douyinfe/semi-ui';
import { IconDelete, IconPlus, IconSearch, IconSave, IconEdit } from '@douyinfe/semi-icons';
import { showError, showSuccess } from '../../../helpers';
import { API } from '../../../helpers';
import { useTranslation } from 'react-i18next';
import { StatusContext } from '../../../context/Status/index.js';
import { getQuotaPerUnit } from '../../../helpers/render.js';
export default function ModelSettingsVisualEditor(props) {
const { t } = useTranslation();
@@ -14,7 +16,11 @@ export default function ModelSettingsVisualEditor(props) {
const [searchText, setSearchText] = useState('');
const [currentPage, setCurrentPage] = useState(1);
const [loading, setLoading] = useState(false);
const [pricingMode, setPricingMode] = useState('per-token'); // 'per-token' or 'per-request'
const [pricingSubMode, setPricingSubMode] = useState('ratio'); // 'ratio' or 'token-price'
const formRef = useRef(null);
const pageSize = 10;
const quotaPerUnit = getQuotaPerUnit()
useEffect(() => {
try {
@@ -171,11 +177,19 @@ export default function ModelSettingsVisualEditor(props) {
title: t('操作'),
key: 'action',
render: (_, record) => (
<Button
icon={<IconDelete />}
type="danger"
onClick={() => deleteModel(record.name)}
/>
<Space>
<Button
type="primary"
icon={<IconEdit />}
onClick={() => editModel(record)}
>
</Button>
<Button
icon={<IconDelete />}
type="danger"
onClick={() => deleteModel(record.name)}
/>
</Space>
)
}
];
@@ -197,28 +211,171 @@ export default function ModelSettingsVisualEditor(props) {
const deleteModel = (name) => {
setModels(prev => prev.filter(model => model.name !== name));
};
const addModel = (values) => {
// 检查模型名称是否存在, 如果存在则拒绝添加
if (models.some(model => model.name === values.name)) {
showError('模型名称已存在');
return;
const calculateRatioFromTokenPrice = (tokenPrice) => {
return tokenPrice / 2;
};
const calculateCompletionRatioFromPrices = (modelTokenPrice, completionTokenPrice) => {
if (!modelTokenPrice || modelTokenPrice === '0') {
showError('模型价格不能为0');
return '';
}
setModels(prev => [{
name: values.name,
price: values.price || '',
ratio: values.ratio || '',
completionRatio: values.completionRatio || ''
}, ...prev]);
setVisible(false);
showSuccess('添加成功');
return completionTokenPrice / modelTokenPrice;
};
const handleTokenPriceChange = (value) => {
// Use a temporary variable to hold the new state
let newState = {
...(currentModel || {}),
tokenPrice: value,
ratio: 0
};
if (!isNaN(value) && value !== '') {
const tokenPrice = parseFloat(value);
const ratio = calculateRatioFromTokenPrice(tokenPrice);
newState.ratio = ratio;
}
// Set the state with the complete updated object
setCurrentModel(newState);
};
const handleCompletionTokenPriceChange = (value) => {
// Use a temporary variable to hold the new state
let newState = {
...(currentModel || {}),
completionTokenPrice: value,
completionRatio: 0
};
if (!isNaN(value) && value !== '' && currentModel?.tokenPrice) {
const completionTokenPrice = parseFloat(value);
const modelTokenPrice = parseFloat(currentModel.tokenPrice);
if (modelTokenPrice > 0) {
const completionRatio = calculateCompletionRatioFromPrices(modelTokenPrice, completionTokenPrice);
newState.completionRatio = completionRatio;
}
}
// Set the state with the complete updated object
setCurrentModel(newState);
};
const addOrUpdateModel = (values) => {
// Check if we're editing an existing model or adding a new one
const existingModelIndex = models.findIndex(model => model.name === values.name);
if (existingModelIndex >= 0) {
// Update existing model
setModels(prev => prev.map((model, index) =>
index === existingModelIndex ? {
name: values.name,
price: values.price || '',
ratio: values.ratio || '',
completionRatio: values.completionRatio || ''
} : model
));
setVisible(false);
showSuccess(t('更新成功'));
} else {
// Add new model
// Check if model name already exists
if (models.some(model => model.name === values.name)) {
showError(t('模型名称已存在'));
return;
}
setModels(prev => [{
name: values.name,
price: values.price || '',
ratio: values.ratio || '',
completionRatio: values.completionRatio || ''
}, ...prev]);
setVisible(false);
showSuccess(t('添加成功'));
}
};
const calculateTokenPriceFromRatio = (ratio) => {
return ratio * 2;
};
const resetModalState = () => {
setCurrentModel(null);
setPricingMode('per-token');
setPricingSubMode('ratio');
};
const editModel = (record) => {
// Determine which pricing mode to use based on the model's current configuration
let initialPricingMode = 'per-token';
let initialPricingSubMode = 'ratio';
if (record.price !== '') {
initialPricingMode = 'per-request';
} else {
initialPricingMode = 'per-token';
// We default to ratio mode, but could set to token-price if needed
}
// Set the pricing modes for the form
setPricingMode(initialPricingMode);
setPricingSubMode(initialPricingSubMode);
// Create a copy of the model data to avoid modifying the original
const modelCopy = { ...record };
// If the model has ratio data and we want to populate token price fields
if (record.ratio) {
modelCopy.tokenPrice = calculateTokenPriceFromRatio(parseFloat(record.ratio)).toString();
if (record.completionRatio) {
modelCopy.completionTokenPrice = (parseFloat(modelCopy.tokenPrice) * parseFloat(record.completionRatio)).toString();
}
}
// Set the current model
setCurrentModel(modelCopy);
// Open the modal
setVisible(true);
// Use setTimeout to ensure the form is rendered before setting values
setTimeout(() => {
if (formRef.current) {
// Update the form fields based on pricing mode
const formValues = {
name: modelCopy.name,
};
if (initialPricingMode === 'per-request') {
formValues.priceInput = modelCopy.price;
} else if (initialPricingMode === 'per-token') {
formValues.ratioInput = modelCopy.ratio;
formValues.completionRatioInput = modelCopy.completionRatio;
formValues.modelTokenPrice = modelCopy.tokenPrice;
formValues.completionTokenPrice = modelCopy.completionTokenPrice;
}
formRef.current.setValues(formValues);
}
}, 0);
};
return (
<>
<Space vertical align="start" style={{ width: '100%' }}>
<Space>
<Button icon={<IconPlus />} onClick={() => setVisible(true)}>
<Button icon={<IconPlus />} onClick={() => {
resetModalState();
setVisible(true);
}}>
{t('添加模型')}
</Button>
<Button type="primary" icon={<IconSave />} onClick={SubmitData}>
@@ -256,56 +413,205 @@ export default function ModelSettingsVisualEditor(props) {
</Space>
<Modal
title={t('添加模型')}
title={currentModel && currentModel.name && models.some(model => model.name === currentModel.name) ? t('编辑模型') : t('添加模型')}
visible={visible}
onCancel={() => setVisible(false)}
onCancel={() => {
resetModalState();
setVisible(false);
}}
onOk={() => {
currentModel && addModel(currentModel);
if (currentModel) {
// If we're in token price mode, make sure ratio values are properly set
const valuesToSave = { ...currentModel };
if (pricingMode === 'per-token' && pricingSubMode === 'token-price' && currentModel.tokenPrice) {
// Calculate and set ratio from token price
const tokenPrice = parseFloat(currentModel.tokenPrice);
valuesToSave.ratio = (tokenPrice / 2).toString();
// Calculate and set completion ratio if both token prices are available
if (currentModel.completionTokenPrice && currentModel.tokenPrice) {
const completionPrice = parseFloat(currentModel.completionTokenPrice);
const modelPrice = parseFloat(currentModel.tokenPrice);
if (modelPrice > 0) {
valuesToSave.completionRatio = (completionPrice / modelPrice).toString();
}
}
}
// Clear price if we're in per-token mode
if (pricingMode === 'per-token') {
valuesToSave.price = '';
} else {
// Clear ratios if we're in per-request mode
valuesToSave.ratio = '';
valuesToSave.completionRatio = '';
}
addOrUpdateModel(valuesToSave);
}
}}
>
<Form>
<Form getFormApi={api => formRef.current = api}>
<Form.Input
field="name"
label={t('模型名称')}
placeholder="strawberry"
required
disabled={currentModel && currentModel.name && models.some(model => model.name === currentModel.name)}
onChange={value => setCurrentModel(prev => ({ ...prev, name: value }))}
/>
<Form.Switch
field="priceMode"
label={<>{t('定价模式')}{currentModel?.priceMode ? t("固定价格") : t("倍率模式")}</>}
onChange={checked => {
setCurrentModel(prev => ({
...prev,
price: '',
ratio: '',
completionRatio: '',
priceMode: checked
}));
}}
/>
{currentModel?.priceMode ? (
<Form.Section text={t('定价模式')}>
<div style={{ marginBottom: '16px' }}>
<RadioGroup type="button" value={pricingMode} onChange={(e) => {
const newMode = e.target.value;
const oldMode = pricingMode;
setPricingMode(newMode);
// Instead of resetting all values, convert between modes
if (currentModel) {
const updatedModel = { ...currentModel };
// Update formRef with converted values
if (formRef.current) {
const formValues = {
name: updatedModel.name
};
if (newMode === 'per-request') {
formValues.priceInput = updatedModel.price || '';
} else if (newMode === 'per-token') {
formValues.ratioInput = updatedModel.ratio || '';
formValues.completionRatioInput = updatedModel.completionRatio || '';
formValues.modelTokenPrice = updatedModel.tokenPrice || '';
formValues.completionTokenPrice = updatedModel.completionTokenPrice || '';
}
formRef.current.setValues(formValues);
}
// Update the model state
setCurrentModel(updatedModel);
}
}}>
<Radio value="per-token">{t('按量计费')}</Radio>
<Radio value="per-request">{t('按次计费')}</Radio>
</RadioGroup>
</div>
</Form.Section>
{pricingMode === 'per-token' && (
<>
<Form.Section text={t('价格设置方式')}>
<div style={{ marginBottom: '16px' }}>
<RadioGroup type="button" value={pricingSubMode} onChange={(e) => {
const newSubMode = e.target.value;
const oldSubMode = pricingSubMode;
setPricingSubMode(newSubMode);
// Handle conversion between submodes
if (currentModel) {
const updatedModel = { ...currentModel };
// Convert between ratio and token price
if (oldSubMode === 'ratio' && newSubMode === 'token-price') {
if (updatedModel.ratio) {
updatedModel.tokenPrice = calculateTokenPriceFromRatio(parseFloat(updatedModel.ratio)).toString();
if (updatedModel.completionRatio) {
updatedModel.completionTokenPrice = (parseFloat(updatedModel.tokenPrice) * parseFloat(updatedModel.completionRatio)).toString();
}
}
} else if (oldSubMode === 'token-price' && newSubMode === 'ratio') {
// Ratio values should already be calculated by the handlers
}
// Update the form values
if (formRef.current) {
const formValues = {};
if (newSubMode === 'ratio') {
formValues.ratioInput = updatedModel.ratio || '';
formValues.completionRatioInput = updatedModel.completionRatio || '';
} else if (newSubMode === 'token-price') {
formValues.modelTokenPrice = updatedModel.tokenPrice || '';
formValues.completionTokenPrice = updatedModel.completionTokenPrice || '';
}
formRef.current.setValues(formValues);
}
setCurrentModel(updatedModel);
}
}}>
<Radio value="ratio">{t('按倍率设置')}</Radio>
<Radio value="token-price">{t('按价格设置')}</Radio>
</RadioGroup>
</div>
</Form.Section>
{pricingSubMode === 'ratio' && (
<>
<Form.Input
field="ratioInput"
label={t('模型倍率')}
placeholder={t('输入模型倍率')}
onChange={value => setCurrentModel(prev => ({
...prev || {},
ratio: value
}))}
initValue={currentModel?.ratio || ''}
/>
<Form.Input
field="completionRatioInput"
label={t('补全倍率')}
placeholder={t('输入补全倍率')}
onChange={value => setCurrentModel(prev => ({
...prev || {},
completionRatio: value
}))}
initValue={currentModel?.completionRatio || ''}
/>
</>
)}
{pricingSubMode === 'token-price' && (
<>
<Form.Input
field="modelTokenPrice"
label={t('输入价格')}
onChange={(value) => {
handleTokenPriceChange(value);
}}
initValue={currentModel?.tokenPrice || ''}
suffix={t('$/1M tokens')}
/>
<Form.Input
field="completionTokenPrice"
label={t('输出价格')}
onChange={(value) => {
handleCompletionTokenPriceChange(value);
}}
initValue={currentModel?.completionTokenPrice || ''}
suffix={t('$/1M tokens')}
/>
</>
)}
</>
)}
{pricingMode === 'per-request' && (
<Form.Input
field="price"
field="priceInput"
label={t('固定价格(每次)')}
placeholder={t('输入每次价格')}
onChange={value => setCurrentModel(prev => ({ ...prev, price: value }))}
onChange={value => setCurrentModel(prev => ({
...prev || {},
price: value
}))}
initValue={currentModel?.price || ''}
/>
) : (
<>
<Form.Input
field="ratio"
label={t('模型倍率')}
placeholder={t('输入模型倍率')}
onChange={value => setCurrentModel(prev => ({ ...prev, ratio: value }))}
/>
<Form.Input
field="completionRatio"
label={t('补全倍率')}
placeholder={t('输入补全价格')}
onChange={value => setCurrentModel(prev => ({ ...prev, completionRatio: value }))}
/>
</>
)}
</Form>
</Modal>

View File

@@ -27,9 +27,10 @@ export default function GeneralSettings(props) {
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
function onChange(value, e) {
const name = e.target.id;
setInputs((inputs) => ({ ...inputs, [name]: value }));
function handleFieldChange(fieldName) {
return (value) => {
setInputs((inputs) => ({ ...inputs, [fieldName]: value }));
};
}
function onSubmit() {
@@ -98,7 +99,7 @@ export default function GeneralSettings(props) {
label={t('充值链接')}
initValue={''}
placeholder={t('例如发卡网站的购买链接')}
onChange={onChange}
onChange={handleFieldChange('TopUpLink')}
showClear
/>
</Col>
@@ -108,7 +109,7 @@ export default function GeneralSettings(props) {
label={t('文档地址')}
initValue={''}
placeholder={t('例如 https://docs.newapi.pro')}
onChange={onChange}
onChange={handleFieldChange('general_setting.docs_link')}
showClear
/>
</Col>
@@ -118,7 +119,7 @@ export default function GeneralSettings(props) {
label={t('单位美元额度')}
initValue={''}
placeholder={t('一单位货币能兑换的额度')}
onChange={onChange}
onChange={handleFieldChange('QuotaPerUnit')}
showClear
onClick={() => setShowQuotaWarning(true)}
/>
@@ -129,7 +130,7 @@ export default function GeneralSettings(props) {
label={t('失败重试次数')}
initValue={''}
placeholder={t('失败重试次数')}
onChange={onChange}
onChange={handleFieldChange('RetryTimes')}
showClear
/>
</Col>
@@ -142,12 +143,7 @@ export default function GeneralSettings(props) {
size='default'
checkedText=''
uncheckedText=''
onChange={(value) => {
setInputs({
...inputs,
DisplayInCurrencyEnabled: value,
});
}}
onChange={handleFieldChange('DisplayInCurrencyEnabled')}
/>
</Col>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
@@ -157,12 +153,7 @@ export default function GeneralSettings(props) {
size='default'
checkedText=''
uncheckedText=''
onChange={(value) =>
setInputs({
...inputs,
DisplayTokenStatEnabled: value,
})
}
onChange={handleFieldChange('DisplayTokenStatEnabled')}
/>
</Col>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
@@ -172,12 +163,7 @@ export default function GeneralSettings(props) {
size='default'
checkedText=''
uncheckedText=''
onChange={(value) =>
setInputs({
...inputs,
DefaultCollapseSidebar: value,
})
}
onChange={handleFieldChange('DefaultCollapseSidebar')}
/>
</Col>
</Row>
@@ -189,12 +175,7 @@ export default function GeneralSettings(props) {
size='default'
checkedText=''
uncheckedText=''
onChange={(value) =>
setInputs({
...inputs,
DemoSiteEnabled: value
})
}
onChange={handleFieldChange('DemoSiteEnabled')}
/>
</Col>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
@@ -205,12 +186,7 @@ export default function GeneralSettings(props) {
size='default'
checkedText=''
uncheckedText=''
onChange={(value) =>
setInputs({
...inputs,
SelfUseModeEnabled: value
})
}
onChange={handleFieldChange('SelfUseModeEnabled')}
/>
</Col>
</Row>

View File

@@ -0,0 +1,252 @@
import React, { useContext, useEffect, useState, useRef } from 'react';
import { Card, Col, Row, Form, Button, Typography, Space, RadioGroup, Radio, Modal, Banner } from '@douyinfe/semi-ui';
import { API, showError, showNotice, timestamp2string } from '../../helpers';
import { StatusContext } from '../../context/Status';
import { marked } from 'marked';
import { StyleContext } from '../../context/Style/index.js';
import { useTranslation } from 'react-i18next';
import { IconHelpCircle, IconInfoCircle, IconAlertTriangle } from '@douyinfe/semi-icons';
const Setup = () => {
const { t, i18n } = useTranslation();
const [statusState] = useContext(StatusContext);
const [styleState, styleDispatch] = useContext(StyleContext);
const [loading, setLoading] = useState(false);
const [selfUseModeInfoVisible, setUsageModeInfoVisible] = useState(false);
const [setupStatus, setSetupStatus] = useState({
status: false,
root_init: false,
database_type: ''
});
const { Text, Title } = Typography;
const formRef = useRef(null);
const [formData, setFormData] = useState({
username: '',
password: '',
confirmPassword: '',
usageMode: 'external'
});
useEffect(() => {
fetchSetupStatus();
}, []);
const fetchSetupStatus = async () => {
try {
const res = await API.get('/api/setup');
const { success, data } = res.data;
if (success) {
setSetupStatus(data);
// If setup is already completed, redirect to home
if (data.status) {
window.location.href = '/';
}
} else {
showError(t('获取初始化状态失败'));
}
} catch (error) {
console.error('Failed to fetch setup status:', error);
showError(t('获取初始化状态失败'));
}
};
const handleUsageModeChange = (val) => {
setFormData({...formData, usageMode: val});
};
const onSubmit = () => {
if (!formRef.current) {
console.error("Form reference is null");
showError(t('表单引用错误,请刷新页面重试'));
return;
}
const values = formRef.current.getValues();
console.log("Form values:", values);
// For root_init=false, validate admin username and password
if (!setupStatus.root_init) {
if (!values.username || !values.username.trim()) {
showError(t('请输入管理员用户名'));
return;
}
if (!values.password || values.password.length < 8) {
showError(t('密码长度至少为8个字符'));
return;
}
if (values.password !== values.confirmPassword) {
showError(t('两次输入的密码不一致'));
return;
}
}
// Prepare submission data
const formValues = {...values};
formValues.SelfUseModeEnabled = values.usageMode === 'self';
formValues.DemoSiteEnabled = values.usageMode === 'demo';
// Remove usageMode as it's not needed by the backend
delete formValues.usageMode;
console.log("Submitting data to backend:", formValues);
setLoading(true);
// Submit to backend
API.post('/api/setup', formValues)
.then(res => {
const { success, message } = res.data;
console.log("API response:", res.data);
if (success) {
showNotice(t('系统初始化成功,正在跳转...'));
setTimeout(() => {
window.location.reload();
}, 1500);
} else {
showError(message || t('初始化失败,请重试'));
}
})
.catch(error => {
console.error('API error:', error);
showError(t('系统初始化失败,请重试'));
setLoading(false);
})
.finally(() => {
// setLoading(false);
});
};
return (
<>
<div style={{ maxWidth: '800px', margin: '0 auto', padding: '20px' }}>
<Card>
<Title heading={2} style={{ marginBottom: '24px' }}>{t('系统初始化')}</Title>
{setupStatus.database_type === 'sqlite' && (
<Banner
type="warning"
icon={<IconAlertTriangle size="large" />}
closeIcon={null}
title={t('数据库警告')}
description={
<div>
<p>{t('您正在使用 SQLite 数据库。如果您在容器环境中运行,请确保已正确设置数据库文件的持久化映射,否则容器重启后所有数据将丢失!')}</p>
<p>{t('建议在生产环境中使用 MySQL 或 PostgreSQL 数据库,或确保 SQLite 数据库文件已映射到宿主机的持久化存储。')}</p>
</div>
}
style={{ marginBottom: '24px' }}
/>
)}
<Form
getFormApi={(formApi) => { formRef.current = formApi; console.log("Form API set:", formApi); }}
initValues={formData}
>
{setupStatus.root_init ? (
<Banner
type="info"
icon={<IconInfoCircle />}
closeIcon={null}
description={t('管理员账号已经初始化过,请继续设置系统参数')}
style={{ marginBottom: '24px' }}
/>
) : (
<Form.Section text={t('管理员账号')}>
<Form.Input
field="username"
label={t('用户名')}
placeholder={t('请输入管理员用户名')}
showClear
onChange={(value) => setFormData({...formData, username: value})}
/>
<Form.Input
field="password"
label={t('密码')}
placeholder={t('请输入管理员密码')}
type="password"
showClear
onChange={(value) => setFormData({...formData, password: value})}
/>
<Form.Input
field="confirmPassword"
label={t('确认密码')}
placeholder={t('请确认管理员密码')}
type="password"
showClear
onChange={(value) => setFormData({...formData, confirmPassword: value})}
/>
</Form.Section>
)}
<Form.Section text={
<div style={{ display: 'flex', alignItems: 'center' }}>
{t('系统设置')}
</div>
}>
<Form.RadioGroup
field="usageMode"
label={
<div style={{ display: 'flex', alignItems: 'center' }}>
{t('使用模式')}
<IconHelpCircle
style={{ marginLeft: '4px', color: 'var(--semi-color-primary)', verticalAlign: 'middle', cursor: 'pointer' }}
onClick={(e) => {
// e.preventDefault();
// e.stopPropagation();
setUsageModeInfoVisible(true);
}}
/>
</div>
}
extraText={t('可在初始化后修改')}
initValue="external"
onChange={handleUsageModeChange}
>
<Form.Radio value="external">{t('对外运营模式')}</Form.Radio>
<Form.Radio value="self">{t('自用模式')}</Form.Radio>
<Form.Radio value="demo">{t('演示站点模式')}</Form.Radio>
</Form.RadioGroup>
</Form.Section>
</Form>
<div style={{ marginTop: '24px', textAlign: 'right' }}>
<Button type="primary" onClick={onSubmit} loading={loading}>
{t('初始化系统')}
</Button>
</div>
</Card>
</div>
<Modal
title={t('使用模式说明')}
visible={selfUseModeInfoVisible}
onOk={() => setUsageModeInfoVisible(false)}
onCancel={() => setUsageModeInfoVisible(false)}
closeOnEsc={true}
okText={t('确定')}
cancelText={null}
>
<div style={{ padding: '8px 0' }}>
<Title heading={6}>{t('对外运营模式')}</Title>
<p>{t('默认模式,适用于为多个用户提供服务的场景。')}</p>
<p>{t('此模式下,系统将计算每次调用的用量,您需要对每个模型都设置价格,如果没有设置价格,用户将无法使用该模型。')}</p>
</div>
<div style={{ padding: '8px 0' }}>
<Title heading={6}>{t('自用模式')}</Title>
<p>{t('适用于个人使用的场景。')}</p>
<p>{t('不需要设置模型价格,系统将弱化用量计算,您可专注于使用模型。')}</p>
</div>
<div style={{ padding: '8px 0' }}>
<Title heading={6}>{t('演示站点模式')}</Title>
<p>{t('适用于展示系统功能的场景。')}</p>
</div>
</Modal>
</>
);
};
export default Setup;