Merge pull request #2398 from seefs001/fix/video-proxy

fix: Use channel proxy settings for task query scenarios
This commit is contained in:
Calcium-Ion
2025-12-09 14:05:30 +08:00
committed by GitHub
17 changed files with 107 additions and 42 deletions

View File

@@ -116,9 +116,10 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
if adaptor == nil { if adaptor == nil {
return errors.New("adaptor not found") return errors.New("adaptor not found")
} }
proxy := channel.GetSetting().Proxy
resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{ resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
"ids": taskIds, "ids": taskIds,
}) }, proxy)
if err != nil { if err != nil {
common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err)) common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
return err return err

View File

@@ -67,6 +67,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
if channel.GetBaseURL() != "" { if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL() baseURL = channel.GetBaseURL()
} }
proxy := channel.GetSetting().Proxy
task := taskM[taskId] task := taskM[taskId]
if task == nil { if task == nil {
@@ -76,7 +77,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{ resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
"task_id": taskId, "task_id": taskId,
"action": task.Action, "action": task.Action,
}) }, proxy)
if err != nil { if err != nil {
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err) return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
} }

View File

@@ -1,6 +1,7 @@
package controller package controller
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@@ -10,6 +11,7 @@ import (
"github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -75,11 +77,22 @@ func VideoProxy(c *gin.Context) {
} }
var videoURL string var videoURL string
client := &http.Client{ proxy := channel.GetSetting().Proxy
Timeout: 60 * time.Second, client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error()))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to create proxy client",
"type": "server_error",
},
})
return
} }
req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, "", nil) ctx, cancel := context.WithTimeout(c.Request.Context(), 60*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil)
if err != nil { if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error())) logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error()))
c.JSON(http.StatusInternalServerError, gin.H{ c.JSON(http.StatusInternalServerError, gin.H{

View File

@@ -35,10 +35,11 @@ func getGeminiVideoURL(channel *model.Channel, task *model.Task, apiKey string)
return "", fmt.Errorf("api key not available for task") return "", fmt.Errorf("api key not available for task")
} }
proxy := channel.GetSetting().Proxy
resp, err := adaptor.FetchTask(baseURL, apiKey, map[string]any{ resp, err := adaptor.FetchTask(baseURL, apiKey, map[string]any{
"task_id": task.TaskID, "task_id": task.TaskID,
"action": task.Action, "action": task.Action,
}) }, proxy)
if err != nil { if err != nil {
return "", fmt.Errorf("fetch task failed: %w", err) return "", fmt.Errorf("fetch task failed: %w", err)
} }

View File

@@ -47,7 +47,7 @@ type TaskAdaptor interface {
GetChannelName() string GetChannelName() string
// FetchTask // FetchTask
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
} }

View File

@@ -393,7 +393,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
} }
// FetchTask 查询任务状态 // FetchTask 查询任务状态
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
taskID, ok := body["task_id"].(string) taskID, ok := body["task_id"].(string)
if !ok { if !ok {
return nil, fmt.Errorf("invalid task_id") return nil, fmt.Errorf("invalid task_id")
@@ -408,7 +408,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Authorization", "Bearer "+key) req.Header.Set("Authorization", "Bearer "+key)
return service.GetHttpClient().Do(req) client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return client.Do(req)
} }
func (a *TaskAdaptor) GetModelList() []string { func (a *TaskAdaptor) GetModelList() []string {

View File

@@ -146,7 +146,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
} }
// FetchTask fetch task status // FetchTask fetch task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
taskID, ok := body["task_id"].(string) taskID, ok := body["task_id"].(string)
if !ok { if !ok {
return nil, fmt.Errorf("invalid task_id") return nil, fmt.Errorf("invalid task_id")
@@ -163,7 +163,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+key) req.Header.Set("Authorization", "Bearer "+key)
return service.GetHttpClient().Do(req) client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return client.Do(req)
} }
func (a *TaskAdaptor) GetModelList() []string { func (a *TaskAdaptor) GetModelList() []string {

View File

@@ -200,7 +200,7 @@ func (a *TaskAdaptor) GetChannelName() string {
} }
// FetchTask fetch task status // FetchTask fetch task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
taskID, ok := body["task_id"].(string) taskID, ok := body["task_id"].(string)
if !ok { if !ok {
return nil, fmt.Errorf("invalid task_id") return nil, fmt.Errorf("invalid task_id")
@@ -223,7 +223,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Accept", "application/json") req.Header.Set("Accept", "application/json")
req.Header.Set("x-goog-api-key", key) req.Header.Set("x-goog-api-key", key)
return service.GetHttpClient().Do(req) client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return client.Do(req)
} }
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {

View File

@@ -110,7 +110,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
return hResp.TaskID, responseBody, nil return hResp.TaskID, responseBody, nil
} }
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
taskID, ok := body["task_id"].(string) taskID, ok := body["task_id"].(string)
if !ok { if !ok {
return nil, fmt.Errorf("invalid task_id") return nil, fmt.Errorf("invalid task_id")
@@ -126,7 +126,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Accept", "application/json") req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+key) req.Header.Set("Authorization", "Bearer "+key)
return service.GetHttpClient().Do(req) client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return client.Do(req)
} }
func (a *TaskAdaptor) GetModelList() []string { func (a *TaskAdaptor) GetModelList() []string {

View File

@@ -210,7 +210,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
} }
// FetchTask fetch task status // FetchTask fetch task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
taskID, ok := body["task_id"].(string) taskID, ok := body["task_id"].(string)
if !ok { if !ok {
return nil, fmt.Errorf("invalid task_id") return nil, fmt.Errorf("invalid task_id")
@@ -251,7 +251,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
return nil, errors.Wrap(err, "sign request failed") return nil, errors.Wrap(err, "sign request failed")
} }
} }
return service.GetHttpClient().Do(req) client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return client.Do(req)
} }
func (a *TaskAdaptor) GetModelList() []string { func (a *TaskAdaptor) GetModelList() []string {

View File

@@ -199,7 +199,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
} }
// FetchTask fetch task status // FetchTask fetch task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
taskID, ok := body["task_id"].(string) taskID, ok := body["task_id"].(string)
if !ok { if !ok {
return nil, fmt.Errorf("invalid task_id") return nil, fmt.Errorf("invalid task_id")
@@ -228,7 +228,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "kling-sdk/1.0") req.Header.Set("User-Agent", "kling-sdk/1.0")
return service.GetHttpClient().Do(req) client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return client.Do(req)
} }
func (a *TaskAdaptor) GetModelList() []string { func (a *TaskAdaptor) GetModelList() []string {

View File

@@ -125,7 +125,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco
} }
// FetchTask fetch task status // FetchTask fetch task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
taskID, ok := body["task_id"].(string) taskID, ok := body["task_id"].(string)
if !ok { if !ok {
return nil, fmt.Errorf("invalid task_id") return nil, fmt.Errorf("invalid task_id")
@@ -140,7 +140,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Authorization", "Bearer "+key) req.Header.Set("Authorization", "Bearer "+key)
return service.GetHttpClient().Do(req) client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return client.Do(req)
} }
func (a *TaskAdaptor) GetModelList() []string { func (a *TaskAdaptor) GetModelList() []string {

View File

@@ -132,7 +132,7 @@ func (a *TaskAdaptor) GetChannelName() string {
return ChannelName return ChannelName
} }
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl) requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
byteBody, err := json.Marshal(body) byteBody, err := json.Marshal(body)
if err != nil { if err != nil {
@@ -153,11 +153,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req = req.WithContext(ctx) req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+key) req.Header.Set("Authorization", "Bearer "+key)
resp, err := service.GetHttpClient().Do(req) client, err := service.GetHttpClientWithProxy(proxy)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("new proxy http client failed: %w", err)
} }
return resp, nil return client.Do(req)
} }
func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) { func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) {

View File

@@ -120,7 +120,11 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
return fmt.Errorf("failed to decode credentials: %w", err) return fmt.Errorf("failed to decode credentials: %w", err)
} }
token, err := vertexcore.AcquireAccessToken(*adc, "") proxy := ""
if info != nil {
proxy = info.ChannelSetting.Proxy
}
token, err := vertexcore.AcquireAccessToken(*adc, proxy)
if err != nil { if err != nil {
return fmt.Errorf("failed to acquire access token: %w", err) return fmt.Errorf("failed to acquire access token: %w", err)
} }
@@ -216,7 +220,7 @@ func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generat
func (a *TaskAdaptor) GetChannelName() string { return "vertex" } func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
// FetchTask fetch task status // FetchTask fetch task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
taskID, ok := body["task_id"].(string) taskID, ok := body["task_id"].(string)
if !ok { if !ok {
return nil, fmt.Errorf("invalid task_id") return nil, fmt.Errorf("invalid task_id")
@@ -249,7 +253,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
if err := json.Unmarshal([]byte(key), adc); err != nil { if err := json.Unmarshal([]byte(key), adc); err != nil {
return nil, fmt.Errorf("failed to decode credentials: %w", err) return nil, fmt.Errorf("failed to decode credentials: %w", err)
} }
token, err := vertexcore.AcquireAccessToken(*adc, "") token, err := vertexcore.AcquireAccessToken(*adc, proxy)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to acquire access token: %w", err) return nil, fmt.Errorf("failed to acquire access token: %w", err)
} }
@@ -261,7 +265,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Accept", "application/json") req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("x-goog-user-project", adc.ProjectID) req.Header.Set("x-goog-user-project", adc.ProjectID)
return service.GetHttpClient().Do(req) client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return client.Do(req)
} }
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {

View File

@@ -188,7 +188,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
return vResp.TaskId, responseBody, nil return vResp.TaskId, responseBody, nil
} }
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
taskID, ok := body["task_id"].(string) taskID, ok := body["task_id"].(string)
if !ok { if !ok {
return nil, fmt.Errorf("invalid task_id") return nil, fmt.Errorf("invalid task_id")
@@ -204,7 +204,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Accept", "application/json") req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Token "+key) req.Header.Set("Authorization", "Token "+key)
return service.GetHttpClient().Do(req) client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return client.Do(req)
} }
func (a *TaskAdaptor) GetModelList() []string { func (a *TaskAdaptor) GetModelList() []string {

View File

@@ -326,6 +326,7 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
if channelModel.GetBaseURL() != "" { if channelModel.GetBaseURL() != "" {
baseURL = channelModel.GetBaseURL() baseURL = channelModel.GetBaseURL()
} }
proxy := channelModel.GetSetting().Proxy
adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type))) adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
if adaptor == nil { if adaptor == nil {
return return
@@ -333,7 +334,7 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{ resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
"task_id": originTask.TaskID, "task_id": originTask.TaskID,
"action": originTask.Action, "action": originTask.Action,
}) }, proxy)
if err2 != nil || resp == nil { if err2 != nil || resp == nil {
return return
} }

View File

@@ -35,9 +35,9 @@ func checkRedirect(req *http.Request, via []*http.Request) error {
func InitHttpClient() { func InitHttpClient() {
transport := &http.Transport{ transport := &http.Transport{
MaxIdleConns: common.RelayMaxIdleConns, MaxIdleConns: common.RelayMaxIdleConns,
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost, MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
ForceAttemptHTTP2: true, ForceAttemptHTTP2: true,
} }
if common.RelayTimeout == 0 { if common.RelayTimeout == 0 {
@@ -58,6 +58,14 @@ func GetHttpClient() *http.Client {
return httpClient return httpClient
} }
// GetHttpClientWithProxy returns the default client or a proxy-enabled one when proxyURL is provided.
func GetHttpClientWithProxy(proxyURL string) (*http.Client, error) {
if proxyURL == "" {
return GetHttpClient(), nil
}
return NewProxyHttpClient(proxyURL)
}
// ResetProxyClientCache 清空代理客户端缓存,确保下次使用时重新初始化 // ResetProxyClientCache 清空代理客户端缓存,确保下次使用时重新初始化
func ResetProxyClientCache() { func ResetProxyClientCache() {
proxyClientLock.Lock() proxyClientLock.Lock()
@@ -92,10 +100,10 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
case "http", "https": case "http", "https":
client := &http.Client{ client := &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
MaxIdleConns: common.RelayMaxIdleConns, MaxIdleConns: common.RelayMaxIdleConns,
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost, MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
ForceAttemptHTTP2: true, ForceAttemptHTTP2: true,
Proxy: http.ProxyURL(parsedURL), Proxy: http.ProxyURL(parsedURL),
}, },
CheckRedirect: checkRedirect, CheckRedirect: checkRedirect,
} }
@@ -127,9 +135,9 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
client := &http.Client{ client := &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
MaxIdleConns: common.RelayMaxIdleConns, MaxIdleConns: common.RelayMaxIdleConns,
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost, MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
ForceAttemptHTTP2: true, ForceAttemptHTTP2: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr) return dialer.Dial(network, addr)
}, },