Files
new-api/middleware/jsrt/jsrt.go
lollipopkit🏳️‍⚧️ fd51f71e0f (jsrt) opt.: struct
2025-07-15 22:19:02 +08:00

517 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package jsrt
import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/model"
"os"
"strings"
"sync"
"time"
"github.com/dop251/goja"
"github.com/gin-gonic/gin"
)
/// 池化
type JSRuntimePool struct {
pool chan *goja.Runtime
maxSize int
createFunc func() *goja.Runtime
scripts map[string]string
mu sync.RWMutex
httpClient *http.Client
}
var (
jsRuntimePool *JSRuntimePool
jsPoolOnce sync.Once
)
func NewJSRuntimePool(maxSize int) *JSRuntimePool {
// 创建HTTP客户端
httpClient := &http.Client{
Timeout: jsConfig.FetchTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: false,
},
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
},
}
pool := &JSRuntimePool{
pool: make(chan *goja.Runtime, maxSize),
maxSize: maxSize,
scripts: make(map[string]string),
httpClient: httpClient,
}
pool.createFunc = func() *goja.Runtime {
vm := goja.New()
pool.setupGlobals(vm)
pool.loadScripts(vm)
return vm
}
// 预创建VM
preCreate := min(maxSize/2, 4)
for range preCreate {
select {
case pool.pool <- pool.createFunc():
default:
}
}
return pool
}
func (p *JSRuntimePool) Get() *goja.Runtime {
select {
case vm := <-p.pool:
return vm
default:
return p.createFunc()
}
}
func (p *JSRuntimePool) Put(vm *goja.Runtime) {
if vm == nil {
return
}
select {
case p.pool <- vm:
default:
// 池满丢弃VM让GC回收
}
}
func (p *JSRuntimePool) setupGlobals(vm *goja.Runtime) {
// console
console := vm.NewObject()
console.Set("log", func(args ...any) {
var strs []string
for _, arg := range args {
strs = append(strs, fmt.Sprintf("%v", arg))
}
common.SysLog("JS: " + strings.Join(strs, " "))
})
console.Set("error", func(args ...any) {
var strs []string
for _, arg := range args {
strs = append(strs, fmt.Sprintf("%v", arg))
}
common.SysError("JS: " + strings.Join(strs, " "))
})
console.Set("warn", func(args ...any) {
var strs []string
for _, arg := range args {
strs = append(strs, fmt.Sprintf("%v", arg))
}
common.SysError("JS WARN: " + strings.Join(strs, " "))
})
vm.Set("console", console)
// JSON
jsonObj := vm.NewObject()
jsonObj.Set("parse", func(str string) any {
var result any
err := json.Unmarshal([]byte(str), &result)
if err != nil {
panic(vm.ToValue(err.Error()))
}
return result
})
jsonObj.Set("stringify", func(obj any) string {
data, err := json.Marshal(obj)
if err != nil {
panic(vm.ToValue(err.Error()))
}
return string(data)
})
vm.Set("JSON", jsonObj)
// fetch 实现
vm.Set("fetch", func(url string, options ...any) *JSFetchResponse {
return p.fetch(url, options...)
})
// 数据库
vm.Set("db", &JSDatabase{db: model.DB})
// 定时器
vm.Set("setTimeout", func(fn func(), delay int) {
go func() {
time.Sleep(time.Duration(delay) * time.Millisecond)
fn()
}()
})
}
func (p *JSRuntimePool) loadScripts(vm *goja.Runtime) {
p.mu.RLock()
defer p.mu.RUnlock()
// 加载预处理脚本
if script, exists := p.scripts["pre"]; exists {
if _, err := vm.RunString(script); err != nil {
common.SysError("Failed to load pre_process.js: " + err.Error())
}
} else if preScript, err := os.ReadFile(jsConfig.PreScriptPath); err == nil {
p.scripts["pre"] = string(preScript)
if _, err = vm.RunString(string(preScript)); err != nil {
common.SysError("Failed to load pre_process.js: " + err.Error())
} else {
common.SysLog("Loaded pre_process.js")
}
}
// 加载后处理脚本
if script, exists := p.scripts["post"]; exists {
if _, err := vm.RunString(script); err != nil {
common.SysError("Failed to load post_process.js: " + err.Error())
}
} else if postScript, err := os.ReadFile(jsConfig.PostScriptPath); err == nil {
p.scripts["post"] = string(postScript)
if _, err = vm.RunString(string(postScript)); err != nil {
common.SysError("Failed to load post_process.js: " + err.Error())
} else {
common.SysLog("Loaded post_process.js")
}
}
}
func (p *JSRuntimePool) ReloadScripts() {
p.mu.Lock()
defer p.mu.Unlock()
// 清空缓存的脚本
p.scripts = make(map[string]string)
// 清空VM池强制重新创建
for {
select {
case <-p.pool:
default:
goto done
}
}
done:
common.SysLog("JavaScript scripts reloaded")
}
func initJSRuntimePool() *JSRuntimePool {
jsPoolOnce.Do(func() {
jsRuntimePool = NewJSRuntimePool(jsConfig.MaxVMCount)
common.SysLog("JavaScript runtime pool initialized successfully")
})
return jsRuntimePool
}
func validateGinContext(c *gin.Context) error {
if c == nil {
return fmt.Errorf("gin context is nil")
}
if c.Request == nil {
return fmt.Errorf("gin context request is nil")
}
return nil
}
func (p *JSRuntimePool) executeWithTimeout(vm *goja.Runtime, fn func() (goja.Value, error)) (goja.Value, error) {
type result struct {
value goja.Value
err error
}
resultChan := make(chan result, 1)
go func() {
defer func() {
if r := recover(); r != nil {
resultChan <- result{err: fmt.Errorf("JS panic: %v", r)}
}
}()
value, err := fn()
resultChan <- result{value: value, err: err}
}()
select {
case res := <-resultChan:
return res.value, res.err
case <-time.After(jsConfig.ScriptTimeout):
return nil, fmt.Errorf("script execution timeout after %v", jsConfig.ScriptTimeout)
}
}
func (p *JSRuntimePool) PreProcessRequest(c *gin.Context) error {
if err := validateGinContext(c); err != nil {
common.SysError("JS PreProcess Validation Error: " + err.Error())
return err
}
vm := p.Get()
defer p.Put(vm)
preProcessFunc := vm.Get("preProcessRequest")
if preProcessFunc == nil || goja.IsUndefined(preProcessFunc) {
return nil
}
jsCtx := createJSContext(c)
if jsCtx == nil {
return fmt.Errorf("failed to create JS context")
}
result, err := p.executeWithTimeout(vm, func() (goja.Value, error) {
vm.Set("ctx", jsCtx)
fn, ok := goja.AssertFunction(preProcessFunc)
if !ok {
return nil, fmt.Errorf("preProcessRequest is not a function")
}
return fn(goja.Undefined(), vm.ToValue(jsCtx))
})
if err != nil {
common.SysError("JS PreProcess Error: " + err.Error())
return err
}
// 处理返回结果
if result != nil && !goja.IsUndefined(result) {
resultObj := result.Export()
if resultMap, ok := resultObj.(map[string]any); ok {
// 是否修改请求
if newBody, exists := resultMap["body"]; exists {
switch v := newBody.(type) {
case string:
c.Request.Body = io.NopCloser(strings.NewReader(v))
c.Request.ContentLength = int64(len(v))
case []byte:
c.Request.Body = io.NopCloser(bytes.NewBuffer(v))
c.Request.ContentLength = int64(len(v))
case map[string]any:
bodyBytes, err := json.Marshal(v)
if err == nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
c.Request.ContentLength = int64(len(bodyBytes))
c.Request.Header.Set("Content-Type", "application/json")
} else {
common.SysError("JS PreProcess JSON Marshal Error: " + err.Error())
}
default:
common.SysError("JS PreProcess Unsupported Body Type: " + fmt.Sprintf("%T", newBody))
}
}
// 是否修改 headers
if newHeaders, exists := resultMap["headers"]; exists {
if headersMap, ok := newHeaders.(map[string]any); ok {
for key, value := range headersMap {
if valueStr, ok := value.(string); ok {
c.Request.Header.Set(key, valueStr)
}
}
}
}
// 是否阻止请求
if block, exists := resultMap["block"]; exists {
if blockBool, ok := block.(bool); ok && blockBool {
status := http.StatusForbidden
if statusCode, exists := resultMap["statusCode"]; exists {
if statusInt, ok := statusCode.(float64); ok {
status = int(statusInt)
}
}
message := "Request blocked by pre-process script"
if msg, exists := resultMap["message"]; exists {
if msgStr, ok := msg.(string); ok {
message = msgStr
}
}
c.JSON(status, gin.H{"error": message})
c.Abort()
return fmt.Errorf("request blocked")
}
}
}
}
return nil
}
func (p *JSRuntimePool) PostProcessResponse(c *gin.Context, statusCode int, body []byte) (int, []byte, error) {
if err := validateGinContext(c); err != nil {
common.SysError("JS PostProcess Validation Error: " + err.Error())
return statusCode, body, err
}
vm := p.Get()
defer p.Put(vm)
postProcessFunc := vm.Get("postProcessResponse")
if postProcessFunc == nil || goja.IsUndefined(postProcessFunc) {
return statusCode, body, nil
}
jsCtx := createJSContext(c)
if jsCtx == nil {
return statusCode, body, fmt.Errorf("failed to create JS context")
}
jsResponse := &JSResponse{
StatusCode: statusCode,
Headers: make(map[string]string),
Body: string(body),
}
// 获取响应头
if c.Writer != nil {
for key, values := range c.Writer.Header() {
if len(values) > 0 {
jsResponse.Headers[key] = values[0]
}
}
}
result, err := p.executeWithTimeout(vm, func() (goja.Value, error) {
vm.Set("ctx", jsCtx)
fn, ok := goja.AssertFunction(postProcessFunc)
if !ok {
return nil, fmt.Errorf("postProcessResponse is not a function")
}
return fn(goja.Undefined(), vm.ToValue(jsCtx), vm.ToValue(jsResponse))
})
if err != nil {
common.SysError("JS PostProcess Error: " + err.Error())
return statusCode, body, err
}
// 处理返回
if result != nil && !goja.IsUndefined(result) {
resultObj := result.Export()
if resultMap, ok := resultObj.(map[string]any); ok {
if newStatusCode, exists := resultMap["statusCode"]; exists {
if statusInt, ok := newStatusCode.(float64); ok {
statusCode = int(statusInt)
}
}
if newBody, exists := resultMap["body"]; exists {
if bodyStr, ok := newBody.(string); ok {
body = []byte(bodyStr)
}
}
if newHeaders, exists := resultMap["headers"]; exists {
if headersMap, ok := newHeaders.(map[string]any); ok {
for key, value := range headersMap {
if valueStr, ok := value.(string); ok {
c.Header(key, valueStr)
}
}
}
}
}
}
return statusCode, body, nil
}
func (p *JSRuntimePool) hasPostProcessFunction() bool {
vm := p.Get()
defer p.Put(vm)
postProcessFunc := vm.Get("postProcessResponse")
return postProcessFunc != nil && !goja.IsUndefined(postProcessFunc)
}
func JSRuntimeMiddleware() gin.HandlerFunc {
if !jsConfig.Enabled {
return func(c *gin.Context) {
c.Next()
}
}
pool := initJSRuntimePool()
return func(c *gin.Context) {
start := time.Now()
// 预处理
common.SysLog("JS Runtime PreProcessing Request: " + c.Request.Method + " " + c.Request.URL.String())
if err := pool.PreProcessRequest(c); err != nil {
common.SysError("JS Runtime PreProcess Error: " + err.Error())
return
}
common.SysLog("JS Runtime PreProcessing Completed")
// 后处理
if pool.hasPostProcessFunction() {
common.SysLog("JS Runtime PostProcessing Response")
writer := newResponseWriter(c.Writer)
c.Writer = writer
c.Next()
// 后处理响应
if writer.body.Len() > 0 {
statusCode, body, err := pool.PostProcessResponse(c, writer.statusCode, writer.body.Bytes())
if err == nil {
// 更新响应
c.Writer = writer.ResponseWriter
c.Writer.Header().Del("Content-Length")
// 设置修改后的headers
for k, v := range writer.headerMap {
for _, value := range v {
c.Writer.Header().Add(k, value)
}
}
c.Status(statusCode)
c.Writer.Write(body)
common.SysLog(fmt.Sprintf("JS Runtime PostProcessing Completed with status %d", statusCode))
} else {
// 出错时返回原始响应
c.Writer = writer.ResponseWriter
c.Writer.Header().Del("Content-Length")
c.Status(writer.statusCode)
c.Writer.Write(writer.body.Bytes())
common.SysError(fmt.Sprintf("JS Runtime PostProcess Error: %v", err))
}
} else {
// 没有响应体时恢复原始writer
c.Writer = writer.ResponseWriter
common.SysLog("JS Runtime PostProcessing Completed with no body")
}
} else {
c.Next()
common.SysLog("JS Runtime PostProcessing Skipped: No postProcessResponse function defined")
}
// 记录处理时间
duration := time.Since(start)
if duration > time.Millisecond*100 {
common.SysLog(fmt.Sprintf("JS Runtime processing took %v", duration))
}
}
}
func ReloadJSScripts() {
if jsRuntimePool != nil {
jsRuntimePool.ReloadScripts()
common.SysLog("JavaScript scripts reloaded")
}
}