From 58895711084c1f454ed5dffe9f5af01b685f03c2 Mon Sep 17 00:00:00 2001 From: Seefs Date: Tue, 9 Dec 2025 11:15:27 +0800 Subject: [PATCH] fix: Use channel proxy settings for task query scenarios --- controller/task.go | 3 ++- controller/task_video.go | 3 ++- controller/video_proxy.go | 19 ++++++++++++++++--- controller/video_proxy_gemini.go | 3 ++- relay/channel/adapter.go | 2 +- relay/channel/task/ali/adaptor.go | 8 ++++++-- relay/channel/task/doubao/adaptor.go | 8 ++++++-- relay/channel/task/gemini/adaptor.go | 8 ++++++-- relay/channel/task/hailuo/adaptor.go | 8 ++++++-- relay/channel/task/jimeng/adaptor.go | 8 ++++++-- relay/channel/task/kling/adaptor.go | 8 ++++++-- relay/channel/task/sora/adaptor.go | 8 ++++++-- relay/channel/task/suno/adaptor.go | 8 ++++---- relay/channel/task/vertex/adaptor.go | 16 ++++++++++++---- relay/channel/task/vidu/adaptor.go | 8 ++++++-- relay/relay_task.go | 3 ++- service/http_client.go | 28 ++++++++++++++++++---------- 17 files changed, 107 insertions(+), 42 deletions(-) diff --git a/controller/task.go b/controller/task.go index ad034d61e..16acc2269 100644 --- a/controller/task.go +++ b/controller/task.go @@ -116,9 +116,10 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas if adaptor == nil { return errors.New("adaptor not found") } + proxy := channel.GetSetting().Proxy resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{ "ids": taskIds, - }) + }, proxy) if err != nil { common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err)) return err diff --git a/controller/task_video.go b/controller/task_video.go index 8c9f9719e..86095307d 100644 --- a/controller/task_video.go +++ b/controller/task_video.go @@ -67,6 +67,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } + proxy := channel.GetSetting().Proxy task := taskM[taskId] if task == nil { @@ -76,7 +77,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{ "task_id": taskId, "action": task.Action, - }) + }, proxy) if err != nil { return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err) } diff --git a/controller/video_proxy.go b/controller/video_proxy.go index a577cf819..f102baae4 100644 --- a/controller/video_proxy.go +++ b/controller/video_proxy.go @@ -1,6 +1,7 @@ package controller import ( + "context" "fmt" "io" "net/http" @@ -10,6 +11,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/service" "github.com/gin-gonic/gin" ) @@ -75,11 +77,22 @@ func VideoProxy(c *gin.Context) { } var videoURL string - client := &http.Client{ - Timeout: 60 * time.Second, + proxy := channel.GetSetting().Proxy + client, err := service.GetHttpClientWithProxy(proxy) + if err != nil { + logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error())) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": "Failed to create proxy client", + "type": "server_error", + }, + }) + return } - req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, "", nil) + ctx, cancel := context.WithTimeout(c.Request.Context(), 60*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error())) c.JSON(http.StatusInternalServerError, gin.H{ diff --git a/controller/video_proxy_gemini.go b/controller/video_proxy_gemini.go index 4e2e60e62..053ac6515 100644 --- a/controller/video_proxy_gemini.go +++ b/controller/video_proxy_gemini.go @@ -35,10 +35,11 @@ func getGeminiVideoURL(channel *model.Channel, task *model.Task, apiKey string) return "", fmt.Errorf("api key not available for task") } + proxy := channel.GetSetting().Proxy resp, err := adaptor.FetchTask(baseURL, apiKey, map[string]any{ "task_id": task.TaskID, "action": task.Action, - }) + }, proxy) if err != nil { return "", fmt.Errorf("fetch task failed: %w", err) } diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index 7f8faf22d..ff7606e2e 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -47,7 +47,7 @@ type TaskAdaptor interface { GetChannelName() string // FetchTask - FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) + FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) } diff --git a/relay/channel/task/ali/adaptor.go b/relay/channel/task/ali/adaptor.go index 32d5da398..eef699665 100644 --- a/relay/channel/task/ali/adaptor.go +++ b/relay/channel/task/ali/adaptor.go @@ -393,7 +393,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } // FetchTask 查询任务状态 -func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) { taskID, ok := body["task_id"].(string) if !ok { return nil, fmt.Errorf("invalid task_id") @@ -408,7 +408,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http req.Header.Set("Authorization", "Bearer "+key) - return service.GetHttpClient().Do(req) + client, err := service.GetHttpClientWithProxy(proxy) + if err != nil { + return nil, fmt.Errorf("new proxy http client failed: %w", err) + } + return client.Do(req) } func (a *TaskAdaptor) GetModelList() []string { diff --git a/relay/channel/task/doubao/adaptor.go b/relay/channel/task/doubao/adaptor.go index 1bacb2019..dd21fb75a 100644 --- a/relay/channel/task/doubao/adaptor.go +++ b/relay/channel/task/doubao/adaptor.go @@ -146,7 +146,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } // FetchTask fetch task status -func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) { taskID, ok := body["task_id"].(string) if !ok { return nil, fmt.Errorf("invalid task_id") @@ -163,7 +163,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+key) - return service.GetHttpClient().Do(req) + client, err := service.GetHttpClientWithProxy(proxy) + if err != nil { + return nil, fmt.Errorf("new proxy http client failed: %w", err) + } + return client.Do(req) } func (a *TaskAdaptor) GetModelList() []string { diff --git a/relay/channel/task/gemini/adaptor.go b/relay/channel/task/gemini/adaptor.go index 0fa9dda4b..16c6919b7 100644 --- a/relay/channel/task/gemini/adaptor.go +++ b/relay/channel/task/gemini/adaptor.go @@ -200,7 +200,7 @@ func (a *TaskAdaptor) GetChannelName() string { } // FetchTask fetch task status -func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) { taskID, ok := body["task_id"].(string) if !ok { return nil, fmt.Errorf("invalid task_id") @@ -223,7 +223,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http req.Header.Set("Accept", "application/json") req.Header.Set("x-goog-api-key", key) - return service.GetHttpClient().Do(req) + client, err := service.GetHttpClientWithProxy(proxy) + if err != nil { + return nil, fmt.Errorf("new proxy http client failed: %w", err) + } + return client.Do(req) } func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { diff --git a/relay/channel/task/hailuo/adaptor.go b/relay/channel/task/hailuo/adaptor.go index cb6f1eebd..c77905bfb 100644 --- a/relay/channel/task/hailuo/adaptor.go +++ b/relay/channel/task/hailuo/adaptor.go @@ -110,7 +110,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela return hResp.TaskID, responseBody, nil } -func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) { taskID, ok := body["task_id"].(string) if !ok { return nil, fmt.Errorf("invalid task_id") @@ -126,7 +126,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Bearer "+key) - return service.GetHttpClient().Do(req) + client, err := service.GetHttpClientWithProxy(proxy) + if err != nil { + return nil, fmt.Errorf("new proxy http client failed: %w", err) + } + return client.Do(req) } func (a *TaskAdaptor) GetModelList() []string { diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index da4a1f8fe..d6973531f 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -210,7 +210,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } // FetchTask fetch task status -func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) { taskID, ok := body["task_id"].(string) if !ok { return nil, fmt.Errorf("invalid task_id") @@ -251,7 +251,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http return nil, errors.Wrap(err, "sign request failed") } } - return service.GetHttpClient().Do(req) + client, err := service.GetHttpClientWithProxy(proxy) + if err != nil { + return nil, fmt.Errorf("new proxy http client failed: %w", err) + } + return client.Do(req) } func (a *TaskAdaptor) GetModelList() []string { diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index c1bbd9d59..d00350652 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -199,7 +199,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } // FetchTask fetch task status -func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) { taskID, ok := body["task_id"].(string) if !ok { return nil, fmt.Errorf("invalid task_id") @@ -228,7 +228,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("User-Agent", "kling-sdk/1.0") - return service.GetHttpClient().Do(req) + client, err := service.GetHttpClientWithProxy(proxy) + if err != nil { + return nil, fmt.Errorf("new proxy http client failed: %w", err) + } + return client.Do(req) } func (a *TaskAdaptor) GetModelList() []string { diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index 17aec18f0..214561b5b 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -125,7 +125,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco } // FetchTask fetch task status -func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) { taskID, ok := body["task_id"].(string) if !ok { return nil, fmt.Errorf("invalid task_id") @@ -140,7 +140,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http req.Header.Set("Authorization", "Bearer "+key) - return service.GetHttpClient().Do(req) + client, err := service.GetHttpClientWithProxy(proxy) + if err != nil { + return nil, fmt.Errorf("new proxy http client failed: %w", err) + } + return client.Do(req) } func (a *TaskAdaptor) GetModelList() []string { diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index c4858d0c0..f7c891723 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -132,7 +132,7 @@ func (a *TaskAdaptor) GetChannelName() string { return ChannelName } -func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) { requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl) byteBody, err := json.Marshal(body) if err != nil { @@ -153,11 +153,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http req = req.WithContext(ctx) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+key) - resp, err := service.GetHttpClient().Do(req) + client, err := service.GetHttpClientWithProxy(proxy) if err != nil { - return nil, err + return nil, fmt.Errorf("new proxy http client failed: %w", err) } - return resp, nil + return client.Do(req) } func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) { diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go index d98ac53cf..8ec77266e 100644 --- a/relay/channel/task/vertex/adaptor.go +++ b/relay/channel/task/vertex/adaptor.go @@ -120,7 +120,11 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info return fmt.Errorf("failed to decode credentials: %w", err) } - token, err := vertexcore.AcquireAccessToken(*adc, "") + proxy := "" + if info != nil { + proxy = info.ChannelSetting.Proxy + } + token, err := vertexcore.AcquireAccessToken(*adc, proxy) if err != nil { return fmt.Errorf("failed to acquire access token: %w", err) } @@ -216,7 +220,7 @@ func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generat func (a *TaskAdaptor) GetChannelName() string { return "vertex" } // FetchTask fetch task status -func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) { taskID, ok := body["task_id"].(string) if !ok { return nil, fmt.Errorf("invalid task_id") @@ -249,7 +253,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http if err := json.Unmarshal([]byte(key), adc); err != nil { return nil, fmt.Errorf("failed to decode credentials: %w", err) } - token, err := vertexcore.AcquireAccessToken(*adc, "") + token, err := vertexcore.AcquireAccessToken(*adc, proxy) if err != nil { return nil, fmt.Errorf("failed to acquire access token: %w", err) } @@ -261,7 +265,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("x-goog-user-project", adc.ProjectID) - return service.GetHttpClient().Do(req) + client, err := service.GetHttpClientWithProxy(proxy) + if err != nil { + return nil, fmt.Errorf("new proxy http client failed: %w", err) + } + return client.Do(req) } func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index 6b62f1f01..3657161c0 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -188,7 +188,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela return vResp.TaskId, responseBody, nil } -func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) { taskID, ok := body["task_id"].(string) if !ok { return nil, fmt.Errorf("invalid task_id") @@ -204,7 +204,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Token "+key) - return service.GetHttpClient().Do(req) + client, err := service.GetHttpClientWithProxy(proxy) + if err != nil { + return nil, fmt.Errorf("new proxy http client failed: %w", err) + } + return client.Do(req) } func (a *TaskAdaptor) GetModelList() []string { diff --git a/relay/relay_task.go b/relay/relay_task.go index 61e2af523..ba9fe1e8f 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -326,6 +326,7 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d if channelModel.GetBaseURL() != "" { baseURL = channelModel.GetBaseURL() } + proxy := channelModel.GetSetting().Proxy adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type))) if adaptor == nil { return @@ -333,7 +334,7 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{ "task_id": originTask.TaskID, "action": originTask.Action, - }) + }, proxy) if err2 != nil || resp == nil { return } diff --git a/service/http_client.go b/service/http_client.go index 2fa9e51cf..be89c73c0 100644 --- a/service/http_client.go +++ b/service/http_client.go @@ -35,9 +35,9 @@ func checkRedirect(req *http.Request, via []*http.Request) error { func InitHttpClient() { transport := &http.Transport{ - MaxIdleConns: common.RelayMaxIdleConns, - MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost, - ForceAttemptHTTP2: true, + MaxIdleConns: common.RelayMaxIdleConns, + MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost, + ForceAttemptHTTP2: true, } if common.RelayTimeout == 0 { @@ -58,6 +58,14 @@ func GetHttpClient() *http.Client { return httpClient } +// GetHttpClientWithProxy returns the default client or a proxy-enabled one when proxyURL is provided. +func GetHttpClientWithProxy(proxyURL string) (*http.Client, error) { + if proxyURL == "" { + return GetHttpClient(), nil + } + return NewProxyHttpClient(proxyURL) +} + // ResetProxyClientCache 清空代理客户端缓存,确保下次使用时重新初始化 func ResetProxyClientCache() { proxyClientLock.Lock() @@ -92,10 +100,10 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) { case "http", "https": client := &http.Client{ Transport: &http.Transport{ - MaxIdleConns: common.RelayMaxIdleConns, - MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost, - ForceAttemptHTTP2: true, - Proxy: http.ProxyURL(parsedURL), + MaxIdleConns: common.RelayMaxIdleConns, + MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost, + ForceAttemptHTTP2: true, + Proxy: http.ProxyURL(parsedURL), }, CheckRedirect: checkRedirect, } @@ -127,9 +135,9 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) { client := &http.Client{ Transport: &http.Transport{ - MaxIdleConns: common.RelayMaxIdleConns, - MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost, - ForceAttemptHTTP2: true, + MaxIdleConns: common.RelayMaxIdleConns, + MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost, + ForceAttemptHTTP2: true, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return dialer.Dial(network, addr) },