diff --git a/.env.example b/.env.example index 3d2d3b7ad..99eda4025 100644 --- a/.env.example +++ b/.env.example @@ -74,11 +74,16 @@ # 如果是主节点则为master # NODE_TYPE=master -# JavaScript 运行时配置 -# 是否启用(默认:true) -# JS_RUNTIME_ENABLED=true -# 脚本文件夹(可选,默认 scripts/) -# JS_SCRIPTS_DIR=./custom_scripts -# 启用调试日志(可选) -# JS_RUNTIME_DEBUG=true + +# JavaScript 运行时配置 +# 是否启用(默认:false) +# JS_RUNTIME_ENABLED=true +# 最大虚拟机数量(默认:8) +# JS_MAX_VM_COUNT= +# 运行超时时间(单位:秒,默认:5) +# JS_SCRIPT_TIMEOUT= +# 预处理脚本路径(默认:scripts/pre_process.js) +# JS_PREPROCESS_SCRIPT_PATH= +# 后处理脚本路径(默认:scripts/post_process.js) +# JS_POSTPROCESS_SCRIPT_PATH= diff --git a/middleware/js_rt.go b/middleware/js_rt.go index b8c3216d0..c9b12c91a 100644 --- a/middleware/js_rt.go +++ b/middleware/js_rt.go @@ -5,11 +5,13 @@ import ( "encoding/json" "fmt" "io" + "maps" "net/http" "net/url" "one-api/common" "one-api/model" "os" + "strconv" "strings" "sync" "time" @@ -19,45 +21,118 @@ import ( "gorm.io/gorm" ) -type JSRuntime struct { - vm *goja.Runtime - mu sync.RWMutex +/// Runtime 配置 +type JSRuntimeConfig struct { + Enabled bool `json:"enabled"` + MaxVMCount int `json:"max_vm_count"` + ScriptTimeout time.Duration `json:"script_timeout"` + PreScriptPath string `json:"pre_script_path"` + PostScriptPath string `json:"post_script_path"` } +/// 池化 +type JSRuntimePool struct { + pool chan *goja.Runtime + maxSize int + createFunc func() *goja.Runtime + scripts map[string]string + mu sync.RWMutex +} + +/// 上下文 type JSContext struct { Method string `json:"method"` URL string `json:"url"` Headers map[string]string `json:"headers"` + // 可能是string、[]byte、map等 Body any `json:"body"` - Query map[string]string `json:"query"` - Params map[string]string `json:"params"` UserAgent string `json:"userAgent"` RemoteIP string `json:"remoteIP"` - ContentType string `json:"contentType"` Extra map[string]any `json:"extra"` } +type JSResponse struct { + StatusCode int `json:"statusCode"` + Headers map[string]string `json:"headers"` + Body string `json:"body"` +} + +type JSDatabase struct { + db *gorm.DB +} + +type responseWriter struct { + gin.ResponseWriter + body *bytes.Buffer + statusCode int + headerMap http.Header + written bool +} + +var ( + jsRuntimePool *JSRuntimePool + jsPoolOnce sync.Once + jsConfig = JSRuntimeConfig{} +) + +const ( + defaultPreScriptPath = "scripts/pre_process.js" + defaultPostScriptPath = "scripts/post_process.js" + defaultScriptTimeout = 5 * time.Second + defaultMaxVMCount = 8 +) + +func init() { + if enabled := os.Getenv("JS_RUNTIME_ENABLED"); enabled != "" { + jsConfig.Enabled = enabled == "true" + } + + if maxCount := os.Getenv("JS_MAX_VM_COUNT"); maxCount != "" { + if count, err := strconv.Atoi(maxCount); err == nil && count > 0 { + jsConfig.MaxVMCount = count + } + } else { + jsConfig.MaxVMCount = defaultMaxVMCount + } + + if timeout := os.Getenv("JS_SCRIPT_TIMEOUT"); timeout != "" { + if t, err := time.ParseDuration(timeout + "s"); err == nil && t > 0 { + jsConfig.ScriptTimeout = t + } + } else { + jsConfig.ScriptTimeout = defaultScriptTimeout + } + + jsConfig.PreScriptPath = os.Getenv("JS_PREPROCESS_SCRIPT_PATH") + if jsConfig.PreScriptPath == "" { + jsConfig.PreScriptPath = defaultPreScriptPath + } + + jsConfig.PostScriptPath = os.Getenv("JS_POSTPROCESS_SCRIPT_PATH") + if jsConfig.PostScriptPath == "" { + jsConfig.PostScriptPath = defaultPostScriptPath + } +} + func parseBodyByType(bodyBytes []byte, contentType string) any { if len(bodyBytes) == 0 { return "" } bodyStr := string(bodyBytes) + contentLower := strings.ToLower(contentType) - // 根据 Content-Type 判断 switch { - case strings.Contains(contentType, "application/json"): + case strings.Contains(contentLower, "application/json"): var jsonObj any if err := json.Unmarshal(bodyBytes, &jsonObj); err == nil { return jsonObj } - return bodyStr // JSON 解析失败时返回字符串 + return bodyStr - case strings.Contains(contentType, "application/x-www-form-urlencoded"): - // 解析为 map[string]string - values, err := url.ParseQuery(bodyStr) - if err == nil { - result := make(map[string]string) + case strings.Contains(contentLower, "application/x-www-form-urlencoded"): + if values, err := url.ParseQuery(bodyStr); err == nil { + result := make(map[string]string, len(values)) for k, v := range values { if len(v) > 0 { result[k] = v[0] @@ -67,24 +142,22 @@ func parseBodyByType(bodyBytes []byte, contentType string) any { } return bodyStr - case strings.Contains(contentType, "multipart/form-data"): - // multipart 数据保持为字节数组,JS 中需要特殊处理 + case strings.Contains(contentLower, "multipart/form-data"): return bodyBytes - case strings.Contains(contentType, "text/"): - // 文本类型返回字符串 + case strings.Contains(contentLower, "text/"): return bodyStr default: - // 尝试 JSON 解析 + // 尝试JSON解析 var jsonObj any - if err := json.Unmarshal(bodyBytes, &jsonObj); err == nil { + if json.Unmarshal(bodyBytes, &jsonObj) == nil { return jsonObj } - // 检查是否是 URL encoded + // 尝试form解析 if values, err := url.ParseQuery(bodyStr); err == nil && len(values) > 0 { - result := make(map[string]string) + result := make(map[string]string, len(values)) for k, v := range values { if len(v) > 0 { result[k] = v[0] @@ -93,38 +166,10 @@ func parseBodyByType(bodyBytes []byte, contentType string) any { return result } - // 二进制数据或未知格式 - if isBinary(bodyBytes) { - return bodyBytes - } - return bodyStr } } -// 检查是否为二进制数据 -func isBinary(data []byte) bool { - if len(data) == 0 { - return false - } - - // 检查前 512 字节中是否包含控制字符(除了常见的换行符等) - checkLen := min(len(data), 512) - - for i := range checkLen { - b := data[i] - // 控制字符但不是常见的文本字符 - if b < 32 && b != 9 && b != 10 && b != 13 { - return true - } - // 非 UTF-8 字符 - if b > 127 { - return true - } - } - return false -} - func createJSContext(c *gin.Context) *JSContext { var bodyBytes []byte if c.Request != nil && c.Request.Body != nil { @@ -142,22 +187,6 @@ func createJSContext(c *gin.Context) *JSContext { } } - // query parameters map - query := make(map[string]string) - if c.Request != nil && c.Request.URL != nil { - for key, values := range c.Request.URL.Query() { - if len(values) > 0 { - query[key] = values[0] - } - } - } - - // path parameters map - params := make(map[string]string) - for _, param := range c.Params { - params[param.Key] = param.Value - } - method := "" url := "" userAgent := "" @@ -177,7 +206,6 @@ func createJSContext(c *gin.Context) *JSContext { remoteIP = c.ClientIP() } - // 智能解析 body parsedBody := parseBodyByType(bodyBytes, contentType) return &JSContext{ @@ -185,48 +213,58 @@ func createJSContext(c *gin.Context) *JSContext { URL: url, Headers: headers, Body: parsedBody, - Query: query, - Params: params, UserAgent: userAgent, RemoteIP: remoteIP, - ContentType: contentType, Extra: make(map[string]any), } } -type JSResponse struct { - StatusCode int `json:"statusCode"` - Headers map[string]string `json:"headers"` - Body string `json:"body"` -} +func NewJSRuntimePool(maxSize int) *JSRuntimePool { + pool := &JSRuntimePool{ + pool: make(chan *goja.Runtime, maxSize), + maxSize: maxSize, + scripts: make(map[string]string), + } -type JSDatabase struct { - db *gorm.DB -} + pool.createFunc = func() *goja.Runtime { + vm := goja.New() + pool.setupGlobals(vm) + pool.loadScripts(vm) + return vm + } -var ( - jsRuntime *JSRuntime - jsRuntimeOnce sync.Once -) - -func initJSRuntime() *JSRuntime { - jsRuntimeOnce.Do(func() { - jsRuntime = &JSRuntime{ - vm: goja.New(), + // 预创建 + preCreate := min(maxSize/2, 4) + for range preCreate { + select { + case pool.pool <- pool.createFunc(): + default: } - jsRuntime.setupGlobals() - jsRuntime.loadScripts() - common.SysLog("JavaScript runtime initialized successfully") - }) - return jsRuntime + } + + return pool } -func (js *JSRuntime) setupGlobals() { - js.mu.Lock() - defer js.mu.Unlock() +func (p *JSRuntimePool) Get() *goja.Runtime { + select { + case vm := <-p.pool: + return vm + default: + return p.createFunc() + } +} +func (p *JSRuntimePool) Put(vm *goja.Runtime) { + select { + case p.pool <- vm: + default: + // 池满,丢弃VM让GC回收 + } +} + +func (p *JSRuntimePool) setupGlobals(vm *goja.Runtime) { // console - console := js.vm.NewObject() + console := vm.NewObject() console.Set("log", func(args ...any) { var strs []string for _, arg := range args { @@ -241,31 +279,88 @@ func (js *JSRuntime) setupGlobals() { } common.SysError("JS: " + strings.Join(strs, " ")) }) - js.vm.Set("console", console) + vm.Set("console", console) // JSON - jsonObj := js.vm.NewObject() + jsonObj := vm.NewObject() jsonObj.Set("parse", func(str string) any { var result any err := json.Unmarshal([]byte(str), &result) if err != nil { - panic(js.vm.ToValue(err.Error())) + panic(vm.ToValue(err.Error())) } return result }) jsonObj.Set("stringify", func(obj any) string { data, err := json.Marshal(obj) if err != nil { - panic(js.vm.ToValue(err.Error())) + panic(vm.ToValue(err.Error())) } return string(data) }) - js.vm.Set("JSON", jsonObj) + vm.Set("JSON", jsonObj) - js.vm.Set("db", &JSDatabase{db: model.DB}) + vm.Set("db", &JSDatabase{db: model.DB}) +} + +func (p *JSRuntimePool) loadScripts(vm *goja.Runtime) { + p.mu.RLock() + defer p.mu.RUnlock() + + // 加载预处理脚本 + if script, exists := p.scripts["pre"]; exists { + if _, err := vm.RunString(script); err != nil { + common.SysError("Failed to load pre_process.js: " + err.Error()) + } + } else if preScript, err := os.ReadFile(jsConfig.PreScriptPath); err == nil { + p.scripts["pre"] = string(preScript) + if _, err = vm.RunString(string(preScript)); err != nil { + common.SysError("Failed to load pre_process.js: " + err.Error()) + } else { + common.SysLog("Loaded pre_process.js") + } + } + + // 加载后处理脚本 + if script, exists := p.scripts["post"]; exists { + if _, err := vm.RunString(script); err != nil { + common.SysError("Failed to load post_process.js: " + err.Error()) + } + } else if postScript, err := os.ReadFile(jsConfig.PostScriptPath); err == nil { + p.scripts["post"] = string(postScript) + if _, err = vm.RunString(string(postScript)); err != nil { + common.SysError("Failed to load post_process.js: " + err.Error()) + } else { + common.SysLog("Loaded post_process.js") + } + } +} + +func (p *JSRuntimePool) ReloadScripts() { + p.mu.Lock() + defer p.mu.Unlock() + + // 清空缓存的脚本 + p.scripts = make(map[string]string) + + // 清空VM池,强制重新创建 + for { + select { + case <-p.pool: + default: + goto done + } + } +done: + common.SysLog("JavaScript scripts reloaded") } func (jsdb *JSDatabase) Query(sql string, args ...any) []map[string]any { + if jsdb.db == nil { + common.SysError("JS DB is nil") + return nil + } + rows, err := jsdb.db.Raw(sql, args...).Rows() if err != nil { common.SysError("JS DB Query Error: " + err.Error()) @@ -279,7 +374,7 @@ func (jsdb *JSDatabase) Query(sql string, args ...any) []map[string]any { return nil } - var results []map[string]any + results := make([]map[string]any, 0, 100) for rows.Next() { values := make([]any, len(columns)) valuePtrs := make([]any, len(columns)) @@ -292,7 +387,7 @@ func (jsdb *JSDatabase) Query(sql string, args ...any) []map[string]any { continue } - row := make(map[string]any) + row := make(map[string]any, len(columns)) for i, col := range columns { val := values[i] if b, ok := val.([]byte); ok { @@ -308,6 +403,13 @@ func (jsdb *JSDatabase) Query(sql string, args ...any) []map[string]any { } func (jsdb *JSDatabase) Exec(sql string, args ...any) map[string]any { + if jsdb.db == nil { + return map[string]any{ + "rowsAffected": int64(0), + "error": "database is nil", + } + } + result := jsdb.db.Exec(sql, args...) return map[string]any{ "rowsAffected": result.RowsAffected, @@ -315,37 +417,14 @@ func (jsdb *JSDatabase) Exec(sql string, args ...any) map[string]any { } } -func (js *JSRuntime) loadScripts() { - // 加载预处理脚本 - if preScript, err := os.ReadFile("scripts/pre_process.js"); err == nil { - js.mu.Lock() - _, err = js.vm.RunString(string(preScript)) - js.mu.Unlock() - if err != nil { - common.SysError("Failed to load pre_process.js: " + err.Error()) - } else { - common.SysLog("Loaded pre_process.js") - } - } - - // 加载后处理脚本 - if postScript, err := os.ReadFile("scripts/post_process.js"); err == nil { - js.mu.Lock() - _, err = js.vm.RunString(string(postScript)) - js.mu.Unlock() - if err != nil { - common.SysError("Failed to load post_process.js: " + err.Error()) - } else { - common.SysLog("Loaded post_process.js") - } - } +func initJSRuntimePool() *JSRuntimePool { + jsPoolOnce.Do(func() { + jsRuntimePool = NewJSRuntimePool(jsConfig.MaxVMCount) + common.SysLog("JavaScript runtime pool initialized successfully") + }) + return jsRuntimePool } -func (js *JSRuntime) ReloadScripts() { - js.loadScripts() -} - -// validateGinContext checks if the gin context is properly initialized func validateGinContext(c *gin.Context) error { if c == nil { return fmt.Errorf("gin context is nil") @@ -356,18 +435,18 @@ func validateGinContext(c *gin.Context) error { return nil } -func (js *JSRuntime) PreProcessRequest(c *gin.Context) error { +func (p *JSRuntimePool) PreProcessRequest(c *gin.Context) error { if err := validateGinContext(c); err != nil { common.SysError("JS PreProcess Validation Error: " + err.Error()) return err } - js.mu.RLock() - preProcessFunc := js.vm.Get("preProcessRequest") - js.mu.RUnlock() + vm := p.Get() + defer p.Put(vm) + preProcessFunc := vm.Get("preProcessRequest") if preProcessFunc == nil || goja.IsUndefined(preProcessFunc) { - return nil // 没有预处理函数 + return nil } jsCtx := createJSContext(c) @@ -375,16 +454,39 @@ func (js *JSRuntime) PreProcessRequest(c *gin.Context) error { return fmt.Errorf("failed to create JS context") } - js.mu.Lock() - defer js.mu.Unlock() - - js.vm.Set("ctx", jsCtx) - fn, ok := goja.AssertFunction(preProcessFunc) - if !ok { - return fmt.Errorf("preProcessRequest is not a function") + type jsResult struct { + result goja.Value + err error } - result, err := fn(goja.Undefined(), js.vm.ToValue(jsCtx)) + resultChan := make(chan jsResult, 1) + go func() { + defer func() { + if r := recover(); r != nil { + resultChan <- jsResult{err: fmt.Errorf("JS panic: %v", r)} + } + }() + + vm.Set("ctx", jsCtx) + fn, ok := goja.AssertFunction(preProcessFunc) + if !ok { + resultChan <- jsResult{err: fmt.Errorf("preProcessRequest is not a function")} + return + } + + result, err := fn(goja.Undefined(), vm.ToValue(jsCtx)) + resultChan <- jsResult{result: result, err: err} + }() + + var err error + var result goja.Value + select { + case res := <-resultChan: + result, err = res.result, res.err + // 超时控制 + case <-time.After(jsConfig.ScriptTimeout): + return fmt.Errorf("JS preProcess timeout after %v", jsConfig.ScriptTimeout) + } if err != nil { common.SysError("JS PreProcess Error: " + err.Error()) @@ -413,7 +515,6 @@ func (js *JSRuntime) PreProcessRequest(c *gin.Context) error { } else { common.SysError("JS PreProcess JSON Marshal Error: " + err.Error()) } - default: common.SysError("JS PreProcess Unsupported Body Type: " + fmt.Sprintf("%T", newBody)) } @@ -458,18 +559,18 @@ func (js *JSRuntime) PreProcessRequest(c *gin.Context) error { return nil } -func (js *JSRuntime) PostProcessResponse(c *gin.Context, statusCode int, body []byte) (int, []byte, error) { +func (p *JSRuntimePool) PostProcessResponse(c *gin.Context, statusCode int, body []byte) (int, []byte, error) { if err := validateGinContext(c); err != nil { common.SysError("JS PostProcess Validation Error: " + err.Error()) return statusCode, body, err } - js.mu.RLock() - postProcessFunc := js.vm.Get("postProcessResponse") - js.mu.RUnlock() + vm := p.Get() + defer p.Put(vm) + postProcessFunc := vm.Get("postProcessResponse") if postProcessFunc == nil || goja.IsUndefined(postProcessFunc) { - return statusCode, body, nil // 没有后处理 + return statusCode, body, nil } jsCtx := createJSContext(c) @@ -492,15 +593,40 @@ func (js *JSRuntime) PostProcessResponse(c *gin.Context, statusCode int, body [] } } - js.mu.Lock() - defer js.mu.Unlock() - - js.vm.Set("ctx", jsCtx) - fn, ok := goja.AssertFunction(postProcessFunc) - if !ok { - return statusCode, body, fmt.Errorf("postProcessResponse is not a function") + type jsResult struct { + result goja.Value + err error + } + + resultChan := make(chan jsResult, 1) + go func() { + defer func() { + if r := recover(); r != nil { + resultChan <- jsResult{err: fmt.Errorf("JS panic: %v", r)} + } + }() + + vm.Set("ctx", jsCtx) + fn, ok := goja.AssertFunction(postProcessFunc) + if !ok { + resultChan <- jsResult{err: fmt.Errorf("postProcessResponse is not a function")} + return + } + + result, err := fn(goja.Undefined(), vm.ToValue(jsCtx), vm.ToValue(jsResponse)) + resultChan <- jsResult{result: result, err: err} + }() + + var result goja.Value + var err error + + select { + case res := <-resultChan: + result, err = res.result, res.err + // 超时控制 + case <-time.After(jsConfig.ScriptTimeout): + return statusCode, body, fmt.Errorf("JS postProcess timeout after %v", jsConfig.ScriptTimeout) } - result, err := fn(goja.Undefined(), js.vm.ToValue(jsCtx), js.vm.ToValue(jsResponse)) if err != nil { common.SysError("JS PostProcess Error: " + err.Error()) @@ -538,54 +664,102 @@ func (js *JSRuntime) PostProcessResponse(c *gin.Context, statusCode int, body [] return statusCode, body, nil } +func (p *JSRuntimePool) hasPostProcessFunction() bool { + vm := p.Get() + defer p.Put(vm) + postProcessFunc := vm.Get("postProcessResponse") + return postProcessFunc != nil && !goja.IsUndefined(postProcessFunc) +} + +func newResponseWriter(w gin.ResponseWriter) *responseWriter { + return &responseWriter{ + ResponseWriter: w, + body: &bytes.Buffer{}, + statusCode: 200, + headerMap: make(http.Header), + written: false, + } +} + +func (w *responseWriter) Write(data []byte) (int, error) { + if !w.written { + w.WriteHeader(200) + } + return w.body.Write(data) +} + +func (w *responseWriter) WriteString(s string) (int, error) { + if !w.written { + w.WriteHeader(200) + } + return w.body.WriteString(s) +} + +func (w *responseWriter) WriteHeader(statusCode int) { + if w.written { + return + } + w.statusCode = statusCode + w.written = true + + maps.Copy(w.headerMap, w.ResponseWriter.Header()) +} + +func (w *responseWriter) Header() http.Header { + if w.headerMap == nil { + w.headerMap = make(http.Header) + } + return w.headerMap +} + func JSRuntimeMiddleware() gin.HandlerFunc { - if os.Getenv("JS_RUNTIME_ENABLED") != "true" { + if !jsConfig.Enabled { return func(c *gin.Context) { c.Next() } } - runtime := initJSRuntime() + pool := initJSRuntimePool() return func(c *gin.Context) { start := time.Now() // 预处理 common.SysLog("JS Runtime PreProcessing Request: " + c.Request.Method + " " + c.Request.URL.String()) - if err := runtime.PreProcessRequest(c); err != nil { - // 如果预处理返回错误,说明请求被阻止 + if err := pool.PreProcessRequest(c); err != nil { common.SysError("JS Runtime PreProcess Error: " + err.Error()) return } common.SysLog("JS Runtime PreProcessing Completed") // 后处理 - if runtime.hasPostProcessFunction() { + if pool.hasPostProcessFunction() { common.SysLog("JS Runtime PostProcessing Response") - writer := &responseWriter{ - ResponseWriter: c.Writer, - body: &bytes.Buffer{}, - statusCode: 200, // 默认状态码 - headers: make(map[string]string), - } + writer := newResponseWriter(c.Writer) c.Writer = writer c.Next() // 后处理响应 if writer.body.Len() > 0 { - statusCode, body, err := runtime.PostProcessResponse(c, writer.statusCode, writer.body.Bytes()) + statusCode, body, err := pool.PostProcessResponse(c, writer.statusCode, writer.body.Bytes()) if err == nil { // 更新响应 c.Writer = writer.ResponseWriter - // Clear any existing content-length header to let Gin handle it c.Writer.Header().Del("Content-Length") + + // 设置修改后的headers + for k, v := range writer.headerMap { + for _, value := range v { + c.Writer.Header().Add(k, value) + } + } + c.Status(statusCode) c.Writer.Write(body) common.SysLog(fmt.Sprintf("JS Runtime PostProcessing Completed with status %d", statusCode)) } else { // 出错时返回原始响应 c.Writer = writer.ResponseWriter - // Clear any existing content-length header to let Gin handle it c.Writer.Header().Del("Content-Length") c.Status(writer.statusCode) c.Writer.Write(writer.body.Bytes()) @@ -609,52 +783,9 @@ func JSRuntimeMiddleware() gin.HandlerFunc { } } -func (js *JSRuntime) hasPostProcessFunction() bool { - js.mu.RLock() - defer js.mu.RUnlock() - postProcessFunc := js.vm.Get("postProcessResponse") - return postProcessFunc != nil && !goja.IsUndefined(postProcessFunc) -} - -type responseWriter struct { - gin.ResponseWriter - body *bytes.Buffer - statusCode int - written bool - headers map[string]string -} - -func (w *responseWriter) Write(data []byte) (int, error) { - if !w.written { - w.statusCode = 200 - w.written = true - } - w.body.Write(data) - return len(data), nil -} - -func (w *responseWriter) WriteString(s string) (int, error) { - if !w.written { - w.statusCode = 200 - w.written = true - } - w.body.WriteString(s) - return len(s), nil -} - -func (w *responseWriter) WriteHeader(statusCode int) { - w.statusCode = statusCode - w.written = true - // 不立即调用原始的 WriteHeader,等后处理完成后再调用 -} - -func (w *responseWriter) Header() http.Header { - return w.ResponseWriter.Header() -} - func ReloadJSScripts() { - if jsRuntime != nil { - jsRuntime.ReloadScripts() + if jsRuntimePool != nil { + jsRuntimePool.ReloadScripts() common.SysLog("JavaScript scripts reloaded") } } diff --git a/router/main.go b/router/main.go index 235764270..2dca96c38 100644 --- a/router/main.go +++ b/router/main.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/middleware" "os" "strings" @@ -12,6 +13,7 @@ import ( ) func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { + router.Use(middleware.JSRuntimeMiddleware()) SetApiRouter(router) SetDashboardRouter(router) SetRelayRouter(router) diff --git a/router/relay-router.go b/router/relay-router.go index 2b73eefbc..8d15d975a 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -12,7 +12,6 @@ func SetRelayRouter(router *gin.Engine) { router.Use(middleware.CORS()) router.Use(middleware.DecompressRequestMiddleware()) router.Use(middleware.StatsMiddleware()) - router.Use(middleware.JSRuntimeMiddleware()) // 放在最后 // https://platform.openai.com/docs/api-reference/introduction modelsRouter := router.Group("/v1/models")