diff --git a/controller/misc.go b/controller/misc.go index 0474b0405..44e1f336b 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -1,12 +1,14 @@ package controller import ( + "slices" "encoding/json" "fmt" "net/http" "one-api/common" "one-api/constant" "one-api/middleware" + "one-api/middleware/jsrt" "one-api/model" "one-api/setting" "one-api/setting/console_setting" @@ -33,7 +35,6 @@ func TestStatus(c *gin.Context) { "message": "Server is running", "http_stats": httpStats, }) - return } func GetStatus(c *gin.Context) { @@ -106,7 +107,6 @@ func GetStatus(c *gin.Context) { "message": "", "data": data, }) - return } func GetNotice(c *gin.Context) { @@ -117,7 +117,6 @@ func GetNotice(c *gin.Context) { "message": "", "data": common.OptionMap["Notice"], }) - return } func GetAbout(c *gin.Context) { @@ -128,7 +127,6 @@ func GetAbout(c *gin.Context) { "message": "", "data": common.OptionMap["About"], }) - return } func GetMidjourney(c *gin.Context) { @@ -139,7 +137,6 @@ func GetMidjourney(c *gin.Context) { "message": "", "data": common.OptionMap["Midjourney"], }) - return } func GetHomePageContent(c *gin.Context) { @@ -150,7 +147,6 @@ func GetHomePageContent(c *gin.Context) { "message": "", "data": common.OptionMap["HomePageContent"], }) - return } func SendEmailVerification(c *gin.Context) { @@ -173,13 +169,7 @@ func SendEmailVerification(c *gin.Context) { localPart := parts[0] domainPart := parts[1] if common.EmailDomainRestrictionEnabled { - allowed := false - for _, domain := range common.EmailDomainWhitelist { - if domainPart == domain { - allowed = true - break - } - } + allowed := slices.Contains(common.EmailDomainWhitelist, domainPart) if !allowed { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -224,7 +214,6 @@ func SendEmailVerification(c *gin.Context) { "success": true, "message": "", }) - return } func SendPasswordResetEmail(c *gin.Context) { @@ -263,7 +252,6 @@ func SendPasswordResetEmail(c *gin.Context) { "success": true, "message": "", }) - return } type PasswordResetRequest struct { @@ -306,7 +294,7 @@ func ResetPassword(c *gin.Context) { } func ReloadJSScripts(c *gin.Context) { - middleware.ReloadJSScripts() + jsrt.ReloadJSScripts() c.JSON(http.StatusOK, gin.H{ "success": true, diff --git a/middleware/jsrt/cfg.go b/middleware/jsrt/cfg.go new file mode 100644 index 000000000..2ad5aff0b --- /dev/null +++ b/middleware/jsrt/cfg.go @@ -0,0 +1,69 @@ +package jsrt + +import ( + "os" + "strconv" + "time" +) + +// / Runtime 配置 +type JSRuntimeConfig struct { + Enabled bool `json:"enabled"` + MaxVMCount int `json:"max_vm_count"` + ScriptTimeout time.Duration `json:"script_timeout"` + PreScriptPath string `json:"pre_script_path"` + PostScriptPath string `json:"post_script_path"` + FetchTimeout time.Duration `json:"fetch_timeout"` +} + +var ( + jsConfig = JSRuntimeConfig{} +) + +const ( + defaultPreScriptPath = "scripts/pre_process.js" + defaultPostScriptPath = "scripts/post_process.js" + defaultScriptTimeout = 5 * time.Second + defaultFetchTimeout = 10 * time.Second + defaultMaxVMCount = 8 +) + +func init() { + if enabled := os.Getenv("JS_RUNTIME_ENABLED"); enabled != "" { + jsConfig.Enabled = enabled == "true" + } + + if maxCount := os.Getenv("JS_MAX_VM_COUNT"); maxCount != "" { + if count, err := strconv.Atoi(maxCount); err == nil && count > 0 { + jsConfig.MaxVMCount = count + } + } else { + jsConfig.MaxVMCount = defaultMaxVMCount + } + + if timeout := os.Getenv("JS_SCRIPT_TIMEOUT"); timeout != "" { + if t, err := time.ParseDuration(timeout + "s"); err == nil && t > 0 { + jsConfig.ScriptTimeout = t + } + } else { + jsConfig.ScriptTimeout = defaultScriptTimeout + } + + if fetchTimeout := os.Getenv("JS_FETCH_TIMEOUT"); fetchTimeout != "" { + if t, err := time.ParseDuration(fetchTimeout + "s"); err == nil && t > 0 { + jsConfig.FetchTimeout = t + } + } else { + jsConfig.FetchTimeout = defaultFetchTimeout + } + + jsConfig.PreScriptPath = os.Getenv("JS_PREPROCESS_SCRIPT_PATH") + if jsConfig.PreScriptPath == "" { + jsConfig.PreScriptPath = defaultPreScriptPath + } + + jsConfig.PostScriptPath = os.Getenv("JS_POSTPROCESS_SCRIPT_PATH") + if jsConfig.PostScriptPath == "" { + jsConfig.PostScriptPath = defaultPostScriptPath + } +} diff --git a/middleware/jsrt/ctx.go b/middleware/jsrt/ctx.go new file mode 100644 index 000000000..c5cee0389 --- /dev/null +++ b/middleware/jsrt/ctx.go @@ -0,0 +1,139 @@ +package jsrt + +import ( + "bytes" + "io" + "maps" + "net/http" + "sync" + + "github.com/gin-gonic/gin" +) + +// / 上下文 +type JSContext struct { + Method string `json:"method"` + URL string `json:"url"` + Headers map[string]string `json:"headers"` + Body any `json:"body"` + UserAgent string `json:"userAgent"` + RemoteIP string `json:"remoteIP"` + Extra map[string]any `json:"extra"` +} + +type JSResponse struct { + StatusCode int `json:"statusCode"` + Headers map[string]string `json:"headers"` + Body string `json:"body"` +} + +type responseWriter struct { + gin.ResponseWriter + body *bytes.Buffer + statusCode int + headerMap http.Header + written bool + mu sync.RWMutex +} + +func createJSContext(c *gin.Context) *JSContext { + var bodyBytes []byte + if c.Request != nil && c.Request.Body != nil { + bodyBytes, _ = io.ReadAll(c.Request.Body) + c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + } + + // headers map + headers := make(map[string]string) + if c.Request != nil && c.Request.Header != nil { + for key, values := range c.Request.Header { + if len(values) > 0 { + headers[key] = values[0] + } + } + } + + method := "" + url := "" + userAgent := "" + remoteIP := "" + contentType := "" + + if c.Request != nil { + method = c.Request.Method + if c.Request.URL != nil { + url = c.Request.URL.String() + } + userAgent = c.Request.UserAgent() + contentType = c.ContentType() + } + + if c != nil { + remoteIP = c.ClientIP() + } + + parsedBody := parseBodyByType(bodyBytes, contentType) + + return &JSContext{ + Method: method, + URL: url, + Headers: headers, + Body: parsedBody, + UserAgent: userAgent, + RemoteIP: remoteIP, + Extra: make(map[string]any), + } +} + +func newResponseWriter(w gin.ResponseWriter) *responseWriter { + return &responseWriter{ + ResponseWriter: w, + body: &bytes.Buffer{}, + statusCode: 200, + headerMap: make(http.Header), + written: false, + } +} + +func (w *responseWriter) Write(data []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + + if !w.written { + w.WriteHeader(200) + } + return w.body.Write(data) +} + +func (w *responseWriter) WriteString(s string) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + + if !w.written { + w.WriteHeader(200) + } + return w.body.WriteString(s) +} + +func (w *responseWriter) WriteHeader(statusCode int) { + w.mu.Lock() + defer w.mu.Unlock() + + if w.written { + return + } + w.statusCode = statusCode + w.written = true + + maps.Copy(w.headerMap, w.ResponseWriter.Header()) +} + +func (w *responseWriter) Header() http.Header { + w.mu.RLock() + defer w.mu.RUnlock() + + if w.headerMap == nil { + w.headerMap = make(http.Header) + } + return w.headerMap +} diff --git a/middleware/jsrt/db.go b/middleware/jsrt/db.go new file mode 100644 index 000000000..546c09f38 --- /dev/null +++ b/middleware/jsrt/db.go @@ -0,0 +1,73 @@ +package jsrt + +import ( + "one-api/common" + + "gorm.io/gorm" +) + +type JSDatabase struct { + db *gorm.DB +} + +func (jsdb *JSDatabase) Query(sql string, args ...any) []map[string]any { + if jsdb.db == nil { + common.SysError("JS DB is nil") + return nil + } + + rows, err := jsdb.db.Raw(sql, args...).Rows() + if err != nil { + common.SysError("JS DB Query Error: " + err.Error()) + return nil + } + defer rows.Close() + + columns, err := rows.Columns() + if err != nil { + common.SysError("JS DB Columns Error: " + err.Error()) + return nil + } + + results := make([]map[string]any, 0, 100) + for rows.Next() { + values := make([]any, len(columns)) + valuePtrs := make([]any, len(columns)) + for i := range values { + valuePtrs[i] = &values[i] + } + + if err := rows.Scan(valuePtrs...); err != nil { + common.SysError("JS DB Scan Error: " + err.Error()) + continue + } + + row := make(map[string]any, len(columns)) + for i, col := range columns { + val := values[i] + if b, ok := val.([]byte); ok { + row[col] = string(b) + } else { + row[col] = val + } + } + results = append(results, row) + } + + return results +} + +func (jsdb *JSDatabase) Exec(sql string, args ...any) map[string]any { + if jsdb.db == nil { + return map[string]any{ + "rowsAffected": int64(0), + "error": "database is nil", + } + } + + result := jsdb.db.Exec(sql, args...) + return map[string]any{ + "rowsAffected": result.RowsAffected, + "error": result.Error, + } +} diff --git a/middleware/jsrt/fetch.go b/middleware/jsrt/fetch.go new file mode 100644 index 000000000..bffe71345 --- /dev/null +++ b/middleware/jsrt/fetch.go @@ -0,0 +1,150 @@ +package jsrt + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +type JSFetchRequest struct { + Method string `json:"method"` + URL string `json:"url"` + Headers map[string]string `json:"headers"` + Body string `json:"body"` + Timeout int `json:"timeout"` +} + +type JSFetchResponse struct { + Status int `json:"status"` + StatusText string `json:"statusText"` + Headers map[string]string `json:"headers"` + Body string `json:"body"` + OK bool `json:"ok"` +} + +func (p *JSRuntimePool) fetch(url string, options ...any) *JSFetchResponse { + req := &JSFetchRequest{ + Method: "GET", + URL: url, + Headers: make(map[string]string), + Timeout: int(jsConfig.FetchTimeout.Seconds()), + } + + // 解析选项 + if len(options) > 0 && options[0] != nil { + if optMap, ok := options[0].(map[string]any); ok { + if method, exists := optMap["method"]; exists { + if methodStr, ok := method.(string); ok { + req.Method = strings.ToUpper(methodStr) + } + } + + if headers, exists := optMap["headers"]; exists { + if headersMap, ok := headers.(map[string]any); ok { + for k, v := range headersMap { + if vStr, ok := v.(string); ok { + req.Headers[k] = vStr + } + } + } + } + + if body, exists := optMap["body"]; exists { + switch v := body.(type) { + case string: + req.Body = v + case map[string]any: + if bodyBytes, err := json.Marshal(v); err == nil { + req.Body = string(bodyBytes) + req.Headers["Content-Type"] = "application/json" + } + default: + req.Body = fmt.Sprintf("%v", body) + } + } + + if timeout, exists := optMap["timeout"]; exists { + if timeoutNum, ok := timeout.(float64); ok { + req.Timeout = int(timeoutNum) + } + } + } + } + + // 创建HTTP请求 + var bodyReader io.Reader + if req.Body != "" { + bodyReader = strings.NewReader(req.Body) + } + + httpReq, err := http.NewRequest(req.Method, req.URL, bodyReader) + if err != nil { + return &JSFetchResponse{ + Status: 0, + StatusText: err.Error(), + Headers: make(map[string]string), + Body: "", + OK: false, + } + } + + // 设置请求头 + for k, v := range req.Headers { + httpReq.Header.Set(k, v) + } + + // 设置默认User-Agent + if httpReq.Header.Get("User-Agent") == "" { + httpReq.Header.Set("User-Agent", "JS-Runtime-Fetch/1.0") + } + + // 创建带超时的上下文 + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(req.Timeout)*time.Second) + defer cancel() + httpReq = httpReq.WithContext(ctx) + + // 执行请求 + resp, err := p.httpClient.Do(httpReq) + if err != nil { + return &JSFetchResponse{ + Status: 0, + StatusText: err.Error(), + Headers: make(map[string]string), + Body: "", + OK: false, + } + } + defer resp.Body.Close() + + // 读取响应体 + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return &JSFetchResponse{ + Status: resp.StatusCode, + StatusText: resp.Status, + Headers: make(map[string]string), + Body: "", + OK: resp.StatusCode >= 200 && resp.StatusCode < 300, + } + } + + // 构建响应头 + headers := make(map[string]string) + for k, v := range resp.Header { + if len(v) > 0 { + headers[k] = v[0] + } + } + + return &JSFetchResponse{ + Status: resp.StatusCode, + StatusText: resp.Status, + Headers: headers, + Body: string(bodyBytes), + OK: resp.StatusCode >= 200 && resp.StatusCode < 300, + } +} diff --git a/middleware/js_rt.go b/middleware/jsrt/jsrt.go similarity index 54% rename from middleware/js_rt.go rename to middleware/jsrt/jsrt.go index 2118f646a..2de1a3811 100644 --- a/middleware/js_rt.go +++ b/middleware/jsrt/jsrt.go @@ -1,37 +1,23 @@ -package middleware +package jsrt import ( "bytes" - "context" "crypto/tls" "encoding/json" "fmt" "io" - "maps" "net/http" - "net/url" "one-api/common" "one-api/model" "os" - "strconv" "strings" "sync" "time" "github.com/dop251/goja" "github.com/gin-gonic/gin" - "gorm.io/gorm" ) -/// Runtime 配置 -type JSRuntimeConfig struct { - Enabled bool `json:"enabled"` - MaxVMCount int `json:"max_vm_count"` - ScriptTimeout time.Duration `json:"script_timeout"` - PreScriptPath string `json:"pre_script_path"` - PostScriptPath string `json:"post_script_path"` - FetchTimeout time.Duration `json:"fetch_timeout"` -} /// 池化 type JSRuntimePool struct { @@ -43,211 +29,11 @@ type JSRuntimePool struct { httpClient *http.Client } -/// 上下文 -type JSContext struct { - Method string `json:"method"` - URL string `json:"url"` - Headers map[string]string `json:"headers"` - Body any `json:"body"` - UserAgent string `json:"userAgent"` - RemoteIP string `json:"remoteIP"` - Extra map[string]any `json:"extra"` -} - -type JSResponse struct { - StatusCode int `json:"statusCode"` - Headers map[string]string `json:"headers"` - Body string `json:"body"` -} - -type JSDatabase struct { - db *gorm.DB -} - -type JSFetchRequest struct { - Method string `json:"method"` - URL string `json:"url"` - Headers map[string]string `json:"headers"` - Body string `json:"body"` - Timeout int `json:"timeout"` -} - -type JSFetchResponse struct { - Status int `json:"status"` - StatusText string `json:"statusText"` - Headers map[string]string `json:"headers"` - Body string `json:"body"` - OK bool `json:"ok"` -} - -type responseWriter struct { - gin.ResponseWriter - body *bytes.Buffer - statusCode int - headerMap http.Header - written bool - mu sync.RWMutex -} - var ( jsRuntimePool *JSRuntimePool jsPoolOnce sync.Once - jsConfig = JSRuntimeConfig{} ) -const ( - defaultPreScriptPath = "scripts/pre_process.js" - defaultPostScriptPath = "scripts/post_process.js" - defaultScriptTimeout = 5 * time.Second - defaultFetchTimeout = 10 * time.Second - defaultMaxVMCount = 8 -) - -func init() { - if enabled := os.Getenv("JS_RUNTIME_ENABLED"); enabled != "" { - jsConfig.Enabled = enabled == "true" - } - - if maxCount := os.Getenv("JS_MAX_VM_COUNT"); maxCount != "" { - if count, err := strconv.Atoi(maxCount); err == nil && count > 0 { - jsConfig.MaxVMCount = count - } - } else { - jsConfig.MaxVMCount = defaultMaxVMCount - } - - if timeout := os.Getenv("JS_SCRIPT_TIMEOUT"); timeout != "" { - if t, err := time.ParseDuration(timeout + "s"); err == nil && t > 0 { - jsConfig.ScriptTimeout = t - } - } else { - jsConfig.ScriptTimeout = defaultScriptTimeout - } - - if fetchTimeout := os.Getenv("JS_FETCH_TIMEOUT"); fetchTimeout != "" { - if t, err := time.ParseDuration(fetchTimeout + "s"); err == nil && t > 0 { - jsConfig.FetchTimeout = t - } - } else { - jsConfig.FetchTimeout = defaultFetchTimeout - } - - jsConfig.PreScriptPath = os.Getenv("JS_PREPROCESS_SCRIPT_PATH") - if jsConfig.PreScriptPath == "" { - jsConfig.PreScriptPath = defaultPreScriptPath - } - - jsConfig.PostScriptPath = os.Getenv("JS_POSTPROCESS_SCRIPT_PATH") - if jsConfig.PostScriptPath == "" { - jsConfig.PostScriptPath = defaultPostScriptPath - } -} - -func parseBodyByType(bodyBytes []byte, contentType string) any { - if len(bodyBytes) == 0 { - return "" - } - - bodyStr := string(bodyBytes) - contentLower := strings.ToLower(contentType) - - switch { - case strings.Contains(contentLower, "application/json"): - var jsonObj any - if err := json.Unmarshal(bodyBytes, &jsonObj); err == nil { - return jsonObj - } - return bodyStr - - case strings.Contains(contentLower, "application/x-www-form-urlencoded"): - if values, err := url.ParseQuery(bodyStr); err == nil { - result := make(map[string]string, len(values)) - for k, v := range values { - if len(v) > 0 { - result[k] = v[0] - } - } - return result - } - return bodyStr - - case strings.Contains(contentLower, "multipart/form-data"): - return bodyBytes - - case strings.Contains(contentLower, "text/"): - return bodyStr - - default: - // 尝试JSON解析 - var jsonObj any - if json.Unmarshal(bodyBytes, &jsonObj) == nil { - return jsonObj - } - - // 尝试form解析 - if values, err := url.ParseQuery(bodyStr); err == nil && len(values) > 0 { - result := make(map[string]string, len(values)) - for k, v := range values { - if len(v) > 0 { - result[k] = v[0] - } - } - return result - } - - return bodyStr - } -} - -func createJSContext(c *gin.Context) *JSContext { - var bodyBytes []byte - if c.Request != nil && c.Request.Body != nil { - bodyBytes, _ = io.ReadAll(c.Request.Body) - c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - } - - // headers map - headers := make(map[string]string) - if c.Request != nil && c.Request.Header != nil { - for key, values := range c.Request.Header { - if len(values) > 0 { - headers[key] = values[0] - } - } - } - - method := "" - url := "" - userAgent := "" - remoteIP := "" - contentType := "" - - if c.Request != nil { - method = c.Request.Method - if c.Request.URL != nil { - url = c.Request.URL.String() - } - userAgent = c.Request.UserAgent() - contentType = c.ContentType() - } - - if c != nil { - remoteIP = c.ClientIP() - } - - parsedBody := parseBodyByType(bodyBytes, contentType) - - return &JSContext{ - Method: method, - URL: url, - Headers: headers, - Body: parsedBody, - UserAgent: userAgent, - RemoteIP: remoteIP, - Extra: make(map[string]any), - } -} - func NewJSRuntimePool(maxSize int) *JSRuntimePool { // 创建HTTP客户端 httpClient := &http.Client{ @@ -362,7 +148,7 @@ func (p *JSRuntimePool) setupGlobals(vm *goja.Runtime) { // 数据库 vm.Set("db", &JSDatabase{db: model.DB}) - // 定时器 (简化版) + // 定时器 vm.Set("setTimeout", func(fn func(), delay int) { go func() { time.Sleep(time.Duration(delay) * time.Millisecond) @@ -371,129 +157,6 @@ func (p *JSRuntimePool) setupGlobals(vm *goja.Runtime) { }) } -func (p *JSRuntimePool) fetch(url string, options ...any) *JSFetchResponse { - req := &JSFetchRequest{ - Method: "GET", - URL: url, - Headers: make(map[string]string), - Timeout: int(jsConfig.FetchTimeout.Seconds()), - } - - // 解析选项 - if len(options) > 0 && options[0] != nil { - if optMap, ok := options[0].(map[string]any); ok { - if method, exists := optMap["method"]; exists { - if methodStr, ok := method.(string); ok { - req.Method = strings.ToUpper(methodStr) - } - } - - if headers, exists := optMap["headers"]; exists { - if headersMap, ok := headers.(map[string]any); ok { - for k, v := range headersMap { - if vStr, ok := v.(string); ok { - req.Headers[k] = vStr - } - } - } - } - - if body, exists := optMap["body"]; exists { - switch v := body.(type) { - case string: - req.Body = v - case map[string]any: - if bodyBytes, err := json.Marshal(v); err == nil { - req.Body = string(bodyBytes) - req.Headers["Content-Type"] = "application/json" - } - default: - req.Body = fmt.Sprintf("%v", body) - } - } - - if timeout, exists := optMap["timeout"]; exists { - if timeoutNum, ok := timeout.(float64); ok { - req.Timeout = int(timeoutNum) - } - } - } - } - - // 创建HTTP请求 - var bodyReader io.Reader - if req.Body != "" { - bodyReader = strings.NewReader(req.Body) - } - - httpReq, err := http.NewRequest(req.Method, req.URL, bodyReader) - if err != nil { - return &JSFetchResponse{ - Status: 0, - StatusText: err.Error(), - Headers: make(map[string]string), - Body: "", - OK: false, - } - } - - // 设置请求头 - for k, v := range req.Headers { - httpReq.Header.Set(k, v) - } - - // 设置默认User-Agent - if httpReq.Header.Get("User-Agent") == "" { - httpReq.Header.Set("User-Agent", "JS-Runtime-Fetch/1.0") - } - - // 创建带超时的上下文 - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(req.Timeout)*time.Second) - defer cancel() - httpReq = httpReq.WithContext(ctx) - - // 执行请求 - resp, err := p.httpClient.Do(httpReq) - if err != nil { - return &JSFetchResponse{ - Status: 0, - StatusText: err.Error(), - Headers: make(map[string]string), - Body: "", - OK: false, - } - } - defer resp.Body.Close() - - // 读取响应体 - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return &JSFetchResponse{ - Status: resp.StatusCode, - StatusText: resp.Status, - Headers: make(map[string]string), - Body: "", - OK: resp.StatusCode >= 200 && resp.StatusCode < 300, - } - } - - // 构建响应头 - headers := make(map[string]string) - for k, v := range resp.Header { - if len(v) > 0 { - headers[k] = v[0] - } - } - - return &JSFetchResponse{ - Status: resp.StatusCode, - StatusText: resp.Status, - Headers: headers, - Body: string(bodyBytes), - OK: resp.StatusCode >= 200 && resp.StatusCode < 300, - } -} - func (p *JSRuntimePool) loadScripts(vm *goja.Runtime) { p.mu.RLock() defer p.mu.RUnlock() @@ -546,68 +209,6 @@ done: common.SysLog("JavaScript scripts reloaded") } -func (jsdb *JSDatabase) Query(sql string, args ...any) []map[string]any { - if jsdb.db == nil { - common.SysError("JS DB is nil") - return nil - } - - rows, err := jsdb.db.Raw(sql, args...).Rows() - if err != nil { - common.SysError("JS DB Query Error: " + err.Error()) - return nil - } - defer rows.Close() - - columns, err := rows.Columns() - if err != nil { - common.SysError("JS DB Columns Error: " + err.Error()) - return nil - } - - results := make([]map[string]any, 0, 100) - for rows.Next() { - values := make([]any, len(columns)) - valuePtrs := make([]any, len(columns)) - for i := range values { - valuePtrs[i] = &values[i] - } - - if err := rows.Scan(valuePtrs...); err != nil { - common.SysError("JS DB Scan Error: " + err.Error()) - continue - } - - row := make(map[string]any, len(columns)) - for i, col := range columns { - val := values[i] - if b, ok := val.([]byte); ok { - row[col] = string(b) - } else { - row[col] = val - } - } - results = append(results, row) - } - - return results -} - -func (jsdb *JSDatabase) Exec(sql string, args ...any) map[string]any { - if jsdb.db == nil { - return map[string]any{ - "rowsAffected": int64(0), - "error": "database is nil", - } - } - - result := jsdb.db.Exec(sql, args...) - return map[string]any{ - "rowsAffected": result.RowsAffected, - "error": result.Error, - } -} - func initJSRuntimePool() *JSRuntimePool { jsPoolOnce.Do(func() { jsRuntimePool = NewJSRuntimePool(jsConfig.MaxVMCount) @@ -837,59 +438,6 @@ func (p *JSRuntimePool) hasPostProcessFunction() bool { return postProcessFunc != nil && !goja.IsUndefined(postProcessFunc) } -func newResponseWriter(w gin.ResponseWriter) *responseWriter { - return &responseWriter{ - ResponseWriter: w, - body: &bytes.Buffer{}, - statusCode: 200, - headerMap: make(http.Header), - written: false, - } -} - -func (w *responseWriter) Write(data []byte) (int, error) { - w.mu.Lock() - defer w.mu.Unlock() - - if !w.written { - w.WriteHeader(200) - } - return w.body.Write(data) -} - -func (w *responseWriter) WriteString(s string) (int, error) { - w.mu.Lock() - defer w.mu.Unlock() - - if !w.written { - w.WriteHeader(200) - } - return w.body.WriteString(s) -} - -func (w *responseWriter) WriteHeader(statusCode int) { - w.mu.Lock() - defer w.mu.Unlock() - - if w.written { - return - } - w.statusCode = statusCode - w.written = true - - maps.Copy(w.headerMap, w.ResponseWriter.Header()) -} - -func (w *responseWriter) Header() http.Header { - w.mu.RLock() - defer w.mu.RUnlock() - - if w.headerMap == nil { - w.headerMap = make(http.Header) - } - return w.headerMap -} - func JSRuntimeMiddleware() gin.HandlerFunc { if !jsConfig.Enabled { return func(c *gin.Context) { diff --git a/middleware/jsrt/utils.go b/middleware/jsrt/utils.go new file mode 100644 index 000000000..446e6bf56 --- /dev/null +++ b/middleware/jsrt/utils.go @@ -0,0 +1,64 @@ +package jsrt + +import ( + "encoding/json" + "net/url" + "strings" +) + + +func parseBodyByType(bodyBytes []byte, contentType string) any { + if len(bodyBytes) == 0 { + return "" + } + + bodyStr := string(bodyBytes) + contentLower := strings.ToLower(contentType) + + switch { + case strings.Contains(contentLower, "application/json"): + var jsonObj any + if err := json.Unmarshal(bodyBytes, &jsonObj); err == nil { + return jsonObj + } + return bodyStr + + case strings.Contains(contentLower, "application/x-www-form-urlencoded"): + if values, err := url.ParseQuery(bodyStr); err == nil { + result := make(map[string]string, len(values)) + for k, v := range values { + if len(v) > 0 { + result[k] = v[0] + } + } + return result + } + return bodyStr + + case strings.Contains(contentLower, "multipart/form-data"): + return bodyBytes + + case strings.Contains(contentLower, "text/"): + return bodyStr + + default: + // 尝试JSON解析 + var jsonObj any + if json.Unmarshal(bodyBytes, &jsonObj) == nil { + return jsonObj + } + + // 尝试form解析 + if values, err := url.ParseQuery(bodyStr); err == nil && len(values) > 0 { + result := make(map[string]string, len(values)) + for k, v := range values { + if len(v) > 0 { + result[k] = v[0] + } + } + return result + } + + return bodyStr + } +} \ No newline at end of file diff --git a/router/main.go b/router/main.go index 2dca96c38..ff4b97055 100644 --- a/router/main.go +++ b/router/main.go @@ -5,7 +5,7 @@ import ( "fmt" "net/http" "one-api/common" - "one-api/middleware" + "one-api/middleware/jsrt" "os" "strings" @@ -13,7 +13,7 @@ import ( ) func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { - router.Use(middleware.JSRuntimeMiddleware()) + router.Use(jsrt.JSRuntimeMiddleware()) SetApiRouter(router) SetDashboardRouter(router) SetRelayRouter(router) diff --git a/scripts/pre_process.js b/scripts/pre_process.js index 4cdfd3df2..5d7408316 100644 --- a/scripts/pre_process.js +++ b/scripts/pre_process.js @@ -50,7 +50,7 @@ function preProcessRequest(ctx) { } }; } catch (e) { - console.error("Failed to parse/modify request body:", { + console.error("Failed to modify request body:", { message: e.message, stack: e.stack, bodyType: typeof ctx.Body,