mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-03 19:14:52 +00:00
Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
99a2fc5852 | ||
|
|
9d9070c899 | ||
|
|
9a48ed47f4 | ||
|
|
155f67e960 | ||
|
|
71778f4174 | ||
|
|
7bb66b8bec | ||
|
|
7bdec28e5f | ||
|
|
5ffdd9f542 | ||
|
|
4c72f2abed | ||
|
|
fd51f71e0f | ||
|
|
59f12d2582 | ||
|
|
f17a419520 | ||
|
|
ee114e14c3 | ||
|
|
78fb457765 | ||
|
|
8759ef012f | ||
|
|
f8d67a62a2 | ||
|
|
efb98854b2 | ||
|
|
7b29f429ee | ||
|
|
265c7d93a2 | ||
|
|
ce57ad3570 | ||
|
|
9282f1d893 | ||
|
|
9546a47f2b | ||
|
|
8073cbd96a |
11
.env.example
11
.env.example
@@ -73,3 +73,14 @@
|
||||
# 节点类型
|
||||
# 如果是主节点则为master
|
||||
# NODE_TYPE=master
|
||||
|
||||
|
||||
# JavaScript 运行时配置
|
||||
# 是否启用(默认:false)
|
||||
# JS_RUNTIME_ENABLED=true
|
||||
# 最大虚拟机数量(默认:8)
|
||||
# JS_MAX_VM_COUNT=
|
||||
# 运行超时时间(单位:秒,默认:5)
|
||||
# JS_SCRIPT_TIMEOUT=
|
||||
# 脚本文件夹(默认:scripts/)
|
||||
# JS_SCRIPT_PATH=
|
||||
|
||||
@@ -76,3 +76,13 @@ func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]
|
||||
func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
|
||||
return c.GetTime(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
|
||||
if value, ok := c.Get(string(key)); ok {
|
||||
if v, ok := value.(T); ok {
|
||||
return v, true
|
||||
}
|
||||
}
|
||||
var t T
|
||||
return t, false
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
@@ -68,3 +69,15 @@ func StringToByteSlice(s string) []byte {
|
||||
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
|
||||
return *(*[]byte)(unsafe.Pointer(&tmp2))
|
||||
}
|
||||
|
||||
func EncodeBase64(str string) string {
|
||||
return base64.StdEncoding.EncodeToString([]byte(str))
|
||||
}
|
||||
|
||||
func GetJsonString(data any) string {
|
||||
if data == nil {
|
||||
return ""
|
||||
}
|
||||
b, _ := json.Marshal(data)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
149
common/struct_reflect.go
Normal file
149
common/struct_reflect.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// StructToMap 递归地把任意结构体 v 转成 map[string]any。
|
||||
// - 只处理导出字段;未导出字段会被跳过。
|
||||
// - 优先使用 `json:"name"` 里逗号前的部分作为键;如果是 "-" 则忽略该字段;若无 tag,则使用字段名。
|
||||
// - 对指针、切片、数组、嵌套结构体、map 做深度遍历,保持原始结构。
|
||||
func StructToMap(v any) (map[string]any, error) {
|
||||
val := reflect.ValueOf(v)
|
||||
if !val.IsValid() {
|
||||
return nil, fmt.Errorf("nil value")
|
||||
}
|
||||
for val.Kind() == reflect.Pointer {
|
||||
if val.IsNil() {
|
||||
return nil, fmt.Errorf("nil pointer")
|
||||
}
|
||||
val = val.Elem()
|
||||
}
|
||||
if val.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("expect struct, got %s", val.Kind())
|
||||
}
|
||||
|
||||
return structValueToMap(val), nil
|
||||
}
|
||||
|
||||
func structValueToMap(val reflect.Value) map[string]any {
|
||||
out := make(map[string]any, val.NumField())
|
||||
|
||||
typ := val.Type()
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
f := typ.Field(i)
|
||||
if f.PkgPath != "" { // 未导出字段
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析 json tag
|
||||
tag := f.Tag.Get("json")
|
||||
name, opts := parseTag(tag)
|
||||
if name == "-" {
|
||||
continue
|
||||
}
|
||||
if name == "" {
|
||||
name = f.Name
|
||||
}
|
||||
|
||||
fv := val.Field(i)
|
||||
out[name] = valueToAny(fv, opts.Contains("omitempty"))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// valueToAny 递归处理各种值类型。
|
||||
func valueToAny(v reflect.Value, omitEmpty bool) any {
|
||||
if !v.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
for v.Kind() == reflect.Pointer {
|
||||
if v.IsNil() {
|
||||
if omitEmpty {
|
||||
return nil
|
||||
}
|
||||
// 保持与 encoding/json 行为一致,nil 指针写成 null
|
||||
return nil
|
||||
}
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
|
||||
case reflect.Struct:
|
||||
return structValueToMap(v)
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
l := v.Len()
|
||||
arr := make([]any, l)
|
||||
for i := 0; i < l; i++ {
|
||||
arr[i] = valueToAny(v.Index(i), false)
|
||||
}
|
||||
return arr
|
||||
|
||||
case reflect.Map:
|
||||
m := make(map[string]any, v.Len())
|
||||
iter := v.MapRange()
|
||||
for iter.Next() {
|
||||
k := iter.Key()
|
||||
// 只支持 string key,与 encoding/json 一致
|
||||
if k.Kind() == reflect.String {
|
||||
m[k.String()] = valueToAny(iter.Value(), false)
|
||||
}
|
||||
}
|
||||
return m
|
||||
|
||||
default:
|
||||
// 基本类型直接返回其接口值
|
||||
return v.Interface()
|
||||
}
|
||||
}
|
||||
|
||||
// tagOptions 用于判断是否包含 "omitempty"
|
||||
type tagOptions string
|
||||
|
||||
func (o tagOptions) Contains(opt string) bool {
|
||||
if len(o) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, s := range splitComma(string(o)) {
|
||||
if s == opt {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func parseTag(tag string) (string, tagOptions) {
|
||||
if idx := indexComma(tag); idx != -1 {
|
||||
return tag[:idx], tagOptions(tag[idx+1:])
|
||||
}
|
||||
return tag, tagOptions("")
|
||||
}
|
||||
|
||||
// 避免 strings.Split 额外分配
|
||||
func indexComma(s string) int {
|
||||
for i, r := range s {
|
||||
if r == ',' {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func splitComma(s string) []string {
|
||||
var parts []string
|
||||
start := 0
|
||||
for i, r := range s {
|
||||
if r == ',' {
|
||||
parts = append(parts, s[start:i])
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
if start <= len(s) {
|
||||
parts = append(parts, s[start:])
|
||||
}
|
||||
return parts
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package constant
|
||||
|
||||
var (
|
||||
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
|
||||
ChanelSettingProxy = "proxy" // Proxy 代理
|
||||
ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent
|
||||
)
|
||||
@@ -1,16 +0,0 @@
|
||||
package constant
|
||||
|
||||
var (
|
||||
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
|
||||
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
|
||||
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
|
||||
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
|
||||
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
|
||||
UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
|
||||
UserSettingRecordIpLog = "record_ip_log" // 是否记录请求和错误日志IP
|
||||
)
|
||||
|
||||
var (
|
||||
NotifyTypeEmail = "email" // Email 邮件
|
||||
NotifyTypeWebhook = "webhook" // Webhook
|
||||
)
|
||||
@@ -173,8 +173,19 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
|
||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
|
||||
quota, "模型测试", 0, quota, int(consumedTime), false, info.UsingGroup, other)
|
||||
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
|
||||
ChannelId: channel.Id,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
ModelName: info.OriginModelName,
|
||||
TokenName: "模型测试",
|
||||
Quota: quota,
|
||||
Content: "模型测试",
|
||||
UseTimeSeconds: int(consumedTime),
|
||||
IsStream: false,
|
||||
Group: info.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -228,7 +228,7 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
func FixChannelsAbilities(c *gin.Context) {
|
||||
count, err := model.FixAbility()
|
||||
success, fails, err := model.FixAbility()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -239,7 +239,10 @@ func FixChannelsAbilities(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": count,
|
||||
"data": gin.H{
|
||||
"success": success,
|
||||
"fails": fails,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -387,6 +390,14 @@ func AddChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
err = channel.ValidateSettings()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "channel setting 格式错误:" + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
channel.CreatedTime = common.GetTimestamp()
|
||||
keys := strings.Split(channel.Key, "\n")
|
||||
if channel.Type == constant.ChannelTypeVertexAi {
|
||||
@@ -614,6 +625,14 @@ func UpdateChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
err = channel.ValidateSettings()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "channel setting 格式错误:" + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeVertexAi {
|
||||
if channel.Other == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/middleware"
|
||||
"one-api/middleware/jsrt"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/console_setting"
|
||||
@@ -33,7 +35,6 @@ func TestStatus(c *gin.Context) {
|
||||
"message": "Server is running",
|
||||
"http_stats": httpStats,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetStatus(c *gin.Context) {
|
||||
@@ -106,7 +107,6 @@ func GetStatus(c *gin.Context) {
|
||||
"message": "",
|
||||
"data": data,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetNotice(c *gin.Context) {
|
||||
@@ -117,7 +117,6 @@ func GetNotice(c *gin.Context) {
|
||||
"message": "",
|
||||
"data": common.OptionMap["Notice"],
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetAbout(c *gin.Context) {
|
||||
@@ -128,7 +127,6 @@ func GetAbout(c *gin.Context) {
|
||||
"message": "",
|
||||
"data": common.OptionMap["About"],
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetMidjourney(c *gin.Context) {
|
||||
@@ -139,7 +137,6 @@ func GetMidjourney(c *gin.Context) {
|
||||
"message": "",
|
||||
"data": common.OptionMap["Midjourney"],
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetHomePageContent(c *gin.Context) {
|
||||
@@ -150,7 +147,6 @@ func GetHomePageContent(c *gin.Context) {
|
||||
"message": "",
|
||||
"data": common.OptionMap["HomePageContent"],
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SendEmailVerification(c *gin.Context) {
|
||||
@@ -173,13 +169,7 @@ func SendEmailVerification(c *gin.Context) {
|
||||
localPart := parts[0]
|
||||
domainPart := parts[1]
|
||||
if common.EmailDomainRestrictionEnabled {
|
||||
allowed := false
|
||||
for _, domain := range common.EmailDomainWhitelist {
|
||||
if domainPart == domain {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
allowed := slices.Contains(common.EmailDomainWhitelist, domainPart)
|
||||
if !allowed {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -224,7 +214,6 @@ func SendEmailVerification(c *gin.Context) {
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SendPasswordResetEmail(c *gin.Context) {
|
||||
@@ -263,7 +252,6 @@ func SendPasswordResetEmail(c *gin.Context) {
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type PasswordResetRequest struct {
|
||||
@@ -303,5 +291,13 @@ func ResetPassword(c *gin.Context) {
|
||||
"message": "",
|
||||
"data": password,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func ReloadJSScripts(c *gin.Context) {
|
||||
jsrt.ReloadJSScripts()
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "JavaScript 脚本已重新加载",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -130,7 +130,7 @@ func ListModels(c *gin.Context) {
|
||||
}
|
||||
} else {
|
||||
userId := c.GetInt("id")
|
||||
userGroup, err := model.GetUserGroup(userId, true)
|
||||
userGroup, err := model.GetUserGroup(userId, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -122,7 +122,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
return fmt.Errorf("Get Task status code: %d", resp.StatusCode)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
@@ -961,7 +962,7 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 验证预警类型
|
||||
if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
|
||||
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的预警类型",
|
||||
@@ -979,7 +980,7 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 如果是webhook类型,验证webhook地址
|
||||
if req.QuotaWarningType == constant.NotifyTypeWebhook {
|
||||
if req.QuotaWarningType == dto.NotifyTypeWebhook {
|
||||
if req.WebhookUrl == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -998,7 +999,7 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 如果是邮件类型,验证邮箱地址
|
||||
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
|
||||
if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
|
||||
// 验证邮箱格式
|
||||
if !strings.Contains(req.NotificationEmail, "@") {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -1020,24 +1021,24 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 构建设置
|
||||
settings := map[string]interface{}{
|
||||
constant.UserSettingNotifyType: req.QuotaWarningType,
|
||||
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
|
||||
"accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel,
|
||||
constant.UserSettingRecordIpLog: req.RecordIpLog,
|
||||
settings := dto.UserSetting{
|
||||
NotifyType: req.QuotaWarningType,
|
||||
QuotaWarningThreshold: req.QuotaWarningThreshold,
|
||||
AcceptUnsetRatioModel: req.AcceptUnsetModelRatioModel,
|
||||
RecordIpLog: req.RecordIpLog,
|
||||
}
|
||||
|
||||
// 如果是webhook类型,添加webhook相关设置
|
||||
if req.QuotaWarningType == constant.NotifyTypeWebhook {
|
||||
settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
|
||||
if req.QuotaWarningType == dto.NotifyTypeWebhook {
|
||||
settings.WebhookUrl = req.WebhookUrl
|
||||
if req.WebhookSecret != "" {
|
||||
settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
|
||||
settings.WebhookSecret = req.WebhookSecret
|
||||
}
|
||||
}
|
||||
|
||||
// 如果提供了通知邮箱,添加到设置中
|
||||
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
|
||||
settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
|
||||
if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
|
||||
settings.NotificationEmail = req.NotificationEmail
|
||||
}
|
||||
|
||||
// 更新用户设置
|
||||
|
||||
@@ -11,6 +11,7 @@ services:
|
||||
volumes:
|
||||
- ./data:/data
|
||||
- ./logs:/app/logs
|
||||
- ${JS_SCRIPT_DIR:-./scripts}:/app/scripts
|
||||
environment:
|
||||
- SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service
|
||||
- REDIS_CONN_STRING=redis://redis
|
||||
@@ -21,7 +22,6 @@ services:
|
||||
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
||||
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
||||
# - FRONTEND_BASE_URL=https://openai.justsong.cn # Uncomment for multi-node deployment with front-end URL
|
||||
|
||||
depends_on:
|
||||
- redis
|
||||
- mysql
|
||||
|
||||
7
dto/channel_settings.go
Normal file
7
dto/channel_settings.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package dto
|
||||
|
||||
type ChannelSettings struct {
|
||||
ForceFormat bool `json:"force_format,omitempty"`
|
||||
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
|
||||
Proxy string `json:"proxy"`
|
||||
}
|
||||
16
dto/user_settings.go
Normal file
16
dto/user_settings.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package dto
|
||||
|
||||
type UserSetting struct {
|
||||
NotifyType string `json:"notify_type,omitempty"` // QuotaWarningType 额度预警类型
|
||||
QuotaWarningThreshold float64 `json:"quota_warning_threshold,omitempty"` // QuotaWarningThreshold 额度预警阈值
|
||||
WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址
|
||||
WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥
|
||||
NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址
|
||||
AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型
|
||||
RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP
|
||||
}
|
||||
|
||||
var (
|
||||
NotifyTypeEmail = "email" // Email 邮件
|
||||
NotifyTypeWebhook = "webhook" // Webhook
|
||||
)
|
||||
5
go.mod
5
go.mod
@@ -11,6 +11,7 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
|
||||
github.com/dop251/goja v0.0.0-20250630131328-58d95d85e994
|
||||
github.com/gin-contrib/cors v1.7.2
|
||||
github.com/gin-contrib/gzip v0.0.6
|
||||
github.com/gin-contrib/sessions v0.0.5
|
||||
@@ -31,6 +32,7 @@ require (
|
||||
golang.org/x/crypto v0.35.0
|
||||
golang.org/x/image v0.23.0
|
||||
golang.org/x/net v0.35.0
|
||||
golang.org/x/sync v0.11.0
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
gorm.io/gorm v1.25.2
|
||||
@@ -56,9 +58,11 @@ require (
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect
|
||||
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect
|
||||
github.com/gorilla/context v1.1.1 // indirect
|
||||
github.com/gorilla/securecookie v1.1.1 // indirect
|
||||
github.com/gorilla/sessions v1.2.1 // indirect
|
||||
@@ -84,7 +88,6 @@ require (
|
||||
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
||||
golang.org/x/arch v0.12.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
||||
golang.org/x/sync v0.11.0 // indirect
|
||||
golang.org/x/sys v0.30.0 // indirect
|
||||
golang.org/x/text v0.22.0 // indirect
|
||||
google.golang.org/protobuf v1.34.2 // indirect
|
||||
|
||||
10
go.sum
10
go.sum
@@ -1,5 +1,7 @@
|
||||
github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
|
||||
github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
|
||||
github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
|
||||
github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
|
||||
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
|
||||
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
|
||||
@@ -40,6 +42,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
|
||||
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dop251/goja v0.0.0-20250630131328-58d95d85e994 h1:aQYWswi+hRL2zJqGacdCZx32XjKYV8ApXFGntw79XAM=
|
||||
github.com/dop251/goja v0.0.0-20250630131328-58d95d85e994/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
@@ -83,6 +87,8 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
|
||||
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU=
|
||||
github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg=
|
||||
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
@@ -97,8 +103,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
|
||||
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U=
|
||||
github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
|
||||
|
||||
7
main.go
7
main.go
@@ -39,7 +39,6 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
common.SetupLogger()
|
||||
common.SysLog("New API " + common.Version + " started")
|
||||
if os.Getenv("GIN_MODE") != "debug" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
@@ -69,9 +68,9 @@ func main() {
|
||||
if r := recover(); r != nil {
|
||||
common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
|
||||
// Retry once
|
||||
_, fixErr := model.FixAbility()
|
||||
_, _, fixErr := model.FixAbility()
|
||||
if fixErr != nil {
|
||||
common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
|
||||
common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -172,6 +171,8 @@ func InitResources() error {
|
||||
// 加载环境变量
|
||||
common.InitEnv()
|
||||
|
||||
common.SetupLogger()
|
||||
|
||||
// Initialize model settings
|
||||
ratio_setting.InitRatioSettings()
|
||||
|
||||
|
||||
@@ -247,9 +247,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
}
|
||||
c.Set("channel_id", channel.Id)
|
||||
c.Set("channel_name", channel.Name)
|
||||
c.Set("channel_type", channel.Type)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
|
||||
c.Set("channel_create_time", channel.CreatedTime)
|
||||
c.Set("channel_setting", channel.GetSetting())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
|
||||
c.Set("param_override", channel.GetParamOverride())
|
||||
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
|
||||
c.Set("channel_organization", *channel.OpenAIOrganization)
|
||||
@@ -258,7 +258,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
c.Set("model_mapping", channel.GetModelMapping())
|
||||
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
common.SetContextKey(c, constant.ContextKeyBaseUrl, channel.GetBaseURL())
|
||||
// TODO: api_version统一
|
||||
switch channel.Type {
|
||||
case constant.ChannelTypeAzure:
|
||||
|
||||
62
middleware/jsrt/cfg.go
Normal file
62
middleware/jsrt/cfg.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package jsrt
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Runtime 配置
|
||||
type JSRuntimeConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
MaxVMCount int `json:"max_vm_count"`
|
||||
ScriptTimeout time.Duration `json:"script_timeout"`
|
||||
ScriptDir string `json:"script_dir"`
|
||||
FetchTimeout time.Duration `json:"fetch_timeout"`
|
||||
}
|
||||
|
||||
var (
|
||||
jsConfig = JSRuntimeConfig{}
|
||||
)
|
||||
|
||||
const (
|
||||
defaultScriptDir = "scripts/"
|
||||
defaultScriptTimeout = 5 * time.Second
|
||||
defaultFetchTimeout = 10 * time.Second
|
||||
defaultMaxVMCount = 8
|
||||
)
|
||||
|
||||
func loadCfg() {
|
||||
if enabled := os.Getenv("JS_RUNTIME_ENABLED"); enabled != "" {
|
||||
jsConfig.Enabled = enabled == "true"
|
||||
}
|
||||
|
||||
if maxCount := os.Getenv("JS_MAX_VM_COUNT"); maxCount != "" {
|
||||
if count, err := strconv.Atoi(maxCount); err == nil && count > 0 {
|
||||
jsConfig.MaxVMCount = count
|
||||
}
|
||||
} else {
|
||||
jsConfig.MaxVMCount = defaultMaxVMCount
|
||||
}
|
||||
|
||||
if timeout := os.Getenv("JS_SCRIPT_TIMEOUT"); timeout != "" {
|
||||
if t, err := time.ParseDuration(timeout + "s"); err == nil && t > 0 {
|
||||
jsConfig.ScriptTimeout = t
|
||||
}
|
||||
} else {
|
||||
jsConfig.ScriptTimeout = defaultScriptTimeout
|
||||
}
|
||||
|
||||
if fetchTimeout := os.Getenv("JS_FETCH_TIMEOUT"); fetchTimeout != "" {
|
||||
if t, err := time.ParseDuration(fetchTimeout + "s"); err == nil && t > 0 {
|
||||
jsConfig.FetchTimeout = t
|
||||
}
|
||||
} else {
|
||||
jsConfig.FetchTimeout = defaultFetchTimeout
|
||||
}
|
||||
|
||||
jsConfig.ScriptDir = os.Getenv("JS_SCRIPT_DIR")
|
||||
if jsConfig.ScriptDir == "" {
|
||||
jsConfig.ScriptDir = defaultScriptDir
|
||||
}
|
||||
}
|
||||
69
middleware/jsrt/db.go
Normal file
69
middleware/jsrt/db.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package jsrt
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func dbQuery(db *gorm.DB, sql string, args ...any) []map[string]any {
|
||||
if db == nil {
|
||||
common.SysError("JS DB is nil")
|
||||
return nil
|
||||
}
|
||||
|
||||
rows, err := db.Raw(sql, args...).Rows()
|
||||
if err != nil {
|
||||
common.SysError("JS DB Query Error: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
common.SysError("JS DB Columns Error: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
results := make([]map[string]any, 0, 100)
|
||||
for rows.Next() {
|
||||
values := make([]any, len(columns))
|
||||
valuePtrs := make([]any, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
common.SysError("JS DB Scan Error: " + err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
row := make(map[string]any, len(columns))
|
||||
for i, col := range columns {
|
||||
val := values[i]
|
||||
if b, ok := val.([]byte); ok {
|
||||
row[col] = string(b)
|
||||
} else {
|
||||
row[col] = val
|
||||
}
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
func dbExec(db *gorm.DB, sql string, args ...any) map[string]any {
|
||||
if db == nil {
|
||||
return map[string]any{
|
||||
"rowsAffected": int64(0),
|
||||
"error": "database is nil",
|
||||
}
|
||||
}
|
||||
|
||||
result := db.Exec(sql, args...)
|
||||
return map[string]any{
|
||||
"rowsAffected": result.RowsAffected,
|
||||
"error": result.Error,
|
||||
}
|
||||
}
|
||||
137
middleware/jsrt/fetch.go
Normal file
137
middleware/jsrt/fetch.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package jsrt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type JSFetchRequest struct {
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Body any `json:"body"`
|
||||
Timeout int `json:"timeout"`
|
||||
}
|
||||
|
||||
type JSFetchResponse struct {
|
||||
Status int `json:"status"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Body string `json:"body"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) fetch(url string, options ...any) *JSFetchResponse {
|
||||
req := &JSFetchRequest{
|
||||
Method: "GET",
|
||||
URL: url,
|
||||
Headers: make(map[string]string),
|
||||
Timeout: int(jsConfig.FetchTimeout.Seconds()),
|
||||
}
|
||||
|
||||
// 解析选项
|
||||
if len(options) > 0 && options[0] != nil {
|
||||
if optMap, ok := options[0].(map[string]any); ok {
|
||||
if method, exists := optMap["method"]; exists {
|
||||
if methodStr, ok := method.(string); ok {
|
||||
req.Method = strings.ToUpper(methodStr)
|
||||
}
|
||||
}
|
||||
|
||||
if headers, exists := optMap["headers"]; exists {
|
||||
if headersMap, ok := headers.(map[string]any); ok {
|
||||
for k, v := range headersMap {
|
||||
if vStr, ok := v.(string); ok {
|
||||
req.Headers[k] = vStr
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if body, exists := optMap["body"]; exists {
|
||||
req.Body = body
|
||||
}
|
||||
|
||||
if timeout, exists := optMap["timeout"]; exists {
|
||||
if timeoutNum, ok := timeout.(float64); ok {
|
||||
req.Timeout = int(timeoutNum)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建HTTP请求
|
||||
var bodyReader io.Reader
|
||||
switch body := req.Body.(type) {
|
||||
case string:
|
||||
bodyReader = strings.NewReader(body)
|
||||
case []byte:
|
||||
bodyReader = bytes.NewReader(body)
|
||||
case nil:
|
||||
bodyReader = nil
|
||||
default:
|
||||
bodyBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return &JSFetchResponse{
|
||||
Error: fmt.Sprintf("Failed to marshal body: %v", err),
|
||||
}
|
||||
}
|
||||
bodyReader = bytes.NewReader(bodyBytes)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest(req.Method, req.URL, bodyReader)
|
||||
if err != nil {
|
||||
return &JSFetchResponse{
|
||||
Error: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
for k, v := range req.Headers {
|
||||
httpReq.Header.Set(k, v)
|
||||
}
|
||||
|
||||
// 设置默认User-Agent
|
||||
if httpReq.Header.Get("User-Agent") == "" {
|
||||
httpReq.Header.Set("User-Agent", "JS-Runtime-Fetch/1.0")
|
||||
}
|
||||
|
||||
// 创建带超时的上下文
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(req.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
httpReq = httpReq.WithContext(ctx)
|
||||
|
||||
// 执行请求
|
||||
resp, err := p.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return &JSFetchResponse{}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应体
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return &JSFetchResponse{
|
||||
Status: resp.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
// 构建响应头
|
||||
headers := make(map[string]string)
|
||||
for k, v := range resp.Header {
|
||||
if len(v) > 0 {
|
||||
headers[k] = v[0]
|
||||
}
|
||||
}
|
||||
|
||||
return &JSFetchResponse{
|
||||
Status: resp.StatusCode,
|
||||
Headers: headers,
|
||||
Body: string(bodyBytes),
|
||||
}
|
||||
}
|
||||
570
middleware/jsrt/jsrt.go
Normal file
570
middleware/jsrt/jsrt.go
Normal file
@@ -0,0 +1,570 @@
|
||||
package jsrt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/dop251/goja"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 池化
|
||||
type JSRuntimePool struct {
|
||||
pool chan *goja.Runtime
|
||||
maxSize int
|
||||
createFunc func() *goja.Runtime
|
||||
scripts string
|
||||
mu sync.RWMutex
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
var (
|
||||
jsRuntimePool *JSRuntimePool
|
||||
jsPoolOnce sync.Once
|
||||
)
|
||||
|
||||
func NewJSRuntimePool(maxSize int) *JSRuntimePool {
|
||||
// 创建HTTP客户端
|
||||
httpClient := &http.Client{
|
||||
Timeout: jsConfig.FetchTimeout,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: false,
|
||||
},
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
pool := &JSRuntimePool{
|
||||
pool: make(chan *goja.Runtime, maxSize),
|
||||
maxSize: maxSize,
|
||||
scripts: "",
|
||||
httpClient: httpClient,
|
||||
}
|
||||
|
||||
pool.createFunc = func() *goja.Runtime {
|
||||
vm := goja.New()
|
||||
pool.setupGlobals(vm)
|
||||
pool.loadScripts(vm)
|
||||
return vm
|
||||
}
|
||||
|
||||
// 预创建VM
|
||||
preCreate := min(maxSize/2, 4)
|
||||
for range preCreate {
|
||||
select {
|
||||
case pool.pool <- pool.createFunc():
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
return pool
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) Get() *goja.Runtime {
|
||||
select {
|
||||
case vm := <-p.pool:
|
||||
return vm
|
||||
default:
|
||||
return p.createFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) Put(vm *goja.Runtime) {
|
||||
if vm == nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case p.pool <- vm:
|
||||
default:
|
||||
// 池满,丢弃VM让GC回收
|
||||
}
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) setupGlobals(vm *goja.Runtime) {
|
||||
// console
|
||||
console := vm.NewObject()
|
||||
console.Set("log", func(args ...any) {
|
||||
var strs []string
|
||||
for _, arg := range args {
|
||||
strs = append(strs, fmt.Sprintf("%v", arg))
|
||||
}
|
||||
common.SysLog("JS: " + strings.Join(strs, " "))
|
||||
})
|
||||
console.Set("error", func(args ...any) {
|
||||
var strs []string
|
||||
for _, arg := range args {
|
||||
strs = append(strs, fmt.Sprintf("%v", arg))
|
||||
}
|
||||
common.SysError("JS: " + strings.Join(strs, " "))
|
||||
})
|
||||
console.Set("warn", func(args ...any) {
|
||||
var strs []string
|
||||
for _, arg := range args {
|
||||
strs = append(strs, fmt.Sprintf("%v", arg))
|
||||
}
|
||||
common.SysError("JS WARN: " + strings.Join(strs, " "))
|
||||
})
|
||||
vm.Set("console", console)
|
||||
|
||||
// JSON
|
||||
jsonObj := vm.NewObject()
|
||||
jsonObj.Set("parse", func(str string) any {
|
||||
var result any
|
||||
err := json.Unmarshal([]byte(str), &result)
|
||||
if err != nil {
|
||||
panic(vm.ToValue(err.Error()))
|
||||
}
|
||||
return result
|
||||
})
|
||||
jsonObj.Set("stringify", func(obj any) string {
|
||||
data, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
panic(vm.ToValue(err.Error()))
|
||||
}
|
||||
return string(data)
|
||||
})
|
||||
vm.Set("JSON", jsonObj)
|
||||
|
||||
// fetch 实现
|
||||
vm.Set("fetch", func(url string, options ...any) *JSFetchResponse {
|
||||
return p.fetch(url, options...)
|
||||
})
|
||||
|
||||
// 数据库
|
||||
setDB(vm, model.DB, "db")
|
||||
setDB(vm, model.LOG_DB, "logdb")
|
||||
|
||||
// 定时器
|
||||
vm.Set("setTimeout", func(fn func(), delay int) {
|
||||
go func() {
|
||||
time.Sleep(time.Duration(delay) * time.Millisecond)
|
||||
fn()
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) loadScripts(vm *goja.Runtime) {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
// 如果已经缓存了合并的脚本,直接使用
|
||||
if p.scripts != "" {
|
||||
if _, err := vm.RunString(p.scripts); err != nil {
|
||||
common.SysError("Failed to load cached scripts: " + err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 首次加载时,读取 scripts/ 文件夹中的所有脚本
|
||||
p.mu.RUnlock()
|
||||
p.mu.Lock()
|
||||
defer func() {
|
||||
p.mu.Unlock()
|
||||
p.mu.RLock()
|
||||
}()
|
||||
|
||||
if p.scripts != "" {
|
||||
if _, err := vm.RunString(p.scripts); err != nil {
|
||||
common.SysError("Failed to load cached scripts: " + err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 读取所有脚本文件
|
||||
var combinedScript strings.Builder
|
||||
scriptDir := jsConfig.ScriptDir
|
||||
|
||||
// 检查目录是否存在
|
||||
if _, err := os.Stat(scriptDir); os.IsNotExist(err) {
|
||||
common.SysLog("Scripts directory does not exist: " + scriptDir)
|
||||
return
|
||||
}
|
||||
|
||||
// 读取目录中的所有 .js 文件
|
||||
files, err := filepath.Glob(filepath.Join(scriptDir, "*.js"))
|
||||
if err != nil {
|
||||
common.SysError("Failed to read scripts directory: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(files) == 0 {
|
||||
common.SysLog("No JavaScript files found in: " + scriptDir)
|
||||
return
|
||||
}
|
||||
|
||||
// 按文件名排序以确保加载顺序一致
|
||||
for _, file := range files {
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
common.SysError("Failed to read script file " + file + ": " + err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
// 添加文件注释和内容
|
||||
combinedScript.WriteString("// File: " + filepath.Base(file) + "\n")
|
||||
combinedScript.WriteString(string(content))
|
||||
combinedScript.WriteString("\n\n")
|
||||
|
||||
common.SysLog("Loaded script: " + filepath.Base(file))
|
||||
}
|
||||
|
||||
// 缓存合并后的脚本
|
||||
p.scripts = combinedScript.String()
|
||||
|
||||
// 执行脚本
|
||||
if p.scripts != "" {
|
||||
if _, err := vm.RunString(p.scripts); err != nil {
|
||||
common.SysError("Failed to load combined scripts: " + err.Error())
|
||||
} else {
|
||||
common.SysLog("Successfully loaded and combined all JavaScript files from: " + scriptDir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) ReloadScripts() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// 清空缓存的脚本
|
||||
p.scripts = ""
|
||||
|
||||
// 清空VM池,强制重新创建
|
||||
for {
|
||||
select {
|
||||
case <-p.pool:
|
||||
default:
|
||||
goto done
|
||||
}
|
||||
}
|
||||
done:
|
||||
common.SysLog("JavaScript scripts reloaded")
|
||||
}
|
||||
|
||||
func initJSRuntimePool() *JSRuntimePool {
|
||||
jsPoolOnce.Do(func() {
|
||||
jsRuntimePool = NewJSRuntimePool(jsConfig.MaxVMCount)
|
||||
common.SysLog("JavaScript runtime pool initialized successfully")
|
||||
})
|
||||
return jsRuntimePool
|
||||
}
|
||||
|
||||
func validateGinContext(c *gin.Context) error {
|
||||
if c == nil {
|
||||
return fmt.Errorf("gin context is nil")
|
||||
}
|
||||
if c.Request == nil {
|
||||
return fmt.Errorf("gin context request is nil")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) executeWithTimeout(_ *goja.Runtime, fn func() (goja.Value, error)) (goja.Value, error) {
|
||||
type result struct {
|
||||
value goja.Value
|
||||
err error
|
||||
}
|
||||
|
||||
resultChan := make(chan result, 1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
resultChan <- result{err: fmt.Errorf("JS panic: %v", r)}
|
||||
}
|
||||
}()
|
||||
|
||||
value, err := fn()
|
||||
resultChan <- result{value: value, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case res := <-resultChan:
|
||||
return res.value, res.err
|
||||
case <-time.After(jsConfig.ScriptTimeout):
|
||||
return nil, fmt.Errorf("script execution timeout after %v", jsConfig.ScriptTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) PreProcessRequest(c *gin.Context) error {
|
||||
if err := validateGinContext(c); err != nil {
|
||||
common.SysError("JS PreProcess Validation Error: " + err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
vm := p.Get()
|
||||
defer p.Put(vm)
|
||||
|
||||
preProcessFunc := vm.Get("preProcessRequest")
|
||||
if preProcessFunc == nil || goja.IsUndefined(preProcessFunc) {
|
||||
return nil
|
||||
}
|
||||
|
||||
jsReq, err := common.StructToMap(createJSReq(c))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create JS context: %v", err)
|
||||
}
|
||||
|
||||
result, err := p.executeWithTimeout(vm, func() (goja.Value, error) {
|
||||
fn, ok := goja.AssertFunction(preProcessFunc)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("preProcessRequest is not a function")
|
||||
}
|
||||
return fn(goja.Undefined(), vm.ToValue(jsReq))
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
common.SysError("JS PreProcess Error: " + err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理返回结果
|
||||
if result != nil && !goja.IsUndefined(result) {
|
||||
resultObj := result.Export()
|
||||
if resultMap, ok := resultObj.(map[string]any); ok {
|
||||
// 是否修改请求
|
||||
if newBody, exists := resultMap["body"]; exists {
|
||||
switch v := newBody.(type) {
|
||||
case string:
|
||||
c.Request.Body = io.NopCloser(strings.NewReader(v))
|
||||
c.Request.ContentLength = int64(len(v))
|
||||
case []byte:
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(v))
|
||||
c.Request.ContentLength = int64(len(v))
|
||||
case map[string]any:
|
||||
bodyBytes, err := json.Marshal(v)
|
||||
if err == nil {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
c.Request.ContentLength = int64(len(bodyBytes))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
} else {
|
||||
common.SysError("JS PreProcess JSON Marshal Error: " + err.Error())
|
||||
}
|
||||
default:
|
||||
common.SysError("JS PreProcess Unsupported Body Type: " + fmt.Sprintf("%T", newBody))
|
||||
}
|
||||
}
|
||||
|
||||
// 是否修改 headers
|
||||
if newHeaders, exists := resultMap["headers"]; exists {
|
||||
if headersMap, ok := newHeaders.(map[string]any); ok {
|
||||
for key, value := range headersMap {
|
||||
if valueStr, ok := value.(string); ok {
|
||||
c.Request.Header.Set(key, valueStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 是否阻止请求
|
||||
if block, exists := resultMap["block"]; exists {
|
||||
if blockBool, ok := block.(bool); ok && blockBool {
|
||||
status := http.StatusForbidden
|
||||
if statusCode, exists := resultMap["statusCode"]; exists {
|
||||
if statusInt, ok := statusCode.(float64); ok {
|
||||
status = int(statusInt)
|
||||
}
|
||||
}
|
||||
|
||||
message := "Request blocked by pre-process script"
|
||||
if msg, exists := resultMap["message"]; exists {
|
||||
if msgStr, ok := msg.(string); ok {
|
||||
message = msgStr
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(status, gin.H{"error": message})
|
||||
c.Abort()
|
||||
return fmt.Errorf("request blocked")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) PostProcessResponse(c *gin.Context, statusCode int, body []byte) (int, []byte, error) {
|
||||
if err := validateGinContext(c); err != nil {
|
||||
common.SysError("JS PostProcess Validation Error: " + err.Error())
|
||||
return statusCode, body, err
|
||||
}
|
||||
|
||||
vm := p.Get()
|
||||
defer p.Put(vm)
|
||||
|
||||
postProcessFunc := vm.Get("postProcessResponse")
|
||||
if postProcessFunc == nil || goja.IsUndefined(postProcessFunc) {
|
||||
return statusCode, body, nil
|
||||
}
|
||||
|
||||
jsReq, err := common.StructToMap(createJSReq(c))
|
||||
if err != nil {
|
||||
return statusCode, body, fmt.Errorf("failed to create JS context: %v", err)
|
||||
}
|
||||
|
||||
jsResp := &JSResponse{
|
||||
StatusCode: statusCode,
|
||||
Headers: make(map[string]string),
|
||||
Body: string(body),
|
||||
}
|
||||
|
||||
// 获取响应头
|
||||
if c.Writer != nil {
|
||||
for key, values := range c.Writer.Header() {
|
||||
if len(values) > 0 {
|
||||
jsResp.Headers[key] = values[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
jsResponse, err := common.StructToMap(jsResp)
|
||||
if err != nil {
|
||||
return statusCode, body, fmt.Errorf("failed to create JS response context: %v", err)
|
||||
}
|
||||
|
||||
result, err := p.executeWithTimeout(vm, func() (goja.Value, error) {
|
||||
fn, ok := goja.AssertFunction(postProcessFunc)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("postProcessResponse is not a function")
|
||||
}
|
||||
return fn(goja.Undefined(), vm.ToValue(jsReq), vm.ToValue(jsResponse))
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
common.SysError("JS PostProcess Error: " + err.Error())
|
||||
return statusCode, body, err
|
||||
}
|
||||
|
||||
// 处理返回
|
||||
if result != nil && !goja.IsUndefined(result) {
|
||||
resultObj := result.Export()
|
||||
if resultMap, ok := resultObj.(map[string]any); ok {
|
||||
if newStatusCode, exists := resultMap["statusCode"]; exists {
|
||||
if statusInt, ok := newStatusCode.(float64); ok {
|
||||
statusCode = int(statusInt)
|
||||
}
|
||||
}
|
||||
|
||||
if newBody, exists := resultMap["body"]; exists {
|
||||
if bodyStr, ok := newBody.(string); ok {
|
||||
body = []byte(bodyStr)
|
||||
}
|
||||
}
|
||||
|
||||
if newHeaders, exists := resultMap["headers"]; exists {
|
||||
if headersMap, ok := newHeaders.(map[string]any); ok {
|
||||
for key, value := range headersMap {
|
||||
if valueStr, ok := value.(string); ok {
|
||||
c.Header(key, valueStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return statusCode, body, nil
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) hasPostProcessFunction() bool {
|
||||
vm := p.Get()
|
||||
defer p.Put(vm)
|
||||
postProcessFunc := vm.Get("postProcessResponse")
|
||||
return postProcessFunc != nil && !goja.IsUndefined(postProcessFunc)
|
||||
}
|
||||
|
||||
func JSRuntimeMiddleware() *gin.HandlerFunc {
|
||||
loadCfg()
|
||||
if !jsConfig.Enabled {
|
||||
common.SysLog("JavaScript Runtime is disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
pool := initJSRuntimePool()
|
||||
var fn gin.HandlerFunc
|
||||
fn = func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
|
||||
// 预处理
|
||||
if err := pool.PreProcessRequest(c); err != nil {
|
||||
common.SysError("JS Runtime PreProcess Error: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
if duration > time.Millisecond*100 {
|
||||
common.SysLog(fmt.Sprintf("JS Runtime PreProcess took %v", duration))
|
||||
}
|
||||
|
||||
// 后处理
|
||||
if pool.hasPostProcessFunction() {
|
||||
writer := newResponseWriter(c.Writer)
|
||||
c.Writer = writer
|
||||
|
||||
c.Next()
|
||||
|
||||
// 后处理响应
|
||||
if writer.body.Len() > 0 {
|
||||
start := time.Now()
|
||||
|
||||
statusCode, body, err := pool.PostProcessResponse(c, writer.statusCode, writer.body.Bytes())
|
||||
if err == nil {
|
||||
c.Writer = writer.ResponseWriter
|
||||
|
||||
for k, v := range writer.headerMap {
|
||||
for _, value := range v {
|
||||
c.Writer.Header().Add(k, value)
|
||||
}
|
||||
}
|
||||
|
||||
c.Status(statusCode)
|
||||
|
||||
if len(body) >= 0 {
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
|
||||
c.Writer.Write(body)
|
||||
} else {
|
||||
c.Writer.Header().Del("Content-Length")
|
||||
c.Writer.Write(body)
|
||||
}
|
||||
} else {
|
||||
// 出错时回复原响应
|
||||
c.Writer = writer.ResponseWriter
|
||||
c.Status(writer.statusCode)
|
||||
|
||||
common.SysError(fmt.Sprintf("JS Runtime PostProcess Error: %v", err))
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
if duration > time.Millisecond*100 {
|
||||
common.SysLog(fmt.Sprintf("JS Runtime PostProcess took %v", duration))
|
||||
}
|
||||
} else {
|
||||
// 没有响应体时,恢复原始writer
|
||||
c.Writer = writer.ResponseWriter
|
||||
}
|
||||
} else {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
return &fn
|
||||
}
|
||||
|
||||
func ReloadJSScripts() {
|
||||
if jsRuntimePool != nil {
|
||||
jsRuntimePool.ReloadScripts()
|
||||
common.SysLog("JavaScript scripts reloaded")
|
||||
}
|
||||
}
|
||||
139
middleware/jsrt/req.go
Normal file
139
middleware/jsrt/req.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package jsrt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 请求
|
||||
type JSReq struct {
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Body any `json:"body"`
|
||||
UserAgent string `json:"userAgent"`
|
||||
RemoteIP string `json:"remoteIP"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
}
|
||||
|
||||
type JSResponse struct {
|
||||
StatusCode int `json:"statusCode"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Body string `json:"body"`
|
||||
}
|
||||
|
||||
type responseWriter struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
statusCode int
|
||||
headerMap http.Header
|
||||
written bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func createJSReq(c *gin.Context) *JSReq {
|
||||
var bodyBytes []byte
|
||||
if c.Request != nil && c.Request.Body != nil {
|
||||
bodyBytes, _ = io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
}
|
||||
|
||||
// headers map
|
||||
headers := make(map[string]string)
|
||||
if c.Request != nil && c.Request.Header != nil {
|
||||
for key, values := range c.Request.Header {
|
||||
if len(values) > 0 {
|
||||
headers[key] = values[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
method := ""
|
||||
url := ""
|
||||
userAgent := ""
|
||||
remoteIP := ""
|
||||
contentType := ""
|
||||
|
||||
if c.Request != nil {
|
||||
method = c.Request.Method
|
||||
if c.Request.URL != nil {
|
||||
url = c.Request.URL.String()
|
||||
}
|
||||
userAgent = c.Request.UserAgent()
|
||||
contentType = c.ContentType()
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
remoteIP = c.ClientIP()
|
||||
}
|
||||
|
||||
parsedBody := parseBodyByType(bodyBytes, contentType)
|
||||
|
||||
return &JSReq{
|
||||
Method: method,
|
||||
URL: url,
|
||||
Headers: headers,
|
||||
Body: parsedBody,
|
||||
UserAgent: userAgent,
|
||||
RemoteIP: remoteIP,
|
||||
Extra: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
func newResponseWriter(w gin.ResponseWriter) *responseWriter {
|
||||
return &responseWriter{
|
||||
ResponseWriter: w,
|
||||
body: &bytes.Buffer{},
|
||||
statusCode: 200,
|
||||
headerMap: make(http.Header),
|
||||
written: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) Write(data []byte) (int, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if !w.written {
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
return w.body.Write(data)
|
||||
}
|
||||
|
||||
func (w *responseWriter) WriteString(s string) (int, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if !w.written {
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
return w.body.WriteString(s)
|
||||
}
|
||||
|
||||
func (w *responseWriter) WriteHeader(statusCode int) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.written {
|
||||
return
|
||||
}
|
||||
w.statusCode = statusCode
|
||||
w.written = true
|
||||
|
||||
maps.Copy(w.headerMap, w.ResponseWriter.Header())
|
||||
}
|
||||
|
||||
func (w *responseWriter) Header() http.Header {
|
||||
w.mu.RLock()
|
||||
defer w.mu.RUnlock()
|
||||
|
||||
if w.headerMap == nil {
|
||||
w.headerMap = make(http.Header)
|
||||
}
|
||||
return w.headerMap
|
||||
}
|
||||
86
middleware/jsrt/utils.go
Normal file
86
middleware/jsrt/utils.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package jsrt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
|
||||
"github.com/dop251/goja"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func setDB(vm *goja.Runtime, db *gorm.DB, name string) {
|
||||
if db == nil {
|
||||
common.SysError("JS DB is nil")
|
||||
return
|
||||
}
|
||||
|
||||
obj := vm.NewObject()
|
||||
obj.Set("query", func(sql string, params ...any) []map[string]any {
|
||||
return dbQuery(db, sql, params...)
|
||||
})
|
||||
obj.Set("exec", func(sql string, params ...any) map[string]any {
|
||||
return dbExec(db, sql, params...)
|
||||
})
|
||||
if err := vm.Set(name, obj); err != nil {
|
||||
common.SysError("Failed to set JS DB: " + err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func parseBodyByType(bodyBytes []byte, contentType string) any {
|
||||
if len(bodyBytes) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
bodyStr := string(bodyBytes)
|
||||
contentLower := strings.ToLower(contentType)
|
||||
|
||||
switch {
|
||||
case strings.Contains(contentLower, "application/json"):
|
||||
var jsonObj any
|
||||
if err := json.Unmarshal(bodyBytes, &jsonObj); err == nil {
|
||||
return jsonObj
|
||||
}
|
||||
return bodyStr
|
||||
|
||||
case strings.Contains(contentLower, "application/x-www-form-urlencoded"):
|
||||
if values, err := url.ParseQuery(bodyStr); err == nil {
|
||||
result := make(map[string]string, len(values))
|
||||
for k, v := range values {
|
||||
if len(v) > 0 {
|
||||
result[k] = v[0]
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
return bodyStr
|
||||
|
||||
case strings.Contains(contentLower, "multipart/form-data"):
|
||||
return bodyBytes
|
||||
|
||||
case strings.Contains(contentLower, "text/"):
|
||||
return bodyStr
|
||||
|
||||
default:
|
||||
// 尝试JSON解析
|
||||
var jsonObj any
|
||||
if json.Unmarshal(bodyBytes, &jsonObj) == nil {
|
||||
return jsonObj
|
||||
}
|
||||
|
||||
// 尝试form解析
|
||||
if values, err := url.ParseQuery(bodyStr); err == nil && len(values) > 0 {
|
||||
result := make(map[string]string, len(values))
|
||||
for k, v := range values {
|
||||
if len(v) > 0 {
|
||||
result[k] = v[0]
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
return bodyStr
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"gorm.io/gorm"
|
||||
@@ -272,74 +273,45 @@ func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uin
|
||||
return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
|
||||
}
|
||||
|
||||
func FixAbility() (int, error) {
|
||||
var channelIds []int
|
||||
count := 0
|
||||
// Find all channel ids from channel table
|
||||
err := DB.Model(&Channel{}).Pluck("id", &channelIds).Error
|
||||
var fixLock = sync.Mutex{}
|
||||
|
||||
func FixAbility() (int, int, error) {
|
||||
lock := fixLock.TryLock()
|
||||
if !lock {
|
||||
return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
|
||||
}
|
||||
defer fixLock.Unlock()
|
||||
var channels []*Channel
|
||||
// Find all channels
|
||||
err := DB.Model(&Channel{}).Find(&channels).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
|
||||
return 0, err
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
// Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders
|
||||
if len(channelIds) > 0 {
|
||||
// Process deletion in chunks to avoid "too many placeholders" error
|
||||
for _, chunk := range lo.Chunk(channelIds, 100) {
|
||||
err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error()))
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If no channels exist, delete all abilities
|
||||
err = DB.Delete(&Ability{}).Error
|
||||
if len(channels) == 0 {
|
||||
return 0, 0, nil
|
||||
}
|
||||
successCount := 0
|
||||
failCount := 0
|
||||
for _, chunk := range lo.Chunk(channels, 50) {
|
||||
ids := lo.Map(chunk, func(c *Channel, _ int) int { return c.Id })
|
||||
// Delete all abilities of this channel
|
||||
err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error()))
|
||||
return 0, err
|
||||
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
||||
failCount += len(chunk)
|
||||
continue
|
||||
}
|
||||
common.SysLog("Delete all abilities successfully")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
|
||||
count += len(channelIds)
|
||||
|
||||
// Use channelIds to find channel not in abilities table
|
||||
var abilityChannelIds []int
|
||||
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
|
||||
return count, err
|
||||
}
|
||||
|
||||
var channels []Channel
|
||||
if len(abilityChannelIds) == 0 {
|
||||
err = DB.Find(&channels).Error
|
||||
} else {
|
||||
// Process query in chunks to avoid "too many placeholders" error
|
||||
err = nil
|
||||
for _, chunk := range lo.Chunk(abilityChannelIds, 100) {
|
||||
var channelsChunk []Channel
|
||||
err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error
|
||||
// Then add new abilities
|
||||
for _, channel := range chunk {
|
||||
err = channel.AddAbilities()
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
|
||||
return count, err
|
||||
common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
|
||||
failCount++
|
||||
} else {
|
||||
successCount++
|
||||
}
|
||||
channels = append(channels, channelsChunk...)
|
||||
}
|
||||
}
|
||||
|
||||
for _, channel := range channels {
|
||||
err := channel.UpdateAbilities(nil)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("Update abilities of channel %d successfully", channel.Id))
|
||||
count++
|
||||
}
|
||||
}
|
||||
InitChannelCache()
|
||||
return count, nil
|
||||
return successCount, failCount, nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package model
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -514,8 +515,19 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
||||
return tags, nil
|
||||
}
|
||||
|
||||
func (channel *Channel) GetSetting() map[string]interface{} {
|
||||
setting := make(map[string]interface{})
|
||||
func (channel *Channel) ValidateSettings() error {
|
||||
channelParams := &dto.ChannelSettings{}
|
||||
if channel.Setting != nil && *channel.Setting != "" {
|
||||
err := json.Unmarshal([]byte(*channel.Setting), channelParams)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (channel *Channel) GetSetting() dto.ChannelSettings {
|
||||
setting := dto.ChannelSettings{}
|
||||
if channel.Setting != nil && *channel.Setting != "" {
|
||||
err := json.Unmarshal([]byte(*channel.Setting), &setting)
|
||||
if err != nil {
|
||||
@@ -525,7 +537,7 @@ func (channel *Channel) GetSetting() map[string]interface{} {
|
||||
return setting
|
||||
}
|
||||
|
||||
func (channel *Channel) SetSetting(setting map[string]interface{}) {
|
||||
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
|
||||
settingBytes, err := json.Marshal(setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to marshal setting: " + err.Error())
|
||||
|
||||
61
model/log.go
61
model/log.go
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -100,10 +99,8 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
|
||||
// 判断是否需要记录 IP
|
||||
needRecordIp := false
|
||||
if settingMap, err := GetUserSetting(userId, false); err == nil {
|
||||
if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
|
||||
if vb, ok := v.(bool); ok && vb {
|
||||
needRecordIp = true
|
||||
}
|
||||
if settingMap.RecordIpLog {
|
||||
needRecordIp = true
|
||||
}
|
||||
}
|
||||
log := &Log{
|
||||
@@ -136,22 +133,34 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
|
||||
}
|
||||
}
|
||||
|
||||
func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
|
||||
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
|
||||
isStream bool, group string, other map[string]interface{}) {
|
||||
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||
type RecordConsumeLogParams struct {
|
||||
ChannelId int `json:"channel_id"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
ModelName string `json:"model_name"`
|
||||
TokenName string `json:"token_name"`
|
||||
Quota int `json:"quota"`
|
||||
Content string `json:"content"`
|
||||
TokenId int `json:"token_id"`
|
||||
UserQuota int `json:"user_quota"`
|
||||
UseTimeSeconds int `json:"use_time_seconds"`
|
||||
IsStream bool `json:"is_stream"`
|
||||
Group string `json:"group"`
|
||||
Other map[string]interface{} `json:"other"`
|
||||
}
|
||||
|
||||
func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
|
||||
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
|
||||
if !common.LogConsumeEnabled {
|
||||
return
|
||||
}
|
||||
username := c.GetString("username")
|
||||
otherStr := common.MapToJsonStr(other)
|
||||
otherStr := common.MapToJsonStr(params.Other)
|
||||
// 判断是否需要记录 IP
|
||||
needRecordIp := false
|
||||
if settingMap, err := GetUserSetting(userId, false); err == nil {
|
||||
if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
|
||||
if vb, ok := v.(bool); ok && vb {
|
||||
needRecordIp = true
|
||||
}
|
||||
if settingMap.RecordIpLog {
|
||||
needRecordIp = true
|
||||
}
|
||||
}
|
||||
log := &Log{
|
||||
@@ -159,17 +168,17 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
|
||||
Username: username,
|
||||
CreatedAt: common.GetTimestamp(),
|
||||
Type: LogTypeConsume,
|
||||
Content: content,
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TokenName: tokenName,
|
||||
ModelName: modelName,
|
||||
Quota: quota,
|
||||
ChannelId: channelId,
|
||||
TokenId: tokenId,
|
||||
UseTime: useTimeSeconds,
|
||||
IsStream: isStream,
|
||||
Group: group,
|
||||
Content: params.Content,
|
||||
PromptTokens: params.PromptTokens,
|
||||
CompletionTokens: params.CompletionTokens,
|
||||
TokenName: params.TokenName,
|
||||
ModelName: params.ModelName,
|
||||
Quota: params.Quota,
|
||||
ChannelId: params.ChannelId,
|
||||
TokenId: params.TokenId,
|
||||
UseTime: params.UseTimeSeconds,
|
||||
IsStream: params.IsStream,
|
||||
Group: params.Group,
|
||||
Ip: func() string {
|
||||
if needRecordIp {
|
||||
return c.ClientIP()
|
||||
@@ -184,7 +193,7 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
|
||||
}
|
||||
if common.DataExportEnabled {
|
||||
gopool.Go(func() {
|
||||
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
|
||||
LogQuotaData(userId, username, params.ModelName, params.Quota, common.GetTimestamp(), params.PromptTokens+params.CompletionTokens)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -68,14 +69,18 @@ func (user *User) SetAccessToken(token string) {
|
||||
user.AccessToken = &token
|
||||
}
|
||||
|
||||
func (user *User) GetSetting() map[string]interface{} {
|
||||
if user.Setting == "" {
|
||||
return nil
|
||||
func (user *User) GetSetting() dto.UserSetting {
|
||||
setting := dto.UserSetting{}
|
||||
if user.Setting != "" {
|
||||
err := json.Unmarshal([]byte(user.Setting), &setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||
}
|
||||
}
|
||||
return common.StrToMap(user.Setting)
|
||||
return setting
|
||||
}
|
||||
|
||||
func (user *User) SetSetting(setting map[string]interface{}) {
|
||||
func (user *User) SetSetting(setting dto.UserSetting) {
|
||||
settingBytes, err := json.Marshal(setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to marshal setting: " + err.Error())
|
||||
@@ -626,7 +631,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
|
||||
}
|
||||
|
||||
// GetUserSetting gets setting from Redis first, falls back to DB if needed
|
||||
func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
|
||||
func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) {
|
||||
var setting string
|
||||
defer func() {
|
||||
// Update Redis cache asynchronously on successful DB read
|
||||
@@ -648,10 +653,12 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
|
||||
fromDB = true
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
|
||||
if err != nil {
|
||||
return map[string]interface{}{}, err
|
||||
return settingMap, err
|
||||
}
|
||||
|
||||
return common.StrToMap(setting), nil
|
||||
userBase := &UserBase{
|
||||
Setting: setting,
|
||||
}
|
||||
return userBase.GetSetting(), nil
|
||||
}
|
||||
|
||||
func IncreaseUserQuota(id int, quota int, db bool) (err error) {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -32,20 +33,15 @@ func (user *UserBase) WriteContext(c *gin.Context) {
|
||||
common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting())
|
||||
}
|
||||
|
||||
func (user *UserBase) GetSetting() map[string]interface{} {
|
||||
if user.Setting == "" {
|
||||
return nil
|
||||
func (user *UserBase) GetSetting() dto.UserSetting {
|
||||
setting := dto.UserSetting{}
|
||||
if user.Setting != "" {
|
||||
err := json.Unmarshal([]byte(user.Setting), &setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||
}
|
||||
}
|
||||
return common.StrToMap(user.Setting)
|
||||
}
|
||||
|
||||
func (user *UserBase) SetSetting(setting map[string]interface{}) {
|
||||
settingBytes, err := json.Marshal(setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to marshal setting: " + err.Error())
|
||||
return
|
||||
}
|
||||
user.Setting = string(settingBytes)
|
||||
return setting
|
||||
}
|
||||
|
||||
// getUserCacheKey returns the key for user cache
|
||||
@@ -174,11 +170,10 @@ func getUserNameCache(userId int) (string, error) {
|
||||
return cache.Username, nil
|
||||
}
|
||||
|
||||
func getUserSettingCache(userId int) (map[string]interface{}, error) {
|
||||
setting := make(map[string]interface{})
|
||||
func getUserSettingCache(userId int) (dto.UserSetting, error) {
|
||||
cache, err := GetUserCache(userId)
|
||||
if err != nil {
|
||||
return setting, err
|
||||
return dto.UserSetting{}, err
|
||||
}
|
||||
return cache.GetSetting(), nil
|
||||
}
|
||||
|
||||
@@ -206,8 +206,8 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
|
||||
func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
|
||||
var client *http.Client
|
||||
var err error
|
||||
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
|
||||
client, err = service.NewProxyHttpClient(proxyURL.(string))
|
||||
if info.ChannelSetting.Proxy != "" {
|
||||
client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -42,7 +42,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
if len(keyParts) == 0 || keyParts[0] == "" {
|
||||
return errors.New("invalid API key: authorization token is required")
|
||||
}
|
||||
if len(keyParts) > 1 {
|
||||
if keyParts[1] != "" {
|
||||
req.Set("appid", keyParts[1])
|
||||
}
|
||||
}
|
||||
req.Set("Authorization", "Bearer "+keyParts[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -278,8 +278,8 @@ func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*ht
|
||||
func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
|
||||
var client *http.Client
|
||||
var err error // 声明 err 变量
|
||||
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
|
||||
client, err = service.NewProxyHttpClient(proxyURL.(string))
|
||||
if info.ChannelSetting.Proxy != "" {
|
||||
client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
|
||||
// initialize ThinkingContentInfo when thinking_to_content is enabled
|
||||
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok && think2Content {
|
||||
if info.ChannelSetting.ThinkingToContent {
|
||||
info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
|
||||
IsFirstThinkingContent: true,
|
||||
SendLastThinkingContent: false,
|
||||
@@ -145,7 +145,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
|
||||
header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
}
|
||||
if info.ChannelType == constant.ChannelTypeOpenRouter {
|
||||
header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
|
||||
header.Set("HTTP-Referer", "https://www.newapi.ai")
|
||||
header.Set("X-Title", "New API")
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -124,12 +124,12 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
var forceFormat bool
|
||||
var thinkToContent bool
|
||||
|
||||
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
|
||||
forceFormat = forceFmt
|
||||
if info.ChannelSetting.ForceFormat {
|
||||
forceFormat = true
|
||||
}
|
||||
|
||||
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok {
|
||||
thinkToContent = think2Content
|
||||
if info.ChannelSetting.ThinkingToContent {
|
||||
thinkToContent = true
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -200,8 +200,8 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
}
|
||||
|
||||
forceFormat := false
|
||||
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
|
||||
forceFormat = forceFmt
|
||||
if info.ChannelSetting.ForceFormat {
|
||||
forceFormat = true
|
||||
}
|
||||
|
||||
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
||||
|
||||
@@ -106,8 +106,8 @@ func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (s
|
||||
|
||||
var client *http.Client
|
||||
var err error
|
||||
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
|
||||
client, err = service.NewProxyHttpClient(proxyURL.(string))
|
||||
if info.ChannelSetting.Proxy != "" {
|
||||
client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -97,9 +97,9 @@ type RelayInfo struct {
|
||||
IsFirstRequest bool
|
||||
AudioUsage bool
|
||||
ReasoningEffort string
|
||||
ChannelSetting map[string]interface{}
|
||||
ChannelSetting dto.ChannelSettings
|
||||
ParamOverride map[string]interface{}
|
||||
UserSetting map[string]interface{}
|
||||
UserSetting dto.UserSetting
|
||||
UserEmail string
|
||||
UserQuota int
|
||||
RelayFormat string
|
||||
@@ -213,7 +213,6 @@ func GenRelayInfoImage(c *gin.Context) *RelayInfo {
|
||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
||||
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
|
||||
channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting)
|
||||
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride)
|
||||
|
||||
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
|
||||
@@ -227,7 +226,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
|
||||
info := &RelayInfo{
|
||||
UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
|
||||
UserSetting: common.GetContextKeyStringMap(c, constant.ContextKeyUserSetting),
|
||||
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
|
||||
isFirstResponse: true,
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
@@ -246,12 +244,12 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||
//RecodeModelName: c.GetString("original_model"),
|
||||
IsModelMapped: false,
|
||||
ApiType: apiType,
|
||||
ApiVersion: c.GetString("api_version"),
|
||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||
Organization: c.GetString("channel_organization"),
|
||||
ChannelSetting: channelSetting,
|
||||
IsModelMapped: false,
|
||||
ApiType: apiType,
|
||||
ApiVersion: c.GetString("api_version"),
|
||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||
Organization: c.GetString("channel_organization"),
|
||||
|
||||
ChannelCreateTime: c.GetInt64("channel_create_time"),
|
||||
ParamOverride: paramOverride,
|
||||
RelayFormat: RelayFormatOpenAI,
|
||||
@@ -277,6 +275,16 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
if streamSupportedChannels[info.ChannelType] {
|
||||
info.SupportStreamOptions = true
|
||||
}
|
||||
|
||||
channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
|
||||
if ok {
|
||||
info.ChannelSetting = channelSetting
|
||||
}
|
||||
userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
|
||||
if ok {
|
||||
info.UserSetting = userSetting
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package helper
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
constant2 "one-api/constant"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
@@ -83,11 +82,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
modelRatio, success = ratio_setting.GetModelRatio(info.OriginModelName)
|
||||
if !success {
|
||||
acceptUnsetRatio := false
|
||||
if accept, ok := info.UserSetting[constant2.UserAcceptUnsetRatioModel]; ok {
|
||||
b, ok := accept.(bool)
|
||||
if ok {
|
||||
acceptUnsetRatio = b
|
||||
}
|
||||
if info.UserSetting.AcceptUnsetRatioModel {
|
||||
acceptUnsetRatio = true
|
||||
}
|
||||
if !acceptUnsetRatio {
|
||||
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
|
||||
|
||||
@@ -34,14 +34,13 @@ func RelayMidjourneyImage(c *gin.Context) {
|
||||
}
|
||||
var httpClient *http.Client
|
||||
if channel, err := model.CacheGetChannel(midjourneyTask.ChannelId); err == nil {
|
||||
if proxy, ok := channel.GetSetting()["proxy"]; ok {
|
||||
if proxyURL, ok := proxy.(string); ok && proxyURL != "" {
|
||||
if httpClient, err = service.NewProxyHttpClient(proxyURL); err != nil {
|
||||
c.JSON(400, gin.H{
|
||||
"error": "proxy_url_invalid",
|
||||
})
|
||||
return
|
||||
}
|
||||
proxy := channel.GetSetting().Proxy
|
||||
if proxy != "" {
|
||||
if httpClient, err = service.NewProxyHttpClient(proxy); err != nil {
|
||||
c.JSON(400, gin.H{
|
||||
"error": "proxy_url_invalid",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -175,7 +174,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
startTime := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
//group := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
var swapFaceRequest dto.SwapFaceRequest
|
||||
@@ -221,8 +220,17 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
|
||||
other := service.GenerateMjOtherInfo(priceData)
|
||||
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
|
||||
priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: channelId,
|
||||
ModelName: modelName,
|
||||
TokenName: tokenName,
|
||||
Quota: priceData.Quota,
|
||||
Content: logContent,
|
||||
TokenId: tokenId,
|
||||
UserQuota: userQuota,
|
||||
Group: relayInfo.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
|
||||
model.UpdateChannelUsedQuota(channelId, priceData.Quota)
|
||||
}
|
||||
@@ -363,7 +371,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
|
||||
|
||||
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
//tokenId := c.GetInt("token_id")
|
||||
//channelType := c.GetInt("channel")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
@@ -518,8 +526,17 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
|
||||
other := service.GenerateMjOtherInfo(priceData)
|
||||
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
|
||||
priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: channelId,
|
||||
ModelName: modelName,
|
||||
TokenName: tokenName,
|
||||
Quota: priceData.Quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UserQuota: userQuota,
|
||||
Group: group,
|
||||
Other: other,
|
||||
})
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
|
||||
model.UpdateChannelUsedQuota(channelId, priceData.Quota)
|
||||
}
|
||||
|
||||
@@ -540,6 +540,19 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
other["audio_input_token_count"] = audioTokens
|
||||
other["audio_input_price"] = audioInputPrice
|
||||
}
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
ModelName: logModel,
|
||||
TokenName: tokenName,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UserQuota: userQuota,
|
||||
UseTimeSeconds: int(useTimeSeconds),
|
||||
IsStream: relayInfo.IsStream,
|
||||
Group: relayInfo.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -139,8 +139,17 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
if hasUserGroupRatio {
|
||||
other["user_group_ratio"] = userGroupRatio
|
||||
}
|
||||
model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
|
||||
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.UsingGroup, other)
|
||||
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
ModelName: modelName,
|
||||
TokenName: tokenName,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UserQuota: userQuota,
|
||||
Group: relayInfo.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/uptime/status", controller.GetUptimeKumaStatus)
|
||||
apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
|
||||
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
|
||||
apiRouter.GET("/jsrt/reload", middleware.AdminAuth(), controller.ReloadJSScripts)
|
||||
apiRouter.GET("/notice", controller.GetNotice)
|
||||
apiRouter.GET("/about", controller.GetAbout)
|
||||
//apiRouter.GET("/midjourney", controller.GetMidjourney)
|
||||
|
||||
@@ -3,14 +3,21 @@ package router
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/middleware/jsrt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
|
||||
jsrtMid := jsrt.JSRuntimeMiddleware()
|
||||
if jsrtMid != nil {
|
||||
router.Use(*jsrtMid)
|
||||
}
|
||||
|
||||
SetApiRouter(router)
|
||||
SetDashboardRouter(router)
|
||||
SetRelayRouter(router)
|
||||
|
||||
@@ -12,6 +12,7 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
router.Use(middleware.CORS())
|
||||
router.Use(middleware.DecompressRequestMiddleware())
|
||||
router.Use(middleware.StatsMiddleware())
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/introduction
|
||||
modelsRouter := router.Group("/v1/models")
|
||||
modelsRouter.Use(middleware.TokenAuth())
|
||||
|
||||
15
scripts/01_utils.js
Normal file
15
scripts/01_utils.js
Normal file
@@ -0,0 +1,15 @@
|
||||
// Utility functions for JavaScript runtime
|
||||
|
||||
function logWithReq(req, message) {
|
||||
let reqPath = req.url || 'unknown path';
|
||||
console.log(`[${req.method} ${reqPath}] ${message}`);
|
||||
}
|
||||
|
||||
function safeJsonParse(str, defaultValue = null) {
|
||||
try {
|
||||
return JSON.parse(str);
|
||||
} catch (e) {
|
||||
console.error('JSON parse error:', e.message);
|
||||
return defaultValue;
|
||||
}
|
||||
}
|
||||
5
scripts/02_pre_process.js
Normal file
5
scripts/02_pre_process.js
Normal file
@@ -0,0 +1,5 @@
|
||||
// Pre-processing function for incoming requests
|
||||
|
||||
function preProcessRequest(req) {
|
||||
logWithReq(req, 'Pre-processing request');
|
||||
}
|
||||
5
scripts/03_post_process.js
Normal file
5
scripts/03_post_process.js
Normal file
@@ -0,0 +1,5 @@
|
||||
// Post-processing function for outgoing responses
|
||||
|
||||
function postProcessResponse(req, resp) {
|
||||
logWithReq(req, 'Post-processing response with: ' + resp.statusCode);
|
||||
}
|
||||
238
scripts/README.md
Normal file
238
scripts/README.md
Normal file
@@ -0,0 +1,238 @@
|
||||
# JavaScript Runtime Scripts
|
||||
|
||||
本目录包含 JavaScript Runtime 中间件使用的脚本文件。
|
||||
|
||||
## 脚本加载
|
||||
|
||||
- 系统会自动读取 `scripts/` 目录下的所有 `.js` 文件
|
||||
- 脚本按文件名字母顺序加载
|
||||
- 建议使用数字前缀来控制加载顺序(如:`01_utils.js`, `02_pre_process.js`)
|
||||
- 所有脚本会被合并到一个 JavaScript 运行时环境中
|
||||
|
||||
## 配置
|
||||
|
||||
通过环境变量配置:
|
||||
|
||||
- `JS_RUNTIME_ENABLED=true` - 启用 JavaScript Runtime
|
||||
- `JS_SCRIPT_DIR=scripts/` - 脚本目录路径
|
||||
- `JS_MAX_VM_COUNT=8` - 最大虚拟机数量
|
||||
- `JS_SCRIPT_TIMEOUT=5s` - 脚本执行超时时间
|
||||
- `JS_FETCH_TIMEOUT=10s` - HTTP 请求超时时间
|
||||
|
||||
更多的详细配置可以在 `.env.example` 文件中找到,并在实际使用时重命名为 `.env`。
|
||||
|
||||
## 必需的函数
|
||||
|
||||
脚本中必须定义以下两个函数:
|
||||
|
||||
### 1. preProcessRequest(req)
|
||||
|
||||
在请求被转发到后端 API 之前调用。
|
||||
|
||||
**参数:**
|
||||
|
||||
- `req`: 请求对象,包含 `method`, `url`, `headers`, `body` 等属性
|
||||
|
||||
**返回值:**
|
||||
返回一个对象,可包含以下属性:
|
||||
|
||||
- `block`: boolean - 是否阻止请求继续执行
|
||||
- `statusCode`: number - 阻止请求时返回的状态码
|
||||
- `message`: string - 阻止请求时返回的错误消息
|
||||
- `headers`: object - 要修改或添加的请求头
|
||||
- `body`: any - 修改后的请求体
|
||||
|
||||
### 2. postProcessResponse(req, resp)
|
||||
|
||||
在响应返回给客户端之前调用。
|
||||
|
||||
**参数:**
|
||||
|
||||
- `req`: 原始请求对象
|
||||
- `resp`: 响应对象,包含 `statusCode`, `headers`, `body` 等属性
|
||||
|
||||
**返回值:**
|
||||
返回一个对象,可包含以下属性:
|
||||
|
||||
- `statusCode`: number - 修改后的状态码
|
||||
- `headers`: object - 要修改或添加的响应头
|
||||
- `body`: string - 修改后的响应体
|
||||
|
||||
## 可用的全局对象和函数
|
||||
|
||||
- `console.log()`, `console.error()`, `console.warn()` - 日志输出
|
||||
- `JSON.parse()`, `JSON.stringify()` - JSON 处理
|
||||
- `fetch(url, options)` - HTTP 请求
|
||||
- `db` - 主数据库连接
|
||||
- `logdb` - 日志数据库连接
|
||||
- `setTimeout(fn, delay)` - 定时器
|
||||
|
||||
## 示例脚本
|
||||
|
||||
参考现有的示例脚本:
|
||||
|
||||
- `01_utils.js` - 工具函数
|
||||
- `02_pre_process.js` - 请求预处理
|
||||
- `03_post_process.js` - 响应后处理
|
||||
|
||||
## 使用示例
|
||||
|
||||
```js
|
||||
// 例子:基于数据库的速率限制
|
||||
if (req.url.includes("/v1/chat/completions")) {
|
||||
try {
|
||||
// Check recent requests from this IP
|
||||
var recentRequests = db.query(
|
||||
"SELECT COUNT(*) as count FROM logs WHERE created_at > ? AND ip = ?",
|
||||
Math.floor(Date.now() / 1000) - 60, // last minute
|
||||
req.remoteIP
|
||||
);
|
||||
if (recentRequests && recentRequests.length > 0 && recentRequests[0].count > 10) {
|
||||
console.log("速率限制 IP:", req.remoteIP);
|
||||
return {
|
||||
block: true,
|
||||
statusCode: 429,
|
||||
message: "超过速率限制"
|
||||
};
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("Ratelimit 数据库错误:", e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 例子:修改请求
|
||||
if (req.url.includes("/chat/completions")) {
|
||||
try {
|
||||
var bodyObj = req.body;
|
||||
let firstMsg = { // 需要新建一个对象,不能修改原有对象
|
||||
role: "user",
|
||||
content: "喵呜🐱~嘻嘻"
|
||||
};
|
||||
bodyObj.messages[0] = firstMsg;
|
||||
console.log("Modified first message:", JSON.stringify(firstMsg));
|
||||
console.log("Modified body:", JSON.stringify(bodyObj));
|
||||
return {
|
||||
body: bodyObj,
|
||||
headers: {
|
||||
...req.headers,
|
||||
"X-Modified-Body": "true"
|
||||
}
|
||||
};
|
||||
} catch (e) {
|
||||
console.error("Failed to modify request body:", {
|
||||
message: e.message,
|
||||
stack: e.stack,
|
||||
bodyType: typeof req.body,
|
||||
url: req.url
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 例子:读取最近一条日志,新增 jsrt 日志,并输出日志总数
|
||||
try {
|
||||
// 1. 读取最近一条日志
|
||||
var recentLogs = logdb.query(
|
||||
"SELECT id, user_id, username, content, created_at FROM logs ORDER BY id DESC LIMIT 1"
|
||||
);
|
||||
var recentLog = null;
|
||||
if (recentLogs && recentLogs.length > 0) {
|
||||
recentLog = recentLogs[0];
|
||||
console.log("最近一条日志:", JSON.stringify(recentLog));
|
||||
}
|
||||
// 2. 新增一条 jsrt 日志
|
||||
var currentTimestamp = Math.floor(Date.now() / 1000);
|
||||
var jsrtLogContent = "JSRT 预处理中间件执行 - " + req.URL + " - " + new Date().toISOString();
|
||||
var insertResult = logdb.exec(
|
||||
"INSERT INTO logs (user_id, username, created_at, type, content) VALUES (?, ?, ?, ?, ?)",
|
||||
req.UserID || 0,
|
||||
req.Username || "jsrt-system",
|
||||
currentTimestamp,
|
||||
4, // LogTypeSystem
|
||||
jsrtLogContent
|
||||
);
|
||||
if (insertResult.error) {
|
||||
console.error("插入 JSRT 日志失败:", insertResult.error);
|
||||
} else {
|
||||
console.log("成功插入 JSRT 日志,影响行数:", insertResult.rowsAffected);
|
||||
}
|
||||
// 3. 输出日志总数
|
||||
var totalLogsResult = logdb.query("SELECT COUNT(*) as total FROM logs");
|
||||
var totalLogs = 0;
|
||||
if (totalLogsResult && totalLogsResult.length > 0) {
|
||||
totalLogs = totalLogsResult[0].total;
|
||||
}
|
||||
console.log("当前日志总数:", totalLogs);
|
||||
console.log("JSRT 日志管理示例执行完成");
|
||||
} catch (e) {
|
||||
console.error("JSRT 日志管理示例执行失败:", {
|
||||
message: e.message,
|
||||
stack: e.stack,
|
||||
url: req.URL
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
// 例子:使用 fetch 调用外部 API
|
||||
if (req.url.includes("/api/uptime/status")) {
|
||||
try {
|
||||
// 使用 httpbin.org/ip 测试 fetch 功能
|
||||
var response = fetch("https://httpbin.org/ip", {
|
||||
method: "GET",
|
||||
timeout: 5, // 5秒超时
|
||||
headers: {
|
||||
"User-Agent": "JSRT/1.0"
|
||||
}
|
||||
});
|
||||
if (response.Error.length === 0) {
|
||||
// 解析响应体
|
||||
var ipData = JSON.parse(response.Body);
|
||||
// 可以根据获取到的 IP 信息进行后续处理
|
||||
if (ipData.origin) {
|
||||
console.log("外部 IP 地址:", ipData.origin);
|
||||
// 示例:记录 IP 信息到数据库
|
||||
var currentTimestamp = Math.floor(Date.now() / 1000);
|
||||
var logContent = "Fetch 示例 - 外部 IP: " + ipData.origin + " - " + new Date().toISOString();
|
||||
var insertResult = logdb.exec(
|
||||
"INSERT INTO logs (user_id, username, created_at, type, content) VALUES (?, ?, ?, ?, ?)",
|
||||
0,
|
||||
"jsrt-fetch",
|
||||
currentTimestamp,
|
||||
4, // LogTypeSystem
|
||||
logContent
|
||||
);
|
||||
if (insertResult.error) {
|
||||
console.error("记录 IP 信息失败:", insertResult.error);
|
||||
} else {
|
||||
console.log("成功记录 IP 信息到数据库");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
console.error("Fetch 失败 ", response.Status, " ", response.Error);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("Fetch 失败:", {
|
||||
message: e.message,
|
||||
stack: e.stack,
|
||||
url: req.url
|
||||
});
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 管理接口
|
||||
|
||||
### 重新加载脚本
|
||||
|
||||
```bash
|
||||
curl -X POST http://host:port/api/jsrt/reload \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization Bearer <admin_token>'
|
||||
```
|
||||
|
||||
## 故障排除
|
||||
|
||||
- 查看服务日志中的 JavaScript 相关错误信息
|
||||
- 使用 `console.log()` 调试脚本逻辑
|
||||
- 确保 JavaScript 语法正确(不支持所有 ES6+ 特性)
|
||||
@@ -209,8 +209,21 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
}
|
||||
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
|
||||
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: usage.InputTokens,
|
||||
CompletionTokens: usage.OutputTokens,
|
||||
ModelName: logModel,
|
||||
TokenName: tokenName,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UserQuota: userQuota,
|
||||
UseTimeSeconds: int(useTimeSeconds),
|
||||
IsStream: relayInfo.IsStream,
|
||||
Group: relayInfo.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
}
|
||||
|
||||
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
@@ -286,8 +299,22 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
|
||||
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
|
||||
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
ModelName: modelName,
|
||||
TokenName: tokenName,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UserQuota: userQuota,
|
||||
UseTimeSeconds: int(useTimeSeconds),
|
||||
IsStream: relayInfo.IsStream,
|
||||
Group: relayInfo.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int {
|
||||
@@ -384,8 +411,21 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
}
|
||||
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
|
||||
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
ModelName: logModel,
|
||||
TokenName: tokenName,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UserQuota: userQuota,
|
||||
UseTimeSeconds: int(useTimeSeconds),
|
||||
IsStream: relayInfo.IsStream,
|
||||
Group: relayInfo.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
}
|
||||
|
||||
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
|
||||
@@ -447,8 +487,8 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
|
||||
gopool.Go(func() {
|
||||
userSetting := relayInfo.UserSetting
|
||||
threshold := common.QuotaRemindThreshold
|
||||
if userCustomThreshold, ok := userSetting[constant.UserSettingQuotaWarningThreshold]; ok {
|
||||
threshold = int(userCustomThreshold.(float64))
|
||||
if userSetting.QuotaWarningThreshold != 0 {
|
||||
threshold = int(userSetting.QuotaWarningThreshold)
|
||||
}
|
||||
|
||||
//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
|
||||
|
||||
@@ -3,7 +3,6 @@ package service
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"strings"
|
||||
@@ -17,10 +16,10 @@ func NotifyRootUser(t string, subject string, content string) {
|
||||
}
|
||||
}
|
||||
|
||||
func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error {
|
||||
notifyType, ok := userSetting[constant.UserSettingNotifyType]
|
||||
if !ok {
|
||||
notifyType = constant.NotifyTypeEmail
|
||||
func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data dto.Notify) error {
|
||||
notifyType := userSetting.NotifyType
|
||||
if notifyType == "" {
|
||||
notifyType = dto.NotifyTypeEmail
|
||||
}
|
||||
|
||||
// Check notification limit
|
||||
@@ -34,34 +33,23 @@ func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}
|
||||
}
|
||||
|
||||
switch notifyType {
|
||||
case constant.NotifyTypeEmail:
|
||||
case dto.NotifyTypeEmail:
|
||||
// check setting email
|
||||
if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
|
||||
userEmail = settingEmail.(string)
|
||||
}
|
||||
userEmail = userSetting.NotificationEmail
|
||||
if userEmail == "" {
|
||||
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
|
||||
return nil
|
||||
}
|
||||
return sendEmailNotify(userEmail, data)
|
||||
case constant.NotifyTypeWebhook:
|
||||
webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
|
||||
if !ok {
|
||||
case dto.NotifyTypeWebhook:
|
||||
webhookURLStr := userSetting.WebhookUrl
|
||||
if webhookURLStr == "" {
|
||||
common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
|
||||
return nil
|
||||
}
|
||||
webhookURLStr, ok := webhookURL.(string)
|
||||
if !ok {
|
||||
common.SysError(fmt.Sprintf("user %d webhook url is not string type", userId))
|
||||
return nil
|
||||
}
|
||||
|
||||
// 获取 webhook secret
|
||||
var webhookSecret string
|
||||
if secret, ok := userSetting[constant.UserSettingWebhookSecret]; ok {
|
||||
webhookSecret, _ = secret.(string)
|
||||
}
|
||||
|
||||
webhookSecret := userSetting.WebhookSecret
|
||||
return SendWebhookNotify(webhookURLStr, webhookSecret, data)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -6,8 +6,11 @@ import (
|
||||
)
|
||||
|
||||
var Chats = []map[string]string{
|
||||
//{
|
||||
// "ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}",
|
||||
//},
|
||||
{
|
||||
"ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}",
|
||||
"Cherry Studio": "cherrystudio://providers/api-keys?v=1&data={cherryConfig}",
|
||||
},
|
||||
{
|
||||
"Lobe Chat 官方示例": "https://chat-preview.lobehub.com/?settings={\"keyVaults\":{\"openai\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\"}}}",
|
||||
|
||||
@@ -1461,9 +1461,9 @@ const ChannelsTable = () => {
|
||||
|
||||
const fixChannelsAbilities = async () => {
|
||||
const res = await API.post(`/api/channel/fix`);
|
||||
const { success, message, data } = res.data;
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
showSuccess(t('已修复 ${data} 个通道!').replace('${data}', data));
|
||||
showSuccess(t('已修复 ${success} 个通道,失败 ${fails} 个通道。').replace('${success}', data.success).replace('${fails}', data.fails));
|
||||
await refresh();
|
||||
} else {
|
||||
showError(message);
|
||||
|
||||
@@ -432,9 +432,22 @@ const TokensTable = () => {
|
||||
if (serverAddress === '') {
|
||||
serverAddress = window.location.origin;
|
||||
}
|
||||
let encodedServerAddress = encodeURIComponent(serverAddress);
|
||||
url = url.replaceAll('{address}', encodedServerAddress);
|
||||
url = url.replaceAll('{key}', 'sk-' + record.key);
|
||||
if (url.includes('{cherryConfig}') === true) {
|
||||
let cherryConfig = {
|
||||
id: 'new-api',
|
||||
baseUrl: serverAddress,
|
||||
apiKey: 'sk-' + record.key,
|
||||
}
|
||||
// 替换 {cherryConfig} 为base64编码的JSON字符串
|
||||
let encodedConfig = encodeURIComponent(
|
||||
btoa(JSON.stringify(cherryConfig))
|
||||
);
|
||||
url = url.replaceAll('{cherryConfig}', encodedConfig);
|
||||
} else {
|
||||
let encodedServerAddress = encodeURIComponent(serverAddress);
|
||||
url = url.replaceAll('{address}', encodedServerAddress);
|
||||
url = url.replaceAll('{key}', 'sk-' + record.key);
|
||||
}
|
||||
|
||||
window.open(url, '_blank');
|
||||
};
|
||||
|
||||
@@ -240,7 +240,7 @@ const EditChannel = (props) => {
|
||||
if (isEdit) {
|
||||
// 如果是编辑模式,使用已有的channel id获取模型列表
|
||||
const res = await API.get('/api/channel/fetch_models/' + channelId);
|
||||
if (res.data && res.data?.success) {
|
||||
if (res.data && res.data.success) {
|
||||
models.push(...res.data.data);
|
||||
} else {
|
||||
err = true;
|
||||
|
||||
Reference in New Issue
Block a user