Files
new-api/relay/relay_task.go

442 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package relay
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
relaycommon "github.com/QuantumNous/new-api/relay/common"
relayconstant "github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
// RelayTaskSubmit submits an asynchronous task to an upstream platform and
// bills the user for it. Tasks are distinguished by platform and action.
//
// The steps are, in order:
//  1. resolve the task platform and its adaptor, then let the adaptor
//     validate the request and set info.Action;
//  2. compute the quota to charge from the model price, the group ratio and
//     any extra ratios in info.PriceData.OtherRatios;
//  3. reject the request when the user's quota is lower than that charge;
//  4. when info.OriginTaskID is set, route the request to the channel that
//     owns the origin task;
//  5. build and send the upstream request;
//  6. on success, persist the task and (in a deferred step) consume the quota
//     and write the consume log.
//
// It returns nil on success, or a *dto.TaskError describing the failure.
func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
	info.InitChannelMeta(c)
	// Ensure TaskRelayInfo is initialized to avoid a nil dereference when
	// accessing its embedded fields.
	if info.TaskRelayInfo == nil {
		info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
	}
	platform := constant.TaskPlatform(c.GetString("platform"))
	if platform == "" {
		platform = GetTaskPlatform(c)
	}
	// NOTE(review): this second InitChannelMeta call is kept from the original;
	// it looks redundant — confirm the method is idempotent before removing it.
	info.InitChannelMeta(c)
	adaptor := GetTaskAdaptor(platform)
	if adaptor == nil {
		return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
	}
	adaptor.Init(info)

	// Parse and validate the task request; this also sets the action on info.
	if taskErr = adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil {
		return taskErr
	}

	modelName := info.OriginModelName
	if modelName == "" {
		modelName = service.CoverTaskActionToModelName(platform, info.Action)
	}

	// Price lookup order: configured price, then the default price map, then
	// a hard-coded fallback of 0.1.
	modelPrice, found := ratio_setting.GetModelPrice(modelName, true)
	if !found {
		modelPrice = 0.1
		if defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName]; ok {
			modelPrice = defaultPrice
		}
	}

	// Pre-deduction: work out the ratio that will be charged. A ratio specific
	// to the (user group, using group) pair replaces the plain group ratio.
	groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
	userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
	ratio := modelPrice * groupRatio
	if hasUserGroupRatio {
		ratio = modelPrice * userGroupRatio
	}

	// FIXME: temporary patch so that the models listed in
	// constant.TaskPricePatches are billed per call only, ignoring OtherRatios.
	perCallOnly := common.StringsContains(constant.TaskPricePatches, modelName)
	if !perCallOnly {
		for _, ra := range info.PriceData.OtherRatios {
			if ra != 1.0 {
				ratio *= ra
			}
		}
	}
	common.SysLog(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio))

	userQuota, err := model.GetUserQuota(info.UserId, false)
	if err != nil {
		return service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
	}
	quota := int(ratio * common.QuotaPerUnit)
	if userQuota < quota {
		return service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
	}

	// A follow-up action on an existing task must go to the channel that owns
	// the origin task, so re-point the request at that channel when it differs.
	if info.OriginTaskID != "" {
		originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
		if err != nil {
			return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
		}
		if !exist {
			return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
		}
		if originTask.ChannelId != info.ChannelId {
			// Named originChannel so it does not shadow the imported
			// "channel" package.
			originChannel, err := model.GetChannelById(originTask.ChannelId, true)
			if err != nil {
				return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
			}
			if originChannel.Status != common.ChannelStatusEnabled {
				return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
			}
			baseURL := originChannel.GetBaseURL()
			c.Set("base_url", baseURL)
			c.Set("channel_id", originTask.ChannelId)
			c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", originChannel.Key))
			info.ChannelBaseUrl = baseURL
			info.ChannelId = originTask.ChannelId
		}
	}

	// Build the upstream request body.
	requestBody, err := adaptor.BuildRequestBody(c, info)
	if err != nil {
		return service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
	}

	// Send the request upstream.
	resp, err := adaptor.DoRequest(c, info, requestBody)
	if err != nil {
		return service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
	}

	// Any non-200 upstream status is returned to the caller with the upstream
	// body as the error message.
	if resp != nil && resp.StatusCode != http.StatusOK {
		// Best effort: a read failure simply yields a shorter/empty message.
		responseBody, _ := io.ReadAll(resp.Body)
		// Close the body on this early-exit path so the connection is released.
		_ = resp.Body.Close()
		// errors.New rather than fmt.Errorf: the body is data, not a format
		// string, so '%' characters must be kept verbatim.
		return service.TaskErrorWrapper(errors.New(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
	}

	// Settle the quota once the function returns. The charge is applied only
	// when info.ConsumeQuota was set and no error is being returned.
	defer func() {
		if !info.ConsumeQuota || taskErr != nil {
			return
		}
		if err := service.PostConsumeQuota(info, quota, 0, true); err != nil {
			common.SysLog("error consuming token remain quota: " + err.Error())
		}
		if quota == 0 {
			return
		}

		tokenName := c.GetString("token_name")
		logContent := fmt.Sprintf("操作 %s", info.Action)
		// FIXME: temporary patch so that certain tasks are billed per call only.
		if perCallOnly {
			logContent = fmt.Sprintf("%s按次计费", logContent)
		} else {
			// List every extra ratio that actually changed the price.
			var contents []string
			for key, ra := range info.PriceData.OtherRatios {
				if ra != 1.0 {
					contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
				}
			}
			if len(contents) > 0 {
				logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
			}
		}

		other := make(map[string]any)
		if c != nil && c.Request != nil && c.Request.URL != nil {
			other["request_path"] = c.Request.URL.Path
		}
		other["model_price"] = modelPrice
		other["group_ratio"] = groupRatio
		if hasUserGroupRatio {
			other["user_group_ratio"] = userGroupRatio
		}
		model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
			ChannelId: info.ChannelId,
			ModelName: modelName,
			TokenName: tokenName,
			Quota:     quota,
			Content:   logContent,
			TokenId:   info.TokenId,
			Group:     info.UsingGroup,
			Other:     other,
		})
		model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
		model.UpdateChannelUsedQuota(info.ChannelId, quota)
	}()

	// This := sits in the function's outermost block, so it assigns to the
	// named result taskErr (read by the deferred func) instead of shadowing it.
	taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
	if taskErr != nil {
		return taskErr
	}
	info.ConsumeQuota = true

	// Persist the task record.
	task := model.InitTask(platform, info)
	task.TaskID = taskID
	task.Quota = quota
	task.Data = taskData
	task.Action = info.Action
	if err = task.Insert(); err != nil {
		return service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
	}
	return nil
}
// fetchRespBuilders maps a relay mode to the function that builds the
// response body for that kind of task fetch request. RelayTaskFetch looks up
// its builder here.
var fetchRespBuilders = map[int]func(*gin.Context) ([]byte, *dto.TaskError){
	relayconstant.RelayModeSunoFetch:      sunoFetchRespBodyBuilder,
	relayconstant.RelayModeSunoFetchByID:  sunoFetchByIDRespBodyBuilder,
	relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
}
// RelayTaskFetch answers a task fetch request for the given relay mode.
//
// It looks up the response builder registered for relayMode in
// fetchRespBuilders, runs it, and writes the resulting JSON body to the
// client. When the builder returns an empty body, a default
// {"code":"success","data":null} payload is written instead.
//
// It returns nil on success, or a *dto.TaskError when the relay mode is
// unknown, the builder fails, or the body cannot be written.
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
	respBuilder, ok := fetchRespBuilders[relayMode]
	if !ok {
		// Return immediately: calling the nil builder below would panic.
		return service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
	}

	respBody, taskErr := respBuilder(c)
	if taskErr != nil {
		return taskErr
	}
	if len(respBody) == 0 {
		respBody = []byte(`{"code":"success","data":null}`)
	}

	c.Writer.Header().Set("Content-Type", "application/json")
	if _, err := io.Copy(c.Writer, bytes.NewReader(respBody)); err != nil {
		return service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
	}
	return nil
}
// sunoFetchRespBodyBuilder builds the JSON response body for a batch task
// fetch.
//
// The request body is bound as JSON with an "ids" list and an "action" field
// (only "ids" is used here). The tasks matching those IDs for the current
// user (context key "id") are loaded and converted with TaskModel2Dto. The
// "data" field is always encoded as a JSON array, which is empty when no ID
// was supplied or no task matched.
//
// It returns the encoded body, or a *dto.TaskError when the request cannot be
// bound, the tasks cannot be loaded, or the response cannot be encoded.
func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
	userId := c.GetInt("id")
	var condition struct {
		IDs    []any  `json:"ids"`
		Action string `json:"action"`
	}
	if err := c.BindJSON(&condition); err != nil {
		return nil, service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
	}

	// Start from a non-nil slice so that "data" encodes as [] rather than null
	// on every path, including when the lookup finds nothing.
	tasks := make([]any, 0, len(condition.IDs))
	if len(condition.IDs) > 0 {
		taskModels, err := model.GetByTaskIds(userId, condition.IDs)
		if err != nil {
			return nil, service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
		}
		for _, task := range taskModels {
			tasks = append(tasks, TaskModel2Dto(task))
		}
	}

	body, err := json.Marshal(dto.TaskResponse[[]any]{
		Code: "success",
		Data: tasks,
	})
	if err != nil {
		// Report the failure instead of returning an empty body, which the
		// caller would otherwise replace with a "success" payload.
		return nil, service.TaskErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
	}
	return body, nil
}
// sunoFetchByIDRespBodyBuilder builds the JSON response body for fetching a
// single task by the "id" route parameter.
//
// The task is looked up for the current user (context key "id") and converted
// with TaskModel2Dto.
//
// It returns the encoded body, or a *dto.TaskError when the lookup fails, the
// task does not exist, or the response cannot be encoded.
func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
	taskId := c.Param("id")
	userId := c.GetInt("id")

	originTask, exist, err := model.GetByTaskId(userId, taskId)
	if err != nil {
		return nil, service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
	}
	if !exist {
		return nil, service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
	}

	body, err := json.Marshal(dto.TaskResponse[any]{
		Code: "success",
		Data: TaskModel2Dto(originTask),
	})
	if err != nil {
		// Report the failure instead of returning an empty body, which the
		// caller would otherwise replace with a "success" payload.
		return nil, service.TaskErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
	}
	return body, nil
}
// videoFetchByIDRespBodyBuilder builds the JSON response body for fetching a
// single video task.
//
// The task ID is read from the "task_id" route parameter, falling back to the
// "task_id" context value. The task is looked up for the current user
// (context key "id"). The body is then produced by the first of these that
// applies:
//  1. for a task on a Vertex AI channel, the task is refreshed from upstream
//     (best effort) and a status/url summary is returned;
//  2. for request URIs under /v1/videos/, the adaptor's OpenAI video
//     conversion is returned;
//  3. otherwise the stored task is returned via TaskModel2Dto.
//
// It returns the encoded body, or a *dto.TaskError on failure.
func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
	taskId := c.Param("task_id")
	if taskId == "" {
		taskId = c.GetString("task_id")
	}
	userId := c.GetInt("id")

	originTask, exist, err := model.GetByTaskId(userId, taskId)
	if err != nil {
		return nil, service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
	}
	if !exist {
		return nil, service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
	}

	// Best-effort live refresh for Vertex AI channels. Every failure in here
	// returns quietly with respBody left empty, so the code below falls back
	// to the stored task.
	func() {
		channelModel, chErr := model.GetChannelById(originTask.ChannelId, true)
		if chErr != nil || channelModel.Type != constant.ChannelTypeVertexAi {
			return
		}
		baseURL := constant.ChannelBaseURLs[channelModel.Type]
		if custom := channelModel.GetBaseURL(); custom != "" {
			baseURL = custom
		}
		adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
		if adaptor == nil {
			return
		}

		resp, fetchErr := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
			"task_id": originTask.TaskID,
			"action":  originTask.Action,
		})
		if fetchErr != nil || resp == nil {
			return
		}
		defer resp.Body.Close()

		body, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			return
		}
		ti, parseErr := adaptor.ParseTaskResult(body)
		if parseErr != nil || ti == nil {
			return
		}

		// Copy over only the fields the upstream result actually filled in.
		if ti.Status != "" {
			originTask.Status = model.TaskStatus(ti.Status)
		}
		if ti.Progress != "" {
			originTask.Progress = ti.Progress
		}
		if ti.Url != "" {
			// The result URL is stored in FailReason and is emitted as "url"
			// in the summary below.
			originTask.FailReason = ti.Url
		}
		// Persisting the refresh is best effort: log a failure and carry on
		// with the in-memory values.
		if updateErr := originTask.Update(); updateErr != nil {
			common.SysLog("error updating video task: " + updateErr.Error())
		}

		// Use the mime type of the first returned video as the format unless
		// it mentions mp4, which is also the default.
		format := "mp4"
		var raw map[string]any
		// A decode failure leaves raw nil, in which case the default is kept.
		_ = json.Unmarshal(body, &raw)
		if respObj, ok := raw["response"].(map[string]any); ok {
			if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
				if v0, ok := vids[0].(map[string]any); ok {
					if mt, ok := v0["mimeType"].(string); ok && mt != "" && !strings.Contains(mt, "mp4") {
						format = mt
					}
				}
			}
		}

		// Map the stored task status onto the status names used in this
		// response; anything unlisted is reported as "processing".
		status := "processing"
		switch originTask.Status {
		case model.TaskStatusSuccess:
			status = "succeeded"
		case model.TaskStatusFailure:
			status = "failed"
		case model.TaskStatusQueued, model.TaskStatusSubmitted:
			status = "queued"
		}

		encoded, marshalErr := json.Marshal(dto.TaskResponse[any]{
			Code: "success",
			Data: map[string]any{
				"error":    nil,
				"format":   format,
				"metadata": nil,
				"status":   status,
				"task_id":  originTask.TaskID,
				"url":      originTask.FailReason,
			},
		})
		if marshalErr != nil {
			common.SysLog("error encoding video task response: " + marshalErr.Error())
			return
		}
		respBody = encoded
	}()
	if len(respBody) != 0 {
		return respBody, nil
	}

	// Requests under /v1/videos/ get the adaptor's OpenAI video representation.
	if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
		adaptor := GetTaskAdaptor(originTask.Platform)
		if adaptor == nil {
			return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
		}
		converter, ok := adaptor.(channel.OpenAIVideoConverter)
		if !ok {
			return nil, service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented)
		}
		openAIVideoData, convErr := converter.ConvertToOpenAIVideo(originTask)
		if convErr != nil {
			return nil, service.TaskErrorWrapper(convErr, "convert_to_openai_video_failed", http.StatusInternalServerError)
		}
		return openAIVideoData, nil
	}

	// Default: return the stored task as a generic task DTO.
	respBody, err = json.Marshal(dto.TaskResponse[any]{
		Code: "success",
		Data: TaskModel2Dto(originTask),
	})
	if err != nil {
		taskResp = service.TaskErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
	}
	return respBody, taskResp
}
// TaskModel2Dto copies the externally visible fields of a stored task into a
// newly allocated dto.TaskDto. The task status is converted to a plain string;
// all other fields are copied as they are. task must not be nil.
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
	var out dto.TaskDto

	// Identity and state.
	out.TaskID = task.TaskID
	out.Action = task.Action
	out.Status = string(task.Status)
	out.FailReason = task.FailReason
	out.Progress = task.Progress

	// Timestamps.
	out.SubmitTime = task.SubmitTime
	out.StartTime = task.StartTime
	out.FinishTime = task.FinishTime

	// Payload.
	out.Data = task.Data

	return &out
}