mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 04:40:59 +00:00
70 lines
2.4 KiB
Go
70 lines
2.4 KiB
Go
package service
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/json"
|
||
"fmt"
|
||
"net/http"
|
||
"one-api/common"
|
||
"one-api/setting/system_setting"
|
||
"strings"
|
||
)
|
||
|
||
// WorkerRequest is the JSON payload posted to the worker. It describes the
// request the worker is asked to perform.
type WorkerRequest struct {
	// URL is the target address. DoWorkerRequest validates it before the
	// payload is sent.
	URL string `json:"url"`
	// Key is sent along with every payload; DoDownloadRequest fills it from
	// system_setting.WorkerValidKey.
	// NOTE(review): presumably the worker uses it to authenticate the caller — confirm.
	Key string `json:"key"`
	// Method is omitted from the JSON when empty.
	Method string `json:"method,omitempty"`
	// Headers is omitted from the JSON when empty.
	Headers map[string]string `json:"headers,omitempty"`
	// Body is embedded verbatim as raw JSON and omitted when empty.
	Body json.RawMessage `json:"body,omitempty"`
}
|
||
|
||
// DoWorkerRequest 通过Worker发送请求
|
||
func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
|
||
if !system_setting.EnableWorker() {
|
||
return nil, fmt.Errorf("worker not enabled")
|
||
}
|
||
if !system_setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") {
|
||
return nil, fmt.Errorf("only support https url")
|
||
}
|
||
|
||
// SSRF防护:验证请求URL
|
||
fetchSetting := system_setting.GetFetchSetting()
|
||
if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
|
||
return nil, fmt.Errorf("request reject: %v", err)
|
||
}
|
||
|
||
workerUrl := system_setting.WorkerUrl
|
||
if !strings.HasSuffix(workerUrl, "/") {
|
||
workerUrl += "/"
|
||
}
|
||
|
||
// 序列化worker请求数据
|
||
workerPayload, err := json.Marshal(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to marshal worker payload: %v", err)
|
||
}
|
||
|
||
return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
|
||
}
|
||
|
||
func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) {
|
||
if system_setting.EnableWorker() {
|
||
common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
|
||
req := &WorkerRequest{
|
||
URL: originUrl,
|
||
Key: system_setting.WorkerValidKey,
|
||
}
|
||
return DoWorkerRequest(req)
|
||
} else {
|
||
// SSRF防护:验证请求URL(非Worker模式)
|
||
fetchSetting := system_setting.GetFetchSetting()
|
||
if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
|
||
return nil, fmt.Errorf("request reject: %v", err)
|
||
}
|
||
|
||
common.SysLog(fmt.Sprintf("downloading from origin: %s, reason: %s", common.MaskSensitiveInfo(originUrl), strings.Join(reason, ", ")))
|
||
return http.Get(originUrl)
|
||
}
|
||
}
|