mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-10 19:57:27 +00:00
Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3003d12a20 | ||
|
|
8129aa76f9 | ||
|
|
fb8595da18 | ||
|
|
93cda60d44 | ||
|
|
2ec5eafbce | ||
|
|
be0c240e97 | ||
|
|
7180e6f114 | ||
|
|
61495a460a | ||
|
|
cf3287a10a | ||
|
|
f3f1817aea | ||
|
|
a4795737fe | ||
|
|
eec8f523ce | ||
|
|
58fac129d6 | ||
|
|
241c9389ef | ||
|
|
1d0ef89ce9 | ||
|
|
cce2990db6 | ||
|
|
a7e1d17c3e | ||
|
|
c4e256e69b | ||
|
|
87a5e40daf | ||
|
|
0c326556aa | ||
|
|
794f6a6e34 | ||
|
|
656e809202 | ||
|
|
53ab2aaee4 | ||
|
|
a02bc3342f | ||
|
|
f54d0cb3b0 | ||
|
|
a5c48c2772 | ||
|
|
cffaf0d636 | ||
|
|
865b98a454 | ||
|
|
5bdbf3a673 | ||
|
|
43a7b59b68 | ||
|
|
eac3463401 | ||
|
|
1fa478af20 | ||
|
|
3b58b4989d | ||
|
|
0b1ba2eeb9 | ||
|
|
35277f2b4a | ||
|
|
03256dbdad | ||
|
|
f9a7f6085e | ||
|
|
c289694fe8 | ||
|
|
c37319cb4c | ||
|
|
57f036747b | ||
|
|
158edab974 | ||
|
|
f3a7119745 | ||
|
|
3fa1d93b9c | ||
|
|
5e3ba45748 | ||
|
|
cab30014ad | ||
|
|
e3a0162206 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -8,3 +8,4 @@ build
|
||||
logs
|
||||
web/dist
|
||||
.env
|
||||
one-api
|
||||
@@ -81,6 +81,7 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
|
||||
- `UPDATE_TASK`: Update async tasks (Midjourney, Suno), default `true`
|
||||
- `GEMINI_MODEL_MAP`: Specify Gemini model versions (v1/v1beta), format: "model:version", comma-separated
|
||||
- `COHERE_SAFETY_SETTING`: Cohere model [safety settings](https://docs.cohere.com/docs/safety-modes#overview), options: `NONE`, `CONTEXTUAL`, `STRICT`, default `NONE`
|
||||
- `GEMINI_VISION_MAX_IMAGE_NUM`: Gemini model maximum image number, default `16`, set to `-1` to disable
|
||||
|
||||
## Deployment
|
||||
> [!TIP]
|
||||
|
||||
@@ -87,6 +87,7 @@
|
||||
- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。
|
||||
- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta)
|
||||
- `COHERE_SAFETY_SETTING`:Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`,`STRICT`,默认为 `NONE`。
|
||||
- `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认为 `16`,设置为 `-1` 则不限制。
|
||||
## 部署
|
||||
> [!TIP]
|
||||
> 最新版Docker镜像:`calciumion/new-api:latest`
|
||||
|
||||
@@ -46,6 +46,8 @@ var defaultModelRatio = map[string]float64{
|
||||
"gpt-4o-2024-08-06": 1.25, // $2.5 / 1M tokens
|
||||
"gpt-4o-2024-11-20": 1.25, // $2.5 / 1M tokens
|
||||
"gpt-4o-realtime-preview": 2.5,
|
||||
"o1": 7.5,
|
||||
"o1-2024-12-17": 7.5,
|
||||
"o1-preview": 7.5,
|
||||
"o1-preview-2024-09-12": 7.5,
|
||||
"o1-mini": 1.5,
|
||||
|
||||
@@ -35,7 +35,9 @@ func StrToMap(str string) map[string]interface{} {
|
||||
m := make(map[string]interface{})
|
||||
err := json.Unmarshal([]byte(str), &m)
|
||||
if err != nil {
|
||||
return nil
|
||||
return map[string]interface{}{
|
||||
"result": str,
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
crand "crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"html/template"
|
||||
"log"
|
||||
"math/big"
|
||||
@@ -15,6 +14,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func OpenBrowser(url string) {
|
||||
|
||||
5
constant/context_key.go
Normal file
5
constant/context_key.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package constant
|
||||
|
||||
const (
|
||||
ContextKeyRequestStartTime = "request_start_time"
|
||||
)
|
||||
@@ -23,6 +23,8 @@ var GeminiModelMap = map[string]string{
|
||||
"gemini-1.0-pro": "v1",
|
||||
}
|
||||
|
||||
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
|
||||
func InitEnv() {
|
||||
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
||||
if modelVersionMapStr == "" {
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
package constant
|
||||
|
||||
var MjNotifyEnabled = false
|
||||
var MjAccountFilterEnabled = false
|
||||
var MjModeClearEnabled = false
|
||||
var MjForwardUrlEnabled = true
|
||||
var MjActionCheckSuccessEnabled = true
|
||||
|
||||
const (
|
||||
MjErrorUnknown = 5
|
||||
MjRequestError = 4
|
||||
|
||||
@@ -141,7 +141,8 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice)
|
||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试", quota, "模型测试", 0, quota, int(consumedTime), false, other)
|
||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试",
|
||||
quota, "模型测试", 0, quota, int(consumedTime), false, "default", other)
|
||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -97,6 +97,7 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -105,34 +106,35 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if channel.Type != common.ChannelTypeOpenAI {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "仅支持 OpenAI 类型渠道",
|
||||
})
|
||||
return
|
||||
|
||||
//if channel.Type != common.ChannelTypeOpenAI {
|
||||
// c.JSON(http.StatusOK, gin.H{
|
||||
// "success": false,
|
||||
// "message": "仅支持 OpenAI 类型渠道",
|
||||
// })
|
||||
// return
|
||||
//}
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
url := fmt.Sprintf("%s/v1/models", *channel.BaseURL)
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
result := OpenAIModelsResponse{}
|
||||
err = json.Unmarshal(body, &result)
|
||||
if err != nil {
|
||||
|
||||
var result OpenAIModelsResponse
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
}
|
||||
if !result.Success {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "上游返回错误",
|
||||
"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var ids []string
|
||||
@@ -492,3 +494,79 @@ func UpdateChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func FetchModels(c *gin.Context) {
|
||||
var req struct {
|
||||
BaseURL string `json:"base_url"`
|
||||
Key string `json:"key"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "Invalid request",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := req.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
|
||||
request, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
request.Header.Set("Authorization", "Bearer "+req.Key)
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
//check status code
|
||||
if response.StatusCode != http.StatusOK {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": "Failed to fetch models",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
var result struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var models []string
|
||||
for _, model := range result.Data {
|
||||
models = append(models, model.ID)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": models,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -25,7 +25,8 @@ func GetAllLogs(c *gin.Context) {
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel)
|
||||
group := c.Query("group")
|
||||
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel, group)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -63,7 +64,8 @@ func GetUserLogs(c *gin.Context) {
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize)
|
||||
group := c.Query("group")
|
||||
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize, group)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -146,7 +148,8 @@ func GetLogsStat(c *gin.Context) {
|
||||
username := c.Query("username")
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
|
||||
group := c.Query("group")
|
||||
stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
|
||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
@@ -168,7 +171,8 @@ func GetLogsSelfStat(c *gin.Context) {
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
|
||||
group := c.Query("group")
|
||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
|
||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
|
||||
@@ -10,10 +10,10 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
@@ -231,9 +231,9 @@ func GetAllMidjourney(c *gin.Context) {
|
||||
if logs == nil {
|
||||
logs = make([]*model.Midjourney, 0)
|
||||
}
|
||||
if constant.MjForwardUrlEnabled {
|
||||
if setting.MjForwardUrlEnabled {
|
||||
for i, midjourney := range logs {
|
||||
midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
logs[i] = midjourney
|
||||
}
|
||||
}
|
||||
@@ -263,9 +263,9 @@ func GetUserMidjourney(c *gin.Context) {
|
||||
if logs == nil {
|
||||
logs = make([]*model.Midjourney, 0)
|
||||
}
|
||||
if constant.MjForwardUrlEnabled {
|
||||
if setting.MjForwardUrlEnabled {
|
||||
for i, midjourney := range logs {
|
||||
midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
logs[i] = midjourney
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -47,9 +47,9 @@ func GetStatus(c *gin.Context) {
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": constant.ServerAddress,
|
||||
"price": constant.Price,
|
||||
"min_topup": constant.MinTopUp,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
@@ -63,9 +63,9 @@ func GetStatus(c *gin.Context) {
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": constant.PayAddress != "" && constant.EpayId != "" && constant.EpayKey != "",
|
||||
"mj_notify_enabled": constant.MjNotifyEnabled,
|
||||
"chats": constant.Chats,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
},
|
||||
})
|
||||
return
|
||||
@@ -207,7 +207,7 @@ func SendPasswordResetEmail(c *gin.Context) {
|
||||
}
|
||||
code := common.GenerateVerificationCode(0)
|
||||
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", constant.ServerAddress, email, code)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", setting.ServerAddress, email, code)
|
||||
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
||||
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
||||
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
||||
|
||||
@@ -8,10 +8,35 @@ import (
|
||||
|
||||
func GetPricing(c *gin.Context) {
|
||||
pricing := model.GetPricing()
|
||||
userId, exists := c.Get("id")
|
||||
usableGroup := map[string]string{}
|
||||
groupRatio := common.GroupRatio
|
||||
var group string
|
||||
if exists {
|
||||
user, err := model.GetChannelById(userId.(int), false)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
group = user.Group
|
||||
}
|
||||
|
||||
usableGroup = common.GetUserUsableGroups(group)
|
||||
// check groupRatio contains usableGroup
|
||||
for group := range common.GroupRatio {
|
||||
if _, ok := usableGroup[group]; !ok {
|
||||
delete(groupRatio, group)
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": pricing,
|
||||
"group_ratio": common.GroupRatio,
|
||||
"success": true,
|
||||
"data": pricing,
|
||||
"group_ratio": groupRatio,
|
||||
"usable_group": usableGroup,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -8,9 +8,9 @@ import (
|
||||
"log"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -28,13 +28,13 @@ type AmountRequest struct {
|
||||
}
|
||||
|
||||
func GetEpayClient() *epay.Client {
|
||||
if constant.PayAddress == "" || constant.EpayId == "" || constant.EpayKey == "" {
|
||||
if setting.PayAddress == "" || setting.EpayId == "" || setting.EpayKey == "" {
|
||||
return nil
|
||||
}
|
||||
withUrl, err := epay.NewClient(&epay.Config{
|
||||
PartnerID: constant.EpayId,
|
||||
Key: constant.EpayKey,
|
||||
}, constant.PayAddress)
|
||||
PartnerID: setting.EpayId,
|
||||
Key: setting.EpayKey,
|
||||
}, setting.PayAddress)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -50,12 +50,12 @@ func getPayMoney(amount float64, group string) float64 {
|
||||
if topupGroupRatio == 0 {
|
||||
topupGroupRatio = 1
|
||||
}
|
||||
payMoney := amount * constant.Price * topupGroupRatio
|
||||
payMoney := amount * setting.Price * topupGroupRatio
|
||||
return payMoney
|
||||
}
|
||||
|
||||
func getMinTopup() int {
|
||||
minTopup := constant.MinTopUp
|
||||
minTopup := setting.MinTopUp
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
minTopup = minTopup * int(common.QuotaPerUnit)
|
||||
}
|
||||
@@ -94,7 +94,7 @@ func RequestEpay(c *gin.Context) {
|
||||
payType = "wxpay"
|
||||
}
|
||||
callBackAddress := service.GetCallbackAddress()
|
||||
returnUrl, _ := url.Parse(constant.ServerAddress + "/log")
|
||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/log")
|
||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
||||
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
|
||||
|
||||
@@ -3,39 +3,48 @@ package dto
|
||||
import "encoding/json"
|
||||
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"`
|
||||
}
|
||||
|
||||
type FormatJsonSchema struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Schema any `json:"schema,omitempty"`
|
||||
Strict any `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
type GeneralOpenAIRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat any `json:"response_format,omitempty"`
|
||||
EncodingFormat any `json:"encoding_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Tools []ToolCall `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
LogProbs bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Modalities any `json:"modalities,omitempty"`
|
||||
Audio any `json:"audio,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
EncodingFormat any `json:"encoding_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Tools []ToolCall `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
LogProbs bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Modalities any `json:"modalities,omitempty"`
|
||||
Audio any `json:"audio,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAITools struct {
|
||||
@@ -80,11 +89,11 @@ type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
ToolCalls any `json:"tool_calls,omitempty"`
|
||||
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
|
||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
type MediaMessage struct {
|
||||
type MediaContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
ImageUrl any `json:"image_url,omitempty"`
|
||||
@@ -107,7 +116,23 @@ const (
|
||||
ContentTypeInputAudio = "input_audio"
|
||||
)
|
||||
|
||||
func (m Message) StringContent() string {
|
||||
func (m *Message) ParseToolCalls() []ToolCall {
|
||||
if m.ToolCalls == nil {
|
||||
return nil
|
||||
}
|
||||
var toolCalls []ToolCall
|
||||
if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil {
|
||||
return toolCalls
|
||||
}
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func (m *Message) SetToolCalls(toolCalls any) {
|
||||
toolCallsJson, _ := json.Marshal(toolCalls)
|
||||
m.ToolCalls = toolCallsJson
|
||||
}
|
||||
|
||||
func (m *Message) StringContent() string {
|
||||
var stringContent string
|
||||
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
||||
return stringContent
|
||||
@@ -120,7 +145,7 @@ func (m *Message) SetStringContent(content string) {
|
||||
m.Content = jsonContent
|
||||
}
|
||||
|
||||
func (m Message) IsStringContent() bool {
|
||||
func (m *Message) IsStringContent() bool {
|
||||
var stringContent string
|
||||
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
||||
return true
|
||||
@@ -128,11 +153,11 @@ func (m Message) IsStringContent() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m Message) ParseContent() []MediaMessage {
|
||||
var contentList []MediaMessage
|
||||
func (m *Message) ParseContent() []MediaContent {
|
||||
var contentList []MediaContent
|
||||
var stringContent string
|
||||
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
||||
contentList = append(contentList, MediaMessage{
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeText,
|
||||
Text: stringContent,
|
||||
})
|
||||
@@ -148,7 +173,7 @@ func (m Message) ParseContent() []MediaMessage {
|
||||
switch contentMap["type"] {
|
||||
case ContentTypeText:
|
||||
if subStr, ok := contentMap["text"].(string); ok {
|
||||
contentList = append(contentList, MediaMessage{
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeText,
|
||||
Text: subStr,
|
||||
})
|
||||
@@ -161,7 +186,7 @@ func (m Message) ParseContent() []MediaMessage {
|
||||
} else {
|
||||
subObj["detail"] = "high"
|
||||
}
|
||||
contentList = append(contentList, MediaMessage{
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeImageURL,
|
||||
ImageUrl: MessageImageUrl{
|
||||
Url: subObj["url"].(string),
|
||||
@@ -169,7 +194,7 @@ func (m Message) ParseContent() []MediaMessage {
|
||||
},
|
||||
})
|
||||
} else if url, ok := contentMap["image_url"].(string); ok {
|
||||
contentList = append(contentList, MediaMessage{
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeImageURL,
|
||||
ImageUrl: MessageImageUrl{
|
||||
Url: url,
|
||||
@@ -179,7 +204,7 @@ func (m Message) ParseContent() []MediaMessage {
|
||||
}
|
||||
case ContentTypeInputAudio:
|
||||
if subObj, ok := contentMap["input_audio"].(map[string]any); ok {
|
||||
contentList = append(contentList, MediaMessage{
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeInputAudio,
|
||||
InputAudio: MessageInputAudio{
|
||||
Data: subObj["data"].(string),
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -112,6 +113,7 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
||||
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
|
||||
c.Next()
|
||||
}
|
||||
|
||||
32
model/log.go
32
model/log.go
@@ -12,6 +12,16 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var groupCol string
|
||||
|
||||
func init() {
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
} else {
|
||||
groupCol = "`group`"
|
||||
}
|
||||
}
|
||||
|
||||
type Log struct {
|
||||
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
@@ -28,6 +38,7 @@ type Log struct {
|
||||
IsStream bool `json:"is_stream" gorm:"default:false"`
|
||||
ChannelId int `json:"channel" gorm:"index"`
|
||||
TokenId int `json:"token_id" gorm:"default:0;index"`
|
||||
Group string `json:"group" gorm:"index"`
|
||||
Other string `json:"other"`
|
||||
}
|
||||
|
||||
@@ -70,7 +81,9 @@ func RecordLog(userId int, logType int, content string) {
|
||||
}
|
||||
}
|
||||
|
||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, isStream bool, other map[string]interface{}) {
|
||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int,
|
||||
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
|
||||
isStream bool, group string, other map[string]interface{}) {
|
||||
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||
if !common.LogConsumeEnabled {
|
||||
return
|
||||
@@ -92,6 +105,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
||||
TokenId: tokenId,
|
||||
UseTime: useTimeSeconds,
|
||||
IsStream: isStream,
|
||||
Group: group,
|
||||
Other: otherStr,
|
||||
}
|
||||
err := LOG_DB.Create(log).Error
|
||||
@@ -105,7 +119,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
||||
}
|
||||
}
|
||||
|
||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, total int64, err error) {
|
||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string) (logs []*Log, total int64, err error) {
|
||||
var tx *gorm.DB
|
||||
if logType == LogTypeUnknown {
|
||||
tx = LOG_DB
|
||||
@@ -130,6 +144,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
if channel != 0 {
|
||||
tx = tx.Where("channel_id = ?", channel)
|
||||
}
|
||||
if group != "" {
|
||||
tx = tx.Where(groupCol+" = ?", group)
|
||||
}
|
||||
err = tx.Model(&Log{}).Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -141,7 +158,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
return logs, total, err
|
||||
}
|
||||
|
||||
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, total int64, err error) {
|
||||
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string) (logs []*Log, total int64, err error) {
|
||||
var tx *gorm.DB
|
||||
if logType == LogTypeUnknown {
|
||||
tx = LOG_DB.Where("user_id = ?", userId)
|
||||
@@ -160,6 +177,9 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
||||
if endTimestamp != 0 {
|
||||
tx = tx.Where("created_at <= ?", endTimestamp)
|
||||
}
|
||||
if group != "" {
|
||||
tx = tx.Where(groupCol+" = ?", group)
|
||||
}
|
||||
err = tx.Model(&Log{}).Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -193,7 +213,7 @@ type Stat struct {
|
||||
Tpm int `json:"tpm"`
|
||||
}
|
||||
|
||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
|
||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat) {
|
||||
tx := LOG_DB.Table("logs").Select("sum(quota) quota")
|
||||
|
||||
// 为rpm和tpm创建单独的查询
|
||||
@@ -221,6 +241,10 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
||||
tx = tx.Where("channel_id = ?", channel)
|
||||
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
|
||||
}
|
||||
if group != "" {
|
||||
tx = tx.Where(groupCol+" = ?", group)
|
||||
rpmTpmQuery = rpmTpmQuery.Where(groupCol+" = ?", group)
|
||||
}
|
||||
|
||||
tx = tx.Where("type = ?", LogTypeConsume)
|
||||
rpmTpmQuery = rpmTpmQuery.Where("type = ?", LogTypeConsume)
|
||||
|
||||
@@ -2,7 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -61,16 +61,16 @@ func InitOptionMap() {
|
||||
common.OptionMap["SystemName"] = common.SystemName
|
||||
common.OptionMap["Logo"] = common.Logo
|
||||
common.OptionMap["ServerAddress"] = ""
|
||||
common.OptionMap["WorkerUrl"] = constant.WorkerUrl
|
||||
common.OptionMap["WorkerValidKey"] = constant.WorkerValidKey
|
||||
common.OptionMap["WorkerUrl"] = setting.WorkerUrl
|
||||
common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey
|
||||
common.OptionMap["PayAddress"] = ""
|
||||
common.OptionMap["CustomCallbackAddress"] = ""
|
||||
common.OptionMap["EpayId"] = ""
|
||||
common.OptionMap["EpayKey"] = ""
|
||||
common.OptionMap["Price"] = strconv.FormatFloat(constant.Price, 'f', -1, 64)
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(constant.MinTopUp)
|
||||
common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64)
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["Chats"] = constant.Chats2JsonString()
|
||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||
common.OptionMap["GitHubClientId"] = ""
|
||||
common.OptionMap["GitHubClientSecret"] = ""
|
||||
common.OptionMap["TelegramBotToken"] = ""
|
||||
@@ -98,17 +98,17 @@ func InitOptionMap() {
|
||||
common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
|
||||
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
|
||||
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
|
||||
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
|
||||
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
|
||||
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
|
||||
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
|
||||
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(constant.MjActionCheckSuccessEnabled)
|
||||
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
|
||||
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
|
||||
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(setting.MjNotifyEnabled)
|
||||
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(setting.MjAccountFilterEnabled)
|
||||
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(setting.MjModeClearEnabled)
|
||||
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled)
|
||||
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
|
||||
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
|
||||
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
|
||||
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
|
||||
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled)
|
||||
common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString()
|
||||
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(constant.StreamCacheQueueLength)
|
||||
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
|
||||
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
|
||||
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
|
||||
|
||||
common.OptionMapRWMutex.Unlock()
|
||||
loadOptionsFromDatabase()
|
||||
@@ -209,23 +209,23 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "DefaultCollapseSidebar":
|
||||
common.DefaultCollapseSidebar = boolValue
|
||||
case "MjNotifyEnabled":
|
||||
constant.MjNotifyEnabled = boolValue
|
||||
setting.MjNotifyEnabled = boolValue
|
||||
case "MjAccountFilterEnabled":
|
||||
constant.MjAccountFilterEnabled = boolValue
|
||||
setting.MjAccountFilterEnabled = boolValue
|
||||
case "MjModeClearEnabled":
|
||||
constant.MjModeClearEnabled = boolValue
|
||||
setting.MjModeClearEnabled = boolValue
|
||||
case "MjForwardUrlEnabled":
|
||||
constant.MjForwardUrlEnabled = boolValue
|
||||
setting.MjForwardUrlEnabled = boolValue
|
||||
case "MjActionCheckSuccessEnabled":
|
||||
constant.MjActionCheckSuccessEnabled = boolValue
|
||||
setting.MjActionCheckSuccessEnabled = boolValue
|
||||
case "CheckSensitiveEnabled":
|
||||
constant.CheckSensitiveEnabled = boolValue
|
||||
setting.CheckSensitiveEnabled = boolValue
|
||||
case "CheckSensitiveOnPromptEnabled":
|
||||
constant.CheckSensitiveOnPromptEnabled = boolValue
|
||||
setting.CheckSensitiveOnPromptEnabled = boolValue
|
||||
//case "CheckSensitiveOnCompletionEnabled":
|
||||
// constant.CheckSensitiveOnCompletionEnabled = boolValue
|
||||
case "StopOnSensitiveEnabled":
|
||||
constant.StopOnSensitiveEnabled = boolValue
|
||||
setting.StopOnSensitiveEnabled = boolValue
|
||||
case "SMTPSSLEnabled":
|
||||
common.SMTPSSLEnabled = boolValue
|
||||
}
|
||||
@@ -245,25 +245,25 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "SMTPToken":
|
||||
common.SMTPToken = value
|
||||
case "ServerAddress":
|
||||
constant.ServerAddress = value
|
||||
setting.ServerAddress = value
|
||||
case "WorkerUrl":
|
||||
constant.WorkerUrl = value
|
||||
setting.WorkerUrl = value
|
||||
case "WorkerValidKey":
|
||||
constant.WorkerValidKey = value
|
||||
setting.WorkerValidKey = value
|
||||
case "PayAddress":
|
||||
constant.PayAddress = value
|
||||
setting.PayAddress = value
|
||||
case "Chats":
|
||||
err = constant.UpdateChatsByJsonString(value)
|
||||
err = setting.UpdateChatsByJsonString(value)
|
||||
case "CustomCallbackAddress":
|
||||
constant.CustomCallbackAddress = value
|
||||
setting.CustomCallbackAddress = value
|
||||
case "EpayId":
|
||||
constant.EpayId = value
|
||||
setting.EpayId = value
|
||||
case "EpayKey":
|
||||
constant.EpayKey = value
|
||||
setting.EpayKey = value
|
||||
case "Price":
|
||||
constant.Price, _ = strconv.ParseFloat(value, 64)
|
||||
setting.Price, _ = strconv.ParseFloat(value, 64)
|
||||
case "MinTopUp":
|
||||
constant.MinTopUp, _ = strconv.Atoi(value)
|
||||
setting.MinTopUp, _ = strconv.Atoi(value)
|
||||
case "TopupGroupRatio":
|
||||
err = common.UpdateTopupGroupRatioByJSONString(value)
|
||||
case "GitHubClientId":
|
||||
@@ -331,9 +331,9 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "QuotaPerUnit":
|
||||
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
|
||||
case "SensitiveWords":
|
||||
constant.SensitiveWordsFromString(value)
|
||||
setting.SensitiveWordsFromString(value)
|
||||
case "StreamCacheQueueLength":
|
||||
constant.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
@@ -325,7 +325,7 @@ func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quot
|
||||
prompt = "您的额度已用尽"
|
||||
}
|
||||
if email != "" {
|
||||
topUpLink := fmt.Sprintf("%s/topup", constant.ServerAddress)
|
||||
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
|
||||
err = common.SendEmail(prompt, email,
|
||||
fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
|
||||
if err != nil {
|
||||
|
||||
@@ -89,26 +89,31 @@ func SearchUsers(keyword string, group string) ([]*User, error) {
|
||||
var users []*User
|
||||
var err error
|
||||
|
||||
groupCol := "`group`"
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
}
|
||||
|
||||
// 尝试将关键字转换为整数ID
|
||||
keywordInt, err := strconv.Atoi(keyword)
|
||||
if err == nil {
|
||||
// 如果转换成功,按照ID和可选的组别搜索用户
|
||||
query := DB.Unscoped().Omit("password").Where("`id` = ?", keywordInt)
|
||||
query := DB.Unscoped().Omit("password").Where("id = ?", keywordInt)
|
||||
if group != "" {
|
||||
query = query.Where("`group` = ?", group) // 使用反引号包围group
|
||||
query = query.Where(groupCol+" = ?", group) // 使用反引号包围group
|
||||
}
|
||||
err = query.Find(&users).Error
|
||||
if err != nil || len(users) > 0 {
|
||||
return users, err
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
err = nil
|
||||
|
||||
query := DB.Unscoped().Omit("password")
|
||||
likeCondition := "`username` LIKE ? OR `email` LIKE ? OR `display_name` LIKE ?"
|
||||
likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
|
||||
if group != "" {
|
||||
query = query.Where("("+likeCondition+") AND `group` = ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
||||
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
||||
} else {
|
||||
query = query.Where(likeCondition, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
@@ -240,14 +240,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
|
||||
}
|
||||
if message.ToolCalls != nil {
|
||||
for _, tc := range message.ToolCalls.([]interface{}) {
|
||||
toolCallJSON, _ := json.Marshal(tc)
|
||||
var toolCall dto.ToolCall
|
||||
err := json.Unmarshal(toolCallJSON, &toolCall)
|
||||
if err != nil {
|
||||
common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
|
||||
continue
|
||||
}
|
||||
for _, toolCall := range message.ParseToolCalls() {
|
||||
inputObj := make(map[string]any)
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
|
||||
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
|
||||
@@ -393,7 +386,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
|
||||
}
|
||||
choice.SetStringContent(responseText)
|
||||
if len(tools) > 0 {
|
||||
choice.Message.ToolCalls = tools
|
||||
choice.Message.SetToolCalls(tools)
|
||||
}
|
||||
fullTextResponse.Model = claudeResponse.Model
|
||||
choices = append(choices, choice)
|
||||
|
||||
@@ -57,7 +57,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return CovertGemini2OpenAI(*request), nil
|
||||
ai, err := CovertGemini2OpenAI(*request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ai, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
@@ -72,7 +76,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
err, usage = GeminiChatStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = GeminiChatHandler(c, resp)
|
||||
err, usage = GeminiChatHandler(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
package gemini
|
||||
|
||||
const (
|
||||
GeminiVisionMaxImageNum = 16
|
||||
)
|
||||
|
||||
var ModelList = []string{
|
||||
// stable version
|
||||
"gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b",
|
||||
// latest version
|
||||
"gemini-1.5-pro-latest", "gemini-1.5-flash-latest",
|
||||
// legacy version
|
||||
"gemini-1.5-pro-exp-0827", "gemini-1.5-flash-exp-0827",
|
||||
// exp
|
||||
"gemini-exp-1114", "gemini-exp-1121", "gemini-exp-1206",
|
||||
// flash exp
|
||||
"gemini-2.0-flash-exp",
|
||||
// thinking exp
|
||||
"gemini-2.0-flash-thinking-exp",
|
||||
"gemini-2.0-flash-thinking-exp-1219",
|
||||
}
|
||||
|
||||
|
||||
@@ -18,10 +18,16 @@ type FunctionCall struct {
|
||||
Arguments any `json:"args"`
|
||||
}
|
||||
|
||||
type FunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response any `json:"response"`
|
||||
}
|
||||
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatContent struct {
|
||||
@@ -40,12 +46,14 @@ type GeminiChatTools struct {
|
||||
}
|
||||
|
||||
type GeminiChatGenerationConfig struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
ResponseMimeType string `json:"responseMimeType,omitempty"`
|
||||
ResponseSchema any `json:"responseSchema,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatCandidate struct {
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
)
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatRequest {
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatRequest, error) {
|
||||
geminiRequest := GeminiChatRequest{
|
||||
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
SafetySettings: []GeminiChatSafetySettings{
|
||||
@@ -56,6 +56,16 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
|
||||
googleSearch = true
|
||||
continue
|
||||
}
|
||||
if tool.Function.Parameters != nil {
|
||||
params, ok := tool.Function.Parameters.(map[string]interface{})
|
||||
if ok {
|
||||
if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
|
||||
if len(props) == 0 {
|
||||
tool.Function.Parameters = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
functions = append(functions, tool.Function)
|
||||
}
|
||||
if len(functions) > 0 {
|
||||
@@ -77,6 +87,15 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
|
||||
},
|
||||
}
|
||||
}
|
||||
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
|
||||
geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
|
||||
|
||||
if textRequest.ResponseFormat.JsonSchema != nil && textRequest.ResponseFormat.JsonSchema.Schema != nil {
|
||||
cleanedSchema := removeAdditionalPropertiesWithDepth(textRequest.ResponseFormat.JsonSchema.Schema, 0)
|
||||
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
|
||||
}
|
||||
}
|
||||
tool_call_ids := make(map[string]string)
|
||||
//shouldAddDummyModelMessage := false
|
||||
for _, message := range textRequest.Messages {
|
||||
|
||||
@@ -89,79 +108,136 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
|
||||
},
|
||||
}
|
||||
continue
|
||||
} else if message.Role == "tool" {
|
||||
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role != "user" {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
||||
Role: "user",
|
||||
})
|
||||
}
|
||||
var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
|
||||
name := ""
|
||||
if message.Name != nil {
|
||||
name = *message.Name
|
||||
} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
|
||||
name = val
|
||||
}
|
||||
functionResp := &FunctionResponse{
|
||||
Name: name,
|
||||
Response: common.StrToMap(message.StringContent()),
|
||||
}
|
||||
*parts = append(*parts, GeminiPart{
|
||||
FunctionResponse: functionResp,
|
||||
})
|
||||
continue
|
||||
}
|
||||
var parts []GeminiPart
|
||||
content := GeminiChatContent{
|
||||
Role: message.Role,
|
||||
//Parts: []GeminiPart{
|
||||
// {
|
||||
// Text: message.StringContent(),
|
||||
// },
|
||||
//},
|
||||
}
|
||||
openaiContent := message.ParseContent()
|
||||
var parts []GeminiPart
|
||||
imageNum := 0
|
||||
for _, part := range openaiContent {
|
||||
if part.Type == dto.ContentTypeText {
|
||||
parts = append(parts, GeminiPart{
|
||||
Text: part.Text,
|
||||
})
|
||||
} else if part.Type == dto.ContentTypeImageURL {
|
||||
imageNum += 1
|
||||
//if imageNum > GeminiVisionMaxImageNum {
|
||||
// continue
|
||||
//}
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
|
||||
// 是url,获取图片的类型和base64编码的数据
|
||||
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
isToolCall := false
|
||||
if message.ToolCalls != nil {
|
||||
message.Role = "model"
|
||||
isToolCall = true
|
||||
for _, call := range message.ParseToolCalls() {
|
||||
toolCall := GeminiPart{
|
||||
FunctionCall: &FunctionCall{
|
||||
FunctionName: call.Function.Name,
|
||||
Arguments: call.Function.Parameters,
|
||||
},
|
||||
}
|
||||
parts = append(parts, toolCall)
|
||||
tool_call_ids[call.ID] = call.Function.Name
|
||||
}
|
||||
}
|
||||
if !isToolCall {
|
||||
openaiContent := message.ParseContent()
|
||||
imageNum := 0
|
||||
for _, part := range openaiContent {
|
||||
if part.Type == dto.ContentTypeText {
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: data,
|
||||
},
|
||||
Text: part.Text,
|
||||
})
|
||||
} else {
|
||||
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
if err != nil {
|
||||
continue
|
||||
} else if part.Type == dto.ContentTypeImageURL {
|
||||
imageNum += 1
|
||||
|
||||
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
|
||||
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
|
||||
}
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
|
||||
// 是url,获取图片的类型和base64编码的数据
|
||||
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: data,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: "image/" + format,
|
||||
Data: base64String,
|
||||
},
|
||||
})
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: "image/" + format,
|
||||
Data: base64String,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
content.Parts = parts
|
||||
|
||||
// there's no assistant role in gemini and API shall vomit if Role is not user or model
|
||||
if content.Role == "assistant" {
|
||||
content.Role = "model"
|
||||
}
|
||||
// Converting system prompt to prompt from user for the same reason
|
||||
//if content.Role == "system" {
|
||||
// content.Role = "user"
|
||||
// shouldAddDummyModelMessage = true
|
||||
//}
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||
//
|
||||
//// If a system message is the last message, we need to add a dummy model message to make gemini happy
|
||||
//if shouldAddDummyModelMessage {
|
||||
// geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
||||
// Role: "model",
|
||||
// Parts: []GeminiPart{
|
||||
// {
|
||||
// Text: "Okay",
|
||||
// },
|
||||
// },
|
||||
// })
|
||||
// shouldAddDummyModelMessage = false
|
||||
//}
|
||||
}
|
||||
return &geminiRequest
|
||||
return &geminiRequest, nil
|
||||
}
|
||||
|
||||
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
|
||||
if depth >= 5 {
|
||||
return schema
|
||||
}
|
||||
|
||||
v, ok := schema.(map[string]interface{})
|
||||
if !ok || len(v) == 0 {
|
||||
return schema
|
||||
}
|
||||
|
||||
// 如果type不为object和array,则直接返回
|
||||
if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
|
||||
return schema
|
||||
}
|
||||
delete(v, "title")
|
||||
switch v["type"] {
|
||||
case "object":
|
||||
delete(v, "additionalProperties")
|
||||
// 处理 properties
|
||||
if properties, ok := v["properties"].(map[string]interface{}); ok {
|
||||
for key, value := range properties {
|
||||
properties[key] = removeAdditionalPropertiesWithDepth(value, depth+1)
|
||||
}
|
||||
}
|
||||
for _, field := range []string{"allOf", "anyOf", "oneOf"} {
|
||||
if nested, ok := v[field].([]interface{}); ok {
|
||||
for i, item := range nested {
|
||||
nested[i] = removeAdditionalPropertiesWithDepth(item, depth+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
case "array":
|
||||
if items, ok := v["items"].(map[string]interface{}); ok {
|
||||
v["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
|
||||
}
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func (g *GeminiChatResponse) GetResponseText() string {
|
||||
@@ -174,19 +250,13 @@ func (g *GeminiChatResponse) GetResponseText() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
|
||||
var toolCalls []dto.ToolCall
|
||||
|
||||
item := candidate.Content.Parts[0]
|
||||
if item.FunctionCall == nil {
|
||||
return toolCalls
|
||||
}
|
||||
func getToolCall(item *GeminiPart) *dto.ToolCall {
|
||||
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
||||
if err != nil {
|
||||
//common.SysError("getToolCalls failed: " + err.Error())
|
||||
return toolCalls
|
||||
//common.SysError("getToolCall failed: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
toolCall := dto.ToolCall{
|
||||
return &dto.ToolCall{
|
||||
ID: fmt.Sprintf("call_%s", common.GetUUID()),
|
||||
Type: "function",
|
||||
Function: dto.FunctionCall{
|
||||
@@ -194,10 +264,32 @@ func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
|
||||
Name: item.FunctionCall.FunctionName,
|
||||
},
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
// func getToolCalls(candidate *GeminiChatCandidate, index int) []dto.ToolCall {
|
||||
// var toolCalls []dto.ToolCall
|
||||
|
||||
// item := candidate.Content.Parts[index]
|
||||
// if item.FunctionCall == nil {
|
||||
// return toolCalls
|
||||
// }
|
||||
// argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
||||
// if err != nil {
|
||||
// //common.SysError("getToolCalls failed: " + err.Error())
|
||||
// return toolCalls
|
||||
// }
|
||||
// toolCall := dto.ToolCall{
|
||||
// ID: fmt.Sprintf("call_%s", common.GetUUID()),
|
||||
// Type: "function",
|
||||
// Function: dto.FunctionCall{
|
||||
// Arguments: string(argsBytes),
|
||||
// Name: item.FunctionCall.FunctionName,
|
||||
// },
|
||||
// }
|
||||
// toolCalls = append(toolCalls, toolCall)
|
||||
// return toolCalls
|
||||
// }
|
||||
|
||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
@@ -207,6 +299,8 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
||||
}
|
||||
content, _ := json.Marshal("")
|
||||
for i, candidate := range response.Candidates {
|
||||
// jsonData, _ := json.MarshalIndent(candidate, "", " ")
|
||||
// common.SysLog(fmt.Sprintf("candidate: %v", string(jsonData)))
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: i,
|
||||
Message: dto.Message{
|
||||
@@ -216,16 +310,20 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
||||
FinishReason: constant.FinishReasonStop,
|
||||
}
|
||||
if len(candidate.Content.Parts) > 0 {
|
||||
if candidate.Content.Parts[0].FunctionCall != nil {
|
||||
choice.FinishReason = constant.FinishReasonToolCalls
|
||||
choice.Message.ToolCalls = getToolCalls(&candidate)
|
||||
} else {
|
||||
var texts []string
|
||||
for _, part := range candidate.Content.Parts {
|
||||
var texts []string
|
||||
var tool_calls []dto.ToolCall
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.FunctionCall != nil {
|
||||
choice.FinishReason = constant.FinishReasonToolCalls
|
||||
if call := getToolCall(&part); call != nil {
|
||||
tool_calls = append(tool_calls, *call)
|
||||
}
|
||||
} else {
|
||||
texts = append(texts, part.Text)
|
||||
}
|
||||
choice.Message.SetStringContent(strings.Join(texts, "\n"))
|
||||
}
|
||||
choice.Message.SetStringContent(strings.Join(texts, "\n"))
|
||||
choice.Message.SetToolCalls(tool_calls)
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
}
|
||||
@@ -236,13 +334,22 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
//choice.Delta.SetContentString(geminiResponse.GetResponseText())
|
||||
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
|
||||
respFirst := geminiResponse.Candidates[0].Content.Parts[0]
|
||||
if respFirst.FunctionCall != nil {
|
||||
// function response
|
||||
choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0])
|
||||
} else {
|
||||
// text response
|
||||
choice.Delta.SetContentString(respFirst.Text)
|
||||
var texts []string
|
||||
var tool_calls []dto.ToolCall
|
||||
for _, part := range geminiResponse.Candidates[0].Content.Parts {
|
||||
if part.FunctionCall != nil {
|
||||
if call := getToolCall(&part); call != nil {
|
||||
tool_calls = append(tool_calls, *call)
|
||||
}
|
||||
} else {
|
||||
texts = append(texts, part.Text)
|
||||
}
|
||||
}
|
||||
if len(texts) > 0 {
|
||||
choice.Delta.SetContentString(strings.Join(texts, "\n"))
|
||||
}
|
||||
if len(tool_calls) > 0 {
|
||||
choice.Delta.ToolCalls = tool_calls
|
||||
}
|
||||
}
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
@@ -283,6 +390,7 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
}
|
||||
response.Id = id
|
||||
response.Created = createAt
|
||||
response.Model = info.UpstreamModelName
|
||||
responseText += response.Choices[0].Delta.GetContentString()
|
||||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||
@@ -311,7 +419,7 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func GeminiChatHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
@@ -337,6 +445,7 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
||||
fullTextResponse.Model = info.UpstreamModelName
|
||||
usage := dto.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
|
||||
|
||||
@@ -109,12 +109,19 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
if info.ChannelType != common.ChannelTypeOpenAI {
|
||||
request.StreamOptions = nil
|
||||
}
|
||||
if strings.HasPrefix(request.Model, "o1-") {
|
||||
if strings.HasPrefix(request.Model, "o1") {
|
||||
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
||||
request.MaxCompletionTokens = request.MaxTokens
|
||||
request.MaxTokens = 0
|
||||
}
|
||||
}
|
||||
if request.Model == "o1" || request.Model == "o1-2024-12-17" {
|
||||
//修改第一个Message的内容,将system改为developer
|
||||
if len(request.Messages) > 0 && request.Messages[0].Role == "system" {
|
||||
request.Messages[0].Role = "developer"
|
||||
}
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ var ModelList = []string{
|
||||
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
|
||||
"o1-preview", "o1-preview-2024-09-12",
|
||||
"o1-mini", "o1-mini-2024-09-12",
|
||||
"o1", "o1-2024-12-17",
|
||||
"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01",
|
||||
"gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01",
|
||||
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
|
||||
|
||||
@@ -135,7 +135,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
c.Set("request_model", request.Model)
|
||||
return vertexClaudeReq, nil
|
||||
} else if a.RequestMode == RequestModeGemini {
|
||||
geminiRequest := gemini.CovertGemini2OpenAI(*request)
|
||||
geminiRequest, err := gemini.CovertGemini2OpenAI(*request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Set("request_model", request.Model)
|
||||
return geminiRequest, nil
|
||||
} else if a.RequestMode == RequestModeLlama {
|
||||
@@ -167,7 +170,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
case RequestModeClaude:
|
||||
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
|
||||
case RequestModeGemini:
|
||||
err, usage = gemini.GeminiChatHandler(c, resp)
|
||||
err, usage = gemini.GeminiChatHandler(c, resp, info)
|
||||
case RequestModeLlama:
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.OriginModelName)
|
||||
}
|
||||
|
||||
@@ -2,8 +2,9 @@ package common
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/constant"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -66,13 +67,13 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
tokenUnlimited := c.GetBool("token_unlimited_quota")
|
||||
startTime := time.Now()
|
||||
startTime := c.GetTime(constant.ContextKeyRequestStartTime)
|
||||
// firstResponseTime = time.Now() - 1 second
|
||||
|
||||
apiType, _ := constant.ChannelType2APIType(channelType)
|
||||
apiType, _ := relayconstant.ChannelType2APIType(channelType)
|
||||
|
||||
info := &RelayInfo{
|
||||
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
BaseUrl: c.GetString("base_url"),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
ChannelType: channelType,
|
||||
@@ -158,10 +159,10 @@ func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
|
||||
group := c.GetString("group")
|
||||
startTime := time.Now()
|
||||
|
||||
apiType, _ := constant.ChannelType2APIType(channelType)
|
||||
apiType, _ := relayconstant.ChannelType2APIType(channelType)
|
||||
|
||||
info := &TaskRelayInfo{
|
||||
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
BaseUrl: c.GetString("base_url"),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
ChannelType: channelType,
|
||||
|
||||
@@ -7,12 +7,12 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
|
||||
@@ -26,7 +26,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
if audioRequest.Model == "" {
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
if constant.ShouldCheckPromptSensitive() {
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
err := service.CheckSensitiveInput(audioRequest.Input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -59,7 +59,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
//if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
|
||||
// return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
|
||||
//}
|
||||
if constant.ShouldCheckPromptSensitive() {
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
err := service.CheckSensitiveInput(imageRequest.Prompt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -111,8 +112,8 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
||||
midjourneyTask.StartTime = originTask.StartTime
|
||||
midjourneyTask.FinishTime = originTask.FinishTime
|
||||
midjourneyTask.ImageUrl = ""
|
||||
if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled {
|
||||
midjourneyTask.ImageUrl = constant.ServerAddress + "/mj/image/" + originTask.MjId
|
||||
if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled {
|
||||
midjourneyTask.ImageUrl = setting.ServerAddress + "/mj/image/" + originTask.MjId
|
||||
if originTask.Status != "SUCCESS" {
|
||||
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||
}
|
||||
@@ -207,7 +208,8 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false, other)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
|
||||
quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
@@ -421,7 +423,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
if originTask == nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
|
||||
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
||||
if constant.MjActionCheckSuccessEnabled {
|
||||
if setting.MjActionCheckSuccessEnabled {
|
||||
if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
|
||||
}
|
||||
@@ -512,7 +514,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false, other)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
|
||||
quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -100,7 +101,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
var modelRatio float64
|
||||
//err := service.SensitiveWordsCheck(textRequest)
|
||||
|
||||
if constant.ShouldCheckPromptSensitive() {
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
err = checkRequestSensitive(textRequest, relayInfo)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||
@@ -384,7 +385,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
|
||||
}
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
|
||||
//if quota != 0 {
|
||||
//
|
||||
|
||||
@@ -126,7 +126,8 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0, modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, other)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
|
||||
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
@@ -98,6 +98,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
channelRoute.POST("/batch", controller.DeleteChannelBatch)
|
||||
channelRoute.POST("/fix", controller.FixChannelsAbilities)
|
||||
channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
|
||||
channelRoute.POST("/fetch_models", controller.FetchModels)
|
||||
|
||||
}
|
||||
tokenRoute := apiRouter.Group("/token")
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"one-api/constant"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
func GetCallbackAddress() string {
|
||||
if constant.CustomCallbackAddress == "" {
|
||||
return constant.ServerAddress
|
||||
if setting.CustomCallbackAddress == "" {
|
||||
return setting.ServerAddress
|
||||
}
|
||||
return constant.CustomCallbackAddress
|
||||
return setting.CustomCallbackAddress
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -167,16 +168,16 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
||||
}
|
||||
if !constant.MjAccountFilterEnabled {
|
||||
if !setting.MjAccountFilterEnabled {
|
||||
delete(mapResult, "accountFilter")
|
||||
}
|
||||
if !constant.MjNotifyEnabled {
|
||||
if !setting.MjNotifyEnabled {
|
||||
delete(mapResult, "notifyHook")
|
||||
}
|
||||
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
// make new request with mapResult
|
||||
}
|
||||
if constant.MjModeClearEnabled {
|
||||
if setting.MjModeClearEnabled {
|
||||
if prompt, ok := mapResult["prompt"].(string); ok {
|
||||
prompt = strings.Replace(prompt, "--fast", "", -1)
|
||||
prompt = strings.Replace(prompt, "--relax", "", -1)
|
||||
|
||||
@@ -139,7 +139,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
}
|
||||
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
}
|
||||
|
||||
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
@@ -208,5 +208,5 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
}
|
||||
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@ package service
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -56,7 +56,7 @@ func CheckSensitiveInput(input any) error {
|
||||
|
||||
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
|
||||
func SensitiveWordContains(text string) (bool, []string) {
|
||||
if len(constant.SensitiveWords) == 0 {
|
||||
if len(setting.SensitiveWords) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
checkText := strings.ToLower(text)
|
||||
@@ -75,7 +75,7 @@ func SensitiveWordContains(text string) (bool, []string) {
|
||||
|
||||
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
|
||||
func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) {
|
||||
if len(constant.SensitiveWords) == 0 {
|
||||
if len(setting.SensitiveWords) == 0 {
|
||||
return false, nil, text
|
||||
}
|
||||
checkText := strings.ToLower(text)
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
goahocorasick "github.com/anknown/ahocorasick"
|
||||
"one-api/constant"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -70,7 +70,7 @@ func InitAc() *goahocorasick.Machine {
|
||||
func readRunes() [][]rune {
|
||||
var dict [][]rune
|
||||
|
||||
for _, word := range constant.SensitiveWords {
|
||||
for _, word := range setting.SensitiveWords {
|
||||
word = strings.ToLower(word)
|
||||
l := bytes.TrimSpace([]byte(word))
|
||||
dict = append(dict, bytes.Runes(l))
|
||||
|
||||
@@ -5,20 +5,20 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func DoImageRequest(originUrl string) (resp *http.Response, err error) {
|
||||
if constant.EnableWorker() {
|
||||
if setting.EnableWorker() {
|
||||
common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl))
|
||||
workerUrl := constant.WorkerUrl
|
||||
workerUrl := setting.WorkerUrl
|
||||
if !strings.HasSuffix(workerUrl, "/") {
|
||||
workerUrl += "/"
|
||||
}
|
||||
// post request to worker
|
||||
data := []byte(`{"url":"` + originUrl + `","key":"` + constant.WorkerValidKey + `"}`)
|
||||
return http.Post(constant.WorkerUrl, "application/json", bytes.NewBuffer(data))
|
||||
data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`)
|
||||
return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data))
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl))
|
||||
return http.Get(originUrl)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package constant
|
||||
package setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
7
setting/midjourney.go
Normal file
7
setting/midjourney.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package setting
|
||||
|
||||
var MjNotifyEnabled = false
|
||||
var MjAccountFilterEnabled = false
|
||||
var MjModeClearEnabled = false
|
||||
var MjForwardUrlEnabled = true
|
||||
var MjActionCheckSuccessEnabled = true
|
||||
@@ -1,4 +1,4 @@
|
||||
package constant
|
||||
package setting
|
||||
|
||||
var PayAddress = ""
|
||||
var CustomCallbackAddress = ""
|
||||
@@ -1,4 +1,4 @@
|
||||
package constant
|
||||
package setting
|
||||
|
||||
import "strings"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package constant
|
||||
package setting
|
||||
|
||||
var ServerAddress = "http://localhost:3000"
|
||||
var WorkerUrl = ""
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import React, { useContext, useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
API,
|
||||
@@ -25,7 +25,7 @@ import {
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { ITEMS_PER_PAGE } from '../constants';
|
||||
import {
|
||||
renderAudioModelPrice,
|
||||
renderAudioModelPrice, renderGroup,
|
||||
renderModelPrice, renderModelPriceSimple,
|
||||
renderNumber,
|
||||
renderQuota,
|
||||
@@ -33,6 +33,7 @@ import {
|
||||
} from '../helpers/render';
|
||||
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
|
||||
import { getLogOther } from '../helpers/other.js';
|
||||
import { StyleContext } from '../context/Style/index.js';
|
||||
|
||||
const { Header } = Layout;
|
||||
|
||||
@@ -217,6 +218,37 @@ const LogsTable = () => {
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: t('分组'),
|
||||
dataIndex: 'group',
|
||||
render: (text, record, index) => {
|
||||
if (record.type === 0 || record.type === 2) {
|
||||
if (record.group) {
|
||||
return (
|
||||
<>
|
||||
{renderGroup(record.group)}
|
||||
</>
|
||||
);
|
||||
} else {
|
||||
let other = JSON.parse(record.other);
|
||||
if (other === null) {
|
||||
return <></>;
|
||||
}
|
||||
if (other.group !== undefined) {
|
||||
return (
|
||||
<>
|
||||
{renderGroup(other.group)}
|
||||
</>
|
||||
);
|
||||
} else {
|
||||
return <></>;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return <></>;
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
title: t('类型'),
|
||||
dataIndex: 'type',
|
||||
@@ -375,6 +407,7 @@ const LogsTable = () => {
|
||||
},
|
||||
];
|
||||
|
||||
const [styleState, styleDispatch] = useContext(StyleContext);
|
||||
const [logs, setLogs] = useState([]);
|
||||
const [expandData, setExpandData] = useState({});
|
||||
const [showStat, setShowStat] = useState(false);
|
||||
@@ -394,6 +427,7 @@ const LogsTable = () => {
|
||||
start_timestamp: timestamp2string(getTodayStartTimestamp()),
|
||||
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
|
||||
channel: '',
|
||||
group: '',
|
||||
});
|
||||
const {
|
||||
username,
|
||||
@@ -402,6 +436,7 @@ const LogsTable = () => {
|
||||
start_timestamp,
|
||||
end_timestamp,
|
||||
channel,
|
||||
group,
|
||||
} = inputs;
|
||||
|
||||
const [stat, setStat] = useState({
|
||||
@@ -410,13 +445,19 @@ const LogsTable = () => {
|
||||
});
|
||||
|
||||
const handleInputChange = (value, name) => {
|
||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||
if (value && (name === 'start_timestamp' || name === 'end_timestamp')) {
|
||||
// 确保日期值是有效的
|
||||
const dateValue = typeof value === 'string' ? value : timestamp2string(value);
|
||||
setInputs(inputs => ({ ...inputs, [name]: dateValue }));
|
||||
} else {
|
||||
setInputs(inputs => ({ ...inputs, [name]: value }));
|
||||
}
|
||||
};
|
||||
|
||||
const getLogSelfStat = async () => {
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localStartTimestamp = Date.parse(3) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
let url = `/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
let url = `/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&group=${group}`;
|
||||
url = encodeURI(url);
|
||||
let res = await API.get(url);
|
||||
const { success, message, data } = res.data;
|
||||
@@ -430,7 +471,7 @@ const LogsTable = () => {
|
||||
const getLogStat = async () => {
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
let url = `/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
|
||||
let url = `/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}&group=${group}`;
|
||||
url = encodeURI(url);
|
||||
let res = await API.get(url);
|
||||
const { success, message, data } = res.data;
|
||||
@@ -573,9 +614,9 @@ const LogsTable = () => {
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
if (isAdminUser) {
|
||||
url = `/api/log/?p=${startIdx}&page_size=${pageSize}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
|
||||
url = `/api/log/?p=${startIdx}&page_size=${pageSize}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}&group=${group}`;
|
||||
} else {
|
||||
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&group=${group}`;
|
||||
}
|
||||
url = encodeURI(url);
|
||||
const res = await API.get(url);
|
||||
@@ -659,10 +700,53 @@ const LogsTable = () => {
|
||||
</Header>
|
||||
<Form layout='horizontal' style={{ marginTop: 10 }}>
|
||||
<>
|
||||
<Form.Section>
|
||||
<div style={{ marginBottom: 10 }}>
|
||||
{
|
||||
styleState.isMobile ? (
|
||||
<div>
|
||||
<Form.DatePicker
|
||||
field='start_timestamp'
|
||||
label={t('起始时间')}
|
||||
style={{ width: 272 }}
|
||||
initValue={start_timestamp}
|
||||
type='dateTime'
|
||||
onChange={(value) => {
|
||||
console.log(value);
|
||||
handleInputChange(value, 'start_timestamp')
|
||||
}}
|
||||
/>
|
||||
<Form.DatePicker
|
||||
field='end_timestamp'
|
||||
fluid
|
||||
label={t('结束时间')}
|
||||
style={{ width: 272 }}
|
||||
initValue={end_timestamp}
|
||||
type='dateTime'
|
||||
onChange={(value) => handleInputChange(value, 'end_timestamp')}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<Form.DatePicker
|
||||
field="range_timestamp"
|
||||
label={t('时间范围')}
|
||||
initValue={[start_timestamp, end_timestamp]}
|
||||
type="dateTimeRange"
|
||||
name="range_timestamp"
|
||||
onChange={(value) => {
|
||||
if (Array.isArray(value) && value.length === 2) {
|
||||
handleInputChange(value[0], 'start_timestamp');
|
||||
handleInputChange(value[1], 'end_timestamp');
|
||||
}
|
||||
}}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
</Form.Section>
|
||||
<Form.Input
|
||||
field='token_name'
|
||||
label={t('令牌名称')}
|
||||
style={{ width: 176 }}
|
||||
value={token_name}
|
||||
placeholder={t('可选值')}
|
||||
name='token_name'
|
||||
@@ -671,39 +755,24 @@ const LogsTable = () => {
|
||||
<Form.Input
|
||||
field='model_name'
|
||||
label={t('模型名称')}
|
||||
style={{ width: 176 }}
|
||||
value={model_name}
|
||||
placeholder={t('可选值')}
|
||||
name='model_name'
|
||||
onChange={(value) => handleInputChange(value, 'model_name')}
|
||||
/>
|
||||
<Form.DatePicker
|
||||
field='start_timestamp'
|
||||
label={t('起始时间')}
|
||||
style={{ width: 272 }}
|
||||
initValue={start_timestamp}
|
||||
value={start_timestamp}
|
||||
type='dateTime'
|
||||
name='start_timestamp'
|
||||
onChange={(value) => handleInputChange(value, 'start_timestamp')}
|
||||
/>
|
||||
<Form.DatePicker
|
||||
field='end_timestamp'
|
||||
fluid
|
||||
label={t('结束时间')}
|
||||
style={{ width: 272 }}
|
||||
initValue={end_timestamp}
|
||||
value={end_timestamp}
|
||||
type='dateTime'
|
||||
name='end_timestamp'
|
||||
onChange={(value) => handleInputChange(value, 'end_timestamp')}
|
||||
<Form.Input
|
||||
field='group'
|
||||
label={t('分组')}
|
||||
value={group}
|
||||
placeholder={t('可选值')}
|
||||
name='group'
|
||||
onChange={(value) => handleInputChange(value, 'group')}
|
||||
/>
|
||||
{isAdminUser && (
|
||||
<>
|
||||
<Form.Input
|
||||
field='channel'
|
||||
label={t('渠道 ID')}
|
||||
style={{ width: 176 }}
|
||||
value={channel}
|
||||
placeholder={t('可选值')}
|
||||
name='channel'
|
||||
@@ -712,7 +781,6 @@ const LogsTable = () => {
|
||||
<Form.Input
|
||||
field='username'
|
||||
label={t('用户名称')}
|
||||
style={{ width: 176 }}
|
||||
value={username}
|
||||
placeholder={t('可选值')}
|
||||
name='username'
|
||||
|
||||
@@ -162,36 +162,39 @@ const ModelPricing = () => {
|
||||
title: t('可用分组'),
|
||||
dataIndex: 'enable_groups',
|
||||
render: (text, record, index) => {
|
||||
|
||||
// enable_groups is a string array
|
||||
return (
|
||||
<Space>
|
||||
{text.map((group) => {
|
||||
if (group === selectedGroup) {
|
||||
return (
|
||||
<Tag
|
||||
color='blue'
|
||||
size='large'
|
||||
prefixIcon={<IconVerify />}
|
||||
>
|
||||
{group}
|
||||
</Tag>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<Tag
|
||||
color='blue'
|
||||
size='large'
|
||||
onClick={() => {
|
||||
setSelectedGroup(group);
|
||||
showInfo(t('当前查看的分组为:{{group}},倍率为:{{ratio}}', {
|
||||
group: group,
|
||||
ratio: groupRatio[group]
|
||||
}));
|
||||
}}
|
||||
>
|
||||
{group}
|
||||
</Tag>
|
||||
);
|
||||
if (usableGroup[group]) {
|
||||
if (group === selectedGroup) {
|
||||
return (
|
||||
<Tag
|
||||
color='blue'
|
||||
size='large'
|
||||
prefixIcon={<IconVerify />}
|
||||
>
|
||||
{group}
|
||||
</Tag>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<Tag
|
||||
color='blue'
|
||||
size='large'
|
||||
onClick={() => {
|
||||
setSelectedGroup(group);
|
||||
showInfo(t('当前查看的分组为:{{group}},倍率为:{{ratio}}', {
|
||||
group: group,
|
||||
ratio: groupRatio[group]
|
||||
}));
|
||||
}}
|
||||
>
|
||||
{group}
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
}
|
||||
})}
|
||||
</Space>
|
||||
@@ -275,6 +278,7 @@ const ModelPricing = () => {
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [userState, userDispatch] = useContext(UserContext);
|
||||
const [groupRatio, setGroupRatio] = useState({});
|
||||
const [usableGroup, setUsableGroup] = useState({});
|
||||
|
||||
const setModelsFormat = (models, groupRatio) => {
|
||||
for (let i = 0; i < models.length; i++) {
|
||||
@@ -309,9 +313,10 @@ const ModelPricing = () => {
|
||||
let url = '';
|
||||
url = `/api/pricing`;
|
||||
const res = await API.get(url);
|
||||
const { success, message, data, group_ratio } = res.data;
|
||||
const { success, message, data, group_ratio, usable_group } = res.data;
|
||||
if (success) {
|
||||
setGroupRatio(group_ratio);
|
||||
setUsableGroup(usable_group);
|
||||
setSelectedGroup(userState.user ? userState.user.group : 'default')
|
||||
setModelsFormat(data, group_ratio);
|
||||
} else {
|
||||
|
||||
@@ -430,6 +430,8 @@
|
||||
"一小时后过期": "Expires after one hour",
|
||||
"一分钟后过期": "Expires after one minute",
|
||||
"创建新的令牌": "Create New Token",
|
||||
"令牌分组,默认为用户的分组": "Token group, default is the your's group",
|
||||
"IP白名单(请勿过度信任此功能)": "IP whitelist (do not overly trust this function)",
|
||||
"注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note that the quota of the token is only used to limit the maximum quota usage of the token itself, and the actual usage is limited by the remaining quota of the account.",
|
||||
"设为无限额度": "Set to unlimited quota",
|
||||
"更新令牌信息": "Update Token Information",
|
||||
@@ -866,8 +868,8 @@
|
||||
"请选择模式": "Please select mode",
|
||||
"图片代理方式": "Picture agency method",
|
||||
"用于替换 https://cdn.discordapp.com 的域名": "The domain name used to replace https://cdn.discordapp.com",
|
||||
"一个月": "a month",
|
||||
"一天": "one day",
|
||||
"一个月": "A month",
|
||||
"一天": "One day",
|
||||
"令牌渠道分组选择": "Token channel grouping selection",
|
||||
"只可使用对应分组包含的模型。": "Only models contained in the corresponding group can be used.",
|
||||
"渠道分组": "Channel grouping",
|
||||
@@ -876,7 +878,7 @@
|
||||
"启用模型限制(非必要,不建议启用)": "Enable model restrictions (not necessary, not recommended)",
|
||||
"秒": "Second",
|
||||
"更新令牌后需等待几分钟生效": "It will take a few minutes to take effect after updating the token.",
|
||||
"一小时": "one hour",
|
||||
"一小时": "One hour",
|
||||
"新建数量": "New quantity",
|
||||
"加载失败,请稍后重试": "Loading failed, please try again later",
|
||||
"未设置": "Not set",
|
||||
@@ -1234,5 +1236,6 @@
|
||||
"应用更改": "Apply changes",
|
||||
"更多": "Expand more",
|
||||
"个模型": "models",
|
||||
"可用模型": "Available models"
|
||||
"可用模型": "Available models",
|
||||
"时间范围": "Time range"
|
||||
}
|
||||
@@ -193,14 +193,16 @@ const EditChannel = (props) => {
|
||||
|
||||
|
||||
const fetchUpstreamModelList = async (name) => {
|
||||
if (inputs['type'] !== 1) {
|
||||
showError(t('仅支持 OpenAI 接口格式'));
|
||||
return;
|
||||
}
|
||||
// if (inputs['type'] !== 1) {
|
||||
// showError(t('仅支持 OpenAI 接口格式'));
|
||||
// return;
|
||||
// }
|
||||
setLoading(true);
|
||||
const models = inputs['models'] || [];
|
||||
let err = false;
|
||||
|
||||
if (isEdit) {
|
||||
// 如果是编辑模式,使用已有的channel id获取模型列表
|
||||
const res = await API.get('/api/channel/fetch_models/' + channelId);
|
||||
if (res.data && res.data?.success) {
|
||||
models.push(...res.data.data);
|
||||
@@ -208,30 +210,29 @@ const EditChannel = (props) => {
|
||||
err = true;
|
||||
}
|
||||
} else {
|
||||
// 如果是新建模式,通过后端代理获取模型列表
|
||||
if (!inputs?.['key']) {
|
||||
showError(t('请填写密钥'));
|
||||
err = true;
|
||||
} else {
|
||||
try {
|
||||
const host = new URL((inputs['base_url'] || 'https://api.openai.com'));
|
||||
|
||||
const url = `https://${host.hostname}/v1/models`;
|
||||
const key = inputs['key'];
|
||||
const res = await axios.get(url, {
|
||||
headers: {
|
||||
'Authorization': `Bearer ${key}`
|
||||
}
|
||||
const res = await API.post('/api/channel/fetch_models', {
|
||||
base_url: inputs['base_url'],
|
||||
key: inputs['key']
|
||||
});
|
||||
if (res.data) {
|
||||
models.push(...res.data.data.map((model) => model.id));
|
||||
|
||||
if (res.data && res.data.success) {
|
||||
models.push(...res.data.data);
|
||||
} else {
|
||||
err = true;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error fetching models:', error);
|
||||
err = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!err) {
|
||||
handleInputChange(name, Array.from(new Set(models)));
|
||||
showSuccess(t('获取模型列表成功'));
|
||||
@@ -638,7 +639,7 @@ const EditChannel = (props) => {
|
||||
{inputs.type === 21 && (
|
||||
<>
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>知识库 ID:</Typography.Text>
|
||||
<Typography.Text strong><EFBFBD><EFBFBD>识库 ID:</Typography.Text>
|
||||
</div>
|
||||
<Input
|
||||
label="知识库 ID"
|
||||
|
||||
@@ -75,7 +75,7 @@ export default function ModelSettingsVisualEditor(props) {
|
||||
output.ModelPrice[model.name] = parseFloat(model.price)
|
||||
} else {
|
||||
if (model.ratio !== '') output.ModelRatio[model.name] = parseFloat(model.ratio);
|
||||
if (model.completionRatio != '') output.CompletionRatio[model.name] = parseFloat(model.completionRatio);
|
||||
if (model.completionRatio !== '') output.CompletionRatio[model.name] = parseFloat(model.completionRatio);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -203,11 +203,6 @@ export default function ModelSettingsVisualEditor(props) {
|
||||
showError('模型名称已存在');
|
||||
return;
|
||||
}
|
||||
// 不允许同时添加固定价格和倍率
|
||||
if (values.price !== '' && (values.ratio !== '' || values.completionRatio !== '')) {
|
||||
showError('固定价格和倍率不能同时存在');
|
||||
return;
|
||||
}
|
||||
setModels(prev => [{
|
||||
name: values.name,
|
||||
price: values.price || '',
|
||||
|
||||
@@ -23,6 +23,7 @@ import {
|
||||
} from '@douyinfe/semi-ui';
|
||||
import Title from '@douyinfe/semi-ui/lib/es/typography/title';
|
||||
import { Divider } from 'semantic-ui-react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const EditToken = (props) => {
|
||||
const [isEdit, setIsEdit] = useState(false);
|
||||
@@ -52,6 +53,7 @@ const EditToken = (props) => {
|
||||
const [models, setModels] = useState([]);
|
||||
const [groups, setGroups] = useState([]);
|
||||
const navigate = useNavigate();
|
||||
const { t } = useTranslation();
|
||||
const handleInputChange = (name, value) => {
|
||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||
};
|
||||
@@ -87,7 +89,7 @@ const EditToken = (props) => {
|
||||
}));
|
||||
setModels(localModelOptions);
|
||||
} else {
|
||||
showError(message);
|
||||
showError(t(message));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -95,15 +97,13 @@ const EditToken = (props) => {
|
||||
let res = await API.get(`/api/user/self/groups`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
// return data is a map, key is group name, value is group description
|
||||
// label is group description, value is group name
|
||||
let localGroupOptions = Object.keys(data).map((group) => ({
|
||||
label: data[group],
|
||||
value: group,
|
||||
}));
|
||||
setGroups(localGroupOptions);
|
||||
let localGroupOptions = Object.keys(data).map((group) => ({
|
||||
label: data[group],
|
||||
value: group,
|
||||
}));
|
||||
setGroups(localGroupOptions);
|
||||
} else {
|
||||
showError(message);
|
||||
showError(t(message));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -176,7 +176,7 @@ const EditToken = (props) => {
|
||||
if (localInputs.expired_time !== -1) {
|
||||
let time = Date.parse(localInputs.expired_time);
|
||||
if (isNaN(time)) {
|
||||
showError('过期时间格式错误!');
|
||||
showError(t('过期时间格式错误!'));
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
@@ -189,11 +189,11 @@ const EditToken = (props) => {
|
||||
});
|
||||
const { success, message } = res.data;
|
||||
if (success) {
|
||||
showSuccess('令牌更新成功!');
|
||||
showSuccess(t('令牌更新成功!'));
|
||||
props.refresh();
|
||||
props.handleClose();
|
||||
} else {
|
||||
showError(message);
|
||||
showError(t(message));
|
||||
}
|
||||
} else {
|
||||
// 处理新增多个令牌的情况
|
||||
@@ -209,7 +209,7 @@ const EditToken = (props) => {
|
||||
if (localInputs.expired_time !== -1) {
|
||||
let time = Date.parse(localInputs.expired_time);
|
||||
if (isNaN(time)) {
|
||||
showError('过期时间格式错误!');
|
||||
showError(t('过期时间格式错误!'));
|
||||
setLoading(false);
|
||||
break;
|
||||
}
|
||||
@@ -222,14 +222,14 @@ const EditToken = (props) => {
|
||||
if (success) {
|
||||
successCount++;
|
||||
} else {
|
||||
showError(message);
|
||||
showError(t(message));
|
||||
break; // 如果创建失败,终止循环
|
||||
}
|
||||
}
|
||||
|
||||
if (successCount > 0) {
|
||||
showSuccess(
|
||||
`${successCount}个令牌创建成功,请在列表页面点击复制获取令牌!`,
|
||||
t('令牌创建成功,请在列表页面点击复制获取令牌!')
|
||||
);
|
||||
props.refresh();
|
||||
props.handleClose();
|
||||
@@ -245,7 +245,7 @@ const EditToken = (props) => {
|
||||
<SideSheet
|
||||
placement={isEdit ? 'right' : 'left'}
|
||||
title={
|
||||
<Title level={3}>{isEdit ? '更新令牌信息' : '创建新的令牌'}</Title>
|
||||
<Title level={3}>{isEdit ? t('更新令牌信息') : t('创建新的令牌')}</Title>
|
||||
}
|
||||
headerStyle={{ borderBottom: '1px solid var(--semi-color-border)' }}
|
||||
bodyStyle={{ borderBottom: '1px solid var(--semi-color-border)' }}
|
||||
@@ -254,7 +254,7 @@ const EditToken = (props) => {
|
||||
<div style={{ display: 'flex', justifyContent: 'flex-end' }}>
|
||||
<Space>
|
||||
<Button theme='solid' size={'large'} onClick={submit}>
|
||||
提交
|
||||
{t('提交')}
|
||||
</Button>
|
||||
<Button
|
||||
theme='solid'
|
||||
@@ -262,7 +262,7 @@ const EditToken = (props) => {
|
||||
type={'tertiary'}
|
||||
onClick={handleCancel}
|
||||
>
|
||||
取消
|
||||
{t('取消')}
|
||||
</Button>
|
||||
</Space>
|
||||
</div>
|
||||
@@ -274,9 +274,9 @@ const EditToken = (props) => {
|
||||
<Spin spinning={loading}>
|
||||
<Input
|
||||
style={{ marginTop: 20 }}
|
||||
label='名称'
|
||||
label={t('名称')}
|
||||
name='name'
|
||||
placeholder={'请输入名称'}
|
||||
placeholder={t('请输入名称')}
|
||||
onChange={(value) => handleInputChange('name', value)}
|
||||
value={name}
|
||||
autoComplete='new-password'
|
||||
@@ -284,9 +284,9 @@ const EditToken = (props) => {
|
||||
/>
|
||||
<Divider />
|
||||
<DatePicker
|
||||
label='过期时间'
|
||||
label={t('过期时间')}
|
||||
name='expired_time'
|
||||
placeholder={'请选择过期时间'}
|
||||
placeholder={t('请选择过期时间')}
|
||||
onChange={(value) => handleInputChange('expired_time', value)}
|
||||
value={expired_time}
|
||||
autoComplete='new-password'
|
||||
@@ -300,7 +300,7 @@ const EditToken = (props) => {
|
||||
setExpiredTime(0, 0, 0, 0);
|
||||
}}
|
||||
>
|
||||
永不过期
|
||||
{t('永不过期')}
|
||||
</Button>
|
||||
<Button
|
||||
type={'tertiary'}
|
||||
@@ -308,7 +308,7 @@ const EditToken = (props) => {
|
||||
setExpiredTime(0, 0, 1, 0);
|
||||
}}
|
||||
>
|
||||
一小时
|
||||
{t('一小时')}
|
||||
</Button>
|
||||
<Button
|
||||
type={'tertiary'}
|
||||
@@ -316,7 +316,7 @@ const EditToken = (props) => {
|
||||
setExpiredTime(1, 0, 0, 0);
|
||||
}}
|
||||
>
|
||||
一个月
|
||||
{t('一个月')}
|
||||
</Button>
|
||||
<Button
|
||||
type={'tertiary'}
|
||||
@@ -324,7 +324,7 @@ const EditToken = (props) => {
|
||||
setExpiredTime(0, 1, 0, 0);
|
||||
}}
|
||||
>
|
||||
一天
|
||||
{t('一天')}
|
||||
</Button>
|
||||
</Space>
|
||||
</div>
|
||||
@@ -332,17 +332,15 @@ const EditToken = (props) => {
|
||||
<Divider />
|
||||
<Banner
|
||||
type={'warning'}
|
||||
description={
|
||||
'注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。'
|
||||
}
|
||||
description={t('注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。')}
|
||||
></Banner>
|
||||
<div style={{ marginTop: 20 }}>
|
||||
<Typography.Text>{`额度${renderQuotaWithPrompt(remain_quota)}`}</Typography.Text>
|
||||
<Typography.Text>{`${t('额度')}${renderQuotaWithPrompt(remain_quota)}`}</Typography.Text>
|
||||
</div>
|
||||
<AutoComplete
|
||||
style={{ marginTop: 8 }}
|
||||
name='remain_quota'
|
||||
placeholder={'请输入额度'}
|
||||
placeholder={t('请输入额度')}
|
||||
onChange={(value) => handleInputChange('remain_quota', value)}
|
||||
value={remain_quota}
|
||||
autoComplete='new-password'
|
||||
@@ -362,22 +360,22 @@ const EditToken = (props) => {
|
||||
{!isEdit && (
|
||||
<>
|
||||
<div style={{ marginTop: 20 }}>
|
||||
<Typography.Text>新建数量</Typography.Text>
|
||||
<Typography.Text>{t('新建数量')}</Typography.Text>
|
||||
</div>
|
||||
<AutoComplete
|
||||
style={{ marginTop: 8 }}
|
||||
label='数量'
|
||||
placeholder={'请选择或输入创建令牌的数量'}
|
||||
label={t('数量')}
|
||||
placeholder={t('请选择或输入创建令牌的数量')}
|
||||
onChange={(value) => handleTokenCountChange(value)}
|
||||
onSelect={(value) => handleTokenCountChange(value)}
|
||||
value={tokenCount.toString()}
|
||||
autoComplete='off'
|
||||
type='number'
|
||||
data={[
|
||||
{ value: 10, label: '10个' },
|
||||
{ value: 20, label: '20个' },
|
||||
{ value: 30, label: '30个' },
|
||||
{ value: 100, label: '100个' },
|
||||
{ value: 10, label: t('10个') },
|
||||
{ value: 20, label: t('20个') },
|
||||
{ value: 30, label: t('30个') },
|
||||
{ value: 100, label: t('100个') },
|
||||
]}
|
||||
disabled={unlimited_quota}
|
||||
/>
|
||||
@@ -392,17 +390,17 @@ const EditToken = (props) => {
|
||||
setUnlimitedQuota();
|
||||
}}
|
||||
>
|
||||
{unlimited_quota ? '取消无限额度' : '设为无限额度'}
|
||||
{unlimited_quota ? t('取消<EFBFBD><EFBFBD><EFBFBD>限额度') : t('设为无限额度')}
|
||||
</Button>
|
||||
</div>
|
||||
<Divider />
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text>IP白名单(请勿过度信任此功能)</Typography.Text>
|
||||
<Typography.Text>{t('IP白名单(请勿过度信任此功能)')}</Typography.Text>
|
||||
</div>
|
||||
<TextArea
|
||||
label='IP白名单'
|
||||
label={t('IP白名单')}
|
||||
name='allow_ips'
|
||||
placeholder={'允许的IP,一行一个'}
|
||||
placeholder={t('允许的IP,一行一个')}
|
||||
onChange={(value) => {
|
||||
handleInputChange('allow_ips', value);
|
||||
}}
|
||||
@@ -417,16 +415,15 @@ const EditToken = (props) => {
|
||||
onChange={(e) =>
|
||||
handleInputChange('model_limits_enabled', e.target.checked)
|
||||
}
|
||||
></Checkbox>
|
||||
<Typography.Text>
|
||||
启用模型限制(非必要,不建议启用)
|
||||
</Typography.Text>
|
||||
>
|
||||
{t('启用模型限制(非必要,不建议启用)')}
|
||||
</Checkbox>
|
||||
</Space>
|
||||
</div>
|
||||
|
||||
<Select
|
||||
style={{ marginTop: 8 }}
|
||||
placeholder={'请选择该渠道所支持的模型'}
|
||||
placeholder={t('请选择该渠道所支持的模型')}
|
||||
name='models'
|
||||
required
|
||||
multiple
|
||||
@@ -440,12 +437,12 @@ const EditToken = (props) => {
|
||||
disabled={!model_limits_enabled}
|
||||
/>
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text>令牌分组,默认为用户的分组</Typography.Text>
|
||||
<Typography.Text>{t('令牌分组,默认为用户的分组')}</Typography.Text>
|
||||
</div>
|
||||
{groups.length > 0 ?
|
||||
<Select
|
||||
style={{ marginTop: 8 }}
|
||||
placeholder={'令牌分组,默认为用户的分组'}
|
||||
placeholder={t('令牌分组,默认为用户的分组')}
|
||||
name='gruop'
|
||||
required
|
||||
selection
|
||||
@@ -458,7 +455,7 @@ const EditToken = (props) => {
|
||||
/>:
|
||||
<Select
|
||||
style={{ marginTop: 8 }}
|
||||
placeholder={'管理员未设置用户可选分组'}
|
||||
placeholder={t('管理员未设置用户可选分组')}
|
||||
name='gruop'
|
||||
disabled={true}
|
||||
/>
|
||||
|
||||
Reference in New Issue
Block a user