mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:25:00 +00:00
Merge branch 'QuantumNous:main' into fix-claude-haiku
This commit is contained in:
@@ -67,8 +67,11 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
|
||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ type TaskAdaptor interface {
|
||||
GetChannelName() string
|
||||
|
||||
// FetchTask
|
||||
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
|
||||
FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
|
||||
|
||||
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ var awsModelIDMap = map[string]string{
|
||||
"claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
"claude-sonnet-4-5-20250929": "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"claude-haiku-4-5-20251001": "anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"claude-opus-4-5-20251101": "anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"claude-opus-4-5-20251101": "anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
// Nova models
|
||||
"nova-micro-v1:0": "amazon.nova-micro-v1:0",
|
||||
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
@@ -129,7 +130,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
awsReq.Body, err = buildAwsRequestBody(c, info, awsClaudeReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
@@ -141,7 +142,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
awsReq.Body, err = buildAwsRequestBody(c, info, awsClaudeReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
@@ -151,6 +152,24 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
|
||||
}
|
||||
}
|
||||
|
||||
// buildAwsRequestBody prepares the payload for AWS requests, applying passthrough rules when enabled.
|
||||
func buildAwsRequestBody(c *gin.Context, info *relaycommon.RelayInfo, awsClaudeReq any) ([]byte, error) {
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get request body for pass-through fail")
|
||||
}
|
||||
var data map[string]interface{}
|
||||
if err := common.Unmarshal(body, &data); err != nil {
|
||||
return nil, errors.Wrap(err, "pass-through unmarshal request body fail")
|
||||
}
|
||||
delete(data, "model")
|
||||
delete(data, "stream")
|
||||
return common.Marshal(data)
|
||||
}
|
||||
return common.Marshal(awsClaudeReq)
|
||||
}
|
||||
|
||||
func getAwsRegionPrefix(awsRegionId string) string {
|
||||
parts := strings.Split(awsRegionId, "-")
|
||||
regionPrefix := ""
|
||||
|
||||
@@ -150,7 +150,7 @@ func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
|
||||
return types.NewError(fmt.Errorf("%s", baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
@@ -175,7 +175,7 @@ func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
|
||||
return types.NewError(fmt.Errorf("%s", baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
|
||||
@@ -9,6 +9,7 @@ var ModelList = []string{
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-3-5-haiku-20241022",
|
||||
"claude-haiku-4-5-20251001",
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
|
||||
@@ -483,9 +483,11 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
|
||||
}
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
|
||||
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
}
|
||||
//claudeUsage = &claudeResponse.Usage
|
||||
} else if claudeResponse.Type == "message_stop" {
|
||||
|
||||
@@ -208,7 +208,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
|
||||
return
|
||||
}
|
||||
|
||||
common.SysLog(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
|
||||
common.SysLog(fmt.Sprintf("stream event error: %v %v", errorData.Code, errorData.Message))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/setting/reasoning"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -137,7 +138,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||
} else if baseModel, level := parseThinkingLevelSuffix(info.UpstreamModelName); level != "" {
|
||||
} else if baseModel, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
|
||||
info.UpstreamModelName = baseModel
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,10 +94,10 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
return geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
|
||||
// 直接发送 GeminiChatResponse 响应
|
||||
err := helper.StringData(c, data)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
logger.LogError(c, "failed to write stream data: "+err.Error())
|
||||
return false
|
||||
}
|
||||
info.SendResponseCount++
|
||||
return true
|
||||
|
||||
@@ -98,6 +98,7 @@ func clampThinkingBudget(modelName string, budget int) int {
|
||||
// "effort": "high" - Allocates a large portion of tokens for reasoning (approximately 80% of max_tokens)
|
||||
// "effort": "medium" - Allocates a moderate portion of tokens (approximately 50% of max_tokens)
|
||||
// "effort": "low" - Allocates a smaller portion of tokens (approximately 20% of max_tokens)
|
||||
// "effort": "minimal" - Allocates a minimal portion of tokens (approximately 5% of max_tokens)
|
||||
func clampThinkingBudgetByEffort(modelName string, effort string) int {
|
||||
isNew25Pro := isNew25ProModel(modelName)
|
||||
is25FlashLite := is25FlashLiteModel(modelName)
|
||||
@@ -118,18 +119,12 @@ func clampThinkingBudgetByEffort(modelName string, effort string) int {
|
||||
maxBudget = maxBudget * 50 / 100
|
||||
case "low":
|
||||
maxBudget = maxBudget * 20 / 100
|
||||
case "minimal":
|
||||
maxBudget = maxBudget * 5 / 100
|
||||
}
|
||||
return clampThinkingBudget(modelName, maxBudget)
|
||||
}
|
||||
|
||||
func parseThinkingLevelSuffix(modelName string) (string, string) {
|
||||
base, level, ok := reasoning.TrimEffortSuffix(modelName)
|
||||
if !ok {
|
||||
return modelName, ""
|
||||
}
|
||||
return base, level
|
||||
}
|
||||
|
||||
func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
modelName := info.UpstreamModelName
|
||||
@@ -186,7 +181,7 @@ func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.Rel
|
||||
ThinkingBudget: common.GetPointer(0),
|
||||
}
|
||||
}
|
||||
} else if _, level := parseThinkingLevelSuffix(modelName); level != "" {
|
||||
} else if _, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
ThinkingLevel: level,
|
||||
|
||||
@@ -42,7 +42,7 @@ type Adaptor struct {
|
||||
// support OAI models: o1-mini/o3-mini/o4-mini/o1/o3 etc...
|
||||
// minimal effort only available in gpt-5
|
||||
func parseReasoningEffortFromModelSuffix(model string) (string, string) {
|
||||
effortSuffixes := []string{"-high", "-minimal", "-low", "-medium", "-none"}
|
||||
effortSuffixes := []string{"-high", "-minimal", "-low", "-medium", "-none", "-xhigh"}
|
||||
for _, suffix := range effortSuffixes {
|
||||
if strings.HasSuffix(model, suffix) {
|
||||
effort := strings.TrimPrefix(suffix, "-")
|
||||
@@ -306,10 +306,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
request.Temperature = nil
|
||||
}
|
||||
|
||||
// gpt-5系列模型适配 归零不再支持的参数
|
||||
if strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
|
||||
if info.UpstreamModelName != "gpt-5-chat-latest" {
|
||||
request.Temperature = nil
|
||||
}
|
||||
request.Temperature = nil
|
||||
request.TopP = 0 // oai 的 top_p 默认值是 1.0,但是为了 omitempty 属性直接不传,这里显式设置为 0
|
||||
request.LogProbs = false
|
||||
}
|
||||
|
||||
// 转换模型推理力度后缀
|
||||
|
||||
145
relay/channel/openai/audio.go
Normal file
145
relay/channel/openai/audio.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
|
||||
// the status code has been judged before, if there is a body reading failure,
|
||||
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
|
||||
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
|
||||
// if the upstream returns a specific status code, once the upstream has already written the header,
|
||||
// the subsequent failure of the response body should be regarded as a non-recoverable error,
|
||||
// and can be terminated directly.
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.GetEstimatePromptTokens()
|
||||
usage.TotalTokens = info.GetEstimatePromptTokens()
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
if info.IsStream {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
if service.SundaySearch(data, "usage") {
|
||||
var simpleResponse dto.SimpleResponse
|
||||
err := common.Unmarshal([]byte(data), &simpleResponse)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
if simpleResponse.Usage.TotalTokens != 0 {
|
||||
usage.PromptTokens = simpleResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = simpleResponse.OutputTokens
|
||||
usage.TotalTokens = simpleResponse.TotalTokens
|
||||
}
|
||||
}
|
||||
_ = helper.StringData(c, data)
|
||||
return true
|
||||
})
|
||||
} else {
|
||||
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
|
||||
// 读取响应体到缓冲区
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
logger.LogError(c, fmt.Sprintf("failed to read TTS response body: %v", err))
|
||||
c.Writer.WriteHeaderNow()
|
||||
return usage
|
||||
}
|
||||
|
||||
// 写入响应到客户端
|
||||
c.Writer.WriteHeaderNow()
|
||||
_, err = c.Writer.Write(bodyBytes)
|
||||
if err != nil {
|
||||
logger.LogError(c, fmt.Sprintf("failed to write TTS response: %v", err))
|
||||
}
|
||||
|
||||
// 计算音频时长并更新 usage
|
||||
audioFormat := "mp3" // 默认格式
|
||||
if audioReq, ok := info.Request.(*dto.AudioRequest); ok && audioReq.ResponseFormat != "" {
|
||||
audioFormat = audioReq.ResponseFormat
|
||||
}
|
||||
|
||||
var duration float64
|
||||
var durationErr error
|
||||
|
||||
if audioFormat == "pcm" {
|
||||
// PCM 格式没有文件头,根据 OpenAI TTS 的 PCM 参数计算时长
|
||||
// 采样率: 24000 Hz, 位深度: 16-bit (2 bytes), 声道数: 1
|
||||
const sampleRate = 24000
|
||||
const bytesPerSample = 2
|
||||
const channels = 1
|
||||
duration = float64(len(bodyBytes)) / float64(sampleRate*bytesPerSample*channels)
|
||||
} else {
|
||||
ext := "." + audioFormat
|
||||
reader := bytes.NewReader(bodyBytes)
|
||||
duration, durationErr = common.GetAudioDuration(c.Request.Context(), reader, ext)
|
||||
}
|
||||
|
||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||
|
||||
if durationErr != nil {
|
||||
logger.LogWarn(c, fmt.Sprintf("failed to get audio duration: %v", durationErr))
|
||||
// 如果无法获取时长,则设置保底的 CompletionTokens,根据body大小计算
|
||||
sizeInKB := float64(len(bodyBytes)) / 1000.0
|
||||
estimatedTokens := int(math.Ceil(sizeInKB)) // 粗略估算每KB约等于1 token
|
||||
usage.CompletionTokens = estimatedTokens
|
||||
usage.CompletionTokenDetails.AudioTokens = estimatedTokens
|
||||
} else if duration > 0 {
|
||||
// 计算 token: ceil(duration) / 60.0 * 1000,即每分钟 1000 tokens
|
||||
completionTokens := int(math.Round(math.Ceil(duration) / 60.0 * 1000))
|
||||
usage.CompletionTokens = completionTokens
|
||||
usage.CompletionTokenDetails.AudioTokens = completionTokens
|
||||
}
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
}
|
||||
|
||||
return usage
|
||||
}
|
||||
|
||||
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
// 写入新的 response body
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
var responseData struct {
|
||||
Usage *dto.Usage `json:"usage"`
|
||||
}
|
||||
if err := common.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
|
||||
if responseData.Usage.TotalTokens > 0 {
|
||||
usage := responseData.Usage
|
||||
if usage.PromptTokens == 0 {
|
||||
usage.PromptTokens = usage.InputTokens
|
||||
}
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = usage.OutputTokens
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
}
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.GetEstimatePromptTokens()
|
||||
usage.CompletionTokens = 0
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return nil, usage
|
||||
}
|
||||
@@ -172,7 +172,7 @@ func handleLastResponse(lastStreamData string, responseId *string, createAt *int
|
||||
shouldSendLastResp *bool) error {
|
||||
|
||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
|
||||
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -151,7 +150,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
var streamResp struct {
|
||||
Usage *dto.Usage `json:"usage"`
|
||||
}
|
||||
err := json.Unmarshal([]byte(secondLastStreamData), &streamResp)
|
||||
err := common.Unmarshal([]byte(secondLastStreamData), &streamResp)
|
||||
if err == nil && streamResp.Usage != nil && service.ValidUsage(streamResp.Usage) {
|
||||
usage = streamResp.Usage
|
||||
containStreamUsage = true
|
||||
@@ -327,68 +326,6 @@ func streamTTSResponse(c *gin.Context, resp *http.Response) {
|
||||
}
|
||||
}
|
||||
|
||||
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
|
||||
// the status code has been judged before, if there is a body reading failure,
|
||||
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
|
||||
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
|
||||
// if the upstream returns a specific status code, once the upstream has already written the header,
|
||||
// the subsequent failure of the response body should be regarded as a non-recoverable error,
|
||||
// and can be terminated directly.
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.GetEstimatePromptTokens()
|
||||
usage.TotalTokens = info.GetEstimatePromptTokens()
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
isStreaming := resp.ContentLength == -1 || resp.Header.Get("Content-Length") == ""
|
||||
if isStreaming {
|
||||
streamTTSResponse(c, resp)
|
||||
} else {
|
||||
c.Writer.WriteHeaderNow()
|
||||
_, err := io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
// 写入新的 response body
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
var responseData struct {
|
||||
Usage *dto.Usage `json:"usage"`
|
||||
}
|
||||
if err := json.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
|
||||
if responseData.Usage.TotalTokens > 0 {
|
||||
usage := responseData.Usage
|
||||
if usage.PromptTokens == 0 {
|
||||
usage.PromptTokens = usage.InputTokens
|
||||
}
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = usage.OutputTokens
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
}
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.GetEstimatePromptTokens()
|
||||
usage.CompletionTokens = 0
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
|
||||
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
|
||||
return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
|
||||
@@ -659,7 +596,7 @@ func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, res
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
case constant.ChannelTypeZhipu_v4:
|
||||
case constant.ChannelTypeZhipu_v4, constant.ChannelTypeMoonshot:
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||
@@ -687,7 +624,7 @@ func extractCachedTokensFromBody(body []byte) (int, bool) {
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
|
||||
@@ -393,7 +393,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
// FetchTask 查询任务状态
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -408,7 +408,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -146,7 +146,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -163,7 +163,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -200,7 +200,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -223,7 +223,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("x-goog-api-key", key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
|
||||
@@ -110,7 +110,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return hResp.TaskID, responseBody, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -126,7 +126,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -196,7 +196,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
if jResp.Code != 10000 {
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -210,7 +210,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -251,7 +251,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
return nil, errors.Wrap(err, "sign request failed")
|
||||
}
|
||||
}
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -186,7 +186,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return
|
||||
}
|
||||
if kResp.Code != 0 {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf(kResp.Message), "task_failed", http.StatusBadRequest)
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("%s", kResp.Message), "task_failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
ov := dto.NewOpenAIVideo()
|
||||
@@ -199,7 +199,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -228,7 +228,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("User-Agent", "kling-sdk/1.0")
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -5,8 +5,10 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
@@ -67,11 +69,30 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.apiKey = info.ApiKey
|
||||
}
|
||||
|
||||
func validateRemixRequest(c *gin.Context) *dto.TaskError {
|
||||
var req struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
return validateRemixRequest(c)
|
||||
}
|
||||
return relaycommon.ValidateMultipartDirect(c, info)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/v1/videos", a.baseURL), nil
|
||||
}
|
||||
|
||||
@@ -125,7 +146,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -140,7 +161,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -105,7 +105,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return
|
||||
}
|
||||
if !sunoResponse.IsSuccess() {
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf(sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError)
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -132,7 +132,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
|
||||
byteBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
@@ -153,11 +153,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
resp, err := service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) {
|
||||
|
||||
@@ -120,7 +120,11 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
return fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, "")
|
||||
proxy := ""
|
||||
if info != nil {
|
||||
proxy = info.ChannelSetting.Proxy
|
||||
}
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, proxy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to acquire access token: %w", err)
|
||||
}
|
||||
@@ -216,7 +220,7 @@ func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generat
|
||||
func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -249,7 +253,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
if err := json.Unmarshal([]byte(key), adc); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, "")
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to acquire access token: %w", err)
|
||||
}
|
||||
@@ -261,7 +265,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("x-goog-user-project", adc.ProjectID)
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
|
||||
@@ -188,7 +188,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return vResp.TaskId, responseBody, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -204,7 +204,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Token "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/setting/reasoning"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -51,10 +52,43 @@ type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
// Vertex AI does not support functionResponse.id; keep it stripped here for consistency.
|
||||
if model_setting.GetGeminiSettings().RemoveFunctionResponseIdEnabled {
|
||||
removeFunctionResponseID(request)
|
||||
}
|
||||
geminiAdaptor := gemini.Adaptor{}
|
||||
return geminiAdaptor.ConvertGeminiRequest(c, info, request)
|
||||
}
|
||||
|
||||
func removeFunctionResponseID(request *dto.GeminiChatRequest) {
|
||||
if request == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(request.Contents) > 0 {
|
||||
for i := range request.Contents {
|
||||
if len(request.Contents[i].Parts) == 0 {
|
||||
continue
|
||||
}
|
||||
for j := range request.Contents[i].Parts {
|
||||
part := &request.Contents[i].Parts[j]
|
||||
if part.FunctionResponse == nil {
|
||||
continue
|
||||
}
|
||||
if len(part.FunctionResponse.ID) > 0 {
|
||||
part.FunctionResponse.ID = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(request.Requests) > 0 {
|
||||
for i := range request.Requests {
|
||||
removeFunctionResponseID(&request.Requests[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||
c.Set("request_model", v)
|
||||
@@ -182,6 +216,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||
} else if baseModel, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
|
||||
info.UpstreamModelName = baseModel
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -36,8 +36,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -63,6 +62,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/embeddings", specialPlan.OpenAIBaseURL), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/api/paas/v4/embeddings", baseURL), nil
|
||||
case relayconstant.RelayModeImagesGenerations:
|
||||
return fmt.Sprintf("%s/api/paas/v4/images/generations", baseURL), nil
|
||||
default:
|
||||
if hasSpecialPlan && specialPlan.OpenAIBaseURL != "" {
|
||||
return fmt.Sprintf("%s/chat/completions", specialPlan.OpenAIBaseURL), nil
|
||||
@@ -114,6 +115,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
}
|
||||
default:
|
||||
if info.RelayMode == relayconstant.RelayModeImagesGenerations {
|
||||
return zhipu4vImageHandler(c, resp, info)
|
||||
}
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
)
|
||||
|
||||
// type ZhipuMessage struct {
|
||||
@@ -37,7 +38,7 @@ type ZhipuV4Response struct {
|
||||
Model string `json:"model"`
|
||||
TextResponseChoices []dto.OpenAITextResponseChoice `json:"choices"`
|
||||
Usage dto.Usage `json:"usage"`
|
||||
Error dto.OpenAIError `json:"error"`
|
||||
Error types.OpenAIError `json:"error"`
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
127
relay/channel/zhipu_4v/image.go
Normal file
127
relay/channel/zhipu_4v/image.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package zhipu_4v
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type zhipuImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
WatermarkEnabled *bool `json:"watermark_enabled,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
type zhipuImageResponse struct {
|
||||
Created *int64 `json:"created,omitempty"`
|
||||
Data []zhipuImageData `json:"data,omitempty"`
|
||||
ContentFilter any `json:"content_filter,omitempty"`
|
||||
Usage *dto.Usage `json:"usage,omitempty"`
|
||||
Error *zhipuImageError `json:"error,omitempty"`
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
ExtendParam map[string]string `json:"extendParam,omitempty"`
|
||||
}
|
||||
|
||||
type zhipuImageError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type zhipuImageData struct {
|
||||
Url string `json:"url,omitempty"`
|
||||
ImageUrl string `json:"image_url,omitempty"`
|
||||
B64Json string `json:"b64_json,omitempty"`
|
||||
B64Image string `json:"b64_image,omitempty"`
|
||||
}
|
||||
|
||||
type openAIImagePayload struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []openAIImageData `json:"data"`
|
||||
}
|
||||
|
||||
type openAIImageData struct {
|
||||
B64Json string `json:"b64_json"`
|
||||
}
|
||||
|
||||
func zhipu4vImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
var zhipuResp zhipuImageResponse
|
||||
if err := common.Unmarshal(responseBody, &zhipuResp); err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if zhipuResp.Error != nil && zhipuResp.Error.Message != "" {
|
||||
return nil, types.WithOpenAIError(types.OpenAIError{
|
||||
Message: zhipuResp.Error.Message,
|
||||
Type: "zhipu_image_error",
|
||||
Code: zhipuResp.Error.Code,
|
||||
}, resp.StatusCode)
|
||||
}
|
||||
|
||||
payload := openAIImagePayload{}
|
||||
if zhipuResp.Created != nil && *zhipuResp.Created != 0 {
|
||||
payload.Created = *zhipuResp.Created
|
||||
} else {
|
||||
payload.Created = info.StartTime.Unix()
|
||||
}
|
||||
for _, data := range zhipuResp.Data {
|
||||
url := data.Url
|
||||
if url == "" {
|
||||
url = data.ImageUrl
|
||||
}
|
||||
if url == "" {
|
||||
logger.LogWarn(c, "zhipu_image_missing_url")
|
||||
continue
|
||||
}
|
||||
|
||||
var b64 string
|
||||
switch {
|
||||
case data.B64Json != "":
|
||||
b64 = data.B64Json
|
||||
case data.B64Image != "":
|
||||
b64 = data.B64Image
|
||||
default:
|
||||
_, downloaded, err := service.GetImageFromUrl(url)
|
||||
if err != nil {
|
||||
logger.LogError(c, "zhipu_image_get_b64_failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
b64 = downloaded
|
||||
}
|
||||
|
||||
if b64 == "" {
|
||||
logger.LogWarn(c, "zhipu_image_empty_b64")
|
||||
continue
|
||||
}
|
||||
|
||||
imageData := openAIImageData{
|
||||
B64Json: b64,
|
||||
}
|
||||
payload.Data = append(payload.Data, imageData)
|
||||
}
|
||||
|
||||
jsonResp, err := common.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResp)
|
||||
|
||||
return &dto.Usage{}, nil
|
||||
}
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
var negativeIndexRegexp = regexp.MustCompile(`\.(-\d+)`)
|
||||
|
||||
type ConditionOperation struct {
|
||||
Path string `json:"path"` // JSON路径
|
||||
Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte
|
||||
@@ -186,8 +188,7 @@ func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperat
|
||||
}
|
||||
|
||||
func processNegativeIndex(jsonStr string, path string) string {
|
||||
re := regexp.MustCompile(`\.(-\d+)`)
|
||||
matches := re.FindAllStringSubmatch(path, -1)
|
||||
matches := negativeIndexRegexp.FindAllStringSubmatch(path, -1)
|
||||
|
||||
if len(matches) == 0 {
|
||||
return path
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -81,8 +82,9 @@ type TokenCountMeta struct {
|
||||
type RelayInfo struct {
|
||||
TokenId int
|
||||
TokenKey string
|
||||
TokenGroup string
|
||||
UserId int
|
||||
UsingGroup string // 使用的分组
|
||||
UsingGroup string // 使用的分组,当auto跨分组重试时,会变动
|
||||
UserGroup string // 用户所在分组
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
@@ -373,6 +375,12 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
//channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
|
||||
//paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
|
||||
|
||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
// 当令牌分组为空时,表示使用用户分组
|
||||
if tokenGroup == "" {
|
||||
tokenGroup = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
}
|
||||
|
||||
startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
|
||||
if startTime.IsZero() {
|
||||
startTime = time.Now()
|
||||
@@ -400,6 +408,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
|
||||
TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
|
||||
TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
|
||||
TokenGroup: tokenGroup,
|
||||
|
||||
isFirstResponse: true,
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
@@ -626,3 +635,47 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
|
||||
}
|
||||
return jsonDataAfter, nil
|
||||
}
|
||||
|
||||
// RemoveGeminiDisabledFields removes disabled fields from Gemini request JSON data
|
||||
// Currently supports removing functionResponse.id field which Vertex AI does not support
|
||||
func RemoveGeminiDisabledFields(jsonData []byte) ([]byte, error) {
|
||||
if !model_setting.GetGeminiSettings().RemoveFunctionResponseIdEnabled {
|
||||
return jsonData, nil
|
||||
}
|
||||
|
||||
var data map[string]interface{}
|
||||
if err := common.Unmarshal(jsonData, &data); err != nil {
|
||||
common.SysError("RemoveGeminiDisabledFields Unmarshal error: " + err.Error())
|
||||
return jsonData, nil
|
||||
}
|
||||
|
||||
// Process contents array
|
||||
// Handle both camelCase (functionResponse) and snake_case (function_response)
|
||||
if contents, ok := data["contents"].([]interface{}); ok {
|
||||
for _, content := range contents {
|
||||
if contentMap, ok := content.(map[string]interface{}); ok {
|
||||
if parts, ok := contentMap["parts"].([]interface{}); ok {
|
||||
for _, part := range parts {
|
||||
if partMap, ok := part.(map[string]interface{}); ok {
|
||||
// Check functionResponse (camelCase)
|
||||
if funcResp, ok := partMap["functionResponse"].(map[string]interface{}); ok {
|
||||
delete(funcResp, "id")
|
||||
}
|
||||
// Check function_response (snake_case)
|
||||
if funcResp, ok := partMap["function_response"].(map[string]interface{}); ok {
|
||||
delete(funcResp, "id")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jsonDataAfter, err := common.Marshal(data)
|
||||
if err != nil {
|
||||
common.SysError("RemoveGeminiDisabledFields Marshal error: " + err.Error())
|
||||
return jsonData, nil
|
||||
}
|
||||
return jsonDataAfter, nil
|
||||
}
|
||||
|
||||
@@ -181,7 +181,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
return newApiErr
|
||||
}
|
||||
|
||||
if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
|
||||
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
|
||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
@@ -300,14 +300,20 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
if !relayInfo.PriceData.UsePrice {
|
||||
baseTokens := dPromptTokens
|
||||
// 减去 cached tokens
|
||||
// Anthropic API 的 input_tokens 已经不包含缓存 tokens,不需要减去
|
||||
// OpenAI/OpenRouter 等 API 的 prompt_tokens 包含缓存 tokens,需要减去
|
||||
var cachedTokensWithRatio decimal.Decimal
|
||||
if !dCacheTokens.IsZero() {
|
||||
baseTokens = baseTokens.Sub(dCacheTokens)
|
||||
if relayInfo.ChannelType != constant.ChannelTypeAnthropic {
|
||||
baseTokens = baseTokens.Sub(dCacheTokens)
|
||||
}
|
||||
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
|
||||
}
|
||||
var dCachedCreationTokensWithRatio decimal.Decimal
|
||||
if !dCachedCreationTokens.IsZero() {
|
||||
baseTokens = baseTokens.Sub(dCachedCreationTokens)
|
||||
if relayInfo.ChannelType != constant.ChannelTypeAnthropic {
|
||||
baseTokens = baseTokens.Sub(dCachedCreationTokens)
|
||||
}
|
||||
dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
|
||||
}
|
||||
|
||||
|
||||
@@ -14,15 +14,28 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func FlushWriter(c *gin.Context) error {
|
||||
if c.Writer == nil {
|
||||
func FlushWriter(c *gin.Context) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("flush panic recovered: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if c == nil || c.Writer == nil {
|
||||
return nil
|
||||
}
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
return nil
|
||||
|
||||
if c.Request != nil && c.Request.Context().Err() != nil {
|
||||
return fmt.Errorf("request context done: %w", c.Request.Context().Err())
|
||||
}
|
||||
return errors.New("streaming error: flusher not found")
|
||||
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return errors.New("streaming error: flusher not found")
|
||||
}
|
||||
|
||||
flusher.Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetEventStreamHeaders(c *gin.Context) {
|
||||
@@ -66,17 +79,31 @@ func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data st
|
||||
}
|
||||
|
||||
func StringData(c *gin.Context, str string) error {
|
||||
//str = strings.TrimPrefix(str, "data: ")
|
||||
//str = strings.TrimSuffix(str, "\r")
|
||||
if c == nil || c.Writer == nil {
|
||||
return errors.New("context or writer is nil")
|
||||
}
|
||||
|
||||
if c.Request != nil && c.Request.Context().Err() != nil {
|
||||
return fmt.Errorf("request context done: %w", c.Request.Context().Err())
|
||||
}
|
||||
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + str})
|
||||
_ = FlushWriter(c)
|
||||
return nil
|
||||
return FlushWriter(c)
|
||||
}
|
||||
|
||||
func PingData(c *gin.Context) error {
|
||||
c.Writer.Write([]byte(": PING\n\n"))
|
||||
_ = FlushWriter(c)
|
||||
return nil
|
||||
if c == nil || c.Writer == nil {
|
||||
return errors.New("context or writer is nil")
|
||||
}
|
||||
|
||||
if c.Request != nil && c.Request.Context().Err() != nil {
|
||||
return fmt.Errorf("request context done: %w", c.Request.Context().Err())
|
||||
}
|
||||
|
||||
if _, err := c.Writer.Write([]byte(": PING\n\n")); err != nil {
|
||||
return fmt.Errorf("write ping data failed: %w", err)
|
||||
}
|
||||
return FlushWriter(c)
|
||||
}
|
||||
|
||||
func ObjectData(c *gin.Context, object interface{}) error {
|
||||
|
||||
@@ -32,7 +32,94 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
if info.TaskRelayInfo == nil {
|
||||
info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
|
||||
}
|
||||
path := c.Request.URL.Path
|
||||
if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
|
||||
info.Action = constant.TaskActionRemix
|
||||
}
|
||||
|
||||
// 提取 remix 任务的 video_id
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
videoID := c.Param("video_id")
|
||||
if strings.TrimSpace(videoID) == "" {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("video_id is required"), "invalid_request", http.StatusBadRequest)
|
||||
}
|
||||
info.OriginTaskID = videoID
|
||||
}
|
||||
|
||||
platform := constant.TaskPlatform(c.GetString("platform"))
|
||||
|
||||
// 获取原始任务信息
|
||||
if info.OriginTaskID != "" {
|
||||
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !exist {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if info.OriginModelName == "" {
|
||||
if originTask.Properties.OriginModelName != "" {
|
||||
info.OriginModelName = originTask.Properties.OriginModelName
|
||||
} else if originTask.Properties.UpstreamModelName != "" {
|
||||
info.OriginModelName = originTask.Properties.UpstreamModelName
|
||||
} else {
|
||||
var taskData map[string]interface{}
|
||||
_ = json.Unmarshal(originTask.Data, &taskData)
|
||||
if m, ok := taskData["model"].(string); ok && m != "" {
|
||||
info.OriginModelName = m
|
||||
platform = originTask.Platform
|
||||
}
|
||||
}
|
||||
}
|
||||
if originTask.ChannelId != info.ChannelId {
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
key, _, newAPIError := channel.GetNextEnabledKey()
|
||||
if newAPIError != nil {
|
||||
taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
|
||||
return
|
||||
}
|
||||
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
|
||||
|
||||
info.ChannelBaseUrl = channel.GetBaseURL()
|
||||
info.ChannelId = originTask.ChannelId
|
||||
info.ChannelType = channel.Type
|
||||
info.ApiKey = key
|
||||
platform = originTask.Platform
|
||||
}
|
||||
|
||||
// 使用原始任务的参数
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
var taskData map[string]interface{}
|
||||
_ = json.Unmarshal(originTask.Data, &taskData)
|
||||
secondsStr, _ := taskData["seconds"].(string)
|
||||
seconds, _ := strconv.Atoi(secondsStr)
|
||||
if seconds <= 0 {
|
||||
seconds = 4
|
||||
}
|
||||
sizeStr, _ := taskData["size"].(string)
|
||||
if info.PriceData.OtherRatios == nil {
|
||||
info.PriceData.OtherRatios = map[string]float64{}
|
||||
}
|
||||
info.PriceData.OtherRatios["seconds"] = float64(seconds)
|
||||
info.PriceData.OtherRatios["size"] = 1
|
||||
if sizeStr == "1792x1024" || sizeStr == "1024x1792" {
|
||||
info.PriceData.OtherRatios["size"] = 1.666667
|
||||
}
|
||||
}
|
||||
}
|
||||
if platform == "" {
|
||||
platform = GetTaskPlatform(c)
|
||||
}
|
||||
@@ -94,34 +181,6 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
return
|
||||
}
|
||||
|
||||
if info.OriginTaskID != "" {
|
||||
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !exist {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if originTask.ChannelId != info.ChannelId {
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
|
||||
}
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
c.Set("channel_id", originTask.ChannelId)
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
|
||||
info.ChannelBaseUrl = channel.GetBaseURL()
|
||||
info.ChannelId = originTask.ChannelId
|
||||
}
|
||||
}
|
||||
|
||||
// build body
|
||||
requestBody, err := adaptor.BuildRequestBody(c, info)
|
||||
if err != nil {
|
||||
@@ -137,7 +196,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
// handle response
|
||||
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||
responseBody, _ := io.ReadAll(resp.Body)
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -326,6 +385,7 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
||||
if channelModel.GetBaseURL() != "" {
|
||||
baseURL = channelModel.GetBaseURL()
|
||||
}
|
||||
proxy := channelModel.GetSetting().Proxy
|
||||
adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
|
||||
if adaptor == nil {
|
||||
return
|
||||
@@ -333,7 +393,7 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
||||
resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
|
||||
"task_id": originTask.TaskID,
|
||||
"action": originTask.Action,
|
||||
})
|
||||
}, proxy)
|
||||
if err2 != nil || resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user