diff --git a/.env.example b/.env.example index ea2464270..3d2d3b7ad 100644 --- a/.env.example +++ b/.env.example @@ -73,3 +73,12 @@ # 节点类型 # 如果是主节点则为master # NODE_TYPE=master + +# JavaScript 运行时配置 +# 是否启用(默认:true) +# JS_RUNTIME_ENABLED=true +# 脚本文件夹(可选,默认 scripts/) +# JS_SCRIPTS_DIR=./custom_scripts +# 启用调试日志(可选) +# JS_RUNTIME_DEBUG=true + diff --git a/controller/misc.go b/controller/misc.go index 4ffe86f43..0474b0405 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -303,5 +303,13 @@ func ResetPassword(c *gin.Context) { "message": "", "data": password, }) - return +} + +func ReloadJSScripts(c *gin.Context) { + middleware.ReloadJSScripts() + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "JavaScript 脚本已重新加载", + }) } diff --git a/controller/task.go b/controller/task.go index 5cfa728aa..fda1e7714 100644 --- a/controller/task.go +++ b/controller/task.go @@ -122,7 +122,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas } if resp.StatusCode != http.StatusOK { common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) - return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + return fmt.Errorf("Get Task status code: %d", resp.StatusCode) } defer resp.Body.Close() responseBody, err := io.ReadAll(resp.Body) diff --git a/doc/JS_RUNTIME.md b/doc/JS_RUNTIME.md new file mode 100644 index 000000000..98daaa183 --- /dev/null +++ b/doc/JS_RUNTIME.md @@ -0,0 +1,179 @@ +# JavaScript Runtime 中间件 + +## 配置 + +将 JavaScript 脚本放置在项目根目录的 `scripts/` 文件夹中: + +- `scripts/pre_process.js` - 请求预处理脚本 +- `scripts/post_process.js` - 响应后处理脚本 + +## API 参考 + +### 预处理函数 + +```javascript +function preProcessRequest(ctx) { + // ctx 包含以下属性: + // - method: 请求方法 (GET, POST, etc.) + // - url: 请求URL + // - headers: 请求头 (object) + // - body: 请求体 (string) + // - query: 查询参数 (object) + // - params: 路径参数 (object) + // - userAgent: User-Agent + // - remoteIP: 客户端IP + // - contentType: Content-Type + // - extra: 额外数据 (object) + + // 返回值: + // - undefined: 继续正常处理 + // - object: 修改请求或阻止请求 + // - block: true/false - 是否阻止请求 + // - statusCode: 状态码 (当 block=true 时) + // - message: 错误消息 (当 block=true 时) + // - headers: 修改的请求头 (object) + // - body: 修改的请求体 (string) +} +``` + +### 后处理函数 + +```javascript +function postProcessResponse(ctx, response) { + // ctx: 请求上下文 (同预处理) + // response 包含以下属性: + // - statusCode: 响应状态码 + // - headers: 响应头 (object) + // - body: 响应体 (string) + + // 返回值: + // - undefined: 保持原始响应 + // - object: 修改响应 + // - statusCode: 新的状态码 + // - headers: 修改的响应头 (object) + // - body: 修改的响应体 (string) +} +``` + +### 数据库对象 + +```javascript +// 查询数据库 +var results = db.Query("SELECT * FROM users WHERE id = ?", 123); + +// 执行 SQL +var result = db.Exec("UPDATE users SET last_login = NOW() WHERE id = ?", 123); +// result 包含: { rowsAffected: number, error: any } +``` + +### 全局对象 + +- `console.log()` - 输出日志 +- `console.error()` - 输出错误日志 +- `JSON.parse()` - 解析 JSON +- `JSON.stringify()` - 序列化为 JSON + +## 使用示例 + +### 请求限流 + +```javascript +function preProcessRequest(ctx) { + // 基于 IP 的简单限流 + var recentRequests = db.Query( + "SELECT COUNT(*) as count FROM request_logs WHERE ip = ? AND timestamp > ?", + ctx.remoteIP, + new Date(Date.now() - 60000).toISOString() // 最近1分钟 + ); + + if (recentRequests[0].count > 100) { + return { + block: true, + statusCode: 429, + message: "Too many requests" + }; + } + + // 记录请求 + db.Exec( + "INSERT INTO request_logs (ip, url, timestamp) VALUES (?, ?, ?)", + ctx.remoteIP, ctx.url, new Date().toISOString() + ); +} +``` + +### 请求修改 + +```javascript +function preProcessRequest(ctx) { + if (ctx.method === "POST" && ctx.body) { + try { + var bodyObj = JSON.parse(ctx.body); + + // 添加默认参数 + if (!bodyObj.temperature) { + bodyObj.temperature = 0.7; + } + + // 添加用户标识 + bodyObj._userId = ctx.extra.userId; + + return { + body: JSON.stringify(bodyObj) + }; + } catch (e) { + console.error("Failed to parse request body:", e); + } + } +} +``` + +### 响应增强 + +```javascript +function postProcessResponse(ctx, response) { + if (response.statusCode === 200 && ctx.url.includes("/v1/chat/completions")) { + try { + var bodyObj = JSON.parse(response.body); + + // 添加自定义元数据 + bodyObj.metadata = { + processedAt: new Date().toISOString(), + version: "1.0.0" + }; + + // 记录成功的对话 + db.Exec( + "INSERT INTO chat_logs (user_ip, model, tokens, timestamp) VALUES (?, ?, ?, ?)", + ctx.remoteIP, bodyObj.model, bodyObj.usage?.total_tokens || 0, new Date().toISOString() + ); + + return { + statusCode: response.statusCode, + headers: response.headers, + body: JSON.stringify(bodyObj) + }; + } catch (e) { + console.error("Failed to process chat completion response:", e); + } + } + + return response; +} +``` + +## 管理接口 + +### 重新加载脚本 + +```bash +curl -X POST http://host:port/api/scripts/reload \ + -H 'Content-Type: application/json' \ + -H 'Authorization Bearer ' +``` + +## 故障排除 + +- 查看服务日志中的 JavaScript 相关错误信息 +- 使用 `console.log()` 调试脚本逻辑 +- 确保 JavaScript 语法正确(不支持所有 ES6+ 特性) diff --git a/go.mod b/go.mod index 9479ba552..1e16c33b9 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/aws/aws-sdk-go-v2/credentials v1.17.11 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b + github.com/dop251/goja v0.0.0-20250630131328-58d95d85e994 github.com/gin-contrib/cors v1.7.2 github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/sessions v0.0.5 @@ -31,6 +32,7 @@ require ( golang.org/x/crypto v0.35.0 golang.org/x/image v0.23.0 golang.org/x/net v0.35.0 + golang.org/x/sync v0.11.0 gorm.io/driver/mysql v1.4.3 gorm.io/driver/postgres v1.5.2 gorm.io/gorm v1.25.2 @@ -56,9 +58,11 @@ require ( github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.6.0 // indirect + github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect github.com/gorilla/sessions v1.2.1 // indirect @@ -84,7 +88,6 @@ require ( github.com/yusufpapurcu/wmi v1.2.3 // indirect golang.org/x/arch v0.12.0 // indirect golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect - golang.org/x/sync v0.11.0 // indirect golang.org/x/sys v0.30.0 // indirect golang.org/x/text v0.22.0 // indirect google.golang.org/protobuf v1.34.2 // indirect diff --git a/go.sum b/go.sum index 71dd83c22..9e18b3c2d 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A= github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U= +github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0= +github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs= @@ -40,6 +42,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20250630131328-58d95d85e994 h1:aQYWswi+hRL2zJqGacdCZx32XjKYV8ApXFGntw79XAM= +github.com/dop251/goja v0.0.0-20250630131328-58d95d85e994/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= @@ -83,6 +87,8 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= @@ -97,8 +103,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= -github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= +github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U= +github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= diff --git a/middleware/js_rt.go b/middleware/js_rt.go new file mode 100644 index 000000000..b8c3216d0 --- /dev/null +++ b/middleware/js_rt.go @@ -0,0 +1,660 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "one-api/common" + "one-api/model" + "os" + "strings" + "sync" + "time" + + "github.com/dop251/goja" + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type JSRuntime struct { + vm *goja.Runtime + mu sync.RWMutex +} + +type JSContext struct { + Method string `json:"method"` + URL string `json:"url"` + Headers map[string]string `json:"headers"` + Body any `json:"body"` + Query map[string]string `json:"query"` + Params map[string]string `json:"params"` + UserAgent string `json:"userAgent"` + RemoteIP string `json:"remoteIP"` + ContentType string `json:"contentType"` + Extra map[string]any `json:"extra"` +} + +func parseBodyByType(bodyBytes []byte, contentType string) any { + if len(bodyBytes) == 0 { + return "" + } + + bodyStr := string(bodyBytes) + + // 根据 Content-Type 判断 + switch { + case strings.Contains(contentType, "application/json"): + var jsonObj any + if err := json.Unmarshal(bodyBytes, &jsonObj); err == nil { + return jsonObj + } + return bodyStr // JSON 解析失败时返回字符串 + + case strings.Contains(contentType, "application/x-www-form-urlencoded"): + // 解析为 map[string]string + values, err := url.ParseQuery(bodyStr) + if err == nil { + result := make(map[string]string) + for k, v := range values { + if len(v) > 0 { + result[k] = v[0] + } + } + return result + } + return bodyStr + + case strings.Contains(contentType, "multipart/form-data"): + // multipart 数据保持为字节数组,JS 中需要特殊处理 + return bodyBytes + + case strings.Contains(contentType, "text/"): + // 文本类型返回字符串 + return bodyStr + + default: + // 尝试 JSON 解析 + var jsonObj any + if err := json.Unmarshal(bodyBytes, &jsonObj); err == nil { + return jsonObj + } + + // 检查是否是 URL encoded + if values, err := url.ParseQuery(bodyStr); err == nil && len(values) > 0 { + result := make(map[string]string) + for k, v := range values { + if len(v) > 0 { + result[k] = v[0] + } + } + return result + } + + // 二进制数据或未知格式 + if isBinary(bodyBytes) { + return bodyBytes + } + + return bodyStr + } +} + +// 检查是否为二进制数据 +func isBinary(data []byte) bool { + if len(data) == 0 { + return false + } + + // 检查前 512 字节中是否包含控制字符(除了常见的换行符等) + checkLen := min(len(data), 512) + + for i := range checkLen { + b := data[i] + // 控制字符但不是常见的文本字符 + if b < 32 && b != 9 && b != 10 && b != 13 { + return true + } + // 非 UTF-8 字符 + if b > 127 { + return true + } + } + return false +} + +func createJSContext(c *gin.Context) *JSContext { + var bodyBytes []byte + if c.Request != nil && c.Request.Body != nil { + bodyBytes, _ = io.ReadAll(c.Request.Body) + c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + } + + // headers map + headers := make(map[string]string) + if c.Request != nil && c.Request.Header != nil { + for key, values := range c.Request.Header { + if len(values) > 0 { + headers[key] = values[0] + } + } + } + + // query parameters map + query := make(map[string]string) + if c.Request != nil && c.Request.URL != nil { + for key, values := range c.Request.URL.Query() { + if len(values) > 0 { + query[key] = values[0] + } + } + } + + // path parameters map + params := make(map[string]string) + for _, param := range c.Params { + params[param.Key] = param.Value + } + + method := "" + url := "" + userAgent := "" + remoteIP := "" + contentType := "" + + if c.Request != nil { + method = c.Request.Method + if c.Request.URL != nil { + url = c.Request.URL.String() + } + userAgent = c.Request.UserAgent() + contentType = c.ContentType() + } + + if c != nil { + remoteIP = c.ClientIP() + } + + // 智能解析 body + parsedBody := parseBodyByType(bodyBytes, contentType) + + return &JSContext{ + Method: method, + URL: url, + Headers: headers, + Body: parsedBody, + Query: query, + Params: params, + UserAgent: userAgent, + RemoteIP: remoteIP, + ContentType: contentType, + Extra: make(map[string]any), + } +} + +type JSResponse struct { + StatusCode int `json:"statusCode"` + Headers map[string]string `json:"headers"` + Body string `json:"body"` +} + +type JSDatabase struct { + db *gorm.DB +} + +var ( + jsRuntime *JSRuntime + jsRuntimeOnce sync.Once +) + +func initJSRuntime() *JSRuntime { + jsRuntimeOnce.Do(func() { + jsRuntime = &JSRuntime{ + vm: goja.New(), + } + jsRuntime.setupGlobals() + jsRuntime.loadScripts() + common.SysLog("JavaScript runtime initialized successfully") + }) + return jsRuntime +} + +func (js *JSRuntime) setupGlobals() { + js.mu.Lock() + defer js.mu.Unlock() + + // console + console := js.vm.NewObject() + console.Set("log", func(args ...any) { + var strs []string + for _, arg := range args { + strs = append(strs, fmt.Sprintf("%v", arg)) + } + common.SysLog("JS: " + strings.Join(strs, " ")) + }) + console.Set("error", func(args ...any) { + var strs []string + for _, arg := range args { + strs = append(strs, fmt.Sprintf("%v", arg)) + } + common.SysError("JS: " + strings.Join(strs, " ")) + }) + js.vm.Set("console", console) + + // JSON + jsonObj := js.vm.NewObject() + jsonObj.Set("parse", func(str string) any { + var result any + err := json.Unmarshal([]byte(str), &result) + if err != nil { + panic(js.vm.ToValue(err.Error())) + } + return result + }) + jsonObj.Set("stringify", func(obj any) string { + data, err := json.Marshal(obj) + if err != nil { + panic(js.vm.ToValue(err.Error())) + } + return string(data) + }) + js.vm.Set("JSON", jsonObj) + + js.vm.Set("db", &JSDatabase{db: model.DB}) +} + +func (jsdb *JSDatabase) Query(sql string, args ...any) []map[string]any { + rows, err := jsdb.db.Raw(sql, args...).Rows() + if err != nil { + common.SysError("JS DB Query Error: " + err.Error()) + return nil + } + defer rows.Close() + + columns, err := rows.Columns() + if err != nil { + common.SysError("JS DB Columns Error: " + err.Error()) + return nil + } + + var results []map[string]any + for rows.Next() { + values := make([]any, len(columns)) + valuePtrs := make([]any, len(columns)) + for i := range values { + valuePtrs[i] = &values[i] + } + + if err := rows.Scan(valuePtrs...); err != nil { + common.SysError("JS DB Scan Error: " + err.Error()) + continue + } + + row := make(map[string]any) + for i, col := range columns { + val := values[i] + if b, ok := val.([]byte); ok { + row[col] = string(b) + } else { + row[col] = val + } + } + results = append(results, row) + } + + return results +} + +func (jsdb *JSDatabase) Exec(sql string, args ...any) map[string]any { + result := jsdb.db.Exec(sql, args...) + return map[string]any{ + "rowsAffected": result.RowsAffected, + "error": result.Error, + } +} + +func (js *JSRuntime) loadScripts() { + // 加载预处理脚本 + if preScript, err := os.ReadFile("scripts/pre_process.js"); err == nil { + js.mu.Lock() + _, err = js.vm.RunString(string(preScript)) + js.mu.Unlock() + if err != nil { + common.SysError("Failed to load pre_process.js: " + err.Error()) + } else { + common.SysLog("Loaded pre_process.js") + } + } + + // 加载后处理脚本 + if postScript, err := os.ReadFile("scripts/post_process.js"); err == nil { + js.mu.Lock() + _, err = js.vm.RunString(string(postScript)) + js.mu.Unlock() + if err != nil { + common.SysError("Failed to load post_process.js: " + err.Error()) + } else { + common.SysLog("Loaded post_process.js") + } + } +} + +func (js *JSRuntime) ReloadScripts() { + js.loadScripts() +} + +// validateGinContext checks if the gin context is properly initialized +func validateGinContext(c *gin.Context) error { + if c == nil { + return fmt.Errorf("gin context is nil") + } + if c.Request == nil { + return fmt.Errorf("gin context request is nil") + } + return nil +} + +func (js *JSRuntime) PreProcessRequest(c *gin.Context) error { + if err := validateGinContext(c); err != nil { + common.SysError("JS PreProcess Validation Error: " + err.Error()) + return err + } + + js.mu.RLock() + preProcessFunc := js.vm.Get("preProcessRequest") + js.mu.RUnlock() + + if preProcessFunc == nil || goja.IsUndefined(preProcessFunc) { + return nil // 没有预处理函数 + } + + jsCtx := createJSContext(c) + if jsCtx == nil { + return fmt.Errorf("failed to create JS context") + } + + js.mu.Lock() + defer js.mu.Unlock() + + js.vm.Set("ctx", jsCtx) + fn, ok := goja.AssertFunction(preProcessFunc) + if !ok { + return fmt.Errorf("preProcessRequest is not a function") + } + + result, err := fn(goja.Undefined(), js.vm.ToValue(jsCtx)) + + if err != nil { + common.SysError("JS PreProcess Error: " + err.Error()) + return err + } + + // 处理返回结果 + if result != nil && !goja.IsUndefined(result) { + resultObj := result.Export() + if resultMap, ok := resultObj.(map[string]any); ok { + // 是否修改请求 + if newBody, exists := resultMap["body"]; exists { + switch v := newBody.(type) { + case string: + c.Request.Body = io.NopCloser(strings.NewReader(v)) + c.Request.ContentLength = int64(len(v)) + case []byte: + c.Request.Body = io.NopCloser(bytes.NewBuffer(v)) + c.Request.ContentLength = int64(len(v)) + case map[string]any: + bodyBytes, err := json.Marshal(v) + if err == nil { + c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + c.Request.ContentLength = int64(len(bodyBytes)) + c.Request.Header.Set("Content-Type", "application/json") + } else { + common.SysError("JS PreProcess JSON Marshal Error: " + err.Error()) + } + + default: + common.SysError("JS PreProcess Unsupported Body Type: " + fmt.Sprintf("%T", newBody)) + } + } + + // 是否修改 headers + if newHeaders, exists := resultMap["headers"]; exists { + if headersMap, ok := newHeaders.(map[string]any); ok { + for key, value := range headersMap { + if valueStr, ok := value.(string); ok { + c.Request.Header.Set(key, valueStr) + } + } + } + } + + // 是否阻止请求 + if block, exists := resultMap["block"]; exists { + if blockBool, ok := block.(bool); ok && blockBool { + status := http.StatusForbidden + if statusCode, exists := resultMap["statusCode"]; exists { + if statusInt, ok := statusCode.(float64); ok { + status = int(statusInt) + } + } + + message := "Request blocked by pre-process script" + if msg, exists := resultMap["message"]; exists { + if msgStr, ok := msg.(string); ok { + message = msgStr + } + } + + c.JSON(status, gin.H{"error": message}) + c.Abort() + return fmt.Errorf("request blocked") + } + } + } + } + + return nil +} + +func (js *JSRuntime) PostProcessResponse(c *gin.Context, statusCode int, body []byte) (int, []byte, error) { + if err := validateGinContext(c); err != nil { + common.SysError("JS PostProcess Validation Error: " + err.Error()) + return statusCode, body, err + } + + js.mu.RLock() + postProcessFunc := js.vm.Get("postProcessResponse") + js.mu.RUnlock() + + if postProcessFunc == nil || goja.IsUndefined(postProcessFunc) { + return statusCode, body, nil // 没有后处理 + } + + jsCtx := createJSContext(c) + if jsCtx == nil { + return statusCode, body, fmt.Errorf("failed to create JS context") + } + + jsResponse := &JSResponse{ + StatusCode: statusCode, + Headers: make(map[string]string), + Body: string(body), + } + + // 获取响应头 + if c.Writer != nil { + for key, values := range c.Writer.Header() { + if len(values) > 0 { + jsResponse.Headers[key] = values[0] + } + } + } + + js.mu.Lock() + defer js.mu.Unlock() + + js.vm.Set("ctx", jsCtx) + fn, ok := goja.AssertFunction(postProcessFunc) + if !ok { + return statusCode, body, fmt.Errorf("postProcessResponse is not a function") + } + result, err := fn(goja.Undefined(), js.vm.ToValue(jsCtx), js.vm.ToValue(jsResponse)) + + if err != nil { + common.SysError("JS PostProcess Error: " + err.Error()) + return statusCode, body, err + } + + // 处理返回 + if result != nil && !goja.IsUndefined(result) { + resultObj := result.Export() + if resultMap, ok := resultObj.(map[string]any); ok { + if newStatusCode, exists := resultMap["statusCode"]; exists { + if statusInt, ok := newStatusCode.(float64); ok { + statusCode = int(statusInt) + } + } + + if newBody, exists := resultMap["body"]; exists { + if bodyStr, ok := newBody.(string); ok { + body = []byte(bodyStr) + } + } + + if newHeaders, exists := resultMap["headers"]; exists { + if headersMap, ok := newHeaders.(map[string]any); ok { + for key, value := range headersMap { + if valueStr, ok := value.(string); ok { + c.Header(key, valueStr) + } + } + } + } + } + } + + return statusCode, body, nil +} + +func JSRuntimeMiddleware() gin.HandlerFunc { + if os.Getenv("JS_RUNTIME_ENABLED") != "true" { + return func(c *gin.Context) { + c.Next() + } + } + + runtime := initJSRuntime() + return func(c *gin.Context) { + start := time.Now() + + // 预处理 + common.SysLog("JS Runtime PreProcessing Request: " + c.Request.Method + " " + c.Request.URL.String()) + if err := runtime.PreProcessRequest(c); err != nil { + // 如果预处理返回错误,说明请求被阻止 + common.SysError("JS Runtime PreProcess Error: " + err.Error()) + return + } + common.SysLog("JS Runtime PreProcessing Completed") + + // 后处理 + if runtime.hasPostProcessFunction() { + common.SysLog("JS Runtime PostProcessing Response") + writer := &responseWriter{ + ResponseWriter: c.Writer, + body: &bytes.Buffer{}, + statusCode: 200, // 默认状态码 + headers: make(map[string]string), + } + c.Writer = writer + + c.Next() + + // 后处理响应 + if writer.body.Len() > 0 { + statusCode, body, err := runtime.PostProcessResponse(c, writer.statusCode, writer.body.Bytes()) + if err == nil { + // 更新响应 + c.Writer = writer.ResponseWriter + // Clear any existing content-length header to let Gin handle it + c.Writer.Header().Del("Content-Length") + c.Status(statusCode) + c.Writer.Write(body) + common.SysLog(fmt.Sprintf("JS Runtime PostProcessing Completed with status %d", statusCode)) + } else { + // 出错时返回原始响应 + c.Writer = writer.ResponseWriter + // Clear any existing content-length header to let Gin handle it + c.Writer.Header().Del("Content-Length") + c.Status(writer.statusCode) + c.Writer.Write(writer.body.Bytes()) + common.SysError(fmt.Sprintf("JS Runtime PostProcess Error: %v", err)) + } + } else { + // 没有响应体时,恢复原始writer + c.Writer = writer.ResponseWriter + common.SysLog("JS Runtime PostProcessing Completed with no body") + } + } else { + c.Next() + common.SysLog("JS Runtime PostProcessing Skipped: No postProcessResponse function defined") + } + + // 记录处理时间 + duration := time.Since(start) + if duration > time.Millisecond*100 { + common.SysLog(fmt.Sprintf("JS Runtime processing took %v", duration)) + } + } +} + +func (js *JSRuntime) hasPostProcessFunction() bool { + js.mu.RLock() + defer js.mu.RUnlock() + postProcessFunc := js.vm.Get("postProcessResponse") + return postProcessFunc != nil && !goja.IsUndefined(postProcessFunc) +} + +type responseWriter struct { + gin.ResponseWriter + body *bytes.Buffer + statusCode int + written bool + headers map[string]string +} + +func (w *responseWriter) Write(data []byte) (int, error) { + if !w.written { + w.statusCode = 200 + w.written = true + } + w.body.Write(data) + return len(data), nil +} + +func (w *responseWriter) WriteString(s string) (int, error) { + if !w.written { + w.statusCode = 200 + w.written = true + } + w.body.WriteString(s) + return len(s), nil +} + +func (w *responseWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode + w.written = true + // 不立即调用原始的 WriteHeader,等后处理完成后再调用 +} + +func (w *responseWriter) Header() http.Header { + return w.ResponseWriter.Header() +} + +func ReloadJSScripts() { + if jsRuntime != nil { + jsRuntime.ReloadScripts() + common.SysLog("JavaScript scripts reloaded") + } +} diff --git a/router/api-router.go b/router/api-router.go index db4c38985..a8965d6bc 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -19,6 +19,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/uptime/status", controller.GetUptimeKumaStatus) apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels) apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus) + apiRouter.GET("/js_rt/reload", middleware.AdminAuth(), controller.ReloadJSScripts) apiRouter.GET("/notice", controller.GetNotice) apiRouter.GET("/about", controller.GetAbout) //apiRouter.GET("/midjourney", controller.GetMidjourney) diff --git a/router/main.go b/router/main.go index 0d2bfdcea..235764270 100644 --- a/router/main.go +++ b/router/main.go @@ -3,11 +3,12 @@ package router import ( "embed" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "os" "strings" + + "github.com/gin-gonic/gin" ) func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { diff --git a/router/relay-router.go b/router/relay-router.go index b48c9dc70..2b73eefbc 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -12,6 +12,8 @@ func SetRelayRouter(router *gin.Engine) { router.Use(middleware.CORS()) router.Use(middleware.DecompressRequestMiddleware()) router.Use(middleware.StatsMiddleware()) + router.Use(middleware.JSRuntimeMiddleware()) // 放在最后 + // https://platform.openai.com/docs/api-reference/introduction modelsRouter := router.Group("/v1/models") modelsRouter.Use(middleware.TokenAuth()) diff --git a/scripts/post_process.js b/scripts/post_process.js new file mode 100644 index 000000000..8cea694ce --- /dev/null +++ b/scripts/post_process.js @@ -0,0 +1,9 @@ +// 后处理 +// 在请求处理完成后执行的函数 +// +// @param {Object} ctx - 请求上下文对象 +// @param {Object} response - 响应对象(包含状态码、头部和正文等) +// @returns {Object|undefined} - 返回修改后的响应对象或 undefined +function postProcessResponse(ctx, response) { + return undefined; +} diff --git a/scripts/pre_process.js b/scripts/pre_process.js new file mode 100644 index 000000000..4cdfd3df2 --- /dev/null +++ b/scripts/pre_process.js @@ -0,0 +1,63 @@ +// 请求预处理 +// 在请求被处理之前执行的函数 +// +// @param {Object} ctx - 请求上下文对象 +// @returns {Object|undefined} - 返回修改后的请求对象或 undefined +// +// 参考: [JS Rt](./middleware/js_rt.go) 里的 `JSContext` +function preProcessRequest(ctx) { + // 例子:基于数据库的速率限制 + // if (ctx.url.includes("/v1/chat/completions")) { + // try { + // // Check recent requests from this IP + // var recentRequests = db.Query( + // "SELECT COUNT(*) as count FROM logs WHERE created_at > ? AND ip = ?", + // Math.floor(Date.now() / 1000) - 60, // last minute + // ctx.remoteIP + // ); + + // if (recentRequests && recentRequests.length > 0 && recentRequests[0].count > 10) { + // console.log("速率限制 IP:", ctx.RemoteIP); + // return { + // block: true, + // statusCode: 429, + // message: "超过速率限制" + // }; + // } + // } catch (e) { + // console.error("Ratelimit 数据库错误:", e); + // } + // } + + // 例子:修改请求 + if (ctx.URL.includes("/v1/chat/completions")) { + try { + var bodyObj = ctx.Body; + + let firstMsg = { + role: "user", + content: "今天天气怎么样" + }; + bodyObj.messages[0] = firstMsg; + console.log("Modified first message:", JSON.stringify(firstMsg)); + console.log("Modified body:", JSON.stringify(bodyObj)); + + return { + body: bodyObj, + headers: { + ...ctx.Headers, + "X-Modified-Body": "true" + } + }; + } catch (e) { + console.error("Failed to parse/modify request body:", { + message: e.message, + stack: e.stack, + bodyType: typeof ctx.Body, + url: ctx.URL + }); + } + } + + return undefined; // 跳过处理,继续执行下一个中间件或路由 +}