mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-29 23:10:35 +00:00
Restructure the task relay system for better separation of concerns: - Extract task billing into service/task_billing.go with unified settlement flow - Move task polling loop from controller to service/task_polling.go (supports Suno + video platforms) - Split RelayTask into fetch/submit paths with dedicated retry logic (taskSubmitWithRetry) - Add TaskDto, TaskResponse generics, and FetchReq to dto/task.go - Add taskcommon/helpers.go for shared task adaptor utilities - Remove controller/task_video.go (logic consolidated into service layer) - Update all task adaptors (ali, doubao, gemini, hailuo, jimeng, kling, sora, suno, vertex, vidu) - Simplify frontend task logs to use new TaskDto response format
222 lines
6.2 KiB
Go
222 lines
6.2 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
"github.com/QuantumNous/new-api/dto"
|
|
"github.com/QuantumNous/new-api/logger"
|
|
"github.com/QuantumNous/new-api/types"
|
|
)
|
|
|
|
func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse {
|
|
return &dto.MidjourneyResponse{
|
|
Code: code,
|
|
Description: desc,
|
|
}
|
|
}
|
|
|
|
func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *dto.MidjourneyResponseWithStatusCode {
|
|
return &dto.MidjourneyResponseWithStatusCode{
|
|
StatusCode: statusCode,
|
|
Response: *MidjourneyErrorWrapper(code, desc),
|
|
}
|
|
}
|
|
|
|
//// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
|
|
//func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
|
|
// text := err.Error()
|
|
// lowerText := strings.ToLower(text)
|
|
// if !strings.HasPrefix(lowerText, "get file base64 from url") && !strings.HasPrefix(lowerText, "mime type is not supported") {
|
|
// if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
|
// common.SysLog(fmt.Sprintf("error: %s", text))
|
|
// text = "请求上游地址失败"
|
|
// }
|
|
// }
|
|
// openAIError := dto.OpenAIError{
|
|
// Message: text,
|
|
// Type: "new_api_error",
|
|
// Code: code,
|
|
// }
|
|
// return &dto.OpenAIErrorWithStatusCode{
|
|
// Error: openAIError,
|
|
// StatusCode: statusCode,
|
|
// }
|
|
//}
|
|
//
|
|
//func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
|
|
// openaiErr := OpenAIErrorWrapper(err, code, statusCode)
|
|
// openaiErr.LocalError = true
|
|
// return openaiErr
|
|
//}
|
|
|
|
func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
|
|
text := err.Error()
|
|
lowerText := strings.ToLower(text)
|
|
if !strings.HasPrefix(lowerText, "get file base64 from url") {
|
|
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
|
common.SysLog(fmt.Sprintf("error: %s", text))
|
|
text = "请求上游地址失败"
|
|
}
|
|
}
|
|
claudeError := types.ClaudeError{
|
|
Message: text,
|
|
Type: "new_api_error",
|
|
}
|
|
return &dto.ClaudeErrorWithStatusCode{
|
|
Error: claudeError,
|
|
StatusCode: statusCode,
|
|
}
|
|
}
|
|
|
|
func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
|
|
claudeErr := ClaudeErrorWrapper(err, code, statusCode)
|
|
claudeErr.LocalError = true
|
|
return claudeErr
|
|
}
|
|
|
|
func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
|
|
newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
|
|
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return
|
|
}
|
|
CloseResponseBodyGracefully(resp)
|
|
var errResponse dto.GeneralErrorResponse
|
|
buildErrWithBody := func(message string) error {
|
|
if message == "" {
|
|
return fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
|
|
}
|
|
return fmt.Errorf("bad response status code %d, message: %s, body: %s", resp.StatusCode, message, string(responseBody))
|
|
}
|
|
|
|
err = common.Unmarshal(responseBody, &errResponse)
|
|
if err != nil {
|
|
if showBodyWhenFail {
|
|
newApiErr.Err = buildErrWithBody("")
|
|
} else {
|
|
logger.LogError(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
|
|
newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
|
|
}
|
|
return
|
|
}
|
|
|
|
if common.GetJsonType(errResponse.Error) == "object" {
|
|
// General format error (OpenAI, Anthropic, Gemini, etc.)
|
|
oaiError := errResponse.TryToOpenAIError()
|
|
if oaiError != nil {
|
|
newApiErr = types.WithOpenAIError(*oaiError, resp.StatusCode)
|
|
if showBodyWhenFail {
|
|
newApiErr.Err = buildErrWithBody(newApiErr.Error())
|
|
}
|
|
return
|
|
}
|
|
}
|
|
newApiErr = types.NewOpenAIError(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
|
|
if showBodyWhenFail {
|
|
newApiErr.Err = buildErrWithBody(newApiErr.Error())
|
|
}
|
|
return
|
|
}
|
|
|
|
func ResetStatusCode(newApiErr *types.NewAPIError, statusCodeMappingStr string) {
|
|
if newApiErr == nil {
|
|
return
|
|
}
|
|
if statusCodeMappingStr == "" || statusCodeMappingStr == "{}" {
|
|
return
|
|
}
|
|
statusCodeMapping := make(map[string]any)
|
|
err := common.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if newApiErr.StatusCode == http.StatusOK {
|
|
return
|
|
}
|
|
codeStr := strconv.Itoa(newApiErr.StatusCode)
|
|
if value, ok := statusCodeMapping[codeStr]; ok {
|
|
intCode, ok := parseStatusCodeMappingValue(value)
|
|
if !ok {
|
|
return
|
|
}
|
|
newApiErr.StatusCode = intCode
|
|
}
|
|
}
|
|
|
|
func parseStatusCodeMappingValue(value any) (int, bool) {
|
|
switch v := value.(type) {
|
|
case string:
|
|
if v == "" {
|
|
return 0, false
|
|
}
|
|
statusCode, err := strconv.Atoi(v)
|
|
if err != nil {
|
|
return 0, false
|
|
}
|
|
return statusCode, true
|
|
case float64:
|
|
if v != math.Trunc(v) {
|
|
return 0, false
|
|
}
|
|
return int(v), true
|
|
case int:
|
|
return v, true
|
|
case json.Number:
|
|
statusCode, err := strconv.Atoi(v.String())
|
|
if err != nil {
|
|
return 0, false
|
|
}
|
|
return statusCode, true
|
|
default:
|
|
return 0, false
|
|
}
|
|
}
|
|
|
|
func TaskErrorWrapperLocal(err error, code string, statusCode int) *dto.TaskError {
|
|
openaiErr := TaskErrorWrapper(err, code, statusCode)
|
|
openaiErr.LocalError = true
|
|
return openaiErr
|
|
}
|
|
|
|
func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
|
|
text := err.Error()
|
|
lowerText := strings.ToLower(text)
|
|
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
|
common.SysLog(fmt.Sprintf("error: %s", text))
|
|
//text = "请求上游地址失败"
|
|
text = common.MaskSensitiveInfo(text)
|
|
}
|
|
//避免暴露内部错误
|
|
taskError := &dto.TaskError{
|
|
Code: code,
|
|
Message: text,
|
|
StatusCode: statusCode,
|
|
Error: err,
|
|
}
|
|
|
|
return taskError
|
|
}
|
|
|
|
// TaskErrorFromAPIError 将 PreConsumeBilling 返回的 NewAPIError 转换为 TaskError。
|
|
func TaskErrorFromAPIError(apiErr *types.NewAPIError) *dto.TaskError {
|
|
if apiErr == nil {
|
|
return nil
|
|
}
|
|
return &dto.TaskError{
|
|
Code: string(apiErr.GetErrorCode()),
|
|
Message: apiErr.Err.Error(),
|
|
StatusCode: apiErr.StatusCode,
|
|
Error: apiErr.Err,
|
|
}
|
|
}
|