diff --git a/controller/relay.go b/controller/relay.go index c055ef71e..d3d93192e 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -3,7 +3,6 @@ package controller import ( "bytes" "fmt" - "github.com/bytedance/gopkg/util/gopool" "io" "log" "net/http" @@ -22,6 +21,8 @@ import ( "one-api/types" "strings" + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" "github.com/gorilla/websocket" ) @@ -383,11 +384,14 @@ func RelayNotFound(c *gin.Context) { func RelayTask(c *gin.Context) { retryTimes := common.RetryTimes channelId := c.GetInt("channel_id") - relayMode := c.GetInt("relay_mode") group := c.GetString("group") originalModel := c.GetString("original_model") c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) - taskErr := taskRelayHandler(c, relayMode) + relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil) + if err != nil { + return + } + taskErr := taskRelayHandler(c, relayInfo) if taskErr == nil { retryTimes = 0 } @@ -407,7 +411,7 @@ func RelayTask(c *gin.Context) { requestBody, _ := common.GetRequestBody(c) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - taskErr = taskRelayHandler(c, relayMode) + taskErr = taskRelayHandler(c, relayInfo) } useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { @@ -422,13 +426,13 @@ func RelayTask(c *gin.Context) { } } -func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError { +func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError { var err *dto.TaskError - switch relayMode { + switch relayInfo.RelayMode { case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID: - err = relay.RelayTaskFetch(c, relayMode) + err = relay.RelayTaskFetch(c, relayInfo.RelayMode) default: - err = relay.RelayTaskSubmit(c, relayMode) + err = relay.RelayTaskSubmit(c, relayInfo) } return err } diff --git a/model/task.go b/model/task.go index 9e4177ba0..4c64a5293 100644 --- a/model/task.go +++ b/model/task.go @@ -77,7 +77,7 @@ type SyncTaskQueryParams struct { UserIDs []int } -func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.TaskRelayInfo) *Task { +func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo) *Task { t := &Task{ UserId: relayInfo.UserId, SubmitTime: time.Now().Unix(), diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index ec7491334..02de99567 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -30,16 +30,16 @@ type Adaptor interface { } type TaskAdaptor interface { - Init(info *relaycommon.TaskRelayInfo) + Init(info *relaycommon.RelayInfo) - ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError + ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError - BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) - BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error - BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) + BuildRequestURL(info *relaycommon.RelayInfo) (string, error) + BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error + BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) - DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) - DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, err *dto.TaskError) + DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) + DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, err *dto.TaskError) GetModelList() []string GetChannelName() string diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index 518d25cea..a50d5bdb5 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -277,7 +277,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http return resp, nil } -func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { +func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) { fullRequestURL, err := a.BuildRequestURL(info) if err != nil { return nil, err @@ -294,7 +294,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, if err != nil { return nil, fmt.Errorf("setup request header failed: %w", err) } - resp, err := doRequest(c, req, info.RelayInfo) + resp, err := doRequest(c, req, info) if err != nil { return nil, fmt.Errorf("do request failed: %w", err) } diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index a5ada1370..955e592a2 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -74,7 +74,7 @@ type TaskAdaptor struct { baseURL string } -func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType a.baseURL = info.ChannelBaseUrl @@ -87,7 +87,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { } // ValidateRequestAndSetAction parses body, validates fields and sets default action. -func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { // Accept only POST /v1/video/generations as "generate" action. action := constant.TaskActionGenerate info.Action = action @@ -108,19 +108,19 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom } // BuildRequestURL constructs the upstream URL. -func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil } // BuildRequestHeader sets required headers. -func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") return a.signRequest(req, a.accessKey, a.secretKey) } // BuildRequestBody converts request into Jimeng specific format. -func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) { +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { v, exists := c.Get("task_request") if !exists { return nil, fmt.Errorf("request not found in context") @@ -139,12 +139,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel } // DoRequest delegates to common helper. -func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoTaskApiRequest(a, c, info, requestBody) } // DoResponse handles upstream response, returns taskID etc. -func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 1fecda08a..f52f9db47 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -4,13 +4,14 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/samber/lo" "io" "net/http" "one-api/model" "strings" "time" + "github.com/samber/lo" + "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt" "github.com/pkg/errors" @@ -79,7 +80,7 @@ type TaskAdaptor struct { baseURL string } -func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType a.baseURL = info.ChannelBaseUrl a.apiKey = info.ApiKey @@ -88,7 +89,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { } // ValidateRequestAndSetAction parses body, validates fields and sets default action. -func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { // Accept only POST /v1/video/generations as "generate" action. action := constant.TaskActionGenerate info.Action = action @@ -109,13 +110,13 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom } // BuildRequestURL constructs the upstream URL. -func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video") return fmt.Sprintf("%s%s", a.baseURL, path), nil } // BuildRequestHeader sets required headers. -func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { token, err := a.createJWTToken() if err != nil { return fmt.Errorf("failed to create JWT token: %w", err) @@ -129,7 +130,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info } // BuildRequestBody converts request into Kling specific format. -func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) { +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { v, exists := c.Get("task_request") if !exists { return nil, fmt.Errorf("request not found in context") @@ -148,7 +149,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel } // DoRequest delegates to common helper. -func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { if action := c.GetString("action"); action != "" { info.Action = action } @@ -156,7 +157,7 @@ func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, } // DoResponse handles upstream response, returns taskID etc. -func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index df2bb99ea..237513d75 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" @@ -16,6 +15,8 @@ import ( "one-api/service" "strings" "time" + + "github.com/gin-gonic/gin" ) type TaskAdaptor struct { @@ -26,11 +27,11 @@ func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, fmt.Errorf("not implement") // todo implement this method if needed } -func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType } -func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { action := strings.ToUpper(c.Param("action")) var sunoRequest *dto.SunoSubmitReq @@ -58,20 +59,20 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom return nil } -func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { baseURL := info.ChannelBaseUrl fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action) return fullRequestURL, nil } -func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Authorization", "Bearer "+info.ApiKey) return nil } -func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) { +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { sunoRequest, ok := c.Get("task_request") if !ok { err := common.UnmarshalBodyReusable(c, &sunoRequest) @@ -86,11 +87,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel return bytes.NewReader(data), nil } -func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoTaskApiRequest(a, c, info, requestBody) } -func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index b0cc0bdc8..c82c1c0e8 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -84,12 +84,12 @@ type TaskAdaptor struct { baseURL string } -func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType a.baseURL = info.ChannelBaseUrl } -func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError { +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { var req SubmitReq if err := c.ShouldBindJSON(&req); err != nil { return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest) @@ -109,7 +109,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom return nil } -func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayInfo) (io.Reader, error) { +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) { v, exists := c.Get("task_request") if !exists { return nil, fmt.Errorf("request not found in context") @@ -132,7 +132,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayI return bytes.NewReader(data), nil } -func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { var path string switch info.Action { case constant.TaskActionGenerate: @@ -143,21 +143,21 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, return fmt.Sprintf("%s/ent/v2%s", a.baseURL, path), nil } -func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Token "+info.ApiKey) return nil } -func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { if action := c.GetString("action"); action != "" { info.Action = action } return channel.DoTaskApiRequest(a, c, info, requestBody) } -func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index caf8b452e..404610627 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -116,6 +116,7 @@ type RelayInfo struct { *RerankerInfo *ResponsesUsageInfo *ChannelMeta + *TaskRelayInfo } func (info *RelayInfo) InitChannelMeta(c *gin.Context) { @@ -400,6 +401,10 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { }, } + if info.RelayMode == relayconstant.RelayModeUnknown { + info.RelayMode = c.GetInt("relay_mode") + } + if strings.HasPrefix(c.Request.URL.Path, "/pg") { info.IsPlayground = true info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg") @@ -465,25 +470,12 @@ func (info *RelayInfo) HasSendResponse() bool { } type TaskRelayInfo struct { - *RelayInfo Action string OriginTaskID string ConsumeQuota bool } -func GenTaskRelayInfo(c *gin.Context) (*TaskRelayInfo, error) { - relayInfo, err := GenRelayInfo(c, types.RelayFormatTask, nil, nil) - if err != nil { - return nil, err - } - info := &TaskRelayInfo{ - RelayInfo: relayInfo, - } - info.InitChannelMeta(c) - return info, nil -} - type TaskSubmitReq struct { Prompt string `json:"prompt"` Model string `json:"model,omitempty"` diff --git a/relay/relay_task.go b/relay/relay_task.go index 95b8083b3..595ee7e28 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -22,31 +22,27 @@ import ( /* Task 任务通过平台、Action 区分任务 */ -func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { +func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { + info.InitChannelMeta(c) platform := constant.TaskPlatform(c.GetString("platform")) if platform == "" { platform = GetTaskPlatform(c) } - relayInfo, err := relaycommon.GenTaskRelayInfo(c) - if err != nil { - return service.TaskErrorWrapper(err, "gen_relay_info_failed", http.StatusInternalServerError) - } - adaptor := GetTaskAdaptor(platform) if adaptor == nil { return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest) } - adaptor.Init(relayInfo) + adaptor.Init(info) // get & validate taskRequest 获取并验证文本请求 - taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo) + taskErr = adaptor.ValidateRequestAndSetAction(c, info) if taskErr != nil { return } - modelName := relayInfo.OriginModelName + modelName := info.OriginModelName if modelName == "" { - modelName = service.CoverTaskActionToModelName(platform, relayInfo.Action) + modelName = service.CoverTaskActionToModelName(platform, info.Action) } modelPrice, success := ratio_setting.GetModelPrice(modelName, true) if !success { @@ -59,15 +55,15 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } // 预扣 - groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup) + groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup) var ratio float64 - userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup) + userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup) if hasUserGroupRatio { ratio = modelPrice * userGroupRatio } else { ratio = modelPrice * groupRatio } - userQuota, err := model.GetUserQuota(relayInfo.UserId, false) + userQuota, err := model.GetUserQuota(info.UserId, false) if err != nil { taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) return @@ -78,8 +74,8 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { return } - if relayInfo.OriginTaskID != "" { - originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID) + if info.OriginTaskID != "" { + originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID) if err != nil { taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) return @@ -88,7 +84,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) return } - if originTask.ChannelId != relayInfo.ChannelId { + if originTask.ChannelId != info.ChannelId { channel, err := model.GetChannelById(originTask.ChannelId, true) if err != nil { taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) @@ -101,19 +97,19 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { c.Set("channel_id", originTask.ChannelId) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - relayInfo.ChannelBaseUrl = channel.GetBaseURL() - relayInfo.ChannelId = originTask.ChannelId + info.ChannelBaseUrl = channel.GetBaseURL() + info.ChannelId = originTask.ChannelId } } // build body - requestBody, err := adaptor.BuildRequestBody(c, relayInfo) + requestBody, err := adaptor.BuildRequestBody(c, info) if err != nil { taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) return } // do request - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) return @@ -127,9 +123,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { defer func() { // release quota - if relayInfo.ConsumeQuota && taskErr == nil { + if info.ConsumeQuota && taskErr == nil { - err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true) + err := service.PostConsumeQuota(info, quota, 0, true) if err != nil { common.SysLog("error consuming token remain quota: " + err.Error()) } @@ -139,40 +135,40 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { if hasUserGroupRatio { gRatio = userGroupRatio } - logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, relayInfo.Action) + logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, info.Action) other := make(map[string]interface{}) other["model_price"] = modelPrice other["group_ratio"] = groupRatio if hasUserGroupRatio { other["user_group_ratio"] = userGroupRatio } - model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ - ChannelId: relayInfo.ChannelId, + model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ + ChannelId: info.ChannelId, ModelName: modelName, TokenName: tokenName, Quota: quota, Content: logContent, - TokenId: relayInfo.TokenId, - Group: relayInfo.UsingGroup, + TokenId: info.TokenId, + Group: info.UsingGroup, Other: other, }) - model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) - model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) + model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota) + model.UpdateChannelUsedQuota(info.ChannelId, quota) } } }() - taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo) + taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info) if taskErr != nil { return } - relayInfo.ConsumeQuota = true + info.ConsumeQuota = true // insert task - task := model.InitTask(platform, relayInfo) + task := model.InitTask(platform, info) task.TaskID = taskID task.Quota = quota task.Data = taskData - task.Action = relayInfo.Action + task.Action = info.Action err = task.Insert() if err != nil { taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)