Compare commits

...

39 Commits

Author SHA1 Message Date
lollipopkit🏳️‍⚧️
99a2fc5852 (jsrt) opt.: move req from global to local 2025-07-16 13:24:59 +08:00
lollipopkit🏳️‍⚧️
9d9070c899 (jsrt) fix: elapsed time calc 2025-07-16 12:53:35 +08:00
lollipopkit🏳️‍⚧️
9a48ed47f4 (jsrt) fix: http: invalid Content-Length of "-1" 2025-07-16 12:48:00 +08:00
lollipopkit🏳️‍⚧️
155f67e960 (jsrt) opt.: script load strategy 2025-07-16 12:46:08 +08:00
lollipopkit🏳️‍⚧️
71778f4174 (jsrt) chore: docs 2025-07-16 11:11:34 +08:00
lollipopkit🏳️‍⚧️
7bb66b8bec (jsrt) opt.: lower case fns 2025-07-16 11:09:32 +08:00
lollipopkit🏳️‍⚧️
7bdec28e5f (jsrt) rename: JSContext -> JSReq 2025-07-16 10:27:22 +08:00
lollipopkit🏳️‍⚧️
5ffdd9f542 (jsrt) fix: fetch 2025-07-16 01:43:01 +08:00
lollipopkit🏳️‍⚧️
4c72f2abed (jsrt) opt. & fix: cfg loading 2025-07-15 22:50:38 +08:00
lollipopkit🏳️‍⚧️
fd51f71e0f (jsrt) opt.: struct 2025-07-15 22:19:02 +08:00
lollipopkit🏳️‍⚧️
59f12d2582 (jsrt) feat: builtin fetch 2025-07-15 21:29:50 +08:00
lollipopkit🏳️‍⚧️
f17a419520 opt.: js rt pool & perf 2025-07-15 21:11:12 +08:00
lollipopkit🏳️‍⚧️
ee114e14c3 feat: dyn middlewares based on js rt 2025-07-15 20:26:33 +08:00
Calcium-Ion
78fb457765 Merge pull request #1346 from QuantumNous/fix-ability
 feat(ability): enhance FixAbility function
2025-07-08 18:38:35 +08:00
CaIon
8759ef012f feat(ability): enhance FixAbility function 2025-07-08 18:33:32 +08:00
Calcium-Ion
f8d67a62a2 Merge pull request #1334 from duyazhe/fix-baidu-bug
修复了百度请求时候需要传appid的bug
2025-07-07 14:51:23 +08:00
Xyfacai
efb98854b2 Merge pull request #1341 from QuantumNous/refactor/log-params
refactor: log params and channel params
2025-07-07 14:29:16 +08:00
Xiangyuan-liu
7b29f429ee refactor: log params and channel params
refactor: log params and channel params
2025-07-07 14:26:37 +08:00
CaIon
265c7d93a2 🔧 refactor(adaptor): update HTTP referer to new API domain 2025-07-07 12:36:04 +08:00
duyazhe
ce57ad3570 Update adaptor.go 2025-07-07 09:57:20 +08:00
duyazhe
9282f1d893 修复了百度请求时候需要传appid的bug 2025-07-06 23:09:49 +08:00
CaIon
9546a47f2b feat(tokens): add cherryConfig support for URL generation and base64 encoding 2025-07-06 20:56:09 +08:00
CaIon
8073cbd96a 🔧 refactor(model): change user group retrieval to non-strict mode 2025-07-06 10:23:38 +08:00
CaIon
5eba2f1d61 🔧 refactor(model): update context key retrieval to use token group instead of user group 2025-07-05 16:40:49 +08:00
Calcium-Ion
5ec421d8e6 Merge pull request #1321 from iszcz/main
支持Midjourney视频任务和图片编辑
2025-07-05 15:28:33 +08:00
CaIon
1e25bf700d Merge remote-tracking branch 'origin/alpha' into alpha 2025-07-05 14:14:48 +08:00
CaIon
30fb349d91 feat(endpoint types): add support for image generation models in endpoint type handling 2025-07-05 14:14:40 +08:00
t0ng7u
d40fb68500 📊 feat(detail): add model consumption trend & call ranking charts
Introduce two new visualizations to the “Model Data Analysis” panel:

1. Model Consumption Trend (line chart)
   • Added `spec_model_line` state and legend support.
   • Calculates per-model counts over time and updates via `updateChartData`.
2. Model Call Ranking (bar chart)
   • Added `spec_rank_bar` state with `seriesField` and legend enabled.
   • Ranks models by total call count.

Additional changes:
• Extended tab navigation with two new `TabPane`s and adjusted chart rendering logic.
• Swapped icons/texts to match new chart purposes.
• Reused existing color mapping to ensure consistent palette.

No breaking changes; UI now offers richer insights into model usage patterns.
2025-07-05 00:37:05 +08:00
t0ng7u
3049ad47e5 🔢 feat(user-edit): replace add-quota input with Semi-UI InputNumber
Summary:
• Imported InputNumber from @douyinfe/semi-ui.
• Swapped plain Input for InputNumber in “Add Quota” modal.
• Added UX tweaks: full-width styling, showClear, step = 500 000.
• Initialized addQuotaLocal to an empty string so the field starts blank.
• Adjusted state handling and kept quota calculation logic unchanged.

This improves numeric input accuracy and overall user experience without breaking existing functionality.
2025-07-05 00:03:12 +08:00
t0ng7u
8945a3a2dd 🖼️ style(RatioSync): remove the useless rounded-full style 2025-07-04 23:49:34 +08:00
t0ng7u
d191eef657 🐛 fix: fix the header height calculation issue in the custom HTML styles on the homepage 2025-07-04 23:42:46 +08:00
CaIon
6ac7878863 🔧 refactor(endpoint types): comment out unused endpoint types in constants 2025-07-04 15:53:46 +08:00
t0ng7u
c0a23ffa62 🎨 refactor(EditTagModal): tidy imports & enhance state-sync on open
Motivation
• Remove unused UI components to keep the bundle lean and silence linter warnings.
• Ensure every time the side-sheet opens it reflects the latest tag data, avoiding stale form values (e.g., model / group mismatches).

Key Changes
1. UI Imports
   – Dropped `Input`, `Select`, `TextArea` from `@douyinfe/semi-ui` (unused in Form-based version).
2. State Reset & Form Sync
   – On `visible` or `tag` change:
     • Refresh model & group options.
     • Reset `inputs` to clean defaults (`originInputs`) carrying the current `tag`.
     • Pre-fill Form through `formApiRef` to keep controlled fields aligned.
3. Minor Cleanup
   – Added inline comment clarifying local state reset purpose.

Result
Opening the “Edit Tag” side-sheet now always displays accurate data without residual selections, and build output is cleaner due to removed dead imports.
2025-07-04 06:14:15 +08:00
t0ng7u
7d691f362d refactor(EditChannel&EditToken): refactor Channel & Token edit pages with Semi Form and UX enhancements
Overview
• Migrated both `EditChannel.js` and `EditToken.js` to fully leverage Semi UI `Form.*` components, removing legacy `Input/Select/TextArea` + manual labels.
• Unified data-loading strategy: when the drawer becomes visible we load (or reset) data via `props.visible + id` effect and `formApi.setValues()`, guaranteeing fields are always populated; form resets on close.
• Fixed blank-form bug when opening the same record twice.

Key improvements
1. Validation
   • `type`, `models` always required.
   • `key` required only while creating (not on edit).
2. Batch key creation
   • Checkbox moved into `extraText`; hidden when editing or when channel type = 41.
3. Layout & UI
   • `Row / Col` (12 + 12) for “Priority” and “Weight”.
   • Placeholders revised; model selector now shows creation hint; removed obsolete banner.
   • Help / extraText used for long hints, template buttons (`model_mapping`, `status_code_mapping`, `param_override`, etc.), and API address notice.
   • Added `showClear`, `min`, rounded card class names for consistency.
4. Reusable helpers
   • `batchAllowed`, `batchExtra` utilities.
   • `getInitValues()` + centralized `inputs`→form synchronization.
5. Token editor aligned to the same pattern (`props.visiable` watcher).

Result
Cleaner code, consistent UX, instant field population on every open, and clearer validation/error feedback across both editors.
2025-07-04 05:36:10 +08:00
t0ng7u
bf577b8937 🔌 feat(api): extend endpoint type support & expose in pricing UI
* backend
  - constant/endpoint_type.go
    • Add EndpointTypeMidjourney, EndpointTypeSuno, EndpointTypeKling, EndpointTypeJimeng.
  - common/endpoint_type.go
    • Map Midjourney / MidjourneyPlus, SunoAPI, Kling, Jimeng channel types to the new endpoint types.

* frontend
  - ModelPricing.js
    • Add “Supported Endpoint Type” column.
    • Implement renderSupportedEndpoints with `stringToColor` for consistent tag colors.

These changes allow `/api/pricing` and model lists to return accurate
`supported_endpoint_types` covering all non-OpenAI providers and display
them clearly in the UI.

No breaking changes.
2025-07-04 03:15:34 +08:00
Calcium-Ion
819290c9b8 Merge pull request #1314 from vickyyd/main
修复使用gemini-balance作为上游时,测试gemini2.5pro模型时出现的错误问题
2025-07-03 15:53:32 +08:00
CaIon
22e8b46159 feat: make TopN field in RerankRequest optional in JSON serialization 2025-07-03 15:45:32 +08:00
iszcz
660180ea1b 支持Midjourney视频任务和图片编辑 2025-06-30 22:31:12 +08:00
kikii16
efc8457770 修复gemini-balance测试gemini2.5pro的错误问题 2025-06-29 13:36:19 +08:00
80 changed files with 3893 additions and 1010 deletions

View File

@@ -73,3 +73,14 @@
# 节点类型
# 如果是主节点则为master
# NODE_TYPE=master
# JavaScript 运行时配置
# 是否启用默认false
# JS_RUNTIME_ENABLED=true
# 最大虚拟机数量默认8
# JS_MAX_VM_COUNT=
# 运行超时时间单位默认5
# JS_SCRIPT_TIMEOUT=
# 脚本文件夹默认scripts/
# JS_SCRIPT_PATH=

View File

@@ -27,9 +27,6 @@
<a href="https://goreportcard.com/report/github.com/Calcium-Ion/new-api">
<img src="https://goreportcard.com/badge/github.com/Calcium-Ion/new-api" alt="GoReportCard">
</a>
<a href="https://coderabbit.ai">
<img src="https://img.shields.io/coderabbit/prs/github/QuantumNous/new-api?utm_source=oss&utm_medium=github&utm_campaign=QuantumNous%2Fnew-api&labelColor=171717&color=FF570A&link=https%3A%2F%2Fcoderabbit.ai&label=CodeRabbit+Reviews" alt="CodeRabbit Pull Request Reviews">
</a>
</p>
</div>

View File

@@ -8,6 +8,14 @@ func GetEndpointTypesByChannelType(channelType int, modelName string) []constant
switch channelType {
case constant.ChannelTypeJina:
endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank}
//case constant.ChannelTypeMidjourney, constant.ChannelTypeMidjourneyPlus:
// endpointTypes = []constant.EndpointType{constant.EndpointTypeMidjourney}
//case constant.ChannelTypeSunoAPI:
// endpointTypes = []constant.EndpointType{constant.EndpointTypeSuno}
//case constant.ChannelTypeKling:
// endpointTypes = []constant.EndpointType{constant.EndpointTypeKling}
//case constant.ChannelTypeJimeng:
// endpointTypes = []constant.EndpointType{constant.EndpointTypeJimeng}
case constant.ChannelTypeAws:
fallthrough
case constant.ChannelTypeAnthropic:
@@ -25,5 +33,9 @@ func GetEndpointTypesByChannelType(channelType int, modelName string) []constant
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
}
}
if IsImageGenerationModel(modelName) {
// add to first
endpointTypes = append([]constant.EndpointType{constant.EndpointTypeImageGeneration}, endpointTypes...)
}
return endpointTypes
}

View File

@@ -76,3 +76,13 @@ func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]
func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
return c.GetTime(string(key))
}
func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
if value, ok := c.Get(string(key)); ok {
if v, ok := value.(T); ok {
return v, true
}
}
var t T
return t, false
}

View File

@@ -9,11 +9,32 @@ var (
"o3-deep-research",
"o4-mini-deep-research",
}
ImageGenerationModels = []string{
"dall-e-3",
"dall-e-2",
"gpt-image-1",
"prefix:imagen-",
"flux-",
"flux.1-",
}
)
func IsOpenAIResponseOnlyModel(modelName string) bool {
for _, m := range OpenAIResponseOnlyModels {
if strings.Contains(m, modelName) {
if strings.Contains(modelName, m) {
return true
}
}
return false
}
func IsImageGenerationModel(modelName string) bool {
modelName = strings.ToLower(modelName)
for _, m := range ImageGenerationModels {
if strings.Contains(modelName, m) {
return true
}
if strings.HasPrefix(m, "prefix:") && strings.HasPrefix(modelName, strings.TrimPrefix(m, "prefix:")) {
return true
}
}

View File

@@ -1,6 +1,7 @@
package common
import (
"encoding/base64"
"encoding/json"
"math/rand"
"strconv"
@@ -68,3 +69,15 @@ func StringToByteSlice(s string) []byte {
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
return *(*[]byte)(unsafe.Pointer(&tmp2))
}
func EncodeBase64(str string) string {
return base64.StdEncoding.EncodeToString([]byte(str))
}
func GetJsonString(data any) string {
if data == nil {
return ""
}
b, _ := json.Marshal(data)
return string(b)
}

149
common/struct_reflect.go Normal file
View File

@@ -0,0 +1,149 @@
package common
import (
"fmt"
"reflect"
)
// StructToMap 递归地把任意结构体 v 转成 map[string]any。
// - 只处理导出字段;未导出字段会被跳过。
// - 优先使用 `json:"name"` 里逗号前的部分作为键;如果是 "-" 则忽略该字段;若无 tag则使用字段名。
// - 对指针、切片、数组、嵌套结构体、map 做深度遍历,保持原始结构。
func StructToMap(v any) (map[string]any, error) {
val := reflect.ValueOf(v)
if !val.IsValid() {
return nil, fmt.Errorf("nil value")
}
for val.Kind() == reflect.Pointer {
if val.IsNil() {
return nil, fmt.Errorf("nil pointer")
}
val = val.Elem()
}
if val.Kind() != reflect.Struct {
return nil, fmt.Errorf("expect struct, got %s", val.Kind())
}
return structValueToMap(val), nil
}
func structValueToMap(val reflect.Value) map[string]any {
out := make(map[string]any, val.NumField())
typ := val.Type()
for i := 0; i < val.NumField(); i++ {
f := typ.Field(i)
if f.PkgPath != "" { // 未导出字段
continue
}
// 解析 json tag
tag := f.Tag.Get("json")
name, opts := parseTag(tag)
if name == "-" {
continue
}
if name == "" {
name = f.Name
}
fv := val.Field(i)
out[name] = valueToAny(fv, opts.Contains("omitempty"))
}
return out
}
// valueToAny 递归处理各种值类型。
func valueToAny(v reflect.Value, omitEmpty bool) any {
if !v.IsValid() {
return nil
}
for v.Kind() == reflect.Pointer {
if v.IsNil() {
if omitEmpty {
return nil
}
// 保持与 encoding/json 行为一致nil 指针写成 null
return nil
}
v = v.Elem()
}
switch v.Kind() {
case reflect.Struct:
return structValueToMap(v)
case reflect.Slice, reflect.Array:
l := v.Len()
arr := make([]any, l)
for i := 0; i < l; i++ {
arr[i] = valueToAny(v.Index(i), false)
}
return arr
case reflect.Map:
m := make(map[string]any, v.Len())
iter := v.MapRange()
for iter.Next() {
k := iter.Key()
// 只支持 string key与 encoding/json 一致
if k.Kind() == reflect.String {
m[k.String()] = valueToAny(iter.Value(), false)
}
}
return m
default:
// 基本类型直接返回其接口值
return v.Interface()
}
}
// tagOptions 用于判断是否包含 "omitempty"
type tagOptions string
func (o tagOptions) Contains(opt string) bool {
if len(o) == 0 {
return false
}
for _, s := range splitComma(string(o)) {
if s == opt {
return true
}
}
return false
}
func parseTag(tag string) (string, tagOptions) {
if idx := indexComma(tag); idx != -1 {
return tag[:idx], tagOptions(tag[idx+1:])
}
return tag, tagOptions("")
}
// 避免 strings.Split 额外分配
func indexComma(s string) int {
for i, r := range s {
if r == ',' {
return i
}
}
return -1
}
func splitComma(s string) []string {
var parts []string
start := 0
for i, r := range s {
if r == ',' {
parts = append(parts, s[start:i])
start = i + 1
}
}
if start <= len(s) {
parts = append(parts, s[start:])
}
return parts
}

View File

@@ -1,7 +0,0 @@
package constant
var (
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
ChanelSettingProxy = "proxy" // Proxy 代理
ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent
)

View File

@@ -3,9 +3,14 @@ package constant
type EndpointType string
const (
EndpointTypeOpenAI EndpointType = "openai"
EndpointTypeOpenAIResponse EndpointType = "openai-response"
EndpointTypeAnthropic EndpointType = "anthropic"
EndpointTypeGemini EndpointType = "gemini"
EndpointTypeJinaRerank EndpointType = "jina-rerank"
EndpointTypeOpenAI EndpointType = "openai"
EndpointTypeOpenAIResponse EndpointType = "openai-response"
EndpointTypeAnthropic EndpointType = "anthropic"
EndpointTypeGemini EndpointType = "gemini"
EndpointTypeJinaRerank EndpointType = "jina-rerank"
EndpointTypeImageGeneration EndpointType = "image-generation"
//EndpointTypeMidjourney EndpointType = "midjourney-proxy"
//EndpointTypeSuno EndpointType = "suno-proxy"
//EndpointTypeKling EndpointType = "kling"
//EndpointTypeJimeng EndpointType = "jimeng"
)

View File

@@ -22,6 +22,8 @@ const (
MjActionPan = "PAN"
MjActionSwapFace = "SWAP_FACE"
MjActionUpload = "UPLOAD"
MjActionVideo = "VIDEO"
MjActionEdits = "EDITS"
)
var MidjourneyModel2Action = map[string]string{
@@ -41,4 +43,6 @@ var MidjourneyModel2Action = map[string]string{
"mj_pan": MjActionPan,
"swap_face": MjActionSwapFace,
"mj_upload": MjActionUpload,
"mj_video": MjActionVideo,
"mj_edits": MjActionEdits,
}

View File

@@ -1,16 +0,0 @@
package constant
var (
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
UserSettingRecordIpLog = "record_ip_log" // 是否记录请求和错误日志IP
)
var (
NotifyTypeEmail = "email" // Email 邮件
NotifyTypeWebhook = "webhook" // Webhook
)

View File

@@ -173,8 +173,19 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
consumedTime := float64(milliseconds) / 1000.0
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
quota, "模型测试", 0, quota, int(consumedTime), false, info.UsingGroup, other)
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
ChannelId: channel.Id,
PromptTokens: usage.PromptTokens,
CompletionTokens: usage.CompletionTokens,
ModelName: info.OriginModelName,
TokenName: "模型测试",
Quota: quota,
Content: "模型测试",
UseTimeSeconds: int(consumedTime),
IsStream: false,
Group: info.UsingGroup,
Other: other,
})
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil
}
@@ -202,7 +213,7 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
testRequest.MaxTokens = 50
}
} else if strings.Contains(model, "gemini") {
testRequest.MaxTokens = 300
testRequest.MaxTokens = 3000
} else {
testRequest.MaxTokens = 10
}

View File

@@ -228,7 +228,7 @@ func FetchUpstreamModels(c *gin.Context) {
}
func FixChannelsAbilities(c *gin.Context) {
count, err := model.FixAbility()
success, fails, err := model.FixAbility()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -239,7 +239,10 @@ func FixChannelsAbilities(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": count,
"data": gin.H{
"success": success,
"fails": fails,
},
})
}
@@ -387,6 +390,14 @@ func AddChannel(c *gin.Context) {
})
return
}
err = channel.ValidateSettings()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "channel setting 格式错误:" + err.Error(),
})
return
}
channel.CreatedTime = common.GetTimestamp()
keys := strings.Split(channel.Key, "\n")
if channel.Type == constant.ChannelTypeVertexAi {
@@ -614,6 +625,14 @@ func UpdateChannel(c *gin.Context) {
})
return
}
err = channel.ValidateSettings()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "channel setting 格式错误:" + err.Error(),
})
return
}
if channel.Type == constant.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{

View File

@@ -1,12 +1,14 @@
package controller
import (
"slices"
"encoding/json"
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/middleware"
"one-api/middleware/jsrt"
"one-api/model"
"one-api/setting"
"one-api/setting/console_setting"
@@ -33,7 +35,6 @@ func TestStatus(c *gin.Context) {
"message": "Server is running",
"http_stats": httpStats,
})
return
}
func GetStatus(c *gin.Context) {
@@ -106,7 +107,6 @@ func GetStatus(c *gin.Context) {
"message": "",
"data": data,
})
return
}
func GetNotice(c *gin.Context) {
@@ -117,7 +117,6 @@ func GetNotice(c *gin.Context) {
"message": "",
"data": common.OptionMap["Notice"],
})
return
}
func GetAbout(c *gin.Context) {
@@ -128,7 +127,6 @@ func GetAbout(c *gin.Context) {
"message": "",
"data": common.OptionMap["About"],
})
return
}
func GetMidjourney(c *gin.Context) {
@@ -139,7 +137,6 @@ func GetMidjourney(c *gin.Context) {
"message": "",
"data": common.OptionMap["Midjourney"],
})
return
}
func GetHomePageContent(c *gin.Context) {
@@ -150,7 +147,6 @@ func GetHomePageContent(c *gin.Context) {
"message": "",
"data": common.OptionMap["HomePageContent"],
})
return
}
func SendEmailVerification(c *gin.Context) {
@@ -173,13 +169,7 @@ func SendEmailVerification(c *gin.Context) {
localPart := parts[0]
domainPart := parts[1]
if common.EmailDomainRestrictionEnabled {
allowed := false
for _, domain := range common.EmailDomainWhitelist {
if domainPart == domain {
allowed = true
break
}
}
allowed := slices.Contains(common.EmailDomainWhitelist, domainPart)
if !allowed {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -224,7 +214,6 @@ func SendEmailVerification(c *gin.Context) {
"success": true,
"message": "",
})
return
}
func SendPasswordResetEmail(c *gin.Context) {
@@ -263,7 +252,6 @@ func SendPasswordResetEmail(c *gin.Context) {
"success": true,
"message": "",
})
return
}
type PasswordResetRequest struct {
@@ -303,5 +291,13 @@ func ResetPassword(c *gin.Context) {
"message": "",
"data": password,
})
return
}
func ReloadJSScripts(c *gin.Context) {
jsrt.ReloadJSScripts()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "JavaScript 脚本已重新加载",
})
}

View File

@@ -130,7 +130,7 @@ func ListModels(c *gin.Context) {
}
} else {
userId := c.GetInt("id")
userGroup, err := model.GetUserGroup(userId, true)
userGroup, err := model.GetUserGroup(userId, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -139,7 +139,7 @@ func ListModels(c *gin.Context) {
return
}
group := userGroup
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
if tokenGroup != "" {
group = tokenGroup
}

View File

@@ -122,7 +122,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
}
if resp.StatusCode != http.StatusOK {
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
return fmt.Errorf("Get Task status code: %d", resp.StatusCode)
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"net/url"
"one-api/common"
"one-api/dto"
"one-api/model"
"one-api/setting"
"strconv"
@@ -961,7 +962,7 @@ func UpdateUserSetting(c *gin.Context) {
}
// 验证预警类型
if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的预警类型",
@@ -979,7 +980,7 @@ func UpdateUserSetting(c *gin.Context) {
}
// 如果是webhook类型,验证webhook地址
if req.QuotaWarningType == constant.NotifyTypeWebhook {
if req.QuotaWarningType == dto.NotifyTypeWebhook {
if req.WebhookUrl == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -998,7 +999,7 @@ func UpdateUserSetting(c *gin.Context) {
}
// 如果是邮件类型,验证邮箱地址
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
// 验证邮箱格式
if !strings.Contains(req.NotificationEmail, "@") {
c.JSON(http.StatusOK, gin.H{
@@ -1020,24 +1021,24 @@ func UpdateUserSetting(c *gin.Context) {
}
// 构建设置
settings := map[string]interface{}{
constant.UserSettingNotifyType: req.QuotaWarningType,
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
"accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel,
constant.UserSettingRecordIpLog: req.RecordIpLog,
settings := dto.UserSetting{
NotifyType: req.QuotaWarningType,
QuotaWarningThreshold: req.QuotaWarningThreshold,
AcceptUnsetRatioModel: req.AcceptUnsetModelRatioModel,
RecordIpLog: req.RecordIpLog,
}
// 如果是webhook类型,添加webhook相关设置
if req.QuotaWarningType == constant.NotifyTypeWebhook {
settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
if req.QuotaWarningType == dto.NotifyTypeWebhook {
settings.WebhookUrl = req.WebhookUrl
if req.WebhookSecret != "" {
settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
settings.WebhookSecret = req.WebhookSecret
}
}
// 如果提供了通知邮箱,添加到设置中
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
settings.NotificationEmail = req.NotificationEmail
}
// 更新用户设置

View File

@@ -11,6 +11,7 @@ services:
volumes:
- ./data:/data
- ./logs:/app/logs
- ${JS_SCRIPT_DIR:-./scripts}:/app/scripts
environment:
- SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service
- REDIS_CONN_STRING=redis://redis
@@ -21,7 +22,6 @@ services:
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
# - FRONTEND_BASE_URL=https://openai.justsong.cn # Uncomment for multi-node deployment with front-end URL
depends_on:
- redis
- mysql

7
dto/channel_settings.go Normal file
View File

@@ -0,0 +1,7 @@
package dto
type ChannelSettings struct {
ForceFormat bool `json:"force_format,omitempty"`
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
Proxy string `json:"proxy"`
}

View File

@@ -57,6 +57,8 @@ type MidjourneyDto struct {
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
ImageUrl string `json:"imageUrl"`
VideoUrl string `json:"videoUrl"`
VideoUrls []ImgUrls `json:"videoUrls"`
Status string `json:"status"`
Progress string `json:"progress"`
FailReason string `json:"failReason"`
@@ -65,6 +67,10 @@ type MidjourneyDto struct {
Properties *Properties `json:"properties"`
}
type ImgUrls struct {
Url string `json:"url"`
}
type MidjourneyStatus struct {
Status int `json:"status"`
}

View File

@@ -4,7 +4,7 @@ type RerankRequest struct {
Documents []any `json:"documents"`
Query string `json:"query"`
Model string `json:"model"`
TopN int `json:"top_n"`
TopN int `json:"top_n,omitempty"`
ReturnDocuments *bool `json:"return_documents,omitempty"`
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
OverLapTokens int `json:"overlap_tokens,omitempty"`

16
dto/user_settings.go Normal file
View File

@@ -0,0 +1,16 @@
package dto
type UserSetting struct {
NotifyType string `json:"notify_type,omitempty"` // QuotaWarningType 额度预警类型
QuotaWarningThreshold float64 `json:"quota_warning_threshold,omitempty"` // QuotaWarningThreshold 额度预警阈值
WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址
WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥
NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址
AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型
RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP
}
var (
NotifyTypeEmail = "email" // Email 邮件
NotifyTypeWebhook = "webhook" // Webhook
)

5
go.mod
View File

@@ -11,6 +11,7 @@ require (
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
github.com/dop251/goja v0.0.0-20250630131328-58d95d85e994
github.com/gin-contrib/cors v1.7.2
github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5
@@ -31,6 +32,7 @@ require (
golang.org/x/crypto v0.35.0
golang.org/x/image v0.23.0
golang.org/x/net v0.35.0
golang.org/x/sync v0.11.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/gorm v1.25.2
@@ -56,9 +58,11 @@ require (
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
@@ -84,7 +88,6 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.12.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/sync v0.11.0 // indirect
golang.org/x/sys v0.30.0 // indirect
golang.org/x/text v0.22.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect

10
go.sum
View File

@@ -1,5 +1,7 @@
github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
@@ -40,6 +42,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dop251/goja v0.0.0-20250630131328-58d95d85e994 h1:aQYWswi+hRL2zJqGacdCZx32XjKYV8ApXFGntw79XAM=
github.com/dop251/goja v0.0.0-20250630131328-58d95d85e994/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
@@ -83,6 +87,8 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU=
github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
@@ -97,8 +103,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U=
github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=

1041
i18n/zh-cn.json Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -39,7 +39,6 @@ func main() {
return
}
common.SetupLogger()
common.SysLog("New API " + common.Version + " started")
if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode)
@@ -69,9 +68,9 @@ func main() {
if r := recover(); r != nil {
common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
// Retry once
_, fixErr := model.FixAbility()
_, _, fixErr := model.FixAbility()
if fixErr != nil {
common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
}
}
}()
@@ -172,6 +171,8 @@ func InitResources() error {
// 加载环境变量
common.InitEnv()
common.SetupLogger()
// Initialize model settings
ratio_setting.InitRatioSettings()

View File

@@ -247,9 +247,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
}
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("channel_type", channel.Type)
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
c.Set("channel_create_time", channel.CreatedTime)
c.Set("channel_setting", channel.GetSetting())
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
c.Set("param_override", channel.GetParamOverride())
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization)
@@ -258,7 +258,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("model_mapping", channel.GetModelMapping())
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
common.SetContextKey(c, constant.ContextKeyBaseUrl, channel.GetBaseURL())
// TODO: api_version统一
switch channel.Type {
case constant.ChannelTypeAzure:

62
middleware/jsrt/cfg.go Normal file
View File

@@ -0,0 +1,62 @@
package jsrt
import (
"os"
"strconv"
"time"
)
// Runtime 配置
type JSRuntimeConfig struct {
Enabled bool `json:"enabled"`
MaxVMCount int `json:"max_vm_count"`
ScriptTimeout time.Duration `json:"script_timeout"`
ScriptDir string `json:"script_dir"`
FetchTimeout time.Duration `json:"fetch_timeout"`
}
var (
jsConfig = JSRuntimeConfig{}
)
const (
defaultScriptDir = "scripts/"
defaultScriptTimeout = 5 * time.Second
defaultFetchTimeout = 10 * time.Second
defaultMaxVMCount = 8
)
func loadCfg() {
if enabled := os.Getenv("JS_RUNTIME_ENABLED"); enabled != "" {
jsConfig.Enabled = enabled == "true"
}
if maxCount := os.Getenv("JS_MAX_VM_COUNT"); maxCount != "" {
if count, err := strconv.Atoi(maxCount); err == nil && count > 0 {
jsConfig.MaxVMCount = count
}
} else {
jsConfig.MaxVMCount = defaultMaxVMCount
}
if timeout := os.Getenv("JS_SCRIPT_TIMEOUT"); timeout != "" {
if t, err := time.ParseDuration(timeout + "s"); err == nil && t > 0 {
jsConfig.ScriptTimeout = t
}
} else {
jsConfig.ScriptTimeout = defaultScriptTimeout
}
if fetchTimeout := os.Getenv("JS_FETCH_TIMEOUT"); fetchTimeout != "" {
if t, err := time.ParseDuration(fetchTimeout + "s"); err == nil && t > 0 {
jsConfig.FetchTimeout = t
}
} else {
jsConfig.FetchTimeout = defaultFetchTimeout
}
jsConfig.ScriptDir = os.Getenv("JS_SCRIPT_DIR")
if jsConfig.ScriptDir == "" {
jsConfig.ScriptDir = defaultScriptDir
}
}

69
middleware/jsrt/db.go Normal file
View File

@@ -0,0 +1,69 @@
package jsrt
import (
"one-api/common"
"gorm.io/gorm"
)
func dbQuery(db *gorm.DB, sql string, args ...any) []map[string]any {
if db == nil {
common.SysError("JS DB is nil")
return nil
}
rows, err := db.Raw(sql, args...).Rows()
if err != nil {
common.SysError("JS DB Query Error: " + err.Error())
return nil
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
common.SysError("JS DB Columns Error: " + err.Error())
return nil
}
results := make([]map[string]any, 0, 100)
for rows.Next() {
values := make([]any, len(columns))
valuePtrs := make([]any, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
common.SysError("JS DB Scan Error: " + err.Error())
continue
}
row := make(map[string]any, len(columns))
for i, col := range columns {
val := values[i]
if b, ok := val.([]byte); ok {
row[col] = string(b)
} else {
row[col] = val
}
}
results = append(results, row)
}
return results
}
func dbExec(db *gorm.DB, sql string, args ...any) map[string]any {
if db == nil {
return map[string]any{
"rowsAffected": int64(0),
"error": "database is nil",
}
}
result := db.Exec(sql, args...)
return map[string]any{
"rowsAffected": result.RowsAffected,
"error": result.Error,
}
}

137
middleware/jsrt/fetch.go Normal file
View File

@@ -0,0 +1,137 @@
package jsrt
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
type JSFetchRequest struct {
Method string `json:"method"`
URL string `json:"url"`
Headers map[string]string `json:"headers"`
Body any `json:"body"`
Timeout int `json:"timeout"`
}
type JSFetchResponse struct {
Status int `json:"status"`
Headers map[string]string `json:"headers"`
Body string `json:"body"`
Error string `json:"error,omitempty"`
}
func (p *JSRuntimePool) fetch(url string, options ...any) *JSFetchResponse {
req := &JSFetchRequest{
Method: "GET",
URL: url,
Headers: make(map[string]string),
Timeout: int(jsConfig.FetchTimeout.Seconds()),
}
// 解析选项
if len(options) > 0 && options[0] != nil {
if optMap, ok := options[0].(map[string]any); ok {
if method, exists := optMap["method"]; exists {
if methodStr, ok := method.(string); ok {
req.Method = strings.ToUpper(methodStr)
}
}
if headers, exists := optMap["headers"]; exists {
if headersMap, ok := headers.(map[string]any); ok {
for k, v := range headersMap {
if vStr, ok := v.(string); ok {
req.Headers[k] = vStr
}
}
}
}
if body, exists := optMap["body"]; exists {
req.Body = body
}
if timeout, exists := optMap["timeout"]; exists {
if timeoutNum, ok := timeout.(float64); ok {
req.Timeout = int(timeoutNum)
}
}
}
}
// 创建HTTP请求
var bodyReader io.Reader
switch body := req.Body.(type) {
case string:
bodyReader = strings.NewReader(body)
case []byte:
bodyReader = bytes.NewReader(body)
case nil:
bodyReader = nil
default:
bodyBytes, err := json.Marshal(body)
if err != nil {
return &JSFetchResponse{
Error: fmt.Sprintf("Failed to marshal body: %v", err),
}
}
bodyReader = bytes.NewReader(bodyBytes)
}
httpReq, err := http.NewRequest(req.Method, req.URL, bodyReader)
if err != nil {
return &JSFetchResponse{
Error: err.Error(),
}
}
// 设置请求头
for k, v := range req.Headers {
httpReq.Header.Set(k, v)
}
// 设置默认User-Agent
if httpReq.Header.Get("User-Agent") == "" {
httpReq.Header.Set("User-Agent", "JS-Runtime-Fetch/1.0")
}
// 创建带超时的上下文
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(req.Timeout)*time.Second)
defer cancel()
httpReq = httpReq.WithContext(ctx)
// 执行请求
resp, err := p.httpClient.Do(httpReq)
if err != nil {
return &JSFetchResponse{}
}
defer resp.Body.Close()
// 读取响应体
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return &JSFetchResponse{
Status: resp.StatusCode,
}
}
// 构建响应头
headers := make(map[string]string)
for k, v := range resp.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
return &JSFetchResponse{
Status: resp.StatusCode,
Headers: headers,
Body: string(bodyBytes),
}
}

570
middleware/jsrt/jsrt.go Normal file
View File

@@ -0,0 +1,570 @@
package jsrt
import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/model"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/dop251/goja"
"github.com/gin-gonic/gin"
)
// 池化
type JSRuntimePool struct {
pool chan *goja.Runtime
maxSize int
createFunc func() *goja.Runtime
scripts string
mu sync.RWMutex
httpClient *http.Client
}
var (
jsRuntimePool *JSRuntimePool
jsPoolOnce sync.Once
)
func NewJSRuntimePool(maxSize int) *JSRuntimePool {
// 创建HTTP客户端
httpClient := &http.Client{
Timeout: jsConfig.FetchTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: false,
},
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
},
}
pool := &JSRuntimePool{
pool: make(chan *goja.Runtime, maxSize),
maxSize: maxSize,
scripts: "",
httpClient: httpClient,
}
pool.createFunc = func() *goja.Runtime {
vm := goja.New()
pool.setupGlobals(vm)
pool.loadScripts(vm)
return vm
}
// 预创建VM
preCreate := min(maxSize/2, 4)
for range preCreate {
select {
case pool.pool <- pool.createFunc():
default:
}
}
return pool
}
func (p *JSRuntimePool) Get() *goja.Runtime {
select {
case vm := <-p.pool:
return vm
default:
return p.createFunc()
}
}
func (p *JSRuntimePool) Put(vm *goja.Runtime) {
if vm == nil {
return
}
select {
case p.pool <- vm:
default:
// 池满丢弃VM让GC回收
}
}
func (p *JSRuntimePool) setupGlobals(vm *goja.Runtime) {
// console
console := vm.NewObject()
console.Set("log", func(args ...any) {
var strs []string
for _, arg := range args {
strs = append(strs, fmt.Sprintf("%v", arg))
}
common.SysLog("JS: " + strings.Join(strs, " "))
})
console.Set("error", func(args ...any) {
var strs []string
for _, arg := range args {
strs = append(strs, fmt.Sprintf("%v", arg))
}
common.SysError("JS: " + strings.Join(strs, " "))
})
console.Set("warn", func(args ...any) {
var strs []string
for _, arg := range args {
strs = append(strs, fmt.Sprintf("%v", arg))
}
common.SysError("JS WARN: " + strings.Join(strs, " "))
})
vm.Set("console", console)
// JSON
jsonObj := vm.NewObject()
jsonObj.Set("parse", func(str string) any {
var result any
err := json.Unmarshal([]byte(str), &result)
if err != nil {
panic(vm.ToValue(err.Error()))
}
return result
})
jsonObj.Set("stringify", func(obj any) string {
data, err := json.Marshal(obj)
if err != nil {
panic(vm.ToValue(err.Error()))
}
return string(data)
})
vm.Set("JSON", jsonObj)
// fetch 实现
vm.Set("fetch", func(url string, options ...any) *JSFetchResponse {
return p.fetch(url, options...)
})
// 数据库
setDB(vm, model.DB, "db")
setDB(vm, model.LOG_DB, "logdb")
// 定时器
vm.Set("setTimeout", func(fn func(), delay int) {
go func() {
time.Sleep(time.Duration(delay) * time.Millisecond)
fn()
}()
})
}
func (p *JSRuntimePool) loadScripts(vm *goja.Runtime) {
p.mu.RLock()
defer p.mu.RUnlock()
// 如果已经缓存了合并的脚本,直接使用
if p.scripts != "" {
if _, err := vm.RunString(p.scripts); err != nil {
common.SysError("Failed to load cached scripts: " + err.Error())
}
return
}
// 首次加载时,读取 scripts/ 文件夹中的所有脚本
p.mu.RUnlock()
p.mu.Lock()
defer func() {
p.mu.Unlock()
p.mu.RLock()
}()
if p.scripts != "" {
if _, err := vm.RunString(p.scripts); err != nil {
common.SysError("Failed to load cached scripts: " + err.Error())
}
return
}
// 读取所有脚本文件
var combinedScript strings.Builder
scriptDir := jsConfig.ScriptDir
// 检查目录是否存在
if _, err := os.Stat(scriptDir); os.IsNotExist(err) {
common.SysLog("Scripts directory does not exist: " + scriptDir)
return
}
// 读取目录中的所有 .js 文件
files, err := filepath.Glob(filepath.Join(scriptDir, "*.js"))
if err != nil {
common.SysError("Failed to read scripts directory: " + err.Error())
return
}
if len(files) == 0 {
common.SysLog("No JavaScript files found in: " + scriptDir)
return
}
// 按文件名排序以确保加载顺序一致
for _, file := range files {
content, err := os.ReadFile(file)
if err != nil {
common.SysError("Failed to read script file " + file + ": " + err.Error())
continue
}
// 添加文件注释和内容
combinedScript.WriteString("// File: " + filepath.Base(file) + "\n")
combinedScript.WriteString(string(content))
combinedScript.WriteString("\n\n")
common.SysLog("Loaded script: " + filepath.Base(file))
}
// 缓存合并后的脚本
p.scripts = combinedScript.String()
// 执行脚本
if p.scripts != "" {
if _, err := vm.RunString(p.scripts); err != nil {
common.SysError("Failed to load combined scripts: " + err.Error())
} else {
common.SysLog("Successfully loaded and combined all JavaScript files from: " + scriptDir)
}
}
}
func (p *JSRuntimePool) ReloadScripts() {
p.mu.Lock()
defer p.mu.Unlock()
// 清空缓存的脚本
p.scripts = ""
// 清空VM池强制重新创建
for {
select {
case <-p.pool:
default:
goto done
}
}
done:
common.SysLog("JavaScript scripts reloaded")
}
func initJSRuntimePool() *JSRuntimePool {
jsPoolOnce.Do(func() {
jsRuntimePool = NewJSRuntimePool(jsConfig.MaxVMCount)
common.SysLog("JavaScript runtime pool initialized successfully")
})
return jsRuntimePool
}
func validateGinContext(c *gin.Context) error {
if c == nil {
return fmt.Errorf("gin context is nil")
}
if c.Request == nil {
return fmt.Errorf("gin context request is nil")
}
return nil
}
func (p *JSRuntimePool) executeWithTimeout(_ *goja.Runtime, fn func() (goja.Value, error)) (goja.Value, error) {
type result struct {
value goja.Value
err error
}
resultChan := make(chan result, 1)
go func() {
defer func() {
if r := recover(); r != nil {
resultChan <- result{err: fmt.Errorf("JS panic: %v", r)}
}
}()
value, err := fn()
resultChan <- result{value: value, err: err}
}()
select {
case res := <-resultChan:
return res.value, res.err
case <-time.After(jsConfig.ScriptTimeout):
return nil, fmt.Errorf("script execution timeout after %v", jsConfig.ScriptTimeout)
}
}
func (p *JSRuntimePool) PreProcessRequest(c *gin.Context) error {
if err := validateGinContext(c); err != nil {
common.SysError("JS PreProcess Validation Error: " + err.Error())
return err
}
vm := p.Get()
defer p.Put(vm)
preProcessFunc := vm.Get("preProcessRequest")
if preProcessFunc == nil || goja.IsUndefined(preProcessFunc) {
return nil
}
jsReq, err := common.StructToMap(createJSReq(c))
if err != nil {
return fmt.Errorf("failed to create JS context: %v", err)
}
result, err := p.executeWithTimeout(vm, func() (goja.Value, error) {
fn, ok := goja.AssertFunction(preProcessFunc)
if !ok {
return nil, fmt.Errorf("preProcessRequest is not a function")
}
return fn(goja.Undefined(), vm.ToValue(jsReq))
})
if err != nil {
common.SysError("JS PreProcess Error: " + err.Error())
return err
}
// 处理返回结果
if result != nil && !goja.IsUndefined(result) {
resultObj := result.Export()
if resultMap, ok := resultObj.(map[string]any); ok {
// 是否修改请求
if newBody, exists := resultMap["body"]; exists {
switch v := newBody.(type) {
case string:
c.Request.Body = io.NopCloser(strings.NewReader(v))
c.Request.ContentLength = int64(len(v))
case []byte:
c.Request.Body = io.NopCloser(bytes.NewBuffer(v))
c.Request.ContentLength = int64(len(v))
case map[string]any:
bodyBytes, err := json.Marshal(v)
if err == nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
c.Request.ContentLength = int64(len(bodyBytes))
c.Request.Header.Set("Content-Type", "application/json")
} else {
common.SysError("JS PreProcess JSON Marshal Error: " + err.Error())
}
default:
common.SysError("JS PreProcess Unsupported Body Type: " + fmt.Sprintf("%T", newBody))
}
}
// 是否修改 headers
if newHeaders, exists := resultMap["headers"]; exists {
if headersMap, ok := newHeaders.(map[string]any); ok {
for key, value := range headersMap {
if valueStr, ok := value.(string); ok {
c.Request.Header.Set(key, valueStr)
}
}
}
}
// 是否阻止请求
if block, exists := resultMap["block"]; exists {
if blockBool, ok := block.(bool); ok && blockBool {
status := http.StatusForbidden
if statusCode, exists := resultMap["statusCode"]; exists {
if statusInt, ok := statusCode.(float64); ok {
status = int(statusInt)
}
}
message := "Request blocked by pre-process script"
if msg, exists := resultMap["message"]; exists {
if msgStr, ok := msg.(string); ok {
message = msgStr
}
}
c.JSON(status, gin.H{"error": message})
c.Abort()
return fmt.Errorf("request blocked")
}
}
}
}
return nil
}
func (p *JSRuntimePool) PostProcessResponse(c *gin.Context, statusCode int, body []byte) (int, []byte, error) {
if err := validateGinContext(c); err != nil {
common.SysError("JS PostProcess Validation Error: " + err.Error())
return statusCode, body, err
}
vm := p.Get()
defer p.Put(vm)
postProcessFunc := vm.Get("postProcessResponse")
if postProcessFunc == nil || goja.IsUndefined(postProcessFunc) {
return statusCode, body, nil
}
jsReq, err := common.StructToMap(createJSReq(c))
if err != nil {
return statusCode, body, fmt.Errorf("failed to create JS context: %v", err)
}
jsResp := &JSResponse{
StatusCode: statusCode,
Headers: make(map[string]string),
Body: string(body),
}
// 获取响应头
if c.Writer != nil {
for key, values := range c.Writer.Header() {
if len(values) > 0 {
jsResp.Headers[key] = values[0]
}
}
}
jsResponse, err := common.StructToMap(jsResp)
if err != nil {
return statusCode, body, fmt.Errorf("failed to create JS response context: %v", err)
}
result, err := p.executeWithTimeout(vm, func() (goja.Value, error) {
fn, ok := goja.AssertFunction(postProcessFunc)
if !ok {
return nil, fmt.Errorf("postProcessResponse is not a function")
}
return fn(goja.Undefined(), vm.ToValue(jsReq), vm.ToValue(jsResponse))
})
if err != nil {
common.SysError("JS PostProcess Error: " + err.Error())
return statusCode, body, err
}
// 处理返回
if result != nil && !goja.IsUndefined(result) {
resultObj := result.Export()
if resultMap, ok := resultObj.(map[string]any); ok {
if newStatusCode, exists := resultMap["statusCode"]; exists {
if statusInt, ok := newStatusCode.(float64); ok {
statusCode = int(statusInt)
}
}
if newBody, exists := resultMap["body"]; exists {
if bodyStr, ok := newBody.(string); ok {
body = []byte(bodyStr)
}
}
if newHeaders, exists := resultMap["headers"]; exists {
if headersMap, ok := newHeaders.(map[string]any); ok {
for key, value := range headersMap {
if valueStr, ok := value.(string); ok {
c.Header(key, valueStr)
}
}
}
}
}
}
return statusCode, body, nil
}
func (p *JSRuntimePool) hasPostProcessFunction() bool {
vm := p.Get()
defer p.Put(vm)
postProcessFunc := vm.Get("postProcessResponse")
return postProcessFunc != nil && !goja.IsUndefined(postProcessFunc)
}
func JSRuntimeMiddleware() *gin.HandlerFunc {
loadCfg()
if !jsConfig.Enabled {
common.SysLog("JavaScript Runtime is disabled")
return nil
}
pool := initJSRuntimePool()
var fn gin.HandlerFunc
fn = func(c *gin.Context) {
start := time.Now()
// 预处理
if err := pool.PreProcessRequest(c); err != nil {
common.SysError("JS Runtime PreProcess Error: " + err.Error())
return
}
duration := time.Since(start)
if duration > time.Millisecond*100 {
common.SysLog(fmt.Sprintf("JS Runtime PreProcess took %v", duration))
}
// 后处理
if pool.hasPostProcessFunction() {
writer := newResponseWriter(c.Writer)
c.Writer = writer
c.Next()
// 后处理响应
if writer.body.Len() > 0 {
start := time.Now()
statusCode, body, err := pool.PostProcessResponse(c, writer.statusCode, writer.body.Bytes())
if err == nil {
c.Writer = writer.ResponseWriter
for k, v := range writer.headerMap {
for _, value := range v {
c.Writer.Header().Add(k, value)
}
}
c.Status(statusCode)
if len(body) >= 0 {
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
c.Writer.Write(body)
} else {
c.Writer.Header().Del("Content-Length")
c.Writer.Write(body)
}
} else {
// 出错时回复原响应
c.Writer = writer.ResponseWriter
c.Status(writer.statusCode)
common.SysError(fmt.Sprintf("JS Runtime PostProcess Error: %v", err))
}
duration := time.Since(start)
if duration > time.Millisecond*100 {
common.SysLog(fmt.Sprintf("JS Runtime PostProcess took %v", duration))
}
} else {
// 没有响应体时恢复原始writer
c.Writer = writer.ResponseWriter
}
} else {
c.Next()
}
}
return &fn
}
func ReloadJSScripts() {
if jsRuntimePool != nil {
jsRuntimePool.ReloadScripts()
common.SysLog("JavaScript scripts reloaded")
}
}

139
middleware/jsrt/req.go Normal file
View File

@@ -0,0 +1,139 @@
package jsrt
import (
"bytes"
"io"
"maps"
"net/http"
"sync"
"github.com/gin-gonic/gin"
)
// 请求
type JSReq struct {
Method string `json:"method"`
URL string `json:"url"`
Headers map[string]string `json:"headers"`
Body any `json:"body"`
UserAgent string `json:"userAgent"`
RemoteIP string `json:"remoteIP"`
Extra map[string]any `json:"extra"`
}
type JSResponse struct {
StatusCode int `json:"statusCode"`
Headers map[string]string `json:"headers"`
Body string `json:"body"`
}
type responseWriter struct {
gin.ResponseWriter
body *bytes.Buffer
statusCode int
headerMap http.Header
written bool
mu sync.RWMutex
}
func createJSReq(c *gin.Context) *JSReq {
var bodyBytes []byte
if c.Request != nil && c.Request.Body != nil {
bodyBytes, _ = io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}
// headers map
headers := make(map[string]string)
if c.Request != nil && c.Request.Header != nil {
for key, values := range c.Request.Header {
if len(values) > 0 {
headers[key] = values[0]
}
}
}
method := ""
url := ""
userAgent := ""
remoteIP := ""
contentType := ""
if c.Request != nil {
method = c.Request.Method
if c.Request.URL != nil {
url = c.Request.URL.String()
}
userAgent = c.Request.UserAgent()
contentType = c.ContentType()
}
if c != nil {
remoteIP = c.ClientIP()
}
parsedBody := parseBodyByType(bodyBytes, contentType)
return &JSReq{
Method: method,
URL: url,
Headers: headers,
Body: parsedBody,
UserAgent: userAgent,
RemoteIP: remoteIP,
Extra: make(map[string]any),
}
}
func newResponseWriter(w gin.ResponseWriter) *responseWriter {
return &responseWriter{
ResponseWriter: w,
body: &bytes.Buffer{},
statusCode: 200,
headerMap: make(http.Header),
written: false,
}
}
func (w *responseWriter) Write(data []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
if !w.written {
w.WriteHeader(200)
}
return w.body.Write(data)
}
func (w *responseWriter) WriteString(s string) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
if !w.written {
w.WriteHeader(200)
}
return w.body.WriteString(s)
}
func (w *responseWriter) WriteHeader(statusCode int) {
w.mu.Lock()
defer w.mu.Unlock()
if w.written {
return
}
w.statusCode = statusCode
w.written = true
maps.Copy(w.headerMap, w.ResponseWriter.Header())
}
func (w *responseWriter) Header() http.Header {
w.mu.RLock()
defer w.mu.RUnlock()
if w.headerMap == nil {
w.headerMap = make(http.Header)
}
return w.headerMap
}

86
middleware/jsrt/utils.go Normal file
View File

@@ -0,0 +1,86 @@
package jsrt
import (
"encoding/json"
"net/url"
"one-api/common"
"strings"
"github.com/dop251/goja"
"gorm.io/gorm"
)
func setDB(vm *goja.Runtime, db *gorm.DB, name string) {
if db == nil {
common.SysError("JS DB is nil")
return
}
obj := vm.NewObject()
obj.Set("query", func(sql string, params ...any) []map[string]any {
return dbQuery(db, sql, params...)
})
obj.Set("exec", func(sql string, params ...any) map[string]any {
return dbExec(db, sql, params...)
})
if err := vm.Set(name, obj); err != nil {
common.SysError("Failed to set JS DB: " + err.Error())
return
}
}
func parseBodyByType(bodyBytes []byte, contentType string) any {
if len(bodyBytes) == 0 {
return ""
}
bodyStr := string(bodyBytes)
contentLower := strings.ToLower(contentType)
switch {
case strings.Contains(contentLower, "application/json"):
var jsonObj any
if err := json.Unmarshal(bodyBytes, &jsonObj); err == nil {
return jsonObj
}
return bodyStr
case strings.Contains(contentLower, "application/x-www-form-urlencoded"):
if values, err := url.ParseQuery(bodyStr); err == nil {
result := make(map[string]string, len(values))
for k, v := range values {
if len(v) > 0 {
result[k] = v[0]
}
}
return result
}
return bodyStr
case strings.Contains(contentLower, "multipart/form-data"):
return bodyBytes
case strings.Contains(contentLower, "text/"):
return bodyStr
default:
// 尝试JSON解析
var jsonObj any
if json.Unmarshal(bodyBytes, &jsonObj) == nil {
return jsonObj
}
// 尝试form解析
if values, err := url.ParseQuery(bodyStr); err == nil && len(values) > 0 {
result := make(map[string]string, len(values))
for k, v := range values {
if len(v) > 0 {
result[k] = v[0]
}
}
return result
}
return bodyStr
}
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"one-api/common"
"strings"
"sync"
"github.com/samber/lo"
"gorm.io/gorm"
@@ -272,74 +273,45 @@ func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uin
return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
}
func FixAbility() (int, error) {
var channelIds []int
count := 0
// Find all channel ids from channel table
err := DB.Model(&Channel{}).Pluck("id", &channelIds).Error
var fixLock = sync.Mutex{}
func FixAbility() (int, int, error) {
lock := fixLock.TryLock()
if !lock {
return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
}
defer fixLock.Unlock()
var channels []*Channel
// Find all channels
err := DB.Model(&Channel{}).Find(&channels).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
return 0, err
return 0, 0, err
}
// Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders
if len(channelIds) > 0 {
// Process deletion in chunks to avoid "too many placeholders" error
for _, chunk := range lo.Chunk(channelIds, 100) {
err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error
if err != nil {
common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error()))
return 0, err
}
}
} else {
// If no channels exist, delete all abilities
err = DB.Delete(&Ability{}).Error
if len(channels) == 0 {
return 0, 0, nil
}
successCount := 0
failCount := 0
for _, chunk := range lo.Chunk(channels, 50) {
ids := lo.Map(chunk, func(c *Channel, _ int) int { return c.Id })
// Delete all abilities of this channel
err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
if err != nil {
common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error()))
return 0, err
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
failCount += len(chunk)
continue
}
common.SysLog("Delete all abilities successfully")
return 0, nil
}
common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
count += len(channelIds)
// Use channelIds to find channel not in abilities table
var abilityChannelIds []int
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
return count, err
}
var channels []Channel
if len(abilityChannelIds) == 0 {
err = DB.Find(&channels).Error
} else {
// Process query in chunks to avoid "too many placeholders" error
err = nil
for _, chunk := range lo.Chunk(abilityChannelIds, 100) {
var channelsChunk []Channel
err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error
// Then add new abilities
for _, channel := range chunk {
err = channel.AddAbilities()
if err != nil {
common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
return count, err
common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
failCount++
} else {
successCount++
}
channels = append(channels, channelsChunk...)
}
}
for _, channel := range channels {
err := channel.UpdateAbilities(nil)
if err != nil {
common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
} else {
common.SysLog(fmt.Sprintf("Update abilities of channel %d successfully", channel.Id))
count++
}
}
InitChannelCache()
return count, nil
return successCount, failCount, nil
}

View File

@@ -3,6 +3,7 @@ package model
import (
"encoding/json"
"one-api/common"
"one-api/dto"
"strings"
"sync"
@@ -514,8 +515,19 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
return tags, nil
}
func (channel *Channel) GetSetting() map[string]interface{} {
setting := make(map[string]interface{})
func (channel *Channel) ValidateSettings() error {
channelParams := &dto.ChannelSettings{}
if channel.Setting != nil && *channel.Setting != "" {
err := json.Unmarshal([]byte(*channel.Setting), channelParams)
if err != nil {
return err
}
}
return nil
}
func (channel *Channel) GetSetting() dto.ChannelSettings {
setting := dto.ChannelSettings{}
if channel.Setting != nil && *channel.Setting != "" {
err := json.Unmarshal([]byte(*channel.Setting), &setting)
if err != nil {
@@ -525,7 +537,7 @@ func (channel *Channel) GetSetting() map[string]interface{} {
return setting
}
func (channel *Channel) SetSetting(setting map[string]interface{}) {
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"one-api/common"
"one-api/constant"
"os"
"strings"
"time"
@@ -100,10 +99,8 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
// 判断是否需要记录 IP
needRecordIp := false
if settingMap, err := GetUserSetting(userId, false); err == nil {
if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
if vb, ok := v.(bool); ok && vb {
needRecordIp = true
}
if settingMap.RecordIpLog {
needRecordIp = true
}
}
log := &Log{
@@ -136,22 +133,34 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
}
}
func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) {
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
type RecordConsumeLogParams struct {
ChannelId int `json:"channel_id"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
ModelName string `json:"model_name"`
TokenName string `json:"token_name"`
Quota int `json:"quota"`
Content string `json:"content"`
TokenId int `json:"token_id"`
UserQuota int `json:"user_quota"`
UseTimeSeconds int `json:"use_time_seconds"`
IsStream bool `json:"is_stream"`
Group string `json:"group"`
Other map[string]interface{} `json:"other"`
}
func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
if !common.LogConsumeEnabled {
return
}
username := c.GetString("username")
otherStr := common.MapToJsonStr(other)
otherStr := common.MapToJsonStr(params.Other)
// 判断是否需要记录 IP
needRecordIp := false
if settingMap, err := GetUserSetting(userId, false); err == nil {
if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
if vb, ok := v.(bool); ok && vb {
needRecordIp = true
}
if settingMap.RecordIpLog {
needRecordIp = true
}
}
log := &Log{
@@ -159,17 +168,17 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
Username: username,
CreatedAt: common.GetTimestamp(),
Type: LogTypeConsume,
Content: content,
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TokenName: tokenName,
ModelName: modelName,
Quota: quota,
ChannelId: channelId,
TokenId: tokenId,
UseTime: useTimeSeconds,
IsStream: isStream,
Group: group,
Content: params.Content,
PromptTokens: params.PromptTokens,
CompletionTokens: params.CompletionTokens,
TokenName: params.TokenName,
ModelName: params.ModelName,
Quota: params.Quota,
ChannelId: params.ChannelId,
TokenId: params.TokenId,
UseTime: params.UseTimeSeconds,
IsStream: params.IsStream,
Group: params.Group,
Ip: func() string {
if needRecordIp {
return c.ClientIP()
@@ -184,7 +193,7 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
}
if common.DataExportEnabled {
gopool.Go(func() {
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
LogQuotaData(userId, username, params.ModelName, params.Quota, common.GetTimestamp(), params.PromptTokens+params.CompletionTokens)
})
}
}

View File

@@ -14,6 +14,8 @@ type Midjourney struct {
StartTime int64 `json:"start_time" gorm:"index"`
FinishTime int64 `json:"finish_time" gorm:"index"`
ImageUrl string `json:"image_url"`
VideoUrl string `json:"video_url"`
VideoUrls string `json:"video_urls"`
Status string `json:"status" gorm:"type:varchar(20);index"`
Progress string `json:"progress" gorm:"type:varchar(30);index"`
FailReason string `json:"fail_reason"`

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"one-api/common"
"one-api/dto"
"strconv"
"strings"
@@ -68,14 +69,18 @@ func (user *User) SetAccessToken(token string) {
user.AccessToken = &token
}
func (user *User) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
func (user *User) GetSetting() dto.UserSetting {
setting := dto.UserSetting{}
if user.Setting != "" {
err := json.Unmarshal([]byte(user.Setting), &setting)
if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error())
}
}
return common.StrToMap(user.Setting)
return setting
}
func (user *User) SetSetting(setting map[string]interface{}) {
func (user *User) SetSetting(setting dto.UserSetting) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
@@ -626,7 +631,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
}
// GetUserSetting gets setting from Redis first, falls back to DB if needed
func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) {
var setting string
defer func() {
// Update Redis cache asynchronously on successful DB read
@@ -648,10 +653,12 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
if err != nil {
return map[string]interface{}{}, err
return settingMap, err
}
return common.StrToMap(setting), nil
userBase := &UserBase{
Setting: setting,
}
return userBase.GetSetting(), nil
}
func IncreaseUserQuota(id int, quota int, db bool) (err error) {

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"one-api/common"
"one-api/constant"
"one-api/dto"
"time"
"github.com/gin-gonic/gin"
@@ -32,20 +33,15 @@ func (user *UserBase) WriteContext(c *gin.Context) {
common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting())
}
func (user *UserBase) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
func (user *UserBase) GetSetting() dto.UserSetting {
setting := dto.UserSetting{}
if user.Setting != "" {
err := json.Unmarshal([]byte(user.Setting), &setting)
if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error())
}
}
return common.StrToMap(user.Setting)
}
func (user *UserBase) SetSetting(setting map[string]interface{}) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
return
}
user.Setting = string(settingBytes)
return setting
}
// getUserCacheKey returns the key for user cache
@@ -174,11 +170,10 @@ func getUserNameCache(userId int) (string, error) {
return cache.Username, nil
}
func getUserSettingCache(userId int) (map[string]interface{}, error) {
setting := make(map[string]interface{})
func getUserSettingCache(userId int) (dto.UserSetting, error) {
cache, err := GetUserCache(userId)
if err != nil {
return setting, err
return dto.UserSetting{}, err
}
return cache.GetSetting(), nil
}

View File

@@ -206,8 +206,8 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
var client *http.Client
var err error
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
client, err = service.NewProxyHttpClient(proxyURL.(string))
if info.ChannelSetting.Proxy != "" {
client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}

View File

@@ -42,7 +42,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
keyParts := strings.Split(info.ApiKey, "|")
if len(keyParts) == 0 || keyParts[0] == "" {
return errors.New("invalid API key: authorization token is required")
}
if len(keyParts) > 1 {
if keyParts[1] != "" {
req.Set("appid", keyParts[1])
}
}
req.Set("Authorization", "Bearer "+keyParts[0])
return nil
}

View File

@@ -278,8 +278,8 @@ func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*ht
func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
var client *http.Client
var err error // 声明 err 变量
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
client, err = service.NewProxyHttpClient(proxyURL.(string))
if info.ChannelSetting.Proxy != "" {
client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}

View File

@@ -3,6 +3,7 @@ package jina
var ModelList = []string{
"jina-clip-v1",
"jina-reranker-v2-base-multilingual",
"jina-reranker-m0",
}
var ChannelName = "jina"

View File

@@ -53,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
a.ChannelType = info.ChannelType
// initialize ThinkingContentInfo when thinking_to_content is enabled
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok && think2Content {
if info.ChannelSetting.ThinkingToContent {
info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
IsFirstThinkingContent: true,
SendLastThinkingContent: false,
@@ -145,7 +145,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
header.Set("Authorization", "Bearer "+info.ApiKey)
}
if info.ChannelType == constant.ChannelTypeOpenRouter {
header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
header.Set("HTTP-Referer", "https://www.newapi.ai")
header.Set("X-Title", "New API")
}
return nil

View File

@@ -124,12 +124,12 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
var forceFormat bool
var thinkToContent bool
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
forceFormat = forceFmt
if info.ChannelSetting.ForceFormat {
forceFormat = true
}
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok {
thinkToContent = think2Content
if info.ChannelSetting.ThinkingToContent {
thinkToContent = true
}
var (
@@ -200,8 +200,8 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
}
forceFormat := false
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
forceFormat = forceFmt
if info.ChannelSetting.ForceFormat {
forceFormat = true
}
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {

View File

@@ -106,8 +106,8 @@ func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (s
var client *http.Client
var err error
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
client, err = service.NewProxyHttpClient(proxyURL.(string))
if info.ChannelSetting.Proxy != "" {
client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
if err != nil {
return "", fmt.Errorf("new proxy http client failed: %w", err)
}

View File

@@ -97,9 +97,9 @@ type RelayInfo struct {
IsFirstRequest bool
AudioUsage bool
ReasoningEffort string
ChannelSetting map[string]interface{}
ChannelSetting dto.ChannelSettings
ParamOverride map[string]interface{}
UserSetting map[string]interface{}
UserSetting dto.UserSetting
UserEmail string
UserQuota int
RelayFormat string
@@ -213,7 +213,6 @@ func GenRelayInfoImage(c *gin.Context) *RelayInfo {
func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting)
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride)
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
@@ -227,7 +226,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
info := &RelayInfo{
UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
UserSetting: common.GetContextKeyStringMap(c, constant.ContextKeyUserSetting),
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
isFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
@@ -246,12 +244,12 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
//RecodeModelName: c.GetString("original_model"),
IsModelMapped: false,
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
ChannelSetting: channelSetting,
IsModelMapped: false,
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
ChannelCreateTime: c.GetInt64("channel_create_time"),
ParamOverride: paramOverride,
RelayFormat: RelayFormatOpenAI,
@@ -277,6 +275,16 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if streamSupportedChannels[info.ChannelType] {
info.SupportStreamOptions = true
}
channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
if ok {
info.ChannelSetting = channelSetting
}
userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
if ok {
info.UserSetting = userSetting
}
return info
}

View File

@@ -29,6 +29,8 @@ const (
RelayModeMidjourneyShorten
RelayModeSwapFace
RelayModeMidjourneyUpload
RelayModeMidjourneyVideo
RelayModeMidjourneyEdits
RelayModeAudioSpeech // tts
RelayModeAudioTranscription // whisper
@@ -108,6 +110,10 @@ func Path2RelayModeMidjourney(path string) int {
relayMode = RelayModeMidjourneyUpload
} else if strings.HasSuffix(path, "/mj/submit/imagine") {
relayMode = RelayModeMidjourneyImagine
} else if strings.HasSuffix(path, "/mj/submit/video") {
relayMode = RelayModeMidjourneyVideo
} else if strings.HasSuffix(path, "/mj/submit/edits") {
relayMode = RelayModeMidjourneyEdits
} else if strings.HasSuffix(path, "/mj/submit/blend") {
relayMode = RelayModeMidjourneyBlend
} else if strings.HasSuffix(path, "/mj/submit/describe") {

View File

@@ -3,7 +3,6 @@ package helper
import (
"fmt"
"one-api/common"
constant2 "one-api/constant"
relaycommon "one-api/relay/common"
"one-api/setting/ratio_setting"
@@ -83,11 +82,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
modelRatio, success = ratio_setting.GetModelRatio(info.OriginModelName)
if !success {
acceptUnsetRatio := false
if accept, ok := info.UserSetting[constant2.UserAcceptUnsetRatioModel]; ok {
b, ok := accept.(bool)
if ok {
acceptUnsetRatio = b
}
if info.UserSetting.AcceptUnsetRatioModel {
acceptUnsetRatio = true
}
if !acceptUnsetRatio {
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置请联系管理员设置或开始自用模式Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)

View File

@@ -34,14 +34,13 @@ func RelayMidjourneyImage(c *gin.Context) {
}
var httpClient *http.Client
if channel, err := model.CacheGetChannel(midjourneyTask.ChannelId); err == nil {
if proxy, ok := channel.GetSetting()["proxy"]; ok {
if proxyURL, ok := proxy.(string); ok && proxyURL != "" {
if httpClient, err = service.NewProxyHttpClient(proxyURL); err != nil {
c.JSON(400, gin.H{
"error": "proxy_url_invalid",
})
return
}
proxy := channel.GetSetting().Proxy
if proxy != "" {
if httpClient, err = service.NewProxyHttpClient(proxy); err != nil {
c.JSON(400, gin.H{
"error": "proxy_url_invalid",
})
return
}
}
}
@@ -106,6 +105,9 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
midjourneyTask.StartTime = midjRequest.StartTime
midjourneyTask.FinishTime = midjRequest.FinishTime
midjourneyTask.ImageUrl = midjRequest.ImageUrl
midjourneyTask.VideoUrl = midjRequest.VideoUrl
videoUrlsStr, _ := json.Marshal(midjRequest.VideoUrls)
midjourneyTask.VideoUrls = string(videoUrlsStr)
midjourneyTask.Status = midjRequest.Status
midjourneyTask.FailReason = midjRequest.FailReason
err = midjourneyTask.Update()
@@ -136,6 +138,9 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
} else {
midjourneyTask.ImageUrl = originTask.ImageUrl
}
if originTask.VideoUrl != "" {
midjourneyTask.VideoUrl = originTask.VideoUrl
}
midjourneyTask.Status = originTask.Status
midjourneyTask.FailReason = originTask.FailReason
midjourneyTask.Action = originTask.Action
@@ -148,6 +153,13 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
midjourneyTask.Buttons = buttons
}
}
if originTask.VideoUrls != "" {
var videoUrls []dto.ImgUrls
err := json.Unmarshal([]byte(originTask.VideoUrls), &videoUrls)
if err == nil {
midjourneyTask.VideoUrls = videoUrls
}
}
if originTask.Properties != "" {
var properties dto.Properties
err := json.Unmarshal([]byte(originTask.Properties), &properties)
@@ -162,7 +174,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
startTime := time.Now().UnixNano() / int64(time.Millisecond)
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
//group := c.GetString("group")
channelId := c.GetInt("channel_id")
relayInfo := relaycommon.GenRelayInfo(c)
var swapFaceRequest dto.SwapFaceRequest
@@ -208,8 +220,17 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
other := service.GenerateMjOtherInfo(priceData)
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other)
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: channelId,
ModelName: modelName,
TokenName: tokenName,
Quota: priceData.Quota,
Content: logContent,
TokenId: tokenId,
UserQuota: userQuota,
Group: relayInfo.UsingGroup,
Other: other,
})
model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
model.UpdateChannelUsedQuota(channelId, priceData.Quota)
}
@@ -350,7 +371,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
tokenId := c.GetInt("token_id")
//tokenId := c.GetInt("token_id")
//channelType := c.GetInt("channel")
userId := c.GetInt("id")
group := c.GetString("group")
@@ -370,6 +391,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
relayMode = relayconstant.RelayModeMidjourneyChange
}
if relayMode == relayconstant.RelayModeMidjourneyVideo {
midjRequest.Action = constant.MjActionVideo
}
if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
if midjRequest.Prompt == "" {
@@ -378,6 +402,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
midjRequest.Action = constant.MjActionImagine
} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
midjRequest.Action = constant.MjActionDescribe
} else if relayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复
midjRequest.Action = constant.MjActionEdits
} else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务此类任务可重复plus only
midjRequest.Action = constant.MjActionShorten
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
@@ -412,6 +438,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
//}
mjId = midjRequest.TaskId
midjRequest.Action = constant.MjActionModal
} else if relayMode == relayconstant.RelayModeMidjourneyVideo {
midjRequest.Action = constant.MjActionVideo
if midjRequest.TaskId == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
} else if midjRequest.Action == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required")
}
mjId = midjRequest.TaskId
}
originTask := model.GetByMJId(userId, mjId)
@@ -492,8 +526,17 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %sID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
other := service.GenerateMjOtherInfo(priceData)
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other)
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: channelId,
ModelName: modelName,
TokenName: tokenName,
Quota: priceData.Quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UserQuota: userQuota,
Group: group,
Other: other,
})
model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
model.UpdateChannelUsedQuota(channelId, priceData.Quota)
}

View File

@@ -540,6 +540,19 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
other["audio_input_token_count"] = audioTokens
other["audio_input_price"] = audioInputPrice
}
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
ModelName: logModel,
TokenName: tokenName,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UserQuota: userQuota,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,
Other: other,
})
}

View File

@@ -139,8 +139,17 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
if hasUserGroupRatio {
other["user_group_ratio"] = userGroupRatio
}
model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.UsingGroup, other)
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
ModelName: modelName,
TokenName: tokenName,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UserQuota: userQuota,
Group: relayInfo.UsingGroup,
Other: other,
})
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}

View File

@@ -78,12 +78,15 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody := bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
if common.DebugEnabled {
println(fmt.Sprintf("Rerank request body: %s", requestBody.String()))
}
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response
if resp != nil {
httpResp = resp.(*http.Response)

View File

@@ -19,6 +19,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/uptime/status", controller.GetUptimeKumaStatus)
apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
apiRouter.GET("/jsrt/reload", middleware.AdminAuth(), controller.ReloadJSScripts)
apiRouter.GET("/notice", controller.GetNotice)
apiRouter.GET("/about", controller.GetAbout)
//apiRouter.GET("/midjourney", controller.GetMidjourney)

View File

@@ -3,14 +3,21 @@ package router
import (
"embed"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/middleware/jsrt"
"os"
"strings"
"github.com/gin-gonic/gin"
)
func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
jsrtMid := jsrt.JSRuntimeMiddleware()
if jsrtMid != nil {
router.Use(*jsrtMid)
}
SetApiRouter(router)
SetDashboardRouter(router)
SetRelayRouter(router)

View File

@@ -12,6 +12,7 @@ func SetRelayRouter(router *gin.Engine) {
router.Use(middleware.CORS())
router.Use(middleware.DecompressRequestMiddleware())
router.Use(middleware.StatsMiddleware())
// https://platform.openai.com/docs/api-reference/introduction
modelsRouter := router.Group("/v1/models")
modelsRouter.Use(middleware.TokenAuth())
@@ -103,6 +104,8 @@ func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney)
relayMjRouter.POST("/submit/describe", controller.RelayMidjourney)
relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
relayMjRouter.POST("/submit/edits", controller.RelayMidjourney)
relayMjRouter.POST("/submit/video", controller.RelayMidjourney)
relayMjRouter.POST("/notify", controller.RelayMidjourney)
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)

15
scripts/01_utils.js Normal file
View File

@@ -0,0 +1,15 @@
// Utility functions for JavaScript runtime
function logWithReq(req, message) {
let reqPath = req.url || 'unknown path';
console.log(`[${req.method} ${reqPath}] ${message}`);
}
function safeJsonParse(str, defaultValue = null) {
try {
return JSON.parse(str);
} catch (e) {
console.error('JSON parse error:', e.message);
return defaultValue;
}
}

View File

@@ -0,0 +1,5 @@
// Pre-processing function for incoming requests
function preProcessRequest(req) {
logWithReq(req, 'Pre-processing request');
}

View File

@@ -0,0 +1,5 @@
// Post-processing function for outgoing responses
function postProcessResponse(req, resp) {
logWithReq(req, 'Post-processing response with: ' + resp.statusCode);
}

238
scripts/README.md Normal file
View File

@@ -0,0 +1,238 @@
# JavaScript Runtime Scripts
本目录包含 JavaScript Runtime 中间件使用的脚本文件。
## 脚本加载
- 系统会自动读取 `scripts/` 目录下的所有 `.js` 文件
- 脚本按文件名字母顺序加载
- 建议使用数字前缀来控制加载顺序(如:`01_utils.js`, `02_pre_process.js`
- 所有脚本会被合并到一个 JavaScript 运行时环境中
## 配置
通过环境变量配置:
- `JS_RUNTIME_ENABLED=true` - 启用 JavaScript Runtime
- `JS_SCRIPT_DIR=scripts/` - 脚本目录路径
- `JS_MAX_VM_COUNT=8` - 最大虚拟机数量
- `JS_SCRIPT_TIMEOUT=5s` - 脚本执行超时时间
- `JS_FETCH_TIMEOUT=10s` - HTTP 请求超时时间
更多的详细配置可以在 `.env.example` 文件中找到,并在实际使用时重命名为 `.env`
## 必需的函数
脚本中必须定义以下两个函数:
### 1. preProcessRequest(req)
在请求被转发到后端 API 之前调用。
**参数:**
- `req`: 请求对象,包含 `method`, `url`, `headers`, `body` 等属性
**返回值:**
返回一个对象,可包含以下属性:
- `block`: boolean - 是否阻止请求继续执行
- `statusCode`: number - 阻止请求时返回的状态码
- `message`: string - 阻止请求时返回的错误消息
- `headers`: object - 要修改或添加的请求头
- `body`: any - 修改后的请求体
### 2. postProcessResponse(req, resp)
在响应返回给客户端之前调用。
**参数:**
- `req`: 原始请求对象
- `resp`: 响应对象,包含 `statusCode`, `headers`, `body` 等属性
**返回值:**
返回一个对象,可包含以下属性:
- `statusCode`: number - 修改后的状态码
- `headers`: object - 要修改或添加的响应头
- `body`: string - 修改后的响应体
## 可用的全局对象和函数
- `console.log()`, `console.error()`, `console.warn()` - 日志输出
- `JSON.parse()`, `JSON.stringify()` - JSON 处理
- `fetch(url, options)` - HTTP 请求
- `db` - 主数据库连接
- `logdb` - 日志数据库连接
- `setTimeout(fn, delay)` - 定时器
## 示例脚本
参考现有的示例脚本:
- `01_utils.js` - 工具函数
- `02_pre_process.js` - 请求预处理
- `03_post_process.js` - 响应后处理
## 使用示例
```js
// 例子:基于数据库的速率限制
if (req.url.includes("/v1/chat/completions")) {
try {
// Check recent requests from this IP
var recentRequests = db.query(
"SELECT COUNT(*) as count FROM logs WHERE created_at > ? AND ip = ?",
Math.floor(Date.now() / 1000) - 60, // last minute
req.remoteIP
);
if (recentRequests && recentRequests.length > 0 && recentRequests[0].count > 10) {
console.log("速率限制 IP:", req.remoteIP);
return {
block: true,
statusCode: 429,
message: "超过速率限制"
};
}
} catch (e) {
console.error("Ratelimit 数据库错误:", e);
}
}
// 例子:修改请求
if (req.url.includes("/chat/completions")) {
try {
var bodyObj = req.body;
let firstMsg = { // 需要新建一个对象,不能修改原有对象
role: "user",
content: "喵呜🐱~嘻嘻"
};
bodyObj.messages[0] = firstMsg;
console.log("Modified first message:", JSON.stringify(firstMsg));
console.log("Modified body:", JSON.stringify(bodyObj));
return {
body: bodyObj,
headers: {
...req.headers,
"X-Modified-Body": "true"
}
};
} catch (e) {
console.error("Failed to modify request body:", {
message: e.message,
stack: e.stack,
bodyType: typeof req.body,
url: req.url
});
}
}
// 例子:读取最近一条日志,新增 jsrt 日志,并输出日志总数
try {
// 1. 读取最近一条日志
var recentLogs = logdb.query(
"SELECT id, user_id, username, content, created_at FROM logs ORDER BY id DESC LIMIT 1"
);
var recentLog = null;
if (recentLogs && recentLogs.length > 0) {
recentLog = recentLogs[0];
console.log("最近一条日志:", JSON.stringify(recentLog));
}
// 2. 新增一条 jsrt 日志
var currentTimestamp = Math.floor(Date.now() / 1000);
var jsrtLogContent = "JSRT 预处理中间件执行 - " + req.URL + " - " + new Date().toISOString();
var insertResult = logdb.exec(
"INSERT INTO logs (user_id, username, created_at, type, content) VALUES (?, ?, ?, ?, ?)",
req.UserID || 0,
req.Username || "jsrt-system",
currentTimestamp,
4, // LogTypeSystem
jsrtLogContent
);
if (insertResult.error) {
console.error("插入 JSRT 日志失败:", insertResult.error);
} else {
console.log("成功插入 JSRT 日志,影响行数:", insertResult.rowsAffected);
}
// 3. 输出日志总数
var totalLogsResult = logdb.query("SELECT COUNT(*) as total FROM logs");
var totalLogs = 0;
if (totalLogsResult && totalLogsResult.length > 0) {
totalLogs = totalLogsResult[0].total;
}
console.log("当前日志总数:", totalLogs);
console.log("JSRT 日志管理示例执行完成");
} catch (e) {
console.error("JSRT 日志管理示例执行失败:", {
message: e.message,
stack: e.stack,
url: req.URL
});
}
// 例子:使用 fetch 调用外部 API
if (req.url.includes("/api/uptime/status")) {
try {
// 使用 httpbin.org/ip 测试 fetch 功能
var response = fetch("https://httpbin.org/ip", {
method: "GET",
timeout: 5, // 5秒超时
headers: {
"User-Agent": "JSRT/1.0"
}
});
if (response.Error.length === 0) {
// 解析响应体
var ipData = JSON.parse(response.Body);
// 可以根据获取到的 IP 信息进行后续处理
if (ipData.origin) {
console.log("外部 IP 地址:", ipData.origin);
// 示例:记录 IP 信息到数据库
var currentTimestamp = Math.floor(Date.now() / 1000);
var logContent = "Fetch 示例 - 外部 IP: " + ipData.origin + " - " + new Date().toISOString();
var insertResult = logdb.exec(
"INSERT INTO logs (user_id, username, created_at, type, content) VALUES (?, ?, ?, ?, ?)",
0,
"jsrt-fetch",
currentTimestamp,
4, // LogTypeSystem
logContent
);
if (insertResult.error) {
console.error("记录 IP 信息失败:", insertResult.error);
} else {
console.log("成功记录 IP 信息到数据库");
}
}
} else {
console.error("Fetch 失败 ", response.Status, " ", response.Error);
}
} catch (e) {
console.error("Fetch 失败:", {
message: e.message,
stack: e.stack,
url: req.url
});
}
}
```
## 管理接口
### 重新加载脚本
```bash
curl -X POST http://host:port/api/jsrt/reload \
-H 'Content-Type: application/json' \
-H 'Authorization Bearer <admin_token>'
```
## 故障排除
- 查看服务日志中的 JavaScript 相关错误信息
- 使用 `console.log()` 调试脚本逻辑
- 确保 JavaScript 语法正确(不支持所有 ES6+ 特性)

View File

@@ -3,7 +3,6 @@ package service
import (
"context"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"log"
"net/http"
@@ -15,6 +14,8 @@ import (
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
func CoverActionToModelName(mjAction string) string {
@@ -38,6 +39,10 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
switch relayMode {
case relayconstant.RelayModeMidjourneyImagine:
action = constant.MjActionImagine
case relayconstant.RelayModeMidjourneyVideo:
action = constant.MjActionVideo
case relayconstant.RelayModeMidjourneyEdits:
action = constant.MjActionEdits
case relayconstant.RelayModeMidjourneyDescribe:
action = constant.MjActionDescribe
case relayconstant.RelayModeMidjourneyBlend:

View File

@@ -209,8 +209,21 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
}
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: usage.InputTokens,
CompletionTokens: usage.OutputTokens,
ModelName: logModel,
TokenName: tokenName,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UserQuota: userQuota,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,
Other: other,
})
}
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
@@ -286,8 +299,22 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
ModelName: modelName,
TokenName: tokenName,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UserQuota: userQuota,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,
Other: other,
})
}
func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int {
@@ -384,8 +411,21 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
}
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: usage.PromptTokens,
CompletionTokens: usage.CompletionTokens,
ModelName: logModel,
TokenName: tokenName,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UserQuota: userQuota,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,
Other: other,
})
}
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
@@ -447,8 +487,8 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
gopool.Go(func() {
userSetting := relayInfo.UserSetting
threshold := common.QuotaRemindThreshold
if userCustomThreshold, ok := userSetting[constant.UserSettingQuotaWarningThreshold]; ok {
threshold = int(userCustomThreshold.(float64))
if userSetting.QuotaWarningThreshold != 0 {
threshold = int(userSetting.QuotaWarningThreshold)
}
//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0

View File

@@ -172,9 +172,6 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
}
}
toolTokens := CountTokenInput(countStr, request.Model)
if err != nil {
return 0, err
}
tkm += 8
tkm += toolTokens
}
@@ -195,9 +192,6 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro
// Count tokens in system message
if request.System != "" {
systemTokens := CountTokenInput(request.System, model)
if err != nil {
return 0, err
}
tkm += systemTokens
}

View File

@@ -3,7 +3,6 @@ package service
import (
"fmt"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
"strings"
@@ -17,10 +16,10 @@ func NotifyRootUser(t string, subject string, content string) {
}
}
func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error {
notifyType, ok := userSetting[constant.UserSettingNotifyType]
if !ok {
notifyType = constant.NotifyTypeEmail
func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data dto.Notify) error {
notifyType := userSetting.NotifyType
if notifyType == "" {
notifyType = dto.NotifyTypeEmail
}
// Check notification limit
@@ -34,34 +33,23 @@ func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}
}
switch notifyType {
case constant.NotifyTypeEmail:
case dto.NotifyTypeEmail:
// check setting email
if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
userEmail = settingEmail.(string)
}
userEmail = userSetting.NotificationEmail
if userEmail == "" {
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
return nil
}
return sendEmailNotify(userEmail, data)
case constant.NotifyTypeWebhook:
webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
if !ok {
case dto.NotifyTypeWebhook:
webhookURLStr := userSetting.WebhookUrl
if webhookURLStr == "" {
common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
return nil
}
webhookURLStr, ok := webhookURL.(string)
if !ok {
common.SysError(fmt.Sprintf("user %d webhook url is not string type", userId))
return nil
}
// 获取 webhook secret
var webhookSecret string
if secret, ok := userSetting[constant.UserSettingWebhookSecret]; ok {
webhookSecret, _ = secret.(string)
}
webhookSecret := userSetting.WebhookSecret
return SendWebhookNotify(webhookURLStr, webhookSecret, data)
}
return nil

View File

@@ -6,8 +6,11 @@ import (
)
var Chats = []map[string]string{
//{
// "ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}",
//},
{
"ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}",
"Cherry Studio": "cherrystudio://providers/api-keys?v=1&data={cherryConfig}",
},
{
"Lobe Chat 官方示例": "https://chat-preview.lobehub.com/?settings={\"keyVaults\":{\"openai\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\"}}}",

View File

@@ -231,7 +231,9 @@ var defaultModelPrice = map[string]float64{
"dall-e-3": 0.04,
"imagen-3.0-generate-002": 0.03,
"gpt-4-gizmo-*": 0.1,
"mj_video": 0.8,
"mj_imagine": 0.1,
"mj_edits": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,

View File

@@ -197,7 +197,6 @@ const ChannelSelectorModal = forwardRef(({
value={searchText}
onChange={setSearchText}
showClear
className="!rounded-full"
/>
<Table

View File

@@ -1461,9 +1461,9 @@ const ChannelsTable = () => {
const fixChannelsAbilities = async () => {
const res = await API.post(`/api/channel/fix`);
const { success, message, data } = res.data;
const { success, message, data } = res.data;
if (success) {
showSuccess(t('已修复 ${data} 个通道').replace('${data}', data));
showSuccess(t('已修复 ${success} 个通道,失败 ${fails} 个通道').replace('${success}', data.success).replace('${fails}', data.fails));
await refresh();
} else {
showError(message);

View File

@@ -195,6 +195,18 @@ const LogsTable = () => {
{t('放大')}
</Tag>
);
case 'VIDEO':
return (
<Tag color='orange' size='large' shape='circle' prefixIcon={<Video size={14} />}>
{t('视频')}
</Tag>
);
case 'EDITS':
return (
<Tag color='orange' size='large' shape='circle' prefixIcon={<Video size={14} />}>
{t('编辑')}
</Tag>
);
case 'VARIATION':
return (
<Tag color='purple' size='large' shape='circle' prefixIcon={<Shuffle size={14} />}>

View File

@@ -1,5 +1,5 @@
import React, { useContext, useEffect, useRef, useMemo, useState } from 'react';
import { API, copy, showError, showInfo, showSuccess, getModelCategories, renderModelTag } from '../../helpers';
import { API, copy, showError, showInfo, showSuccess, getModelCategories, renderModelTag, stringToColor } from '../../helpers';
import { useTranslation } from 'react-i18next';
import {
@@ -106,6 +106,26 @@ const ModelPricing = () => {
) : null;
}
function renderSupportedEndpoints(endpoints) {
if (!endpoints || endpoints.length === 0) {
return null;
}
return (
<Space wrap>
{endpoints.map((endpoint, idx) => (
<Tag
key={endpoint}
color={stringToColor(endpoint)}
size='large'
shape='circle'
>
{endpoint}
</Tag>
))}
</Space>
);
}
const columns = [
{
title: t('可用性'),
@@ -120,6 +140,13 @@ const ModelPricing = () => {
},
defaultSortOrder: 'descend',
},
{
title: t('可用端点类型'),
dataIndex: 'supported_endpoint_types',
render: (text, record, index) => {
return renderSupportedEndpoints(text);
},
},
{
title: t('模型名称'),
dataIndex: 'model_name',
@@ -499,7 +526,7 @@ const ModelPricing = () => {
<div className="flex items-center">
<AlertCircle size={14} className="mr-1.5 flex-shrink-0" />
<span className="truncate">
{t('未登录,使用默认分组倍率')}: {groupRatio['default']}
{t('未登录,使用默认分组倍率')}{groupRatio['default']}
</span>
</div>
)}

View File

@@ -432,9 +432,22 @@ const TokensTable = () => {
if (serverAddress === '') {
serverAddress = window.location.origin;
}
let encodedServerAddress = encodeURIComponent(serverAddress);
url = url.replaceAll('{address}', encodedServerAddress);
url = url.replaceAll('{key}', 'sk-' + record.key);
if (url.includes('{cherryConfig}') === true) {
let cherryConfig = {
id: 'new-api',
baseUrl: serverAddress,
apiKey: 'sk-' + record.key,
}
// 替换 {cherryConfig} 为base64编码的JSON字符串
let encodedConfig = encodeURIComponent(
btoa(JSON.stringify(cherryConfig))
);
url = url.replaceAll('{cherryConfig}', encodedConfig);
} else {
let encodedServerAddress = encodeURIComponent(serverAddress);
url = url.replaceAll('{address}', encodedServerAddress);
url = url.replaceAll('{key}', 'sk-' + record.key);
}
window.open(url, '_blank');
};

View File

@@ -119,7 +119,7 @@ const UsersTable = () => {
<Tooltip content={remark} position="top" showArrow>
<Tag color='white' size='large' shape='circle' className="!text-xs">
<div className="flex items-center gap-1">
<div className="w-2 h-2 flex-shrink-0" style={{ backgroundColor: '#10b981' }} />
<div className="w-2 h-2 flex-shrink-0 rounded-full" style={{ backgroundColor: '#10b981' }} />
{displayRemark}
</div>
</Tag>

View File

@@ -876,7 +876,7 @@
"加载token失败": "Failed to load token",
"配置聊天": "Configure chat",
"模型消耗分布": "Model consumption distribution",
"模型调用次数占比": "Proportion of model calls",
"模型调用次数占比": "Model call ratio",
"用户消耗分布": "User consumption distribution",
"时间粒度": "Time granularity",
"天": "day",
@@ -1119,6 +1119,10 @@
"平均TPM": "Average TPM",
"消耗分布": "Consumption distribution",
"调用次数分布": "Models call distribution",
"消耗趋势": "Consumption trend",
"模型消耗趋势": "Model consumption trend",
"调用次数排行": "Models call ranking",
"模型调用次数排行": "Model call ranking",
"添加渠道": "Add channel",
"测试所有通道": "Test all channels",
"删除禁用通道": "Delete disabled channels",
@@ -1143,8 +1147,8 @@
"默认测试模型": "Default Test Model",
"不填则为模型列表第一个": "First model in list if empty",
"是否自动禁用(仅当自动禁用开启时有效),关闭后不会自动禁用该渠道": "Auto-disable (only effective when auto-disable is enabled). When turned off, this channel will not be automatically disabled",
"状态码复写(仅影响本地判断,不修改返回到上游的状态码)": "Status Code Override (only affects local judgment, does not modify status code returned upstream)",
"此项可选用于复写返回的状态码比如将claude渠道的400错误复写为500用于重试请勿滥用该功能例如": "Optional, used to override returned status codes, e.g. rewriting Claude channel's 400 error to 500 (for retry). Do not abuse this feature. Example:",
"状态码复写": "Status Code Override",
"此项可选,用于复写返回的状态码,仅影响本地判断,不修改返回到上游的状态码,比如将claude渠道的400错误复写为500用于重试请勿滥用该功能例如": "Optional, used to override returned status codes, only affects local judgment, does not modify status code returned upstream, e.g. rewriting Claude channel's 400 error to 500 (for retry). Do not abuse this feature. Example:",
"渠道标签": "Channel Tag",
"渠道优先级": "Channel Priority",
"渠道权重": "Channel Weight",
@@ -1199,7 +1203,7 @@
"添加用户": "Add user",
"角色": "Role",
"已绑定的 Telegram 账户": "Bound Telegram account",
"新额度": "New quota",
"新额度": "New quota: ",
"需要添加的额度(支持负数)": "Need to add quota (supports negative numbers)",
"此项只读,需要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改": "Read-only, user's personal settings, and cannot be modified directly",
"请输入新的密码,最短 8 位": "Please enter a new password, at least 8 characterss",
@@ -1750,5 +1754,7 @@
"批量创建时会在名称后自动添加随机后缀": "When creating in batches, a random suffix will be automatically added to the name",
"额度必须大于0": "Quota must be greater than 0",
"生成数量必须大于0": "Generation quantity must be greater than 0",
"创建后可在编辑渠道时获取上游模型列表": "After creation, you can get the upstream model list when editing the channel"
"创建后可在编辑渠道时获取上游模型列表": "After creation, you can get the upstream model list when editing the channel",
"可用端点类型": "Supported endpoint types",
"未登录,使用默认分组倍率:": "Not logged in, using default group ratio: "
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
import React, { useState, useEffect } from 'react';
import React, { useState, useEffect, useRef } from 'react';
import {
API,
showError,
@@ -11,15 +11,13 @@ import {
SideSheet,
Space,
Button,
Input,
Typography,
Spin,
Select,
Banner,
TextArea,
Card,
Tag,
Avatar,
Form,
} from '@douyinfe/semi-ui';
import {
IconSave,
@@ -53,9 +51,14 @@ const EditTagModal = (props) => {
models: [],
};
const [inputs, setInputs] = useState(originInputs);
const formApiRef = useRef(null);
const getInitValues = () => ({ ...originInputs });
const handleInputChange = (name, value) => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
if (formApiRef.current) {
formApiRef.current.setValue(name, value);
}
if (name === 'type') {
let localModels = [];
switch (value) {
@@ -74,6 +77,8 @@ const EditTagModal = (props) => {
localModels = [
'swap_face',
'mj_imagine',
'mj_video',
'mj_edits',
'mj_variation',
'mj_reroll',
'mj_blend',
@@ -133,27 +138,25 @@ const EditTagModal = (props) => {
}
};
const handleSave = async () => {
const handleSave = async (values) => {
setLoading(true);
let data = {
tag: tag,
};
if (inputs.model_mapping !== null && inputs.model_mapping !== '') {
if (inputs.model_mapping !== '' && !verifyJSON(inputs.model_mapping)) {
const formVals = values || formApiRef.current?.getValues() || {};
let data = { tag };
if (formVals.model_mapping) {
if (!verifyJSON(formVals.model_mapping)) {
showInfo('模型映射必须是合法的 JSON 格式!');
setLoading(false);
return;
}
data.model_mapping = inputs.model_mapping;
data.model_mapping = formVals.model_mapping;
}
if (inputs.groups.length > 0) {
data.groups = inputs.groups.join(',');
if (formVals.groups && formVals.groups.length > 0) {
data.groups = formVals.groups.join(',');
}
if (inputs.models.length > 0) {
data.models = inputs.models.join(',');
if (formVals.models && formVals.models.length > 0) {
data.models = formVals.models.join(',');
}
data.new_tag = inputs.new_tag;
// check have any change
data.new_tag = formVals.new_tag;
if (
data.model_mapping === undefined &&
data.groups === undefined &&
@@ -202,7 +205,7 @@ const EditTagModal = (props) => {
const res = await API.get(`/api/channel/tag/models?tag=${tag}`);
if (res?.data?.success) {
const models = res.data.data ? res.data.data.split(',') : [];
setInputs((inputs) => ({ ...inputs, models: models }));
handleInputChange('models', models);
} else {
showError(res.data.message);
}
@@ -213,19 +216,32 @@ const EditTagModal = (props) => {
}
};
fetchModels().then();
fetchGroups().then();
fetchTagModels().then();
if (formApiRef.current) {
formApiRef.current.setValues({
...getInitValues(),
tag: tag,
new_tag: tag,
});
}
setInputs({
...originInputs,
tag: tag,
new_tag: tag,
});
fetchModels().then();
fetchGroups().then();
fetchTagModels().then(); // Call the new function
}, [visible, tag]); // Add tag to dependency array
}, [visible, tag]);
useEffect(() => {
if (formApiRef.current) {
formApiRef.current.setValues(inputs);
}
}, [inputs]);
const addCustomModels = () => {
if (customModel.trim() === '') return;
// 使用逗号分隔字符串,然后去除每个模型名称前后的空格
const modelArray = customModel.split(',').map((model) => model.trim());
let localModels = [...inputs.models];
@@ -233,11 +249,9 @@ const EditTagModal = (props) => {
const addedModels = [];
modelArray.forEach((model) => {
// 检查模型是否已存在,且模型名称非空
if (model && !localModels.includes(model)) {
localModels.push(model); // 添加到模型列表
localModels.push(model);
localModelOptions.push({
// 添加到下拉选项
key: model,
text: model,
value: model,
@@ -246,7 +260,6 @@ const EditTagModal = (props) => {
}
});
// 更新状态值
setModelOptions(localModelOptions);
setCustomModel('');
handleInputChange('models', localModels);
@@ -283,7 +296,7 @@ const EditTagModal = (props) => {
<Space>
<Button
theme="solid"
onClick={handleSave}
onClick={() => formApiRef.current?.submitForm()}
loading={loading}
icon={<IconSave />}
>
@@ -302,146 +315,128 @@ const EditTagModal = (props) => {
}
closeIcon={null}
>
<Spin spinning={loading}>
<div className="p-2">
<Card className="!rounded-2xl shadow-sm border-0 mb-6">
{/* Header: Tag Info */}
<div className="flex items-center mb-2">
<Avatar size="small" color="blue" className="mr-2 shadow-md">
<IconBookmark size={16} />
</Avatar>
<div>
<Text className="text-lg font-medium">{t('标签信息')}</Text>
<div className="text-xs text-gray-600">{t('标签的基本配置')}</div>
</div>
</div>
<Form
key={tag || 'edit'}
initValues={getInitValues()}
getFormApi={(api) => (formApiRef.current = api)}
onSubmit={handleSave}
>
{() => (
<Spin spinning={loading}>
<div className="p-2">
<Card className="!rounded-2xl shadow-sm border-0 mb-6">
{/* Header: Tag Info */}
<div className="flex items-center mb-2">
<Avatar size="small" color="blue" className="mr-2 shadow-md">
<IconBookmark size={16} />
</Avatar>
<div>
<Text className="text-lg font-medium">{t('标签信息')}</Text>
<div className="text-xs text-gray-600">{t('标签的基本配置')}</div>
</div>
</div>
<Banner
type="warning"
description={t('所有编辑均为覆盖操作,留空则不更改')}
className="!rounded-lg mb-4"
/>
<div className="space-y-4">
<div>
<Text strong className="block mb-2">{t('标签名称')}</Text>
<Input
value={inputs.new_tag}
onChange={(value) => setInputs({ ...inputs, new_tag: value })}
placeholder={t('请输入新标签,留空则解散标签')}
/>
</div>
</div>
</Card>
<Card className="!rounded-2xl shadow-sm border-0 mb-6">
{/* Header: Model Config */}
<div className="flex items-center mb-2">
<Avatar size="small" color="purple" className="mr-2 shadow-md">
<IconCode size={16} />
</Avatar>
<div>
<Text className="text-lg font-medium">{t('模型配置')}</Text>
<div className="text-xs text-gray-600">{t('模型选择和映射设置')}</div>
</div>
</div>
<div className="space-y-4">
<div>
<Text strong className="block mb-2">{t('模型')}</Text>
<Banner
type="info"
description={t('当前模型列表为该标签下所有渠道模型列表最长的一个,并非所有渠道的并集,请注意可能导致某些渠道模型丢失。')}
type="warning"
description={t('所有编辑均为覆盖操作,留空则不更改')}
className="!rounded-lg mb-4"
/>
<Select
placeholder={t('请选择该渠道所支持的模型,留空则不更改')}
name='models'
multiple
filter
searchPosition='dropdown'
onChange={(value) => handleInputChange('models', value)}
value={inputs.models}
optionList={modelOptions}
/>
</div>
<div>
<Input
addonAfter={
<Button type='primary' onClick={addCustomModels} className="!rounded-r-lg">
{t('填入')}
</Button>
}
placeholder={t('输入自定义模型名称')}
value={customModel}
onChange={(value) => setCustomModel(value.trim())}
/>
</div>
<div className="space-y-4">
<Form.Input
field='new_tag'
label={t('标签名称')}
placeholder={t('请输入新标签,留空则解散标签')}
onChange={(value) => handleInputChange('new_tag', value)}
/>
</div>
</Card>
<div>
<Text strong className="block mb-2">{t('模型重定向')}</Text>
<TextArea
placeholder={t('此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,留空则不更改')}
name='model_mapping'
onChange={(value) => handleInputChange('model_mapping', value)}
autosize
value={inputs.model_mapping}
/>
<Space className="mt-2">
<Text
className="!text-semi-color-primary cursor-pointer"
onClick={() => handleInputChange('model_mapping', JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2))}
>
{t('填入模板')}
</Text>
<Text
className="!text-semi-color-primary cursor-pointer"
onClick={() => handleInputChange('model_mapping', JSON.stringify({}, null, 2))}
>
{t('清空重定向')}
</Text>
<Text
className="!text-semi-color-primary cursor-pointer"
onClick={() => handleInputChange('model_mapping', '')}
>
{t('不更改')}
</Text>
</Space>
</div>
<Card className="!rounded-2xl shadow-sm border-0 mb-6">
{/* Header: Model Config */}
<div className="flex items-center mb-2">
<Avatar size="small" color="purple" className="mr-2 shadow-md">
<IconCode size={16} />
</Avatar>
<div>
<Text className="text-lg font-medium">{t('模型配置')}</Text>
<div className="text-xs text-gray-600">{t('模型选择和映射设置')}</div>
</div>
</div>
<div className="space-y-4">
<Banner
type="info"
description={t('当前模型列表为该标签下所有渠道模型列表最长的一个,并非所有渠道的并集,请注意可能导致某些渠道模型丢失。')}
className="!rounded-lg mb-4"
/>
<Form.Select
field='models'
label={t('模型')}
placeholder={t('请选择该渠道所支持的模型,留空则不更改')}
multiple
filter
searchPosition='dropdown'
optionList={modelOptions}
style={{ width: '100%' }}
onChange={(value) => handleInputChange('models', value)}
/>
<Form.Input
field='custom_model'
label={t('自定义模型名称')}
placeholder={t('输入自定义模型名称')}
onChange={(value) => setCustomModel(value.trim())}
suffix={<Button size='small' type='primary' onClick={addCustomModels}>{t('填入')}</Button>}
/>
<Form.TextArea
field='model_mapping'
label={t('模型重定向')}
placeholder={t('此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,留空则不更改')}
autosize
onChange={(value) => handleInputChange('model_mapping', value)}
extraText={(
<Space>
<Text className="!text-semi-color-primary cursor-pointer" onClick={() => handleInputChange('model_mapping', JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2))}>{t('填入模板')}</Text>
<Text className="!text-semi-color-primary cursor-pointer" onClick={() => handleInputChange('model_mapping', JSON.stringify({}, null, 2))}>{t('清空重定向')}</Text>
<Text className="!text-semi-color-primary cursor-pointer" onClick={() => handleInputChange('model_mapping', '')}>{t('不更改')}</Text>
</Space>
)}
/>
</div>
</Card>
<Card className="!rounded-2xl shadow-sm border-0">
{/* Header: Group Settings */}
<div className="flex items-center mb-2">
<Avatar size="small" color="green" className="mr-2 shadow-md">
<IconUser size={16} />
</Avatar>
<div>
<Text className="text-lg font-medium">{t('分组设置')}</Text>
<div className="text-xs text-gray-600">{t('用户分组配置')}</div>
</div>
</div>
<div className="space-y-4">
<Form.Select
field='groups'
label={t('分组')}
placeholder={t('请选择可以使用该渠道的分组,留空则不更改')}
multiple
allowAdditions
additionLabel={t('请在系统设置页面编辑分组倍率以添加新的分组:')}
optionList={groupOptions}
style={{ width: '100%' }}
onChange={(value) => handleInputChange('groups', value)}
/>
</div>
</Card>
</div>
</Card>
<Card className="!rounded-2xl shadow-sm border-0">
{/* Header: Group Settings */}
<div className="flex items-center mb-2">
<Avatar size="small" color="green" className="mr-2 shadow-md">
<IconUser size={16} />
</Avatar>
<div>
<Text className="text-lg font-medium">{t('分组设置')}</Text>
<div className="text-xs text-gray-600">{t('用户分组配置')}</div>
</div>
</div>
<div className="space-y-4">
<div>
<Text strong className="block mb-2">{t('分组')}</Text>
<Select
placeholder={t('请选择可以使用该渠道的分组,留空则不更改')}
name='groups'
multiple
allowAdditions
additionLabel={t('请在系统设置页面编辑分组倍率以添加新的分组:')}
onChange={(value) => handleInputChange('groups', value)}
value={inputs.groups}
optionList={groupOptions}
/>
</div>
</div>
</Card>
</div>
</Spin>
</Spin>
)}
</Form>
</SideSheet>
);
};

View File

@@ -366,6 +366,86 @@ const Detail = (props) => {
},
});
// 模型消耗趋势折线图
const [spec_model_line, setSpecModelLine] = useState({
type: 'line',
data: [
{
id: 'lineData',
values: [],
},
],
xField: 'Time',
yField: 'Count',
seriesField: 'Model',
legends: {
visible: true,
selectMode: 'single',
},
title: {
visible: true,
text: t('模型消耗趋势'),
subtext: '',
},
tooltip: {
mark: {
content: [
{
key: (datum) => datum['Model'],
value: (datum) => renderNumber(datum['Count']),
},
],
},
},
color: {
specified: modelColorMap,
},
});
// 模型调用次数排行柱状图
const [spec_rank_bar, setSpecRankBar] = useState({
type: 'bar',
data: [
{
id: 'rankData',
values: [],
},
],
xField: 'Model',
yField: 'Count',
seriesField: 'Model',
legends: {
visible: true,
selectMode: 'single',
},
title: {
visible: true,
text: t('模型调用次数排行'),
subtext: '',
},
bar: {
state: {
hover: {
stroke: '#000',
lineWidth: 1,
},
},
},
tooltip: {
mark: {
content: [
{
key: (datum) => datum['Model'],
value: (datum) => renderNumber(datum['Count']),
},
],
},
},
color: {
specified: modelColorMap,
},
});
// ========== Hooks - Memoized Values ==========
const performanceMetrics = useMemo(() => {
const timeDiff = (Date.parse(end_timestamp) - Date.parse(start_timestamp)) / 60000;
@@ -853,6 +933,46 @@ const Detail = (props) => {
'barData'
);
// ===== 模型调用次数折线图 =====
let modelLineData = [];
chartTimePoints.forEach((time) => {
const timeData = Array.from(uniqueModels).map((model) => {
const key = `${time}-${model}`;
const aggregated = aggregatedData.get(key);
return {
Time: time,
Model: model,
Count: aggregated?.count || 0,
};
});
modelLineData.push(...timeData);
});
modelLineData.sort((a, b) => a.Time.localeCompare(b.Time));
// ===== 模型调用次数排行柱状图 =====
const rankData = Array.from(modelTotals)
.map(([model, count]) => ({
Model: model,
Count: count,
}))
.sort((a, b) => b.Count - a.Count);
updateChartSpec(
setSpecModelLine,
modelLineData,
`${t('总计')}${renderNumber(totalTimes)}`,
newModelColors,
'lineData'
);
updateChartSpec(
setSpecRankBar,
rankData,
`${t('总计')}${renderNumber(totalTimes)}`,
newModelColors,
'rankData'
);
setPieData(newPieData);
setLineData(newLineData);
setConsumeQuota(totalQuota);
@@ -1122,28 +1242,53 @@ const Detail = (props) => {
{t('消耗分布')}
</span>
} itemKey="1" />
<TabPane tab={
<span>
<IconPulse />
{t('消耗趋势')}
</span>
} itemKey="2" />
<TabPane tab={
<span>
<IconPieChart2Stroked />
{t('调用次数分布')}
</span>
} itemKey="2" />
} itemKey="3" />
<TabPane tab={
<span>
<IconHistogram />
{t('调用次数排行')}
</span>
} itemKey="4" />
</Tabs>
</div>
}
>
<div style={{ height: 400 }}>
{activeChartTab === '1' ? (
{activeChartTab === '1' && (
<VChart
spec={spec_line}
option={CHART_CONFIG}
/>
) : (
)}
{activeChartTab === '2' && (
<VChart
spec={spec_model_line}
option={CHART_CONFIG}
/>
)}
{activeChartTab === '3' && (
<VChart
spec={spec_pie}
option={CHART_CONFIG}
/>
)}
{activeChartTab === '4' && (
<VChart
spec={spec_rank_bar}
option={CHART_CONFIG}
/>
)}
</div>
</Card>

View File

@@ -272,10 +272,7 @@ const Home = () => {
className="w-full h-screen border-none"
/>
) : (
<div
className="text-base md:text-lg p-4 md:p-6 lg:p-8 overflow-x-hidden max-w-6xl mx-auto"
dangerouslySetInnerHTML={{ __html: homePageContent }}
></div>
<div className="mt-[64px]" dangerouslySetInnerHTML={{ __html: homePageContent }} />
)}
</div>
)}

View File

@@ -373,7 +373,7 @@ export default function UpstreamRatioSync(props) {
<div className="flex flex-col md:flex-row gap-2 w-full md:w-auto order-2 md:order-1">
<Button
icon={<RefreshCcw size={14} />}
className="!rounded-full w-full md:w-auto mt-2"
className="w-full md:w-auto mt-2"
onClick={() => {
setModalVisible(true);
if (allChannels.length === 0) {
@@ -393,7 +393,7 @@ export default function UpstreamRatioSync(props) {
type='secondary'
onClick={applySync}
disabled={!hasSelections}
className="!rounded-full w-full md:w-auto mt-2"
className="w-full md:w-auto mt-2"
>
{t('应用同步')}
</Button>
@@ -406,7 +406,7 @@ export default function UpstreamRatioSync(props) {
placeholder={t('搜索模型名称')}
value={searchKeyword}
onChange={setSearchKeyword}
className="!rounded-full w-full sm:w-64"
className="w-full sm:w-64"
showClear
/>
@@ -414,7 +414,7 @@ export default function UpstreamRatioSync(props) {
placeholder={t('按倍率类型筛选')}
value={ratioTypeFilter}
onChange={setRatioTypeFilter}
className="!rounded-full w-full sm:w-48"
className="w-full sm:w-48"
showClear
onClear={() => setRatioTypeFilter('')}
>
@@ -704,7 +704,6 @@ export default function UpstreamRatioSync(props) {
scroll={{ x: 'max-content' }}
size='middle'
loading={loading || syncLoading}
className="rounded-xl overflow-hidden"
/>
);
};

View File

@@ -139,14 +139,24 @@ const EditToken = (props) => {
if (formApiRef.current) {
if (!isEdit) {
formApiRef.current.setValues(getInitValues());
} else {
loadToken();
}
}
loadModels();
loadGroups();
}, [props.editingToken.id]);
useEffect(() => {
if (props.visiable) {
if (isEdit) {
loadToken();
} else {
formApiRef.current?.setValues(getInitValues());
}
} else {
formApiRef.current?.reset();
}
}, [props.visiable, props.editingToken.id]);
const generateRandomSuffix = () => {
const characters =
'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789';

View File

@@ -22,6 +22,7 @@ import {
Row,
Col,
Input,
InputNumber,
} from '@douyinfe/semi-ui';
import {
IconUser,
@@ -39,7 +40,7 @@ const EditUser = (props) => {
const userId = props.editingUser.id;
const [loading, setLoading] = useState(true);
const [addQuotaModalOpen, setIsModalOpen] = useState(false);
const [addQuotaLocal, setAddQuotaLocal] = useState('0');
const [addQuotaLocal, setAddQuotaLocal] = useState('');
const [groupOptions, setGroupOptions] = useState([]);
const formApiRef = useRef(null);
@@ -254,7 +255,6 @@ const EditUser = (props) => {
field='quota'
label={t('剩余额度')}
placeholder={t('请输入新的剩余额度')}
min={0}
step={500000}
extraText={renderQuotaWithPrompt(values.quota || 0)}
rules={[{ required: true, message: t('请输入额度') }]}
@@ -328,18 +328,19 @@ const EditUser = (props) => {
const current = formApiRef.current?.getValue('quota') || 0;
return (
<Text type='secondary' className='block mb-2'>
{`${t('新额度')}${renderQuota(current)} + ${renderQuota(addQuotaLocal)} = ${renderQuota(current + parseInt(addQuotaLocal || 0))}`}
{`${t('新额度')}${renderQuota(current)} + ${renderQuota(addQuotaLocal)} = ${renderQuota(current + parseInt(addQuotaLocal || 0))}`}
</Text>
);
})()
}
</div>
<Input
<InputNumber
placeholder={t('需要添加的额度(支持负数)')}
type='number'
value={addQuotaLocal}
onChange={setAddQuotaLocal}
style={{ width: '100%' }}
showClear
step={500000}
/>
</Modal>
</>