mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-06-07 22:09:57 +00:00
Merge branch 'upstream-main' into feature/improve-param-override
# Conflicts: # relay/channel/api_request_test.go # relay/common/override_test.go # web/src/components/table/channels/modals/EditChannelModal.jsx
This commit is contained in:
@@ -36,6 +36,32 @@ type TaskAdaptor interface {
|
||||
|
||||
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError
|
||||
|
||||
// ── Billing ──────────────────────────────────────────────────────
|
||||
|
||||
// EstimateBilling returns OtherRatios for pre-charge based on user request.
|
||||
// Called after ValidateRequestAndSetAction, before price calculation.
|
||||
// Adaptors should extract duration, resolution, etc. from the parsed request
|
||||
// and return them as ratio multipliers (e.g. {"seconds": 5, "size": 1.666}).
|
||||
// Return nil to use the base model price without extra ratios.
|
||||
EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64
|
||||
|
||||
// AdjustBillingOnSubmit returns adjusted OtherRatios from the upstream
|
||||
// submit response. Called after a successful DoResponse.
|
||||
// If the upstream returned actual parameters that differ from the estimate
|
||||
// (e.g. actual seconds), return updated ratios so the caller can recalculate
|
||||
// the quota and settle the delta with the pre-charge.
|
||||
// Return nil if no adjustment is needed.
|
||||
AdjustBillingOnSubmit(info *relaycommon.RelayInfo, taskData []byte) map[string]float64
|
||||
|
||||
// AdjustBillingOnComplete returns the actual quota when a task reaches a
|
||||
// terminal state (success/failure) during polling.
|
||||
// Called by the polling loop after ParseTaskResult.
|
||||
// Return a positive value to trigger delta settlement (supplement / refund).
|
||||
// Return 0 to keep the pre-charged amount unchanged.
|
||||
AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int
|
||||
|
||||
// ── Request / Response ───────────────────────────────────────────
|
||||
|
||||
BuildRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
|
||||
BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error)
|
||||
@@ -46,9 +72,9 @@ type TaskAdaptor interface {
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
|
||||
// FetchTask
|
||||
FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
|
||||
// ── Polling ──────────────────────────────────────────────────────
|
||||
|
||||
FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
|
||||
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
|
||||
}
|
||||
|
||||
|
||||
@@ -61,8 +61,9 @@ var passthroughSkipHeaderNamesLower = map[string]struct{}{
|
||||
"cookie": {},
|
||||
|
||||
// Additional headers that should not be forwarded by name-matching passthrough rules.
|
||||
"host": {},
|
||||
"content-length": {},
|
||||
"host": {},
|
||||
"content-length": {},
|
||||
"accept-encoding": {},
|
||||
|
||||
// Do not passthrough credentials by wildcard/regex.
|
||||
"authorization": {},
|
||||
|
||||
@@ -110,3 +110,30 @@ func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) {
|
||||
_, ok := headers["X-Legacy"]
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
ctx.Request.Header.Set("X-Trace-Id", "trace-123")
|
||||
ctx.Request.Header.Set("Accept-Encoding", "gzip")
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
IsChannelTest: false,
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
HeadersOverride: map[string]any{
|
||||
"*": "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
headers, err := processHeaderOverride(info, ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "trace-123", headers["X-Trace-Id"])
|
||||
|
||||
_, hasAcceptEncoding := headers["Accept-Encoding"]
|
||||
require.False(t, hasAcceptEncoding)
|
||||
}
|
||||
|
||||
@@ -42,22 +42,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
}
|
||||
|
||||
// 计算使用量(基于 UsageMetadata)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
|
||||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
|
||||
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
|
||||
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
|
||||
@@ -1032,6 +1032,46 @@ func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
|
||||
}
|
||||
}
|
||||
|
||||
func buildUsageFromGeminiMetadata(metadata dto.GeminiUsageMetadata, fallbackPromptTokens int) dto.Usage {
|
||||
promptTokens := metadata.PromptTokenCount + metadata.ToolUsePromptTokenCount
|
||||
if promptTokens <= 0 && fallbackPromptTokens > 0 {
|
||||
promptTokens = fallbackPromptTokens
|
||||
}
|
||||
|
||||
usage := dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: metadata.CandidatesTokenCount + metadata.ThoughtsTokenCount,
|
||||
TotalTokens: metadata.TotalTokenCount,
|
||||
}
|
||||
usage.CompletionTokenDetails.ReasoningTokens = metadata.ThoughtsTokenCount
|
||||
usage.PromptTokensDetails.CachedTokens = metadata.CachedContentTokenCount
|
||||
|
||||
for _, detail := range metadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens += detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens += detail.TokenCount
|
||||
}
|
||||
}
|
||||
for _, detail := range metadata.ToolUsePromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens += detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens += detail.TokenCount
|
||||
}
|
||||
}
|
||||
|
||||
if usage.TotalTokens > 0 && usage.CompletionTokens <= 0 {
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
}
|
||||
|
||||
if usage.PromptTokens > 0 && usage.PromptTokensDetails.TextTokens == 0 && usage.PromptTokensDetails.AudioTokens == 0 {
|
||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||
}
|
||||
|
||||
return usage
|
||||
}
|
||||
|
||||
func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: helper.GetResponseID(c),
|
||||
@@ -1272,18 +1312,8 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
|
||||
// 更新使用量统计
|
||||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
||||
usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
mappedUsage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
|
||||
*usage = mappedUsage
|
||||
}
|
||||
|
||||
return callback(data, &geminiResponse)
|
||||
@@ -1295,11 +1325,6 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
}
|
||||
}
|
||||
|
||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||
if usage.TotalTokens > 0 {
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
}
|
||||
|
||||
if usage.CompletionTokens <= 0 {
|
||||
if info.ReceivedResponseCount > 0 {
|
||||
usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
@@ -1416,21 +1441,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
usage := dto.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
}
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
if usage.PromptTokens <= 0 {
|
||||
usage.PromptTokens = info.GetEstimatePromptTokens()
|
||||
}
|
||||
usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
|
||||
|
||||
var newAPIError *types.NewAPIError
|
||||
if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
|
||||
@@ -1466,23 +1477,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
|
||||
fullTextResponse.Model = info.UpstreamModelName
|
||||
usage := dto.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
|
||||
|
||||
fullTextResponse.Usage = usage
|
||||
|
||||
|
||||
333
relay/channel/gemini/relay_gemini_usage_test.go
Normal file
333
relay/channel/gemini/relay_gemini_usage_test.go
Normal file
@@ -0,0 +1,333 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGeminiChatHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatGemini,
|
||||
OriginModelName: "gemini-3-flash-preview",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
UpstreamModelName: "gemini-3-flash-preview",
|
||||
},
|
||||
}
|
||||
|
||||
payload := dto.GeminiChatResponse{
|
||||
Candidates: []dto.GeminiChatCandidate{
|
||||
{
|
||||
Content: dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: []dto.GeminiPart{
|
||||
{Text: "ok"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: 151,
|
||||
ToolUsePromptTokenCount: 18329,
|
||||
CandidatesTokenCount: 1089,
|
||||
ThoughtsTokenCount: 1120,
|
||||
TotalTokenCount: 20689,
|
||||
},
|
||||
}
|
||||
|
||||
body, err := common.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := &http.Response{
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
}
|
||||
|
||||
usage, newAPIError := GeminiChatHandler(c, info, resp)
|
||||
require.Nil(t, newAPIError)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 18480, usage.PromptTokens)
|
||||
require.Equal(t, 2209, usage.CompletionTokens)
|
||||
require.Equal(t, 20689, usage.TotalTokens)
|
||||
require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
|
||||
}
|
||||
|
||||
func TestGeminiStreamHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
oldStreamingTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 300
|
||||
t.Cleanup(func() {
|
||||
constant.StreamingTimeout = oldStreamingTimeout
|
||||
})
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
OriginModelName: "gemini-3-flash-preview",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
UpstreamModelName: "gemini-3-flash-preview",
|
||||
},
|
||||
}
|
||||
|
||||
chunk := dto.GeminiChatResponse{
|
||||
Candidates: []dto.GeminiChatCandidate{
|
||||
{
|
||||
Content: dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: []dto.GeminiPart{
|
||||
{Text: "partial"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: 151,
|
||||
ToolUsePromptTokenCount: 18329,
|
||||
CandidatesTokenCount: 1089,
|
||||
ThoughtsTokenCount: 1120,
|
||||
TotalTokenCount: 20689,
|
||||
},
|
||||
}
|
||||
|
||||
chunkData, err := common.Marshal(chunk)
|
||||
require.NoError(t, err)
|
||||
|
||||
streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
|
||||
resp := &http.Response{
|
||||
Body: io.NopCloser(bytes.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
|
||||
return true
|
||||
})
|
||||
require.Nil(t, newAPIError)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 18480, usage.PromptTokens)
|
||||
require.Equal(t, 2209, usage.CompletionTokens)
|
||||
require.Equal(t, 20689, usage.TotalTokens)
|
||||
require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
|
||||
}
|
||||
|
||||
func TestGeminiTextGenerationHandlerPromptTokensIncludeToolUsePromptTokens(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
OriginModelName: "gemini-3-flash-preview",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
UpstreamModelName: "gemini-3-flash-preview",
|
||||
},
|
||||
}
|
||||
|
||||
payload := dto.GeminiChatResponse{
|
||||
Candidates: []dto.GeminiChatCandidate{
|
||||
{
|
||||
Content: dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: []dto.GeminiPart{
|
||||
{Text: "ok"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: 151,
|
||||
ToolUsePromptTokenCount: 18329,
|
||||
CandidatesTokenCount: 1089,
|
||||
ThoughtsTokenCount: 1120,
|
||||
TotalTokenCount: 20689,
|
||||
},
|
||||
}
|
||||
|
||||
body, err := common.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := &http.Response{
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
}
|
||||
|
||||
usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
|
||||
require.Nil(t, newAPIError)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 18480, usage.PromptTokens)
|
||||
require.Equal(t, 2209, usage.CompletionTokens)
|
||||
require.Equal(t, 20689, usage.TotalTokens)
|
||||
require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
|
||||
}
|
||||
|
||||
func TestGeminiChatHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatGemini,
|
||||
OriginModelName: "gemini-3-flash-preview",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
UpstreamModelName: "gemini-3-flash-preview",
|
||||
},
|
||||
}
|
||||
info.SetEstimatePromptTokens(20)
|
||||
|
||||
payload := dto.GeminiChatResponse{
|
||||
Candidates: []dto.GeminiChatCandidate{
|
||||
{
|
||||
Content: dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: []dto.GeminiPart{
|
||||
{Text: "ok"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: 0,
|
||||
ToolUsePromptTokenCount: 0,
|
||||
CandidatesTokenCount: 90,
|
||||
ThoughtsTokenCount: 10,
|
||||
TotalTokenCount: 110,
|
||||
},
|
||||
}
|
||||
|
||||
body, err := common.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := &http.Response{
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
}
|
||||
|
||||
usage, newAPIError := GeminiChatHandler(c, info, resp)
|
||||
require.Nil(t, newAPIError)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 20, usage.PromptTokens)
|
||||
require.Equal(t, 100, usage.CompletionTokens)
|
||||
require.Equal(t, 110, usage.TotalTokens)
|
||||
}
|
||||
|
||||
func TestGeminiStreamHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
oldStreamingTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 300
|
||||
t.Cleanup(func() {
|
||||
constant.StreamingTimeout = oldStreamingTimeout
|
||||
})
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
OriginModelName: "gemini-3-flash-preview",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
UpstreamModelName: "gemini-3-flash-preview",
|
||||
},
|
||||
}
|
||||
info.SetEstimatePromptTokens(20)
|
||||
|
||||
chunk := dto.GeminiChatResponse{
|
||||
Candidates: []dto.GeminiChatCandidate{
|
||||
{
|
||||
Content: dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: []dto.GeminiPart{
|
||||
{Text: "partial"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: 0,
|
||||
ToolUsePromptTokenCount: 0,
|
||||
CandidatesTokenCount: 90,
|
||||
ThoughtsTokenCount: 10,
|
||||
TotalTokenCount: 110,
|
||||
},
|
||||
}
|
||||
|
||||
chunkData, err := common.Marshal(chunk)
|
||||
require.NoError(t, err)
|
||||
|
||||
streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
|
||||
resp := &http.Response{
|
||||
Body: io.NopCloser(bytes.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
|
||||
return true
|
||||
})
|
||||
require.Nil(t, newAPIError)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 20, usage.PromptTokens)
|
||||
require.Equal(t, 100, usage.CompletionTokens)
|
||||
require.Equal(t, 110, usage.TotalTokens)
|
||||
}
|
||||
|
||||
func TestGeminiTextGenerationHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
OriginModelName: "gemini-3-flash-preview",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
UpstreamModelName: "gemini-3-flash-preview",
|
||||
},
|
||||
}
|
||||
info.SetEstimatePromptTokens(20)
|
||||
|
||||
payload := dto.GeminiChatResponse{
|
||||
Candidates: []dto.GeminiChatCandidate{
|
||||
{
|
||||
Content: dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: []dto.GeminiPart{
|
||||
{Text: "ok"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: 0,
|
||||
ToolUsePromptTokenCount: 0,
|
||||
CandidatesTokenCount: 90,
|
||||
ThoughtsTokenCount: 10,
|
||||
TotalTokenCount: 110,
|
||||
},
|
||||
}
|
||||
|
||||
body, err := common.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := &http.Response{
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
}
|
||||
|
||||
usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
|
||||
require.Nil(t, newAPIError)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 20, usage.PromptTokens)
|
||||
require.Equal(t, 100, usage.CompletionTokens)
|
||||
require.Equal(t, 110, usage.TotalTokens)
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/claude"
|
||||
"github.com/QuantumNous/new-api/relay/channel/openai"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/constant"
|
||||
@@ -26,7 +27,8 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
adaptor := claude.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -119,8 +121,14 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
return handleTTSResponse(c, resp, info)
|
||||
}
|
||||
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
adaptor := claude.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
default:
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
channelconstant "github.com/QuantumNous/new-api/constant"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
)
|
||||
|
||||
func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
@@ -13,13 +14,17 @@ func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if baseUrl == "" {
|
||||
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeMiniMax]
|
||||
}
|
||||
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", baseUrl), nil
|
||||
case constant.RelayModeAudioSpeech:
|
||||
return fmt.Sprintf("%s/v1/t2a_v2", baseUrl), nil
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", baseUrl), nil
|
||||
case constant.RelayModeAudioSpeech:
|
||||
return fmt.Sprintf("%s/v1/t2a_v2", baseUrl), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/samber/lo"
|
||||
@@ -108,10 +109,10 @@ type AliMetadata struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
aliReq *AliVideoRequest
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -121,17 +122,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// 阿里通义万相支持 JSON 格式,不使用 multipart
|
||||
var taskReq relaycommon.TaskSubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil {
|
||||
return service.TaskErrorWrapper(err, "unmarshal_task_request_failed", http.StatusBadRequest)
|
||||
}
|
||||
aliReq, err := a.convertToAliRequest(info, taskReq)
|
||||
if err != nil {
|
||||
return service.TaskErrorWrapper(err, "convert_to_ali_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
a.aliReq = aliReq
|
||||
logger.LogJson(c, "ali video request body", aliReq)
|
||||
// ValidateMultipartDirect 负责解析并将原始 TaskSubmitReq 存入 context
|
||||
return relaycommon.ValidateMultipartDirect(c, info)
|
||||
}
|
||||
|
||||
@@ -148,11 +139,21 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
bodyBytes, err := common.Marshal(a.aliReq)
|
||||
taskReq, err := relaycommon.GetTaskRequest(c)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get_task_request_failed")
|
||||
}
|
||||
|
||||
aliReq, err := a.convertToAliRequest(info, taskReq)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "convert_to_ali_request_failed")
|
||||
}
|
||||
logger.LogJson(c, "ali video request body", aliReq)
|
||||
|
||||
bodyBytes, err := common.Marshal(aliReq)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "marshal_ali_request_failed")
|
||||
}
|
||||
|
||||
return bytes.NewReader(bodyBytes), nil
|
||||
}
|
||||
|
||||
@@ -252,8 +253,12 @@ func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) {
|
||||
upstreamModel := req.Model
|
||||
if info.IsModelMapped {
|
||||
upstreamModel = info.UpstreamModelName
|
||||
}
|
||||
aliReq := &AliVideoRequest{
|
||||
Model: req.Model,
|
||||
Model: upstreamModel,
|
||||
Input: AliVideoInput{
|
||||
Prompt: req.Prompt,
|
||||
ImgURL: req.InputReference,
|
||||
@@ -331,23 +336,37 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
|
||||
}
|
||||
}
|
||||
|
||||
if aliReq.Model != req.Model {
|
||||
if aliReq.Model != upstreamModel {
|
||||
return nil, errors.New("can't change model with metadata")
|
||||
}
|
||||
|
||||
info.PriceData.OtherRatios = map[string]float64{
|
||||
return aliReq, nil
|
||||
}
|
||||
|
||||
// EstimateBilling 根据用户请求参数计算 OtherRatios(时长、分辨率等)。
|
||||
// 在 ValidateRequestAndSetAction 之后、价格计算之前调用。
|
||||
func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
|
||||
taskReq, err := relaycommon.GetTaskRequest(c)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
aliReq, err := a.convertToAliRequest(info, taskReq)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
otherRatios := map[string]float64{
|
||||
"seconds": float64(aliReq.Parameters.Duration),
|
||||
}
|
||||
|
||||
ratios, err := ProcessAliOtherRatios(aliReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return otherRatios
|
||||
}
|
||||
for s, f := range ratios {
|
||||
info.PriceData.OtherRatios[s] = f
|
||||
for k, v := range ratios {
|
||||
otherRatios[k] = v
|
||||
}
|
||||
|
||||
return aliReq, nil
|
||||
return otherRatios
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper
|
||||
@@ -384,7 +403,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
|
||||
// 转换为 OpenAI 格式响应
|
||||
openAIResp := dto.NewOpenAIVideo()
|
||||
openAIResp.ID = aliResp.Output.TaskID
|
||||
openAIResp.ID = info.PublicTaskID
|
||||
openAIResp.TaskID = info.PublicTaskID
|
||||
openAIResp.Model = c.GetString("model")
|
||||
if openAIResp.Model == "" && info != nil {
|
||||
openAIResp.Model = info.OriginModelName
|
||||
|
||||
@@ -2,7 +2,6 @@ package doubao
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -14,6 +13,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
@@ -89,6 +89,7 @@ type responseTask struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
@@ -130,8 +131,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "convert request payload failed")
|
||||
}
|
||||
info.UpstreamModelName = body.Model
|
||||
data, err := json.Marshal(body)
|
||||
if info.IsModelMapped {
|
||||
body.Model = info.UpstreamModelName
|
||||
} else {
|
||||
info.UpstreamModelName = body.Model
|
||||
}
|
||||
data, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -154,7 +159,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
|
||||
// Parse Doubao response
|
||||
var dResp responsePayload
|
||||
if err := json.Unmarshal(responseBody, &dResp); err != nil {
|
||||
if err := common.Unmarshal(responseBody, &dResp); err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -165,8 +170,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = dResp.ID
|
||||
ov.TaskID = dResp.ID
|
||||
ov.ID = info.PublicTaskID
|
||||
ov.TaskID = info.PublicTaskID
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
ov.Model = info.OriginModelName
|
||||
|
||||
@@ -234,12 +239,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
}
|
||||
|
||||
metadata := req.Metadata
|
||||
medaBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||
}
|
||||
err = json.Unmarshal(medaBytes, &r)
|
||||
if err != nil {
|
||||
if err := taskcommon.UnmarshalMetadata(metadata, &r); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
|
||||
@@ -248,7 +248,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
resTask := responseTask{}
|
||||
if err := json.Unmarshal(respBody, &resTask); err != nil {
|
||||
if err := common.Unmarshal(respBody, &resTask); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal task result failed")
|
||||
}
|
||||
|
||||
@@ -286,7 +286,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||||
var dResp responseTask
|
||||
if err := json.Unmarshal(originTask.Data, &dResp); err != nil {
|
||||
if err := common.Unmarshal(originTask.Data, &dResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal doubao task data failed")
|
||||
}
|
||||
|
||||
@@ -307,6 +307,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, _ := common.Marshal(openAIVideo)
|
||||
return jsonData, nil
|
||||
return common.Marshal(openAIVideo)
|
||||
}
|
||||
|
||||
@@ -2,8 +2,6 @@ package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -16,10 +14,10 @@ import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
@@ -87,6 +85,7 @@ type operationResponse struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
@@ -106,7 +105,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
modelName := info.OriginModelName
|
||||
modelName := info.UpstreamModelName
|
||||
version := model_setting.GetGeminiVersionSetting(modelName)
|
||||
|
||||
return fmt.Sprintf(
|
||||
@@ -145,16 +144,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
metadata := req.Metadata
|
||||
medaBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||
}
|
||||
err = json.Unmarshal(medaBytes, &body.Parameters)
|
||||
if err != nil {
|
||||
if err := taskcommon.UnmarshalMetadata(metadata, &body.Parameters); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
|
||||
data, err := json.Marshal(body)
|
||||
data, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -175,16 +169,16 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var s submitResponse
|
||||
if err := json.Unmarshal(responseBody, &s); err != nil {
|
||||
if err := common.Unmarshal(responseBody, &s); err != nil {
|
||||
return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if strings.TrimSpace(s.Name) == "" {
|
||||
return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
|
||||
}
|
||||
taskID = encodeLocalTaskID(s.Name)
|
||||
taskID = taskcommon.EncodeLocalTaskID(s.Name)
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = taskID
|
||||
ov.TaskID = taskID
|
||||
ov.ID = info.PublicTaskID
|
||||
ov.TaskID = info.PublicTaskID
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
ov.Model = info.OriginModelName
|
||||
c.JSON(http.StatusOK, ov)
|
||||
@@ -206,7 +200,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
|
||||
upstreamName, err := decodeLocalTaskID(taskID)
|
||||
upstreamName, err := taskcommon.DecodeLocalTaskID(taskID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode task_id failed: %w", err)
|
||||
}
|
||||
@@ -232,7 +226,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
var op operationResponse
|
||||
if err := json.Unmarshal(respBody, &op); err != nil {
|
||||
if err := common.Unmarshal(respBody, &op); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -254,9 +248,8 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
ti.Status = model.TaskStatusSuccess
|
||||
ti.Progress = "100%"
|
||||
|
||||
taskID := encodeLocalTaskID(op.Name)
|
||||
ti.TaskID = taskID
|
||||
ti.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
|
||||
ti.TaskID = taskcommon.EncodeLocalTaskID(op.Name)
|
||||
// Url intentionally left empty — the caller constructs the proxy URL using the public task ID
|
||||
|
||||
// Extract URL from generateVideoResponse if available
|
||||
if len(op.Response.GenerateVideoResponse.GeneratedSamples) > 0 {
|
||||
@@ -269,7 +262,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
upstreamName, err := decodeLocalTaskID(task.TaskID)
|
||||
// Use GetUpstreamTaskID() to get the real upstream operation name for model extraction.
|
||||
// task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name.
|
||||
upstreamTaskID := task.GetUpstreamTaskID()
|
||||
upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID)
|
||||
if err != nil {
|
||||
upstreamName = ""
|
||||
}
|
||||
@@ -297,18 +293,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func encodeLocalTaskID(name string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(name))
|
||||
}
|
||||
|
||||
func decodeLocalTaskID(local string) (string, error) {
|
||||
b, err := base64.RawURLEncoding.DecodeString(local)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`)
|
||||
|
||||
func extractModelFromOperationName(name string) string {
|
||||
|
||||
@@ -2,7 +2,6 @@ package hailuo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -18,12 +17,14 @@ import (
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
)
|
||||
|
||||
// https://platform.minimaxi.com/docs/api-reference/video-generation-intro
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
@@ -60,12 +61,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
return nil, fmt.Errorf("invalid request type in context")
|
||||
}
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
body, err := a.convertToRequestPayload(&req, info)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "convert request payload failed")
|
||||
}
|
||||
|
||||
data, err := json.Marshal(body)
|
||||
data, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -86,7 +87,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var hResp VideoResponse
|
||||
if err := json.Unmarshal(responseBody, &hResp); err != nil {
|
||||
if err := common.Unmarshal(responseBody, &hResp); err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -101,8 +102,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = hResp.TaskID
|
||||
ov.TaskID = hResp.TaskID
|
||||
ov.ID = info.PublicTaskID
|
||||
ov.TaskID = info.PublicTaskID
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
ov.Model = info.OriginModelName
|
||||
|
||||
@@ -141,8 +142,8 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*VideoRequest, error) {
|
||||
modelConfig := GetModelConfig(req.Model)
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*VideoRequest, error) {
|
||||
modelConfig := GetModelConfig(info.UpstreamModelName)
|
||||
duration := DefaultDuration
|
||||
if req.Duration > 0 {
|
||||
duration = req.Duration
|
||||
@@ -153,7 +154,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
}
|
||||
|
||||
videoRequest := &VideoRequest{
|
||||
Model: req.Model,
|
||||
Model: info.UpstreamModelName,
|
||||
Prompt: req.Prompt,
|
||||
Duration: &duration,
|
||||
Resolution: resolution,
|
||||
@@ -182,7 +183,7 @@ func (a *TaskAdaptor) parseResolutionFromSize(size string, modelConfig ModelConf
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
resTask := QueryTaskResponse{}
|
||||
if err := json.Unmarshal(respBody, &resTask); err != nil {
|
||||
if err := common.Unmarshal(respBody, &resTask); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal task result failed")
|
||||
}
|
||||
|
||||
@@ -224,7 +225,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||||
var hailuoResp QueryTaskResponse
|
||||
if err := json.Unmarshal(originTask.Data, &hailuoResp); err != nil {
|
||||
if err := common.Unmarshal(originTask.Data, &hailuoResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal hailuo task data failed")
|
||||
}
|
||||
|
||||
@@ -271,7 +272,7 @@ func (a *TaskAdaptor) buildVideoURL(_, fileID string) string {
|
||||
}
|
||||
|
||||
var retrieveResp RetrieveFileResponse
|
||||
if err := json.Unmarshal(responseBody, &retrieveResp); err != nil {
|
||||
if err := common.Unmarshal(responseBody, &retrieveResp); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -25,6 +24,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
)
|
||||
@@ -77,6 +77,7 @@ const (
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
accessKey string
|
||||
secretKey string
|
||||
@@ -164,11 +165,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
}
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
body, err := a.convertToRequestPayload(&req, info)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "convert request payload failed")
|
||||
}
|
||||
data, err := json.Marshal(body)
|
||||
data, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -191,7 +192,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
|
||||
// Parse Jimeng response
|
||||
var jResp responsePayload
|
||||
if err := json.Unmarshal(responseBody, &jResp); err != nil {
|
||||
if err := common.Unmarshal(responseBody, &jResp); err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -202,8 +203,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = jResp.Data.TaskID
|
||||
ov.TaskID = jResp.Data.TaskID
|
||||
ov.ID = info.PublicTaskID
|
||||
ov.TaskID = info.PublicTaskID
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
ov.Model = info.OriginModelName
|
||||
c.JSON(http.StatusOK, ov)
|
||||
@@ -225,7 +226,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
|
||||
"task_id": taskID,
|
||||
}
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
payloadBytes, err := common.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "marshal fetch task payload failed")
|
||||
}
|
||||
@@ -377,9 +378,9 @@ func hmacSHA256(key []byte, data []byte) []byte {
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
ReqKey: req.Model,
|
||||
ReqKey: info.UpstreamModelName,
|
||||
Prompt: req.Prompt,
|
||||
}
|
||||
|
||||
@@ -398,13 +399,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
r.BinaryDataBase64 = req.Images
|
||||
}
|
||||
}
|
||||
metadata := req.Metadata
|
||||
medaBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||
}
|
||||
err = json.Unmarshal(medaBytes, &r)
|
||||
if err != nil {
|
||||
if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
|
||||
@@ -432,7 +427,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
resTask := responseTask{}
|
||||
if err := json.Unmarshal(respBody, &resTask); err != nil {
|
||||
if err := common.Unmarshal(respBody, &resTask); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal task result failed")
|
||||
}
|
||||
taskResult := relaycommon.TaskInfo{}
|
||||
@@ -458,7 +453,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||||
var jimengResp responseTask
|
||||
if err := json.Unmarshal(originTask.Data, &jimengResp); err != nil {
|
||||
if err := common.Unmarshal(originTask.Data, &jimengResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal jimeng task data failed")
|
||||
}
|
||||
|
||||
@@ -477,8 +472,7 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, _ := common.Marshal(openAIVideo)
|
||||
return jsonData, nil
|
||||
return common.Marshal(openAIVideo)
|
||||
}
|
||||
|
||||
func isNewAPIRelay(apiKey string) bool {
|
||||
|
||||
@@ -2,7 +2,6 @@ package kling
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -21,6 +20,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
)
|
||||
@@ -97,6 +97,7 @@ type responsePayload struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
@@ -149,14 +150,14 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
body, err := a.convertToRequestPayload(&req, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if body.Image == "" && body.ImageTail == "" {
|
||||
c.Set("action", constant.TaskActionTextGenerate)
|
||||
}
|
||||
data, err := json.Marshal(body)
|
||||
data, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -180,7 +181,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
var kResp responsePayload
|
||||
err = json.Unmarshal(responseBody, &kResp)
|
||||
err = common.Unmarshal(responseBody, &kResp)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -190,8 +191,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return
|
||||
}
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = kResp.Data.TaskId
|
||||
ov.TaskID = kResp.Data.TaskId
|
||||
ov.ID = info.PublicTaskID
|
||||
ov.TaskID = info.PublicTaskID
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
ov.Model = info.OriginModelName
|
||||
c.JSON(http.StatusOK, ov)
|
||||
@@ -247,15 +248,15 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
Prompt: req.Prompt,
|
||||
Image: req.Image,
|
||||
Mode: defaultString(req.Mode, "std"),
|
||||
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
|
||||
Mode: taskcommon.DefaultString(req.Mode, "std"),
|
||||
Duration: fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)),
|
||||
AspectRatio: a.getAspectRatio(req.Size),
|
||||
ModelName: req.Model,
|
||||
Model: req.Model, // Keep consistent with model_name, double writing improves compatibility
|
||||
ModelName: info.UpstreamModelName,
|
||||
Model: info.UpstreamModelName,
|
||||
CfgScale: 0.5,
|
||||
StaticMask: "",
|
||||
DynamicMasks: []DynamicMask{},
|
||||
@@ -265,14 +266,9 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
}
|
||||
if r.ModelName == "" {
|
||||
r.ModelName = "kling-v1"
|
||||
r.Model = "kling-v1"
|
||||
}
|
||||
metadata := req.Metadata
|
||||
medaBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||
}
|
||||
err = json.Unmarshal(medaBytes, &r)
|
||||
if err != nil {
|
||||
if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
return &r, nil
|
||||
@@ -291,20 +287,6 @@ func (a *TaskAdaptor) getAspectRatio(size string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func defaultString(s, def string) string {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return def
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func defaultInt(v int, def int) int {
|
||||
if v == 0 {
|
||||
return def
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// ============================
|
||||
// JWT helpers
|
||||
// ============================
|
||||
@@ -340,7 +322,7 @@ func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
taskInfo := &relaycommon.TaskInfo{}
|
||||
resPayload := responsePayload{}
|
||||
err := json.Unmarshal(respBody, &resPayload)
|
||||
err := common.Unmarshal(respBody, &resPayload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal response body")
|
||||
}
|
||||
@@ -374,7 +356,7 @@ func isNewAPIRelay(apiKey string) bool {
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||||
var klingResp responsePayload
|
||||
if err := json.Unmarshal(originTask.Data, &klingResp); err != nil {
|
||||
if err := common.Unmarshal(originTask.Data, &klingResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal kling task data failed")
|
||||
}
|
||||
|
||||
@@ -401,6 +383,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
|
||||
Code: fmt.Sprintf("%d", klingResp.Code),
|
||||
}
|
||||
}
|
||||
jsonData, _ := common.Marshal(openAIVideo)
|
||||
return jsonData, nil
|
||||
return common.Marshal(openAIVideo)
|
||||
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
package sora
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -11,12 +15,13 @@ import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ============================
|
||||
@@ -57,6 +62,7 @@ type responseTask struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
@@ -69,15 +75,15 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func validateRemixRequest(c *gin.Context) *dto.TaskError {
|
||||
var req struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
var req relaycommon.TaskSubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
}
|
||||
// 存储原始请求到 context,与 ValidateMultipartDirect 路径保持一致
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -88,6 +94,41 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
return relaycommon.ValidateMultipartDirect(c, info)
|
||||
}
|
||||
|
||||
// EstimateBilling 根据用户请求的 seconds 和 size 计算 OtherRatios。
|
||||
func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
|
||||
// remix 路径的 OtherRatios 已在 ResolveOriginTask 中设置
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
return nil
|
||||
}
|
||||
|
||||
req, err := relaycommon.GetTaskRequest(c)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
seconds, _ := strconv.Atoi(req.Seconds)
|
||||
if seconds == 0 {
|
||||
seconds = req.Duration
|
||||
}
|
||||
if seconds <= 0 {
|
||||
seconds = 4
|
||||
}
|
||||
|
||||
size := req.Size
|
||||
if size == "" {
|
||||
size = "720x1280"
|
||||
}
|
||||
|
||||
ratios := map[string]float64{
|
||||
"seconds": float64(seconds),
|
||||
"size": 1,
|
||||
}
|
||||
if size == "1792x1024" || size == "1024x1792" {
|
||||
ratios["size"] = 1.666667
|
||||
}
|
||||
return ratios
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
|
||||
@@ -107,6 +148,74 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get_request_body_failed")
|
||||
}
|
||||
cachedBody, err := storage.Bytes()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "read_body_bytes_failed")
|
||||
}
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
var bodyMap map[string]interface{}
|
||||
if err := common.Unmarshal(cachedBody, &bodyMap); err == nil {
|
||||
bodyMap["model"] = info.UpstreamModelName
|
||||
if newBody, err := common.Marshal(bodyMap); err == nil {
|
||||
return bytes.NewReader(newBody), nil
|
||||
}
|
||||
}
|
||||
return bytes.NewReader(cachedBody), nil
|
||||
}
|
||||
|
||||
if strings.Contains(contentType, "multipart/form-data") {
|
||||
formData, err := common.ParseMultipartFormReusable(c)
|
||||
if err != nil {
|
||||
return bytes.NewReader(cachedBody), nil
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
writer.WriteField("model", info.UpstreamModelName)
|
||||
for key, values := range formData.Value {
|
||||
if key == "model" {
|
||||
continue
|
||||
}
|
||||
for _, v := range values {
|
||||
writer.WriteField(key, v)
|
||||
}
|
||||
}
|
||||
for fieldName, fileHeaders := range formData.File {
|
||||
for _, fh := range fileHeaders {
|
||||
f, err := fh.Open()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
ct := fh.Header.Get("Content-Type")
|
||||
if ct == "" || ct == "application/octet-stream" {
|
||||
buf512 := make([]byte, 512)
|
||||
n, _ := io.ReadFull(f, buf512)
|
||||
ct = http.DetectContentType(buf512[:n])
|
||||
// Re-open after sniffing so the full content is copied below
|
||||
f.Close()
|
||||
f, err = fh.Open()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
h := make(textproto.MIMEHeader)
|
||||
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fh.Filename))
|
||||
h.Set("Content-Type", ct)
|
||||
part, err := writer.CreatePart(h)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
continue
|
||||
}
|
||||
io.Copy(part, f)
|
||||
f.Close()
|
||||
}
|
||||
}
|
||||
writer.Close()
|
||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return &buf, nil
|
||||
}
|
||||
|
||||
return common.ReaderOnly(storage), nil
|
||||
}
|
||||
|
||||
@@ -116,7 +225,7 @@ func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, req
|
||||
}
|
||||
|
||||
// DoResponse handles upstream response, returns taskID etc.
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
@@ -131,17 +240,20 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco
|
||||
return
|
||||
}
|
||||
|
||||
if dResp.ID == "" {
|
||||
if dResp.TaskID == "" {
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
dResp.ID = dResp.TaskID
|
||||
dResp.TaskID = ""
|
||||
upstreamID := dResp.ID
|
||||
if upstreamID == "" {
|
||||
upstreamID = dResp.TaskID
|
||||
}
|
||||
if upstreamID == "" {
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 使用公开 task_xxxx ID 返回给客户端
|
||||
dResp.ID = info.PublicTaskID
|
||||
dResp.TaskID = info.PublicTaskID
|
||||
c.JSON(http.StatusOK, dResp)
|
||||
return dResp.ID, responseBody, nil
|
||||
return upstreamID, responseBody, nil
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
@@ -192,7 +304,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
taskResult.Status = model.TaskStatusInProgress
|
||||
case "completed":
|
||||
taskResult.Status = model.TaskStatusSuccess
|
||||
taskResult.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, resTask.ID)
|
||||
// Url intentionally left empty — the caller constructs the proxy URL using the public task ID
|
||||
case "failed", "cancelled":
|
||||
taskResult.Status = model.TaskStatusFailure
|
||||
if resTask.Error != nil {
|
||||
@@ -210,5 +322,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
return task.Data, nil
|
||||
data := task.Data
|
||||
var err error
|
||||
if data, err = sjson.SetBytes(data, "id", task.TaskID); err != nil {
|
||||
return nil, errors.Wrap(err, "set id failed")
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
@@ -2,18 +2,16 @@ package suno
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
@@ -21,11 +19,16 @@ import (
|
||||
)
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
}
|
||||
|
||||
// ParseTaskResult is not used for Suno tasks.
|
||||
// Suno polling uses a dedicated batch-fetch path (service.UpdateSunoTasks) that
|
||||
// receives dto.TaskResponse[[]dto.SunoDataResponse] from the upstream /fetch API.
|
||||
// This differs from the per-task polling used by video adaptors.
|
||||
func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) {
|
||||
return nil, fmt.Errorf("not implement") // todo implement this method if needed
|
||||
return nil, fmt.Errorf("suno uses batch polling via UpdateSunoTasks, ParseTaskResult is not applicable")
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -47,13 +50,13 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
return
|
||||
}
|
||||
|
||||
if sunoRequest.ContinueClipId != "" {
|
||||
if sunoRequest.TaskID == "" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
info.OriginTaskID = sunoRequest.TaskID
|
||||
}
|
||||
//if sunoRequest.ContinueClipId != "" {
|
||||
// if sunoRequest.TaskID == "" {
|
||||
// taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
|
||||
// return
|
||||
// }
|
||||
// info.OriginTaskID = sunoRequest.TaskID
|
||||
//}
|
||||
|
||||
info.Action = action
|
||||
c.Set("task_request", sunoRequest)
|
||||
@@ -76,12 +79,9 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
sunoRequest, ok := c.Get("task_request")
|
||||
if !ok {
|
||||
err := common.UnmarshalBodyReusable(c, &sunoRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf("task_request not found in context")
|
||||
}
|
||||
data, err := json.Marshal(sunoRequest)
|
||||
data, err := common.Marshal(sunoRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -99,7 +99,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return
|
||||
}
|
||||
var sunoResponse dto.TaskResponse[string]
|
||||
err = json.Unmarshal(responseBody, &sunoResponse)
|
||||
err = common.Unmarshal(responseBody, &sunoResponse)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -109,17 +109,13 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
_, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody))
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
// 使用公开 task_xxxx ID 替换上游 ID 返回给客户端
|
||||
publicResponse := dto.TaskResponse[string]{
|
||||
Code: sunoResponse.Code,
|
||||
Message: sunoResponse.Message,
|
||||
Data: info.PublicTaskID,
|
||||
}
|
||||
c.JSON(http.StatusOK, publicResponse)
|
||||
|
||||
return sunoResponse.Data, nil, nil
|
||||
}
|
||||
@@ -134,7 +130,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
|
||||
byteBody, err := json.Marshal(body)
|
||||
byteBody, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -144,13 +140,6 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
common.SysLog(fmt.Sprintf("Get Task error: %v", err))
|
||||
return nil, err
|
||||
}
|
||||
defer req.Body.Close()
|
||||
// 设置超时时间
|
||||
timeout := time.Second * 15
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
// 使用带有超时的 context 创建新的请求
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
|
||||
95
relay/channel/task/taskcommon/helpers.go
Normal file
95
relay/channel/task/taskcommon/helpers.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package taskcommon
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip.
|
||||
// This replaces the repeated pattern: json.Marshal(metadata) → json.Unmarshal(bytes, &target).
|
||||
func UnmarshalMetadata(metadata map[string]any, target any) error {
|
||||
if metadata == nil {
|
||||
return nil
|
||||
}
|
||||
metaBytes, err := common.Marshal(metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal metadata failed: %w", err)
|
||||
}
|
||||
if err := common.Unmarshal(metaBytes, target); err != nil {
|
||||
return fmt.Errorf("unmarshal metadata failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefaultString returns val if non-empty, otherwise fallback.
|
||||
func DefaultString(val, fallback string) string {
|
||||
if val == "" {
|
||||
return fallback
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// DefaultInt returns val if non-zero, otherwise fallback.
|
||||
func DefaultInt(val, fallback int) int {
|
||||
if val == 0 {
|
||||
return fallback
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// EncodeLocalTaskID encodes an upstream operation name to a URL-safe base64 string.
|
||||
// Used by Gemini/Vertex to store upstream names as task IDs.
|
||||
func EncodeLocalTaskID(name string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(name))
|
||||
}
|
||||
|
||||
// DecodeLocalTaskID decodes a base64-encoded upstream operation name.
|
||||
func DecodeLocalTaskID(id string) (string, error) {
|
||||
b, err := base64.RawURLEncoding.DecodeString(id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// BuildProxyURL constructs the video proxy URL using the public task ID.
|
||||
// e.g., "https://your-server.com/v1/videos/task_xxxx/content"
|
||||
func BuildProxyURL(taskID string) string {
|
||||
return fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
|
||||
}
|
||||
|
||||
// Status-to-progress mapping constants for polling updates.
|
||||
const (
|
||||
ProgressSubmitted = "10%"
|
||||
ProgressQueued = "20%"
|
||||
ProgressInProgress = "30%"
|
||||
ProgressComplete = "100%"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BaseBilling — embeddable no-op implementations for TaskAdaptor billing methods.
|
||||
// Adaptors that do not need custom billing can embed this struct directly.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type BaseBilling struct{}
|
||||
|
||||
// EstimateBilling returns nil (no extra ratios; use base model price).
|
||||
func (BaseBilling) EstimateBilling(_ *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AdjustBillingOnSubmit returns nil (no submit-time adjustment).
|
||||
func (BaseBilling) AdjustBillingOnSubmit(_ *relaycommon.RelayInfo, _ []byte) map[string]float64 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AdjustBillingOnComplete returns 0 (keep pre-charged amount).
|
||||
func (BaseBilling) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
|
||||
return 0
|
||||
}
|
||||
@@ -2,13 +2,12 @@ package vertex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
@@ -17,6 +16,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
vertexcore "github.com/QuantumNous/new-api/relay/channel/vertex"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
@@ -62,6 +62,7 @@ type operationResponse struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
@@ -82,10 +83,10 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
return "", fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
modelName := info.OriginModelName
|
||||
modelName := info.UpstreamModelName
|
||||
if modelName == "" {
|
||||
modelName = "veo-3.0-generate-001"
|
||||
}
|
||||
@@ -116,7 +117,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
return fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
|
||||
@@ -133,6 +134,28 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
return nil
|
||||
}
|
||||
|
||||
// EstimateBilling 根据用户请求中的 sampleCount 计算 OtherRatios。
|
||||
func (a *TaskAdaptor) EstimateBilling(c *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
|
||||
sampleCount := 1
|
||||
v, ok := c.Get("task_request")
|
||||
if ok {
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
if req.Metadata != nil {
|
||||
if sc, exists := req.Metadata["sampleCount"]; exists {
|
||||
if i, ok := sc.(int); ok && i > 0 {
|
||||
sampleCount = i
|
||||
}
|
||||
if f, ok := sc.(float64); ok && int(f) > 0 {
|
||||
sampleCount = int(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return map[string]float64{
|
||||
"sampleCount": float64(sampleCount),
|
||||
}
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Vertex specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
v, ok := c.Get("task_request")
|
||||
@@ -166,25 +189,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
return nil, fmt.Errorf("sampleCount must be greater than 0")
|
||||
}
|
||||
|
||||
// if req.Duration > 0 {
|
||||
// body.Parameters["durationSeconds"] = req.Duration
|
||||
// } else if req.Seconds != "" {
|
||||
// seconds, err := strconv.Atoi(req.Seconds)
|
||||
// if err != nil {
|
||||
// return nil, errors.Wrap(err, "convert seconds to int failed")
|
||||
// }
|
||||
// body.Parameters["durationSeconds"] = seconds
|
||||
// }
|
||||
|
||||
info.PriceData.OtherRatios = map[string]float64{
|
||||
"sampleCount": float64(body.Parameters["sampleCount"].(int)),
|
||||
}
|
||||
|
||||
// if v, ok := body.Parameters["durationSeconds"]; ok {
|
||||
// info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int))
|
||||
// }
|
||||
|
||||
data, err := json.Marshal(body)
|
||||
data, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -205,14 +210,19 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var s submitResponse
|
||||
if err := json.Unmarshal(responseBody, &s); err != nil {
|
||||
if err := common.Unmarshal(responseBody, &s); err != nil {
|
||||
return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if strings.TrimSpace(s.Name) == "" {
|
||||
return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
|
||||
}
|
||||
localID := encodeLocalTaskID(s.Name)
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": localID})
|
||||
localID := taskcommon.EncodeLocalTaskID(s.Name)
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = info.PublicTaskID
|
||||
ov.TaskID = info.PublicTaskID
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
ov.Model = info.OriginModelName
|
||||
c.JSON(http.StatusOK, ov)
|
||||
return localID, responseBody, nil
|
||||
}
|
||||
|
||||
@@ -225,7 +235,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
upstreamName, err := decodeLocalTaskID(taskID)
|
||||
upstreamName, err := taskcommon.DecodeLocalTaskID(taskID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode task_id failed: %w", err)
|
||||
}
|
||||
@@ -245,12 +255,12 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
|
||||
}
|
||||
payload := map[string]string{"operationName": upstreamName}
|
||||
data, err := json.Marshal(payload)
|
||||
data, err := common.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(key), adc); err != nil {
|
||||
if err := common.Unmarshal([]byte(key), adc); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, proxy)
|
||||
@@ -274,7 +284,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
var op operationResponse
|
||||
if err := json.Unmarshal(respBody, &op); err != nil {
|
||||
if err := common.Unmarshal(respBody, &op); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
|
||||
}
|
||||
ti := &relaycommon.TaskInfo{}
|
||||
@@ -338,7 +348,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
upstreamName, err := decodeLocalTaskID(task.TaskID)
|
||||
// Use GetUpstreamTaskID() to get the real upstream operation name for model extraction.
|
||||
// task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name.
|
||||
upstreamTaskID := task.GetUpstreamTaskID()
|
||||
upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID)
|
||||
if err != nil {
|
||||
upstreamName = ""
|
||||
}
|
||||
@@ -353,8 +366,8 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
v.SetProgressStr(task.Progress)
|
||||
v.CreatedAt = task.CreatedAt
|
||||
v.CompletedAt = task.UpdatedAt
|
||||
if strings.HasPrefix(task.FailReason, "data:") && len(task.FailReason) > 0 {
|
||||
v.SetMetadata("url", task.FailReason)
|
||||
if resultURL := task.GetResultURL(); strings.HasPrefix(resultURL, "data:") && len(resultURL) > 0 {
|
||||
v.SetMetadata("url", resultURL)
|
||||
}
|
||||
|
||||
return common.Marshal(v)
|
||||
@@ -364,18 +377,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func encodeLocalTaskID(name string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(name))
|
||||
}
|
||||
|
||||
func decodeLocalTaskID(local string) (string, error) {
|
||||
b, err := base64.RawURLEncoding.DecodeString(local)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`)
|
||||
|
||||
func extractRegionFromOperationName(name string) string {
|
||||
|
||||
@@ -2,7 +2,6 @@ package vidu
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -16,6 +15,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
@@ -73,6 +73,7 @@ type creation struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
baseURL string
|
||||
}
|
||||
@@ -115,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
body, err := a.convertToRequestPayload(&req, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -127,7 +128,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(body)
|
||||
data, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -168,7 +169,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
var vResp responsePayload
|
||||
err = json.Unmarshal(responseBody, &vResp)
|
||||
err = common.Unmarshal(responseBody, &vResp)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -180,8 +181,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = vResp.TaskId
|
||||
ov.TaskID = vResp.TaskId
|
||||
ov.ID = info.PublicTaskID
|
||||
ov.TaskID = info.PublicTaskID
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
ov.Model = info.OriginModelName
|
||||
c.JSON(http.StatusOK, ov)
|
||||
@@ -223,47 +224,27 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
Model: defaultString(req.Model, "viduq1"),
|
||||
Model: taskcommon.DefaultString(info.UpstreamModelName, "viduq1"),
|
||||
Images: req.Images,
|
||||
Prompt: req.Prompt,
|
||||
Duration: defaultInt(req.Duration, 5),
|
||||
Resolution: defaultString(req.Size, "1080p"),
|
||||
Duration: taskcommon.DefaultInt(req.Duration, 5),
|
||||
Resolution: taskcommon.DefaultString(req.Size, "1080p"),
|
||||
MovementAmplitude: "auto",
|
||||
Bgm: false,
|
||||
}
|
||||
metadata := req.Metadata
|
||||
medaBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||
}
|
||||
err = json.Unmarshal(medaBytes, &r)
|
||||
if err != nil {
|
||||
if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func defaultString(value, defaultValue string) string {
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func defaultInt(value, defaultValue int) int {
|
||||
if value == 0 {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
taskInfo := &relaycommon.TaskInfo{}
|
||||
|
||||
var taskResp taskResultResponse
|
||||
err := json.Unmarshal(respBody, &taskResp)
|
||||
err := common.Unmarshal(respBody, &taskResp)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal response body")
|
||||
}
|
||||
@@ -293,7 +274,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||||
var viduResp taskResultResponse
|
||||
if err := json.Unmarshal(originTask.Data, &viduResp); err != nil {
|
||||
if err := common.Unmarshal(originTask.Data, &viduResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal vidu task data failed")
|
||||
}
|
||||
|
||||
@@ -315,6 +296,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, _ := common.Marshal(openAIVideo)
|
||||
return jsonData, nil
|
||||
return common.Marshal(openAIVideo)
|
||||
}
|
||||
|
||||
@@ -75,7 +75,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
|
||||
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
chatJSON, err = relaycommon.RemoveDisabledFields(chatJSON, info.ChannelOtherSettings)
|
||||
chatJSON, err = relaycommon.RemoveDisabledFields(chatJSON, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -119,7 +119,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
|
||||
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
|
||||
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
}
|
||||
|
||||
// remove disabled fields for Claude API
|
||||
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
|
||||
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -6,6 +6,9 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
)
|
||||
|
||||
func TestApplyParamOverrideTrimPrefix(t *testing.T) {
|
||||
@@ -1311,6 +1314,76 @@ func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveDisabledFieldsSkipWhenChannelPassThroughEnabled(t *testing.T) {
|
||||
input := `{
|
||||
"service_tier":"flex",
|
||||
"safety_identifier":"user-123",
|
||||
"store":true,
|
||||
"stream_options":{"include_obfuscation":false}
|
||||
}`
|
||||
settings := dto.ChannelOtherSettings{}
|
||||
|
||||
out, err := RemoveDisabledFields([]byte(input), settings, true)
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveDisabledFields returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, input, string(out))
|
||||
}
|
||||
|
||||
func TestRemoveDisabledFieldsSkipWhenGlobalPassThroughEnabled(t *testing.T) {
|
||||
original := model_setting.GetGlobalSettings().PassThroughRequestEnabled
|
||||
model_setting.GetGlobalSettings().PassThroughRequestEnabled = true
|
||||
t.Cleanup(func() {
|
||||
model_setting.GetGlobalSettings().PassThroughRequestEnabled = original
|
||||
})
|
||||
|
||||
input := `{
|
||||
"service_tier":"flex",
|
||||
"safety_identifier":"user-123",
|
||||
"stream_options":{"include_obfuscation":false}
|
||||
}`
|
||||
settings := dto.ChannelOtherSettings{}
|
||||
|
||||
out, err := RemoveDisabledFields([]byte(input), settings, false)
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveDisabledFields returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, input, string(out))
|
||||
}
|
||||
|
||||
func TestRemoveDisabledFieldsDefaultFiltering(t *testing.T) {
|
||||
input := `{
|
||||
"service_tier":"flex",
|
||||
"inference_geo":"eu",
|
||||
"safety_identifier":"user-123",
|
||||
"store":true,
|
||||
"stream_options":{"include_obfuscation":false}
|
||||
}`
|
||||
settings := dto.ChannelOtherSettings{}
|
||||
|
||||
out, err := RemoveDisabledFields([]byte(input), settings, false)
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveDisabledFields returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"store":true}`, string(out))
|
||||
}
|
||||
|
||||
func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) {
|
||||
input := `{
|
||||
"inference_geo":"eu",
|
||||
"store":true
|
||||
}`
|
||||
settings := dto.ChannelOtherSettings{
|
||||
AllowInferenceGeo: true,
|
||||
}
|
||||
|
||||
out, err := RemoveDisabledFields([]byte(input), settings, false)
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveDisabledFields returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"inference_geo":"eu","store":true}`, string(out))
|
||||
}
|
||||
|
||||
func assertJSONEqual(t *testing.T, want, got string) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -119,8 +119,12 @@ type RelayInfo struct {
|
||||
SendResponseCount int
|
||||
ReceivedResponseCount int
|
||||
FinalPreConsumedQuota int // 最终预消耗的配额
|
||||
// ForcePreConsume 为 true 时禁用 BillingSession 的信任额度旁路,
|
||||
// 强制预扣全额。用于异步任务(视频/音乐生成等),因为请求返回后任务仍在运行,
|
||||
// 必须在提交前锁定全额。
|
||||
ForcePreConsume bool
|
||||
// Billing 是计费会话,封装了预扣费/结算/退款的统一生命周期。
|
||||
// 免费模型和按次计费(MJ/Task)时为 nil。
|
||||
// 免费模型时为 nil。
|
||||
Billing BillingSettler
|
||||
// BillingSource indicates whether this request is billed from wallet quota or subscription.
|
||||
// "" or "wallet" => wallet; "subscription" => subscription
|
||||
@@ -153,7 +157,8 @@ type RelayInfo struct {
|
||||
// RequestConversionChain records request format conversions in order, e.g.
|
||||
// ["openai", "openai_responses"] or ["openai", "claude"].
|
||||
RequestConversionChain []types.RelayFormat
|
||||
// 最终请求到上游的格式 TODO: 当前仅设置了Claude
|
||||
// 最终请求到上游的格式。可由 adaptor 显式设置;
|
||||
// 若为空,调用 GetFinalRequestRelayFormat 会回退到 RequestConversionChain 的最后一项或 RelayFormat。
|
||||
FinalRequestRelayFormat types.RelayFormat
|
||||
|
||||
ThinkingContentInfo
|
||||
@@ -552,8 +557,10 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req
|
||||
return nil, errors.New("request is not a OpenAIResponsesCompactionRequest")
|
||||
case types.RelayFormatTask:
|
||||
info = genBaseRelayInfo(c, nil)
|
||||
info.TaskRelayInfo = &TaskRelayInfo{}
|
||||
case types.RelayFormatMjProxy:
|
||||
info = genBaseRelayInfo(c, nil)
|
||||
info.TaskRelayInfo = &TaskRelayInfo{}
|
||||
default:
|
||||
err = errors.New("invalid relay format")
|
||||
}
|
||||
@@ -600,6 +607,19 @@ func (info *RelayInfo) AppendRequestConversion(format types.RelayFormat) {
|
||||
info.RequestConversionChain = append(info.RequestConversionChain, format)
|
||||
}
|
||||
|
||||
func (info *RelayInfo) GetFinalRequestRelayFormat() types.RelayFormat {
|
||||
if info == nil {
|
||||
return ""
|
||||
}
|
||||
if info.FinalRequestRelayFormat != "" {
|
||||
return info.FinalRequestRelayFormat
|
||||
}
|
||||
if n := len(info.RequestConversionChain); n > 0 {
|
||||
return info.RequestConversionChain[n-1]
|
||||
}
|
||||
return info.RelayFormat
|
||||
}
|
||||
|
||||
func GenRelayInfoResponsesCompaction(c *gin.Context, request *dto.OpenAIResponsesCompactionRequest) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
if info.RelayMode == relayconstant.RelayModeUnknown {
|
||||
@@ -635,8 +655,16 @@ func (info *RelayInfo) HasSendResponse() bool {
|
||||
type TaskRelayInfo struct {
|
||||
Action string
|
||||
OriginTaskID string
|
||||
// PublicTaskID 是提交时预生成的 task_xxxx 格式公开 ID,
|
||||
// 供 DoResponse 在返回给客户端时使用(避免暴露上游真实 ID)。
|
||||
PublicTaskID string
|
||||
|
||||
ConsumeQuota bool
|
||||
|
||||
// LockedChannel holds the full channel object when the request is bound to
|
||||
// a specific channel (e.g., remix on origin task's channel). Stored as any
|
||||
// to avoid an import cycle with model; callers type-assert to *model.Channel.
|
||||
LockedChannel any
|
||||
}
|
||||
|
||||
type TaskSubmitReq struct {
|
||||
@@ -694,11 +722,11 @@ func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error {
|
||||
func (t *TaskSubmitReq) UnmarshalMetadata(v any) error {
|
||||
metadata := t.Metadata
|
||||
if metadata != nil {
|
||||
metadataBytes, err := json.Marshal(metadata)
|
||||
metadataBytes, err := common.Marshal(metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal metadata failed: %w", err)
|
||||
}
|
||||
err = json.Unmarshal(metadataBytes, v)
|
||||
err = common.Unmarshal(metadataBytes, v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshal metadata to target failed: %w", err)
|
||||
}
|
||||
@@ -727,9 +755,15 @@ func FailTaskInfo(reason string) *TaskInfo {
|
||||
|
||||
// RemoveDisabledFields 从请求 JSON 数据中移除渠道设置中禁用的字段
|
||||
// service_tier: 服务层级字段,可能导致额外计费(OpenAI、Claude、Responses API 支持)
|
||||
// inference_geo: Claude 数据驻留推理区域字段(仅 Claude 支持,默认过滤)
|
||||
// store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用)
|
||||
// safety_identifier: 安全标识符,用于向 OpenAI 报告违规用户(仅 OpenAI 支持,涉及用户隐私)
|
||||
func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings) ([]byte, error) {
|
||||
// stream_options.include_obfuscation: 响应流混淆控制字段(仅 OpenAI Responses API 支持)
|
||||
func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings, channelPassThroughEnabled bool) ([]byte, error) {
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled {
|
||||
return jsonData, nil
|
||||
}
|
||||
|
||||
var data map[string]interface{}
|
||||
if err := common.Unmarshal(jsonData, &data); err != nil {
|
||||
common.SysError("RemoveDisabledFields Unmarshal error :" + err.Error())
|
||||
@@ -743,6 +777,13 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
|
||||
}
|
||||
}
|
||||
|
||||
// 默认移除 inference_geo,除非明确允许(避免在未授权情况下透传数据驻留区域)
|
||||
if !channelOtherSettings.AllowInferenceGeo {
|
||||
if _, exists := data["inference_geo"]; exists {
|
||||
delete(data, "inference_geo")
|
||||
}
|
||||
}
|
||||
|
||||
// 默认允许 store 透传,除非明确禁用(禁用可能影响 Codex 使用)
|
||||
if channelOtherSettings.DisableStore {
|
||||
if _, exists := data["store"]; exists {
|
||||
@@ -757,6 +798,22 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
|
||||
}
|
||||
}
|
||||
|
||||
// 默认移除 stream_options.include_obfuscation,除非明确允许(避免关闭响应流混淆保护)
|
||||
if !channelOtherSettings.AllowIncludeObfuscation {
|
||||
if streamOptionsAny, exists := data["stream_options"]; exists {
|
||||
if streamOptions, ok := streamOptionsAny.(map[string]interface{}); ok {
|
||||
if _, includeExists := streamOptions["include_obfuscation"]; includeExists {
|
||||
delete(streamOptions, "include_obfuscation")
|
||||
}
|
||||
if len(streamOptions) == 0 {
|
||||
delete(data, "stream_options")
|
||||
} else {
|
||||
data["stream_options"] = streamOptions
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jsonDataAfter, err := common.Marshal(data)
|
||||
if err != nil {
|
||||
common.SysError("RemoveDisabledFields Marshal error :" + err.Error())
|
||||
|
||||
40
relay/common/relay_info_test.go
Normal file
40
relay/common/relay_info_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRelayInfoGetFinalRequestRelayFormatPrefersExplicitFinal(t *testing.T) {
|
||||
info := &RelayInfo{
|
||||
RelayFormat: types.RelayFormatOpenAI,
|
||||
RequestConversionChain: []types.RelayFormat{types.RelayFormatOpenAI, types.RelayFormatClaude},
|
||||
FinalRequestRelayFormat: types.RelayFormatOpenAIResponses,
|
||||
}
|
||||
|
||||
require.Equal(t, types.RelayFormat(types.RelayFormatOpenAIResponses), info.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
func TestRelayInfoGetFinalRequestRelayFormatFallsBackToConversionChain(t *testing.T) {
|
||||
info := &RelayInfo{
|
||||
RelayFormat: types.RelayFormatOpenAI,
|
||||
RequestConversionChain: []types.RelayFormat{types.RelayFormatOpenAI, types.RelayFormatClaude},
|
||||
}
|
||||
|
||||
require.Equal(t, types.RelayFormat(types.RelayFormatClaude), info.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
func TestRelayInfoGetFinalRequestRelayFormatFallsBackToRelayFormat(t *testing.T) {
|
||||
info := &RelayInfo{
|
||||
RelayFormat: types.RelayFormatGemini,
|
||||
}
|
||||
|
||||
require.Equal(t, types.RelayFormat(types.RelayFormatGemini), info.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
func TestRelayInfoGetFinalRequestRelayFormatNilReceiver(t *testing.T) {
|
||||
var info *RelayInfo
|
||||
require.Equal(t, types.RelayFormat(""), info.GetFinalRequestRelayFormat())
|
||||
}
|
||||
@@ -173,16 +173,10 @@ func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
|
||||
if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
|
||||
return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
|
||||
}
|
||||
info.PriceData.OtherRatios = map[string]float64{
|
||||
"seconds": float64(seconds),
|
||||
"size": 1,
|
||||
}
|
||||
if lo.Contains([]string{"1792x1024", "1024x1792"}, size) {
|
||||
info.PriceData.OtherRatios["size"] = 1.666667
|
||||
}
|
||||
// OtherRatios 已移到 Sora adaptor 的 EstimateBilling 中设置
|
||||
}
|
||||
|
||||
info.Action = action
|
||||
storeTaskRequest(c, info, action, req)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -165,7 +165,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
}
|
||||
|
||||
// remove disabled fields for OpenAI API
|
||||
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
|
||||
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -232,7 +232,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
}
|
||||
|
||||
if originUsage != nil {
|
||||
service.ObserveChannelAffinityUsageCacheFromContext(ctx, usage)
|
||||
service.ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason)
|
||||
@@ -336,7 +336,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
|
||||
var audioInputQuota decimal.Decimal
|
||||
var audioInputPrice float64
|
||||
isClaudeUsageSemantic := relayInfo.FinalRequestRelayFormat == types.RelayFormatClaude
|
||||
isClaudeUsageSemantic := relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude
|
||||
if !relayInfo.PriceData.UsePrice {
|
||||
baseTokens := dPromptTokens
|
||||
// 减去 cached tokens
|
||||
|
||||
@@ -140,7 +140,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
}
|
||||
|
||||
// ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
|
||||
func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData {
|
||||
func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PriceData {
|
||||
groupRatioInfo := HandleGroupRatio(c, info)
|
||||
|
||||
modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
|
||||
@@ -154,7 +154,18 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.
|
||||
}
|
||||
}
|
||||
quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
|
||||
priceData := types.PerCallPriceData{
|
||||
|
||||
// 免费模型检测(与 ModelPriceHelper 对齐)
|
||||
freeModel := false
|
||||
if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume {
|
||||
if groupRatioInfo.GroupRatio == 0 || modelPrice == 0 {
|
||||
quota = 0
|
||||
freeModel = true
|
||||
}
|
||||
}
|
||||
|
||||
priceData := types.PriceData{
|
||||
FreeModel: freeModel,
|
||||
ModelPrice: modelPrice,
|
||||
Quota: quota,
|
||||
GroupRatioInfo: groupRatioInfo,
|
||||
|
||||
@@ -176,10 +176,32 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
})
|
||||
}
|
||||
|
||||
dataChan := make(chan string, 10)
|
||||
|
||||
wg.Add(1)
|
||||
gopool.Go(func() {
|
||||
defer func() {
|
||||
wg.Done()
|
||||
if r := recover(); r != nil {
|
||||
logger.LogError(c, fmt.Sprintf("data handler goroutine panic: %v", r))
|
||||
}
|
||||
common.SafeSendBool(stopChan, true)
|
||||
}()
|
||||
for data := range dataChan {
|
||||
writeMutex.Lock()
|
||||
success := dataHandler(data)
|
||||
writeMutex.Unlock()
|
||||
if !success {
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Scanner goroutine with improved error handling
|
||||
wg.Add(1)
|
||||
common.RelayCtxGo(ctx, func() {
|
||||
defer func() {
|
||||
close(dataChan)
|
||||
wg.Done()
|
||||
if r := recover(); r != nil {
|
||||
logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
|
||||
@@ -215,27 +237,16 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
data = strings.TrimLeft(data, " ")
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
data = strings.TrimSpace(data)
|
||||
if data == "" {
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
info.SetFirstResponseTime()
|
||||
info.ReceivedResponseCount++
|
||||
// 使用超时机制防止写操作阻塞
|
||||
done := make(chan bool, 1)
|
||||
gopool.Go(func() {
|
||||
writeMutex.Lock()
|
||||
defer writeMutex.Unlock()
|
||||
done <- dataHandler(data)
|
||||
})
|
||||
|
||||
select {
|
||||
case success := <-done:
|
||||
if !success {
|
||||
return
|
||||
}
|
||||
case <-time.After(10 * time.Second):
|
||||
logger.LogError(c, "data handler timeout")
|
||||
return
|
||||
case dataChan <- data:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-stopChan:
|
||||
|
||||
521
relay/helper/stream_scanner_test.go
Normal file
521
relay/helper/stream_scanner_test.go
Normal file
@@ -0,0 +1,521 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
func setupStreamTest(t *testing.T, body io.Reader) (*gin.Context, *http.Response, *relaycommon.RelayInfo) {
|
||||
t.Helper()
|
||||
|
||||
oldTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 30
|
||||
t.Cleanup(func() {
|
||||
constant.StreamingTimeout = oldTimeout
|
||||
})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
resp := &http.Response{
|
||||
Body: io.NopCloser(body),
|
||||
}
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
ChannelMeta: &relaycommon.ChannelMeta{},
|
||||
}
|
||||
|
||||
return c, resp, info
|
||||
}
|
||||
|
||||
func buildSSEBody(n int) string {
|
||||
var b strings.Builder
|
||||
for i := 0; i < n; i++ {
|
||||
fmt.Fprintf(&b, "data: {\"id\":%d,\"choices\":[{\"delta\":{\"content\":\"token_%d\"}}]}\n", i, i)
|
||||
}
|
||||
b.WriteString("data: [DONE]\n")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// slowReader wraps a reader and injects a delay before each Read call,
|
||||
// simulating a slow upstream that trickles data.
|
||||
type slowReader struct {
|
||||
r io.Reader
|
||||
delay time.Duration
|
||||
}
|
||||
|
||||
func (s *slowReader) Read(p []byte) (int, error) {
|
||||
time.Sleep(s.delay)
|
||||
return s.r.Read(p)
|
||||
}
|
||||
|
||||
// ---------- Basic correctness ----------
|
||||
|
||||
func TestStreamScannerHandler_NilInputs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
|
||||
|
||||
StreamScannerHandler(c, nil, info, func(data string) bool { return true })
|
||||
StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil)
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_EmptyBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(""))
|
||||
|
||||
var called atomic.Bool
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
called.Store(true)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.False(t, called.Load(), "handler should not be called for empty body")
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_1000Chunks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const numChunks = 1000
|
||||
body := buildSSEBody(numChunks)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(numChunks), count.Load())
|
||||
assert.Equal(t, numChunks, info.ReceivedResponseCount)
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_10000Chunks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const numChunks = 10000
|
||||
body := buildSSEBody(numChunks)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var count atomic.Int64
|
||||
start := time.Now()
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
elapsed := time.Since(start)
|
||||
assert.Equal(t, int64(numChunks), count.Load())
|
||||
assert.Equal(t, numChunks, info.ReceivedResponseCount)
|
||||
t.Logf("10000 chunks processed in %v", elapsed)
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_OrderPreserved(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const numChunks = 500
|
||||
body := buildSSEBody(numChunks)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var mu sync.Mutex
|
||||
received := make([]string, 0, numChunks)
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
mu.Lock()
|
||||
received = append(received, data)
|
||||
mu.Unlock()
|
||||
return true
|
||||
})
|
||||
|
||||
require.Equal(t, numChunks, len(received))
|
||||
for i := 0; i < numChunks; i++ {
|
||||
expected := fmt.Sprintf("{\"id\":%d,\"choices\":[{\"delta\":{\"content\":\"token_%d\"}}]}", i, i)
|
||||
assert.Equal(t, expected, received[i], "chunk %d out of order", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_DoneStopsScanner(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(50) + "data: should_not_appear\n"
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(50), count.Load(), "data after [DONE] must not be processed")
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_HandlerFailureStops(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const numChunks = 200
|
||||
body := buildSSEBody(numChunks)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
const failAt = 50
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
n := count.Add(1)
|
||||
return n < failAt
|
||||
})
|
||||
|
||||
// The worker stops at failAt; the scanner may have read ahead,
|
||||
// but the handler should not be called beyond failAt.
|
||||
assert.Equal(t, int64(failAt), count.Load())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(": comment line\n")
|
||||
b.WriteString("event: message\n")
|
||||
b.WriteString("id: 12345\n")
|
||||
b.WriteString("retry: 5000\n")
|
||||
for i := 0; i < 100; i++ {
|
||||
fmt.Fprintf(&b, "data: payload_%d\n", i)
|
||||
b.WriteString(": interleaved comment\n")
|
||||
}
|
||||
b.WriteString("data: [DONE]\n")
|
||||
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(100), count.Load())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_DataWithExtraSpaces(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := "data: {\"trimmed\":true} \ndata: [DONE]\n"
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var got string
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
got = data
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, "{\"trimmed\":true}", got)
|
||||
}
|
||||
|
||||
// ---------- Decoupling: scanner not blocked by slow handler ----------
|
||||
|
||||
func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Strategy: use a slow upstream (io.Pipe, 10ms per chunk) AND a slow handler (20ms per chunk).
|
||||
// If the scanner were synchronously coupled to the handler, total time would be
|
||||
// ~numChunks * (10ms + 20ms) = 30ms * 50 = 1500ms.
|
||||
// With decoupling, total time should be closer to
|
||||
// ~numChunks * max(10ms, 20ms) = 20ms * 50 = 1000ms
|
||||
// because the scanner reads ahead into the buffer while the handler processes.
|
||||
const numChunks = 50
|
||||
const upstreamDelay = 10 * time.Millisecond
|
||||
const handlerDelay = 20 * time.Millisecond
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
for i := 0; i < numChunks; i++ {
|
||||
fmt.Fprintf(pw, "data: {\"id\":%d}\n", i)
|
||||
time.Sleep(upstreamDelay)
|
||||
}
|
||||
fmt.Fprint(pw, "data: [DONE]\n")
|
||||
}()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
oldTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 30
|
||||
t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
|
||||
|
||||
resp := &http.Response{Body: pr}
|
||||
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
|
||||
|
||||
var count atomic.Int64
|
||||
start := time.Now()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
time.Sleep(handlerDelay)
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("StreamScannerHandler did not complete in time")
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
assert.Equal(t, int64(numChunks), count.Load())
|
||||
|
||||
coupledTime := time.Duration(numChunks) * (upstreamDelay + handlerDelay)
|
||||
t.Logf("elapsed=%v, coupled_estimate=%v", elapsed, coupledTime)
|
||||
|
||||
// If decoupled, elapsed should be well under the coupled estimate.
|
||||
assert.Less(t, elapsed, coupledTime*85/100,
|
||||
"decoupled elapsed time (%v) should be significantly less than coupled estimate (%v)", elapsed, coupledTime)
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_SlowUpstreamFastHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const numChunks = 50
|
||||
body := buildSSEBody(numChunks)
|
||||
reader := &slowReader{r: strings.NewReader(body), delay: 2 * time.Millisecond}
|
||||
c, resp, info := setupStreamTest(t, reader)
|
||||
|
||||
var count atomic.Int64
|
||||
start := time.Now()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("timed out with slow upstream")
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
assert.Equal(t, int64(numChunks), count.Load())
|
||||
t.Logf("slow upstream (%d chunks, 2ms/read): %v", numChunks, elapsed)
|
||||
}
|
||||
|
||||
// ---------- Ping tests ----------
|
||||
|
||||
func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setting := operation_setting.GetGeneralSetting()
|
||||
oldEnabled := setting.PingIntervalEnabled
|
||||
oldSeconds := setting.PingIntervalSeconds
|
||||
setting.PingIntervalEnabled = true
|
||||
setting.PingIntervalSeconds = 1
|
||||
t.Cleanup(func() {
|
||||
setting.PingIntervalEnabled = oldEnabled
|
||||
setting.PingIntervalSeconds = oldSeconds
|
||||
})
|
||||
|
||||
// Create a reader that delivers data slowly: one chunk every 500ms over 3.5 seconds.
|
||||
// The ping interval is 1s, so we should see at least 2 pings.
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
for i := 0; i < 7; i++ {
|
||||
fmt.Fprintf(pw, "data: chunk_%d\n", i)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
fmt.Fprint(pw, "data: [DONE]\n")
|
||||
}()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
oldTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 30
|
||||
t.Cleanup(func() {
|
||||
constant.StreamingTimeout = oldTimeout
|
||||
})
|
||||
|
||||
resp := &http.Response{Body: pr}
|
||||
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
|
||||
|
||||
var count atomic.Int64
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("timed out waiting for stream to finish")
|
||||
}
|
||||
|
||||
assert.Equal(t, int64(7), count.Load())
|
||||
|
||||
body := recorder.Body.String()
|
||||
pingCount := strings.Count(body, ": PING")
|
||||
t.Logf("received %d pings in response body", pingCount)
|
||||
assert.GreaterOrEqual(t, pingCount, 2,
|
||||
"expected at least 2 pings during 3.5s stream with 1s interval; got %d", pingCount)
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setting := operation_setting.GetGeneralSetting()
|
||||
oldEnabled := setting.PingIntervalEnabled
|
||||
oldSeconds := setting.PingIntervalSeconds
|
||||
setting.PingIntervalEnabled = true
|
||||
setting.PingIntervalSeconds = 1
|
||||
t.Cleanup(func() {
|
||||
setting.PingIntervalEnabled = oldEnabled
|
||||
setting.PingIntervalSeconds = oldSeconds
|
||||
})
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
for i := 0; i < 5; i++ {
|
||||
fmt.Fprintf(pw, "data: chunk_%d\n", i)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
fmt.Fprint(pw, "data: [DONE]\n")
|
||||
}()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
oldTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 30
|
||||
t.Cleanup(func() {
|
||||
constant.StreamingTimeout = oldTimeout
|
||||
})
|
||||
|
||||
resp := &http.Response{Body: pr}
|
||||
info := &relaycommon.RelayInfo{
|
||||
DisablePing: true,
|
||||
ChannelMeta: &relaycommon.ChannelMeta{},
|
||||
}
|
||||
|
||||
var count atomic.Int64
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("timed out")
|
||||
}
|
||||
|
||||
assert.Equal(t, int64(5), count.Load())
|
||||
|
||||
body := recorder.Body.String()
|
||||
pingCount := strings.Count(body, ": PING")
|
||||
assert.Equal(t, 0, pingCount, "pings should be disabled when DisablePing=true")
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setting := operation_setting.GetGeneralSetting()
|
||||
oldEnabled := setting.PingIntervalEnabled
|
||||
oldSeconds := setting.PingIntervalSeconds
|
||||
setting.PingIntervalEnabled = true
|
||||
setting.PingIntervalSeconds = 1
|
||||
t.Cleanup(func() {
|
||||
setting.PingIntervalEnabled = oldEnabled
|
||||
setting.PingIntervalSeconds = oldSeconds
|
||||
})
|
||||
|
||||
// Slow upstream + slow handler. Total stream takes ~5 seconds.
|
||||
// The ping goroutine stays alive as long as the scanner is reading,
|
||||
// so pings should fire between data writes.
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
for i := 0; i < 10; i++ {
|
||||
fmt.Fprintf(pw, "data: chunk_%d\n", i)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
fmt.Fprint(pw, "data: [DONE]\n")
|
||||
}()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
oldTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 30
|
||||
t.Cleanup(func() {
|
||||
constant.StreamingTimeout = oldTimeout
|
||||
})
|
||||
|
||||
resp := &http.Response{Body: pr}
|
||||
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
|
||||
|
||||
var count atomic.Int64
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("timed out")
|
||||
}
|
||||
|
||||
assert.Equal(t, int64(10), count.Load())
|
||||
|
||||
body := recorder.Body.String()
|
||||
pingCount := strings.Count(body, ": PING")
|
||||
t.Logf("received %d pings interleaved with 10 chunks over 5s", pingCount)
|
||||
assert.GreaterOrEqual(t, pingCount, 3,
|
||||
"expected at least 3 pings during 5s stream with 1s ping interval; got %d", pingCount)
|
||||
}
|
||||
@@ -184,7 +184,7 @@ func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyR
|
||||
if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
|
||||
}
|
||||
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
|
||||
modelName := service.CovertMjpActionToModelName(constant.MjActionSwapFace)
|
||||
|
||||
priceData := helper.ModelPriceHelperPerCall(c, info)
|
||||
|
||||
@@ -485,7 +485,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dt
|
||||
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
|
||||
modelName := service.CoverActionToModelName(midjRequest.Action)
|
||||
modelName := service.CovertMjpActionToModelName(midjRequest.Action)
|
||||
|
||||
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -15,29 +14,33 @@ import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
/*
|
||||
Task 任务通过平台、Action 区分任务
|
||||
*/
|
||||
func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
info.InitChannelMeta(c)
|
||||
// ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields
|
||||
if info.TaskRelayInfo == nil {
|
||||
info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
|
||||
}
|
||||
type TaskSubmitResult struct {
|
||||
UpstreamTaskID string
|
||||
TaskData []byte
|
||||
Platform constant.TaskPlatform
|
||||
Quota int
|
||||
//PerCallPrice types.PriceData
|
||||
}
|
||||
|
||||
// ResolveOriginTask 处理基于已有任务的提交(remix / continuation):
|
||||
// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道
|
||||
// (通过 info.LockedChannel,重试时复用同一渠道并轮换 key),
|
||||
// 以及提取 OtherRatios(时长、分辨率)。
|
||||
// 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。
|
||||
func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
|
||||
// 检测 remix action
|
||||
path := c.Request.URL.Path
|
||||
if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
|
||||
info.Action = constant.TaskActionRemix
|
||||
}
|
||||
|
||||
// 提取 remix 任务的 video_id
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
videoID := c.Param("video_id")
|
||||
if strings.TrimSpace(videoID) == "" {
|
||||
@@ -46,64 +49,71 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
info.OriginTaskID = videoID
|
||||
}
|
||||
|
||||
platform := constant.TaskPlatform(c.GetString("platform"))
|
||||
if info.OriginTaskID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 获取原始任务信息
|
||||
if info.OriginTaskID != "" {
|
||||
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !exist {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if info.OriginModelName == "" {
|
||||
if originTask.Properties.OriginModelName != "" {
|
||||
info.OriginModelName = originTask.Properties.OriginModelName
|
||||
} else if originTask.Properties.UpstreamModelName != "" {
|
||||
info.OriginModelName = originTask.Properties.UpstreamModelName
|
||||
} else {
|
||||
var taskData map[string]interface{}
|
||||
_ = json.Unmarshal(originTask.Data, &taskData)
|
||||
if m, ok := taskData["model"].(string); ok && m != "" {
|
||||
info.OriginModelName = m
|
||||
platform = originTask.Platform
|
||||
}
|
||||
}
|
||||
}
|
||||
if originTask.ChannelId != info.ChannelId {
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
key, _, newAPIError := channel.GetNextEnabledKey()
|
||||
if newAPIError != nil {
|
||||
taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
|
||||
return
|
||||
}
|
||||
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
|
||||
// 查找原始任务
|
||||
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
|
||||
if err != nil {
|
||||
return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if !exist {
|
||||
return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
info.ChannelBaseUrl = channel.GetBaseURL()
|
||||
info.ChannelId = originTask.ChannelId
|
||||
info.ChannelType = channel.Type
|
||||
info.ApiKey = key
|
||||
platform = originTask.Platform
|
||||
}
|
||||
|
||||
// 使用原始任务的参数
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
// 从原始任务推导模型名称
|
||||
if info.OriginModelName == "" {
|
||||
if originTask.Properties.OriginModelName != "" {
|
||||
info.OriginModelName = originTask.Properties.OriginModelName
|
||||
} else if originTask.Properties.UpstreamModelName != "" {
|
||||
info.OriginModelName = originTask.Properties.UpstreamModelName
|
||||
} else {
|
||||
var taskData map[string]interface{}
|
||||
_ = json.Unmarshal(originTask.Data, &taskData)
|
||||
_ = common.Unmarshal(originTask.Data, &taskData)
|
||||
if m, ok := taskData["model"].(string); ok && m != "" {
|
||||
info.OriginModelName = m
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 锁定到原始任务的渠道(重试时复用同一渠道,轮换 key)
|
||||
ch, err := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err != nil {
|
||||
return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
||||
}
|
||||
if ch.Status != common.ChannelStatusEnabled {
|
||||
return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
|
||||
}
|
||||
info.LockedChannel = ch
|
||||
|
||||
if originTask.ChannelId != info.ChannelId {
|
||||
key, _, newAPIError := ch.GetNextEnabledKey()
|
||||
if newAPIError != nil {
|
||||
return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
|
||||
}
|
||||
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelType, ch.Type)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, ch.GetBaseURL())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
|
||||
|
||||
info.ChannelBaseUrl = ch.GetBaseURL()
|
||||
info.ChannelId = originTask.ChannelId
|
||||
info.ChannelType = ch.Type
|
||||
info.ApiKey = key
|
||||
}
|
||||
|
||||
// 提取 remix 参数(时长、分辨率 → OtherRatios)
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
if originTask.PrivateData.BillingContext != nil {
|
||||
// 新的 remix 逻辑:直接从原始任务的 BillingContext 中提取 OtherRatios(如果存在)
|
||||
for s, f := range originTask.PrivateData.BillingContext.OtherRatios {
|
||||
info.PriceData.AddOtherRatio(s, f)
|
||||
}
|
||||
} else {
|
||||
// 旧的 remix 逻辑:直接从 task data 解析 seconds 和 size(如果存在)
|
||||
var taskData map[string]interface{}
|
||||
_ = common.Unmarshal(originTask.Data, &taskData)
|
||||
secondsStr, _ := taskData["seconds"].(string)
|
||||
seconds, _ := strconv.Atoi(secondsStr)
|
||||
if seconds <= 0 {
|
||||
@@ -120,167 +130,146 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次):
|
||||
// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 →
|
||||
// 估算计费(EstimateBilling) → 计算价格 → 预扣费(仅首次)→
|
||||
// 构建/发送/解析上游请求 → 提交后计费调整(AdjustBillingOnSubmit)。
|
||||
// 控制器负责 defer Refund 和成功后 Settle。
|
||||
func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
// 1. 确定 platform → 创建适配器 → 验证请求
|
||||
platform := constant.TaskPlatform(c.GetString("platform"))
|
||||
if platform == "" {
|
||||
platform = GetTaskPlatform(c)
|
||||
}
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
adaptor := GetTaskAdaptor(platform)
|
||||
if adaptor == nil {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
||||
return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
||||
}
|
||||
adaptor.Init(info)
|
||||
// get & validate taskRequest 获取并验证文本请求
|
||||
taskErr = adaptor.ValidateRequestAndSetAction(c, info)
|
||||
if taskErr != nil {
|
||||
return
|
||||
if taskErr := adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil {
|
||||
return nil, taskErr
|
||||
}
|
||||
|
||||
// 2. 确定模型名称
|
||||
modelName := info.OriginModelName
|
||||
if modelName == "" {
|
||||
modelName = service.CoverTaskActionToModelName(platform, info.Action)
|
||||
}
|
||||
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
|
||||
if !success {
|
||||
defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName]
|
||||
if !ok {
|
||||
modelPrice = float64(common.PreConsumedQuota) / common.QuotaPerUnit
|
||||
} else {
|
||||
modelPrice = defaultPrice
|
||||
|
||||
// 2.5 应用渠道的模型映射(与同步任务对齐)
|
||||
info.OriginModelName = modelName
|
||||
info.UpstreamModelName = modelName
|
||||
if err := helper.ModelMappedHelper(c, info, nil); err != nil {
|
||||
return nil, service.TaskErrorWrapperLocal(err, "model_mapping_failed", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// 3. 预生成公开 task ID(仅首次)
|
||||
if info.PublicTaskID == "" {
|
||||
info.PublicTaskID = model.GenerateTaskID()
|
||||
}
|
||||
|
||||
// 4. 价格计算:基础模型价格
|
||||
info.OriginModelName = modelName
|
||||
info.PriceData = helper.ModelPriceHelperPerCall(c, info)
|
||||
|
||||
// 5. 计费估算:让适配器根据用户请求提供 OtherRatios(时长、分辨率等)
|
||||
// 必须在 ModelPriceHelperPerCall 之后调用(它会重建 PriceData)。
|
||||
// ResolveOriginTask 可能已在 remix 路径中预设了 OtherRatios,此处合并。
|
||||
if estimatedRatios := adaptor.EstimateBilling(c, info); len(estimatedRatios) > 0 {
|
||||
for k, v := range estimatedRatios {
|
||||
info.PriceData.AddOtherRatio(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 auto 分组:从 context 获取实际选中的分组
|
||||
// 当使用 auto 分组时,Distribute 中间件会将实际选中的分组存储在 ContextKeyAutoGroup 中
|
||||
if autoGroup, exists := common.GetContextKey(c, constant.ContextKeyAutoGroup); exists {
|
||||
if groupStr, ok := autoGroup.(string); ok && groupStr != "" {
|
||||
info.UsingGroup = groupStr
|
||||
}
|
||||
}
|
||||
|
||||
// 预扣
|
||||
groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
|
||||
var ratio float64
|
||||
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
|
||||
if hasUserGroupRatio {
|
||||
ratio = modelPrice * userGroupRatio
|
||||
} else {
|
||||
ratio = modelPrice * groupRatio
|
||||
}
|
||||
// FIXME: 临时修补,支持任务仅按次计费
|
||||
// 6. 将 OtherRatios 应用到基础额度
|
||||
if !common.StringsContains(constant.TaskPricePatches, modelName) {
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
for _, ra := range info.PriceData.OtherRatios {
|
||||
if 1.0 != ra {
|
||||
ratio *= ra
|
||||
}
|
||||
for _, ra := range info.PriceData.OtherRatios {
|
||||
if ra != 1.0 {
|
||||
info.PriceData.Quota = int(float64(info.PriceData.Quota) * ra)
|
||||
}
|
||||
}
|
||||
}
|
||||
println(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio))
|
||||
userQuota, err := model.GetUserQuota(info.UserId, false)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
quota := int(ratio * common.QuotaPerUnit)
|
||||
if userQuota-quota < 0 {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
|
||||
return
|
||||
|
||||
// 7. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过)
|
||||
if info.Billing == nil && !info.PriceData.FreeModel {
|
||||
info.ForcePreConsume = true
|
||||
if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil {
|
||||
return nil, service.TaskErrorFromAPIError(apiErr)
|
||||
}
|
||||
}
|
||||
|
||||
// build body
|
||||
// 8. 构建请求体
|
||||
requestBody, err := adaptor.BuildRequestBody(c, info)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
// do request
|
||||
|
||||
// 9. 发送请求
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
// handle response
|
||||
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||
responseBody, _ := io.ReadAll(resp.Body)
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
|
||||
return
|
||||
return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// release quota
|
||||
if info.ConsumeQuota && taskErr == nil {
|
||||
// 10. 返回 OtherRatios 给下游(header 必须在 DoResponse 写 body 之前设置)
|
||||
otherRatios := info.PriceData.OtherRatios
|
||||
if otherRatios == nil {
|
||||
otherRatios = map[string]float64{}
|
||||
}
|
||||
ratiosJSON, _ := common.Marshal(otherRatios)
|
||||
c.Header("X-New-Api-Other-Ratios", string(ratiosJSON))
|
||||
|
||||
err := service.PostConsumeQuota(info, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysLog("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
//gRatio := groupRatio
|
||||
//if hasUserGroupRatio {
|
||||
// gRatio = userGroupRatio
|
||||
//}
|
||||
logContent := fmt.Sprintf("操作 %s", info.Action)
|
||||
// FIXME: 临时修补,支持任务仅按次计费
|
||||
if common.StringsContains(constant.TaskPricePatches, modelName) {
|
||||
logContent = fmt.Sprintf("%s,按次计费", logContent)
|
||||
} else {
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
var contents []string
|
||||
for key, ra := range info.PriceData.OtherRatios {
|
||||
if 1.0 != ra {
|
||||
contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
|
||||
}
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
|
||||
}
|
||||
}
|
||||
}
|
||||
other := make(map[string]interface{})
|
||||
if c != nil && c.Request != nil && c.Request.URL != nil {
|
||||
other["request_path"] = c.Request.URL.Path
|
||||
}
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
if hasUserGroupRatio {
|
||||
other["user_group_ratio"] = userGroupRatio
|
||||
}
|
||||
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: info.ChannelId,
|
||||
ModelName: modelName,
|
||||
TokenName: tokenName,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: info.TokenId,
|
||||
Group: info.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(info.ChannelId, quota)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
|
||||
// 11. 解析响应
|
||||
upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
|
||||
if taskErr != nil {
|
||||
return
|
||||
return nil, taskErr
|
||||
}
|
||||
info.ConsumeQuota = true
|
||||
// insert task
|
||||
task := model.InitTask(platform, info)
|
||||
task.TaskID = taskID
|
||||
task.Quota = quota
|
||||
task.Data = taskData
|
||||
task.Action = info.Action
|
||||
err = task.Insert()
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
// 11. 提交后计费调整:让适配器根据上游实际返回调整 OtherRatios
|
||||
finalQuota := info.PriceData.Quota
|
||||
if adjustedRatios := adaptor.AdjustBillingOnSubmit(info, taskData); len(adjustedRatios) > 0 {
|
||||
// 基于调整后的 ratios 重新计算 quota
|
||||
finalQuota = recalcQuotaFromRatios(info, adjustedRatios)
|
||||
info.PriceData.OtherRatios = adjustedRatios
|
||||
info.PriceData.Quota = finalQuota
|
||||
}
|
||||
return nil
|
||||
|
||||
return &TaskSubmitResult{
|
||||
UpstreamTaskID: upstreamTaskID,
|
||||
TaskData: taskData,
|
||||
Platform: platform,
|
||||
Quota: finalQuota,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// recalcQuotaFromRatios 根据 adjustedRatios 重新计算 quota。
|
||||
// 公式: baseQuota × ∏(ratio) — 其中 baseQuota 是不含 OtherRatios 的基础额度。
|
||||
func recalcQuotaFromRatios(info *relaycommon.RelayInfo, ratios map[string]float64) int {
|
||||
// 从 PriceData 获取不含 OtherRatios 的基础价格
|
||||
baseQuota := info.PriceData.Quota
|
||||
// 先除掉原有的 OtherRatios 恢复基础额度
|
||||
for _, ra := range info.PriceData.OtherRatios {
|
||||
if ra != 1.0 && ra > 0 {
|
||||
baseQuota = int(float64(baseQuota) / ra)
|
||||
}
|
||||
}
|
||||
// 应用新的 ratios
|
||||
result := float64(baseQuota)
|
||||
for _, ra := range ratios {
|
||||
if ra != 1.0 {
|
||||
result *= ra
|
||||
}
|
||||
}
|
||||
return int(result)
|
||||
}
|
||||
|
||||
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
|
||||
@@ -336,7 +325,7 @@ func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.Ta
|
||||
} else {
|
||||
tasks = make([]any, 0)
|
||||
}
|
||||
respBody, err = json.Marshal(dto.TaskResponse[[]any]{
|
||||
respBody, err = common.Marshal(dto.TaskResponse[[]any]{
|
||||
Code: "success",
|
||||
Data: tasks,
|
||||
})
|
||||
@@ -357,7 +346,7 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
|
||||
return
|
||||
}
|
||||
|
||||
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
||||
respBody, err = common.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: TaskModel2Dto(originTask),
|
||||
})
|
||||
@@ -381,97 +370,16 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
||||
return
|
||||
}
|
||||
|
||||
func() {
|
||||
channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err2 != nil {
|
||||
return
|
||||
}
|
||||
if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
|
||||
return
|
||||
}
|
||||
baseURL := constant.ChannelBaseURLs[channelModel.Type]
|
||||
if channelModel.GetBaseURL() != "" {
|
||||
baseURL = channelModel.GetBaseURL()
|
||||
}
|
||||
proxy := channelModel.GetSetting().Proxy
|
||||
adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
|
||||
if adaptor == nil {
|
||||
return
|
||||
}
|
||||
resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
|
||||
"task_id": originTask.TaskID,
|
||||
"action": originTask.Action,
|
||||
}, proxy)
|
||||
if err2 != nil || resp == nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err2 := io.ReadAll(resp.Body)
|
||||
if err2 != nil {
|
||||
return
|
||||
}
|
||||
ti, err2 := adaptor.ParseTaskResult(body)
|
||||
if err2 == nil && ti != nil {
|
||||
if ti.Status != "" {
|
||||
originTask.Status = model.TaskStatus(ti.Status)
|
||||
}
|
||||
if ti.Progress != "" {
|
||||
originTask.Progress = ti.Progress
|
||||
}
|
||||
if ti.Url != "" {
|
||||
if strings.HasPrefix(ti.Url, "data:") {
|
||||
} else {
|
||||
originTask.FailReason = ti.Url
|
||||
}
|
||||
}
|
||||
_ = originTask.Update()
|
||||
var raw map[string]any
|
||||
_ = json.Unmarshal(body, &raw)
|
||||
format := "mp4"
|
||||
if respObj, ok := raw["response"].(map[string]any); ok {
|
||||
if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
|
||||
if v0, ok := vids[0].(map[string]any); ok {
|
||||
if mt, ok := v0["mimeType"].(string); ok && mt != "" {
|
||||
if strings.Contains(mt, "mp4") {
|
||||
format = "mp4"
|
||||
} else {
|
||||
format = mt
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
status := "processing"
|
||||
switch originTask.Status {
|
||||
case model.TaskStatusSuccess:
|
||||
status = "succeeded"
|
||||
case model.TaskStatusFailure:
|
||||
status = "failed"
|
||||
case model.TaskStatusQueued, model.TaskStatusSubmitted:
|
||||
status = "queued"
|
||||
}
|
||||
if !strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
|
||||
out := map[string]any{
|
||||
"error": nil,
|
||||
"format": format,
|
||||
"metadata": nil,
|
||||
"status": status,
|
||||
"task_id": originTask.TaskID,
|
||||
"url": originTask.FailReason,
|
||||
}
|
||||
respBody, _ = json.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: out,
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
isOpenAIVideoAPI := strings.HasPrefix(c.Request.RequestURI, "/v1/videos/")
|
||||
|
||||
if len(respBody) != 0 {
|
||||
// Gemini/Vertex 支持实时查询:用户 fetch 时直接从上游拉取最新状态
|
||||
if realtimeResp := tryRealtimeFetch(originTask, isOpenAIVideoAPI); len(realtimeResp) > 0 {
|
||||
respBody = realtimeResp
|
||||
return
|
||||
}
|
||||
|
||||
if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
|
||||
// OpenAI Video API 格式: 走各 adaptor 的 ConvertToOpenAIVideo
|
||||
if isOpenAIVideoAPI {
|
||||
adaptor := GetTaskAdaptor(originTask.Platform)
|
||||
if adaptor == nil {
|
||||
taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
|
||||
@@ -486,10 +394,12 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
||||
respBody = openAIVideoData
|
||||
return
|
||||
}
|
||||
taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented)
|
||||
taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
||||
|
||||
// 通用 TaskDto 格式
|
||||
respBody, err = common.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: TaskModel2Dto(originTask),
|
||||
})
|
||||
@@ -499,16 +409,150 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
||||
return
|
||||
}
|
||||
|
||||
// tryRealtimeFetch 尝试从上游实时拉取 Gemini/Vertex 任务状态。
|
||||
// 仅当渠道类型为 Gemini 或 Vertex 时触发;其他渠道或出错时返回 nil。
|
||||
// 当非 OpenAI Video API 时,还会构建自定义格式的响应体。
|
||||
func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte {
|
||||
channelModel, err := model.GetChannelById(task.ChannelId, true)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
|
||||
return nil
|
||||
}
|
||||
|
||||
baseURL := constant.ChannelBaseURLs[channelModel.Type]
|
||||
if channelModel.GetBaseURL() != "" {
|
||||
baseURL = channelModel.GetBaseURL()
|
||||
}
|
||||
proxy := channelModel.GetSetting().Proxy
|
||||
adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
|
||||
if adaptor == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
|
||||
"task_id": task.GetUpstreamTaskID(),
|
||||
"action": task.Action,
|
||||
}, proxy)
|
||||
if err != nil || resp == nil {
|
||||
return nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ti, err := adaptor.ParseTaskResult(body)
|
||||
if err != nil || ti == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
snap := task.Snapshot()
|
||||
|
||||
// 将上游最新状态更新到 task
|
||||
if ti.Status != "" {
|
||||
task.Status = model.TaskStatus(ti.Status)
|
||||
}
|
||||
if ti.Progress != "" {
|
||||
task.Progress = ti.Progress
|
||||
}
|
||||
if strings.HasPrefix(ti.Url, "data:") {
|
||||
// data: URI — kept in Data, not ResultURL
|
||||
} else if ti.Url != "" {
|
||||
task.PrivateData.ResultURL = ti.Url
|
||||
} else if task.Status == model.TaskStatusSuccess {
|
||||
// No URL from adaptor — construct proxy URL using public task ID
|
||||
task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
|
||||
}
|
||||
|
||||
if !snap.Equal(task.Snapshot()) {
|
||||
_, _ = task.UpdateWithStatus(snap.Status)
|
||||
}
|
||||
|
||||
// OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理
|
||||
if isOpenAIVideoAPI {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 非 OpenAI Video API: 构建自定义格式响应
|
||||
format := detectVideoFormat(body)
|
||||
out := map[string]any{
|
||||
"error": nil,
|
||||
"format": format,
|
||||
"metadata": nil,
|
||||
"status": mapTaskStatusToSimple(task.Status),
|
||||
"task_id": task.TaskID,
|
||||
"url": task.GetResultURL(),
|
||||
}
|
||||
respBody, _ := common.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: out,
|
||||
})
|
||||
return respBody
|
||||
}
|
||||
|
||||
// detectVideoFormat 从 Gemini/Vertex 原始响应中探测视频格式
|
||||
func detectVideoFormat(rawBody []byte) string {
|
||||
var raw map[string]any
|
||||
if err := common.Unmarshal(rawBody, &raw); err != nil {
|
||||
return "mp4"
|
||||
}
|
||||
respObj, ok := raw["response"].(map[string]any)
|
||||
if !ok {
|
||||
return "mp4"
|
||||
}
|
||||
vids, ok := respObj["videos"].([]any)
|
||||
if !ok || len(vids) == 0 {
|
||||
return "mp4"
|
||||
}
|
||||
v0, ok := vids[0].(map[string]any)
|
||||
if !ok {
|
||||
return "mp4"
|
||||
}
|
||||
mt, ok := v0["mimeType"].(string)
|
||||
if !ok || mt == "" || strings.Contains(mt, "mp4") {
|
||||
return "mp4"
|
||||
}
|
||||
return mt
|
||||
}
|
||||
|
||||
// mapTaskStatusToSimple 将内部 TaskStatus 映射为简化状态字符串
|
||||
func mapTaskStatusToSimple(status model.TaskStatus) string {
|
||||
switch status {
|
||||
case model.TaskStatusSuccess:
|
||||
return "succeeded"
|
||||
case model.TaskStatusFailure:
|
||||
return "failed"
|
||||
case model.TaskStatusQueued, model.TaskStatusSubmitted:
|
||||
return "queued"
|
||||
default:
|
||||
return "processing"
|
||||
}
|
||||
}
|
||||
|
||||
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
|
||||
return &dto.TaskDto{
|
||||
ID: task.ID,
|
||||
CreatedAt: task.CreatedAt,
|
||||
UpdatedAt: task.UpdatedAt,
|
||||
TaskID: task.TaskID,
|
||||
Platform: string(task.Platform),
|
||||
UserId: task.UserId,
|
||||
Group: task.Group,
|
||||
ChannelId: task.ChannelId,
|
||||
Quota: task.Quota,
|
||||
Action: task.Action,
|
||||
Status: string(task.Status),
|
||||
FailReason: task.FailReason,
|
||||
ResultURL: task.GetResultURL(),
|
||||
SubmitTime: task.SubmitTime,
|
||||
StartTime: task.StartTime,
|
||||
FinishTime: task.FinishTime,
|
||||
Progress: task.Progress,
|
||||
Properties: task.Properties,
|
||||
Username: task.Username,
|
||||
Data: task.Data,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,7 +89,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
}
|
||||
|
||||
// remove disabled fields for OpenAI Responses API
|
||||
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
|
||||
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user