(jsrt) opt.: struct

This commit is contained in:
lollipopkit🏳️‍⚧️
2025-07-15 22:19:02 +08:00
parent 59f12d2582
commit fd51f71e0f
9 changed files with 504 additions and 473 deletions

517
middleware/jsrt/jsrt.go Normal file
View File

@@ -0,0 +1,517 @@
package jsrt
import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/model"
"os"
"strings"
"sync"
"time"
"github.com/dop251/goja"
"github.com/gin-gonic/gin"
)
/// 池化
type JSRuntimePool struct {
pool chan *goja.Runtime
maxSize int
createFunc func() *goja.Runtime
scripts map[string]string
mu sync.RWMutex
httpClient *http.Client
}
var (
jsRuntimePool *JSRuntimePool
jsPoolOnce sync.Once
)
func NewJSRuntimePool(maxSize int) *JSRuntimePool {
// 创建HTTP客户端
httpClient := &http.Client{
Timeout: jsConfig.FetchTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: false,
},
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
},
}
pool := &JSRuntimePool{
pool: make(chan *goja.Runtime, maxSize),
maxSize: maxSize,
scripts: make(map[string]string),
httpClient: httpClient,
}
pool.createFunc = func() *goja.Runtime {
vm := goja.New()
pool.setupGlobals(vm)
pool.loadScripts(vm)
return vm
}
// 预创建VM
preCreate := min(maxSize/2, 4)
for range preCreate {
select {
case pool.pool <- pool.createFunc():
default:
}
}
return pool
}
func (p *JSRuntimePool) Get() *goja.Runtime {
select {
case vm := <-p.pool:
return vm
default:
return p.createFunc()
}
}
func (p *JSRuntimePool) Put(vm *goja.Runtime) {
if vm == nil {
return
}
select {
case p.pool <- vm:
default:
// 池满丢弃VM让GC回收
}
}
func (p *JSRuntimePool) setupGlobals(vm *goja.Runtime) {
// console
console := vm.NewObject()
console.Set("log", func(args ...any) {
var strs []string
for _, arg := range args {
strs = append(strs, fmt.Sprintf("%v", arg))
}
common.SysLog("JS: " + strings.Join(strs, " "))
})
console.Set("error", func(args ...any) {
var strs []string
for _, arg := range args {
strs = append(strs, fmt.Sprintf("%v", arg))
}
common.SysError("JS: " + strings.Join(strs, " "))
})
console.Set("warn", func(args ...any) {
var strs []string
for _, arg := range args {
strs = append(strs, fmt.Sprintf("%v", arg))
}
common.SysError("JS WARN: " + strings.Join(strs, " "))
})
vm.Set("console", console)
// JSON
jsonObj := vm.NewObject()
jsonObj.Set("parse", func(str string) any {
var result any
err := json.Unmarshal([]byte(str), &result)
if err != nil {
panic(vm.ToValue(err.Error()))
}
return result
})
jsonObj.Set("stringify", func(obj any) string {
data, err := json.Marshal(obj)
if err != nil {
panic(vm.ToValue(err.Error()))
}
return string(data)
})
vm.Set("JSON", jsonObj)
// fetch 实现
vm.Set("fetch", func(url string, options ...any) *JSFetchResponse {
return p.fetch(url, options...)
})
// 数据库
vm.Set("db", &JSDatabase{db: model.DB})
// 定时器
vm.Set("setTimeout", func(fn func(), delay int) {
go func() {
time.Sleep(time.Duration(delay) * time.Millisecond)
fn()
}()
})
}
func (p *JSRuntimePool) loadScripts(vm *goja.Runtime) {
p.mu.RLock()
defer p.mu.RUnlock()
// 加载预处理脚本
if script, exists := p.scripts["pre"]; exists {
if _, err := vm.RunString(script); err != nil {
common.SysError("Failed to load pre_process.js: " + err.Error())
}
} else if preScript, err := os.ReadFile(jsConfig.PreScriptPath); err == nil {
p.scripts["pre"] = string(preScript)
if _, err = vm.RunString(string(preScript)); err != nil {
common.SysError("Failed to load pre_process.js: " + err.Error())
} else {
common.SysLog("Loaded pre_process.js")
}
}
// 加载后处理脚本
if script, exists := p.scripts["post"]; exists {
if _, err := vm.RunString(script); err != nil {
common.SysError("Failed to load post_process.js: " + err.Error())
}
} else if postScript, err := os.ReadFile(jsConfig.PostScriptPath); err == nil {
p.scripts["post"] = string(postScript)
if _, err = vm.RunString(string(postScript)); err != nil {
common.SysError("Failed to load post_process.js: " + err.Error())
} else {
common.SysLog("Loaded post_process.js")
}
}
}
func (p *JSRuntimePool) ReloadScripts() {
p.mu.Lock()
defer p.mu.Unlock()
// 清空缓存的脚本
p.scripts = make(map[string]string)
// 清空VM池强制重新创建
for {
select {
case <-p.pool:
default:
goto done
}
}
done:
common.SysLog("JavaScript scripts reloaded")
}
func initJSRuntimePool() *JSRuntimePool {
jsPoolOnce.Do(func() {
jsRuntimePool = NewJSRuntimePool(jsConfig.MaxVMCount)
common.SysLog("JavaScript runtime pool initialized successfully")
})
return jsRuntimePool
}
func validateGinContext(c *gin.Context) error {
if c == nil {
return fmt.Errorf("gin context is nil")
}
if c.Request == nil {
return fmt.Errorf("gin context request is nil")
}
return nil
}
func (p *JSRuntimePool) executeWithTimeout(vm *goja.Runtime, fn func() (goja.Value, error)) (goja.Value, error) {
type result struct {
value goja.Value
err error
}
resultChan := make(chan result, 1)
go func() {
defer func() {
if r := recover(); r != nil {
resultChan <- result{err: fmt.Errorf("JS panic: %v", r)}
}
}()
value, err := fn()
resultChan <- result{value: value, err: err}
}()
select {
case res := <-resultChan:
return res.value, res.err
case <-time.After(jsConfig.ScriptTimeout):
return nil, fmt.Errorf("script execution timeout after %v", jsConfig.ScriptTimeout)
}
}
func (p *JSRuntimePool) PreProcessRequest(c *gin.Context) error {
if err := validateGinContext(c); err != nil {
common.SysError("JS PreProcess Validation Error: " + err.Error())
return err
}
vm := p.Get()
defer p.Put(vm)
preProcessFunc := vm.Get("preProcessRequest")
if preProcessFunc == nil || goja.IsUndefined(preProcessFunc) {
return nil
}
jsCtx := createJSContext(c)
if jsCtx == nil {
return fmt.Errorf("failed to create JS context")
}
result, err := p.executeWithTimeout(vm, func() (goja.Value, error) {
vm.Set("ctx", jsCtx)
fn, ok := goja.AssertFunction(preProcessFunc)
if !ok {
return nil, fmt.Errorf("preProcessRequest is not a function")
}
return fn(goja.Undefined(), vm.ToValue(jsCtx))
})
if err != nil {
common.SysError("JS PreProcess Error: " + err.Error())
return err
}
// 处理返回结果
if result != nil && !goja.IsUndefined(result) {
resultObj := result.Export()
if resultMap, ok := resultObj.(map[string]any); ok {
// 是否修改请求
if newBody, exists := resultMap["body"]; exists {
switch v := newBody.(type) {
case string:
c.Request.Body = io.NopCloser(strings.NewReader(v))
c.Request.ContentLength = int64(len(v))
case []byte:
c.Request.Body = io.NopCloser(bytes.NewBuffer(v))
c.Request.ContentLength = int64(len(v))
case map[string]any:
bodyBytes, err := json.Marshal(v)
if err == nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
c.Request.ContentLength = int64(len(bodyBytes))
c.Request.Header.Set("Content-Type", "application/json")
} else {
common.SysError("JS PreProcess JSON Marshal Error: " + err.Error())
}
default:
common.SysError("JS PreProcess Unsupported Body Type: " + fmt.Sprintf("%T", newBody))
}
}
// 是否修改 headers
if newHeaders, exists := resultMap["headers"]; exists {
if headersMap, ok := newHeaders.(map[string]any); ok {
for key, value := range headersMap {
if valueStr, ok := value.(string); ok {
c.Request.Header.Set(key, valueStr)
}
}
}
}
// 是否阻止请求
if block, exists := resultMap["block"]; exists {
if blockBool, ok := block.(bool); ok && blockBool {
status := http.StatusForbidden
if statusCode, exists := resultMap["statusCode"]; exists {
if statusInt, ok := statusCode.(float64); ok {
status = int(statusInt)
}
}
message := "Request blocked by pre-process script"
if msg, exists := resultMap["message"]; exists {
if msgStr, ok := msg.(string); ok {
message = msgStr
}
}
c.JSON(status, gin.H{"error": message})
c.Abort()
return fmt.Errorf("request blocked")
}
}
}
}
return nil
}
func (p *JSRuntimePool) PostProcessResponse(c *gin.Context, statusCode int, body []byte) (int, []byte, error) {
if err := validateGinContext(c); err != nil {
common.SysError("JS PostProcess Validation Error: " + err.Error())
return statusCode, body, err
}
vm := p.Get()
defer p.Put(vm)
postProcessFunc := vm.Get("postProcessResponse")
if postProcessFunc == nil || goja.IsUndefined(postProcessFunc) {
return statusCode, body, nil
}
jsCtx := createJSContext(c)
if jsCtx == nil {
return statusCode, body, fmt.Errorf("failed to create JS context")
}
jsResponse := &JSResponse{
StatusCode: statusCode,
Headers: make(map[string]string),
Body: string(body),
}
// 获取响应头
if c.Writer != nil {
for key, values := range c.Writer.Header() {
if len(values) > 0 {
jsResponse.Headers[key] = values[0]
}
}
}
result, err := p.executeWithTimeout(vm, func() (goja.Value, error) {
vm.Set("ctx", jsCtx)
fn, ok := goja.AssertFunction(postProcessFunc)
if !ok {
return nil, fmt.Errorf("postProcessResponse is not a function")
}
return fn(goja.Undefined(), vm.ToValue(jsCtx), vm.ToValue(jsResponse))
})
if err != nil {
common.SysError("JS PostProcess Error: " + err.Error())
return statusCode, body, err
}
// 处理返回
if result != nil && !goja.IsUndefined(result) {
resultObj := result.Export()
if resultMap, ok := resultObj.(map[string]any); ok {
if newStatusCode, exists := resultMap["statusCode"]; exists {
if statusInt, ok := newStatusCode.(float64); ok {
statusCode = int(statusInt)
}
}
if newBody, exists := resultMap["body"]; exists {
if bodyStr, ok := newBody.(string); ok {
body = []byte(bodyStr)
}
}
if newHeaders, exists := resultMap["headers"]; exists {
if headersMap, ok := newHeaders.(map[string]any); ok {
for key, value := range headersMap {
if valueStr, ok := value.(string); ok {
c.Header(key, valueStr)
}
}
}
}
}
}
return statusCode, body, nil
}
func (p *JSRuntimePool) hasPostProcessFunction() bool {
vm := p.Get()
defer p.Put(vm)
postProcessFunc := vm.Get("postProcessResponse")
return postProcessFunc != nil && !goja.IsUndefined(postProcessFunc)
}
func JSRuntimeMiddleware() gin.HandlerFunc {
if !jsConfig.Enabled {
return func(c *gin.Context) {
c.Next()
}
}
pool := initJSRuntimePool()
return func(c *gin.Context) {
start := time.Now()
// 预处理
common.SysLog("JS Runtime PreProcessing Request: " + c.Request.Method + " " + c.Request.URL.String())
if err := pool.PreProcessRequest(c); err != nil {
common.SysError("JS Runtime PreProcess Error: " + err.Error())
return
}
common.SysLog("JS Runtime PreProcessing Completed")
// 后处理
if pool.hasPostProcessFunction() {
common.SysLog("JS Runtime PostProcessing Response")
writer := newResponseWriter(c.Writer)
c.Writer = writer
c.Next()
// 后处理响应
if writer.body.Len() > 0 {
statusCode, body, err := pool.PostProcessResponse(c, writer.statusCode, writer.body.Bytes())
if err == nil {
// 更新响应
c.Writer = writer.ResponseWriter
c.Writer.Header().Del("Content-Length")
// 设置修改后的headers
for k, v := range writer.headerMap {
for _, value := range v {
c.Writer.Header().Add(k, value)
}
}
c.Status(statusCode)
c.Writer.Write(body)
common.SysLog(fmt.Sprintf("JS Runtime PostProcessing Completed with status %d", statusCode))
} else {
// 出错时返回原始响应
c.Writer = writer.ResponseWriter
c.Writer.Header().Del("Content-Length")
c.Status(writer.statusCode)
c.Writer.Write(writer.body.Bytes())
common.SysError(fmt.Sprintf("JS Runtime PostProcess Error: %v", err))
}
} else {
// 没有响应体时恢复原始writer
c.Writer = writer.ResponseWriter
common.SysLog("JS Runtime PostProcessing Completed with no body")
}
} else {
c.Next()
common.SysLog("JS Runtime PostProcessing Skipped: No postProcessResponse function defined")
}
// 记录处理时间
duration := time.Since(start)
if duration > time.Millisecond*100 {
common.SysLog(fmt.Sprintf("JS Runtime processing took %v", duration))
}
}
}
func ReloadJSScripts() {
if jsRuntimePool != nil {
jsRuntimePool.ReloadScripts()
common.SysLog("JavaScript scripts reloaded")
}
}