Merge branch 'upstream-main' into feature/improve-param-override

# Conflicts:
#	relay/channel/api_request_test.go
#	relay/common/override_test.go
#	web/src/components/table/channels/modals/EditChannelModal.jsx
This commit is contained in:
Seefs
2026-02-25 13:39:54 +08:00
124 changed files with 7729 additions and 1971 deletions

View File

@@ -36,6 +36,32 @@ type TaskAdaptor interface {
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError
// ── Billing ──────────────────────────────────────────────────────
// EstimateBilling returns OtherRatios for pre-charge based on user request.
// Called after ValidateRequestAndSetAction, before price calculation.
// Adaptors should extract duration, resolution, etc. from the parsed request
// and return them as ratio multipliers (e.g. {"seconds": 5, "size": 1.666}).
// Return nil to use the base model price without extra ratios.
EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64
// AdjustBillingOnSubmit returns adjusted OtherRatios from the upstream
// submit response. Called after a successful DoResponse.
// If the upstream returned actual parameters that differ from the estimate
// (e.g. actual seconds), return updated ratios so the caller can recalculate
// the quota and settle the delta with the pre-charge.
// Return nil if no adjustment is needed.
AdjustBillingOnSubmit(info *relaycommon.RelayInfo, taskData []byte) map[string]float64
// AdjustBillingOnComplete returns the actual quota when a task reaches a
// terminal state (success/failure) during polling.
// Called by the polling loop after ParseTaskResult.
// Return a positive value to trigger delta settlement (supplement / refund).
// Return 0 to keep the pre-charged amount unchanged.
AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int
// ── Request / Response ───────────────────────────────────────────
BuildRequestURL(info *relaycommon.RelayInfo) (string, error)
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error)
@@ -46,9 +72,9 @@ type TaskAdaptor interface {
GetModelList() []string
GetChannelName() string
// FetchTask
FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
// ── Polling ──────────────────────────────────────────────────────
FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
}

View File

@@ -61,8 +61,9 @@ var passthroughSkipHeaderNamesLower = map[string]struct{}{
"cookie": {},
// Additional headers that should not be forwarded by name-matching passthrough rules.
"host": {},
"content-length": {},
"host": {},
"content-length": {},
"accept-encoding": {},
// Do not passthrough credentials by wildcard/regex.
"authorization": {},

View File

@@ -110,3 +110,30 @@ func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) {
_, ok := headers["X-Legacy"]
require.False(t, ok)
}
func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
ctx.Request.Header.Set("X-Trace-Id", "trace-123")
ctx.Request.Header.Set("Accept-Encoding", "gzip")
info := &relaycommon.RelayInfo{
IsChannelTest: false,
ChannelMeta: &relaycommon.ChannelMeta{
HeadersOverride: map[string]any{
"*": "",
},
},
}
headers, err := processHeaderOverride(info, ctx)
require.NoError(t, err)
require.Equal(t, "trace-123", headers["X-Trace-Id"])
_, hasAcceptEncoding := headers["Accept-Encoding"]
require.False(t, hasAcceptEncoding)
}

View File

@@ -42,22 +42,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
}
// 计算使用量(基于 UsageMetadata
usage := dto.Usage{
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
}
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
if detail.Modality == "AUDIO" {
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
} else if detail.Modality == "TEXT" {
usage.PromptTokensDetails.TextTokens = detail.TokenCount
}
}
usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
service.IOCopyBytesGracefully(c, resp, responseBody)

View File

@@ -1032,6 +1032,46 @@ func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
}
}
func buildUsageFromGeminiMetadata(metadata dto.GeminiUsageMetadata, fallbackPromptTokens int) dto.Usage {
promptTokens := metadata.PromptTokenCount + metadata.ToolUsePromptTokenCount
if promptTokens <= 0 && fallbackPromptTokens > 0 {
promptTokens = fallbackPromptTokens
}
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: metadata.CandidatesTokenCount + metadata.ThoughtsTokenCount,
TotalTokens: metadata.TotalTokenCount,
}
usage.CompletionTokenDetails.ReasoningTokens = metadata.ThoughtsTokenCount
usage.PromptTokensDetails.CachedTokens = metadata.CachedContentTokenCount
for _, detail := range metadata.PromptTokensDetails {
if detail.Modality == "AUDIO" {
usage.PromptTokensDetails.AudioTokens += detail.TokenCount
} else if detail.Modality == "TEXT" {
usage.PromptTokensDetails.TextTokens += detail.TokenCount
}
}
for _, detail := range metadata.ToolUsePromptTokensDetails {
if detail.Modality == "AUDIO" {
usage.PromptTokensDetails.AudioTokens += detail.TokenCount
} else if detail.Modality == "TEXT" {
usage.PromptTokensDetails.TextTokens += detail.TokenCount
}
}
if usage.TotalTokens > 0 && usage.CompletionTokens <= 0 {
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
}
if usage.PromptTokens > 0 && usage.PromptTokensDetails.TextTokens == 0 && usage.PromptTokensDetails.AudioTokens == 0 {
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
}
return usage
}
func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Id: helper.GetResponseID(c),
@@ -1272,18 +1312,8 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
// 更新使用量统计
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
if detail.Modality == "AUDIO" {
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
} else if detail.Modality == "TEXT" {
usage.PromptTokensDetails.TextTokens = detail.TokenCount
}
}
mappedUsage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
*usage = mappedUsage
}
return callback(data, &geminiResponse)
@@ -1295,11 +1325,6 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
}
}
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
if usage.TotalTokens > 0 {
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
}
if usage.CompletionTokens <= 0 {
if info.ReceivedResponseCount > 0 {
usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
@@ -1416,21 +1441,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if len(geminiResponse.Candidates) == 0 {
usage := dto.Usage{
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
}
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
if detail.Modality == "AUDIO" {
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
} else if detail.Modality == "TEXT" {
usage.PromptTokensDetails.TextTokens = detail.TokenCount
}
}
if usage.PromptTokens <= 0 {
usage.PromptTokens = info.GetEstimatePromptTokens()
}
usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
var newAPIError *types.NewAPIError
if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
@@ -1466,23 +1477,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
}
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
fullTextResponse.Model = info.UpstreamModelName
usage := dto.Usage{
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
}
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
if detail.Modality == "AUDIO" {
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
} else if detail.Modality == "TEXT" {
usage.PromptTokensDetails.TextTokens = detail.TokenCount
}
}
usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
fullTextResponse.Usage = usage

View File

@@ -0,0 +1,333 @@
package gemini
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestGeminiChatHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
info := &relaycommon.RelayInfo{
RelayFormat: types.RelayFormatGemini,
OriginModelName: "gemini-3-flash-preview",
ChannelMeta: &relaycommon.ChannelMeta{
UpstreamModelName: "gemini-3-flash-preview",
},
}
payload := dto.GeminiChatResponse{
Candidates: []dto.GeminiChatCandidate{
{
Content: dto.GeminiChatContent{
Role: "model",
Parts: []dto.GeminiPart{
{Text: "ok"},
},
},
},
},
UsageMetadata: dto.GeminiUsageMetadata{
PromptTokenCount: 151,
ToolUsePromptTokenCount: 18329,
CandidatesTokenCount: 1089,
ThoughtsTokenCount: 1120,
TotalTokenCount: 20689,
},
}
body, err := common.Marshal(payload)
require.NoError(t, err)
resp := &http.Response{
Body: io.NopCloser(bytes.NewReader(body)),
}
usage, newAPIError := GeminiChatHandler(c, info, resp)
require.Nil(t, newAPIError)
require.NotNil(t, usage)
require.Equal(t, 18480, usage.PromptTokens)
require.Equal(t, 2209, usage.CompletionTokens)
require.Equal(t, 20689, usage.TotalTokens)
require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
}
func TestGeminiStreamHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
gin.SetMode(gin.TestMode)
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
oldStreamingTimeout := constant.StreamingTimeout
constant.StreamingTimeout = 300
t.Cleanup(func() {
constant.StreamingTimeout = oldStreamingTimeout
})
info := &relaycommon.RelayInfo{
OriginModelName: "gemini-3-flash-preview",
ChannelMeta: &relaycommon.ChannelMeta{
UpstreamModelName: "gemini-3-flash-preview",
},
}
chunk := dto.GeminiChatResponse{
Candidates: []dto.GeminiChatCandidate{
{
Content: dto.GeminiChatContent{
Role: "model",
Parts: []dto.GeminiPart{
{Text: "partial"},
},
},
},
},
UsageMetadata: dto.GeminiUsageMetadata{
PromptTokenCount: 151,
ToolUsePromptTokenCount: 18329,
CandidatesTokenCount: 1089,
ThoughtsTokenCount: 1120,
TotalTokenCount: 20689,
},
}
chunkData, err := common.Marshal(chunk)
require.NoError(t, err)
streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
resp := &http.Response{
Body: io.NopCloser(bytes.NewReader(streamBody)),
}
usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
return true
})
require.Nil(t, newAPIError)
require.NotNil(t, usage)
require.Equal(t, 18480, usage.PromptTokens)
require.Equal(t, 2209, usage.CompletionTokens)
require.Equal(t, 20689, usage.TotalTokens)
require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
}
func TestGeminiTextGenerationHandlerPromptTokensIncludeToolUsePromptTokens(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
info := &relaycommon.RelayInfo{
OriginModelName: "gemini-3-flash-preview",
ChannelMeta: &relaycommon.ChannelMeta{
UpstreamModelName: "gemini-3-flash-preview",
},
}
payload := dto.GeminiChatResponse{
Candidates: []dto.GeminiChatCandidate{
{
Content: dto.GeminiChatContent{
Role: "model",
Parts: []dto.GeminiPart{
{Text: "ok"},
},
},
},
},
UsageMetadata: dto.GeminiUsageMetadata{
PromptTokenCount: 151,
ToolUsePromptTokenCount: 18329,
CandidatesTokenCount: 1089,
ThoughtsTokenCount: 1120,
TotalTokenCount: 20689,
},
}
body, err := common.Marshal(payload)
require.NoError(t, err)
resp := &http.Response{
Body: io.NopCloser(bytes.NewReader(body)),
}
usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
require.Nil(t, newAPIError)
require.NotNil(t, usage)
require.Equal(t, 18480, usage.PromptTokens)
require.Equal(t, 2209, usage.CompletionTokens)
require.Equal(t, 20689, usage.TotalTokens)
require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
}
func TestGeminiChatHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
info := &relaycommon.RelayInfo{
RelayFormat: types.RelayFormatGemini,
OriginModelName: "gemini-3-flash-preview",
ChannelMeta: &relaycommon.ChannelMeta{
UpstreamModelName: "gemini-3-flash-preview",
},
}
info.SetEstimatePromptTokens(20)
payload := dto.GeminiChatResponse{
Candidates: []dto.GeminiChatCandidate{
{
Content: dto.GeminiChatContent{
Role: "model",
Parts: []dto.GeminiPart{
{Text: "ok"},
},
},
},
},
UsageMetadata: dto.GeminiUsageMetadata{
PromptTokenCount: 0,
ToolUsePromptTokenCount: 0,
CandidatesTokenCount: 90,
ThoughtsTokenCount: 10,
TotalTokenCount: 110,
},
}
body, err := common.Marshal(payload)
require.NoError(t, err)
resp := &http.Response{
Body: io.NopCloser(bytes.NewReader(body)),
}
usage, newAPIError := GeminiChatHandler(c, info, resp)
require.Nil(t, newAPIError)
require.NotNil(t, usage)
require.Equal(t, 20, usage.PromptTokens)
require.Equal(t, 100, usage.CompletionTokens)
require.Equal(t, 110, usage.TotalTokens)
}
func TestGeminiStreamHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
gin.SetMode(gin.TestMode)
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
oldStreamingTimeout := constant.StreamingTimeout
constant.StreamingTimeout = 300
t.Cleanup(func() {
constant.StreamingTimeout = oldStreamingTimeout
})
info := &relaycommon.RelayInfo{
OriginModelName: "gemini-3-flash-preview",
ChannelMeta: &relaycommon.ChannelMeta{
UpstreamModelName: "gemini-3-flash-preview",
},
}
info.SetEstimatePromptTokens(20)
chunk := dto.GeminiChatResponse{
Candidates: []dto.GeminiChatCandidate{
{
Content: dto.GeminiChatContent{
Role: "model",
Parts: []dto.GeminiPart{
{Text: "partial"},
},
},
},
},
UsageMetadata: dto.GeminiUsageMetadata{
PromptTokenCount: 0,
ToolUsePromptTokenCount: 0,
CandidatesTokenCount: 90,
ThoughtsTokenCount: 10,
TotalTokenCount: 110,
},
}
chunkData, err := common.Marshal(chunk)
require.NoError(t, err)
streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
resp := &http.Response{
Body: io.NopCloser(bytes.NewReader(streamBody)),
}
usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
return true
})
require.Nil(t, newAPIError)
require.NotNil(t, usage)
require.Equal(t, 20, usage.PromptTokens)
require.Equal(t, 100, usage.CompletionTokens)
require.Equal(t, 110, usage.TotalTokens)
}
func TestGeminiTextGenerationHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
info := &relaycommon.RelayInfo{
OriginModelName: "gemini-3-flash-preview",
ChannelMeta: &relaycommon.ChannelMeta{
UpstreamModelName: "gemini-3-flash-preview",
},
}
info.SetEstimatePromptTokens(20)
payload := dto.GeminiChatResponse{
Candidates: []dto.GeminiChatCandidate{
{
Content: dto.GeminiChatContent{
Role: "model",
Parts: []dto.GeminiPart{
{Text: "ok"},
},
},
},
},
UsageMetadata: dto.GeminiUsageMetadata{
PromptTokenCount: 0,
ToolUsePromptTokenCount: 0,
CandidatesTokenCount: 90,
ThoughtsTokenCount: 10,
TotalTokenCount: 110,
},
}
body, err := common.Marshal(payload)
require.NoError(t, err)
resp := &http.Response{
Body: io.NopCloser(bytes.NewReader(body)),
}
usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
require.Nil(t, newAPIError)
require.NotNil(t, usage)
require.Equal(t, 20, usage.PromptTokens)
require.Equal(t, 100, usage.CompletionTokens)
require.Equal(t, 110, usage.TotalTokens)
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/claude"
"github.com/QuantumNous/new-api/relay/channel/openai"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/constant"
@@ -26,7 +27,8 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
return nil, errors.New("not implemented")
adaptor := claude.Adaptor{}
return adaptor.ConvertClaudeRequest(c, info, req)
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -119,8 +121,14 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
return handleTTSResponse(c, resp, info)
}
adaptor := openai.Adaptor{}
return adaptor.DoResponse(c, resp, info)
switch info.RelayFormat {
case types.RelayFormatClaude:
adaptor := claude.Adaptor{}
return adaptor.DoResponse(c, resp, info)
default:
adaptor := openai.Adaptor{}
return adaptor.DoResponse(c, resp, info)
}
}
func (a *Adaptor) GetModelList() []string {

View File

@@ -6,6 +6,7 @@ import (
channelconstant "github.com/QuantumNous/new-api/constant"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/types"
)
func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -13,13 +14,17 @@ func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if baseUrl == "" {
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeMiniMax]
}
switch info.RelayMode {
case constant.RelayModeChatCompletions:
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", baseUrl), nil
case constant.RelayModeAudioSpeech:
return fmt.Sprintf("%s/v1/t2a_v2", baseUrl), nil
switch info.RelayFormat {
case types.RelayFormatClaude:
return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil
default:
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
switch info.RelayMode {
case constant.RelayModeChatCompletions:
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", baseUrl), nil
case constant.RelayModeAudioSpeech:
return fmt.Sprintf("%s/v1/t2a_v2", baseUrl), nil
default:
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
}
}
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
"github.com/samber/lo"
@@ -108,10 +109,10 @@ type AliMetadata struct {
// ============================
type TaskAdaptor struct {
taskcommon.BaseBilling
ChannelType int
apiKey string
baseURL string
aliReq *AliVideoRequest
}
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
@@ -121,17 +122,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
// 阿里通义万相支持 JSON 格式,不使用 multipart
var taskReq relaycommon.TaskSubmitReq
if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil {
return service.TaskErrorWrapper(err, "unmarshal_task_request_failed", http.StatusBadRequest)
}
aliReq, err := a.convertToAliRequest(info, taskReq)
if err != nil {
return service.TaskErrorWrapper(err, "convert_to_ali_request_failed", http.StatusInternalServerError)
}
a.aliReq = aliReq
logger.LogJson(c, "ali video request body", aliReq)
// ValidateMultipartDirect 负责解析并将原始 TaskSubmitReq 存入 context
return relaycommon.ValidateMultipartDirect(c, info)
}
@@ -148,11 +139,21 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
}
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
bodyBytes, err := common.Marshal(a.aliReq)
taskReq, err := relaycommon.GetTaskRequest(c)
if err != nil {
return nil, errors.Wrap(err, "get_task_request_failed")
}
aliReq, err := a.convertToAliRequest(info, taskReq)
if err != nil {
return nil, errors.Wrap(err, "convert_to_ali_request_failed")
}
logger.LogJson(c, "ali video request body", aliReq)
bodyBytes, err := common.Marshal(aliReq)
if err != nil {
return nil, errors.Wrap(err, "marshal_ali_request_failed")
}
return bytes.NewReader(bodyBytes), nil
}
@@ -252,8 +253,12 @@ func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error)
}
func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) {
upstreamModel := req.Model
if info.IsModelMapped {
upstreamModel = info.UpstreamModelName
}
aliReq := &AliVideoRequest{
Model: req.Model,
Model: upstreamModel,
Input: AliVideoInput{
Prompt: req.Prompt,
ImgURL: req.InputReference,
@@ -331,23 +336,37 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
}
}
if aliReq.Model != req.Model {
if aliReq.Model != upstreamModel {
return nil, errors.New("can't change model with metadata")
}
info.PriceData.OtherRatios = map[string]float64{
return aliReq, nil
}
// EstimateBilling 根据用户请求参数计算 OtherRatios时长、分辨率等
// 在 ValidateRequestAndSetAction 之后、价格计算之前调用。
func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
taskReq, err := relaycommon.GetTaskRequest(c)
if err != nil {
return nil
}
aliReq, err := a.convertToAliRequest(info, taskReq)
if err != nil {
return nil
}
otherRatios := map[string]float64{
"seconds": float64(aliReq.Parameters.Duration),
}
ratios, err := ProcessAliOtherRatios(aliReq)
if err != nil {
return nil, err
return otherRatios
}
for s, f := range ratios {
info.PriceData.OtherRatios[s] = f
for k, v := range ratios {
otherRatios[k] = v
}
return aliReq, nil
return otherRatios
}
// DoRequest delegates to common helper
@@ -384,7 +403,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
// 转换为 OpenAI 格式响应
openAIResp := dto.NewOpenAIVideo()
openAIResp.ID = aliResp.Output.TaskID
openAIResp.ID = info.PublicTaskID
openAIResp.TaskID = info.PublicTaskID
openAIResp.Model = c.GetString("model")
if openAIResp.Model == "" && info != nil {
openAIResp.Model = info.OriginModelName

View File

@@ -2,7 +2,6 @@ package doubao
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -14,6 +13,7 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
@@ -89,6 +89,7 @@ type responseTask struct {
// ============================
type TaskAdaptor struct {
taskcommon.BaseBilling
ChannelType int
apiKey string
baseURL string
@@ -130,8 +131,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
if err != nil {
return nil, errors.Wrap(err, "convert request payload failed")
}
info.UpstreamModelName = body.Model
data, err := json.Marshal(body)
if info.IsModelMapped {
body.Model = info.UpstreamModelName
} else {
info.UpstreamModelName = body.Model
}
data, err := common.Marshal(body)
if err != nil {
return nil, err
}
@@ -154,7 +159,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
// Parse Doubao response
var dResp responsePayload
if err := json.Unmarshal(responseBody, &dResp); err != nil {
if err := common.Unmarshal(responseBody, &dResp); err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
return
}
@@ -165,8 +170,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
}
ov := dto.NewOpenAIVideo()
ov.ID = dResp.ID
ov.TaskID = dResp.ID
ov.ID = info.PublicTaskID
ov.TaskID = info.PublicTaskID
ov.CreatedAt = time.Now().Unix()
ov.Model = info.OriginModelName
@@ -234,12 +239,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
}
metadata := req.Metadata
medaBytes, err := json.Marshal(metadata)
if err != nil {
return nil, errors.Wrap(err, "metadata marshal metadata failed")
}
err = json.Unmarshal(medaBytes, &r)
if err != nil {
if err := taskcommon.UnmarshalMetadata(metadata, &r); err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
@@ -248,7 +248,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
resTask := responseTask{}
if err := json.Unmarshal(respBody, &resTask); err != nil {
if err := common.Unmarshal(respBody, &resTask); err != nil {
return nil, errors.Wrap(err, "unmarshal task result failed")
}
@@ -286,7 +286,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
var dResp responseTask
if err := json.Unmarshal(originTask.Data, &dResp); err != nil {
if err := common.Unmarshal(originTask.Data, &dResp); err != nil {
return nil, errors.Wrap(err, "unmarshal doubao task data failed")
}
@@ -307,6 +307,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
}
}
jsonData, _ := common.Marshal(openAIVideo)
return jsonData, nil
return common.Marshal(openAIVideo)
}

View File

@@ -2,8 +2,6 @@ package gemini
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -16,10 +14,10 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
)
@@ -87,6 +85,7 @@ type operationResponse struct {
// ============================
type TaskAdaptor struct {
taskcommon.BaseBilling
ChannelType int
apiKey string
baseURL string
@@ -106,7 +105,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
modelName := info.OriginModelName
modelName := info.UpstreamModelName
version := model_setting.GetGeminiVersionSetting(modelName)
return fmt.Sprintf(
@@ -145,16 +144,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
}
metadata := req.Metadata
medaBytes, err := json.Marshal(metadata)
if err != nil {
return nil, errors.Wrap(err, "metadata marshal metadata failed")
}
err = json.Unmarshal(medaBytes, &body.Parameters)
if err != nil {
if err := taskcommon.UnmarshalMetadata(metadata, &body.Parameters); err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
data, err := json.Marshal(body)
data, err := common.Marshal(body)
if err != nil {
return nil, err
}
@@ -175,16 +169,16 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
_ = resp.Body.Close()
var s submitResponse
if err := json.Unmarshal(responseBody, &s); err != nil {
if err := common.Unmarshal(responseBody, &s); err != nil {
return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
}
if strings.TrimSpace(s.Name) == "" {
return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
}
taskID = encodeLocalTaskID(s.Name)
taskID = taskcommon.EncodeLocalTaskID(s.Name)
ov := dto.NewOpenAIVideo()
ov.ID = taskID
ov.TaskID = taskID
ov.ID = info.PublicTaskID
ov.TaskID = info.PublicTaskID
ov.CreatedAt = time.Now().Unix()
ov.Model = info.OriginModelName
c.JSON(http.StatusOK, ov)
@@ -206,7 +200,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
return nil, fmt.Errorf("invalid task_id")
}
upstreamName, err := decodeLocalTaskID(taskID)
upstreamName, err := taskcommon.DecodeLocalTaskID(taskID)
if err != nil {
return nil, fmt.Errorf("decode task_id failed: %w", err)
}
@@ -232,7 +226,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
var op operationResponse
if err := json.Unmarshal(respBody, &op); err != nil {
if err := common.Unmarshal(respBody, &op); err != nil {
return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
}
@@ -254,9 +248,8 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
ti.Status = model.TaskStatusSuccess
ti.Progress = "100%"
taskID := encodeLocalTaskID(op.Name)
ti.TaskID = taskID
ti.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
ti.TaskID = taskcommon.EncodeLocalTaskID(op.Name)
// Url intentionally left empty — the caller constructs the proxy URL using the public task ID
// Extract URL from generateVideoResponse if available
if len(op.Response.GenerateVideoResponse.GeneratedSamples) > 0 {
@@ -269,7 +262,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
upstreamName, err := decodeLocalTaskID(task.TaskID)
// Use GetUpstreamTaskID() to get the real upstream operation name for model extraction.
// task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name.
upstreamTaskID := task.GetUpstreamTaskID()
upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID)
if err != nil {
upstreamName = ""
}
@@ -297,18 +293,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
// helpers
// ============================
func encodeLocalTaskID(name string) string {
return base64.RawURLEncoding.EncodeToString([]byte(name))
}
func decodeLocalTaskID(local string) (string, error) {
b, err := base64.RawURLEncoding.DecodeString(local)
if err != nil {
return "", err
}
return string(b), nil
}
var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`)
func extractModelFromOperationName(name string) string {

View File

@@ -2,7 +2,6 @@ package hailuo
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -18,12 +17,14 @@ import (
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
)
// https://platform.minimaxi.com/docs/api-reference/video-generation-intro
type TaskAdaptor struct {
taskcommon.BaseBilling
ChannelType int
apiKey string
baseURL string
@@ -60,12 +61,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
return nil, fmt.Errorf("invalid request type in context")
}
body, err := a.convertToRequestPayload(&req)
body, err := a.convertToRequestPayload(&req, info)
if err != nil {
return nil, errors.Wrap(err, "convert request payload failed")
}
data, err := json.Marshal(body)
data, err := common.Marshal(body)
if err != nil {
return nil, err
}
@@ -86,7 +87,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
_ = resp.Body.Close()
var hResp VideoResponse
if err := json.Unmarshal(responseBody, &hResp); err != nil {
if err := common.Unmarshal(responseBody, &hResp); err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
return
}
@@ -101,8 +102,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
}
ov := dto.NewOpenAIVideo()
ov.ID = hResp.TaskID
ov.TaskID = hResp.TaskID
ov.ID = info.PublicTaskID
ov.TaskID = info.PublicTaskID
ov.CreatedAt = time.Now().Unix()
ov.Model = info.OriginModelName
@@ -141,8 +142,8 @@ func (a *TaskAdaptor) GetChannelName() string {
return ChannelName
}
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*VideoRequest, error) {
modelConfig := GetModelConfig(req.Model)
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*VideoRequest, error) {
modelConfig := GetModelConfig(info.UpstreamModelName)
duration := DefaultDuration
if req.Duration > 0 {
duration = req.Duration
@@ -153,7 +154,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
}
videoRequest := &VideoRequest{
Model: req.Model,
Model: info.UpstreamModelName,
Prompt: req.Prompt,
Duration: &duration,
Resolution: resolution,
@@ -182,7 +183,7 @@ func (a *TaskAdaptor) parseResolutionFromSize(size string, modelConfig ModelConf
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
resTask := QueryTaskResponse{}
if err := json.Unmarshal(respBody, &resTask); err != nil {
if err := common.Unmarshal(respBody, &resTask); err != nil {
return nil, errors.Wrap(err, "unmarshal task result failed")
}
@@ -224,7 +225,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
var hailuoResp QueryTaskResponse
if err := json.Unmarshal(originTask.Data, &hailuoResp); err != nil {
if err := common.Unmarshal(originTask.Data, &hailuoResp); err != nil {
return nil, errors.Wrap(err, "unmarshal hailuo task data failed")
}
@@ -271,7 +272,7 @@ func (a *TaskAdaptor) buildVideoURL(_, fileID string) string {
}
var retrieveResp RetrieveFileResponse
if err := json.Unmarshal(responseBody, &retrieveResp); err != nil {
if err := common.Unmarshal(responseBody, &retrieveResp); err != nil {
return ""
}

View File

@@ -6,7 +6,6 @@ import (
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -25,6 +24,7 @@ import (
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
)
@@ -77,6 +77,7 @@ const (
// ============================
type TaskAdaptor struct {
taskcommon.BaseBilling
ChannelType int
accessKey string
secretKey string
@@ -164,11 +165,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
}
}
body, err := a.convertToRequestPayload(&req)
body, err := a.convertToRequestPayload(&req, info)
if err != nil {
return nil, errors.Wrap(err, "convert request payload failed")
}
data, err := json.Marshal(body)
data, err := common.Marshal(body)
if err != nil {
return nil, err
}
@@ -191,7 +192,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
// Parse Jimeng response
var jResp responsePayload
if err := json.Unmarshal(responseBody, &jResp); err != nil {
if err := common.Unmarshal(responseBody, &jResp); err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
return
}
@@ -202,8 +203,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
}
ov := dto.NewOpenAIVideo()
ov.ID = jResp.Data.TaskID
ov.TaskID = jResp.Data.TaskID
ov.ID = info.PublicTaskID
ov.TaskID = info.PublicTaskID
ov.CreatedAt = time.Now().Unix()
ov.Model = info.OriginModelName
c.JSON(http.StatusOK, ov)
@@ -225,7 +226,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
"task_id": taskID,
}
payloadBytes, err := json.Marshal(payload)
payloadBytes, err := common.Marshal(payload)
if err != nil {
return nil, errors.Wrap(err, "marshal fetch task payload failed")
}
@@ -377,9 +378,9 @@ func hmacSHA256(key []byte, data []byte) []byte {
return h.Sum(nil)
}
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
r := requestPayload{
ReqKey: req.Model,
ReqKey: info.UpstreamModelName,
Prompt: req.Prompt,
}
@@ -398,13 +399,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
r.BinaryDataBase64 = req.Images
}
}
metadata := req.Metadata
medaBytes, err := json.Marshal(metadata)
if err != nil {
return nil, errors.Wrap(err, "metadata marshal metadata failed")
}
err = json.Unmarshal(medaBytes, &r)
if err != nil {
if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
@@ -432,7 +427,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
resTask := responseTask{}
if err := json.Unmarshal(respBody, &resTask); err != nil {
if err := common.Unmarshal(respBody, &resTask); err != nil {
return nil, errors.Wrap(err, "unmarshal task result failed")
}
taskResult := relaycommon.TaskInfo{}
@@ -458,7 +453,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
var jimengResp responseTask
if err := json.Unmarshal(originTask.Data, &jimengResp); err != nil {
if err := common.Unmarshal(originTask.Data, &jimengResp); err != nil {
return nil, errors.Wrap(err, "unmarshal jimeng task data failed")
}
@@ -477,8 +472,7 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
}
}
jsonData, _ := common.Marshal(openAIVideo)
return jsonData, nil
return common.Marshal(openAIVideo)
}
func isNewAPIRelay(apiKey string) bool {

View File

@@ -2,7 +2,6 @@ package kling
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -21,6 +20,7 @@ import (
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
)
@@ -97,6 +97,7 @@ type responsePayload struct {
// ============================
type TaskAdaptor struct {
taskcommon.BaseBilling
ChannelType int
apiKey string
baseURL string
@@ -149,14 +150,14 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
}
req := v.(relaycommon.TaskSubmitReq)
body, err := a.convertToRequestPayload(&req)
body, err := a.convertToRequestPayload(&req, info)
if err != nil {
return nil, err
}
if body.Image == "" && body.ImageTail == "" {
c.Set("action", constant.TaskActionTextGenerate)
}
data, err := json.Marshal(body)
data, err := common.Marshal(body)
if err != nil {
return nil, err
}
@@ -180,7 +181,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
}
var kResp responsePayload
err = json.Unmarshal(responseBody, &kResp)
err = common.Unmarshal(responseBody, &kResp)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
return
@@ -190,8 +191,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
return
}
ov := dto.NewOpenAIVideo()
ov.ID = kResp.Data.TaskId
ov.TaskID = kResp.Data.TaskId
ov.ID = info.PublicTaskID
ov.TaskID = info.PublicTaskID
ov.CreatedAt = time.Now().Unix()
ov.Model = info.OriginModelName
c.JSON(http.StatusOK, ov)
@@ -247,15 +248,15 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers
// ============================
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
r := requestPayload{
Prompt: req.Prompt,
Image: req.Image,
Mode: defaultString(req.Mode, "std"),
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
Mode: taskcommon.DefaultString(req.Mode, "std"),
Duration: fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)),
AspectRatio: a.getAspectRatio(req.Size),
ModelName: req.Model,
Model: req.Model, // Keep consistent with model_name, double writing improves compatibility
ModelName: info.UpstreamModelName,
Model: info.UpstreamModelName,
CfgScale: 0.5,
StaticMask: "",
DynamicMasks: []DynamicMask{},
@@ -265,14 +266,9 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
}
if r.ModelName == "" {
r.ModelName = "kling-v1"
r.Model = "kling-v1"
}
metadata := req.Metadata
medaBytes, err := json.Marshal(metadata)
if err != nil {
return nil, errors.Wrap(err, "metadata marshal metadata failed")
}
err = json.Unmarshal(medaBytes, &r)
if err != nil {
if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
return &r, nil
@@ -291,20 +287,6 @@ func (a *TaskAdaptor) getAspectRatio(size string) string {
}
}
func defaultString(s, def string) string {
if strings.TrimSpace(s) == "" {
return def
}
return s
}
func defaultInt(v int, def int) int {
if v == 0 {
return def
}
return v
}
// ============================
// JWT helpers
// ============================
@@ -340,7 +322,7 @@ func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
taskInfo := &relaycommon.TaskInfo{}
resPayload := responsePayload{}
err := json.Unmarshal(respBody, &resPayload)
err := common.Unmarshal(respBody, &resPayload)
if err != nil {
return nil, errors.Wrap(err, "failed to unmarshal response body")
}
@@ -374,7 +356,7 @@ func isNewAPIRelay(apiKey string) bool {
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
var klingResp responsePayload
if err := json.Unmarshal(originTask.Data, &klingResp); err != nil {
if err := common.Unmarshal(originTask.Data, &klingResp); err != nil {
return nil, errors.Wrap(err, "unmarshal kling task data failed")
}
@@ -401,6 +383,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
Code: fmt.Sprintf("%d", klingResp.Code),
}
}
jsonData, _ := common.Marshal(openAIVideo)
return jsonData, nil
return common.Marshal(openAIVideo)
}

View File

@@ -1,9 +1,13 @@
package sora
import (
"bytes"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/textproto"
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
@@ -11,12 +15,13 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/tidwall/sjson"
)
// ============================
@@ -57,6 +62,7 @@ type responseTask struct {
// ============================
type TaskAdaptor struct {
taskcommon.BaseBilling
ChannelType int
apiKey string
baseURL string
@@ -69,15 +75,15 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
}
func validateRemixRequest(c *gin.Context) *dto.TaskError {
var req struct {
Prompt string `json:"prompt"`
}
var req relaycommon.TaskSubmitReq
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
}
if strings.TrimSpace(req.Prompt) == "" {
return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
}
// 存储原始请求到 context与 ValidateMultipartDirect 路径保持一致
c.Set("task_request", req)
return nil
}
@@ -88,6 +94,41 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
return relaycommon.ValidateMultipartDirect(c, info)
}
// EstimateBilling 根据用户请求的 seconds 和 size 计算 OtherRatios。
func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
// remix 路径的 OtherRatios 已在 ResolveOriginTask 中设置
if info.Action == constant.TaskActionRemix {
return nil
}
req, err := relaycommon.GetTaskRequest(c)
if err != nil {
return nil
}
seconds, _ := strconv.Atoi(req.Seconds)
if seconds == 0 {
seconds = req.Duration
}
if seconds <= 0 {
seconds = 4
}
size := req.Size
if size == "" {
size = "720x1280"
}
ratios := map[string]float64{
"seconds": float64(seconds),
"size": 1,
}
if size == "1792x1024" || size == "1024x1792" {
ratios["size"] = 1.666667
}
return ratios
}
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.Action == constant.TaskActionRemix {
return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
@@ -107,6 +148,74 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
if err != nil {
return nil, errors.Wrap(err, "get_request_body_failed")
}
cachedBody, err := storage.Bytes()
if err != nil {
return nil, errors.Wrap(err, "read_body_bytes_failed")
}
contentType := c.GetHeader("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
var bodyMap map[string]interface{}
if err := common.Unmarshal(cachedBody, &bodyMap); err == nil {
bodyMap["model"] = info.UpstreamModelName
if newBody, err := common.Marshal(bodyMap); err == nil {
return bytes.NewReader(newBody), nil
}
}
return bytes.NewReader(cachedBody), nil
}
if strings.Contains(contentType, "multipart/form-data") {
formData, err := common.ParseMultipartFormReusable(c)
if err != nil {
return bytes.NewReader(cachedBody), nil
}
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
writer.WriteField("model", info.UpstreamModelName)
for key, values := range formData.Value {
if key == "model" {
continue
}
for _, v := range values {
writer.WriteField(key, v)
}
}
for fieldName, fileHeaders := range formData.File {
for _, fh := range fileHeaders {
f, err := fh.Open()
if err != nil {
continue
}
ct := fh.Header.Get("Content-Type")
if ct == "" || ct == "application/octet-stream" {
buf512 := make([]byte, 512)
n, _ := io.ReadFull(f, buf512)
ct = http.DetectContentType(buf512[:n])
// Re-open after sniffing so the full content is copied below
f.Close()
f, err = fh.Open()
if err != nil {
continue
}
}
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fh.Filename))
h.Set("Content-Type", ct)
part, err := writer.CreatePart(h)
if err != nil {
f.Close()
continue
}
io.Copy(part, f)
f.Close()
}
}
writer.Close()
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
return &buf, nil
}
return common.ReaderOnly(storage), nil
}
@@ -116,7 +225,7 @@ func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, req
}
// DoResponse handles upstream response, returns taskID etc.
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
@@ -131,17 +240,20 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco
return
}
if dResp.ID == "" {
if dResp.TaskID == "" {
taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
return
}
dResp.ID = dResp.TaskID
dResp.TaskID = ""
upstreamID := dResp.ID
if upstreamID == "" {
upstreamID = dResp.TaskID
}
if upstreamID == "" {
taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
return
}
// 使用公开 task_xxxx ID 返回给客户端
dResp.ID = info.PublicTaskID
dResp.TaskID = info.PublicTaskID
c.JSON(http.StatusOK, dResp)
return dResp.ID, responseBody, nil
return upstreamID, responseBody, nil
}
// FetchTask fetch task status
@@ -192,7 +304,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
taskResult.Status = model.TaskStatusInProgress
case "completed":
taskResult.Status = model.TaskStatusSuccess
taskResult.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, resTask.ID)
// Url intentionally left empty — the caller constructs the proxy URL using the public task ID
case "failed", "cancelled":
taskResult.Status = model.TaskStatusFailure
if resTask.Error != nil {
@@ -210,5 +322,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
return task.Data, nil
data := task.Data
var err error
if data, err = sjson.SetBytes(data, "id", task.TaskID); err != nil {
return nil, errors.Wrap(err, "set id failed")
}
return data, nil
}

View File

@@ -2,18 +2,16 @@ package suno
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
@@ -21,11 +19,16 @@ import (
)
type TaskAdaptor struct {
taskcommon.BaseBilling
ChannelType int
}
// ParseTaskResult is not used for Suno tasks.
// Suno polling uses a dedicated batch-fetch path (service.UpdateSunoTasks) that
// receives dto.TaskResponse[[]dto.SunoDataResponse] from the upstream /fetch API.
// This differs from the per-task polling used by video adaptors.
func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) {
return nil, fmt.Errorf("not implement") // todo implement this method if needed
return nil, fmt.Errorf("suno uses batch polling via UpdateSunoTasks, ParseTaskResult is not applicable")
}
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
@@ -47,13 +50,13 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
return
}
if sunoRequest.ContinueClipId != "" {
if sunoRequest.TaskID == "" {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
return
}
info.OriginTaskID = sunoRequest.TaskID
}
//if sunoRequest.ContinueClipId != "" {
// if sunoRequest.TaskID == "" {
// taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
// return
// }
// info.OriginTaskID = sunoRequest.TaskID
//}
info.Action = action
c.Set("task_request", sunoRequest)
@@ -76,12 +79,9 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
sunoRequest, ok := c.Get("task_request")
if !ok {
err := common.UnmarshalBodyReusable(c, &sunoRequest)
if err != nil {
return nil, err
}
return nil, fmt.Errorf("task_request not found in context")
}
data, err := json.Marshal(sunoRequest)
data, err := common.Marshal(sunoRequest)
if err != nil {
return nil, err
}
@@ -99,7 +99,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
return
}
var sunoResponse dto.TaskResponse[string]
err = json.Unmarshal(responseBody, &sunoResponse)
err = common.Unmarshal(responseBody, &sunoResponse)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
return
@@ -109,17 +109,13 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
return
}
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody))
if err != nil {
taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
return
// 使用公开 task_xxxx ID 替换上游 ID 返回给客户端
publicResponse := dto.TaskResponse[string]{
Code: sunoResponse.Code,
Message: sunoResponse.Message,
Data: info.PublicTaskID,
}
c.JSON(http.StatusOK, publicResponse)
return sunoResponse.Data, nil, nil
}
@@ -134,7 +130,7 @@ func (a *TaskAdaptor) GetChannelName() string {
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
byteBody, err := json.Marshal(body)
byteBody, err := common.Marshal(body)
if err != nil {
return nil, err
}
@@ -144,13 +140,6 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
common.SysLog(fmt.Sprintf("Get Task error: %v", err))
return nil, err
}
defer req.Body.Close()
// 设置超时时间
timeout := time.Second * 15
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+key)
client, err := service.GetHttpClientWithProxy(proxy)

View File

@@ -0,0 +1,95 @@
package taskcommon
import (
"encoding/base64"
"fmt"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
)
// UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip.
// This replaces the repeated pattern: json.Marshal(metadata) → json.Unmarshal(bytes, &target).
func UnmarshalMetadata(metadata map[string]any, target any) error {
if metadata == nil {
return nil
}
metaBytes, err := common.Marshal(metadata)
if err != nil {
return fmt.Errorf("marshal metadata failed: %w", err)
}
if err := common.Unmarshal(metaBytes, target); err != nil {
return fmt.Errorf("unmarshal metadata failed: %w", err)
}
return nil
}
// DefaultString returns val if non-empty, otherwise fallback.
func DefaultString(val, fallback string) string {
if val == "" {
return fallback
}
return val
}
// DefaultInt returns val if non-zero, otherwise fallback.
func DefaultInt(val, fallback int) int {
if val == 0 {
return fallback
}
return val
}
// EncodeLocalTaskID encodes an upstream operation name to a URL-safe base64 string.
// Used by Gemini/Vertex to store upstream names as task IDs.
func EncodeLocalTaskID(name string) string {
return base64.RawURLEncoding.EncodeToString([]byte(name))
}
// DecodeLocalTaskID decodes a base64-encoded upstream operation name.
func DecodeLocalTaskID(id string) (string, error) {
b, err := base64.RawURLEncoding.DecodeString(id)
if err != nil {
return "", err
}
return string(b), nil
}
// BuildProxyURL constructs the video proxy URL using the public task ID.
// e.g., "https://your-server.com/v1/videos/task_xxxx/content"
func BuildProxyURL(taskID string) string {
return fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
}
// Status-to-progress mapping constants for polling updates.
const (
ProgressSubmitted = "10%"
ProgressQueued = "20%"
ProgressInProgress = "30%"
ProgressComplete = "100%"
)
// ---------------------------------------------------------------------------
// BaseBilling — embeddable no-op implementations for TaskAdaptor billing methods.
// Adaptors that do not need custom billing can embed this struct directly.
// ---------------------------------------------------------------------------
type BaseBilling struct{}
// EstimateBilling returns nil (no extra ratios; use base model price).
func (BaseBilling) EstimateBilling(_ *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
return nil
}
// AdjustBillingOnSubmit returns nil (no submit-time adjustment).
func (BaseBilling) AdjustBillingOnSubmit(_ *relaycommon.RelayInfo, _ []byte) map[string]float64 {
return nil
}
// AdjustBillingOnComplete returns 0 (keep pre-charged amount).
func (BaseBilling) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
return 0
}

View File

@@ -2,13 +2,12 @@ package vertex
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
@@ -17,6 +16,7 @@ import (
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
vertexcore "github.com/QuantumNous/new-api/relay/channel/vertex"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
@@ -62,6 +62,7 @@ type operationResponse struct {
// ============================
type TaskAdaptor struct {
taskcommon.BaseBilling
ChannelType int
apiKey string
baseURL string
@@ -82,10 +83,10 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
adc := &vertexcore.Credentials{}
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
return "", fmt.Errorf("failed to decode credentials: %w", err)
}
modelName := info.OriginModelName
modelName := info.UpstreamModelName
if modelName == "" {
modelName = "veo-3.0-generate-001"
}
@@ -116,7 +117,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
req.Header.Set("Accept", "application/json")
adc := &vertexcore.Credentials{}
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
return fmt.Errorf("failed to decode credentials: %w", err)
}
@@ -133,6 +134,28 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
return nil
}
// EstimateBilling 根据用户请求中的 sampleCount 计算 OtherRatios。
func (a *TaskAdaptor) EstimateBilling(c *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
sampleCount := 1
v, ok := c.Get("task_request")
if ok {
req := v.(relaycommon.TaskSubmitReq)
if req.Metadata != nil {
if sc, exists := req.Metadata["sampleCount"]; exists {
if i, ok := sc.(int); ok && i > 0 {
sampleCount = i
}
if f, ok := sc.(float64); ok && int(f) > 0 {
sampleCount = int(f)
}
}
}
}
return map[string]float64{
"sampleCount": float64(sampleCount),
}
}
// BuildRequestBody converts request into Vertex specific format.
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
v, ok := c.Get("task_request")
@@ -166,25 +189,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
return nil, fmt.Errorf("sampleCount must be greater than 0")
}
// if req.Duration > 0 {
// body.Parameters["durationSeconds"] = req.Duration
// } else if req.Seconds != "" {
// seconds, err := strconv.Atoi(req.Seconds)
// if err != nil {
// return nil, errors.Wrap(err, "convert seconds to int failed")
// }
// body.Parameters["durationSeconds"] = seconds
// }
info.PriceData.OtherRatios = map[string]float64{
"sampleCount": float64(body.Parameters["sampleCount"].(int)),
}
// if v, ok := body.Parameters["durationSeconds"]; ok {
// info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int))
// }
data, err := json.Marshal(body)
data, err := common.Marshal(body)
if err != nil {
return nil, err
}
@@ -205,14 +210,19 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
_ = resp.Body.Close()
var s submitResponse
if err := json.Unmarshal(responseBody, &s); err != nil {
if err := common.Unmarshal(responseBody, &s); err != nil {
return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
}
if strings.TrimSpace(s.Name) == "" {
return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
}
localID := encodeLocalTaskID(s.Name)
c.JSON(http.StatusOK, gin.H{"task_id": localID})
localID := taskcommon.EncodeLocalTaskID(s.Name)
ov := dto.NewOpenAIVideo()
ov.ID = info.PublicTaskID
ov.TaskID = info.PublicTaskID
ov.CreatedAt = time.Now().Unix()
ov.Model = info.OriginModelName
c.JSON(http.StatusOK, ov)
return localID, responseBody, nil
}
@@ -225,7 +235,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
if !ok {
return nil, fmt.Errorf("invalid task_id")
}
upstreamName, err := decodeLocalTaskID(taskID)
upstreamName, err := taskcommon.DecodeLocalTaskID(taskID)
if err != nil {
return nil, fmt.Errorf("decode task_id failed: %w", err)
}
@@ -245,12 +255,12 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
}
payload := map[string]string{"operationName": upstreamName}
data, err := json.Marshal(payload)
data, err := common.Marshal(payload)
if err != nil {
return nil, err
}
adc := &vertexcore.Credentials{}
if err := json.Unmarshal([]byte(key), adc); err != nil {
if err := common.Unmarshal([]byte(key), adc); err != nil {
return nil, fmt.Errorf("failed to decode credentials: %w", err)
}
token, err := vertexcore.AcquireAccessToken(*adc, proxy)
@@ -274,7 +284,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
var op operationResponse
if err := json.Unmarshal(respBody, &op); err != nil {
if err := common.Unmarshal(respBody, &op); err != nil {
return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
}
ti := &relaycommon.TaskInfo{}
@@ -338,7 +348,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
upstreamName, err := decodeLocalTaskID(task.TaskID)
// Use GetUpstreamTaskID() to get the real upstream operation name for model extraction.
// task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name.
upstreamTaskID := task.GetUpstreamTaskID()
upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID)
if err != nil {
upstreamName = ""
}
@@ -353,8 +366,8 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
v.SetProgressStr(task.Progress)
v.CreatedAt = task.CreatedAt
v.CompletedAt = task.UpdatedAt
if strings.HasPrefix(task.FailReason, "data:") && len(task.FailReason) > 0 {
v.SetMetadata("url", task.FailReason)
if resultURL := task.GetResultURL(); strings.HasPrefix(resultURL, "data:") && len(resultURL) > 0 {
v.SetMetadata("url", resultURL)
}
return common.Marshal(v)
@@ -364,18 +377,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
// helpers
// ============================
func encodeLocalTaskID(name string) string {
return base64.RawURLEncoding.EncodeToString([]byte(name))
}
func decodeLocalTaskID(local string) (string, error) {
b, err := base64.RawURLEncoding.DecodeString(local)
if err != nil {
return "", err
}
return string(b), nil
}
var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`)
func extractRegionFromOperationName(name string) string {

View File

@@ -2,7 +2,6 @@ package vidu
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -16,6 +15,7 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
@@ -73,6 +73,7 @@ type creation struct {
// ============================
type TaskAdaptor struct {
taskcommon.BaseBilling
ChannelType int
baseURL string
}
@@ -115,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
}
req := v.(relaycommon.TaskSubmitReq)
body, err := a.convertToRequestPayload(&req)
body, err := a.convertToRequestPayload(&req, info)
if err != nil {
return nil, err
}
@@ -127,7 +128,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
}
}
data, err := json.Marshal(body)
data, err := common.Marshal(body)
if err != nil {
return nil, err
}
@@ -168,7 +169,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
}
var vResp responsePayload
err = json.Unmarshal(responseBody, &vResp)
err = common.Unmarshal(responseBody, &vResp)
if err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError)
return
@@ -180,8 +181,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
}
ov := dto.NewOpenAIVideo()
ov.ID = vResp.TaskId
ov.TaskID = vResp.TaskId
ov.ID = info.PublicTaskID
ov.TaskID = info.PublicTaskID
ov.CreatedAt = time.Now().Unix()
ov.Model = info.OriginModelName
c.JSON(http.StatusOK, ov)
@@ -223,47 +224,27 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers
// ============================
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
r := requestPayload{
Model: defaultString(req.Model, "viduq1"),
Model: taskcommon.DefaultString(info.UpstreamModelName, "viduq1"),
Images: req.Images,
Prompt: req.Prompt,
Duration: defaultInt(req.Duration, 5),
Resolution: defaultString(req.Size, "1080p"),
Duration: taskcommon.DefaultInt(req.Duration, 5),
Resolution: taskcommon.DefaultString(req.Size, "1080p"),
MovementAmplitude: "auto",
Bgm: false,
}
metadata := req.Metadata
medaBytes, err := json.Marshal(metadata)
if err != nil {
return nil, errors.Wrap(err, "metadata marshal metadata failed")
}
err = json.Unmarshal(medaBytes, &r)
if err != nil {
if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
return &r, nil
}
func defaultString(value, defaultValue string) string {
if value == "" {
return defaultValue
}
return value
}
func defaultInt(value, defaultValue int) int {
if value == 0 {
return defaultValue
}
return value
}
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
taskInfo := &relaycommon.TaskInfo{}
var taskResp taskResultResponse
err := json.Unmarshal(respBody, &taskResp)
err := common.Unmarshal(respBody, &taskResp)
if err != nil {
return nil, errors.Wrap(err, "failed to unmarshal response body")
}
@@ -293,7 +274,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
var viduResp taskResultResponse
if err := json.Unmarshal(originTask.Data, &viduResp); err != nil {
if err := common.Unmarshal(originTask.Data, &viduResp); err != nil {
return nil, errors.Wrap(err, "unmarshal vidu task data failed")
}
@@ -315,6 +296,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
}
}
jsonData, _ := common.Marshal(openAIVideo)
return jsonData, nil
return common.Marshal(openAIVideo)
}

View File

@@ -75,7 +75,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
chatJSON, err = relaycommon.RemoveDisabledFields(chatJSON, info.ChannelOtherSettings)
chatJSON, err = relaycommon.RemoveDisabledFields(chatJSON, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
@@ -119,7 +119,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}

View File

@@ -146,7 +146,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
}
// remove disabled fields for Claude API
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}

View File

@@ -6,6 +6,9 @@ import (
"testing"
"github.com/QuantumNous/new-api/types"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/setting/model_setting"
)
func TestApplyParamOverrideTrimPrefix(t *testing.T) {
@@ -1311,6 +1314,76 @@ func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) {
}
}
func TestRemoveDisabledFieldsSkipWhenChannelPassThroughEnabled(t *testing.T) {
input := `{
"service_tier":"flex",
"safety_identifier":"user-123",
"store":true,
"stream_options":{"include_obfuscation":false}
}`
settings := dto.ChannelOtherSettings{}
out, err := RemoveDisabledFields([]byte(input), settings, true)
if err != nil {
t.Fatalf("RemoveDisabledFields returned error: %v", err)
}
assertJSONEqual(t, input, string(out))
}
func TestRemoveDisabledFieldsSkipWhenGlobalPassThroughEnabled(t *testing.T) {
original := model_setting.GetGlobalSettings().PassThroughRequestEnabled
model_setting.GetGlobalSettings().PassThroughRequestEnabled = true
t.Cleanup(func() {
model_setting.GetGlobalSettings().PassThroughRequestEnabled = original
})
input := `{
"service_tier":"flex",
"safety_identifier":"user-123",
"stream_options":{"include_obfuscation":false}
}`
settings := dto.ChannelOtherSettings{}
out, err := RemoveDisabledFields([]byte(input), settings, false)
if err != nil {
t.Fatalf("RemoveDisabledFields returned error: %v", err)
}
assertJSONEqual(t, input, string(out))
}
func TestRemoveDisabledFieldsDefaultFiltering(t *testing.T) {
input := `{
"service_tier":"flex",
"inference_geo":"eu",
"safety_identifier":"user-123",
"store":true,
"stream_options":{"include_obfuscation":false}
}`
settings := dto.ChannelOtherSettings{}
out, err := RemoveDisabledFields([]byte(input), settings, false)
if err != nil {
t.Fatalf("RemoveDisabledFields returned error: %v", err)
}
assertJSONEqual(t, `{"store":true}`, string(out))
}
func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) {
input := `{
"inference_geo":"eu",
"store":true
}`
settings := dto.ChannelOtherSettings{
AllowInferenceGeo: true,
}
out, err := RemoveDisabledFields([]byte(input), settings, false)
if err != nil {
t.Fatalf("RemoveDisabledFields returned error: %v", err)
}
assertJSONEqual(t, `{"inference_geo":"eu","store":true}`, string(out))
}
func assertJSONEqual(t *testing.T, want, got string) {
t.Helper()

View File

@@ -119,8 +119,12 @@ type RelayInfo struct {
SendResponseCount int
ReceivedResponseCount int
FinalPreConsumedQuota int // 最终预消耗的配额
// ForcePreConsume 为 true 时禁用 BillingSession 的信任额度旁路,
// 强制预扣全额。用于异步任务(视频/音乐生成等),因为请求返回后任务仍在运行,
// 必须在提交前锁定全额。
ForcePreConsume bool
// Billing 是计费会话,封装了预扣费/结算/退款的统一生命周期。
// 免费模型和按次计费MJ/Task时为 nil。
// 免费模型时为 nil。
Billing BillingSettler
// BillingSource indicates whether this request is billed from wallet quota or subscription.
// "" or "wallet" => wallet; "subscription" => subscription
@@ -153,7 +157,8 @@ type RelayInfo struct {
// RequestConversionChain records request format conversions in order, e.g.
// ["openai", "openai_responses"] or ["openai", "claude"].
RequestConversionChain []types.RelayFormat
// 最终请求到上游的格式 TODO: 当前仅设置了Claude
// 最终请求到上游的格式。可由 adaptor 显式设置;
// 若为空,调用 GetFinalRequestRelayFormat 会回退到 RequestConversionChain 的最后一项或 RelayFormat。
FinalRequestRelayFormat types.RelayFormat
ThinkingContentInfo
@@ -552,8 +557,10 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req
return nil, errors.New("request is not a OpenAIResponsesCompactionRequest")
case types.RelayFormatTask:
info = genBaseRelayInfo(c, nil)
info.TaskRelayInfo = &TaskRelayInfo{}
case types.RelayFormatMjProxy:
info = genBaseRelayInfo(c, nil)
info.TaskRelayInfo = &TaskRelayInfo{}
default:
err = errors.New("invalid relay format")
}
@@ -600,6 +607,19 @@ func (info *RelayInfo) AppendRequestConversion(format types.RelayFormat) {
info.RequestConversionChain = append(info.RequestConversionChain, format)
}
func (info *RelayInfo) GetFinalRequestRelayFormat() types.RelayFormat {
if info == nil {
return ""
}
if info.FinalRequestRelayFormat != "" {
return info.FinalRequestRelayFormat
}
if n := len(info.RequestConversionChain); n > 0 {
return info.RequestConversionChain[n-1]
}
return info.RelayFormat
}
func GenRelayInfoResponsesCompaction(c *gin.Context, request *dto.OpenAIResponsesCompactionRequest) *RelayInfo {
info := genBaseRelayInfo(c, request)
if info.RelayMode == relayconstant.RelayModeUnknown {
@@ -635,8 +655,16 @@ func (info *RelayInfo) HasSendResponse() bool {
type TaskRelayInfo struct {
Action string
OriginTaskID string
// PublicTaskID 是提交时预生成的 task_xxxx 格式公开 ID
// 供 DoResponse 在返回给客户端时使用(避免暴露上游真实 ID
PublicTaskID string
ConsumeQuota bool
// LockedChannel holds the full channel object when the request is bound to
// a specific channel (e.g., remix on origin task's channel). Stored as any
// to avoid an import cycle with model; callers type-assert to *model.Channel.
LockedChannel any
}
type TaskSubmitReq struct {
@@ -694,11 +722,11 @@ func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error {
func (t *TaskSubmitReq) UnmarshalMetadata(v any) error {
metadata := t.Metadata
if metadata != nil {
metadataBytes, err := json.Marshal(metadata)
metadataBytes, err := common.Marshal(metadata)
if err != nil {
return fmt.Errorf("marshal metadata failed: %w", err)
}
err = json.Unmarshal(metadataBytes, v)
err = common.Unmarshal(metadataBytes, v)
if err != nil {
return fmt.Errorf("unmarshal metadata to target failed: %w", err)
}
@@ -727,9 +755,15 @@ func FailTaskInfo(reason string) *TaskInfo {
// RemoveDisabledFields 从请求 JSON 数据中移除渠道设置中禁用的字段
// service_tier: 服务层级字段可能导致额外计费OpenAI、Claude、Responses API 支持)
// inference_geo: Claude 数据驻留推理区域字段(仅 Claude 支持,默认过滤)
// store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用)
// safety_identifier: 安全标识符,用于向 OpenAI 报告违规用户(仅 OpenAI 支持,涉及用户隐私)
func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings) ([]byte, error) {
// stream_options.include_obfuscation: 响应流混淆控制字段(仅 OpenAI Responses API 支持)
func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings, channelPassThroughEnabled bool) ([]byte, error) {
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled {
return jsonData, nil
}
var data map[string]interface{}
if err := common.Unmarshal(jsonData, &data); err != nil {
common.SysError("RemoveDisabledFields Unmarshal error :" + err.Error())
@@ -743,6 +777,13 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
}
}
// 默认移除 inference_geo除非明确允许避免在未授权情况下透传数据驻留区域
if !channelOtherSettings.AllowInferenceGeo {
if _, exists := data["inference_geo"]; exists {
delete(data, "inference_geo")
}
}
// 默认允许 store 透传,除非明确禁用(禁用可能影响 Codex 使用)
if channelOtherSettings.DisableStore {
if _, exists := data["store"]; exists {
@@ -757,6 +798,22 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
}
}
// 默认移除 stream_options.include_obfuscation除非明确允许避免关闭响应流混淆保护
if !channelOtherSettings.AllowIncludeObfuscation {
if streamOptionsAny, exists := data["stream_options"]; exists {
if streamOptions, ok := streamOptionsAny.(map[string]interface{}); ok {
if _, includeExists := streamOptions["include_obfuscation"]; includeExists {
delete(streamOptions, "include_obfuscation")
}
if len(streamOptions) == 0 {
delete(data, "stream_options")
} else {
data["stream_options"] = streamOptions
}
}
}
}
jsonDataAfter, err := common.Marshal(data)
if err != nil {
common.SysError("RemoveDisabledFields Marshal error :" + err.Error())

View File

@@ -0,0 +1,40 @@
package common
import (
"testing"
"github.com/QuantumNous/new-api/types"
"github.com/stretchr/testify/require"
)
func TestRelayInfoGetFinalRequestRelayFormatPrefersExplicitFinal(t *testing.T) {
info := &RelayInfo{
RelayFormat: types.RelayFormatOpenAI,
RequestConversionChain: []types.RelayFormat{types.RelayFormatOpenAI, types.RelayFormatClaude},
FinalRequestRelayFormat: types.RelayFormatOpenAIResponses,
}
require.Equal(t, types.RelayFormat(types.RelayFormatOpenAIResponses), info.GetFinalRequestRelayFormat())
}
func TestRelayInfoGetFinalRequestRelayFormatFallsBackToConversionChain(t *testing.T) {
info := &RelayInfo{
RelayFormat: types.RelayFormatOpenAI,
RequestConversionChain: []types.RelayFormat{types.RelayFormatOpenAI, types.RelayFormatClaude},
}
require.Equal(t, types.RelayFormat(types.RelayFormatClaude), info.GetFinalRequestRelayFormat())
}
func TestRelayInfoGetFinalRequestRelayFormatFallsBackToRelayFormat(t *testing.T) {
info := &RelayInfo{
RelayFormat: types.RelayFormatGemini,
}
require.Equal(t, types.RelayFormat(types.RelayFormatGemini), info.GetFinalRequestRelayFormat())
}
func TestRelayInfoGetFinalRequestRelayFormatNilReceiver(t *testing.T) {
var info *RelayInfo
require.Equal(t, types.RelayFormat(""), info.GetFinalRequestRelayFormat())
}

View File

@@ -173,16 +173,10 @@ func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
}
info.PriceData.OtherRatios = map[string]float64{
"seconds": float64(seconds),
"size": 1,
}
if lo.Contains([]string{"1792x1024", "1024x1792"}, size) {
info.PriceData.OtherRatios["size"] = 1.666667
}
// OtherRatios 已移到 Sora adaptor 的 EstimateBilling 中设置
}
info.Action = action
storeTaskRequest(c, info, action, req)
return nil
}

View File

@@ -165,7 +165,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
}
// remove disabled fields for OpenAI API
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
@@ -232,7 +232,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
}
if originUsage != nil {
service.ObserveChannelAffinityUsageCacheFromContext(ctx, usage)
service.ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
}
adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason)
@@ -336,7 +336,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
var audioInputQuota decimal.Decimal
var audioInputPrice float64
isClaudeUsageSemantic := relayInfo.FinalRequestRelayFormat == types.RelayFormatClaude
isClaudeUsageSemantic := relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude
if !relayInfo.PriceData.UsePrice {
baseTokens := dPromptTokens
// 减去 cached tokens

View File

@@ -140,7 +140,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
}
// ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData {
func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PriceData {
groupRatioInfo := HandleGroupRatio(c, info)
modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
@@ -154,7 +154,18 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.
}
}
quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
priceData := types.PerCallPriceData{
// 免费模型检测(与 ModelPriceHelper 对齐)
freeModel := false
if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume {
if groupRatioInfo.GroupRatio == 0 || modelPrice == 0 {
quota = 0
freeModel = true
}
}
priceData := types.PriceData{
FreeModel: freeModel,
ModelPrice: modelPrice,
Quota: quota,
GroupRatioInfo: groupRatioInfo,

View File

@@ -176,10 +176,32 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
})
}
dataChan := make(chan string, 10)
wg.Add(1)
gopool.Go(func() {
defer func() {
wg.Done()
if r := recover(); r != nil {
logger.LogError(c, fmt.Sprintf("data handler goroutine panic: %v", r))
}
common.SafeSendBool(stopChan, true)
}()
for data := range dataChan {
writeMutex.Lock()
success := dataHandler(data)
writeMutex.Unlock()
if !success {
return
}
}
})
// Scanner goroutine with improved error handling
wg.Add(1)
common.RelayCtxGo(ctx, func() {
defer func() {
close(dataChan)
wg.Done()
if r := recover(); r != nil {
logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
@@ -215,27 +237,16 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
continue
}
data = data[5:]
data = strings.TrimLeft(data, " ")
data = strings.TrimSuffix(data, "\r")
data = strings.TrimSpace(data)
if data == "" {
continue
}
if !strings.HasPrefix(data, "[DONE]") {
info.SetFirstResponseTime()
info.ReceivedResponseCount++
// 使用超时机制防止写操作阻塞
done := make(chan bool, 1)
gopool.Go(func() {
writeMutex.Lock()
defer writeMutex.Unlock()
done <- dataHandler(data)
})
select {
case success := <-done:
if !success {
return
}
case <-time.After(10 * time.Second):
logger.LogError(c, "data handler timeout")
return
case dataChan <- data:
case <-ctx.Done():
return
case <-stopChan:

View File

@@ -0,0 +1,521 @@
package helper
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/QuantumNous/new-api/constant"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func init() {
gin.SetMode(gin.TestMode)
}
func setupStreamTest(t *testing.T, body io.Reader) (*gin.Context, *http.Response, *relaycommon.RelayInfo) {
t.Helper()
oldTimeout := constant.StreamingTimeout
constant.StreamingTimeout = 30
t.Cleanup(func() {
constant.StreamingTimeout = oldTimeout
})
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
resp := &http.Response{
Body: io.NopCloser(body),
}
info := &relaycommon.RelayInfo{
ChannelMeta: &relaycommon.ChannelMeta{},
}
return c, resp, info
}
func buildSSEBody(n int) string {
var b strings.Builder
for i := 0; i < n; i++ {
fmt.Fprintf(&b, "data: {\"id\":%d,\"choices\":[{\"delta\":{\"content\":\"token_%d\"}}]}\n", i, i)
}
b.WriteString("data: [DONE]\n")
return b.String()
}
// slowReader wraps a reader and injects a delay before each Read call,
// simulating a slow upstream that trickles data.
type slowReader struct {
r io.Reader
delay time.Duration
}
func (s *slowReader) Read(p []byte) (int, error) {
time.Sleep(s.delay)
return s.r.Read(p)
}
// ---------- Basic correctness ----------
func TestStreamScannerHandler_NilInputs(t *testing.T) {
t.Parallel()
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
StreamScannerHandler(c, nil, info, func(data string) bool { return true })
StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil)
}
func TestStreamScannerHandler_EmptyBody(t *testing.T) {
t.Parallel()
c, resp, info := setupStreamTest(t, strings.NewReader(""))
var called atomic.Bool
StreamScannerHandler(c, resp, info, func(data string) bool {
called.Store(true)
return true
})
assert.False(t, called.Load(), "handler should not be called for empty body")
}
func TestStreamScannerHandler_1000Chunks(t *testing.T) {
t.Parallel()
const numChunks = 1000
body := buildSSEBody(numChunks)
c, resp, info := setupStreamTest(t, strings.NewReader(body))
var count atomic.Int64
StreamScannerHandler(c, resp, info, func(data string) bool {
count.Add(1)
return true
})
assert.Equal(t, int64(numChunks), count.Load())
assert.Equal(t, numChunks, info.ReceivedResponseCount)
}
func TestStreamScannerHandler_10000Chunks(t *testing.T) {
t.Parallel()
const numChunks = 10000
body := buildSSEBody(numChunks)
c, resp, info := setupStreamTest(t, strings.NewReader(body))
var count atomic.Int64
start := time.Now()
StreamScannerHandler(c, resp, info, func(data string) bool {
count.Add(1)
return true
})
elapsed := time.Since(start)
assert.Equal(t, int64(numChunks), count.Load())
assert.Equal(t, numChunks, info.ReceivedResponseCount)
t.Logf("10000 chunks processed in %v", elapsed)
}
func TestStreamScannerHandler_OrderPreserved(t *testing.T) {
t.Parallel()
const numChunks = 500
body := buildSSEBody(numChunks)
c, resp, info := setupStreamTest(t, strings.NewReader(body))
var mu sync.Mutex
received := make([]string, 0, numChunks)
StreamScannerHandler(c, resp, info, func(data string) bool {
mu.Lock()
received = append(received, data)
mu.Unlock()
return true
})
require.Equal(t, numChunks, len(received))
for i := 0; i < numChunks; i++ {
expected := fmt.Sprintf("{\"id\":%d,\"choices\":[{\"delta\":{\"content\":\"token_%d\"}}]}", i, i)
assert.Equal(t, expected, received[i], "chunk %d out of order", i)
}
}
func TestStreamScannerHandler_DoneStopsScanner(t *testing.T) {
t.Parallel()
body := buildSSEBody(50) + "data: should_not_appear\n"
c, resp, info := setupStreamTest(t, strings.NewReader(body))
var count atomic.Int64
StreamScannerHandler(c, resp, info, func(data string) bool {
count.Add(1)
return true
})
assert.Equal(t, int64(50), count.Load(), "data after [DONE] must not be processed")
}
func TestStreamScannerHandler_HandlerFailureStops(t *testing.T) {
t.Parallel()
const numChunks = 200
body := buildSSEBody(numChunks)
c, resp, info := setupStreamTest(t, strings.NewReader(body))
const failAt = 50
var count atomic.Int64
StreamScannerHandler(c, resp, info, func(data string) bool {
n := count.Add(1)
return n < failAt
})
// The worker stops at failAt; the scanner may have read ahead,
// but the handler should not be called beyond failAt.
assert.Equal(t, int64(failAt), count.Load())
}
func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) {
t.Parallel()
var b strings.Builder
b.WriteString(": comment line\n")
b.WriteString("event: message\n")
b.WriteString("id: 12345\n")
b.WriteString("retry: 5000\n")
for i := 0; i < 100; i++ {
fmt.Fprintf(&b, "data: payload_%d\n", i)
b.WriteString(": interleaved comment\n")
}
b.WriteString("data: [DONE]\n")
c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
var count atomic.Int64
StreamScannerHandler(c, resp, info, func(data string) bool {
count.Add(1)
return true
})
assert.Equal(t, int64(100), count.Load())
}
func TestStreamScannerHandler_DataWithExtraSpaces(t *testing.T) {
t.Parallel()
body := "data: {\"trimmed\":true} \ndata: [DONE]\n"
c, resp, info := setupStreamTest(t, strings.NewReader(body))
var got string
StreamScannerHandler(c, resp, info, func(data string) bool {
got = data
return true
})
assert.Equal(t, "{\"trimmed\":true}", got)
}
// ---------- Decoupling: scanner not blocked by slow handler ----------
func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
t.Parallel()
// Strategy: use a slow upstream (io.Pipe, 10ms per chunk) AND a slow handler (20ms per chunk).
// If the scanner were synchronously coupled to the handler, total time would be
// ~numChunks * (10ms + 20ms) = 30ms * 50 = 1500ms.
// With decoupling, total time should be closer to
// ~numChunks * max(10ms, 20ms) = 20ms * 50 = 1000ms
// because the scanner reads ahead into the buffer while the handler processes.
const numChunks = 50
const upstreamDelay = 10 * time.Millisecond
const handlerDelay = 20 * time.Millisecond
pr, pw := io.Pipe()
go func() {
defer pw.Close()
for i := 0; i < numChunks; i++ {
fmt.Fprintf(pw, "data: {\"id\":%d}\n", i)
time.Sleep(upstreamDelay)
}
fmt.Fprint(pw, "data: [DONE]\n")
}()
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
oldTimeout := constant.StreamingTimeout
constant.StreamingTimeout = 30
t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
resp := &http.Response{Body: pr}
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
var count atomic.Int64
start := time.Now()
done := make(chan struct{})
go func() {
StreamScannerHandler(c, resp, info, func(data string) bool {
time.Sleep(handlerDelay)
count.Add(1)
return true
})
close(done)
}()
select {
case <-done:
case <-time.After(15 * time.Second):
t.Fatal("StreamScannerHandler did not complete in time")
}
elapsed := time.Since(start)
assert.Equal(t, int64(numChunks), count.Load())
coupledTime := time.Duration(numChunks) * (upstreamDelay + handlerDelay)
t.Logf("elapsed=%v, coupled_estimate=%v", elapsed, coupledTime)
// If decoupled, elapsed should be well under the coupled estimate.
assert.Less(t, elapsed, coupledTime*85/100,
"decoupled elapsed time (%v) should be significantly less than coupled estimate (%v)", elapsed, coupledTime)
}
func TestStreamScannerHandler_SlowUpstreamFastHandler(t *testing.T) {
t.Parallel()
const numChunks = 50
body := buildSSEBody(numChunks)
reader := &slowReader{r: strings.NewReader(body), delay: 2 * time.Millisecond}
c, resp, info := setupStreamTest(t, reader)
var count atomic.Int64
start := time.Now()
done := make(chan struct{})
go func() {
StreamScannerHandler(c, resp, info, func(data string) bool {
count.Add(1)
return true
})
close(done)
}()
select {
case <-done:
case <-time.After(15 * time.Second):
t.Fatal("timed out with slow upstream")
}
elapsed := time.Since(start)
assert.Equal(t, int64(numChunks), count.Load())
t.Logf("slow upstream (%d chunks, 2ms/read): %v", numChunks, elapsed)
}
// ---------- Ping tests ----------
func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) {
t.Parallel()
setting := operation_setting.GetGeneralSetting()
oldEnabled := setting.PingIntervalEnabled
oldSeconds := setting.PingIntervalSeconds
setting.PingIntervalEnabled = true
setting.PingIntervalSeconds = 1
t.Cleanup(func() {
setting.PingIntervalEnabled = oldEnabled
setting.PingIntervalSeconds = oldSeconds
})
// Create a reader that delivers data slowly: one chunk every 500ms over 3.5 seconds.
// The ping interval is 1s, so we should see at least 2 pings.
pr, pw := io.Pipe()
go func() {
defer pw.Close()
for i := 0; i < 7; i++ {
fmt.Fprintf(pw, "data: chunk_%d\n", i)
time.Sleep(500 * time.Millisecond)
}
fmt.Fprint(pw, "data: [DONE]\n")
}()
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
oldTimeout := constant.StreamingTimeout
constant.StreamingTimeout = 30
t.Cleanup(func() {
constant.StreamingTimeout = oldTimeout
})
resp := &http.Response{Body: pr}
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
var count atomic.Int64
done := make(chan struct{})
go func() {
StreamScannerHandler(c, resp, info, func(data string) bool {
count.Add(1)
return true
})
close(done)
}()
select {
case <-done:
case <-time.After(15 * time.Second):
t.Fatal("timed out waiting for stream to finish")
}
assert.Equal(t, int64(7), count.Load())
body := recorder.Body.String()
pingCount := strings.Count(body, ": PING")
t.Logf("received %d pings in response body", pingCount)
assert.GreaterOrEqual(t, pingCount, 2,
"expected at least 2 pings during 3.5s stream with 1s interval; got %d", pingCount)
}
func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) {
t.Parallel()
setting := operation_setting.GetGeneralSetting()
oldEnabled := setting.PingIntervalEnabled
oldSeconds := setting.PingIntervalSeconds
setting.PingIntervalEnabled = true
setting.PingIntervalSeconds = 1
t.Cleanup(func() {
setting.PingIntervalEnabled = oldEnabled
setting.PingIntervalSeconds = oldSeconds
})
pr, pw := io.Pipe()
go func() {
defer pw.Close()
for i := 0; i < 5; i++ {
fmt.Fprintf(pw, "data: chunk_%d\n", i)
time.Sleep(500 * time.Millisecond)
}
fmt.Fprint(pw, "data: [DONE]\n")
}()
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
oldTimeout := constant.StreamingTimeout
constant.StreamingTimeout = 30
t.Cleanup(func() {
constant.StreamingTimeout = oldTimeout
})
resp := &http.Response{Body: pr}
info := &relaycommon.RelayInfo{
DisablePing: true,
ChannelMeta: &relaycommon.ChannelMeta{},
}
var count atomic.Int64
done := make(chan struct{})
go func() {
StreamScannerHandler(c, resp, info, func(data string) bool {
count.Add(1)
return true
})
close(done)
}()
select {
case <-done:
case <-time.After(15 * time.Second):
t.Fatal("timed out")
}
assert.Equal(t, int64(5), count.Load())
body := recorder.Body.String()
pingCount := strings.Count(body, ": PING")
assert.Equal(t, 0, pingCount, "pings should be disabled when DisablePing=true")
}
func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
t.Parallel()
setting := operation_setting.GetGeneralSetting()
oldEnabled := setting.PingIntervalEnabled
oldSeconds := setting.PingIntervalSeconds
setting.PingIntervalEnabled = true
setting.PingIntervalSeconds = 1
t.Cleanup(func() {
setting.PingIntervalEnabled = oldEnabled
setting.PingIntervalSeconds = oldSeconds
})
// Slow upstream + slow handler. Total stream takes ~5 seconds.
// The ping goroutine stays alive as long as the scanner is reading,
// so pings should fire between data writes.
pr, pw := io.Pipe()
go func() {
defer pw.Close()
for i := 0; i < 10; i++ {
fmt.Fprintf(pw, "data: chunk_%d\n", i)
time.Sleep(500 * time.Millisecond)
}
fmt.Fprint(pw, "data: [DONE]\n")
}()
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
oldTimeout := constant.StreamingTimeout
constant.StreamingTimeout = 30
t.Cleanup(func() {
constant.StreamingTimeout = oldTimeout
})
resp := &http.Response{Body: pr}
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
var count atomic.Int64
done := make(chan struct{})
go func() {
StreamScannerHandler(c, resp, info, func(data string) bool {
count.Add(1)
return true
})
close(done)
}()
select {
case <-done:
case <-time.After(15 * time.Second):
t.Fatal("timed out")
}
assert.Equal(t, int64(10), count.Load())
body := recorder.Body.String()
pingCount := strings.Count(body, ": PING")
t.Logf("received %d pings interleaved with 10 chunks over 5s", pingCount)
assert.GreaterOrEqual(t, pingCount, 3,
"expected at least 3 pings during 5s stream with 1s ping interval; got %d", pingCount)
}

View File

@@ -184,7 +184,7 @@ func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyR
if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
}
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
modelName := service.CovertMjpActionToModelName(constant.MjActionSwapFace)
priceData := helper.ModelPriceHelperPerCall(c, info)
@@ -485,7 +485,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dt
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
modelName := service.CoverActionToModelName(midjRequest.Action)
modelName := service.CovertMjpActionToModelName(midjRequest.Action)
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)

View File

@@ -2,7 +2,6 @@ package relay
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
@@ -15,29 +14,33 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
relayconstant "github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
/*
Task 任务通过平台、Action 区分任务
*/
func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
info.InitChannelMeta(c)
// ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields
if info.TaskRelayInfo == nil {
info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
}
type TaskSubmitResult struct {
UpstreamTaskID string
TaskData []byte
Platform constant.TaskPlatform
Quota int
//PerCallPrice types.PriceData
}
// ResolveOriginTask 处理基于已有任务的提交remix / continuation
// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道
// (通过 info.LockedChannel重试时复用同一渠道并轮换 key
// 以及提取 OtherRatios时长、分辨率
// 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。
func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
// 检测 remix action
path := c.Request.URL.Path
if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
info.Action = constant.TaskActionRemix
}
// 提取 remix 任务的 video_id
if info.Action == constant.TaskActionRemix {
videoID := c.Param("video_id")
if strings.TrimSpace(videoID) == "" {
@@ -46,64 +49,71 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
info.OriginTaskID = videoID
}
platform := constant.TaskPlatform(c.GetString("platform"))
if info.OriginTaskID == "" {
return nil
}
// 获取原始任务信息
if info.OriginTaskID != "" {
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
return
}
if !exist {
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
return
}
if info.OriginModelName == "" {
if originTask.Properties.OriginModelName != "" {
info.OriginModelName = originTask.Properties.OriginModelName
} else if originTask.Properties.UpstreamModelName != "" {
info.OriginModelName = originTask.Properties.UpstreamModelName
} else {
var taskData map[string]interface{}
_ = json.Unmarshal(originTask.Data, &taskData)
if m, ok := taskData["model"].(string); ok && m != "" {
info.OriginModelName = m
platform = originTask.Platform
}
}
}
if originTask.ChannelId != info.ChannelId {
channel, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
return
}
if channel.Status != common.ChannelStatusEnabled {
taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
return
}
key, _, newAPIError := channel.GetNextEnabledKey()
if newAPIError != nil {
taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
return
}
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
// 查找原始任务
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
if err != nil {
return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
}
if !exist {
return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
}
info.ChannelBaseUrl = channel.GetBaseURL()
info.ChannelId = originTask.ChannelId
info.ChannelType = channel.Type
info.ApiKey = key
platform = originTask.Platform
}
// 使用原始任务的参数
if info.Action == constant.TaskActionRemix {
// 从原始任务推导模型名称
if info.OriginModelName == "" {
if originTask.Properties.OriginModelName != "" {
info.OriginModelName = originTask.Properties.OriginModelName
} else if originTask.Properties.UpstreamModelName != "" {
info.OriginModelName = originTask.Properties.UpstreamModelName
} else {
var taskData map[string]interface{}
_ = json.Unmarshal(originTask.Data, &taskData)
_ = common.Unmarshal(originTask.Data, &taskData)
if m, ok := taskData["model"].(string); ok && m != "" {
info.OriginModelName = m
}
}
}
// 锁定到原始任务的渠道(重试时复用同一渠道,轮换 key
ch, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil {
return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
}
if ch.Status != common.ChannelStatusEnabled {
return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
}
info.LockedChannel = ch
if originTask.ChannelId != info.ChannelId {
key, _, newAPIError := ch.GetNextEnabledKey()
if newAPIError != nil {
return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
}
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
common.SetContextKey(c, constant.ContextKeyChannelType, ch.Type)
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, ch.GetBaseURL())
common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
info.ChannelBaseUrl = ch.GetBaseURL()
info.ChannelId = originTask.ChannelId
info.ChannelType = ch.Type
info.ApiKey = key
}
// 提取 remix 参数(时长、分辨率 → OtherRatios
if info.Action == constant.TaskActionRemix {
if originTask.PrivateData.BillingContext != nil {
// 新的 remix 逻辑:直接从原始任务的 BillingContext 中提取 OtherRatios如果存在
for s, f := range originTask.PrivateData.BillingContext.OtherRatios {
info.PriceData.AddOtherRatio(s, f)
}
} else {
// 旧的 remix 逻辑:直接从 task data 解析 seconds 和 size如果存在
var taskData map[string]interface{}
_ = common.Unmarshal(originTask.Data, &taskData)
secondsStr, _ := taskData["seconds"].(string)
seconds, _ := strconv.Atoi(secondsStr)
if seconds <= 0 {
@@ -120,167 +130,146 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
}
}
}
return nil
}
// RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次):
// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 →
// 估算计费(EstimateBilling) → 计算价格 → 预扣费(仅首次)→
// 构建/发送/解析上游请求 → 提交后计费调整(AdjustBillingOnSubmit)。
// 控制器负责 defer Refund 和成功后 Settle。
func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) {
info.InitChannelMeta(c)
// 1. 确定 platform → 创建适配器 → 验证请求
platform := constant.TaskPlatform(c.GetString("platform"))
if platform == "" {
platform = GetTaskPlatform(c)
}
info.InitChannelMeta(c)
adaptor := GetTaskAdaptor(platform)
if adaptor == nil {
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
}
adaptor.Init(info)
// get & validate taskRequest 获取并验证文本请求
taskErr = adaptor.ValidateRequestAndSetAction(c, info)
if taskErr != nil {
return
if taskErr := adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil {
return nil, taskErr
}
// 2. 确定模型名称
modelName := info.OriginModelName
if modelName == "" {
modelName = service.CoverTaskActionToModelName(platform, info.Action)
}
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
if !success {
defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName]
if !ok {
modelPrice = float64(common.PreConsumedQuota) / common.QuotaPerUnit
} else {
modelPrice = defaultPrice
// 2.5 应用渠道的模型映射(与同步任务对齐)
info.OriginModelName = modelName
info.UpstreamModelName = modelName
if err := helper.ModelMappedHelper(c, info, nil); err != nil {
return nil, service.TaskErrorWrapperLocal(err, "model_mapping_failed", http.StatusBadRequest)
}
// 3. 预生成公开 task ID仅首次
if info.PublicTaskID == "" {
info.PublicTaskID = model.GenerateTaskID()
}
// 4. 价格计算:基础模型价格
info.OriginModelName = modelName
info.PriceData = helper.ModelPriceHelperPerCall(c, info)
// 5. 计费估算:让适配器根据用户请求提供 OtherRatios时长、分辨率等
// 必须在 ModelPriceHelperPerCall 之后调用(它会重建 PriceData
// ResolveOriginTask 可能已在 remix 路径中预设了 OtherRatios此处合并。
if estimatedRatios := adaptor.EstimateBilling(c, info); len(estimatedRatios) > 0 {
for k, v := range estimatedRatios {
info.PriceData.AddOtherRatio(k, v)
}
}
// 处理 auto 分组:从 context 获取实际选中的分组
// 当使用 auto 分组时Distribute 中间件会将实际选中的分组存储在 ContextKeyAutoGroup 中
if autoGroup, exists := common.GetContextKey(c, constant.ContextKeyAutoGroup); exists {
if groupStr, ok := autoGroup.(string); ok && groupStr != "" {
info.UsingGroup = groupStr
}
}
// 预扣
groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
var ratio float64
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
if hasUserGroupRatio {
ratio = modelPrice * userGroupRatio
} else {
ratio = modelPrice * groupRatio
}
// FIXME: 临时修补,支持任务仅按次计费
// 6. 将 OtherRatios 应用到基础额度
if !common.StringsContains(constant.TaskPricePatches, modelName) {
if len(info.PriceData.OtherRatios) > 0 {
for _, ra := range info.PriceData.OtherRatios {
if 1.0 != ra {
ratio *= ra
}
for _, ra := range info.PriceData.OtherRatios {
if ra != 1.0 {
info.PriceData.Quota = int(float64(info.PriceData.Quota) * ra)
}
}
}
println(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio))
userQuota, err := model.GetUserQuota(info.UserId, false)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
return
}
quota := int(ratio * common.QuotaPerUnit)
if userQuota-quota < 0 {
taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
return
// 7. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过)
if info.Billing == nil && !info.PriceData.FreeModel {
info.ForcePreConsume = true
if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil {
return nil, service.TaskErrorFromAPIError(apiErr)
}
}
// build body
// 8. 构建请求体
requestBody, err := adaptor.BuildRequestBody(c, info)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
return
return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
}
// do request
// 9. 发送请求
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
return
return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
// handle response
if resp != nil && resp.StatusCode != http.StatusOK {
responseBody, _ := io.ReadAll(resp.Body)
taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
return
return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
}
defer func() {
// release quota
if info.ConsumeQuota && taskErr == nil {
// 10. 返回 OtherRatios 给下游header 必须在 DoResponse 写 body 之前设置)
otherRatios := info.PriceData.OtherRatios
if otherRatios == nil {
otherRatios = map[string]float64{}
}
ratiosJSON, _ := common.Marshal(otherRatios)
c.Header("X-New-Api-Other-Ratios", string(ratiosJSON))
err := service.PostConsumeQuota(info, quota, 0, true)
if err != nil {
common.SysLog("error consuming token remain quota: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
//gRatio := groupRatio
//if hasUserGroupRatio {
// gRatio = userGroupRatio
//}
logContent := fmt.Sprintf("操作 %s", info.Action)
// FIXME: 临时修补,支持任务仅按次计费
if common.StringsContains(constant.TaskPricePatches, modelName) {
logContent = fmt.Sprintf("%s按次计费", logContent)
} else {
if len(info.PriceData.OtherRatios) > 0 {
var contents []string
for key, ra := range info.PriceData.OtherRatios {
if 1.0 != ra {
contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
}
}
if len(contents) > 0 {
logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
}
}
}
other := make(map[string]interface{})
if c != nil && c.Request != nil && c.Request.URL != nil {
other["request_path"] = c.Request.URL.Path
}
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
if hasUserGroupRatio {
other["user_group_ratio"] = userGroupRatio
}
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
ChannelId: info.ChannelId,
ModelName: modelName,
TokenName: tokenName,
Quota: quota,
Content: logContent,
TokenId: info.TokenId,
Group: info.UsingGroup,
Other: other,
})
model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
model.UpdateChannelUsedQuota(info.ChannelId, quota)
}
}
}()
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
// 11. 解析响应
upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
if taskErr != nil {
return
return nil, taskErr
}
info.ConsumeQuota = true
// insert task
task := model.InitTask(platform, info)
task.TaskID = taskID
task.Quota = quota
task.Data = taskData
task.Action = info.Action
err = task.Insert()
if err != nil {
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
return
// 11. 提交后计费调整:让适配器根据上游实际返回调整 OtherRatios
finalQuota := info.PriceData.Quota
if adjustedRatios := adaptor.AdjustBillingOnSubmit(info, taskData); len(adjustedRatios) > 0 {
// 基于调整后的 ratios 重新计算 quota
finalQuota = recalcQuotaFromRatios(info, adjustedRatios)
info.PriceData.OtherRatios = adjustedRatios
info.PriceData.Quota = finalQuota
}
return nil
return &TaskSubmitResult{
UpstreamTaskID: upstreamTaskID,
TaskData: taskData,
Platform: platform,
Quota: finalQuota,
}, nil
}
// recalcQuotaFromRatios 根据 adjustedRatios 重新计算 quota。
// 公式: baseQuota × ∏(ratio) — 其中 baseQuota 是不含 OtherRatios 的基础额度。
func recalcQuotaFromRatios(info *relaycommon.RelayInfo, ratios map[string]float64) int {
// 从 PriceData 获取不含 OtherRatios 的基础价格
baseQuota := info.PriceData.Quota
// 先除掉原有的 OtherRatios 恢复基础额度
for _, ra := range info.PriceData.OtherRatios {
if ra != 1.0 && ra > 0 {
baseQuota = int(float64(baseQuota) / ra)
}
}
// 应用新的 ratios
result := float64(baseQuota)
for _, ra := range ratios {
if ra != 1.0 {
result *= ra
}
}
return int(result)
}
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
@@ -336,7 +325,7 @@ func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.Ta
} else {
tasks = make([]any, 0)
}
respBody, err = json.Marshal(dto.TaskResponse[[]any]{
respBody, err = common.Marshal(dto.TaskResponse[[]any]{
Code: "success",
Data: tasks,
})
@@ -357,7 +346,7 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
return
}
respBody, err = json.Marshal(dto.TaskResponse[any]{
respBody, err = common.Marshal(dto.TaskResponse[any]{
Code: "success",
Data: TaskModel2Dto(originTask),
})
@@ -381,97 +370,16 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
return
}
func() {
channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
if err2 != nil {
return
}
if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
return
}
baseURL := constant.ChannelBaseURLs[channelModel.Type]
if channelModel.GetBaseURL() != "" {
baseURL = channelModel.GetBaseURL()
}
proxy := channelModel.GetSetting().Proxy
adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
if adaptor == nil {
return
}
resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
"task_id": originTask.TaskID,
"action": originTask.Action,
}, proxy)
if err2 != nil || resp == nil {
return
}
defer resp.Body.Close()
body, err2 := io.ReadAll(resp.Body)
if err2 != nil {
return
}
ti, err2 := adaptor.ParseTaskResult(body)
if err2 == nil && ti != nil {
if ti.Status != "" {
originTask.Status = model.TaskStatus(ti.Status)
}
if ti.Progress != "" {
originTask.Progress = ti.Progress
}
if ti.Url != "" {
if strings.HasPrefix(ti.Url, "data:") {
} else {
originTask.FailReason = ti.Url
}
}
_ = originTask.Update()
var raw map[string]any
_ = json.Unmarshal(body, &raw)
format := "mp4"
if respObj, ok := raw["response"].(map[string]any); ok {
if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
if v0, ok := vids[0].(map[string]any); ok {
if mt, ok := v0["mimeType"].(string); ok && mt != "" {
if strings.Contains(mt, "mp4") {
format = "mp4"
} else {
format = mt
}
}
}
}
}
status := "processing"
switch originTask.Status {
case model.TaskStatusSuccess:
status = "succeeded"
case model.TaskStatusFailure:
status = "failed"
case model.TaskStatusQueued, model.TaskStatusSubmitted:
status = "queued"
}
if !strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
out := map[string]any{
"error": nil,
"format": format,
"metadata": nil,
"status": status,
"task_id": originTask.TaskID,
"url": originTask.FailReason,
}
respBody, _ = json.Marshal(dto.TaskResponse[any]{
Code: "success",
Data: out,
})
}
}
}()
isOpenAIVideoAPI := strings.HasPrefix(c.Request.RequestURI, "/v1/videos/")
if len(respBody) != 0 {
// Gemini/Vertex 支持实时查询:用户 fetch 时直接从上游拉取最新状态
if realtimeResp := tryRealtimeFetch(originTask, isOpenAIVideoAPI); len(realtimeResp) > 0 {
respBody = realtimeResp
return
}
if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
// OpenAI Video API 格式: 走各 adaptor 的 ConvertToOpenAIVideo
if isOpenAIVideoAPI {
adaptor := GetTaskAdaptor(originTask.Platform)
if adaptor == nil {
taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
@@ -486,10 +394,12 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
respBody = openAIVideoData
return
}
taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented)
taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented)
return
}
respBody, err = json.Marshal(dto.TaskResponse[any]{
// 通用 TaskDto 格式
respBody, err = common.Marshal(dto.TaskResponse[any]{
Code: "success",
Data: TaskModel2Dto(originTask),
})
@@ -499,16 +409,150 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
return
}
// tryRealtimeFetch 尝试从上游实时拉取 Gemini/Vertex 任务状态。
// 仅当渠道类型为 Gemini 或 Vertex 时触发;其他渠道或出错时返回 nil。
// 当非 OpenAI Video API 时,还会构建自定义格式的响应体。
func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte {
channelModel, err := model.GetChannelById(task.ChannelId, true)
if err != nil {
return nil
}
if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
return nil
}
baseURL := constant.ChannelBaseURLs[channelModel.Type]
if channelModel.GetBaseURL() != "" {
baseURL = channelModel.GetBaseURL()
}
proxy := channelModel.GetSetting().Proxy
adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
if adaptor == nil {
return nil
}
resp, err := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
"task_id": task.GetUpstreamTaskID(),
"action": task.Action,
}, proxy)
if err != nil || resp == nil {
return nil
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil
}
ti, err := adaptor.ParseTaskResult(body)
if err != nil || ti == nil {
return nil
}
snap := task.Snapshot()
// 将上游最新状态更新到 task
if ti.Status != "" {
task.Status = model.TaskStatus(ti.Status)
}
if ti.Progress != "" {
task.Progress = ti.Progress
}
if strings.HasPrefix(ti.Url, "data:") {
// data: URI — kept in Data, not ResultURL
} else if ti.Url != "" {
task.PrivateData.ResultURL = ti.Url
} else if task.Status == model.TaskStatusSuccess {
// No URL from adaptor — construct proxy URL using public task ID
task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
}
if !snap.Equal(task.Snapshot()) {
_, _ = task.UpdateWithStatus(snap.Status)
}
// OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理
if isOpenAIVideoAPI {
return nil
}
// 非 OpenAI Video API: 构建自定义格式响应
format := detectVideoFormat(body)
out := map[string]any{
"error": nil,
"format": format,
"metadata": nil,
"status": mapTaskStatusToSimple(task.Status),
"task_id": task.TaskID,
"url": task.GetResultURL(),
}
respBody, _ := common.Marshal(dto.TaskResponse[any]{
Code: "success",
Data: out,
})
return respBody
}
// detectVideoFormat 从 Gemini/Vertex 原始响应中探测视频格式
func detectVideoFormat(rawBody []byte) string {
var raw map[string]any
if err := common.Unmarshal(rawBody, &raw); err != nil {
return "mp4"
}
respObj, ok := raw["response"].(map[string]any)
if !ok {
return "mp4"
}
vids, ok := respObj["videos"].([]any)
if !ok || len(vids) == 0 {
return "mp4"
}
v0, ok := vids[0].(map[string]any)
if !ok {
return "mp4"
}
mt, ok := v0["mimeType"].(string)
if !ok || mt == "" || strings.Contains(mt, "mp4") {
return "mp4"
}
return mt
}
// mapTaskStatusToSimple 将内部 TaskStatus 映射为简化状态字符串
func mapTaskStatusToSimple(status model.TaskStatus) string {
switch status {
case model.TaskStatusSuccess:
return "succeeded"
case model.TaskStatusFailure:
return "failed"
case model.TaskStatusQueued, model.TaskStatusSubmitted:
return "queued"
default:
return "processing"
}
}
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
return &dto.TaskDto{
ID: task.ID,
CreatedAt: task.CreatedAt,
UpdatedAt: task.UpdatedAt,
TaskID: task.TaskID,
Platform: string(task.Platform),
UserId: task.UserId,
Group: task.Group,
ChannelId: task.ChannelId,
Quota: task.Quota,
Action: task.Action,
Status: string(task.Status),
FailReason: task.FailReason,
ResultURL: task.GetResultURL(),
SubmitTime: task.SubmitTime,
StartTime: task.StartTime,
FinishTime: task.FinishTime,
Progress: task.Progress,
Properties: task.Properties,
Username: task.Username,
Data: task.Data,
}
}

View File

@@ -89,7 +89,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
}
// remove disabled fields for OpenAI Responses API
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}