From e5d47daf26c2472c123faf8549331ac1338d69f5 Mon Sep 17 00:00:00 2001 From: feitianbubu Date: Mon, 9 Feb 2026 15:03:41 +0800 Subject: [PATCH 01/41] feat: allow custom username for new users --- controller/oauth.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/controller/oauth.go b/controller/oauth.go index 65e18f9da..faa22dd4f 100644 --- a/controller/oauth.go +++ b/controller/oauth.go @@ -237,6 +237,13 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o // Set up new user user.Username = provider.GetProviderPrefix() + strconv.Itoa(model.GetMaxUserId()+1) + + if oauthUser.Username != "" { + if exists, err := model.CheckUserExistOrDeleted(oauthUser.Username, ""); err == nil && !exists { + user.Username = oauthUser.Username + } + } + if oauthUser.DisplayName != "" { user.DisplayName = oauthUser.DisplayName } else if oauthUser.Username != "" { From 4360393dc10b2f1cac279aa2a00cc14d3825da4d Mon Sep 17 00:00:00 2001 From: Seefs Date: Tue, 17 Feb 2026 15:45:14 +0800 Subject: [PATCH 02/41] fix: unify usage mapping and include toolUsePromptTokenCount in input tokens --- dto/gemini.go | 14 +- relay/channel/gemini/relay-gemini-native.go | 17 +- relay/channel/gemini/relay-gemini.go | 93 +++-- .../channel/gemini/relay_gemini_usage_test.go | 333 ++++++++++++++++++ 4 files changed, 386 insertions(+), 71 deletions(-) create mode 100644 relay/channel/gemini/relay_gemini_usage_test.go diff --git a/dto/gemini.go b/dto/gemini.go index 0fd74c639..c963960e5 100644 --- a/dto/gemini.go +++ b/dto/gemini.go @@ -453,12 +453,14 @@ type GeminiChatResponse struct { } type GeminiUsageMetadata struct { - PromptTokenCount int `json:"promptTokenCount"` - CandidatesTokenCount int `json:"candidatesTokenCount"` - TotalTokenCount int `json:"totalTokenCount"` - ThoughtsTokenCount int `json:"thoughtsTokenCount"` - CachedContentTokenCount int `json:"cachedContentTokenCount"` - PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"` + PromptTokenCount int `json:"promptTokenCount"` + ToolUsePromptTokenCount int `json:"toolUsePromptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + ThoughtsTokenCount int `json:"thoughtsTokenCount"` + CachedContentTokenCount int `json:"cachedContentTokenCount"` + PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"` + ToolUsePromptTokensDetails []GeminiPromptTokensDetails `json:"toolUsePromptTokensDetails"` } type GeminiPromptTokensDetails struct { diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 39485b16f..1a434a432 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -42,22 +42,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re } // 计算使用量(基于 UsageMetadata) - usage := dto.Usage{ - PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount, - CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount, - TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount, - } - - usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount - usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount - - for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails { - if detail.Modality == "AUDIO" { - usage.PromptTokensDetails.AudioTokens = detail.TokenCount - } else if detail.Modality == "TEXT" { - usage.PromptTokensDetails.TextTokens = detail.TokenCount - } - } + usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens()) service.IOCopyBytesGracefully(c, resp, responseBody) diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index b10ec06c7..b81a148a3 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -1032,6 +1032,46 @@ func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse { } } +func buildUsageFromGeminiMetadata(metadata dto.GeminiUsageMetadata, fallbackPromptTokens int) dto.Usage { + promptTokens := metadata.PromptTokenCount + metadata.ToolUsePromptTokenCount + if promptTokens <= 0 && fallbackPromptTokens > 0 { + promptTokens = fallbackPromptTokens + } + + usage := dto.Usage{ + PromptTokens: promptTokens, + CompletionTokens: metadata.CandidatesTokenCount + metadata.ThoughtsTokenCount, + TotalTokens: metadata.TotalTokenCount, + } + usage.CompletionTokenDetails.ReasoningTokens = metadata.ThoughtsTokenCount + usage.PromptTokensDetails.CachedTokens = metadata.CachedContentTokenCount + + for _, detail := range metadata.PromptTokensDetails { + if detail.Modality == "AUDIO" { + usage.PromptTokensDetails.AudioTokens += detail.TokenCount + } else if detail.Modality == "TEXT" { + usage.PromptTokensDetails.TextTokens += detail.TokenCount + } + } + for _, detail := range metadata.ToolUsePromptTokensDetails { + if detail.Modality == "AUDIO" { + usage.PromptTokensDetails.AudioTokens += detail.TokenCount + } else if detail.Modality == "TEXT" { + usage.PromptTokensDetails.TextTokens += detail.TokenCount + } + } + + if usage.TotalTokens > 0 && usage.CompletionTokens <= 0 { + usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens + } + + if usage.PromptTokens > 0 && usage.PromptTokensDetails.TextTokens == 0 && usage.PromptTokensDetails.AudioTokens == 0 { + usage.PromptTokensDetails.TextTokens = usage.PromptTokens + } + + return usage +} + func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse { fullTextResponse := dto.OpenAITextResponse{ Id: helper.GetResponseID(c), @@ -1272,18 +1312,8 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http // 更新使用量统计 if geminiResponse.UsageMetadata.TotalTokenCount != 0 { - usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount - usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount - usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount - usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount - usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount - for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails { - if detail.Modality == "AUDIO" { - usage.PromptTokensDetails.AudioTokens = detail.TokenCount - } else if detail.Modality == "TEXT" { - usage.PromptTokensDetails.TextTokens = detail.TokenCount - } - } + mappedUsage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens()) + *usage = mappedUsage } return callback(data, &geminiResponse) @@ -1295,11 +1325,6 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http } } - usage.PromptTokensDetails.TextTokens = usage.PromptTokens - if usage.TotalTokens > 0 { - usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens - } - if usage.CompletionTokens <= 0 { if info.ReceivedResponseCount > 0 { usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens()) @@ -1416,21 +1441,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if len(geminiResponse.Candidates) == 0 { - usage := dto.Usage{ - PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount, - } - usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount - usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount - for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails { - if detail.Modality == "AUDIO" { - usage.PromptTokensDetails.AudioTokens = detail.TokenCount - } else if detail.Modality == "TEXT" { - usage.PromptTokensDetails.TextTokens = detail.TokenCount - } - } - if usage.PromptTokens <= 0 { - usage.PromptTokens = info.GetEstimatePromptTokens() - } + usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens()) var newAPIError *types.NewAPIError if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil { @@ -1466,23 +1477,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R } fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse) fullTextResponse.Model = info.UpstreamModelName - usage := dto.Usage{ - PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount, - CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount, - TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount, - } - - usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount - usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount - usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens - - for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails { - if detail.Modality == "AUDIO" { - usage.PromptTokensDetails.AudioTokens = detail.TokenCount - } else if detail.Modality == "TEXT" { - usage.PromptTokensDetails.TextTokens = detail.TokenCount - } - } + usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens()) fullTextResponse.Usage = usage diff --git a/relay/channel/gemini/relay_gemini_usage_test.go b/relay/channel/gemini/relay_gemini_usage_test.go new file mode 100644 index 000000000..c8f9f8343 --- /dev/null +++ b/relay/channel/gemini/relay_gemini_usage_test.go @@ -0,0 +1,333 @@ +package gemini + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/dto" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/types" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestGeminiChatHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + info := &relaycommon.RelayInfo{ + RelayFormat: types.RelayFormatGemini, + OriginModelName: "gemini-3-flash-preview", + ChannelMeta: &relaycommon.ChannelMeta{ + UpstreamModelName: "gemini-3-flash-preview", + }, + } + + payload := dto.GeminiChatResponse{ + Candidates: []dto.GeminiChatCandidate{ + { + Content: dto.GeminiChatContent{ + Role: "model", + Parts: []dto.GeminiPart{ + {Text: "ok"}, + }, + }, + }, + }, + UsageMetadata: dto.GeminiUsageMetadata{ + PromptTokenCount: 151, + ToolUsePromptTokenCount: 18329, + CandidatesTokenCount: 1089, + ThoughtsTokenCount: 1120, + TotalTokenCount: 20689, + }, + } + + body, err := common.Marshal(payload) + require.NoError(t, err) + + resp := &http.Response{ + Body: io.NopCloser(bytes.NewReader(body)), + } + + usage, newAPIError := GeminiChatHandler(c, info, resp) + require.Nil(t, newAPIError) + require.NotNil(t, usage) + require.Equal(t, 18480, usage.PromptTokens) + require.Equal(t, 2209, usage.CompletionTokens) + require.Equal(t, 20689, usage.TotalTokens) + require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens) +} + +func TestGeminiStreamHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) { + gin.SetMode(gin.TestMode) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + oldStreamingTimeout := constant.StreamingTimeout + constant.StreamingTimeout = 300 + t.Cleanup(func() { + constant.StreamingTimeout = oldStreamingTimeout + }) + + info := &relaycommon.RelayInfo{ + OriginModelName: "gemini-3-flash-preview", + ChannelMeta: &relaycommon.ChannelMeta{ + UpstreamModelName: "gemini-3-flash-preview", + }, + } + + chunk := dto.GeminiChatResponse{ + Candidates: []dto.GeminiChatCandidate{ + { + Content: dto.GeminiChatContent{ + Role: "model", + Parts: []dto.GeminiPart{ + {Text: "partial"}, + }, + }, + }, + }, + UsageMetadata: dto.GeminiUsageMetadata{ + PromptTokenCount: 151, + ToolUsePromptTokenCount: 18329, + CandidatesTokenCount: 1089, + ThoughtsTokenCount: 1120, + TotalTokenCount: 20689, + }, + } + + chunkData, err := common.Marshal(chunk) + require.NoError(t, err) + + streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n") + resp := &http.Response{ + Body: io.NopCloser(bytes.NewReader(streamBody)), + } + + usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool { + return true + }) + require.Nil(t, newAPIError) + require.NotNil(t, usage) + require.Equal(t, 18480, usage.PromptTokens) + require.Equal(t, 2209, usage.CompletionTokens) + require.Equal(t, 20689, usage.TotalTokens) + require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens) +} + +func TestGeminiTextGenerationHandlerPromptTokensIncludeToolUsePromptTokens(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil) + + info := &relaycommon.RelayInfo{ + OriginModelName: "gemini-3-flash-preview", + ChannelMeta: &relaycommon.ChannelMeta{ + UpstreamModelName: "gemini-3-flash-preview", + }, + } + + payload := dto.GeminiChatResponse{ + Candidates: []dto.GeminiChatCandidate{ + { + Content: dto.GeminiChatContent{ + Role: "model", + Parts: []dto.GeminiPart{ + {Text: "ok"}, + }, + }, + }, + }, + UsageMetadata: dto.GeminiUsageMetadata{ + PromptTokenCount: 151, + ToolUsePromptTokenCount: 18329, + CandidatesTokenCount: 1089, + ThoughtsTokenCount: 1120, + TotalTokenCount: 20689, + }, + } + + body, err := common.Marshal(payload) + require.NoError(t, err) + + resp := &http.Response{ + Body: io.NopCloser(bytes.NewReader(body)), + } + + usage, newAPIError := GeminiTextGenerationHandler(c, info, resp) + require.Nil(t, newAPIError) + require.NotNil(t, usage) + require.Equal(t, 18480, usage.PromptTokens) + require.Equal(t, 2209, usage.CompletionTokens) + require.Equal(t, 20689, usage.TotalTokens) + require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens) +} + +func TestGeminiChatHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + info := &relaycommon.RelayInfo{ + RelayFormat: types.RelayFormatGemini, + OriginModelName: "gemini-3-flash-preview", + ChannelMeta: &relaycommon.ChannelMeta{ + UpstreamModelName: "gemini-3-flash-preview", + }, + } + info.SetEstimatePromptTokens(20) + + payload := dto.GeminiChatResponse{ + Candidates: []dto.GeminiChatCandidate{ + { + Content: dto.GeminiChatContent{ + Role: "model", + Parts: []dto.GeminiPart{ + {Text: "ok"}, + }, + }, + }, + }, + UsageMetadata: dto.GeminiUsageMetadata{ + PromptTokenCount: 0, + ToolUsePromptTokenCount: 0, + CandidatesTokenCount: 90, + ThoughtsTokenCount: 10, + TotalTokenCount: 110, + }, + } + + body, err := common.Marshal(payload) + require.NoError(t, err) + + resp := &http.Response{ + Body: io.NopCloser(bytes.NewReader(body)), + } + + usage, newAPIError := GeminiChatHandler(c, info, resp) + require.Nil(t, newAPIError) + require.NotNil(t, usage) + require.Equal(t, 20, usage.PromptTokens) + require.Equal(t, 100, usage.CompletionTokens) + require.Equal(t, 110, usage.TotalTokens) +} + +func TestGeminiStreamHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) { + gin.SetMode(gin.TestMode) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + oldStreamingTimeout := constant.StreamingTimeout + constant.StreamingTimeout = 300 + t.Cleanup(func() { + constant.StreamingTimeout = oldStreamingTimeout + }) + + info := &relaycommon.RelayInfo{ + OriginModelName: "gemini-3-flash-preview", + ChannelMeta: &relaycommon.ChannelMeta{ + UpstreamModelName: "gemini-3-flash-preview", + }, + } + info.SetEstimatePromptTokens(20) + + chunk := dto.GeminiChatResponse{ + Candidates: []dto.GeminiChatCandidate{ + { + Content: dto.GeminiChatContent{ + Role: "model", + Parts: []dto.GeminiPart{ + {Text: "partial"}, + }, + }, + }, + }, + UsageMetadata: dto.GeminiUsageMetadata{ + PromptTokenCount: 0, + ToolUsePromptTokenCount: 0, + CandidatesTokenCount: 90, + ThoughtsTokenCount: 10, + TotalTokenCount: 110, + }, + } + + chunkData, err := common.Marshal(chunk) + require.NoError(t, err) + + streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n") + resp := &http.Response{ + Body: io.NopCloser(bytes.NewReader(streamBody)), + } + + usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool { + return true + }) + require.Nil(t, newAPIError) + require.NotNil(t, usage) + require.Equal(t, 20, usage.PromptTokens) + require.Equal(t, 100, usage.CompletionTokens) + require.Equal(t, 110, usage.TotalTokens) +} + +func TestGeminiTextGenerationHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil) + + info := &relaycommon.RelayInfo{ + OriginModelName: "gemini-3-flash-preview", + ChannelMeta: &relaycommon.ChannelMeta{ + UpstreamModelName: "gemini-3-flash-preview", + }, + } + info.SetEstimatePromptTokens(20) + + payload := dto.GeminiChatResponse{ + Candidates: []dto.GeminiChatCandidate{ + { + Content: dto.GeminiChatContent{ + Role: "model", + Parts: []dto.GeminiPart{ + {Text: "ok"}, + }, + }, + }, + }, + UsageMetadata: dto.GeminiUsageMetadata{ + PromptTokenCount: 0, + ToolUsePromptTokenCount: 0, + CandidatesTokenCount: 90, + ThoughtsTokenCount: 10, + TotalTokenCount: 110, + }, + } + + body, err := common.Marshal(payload) + require.NoError(t, err) + + resp := &http.Response{ + Body: io.NopCloser(bytes.NewReader(body)), + } + + usage, newAPIError := GeminiTextGenerationHandler(c, info, resp) + require.Nil(t, newAPIError) + require.NotNil(t, usage) + require.Equal(t, 20, usage.PromptTokens) + require.Equal(t, 100, usage.CompletionTokens) + require.Equal(t, 110, usage.TotalTokens) +} From 721d0a41fb7a95bcd25c95f94455338c34fdf3c0 Mon Sep 17 00:00:00 2001 From: Seefs Date: Tue, 17 Feb 2026 17:27:57 +0800 Subject: [PATCH 03/41] feat: minimax native /v1/messages --- relay/channel/minimax/adaptor.go | 14 +++++++++++--- relay/channel/minimax/relay-minimax.go | 19 ++++++++++++------- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/relay/channel/minimax/adaptor.go b/relay/channel/minimax/adaptor.go index 8235abc05..d244e695a 100644 --- a/relay/channel/minimax/adaptor.go +++ b/relay/channel/minimax/adaptor.go @@ -10,6 +10,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + "github.com/QuantumNous/new-api/relay/channel/claude" "github.com/QuantumNous/new-api/relay/channel/openai" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/relay/constant" @@ -26,7 +27,8 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt } func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { - return nil, errors.New("not implemented") + adaptor := claude.Adaptor{} + return adaptor.ConvertClaudeRequest(c, info, req) } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -119,8 +121,14 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom return handleTTSResponse(c, resp, info) } - adaptor := openai.Adaptor{} - return adaptor.DoResponse(c, resp, info) + switch info.RelayFormat { + case types.RelayFormatClaude: + adaptor := claude.Adaptor{} + return adaptor.DoResponse(c, resp, info) + default: + adaptor := openai.Adaptor{} + return adaptor.DoResponse(c, resp, info) + } } func (a *Adaptor) GetModelList() []string { diff --git a/relay/channel/minimax/relay-minimax.go b/relay/channel/minimax/relay-minimax.go index b314e69d7..c249de6a4 100644 --- a/relay/channel/minimax/relay-minimax.go +++ b/relay/channel/minimax/relay-minimax.go @@ -6,6 +6,7 @@ import ( channelconstant "github.com/QuantumNous/new-api/constant" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/relay/constant" + "github.com/QuantumNous/new-api/types" ) func GetRequestURL(info *relaycommon.RelayInfo) (string, error) { @@ -13,13 +14,17 @@ func GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if baseUrl == "" { baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeMiniMax] } - - switch info.RelayMode { - case constant.RelayModeChatCompletions: - return fmt.Sprintf("%s/v1/text/chatcompletion_v2", baseUrl), nil - case constant.RelayModeAudioSpeech: - return fmt.Sprintf("%s/v1/t2a_v2", baseUrl), nil + switch info.RelayFormat { + case types.RelayFormatClaude: + return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil default: - return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) + switch info.RelayMode { + case constant.RelayModeChatCompletions: + return fmt.Sprintf("%s/v1/text/chatcompletion_v2", baseUrl), nil + case constant.RelayModeAudioSpeech: + return fmt.Sprintf("%s/v1/t2a_v2", baseUrl), nil + default: + return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) + } } } From 20c9002fdecbd9829b8ec8aaadd4531a47a25a23 Mon Sep 17 00:00:00 2001 From: Seefs Date: Tue, 17 Feb 2026 18:00:10 +0800 Subject: [PATCH 04/41] feat: codex oauth proxy --- controller/codex_oauth.go | 4 +++- controller/codex_usage.go | 5 ++-- service/codex_credential_refresh.go | 2 +- service/codex_oauth.go | 37 +++++++++++++++++++++++++---- 4 files changed, 39 insertions(+), 9 deletions(-) diff --git a/controller/codex_oauth.go b/controller/codex_oauth.go index 3071413c6..de9743ab7 100644 --- a/controller/codex_oauth.go +++ b/controller/codex_oauth.go @@ -145,6 +145,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) { return } + channelProxy := "" if channelID > 0 { ch, err := model.GetChannelById(channelID, false) if err != nil { @@ -159,6 +160,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) { c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"}) return } + channelProxy = ch.GetSetting().Proxy } session := sessions.Default(c) @@ -176,7 +178,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) { ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second) defer cancel() - tokenRes, err := service.ExchangeCodexAuthorizationCode(ctx, code, verifier) + tokenRes, err := service.ExchangeCodexAuthorizationCodeWithProxy(ctx, code, verifier, channelProxy) if err != nil { common.SysError("failed to exchange codex authorization code: " + err.Error()) c.JSON(http.StatusOK, gin.H{"success": false, "message": "授权码交换失败,请重试"}) diff --git a/controller/codex_usage.go b/controller/codex_usage.go index 62b7a754f..52fdbdf6f 100644 --- a/controller/codex_usage.go +++ b/controller/codex_usage.go @@ -2,7 +2,6 @@ package controller import ( "context" - "encoding/json" "fmt" "net/http" "strconv" @@ -80,7 +79,7 @@ func GetCodexChannelUsage(c *gin.Context) { refreshCtx, refreshCancel := context.WithTimeout(c.Request.Context(), 10*time.Second) defer refreshCancel() - res, refreshErr := service.RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken) + res, refreshErr := service.RefreshCodexOAuthTokenWithProxy(refreshCtx, oauthKey.RefreshToken, ch.GetSetting().Proxy) if refreshErr == nil { oauthKey.AccessToken = res.AccessToken oauthKey.RefreshToken = res.RefreshToken @@ -109,7 +108,7 @@ func GetCodexChannelUsage(c *gin.Context) { } var payload any - if json.Unmarshal(body, &payload) != nil { + if common.Unmarshal(body, &payload) != nil { payload = string(body) } diff --git a/service/codex_credential_refresh.go b/service/codex_credential_refresh.go index 0290fe516..2e681ee61 100644 --- a/service/codex_credential_refresh.go +++ b/service/codex_credential_refresh.go @@ -62,7 +62,7 @@ func RefreshCodexChannelCredential(ctx context.Context, channelID int, opts Code refreshCtx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - res, err := RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken) + res, err := RefreshCodexOAuthTokenWithProxy(refreshCtx, oauthKey.RefreshToken, ch.GetSetting().Proxy) if err != nil { return nil, nil, err } diff --git a/service/codex_oauth.go b/service/codex_oauth.go index 4c2dce1cc..33ef1d60a 100644 --- a/service/codex_oauth.go +++ b/service/codex_oauth.go @@ -12,6 +12,8 @@ import ( "net/url" "strings" "time" + + "github.com/QuantumNous/new-api/common" ) const ( @@ -38,12 +40,26 @@ type CodexOAuthAuthorizationFlow struct { } func RefreshCodexOAuthToken(ctx context.Context, refreshToken string) (*CodexOAuthTokenResult, error) { - client := &http.Client{Timeout: defaultHTTPTimeout} + return RefreshCodexOAuthTokenWithProxy(ctx, refreshToken, "") +} + +func RefreshCodexOAuthTokenWithProxy(ctx context.Context, refreshToken string, proxyURL string) (*CodexOAuthTokenResult, error) { + client, err := getCodexOAuthHTTPClient(proxyURL) + if err != nil { + return nil, err + } return refreshCodexOAuthToken(ctx, client, codexOAuthTokenURL, codexOAuthClientID, refreshToken) } func ExchangeCodexAuthorizationCode(ctx context.Context, code string, verifier string) (*CodexOAuthTokenResult, error) { - client := &http.Client{Timeout: defaultHTTPTimeout} + return ExchangeCodexAuthorizationCodeWithProxy(ctx, code, verifier, "") +} + +func ExchangeCodexAuthorizationCodeWithProxy(ctx context.Context, code string, verifier string, proxyURL string) (*CodexOAuthTokenResult, error) { + client, err := getCodexOAuthHTTPClient(proxyURL) + if err != nil { + return nil, err + } return exchangeCodexAuthorizationCode(ctx, client, codexOAuthTokenURL, codexOAuthClientID, code, verifier, codexOAuthRedirectURI) } @@ -104,7 +120,7 @@ func refreshCodexOAuthToken( ExpiresIn int `json:"expires_in"` } - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + if err := common.DecodeJson(resp.Body, &payload); err != nil { return nil, err } if resp.StatusCode < 200 || resp.StatusCode >= 300 { @@ -165,7 +181,7 @@ func exchangeCodexAuthorizationCode( RefreshToken string `json:"refresh_token"` ExpiresIn int `json:"expires_in"` } - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + if err := common.DecodeJson(resp.Body, &payload); err != nil { return nil, err } if resp.StatusCode < 200 || resp.StatusCode >= 300 { @@ -181,6 +197,19 @@ func exchangeCodexAuthorizationCode( }, nil } +func getCodexOAuthHTTPClient(proxyURL string) (*http.Client, error) { + baseClient, err := GetHttpClientWithProxy(strings.TrimSpace(proxyURL)) + if err != nil { + return nil, err + } + if baseClient == nil { + return &http.Client{Timeout: defaultHTTPTimeout}, nil + } + clientCopy := *baseClient + clientCopy.Timeout = defaultHTTPTimeout + return &clientCopy, nil +} + func buildCodexAuthorizeURL(state string, challenge string) (string, error) { u, err := url.Parse(codexOAuthAuthorizeURL) if err != nil { From 6004314c8800a4659eae294712b49af35d93481c Mon Sep 17 00:00:00 2001 From: Seefs Date: Thu, 19 Feb 2026 14:16:07 +0800 Subject: [PATCH 05/41] feat: add missing OpenAI/Claude/Gemini request fields and responses stream options --- dto/channel_settings.go | 17 +++-- dto/claude.go | 10 ++- dto/gemini.go | 75 ++++++++++--------- dto/openai_request.go | 17 ++++- relay/common/relay_info.go | 17 +++++ .../channels/modals/EditChannelModal.jsx | 28 ++++++- 6 files changed, 113 insertions(+), 51 deletions(-) diff --git a/dto/channel_settings.go b/dto/channel_settings.go index 74bceb281..58c15db0b 100644 --- a/dto/channel_settings.go +++ b/dto/channel_settings.go @@ -24,14 +24,15 @@ const ( ) type ChannelOtherSettings struct { - AzureResponsesVersion string `json:"azure_responses_version,omitempty"` - VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key" - OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"` - ClaudeBetaQuery bool `json:"claude_beta_query,omitempty"` // Claude 渠道是否强制追加 ?beta=true - AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费) - DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用) - AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私) - AwsKeyType AwsKeyType `json:"aws_key_type,omitempty"` + AzureResponsesVersion string `json:"azure_responses_version,omitempty"` + VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key" + OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"` + ClaudeBetaQuery bool `json:"claude_beta_query,omitempty"` // Claude 渠道是否强制追加 ?beta=true + AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费) + DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用) + AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私) + AllowIncludeObfuscation bool `json:"allow_include_obfuscation,omitempty"` // 是否允许 stream_options.include_obfuscation 透传(默认过滤以避免关闭流混淆保护) + AwsKeyType AwsKeyType `json:"aws_key_type,omitempty"` } func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool { diff --git a/dto/claude.go b/dto/claude.go index 8b6b495f6..bad3835fa 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -190,10 +190,12 @@ type ClaudeToolChoice struct { } type ClaudeRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt,omitempty"` - System any `json:"system,omitempty"` - Messages []ClaudeMessage `json:"messages,omitempty"` + Model string `json:"model"` + Prompt string `json:"prompt,omitempty"` + System any `json:"system,omitempty"` + Messages []ClaudeMessage `json:"messages,omitempty"` + // https://platform.claude.com/docs/en/build-with-claude/data-residency#inference-geo + // InferenceGeo string `json:"inference_geo,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"` MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` diff --git a/dto/gemini.go b/dto/gemini.go index 0fd74c639..51490fc4c 100644 --- a/dto/gemini.go +++ b/dto/gemini.go @@ -324,25 +324,26 @@ type GeminiChatTool struct { } type GeminiChatGenerationConfig struct { - Temperature *float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK float64 `json:"topK,omitempty"` - MaxOutputTokens uint `json:"maxOutputTokens,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - StopSequences []string `json:"stopSequences,omitempty"` - ResponseMimeType string `json:"responseMimeType,omitempty"` - ResponseSchema any `json:"responseSchema,omitempty"` - ResponseJsonSchema json.RawMessage `json:"responseJsonSchema,omitempty"` - PresencePenalty *float32 `json:"presencePenalty,omitempty"` - FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"` - ResponseLogprobs bool `json:"responseLogprobs,omitempty"` - Logprobs *int32 `json:"logprobs,omitempty"` - MediaResolution MediaResolution `json:"mediaResolution,omitempty"` - Seed int64 `json:"seed,omitempty"` - ResponseModalities []string `json:"responseModalities,omitempty"` - ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"` - SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config - ImageConfig json.RawMessage `json:"imageConfig,omitempty"` // RawMessage to allow flexible image config + Temperature *float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK float64 `json:"topK,omitempty"` + MaxOutputTokens uint `json:"maxOutputTokens,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` + ResponseMimeType string `json:"responseMimeType,omitempty"` + ResponseSchema any `json:"responseSchema,omitempty"` + ResponseJsonSchema json.RawMessage `json:"responseJsonSchema,omitempty"` + PresencePenalty *float32 `json:"presencePenalty,omitempty"` + FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"` + ResponseLogprobs bool `json:"responseLogprobs,omitempty"` + Logprobs *int32 `json:"logprobs,omitempty"` + EnableEnhancedCivicAnswers *bool `json:"enableEnhancedCivicAnswers,omitempty"` + MediaResolution MediaResolution `json:"mediaResolution,omitempty"` + Seed int64 `json:"seed,omitempty"` + ResponseModalities []string `json:"responseModalities,omitempty"` + ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"` + SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config + ImageConfig json.RawMessage `json:"imageConfig,omitempty"` // RawMessage to allow flexible image config } // UnmarshalJSON allows GeminiChatGenerationConfig to accept both snake_case and camelCase fields. @@ -350,22 +351,23 @@ func (c *GeminiChatGenerationConfig) UnmarshalJSON(data []byte) error { type Alias GeminiChatGenerationConfig var aux struct { Alias - TopPSnake float64 `json:"top_p,omitempty"` - TopKSnake float64 `json:"top_k,omitempty"` - MaxOutputTokensSnake uint `json:"max_output_tokens,omitempty"` - CandidateCountSnake int `json:"candidate_count,omitempty"` - StopSequencesSnake []string `json:"stop_sequences,omitempty"` - ResponseMimeTypeSnake string `json:"response_mime_type,omitempty"` - ResponseSchemaSnake any `json:"response_schema,omitempty"` - ResponseJsonSchemaSnake json.RawMessage `json:"response_json_schema,omitempty"` - PresencePenaltySnake *float32 `json:"presence_penalty,omitempty"` - FrequencyPenaltySnake *float32 `json:"frequency_penalty,omitempty"` - ResponseLogprobsSnake bool `json:"response_logprobs,omitempty"` - MediaResolutionSnake MediaResolution `json:"media_resolution,omitempty"` - ResponseModalitiesSnake []string `json:"response_modalities,omitempty"` - ThinkingConfigSnake *GeminiThinkingConfig `json:"thinking_config,omitempty"` - SpeechConfigSnake json.RawMessage `json:"speech_config,omitempty"` - ImageConfigSnake json.RawMessage `json:"image_config,omitempty"` + TopPSnake float64 `json:"top_p,omitempty"` + TopKSnake float64 `json:"top_k,omitempty"` + MaxOutputTokensSnake uint `json:"max_output_tokens,omitempty"` + CandidateCountSnake int `json:"candidate_count,omitempty"` + StopSequencesSnake []string `json:"stop_sequences,omitempty"` + ResponseMimeTypeSnake string `json:"response_mime_type,omitempty"` + ResponseSchemaSnake any `json:"response_schema,omitempty"` + ResponseJsonSchemaSnake json.RawMessage `json:"response_json_schema,omitempty"` + PresencePenaltySnake *float32 `json:"presence_penalty,omitempty"` + FrequencyPenaltySnake *float32 `json:"frequency_penalty,omitempty"` + ResponseLogprobsSnake bool `json:"response_logprobs,omitempty"` + EnableEnhancedCivicAnswersSnake *bool `json:"enable_enhanced_civic_answers,omitempty"` + MediaResolutionSnake MediaResolution `json:"media_resolution,omitempty"` + ResponseModalitiesSnake []string `json:"response_modalities,omitempty"` + ThinkingConfigSnake *GeminiThinkingConfig `json:"thinking_config,omitempty"` + SpeechConfigSnake json.RawMessage `json:"speech_config,omitempty"` + ImageConfigSnake json.RawMessage `json:"image_config,omitempty"` } if err := common.Unmarshal(data, &aux); err != nil { @@ -408,6 +410,9 @@ func (c *GeminiChatGenerationConfig) UnmarshalJSON(data []byte) error { if aux.ResponseLogprobsSnake { c.ResponseLogprobs = aux.ResponseLogprobsSnake } + if aux.EnableEnhancedCivicAnswersSnake != nil { + c.EnableEnhancedCivicAnswers = aux.EnableEnhancedCivicAnswersSnake + } if aux.MediaResolutionSnake != "" { c.MediaResolution = aux.MediaResolutionSnake } diff --git a/dto/openai_request.go b/dto/openai_request.go index 9113a086e..0b261a61e 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -54,7 +54,9 @@ type GeneralOpenAIRequest struct { ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"` Tools []ToolCallRequest `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` + FunctionCall json.RawMessage `json:"function_call,omitempty"` User string `json:"user,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` LogProbs bool `json:"logprobs,omitempty"` TopLogProbs int `json:"top_logprobs,omitempty"` Dimensions int `json:"dimensions,omitempty"` @@ -261,6 +263,8 @@ type FunctionRequest struct { type StreamOptions struct { IncludeUsage bool `json:"include_usage,omitempty"` + // for /v1/responses + IncludeObfuscation bool `json:"include_obfuscation,omitempty"` } func (r *GeneralOpenAIRequest) GetMaxTokens() uint { @@ -799,11 +803,16 @@ type WebSearchOptions struct { // https://platform.openai.com/docs/api-reference/responses/create type OpenAIResponsesRequest struct { - Model string `json:"model"` - Input json.RawMessage `json:"input,omitempty"` - Include json.RawMessage `json:"include,omitempty"` + Model string `json:"model"` + Input json.RawMessage `json:"input,omitempty"` + Include json.RawMessage `json:"include,omitempty"` + // 在后台运行推理,暂时还不支持依赖的接口 + // Background json.RawMessage `json:"background,omitempty"` + Conversation json.RawMessage `json:"conversation,omitempty"` + ContextManagement json.RawMessage `json:"context_management,omitempty"` Instructions json.RawMessage `json:"instructions,omitempty"` MaxOutputTokens uint `json:"max_output_tokens,omitempty"` + TopLogProbs *int `json:"top_logprobs,omitempty"` Metadata json.RawMessage `json:"metadata,omitempty"` ParallelToolCalls json.RawMessage `json:"parallel_tool_calls,omitempty"` PreviousResponseID string `json:"previous_response_id,omitempty"` @@ -813,7 +822,9 @@ type OpenAIResponsesRequest struct { Store json.RawMessage `json:"store,omitempty"` PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"` PromptCacheRetention json.RawMessage `json:"prompt_cache_retention,omitempty"` + SafetyIdentifier string `json:"safety_identifier,omitempty"` Stream bool `json:"stream,omitempty"` + StreamOptions *StreamOptions `json:"stream_options,omitempty"` Temperature *float64 `json:"temperature,omitempty"` Text json.RawMessage `json:"text,omitempty"` ToolChoice json.RawMessage `json:"tool_choice,omitempty"` diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 81b7d21d6..25dac8cfc 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -702,6 +702,7 @@ func FailTaskInfo(reason string) *TaskInfo { // service_tier: 服务层级字段,可能导致额外计费(OpenAI、Claude、Responses API 支持) // store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用) // safety_identifier: 安全标识符,用于向 OpenAI 报告违规用户(仅 OpenAI 支持,涉及用户隐私) +// stream_options.include_obfuscation: 响应流混淆控制字段(仅 OpenAI Responses API 支持) func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings) ([]byte, error) { var data map[string]interface{} if err := common.Unmarshal(jsonData, &data); err != nil { @@ -730,6 +731,22 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther } } + // 默认移除 stream_options.include_obfuscation,除非明确允许(避免关闭响应流混淆保护) + if !channelOtherSettings.AllowIncludeObfuscation { + if streamOptionsAny, exists := data["stream_options"]; exists { + if streamOptions, ok := streamOptionsAny.(map[string]interface{}); ok { + if _, includeExists := streamOptions["include_obfuscation"]; includeExists { + delete(streamOptions, "include_obfuscation") + } + if len(streamOptions) == 0 { + delete(data, "stream_options") + } else { + data["stream_options"] = streamOptions + } + } + } + } + jsonDataAfter, err := common.Marshal(data) if err != nil { common.SysError("RemoveDisabledFields Marshal error :" + err.Error()) diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index 6e85ca982..a33c070c4 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -170,6 +170,7 @@ const EditChannelModal = (props) => { allow_service_tier: false, disable_store: false, // false = 允许透传(默认开启) allow_safety_identifier: false, + allow_include_obfuscation: false, claude_beta_query: false, }; const [batch, setBatch] = useState(false); @@ -634,6 +635,8 @@ const EditChannelModal = (props) => { data.disable_store = parsedSettings.disable_store || false; data.allow_safety_identifier = parsedSettings.allow_safety_identifier || false; + data.allow_include_obfuscation = + parsedSettings.allow_include_obfuscation || false; data.claude_beta_query = parsedSettings.claude_beta_query || false; } catch (error) { console.error('解析其他设置失败:', error); @@ -645,6 +648,7 @@ const EditChannelModal = (props) => { data.allow_service_tier = false; data.disable_store = false; data.allow_safety_identifier = false; + data.allow_include_obfuscation = false; data.claude_beta_query = false; } } else { @@ -655,6 +659,7 @@ const EditChannelModal = (props) => { data.allow_service_tier = false; data.disable_store = false; data.allow_safety_identifier = false; + data.allow_include_obfuscation = false; data.claude_beta_query = false; } @@ -1392,11 +1397,13 @@ const EditChannelModal = (props) => { // type === 1 (OpenAI) 或 type === 14 (Claude): 设置字段透传控制(显式保存布尔值) if (localInputs.type === 1 || localInputs.type === 14) { settings.allow_service_tier = localInputs.allow_service_tier === true; - // 仅 OpenAI 渠道需要 store 和 safety_identifier + // 仅 OpenAI 渠道需要 store / safety_identifier / include_obfuscation if (localInputs.type === 1) { settings.disable_store = localInputs.disable_store === true; settings.allow_safety_identifier = localInputs.allow_safety_identifier === true; + settings.allow_include_obfuscation = + localInputs.allow_include_obfuscation === true; } if (localInputs.type === 14) { settings.claude_beta_query = localInputs.claude_beta_query === true; @@ -1421,6 +1428,7 @@ const EditChannelModal = (props) => { delete localInputs.allow_service_tier; delete localInputs.disable_store; delete localInputs.allow_safety_identifier; + delete localInputs.allow_include_obfuscation; delete localInputs.claude_beta_query; let res; @@ -3271,6 +3279,24 @@ const EditChannelModal = (props) => { 'safety_identifier 字段用于帮助 OpenAI 识别可能违反使用政策的应用程序用户。默认关闭以保护用户隐私', )} /> + + + handleChannelOtherSettingsChange( + 'allow_include_obfuscation', + value, + ) + } + extraText={t( + 'include_obfuscation 用于控制 Responses 流混淆字段。默认关闭以避免客户端关闭该安全保护', + )} + /> )} From 1770a08504e3ae1d5337a0310a6a9a9912273db1 Mon Sep 17 00:00:00 2001 From: Seefs Date: Thu, 19 Feb 2026 15:09:13 +0800 Subject: [PATCH 06/41] fix: skip field filtering when request passthrough is enabled --- relay/chat_completions_via_responses.go | 4 +-- relay/claude_handler.go | 2 +- relay/common/override_test.go | 40 +++++++++++++++++++++++++ relay/common/relay_info.go | 6 +++- relay/compatible_handler.go | 2 +- relay/responses_handler.go | 2 +- 6 files changed, 50 insertions(+), 6 deletions(-) diff --git a/relay/chat_completions_via_responses.go b/relay/chat_completions_via_responses.go index 38dae3c56..6412b7d24 100644 --- a/relay/chat_completions_via_responses.go +++ b/relay/chat_completions_via_responses.go @@ -76,7 +76,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } - chatJSON, err = relaycommon.RemoveDisabledFields(chatJSON, info.ChannelOtherSettings) + chatJSON, err = relaycommon.RemoveDisabledFields(chatJSON, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled) if err != nil { return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -120,7 +120,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } - jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings) + jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled) if err != nil { return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 81adb276a..9b08781c8 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -146,7 +146,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ } // remove disabled fields for Claude API - jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings) + jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } diff --git a/relay/common/override_test.go b/relay/common/override_test.go index 021df3f60..4e8cd5cff 100644 --- a/relay/common/override_test.go +++ b/relay/common/override_test.go @@ -4,6 +4,9 @@ import ( "encoding/json" "reflect" "testing" + + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/setting/model_setting" ) func TestApplyParamOverrideTrimPrefix(t *testing.T) { @@ -772,6 +775,43 @@ func TestApplyParamOverrideToUpper(t *testing.T) { assertJSONEqual(t, `{"model":"GPT-4"}`, string(out)) } +func TestRemoveDisabledFieldsSkipWhenChannelPassThroughEnabled(t *testing.T) { + input := `{ + "service_tier":"flex", + "safety_identifier":"user-123", + "store":true, + "stream_options":{"include_obfuscation":false} + }` + settings := dto.ChannelOtherSettings{} + + out, err := RemoveDisabledFields([]byte(input), settings, true) + if err != nil { + t.Fatalf("RemoveDisabledFields returned error: %v", err) + } + assertJSONEqual(t, input, string(out)) +} + +func TestRemoveDisabledFieldsSkipWhenGlobalPassThroughEnabled(t *testing.T) { + original := model_setting.GetGlobalSettings().PassThroughRequestEnabled + model_setting.GetGlobalSettings().PassThroughRequestEnabled = true + t.Cleanup(func() { + model_setting.GetGlobalSettings().PassThroughRequestEnabled = original + }) + + input := `{ + "service_tier":"flex", + "safety_identifier":"user-123", + "stream_options":{"include_obfuscation":false} + }` + settings := dto.ChannelOtherSettings{} + + out, err := RemoveDisabledFields([]byte(input), settings, false) + if err != nil { + t.Fatalf("RemoveDisabledFields returned error: %v", err) + } + assertJSONEqual(t, input, string(out)) +} + func assertJSONEqual(t *testing.T, want, got string) { t.Helper() diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 25dac8cfc..491bfb67d 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -703,7 +703,11 @@ func FailTaskInfo(reason string) *TaskInfo { // store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用) // safety_identifier: 安全标识符,用于向 OpenAI 报告违规用户(仅 OpenAI 支持,涉及用户隐私) // stream_options.include_obfuscation: 响应流混淆控制字段(仅 OpenAI Responses API 支持) -func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings) ([]byte, error) { +func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings, channelPassThroughEnabled bool) ([]byte, error) { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled { + return jsonData, nil + } + var data map[string]interface{} if err := common.Unmarshal(jsonData, &data); err != nil { common.SysError("RemoveDisabledFields Unmarshal error :" + err.Error()) diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index e7adddbbf..5fff8a918 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -165,7 +165,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types } // remove disabled fields for OpenAI API - jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings) + jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } diff --git a/relay/responses_handler.go b/relay/responses_handler.go index 04fc3470e..b3169e726 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -89,7 +89,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * } // remove disabled fields for OpenAI Responses API - jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings) + jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } From 2c5af0df36aa1f4d18dae2043beb735b2b0d7bf6 Mon Sep 17 00:00:00 2001 From: Seefs Date: Thu, 19 Feb 2026 16:27:11 +0800 Subject: [PATCH 07/41] fix: include subscription in personal sidebar module controls --- .../settings/personal/cards/NotificationSettings.jsx | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/web/src/components/settings/personal/cards/NotificationSettings.jsx b/web/src/components/settings/personal/cards/NotificationSettings.jsx index 964a730e4..e57e39d63 100644 --- a/web/src/components/settings/personal/cards/NotificationSettings.jsx +++ b/web/src/components/settings/personal/cards/NotificationSettings.jsx @@ -86,6 +86,7 @@ const NotificationSettings = ({ channel: true, models: true, deployment: true, + subscription: true, redemption: true, user: true, setting: true, @@ -169,6 +170,7 @@ const NotificationSettings = ({ channel: true, models: true, deployment: true, + subscription: true, redemption: true, user: true, setting: true, @@ -296,6 +298,11 @@ const NotificationSettings = ({ title: t('模型部署'), description: t('模型部署管理'), }, + { + key: 'subscription', + title: t('订阅管理'), + description: t('订阅套餐管理'), + }, { key: 'redemption', title: t('兑换码管理'), From a546871a80e09c733f5803972865ecbdb3487684 Mon Sep 17 00:00:00 2001 From: Seefs Date: Sat, 21 Feb 2026 14:25:58 +0800 Subject: [PATCH 08/41] feat: gate Claude inference_geo passthrough behind channel setting and add field docs --- dto/channel_settings.go | 1 + dto/claude.go | 8 +-- dto/openai_request.go | 54 +++++++++++-------- relay/common/override_test.go | 33 ++++++++++++ relay/common/relay_info.go | 8 +++ .../channels/modals/EditChannelModal.jsx | 23 ++++++++ 6 files changed, 101 insertions(+), 26 deletions(-) diff --git a/dto/channel_settings.go b/dto/channel_settings.go index 58c15db0b..72fdf460c 100644 --- a/dto/channel_settings.go +++ b/dto/channel_settings.go @@ -29,6 +29,7 @@ type ChannelOtherSettings struct { OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"` ClaudeBetaQuery bool `json:"claude_beta_query,omitempty"` // Claude 渠道是否强制追加 ?beta=true AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费) + AllowInferenceGeo bool `json:"allow_inference_geo,omitempty"` // 是否允许 inference_geo 透传(仅 Claude,默认过滤以满足数据驻留合规) DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用) AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私) AllowIncludeObfuscation bool `json:"allow_include_obfuscation,omitempty"` // 是否允许 stream_options.include_obfuscation 透传(默认过滤以避免关闭流混淆保护) diff --git a/dto/claude.go b/dto/claude.go index bad3835fa..32e31710b 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -194,8 +194,9 @@ type ClaudeRequest struct { Prompt string `json:"prompt,omitempty"` System any `json:"system,omitempty"` Messages []ClaudeMessage `json:"messages,omitempty"` - // https://platform.claude.com/docs/en/build-with-claude/data-residency#inference-geo - // InferenceGeo string `json:"inference_geo,omitempty"` + // InferenceGeo controls Claude data residency region. + // This field is filtered by default and can be enabled via channel setting allow_inference_geo. + InferenceGeo string `json:"inference_geo,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"` MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` @@ -212,7 +213,8 @@ type ClaudeRequest struct { Thinking *Thinking `json:"thinking,omitempty"` McpServers json.RawMessage `json:"mcp_servers,omitempty"` Metadata json.RawMessage `json:"metadata,omitempty"` - // 服务层级字段,用于指定 API 服务等级。允许透传可能导致实际计费高于预期,默认应过滤 + // ServiceTier specifies upstream service level and may affect billing. + // This field is filtered by default and can be enabled via channel setting allow_service_tier. ServiceTier string `json:"service_tier,omitempty"` } diff --git a/dto/openai_request.go b/dto/openai_request.go index 0b261a61e..c0a69a376 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -56,18 +56,20 @@ type GeneralOpenAIRequest struct { ToolChoice any `json:"tool_choice,omitempty"` FunctionCall json.RawMessage `json:"function_call,omitempty"` User string `json:"user,omitempty"` - ServiceTier string `json:"service_tier,omitempty"` - LogProbs bool `json:"logprobs,omitempty"` - TopLogProbs int `json:"top_logprobs,omitempty"` - Dimensions int `json:"dimensions,omitempty"` - Modalities json.RawMessage `json:"modalities,omitempty"` - Audio json.RawMessage `json:"audio,omitempty"` + // ServiceTier specifies upstream service level and may affect billing. + // This field is filtered by default and can be enabled via channel setting allow_service_tier. + ServiceTier string `json:"service_tier,omitempty"` + LogProbs bool `json:"logprobs,omitempty"` + TopLogProbs int `json:"top_logprobs,omitempty"` + Dimensions int `json:"dimensions,omitempty"` + Modalities json.RawMessage `json:"modalities,omitempty"` + Audio json.RawMessage `json:"audio,omitempty"` // 安全标识符,用于帮助 OpenAI 检测可能违反使用政策的应用程序用户 - // 注意:此字段会向 OpenAI 发送用户标识信息,默认过滤以保护用户隐私 + // 注意:此字段会向 OpenAI 发送用户标识信息,默认过滤,可通过 allow_safety_identifier 开启 SafetyIdentifier string `json:"safety_identifier,omitempty"` // Whether or not to store the output of this chat completion request for use in our model distillation or evals products. // 是否存储此次请求数据供 OpenAI 用于评估和优化产品 - // 注意:默认过滤此字段以保护用户隐私,但过滤后可能导致 Codex 无法正常使用 + // 注意:默认允许透传,可通过 disable_store 禁用;禁用后可能导致 Codex 无法正常使用 Store json.RawMessage `json:"store,omitempty"` // Used by OpenAI to cache responses for similar requests to optimize your cache hit rates. Replaces the user field PromptCacheKey string `json:"prompt_cache_key,omitempty"` @@ -263,7 +265,8 @@ type FunctionRequest struct { type StreamOptions struct { IncludeUsage bool `json:"include_usage,omitempty"` - // for /v1/responses + // IncludeObfuscation is only for /v1/responses stream payload. + // This field is filtered by default and can be enabled via channel setting allow_include_obfuscation. IncludeObfuscation bool `json:"include_obfuscation,omitempty"` } @@ -817,23 +820,28 @@ type OpenAIResponsesRequest struct { ParallelToolCalls json.RawMessage `json:"parallel_tool_calls,omitempty"` PreviousResponseID string `json:"previous_response_id,omitempty"` Reasoning *Reasoning `json:"reasoning,omitempty"` - // 服务层级字段,用于指定 API 服务等级。允许透传可能导致实际计费高于预期,默认应过滤 - ServiceTier string `json:"service_tier,omitempty"` + // ServiceTier specifies upstream service level and may affect billing. + // This field is filtered by default and can be enabled via channel setting allow_service_tier. + ServiceTier string `json:"service_tier,omitempty"` + // Store controls whether upstream may store request/response data. + // This field is allowed by default and can be disabled via channel setting disable_store. Store json.RawMessage `json:"store,omitempty"` PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"` PromptCacheRetention json.RawMessage `json:"prompt_cache_retention,omitempty"` - SafetyIdentifier string `json:"safety_identifier,omitempty"` - Stream bool `json:"stream,omitempty"` - StreamOptions *StreamOptions `json:"stream_options,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - Text json.RawMessage `json:"text,omitempty"` - ToolChoice json.RawMessage `json:"tool_choice,omitempty"` - Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map - TopP *float64 `json:"top_p,omitempty"` - Truncation string `json:"truncation,omitempty"` - User string `json:"user,omitempty"` - MaxToolCalls uint `json:"max_tool_calls,omitempty"` - Prompt json.RawMessage `json:"prompt,omitempty"` + // SafetyIdentifier carries client identity for policy abuse detection. + // This field is filtered by default and can be enabled via channel setting allow_safety_identifier. + SafetyIdentifier string `json:"safety_identifier,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *StreamOptions `json:"stream_options,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + Text json.RawMessage `json:"text,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map + TopP *float64 `json:"top_p,omitempty"` + Truncation string `json:"truncation,omitempty"` + User string `json:"user,omitempty"` + MaxToolCalls uint `json:"max_tool_calls,omitempty"` + Prompt json.RawMessage `json:"prompt,omitempty"` // qwen EnableThinking json.RawMessage `json:"enable_thinking,omitempty"` // perplexity diff --git a/relay/common/override_test.go b/relay/common/override_test.go index 4e8cd5cff..c83cddff7 100644 --- a/relay/common/override_test.go +++ b/relay/common/override_test.go @@ -812,6 +812,39 @@ func TestRemoveDisabledFieldsSkipWhenGlobalPassThroughEnabled(t *testing.T) { assertJSONEqual(t, input, string(out)) } +func TestRemoveDisabledFieldsDefaultFiltering(t *testing.T) { + input := `{ + "service_tier":"flex", + "inference_geo":"eu", + "safety_identifier":"user-123", + "store":true, + "stream_options":{"include_obfuscation":false} + }` + settings := dto.ChannelOtherSettings{} + + out, err := RemoveDisabledFields([]byte(input), settings, false) + if err != nil { + t.Fatalf("RemoveDisabledFields returned error: %v", err) + } + assertJSONEqual(t, `{"store":true}`, string(out)) +} + +func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) { + input := `{ + "inference_geo":"eu", + "store":true + }` + settings := dto.ChannelOtherSettings{ + AllowInferenceGeo: true, + } + + out, err := RemoveDisabledFields([]byte(input), settings, false) + if err != nil { + t.Fatalf("RemoveDisabledFields returned error: %v", err) + } + assertJSONEqual(t, `{"inference_geo":"eu","store":true}`, string(out)) +} + func assertJSONEqual(t *testing.T, want, got string) { t.Helper() diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 491bfb67d..c5c5a883f 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -700,6 +700,7 @@ func FailTaskInfo(reason string) *TaskInfo { // RemoveDisabledFields 从请求 JSON 数据中移除渠道设置中禁用的字段 // service_tier: 服务层级字段,可能导致额外计费(OpenAI、Claude、Responses API 支持) +// inference_geo: Claude 数据驻留推理区域字段(仅 Claude 支持,默认过滤) // store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用) // safety_identifier: 安全标识符,用于向 OpenAI 报告违规用户(仅 OpenAI 支持,涉及用户隐私) // stream_options.include_obfuscation: 响应流混淆控制字段(仅 OpenAI Responses API 支持) @@ -721,6 +722,13 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther } } + // 默认移除 inference_geo,除非明确允许(避免在未授权情况下透传数据驻留区域) + if !channelOtherSettings.AllowInferenceGeo { + if _, exists := data["inference_geo"]; exists { + delete(data, "inference_geo") + } + } + // 默认允许 store 透传,除非明确禁用(禁用可能影响 Codex 使用) if channelOtherSettings.DisableStore { if _, exists := data["store"]; exists { diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index a33c070c4..931a42efb 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -171,6 +171,7 @@ const EditChannelModal = (props) => { disable_store: false, // false = 允许透传(默认开启) allow_safety_identifier: false, allow_include_obfuscation: false, + allow_inference_geo: false, claude_beta_query: false, }; const [batch, setBatch] = useState(false); @@ -637,6 +638,8 @@ const EditChannelModal = (props) => { parsedSettings.allow_safety_identifier || false; data.allow_include_obfuscation = parsedSettings.allow_include_obfuscation || false; + data.allow_inference_geo = + parsedSettings.allow_inference_geo || false; data.claude_beta_query = parsedSettings.claude_beta_query || false; } catch (error) { console.error('解析其他设置失败:', error); @@ -649,6 +652,7 @@ const EditChannelModal = (props) => { data.disable_store = false; data.allow_safety_identifier = false; data.allow_include_obfuscation = false; + data.allow_inference_geo = false; data.claude_beta_query = false; } } else { @@ -660,6 +664,7 @@ const EditChannelModal = (props) => { data.disable_store = false; data.allow_safety_identifier = false; data.allow_include_obfuscation = false; + data.allow_inference_geo = false; data.claude_beta_query = false; } @@ -1406,6 +1411,7 @@ const EditChannelModal = (props) => { localInputs.allow_include_obfuscation === true; } if (localInputs.type === 14) { + settings.allow_inference_geo = localInputs.allow_inference_geo === true; settings.claude_beta_query = localInputs.claude_beta_query === true; } } @@ -1429,6 +1435,7 @@ const EditChannelModal = (props) => { delete localInputs.disable_store; delete localInputs.allow_safety_identifier; delete localInputs.allow_include_obfuscation; + delete localInputs.allow_inference_geo; delete localInputs.claude_beta_query; let res; @@ -3322,6 +3329,22 @@ const EditChannelModal = (props) => { 'service_tier 字段用于指定服务层级,允许透传可能导致实际计费高于预期。默认关闭以避免额外费用', )} /> + + + handleChannelOtherSettingsChange( + 'allow_inference_geo', + value, + ) + } + extraText={t( + 'inference_geo 字段用于控制 Claude 数据驻留推理区域。默认关闭以避免未经授权透传地域信息', + )} + /> )} From e0a6ee1cb804d255c9fd36b90447929a4e9ff6d5 Mon Sep 17 00:00:00 2001 From: Seefs <40468931+seefs001@users.noreply.github.com> Date: Sun, 22 Feb 2026 15:41:29 +0800 Subject: [PATCH 09/41] imporve oauth provider UI/UX (#2983) * feat: imporve UI/UX * fix: stabilize provider enabled toggle and polish custom OAuth settings UX * fix: add access policy/message templates and persist advanced fields reliably * fix: move template fill actions below fields and keep advanced form flow cleaner --- controller/custom_oauth.go | 112 ++- controller/misc.go | 4 + controller/oauth.go | 14 +- model/custom_oauth_provider.go | 116 ++- oauth/generic.go | 404 ++++++++++- oauth/types.go | 9 + router/api-router.go | 3 +- web/src/components/auth/LoginForm.jsx | 30 +- web/src/components/auth/RegisterForm.jsx | 58 +- .../settings/CustomOAuthSetting.jsx | 672 ++++++++++++++---- .../personal/cards/AccountManagement.jsx | 15 +- web/src/helpers/render.jsx | 125 ++++ 12 files changed, 1375 insertions(+), 187 deletions(-) diff --git a/controller/custom_oauth.go b/controller/custom_oauth.go index e2245f880..3197a9163 100644 --- a/controller/custom_oauth.go +++ b/controller/custom_oauth.go @@ -1,8 +1,13 @@ package controller import ( + "context" + "io" "net/http" + "net/url" "strconv" + "strings" + "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/model" @@ -16,6 +21,7 @@ type CustomOAuthProviderResponse struct { Id int `json:"id"` Name string `json:"name"` Slug string `json:"slug"` + Icon string `json:"icon"` Enabled bool `json:"enabled"` ClientId string `json:"client_id"` AuthorizationEndpoint string `json:"authorization_endpoint"` @@ -28,6 +34,8 @@ type CustomOAuthProviderResponse struct { EmailField string `json:"email_field"` WellKnown string `json:"well_known"` AuthStyle int `json:"auth_style"` + AccessPolicy string `json:"access_policy"` + AccessDeniedMessage string `json:"access_denied_message"` } func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse { @@ -35,6 +43,7 @@ func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthPro Id: p.Id, Name: p.Name, Slug: p.Slug, + Icon: p.Icon, Enabled: p.Enabled, ClientId: p.ClientId, AuthorizationEndpoint: p.AuthorizationEndpoint, @@ -47,6 +56,8 @@ func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthPro EmailField: p.EmailField, WellKnown: p.WellKnown, AuthStyle: p.AuthStyle, + AccessPolicy: p.AccessPolicy, + AccessDeniedMessage: p.AccessDeniedMessage, } } @@ -96,6 +107,7 @@ func GetCustomOAuthProvider(c *gin.Context) { type CreateCustomOAuthProviderRequest struct { Name string `json:"name" binding:"required"` Slug string `json:"slug" binding:"required"` + Icon string `json:"icon"` Enabled bool `json:"enabled"` ClientId string `json:"client_id" binding:"required"` ClientSecret string `json:"client_secret" binding:"required"` @@ -109,6 +121,85 @@ type CreateCustomOAuthProviderRequest struct { EmailField string `json:"email_field"` WellKnown string `json:"well_known"` AuthStyle int `json:"auth_style"` + AccessPolicy string `json:"access_policy"` + AccessDeniedMessage string `json:"access_denied_message"` +} + +type FetchCustomOAuthDiscoveryRequest struct { + WellKnownURL string `json:"well_known_url"` + IssuerURL string `json:"issuer_url"` +} + +// FetchCustomOAuthDiscovery fetches OIDC discovery document via backend (root-only route) +func FetchCustomOAuthDiscovery(c *gin.Context) { + var req FetchCustomOAuthDiscoveryRequest + if err := c.ShouldBindJSON(&req); err != nil { + common.ApiErrorMsg(c, "无效的请求参数: "+err.Error()) + return + } + + wellKnownURL := strings.TrimSpace(req.WellKnownURL) + issuerURL := strings.TrimSpace(req.IssuerURL) + + if wellKnownURL == "" && issuerURL == "" { + common.ApiErrorMsg(c, "请先填写 Discovery URL 或 Issuer URL") + return + } + + targetURL := wellKnownURL + if targetURL == "" { + targetURL = strings.TrimRight(issuerURL, "/") + "/.well-known/openid-configuration" + } + targetURL = strings.TrimSpace(targetURL) + + parsedURL, err := url.Parse(targetURL) + if err != nil || parsedURL.Host == "" || (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") { + common.ApiErrorMsg(c, "Discovery URL 无效,仅支持 http/https") + return + } + + ctx, cancel := context.WithTimeout(c.Request.Context(), 20*time.Second) + defer cancel() + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil) + if err != nil { + common.ApiErrorMsg(c, "创建 Discovery 请求失败: "+err.Error()) + return + } + httpReq.Header.Set("Accept", "application/json") + + client := &http.Client{Timeout: 20 * time.Second} + resp, err := client.Do(httpReq) + if err != nil { + common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+err.Error()) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + message := strings.TrimSpace(string(body)) + if message == "" { + message = resp.Status + } + common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+message) + return + } + + var discovery map[string]any + if err = common.DecodeJson(resp.Body, &discovery); err != nil { + common.ApiErrorMsg(c, "解析 Discovery 配置失败: "+err.Error()) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": gin.H{ + "well_known_url": targetURL, + "discovery": discovery, + }, + }) } // CreateCustomOAuthProvider creates a new custom OAuth provider @@ -134,6 +225,7 @@ func CreateCustomOAuthProvider(c *gin.Context) { provider := &model.CustomOAuthProvider{ Name: req.Name, Slug: req.Slug, + Icon: req.Icon, Enabled: req.Enabled, ClientId: req.ClientId, ClientSecret: req.ClientSecret, @@ -147,6 +239,8 @@ func CreateCustomOAuthProvider(c *gin.Context) { EmailField: req.EmailField, WellKnown: req.WellKnown, AuthStyle: req.AuthStyle, + AccessPolicy: req.AccessPolicy, + AccessDeniedMessage: req.AccessDeniedMessage, } if err := model.CreateCustomOAuthProvider(provider); err != nil { @@ -168,9 +262,10 @@ func CreateCustomOAuthProvider(c *gin.Context) { type UpdateCustomOAuthProviderRequest struct { Name string `json:"name"` Slug string `json:"slug"` - Enabled *bool `json:"enabled"` // Optional: if nil, keep existing + Icon *string `json:"icon"` // Optional: if nil, keep existing + Enabled *bool `json:"enabled"` // Optional: if nil, keep existing ClientId string `json:"client_id"` - ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing + ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing AuthorizationEndpoint string `json:"authorization_endpoint"` TokenEndpoint string `json:"token_endpoint"` UserInfoEndpoint string `json:"user_info_endpoint"` @@ -181,6 +276,8 @@ type UpdateCustomOAuthProviderRequest struct { EmailField string `json:"email_field"` WellKnown *string `json:"well_known"` // Optional: if nil, keep existing AuthStyle *int `json:"auth_style"` // Optional: if nil, keep existing + AccessPolicy *string `json:"access_policy"` // Optional: if nil, keep existing + AccessDeniedMessage *string `json:"access_denied_message"` // Optional: if nil, keep existing } // UpdateCustomOAuthProvider updates an existing custom OAuth provider @@ -227,6 +324,9 @@ func UpdateCustomOAuthProvider(c *gin.Context) { if req.Slug != "" { provider.Slug = req.Slug } + if req.Icon != nil { + provider.Icon = *req.Icon + } if req.Enabled != nil { provider.Enabled = *req.Enabled } @@ -266,6 +366,12 @@ func UpdateCustomOAuthProvider(c *gin.Context) { if req.AuthStyle != nil { provider.AuthStyle = *req.AuthStyle } + if req.AccessPolicy != nil { + provider.AccessPolicy = *req.AccessPolicy + } + if req.AccessDeniedMessage != nil { + provider.AccessDeniedMessage = *req.AccessDeniedMessage + } if err := model.UpdateCustomOAuthProvider(provider); err != nil { common.ApiError(c, err) @@ -346,6 +452,7 @@ func GetUserOAuthBindings(c *gin.Context) { ProviderId int `json:"provider_id"` ProviderName string `json:"provider_name"` ProviderSlug string `json:"provider_slug"` + ProviderIcon string `json:"provider_icon"` ProviderUserId string `json:"provider_user_id"` } @@ -359,6 +466,7 @@ func GetUserOAuthBindings(c *gin.Context) { ProviderId: binding.ProviderId, ProviderName: provider.Name, ProviderSlug: provider.Slug, + ProviderIcon: provider.Icon, ProviderUserId: binding.ProviderUserId, }) } diff --git a/controller/misc.go b/controller/misc.go index a16e2d554..b24a74adf 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -134,8 +134,10 @@ func GetStatus(c *gin.Context) { customProviders := oauth.GetEnabledCustomProviders() if len(customProviders) > 0 { type CustomOAuthInfo struct { + Id int `json:"id"` Name string `json:"name"` Slug string `json:"slug"` + Icon string `json:"icon"` ClientId string `json:"client_id"` AuthorizationEndpoint string `json:"authorization_endpoint"` Scopes string `json:"scopes"` @@ -144,8 +146,10 @@ func GetStatus(c *gin.Context) { for _, p := range customProviders { config := p.GetConfig() providersInfo = append(providersInfo, CustomOAuthInfo{ + Id: config.Id, Name: config.Name, Slug: config.Slug, + Icon: config.Icon, ClientId: config.ClientId, AuthorizationEndpoint: config.AuthorizationEndpoint, Scopes: config.Scopes, diff --git a/controller/oauth.go b/controller/oauth.go index 65e18f9da..75ab29898 100644 --- a/controller/oauth.go +++ b/controller/oauth.go @@ -295,12 +295,12 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o // Set the provider user ID on the user model and update provider.SetProviderUserID(user, oauthUser.ProviderUserID) if err := tx.Model(user).Updates(map[string]interface{}{ - "github_id": user.GitHubId, - "discord_id": user.DiscordId, - "oidc_id": user.OidcId, - "linux_do_id": user.LinuxDOId, - "wechat_id": user.WeChatId, - "telegram_id": user.TelegramId, + "github_id": user.GitHubId, + "discord_id": user.DiscordId, + "oidc_id": user.OidcId, + "linux_do_id": user.LinuxDOId, + "wechat_id": user.WeChatId, + "telegram_id": user.TelegramId, }).Error; err != nil { return err } @@ -340,6 +340,8 @@ func handleOAuthError(c *gin.Context, err error) { } else { common.ApiErrorI18n(c, e.MsgKey) } + case *oauth.AccessDeniedError: + common.ApiErrorMsg(c, e.Message) case *oauth.TrustLevelError: common.ApiErrorI18n(c, i18n.MsgOAuthTrustLevelLow) default: diff --git a/model/custom_oauth_provider.go b/model/custom_oauth_provider.go index 43c69833a..12b4d1111 100644 --- a/model/custom_oauth_provider.go +++ b/model/custom_oauth_provider.go @@ -2,32 +2,65 @@ package model import ( "errors" + "fmt" "strings" "time" + + "github.com/QuantumNous/new-api/common" ) +type accessPolicyPayload struct { + Logic string `json:"logic"` + Conditions []accessConditionItem `json:"conditions"` + Groups []accessPolicyPayload `json:"groups"` +} + +type accessConditionItem struct { + Field string `json:"field"` + Op string `json:"op"` + Value any `json:"value"` +} + +var supportedAccessPolicyOps = map[string]struct{}{ + "eq": {}, + "ne": {}, + "gt": {}, + "gte": {}, + "lt": {}, + "lte": {}, + "in": {}, + "not_in": {}, + "contains": {}, + "not_contains": {}, + "exists": {}, + "not_exists": {}, +} + // CustomOAuthProvider stores configuration for custom OAuth providers type CustomOAuthProvider struct { - Id int `json:"id" gorm:"primaryKey"` - Name string `json:"name" gorm:"type:varchar(64);not null"` // Display name, e.g., "GitHub Enterprise" - Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"` // URL identifier, e.g., "github-enterprise" - Enabled bool `json:"enabled" gorm:"default:false"` // Whether this provider is enabled - ClientId string `json:"client_id" gorm:"type:varchar(256)"` // OAuth client ID - ClientSecret string `json:"-" gorm:"type:varchar(512)"` // OAuth client secret (not returned to frontend) - AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"` // Authorization URL - TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"` // Token exchange URL - UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"` // User info URL - Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes + Id int `json:"id" gorm:"primaryKey"` + Name string `json:"name" gorm:"type:varchar(64);not null"` // Display name, e.g., "GitHub Enterprise" + Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"` // URL identifier, e.g., "github-enterprise" + Icon string `json:"icon" gorm:"type:varchar(128);default:''"` // Icon name from @lobehub/icons + Enabled bool `json:"enabled" gorm:"default:false"` // Whether this provider is enabled + ClientId string `json:"client_id" gorm:"type:varchar(256)"` // OAuth client ID + ClientSecret string `json:"-" gorm:"type:varchar(512)"` // OAuth client secret (not returned to frontend) + AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"` // Authorization URL + TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"` // Token exchange URL + UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"` // User info URL + Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes // Field mapping configuration (supports JSONPath via gjson) - UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"` // User ID field path, e.g., "sub", "id", "data.user.id" - UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path - DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"` // Display name field path - EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"` // Email field path + UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"` // User ID field path, e.g., "sub", "id", "data.user.id" + UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path + DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"` // Display name field path + EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"` // Email field path // Advanced options - WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional) - AuthStyle int `json:"auth_style" gorm:"default:0"` // 0=auto, 1=params, 2=header (Basic Auth) + WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional) + AuthStyle int `json:"auth_style" gorm:"default:0"` // 0=auto, 1=params, 2=header (Basic Auth) + AccessPolicy string `json:"access_policy" gorm:"type:text"` // JSON policy for access control based on user info + AccessDeniedMessage string `json:"access_denied_message" gorm:"type:varchar(512)"` // Custom error message template when access is denied CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` @@ -158,6 +191,57 @@ func validateCustomOAuthProvider(provider *CustomOAuthProvider) error { if provider.Scopes == "" { provider.Scopes = "openid profile email" } + if strings.TrimSpace(provider.AccessPolicy) != "" { + var policy accessPolicyPayload + if err := common.UnmarshalJsonStr(provider.AccessPolicy, &policy); err != nil { + return errors.New("access_policy must be valid JSON") + } + if err := validateAccessPolicyPayload(&policy); err != nil { + return fmt.Errorf("access_policy is invalid: %w", err) + } + } + + return nil +} + +func validateAccessPolicyPayload(policy *accessPolicyPayload) error { + if policy == nil { + return errors.New("policy is nil") + } + + logic := strings.ToLower(strings.TrimSpace(policy.Logic)) + if logic == "" { + logic = "and" + } + if logic != "and" && logic != "or" { + return fmt.Errorf("unsupported logic: %s", logic) + } + + if len(policy.Conditions) == 0 && len(policy.Groups) == 0 { + return errors.New("policy requires at least one condition or group") + } + + for index, condition := range policy.Conditions { + field := strings.TrimSpace(condition.Field) + if field == "" { + return fmt.Errorf("condition[%d].field is required", index) + } + op := strings.ToLower(strings.TrimSpace(condition.Op)) + if _, ok := supportedAccessPolicyOps[op]; !ok { + return fmt.Errorf("condition[%d].op is unsupported: %s", index, op) + } + if op == "in" || op == "not_in" { + if _, ok := condition.Value.([]any); !ok { + return fmt.Errorf("condition[%d].value must be an array for op %s", index, op) + } + } + } + + for index := range policy.Groups { + if err := validateAccessPolicyPayload(&policy.Groups[index]); err != nil { + return fmt.Errorf("group[%d]: %w", index, err) + } + } return nil } diff --git a/oauth/generic.go b/oauth/generic.go index c7aa87931..bc18054d5 100644 --- a/oauth/generic.go +++ b/oauth/generic.go @@ -3,19 +3,24 @@ package oauth import ( "context" "encoding/base64" - "encoding/json" + stdjson "encoding/json" + "errors" "fmt" "io" "net/http" "net/url" + "regexp" + "strconv" "strings" "time" + "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/i18n" "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/setting/system_setting" "github.com/gin-gonic/gin" + "github.com/samber/lo" "github.com/tidwall/gjson" ) @@ -31,6 +36,40 @@ type GenericOAuthProvider struct { config *model.CustomOAuthProvider } +type accessPolicy struct { + Logic string `json:"logic"` + Conditions []accessCondition `json:"conditions"` + Groups []accessPolicy `json:"groups"` +} + +type accessCondition struct { + Field string `json:"field"` + Op string `json:"op"` + Value any `json:"value"` +} + +type accessPolicyFailure struct { + Field string + Op string + Expected any + Current any +} + +var supportedAccessPolicyOps = []string{ + "eq", + "ne", + "gt", + "gte", + "lt", + "lte", + "in", + "not_in", + "contains", + "not_contains", + "exists", + "not_exists", +} + // NewGenericOAuthProvider creates a new generic OAuth provider from config func NewGenericOAuthProvider(config *model.CustomOAuthProvider) *GenericOAuthProvider { return &GenericOAuthProvider{config: config} @@ -125,7 +164,7 @@ func (p *GenericOAuthProvider) ExchangeToken(ctx context.Context, code string, c ErrorDesc string `json:"error_description"` } - if err := json.Unmarshal(body, &tokenResponse); err != nil { + if err := common.Unmarshal(body, &tokenResponse); err != nil { // Try to parse as URL-encoded (some OAuth servers like GitHub return this format) parsedValues, parseErr := url.ParseQuery(bodyStr) if parseErr != nil { @@ -227,11 +266,30 @@ func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToke logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo success: id=%s, username=%s, name=%s, email=%s", p.config.Slug, userId, username, displayName, email) + policyRaw := strings.TrimSpace(p.config.AccessPolicy) + if policyRaw != "" { + policy, err := parseAccessPolicy(policyRaw) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] invalid access policy: %s", p.config.Slug, err.Error())) + return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthGetUserErr, nil, "invalid access policy configuration") + } + allowed, failure := evaluateAccessPolicy(bodyStr, policy) + if !allowed { + message := renderAccessDeniedMessage(p.config.AccessDeniedMessage, p.config.Name, bodyStr, failure) + logger.LogWarn(ctx, fmt.Sprintf("[OAuth-Generic-%s] access denied by policy: field=%s op=%s expected=%v current=%v", + p.config.Slug, failure.Field, failure.Op, failure.Expected, failure.Current)) + return nil, &AccessDeniedError{Message: message} + } + } + return &OAuthUser{ ProviderUserID: userId, Username: username, DisplayName: displayName, Email: email, + Extra: map[string]any{ + "provider": p.config.Slug, + }, }, nil } @@ -266,3 +324,345 @@ func (p *GenericOAuthProvider) GetProviderId() int { func (p *GenericOAuthProvider) IsGenericProvider() bool { return true } + +func parseAccessPolicy(raw string) (*accessPolicy, error) { + var policy accessPolicy + if err := common.UnmarshalJsonStr(raw, &policy); err != nil { + return nil, err + } + if err := validateAccessPolicy(&policy); err != nil { + return nil, err + } + return &policy, nil +} + +func validateAccessPolicy(policy *accessPolicy) error { + if policy == nil { + return errors.New("policy is nil") + } + + logic := strings.ToLower(strings.TrimSpace(policy.Logic)) + if logic == "" { + logic = "and" + } + if !lo.Contains([]string{"and", "or"}, logic) { + return fmt.Errorf("unsupported policy logic: %s", logic) + } + policy.Logic = logic + + if len(policy.Conditions) == 0 && len(policy.Groups) == 0 { + return errors.New("policy requires at least one condition or group") + } + + for index := range policy.Conditions { + if err := validateAccessCondition(&policy.Conditions[index], index); err != nil { + return err + } + } + + for index := range policy.Groups { + if err := validateAccessPolicy(&policy.Groups[index]); err != nil { + return fmt.Errorf("invalid policy group[%d]: %w", index, err) + } + } + + return nil +} + +func validateAccessCondition(condition *accessCondition, index int) error { + if condition == nil { + return fmt.Errorf("condition[%d] is nil", index) + } + + condition.Field = strings.TrimSpace(condition.Field) + if condition.Field == "" { + return fmt.Errorf("condition[%d].field is required", index) + } + + condition.Op = normalizePolicyOp(condition.Op) + if !lo.Contains(supportedAccessPolicyOps, condition.Op) { + return fmt.Errorf("condition[%d].op is unsupported: %s", index, condition.Op) + } + + if lo.Contains([]string{"in", "not_in"}, condition.Op) { + if _, ok := condition.Value.([]any); !ok { + return fmt.Errorf("condition[%d].value must be an array for op %s", index, condition.Op) + } + } + + return nil +} + +func evaluateAccessPolicy(body string, policy *accessPolicy) (bool, *accessPolicyFailure) { + if policy == nil { + return true, nil + } + + logic := strings.ToLower(strings.TrimSpace(policy.Logic)) + if logic == "" { + logic = "and" + } + + hasAny := len(policy.Conditions) > 0 || len(policy.Groups) > 0 + if !hasAny { + return true, nil + } + + if logic == "or" { + var firstFailure *accessPolicyFailure + for _, cond := range policy.Conditions { + ok, failure := evaluateAccessCondition(body, cond) + if ok { + return true, nil + } + if firstFailure == nil { + firstFailure = failure + } + } + for _, group := range policy.Groups { + ok, failure := evaluateAccessPolicy(body, &group) + if ok { + return true, nil + } + if firstFailure == nil { + firstFailure = failure + } + } + return false, firstFailure + } + + for _, cond := range policy.Conditions { + ok, failure := evaluateAccessCondition(body, cond) + if !ok { + return false, failure + } + } + for _, group := range policy.Groups { + ok, failure := evaluateAccessPolicy(body, &group) + if !ok { + return false, failure + } + } + return true, nil +} + +func evaluateAccessCondition(body string, cond accessCondition) (bool, *accessPolicyFailure) { + path := cond.Field + op := cond.Op + result := gjson.Get(body, path) + current := gjsonResultToValue(result) + failure := &accessPolicyFailure{ + Field: path, + Op: op, + Expected: cond.Value, + Current: current, + } + + switch op { + case "exists": + return result.Exists(), failure + case "not_exists": + return !result.Exists(), failure + case "eq": + return compareAny(current, cond.Value) == 0, failure + case "ne": + return compareAny(current, cond.Value) != 0, failure + case "gt": + return compareAny(current, cond.Value) > 0, failure + case "gte": + return compareAny(current, cond.Value) >= 0, failure + case "lt": + return compareAny(current, cond.Value) < 0, failure + case "lte": + return compareAny(current, cond.Value) <= 0, failure + case "in": + return valueInSlice(current, cond.Value), failure + case "not_in": + return !valueInSlice(current, cond.Value), failure + case "contains": + return containsValue(current, cond.Value), failure + case "not_contains": + return !containsValue(current, cond.Value), failure + default: + return false, failure + } +} + +func normalizePolicyOp(op string) string { + return strings.ToLower(strings.TrimSpace(op)) +} + +func gjsonResultToValue(result gjson.Result) any { + if !result.Exists() { + return nil + } + if result.IsArray() { + arr := result.Array() + values := make([]any, 0, len(arr)) + for _, item := range arr { + values = append(values, gjsonResultToValue(item)) + } + return values + } + switch result.Type { + case gjson.Null: + return nil + case gjson.True: + return true + case gjson.False: + return false + case gjson.Number: + return result.Num + case gjson.String: + return result.String() + case gjson.JSON: + var data any + if err := common.UnmarshalJsonStr(result.Raw, &data); err == nil { + return data + } + return result.Raw + default: + return result.Value() + } +} + +func compareAny(left any, right any) int { + if lf, ok := toFloat(left); ok { + if rf, ok2 := toFloat(right); ok2 { + switch { + case lf < rf: + return -1 + case lf > rf: + return 1 + default: + return 0 + } + } + } + + ls := strings.TrimSpace(fmt.Sprint(left)) + rs := strings.TrimSpace(fmt.Sprint(right)) + switch { + case ls < rs: + return -1 + case ls > rs: + return 1 + default: + return 0 + } +} + +func toFloat(v any) (float64, bool) { + switch value := v.(type) { + case float64: + return value, true + case float32: + return float64(value), true + case int: + return float64(value), true + case int8: + return float64(value), true + case int16: + return float64(value), true + case int32: + return float64(value), true + case int64: + return float64(value), true + case uint: + return float64(value), true + case uint8: + return float64(value), true + case uint16: + return float64(value), true + case uint32: + return float64(value), true + case uint64: + return float64(value), true + case stdjson.Number: + n, err := value.Float64() + if err == nil { + return n, true + } + case string: + n, err := strconv.ParseFloat(strings.TrimSpace(value), 64) + if err == nil { + return n, true + } + } + return 0, false +} + +func valueInSlice(current any, expected any) bool { + list, ok := expected.([]any) + if !ok { + return false + } + return lo.ContainsBy(list, func(item any) bool { + return compareAny(current, item) == 0 + }) +} + +func containsValue(current any, expected any) bool { + switch value := current.(type) { + case string: + target := strings.TrimSpace(fmt.Sprint(expected)) + return strings.Contains(value, target) + case []any: + return lo.ContainsBy(value, func(item any) bool { + return compareAny(item, expected) == 0 + }) + } + return false +} + +func renderAccessDeniedMessage(template string, providerName string, body string, failure *accessPolicyFailure) string { + defaultMessage := "Access denied: your account does not meet this provider's access requirements." + message := strings.TrimSpace(template) + if message == "" { + return defaultMessage + } + + if failure == nil { + failure = &accessPolicyFailure{} + } + + replacements := map[string]string{ + "{{provider}}": providerName, + "{{field}}": failure.Field, + "{{op}}": failure.Op, + "{{required}}": fmt.Sprint(failure.Expected), + "{{current}}": fmt.Sprint(failure.Current), + } + + for key, value := range replacements { + message = strings.ReplaceAll(message, key, value) + } + + currentPattern := regexp.MustCompile(`\{\{current\.([^}]+)\}\}`) + message = currentPattern.ReplaceAllStringFunc(message, func(token string) string { + match := currentPattern.FindStringSubmatch(token) + if len(match) != 2 { + return "" + } + path := strings.TrimSpace(match[1]) + if path == "" { + return "" + } + return strings.TrimSpace(gjson.Get(body, path).String()) + }) + + requiredPattern := regexp.MustCompile(`\{\{required\.([^}]+)\}\}`) + message = requiredPattern.ReplaceAllStringFunc(message, func(token string) string { + match := requiredPattern.FindStringSubmatch(token) + if len(match) != 2 { + return "" + } + path := strings.TrimSpace(match[1]) + if failure.Field == path { + return fmt.Sprint(failure.Expected) + } + return "" + }) + + return strings.TrimSpace(message) +} diff --git a/oauth/types.go b/oauth/types.go index 1b0e3646a..383e6f351 100644 --- a/oauth/types.go +++ b/oauth/types.go @@ -57,3 +57,12 @@ func NewOAuthErrorWithRaw(msgKey string, params map[string]any, rawError string) RawError: rawError, } } + +// AccessDeniedError is a direct user-facing access denial message. +type AccessDeniedError struct { + Message string +} + +func (e *AccessDeniedError) Error() string { + return e.Message +} diff --git a/router/api-router.go b/router/api-router.go index e2ef2f531..d60ba39b2 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -170,10 +170,11 @@ func SetApiRouter(router *gin.Engine) { optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除 } - // Custom OAuth provider management (admin only) + // Custom OAuth provider management (root only) customOAuthRoute := apiRouter.Group("/custom-oauth-provider") customOAuthRoute.Use(middleware.RootAuth()) { + customOAuthRoute.POST("/discovery", controller.FetchCustomOAuthDiscovery) customOAuthRoute.GET("/", controller.GetCustomOAuthProviders) customOAuthRoute.GET("/:id", controller.GetCustomOAuthProvider) customOAuthRoute.POST("/", controller.CreateCustomOAuthProvider) diff --git a/web/src/components/auth/LoginForm.jsx b/web/src/components/auth/LoginForm.jsx index 636317e44..7e8c0ce01 100644 --- a/web/src/components/auth/LoginForm.jsx +++ b/web/src/components/auth/LoginForm.jsx @@ -29,6 +29,7 @@ import { showSuccess, updateAPI, getSystemName, + getOAuthProviderIcon, setUserData, onGitHubOAuthClicked, onDiscordOAuthClicked, @@ -130,6 +131,17 @@ const LoginForm = () => { return {}; } }, [statusState?.status]); + const hasCustomOAuthProviders = + (status.custom_oauth_providers || []).length > 0; + const hasOAuthLoginOptions = Boolean( + status.github_oauth || + status.discord_oauth || + status.oidc_enabled || + status.wechat_login || + status.linuxdo_oauth || + status.telegram_oauth || + hasCustomOAuthProviders, + ); useEffect(() => { if (status?.turnstile_check) { @@ -598,7 +610,7 @@ const LoginForm = () => { theme='outline' className='w-full h-12 flex items-center justify-center !rounded-full border border-gray-200 hover:bg-gray-50 transition-colors' type='tertiary' - icon={} + icon={getOAuthProviderIcon(provider.icon || '', 20)} onClick={() => handleCustomOAuthClick(provider)} loading={customOAuthLoading[provider.slug]} > @@ -817,12 +829,7 @@ const LoginForm = () => { - {(status.github_oauth || - status.discord_oauth || - status.oidc_enabled || - status.wechat_login || - status.linuxdo_oauth || - status.telegram_oauth) && ( + {hasOAuthLoginOptions && ( <> {t('或')} @@ -952,14 +959,7 @@ const LoginForm = () => { />
{showEmailLogin || - !( - status.github_oauth || - status.discord_oauth || - status.oidc_enabled || - status.wechat_login || - status.linuxdo_oauth || - status.telegram_oauth - ) + !hasOAuthLoginOptions ? renderEmailLoginForm() : renderOAuthOptions()} {renderWeChatLoginModal()} diff --git a/web/src/components/auth/RegisterForm.jsx b/web/src/components/auth/RegisterForm.jsx index 2edc499b1..0a755b194 100644 --- a/web/src/components/auth/RegisterForm.jsx +++ b/web/src/components/auth/RegisterForm.jsx @@ -27,8 +27,10 @@ import { showSuccess, updateAPI, getSystemName, + getOAuthProviderIcon, setUserData, onDiscordOAuthClicked, + onCustomOAuthClicked, } from '../../helpers'; import Turnstile from 'react-turnstile'; import { @@ -98,6 +100,7 @@ const RegisterForm = () => { const [otherRegisterOptionsLoading, setOtherRegisterOptionsLoading] = useState(false); const [wechatCodeSubmitLoading, setWechatCodeSubmitLoading] = useState(false); + const [customOAuthLoading, setCustomOAuthLoading] = useState({}); const [disableButton, setDisableButton] = useState(false); const [countdown, setCountdown] = useState(30); const [agreedToTerms, setAgreedToTerms] = useState(false); @@ -126,6 +129,17 @@ const RegisterForm = () => { return {}; } }, [statusState?.status]); + const hasCustomOAuthProviders = + (status.custom_oauth_providers || []).length > 0; + const hasOAuthRegisterOptions = Boolean( + status.github_oauth || + status.discord_oauth || + status.oidc_enabled || + status.wechat_login || + status.linuxdo_oauth || + status.telegram_oauth || + hasCustomOAuthProviders, + ); const [showEmailVerification, setShowEmailVerification] = useState(false); @@ -319,6 +333,17 @@ const RegisterForm = () => { } }; + const handleCustomOAuthClick = (provider) => { + setCustomOAuthLoading((prev) => ({ ...prev, [provider.slug]: true })); + try { + onCustomOAuthClicked(provider, { shouldLogout: true }); + } finally { + setTimeout(() => { + setCustomOAuthLoading((prev) => ({ ...prev, [provider.slug]: false })); + }, 3000); + } + }; + const handleEmailRegisterClick = () => { setEmailRegisterLoading(true); setShowEmailRegister(true); @@ -469,6 +494,23 @@ const RegisterForm = () => { )} + {status.custom_oauth_providers && + status.custom_oauth_providers.map((provider) => ( + + ))} + {status.telegram_oauth && (
{
- {(status.github_oauth || - status.discord_oauth || - status.oidc_enabled || - status.wechat_login || - status.linuxdo_oauth || - status.telegram_oauth) && ( + {hasOAuthRegisterOptions && ( <> {t('或')} @@ -745,14 +782,7 @@ const RegisterForm = () => { />
{showEmailRegister || - !( - status.github_oauth || - status.discord_oauth || - status.oidc_enabled || - status.wechat_login || - status.linuxdo_oauth || - status.telegram_oauth - ) + !hasOAuthRegisterOptions ? renderEmailRegisterForm() : renderOAuthOptions()} {renderWeChatLoginModal()} diff --git a/web/src/components/settings/CustomOAuthSetting.jsx b/web/src/components/settings/CustomOAuthSetting.jsx index 4b6df4c81..0912160be 100644 --- a/web/src/components/settings/CustomOAuthSetting.jsx +++ b/web/src/components/settings/CustomOAuthSetting.jsx @@ -27,14 +27,20 @@ import { Modal, Banner, Card, + Collapse, + Switch, Table, Tag, Popconfirm, Space, - Select, } from '@douyinfe/semi-ui'; -import { IconPlus, IconEdit, IconDelete } from '@douyinfe/semi-icons'; -import { API, showError, showSuccess } from '../../helpers'; +import { + IconPlus, + IconEdit, + IconDelete, + IconRefresh, +} from '@douyinfe/semi-icons'; +import { API, showError, showSuccess, getOAuthProviderIcon } from '../../helpers'; import { useTranslation } from 'react-i18next'; const { Text } = Typography; @@ -120,6 +126,69 @@ const OAUTH_PRESETS = { }, }; +const OAUTH_PRESET_ICONS = { + 'github-enterprise': 'github', + gitlab: 'gitlab', + gitea: 'gitea', + nextcloud: 'nextcloud', + keycloak: 'keycloak', + authentik: 'authentik', + ory: 'openid', +}; + +const getPresetIcon = (preset) => OAUTH_PRESET_ICONS[preset] || ''; + +const PRESET_RESET_VALUES = { + name: '', + slug: '', + icon: '', + authorization_endpoint: '', + token_endpoint: '', + user_info_endpoint: '', + scopes: '', + user_id_field: '', + username_field: '', + display_name_field: '', + email_field: '', + well_known: '', + auth_style: 0, + access_policy: '', + access_denied_message: '', +}; + +const DISCOVERY_FIELD_LABELS = { + authorization_endpoint: 'Authorization Endpoint', + token_endpoint: 'Token Endpoint', + user_info_endpoint: 'User Info Endpoint', + scopes: 'Scopes', + user_id_field: 'User ID Field', + username_field: 'Username Field', + display_name_field: 'Display Name Field', + email_field: 'Email Field', +}; + +const ACCESS_POLICY_TEMPLATES = { + level_active: `{ + "logic": "and", + "conditions": [ + {"field": "trust_level", "op": "gte", "value": 2}, + {"field": "active", "op": "eq", "value": true} + ] +}`, + org_or_role: `{ + "logic": "or", + "conditions": [ + {"field": "org", "op": "eq", "value": "core"}, + {"field": "roles", "op": "contains", "value": "admin"} + ] +}`, +}; + +const ACCESS_DENIED_TEMPLATES = { + level_hint: '需要等级 {{required}},你当前等级 {{current}}(字段:{{field}})', + org_hint: '仅限指定组织或角色访问。组织={{current.org}},角色={{current.roles}}', +}; + const CustomOAuthSetting = ({ serverAddress }) => { const { t } = useTranslation(); const [providers, setProviders] = useState([]); @@ -129,8 +198,47 @@ const CustomOAuthSetting = ({ serverAddress }) => { const [formValues, setFormValues] = useState({}); const [selectedPreset, setSelectedPreset] = useState(''); const [baseUrl, setBaseUrl] = useState(''); + const [discoveryLoading, setDiscoveryLoading] = useState(false); + const [discoveryInfo, setDiscoveryInfo] = useState(null); + const [advancedActiveKeys, setAdvancedActiveKeys] = useState([]); const formApiRef = React.useRef(null); + const mergeFormValues = (newValues) => { + setFormValues((prev) => ({ ...prev, ...newValues })); + if (!formApiRef.current) return; + Object.entries(newValues).forEach(([key, value]) => { + formApiRef.current.setValue(key, value); + }); + }; + + const getLatestFormValues = () => { + const values = formApiRef.current?.getValues?.(); + return values && typeof values === 'object' ? values : formValues; + }; + + const normalizeBaseUrl = (url) => (url || '').trim().replace(/\/+$/, ''); + + const inferBaseUrlFromProvider = (provider) => { + const endpoint = provider?.authorization_endpoint || provider?.token_endpoint; + if (!endpoint) return ''; + try { + const url = new URL(endpoint); + return `${url.protocol}//${url.host}`; + } catch (error) { + return ''; + } + }; + + const resetDiscoveryState = () => { + setDiscoveryInfo(null); + }; + + const closeModal = () => { + setModalVisible(false); + resetDiscoveryState(); + setAdvancedActiveKeys([]); + }; + const fetchProviders = async () => { setLoading(true); try { @@ -154,23 +262,30 @@ const CustomOAuthSetting = ({ serverAddress }) => { setEditingProvider(null); setFormValues({ enabled: false, + icon: '', scopes: 'openid profile email', user_id_field: 'sub', username_field: 'preferred_username', display_name_field: 'name', email_field: 'email', auth_style: 0, + access_policy: '', + access_denied_message: '', }); setSelectedPreset(''); setBaseUrl(''); + resetDiscoveryState(); + setAdvancedActiveKeys([]); setModalVisible(true); }; const handleEdit = (provider) => { setEditingProvider(provider); setFormValues({ ...provider }); - setSelectedPreset(''); - setBaseUrl(''); + setSelectedPreset(OAUTH_PRESETS[provider.slug] ? provider.slug : ''); + setBaseUrl(inferBaseUrlFromProvider(provider)); + resetDiscoveryState(); + setAdvancedActiveKeys([]); setModalVisible(true); }; @@ -189,6 +304,8 @@ const CustomOAuthSetting = ({ serverAddress }) => { }; const handleSubmit = async () => { + const currentValues = getLatestFormValues(); + // Validate required fields const requiredFields = [ 'name', @@ -204,7 +321,7 @@ const CustomOAuthSetting = ({ serverAddress }) => { } for (const field of requiredFields) { - if (!formValues[field]) { + if (!currentValues[field]) { showError(t(`请填写 ${field}`)); return; } @@ -213,11 +330,11 @@ const CustomOAuthSetting = ({ serverAddress }) => { // Validate endpoint URLs must be full URLs const endpointFields = ['authorization_endpoint', 'token_endpoint', 'user_info_endpoint']; for (const field of endpointFields) { - const value = formValues[field]; + const value = currentValues[field]; if (value && !value.startsWith('http://') && !value.startsWith('https://')) { - // Check if user selected a preset but forgot to fill server address + // Check if user selected a preset but forgot to fill issuer URL if (selectedPreset && !baseUrl) { - showError(t('请先填写服务器地址,以自动生成完整的端点 URL')); + showError(t('请先填写 Issuer URL,以自动生成完整的端点 URL')); } else { showError(t('端点 URL 必须是完整地址(以 http:// 或 https:// 开头)')); } @@ -226,80 +343,199 @@ const CustomOAuthSetting = ({ serverAddress }) => { } try { + const payload = { ...currentValues, enabled: !!currentValues.enabled }; + delete payload.preset; + delete payload.base_url; + let res; if (editingProvider) { res = await API.put( `/api/custom-oauth-provider/${editingProvider.id}`, - formValues + payload ); } else { - res = await API.post('/api/custom-oauth-provider/', formValues); + res = await API.post('/api/custom-oauth-provider/', payload); } if (res.data.success) { showSuccess(editingProvider ? t('更新成功') : t('创建成功')); - setModalVisible(false); + closeModal(); fetchProviders(); } else { showError(res.data.message); } } catch (error) { - showError(editingProvider ? t('更新失败') : t('创建失败')); + showError( + error?.response?.data?.message || + (editingProvider ? t('更新失败') : t('创建失败')), + ); + } + }; + + const handleFetchFromDiscovery = async () => { + const cleanBaseUrl = normalizeBaseUrl(baseUrl); + const configuredWellKnown = (formValues.well_known || '').trim(); + const wellKnownUrl = + configuredWellKnown || + (cleanBaseUrl ? `${cleanBaseUrl}/.well-known/openid-configuration` : ''); + + if (!wellKnownUrl) { + showError(t('请先填写 Discovery URL 或 Issuer URL')); + return; + } + + setDiscoveryLoading(true); + try { + const res = await API.post('/api/custom-oauth-provider/discovery', { + well_known_url: configuredWellKnown || '', + issuer_url: cleanBaseUrl || '', + }); + if (!res.data.success) { + throw new Error(res.data.message || t('未知错误')); + } + const data = res.data.data?.discovery || {}; + const resolvedWellKnown = res.data.data?.well_known_url || wellKnownUrl; + + const discoveredValues = { + well_known: resolvedWellKnown, + }; + const autoFilledFields = []; + if (data.authorization_endpoint) { + discoveredValues.authorization_endpoint = data.authorization_endpoint; + autoFilledFields.push('authorization_endpoint'); + } + if (data.token_endpoint) { + discoveredValues.token_endpoint = data.token_endpoint; + autoFilledFields.push('token_endpoint'); + } + if (data.userinfo_endpoint) { + discoveredValues.user_info_endpoint = data.userinfo_endpoint; + autoFilledFields.push('user_info_endpoint'); + } + + const scopesSupported = Array.isArray(data.scopes_supported) + ? data.scopes_supported + : []; + if (scopesSupported.length > 0 && !formValues.scopes) { + const preferredScopes = ['openid', 'profile', 'email'].filter((scope) => + scopesSupported.includes(scope), + ); + discoveredValues.scopes = + preferredScopes.length > 0 + ? preferredScopes.join(' ') + : scopesSupported.slice(0, 5).join(' '); + autoFilledFields.push('scopes'); + } + + const claimsSupported = Array.isArray(data.claims_supported) + ? data.claims_supported + : []; + const claimMap = { + user_id_field: 'sub', + username_field: 'preferred_username', + display_name_field: 'name', + email_field: 'email', + }; + Object.entries(claimMap).forEach(([field, claim]) => { + if (!formValues[field] && claimsSupported.includes(claim)) { + discoveredValues[field] = claim; + autoFilledFields.push(field); + } + }); + + const hasCoreEndpoint = + discoveredValues.authorization_endpoint || + discoveredValues.token_endpoint || + discoveredValues.user_info_endpoint; + if (!hasCoreEndpoint) { + showError(t('未在 Discovery 响应中找到可用的 OAuth 端点')); + return; + } + + mergeFormValues(discoveredValues); + setDiscoveryInfo({ + wellKnown: wellKnownUrl, + autoFilledFields, + scopesSupported: scopesSupported.slice(0, 12), + claimsSupported: claimsSupported.slice(0, 12), + }); + showSuccess(t('已从 Discovery 自动填充配置')); + } catch (error) { + showError( + t('获取 Discovery 配置失败:') + (error?.message || t('未知错误')), + ); + } finally { + setDiscoveryLoading(false); } }; const handlePresetChange = (preset) => { setSelectedPreset(preset); - if (preset && OAUTH_PRESETS[preset]) { - const presetConfig = OAUTH_PRESETS[preset]; - const cleanUrl = baseUrl ? baseUrl.replace(/\/+$/, '') : ''; - const newValues = { - name: presetConfig.name, - slug: preset, - scopes: presetConfig.scopes, - user_id_field: presetConfig.user_id_field, - username_field: presetConfig.username_field, - display_name_field: presetConfig.display_name_field, - email_field: presetConfig.email_field, - auth_style: presetConfig.auth_style ?? 0, - }; - // Only fill endpoints if server address is provided - if (cleanUrl) { - newValues.authorization_endpoint = cleanUrl + presetConfig.authorization_endpoint; - newValues.token_endpoint = cleanUrl + presetConfig.token_endpoint; - newValues.user_info_endpoint = cleanUrl + presetConfig.user_info_endpoint; - } - setFormValues((prev) => ({ ...prev, ...newValues })); - // Update form fields directly via formApi - if (formApiRef.current) { - Object.entries(newValues).forEach(([key, value]) => { - formApiRef.current.setValue(key, value); - }); - } + resetDiscoveryState(); + const cleanUrl = normalizeBaseUrl(baseUrl); + if (!preset || !OAUTH_PRESETS[preset]) { + mergeFormValues(PRESET_RESET_VALUES); + return; } + + const presetConfig = OAUTH_PRESETS[preset]; + const newValues = { + ...PRESET_RESET_VALUES, + name: presetConfig.name, + slug: preset, + icon: getPresetIcon(preset), + scopes: presetConfig.scopes, + user_id_field: presetConfig.user_id_field, + username_field: presetConfig.username_field, + display_name_field: presetConfig.display_name_field, + email_field: presetConfig.email_field, + auth_style: presetConfig.auth_style ?? 0, + }; + if (cleanUrl) { + newValues.authorization_endpoint = + cleanUrl + presetConfig.authorization_endpoint; + newValues.token_endpoint = cleanUrl + presetConfig.token_endpoint; + newValues.user_info_endpoint = cleanUrl + presetConfig.user_info_endpoint; + } + mergeFormValues(newValues); }; const handleBaseUrlChange = (url) => { setBaseUrl(url); if (url && selectedPreset && OAUTH_PRESETS[selectedPreset]) { const presetConfig = OAUTH_PRESETS[selectedPreset]; - const cleanUrl = url.replace(/\/+$/, ''); // Remove trailing slashes + const cleanUrl = normalizeBaseUrl(url); const newValues = { authorization_endpoint: cleanUrl + presetConfig.authorization_endpoint, token_endpoint: cleanUrl + presetConfig.token_endpoint, user_info_endpoint: cleanUrl + presetConfig.user_info_endpoint, }; - setFormValues((prev) => ({ ...prev, ...newValues })); - // Update form fields directly via formApi (use merge mode to preserve other fields) - if (formApiRef.current) { - Object.entries(newValues).forEach(([key, value]) => { - formApiRef.current.setValue(key, value); - }); - } + mergeFormValues(newValues); } }; + const applyAccessPolicyTemplate = (templateKey) => { + const template = ACCESS_POLICY_TEMPLATES[templateKey]; + if (!template) return; + mergeFormValues({ access_policy: template }); + showSuccess(t('已填充策略模板')); + }; + + const applyDeniedTemplate = (templateKey) => { + const template = ACCESS_DENIED_TEMPLATES[templateKey]; + if (!template) return; + mergeFormValues({ access_denied_message: template }); + showSuccess(t('已填充提示模板')); + }; + const columns = [ + { + title: t('图标'), + dataIndex: 'icon', + key: 'icon', + width: 80, + render: (icon) => getOAuthProviderIcon(icon || '', 18), + }, { title: t('名称'), dataIndex: 'name', @@ -325,7 +561,10 @@ const CustomOAuthSetting = ({ serverAddress }) => { title: t('Client ID'), dataIndex: 'client_id', key: 'client_id', - render: (id) => (id ? id.substring(0, 20) + '...' : '-'), + render: (id) => { + if (!id) return '-'; + return id.length > 20 ? `${id.substring(0, 20)}...` : id; + }, }, { title: t('操作'), @@ -352,6 +591,10 @@ const CustomOAuthSetting = ({ serverAddress }) => { }, ]; + const discoveryAutoFilledLabels = (discoveryInfo?.autoFilledFields || []) + .map((field) => DISCOVERY_FIELD_LABELS[field] || field) + .join(', '); + return ( @@ -391,56 +634,142 @@ const CustomOAuthSetting = ({ serverAddress }) => { setModalVisible(false)} - okText={t('保存')} - cancelText={t('取消')} - width={800} + onCancel={closeModal} + width={860} + centered + bodyStyle={{ maxHeight: '72vh', overflowY: 'auto', paddingRight: 6 }} + footer={ +
+ + {t('启用供应商')} + mergeFormValues({ enabled: !!checked })} + /> + + {formValues.enabled ? t('已启用') : t('已禁用')} + + + + +
+ } >
setFormValues(values)} + onValueChange={() => { + setFormValues((prev) => ({ ...prev, ...getLatestFormValues() })); + }} getFormApi={(api) => (formApiRef.current = api)} > - {!editingProvider && ( - - - ({ - value: key, - label: config.name, - })), - ]} - /> - - - - - + + {t('Configuration')} + + + {t('先填写配置,再自动填充 OAuth 端点,能显著减少手工输入')} + + {discoveryInfo && ( + +
+ {t('已从 Discovery 获取配置,可继续手动修改所有字段。')} +
+ {discoveryAutoFilledLabels ? ( +
+ {t('自动填充字段')}: + {' '} + {discoveryAutoFilledLabels} +
+ ) : null} + {discoveryInfo.scopesSupported?.length ? ( +
+ {t('Discovery scopes')}: + {' '} + {discoveryInfo.scopesSupported.join(', ')} +
+ ) : null} + {discoveryInfo.claimsSupported?.length ? ( +
+ {t('Discovery claims')}: + {' '} + {discoveryInfo.claimsSupported.join(', ')} +
+ ) : null} +
+ } + /> )} + + + ({ + value: key, + label: config.name, + })), + ]} + /> + + + + + +
+ +
+ +
+ + + + + + { + + + + {t( + '图标使用 react-icons(Simple Icons)或 URL/emoji,例如:github、gitlab、si:google', + )} + + } + showClear + /> + + +
+ {getOAuthProviderIcon(formValues.icon || '', 24)} +
+ +
+ { label={t('Authorization Endpoint')} placeholder={ selectedPreset && OAUTH_PRESETS[selectedPreset] - ? t('填写服务器地址后自动生成:') + + ? t('填写 Issuer URL 后自动生成:') + OAUTH_PRESETS[selectedPreset].authorization_endpoint : 'https://example.com/oauth/authorize' } @@ -544,15 +908,14 @@ const CustomOAuthSetting = ({ serverAddress }) => { - - - @@ -568,7 +931,7 @@ const CustomOAuthSetting = ({ serverAddress }) => { @@ -576,7 +939,7 @@ const CustomOAuthSetting = ({ serverAddress }) => { @@ -586,41 +949,100 @@ const CustomOAuthSetting = ({ serverAddress }) => { - - {t('高级选项')} - + { + const keys = Array.isArray(activeKey) ? activeKey : [activeKey]; + setAdvancedActiveKeys(keys.filter(Boolean)); + }} + > + + + + + + - - - - - - - {t('启用此 OAuth 提供商')} - - - + + {t('准入策略')} + + + {t('可选:基于用户信息 JSON 做组合条件准入,条件不满足时返回自定义提示')} + + + + mergeFormValues({ access_policy: value })} + label={t('准入策略 JSON(可选)')} + rows={6} + placeholder={`{ + "logic": "and", + "conditions": [ + {"field": "trust_level", "op": "gte", "value": 2}, + {"field": "active", "op": "eq", "value": true} + ] +}`} + extraText={t('支持逻辑 and/or 与嵌套 groups;操作符支持 eq/ne/gt/gte/lt/lte/in/not_in/contains/exists')} + showClear + /> + + + + + + + + + mergeFormValues({ access_denied_message: value })} + label={t('拒绝提示模板(可选)')} + placeholder={t('例如:需要等级 {{required}},你当前等级 {{current}}')} + extraText={t('可用变量:{{provider}} {{field}} {{op}} {{required}} {{current}} 以及 {{current.path}}')} + showClear + /> + + + + + + + + diff --git a/web/src/components/settings/personal/cards/AccountManagement.jsx b/web/src/components/settings/personal/cards/AccountManagement.jsx index bc27630ba..29249caa1 100644 --- a/web/src/components/settings/personal/cards/AccountManagement.jsx +++ b/web/src/components/settings/personal/cards/AccountManagement.jsx @@ -50,6 +50,7 @@ import { onLinuxDOOAuthClicked, onDiscordOAuthClicked, onCustomOAuthClicked, + getOAuthProviderIcon, } from '../../../../helpers'; import TwoFASetting from '../components/TwoFASetting'; @@ -148,12 +149,14 @@ const AccountManagement = ({ // Check if custom OAuth provider is bound const isCustomOAuthBound = (providerId) => { - return customOAuthBindings.some((b) => b.provider_id === providerId); + const normalizedId = Number(providerId); + return customOAuthBindings.some((b) => Number(b.provider_id) === normalizedId); }; // Get binding info for a provider const getCustomOAuthBinding = (providerId) => { - return customOAuthBindings.find((b) => b.provider_id === providerId); + const normalizedId = Number(providerId); + return customOAuthBindings.find((b) => Number(b.provider_id) === normalizedId); }; React.useEffect(() => { @@ -524,10 +527,10 @@ const AccountManagement = ({
- + {getOAuthProviderIcon( + provider.icon || binding?.provider_icon || '', + 20, + )}
diff --git a/web/src/helpers/render.jsx b/web/src/helpers/render.jsx index ecc252cfd..3ba198cb3 100644 --- a/web/src/helpers/render.jsx +++ b/web/src/helpers/render.jsx @@ -76,6 +76,31 @@ import { Server, CalendarClock, } from 'lucide-react'; +import { + SiAtlassian, + SiAuth0, + SiAuthentik, + SiBitbucket, + SiDiscord, + SiDropbox, + SiFacebook, + SiGitea, + SiGithub, + SiGitlab, + SiGoogle, + SiKeycloak, + SiLinkedin, + SiNextcloud, + SiNotion, + SiOkta, + SiOpenid, + SiReddit, + SiSlack, + SiTelegram, + SiTwitch, + SiWechat, + SiX, +} from 'react-icons/si'; // 获取侧边栏Lucide图标组件 export function getLucideIcon(key, selected = false) { @@ -472,6 +497,106 @@ export function getLobeHubIcon(iconName, size = 14) { return ; } +const oauthProviderIconMap = { + github: SiGithub, + gitlab: SiGitlab, + gitea: SiGitea, + google: SiGoogle, + discord: SiDiscord, + facebook: SiFacebook, + linkedin: SiLinkedin, + x: SiX, + twitter: SiX, + slack: SiSlack, + telegram: SiTelegram, + wechat: SiWechat, + keycloak: SiKeycloak, + nextcloud: SiNextcloud, + authentik: SiAuthentik, + openid: SiOpenid, + okta: SiOkta, + auth0: SiAuth0, + atlassian: SiAtlassian, + bitbucket: SiBitbucket, + notion: SiNotion, + twitch: SiTwitch, + reddit: SiReddit, + dropbox: SiDropbox, +}; + +function isHttpUrl(value) { + return /^https?:\/\//i.test(value || ''); +} + +function isSimpleEmoji(value) { + if (!value) return false; + const trimmed = String(value).trim(); + return trimmed.length > 0 && trimmed.length <= 4 && !isHttpUrl(trimmed); +} + +function normalizeOAuthIconKey(raw) { + return raw + .trim() + .toLowerCase() + .replace(/^ri:/, '') + .replace(/^react-icons:/, '') + .replace(/^si:/, ''); +} + +/** + * Render custom OAuth provider icon with react-icons or URL/emoji fallback. + * Supported formats: + * - react-icons simple key: github / gitlab / google / keycloak + * - prefixed key: ri:github / si:github + * - full URL image: https://example.com/logo.png + * - emoji: 🐱 + */ +export function getOAuthProviderIcon(iconName, size = 20) { + const raw = String(iconName || '').trim(); + const iconSize = Number(size) > 0 ? Number(size) : 20; + + if (!raw) { + return ; + } + + if (isHttpUrl(raw)) { + return ( + provider icon + ); + } + + if (isSimpleEmoji(raw)) { + return ( + + {raw} + + ); + } + + const key = normalizeOAuthIconKey(raw); + const IconComp = oauthProviderIconMap[key]; + if (IconComp) { + return ; + } + + return {raw.charAt(0).toUpperCase()}; +} + // 颜色列表 const colors = [ 'amber', From 9e3954428dc0bf6bf5c29eed415d8b213affad22 Mon Sep 17 00:00:00 2001 From: CaIon Date: Tue, 10 Feb 2026 20:40:33 +0800 Subject: [PATCH 10/41] refactor(task): extract billing and polling logic from controller to service layer Restructure the task relay system for better separation of concerns: - Extract task billing into service/task_billing.go with unified settlement flow - Move task polling loop from controller to service/task_polling.go (supports Suno + video platforms) - Split RelayTask into fetch/submit paths with dedicated retry logic (taskSubmitWithRetry) - Add TaskDto, TaskResponse generics, and FetchReq to dto/task.go - Add taskcommon/helpers.go for shared task adaptor utilities - Remove controller/task_video.go (logic consolidated into service layer) - Update all task adaptors (ali, doubao, gemini, hailuo, jimeng, kling, sora, suno, vertex, vidu) - Simplify frontend task logs to use new TaskDto response format --- controller/relay.go | 122 +++- controller/task.go | 228 +------ controller/task_video.go | 313 ---------- controller/video_proxy.go | 111 +--- controller/video_proxy_gemini.go | 8 +- dto/suno.go | 32 - dto/task.go | 47 ++ main.go | 10 + middleware/auth.go | 18 + model/task.go | 57 +- model/token.go | 6 +- relay/channel/task/ali/adaptor.go | 3 +- relay/channel/task/doubao/adaptor.go | 24 +- relay/channel/task/gemini/adaptor.go | 47 +- relay/channel/task/hailuo/adaptor.go | 15 +- relay/channel/task/jimeng/adaptor.go | 27 +- relay/channel/task/kling/adaptor.go | 43 +- relay/channel/task/sora/adaptor.go | 24 +- relay/channel/task/suno/adaptor.go | 29 +- relay/channel/task/taskcommon/helpers.go | 70 +++ relay/channel/task/vertex/adaptor.go | 50 +- relay/channel/task/vidu/adaptor.go | 45 +- relay/common/relay_info.go | 15 +- relay/helper/price.go | 15 +- relay/relay_task.go | 576 +++++++++--------- router/video-router.go | 8 +- service/billing_session.go | 5 + service/error.go | 13 + service/log_info_generate.go | 2 +- service/task_billing.go | 227 +++++++ service/task_polling.go | 446 ++++++++++++++ types/price_data.go | 9 +- .../table/task-logs/TaskLogsColumnDefs.jsx | 9 +- .../table/task-logs/modals/ContentModal.jsx | 2 - 34 files changed, 1465 insertions(+), 1191 deletions(-) delete mode 100644 controller/task_video.go create mode 100644 relay/channel/task/taskcommon/helpers.go create mode 100644 service/task_billing.go create mode 100644 service/task_polling.go diff --git a/controller/relay.go b/controller/relay.go index 0b30e6e9e..132fee9ba 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -451,17 +451,102 @@ func RelayNotFound(c *gin.Context) { } func RelayTask(c *gin.Context) { - retryTimes := common.RetryTimes channelId := c.GetInt("channel_id") c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil) if err != nil { + c.JSON(http.StatusInternalServerError, &dto.TaskError{ + Code: "gen_relay_info_failed", + Message: err.Error(), + StatusCode: http.StatusInternalServerError, + }) return } - taskErr := taskRelayHandler(c, relayInfo) - if taskErr == nil { - retryTimes = 0 + + // Fetch 操作是纯 DB 查询(或 task 自带 channelId 的上游查询),不依赖上下文 channel,无需重试 + // TODO: 在video-route层面优化,避免无谓的 channel 选择和上下文设置,也没必要吧代码放到这里来写这么多屎山 + switch relayInfo.RelayMode { + case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID: + if taskErr := relay.RelayTaskFetch(c, relayInfo.RelayMode); taskErr != nil { + respondTaskError(c, taskErr) + } + return } + + // ── Submit 路径 ───────────────────────────────────────────────── + + // 1. 解析原始任务(remix / continuation),一次性,可能锁定渠道并禁止重试 + if taskErr := relay.ResolveOriginTask(c, relayInfo); taskErr != nil { + respondTaskError(c, taskErr) + return + } + + // 2. defer Refund(全部失败时回滚预扣费) + var result *relay.TaskSubmitResult + var taskErr *dto.TaskError + defer func() { + if taskErr != nil && relayInfo.Billing != nil { + relayInfo.Billing.Refund(c) + } + }() + + // 3. 执行 + 重试(RelayTaskSubmit 内部在首次调用时自动预扣费) + taskErr = taskSubmitWithRetry(c, relayInfo, channelId, common.RetryTimes, func() *dto.TaskError { + var te *dto.TaskError + result, te = relay.RelayTaskSubmit(c, relayInfo) + return te + }) + + // 4. 成功:结算 + 日志 + 插入任务 + if taskErr == nil { + if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil { + common.SysError("settle task billing error: " + settleErr.Error()) + } + service.LogTaskConsumption(c, relayInfo, result.ModelName) + + task := model.InitTask(result.Platform, relayInfo) + task.PrivateData.UpstreamTaskID = result.UpstreamTaskID + task.PrivateData.BillingSource = relayInfo.BillingSource + task.PrivateData.SubscriptionId = relayInfo.SubscriptionId + task.PrivateData.TokenId = relayInfo.TokenId + task.Quota = result.Quota + task.Data = result.TaskData + task.Action = relayInfo.Action + if insertErr := task.Insert(); insertErr != nil { + //taskErr = service.TaskErrorWrapper(insertErr, "insert_task_failed", http.StatusInternalServerError) + common.SysError("insert task error: " + insertErr.Error()) + } + } + + if taskErr != nil { + respondTaskError(c, taskErr) + } +} + +// respondTaskError 统一输出 Task 错误响应(含 429 限流提示改写) +func respondTaskError(c *gin.Context, taskErr *dto.TaskError) { + if taskErr.StatusCode == http.StatusTooManyRequests { + taskErr.Message = "当前分组上游负载已饱和,请稍后再试" + } + c.JSON(taskErr.StatusCode, taskErr) +} + +// taskSubmitWithRetry 执行首次尝试并在失败时切换渠道重试,返回最终的 taskErr。 +// attempt 闭包负责实际的上游请求,不涉及计费。 +func taskSubmitWithRetry(c *gin.Context, relayInfo *relaycommon.RelayInfo, + channelId int, retryTimes int, attempt func() *dto.TaskError) *dto.TaskError { + + taskErr := attempt() + if taskErr == nil { + return nil + } + if !taskErr.LocalError { + processChannelError(c, + *types.NewChannelError(channelId, c.GetInt("channel_type"), c.GetString("channel_name"), common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey), + common.GetContextKeyString(c, constant.ContextKeyChannelKey), common.GetContextKeyBool(c, constant.ContextKeyChannelAutoBan)), + types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode)) + } + retryParam := &service.RetryParam{ Ctx: c, TokenGroup: relayInfo.TokenGroup, @@ -480,7 +565,7 @@ func RelayTask(c *gin.Context) { useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) c.Set("use_channel", useChannel) logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry())) - //middleware.SetupContextForSelectedChannel(c, channel, originalModel) + middleware.SetupContextForSelectedChannel(c, channel, c.GetString("original_model")) bodyStorage, err := common.GetBodyStorage(c) if err != nil { @@ -492,30 +577,21 @@ func RelayTask(c *gin.Context) { break } c.Request.Body = io.NopCloser(bodyStorage) - taskErr = taskRelayHandler(c, relayInfo) + taskErr = attempt() + if taskErr != nil && !taskErr.LocalError { + processChannelError(c, + *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, + common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), + types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode)) + } } + useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) logger.LogInfo(c, retryLogStr) } - if taskErr != nil { - if taskErr.StatusCode == http.StatusTooManyRequests { - taskErr.Message = "当前分组上游负载已饱和,请稍后再试" - } - c.JSON(taskErr.StatusCode, taskErr) - } -} - -func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError { - var err *dto.TaskError - switch relayInfo.RelayMode { - case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID: - err = relay.RelayTaskFetch(c, relayInfo.RelayMode) - default: - err = relay.RelayTaskSubmit(c, relayInfo) - } - return err + return taskErr } func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool { diff --git a/controller/task.go b/controller/task.go index 244f9161c..ec713c5d2 100644 --- a/controller/task.go +++ b/controller/task.go @@ -1,231 +1,21 @@ package controller import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "sort" "strconv" - "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" - "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay" + "github.com/QuantumNous/new-api/service" "github.com/gin-gonic/gin" - "github.com/samber/lo" ) +// UpdateTaskBulk 薄入口,实际轮询逻辑在 service 层 func UpdateTaskBulk() { - //revocer - //imageModel := "midjourney" - for { - time.Sleep(time.Duration(15) * time.Second) - common.SysLog("任务进度轮询开始") - ctx := context.TODO() - allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit) - platformTask := make(map[constant.TaskPlatform][]*model.Task) - for _, t := range allTasks { - platformTask[t.Platform] = append(platformTask[t.Platform], t) - } - for platform, tasks := range platformTask { - if len(tasks) == 0 { - continue - } - taskChannelM := make(map[int][]string) - taskM := make(map[string]*model.Task) - nullTaskIds := make([]int64, 0) - for _, task := range tasks { - if task.TaskID == "" { - // 统计失败的未完成任务 - nullTaskIds = append(nullTaskIds, task.ID) - continue - } - taskM[task.TaskID] = task - taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID) - } - if len(nullTaskIds) > 0 { - err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{ - "status": "FAILURE", - "progress": "100%", - }) - if err != nil { - logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) - } else { - logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) - } - } - if len(taskChannelM) == 0 { - continue - } - - UpdateTaskByPlatform(platform, taskChannelM, taskM) - } - common.SysLog("任务进度轮询完成") - } -} - -func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) { - switch platform { - case constant.TaskPlatformMidjourney: - //_ = UpdateMidjourneyTaskAll(context.Background(), tasks) - case constant.TaskPlatformSuno: - _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM) - default: - if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil { - common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err)) - } - } -} - -func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error { - for channelId, taskIds := range taskChannelM { - err := updateSunoTaskAll(ctx, channelId, taskIds, taskM) - if err != nil { - logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error())) - } - } - return nil -} - -func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { - logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) - if len(taskIds) == 0 { - return nil - } - channel, err := model.CacheGetChannel(channelId) - if err != nil { - common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) - err = model.TaskBulkUpdate(taskIds, map[string]any{ - "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), - "status": "FAILURE", - "progress": "100%", - }) - if err != nil { - common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) - } - return err - } - adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno) - if adaptor == nil { - return errors.New("adaptor not found") - } - proxy := channel.GetSetting().Proxy - resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{ - "ids": taskIds, - }, proxy) - if err != nil { - common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err)) - return err - } - if resp.StatusCode != http.StatusOK { - logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) - return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) - } - defer resp.Body.Close() - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err)) - return err - } - var responseItems dto.TaskResponse[[]dto.SunoDataResponse] - err = json.Unmarshal(responseBody, &responseItems) - if err != nil { - logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) - return err - } - if !responseItems.IsSuccess() { - common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody))) - return err - } - - for _, responseItem := range responseItems.Data { - task := taskM[responseItem.TaskID] - if !checkTaskNeedUpdate(task, responseItem) { - continue - } - - task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status) - task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason) - task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime) - task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime) - task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime) - if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { - logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) - task.Progress = "100%" - //err = model.CacheUpdateUserQuota(task.UserId) ? - if err != nil { - logger.LogError(ctx, "error update user quota cache: "+err.Error()) - } else { - quota := task.Quota - if quota != 0 { - err = model.IncreaseUserQuota(task.UserId, quota, false) - if err != nil { - logger.LogError(ctx, "fail to increase user quota: "+err.Error()) - } - logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - } - } - if responseItem.Status == model.TaskStatusSuccess { - task.Progress = "100%" - } - task.Data = responseItem.Data - - err = task.Update() - if err != nil { - common.SysLog("UpdateMidjourneyTask task error: " + err.Error()) - } - } - return nil -} - -func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool { - - if oldTask.SubmitTime != newTask.SubmitTime { - return true - } - if oldTask.StartTime != newTask.StartTime { - return true - } - if oldTask.FinishTime != newTask.FinishTime { - return true - } - if string(oldTask.Status) != newTask.Status { - return true - } - if oldTask.FailReason != newTask.FailReason { - return true - } - if oldTask.FinishTime != newTask.FinishTime { - return true - } - - if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" { - return true - } - - oldData, _ := json.Marshal(oldTask.Data) - newData, _ := json.Marshal(newTask.Data) - - sort.Slice(oldData, func(i, j int) bool { - return oldData[i] < oldData[j] - }) - sort.Slice(newData, func(i, j int) bool { - return newData[i] < newData[j] - }) - - if string(oldData) != string(newData) { - return true - } - return false + service.TaskPollingLoop() } func GetAllTask(c *gin.Context) { @@ -247,7 +37,7 @@ func GetAllTask(c *gin.Context) { items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams) total := model.TaskCountAllTasks(queryParams) pageInfo.SetTotal(int(total)) - pageInfo.SetItems(items) + pageInfo.SetItems(tasksToDto(items)) common.ApiSuccess(c, pageInfo) } @@ -271,6 +61,14 @@ func GetUserTask(c *gin.Context) { items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams) total := model.TaskCountAllUserTask(userId, queryParams) pageInfo.SetTotal(int(total)) - pageInfo.SetItems(items) + pageInfo.SetItems(tasksToDto(items)) common.ApiSuccess(c, pageInfo) } + +func tasksToDto(tasks []*model.Task) []*dto.TaskDto { + result := make([]*dto.TaskDto, len(tasks)) + for i, task := range tasks { + result[i] = relay.TaskModel2Dto(task) + } + return result +} diff --git a/controller/task_video.go b/controller/task_video.go deleted file mode 100644 index d7c19e620..000000000 --- a/controller/task_video.go +++ /dev/null @@ -1,313 +0,0 @@ -package controller - -import ( - "context" - "encoding/json" - "fmt" - "io" - "time" - - "github.com/QuantumNous/new-api/common" - "github.com/QuantumNous/new-api/constant" - "github.com/QuantumNous/new-api/dto" - "github.com/QuantumNous/new-api/logger" - "github.com/QuantumNous/new-api/model" - "github.com/QuantumNous/new-api/relay" - "github.com/QuantumNous/new-api/relay/channel" - relaycommon "github.com/QuantumNous/new-api/relay/common" - "github.com/QuantumNous/new-api/setting/ratio_setting" -) - -func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error { - for channelId, taskIds := range taskChannelM { - if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil { - logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) - } - } - return nil -} - -func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error { - logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) - if len(taskIds) == 0 { - return nil - } - cacheGetChannel, err := model.CacheGetChannel(channelId) - if err != nil { - errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{ - "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId), - "status": "FAILURE", - "progress": "100%", - }) - if errUpdate != nil { - common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) - } - return fmt.Errorf("CacheGetChannel failed: %w", err) - } - adaptor := relay.GetTaskAdaptor(platform) - if adaptor == nil { - return fmt.Errorf("video adaptor not found") - } - info := &relaycommon.RelayInfo{} - info.ChannelMeta = &relaycommon.ChannelMeta{ - ChannelBaseUrl: cacheGetChannel.GetBaseURL(), - } - info.ApiKey = cacheGetChannel.Key - adaptor.Init(info) - for _, taskId := range taskIds { - if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { - logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) - } - } - return nil -} - -func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error { - baseURL := constant.ChannelBaseURLs[channel.Type] - if channel.GetBaseURL() != "" { - baseURL = channel.GetBaseURL() - } - proxy := channel.GetSetting().Proxy - - task := taskM[taskId] - if task == nil { - logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) - return fmt.Errorf("task %s not found", taskId) - } - key := channel.Key - - privateData := task.PrivateData - if privateData.Key != "" { - key = privateData.Key - } - resp, err := adaptor.FetchTask(baseURL, key, map[string]any{ - "task_id": taskId, - "action": task.Action, - }, proxy) - if err != nil { - return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err) - } - //if resp.StatusCode != http.StatusOK { - //return fmt.Errorf("get Video Task status code: %d", resp.StatusCode) - //} - defer resp.Body.Close() - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("readAll failed for task %s: %w", taskId, err) - } - - logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody))) - - taskResult := &relaycommon.TaskInfo{} - // try parse as New API response format - var responseItems dto.TaskResponse[model.Task] - if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() { - logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems)) - t := responseItems.Data - taskResult.TaskID = t.TaskID - taskResult.Status = string(t.Status) - taskResult.Url = t.FailReason - taskResult.Progress = t.Progress - taskResult.Reason = t.FailReason - task.Data = t.Data - } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil { - return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err) - } else { - task.Data = redactVideoResponseBody(responseBody) - } - - logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult)) - - now := time.Now().Unix() - if taskResult.Status == "" { - //return fmt.Errorf("task %s status is empty", taskId) - taskResult = relaycommon.FailTaskInfo("upstream returned empty status") - } - - // 记录原本的状态,防止重复退款 - shouldRefund := false - quota := task.Quota - preStatus := task.Status - - task.Status = model.TaskStatus(taskResult.Status) - switch taskResult.Status { - case model.TaskStatusSubmitted: - task.Progress = "10%" - case model.TaskStatusQueued: - task.Progress = "20%" - case model.TaskStatusInProgress: - task.Progress = "30%" - if task.StartTime == 0 { - task.StartTime = now - } - case model.TaskStatusSuccess: - task.Progress = "100%" - if task.FinishTime == 0 { - task.FinishTime = now - } - if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") { - task.FailReason = taskResult.Url - } - - // 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费 - if taskResult.TotalTokens > 0 { - // 获取模型名称 - var taskData map[string]interface{} - if err := json.Unmarshal(task.Data, &taskData); err == nil { - if modelName, ok := taskData["model"].(string); ok && modelName != "" { - // 获取模型价格和倍率 - modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName) - // 只有配置了倍率(非固定价格)时才按 token 重新计费 - if hasRatioSetting && modelRatio > 0 { - // 获取用户和组的倍率信息 - group := task.Group - if group == "" { - user, err := model.GetUserById(task.UserId, false) - if err == nil { - group = user.Group - } - } - if group != "" { - groupRatio := ratio_setting.GetGroupRatio(group) - userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group) - - var finalGroupRatio float64 - if hasUserGroupRatio { - finalGroupRatio = userGroupRatio - } else { - finalGroupRatio = groupRatio - } - - // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio - actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio) - - // 计算差额 - preConsumedQuota := task.Quota - quotaDelta := actualQuota - preConsumedQuota - - if quotaDelta > 0 { - // 需要补扣费 - logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)", - task.TaskID, - logger.LogQuota(quotaDelta), - logger.LogQuota(actualQuota), - logger.LogQuota(preConsumedQuota), - taskResult.TotalTokens, - )) - if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil { - logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error())) - } else { - model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) - model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) - task.Quota = actualQuota // 更新任务记录的实际扣费额度 - - // 记录消费日志 - logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,补扣费 %s", - modelRatio, finalGroupRatio, taskResult.TotalTokens, - logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(quotaDelta)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - } else if quotaDelta < 0 { - // 需要退还多扣的费用 - refundQuota := -quotaDelta - logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)", - task.TaskID, - logger.LogQuota(refundQuota), - logger.LogQuota(actualQuota), - logger.LogQuota(preConsumedQuota), - taskResult.TotalTokens, - )) - if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil { - logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error())) - } else { - task.Quota = actualQuota // 更新任务记录的实际扣费额度 - - // 记录退款日志 - logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,退还 %s", - modelRatio, finalGroupRatio, taskResult.TotalTokens, - logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(refundQuota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - } else { - // quotaDelta == 0, 预扣费刚好准确 - logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)", - task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens)) - } - } - } - } - } - } - case model.TaskStatusFailure: - logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task) - task.Status = model.TaskStatusFailure - task.Progress = "100%" - if task.FinishTime == 0 { - task.FinishTime = now - } - task.FailReason = taskResult.Reason - logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) - taskResult.Progress = "100%" - if quota != 0 { - if preStatus != model.TaskStatusFailure { - shouldRefund = true - } else { - logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID)) - } - } - default: - return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId) - } - if taskResult.Progress != "" { - task.Progress = taskResult.Progress - } - if err := task.Update(); err != nil { - common.SysLog("UpdateVideoTask task error: " + err.Error()) - shouldRefund = false - } - - if shouldRefund { - // 任务失败且之前状态不是失败才退还额度,防止重复退还 - if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil { - logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error()) - } - logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - - return nil -} - -func redactVideoResponseBody(body []byte) []byte { - var m map[string]any - if err := json.Unmarshal(body, &m); err != nil { - return body - } - resp, _ := m["response"].(map[string]any) - if resp != nil { - delete(resp, "bytesBase64Encoded") - if v, ok := resp["video"].(string); ok { - resp["video"] = truncateBase64(v) - } - if vs, ok := resp["videos"].([]any); ok { - for i := range vs { - if vm, ok := vs[i].(map[string]any); ok { - delete(vm, "bytesBase64Encoded") - } - } - } - } - b, err := json.Marshal(m) - if err != nil { - return body - } - return b -} - -func truncateBase64(s string) string { - const maxKeep = 256 - if len(s) <= maxKeep { - return s - } - return s[:maxKeep] + "..." -} diff --git a/controller/video_proxy.go b/controller/video_proxy.go index f102baae4..f1dd2bc92 100644 --- a/controller/video_proxy.go +++ b/controller/video_proxy.go @@ -16,59 +16,44 @@ import ( "github.com/gin-gonic/gin" ) +// videoProxyError returns a standardized OpenAI-style error response. +func videoProxyError(c *gin.Context, status int, errType, message string) { + c.JSON(status, gin.H{ + "error": gin.H{ + "message": message, + "type": errType, + }, + }) +} + func VideoProxy(c *gin.Context) { taskID := c.Param("task_id") if taskID == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "task_id is required", - "type": "invalid_request_error", - }, - }) + videoProxyError(c, http.StatusBadRequest, "invalid_request_error", "task_id is required") return } task, exists, err := model.GetByOnlyTaskId(taskID) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error())) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to query task", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to query task") return } if !exists || task == nil { - logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %v", taskID, err)) - c.JSON(http.StatusNotFound, gin.H{ - "error": gin.H{ - "message": "Task not found", - "type": "invalid_request_error", - }, - }) + videoProxyError(c, http.StatusNotFound, "invalid_request_error", "Task not found") return } if task.Status != model.TaskStatusSuccess { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status), - "type": "invalid_request_error", - }, - }) + videoProxyError(c, http.StatusBadRequest, "invalid_request_error", + fmt.Sprintf("Task is not completed yet, current status: %s", task.Status)) return } channel, err := model.CacheGetChannel(task.ChannelId) if err != nil { - logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: not found", taskID)) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to retrieve channel information", - "type": "server_error", - }, - }) + logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get channel for task %s: %s", taskID, err.Error())) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to retrieve channel information") return } baseURL := channel.GetBaseURL() @@ -81,12 +66,7 @@ func VideoProxy(c *gin.Context) { client, err := service.GetHttpClientWithProxy(proxy) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error())) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to create proxy client", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy client") return } @@ -95,12 +75,7 @@ func VideoProxy(c *gin.Context) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error())) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to create proxy request", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request") return } @@ -109,68 +84,43 @@ func VideoProxy(c *gin.Context) { apiKey := task.PrivateData.Key if apiKey == "" { logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID)) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "API key not stored for task", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "API key not stored for task") return } - videoURL, err = getGeminiVideoURL(channel, task, apiKey) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Gemini video URL for task %s: %s", taskID, err.Error())) - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "message": "Failed to resolve Gemini video URL", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to resolve Gemini video URL") return } req.Header.Set("x-goog-api-key", apiKey) case constant.ChannelTypeOpenAI, constant.ChannelTypeSora: - videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID) + videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.GetUpstreamTaskID()) req.Header.Set("Authorization", "Bearer "+channel.Key) default: - // Video URL is directly in task.FailReason - videoURL = task.FailReason + // Video URL is stored in PrivateData.ResultURL (fallback to FailReason for old data) + videoURL = task.GetResultURL() } req.URL, err = url.Parse(videoURL) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error())) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to create proxy request", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request") return } resp, err := client.Do(req) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error())) - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "message": "Failed to fetch video content", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content") return } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL)) - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode), - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusBadGateway, "server_error", + fmt.Sprintf("Upstream service returned status %d", resp.StatusCode)) return } @@ -180,10 +130,9 @@ func VideoProxy(c *gin.Context) { } } - c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours + c.Writer.Header().Set("Cache-Control", "public, max-age=86400") c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { + if _, err = io.Copy(c.Writer, resp.Body); err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error())) } } diff --git a/controller/video_proxy_gemini.go b/controller/video_proxy_gemini.go index 053ac6515..a63a2a5c4 100644 --- a/controller/video_proxy_gemini.go +++ b/controller/video_proxy_gemini.go @@ -1,12 +1,12 @@ package controller import ( - "encoding/json" "fmt" "io" "strconv" "strings" + "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay" @@ -37,7 +37,7 @@ func getGeminiVideoURL(channel *model.Channel, task *model.Task, apiKey string) proxy := channel.GetSetting().Proxy resp, err := adaptor.FetchTask(baseURL, apiKey, map[string]any{ - "task_id": task.TaskID, + "task_id": task.GetUpstreamTaskID(), "action": task.Action, }, proxy) if err != nil { @@ -71,7 +71,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string { return "" } var payload map[string]any - if err := json.Unmarshal(task.Data, &payload); err != nil { + if err := common.Unmarshal(task.Data, &payload); err != nil { return "" } return extractGeminiVideoURLFromMap(payload) @@ -79,7 +79,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string { func extractGeminiVideoURLFromPayload(body []byte) string { var payload map[string]any - if err := json.Unmarshal(body, &payload); err != nil { + if err := common.Unmarshal(body, &payload); err != nil { return "" } return extractGeminiVideoURLFromMap(payload) diff --git a/dto/suno.go b/dto/suno.go index a6bb3ebae..90e11b810 100644 --- a/dto/suno.go +++ b/dto/suno.go @@ -4,10 +4,6 @@ import ( "encoding/json" ) -type TaskData interface { - SunoDataResponse | []SunoDataResponse | string | any -} - type SunoSubmitReq struct { GptDescriptionPrompt string `json:"gpt_description_prompt,omitempty"` Prompt string `json:"prompt,omitempty"` @@ -20,10 +16,6 @@ type SunoSubmitReq struct { MakeInstrumental bool `json:"make_instrumental"` } -type FetchReq struct { - IDs []string `json:"ids"` -} - type SunoDataResponse struct { TaskID string `json:"task_id" gorm:"type:varchar(50);index"` Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode @@ -66,30 +58,6 @@ type SunoLyrics struct { Text string `json:"text"` } -const TaskSuccessCode = "success" - -type TaskResponse[T TaskData] struct { - Code string `json:"code"` - Message string `json:"message"` - Data T `json:"data"` -} - -func (t *TaskResponse[T]) IsSuccess() bool { - return t.Code == TaskSuccessCode -} - -type TaskDto struct { - TaskID string `json:"task_id"` // 第三方id,不一定有/ song id\ Task id - Action string `json:"action"` // 任务类型, song, lyrics, description-mode - Status string `json:"status"` // 任务状态, submitted, queueing, processing, success, failed - FailReason string `json:"fail_reason"` - SubmitTime int64 `json:"submit_time"` - StartTime int64 `json:"start_time"` - FinishTime int64 `json:"finish_time"` - Progress string `json:"progress"` - Data json.RawMessage `json:"data"` -} - type SunoGoAPISubmitReq struct { CustomMode bool `json:"custom_mode"` diff --git a/dto/task.go b/dto/task.go index afc186b41..4a9a8e2e6 100644 --- a/dto/task.go +++ b/dto/task.go @@ -1,5 +1,9 @@ package dto +import ( + "encoding/json" +) + type TaskError struct { Code string `json:"code"` Message string `json:"message"` @@ -8,3 +12,46 @@ type TaskError struct { LocalError bool `json:"-"` Error error `json:"-"` } + +type TaskData interface { + SunoDataResponse | []SunoDataResponse | string | any +} + +const TaskSuccessCode = "success" + +type TaskResponse[T TaskData] struct { + Code string `json:"code"` + Message string `json:"message"` + Data T `json:"data"` +} + +func (t *TaskResponse[T]) IsSuccess() bool { + return t.Code == TaskSuccessCode +} + +type TaskDto struct { + ID int64 `json:"id"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + TaskID string `json:"task_id"` + Platform string `json:"platform"` + UserId int `json:"user_id"` + Group string `json:"group"` + ChannelId int `json:"channel_id"` + Quota int `json:"quota"` + Action string `json:"action"` + Status string `json:"status"` + FailReason string `json:"fail_reason"` + ResultURL string `json:"result_url,omitempty"` // 任务结果 URL(视频地址等) + SubmitTime int64 `json:"submit_time"` + StartTime int64 `json:"start_time"` + FinishTime int64 `json:"finish_time"` + Progress string `json:"progress"` + Properties any `json:"properties"` + Username string `json:"username,omitempty"` + Data json.RawMessage `json:"data"` +} + +type FetchReq struct { + IDs []string `json:"ids"` +} diff --git a/main.go b/main.go index 852e1a0a8..476a2ed24 100644 --- a/main.go +++ b/main.go @@ -19,6 +19,7 @@ import ( "github.com/QuantumNous/new-api/middleware" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/oauth" + "github.com/QuantumNous/new-api/relay" "github.com/QuantumNous/new-api/router" "github.com/QuantumNous/new-api/service" _ "github.com/QuantumNous/new-api/setting/performance_setting" @@ -111,6 +112,15 @@ func main() { // Subscription quota reset task (daily/weekly/monthly/custom) service.StartSubscriptionQuotaResetTask() + // Wire task polling adaptor factory (breaks service -> relay import cycle) + service.GetTaskAdaptorFunc = func(platform constant.TaskPlatform) service.TaskPollingAdaptor { + a := relay.GetTaskAdaptor(platform) + if a == nil { + return nil + } + return a + } + if common.IsMasterNode && constant.UpdateTask { gopool.Go(func() { controller.UpdateMidjourneyTaskBulk() diff --git a/middleware/auth.go b/middleware/auth.go index cf1843510..342e7f498 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -170,6 +170,24 @@ func WssAuth(c *gin.Context) { } +// TokenOrUserAuth allows either session-based user auth or API token auth. +// Used for endpoints that need to be accessible from both the dashboard and API clients. +func TokenOrUserAuth() func(c *gin.Context) { + return func(c *gin.Context) { + // Try session auth first (dashboard users) + session := sessions.Default(c) + if id := session.Get("id"); id != nil { + if status, ok := session.Get("status").(int); ok && status == common.UserStatusEnabled { + c.Set("id", id) + c.Next() + return + } + } + // Fall back to token auth (API clients) + TokenAuth()(c) + } +} + // TokenAuthReadOnly 宽松版本的令牌认证中间件,用于只读查询接口。 // 只验证令牌 key 是否存在,不检查令牌状态、过期时间和额度。 // 即使令牌已过期、已耗尽或已禁用,也允许访问。 diff --git a/model/task.go b/model/task.go index 82c2e978a..38bb4d05a 100644 --- a/model/task.go +++ b/model/task.go @@ -5,6 +5,7 @@ import ( "encoding/json" "time" + "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" commonRelay "github.com/QuantumNous/new-api/relay/common" @@ -64,13 +65,12 @@ type Task struct { } func (t *Task) SetData(data any) { - b, _ := json.Marshal(data) + b, _ := common.Marshal(data) t.Data = json.RawMessage(b) } func (t *Task) GetData(v any) error { - err := json.Unmarshal(t.Data, &v) - return err + return common.Unmarshal(t.Data, &v) } type Properties struct { @@ -85,18 +85,48 @@ func (m *Properties) Scan(val interface{}) error { *m = Properties{} return nil } - return json.Unmarshal(bytesValue, m) + return common.Unmarshal(bytesValue, m) } func (m Properties) Value() (driver.Value, error) { if m == (Properties{}) { return nil, nil } - return json.Marshal(m) + return common.Marshal(m) } type TaskPrivateData struct { - Key string `json:"key,omitempty"` + Key string `json:"key,omitempty"` + UpstreamTaskID string `json:"upstream_task_id,omitempty"` // 上游真实 task ID + ResultURL string `json:"result_url,omitempty"` // 任务成功后的结果 URL(视频地址等) + // 计费上下文:用于异步退款/差额结算(轮询阶段读取) + BillingSource string `json:"billing_source,omitempty"` // "wallet" 或 "subscription" + SubscriptionId int `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款 + TokenId int `json:"token_id,omitempty"` // 令牌 ID,用于令牌额度退款 +} + +// GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信) +// 旧数据没有 UpstreamTaskID 时,TaskID 本身就是上游 ID +func (t *Task) GetUpstreamTaskID() string { + if t.PrivateData.UpstreamTaskID != "" { + return t.PrivateData.UpstreamTaskID + } + return t.TaskID +} + +// GetResultURL 获取任务结果 URL(视频地址等) +// 新数据存在 PrivateData.ResultURL 中;旧数据回退到 FailReason(历史兼容) +func (t *Task) GetResultURL() string { + if t.PrivateData.ResultURL != "" { + return t.PrivateData.ResultURL + } + return t.FailReason +} + +// GenerateTaskID 生成对外暴露的 task_xxxx 格式 ID +func GenerateTaskID() string { + key, _ := common.GenerateRandomCharsKey(32) + return "task_" + key } func (p *TaskPrivateData) Scan(val interface{}) error { @@ -104,14 +134,14 @@ func (p *TaskPrivateData) Scan(val interface{}) error { if len(bytesValue) == 0 { return nil } - return json.Unmarshal(bytesValue, p) + return common.Unmarshal(bytesValue, p) } func (p TaskPrivateData) Value() (driver.Value, error) { if (p == TaskPrivateData{}) { return nil, nil } - return json.Marshal(p) + return common.Marshal(p) } // SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段 @@ -142,7 +172,16 @@ func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo) } } + // 使用预生成的公开 ID(如果有),否则新生成 + taskID := "" + if relayInfo.TaskRelayInfo != nil && relayInfo.TaskRelayInfo.PublicTaskID != "" { + taskID = relayInfo.TaskRelayInfo.PublicTaskID + } else { + taskID = GenerateTaskID() + } + t := &Task{ + TaskID: taskID, UserId: relayInfo.UserId, Group: relayInfo.UsingGroup, SubmitTime: time.Now().Unix(), @@ -438,6 +477,6 @@ func (t *Task) ToOpenAIVideo() *dto.OpenAIVideo { openAIVideo.SetProgressStr(t.Progress) openAIVideo.CreatedAt = t.CreatedAt openAIVideo.CompletedAt = t.UpdatedAt - openAIVideo.SetMetadata("url", t.FailReason) + openAIVideo.SetMetadata("url", t.GetResultURL()) return openAIVideo } diff --git a/model/token.go b/model/token.go index 9e05b63ca..773b2d792 100644 --- a/model/token.go +++ b/model/token.go @@ -360,7 +360,7 @@ func DeleteTokenById(id int, userId int) (err error) { return token.Delete() } -func IncreaseTokenQuota(id int, key string, quota int) (err error) { +func IncreaseTokenQuota(tokenId int, key string, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -373,10 +373,10 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) { }) } if common.BatchUpdateEnabled { - addNewRecord(BatchUpdateTypeTokenQuota, id, quota) + addNewRecord(BatchUpdateTypeTokenQuota, tokenId, quota) return nil } - return increaseTokenQuota(id, quota) + return increaseTokenQuota(tokenId, quota) } func increaseTokenQuota(id int, quota int) (err error) { diff --git a/relay/channel/task/ali/adaptor.go b/relay/channel/task/ali/adaptor.go index d55452c08..5d14ff655 100644 --- a/relay/channel/task/ali/adaptor.go +++ b/relay/channel/task/ali/adaptor.go @@ -384,7 +384,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela // 转换为 OpenAI 格式响应 openAIResp := dto.NewOpenAIVideo() - openAIResp.ID = aliResp.Output.TaskID + openAIResp.ID = info.PublicTaskID + openAIResp.TaskID = info.PublicTaskID openAIResp.Model = c.GetString("model") if openAIResp.Model == "" && info != nil { openAIResp.Model = info.OriginModelName diff --git a/relay/channel/task/doubao/adaptor.go b/relay/channel/task/doubao/adaptor.go index 6ebecb3c0..3da125afc 100644 --- a/relay/channel/task/doubao/adaptor.go +++ b/relay/channel/task/doubao/adaptor.go @@ -2,7 +2,6 @@ package doubao import ( "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -14,6 +13,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -131,7 +131,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn return nil, errors.Wrap(err, "convert request payload failed") } info.UpstreamModelName = body.Model - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -154,7 +154,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela // Parse Doubao response var dResp responsePayload - if err := json.Unmarshal(responseBody, &dResp); err != nil { + if err := common.Unmarshal(responseBody, &dResp); err != nil { taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) return } @@ -165,8 +165,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } ov := dto.NewOpenAIVideo() - ov.ID = dResp.ID - ov.TaskID = dResp.ID + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName @@ -234,12 +234,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* } metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &r) - if err != nil { + if err := taskcommon.UnmarshalMetadata(metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } @@ -248,7 +243,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { resTask := responseTask{} - if err := json.Unmarshal(respBody, &resTask); err != nil { + if err := common.Unmarshal(respBody, &resTask); err != nil { return nil, errors.Wrap(err, "unmarshal task result failed") } @@ -286,7 +281,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var dResp responseTask - if err := json.Unmarshal(originTask.Data, &dResp); err != nil { + if err := common.Unmarshal(originTask.Data, &dResp); err != nil { return nil, errors.Wrap(err, "unmarshal doubao task data failed") } @@ -307,6 +302,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro } } - jsonData, _ := common.Marshal(openAIVideo) - return jsonData, nil + return common.Marshal(openAIVideo) } diff --git a/relay/channel/task/gemini/adaptor.go b/relay/channel/task/gemini/adaptor.go index 16c6919b7..a863ea852 100644 --- a/relay/channel/task/gemini/adaptor.go +++ b/relay/channel/task/gemini/adaptor.go @@ -2,8 +2,6 @@ package gemini import ( "bytes" - "encoding/base64" - "encoding/json" "fmt" "io" "net/http" @@ -16,10 +14,10 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" "github.com/QuantumNous/new-api/setting/model_setting" - "github.com/QuantumNous/new-api/setting/system_setting" "github.com/gin-gonic/gin" "github.com/pkg/errors" ) @@ -145,16 +143,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &body.Parameters) - if err != nil { + if err := taskcommon.UnmarshalMetadata(metadata, &body.Parameters); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -175,16 +168,16 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela _ = resp.Body.Close() var s submitResponse - if err := json.Unmarshal(responseBody, &s); err != nil { + if err := common.Unmarshal(responseBody, &s); err != nil { return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError) } if strings.TrimSpace(s.Name) == "" { return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError) } - taskID = encodeLocalTaskID(s.Name) + taskID = taskcommon.EncodeLocalTaskID(s.Name) ov := dto.NewOpenAIVideo() - ov.ID = taskID - ov.TaskID = taskID + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName c.JSON(http.StatusOK, ov) @@ -206,7 +199,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy return nil, fmt.Errorf("invalid task_id") } - upstreamName, err := decodeLocalTaskID(taskID) + upstreamName, err := taskcommon.DecodeLocalTaskID(taskID) if err != nil { return nil, fmt.Errorf("decode task_id failed: %w", err) } @@ -232,7 +225,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { var op operationResponse - if err := json.Unmarshal(respBody, &op); err != nil { + if err := common.Unmarshal(respBody, &op); err != nil { return nil, fmt.Errorf("unmarshal operation response failed: %w", err) } @@ -254,9 +247,8 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e ti.Status = model.TaskStatusSuccess ti.Progress = "100%" - taskID := encodeLocalTaskID(op.Name) - ti.TaskID = taskID - ti.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID) + ti.TaskID = taskcommon.EncodeLocalTaskID(op.Name) + // Url intentionally left empty — the caller constructs the proxy URL using the public task ID // Extract URL from generateVideoResponse if available if len(op.Response.GenerateVideoResponse.GeneratedSamples) > 0 { @@ -269,7 +261,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e } func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { - upstreamName, err := decodeLocalTaskID(task.TaskID) + // Use GetUpstreamTaskID() to get the real upstream operation name for model extraction. + // task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name. + upstreamTaskID := task.GetUpstreamTaskID() + upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID) if err != nil { upstreamName = "" } @@ -297,18 +292,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { // helpers // ============================ -func encodeLocalTaskID(name string) string { - return base64.RawURLEncoding.EncodeToString([]byte(name)) -} - -func decodeLocalTaskID(local string) (string, error) { - b, err := base64.RawURLEncoding.DecodeString(local) - if err != nil { - return "", err - } - return string(b), nil -} - var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`) func extractModelFromOperationName(name string) string { diff --git a/relay/channel/task/hailuo/adaptor.go b/relay/channel/task/hailuo/adaptor.go index c77905bfb..67a68a10e 100644 --- a/relay/channel/task/hailuo/adaptor.go +++ b/relay/channel/task/hailuo/adaptor.go @@ -2,7 +2,6 @@ package hailuo import ( "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -65,7 +64,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn return nil, errors.Wrap(err, "convert request payload failed") } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -86,7 +85,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela _ = resp.Body.Close() var hResp VideoResponse - if err := json.Unmarshal(responseBody, &hResp); err != nil { + if err := common.Unmarshal(responseBody, &hResp); err != nil { taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) return } @@ -101,8 +100,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } ov := dto.NewOpenAIVideo() - ov.ID = hResp.TaskID - ov.TaskID = hResp.TaskID + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName @@ -182,7 +181,7 @@ func (a *TaskAdaptor) parseResolutionFromSize(size string, modelConfig ModelConf func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { resTask := QueryTaskResponse{} - if err := json.Unmarshal(respBody, &resTask); err != nil { + if err := common.Unmarshal(respBody, &resTask); err != nil { return nil, errors.Wrap(err, "unmarshal task result failed") } @@ -224,7 +223,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var hailuoResp QueryTaskResponse - if err := json.Unmarshal(originTask.Data, &hailuoResp); err != nil { + if err := common.Unmarshal(originTask.Data, &hailuoResp); err != nil { return nil, errors.Wrap(err, "unmarshal hailuo task data failed") } @@ -271,7 +270,7 @@ func (a *TaskAdaptor) buildVideoURL(_, fileID string) string { } var retrieveResp RetrieveFileResponse - if err := json.Unmarshal(responseBody, &retrieveResp); err != nil { + if err := common.Unmarshal(responseBody, &retrieveResp); err != nil { return "" } diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index 1522a967f..7f88be248 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -6,7 +6,6 @@ import ( "crypto/sha256" "encoding/base64" "encoding/hex" - "encoding/json" "fmt" "io" "net/http" @@ -25,6 +24,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" ) @@ -168,7 +168,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn if err != nil { return nil, errors.Wrap(err, "convert request payload failed") } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -191,7 +191,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela // Parse Jimeng response var jResp responsePayload - if err := json.Unmarshal(responseBody, &jResp); err != nil { + if err := common.Unmarshal(responseBody, &jResp); err != nil { taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) return } @@ -202,8 +202,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } ov := dto.NewOpenAIVideo() - ov.ID = jResp.Data.TaskID - ov.TaskID = jResp.Data.TaskID + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName c.JSON(http.StatusOK, ov) @@ -225,7 +225,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy "req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774 "task_id": taskID, } - payloadBytes, err := json.Marshal(payload) + payloadBytes, err := common.Marshal(payload) if err != nil { return nil, errors.Wrap(err, "marshal fetch task payload failed") } @@ -398,13 +398,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* r.BinaryDataBase64 = req.Images } } - metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &r) - if err != nil { + if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } @@ -432,7 +426,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { resTask := responseTask{} - if err := json.Unmarshal(respBody, &resTask); err != nil { + if err := common.Unmarshal(respBody, &resTask); err != nil { return nil, errors.Wrap(err, "unmarshal task result failed") } taskResult := relaycommon.TaskInfo{} @@ -458,7 +452,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var jimengResp responseTask - if err := json.Unmarshal(originTask.Data, &jimengResp); err != nil { + if err := common.Unmarshal(originTask.Data, &jimengResp); err != nil { return nil, errors.Wrap(err, "unmarshal jimeng task data failed") } @@ -477,8 +471,7 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro } } - jsonData, _ := common.Marshal(openAIVideo) - return jsonData, nil + return common.Marshal(openAIVideo) } func isNewAPIRelay(apiKey string) bool { diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 5fb853481..4458626b2 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -2,7 +2,6 @@ package kling import ( "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -21,6 +20,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" ) @@ -156,7 +156,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn if body.Image == "" && body.ImageTail == "" { c.Set("action", constant.TaskActionTextGenerate) } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -180,7 +180,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } var kResp responsePayload - err = json.Unmarshal(responseBody, &kResp) + err = common.Unmarshal(responseBody, &kResp) if err != nil { taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError) return @@ -190,8 +190,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela return } ov := dto.NewOpenAIVideo() - ov.ID = kResp.Data.TaskId - ov.TaskID = kResp.Data.TaskId + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName c.JSON(http.StatusOK, ov) @@ -251,8 +251,8 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* r := requestPayload{ Prompt: req.Prompt, Image: req.Image, - Mode: defaultString(req.Mode, "std"), - Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)), + Mode: taskcommon.DefaultString(req.Mode, "std"), + Duration: fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)), AspectRatio: a.getAspectRatio(req.Size), ModelName: req.Model, Model: req.Model, // Keep consistent with model_name, double writing improves compatibility @@ -266,13 +266,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* if r.ModelName == "" { r.ModelName = "kling-v1" } - metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &r) - if err != nil { + if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } return &r, nil @@ -291,20 +285,6 @@ func (a *TaskAdaptor) getAspectRatio(size string) string { } } -func defaultString(s, def string) string { - if strings.TrimSpace(s) == "" { - return def - } - return s -} - -func defaultInt(v int, def int) int { - if v == 0 { - return def - } - return v -} - // ============================ // JWT helpers // ============================ @@ -340,7 +320,7 @@ func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) { func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { taskInfo := &relaycommon.TaskInfo{} resPayload := responsePayload{} - err := json.Unmarshal(respBody, &resPayload) + err := common.Unmarshal(respBody, &resPayload) if err != nil { return nil, errors.Wrap(err, "failed to unmarshal response body") } @@ -374,7 +354,7 @@ func isNewAPIRelay(apiKey string) bool { func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var klingResp responsePayload - if err := json.Unmarshal(originTask.Data, &klingResp); err != nil { + if err := common.Unmarshal(originTask.Data, &klingResp); err != nil { return nil, errors.Wrap(err, "unmarshal kling task data failed") } @@ -401,6 +381,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro Code: fmt.Sprintf("%d", klingResp.Code), } } - jsonData, _ := common.Marshal(openAIVideo) - return jsonData, nil + return common.Marshal(openAIVideo) } diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index c149f9663..ee69a3e48 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -13,7 +13,6 @@ import ( "github.com/QuantumNous/new-api/relay/channel" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" - "github.com/QuantumNous/new-api/setting/system_setting" "github.com/gin-gonic/gin" "github.com/pkg/errors" @@ -116,7 +115,7 @@ func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, req } // DoResponse handles upstream response, returns taskID etc. -func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) @@ -131,17 +130,20 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco return } - if dResp.ID == "" { - if dResp.TaskID == "" { - taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError) - return - } - dResp.ID = dResp.TaskID - dResp.TaskID = "" + upstreamID := dResp.ID + if upstreamID == "" { + upstreamID = dResp.TaskID + } + if upstreamID == "" { + taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError) + return } + // 使用公开 task_xxxx ID 返回给客户端 + dResp.ID = info.PublicTaskID + dResp.TaskID = info.PublicTaskID c.JSON(http.StatusOK, dResp) - return dResp.ID, responseBody, nil + return upstreamID, responseBody, nil } // FetchTask fetch task status @@ -192,7 +194,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e taskResult.Status = model.TaskStatusInProgress case "completed": taskResult.Status = model.TaskStatusSuccess - taskResult.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, resTask.ID) + // Url intentionally left empty — the caller constructs the proxy URL using the public task ID case "failed", "cancelled": taskResult.Status = model.TaskStatusFailure if resTask.Error != nil { diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index 8ea9a1c7f..5dd62a70f 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -3,7 +3,6 @@ package suno import ( "bytes" "context" - "encoding/json" "fmt" "io" "net/http" @@ -24,8 +23,12 @@ type TaskAdaptor struct { ChannelType int } +// ParseTaskResult is not used for Suno tasks. +// Suno polling uses a dedicated batch-fetch path (service.UpdateSunoTasks) that +// receives dto.TaskResponse[[]dto.SunoDataResponse] from the upstream /fetch API. +// This differs from the per-task polling used by video adaptors. func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { - return nil, fmt.Errorf("not implement") // todo implement this method if needed + return nil, fmt.Errorf("suno uses batch polling via UpdateSunoTasks, ParseTaskResult is not applicable") } func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { @@ -81,7 +84,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn return nil, err } } - data, err := json.Marshal(sunoRequest) + data, err := common.Marshal(sunoRequest) if err != nil { return nil, err } @@ -99,7 +102,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela return } var sunoResponse dto.TaskResponse[string] - err = json.Unmarshal(responseBody, &sunoResponse) + err = common.Unmarshal(responseBody, &sunoResponse) if err != nil { taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) return @@ -109,17 +112,13 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela return } - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - - _, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody)) - if err != nil { - taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) - return + // 使用公开 task_xxxx ID 替换上游 ID 返回给客户端 + publicResponse := dto.TaskResponse[string]{ + Code: sunoResponse.Code, + Message: sunoResponse.Message, + Data: info.PublicTaskID, } + c.JSON(http.StatusOK, publicResponse) return sunoResponse.Data, nil, nil } @@ -134,7 +133,7 @@ func (a *TaskAdaptor) GetChannelName() string { func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) { requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl) - byteBody, err := json.Marshal(body) + byteBody, err := common.Marshal(body) if err != nil { return nil, err } diff --git a/relay/channel/task/taskcommon/helpers.go b/relay/channel/task/taskcommon/helpers.go new file mode 100644 index 000000000..b1dde998b --- /dev/null +++ b/relay/channel/task/taskcommon/helpers.go @@ -0,0 +1,70 @@ +package taskcommon + +import ( + "encoding/base64" + "fmt" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/setting/system_setting" +) + +// UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip. +// This replaces the repeated pattern: json.Marshal(metadata) → json.Unmarshal(bytes, &target). +func UnmarshalMetadata(metadata map[string]any, target any) error { + if metadata == nil { + return nil + } + metaBytes, err := common.Marshal(metadata) + if err != nil { + return fmt.Errorf("marshal metadata failed: %w", err) + } + if err := common.Unmarshal(metaBytes, target); err != nil { + return fmt.Errorf("unmarshal metadata failed: %w", err) + } + return nil +} + +// DefaultString returns val if non-empty, otherwise fallback. +func DefaultString(val, fallback string) string { + if val == "" { + return fallback + } + return val +} + +// DefaultInt returns val if non-zero, otherwise fallback. +func DefaultInt(val, fallback int) int { + if val == 0 { + return fallback + } + return val +} + +// EncodeLocalTaskID encodes an upstream operation name to a URL-safe base64 string. +// Used by Gemini/Vertex to store upstream names as task IDs. +func EncodeLocalTaskID(name string) string { + return base64.RawURLEncoding.EncodeToString([]byte(name)) +} + +// DecodeLocalTaskID decodes a base64-encoded upstream operation name. +func DecodeLocalTaskID(id string) (string, error) { + b, err := base64.RawURLEncoding.DecodeString(id) + if err != nil { + return "", err + } + return string(b), nil +} + +// BuildProxyURL constructs the video proxy URL using the public task ID. +// e.g., "https://your-server.com/v1/videos/task_xxxx/content" +func BuildProxyURL(taskID string) string { + return fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID) +} + +// Status-to-progress mapping constants for polling updates. +const ( + ProgressSubmitted = "10%" + ProgressQueued = "20%" + ProgressInProgress = "30%" + ProgressComplete = "100%" +) diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go index 8ec77266e..fb3a313ff 100644 --- a/relay/channel/task/vertex/adaptor.go +++ b/relay/channel/task/vertex/adaptor.go @@ -2,13 +2,12 @@ package vertex import ( "bytes" - "encoding/base64" - "encoding/json" "fmt" "io" "net/http" "regexp" "strings" + "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/model" @@ -17,6 +16,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" vertexcore "github.com/QuantumNous/new-api/relay/channel/vertex" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -82,7 +82,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom // BuildRequestURL constructs the upstream URL. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { adc := &vertexcore.Credentials{} - if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil { + if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil { return "", fmt.Errorf("failed to decode credentials: %w", err) } modelName := info.OriginModelName @@ -116,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info req.Header.Set("Accept", "application/json") adc := &vertexcore.Credentials{} - if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil { + if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil { return fmt.Errorf("failed to decode credentials: %w", err) } @@ -184,7 +184,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn // info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int)) // } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -205,14 +205,19 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela _ = resp.Body.Close() var s submitResponse - if err := json.Unmarshal(responseBody, &s); err != nil { + if err := common.Unmarshal(responseBody, &s); err != nil { return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError) } if strings.TrimSpace(s.Name) == "" { return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError) } - localID := encodeLocalTaskID(s.Name) - c.JSON(http.StatusOK, gin.H{"task_id": localID}) + localID := taskcommon.EncodeLocalTaskID(s.Name) + ov := dto.NewOpenAIVideo() + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID + ov.CreatedAt = time.Now().Unix() + ov.Model = info.OriginModelName + c.JSON(http.StatusOK, ov) return localID, responseBody, nil } @@ -225,7 +230,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy if !ok { return nil, fmt.Errorf("invalid task_id") } - upstreamName, err := decodeLocalTaskID(taskID) + upstreamName, err := taskcommon.DecodeLocalTaskID(taskID) if err != nil { return nil, fmt.Errorf("decode task_id failed: %w", err) } @@ -245,12 +250,12 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName) } payload := map[string]string{"operationName": upstreamName} - data, err := json.Marshal(payload) + data, err := common.Marshal(payload) if err != nil { return nil, err } adc := &vertexcore.Credentials{} - if err := json.Unmarshal([]byte(key), adc); err != nil { + if err := common.Unmarshal([]byte(key), adc); err != nil { return nil, fmt.Errorf("failed to decode credentials: %w", err) } token, err := vertexcore.AcquireAccessToken(*adc, proxy) @@ -274,7 +279,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { var op operationResponse - if err := json.Unmarshal(respBody, &op); err != nil { + if err := common.Unmarshal(respBody, &op); err != nil { return nil, fmt.Errorf("unmarshal operation response failed: %w", err) } ti := &relaycommon.TaskInfo{} @@ -338,7 +343,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e } func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { - upstreamName, err := decodeLocalTaskID(task.TaskID) + // Use GetUpstreamTaskID() to get the real upstream operation name for model extraction. + // task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name. + upstreamTaskID := task.GetUpstreamTaskID() + upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID) if err != nil { upstreamName = "" } @@ -353,8 +361,8 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { v.SetProgressStr(task.Progress) v.CreatedAt = task.CreatedAt v.CompletedAt = task.UpdatedAt - if strings.HasPrefix(task.FailReason, "data:") && len(task.FailReason) > 0 { - v.SetMetadata("url", task.FailReason) + if resultURL := task.GetResultURL(); strings.HasPrefix(resultURL, "data:") && len(resultURL) > 0 { + v.SetMetadata("url", resultURL) } return common.Marshal(v) @@ -364,18 +372,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { // helpers // ============================ -func encodeLocalTaskID(name string) string { - return base64.RawURLEncoding.EncodeToString([]byte(name)) -} - -func decodeLocalTaskID(local string) (string, error) { - b, err := base64.RawURLEncoding.DecodeString(local) - if err != nil { - return "", err - } - return string(b), nil -} - var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`) func extractRegionFromOperationName(name string) string { diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index 3657161c0..1bab12f03 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -2,7 +2,6 @@ package vidu import ( "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -16,6 +15,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -127,7 +127,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -168,7 +168,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } var vResp responsePayload - err = json.Unmarshal(responseBody, &vResp) + err = common.Unmarshal(responseBody, &vResp) if err != nil { taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError) return @@ -180,8 +180,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } ov := dto.NewOpenAIVideo() - ov.ID = vResp.TaskId - ov.TaskID = vResp.TaskId + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName c.JSON(http.StatusOK, ov) @@ -225,45 +225,25 @@ func (a *TaskAdaptor) GetChannelName() string { func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { r := requestPayload{ - Model: defaultString(req.Model, "viduq1"), + Model: taskcommon.DefaultString(req.Model, "viduq1"), Images: req.Images, Prompt: req.Prompt, - Duration: defaultInt(req.Duration, 5), - Resolution: defaultString(req.Size, "1080p"), + Duration: taskcommon.DefaultInt(req.Duration, 5), + Resolution: taskcommon.DefaultString(req.Size, "1080p"), MovementAmplitude: "auto", Bgm: false, } - metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &r) - if err != nil { + if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } return &r, nil } -func defaultString(value, defaultValue string) string { - if value == "" { - return defaultValue - } - return value -} - -func defaultInt(value, defaultValue int) int { - if value == 0 { - return defaultValue - } - return value -} - func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { taskInfo := &relaycommon.TaskInfo{} var taskResp taskResultResponse - err := json.Unmarshal(respBody, &taskResp) + err := common.Unmarshal(respBody, &taskResp) if err != nil { return nil, errors.Wrap(err, "failed to unmarshal response body") } @@ -293,7 +273,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var viduResp taskResultResponse - if err := json.Unmarshal(originTask.Data, &viduResp); err != nil { + if err := common.Unmarshal(originTask.Data, &viduResp); err != nil { return nil, errors.Wrap(err, "unmarshal vidu task data failed") } @@ -315,6 +295,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro } } - jsonData, _ := common.Marshal(openAIVideo) - return jsonData, nil + return common.Marshal(openAIVideo) } diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 81b7d21d6..b68826812 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -118,8 +118,12 @@ type RelayInfo struct { SendResponseCount int ReceivedResponseCount int FinalPreConsumedQuota int // 最终预消耗的配额 + // ForcePreConsume 为 true 时禁用 BillingSession 的信任额度旁路, + // 强制预扣全额。用于异步任务(视频/音乐生成等),因为请求返回后任务仍在运行, + // 必须在提交前锁定全额。 + ForcePreConsume bool // Billing 是计费会话,封装了预扣费/结算/退款的统一生命周期。 - // 免费模型和按次计费(MJ/Task)时为 nil。 + // 免费模型时为 nil。 Billing BillingSettler // BillingSource indicates whether this request is billed from wallet quota or subscription. // "" or "wallet" => wallet; "subscription" => subscription @@ -525,8 +529,10 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req return nil, errors.New("request is not a OpenAIResponsesCompactionRequest") case types.RelayFormatTask: info = genBaseRelayInfo(c, nil) + info.TaskRelayInfo = &TaskRelayInfo{} case types.RelayFormatMjProxy: info = genBaseRelayInfo(c, nil) + info.TaskRelayInfo = &TaskRelayInfo{} default: err = errors.New("invalid relay format") } @@ -608,6 +614,9 @@ func (info *RelayInfo) HasSendResponse() bool { type TaskRelayInfo struct { Action string OriginTaskID string + // PublicTaskID 是提交时预生成的 task_xxxx 格式公开 ID, + // 供 DoResponse 在返回给客户端时使用(避免暴露上游真实 ID)。 + PublicTaskID string ConsumeQuota bool } @@ -667,11 +676,11 @@ func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error { func (t *TaskSubmitReq) UnmarshalMetadata(v any) error { metadata := t.Metadata if metadata != nil { - metadataBytes, err := json.Marshal(metadata) + metadataBytes, err := common.Marshal(metadata) if err != nil { return fmt.Errorf("marshal metadata failed: %w", err) } - err = json.Unmarshal(metadataBytes, v) + err = common.Unmarshal(metadataBytes, v) if err != nil { return fmt.Errorf("unmarshal metadata to target failed: %w", err) } diff --git a/relay/helper/price.go b/relay/helper/price.go index c310220fe..1cb04166f 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -140,7 +140,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens } // ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task) -func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData { +func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PriceData { groupRatioInfo := HandleGroupRatio(c, info) modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true) @@ -154,7 +154,18 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types. } } quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) - priceData := types.PerCallPriceData{ + + // 免费模型检测(与 ModelPriceHelper 对齐) + freeModel := false + if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume { + if groupRatioInfo.GroupRatio == 0 || modelPrice == 0 { + quota = 0 + freeModel = true + } + } + + priceData := types.PriceData{ + FreeModel: freeModel, ModelPrice: modelPrice, Quota: quota, GroupRatioInfo: groupRatioInfo, diff --git a/relay/relay_task.go b/relay/relay_task.go index ebbd1f65d..d372ca2e8 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "encoding/json" "errors" "fmt" "io" @@ -15,29 +14,33 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" relayconstant "github.com/QuantumNous/new-api/relay/constant" + "github.com/QuantumNous/new-api/relay/helper" "github.com/QuantumNous/new-api/service" - "github.com/QuantumNous/new-api/setting/ratio_setting" - "github.com/gin-gonic/gin" ) -/* -Task 任务通过平台、Action 区分任务 -*/ -func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { - info.InitChannelMeta(c) - // ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields - if info.TaskRelayInfo == nil { - info.TaskRelayInfo = &relaycommon.TaskRelayInfo{} - } +type TaskSubmitResult struct { + UpstreamTaskID string + TaskData []byte + Platform constant.TaskPlatform + ModelName string + Quota int + //PerCallPrice types.PriceData +} + +// ResolveOriginTask 处理基于已有任务的提交(remix / continuation): +// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道(并通过 +// specific_channel_id 禁止重试),以及提取 OtherRatios(时长、分辨率)。 +// 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。 +func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { + // 检测 remix action path := c.Request.URL.Path if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") { info.Action = constant.TaskActionRemix } - - // 提取 remix 任务的 video_id if info.Action == constant.TaskActionRemix { videoID := c.Param("video_id") if strings.TrimSpace(videoID) == "" { @@ -46,241 +49,164 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto. info.OriginTaskID = videoID } - platform := constant.TaskPlatform(c.GetString("platform")) + if info.OriginTaskID == "" { + return nil + } - // 获取原始任务信息 - if info.OriginTaskID != "" { - originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID) - if err != nil { - taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) - return - } - if !exist { - taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) - return - } - if info.OriginModelName == "" { - if originTask.Properties.OriginModelName != "" { - info.OriginModelName = originTask.Properties.OriginModelName - } else if originTask.Properties.UpstreamModelName != "" { - info.OriginModelName = originTask.Properties.UpstreamModelName - } else { - var taskData map[string]interface{} - _ = json.Unmarshal(originTask.Data, &taskData) - if m, ok := taskData["model"].(string); ok && m != "" { - info.OriginModelName = m - platform = originTask.Platform - } - } - } - if originTask.ChannelId != info.ChannelId { - channel, err := model.GetChannelById(originTask.ChannelId, true) - if err != nil { - taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) - return - } - if channel.Status != common.ChannelStatusEnabled { - taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) - return - } - key, _, newAPIError := channel.GetNextEnabledKey() - if newAPIError != nil { - taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode) - return - } - common.SetContextKey(c, constant.ContextKeyChannelKey, key) - common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type) - common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL()) - common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId) + // 查找原始任务 + originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID) + if err != nil { + return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) + } + if !exist { + return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) + } - info.ChannelBaseUrl = channel.GetBaseURL() - info.ChannelId = originTask.ChannelId - info.ChannelType = channel.Type - info.ApiKey = key - platform = originTask.Platform - } - - // 使用原始任务的参数 - if info.Action == constant.TaskActionRemix { + // 从原始任务推导模型名称 + if info.OriginModelName == "" { + if originTask.Properties.OriginModelName != "" { + info.OriginModelName = originTask.Properties.OriginModelName + } else if originTask.Properties.UpstreamModelName != "" { + info.OriginModelName = originTask.Properties.UpstreamModelName + } else { var taskData map[string]interface{} - _ = json.Unmarshal(originTask.Data, &taskData) - secondsStr, _ := taskData["seconds"].(string) - seconds, _ := strconv.Atoi(secondsStr) - if seconds <= 0 { - seconds = 4 - } - sizeStr, _ := taskData["size"].(string) - if info.PriceData.OtherRatios == nil { - info.PriceData.OtherRatios = map[string]float64{} - } - info.PriceData.OtherRatios["seconds"] = float64(seconds) - info.PriceData.OtherRatios["size"] = 1 - if sizeStr == "1792x1024" || sizeStr == "1024x1792" { - info.PriceData.OtherRatios["size"] = 1.666667 + _ = common.Unmarshal(originTask.Data, &taskData) + if m, ok := taskData["model"].(string); ok && m != "" { + info.OriginModelName = m } } } + + // 锁定到原始任务的渠道(如果与当前选中的不同) + if originTask.ChannelId != info.ChannelId { + ch, err := model.GetChannelById(originTask.ChannelId, true) + if err != nil { + return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) + } + if ch.Status != common.ChannelStatusEnabled { + return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) + } + key, _, newAPIError := ch.GetNextEnabledKey() + if newAPIError != nil { + return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode) + } + common.SetContextKey(c, constant.ContextKeyChannelKey, key) + common.SetContextKey(c, constant.ContextKeyChannelType, ch.Type) + common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, ch.GetBaseURL()) + common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId) + + info.ChannelBaseUrl = ch.GetBaseURL() + info.ChannelId = originTask.ChannelId + info.ChannelType = ch.Type + info.ApiKey = key + } + + // 渠道已锁定到原始任务 → 禁止重试切换到其他渠道 + c.Set("specific_channel_id", fmt.Sprintf("%d", originTask.ChannelId)) + + // 提取 remix 参数(时长、分辨率 → OtherRatios) + if info.Action == constant.TaskActionRemix { + var taskData map[string]interface{} + _ = common.Unmarshal(originTask.Data, &taskData) + secondsStr, _ := taskData["seconds"].(string) + seconds, _ := strconv.Atoi(secondsStr) + if seconds <= 0 { + seconds = 4 + } + sizeStr, _ := taskData["size"].(string) + if info.PriceData.OtherRatios == nil { + info.PriceData.OtherRatios = map[string]float64{} + } + info.PriceData.OtherRatios["seconds"] = float64(seconds) + info.PriceData.OtherRatios["size"] = 1 + if sizeStr == "1792x1024" || sizeStr == "1024x1792" { + info.PriceData.OtherRatios["size"] = 1.666667 + } + } + + return nil +} + +// RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次): +// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 → 计算价格 → +// 预扣费(仅首次,通过 info.Billing==nil 守卫)→ 构建/发送/解析上游请求。 +// 控制器负责 defer Refund 和成功后 Settle。 +func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) { + info.InitChannelMeta(c) + + // 1. 确定 platform → 创建适配器 → 验证请求 + platform := constant.TaskPlatform(c.GetString("platform")) if platform == "" { platform = GetTaskPlatform(c) } - - info.InitChannelMeta(c) adaptor := GetTaskAdaptor(platform) if adaptor == nil { - return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest) + return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest) } adaptor.Init(info) - // get & validate taskRequest 获取并验证文本请求 - taskErr = adaptor.ValidateRequestAndSetAction(c, info) - if taskErr != nil { - return + if taskErr := adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil { + return nil, taskErr } + // 2. 确定模型名称 modelName := info.OriginModelName if modelName == "" { modelName = service.CoverTaskActionToModelName(platform, info.Action) } - modelPrice, success := ratio_setting.GetModelPrice(modelName, true) - if !success { - defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName] - if !ok { - modelPrice = float64(common.PreConsumedQuota) / common.QuotaPerUnit - } else { - modelPrice = defaultPrice - } + + // 3. 预生成公开 task ID(仅首次) + if info.PublicTaskID == "" { + info.PublicTaskID = model.GenerateTaskID() } - // 处理 auto 分组:从 context 获取实际选中的分组 - // 当使用 auto 分组时,Distribute 中间件会将实际选中的分组存储在 ContextKeyAutoGroup 中 - if autoGroup, exists := common.GetContextKey(c, constant.ContextKeyAutoGroup); exists { - if groupStr, ok := autoGroup.(string); ok && groupStr != "" { - info.UsingGroup = groupStr - } - } + // 4. 价格计算 + info.OriginModelName = modelName + info.PriceData = helper.ModelPriceHelperPerCall(c, info) - // 预扣 - groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup) - var ratio float64 - userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup) - if hasUserGroupRatio { - ratio = modelPrice * userGroupRatio - } else { - ratio = modelPrice * groupRatio - } - // FIXME: 临时修补,支持任务仅按次计费 if !common.StringsContains(constant.TaskPricePatches, modelName) { - if len(info.PriceData.OtherRatios) > 0 { - for _, ra := range info.PriceData.OtherRatios { - if 1.0 != ra { - ratio *= ra - } + for _, ra := range info.PriceData.OtherRatios { + if ra != 1.0 { + info.PriceData.Quota = int(float64(info.PriceData.Quota) * ra) } } } - println(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio)) - userQuota, err := model.GetUserQuota(info.UserId, false) - if err != nil { - taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) - return - } - quota := int(ratio * common.QuotaPerUnit) - if userQuota-quota < 0 { - taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden) - return + + // 5. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过) + if info.Billing == nil && !info.PriceData.FreeModel { + info.ForcePreConsume = true + if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil { + return nil, service.TaskErrorFromAPIError(apiErr) + } } - // build body + // 6. 构建请求体 requestBody, err := adaptor.BuildRequestBody(c, info) if err != nil { - taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) - return + return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) } - // do request + + // 7. 发送请求 resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { - taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) - return + return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - // handle response if resp != nil && resp.StatusCode != http.StatusOK { responseBody, _ := io.ReadAll(resp.Body) - taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode) - return + return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode) } - defer func() { - // release quota - if info.ConsumeQuota && taskErr == nil { - - err := service.PostConsumeQuota(info, quota, 0, true) - if err != nil { - common.SysLog("error consuming token remain quota: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - //gRatio := groupRatio - //if hasUserGroupRatio { - // gRatio = userGroupRatio - //} - logContent := fmt.Sprintf("操作 %s", info.Action) - // FIXME: 临时修补,支持任务仅按次计费 - if common.StringsContains(constant.TaskPricePatches, modelName) { - logContent = fmt.Sprintf("%s,按次计费", logContent) - } else { - if len(info.PriceData.OtherRatios) > 0 { - var contents []string - for key, ra := range info.PriceData.OtherRatios { - if 1.0 != ra { - contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra)) - } - } - if len(contents) > 0 { - logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", ")) - } - } - } - other := make(map[string]interface{}) - if c != nil && c.Request != nil && c.Request.URL != nil { - other["request_path"] = c.Request.URL.Path - } - other["model_price"] = modelPrice - other["group_ratio"] = groupRatio - if hasUserGroupRatio { - other["user_group_ratio"] = userGroupRatio - } - model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ - ChannelId: info.ChannelId, - ModelName: modelName, - TokenName: tokenName, - Quota: quota, - Content: logContent, - TokenId: info.TokenId, - Group: info.UsingGroup, - Other: other, - }) - model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota) - model.UpdateChannelUsedQuota(info.ChannelId, quota) - } - } - }() - - taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info) + // 8. 解析响应 + upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info) if taskErr != nil { - return + return nil, taskErr } - info.ConsumeQuota = true - // insert task - task := model.InitTask(platform, info) - task.TaskID = taskID - task.Quota = quota - task.Data = taskData - task.Action = info.Action - err = task.Insert() - if err != nil { - taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError) - return - } - return nil + + return &TaskSubmitResult{ + UpstreamTaskID: upstreamTaskID, + TaskData: taskData, + Platform: platform, + ModelName: modelName, + }, nil } var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ @@ -336,7 +262,7 @@ func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.Ta } else { tasks = make([]any, 0) } - respBody, err = json.Marshal(dto.TaskResponse[[]any]{ + respBody, err = common.Marshal(dto.TaskResponse[[]any]{ Code: "success", Data: tasks, }) @@ -357,7 +283,7 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt return } - respBody, err = json.Marshal(dto.TaskResponse[any]{ + respBody, err = common.Marshal(dto.TaskResponse[any]{ Code: "success", Data: TaskModel2Dto(originTask), }) @@ -381,97 +307,16 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d return } - func() { - channelModel, err2 := model.GetChannelById(originTask.ChannelId, true) - if err2 != nil { - return - } - if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini { - return - } - baseURL := constant.ChannelBaseURLs[channelModel.Type] - if channelModel.GetBaseURL() != "" { - baseURL = channelModel.GetBaseURL() - } - proxy := channelModel.GetSetting().Proxy - adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type))) - if adaptor == nil { - return - } - resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{ - "task_id": originTask.TaskID, - "action": originTask.Action, - }, proxy) - if err2 != nil || resp == nil { - return - } - defer resp.Body.Close() - body, err2 := io.ReadAll(resp.Body) - if err2 != nil { - return - } - ti, err2 := adaptor.ParseTaskResult(body) - if err2 == nil && ti != nil { - if ti.Status != "" { - originTask.Status = model.TaskStatus(ti.Status) - } - if ti.Progress != "" { - originTask.Progress = ti.Progress - } - if ti.Url != "" { - if strings.HasPrefix(ti.Url, "data:") { - } else { - originTask.FailReason = ti.Url - } - } - _ = originTask.Update() - var raw map[string]any - _ = json.Unmarshal(body, &raw) - format := "mp4" - if respObj, ok := raw["response"].(map[string]any); ok { - if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 { - if v0, ok := vids[0].(map[string]any); ok { - if mt, ok := v0["mimeType"].(string); ok && mt != "" { - if strings.Contains(mt, "mp4") { - format = "mp4" - } else { - format = mt - } - } - } - } - } - status := "processing" - switch originTask.Status { - case model.TaskStatusSuccess: - status = "succeeded" - case model.TaskStatusFailure: - status = "failed" - case model.TaskStatusQueued, model.TaskStatusSubmitted: - status = "queued" - } - if !strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") { - out := map[string]any{ - "error": nil, - "format": format, - "metadata": nil, - "status": status, - "task_id": originTask.TaskID, - "url": originTask.FailReason, - } - respBody, _ = json.Marshal(dto.TaskResponse[any]{ - Code: "success", - Data: out, - }) - } - } - }() + isOpenAIVideoAPI := strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") - if len(respBody) != 0 { + // Gemini/Vertex 支持实时查询:用户 fetch 时直接从上游拉取最新状态 + if realtimeResp := tryRealtimeFetch(originTask, isOpenAIVideoAPI); len(realtimeResp) > 0 { + respBody = realtimeResp return } - if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") { + // OpenAI Video API 格式: 走各 adaptor 的 ConvertToOpenAIVideo + if isOpenAIVideoAPI { adaptor := GetTaskAdaptor(originTask.Platform) if adaptor == nil { taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest) @@ -486,10 +331,12 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d respBody = openAIVideoData return } - taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented) + taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented) return } - respBody, err = json.Marshal(dto.TaskResponse[any]{ + + // 通用 TaskDto 格式 + respBody, err = common.Marshal(dto.TaskResponse[any]{ Code: "success", Data: TaskModel2Dto(originTask), }) @@ -499,16 +346,145 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d return } +// tryRealtimeFetch 尝试从上游实时拉取 Gemini/Vertex 任务状态。 +// 仅当渠道类型为 Gemini 或 Vertex 时触发;其他渠道或出错时返回 nil。 +// 当非 OpenAI Video API 时,还会构建自定义格式的响应体。 +func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte { + channelModel, err := model.GetChannelById(task.ChannelId, true) + if err != nil { + return nil + } + if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini { + return nil + } + + baseURL := constant.ChannelBaseURLs[channelModel.Type] + if channelModel.GetBaseURL() != "" { + baseURL = channelModel.GetBaseURL() + } + proxy := channelModel.GetSetting().Proxy + adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type))) + if adaptor == nil { + return nil + } + + resp, err := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{ + "task_id": task.GetUpstreamTaskID(), + "action": task.Action, + }, proxy) + if err != nil || resp == nil { + return nil + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil + } + + ti, err := adaptor.ParseTaskResult(body) + if err != nil || ti == nil { + return nil + } + + // 将上游最新状态更新到 task + if ti.Status != "" { + task.Status = model.TaskStatus(ti.Status) + } + if ti.Progress != "" { + task.Progress = ti.Progress + } + if strings.HasPrefix(ti.Url, "data:") { + // data: URI — kept in Data, not ResultURL + } else if ti.Url != "" { + task.PrivateData.ResultURL = ti.Url + } else if task.Status == model.TaskStatusSuccess { + // No URL from adaptor — construct proxy URL using public task ID + task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) + } + _ = task.Update() + + // OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理 + if isOpenAIVideoAPI { + return nil + } + + // 非 OpenAI Video API: 构建自定义格式响应 + format := detectVideoFormat(body) + out := map[string]any{ + "error": nil, + "format": format, + "metadata": nil, + "status": mapTaskStatusToSimple(task.Status), + "task_id": task.TaskID, + "url": task.GetResultURL(), + } + respBody, _ := common.Marshal(dto.TaskResponse[any]{ + Code: "success", + Data: out, + }) + return respBody +} + +// detectVideoFormat 从 Gemini/Vertex 原始响应中探测视频格式 +func detectVideoFormat(rawBody []byte) string { + var raw map[string]any + if err := common.Unmarshal(rawBody, &raw); err != nil { + return "mp4" + } + respObj, ok := raw["response"].(map[string]any) + if !ok { + return "mp4" + } + vids, ok := respObj["videos"].([]any) + if !ok || len(vids) == 0 { + return "mp4" + } + v0, ok := vids[0].(map[string]any) + if !ok { + return "mp4" + } + mt, ok := v0["mimeType"].(string) + if !ok || mt == "" || strings.Contains(mt, "mp4") { + return "mp4" + } + return mt +} + +// mapTaskStatusToSimple 将内部 TaskStatus 映射为简化状态字符串 +func mapTaskStatusToSimple(status model.TaskStatus) string { + switch status { + case model.TaskStatusSuccess: + return "succeeded" + case model.TaskStatusFailure: + return "failed" + case model.TaskStatusQueued, model.TaskStatusSubmitted: + return "queued" + default: + return "processing" + } +} + func TaskModel2Dto(task *model.Task) *dto.TaskDto { return &dto.TaskDto{ + ID: task.ID, + CreatedAt: task.CreatedAt, + UpdatedAt: task.UpdatedAt, TaskID: task.TaskID, + Platform: string(task.Platform), + UserId: task.UserId, + Group: task.Group, + ChannelId: task.ChannelId, + Quota: task.Quota, Action: task.Action, Status: string(task.Status), FailReason: task.FailReason, + ResultURL: task.GetResultURL(), SubmitTime: task.SubmitTime, StartTime: task.StartTime, FinishTime: task.FinishTime, Progress: task.Progress, + Properties: task.Properties, + Username: task.Username, Data: task.Data, } } diff --git a/router/video-router.go b/router/video-router.go index d5fed1d78..d2bce42b2 100644 --- a/router/video-router.go +++ b/router/video-router.go @@ -8,10 +8,16 @@ import ( ) func SetVideoRouter(router *gin.Engine) { + // Video proxy: accepts either session auth (dashboard) or token auth (API clients) + videoProxyRouter := router.Group("/v1") + videoProxyRouter.Use(middleware.TokenOrUserAuth()) + { + videoProxyRouter.GET("/videos/:task_id/content", controller.VideoProxy) + } + videoV1Router := router.Group("/v1") videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) { - videoV1Router.GET("/videos/:task_id/content", controller.VideoProxy) videoV1Router.POST("/video/generations", controller.RelayTask) videoV1Router.GET("/video/generations/:task_id", controller.RelayTask) videoV1Router.POST("/videos/:video_id/remix", controller.RelayTask) diff --git a/service/billing_session.go b/service/billing_session.go index 1a31316b5..f24b68e55 100644 --- a/service/billing_session.go +++ b/service/billing_session.go @@ -193,6 +193,11 @@ func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIErro // shouldTrust 统一信任额度检查,适用于钱包和订阅。 func (s *BillingSession) shouldTrust(c *gin.Context) bool { + // 异步任务(ForcePreConsume=true)必须预扣全额,不允许信任旁路 + if s.relayInfo.ForcePreConsume { + return false + } + trustQuota := common.GetTrustQuota() if trustQuota <= 0 { return false diff --git a/service/error.go b/service/error.go index 7a9d7a815..a2ff0aad7 100644 --- a/service/error.go +++ b/service/error.go @@ -206,3 +206,16 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError { return taskError } + +// TaskErrorFromAPIError 将 PreConsumeBilling 返回的 NewAPIError 转换为 TaskError。 +func TaskErrorFromAPIError(apiErr *types.NewAPIError) *dto.TaskError { + if apiErr == nil { + return nil + } + return &dto.TaskError{ + Code: string(apiErr.GetErrorCode()), + Message: apiErr.Err.Error(), + StatusCode: apiErr.StatusCode, + Error: apiErr.Err, + } +} diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 771da5b77..1c440911b 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -204,7 +204,7 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, return info } -func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PerCallPriceData) map[string]interface{} { +func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PriceData) map[string]interface{} { other := make(map[string]interface{}) other["model_price"] = priceData.ModelPrice other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio diff --git a/service/task_billing.go b/service/task_billing.go new file mode 100644 index 000000000..ec0094bd9 --- /dev/null +++ b/service/task_billing.go @@ -0,0 +1,227 @@ +package service + +import ( + "context" + "fmt" + "strings" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/model" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/setting/ratio_setting" + "github.com/gin-gonic/gin" +) + +// LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。 +// 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。 +func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName string) { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("操作 %s", info.Action) + // 支持任务仅按次计费 + if common.StringsContains(constant.TaskPricePatches, modelName) { + logContent = fmt.Sprintf("%s,按次计费", logContent) + } else { + if len(info.PriceData.OtherRatios) > 0 { + var contents []string + for key, ra := range info.PriceData.OtherRatios { + if 1.0 != ra { + contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra)) + } + } + if len(contents) > 0 { + logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", ")) + } + } + } + other := make(map[string]interface{}) + other["request_path"] = c.Request.URL.Path + other["model_price"] = info.PriceData.ModelPrice + other["group_ratio"] = info.PriceData.GroupRatioInfo.GroupRatio + if info.PriceData.GroupRatioInfo.HasSpecialRatio { + other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio + } + model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ + ChannelId: info.ChannelId, + ModelName: modelName, + TokenName: tokenName, + Quota: info.PriceData.Quota, + Content: logContent, + TokenId: info.TokenId, + Group: info.UsingGroup, + Other: other, + }) + model.UpdateUserUsedQuotaAndRequestCount(info.UserId, info.PriceData.Quota) + model.UpdateChannelUsedQuota(info.ChannelId, info.PriceData.Quota) +} + +// --------------------------------------------------------------------------- +// 异步任务计费辅助函数 +// --------------------------------------------------------------------------- + +// resolveTokenKey 通过 TokenId 运行时获取令牌 Key(用于 Redis 缓存操作)。 +// 如果令牌已被删除或查询失败,返回空字符串。 +func resolveTokenKey(ctx context.Context, tokenId int, taskID string) string { + token, err := model.GetTokenById(tokenId) + if err != nil { + logger.LogWarn(ctx, fmt.Sprintf("获取令牌 key 失败 (tokenId=%d, task=%s): %s", tokenId, taskID, err.Error())) + return "" + } + return token.Key +} + +// taskIsSubscription 判断任务是否通过订阅计费。 +func taskIsSubscription(task *model.Task) bool { + return task.PrivateData.BillingSource == BillingSourceSubscription && task.PrivateData.SubscriptionId > 0 +} + +// taskAdjustFunding 调整任务的资金来源(钱包或订阅),delta > 0 表示扣费,delta < 0 表示退还。 +func taskAdjustFunding(task *model.Task, delta int) error { + if taskIsSubscription(task) { + return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta)) + } + if delta > 0 { + return model.DecreaseUserQuota(task.UserId, delta) + } + return model.IncreaseUserQuota(task.UserId, -delta, false) +} + +// taskAdjustTokenQuota 调整任务的令牌额度,delta > 0 表示扣费,delta < 0 表示退还。 +// 需要通过 resolveTokenKey 运行时获取 key(不从 PrivateData 中读取)。 +func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) { + if task.PrivateData.TokenId <= 0 || delta == 0 { + return + } + tokenKey := resolveTokenKey(ctx, task.PrivateData.TokenId, task.TaskID) + if tokenKey == "" { + return + } + var err error + if delta > 0 { + err = model.DecreaseTokenQuota(task.PrivateData.TokenId, tokenKey, delta) + } else { + err = model.IncreaseTokenQuota(task.PrivateData.TokenId, tokenKey, -delta) + } + if err != nil { + logger.LogWarn(ctx, fmt.Sprintf("调整令牌额度失败 (delta=%d, task=%s): %s", delta, task.TaskID, err.Error())) + } +} + +// RefundTaskQuota 统一的任务失败退款逻辑。 +// 当异步任务失败时,将预扣的 quota 退还给用户(支持钱包和订阅),并退还令牌额度。 +func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) { + quota := task.Quota + if quota == 0 { + return + } + + // 1. 退还资金来源(钱包或订阅) + if err := taskAdjustFunding(task, -quota); err != nil { + logger.LogWarn(ctx, fmt.Sprintf("退还资金来源失败 task %s: %s", task.TaskID, err.Error())) + return + } + + // 2. 退还令牌额度 + taskAdjustTokenQuota(ctx, task, -quota) + + // 3. 记录日志 + logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s,原因:%s", task.TaskID, logger.LogQuota(quota), reason) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) +} + +// RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。 +// 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度, +// 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。 +func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTokens int) { + if totalTokens <= 0 { + return + } + + // 获取模型名称 + var taskData map[string]interface{} + if err := common.Unmarshal(task.Data, &taskData); err != nil { + return + } + modelName, ok := taskData["model"].(string) + if !ok || modelName == "" { + return + } + + // 获取模型价格和倍率 + modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName) + // 只有配置了倍率(非固定价格)时才按 token 重新计费 + if !hasRatioSetting || modelRatio <= 0 { + return + } + + // 获取用户和组的倍率信息 + group := task.Group + if group == "" { + user, err := model.GetUserById(task.UserId, false) + if err == nil { + group = user.Group + } + } + if group == "" { + return + } + + groupRatio := ratio_setting.GetGroupRatio(group) + userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group) + + var finalGroupRatio float64 + if hasUserGroupRatio { + finalGroupRatio = userGroupRatio + } else { + finalGroupRatio = groupRatio + } + + // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio + actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio) + + // 计算差额(正数=需要补扣,负数=需要退还) + preConsumedQuota := task.Quota + quotaDelta := actualQuota - preConsumedQuota + + if quotaDelta == 0 { + logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)", + task.TaskID, logger.LogQuota(actualQuota), totalTokens)) + return + } + + logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,tokens:%d)", + task.TaskID, + logger.LogQuota(quotaDelta), + logger.LogQuota(actualQuota), + logger.LogQuota(preConsumedQuota), + totalTokens, + )) + + // 调整资金来源 + if err := taskAdjustFunding(task, quotaDelta); err != nil { + logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error())) + return + } + + // 调整令牌额度 + taskAdjustTokenQuota(ctx, task, quotaDelta) + + // 更新统计(仅补扣时更新,退还不影响已用统计) + if quotaDelta > 0 { + model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) + model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) + } + task.Quota = actualQuota + + var action string + if quotaDelta > 0 { + action = "补扣费" + } else { + action = "退还" + } + logContent := fmt.Sprintf("视频任务成功%s,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s", + action, modelRatio, finalGroupRatio, totalTokens, + logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota)) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) +} diff --git a/service/task_polling.go b/service/task_polling.go new file mode 100644 index 000000000..847e1659b --- /dev/null +++ b/service/task_polling.go @@ -0,0 +1,446 @@ +package service + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "sort" + "strings" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" + relaycommon "github.com/QuantumNous/new-api/relay/common" + + "github.com/samber/lo" +) + +// TaskPollingAdaptor 定义轮询所需的最小适配器接口,避免 service -> relay 的循环依赖 +type TaskPollingAdaptor interface { + Init(info *relaycommon.RelayInfo) + FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error) + ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error) +} + +// GetTaskAdaptorFunc 由 main 包注入,用于获取指定平台的任务适配器。 +// 打破 service -> relay -> relay/channel -> service 的循环依赖。 +var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor + +// TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务 +func TaskPollingLoop() { + for { + time.Sleep(time.Duration(15) * time.Second) + common.SysLog("任务进度轮询开始") + ctx := context.TODO() + allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit) + platformTask := make(map[constant.TaskPlatform][]*model.Task) + for _, t := range allTasks { + platformTask[t.Platform] = append(platformTask[t.Platform], t) + } + for platform, tasks := range platformTask { + if len(tasks) == 0 { + continue + } + taskChannelM := make(map[int][]string) + taskM := make(map[string]*model.Task) + nullTaskIds := make([]int64, 0) + for _, task := range tasks { + upstreamID := task.GetUpstreamTaskID() + if upstreamID == "" { + // 统计失败的未完成任务 + nullTaskIds = append(nullTaskIds, task.ID) + continue + } + taskM[upstreamID] = task + taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], upstreamID) + } + if len(nullTaskIds) > 0 { + err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{ + "status": "FAILURE", + "progress": "100%", + }) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) + } else { + logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) + } + } + if len(taskChannelM) == 0 { + continue + } + + DispatchPlatformUpdate(platform, taskChannelM, taskM) + } + common.SysLog("任务进度轮询完成") + } +} + +// DispatchPlatformUpdate 按平台分发轮询更新 +func DispatchPlatformUpdate(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) { + switch platform { + case constant.TaskPlatformMidjourney: + // MJ 轮询由其自身处理,这里预留入口 + case constant.TaskPlatformSuno: + _ = UpdateSunoTasks(context.Background(), taskChannelM, taskM) + default: + if err := UpdateVideoTasks(context.Background(), platform, taskChannelM, taskM); err != nil { + common.SysLog(fmt.Sprintf("UpdateVideoTasks fail: %s", err)) + } + } +} + +// UpdateSunoTasks 按渠道更新所有 Suno 任务 +func UpdateSunoTasks(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error { + for channelId, taskIds := range taskChannelM { + err := updateSunoTasks(ctx, channelId, taskIds, taskM) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error())) + } + } + return nil +} + +func updateSunoTasks(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { + logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) + if len(taskIds) == 0 { + return nil + } + ch, err := model.CacheGetChannel(channelId) + if err != nil { + common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) + // Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values) + var failedIDs []int64 + for _, upstreamID := range taskIds { + if t, ok := taskM[upstreamID]; ok { + failedIDs = append(failedIDs, t.ID) + } + } + err = model.TaskBulkUpdateByID(failedIDs, map[string]any{ + "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), + "status": "FAILURE", + "progress": "100%", + }) + if err != nil { + common.SysLog(fmt.Sprintf("UpdateSunoTask error: %v", err)) + } + return err + } + adaptor := GetTaskAdaptorFunc(constant.TaskPlatformSuno) + if adaptor == nil { + return errors.New("adaptor not found") + } + proxy := ch.GetSetting().Proxy + resp, err := adaptor.FetchTask(*ch.BaseURL, ch.Key, map[string]any{ + "ids": taskIds, + }, proxy) + if err != nil { + common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err)) + return err + } + if resp.StatusCode != http.StatusOK { + logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + return fmt.Errorf("Get Task status code: %d", resp.StatusCode) + } + defer resp.Body.Close() + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err)) + return err + } + var responseItems dto.TaskResponse[[]dto.SunoDataResponse] + err = common.Unmarshal(responseBody, &responseItems) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) + return err + } + if !responseItems.IsSuccess() { + common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody))) + return err + } + + for _, responseItem := range responseItems.Data { + task := taskM[responseItem.TaskID] + if !taskNeedsUpdate(task, responseItem) { + continue + } + + task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status) + task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason) + task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime) + task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime) + task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime) + if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { + logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) + task.Progress = "100%" + RefundTaskQuota(ctx, task, task.FailReason) + } + if responseItem.Status == model.TaskStatusSuccess { + task.Progress = "100%" + } + task.Data = responseItem.Data + + err = task.Update() + if err != nil { + common.SysLog("UpdateSunoTask task error: " + err.Error()) + } + } + return nil +} + +// taskNeedsUpdate 检查 Suno 任务是否需要更新 +func taskNeedsUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool { + if oldTask.SubmitTime != newTask.SubmitTime { + return true + } + if oldTask.StartTime != newTask.StartTime { + return true + } + if oldTask.FinishTime != newTask.FinishTime { + return true + } + if string(oldTask.Status) != newTask.Status { + return true + } + if oldTask.FailReason != newTask.FailReason { + return true + } + + if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" { + return true + } + + oldData, _ := common.Marshal(oldTask.Data) + newData, _ := common.Marshal(newTask.Data) + + sort.Slice(oldData, func(i, j int) bool { + return oldData[i] < oldData[j] + }) + sort.Slice(newData, func(i, j int) bool { + return newData[i] < newData[j] + }) + + if string(oldData) != string(newData) { + return true + } + return false +} + +// UpdateVideoTasks 按渠道更新所有视频任务 +func UpdateVideoTasks(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error { + for channelId, taskIds := range taskChannelM { + if err := updateVideoTasks(ctx, platform, channelId, taskIds, taskM); err != nil { + logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) + } + } + return nil +} + +func updateVideoTasks(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error { + logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) + if len(taskIds) == 0 { + return nil + } + cacheGetChannel, err := model.CacheGetChannel(channelId) + if err != nil { + // Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values) + var failedIDs []int64 + for _, upstreamID := range taskIds { + if t, ok := taskM[upstreamID]; ok { + failedIDs = append(failedIDs, t.ID) + } + } + errUpdate := model.TaskBulkUpdateByID(failedIDs, map[string]any{ + "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId), + "status": "FAILURE", + "progress": "100%", + }) + if errUpdate != nil { + common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) + } + return fmt.Errorf("CacheGetChannel failed: %w", err) + } + adaptor := GetTaskAdaptorFunc(platform) + if adaptor == nil { + return fmt.Errorf("video adaptor not found") + } + info := &relaycommon.RelayInfo{} + info.ChannelMeta = &relaycommon.ChannelMeta{ + ChannelBaseUrl: cacheGetChannel.GetBaseURL(), + } + info.ApiKey = cacheGetChannel.Key + adaptor.Init(info) + for _, taskId := range taskIds { + if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { + logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) + } + } + return nil +} + +func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *model.Channel, taskId string, taskM map[string]*model.Task) error { + baseURL := constant.ChannelBaseURLs[ch.Type] + if ch.GetBaseURL() != "" { + baseURL = ch.GetBaseURL() + } + proxy := ch.GetSetting().Proxy + + task := taskM[taskId] + if task == nil { + logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) + return fmt.Errorf("task %s not found", taskId) + } + key := ch.Key + + privateData := task.PrivateData + if privateData.Key != "" { + key = privateData.Key + } + resp, err := adaptor.FetchTask(baseURL, key, map[string]any{ + "task_id": task.GetUpstreamTaskID(), + "action": task.Action, + }, proxy) + if err != nil { + return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err) + } + defer resp.Body.Close() + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("readAll failed for task %s: %w", taskId, err) + } + + logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody))) + + taskResult := &relaycommon.TaskInfo{} + // try parse as New API response format + var responseItems dto.TaskResponse[model.Task] + if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() { + logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask parsed as new api response format: %+v", responseItems)) + t := responseItems.Data + taskResult.TaskID = t.TaskID + taskResult.Status = string(t.Status) + taskResult.Url = t.GetResultURL() + taskResult.Progress = t.Progress + taskResult.Reason = t.FailReason + task.Data = t.Data + } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil { + return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err) + } else { + task.Data = redactVideoResponseBody(responseBody) + } + + logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask taskResult: %+v", taskResult)) + + now := time.Now().Unix() + if taskResult.Status == "" { + taskResult = relaycommon.FailTaskInfo("upstream returned empty status") + } + + // 记录原本的状态,防止重复退款 + shouldRefund := false + quota := task.Quota + preStatus := task.Status + + task.Status = model.TaskStatus(taskResult.Status) + switch taskResult.Status { + case model.TaskStatusSubmitted: + task.Progress = taskcommon.ProgressSubmitted + case model.TaskStatusQueued: + task.Progress = taskcommon.ProgressQueued + case model.TaskStatusInProgress: + task.Progress = taskcommon.ProgressInProgress + if task.StartTime == 0 { + task.StartTime = now + } + case model.TaskStatusSuccess: + task.Progress = taskcommon.ProgressComplete + if task.FinishTime == 0 { + task.FinishTime = now + } + if strings.HasPrefix(taskResult.Url, "data:") { + // data: URI (e.g. Vertex base64 encoded video) — keep in Data, not in ResultURL + } else if taskResult.Url != "" { + // Direct upstream URL (e.g. Kling, Ali, Doubao, etc.) + task.PrivateData.ResultURL = taskResult.Url + } else { + // No URL from adaptor — construct proxy URL using public task ID + task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) + } + + // 如果返回了 total_tokens,根据模型倍率重新计费 + if taskResult.TotalTokens > 0 { + RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens) + } + case model.TaskStatusFailure: + logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task) + task.Status = model.TaskStatusFailure + task.Progress = taskcommon.ProgressComplete + if task.FinishTime == 0 { + task.FinishTime = now + } + task.FailReason = taskResult.Reason + logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) + taskResult.Progress = taskcommon.ProgressComplete + if quota != 0 { + if preStatus != model.TaskStatusFailure { + shouldRefund = true + } else { + logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID)) + } + } + default: + return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId) + } + if taskResult.Progress != "" { + task.Progress = taskResult.Progress + } + if err := task.Update(); err != nil { + common.SysLog("UpdateVideoTask task error: " + err.Error()) + shouldRefund = false + } + + if shouldRefund { + RefundTaskQuota(ctx, task, task.FailReason) + } + + return nil +} + +func redactVideoResponseBody(body []byte) []byte { + var m map[string]any + if err := common.Unmarshal(body, &m); err != nil { + return body + } + resp, _ := m["response"].(map[string]any) + if resp != nil { + delete(resp, "bytesBase64Encoded") + if v, ok := resp["video"].(string); ok { + resp["video"] = truncateBase64(v) + } + if vs, ok := resp["videos"].([]any); ok { + for i := range vs { + if vm, ok := vs[i].(map[string]any); ok { + delete(vm, "bytesBase64Encoded") + } + } + } + } + b, err := common.Marshal(m) + if err != nil { + return body + } + return b +} + +func truncateBase64(s string) string { + const maxKeep = 256 + if len(s) <= maxKeep { + return s + } + return s[:maxKeep] + "..." +} diff --git a/types/price_data.go b/types/price_data.go index 3f7121b8c..93bc6ae8d 100644 --- a/types/price_data.go +++ b/types/price_data.go @@ -22,7 +22,8 @@ type PriceData struct { AudioCompletionRatio float64 OtherRatios map[string]float64 UsePrice bool - QuotaToPreConsume int // 预消耗额度 + Quota int // 按次计费的最终额度(MJ / Task) + QuotaToPreConsume int // 按量计费的预消耗额度 GroupRatioInfo GroupRatioInfo } @@ -36,12 +37,6 @@ func (p *PriceData) AddOtherRatio(key string, ratio float64) { p.OtherRatios[key] = ratio } -type PerCallPriceData struct { - ModelPrice float64 - Quota int - GroupRatioInfo GroupRatioInfo -} - func (p *PriceData) ToSetting() string { return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, CacheCreation5mRatio: %f, CacheCreation1hRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.CacheCreation5mRatio, p.CacheCreation1hRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio) } diff --git a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx index c78d5773e..4bce45256 100644 --- a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx +++ b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx @@ -396,7 +396,7 @@ export const getTaskLogsColumns = ({ dataIndex: 'fail_reason', fixed: 'right', render: (text, record, index) => { - // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接 + // 视频预览:优先使用 result_url,兼容旧数据 fail_reason 中的 URL const isVideoTask = record.action === TASK_ACTION_GENERATE || record.action === TASK_ACTION_TEXT_GENERATE || @@ -404,14 +404,15 @@ export const getTaskLogsColumns = ({ record.action === TASK_ACTION_REFERENCE_GENERATE || record.action === TASK_ACTION_REMIX_GENERATE; const isSuccess = record.status === 'SUCCESS'; - const isUrl = typeof text === 'string' && /^https?:\/\//.test(text); - if (isSuccess && isVideoTask && isUrl) { + const resultUrl = record.result_url; + const hasResultUrl = typeof resultUrl === 'string' && /^https?:\/\//.test(resultUrl); + if (isSuccess && isVideoTask && hasResultUrl) { return ( { e.preventDefault(); - openVideoModal(text); + openVideoModal(resultUrl); }} > {t('点击预览视频')} diff --git a/web/src/components/table/task-logs/modals/ContentModal.jsx b/web/src/components/table/task-logs/modals/ContentModal.jsx index 88df4d8ce..3527fd96d 100644 --- a/web/src/components/table/task-logs/modals/ContentModal.jsx +++ b/web/src/components/table/task-logs/modals/ContentModal.jsx @@ -144,8 +144,6 @@ const ContentModal = ({ maxHeight: '100%', objectFit: 'contain', }} - autoPlay - crossOrigin='anonymous' onError={handleVideoError} onLoadedData={handleVideoLoaded} onLoadStart={() => setIsLoading(true)} From d6e11fd2e1764b4f3b2d78dfbbd3d867612ad946 Mon Sep 17 00:00:00 2001 From: CaIon Date: Tue, 10 Feb 2026 21:15:09 +0800 Subject: [PATCH 11/41] feat(task): add adaptor billing interface and async settlement framework Add three billing lifecycle methods to the TaskAdaptor interface: - EstimateBilling: compute OtherRatios from user request before pricing - AdjustBillingOnSubmit: adjust ratios from upstream submit response - AdjustBillingOnComplete: determine final quota at task terminal state Introduce BaseBilling as embeddable no-op default for adaptors without custom billing. Move Sora/Ali OtherRatios logic from shared validation into per-adaptor EstimateBilling implementations. Add TaskBillingContext to persist pricing params (model_price, group_ratio, other_ratios) in task private data for async polling settlement. Extract RecalculateTaskQuota as a general-purpose delta settlement function and unify polling billing via settleTaskBillingOnComplete (adaptor-first, then token-based fallback). --- controller/relay.go | 7 ++ logger/logger.go | 3 +- model/task.go | 16 +++- relay/channel/adapter.go | 30 +++++++- relay/channel/task/ali/adaptor.go | 57 +++++++++----- relay/channel/task/doubao/adaptor.go | 1 + relay/channel/task/gemini/adaptor.go | 1 + relay/channel/task/hailuo/adaptor.go | 2 + relay/channel/task/jimeng/adaptor.go | 1 + relay/channel/task/kling/adaptor.go | 1 + relay/channel/task/sora/adaptor.go | 44 ++++++++++- relay/channel/task/suno/adaptor.go | 7 +- relay/channel/task/taskcommon/helpers.go | 25 ++++++ relay/channel/task/vertex/adaptor.go | 41 +++++----- relay/channel/task/vidu/adaptor.go | 1 + relay/common/relay_utils.go | 10 +-- relay/relay_task.go | 64 ++++++++++++++-- service/task_billing.go | 98 +++++++++++++----------- service/task_polling.go | 28 ++++++- 19 files changed, 321 insertions(+), 116 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index 132fee9ba..3d2f20e82 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -509,6 +509,13 @@ func RelayTask(c *gin.Context) { task.PrivateData.BillingSource = relayInfo.BillingSource task.PrivateData.SubscriptionId = relayInfo.SubscriptionId task.PrivateData.TokenId = relayInfo.TokenId + task.PrivateData.BillingContext = &model.TaskBillingContext{ + ModelPrice: relayInfo.PriceData.ModelPrice, + GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio, + ModelRatio: relayInfo.PriceData.ModelRatio, + OtherRatios: relayInfo.PriceData.OtherRatios, + ModelName: result.ModelName, + } task.Quota = result.Quota task.Data = result.TaskData task.Action = relayInfo.Action diff --git a/logger/logger.go b/logger/logger.go index 61b1d49d8..90cf5006e 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -2,7 +2,6 @@ package logger import ( "context" - "encoding/json" "fmt" "io" "log" @@ -151,7 +150,7 @@ func FormatQuota(quota int) string { // LogJson 仅供测试使用 only for test func LogJson(ctx context.Context, msg string, obj any) { - jsonStr, err := json.Marshal(obj) + jsonStr, err := common.Marshal(obj) if err != nil { LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error())) return diff --git a/model/task.go b/model/task.go index 38bb4d05a..592643ebb 100644 --- a/model/task.go +++ b/model/task.go @@ -100,9 +100,19 @@ type TaskPrivateData struct { UpstreamTaskID string `json:"upstream_task_id,omitempty"` // 上游真实 task ID ResultURL string `json:"result_url,omitempty"` // 任务成功后的结果 URL(视频地址等) // 计费上下文:用于异步退款/差额结算(轮询阶段读取) - BillingSource string `json:"billing_source,omitempty"` // "wallet" 或 "subscription" - SubscriptionId int `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款 - TokenId int `json:"token_id,omitempty"` // 令牌 ID,用于令牌额度退款 + BillingSource string `json:"billing_source,omitempty"` // "wallet" 或 "subscription" + SubscriptionId int `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款 + TokenId int `json:"token_id,omitempty"` // 令牌 ID,用于令牌额度退款 + BillingContext *TaskBillingContext `json:"billing_context,omitempty"` // 计费参数快照(用于轮询阶段重新计算) +} + +// TaskBillingContext 记录任务提交时的计费参数,以便轮询阶段可以重新计算额度。 +type TaskBillingContext struct { + ModelPrice float64 `json:"model_price,omitempty"` // 模型单价 + GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率 + ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率 + OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等) + ModelName string `json:"model_name,omitempty"` // 模型名称 } // GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信) diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index ff7606e2e..d2f7c6bb6 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -36,6 +36,32 @@ type TaskAdaptor interface { ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError + // ── Billing ────────────────────────────────────────────────────── + + // EstimateBilling returns OtherRatios for pre-charge based on user request. + // Called after ValidateRequestAndSetAction, before price calculation. + // Adaptors should extract duration, resolution, etc. from the parsed request + // and return them as ratio multipliers (e.g. {"seconds": 5, "size": 1.666}). + // Return nil to use the base model price without extra ratios. + EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 + + // AdjustBillingOnSubmit returns adjusted OtherRatios from the upstream + // submit response. Called after a successful DoResponse. + // If the upstream returned actual parameters that differ from the estimate + // (e.g. actual seconds), return updated ratios so the caller can recalculate + // the quota and settle the delta with the pre-charge. + // Return nil if no adjustment is needed. + AdjustBillingOnSubmit(info *relaycommon.RelayInfo, taskData []byte) map[string]float64 + + // AdjustBillingOnComplete returns the actual quota when a task reaches a + // terminal state (success/failure) during polling. + // Called by the polling loop after ParseTaskResult. + // Return a positive value to trigger delta settlement (supplement / refund). + // Return 0 to keep the pre-charged amount unchanged. + AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int + + // ── Request / Response ─────────────────────────────────────────── + BuildRequestURL(info *relaycommon.RelayInfo) (string, error) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) @@ -46,9 +72,9 @@ type TaskAdaptor interface { GetModelList() []string GetChannelName() string - // FetchTask - FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) + // ── Polling ────────────────────────────────────────────────────── + FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) } diff --git a/relay/channel/task/ali/adaptor.go b/relay/channel/task/ali/adaptor.go index 5d14ff655..f55178b3b 100644 --- a/relay/channel/task/ali/adaptor.go +++ b/relay/channel/task/ali/adaptor.go @@ -13,6 +13,7 @@ import ( "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" "github.com/samber/lo" @@ -108,10 +109,10 @@ type AliMetadata struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string - aliReq *AliVideoRequest } func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { @@ -121,17 +122,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { } func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { - // 阿里通义万相支持 JSON 格式,不使用 multipart - var taskReq relaycommon.TaskSubmitReq - if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil { - return service.TaskErrorWrapper(err, "unmarshal_task_request_failed", http.StatusBadRequest) - } - aliReq, err := a.convertToAliRequest(info, taskReq) - if err != nil { - return service.TaskErrorWrapper(err, "convert_to_ali_request_failed", http.StatusInternalServerError) - } - a.aliReq = aliReq - logger.LogJson(c, "ali video request body", aliReq) + // ValidateMultipartDirect 负责解析并将原始 TaskSubmitReq 存入 context return relaycommon.ValidateMultipartDirect(c, info) } @@ -148,11 +139,21 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info } func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { - bodyBytes, err := common.Marshal(a.aliReq) + taskReq, err := relaycommon.GetTaskRequest(c) + if err != nil { + return nil, errors.Wrap(err, "get_task_request_failed") + } + + aliReq, err := a.convertToAliRequest(info, taskReq) + if err != nil { + return nil, errors.Wrap(err, "convert_to_ali_request_failed") + } + logger.LogJson(c, "ali video request body", aliReq) + + bodyBytes, err := common.Marshal(aliReq) if err != nil { return nil, errors.Wrap(err, "marshal_ali_request_failed") } - return bytes.NewReader(bodyBytes), nil } @@ -335,19 +336,33 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay return nil, errors.New("can't change model with metadata") } - info.PriceData.OtherRatios = map[string]float64{ + return aliReq, nil +} + +// EstimateBilling 根据用户请求参数计算 OtherRatios(时长、分辨率等)。 +// 在 ValidateRequestAndSetAction 之后、价格计算之前调用。 +func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 { + taskReq, err := relaycommon.GetTaskRequest(c) + if err != nil { + return nil + } + + aliReq, err := a.convertToAliRequest(info, taskReq) + if err != nil { + return nil + } + + otherRatios := map[string]float64{ "seconds": float64(aliReq.Parameters.Duration), } - ratios, err := ProcessAliOtherRatios(aliReq) if err != nil { - return nil, err + return otherRatios } - for s, f := range ratios { - info.PriceData.OtherRatios[s] = f + for k, v := range ratios { + otherRatios[k] = v } - - return aliReq, nil + return otherRatios } // DoRequest delegates to common helper diff --git a/relay/channel/task/doubao/adaptor.go b/relay/channel/task/doubao/adaptor.go index 3da125afc..eca421bd3 100644 --- a/relay/channel/task/doubao/adaptor.go +++ b/relay/channel/task/doubao/adaptor.go @@ -89,6 +89,7 @@ type responseTask struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string diff --git a/relay/channel/task/gemini/adaptor.go b/relay/channel/task/gemini/adaptor.go index a863ea852..06c00a469 100644 --- a/relay/channel/task/gemini/adaptor.go +++ b/relay/channel/task/gemini/adaptor.go @@ -85,6 +85,7 @@ type operationResponse struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string diff --git a/relay/channel/task/hailuo/adaptor.go b/relay/channel/task/hailuo/adaptor.go index 67a68a10e..ab83d659b 100644 --- a/relay/channel/task/hailuo/adaptor.go +++ b/relay/channel/task/hailuo/adaptor.go @@ -17,12 +17,14 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" ) // https://platform.minimaxi.com/docs/api-reference/video-generation-intro type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index 7f88be248..b61cca418 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -77,6 +77,7 @@ const ( // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int accessKey string secretKey string diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 4458626b2..46e210f19 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -97,6 +97,7 @@ type responsePayload struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index ee69a3e48..8faaf984f 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net/http" + "strconv" "strings" "github.com/QuantumNous/new-api/common" @@ -11,6 +12,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -56,6 +58,7 @@ type responseTask struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string @@ -68,15 +71,15 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { } func validateRemixRequest(c *gin.Context) *dto.TaskError { - var req struct { - Prompt string `json:"prompt"` - } + var req relaycommon.TaskSubmitReq if err := common.UnmarshalBodyReusable(c, &req); err != nil { return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) } if strings.TrimSpace(req.Prompt) == "" { return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest) } + // 存储原始请求到 context,与 ValidateMultipartDirect 路径保持一致 + c.Set("task_request", req) return nil } @@ -87,6 +90,41 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom return relaycommon.ValidateMultipartDirect(c, info) } +// EstimateBilling 根据用户请求的 seconds 和 size 计算 OtherRatios。 +func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 { + // remix 路径的 OtherRatios 已在 ResolveOriginTask 中设置 + if info.Action == constant.TaskActionRemix { + return nil + } + + req, err := relaycommon.GetTaskRequest(c) + if err != nil { + return nil + } + + seconds, _ := strconv.Atoi(req.Seconds) + if seconds == 0 { + seconds = req.Duration + } + if seconds <= 0 { + seconds = 4 + } + + size := req.Size + if size == "" { + size = "720x1280" + } + + ratios := map[string]float64{ + "seconds": float64(seconds), + "size": 1, + } + if size == "1792x1024" || size == "1024x1792" { + ratios["size"] = 1.666667 + } + return ratios +} + func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.Action == constant.TaskActionRemix { return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index 5dd62a70f..2dbb44f00 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -13,6 +13,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -20,6 +21,7 @@ import ( ) type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int } @@ -79,10 +81,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { sunoRequest, ok := c.Get("task_request") if !ok { - err := common.UnmarshalBodyReusable(c, &sunoRequest) - if err != nil { - return nil, err - } + return nil, fmt.Errorf("task_request not found in context") } data, err := common.Marshal(sunoRequest) if err != nil { diff --git a/relay/channel/task/taskcommon/helpers.go b/relay/channel/task/taskcommon/helpers.go index b1dde998b..27d6612d4 100644 --- a/relay/channel/task/taskcommon/helpers.go +++ b/relay/channel/task/taskcommon/helpers.go @@ -5,7 +5,10 @@ import ( "fmt" "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/model" + relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/setting/system_setting" + "github.com/gin-gonic/gin" ) // UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip. @@ -68,3 +71,25 @@ const ( ProgressInProgress = "30%" ProgressComplete = "100%" ) + +// --------------------------------------------------------------------------- +// BaseBilling — embeddable no-op implementations for TaskAdaptor billing methods. +// Adaptors that do not need custom billing can embed this struct directly. +// --------------------------------------------------------------------------- + +type BaseBilling struct{} + +// EstimateBilling returns nil (no extra ratios; use base model price). +func (BaseBilling) EstimateBilling(_ *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 { + return nil +} + +// AdjustBillingOnSubmit returns nil (no submit-time adjustment). +func (BaseBilling) AdjustBillingOnSubmit(_ *relaycommon.RelayInfo, _ []byte) map[string]float64 { + return nil +} + +// AdjustBillingOnComplete returns 0 (keep pre-charged amount). +func (BaseBilling) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int { + return 0 +} diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go index fb3a313ff..4931002dd 100644 --- a/relay/channel/task/vertex/adaptor.go +++ b/relay/channel/task/vertex/adaptor.go @@ -62,6 +62,7 @@ type operationResponse struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string @@ -133,6 +134,28 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info return nil } +// EstimateBilling 根据用户请求中的 sampleCount 计算 OtherRatios。 +func (a *TaskAdaptor) EstimateBilling(c *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 { + sampleCount := 1 + v, ok := c.Get("task_request") + if ok { + req := v.(relaycommon.TaskSubmitReq) + if req.Metadata != nil { + if sc, exists := req.Metadata["sampleCount"]; exists { + if i, ok := sc.(int); ok && i > 0 { + sampleCount = i + } + if f, ok := sc.(float64); ok && int(f) > 0 { + sampleCount = int(f) + } + } + } + } + return map[string]float64{ + "sampleCount": float64(sampleCount), + } +} + // BuildRequestBody converts request into Vertex specific format. func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { v, ok := c.Get("task_request") @@ -166,24 +189,6 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn return nil, fmt.Errorf("sampleCount must be greater than 0") } - // if req.Duration > 0 { - // body.Parameters["durationSeconds"] = req.Duration - // } else if req.Seconds != "" { - // seconds, err := strconv.Atoi(req.Seconds) - // if err != nil { - // return nil, errors.Wrap(err, "convert seconds to int failed") - // } - // body.Parameters["durationSeconds"] = seconds - // } - - info.PriceData.OtherRatios = map[string]float64{ - "sampleCount": float64(body.Parameters["sampleCount"].(int)), - } - - // if v, ok := body.Parameters["durationSeconds"]; ok { - // info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int)) - // } - data, err := common.Marshal(body) if err != nil { return nil, err diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index 1bab12f03..e689bf888 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -73,6 +73,7 @@ type creation struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int baseURL string } diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index b662f9053..3cbb18c22 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -173,16 +173,10 @@ func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError { if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) { return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true) } - info.PriceData.OtherRatios = map[string]float64{ - "seconds": float64(seconds), - "size": 1, - } - if lo.Contains([]string{"1792x1024", "1024x1792"}, size) { - info.PriceData.OtherRatios["size"] = 1.666667 - } + // OtherRatios 已移到 Sora adaptor 的 EstimateBilling 中设置 } - info.Action = action + storeTaskRequest(c, info, action, req) return nil } diff --git a/relay/relay_task.go b/relay/relay_task.go index d372ca2e8..7c6724d80 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -128,8 +128,9 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr } // RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次): -// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 → 计算价格 → -// 预扣费(仅首次,通过 info.Billing==nil 守卫)→ 构建/发送/解析上游请求。 +// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 → +// 估算计费(EstimateBilling) → 计算价格 → 预扣费(仅首次)→ +// 构建/发送/解析上游请求 → 提交后计费调整(AdjustBillingOnSubmit)。 // 控制器负责 defer Refund 和成功后 Settle。 func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) { info.InitChannelMeta(c) @@ -159,10 +160,20 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe info.PublicTaskID = model.GenerateTaskID() } - // 4. 价格计算 + // 4. 价格计算:基础模型价格 info.OriginModelName = modelName info.PriceData = helper.ModelPriceHelperPerCall(c, info) + // 5. 计费估算:让适配器根据用户请求提供 OtherRatios(时长、分辨率等) + // 必须在 ModelPriceHelperPerCall 之后调用(它会重建 PriceData)。 + // ResolveOriginTask 可能已在 remix 路径中预设了 OtherRatios,此处合并。 + if estimatedRatios := adaptor.EstimateBilling(c, info); len(estimatedRatios) > 0 { + for k, v := range estimatedRatios { + info.PriceData.AddOtherRatio(k, v) + } + } + + // 6. 将 OtherRatios 应用到基础额度 if !common.StringsContains(constant.TaskPricePatches, modelName) { for _, ra := range info.PriceData.OtherRatios { if ra != 1.0 { @@ -171,7 +182,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe } } - // 5. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过) + // 7. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过) if info.Billing == nil && !info.PriceData.FreeModel { info.ForcePreConsume = true if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil { @@ -179,13 +190,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe } } - // 6. 构建请求体 + // 8. 构建请求体 requestBody, err := adaptor.BuildRequestBody(c, info) if err != nil { return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) } - // 7. 发送请求 + // 9. 发送请求 resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) @@ -195,20 +206,59 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode) } - // 8. 解析响应 + // 10. 返回 OtherRatios 给下游(header 必须在 DoResponse 写 body 之前设置) + otherRatios := info.PriceData.OtherRatios + if otherRatios == nil { + otherRatios = map[string]float64{} + } + ratiosJSON, _ := common.Marshal(otherRatios) + c.Header("X-New-Api-Other-Ratios", string(ratiosJSON)) + + // 11. 解析响应 upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info) if taskErr != nil { return nil, taskErr } + // 11. 提交后计费调整:让适配器根据上游实际返回调整 OtherRatios + finalQuota := info.PriceData.Quota + if adjustedRatios := adaptor.AdjustBillingOnSubmit(info, taskData); len(adjustedRatios) > 0 { + // 基于调整后的 ratios 重新计算 quota + finalQuota = recalcQuotaFromRatios(info, adjustedRatios) + info.PriceData.OtherRatios = adjustedRatios + info.PriceData.Quota = finalQuota + } + return &TaskSubmitResult{ UpstreamTaskID: upstreamTaskID, TaskData: taskData, Platform: platform, ModelName: modelName, + Quota: finalQuota, }, nil } +// recalcQuotaFromRatios 根据 adjustedRatios 重新计算 quota。 +// 公式: baseQuota × ∏(ratio) — 其中 baseQuota 是不含 OtherRatios 的基础额度。 +func recalcQuotaFromRatios(info *relaycommon.RelayInfo, ratios map[string]float64) int { + // 从 PriceData 获取不含 OtherRatios 的基础价格 + baseQuota := info.PriceData.Quota + // 先除掉原有的 OtherRatios 恢复基础额度 + for _, ra := range info.PriceData.OtherRatios { + if ra != 1.0 && ra > 0 { + baseQuota = int(float64(baseQuota) / ra) + } + } + // 应用新的 ratios + result := float64(baseQuota) + for _, ra := range ratios { + if ra != 1.0 { + result *= ra + } + } + return int(result) +} + var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, diff --git a/service/task_billing.go b/service/task_billing.go index ec0094bd9..fc44c5876 100644 --- a/service/task_billing.go +++ b/service/task_billing.go @@ -130,6 +130,58 @@ func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) { model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } +// RecalculateTaskQuota 通用的异步差额结算。 +// actualQuota 是任务完成后的实际应扣额度,与预扣额度 (task.Quota) 做差额结算。 +// reason 用于日志记录(例如 "token重算" 或 "adaptor调整")。 +func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int, reason string) { + if actualQuota <= 0 { + return + } + preConsumedQuota := task.Quota + quotaDelta := actualQuota - preConsumedQuota + + if quotaDelta == 0 { + logger.LogInfo(ctx, fmt.Sprintf("任务 %s 预扣费准确(%s,%s)", + task.TaskID, logger.LogQuota(actualQuota), reason)) + return + } + + logger.LogInfo(ctx, fmt.Sprintf("任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,%s)", + task.TaskID, + logger.LogQuota(quotaDelta), + logger.LogQuota(actualQuota), + logger.LogQuota(preConsumedQuota), + reason, + )) + + // 调整资金来源 + if err := taskAdjustFunding(task, quotaDelta); err != nil { + logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error())) + return + } + + // 调整令牌额度 + taskAdjustTokenQuota(ctx, task, quotaDelta) + + // 更新统计(仅补扣时更新,退还不影响已用统计) + if quotaDelta > 0 { + model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) + model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) + } + task.Quota = actualQuota + + var action string + if quotaDelta > 0 { + action = "补扣费" + } else { + action = "退还" + } + logContent := fmt.Sprintf("异步任务成功%s,预扣费 %s,实际扣费 %s,原因:%s", + action, + logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), reason) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) +} + // RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。 // 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度, // 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。 @@ -180,48 +232,6 @@ func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTo // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio) - // 计算差额(正数=需要补扣,负数=需要退还) - preConsumedQuota := task.Quota - quotaDelta := actualQuota - preConsumedQuota - - if quotaDelta == 0 { - logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)", - task.TaskID, logger.LogQuota(actualQuota), totalTokens)) - return - } - - logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,tokens:%d)", - task.TaskID, - logger.LogQuota(quotaDelta), - logger.LogQuota(actualQuota), - logger.LogQuota(preConsumedQuota), - totalTokens, - )) - - // 调整资金来源 - if err := taskAdjustFunding(task, quotaDelta); err != nil { - logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error())) - return - } - - // 调整令牌额度 - taskAdjustTokenQuota(ctx, task, quotaDelta) - - // 更新统计(仅补扣时更新,退还不影响已用统计) - if quotaDelta > 0 { - model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) - model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) - } - task.Quota = actualQuota - - var action string - if quotaDelta > 0 { - action = "补扣费" - } else { - action = "退还" - } - logContent := fmt.Sprintf("视频任务成功%s,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s", - action, modelRatio, finalGroupRatio, totalTokens, - logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + reason := fmt.Sprintf("token重算:tokens=%d, modelRatio=%.2f, groupRatio=%.2f", totalTokens, modelRatio, finalGroupRatio) + RecalculateTaskQuota(ctx, task, actualQuota, reason) } diff --git a/service/task_polling.go b/service/task_polling.go index 847e1659b..efbad8afa 100644 --- a/service/task_polling.go +++ b/service/task_polling.go @@ -26,6 +26,9 @@ type TaskPollingAdaptor interface { Init(info *relaycommon.RelayInfo) FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error) ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error) + // AdjustBillingOnComplete 在任务到达终态(成功/失败)时由轮询循环调用。 + // 返回正数触发差额结算(补扣/退还),返回 0 保持预扣费金额不变。 + AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int } // GetTaskAdaptorFunc 由 main 包注入,用于获取指定平台的任务适配器。 @@ -372,10 +375,8 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch * task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) } - // 如果返回了 total_tokens,根据模型倍率重新计费 - if taskResult.TotalTokens > 0 { - RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens) - } + // 完成时计费调整:优先由 adaptor 计算,回退到 token 重算 + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) case model.TaskStatusFailure: logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task) task.Status = model.TaskStatusFailure @@ -444,3 +445,22 @@ func truncateBase64(s string) string { } return s[:maxKeep] + "..." } + +// settleTaskBillingOnComplete 任务完成时的统一计费调整。 +// 优先级:1. adaptor.AdjustBillingOnComplete 返回正数 → 使用 adaptor 计算的额度 +// +// 2. taskResult.TotalTokens > 0 → 按 token 重算 +// 3. 都不满足 → 保持预扣额度不变 +func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) { + // 1. 优先让 adaptor 决定最终额度 + if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 { + RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整") + return + } + // 2. 回退到 token 重算 + if taskResult.TotalTokens > 0 { + RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens) + return + } + // 3. 无调整,保持预扣额度 +} From 809ba92089685b761e47b28d550f0d77356c3e68 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sat, 21 Feb 2026 22:48:30 +0800 Subject: [PATCH 12/41] refactor(logs): add refund logging for asynchronous tasks and update translations --- controller/relay.go | 128 +++++++----------- model/log.go | 43 ++++++ service/task_billing.go | 75 ++++++++-- .../table/usage-logs/UsageLogsColumnDefs.jsx | 30 +++- .../table/usage-logs/UsageLogsFilters.jsx | 1 + web/src/hooks/usage-logs/useUsageLogsData.jsx | 24 +++- web/src/i18n/locales/en.json | 5 + web/src/i18n/locales/fr.json | 5 + web/src/i18n/locales/ja.json | 5 + web/src/i18n/locales/ru.json | 5 + web/src/i18n/locales/vi.json | 4 + web/src/i18n/locales/zh-CN.json | 5 + 12 files changed, 229 insertions(+), 101 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index 3d2f20e82..e90d6dd0c 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -451,8 +451,6 @@ func RelayNotFound(c *gin.Context) { } func RelayTask(c *gin.Context) { - channelId := c.GetInt("channel_id") - c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil) if err != nil { c.JSON(http.StatusInternalServerError, &dto.TaskError{ @@ -463,8 +461,7 @@ func RelayTask(c *gin.Context) { return } - // Fetch 操作是纯 DB 查询(或 task 自带 channelId 的上游查询),不依赖上下文 channel,无需重试 - // TODO: 在video-route层面优化,避免无谓的 channel 选择和上下文设置,也没必要吧代码放到这里来写这么多屎山 + // Fetch 路径:纯 DB 查询,不依赖上下文 channel,无需重试 switch relayInfo.RelayMode { case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID: if taskErr := relay.RelayTaskFetch(c, relayInfo.RelayMode); taskErr != nil { @@ -475,13 +472,11 @@ func RelayTask(c *gin.Context) { // ── Submit 路径 ───────────────────────────────────────────────── - // 1. 解析原始任务(remix / continuation),一次性,可能锁定渠道并禁止重试 if taskErr := relay.ResolveOriginTask(c, relayInfo); taskErr != nil { respondTaskError(c, taskErr) return } - // 2. defer Refund(全部失败时回滚预扣费) var result *relay.TaskSubmitResult var taskErr *dto.TaskError defer func() { @@ -490,14 +485,57 @@ func RelayTask(c *gin.Context) { } }() - // 3. 执行 + 重试(RelayTaskSubmit 内部在首次调用时自动预扣费) - taskErr = taskSubmitWithRetry(c, relayInfo, channelId, common.RetryTimes, func() *dto.TaskError { - var te *dto.TaskError - result, te = relay.RelayTaskSubmit(c, relayInfo) - return te - }) + retryParam := &service.RetryParam{ + Ctx: c, + TokenGroup: relayInfo.TokenGroup, + ModelName: relayInfo.OriginModelName, + Retry: common.GetPointer(0), + } - // 4. 成功:结算 + 日志 + 插入任务 + for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() { + channel, channelErr := getChannel(c, relayInfo, retryParam) + if channelErr != nil { + logger.LogError(c, channelErr.Error()) + taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError) + break + } + + addUsedChannel(c, channel.Id) + requestBody, bodyErr := common.GetRequestBody(c) + if bodyErr != nil { + if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) { + taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusRequestEntityTooLarge) + } else { + taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusBadRequest) + } + break + } + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + + result, taskErr = relay.RelayTaskSubmit(c, relayInfo) + if taskErr == nil { + break + } + + if !taskErr.LocalError { + processChannelError(c, + *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, + common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), + types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode)) + } + + if !shouldRetryTaskRelay(c, channel.Id, taskErr, common.RetryTimes-retryParam.GetRetry()) { + break + } + } + + useChannel := c.GetStringSlice("use_channel") + if len(useChannel) > 1 { + retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) + logger.LogInfo(c, retryLogStr) + } + + // ── 成功:结算 + 日志 + 插入任务 ── if taskErr == nil { if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil { common.SysError("settle task billing error: " + settleErr.Error()) @@ -520,7 +558,6 @@ func RelayTask(c *gin.Context) { task.Data = result.TaskData task.Action = relayInfo.Action if insertErr := task.Insert(); insertErr != nil { - //taskErr = service.TaskErrorWrapper(insertErr, "insert_task_failed", http.StatusInternalServerError) common.SysError("insert task error: " + insertErr.Error()) } } @@ -538,69 +575,6 @@ func respondTaskError(c *gin.Context, taskErr *dto.TaskError) { c.JSON(taskErr.StatusCode, taskErr) } -// taskSubmitWithRetry 执行首次尝试并在失败时切换渠道重试,返回最终的 taskErr。 -// attempt 闭包负责实际的上游请求,不涉及计费。 -func taskSubmitWithRetry(c *gin.Context, relayInfo *relaycommon.RelayInfo, - channelId int, retryTimes int, attempt func() *dto.TaskError) *dto.TaskError { - - taskErr := attempt() - if taskErr == nil { - return nil - } - if !taskErr.LocalError { - processChannelError(c, - *types.NewChannelError(channelId, c.GetInt("channel_type"), c.GetString("channel_name"), common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey), - common.GetContextKeyString(c, constant.ContextKeyChannelKey), common.GetContextKeyBool(c, constant.ContextKeyChannelAutoBan)), - types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode)) - } - - retryParam := &service.RetryParam{ - Ctx: c, - TokenGroup: relayInfo.TokenGroup, - ModelName: relayInfo.OriginModelName, - Retry: common.GetPointer(0), - } - for ; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && retryParam.GetRetry() < retryTimes; retryParam.IncreaseRetry() { - channel, newAPIError := getChannel(c, relayInfo, retryParam) - if newAPIError != nil { - logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error())) - taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError) - break - } - channelId = channel.Id - useChannel := c.GetStringSlice("use_channel") - useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) - c.Set("use_channel", useChannel) - logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry())) - middleware.SetupContextForSelectedChannel(c, channel, c.GetString("original_model")) - - bodyStorage, err := common.GetBodyStorage(c) - if err != nil { - if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) { - taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusRequestEntityTooLarge) - } else { - taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusBadRequest) - } - break - } - c.Request.Body = io.NopCloser(bodyStorage) - taskErr = attempt() - if taskErr != nil && !taskErr.LocalError { - processChannelError(c, - *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, - common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), - types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode)) - } - } - - useChannel := c.GetStringSlice("use_channel") - if len(useChannel) > 1 { - retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - logger.LogInfo(c, retryLogStr) - } - return taskErr -} - func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool { if taskErr == nil { return false diff --git a/model/log.go b/model/log.go index d7cd97a42..1f521b1e5 100644 --- a/model/log.go +++ b/model/log.go @@ -199,6 +199,49 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) } } +type RecordTaskBillingLogParams struct { + UserId int + LogType int + Content string + ChannelId int + ModelName string + Quota int + TokenId int + Group string + Other map[string]interface{} +} + +func RecordTaskBillingLog(params RecordTaskBillingLogParams) { + if params.LogType == LogTypeConsume && !common.LogConsumeEnabled { + return + } + username, _ := GetUsernameById(params.UserId, false) + tokenName := "" + if params.TokenId > 0 { + if token, err := GetTokenById(params.TokenId); err == nil { + tokenName = token.Name + } + } + log := &Log{ + UserId: params.UserId, + Username: username, + CreatedAt: common.GetTimestamp(), + Type: params.LogType, + Content: params.Content, + TokenName: tokenName, + ModelName: params.ModelName, + Quota: params.Quota, + ChannelId: params.ChannelId, + TokenId: params.TokenId, + Group: params.Group, + Other: common.MapToJsonStr(params.Other), + } + err := LOG_DB.Create(log).Error + if err != nil { + common.SysLog("failed to record task billing log: " + err.Error()) + } +} + func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string, requestId string) (logs []*Log, total int64, err error) { var tx *gorm.DB if logType == LogTypeUnknown { diff --git a/service/task_billing.go b/service/task_billing.go index fc44c5876..78ad0fc09 100644 --- a/service/task_billing.go +++ b/service/task_billing.go @@ -108,6 +108,29 @@ func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) { } } +// taskBillingOther 从 task 的 BillingContext 构建日志 Other 字段。 +func taskBillingOther(task *model.Task) map[string]interface{} { + other := make(map[string]interface{}) + if bc := task.PrivateData.BillingContext; bc != nil { + other["model_price"] = bc.ModelPrice + other["group_ratio"] = bc.GroupRatio + if len(bc.OtherRatios) > 0 { + for k, v := range bc.OtherRatios { + other[k] = v + } + } + } + return other +} + +// taskModelName 从 BillingContext 或 Properties 中获取模型名称。 +func taskModelName(task *model.Task) string { + if bc := task.PrivateData.BillingContext; bc != nil && bc.ModelName != "" { + return bc.ModelName + } + return task.Properties.OriginModelName +} + // RefundTaskQuota 统一的任务失败退款逻辑。 // 当异步任务失败时,将预扣的 quota 退还给用户(支持钱包和订阅),并退还令牌额度。 func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) { @@ -126,8 +149,20 @@ func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) { taskAdjustTokenQuota(ctx, task, -quota) // 3. 记录日志 - logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s,原因:%s", task.TaskID, logger.LogQuota(quota), reason) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + other := taskBillingOther(task) + other["task_id"] = task.TaskID + other["reason"] = reason + model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{ + UserId: task.UserId, + LogType: model.LogTypeRefund, + Content: "", + ChannelId: task.ChannelId, + ModelName: taskModelName(task), + Quota: quota, + TokenId: task.PrivateData.TokenId, + Group: task.Group, + Other: other, + }) } // RecalculateTaskQuota 通用的异步差额结算。 @@ -163,23 +198,35 @@ func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int // 调整令牌额度 taskAdjustTokenQuota(ctx, task, quotaDelta) - // 更新统计(仅补扣时更新,退还不影响已用统计) - if quotaDelta > 0 { - model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) - model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) - } task.Quota = actualQuota - var action string + var logType int + var logQuota int if quotaDelta > 0 { - action = "补扣费" + logType = model.LogTypeConsume + logQuota = quotaDelta + model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) + model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) } else { - action = "退还" + logType = model.LogTypeRefund + logQuota = -quotaDelta } - logContent := fmt.Sprintf("异步任务成功%s,预扣费 %s,实际扣费 %s,原因:%s", - action, - logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), reason) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + other := taskBillingOther(task) + other["task_id"] = task.TaskID + other["reason"] = reason + other["pre_consumed_quota"] = preConsumedQuota + other["actual_quota"] = actualQuota + model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{ + UserId: task.UserId, + LogType: logType, + Content: "", + ChannelId: task.ChannelId, + ModelName: taskModelName(task), + Quota: logQuota, + TokenId: task.PrivateData.TokenId, + Group: task.Group, + Other: other, + }) } // RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。 diff --git a/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx b/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx index f0dcd379e..b1538877a 100644 --- a/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx +++ b/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx @@ -133,6 +133,12 @@ function renderType(type, t) { {t('错误')} ); + case 6: + return ( + + {t('退款')} + + ); default: return ( @@ -368,7 +374,7 @@ export const getLogsColumns = ({ } return isAdminUser && - (record.type === 0 || record.type === 2 || record.type === 5) ? ( + (record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6) ? ( @@ -459,7 +465,7 @@ export const getLogsColumns = ({ title: t('令牌'), dataIndex: 'token_name', render: (text, record, index) => { - return record.type === 0 || record.type === 2 || record.type === 5 ? ( + return record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6 ? (
{ - if (record.type === 0 || record.type === 2 || record.type === 5) { + if (record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6) { if (record.group) { return <>{renderGroup(record.group)}; } else { @@ -522,7 +528,7 @@ export const getLogsColumns = ({ title: t('模型'), dataIndex: 'model_name', render: (text, record, index) => { - return record.type === 0 || record.type === 2 || record.type === 5 ? ( + return record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6 ? ( <>{renderModelName(record, copyText, t)} ) : ( <> @@ -589,7 +595,7 @@ export const getLogsColumns = ({ cacheText = `${t('缓存写')} ${formatTokenCount(cacheSummary.cacheWriteTokens)}`; } - return record.type === 0 || record.type === 2 || record.type === 5 ? ( + return record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6 ? (
{ return parseInt(text) > 0 && - (record.type === 0 || record.type === 2 || record.type === 5) ? ( + (record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6) ? ( <>{ {text} } ) : ( <> @@ -635,7 +641,7 @@ export const getLogsColumns = ({ title: t('花费'), dataIndex: 'quota', render: (text, record, index) => { - if (!(record.type === 0 || record.type === 2 || record.type === 5)) { + if (!(record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6)) { return <>; } const other = getLogOther(record.other); @@ -722,6 +728,16 @@ export const getLogsColumns = ({ fixed: 'right', render: (text, record, index) => { let other = getLogOther(record.other); + if (record.type === 6) { + return ( + + {t('异步任务退款')} + + ); + } if (other == null || record.type !== 2) { return ( {t('管理')} {t('系统')} {t('错误')} + {t('退款')}
diff --git a/web/src/hooks/usage-logs/useUsageLogsData.jsx b/web/src/hooks/usage-logs/useUsageLogsData.jsx index 14c021e41..b69a7cf18 100644 --- a/web/src/hooks/usage-logs/useUsageLogsData.jsx +++ b/web/src/hooks/usage-logs/useUsageLogsData.jsx @@ -344,7 +344,7 @@ export const useLogsData = () => { let other = getLogOther(logs[i].other); let expandDataLocal = []; - if (isAdminUser && (logs[i].type === 0 || logs[i].type === 2)) { + if (isAdminUser && (logs[i].type === 0 || logs[i].type === 2 || logs[i].type === 6)) { expandDataLocal.push({ key: t('渠道信息'), value: `${logs[i].channel} - ${logs[i].channel_name || '[未知]'}`, @@ -535,6 +535,24 @@ export const useLogsData = () => { }); } } + if (logs[i].type === 6) { + if (other?.task_id) { + expandDataLocal.push({ + key: t('任务ID'), + value: other.task_id, + }); + } + if (other?.reason) { + expandDataLocal.push({ + key: t('失败原因'), + value: ( +
+ {other.reason} +
+ ), + }); + } + } if (other?.request_path) { expandDataLocal.push({ key: t('请求路径'), @@ -590,13 +608,13 @@ export const useLogsData = () => { ), }); } - if (isAdminUser) { + if (isAdminUser && logs[i].type !== 6) { expandDataLocal.push({ key: t('请求转换'), value: requestConversionDisplayValue(other?.request_conversion), }); } - if (isAdminUser) { + if (isAdminUser && logs[i].type !== 6) { let localCountMode = ''; if (other?.admin_info?.local_count_tokens) { localCountMode = t('本地计费'); diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index 8b2b08529..c25468339 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -2545,6 +2545,11 @@ "销毁容器": "Destroy Container", "销毁容器失败": "Failed to destroy container", "错误": "errors", + "退款": "Refund", + "错误详情": "Error Details", + "异步任务退款": "Async Task Refund", + "任务ID": "Task ID", + "失败原因": "Failure Reason", "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "The key is the group name, and the value is another JSON object. The key is the group name, and the value is the special group ratio for users in that group. For example: {\"vip\": {\"default\": 0.5, \"test\": 1}} means that users in the vip group have a ratio of 0.5 when using tokens from the default group, and a ratio of 1 when using tokens from the test group", "键为原状态码,值为要复写的状态码,仅影响本地判断": "The key is the original status code, and the value is the status code to override, only affects local judgment", "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "Keys are user group names and values are operation mappings. Inner keys prefixed with \"+:\" add the specified group (key is the group name, value is the description); keys prefixed with \"-:\" remove the specified group; keys without a prefix add that group directly. Example: {\"vip\": {\"+:premium\": \"Advanced group\", \"special\": \"Special group\", \"-:default\": \"Default group\"}} means vip users can access the premium and special groups while removing access to the default group.", diff --git a/web/src/i18n/locales/fr.json b/web/src/i18n/locales/fr.json index d4c76db69..54fd3617e 100644 --- a/web/src/i18n/locales/fr.json +++ b/web/src/i18n/locales/fr.json @@ -2508,6 +2508,11 @@ "销毁容器": "Destroy Container", "销毁容器失败": "Failed to destroy container", "错误": "Erreur", + "退款": "Remboursement", + "错误详情": "Détails de l'erreur", + "异步任务退款": "Remboursement de tâche asynchrone", + "任务ID": "ID de tâche", + "失败原因": "Raison de l'échec", "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "La clé est le nom du groupe, la valeur est un autre objet JSON, la clé est le nom du groupe, la valeur est le ratio de groupe spécial des utilisateurs de ce groupe, par exemple : {\"vip\": {\"default\": 0.5, \"test\": 1}}, ce qui signifie que les utilisateurs du groupe vip ont un ratio de 0.5 lors de l'utilisation de jetons du groupe default et un ratio de 1 lors de l'utilisation du groupe test", "键为原状态码,值为要复写的状态码,仅影响本地判断": "La clé est le code d'état d'origine, la valeur est le code d'état à réécrire, n'affecte que le jugement local", "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "La clé correspond au nom du groupe d'utilisateurs et la valeur à un objet de mappage des opérations. Les clés internes commençant par \"+:\" ajoutent le groupe indiqué (clé = nom du groupe, valeur = description), celles commençant par \"-:\" retirent le groupe indiqué, et les clés sans préfixe ajoutent directement ce groupe. Exemple : {\"vip\": {\"+:premium\": \"Groupe avancé\", \"special\": \"Groupe spécial\", \"-:default\": \"Groupe par défaut\"}} signifie que les utilisateurs du groupe vip peuvent accéder aux groupes premium et special tout en perdant l'accès au groupe default.", diff --git a/web/src/i18n/locales/ja.json b/web/src/i18n/locales/ja.json index 9ab727ec4..d9a49aa50 100644 --- a/web/src/i18n/locales/ja.json +++ b/web/src/i18n/locales/ja.json @@ -2491,6 +2491,11 @@ "销毁容器": "Destroy Container", "销毁容器失败": "Failed to destroy container", "错误": "エラー", + "退款": "返金", + "错误详情": "エラー詳細", + "异步任务退款": "非同期タスク返金", + "任务ID": "タスクID", + "失败原因": "失敗の原因", "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "キーはグループ名、値は別のJSONオブジェクトです。このオブジェクトのキーには、利用するトークンが属するグループ名を指定し、値にはそのユーザーグループに適用される特別な倍率を指定します。例:{\"vip\": {\"default\": 0.5, \"test\": 1}} は、vipグループのユーザーがdefaultグループのトークンを利用する際の倍率が0.5、testグループのトークンを利用する際の倍率が1になることを示します", "键为原状态码,值为要复写的状态码,仅影响本地判断": "キーは元のステータスコード、値は上書きするステータスコードで、ローカルでの判断にのみ影響します", "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "Keys are user group names and values are operation mappings. Inner keys prefixed with \"+:\" add the specified group (key is the group name, value is the description); keys prefixed with \"-:\" remove the specified group; keys without a prefix add that group directly. Example: {\"vip\": {\"+:premium\": \"Advanced group\", \"special\": \"Special group\", \"-:default\": \"Default group\"}} means vip users can access the premium and special groups while removing access to the default group.", diff --git a/web/src/i18n/locales/ru.json b/web/src/i18n/locales/ru.json index 97e243d37..fc117a51a 100644 --- a/web/src/i18n/locales/ru.json +++ b/web/src/i18n/locales/ru.json @@ -2521,6 +2521,11 @@ "销毁容器": "Destroy Container", "销毁容器失败": "Failed to destroy container", "错误": "Ошибка", + "退款": "Возврат", + "错误详情": "Детали ошибки", + "异步任务退款": "Возврат асинхронной задачи", + "任务ID": "ID задачи", + "失败原因": "Причина ошибки", "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "Ключ - это имя группы, значение - другой JSON объект, ключ - имя группы, значение - специальный групповой коэффициент для пользователей этой группы, например: {\"vip\": {\"default\": 0.5, \"test\": 1}}, означает, что пользователи группы vip при использовании токенов группы default имеют коэффициент 0.5, при использовании группы test - коэффициент 1", "键为原状态码,值为要复写的状态码,仅影响本地判断": "Ключ - исходный код состояния, значение - код состояния для перезаписи, влияет только на локальную проверку", "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "Ключ — это название группы пользователей, значение — объект сопоставления операций. Внутренние ключи с префиксом \"+:\" добавляют указанные группы (ключ — название группы, значение — описание), с префиксом \"-:\" удаляют указанные группы, без префикса — сразу добавляют эту группу. Пример: {\"vip\": {\"+:premium\": \"Продвинутая группа\", \"special\": \"Особая группа\", \"-:default\": \"Группа по умолчанию\"}} означает, что пользователи группы vip могут использовать группы premium и special, одновременно теряя доступ к группе default.", diff --git a/web/src/i18n/locales/vi.json b/web/src/i18n/locales/vi.json index 8875b1b5f..89d8715e2 100644 --- a/web/src/i18n/locales/vi.json +++ b/web/src/i18n/locales/vi.json @@ -3060,10 +3060,14 @@ "销毁容器失败": "Failed to destroy container", "锁定": "Khóa", "错误": "Lỗi", + "退款": "Hoàn tiền", "错误信息": "Thông tin lỗi", "错误日志": "Nhật ký lỗi", "错误码": "Mã lỗi", "错误详情": "Chi tiết lỗi", + "异步任务退款": "Hoàn tiền tác vụ bất đồng bộ", + "任务ID": "ID tác vụ", + "失败原因": "Nguyên nhân thất bại", "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "Khóa là tên nhóm và giá trị là một đối tượng JSON khác. Khóa là tên nhóm và giá trị là tỷ lệ nhóm đặc biệt cho người dùng trong nhóm đó. Ví dụ: {\"vip\": {\"default\": 0.5, \"test\": 1}} có nghĩa là người dùng trong nhóm vip có tỷ lệ 0.5 khi sử dụng mã thông báo từ nhóm default và tỷ lệ 1 khi sử dụng mã thông báo từ nhóm test.", "键为原状态码,值为要复写的状态码,仅影响本地判断": "Khóa là mã trạng thái gốc và giá trị là mã trạng thái cần ghi đè, chỉ ảnh hưởng đến phán đoán cục bộ", "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "Keys are user group names and values are operation mappings. Inner keys prefixed with \"+:\" add the specified group (key is the group name, value is the description); keys prefixed with \"-:\" remove the specified group; keys without a prefix add that group directly. Example: {\"vip\": {\"+:premium\": \"Advanced group\", \"special\": \"Special group\", \"-:default\": \"Default group\"}} means vip users can access the premium and special groups while removing access to the default group.", diff --git a/web/src/i18n/locales/zh-CN.json b/web/src/i18n/locales/zh-CN.json index 43ce65b7a..3cfcc0326 100644 --- a/web/src/i18n/locales/zh-CN.json +++ b/web/src/i18n/locales/zh-CN.json @@ -2531,6 +2531,11 @@ "销毁容器": "销毁容器", "销毁容器失败": "销毁容器失败", "错误": "错误", + "退款": "退款", + "错误详情": "错误详情", + "异步任务退款": "异步任务退款", + "任务ID": "任务ID", + "失败原因": "失败原因", "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1", "键为原状态码,值为要复写的状态码,仅影响本地判断": "键为原状态码,值为要复写的状态码,仅影响本地判断", "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限", From a920d1f9258389a093286db03509aabefa434063 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sat, 21 Feb 2026 23:05:58 +0800 Subject: [PATCH 13/41] refactor(relay): rename RelayTask to RelayTaskFetch and update routing - Renamed RelayTask function to RelayTaskFetch for clarity. - Updated routing in relay-router.go and video-router.go to use RelayTaskFetch for fetch operations. - Enhanced error handling in RelayTaskFetch function. - Adjusted task data conversion in TaskAdaptor to include task ID. --- controller/relay.go | 26 +++++++++++++++----------- relay/channel/task/sora/adaptor.go | 8 +++++++- router/relay-router.go | 4 ++-- router/video-router.go | 8 ++++---- 4 files changed, 28 insertions(+), 18 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index e90d6dd0c..1477df8f7 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -450,6 +450,21 @@ func RelayNotFound(c *gin.Context) { }) } +func RelayTaskFetch(c *gin.Context) { + relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil) + if err != nil { + c.JSON(http.StatusInternalServerError, &dto.TaskError{ + Code: "gen_relay_info_failed", + Message: err.Error(), + StatusCode: http.StatusInternalServerError, + }) + return + } + if taskErr := relay.RelayTaskFetch(c, relayInfo.RelayMode); taskErr != nil { + respondTaskError(c, taskErr) + } +} + func RelayTask(c *gin.Context) { relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil) if err != nil { @@ -461,17 +476,6 @@ func RelayTask(c *gin.Context) { return } - // Fetch 路径:纯 DB 查询,不依赖上下文 channel,无需重试 - switch relayInfo.RelayMode { - case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID: - if taskErr := relay.RelayTaskFetch(c, relayInfo.RelayMode); taskErr != nil { - respondTaskError(c, taskErr) - } - return - } - - // ── Submit 路径 ───────────────────────────────────────────────── - if taskErr := relay.ResolveOriginTask(c, relayInfo); taskErr != nil { respondTaskError(c, taskErr) return diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index 8faaf984f..bf2f70053 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -18,6 +18,7 @@ import ( "github.com/gin-gonic/gin" "github.com/pkg/errors" + "github.com/tidwall/sjson" ) // ============================ @@ -250,5 +251,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e } func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { - return task.Data, nil + data := task.Data + var err error + if data, err = sjson.SetBytes(data, "id", task.TaskID); err != nil { + return nil, errors.Wrap(err, "set id failed") + } + return data, nil } diff --git a/router/relay-router.go b/router/relay-router.go index 04584945b..dcec439cb 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -174,8 +174,8 @@ func SetRelayRouter(router *gin.Engine) { relaySunoRouter.Use(middleware.TokenAuth(), middleware.Distribute()) { relaySunoRouter.POST("/submit/:action", controller.RelayTask) - relaySunoRouter.POST("/fetch", controller.RelayTask) - relaySunoRouter.GET("/fetch/:id", controller.RelayTask) + relaySunoRouter.POST("/fetch", controller.RelayTaskFetch) + relaySunoRouter.GET("/fetch/:id", controller.RelayTaskFetch) } relayGeminiRouter := router.Group("/v1beta") diff --git a/router/video-router.go b/router/video-router.go index d2bce42b2..875b0af86 100644 --- a/router/video-router.go +++ b/router/video-router.go @@ -19,14 +19,14 @@ func SetVideoRouter(router *gin.Engine) { videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) { videoV1Router.POST("/video/generations", controller.RelayTask) - videoV1Router.GET("/video/generations/:task_id", controller.RelayTask) + videoV1Router.GET("/video/generations/:task_id", controller.RelayTaskFetch) videoV1Router.POST("/videos/:video_id/remix", controller.RelayTask) } // openai compatible API video routes // docs: https://platform.openai.com/docs/api-reference/videos/create { videoV1Router.POST("/videos", controller.RelayTask) - videoV1Router.GET("/videos/:task_id", controller.RelayTask) + videoV1Router.GET("/videos/:task_id", controller.RelayTaskFetch) } klingV1Router := router.Group("/kling/v1") @@ -34,8 +34,8 @@ func SetVideoRouter(router *gin.Engine) { { klingV1Router.POST("/videos/text2video", controller.RelayTask) klingV1Router.POST("/videos/image2video", controller.RelayTask) - klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTask) - klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTask) + klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTaskFetch) + klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTaskFetch) } // Jimeng official API routes - direct mapping to official API format From 76892e8376e46feb470034977e95cbad3273b717 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sat, 21 Feb 2026 23:20:31 +0800 Subject: [PATCH 14/41] refactor(relay): enhance remix logic for billing context extraction - Updated the remix handling in ResolveOriginTask to prioritize extracting OtherRatios from the BillingContext of the original task if available. - Retained the previous logic for extracting seconds and size from task data as a fallback. - Improved clarity and maintainability of the remix logic by separating the new and old approaches. --- relay/relay_task.go | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/relay/relay_task.go b/relay/relay_task.go index 7c6724d80..cc4d0e450 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -106,21 +106,29 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr // 提取 remix 参数(时长、分辨率 → OtherRatios) if info.Action == constant.TaskActionRemix { - var taskData map[string]interface{} - _ = common.Unmarshal(originTask.Data, &taskData) - secondsStr, _ := taskData["seconds"].(string) - seconds, _ := strconv.Atoi(secondsStr) - if seconds <= 0 { - seconds = 4 - } - sizeStr, _ := taskData["size"].(string) - if info.PriceData.OtherRatios == nil { - info.PriceData.OtherRatios = map[string]float64{} - } - info.PriceData.OtherRatios["seconds"] = float64(seconds) - info.PriceData.OtherRatios["size"] = 1 - if sizeStr == "1792x1024" || sizeStr == "1024x1792" { - info.PriceData.OtherRatios["size"] = 1.666667 + if originTask.PrivateData.BillingContext != nil { + // 新的 remix 逻辑:直接从原始任务的 BillingContext 中提取 OtherRatios(如果存在) + for s, f := range originTask.PrivateData.BillingContext.OtherRatios { + info.PriceData.AddOtherRatio(s, f) + } + } else { + // 旧的 remix 逻辑:直接从 task data 解析 seconds 和 size(如果存在) + var taskData map[string]interface{} + _ = common.Unmarshal(originTask.Data, &taskData) + secondsStr, _ := taskData["seconds"].(string) + seconds, _ := strconv.Atoi(secondsStr) + if seconds <= 0 { + seconds = 4 + } + sizeStr, _ := taskData["size"].(string) + if info.PriceData.OtherRatios == nil { + info.PriceData.OtherRatios = map[string]float64{} + } + info.PriceData.OtherRatios["seconds"] = float64(seconds) + info.PriceData.OtherRatios["size"] = 1 + if sizeStr == "1792x1024" || sizeStr == "1024x1792" { + info.PriceData.OtherRatios["size"] = 1.666667 + } } } From cda540180b0c2b0c8e9ee888b37110e15a8663a5 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sat, 21 Feb 2026 23:47:55 +0800 Subject: [PATCH 15/41] refactor(relay): improve channel locking and retry logic in RelayTask - Enhanced the RelayTask function to utilize a locked channel when available, allowing for better reuse during retries. - Updated error handling to ensure proper context setup for the selected channel. - Clarified comments in ResolveOriginTask regarding channel locking and retry behavior. - Introduced a new field in TaskRelayInfo to store the locked channel object, improving type safety and reducing import cycles. --- controller/relay.go | 23 ++++++++++++++++++----- relay/common/relay_info.go | 5 +++++ relay/relay_task.go | 26 +++++++++++++------------- 3 files changed, 36 insertions(+), 18 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index 1477df8f7..6951974c5 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -497,11 +497,24 @@ func RelayTask(c *gin.Context) { } for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() { - channel, channelErr := getChannel(c, relayInfo, retryParam) - if channelErr != nil { - logger.LogError(c, channelErr.Error()) - taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError) - break + var channel *model.Channel + + if lockedCh, ok := relayInfo.LockedChannel.(*model.Channel); ok && lockedCh != nil { + channel = lockedCh + if retryParam.GetRetry() > 0 { + if setupErr := middleware.SetupContextForSelectedChannel(c, channel, relayInfo.OriginModelName); setupErr != nil { + taskErr = service.TaskErrorWrapperLocal(setupErr.Err, "setup_locked_channel_failed", http.StatusInternalServerError) + break + } + } + } else { + var channelErr *types.NewAPIError + channel, channelErr = getChannel(c, relayInfo, retryParam) + if channelErr != nil { + logger.LogError(c, channelErr.Error()) + taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError) + break + } } addUsedChannel(c, channel.Id) diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index b68826812..541f1b9f8 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -619,6 +619,11 @@ type TaskRelayInfo struct { PublicTaskID string ConsumeQuota bool + + // LockedChannel holds the full channel object when the request is bound to + // a specific channel (e.g., remix on origin task's channel). Stored as any + // to avoid an import cycle with model; callers type-assert to *model.Channel. + LockedChannel any } type TaskSubmitReq struct { diff --git a/relay/relay_task.go b/relay/relay_task.go index cc4d0e450..8d0e61d72 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -32,8 +32,9 @@ type TaskSubmitResult struct { } // ResolveOriginTask 处理基于已有任务的提交(remix / continuation): -// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道(并通过 -// specific_channel_id 禁止重试),以及提取 OtherRatios(时长、分辨率)。 +// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道 +// (通过 info.LockedChannel,重试时复用同一渠道并轮换 key), +// 以及提取 OtherRatios(时长、分辨率)。 // 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。 func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { // 检测 remix action @@ -77,15 +78,17 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr } } - // 锁定到原始任务的渠道(如果与当前选中的不同) + // 锁定到原始任务的渠道(重试时复用同一渠道,轮换 key) + ch, err := model.GetChannelById(originTask.ChannelId, true) + if err != nil { + return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) + } + if ch.Status != common.ChannelStatusEnabled { + return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) + } + info.LockedChannel = ch + if originTask.ChannelId != info.ChannelId { - ch, err := model.GetChannelById(originTask.ChannelId, true) - if err != nil { - return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) - } - if ch.Status != common.ChannelStatusEnabled { - return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) - } key, _, newAPIError := ch.GetNextEnabledKey() if newAPIError != nil { return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode) @@ -101,9 +104,6 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr info.ApiKey = key } - // 渠道已锁定到原始任务 → 禁止重试切换到其他渠道 - c.Set("specific_channel_id", fmt.Sprintf("%d", originTask.ChannelId)) - // 提取 remix 参数(时长、分辨率 → OtherRatios) if info.Action == constant.TaskActionRemix { if originTask.PrivateData.BillingContext != nil { From 5ec4633cb8ed92c7c863a106fdf9b5cfa389c66c Mon Sep 17 00:00:00 2001 From: CaIon Date: Sun, 22 Feb 2026 00:52:35 +0800 Subject: [PATCH 16/41] refactor(task): add CAS-guarded updates to prevent concurrent billing conflicts Replace all bare task.Update() (DB.Save) calls with UpdateWithStatus(), which adds a WHERE status = ? guard to prevent concurrent processes from overwriting each other's state transitions. Key changes: model/task.go: - Add taskSnapshot struct with Equal() method for change detection - Add Snapshot() method to capture pre-update state - Add UpdateWithStatus(fromStatus) using DB.Where().Save() for CAS semantics with full-struct save (no explicit field listing needed) model/midjourney.go: - Add UpdateWithStatus(fromStatus string) with same CAS pattern service/task_polling.go (updateVideoSingleTask): - Snapshot before processing upstream response; skip DB write if unchanged - Terminal transitions (SUCCESS/FAILURE) use UpdateWithStatus CAS: billing/refund only executes if this process wins the transition - Non-terminal updates also use UpdateWithStatus to prevent overwriting a concurrent terminal transition back to IN_PROGRESS - Defer settleTaskBillingOnComplete to after CAS check (shouldSettle flag) relay/relay_task.go (tryRealtimeFetch): - Add snapshot + change detection; use UpdateWithStatus for CAS safety controller/midjourney.go (UpdateMidjourneyTaskBulk): - Capture preStatus before mutations; use UpdateWithStatus CAS - Gate refund (IncreaseUserQuota) on CAS success (won && shouldReturnQuota) This prevents the multi-instance race condition where: 1. Instance A reads task (IN_PROGRESS), fetches upstream (still IN_PROGRESS) 2. Instance B reads same task, fetches upstream (now SUCCESS), writes SUCCESS 3. Instance A's bare Save() overwrites SUCCESS back to IN_PROGRESS --- controller/midjourney.go | 17 ++++---- model/midjourney.go | 11 +++++ model/task.go | 91 ++++++++++++++++++---------------------- relay/relay_task.go | 7 +++- service/task_polling.go | 43 ++++++++++++------- 5 files changed, 95 insertions(+), 74 deletions(-) diff --git a/controller/midjourney.go b/controller/midjourney.go index c480c12bb..4045a5509 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -130,6 +130,7 @@ func UpdateMidjourneyTaskBulk() { if !checkMjTaskNeedUpdate(task, responseItem) { continue } + preStatus := task.Status task.Code = 1 task.Progress = responseItem.Progress task.PromptEn = responseItem.PromptEn @@ -172,18 +173,16 @@ func UpdateMidjourneyTaskBulk() { shouldReturnQuota = true } } - err = task.Update() + won, err := task.UpdateWithStatus(preStatus) if err != nil { logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) - } else { - if shouldReturnQuota { - err = model.IncreaseUserQuota(task.UserId, task.Quota, false) - if err != nil { - logger.LogError(ctx, "fail to increase user quota: "+err.Error()) - } - logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + } else if won && shouldReturnQuota { + err = model.IncreaseUserQuota(task.UserId, task.Quota, false) + if err != nil { + logger.LogError(ctx, "fail to increase user quota: "+err.Error()) } + logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota)) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } } } diff --git a/model/midjourney.go b/model/midjourney.go index c6ef5de5b..9867e8a96 100644 --- a/model/midjourney.go +++ b/model/midjourney.go @@ -157,6 +157,17 @@ func (midjourney *Midjourney) Update() error { return err } +// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS). +// Returns (true, nil) if this caller won the update, (false, nil) if +// another process already moved the task out of fromStatus. +func (midjourney *Midjourney) UpdateWithStatus(fromStatus string) (bool, error) { + result := DB.Where("status = ?", fromStatus).Save(midjourney) + if result.Error != nil { + return false, result.Error + } + return result.RowsAffected > 0, nil +} + func MjBulkUpdate(mjIds []string, params map[string]any) error { return DB.Model(&Midjourney{}). Where("mj_id in (?)", mjIds). diff --git a/model/task.go b/model/task.go index 592643ebb..4d1482f8b 100644 --- a/model/task.go +++ b/model/task.go @@ -1,6 +1,7 @@ package model import ( + "bytes" "database/sql/driver" "encoding/json" "time" @@ -340,38 +341,59 @@ func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) { return task, nil } -func TaskUpdateProgress(id int64, progress string) error { - return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error -} - func (Task *Task) Insert() error { var err error err = DB.Create(Task).Error return err } +type taskSnapshot struct { + Status TaskStatus + Progress string + StartTime int64 + FinishTime int64 + FailReason string + ResultURL string + Data json.RawMessage +} + +func (s taskSnapshot) Equal(other taskSnapshot) bool { + return s.Status == other.Status && + s.Progress == other.Progress && + s.StartTime == other.StartTime && + s.FinishTime == other.FinishTime && + s.FailReason == other.FailReason && + s.ResultURL == other.ResultURL && + bytes.Equal(s.Data, other.Data) +} + +func (t *Task) Snapshot() taskSnapshot { + return taskSnapshot{ + Status: t.Status, + Progress: t.Progress, + StartTime: t.StartTime, + FinishTime: t.FinishTime, + FailReason: t.FailReason, + ResultURL: t.PrivateData.ResultURL, + Data: t.Data, + } +} + func (Task *Task) Update() error { var err error err = DB.Save(Task).Error return err } -func TaskBulkUpdate(TaskIds []string, params map[string]any) error { - if len(TaskIds) == 0 { - return nil +// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS). +// Returns (true, nil) if this caller won the update, (false, nil) if +// another process already moved the task out of fromStatus. +func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) { + result := DB.Where("status = ?", fromStatus).Save(t) + if result.Error != nil { + return false, result.Error } - return DB.Model(&Task{}). - Where("task_id in (?)", TaskIds). - Updates(params).Error -} - -func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error { - if len(taskIDs) == 0 { - return nil - } - return DB.Model(&Task{}). - Where("id in (?)", taskIDs). - Updates(params).Error + return result.RowsAffected > 0, nil } func TaskBulkUpdateByID(ids []int64, params map[string]any) error { @@ -388,37 +410,6 @@ type TaskQuotaUsage struct { Count float64 `json:"count"` } -func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) { - query := DB.Model(Task{}) - // 添加过滤条件 - if queryParams.ChannelID != "" { - query = query.Where("channel_id = ?", queryParams.ChannelID) - } - if queryParams.UserID != "" { - query = query.Where("user_id = ?", queryParams.UserID) - } - if len(queryParams.UserIDs) != 0 { - query = query.Where("user_id in (?)", queryParams.UserIDs) - } - if queryParams.TaskID != "" { - query = query.Where("task_id = ?", queryParams.TaskID) - } - if queryParams.Action != "" { - query = query.Where("action = ?", queryParams.Action) - } - if queryParams.Status != "" { - query = query.Where("status = ?", queryParams.Status) - } - if queryParams.StartTimestamp != 0 { - query = query.Where("submit_time >= ?", queryParams.StartTimestamp) - } - if queryParams.EndTimestamp != 0 { - query = query.Where("submit_time <= ?", queryParams.EndTimestamp) - } - err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error - return stat, err -} - // TaskCountAllTasks returns total tasks that match the given query params (admin usage) func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 { var total int64 diff --git a/relay/relay_task.go b/relay/relay_task.go index 8d0e61d72..cd43e6ebb 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -444,6 +444,8 @@ func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte { return nil } + snap := task.Snapshot() + // 将上游最新状态更新到 task if ti.Status != "" { task.Status = model.TaskStatus(ti.Status) @@ -459,7 +461,10 @@ func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte { // No URL from adaptor — construct proxy URL using public task ID task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) } - _ = task.Update() + + if !snap.Equal(task.Snapshot()) { + _, _ = task.UpdateWithStatus(snap.Status) + } // OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理 if isOpenAIVideoAPI { diff --git a/service/task_polling.go b/service/task_polling.go index efbad8afa..7e92d14ba 100644 --- a/service/task_polling.go +++ b/service/task_polling.go @@ -319,6 +319,8 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch * logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody))) + snap := task.Snapshot() + taskResult := &relaycommon.TaskInfo{} // try parse as New API response format var responseItems dto.TaskResponse[model.Task] @@ -344,10 +346,9 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch * taskResult = relaycommon.FailTaskInfo("upstream returned empty status") } - // 记录原本的状态,防止重复退款 shouldRefund := false + shouldSettle := false quota := task.Quota - preStatus := task.Status task.Status = model.TaskStatus(taskResult.Status) switch taskResult.Status { @@ -374,9 +375,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch * // No URL from adaptor — construct proxy URL using public task ID task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) } - - // 完成时计费调整:优先由 adaptor 计算,回退到 token 重算 - settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + shouldSettle = true case model.TaskStatusFailure: logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task) task.Status = model.TaskStatusFailure @@ -388,23 +387,39 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch * logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) taskResult.Progress = taskcommon.ProgressComplete if quota != 0 { - if preStatus != model.TaskStatusFailure { - shouldRefund = true - } else { - logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID)) - } + shouldRefund = true } default: - return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId) + return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, task.TaskID) } if taskResult.Progress != "" { task.Progress = taskResult.Progress } - if err := task.Update(); err != nil { - common.SysLog("UpdateVideoTask task error: " + err.Error()) - shouldRefund = false + + isDone := task.Status == model.TaskStatusSuccess || task.Status == model.TaskStatusFailure + if isDone && snap.Status != task.Status { + won, err := task.UpdateWithStatus(snap.Status) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("UpdateWithStatus failed for task %s: %s", task.TaskID, err.Error())) + shouldRefund = false + shouldSettle = false + } else if !won { + logger.LogWarn(ctx, fmt.Sprintf("Task %s already transitioned by another process, skip billing", task.TaskID)) + shouldRefund = false + shouldSettle = false + } + } else if !snap.Equal(task.Snapshot()) { + if _, err := task.UpdateWithStatus(snap.Status); err != nil { + logger.LogError(ctx, fmt.Sprintf("Failed to update task %s: %s", task.TaskID, err.Error())) + } + } else { + // No changes, skip update + logger.LogDebug(ctx, fmt.Sprintf("No update needed for task %s", task.TaskID)) } + if shouldSettle { + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + } if shouldRefund { RefundTaskQuota(ctx, task, task.FailReason) } From 9976b311efa92da33a9355d4048a24f33014b1c3 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sun, 22 Feb 2026 01:25:04 +0800 Subject: [PATCH 17/41] refactor(task): enhance UpdateWithStatus for CAS updates and add integration tests - Updated UpdateWithStatus method to use Model().Select("*").Updates() for conditional updates, preventing GORM's INSERT fallback. - Introduced comprehensive integration tests for UpdateWithStatus, covering scenarios for winning and losing CAS updates, as well as concurrent updates. - Added task_cas_test.go to validate the new behavior and ensure data integrity during concurrent state transitions. --- model/midjourney.go | 4 +- model/task.go | 6 +- model/task_cas_test.go | 217 +++++++++++++ service/task_billing_test.go | 606 +++++++++++++++++++++++++++++++++++ 4 files changed, 831 insertions(+), 2 deletions(-) create mode 100644 model/task_cas_test.go create mode 100644 service/task_billing_test.go diff --git a/model/midjourney.go b/model/midjourney.go index 9867e8a96..e1a8d772b 100644 --- a/model/midjourney.go +++ b/model/midjourney.go @@ -160,8 +160,10 @@ func (midjourney *Midjourney) Update() error { // UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS). // Returns (true, nil) if this caller won the update, (false, nil) if // another process already moved the task out of fromStatus. +// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS). +// Uses Model().Select("*").Updates() to avoid GORM Save()'s INSERT fallback. func (midjourney *Midjourney) UpdateWithStatus(fromStatus string) (bool, error) { - result := DB.Where("status = ?", fromStatus).Save(midjourney) + result := DB.Model(midjourney).Where("status = ?", fromStatus).Select("*").Updates(midjourney) if result.Error != nil { return false, result.Error } diff --git a/model/task.go b/model/task.go index 4d1482f8b..0cf6bd47e 100644 --- a/model/task.go +++ b/model/task.go @@ -388,8 +388,12 @@ func (Task *Task) Update() error { // UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS). // Returns (true, nil) if this caller won the update, (false, nil) if // another process already moved the task out of fromStatus. +// +// Uses Model().Select("*").Updates() instead of Save() because GORM's Save +// falls back to INSERT ON CONFLICT when the WHERE-guarded UPDATE matches +// zero rows, which silently bypasses the CAS guard. func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) { - result := DB.Where("status = ?", fromStatus).Save(t) + result := DB.Model(t).Where("status = ?", fromStatus).Select("*").Updates(t) if result.Error != nil { return false, result.Error } diff --git a/model/task_cas_test.go b/model/task_cas_test.go new file mode 100644 index 000000000..3449c6d26 --- /dev/null +++ b/model/task_cas_test.go @@ -0,0 +1,217 @@ +package model + +import ( + "encoding/json" + "os" + "sync" + "testing" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/glebarez/sqlite" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +func TestMain(m *testing.M) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + panic("failed to open test db: " + err.Error()) + } + DB = db + LOG_DB = db + + common.UsingSQLite = true + common.RedisEnabled = false + common.BatchUpdateEnabled = false + common.LogConsumeEnabled = true + + sqlDB, err := db.DB() + if err != nil { + panic("failed to get sql.DB: " + err.Error()) + } + sqlDB.SetMaxOpenConns(1) + + if err := db.AutoMigrate(&Task{}, &User{}, &Token{}, &Log{}, &Channel{}); err != nil { + panic("failed to migrate: " + err.Error()) + } + + os.Exit(m.Run()) +} + +func truncateTables(t *testing.T) { + t.Helper() + t.Cleanup(func() { + DB.Exec("DELETE FROM tasks") + DB.Exec("DELETE FROM users") + DB.Exec("DELETE FROM tokens") + DB.Exec("DELETE FROM logs") + DB.Exec("DELETE FROM channels") + }) +} + +func insertTask(t *testing.T, task *Task) { + t.Helper() + task.CreatedAt = time.Now().Unix() + task.UpdatedAt = time.Now().Unix() + require.NoError(t, DB.Create(task).Error) +} + +// --------------------------------------------------------------------------- +// Snapshot / Equal — pure logic tests (no DB) +// --------------------------------------------------------------------------- + +func TestSnapshotEqual_Same(t *testing.T) { + s := taskSnapshot{ + Status: TaskStatusInProgress, + Progress: "50%", + StartTime: 1000, + FinishTime: 0, + FailReason: "", + ResultURL: "", + Data: json.RawMessage(`{"key":"value"}`), + } + assert.True(t, s.Equal(s)) +} + +func TestSnapshotEqual_DifferentStatus(t *testing.T) { + a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{}`)} + b := taskSnapshot{Status: TaskStatusSuccess, Data: json.RawMessage(`{}`)} + assert.False(t, a.Equal(b)) +} + +func TestSnapshotEqual_DifferentProgress(t *testing.T) { + a := taskSnapshot{Status: TaskStatusInProgress, Progress: "30%", Data: json.RawMessage(`{}`)} + b := taskSnapshot{Status: TaskStatusInProgress, Progress: "60%", Data: json.RawMessage(`{}`)} + assert.False(t, a.Equal(b)) +} + +func TestSnapshotEqual_DifferentData(t *testing.T) { + a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":1}`)} + b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":2}`)} + assert.False(t, a.Equal(b)) +} + +func TestSnapshotEqual_NilVsEmpty(t *testing.T) { + a := taskSnapshot{Status: TaskStatusInProgress, Data: nil} + b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage{}} + // bytes.Equal(nil, []byte{}) == true + assert.True(t, a.Equal(b)) +} + +func TestSnapshot_Roundtrip(t *testing.T) { + task := &Task{ + Status: TaskStatusInProgress, + Progress: "42%", + StartTime: 1234, + FinishTime: 5678, + FailReason: "timeout", + PrivateData: TaskPrivateData{ + ResultURL: "https://example.com/result.mp4", + }, + Data: json.RawMessage(`{"model":"test-model"}`), + } + snap := task.Snapshot() + assert.Equal(t, task.Status, snap.Status) + assert.Equal(t, task.Progress, snap.Progress) + assert.Equal(t, task.StartTime, snap.StartTime) + assert.Equal(t, task.FinishTime, snap.FinishTime) + assert.Equal(t, task.FailReason, snap.FailReason) + assert.Equal(t, task.PrivateData.ResultURL, snap.ResultURL) + assert.JSONEq(t, string(task.Data), string(snap.Data)) +} + +// --------------------------------------------------------------------------- +// UpdateWithStatus CAS — DB integration tests +// --------------------------------------------------------------------------- + +func TestUpdateWithStatus_Win(t *testing.T) { + truncateTables(t) + + task := &Task{ + TaskID: "task_cas_win", + Status: TaskStatusInProgress, + Progress: "50%", + Data: json.RawMessage(`{}`), + } + insertTask(t, task) + + task.Status = TaskStatusSuccess + task.Progress = "100%" + won, err := task.UpdateWithStatus(TaskStatusInProgress) + require.NoError(t, err) + assert.True(t, won) + + var reloaded Task + require.NoError(t, DB.First(&reloaded, task.ID).Error) + assert.EqualValues(t, TaskStatusSuccess, reloaded.Status) + assert.Equal(t, "100%", reloaded.Progress) +} + +func TestUpdateWithStatus_Lose(t *testing.T) { + truncateTables(t) + + task := &Task{ + TaskID: "task_cas_lose", + Status: TaskStatusFailure, + Data: json.RawMessage(`{}`), + } + insertTask(t, task) + + task.Status = TaskStatusSuccess + won, err := task.UpdateWithStatus(TaskStatusInProgress) // wrong fromStatus + require.NoError(t, err) + assert.False(t, won) + + var reloaded Task + require.NoError(t, DB.First(&reloaded, task.ID).Error) + assert.EqualValues(t, TaskStatusFailure, reloaded.Status) // unchanged +} + +func TestUpdateWithStatus_ConcurrentWinner(t *testing.T) { + truncateTables(t) + + task := &Task{ + TaskID: "task_cas_race", + Status: TaskStatusInProgress, + Quota: 1000, + Data: json.RawMessage(`{}`), + } + insertTask(t, task) + + const goroutines = 5 + wins := make([]bool, goroutines) + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + t := &Task{} + *t = Task{ + ID: task.ID, + TaskID: task.TaskID, + Status: TaskStatusSuccess, + Progress: "100%", + Quota: task.Quota, + Data: json.RawMessage(`{}`), + } + t.CreatedAt = task.CreatedAt + t.UpdatedAt = time.Now().Unix() + won, err := t.UpdateWithStatus(TaskStatusInProgress) + if err == nil { + wins[idx] = won + } + }(i) + } + wg.Wait() + + winCount := 0 + for _, w := range wins { + if w { + winCount++ + } + } + assert.Equal(t, 1, winCount, "exactly one goroutine should win the CAS") +} diff --git a/service/task_billing_test.go b/service/task_billing_test.go new file mode 100644 index 000000000..6c2d231d5 --- /dev/null +++ b/service/task_billing_test.go @@ -0,0 +1,606 @@ +package service + +import ( + "context" + "encoding/json" + "os" + "testing" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/model" + "github.com/glebarez/sqlite" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +func TestMain(m *testing.M) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + panic("failed to open test db: " + err.Error()) + } + sqlDB, err := db.DB() + if err != nil { + panic("failed to get sql.DB: " + err.Error()) + } + sqlDB.SetMaxOpenConns(1) + + model.DB = db + model.LOG_DB = db + + common.UsingSQLite = true + common.RedisEnabled = false + common.BatchUpdateEnabled = false + common.LogConsumeEnabled = true + + if err := db.AutoMigrate( + &model.Task{}, + &model.User{}, + &model.Token{}, + &model.Log{}, + &model.Channel{}, + &model.UserSubscription{}, + ); err != nil { + panic("failed to migrate: " + err.Error()) + } + + os.Exit(m.Run()) +} + +// --------------------------------------------------------------------------- +// Seed helpers +// --------------------------------------------------------------------------- + +func truncate(t *testing.T) { + t.Helper() + t.Cleanup(func() { + model.DB.Exec("DELETE FROM tasks") + model.DB.Exec("DELETE FROM users") + model.DB.Exec("DELETE FROM tokens") + model.DB.Exec("DELETE FROM logs") + model.DB.Exec("DELETE FROM channels") + model.DB.Exec("DELETE FROM user_subscriptions") + }) +} + +func seedUser(t *testing.T, id int, quota int) { + t.Helper() + user := &model.User{Id: id, Username: "test_user", Quota: quota, Status: common.UserStatusEnabled} + require.NoError(t, model.DB.Create(user).Error) +} + +func seedToken(t *testing.T, id int, userId int, key string, remainQuota int) { + t.Helper() + token := &model.Token{ + Id: id, + UserId: userId, + Key: key, + Name: "test_token", + Status: common.TokenStatusEnabled, + RemainQuota: remainQuota, + UsedQuota: 0, + } + require.NoError(t, model.DB.Create(token).Error) +} + +func seedSubscription(t *testing.T, id int, userId int, amountTotal int64, amountUsed int64) { + t.Helper() + sub := &model.UserSubscription{ + Id: id, + UserId: userId, + AmountTotal: amountTotal, + AmountUsed: amountUsed, + Status: "active", + StartTime: time.Now().Unix(), + EndTime: time.Now().Add(30 * 24 * time.Hour).Unix(), + } + require.NoError(t, model.DB.Create(sub).Error) +} + +func seedChannel(t *testing.T, id int) { + t.Helper() + ch := &model.Channel{Id: id, Name: "test_channel", Key: "sk-test", Status: common.ChannelStatusEnabled} + require.NoError(t, model.DB.Create(ch).Error) +} + +func makeTask(userId, channelId, quota, tokenId int, billingSource string, subscriptionId int) *model.Task { + return &model.Task{ + TaskID: "task_" + time.Now().Format("150405.000"), + UserId: userId, + ChannelId: channelId, + Quota: quota, + Status: model.TaskStatus(model.TaskStatusInProgress), + Group: "default", + Data: json.RawMessage(`{}`), + CreatedAt: time.Now().Unix(), + UpdatedAt: time.Now().Unix(), + Properties: model.Properties{ + OriginModelName: "test-model", + }, + PrivateData: model.TaskPrivateData{ + BillingSource: billingSource, + SubscriptionId: subscriptionId, + TokenId: tokenId, + BillingContext: &model.TaskBillingContext{ + ModelPrice: 0.02, + GroupRatio: 1.0, + ModelName: "test-model", + }, + }, + } +} + +// --------------------------------------------------------------------------- +// Read-back helpers +// --------------------------------------------------------------------------- + +func getUserQuota(t *testing.T, id int) int { + t.Helper() + var user model.User + require.NoError(t, model.DB.Select("quota").Where("id = ?", id).First(&user).Error) + return user.Quota +} + +func getTokenRemainQuota(t *testing.T, id int) int { + t.Helper() + var token model.Token + require.NoError(t, model.DB.Select("remain_quota").Where("id = ?", id).First(&token).Error) + return token.RemainQuota +} + +func getTokenUsedQuota(t *testing.T, id int) int { + t.Helper() + var token model.Token + require.NoError(t, model.DB.Select("used_quota").Where("id = ?", id).First(&token).Error) + return token.UsedQuota +} + +func getSubscriptionUsed(t *testing.T, id int) int64 { + t.Helper() + var sub model.UserSubscription + require.NoError(t, model.DB.Select("amount_used").Where("id = ?", id).First(&sub).Error) + return sub.AmountUsed +} + +func getLastLog(t *testing.T) *model.Log { + t.Helper() + var log model.Log + err := model.LOG_DB.Order("id desc").First(&log).Error + if err != nil { + return nil + } + return &log +} + +func countLogs(t *testing.T) int64 { + t.Helper() + var count int64 + model.LOG_DB.Model(&model.Log{}).Count(&count) + return count +} + +// =========================================================================== +// RefundTaskQuota tests +// =========================================================================== + +func TestRefundTaskQuota_Wallet(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 1, 1, 1 + const initQuota, preConsumed = 10000, 3000 + const tokenRemain = 5000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-test-key", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + + RefundTaskQuota(ctx, task, "task failed: upstream error") + + // User quota should increase by preConsumed + assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) + + // Token remain_quota should increase, used_quota should decrease + assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) + assert.Equal(t, -preConsumed, getTokenUsedQuota(t, tokenID)) + + // A refund log should be created + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) + assert.Equal(t, preConsumed, log.Quota) + assert.Equal(t, "test-model", log.ModelName) +} + +func TestRefundTaskQuota_Subscription(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID, subID = 2, 2, 2, 1 + const preConsumed = 2000 + const subTotal, subUsed int64 = 100000, 50000 + const tokenRemain = 8000 + + seedUser(t, userID, 0) + seedToken(t, tokenID, userID, "sk-sub-key", tokenRemain) + seedChannel(t, channelID) + seedSubscription(t, subID, userID, subTotal, subUsed) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID) + + RefundTaskQuota(ctx, task, "subscription task failed") + + // Subscription used should decrease by preConsumed + assert.Equal(t, subUsed-int64(preConsumed), getSubscriptionUsed(t, subID)) + + // Token should also be refunded + assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) + + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} + +func TestRefundTaskQuota_ZeroQuota(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID = 3 + seedUser(t, userID, 5000) + + task := makeTask(userID, 0, 0, 0, BillingSourceWallet, 0) + + RefundTaskQuota(ctx, task, "zero quota task") + + // No change to user quota + assert.Equal(t, 5000, getUserQuota(t, userID)) + + // No log created + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestRefundTaskQuota_NoToken(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, channelID = 4, 4 + const initQuota, preConsumed = 10000, 1500 + + seedUser(t, userID, initQuota) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) // TokenId=0 + + RefundTaskQuota(ctx, task, "no token task failed") + + // User quota refunded + assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) + + // Log created + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} + +// =========================================================================== +// RecalculateTaskQuota tests +// =========================================================================== + +func TestRecalculate_PositiveDelta(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 10, 10, 10 + const initQuota, preConsumed = 10000, 2000 + const actualQuota = 3000 // under-charged by 1000 + const tokenRemain = 5000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-recalc-pos", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + + RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment") + + // User quota should decrease by the delta (1000 additional charge) + assert.Equal(t, initQuota-(actualQuota-preConsumed), getUserQuota(t, userID)) + + // Token should also be charged the delta + assert.Equal(t, tokenRemain-(actualQuota-preConsumed), getTokenRemainQuota(t, tokenID)) + + // task.Quota should be updated to actualQuota + assert.Equal(t, actualQuota, task.Quota) + + // Log type should be Consume (additional charge) + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeConsume, log.Type) + assert.Equal(t, actualQuota-preConsumed, log.Quota) +} + +func TestRecalculate_NegativeDelta(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 11, 11, 11 + const initQuota, preConsumed = 10000, 5000 + const actualQuota = 3000 // over-charged by 2000 + const tokenRemain = 5000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-recalc-neg", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + + RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment") + + // User quota should increase by abs(delta) = 2000 (refund overpayment) + assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID)) + + // Token should be refunded the difference + assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) + + // task.Quota updated + assert.Equal(t, actualQuota, task.Quota) + + // Log type should be Refund + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) + assert.Equal(t, preConsumed-actualQuota, log.Quota) +} + +func TestRecalculate_ZeroDelta(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID = 12 + const initQuota, preConsumed = 10000, 3000 + + seedUser(t, userID, initQuota) + + task := makeTask(userID, 0, preConsumed, 0, BillingSourceWallet, 0) + + RecalculateTaskQuota(ctx, task, preConsumed, "exact match") + + // No change to user quota + assert.Equal(t, initQuota, getUserQuota(t, userID)) + + // No log created (delta is zero) + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestRecalculate_ActualQuotaZero(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID = 13 + const initQuota = 10000 + + seedUser(t, userID, initQuota) + + task := makeTask(userID, 0, 5000, 0, BillingSourceWallet, 0) + + RecalculateTaskQuota(ctx, task, 0, "zero actual") + + // No change (early return) + assert.Equal(t, initQuota, getUserQuota(t, userID)) + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestRecalculate_Subscription_NegativeDelta(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID, subID = 14, 14, 14, 2 + const preConsumed = 5000 + const actualQuota = 2000 // over-charged by 3000 + const subTotal, subUsed int64 = 100000, 50000 + const tokenRemain = 8000 + + seedUser(t, userID, 0) + seedToken(t, tokenID, userID, "sk-sub-recalc", tokenRemain) + seedChannel(t, channelID) + seedSubscription(t, subID, userID, subTotal, subUsed) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID) + + RecalculateTaskQuota(ctx, task, actualQuota, "subscription over-charge") + + // Subscription used should decrease by delta (refund 3000) + assert.Equal(t, subUsed-int64(preConsumed-actualQuota), getSubscriptionUsed(t, subID)) + + // Token refunded + assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) + + assert.Equal(t, actualQuota, task.Quota) + + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} + +// =========================================================================== +// CAS + Billing integration tests +// Simulates the flow in updateVideoSingleTask (service/task_polling.go) +// =========================================================================== + +// simulatePollBilling reproduces the CAS + billing logic from updateVideoSingleTask. +// It takes a persisted task (already in DB), applies the new status, and performs +// the conditional update + billing exactly as the polling loop does. +func simulatePollBilling(ctx context.Context, task *model.Task, newStatus model.TaskStatus, actualQuota int) { + snap := task.Snapshot() + + shouldRefund := false + shouldSettle := false + quota := task.Quota + + task.Status = newStatus + switch string(newStatus) { + case model.TaskStatusSuccess: + task.Progress = "100%" + task.FinishTime = 9999 + shouldSettle = true + case model.TaskStatusFailure: + task.Progress = "100%" + task.FinishTime = 9999 + task.FailReason = "upstream error" + if quota != 0 { + shouldRefund = true + } + default: + task.Progress = "50%" + } + + isDone := task.Status == model.TaskStatus(model.TaskStatusSuccess) || task.Status == model.TaskStatus(model.TaskStatusFailure) + if isDone && snap.Status != task.Status { + won, err := task.UpdateWithStatus(snap.Status) + if err != nil { + shouldRefund = false + shouldSettle = false + } else if !won { + shouldRefund = false + shouldSettle = false + } + } else if !snap.Equal(task.Snapshot()) { + _, _ = task.UpdateWithStatus(snap.Status) + } + + if shouldSettle && actualQuota > 0 { + RecalculateTaskQuota(ctx, task, actualQuota, "test settle") + } + if shouldRefund { + RefundTaskQuota(ctx, task, task.FailReason) + } +} + +func TestCASGuardedRefund_Win(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 20, 20, 20 + const initQuota, preConsumed = 10000, 4000 + const tokenRemain = 6000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-cas-refund-win", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.Status = model.TaskStatus(model.TaskStatusInProgress) + require.NoError(t, model.DB.Create(task).Error) + + simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0) + + // CAS wins: task in DB should now be FAILURE + var reloaded model.Task + require.NoError(t, model.DB.First(&reloaded, task.ID).Error) + assert.EqualValues(t, model.TaskStatusFailure, reloaded.Status) + + // Refund should have happened + assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) + assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) + + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} + +func TestCASGuardedRefund_Lose(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 21, 21, 21 + const initQuota, preConsumed = 10000, 4000 + const tokenRemain = 6000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-cas-refund-lose", tokenRemain) + seedChannel(t, channelID) + + // Create task with IN_PROGRESS in DB + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.Status = model.TaskStatus(model.TaskStatusInProgress) + require.NoError(t, model.DB.Create(task).Error) + + // Simulate another process already transitioning to FAILURE + model.DB.Model(&model.Task{}).Where("id = ?", task.ID).Update("status", model.TaskStatusFailure) + + // Our process still has the old in-memory state (IN_PROGRESS) and tries to transition + // task.Status is still IN_PROGRESS in the snapshot + simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0) + + // CAS lost: user quota should NOT change (no double refund) + assert.Equal(t, initQuota, getUserQuota(t, userID)) + assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID)) + + // No billing log should be created + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestCASGuardedSettle_Win(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 22, 22, 22 + const initQuota, preConsumed = 10000, 5000 + const actualQuota = 3000 // over-charged, should get partial refund + const tokenRemain = 8000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-cas-settle-win", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.Status = model.TaskStatus(model.TaskStatusInProgress) + require.NoError(t, model.DB.Create(task).Error) + + simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusSuccess), actualQuota) + + // CAS wins: task should be SUCCESS + var reloaded model.Task + require.NoError(t, model.DB.First(&reloaded, task.ID).Error) + assert.EqualValues(t, model.TaskStatusSuccess, reloaded.Status) + + // Settlement should refund the over-charge (5000 - 3000 = 2000 back to user) + assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID)) + assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) + + // task.Quota should be updated to actualQuota + assert.Equal(t, actualQuota, task.Quota) +} + +func TestNonTerminalUpdate_NoBilling(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, channelID = 23, 23 + const initQuota, preConsumed = 10000, 3000 + + seedUser(t, userID, initQuota) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) + task.Status = model.TaskStatus(model.TaskStatusInProgress) + task.Progress = "20%" + require.NoError(t, model.DB.Create(task).Error) + + // Simulate a non-terminal poll update (still IN_PROGRESS, progress changed) + simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusInProgress), 0) + + // User quota should NOT change + assert.Equal(t, initQuota, getUserQuota(t, userID)) + + // No billing log + assert.Equal(t, int64(0), countLogs(t)) + + // Task progress should be updated in DB + var reloaded model.Task + require.NoError(t, model.DB.First(&reloaded, task.ID).Error) + assert.Equal(t, "50%", reloaded.Progress) +} From ec5c6b28eafb165d402b897f3fd252e0ffe98028 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sun, 22 Feb 2026 15:32:33 +0800 Subject: [PATCH 18/41] feat(task): add model redirection, per-call billing, and multipart retry fix for async tasks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Async task model redirection (aligned with sync tasks): - Integrate ModelMappedHelper in RelayTaskSubmit after model name determination, populating OriginModelName / UpstreamModelName on RelayInfo. - All task adaptors now send UpstreamModelName to upstream providers: - Gemini & Vertex: BuildRequestURL uses UpstreamModelName. - Doubao & Ali: BuildRequestBody conditionally overwrites body.Model. - Vidu, Kling, Hailuo, Jimeng: convertToRequestPayload accepts RelayInfo and unconditionally uses info.UpstreamModelName. - Sora: BuildRequestBody parses JSON and multipart bodies to replace the "model" field with UpstreamModelName. - Frontend log visibility: LogTaskConsumption and taskBillingOther now emit is_model_mapped / upstream_model_name in the "other" JSON field. - Billing safety: RecalculateTaskQuotaByTokens reads model name from BillingContext.OriginModelName (via taskModelName) instead of task.Data["model"], preventing billing leaks from upstream model names. 2. Per-call billing (TaskPricePatches lifecycle): - Rename TaskBillingContext.ModelName → OriginModelName; add PerCallBilling bool field, populated from TaskPricePatches at submission time. - settleTaskBillingOnComplete short-circuits when PerCallBilling is true, skipping both adaptor adjustments and token-based recalculation. - Remove ModelName from TaskSubmitResult; use relayInfo.OriginModelName consistently in controller/relay.go for billing context and logging. 3. Multipart retry boundary mismatch fix: - Root cause: after Sora (or OpenAI audio) rebuilds a multipart body with a new boundary and overwrites c.Request.Header["Content-Type"], subsequent calls to ParseMultipartFormReusable on retry would parse the cached original body with the wrong boundary, causing "NextPart: EOF". - Fix: ParseMultipartFormReusable now caches the original Content-Type in gin context key "_original_multipart_ct" on first call and reuses it for all subsequent parses, making multipart parsing retry-safe globally. - Sora adaptor reverted to the standard pattern (direct header set/get), which is now safe thanks to the root fix. 4. Tests: - task_billing_test.go: update makeTask to use OriginModelName; add PerCallBilling settlement tests (skip adaptor adjust, skip token recalc); add non-per-call adaptor adjustment test with refund verification. --- common/gin.go | 10 +- controller/relay.go | 17 +-- controller/task.go | 26 ++++- model/task.go | 11 +- relay/channel/task/ali/adaptor.go | 8 +- relay/channel/task/doubao/adaptor.go | 6 +- relay/channel/task/gemini/adaptor.go | 2 +- relay/channel/task/hailuo/adaptor.go | 8 +- relay/channel/task/jimeng/adaptor.go | 6 +- relay/channel/task/kling/adaptor.go | 9 +- relay/channel/task/sora/adaptor.go | 55 +++++++++ relay/channel/task/vertex/adaptor.go | 2 +- relay/channel/task/vidu/adaptor.go | 6 +- relay/relay_task.go | 9 +- service/task_billing.go | 29 ++--- service/task_billing_test.go | 108 +++++++++++++++++- service/task_polling.go | 5 + .../table/task-logs/TaskLogsColumnDefs.jsx | 36 +++--- web/src/components/table/task-logs/index.jsx | 2 - 19 files changed, 277 insertions(+), 78 deletions(-) diff --git a/common/gin.go b/common/gin.go index 48971c130..009e39080 100644 --- a/common/gin.go +++ b/common/gin.go @@ -243,7 +243,15 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) { return nil, err } - contentType := c.Request.Header.Get("Content-Type") + // Use the original Content-Type saved on first call to avoid boundary + // mismatch when callers overwrite c.Request.Header after multipart rebuild. + var contentType string + if saved, ok := c.Get("_original_multipart_ct"); ok { + contentType = saved.(string) + } else { + contentType = c.Request.Header.Get("Content-Type") + c.Set("_original_multipart_ct", contentType) + } boundary, err := parseBoundary(contentType) if err != nil { return nil, err diff --git a/controller/relay.go b/controller/relay.go index 6951974c5..7e7922e75 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -518,7 +518,7 @@ func RelayTask(c *gin.Context) { } addUsedChannel(c, channel.Id) - requestBody, bodyErr := common.GetRequestBody(c) + bodyStorage, bodyErr := common.GetBodyStorage(c) if bodyErr != nil { if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) { taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusRequestEntityTooLarge) @@ -527,7 +527,7 @@ func RelayTask(c *gin.Context) { } break } - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + c.Request.Body = io.NopCloser(bodyStorage) result, taskErr = relay.RelayTaskSubmit(c, relayInfo) if taskErr == nil { @@ -557,7 +557,7 @@ func RelayTask(c *gin.Context) { if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil { common.SysError("settle task billing error: " + settleErr.Error()) } - service.LogTaskConsumption(c, relayInfo, result.ModelName) + service.LogTaskConsumption(c, relayInfo) task := model.InitTask(result.Platform, relayInfo) task.PrivateData.UpstreamTaskID = result.UpstreamTaskID @@ -565,11 +565,12 @@ func RelayTask(c *gin.Context) { task.PrivateData.SubscriptionId = relayInfo.SubscriptionId task.PrivateData.TokenId = relayInfo.TokenId task.PrivateData.BillingContext = &model.TaskBillingContext{ - ModelPrice: relayInfo.PriceData.ModelPrice, - GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio, - ModelRatio: relayInfo.PriceData.ModelRatio, - OtherRatios: relayInfo.PriceData.OtherRatios, - ModelName: result.ModelName, + ModelPrice: relayInfo.PriceData.ModelPrice, + GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio, + ModelRatio: relayInfo.PriceData.ModelRatio, + OtherRatios: relayInfo.PriceData.OtherRatios, + OriginModelName: relayInfo.OriginModelName, + PerCallBilling: common.StringsContains(constant.TaskPricePatches, relayInfo.OriginModelName), } task.Quota = result.Quota task.Data = result.TaskData diff --git a/controller/task.go b/controller/task.go index ec713c5d2..eac7db153 100644 --- a/controller/task.go +++ b/controller/task.go @@ -9,6 +9,7 @@ import ( "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay" "github.com/QuantumNous/new-api/service" + "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" ) @@ -37,7 +38,7 @@ func GetAllTask(c *gin.Context) { items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams) total := model.TaskCountAllTasks(queryParams) pageInfo.SetTotal(int(total)) - pageInfo.SetItems(tasksToDto(items)) + pageInfo.SetItems(tasksToDto(items, true)) common.ApiSuccess(c, pageInfo) } @@ -61,13 +62,32 @@ func GetUserTask(c *gin.Context) { items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams) total := model.TaskCountAllUserTask(userId, queryParams) pageInfo.SetTotal(int(total)) - pageInfo.SetItems(tasksToDto(items)) + pageInfo.SetItems(tasksToDto(items, false)) common.ApiSuccess(c, pageInfo) } -func tasksToDto(tasks []*model.Task) []*dto.TaskDto { +func tasksToDto(tasks []*model.Task, fillUser bool) []*dto.TaskDto { + var userIdMap map[int]*model.UserBase + if fillUser { + userIdMap = make(map[int]*model.UserBase) + userIds := types.NewSet[int]() + for _, task := range tasks { + userIds.Add(task.UserId) + } + for _, userId := range userIds.Items() { + cacheUser, err := model.GetUserCache(userId) + if err == nil { + userIdMap[userId] = cacheUser + } + } + } result := make([]*dto.TaskDto, len(tasks)) for i, task := range tasks { + if fillUser { + if user, ok := userIdMap[task.UserId]; ok { + task.Username = user.Username + } + } result[i] = relay.TaskModel2Dto(task) } return result diff --git a/model/task.go b/model/task.go index 0cf6bd47e..da3be34ed 100644 --- a/model/task.go +++ b/model/task.go @@ -109,11 +109,12 @@ type TaskPrivateData struct { // TaskBillingContext 记录任务提交时的计费参数,以便轮询阶段可以重新计算额度。 type TaskBillingContext struct { - ModelPrice float64 `json:"model_price,omitempty"` // 模型单价 - GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率 - ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率 - OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等) - ModelName string `json:"model_name,omitempty"` // 模型名称 + ModelPrice float64 `json:"model_price,omitempty"` // 模型单价 + GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率 + ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率 + OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等) + OriginModelName string `json:"origin_model_name,omitempty"` // 模型名称,必须为OriginModelName + PerCallBilling bool `json:"per_call_billing,omitempty"` // 按次计费:跳过轮询阶段的差额结算 } // GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信) diff --git a/relay/channel/task/ali/adaptor.go b/relay/channel/task/ali/adaptor.go index f55178b3b..f698fc9f6 100644 --- a/relay/channel/task/ali/adaptor.go +++ b/relay/channel/task/ali/adaptor.go @@ -253,8 +253,12 @@ func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error) } func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) { + upstreamModel := req.Model + if info.IsModelMapped { + upstreamModel = info.UpstreamModelName + } aliReq := &AliVideoRequest{ - Model: req.Model, + Model: upstreamModel, Input: AliVideoInput{ Prompt: req.Prompt, ImgURL: req.InputReference, @@ -332,7 +336,7 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay } } - if aliReq.Model != req.Model { + if aliReq.Model != upstreamModel { return nil, errors.New("can't change model with metadata") } diff --git a/relay/channel/task/doubao/adaptor.go b/relay/channel/task/doubao/adaptor.go index eca421bd3..8f1d748ce 100644 --- a/relay/channel/task/doubao/adaptor.go +++ b/relay/channel/task/doubao/adaptor.go @@ -131,7 +131,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn if err != nil { return nil, errors.Wrap(err, "convert request payload failed") } - info.UpstreamModelName = body.Model + if info.IsModelMapped { + body.Model = info.UpstreamModelName + } else { + info.UpstreamModelName = body.Model + } data, err := common.Marshal(body) if err != nil { return nil, err diff --git a/relay/channel/task/gemini/adaptor.go b/relay/channel/task/gemini/adaptor.go index 06c00a469..5644cd5dc 100644 --- a/relay/channel/task/gemini/adaptor.go +++ b/relay/channel/task/gemini/adaptor.go @@ -105,7 +105,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom // BuildRequestURL constructs the upstream URL. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { - modelName := info.OriginModelName + modelName := info.UpstreamModelName version := model_setting.GetGeminiVersionSetting(modelName) return fmt.Sprintf( diff --git a/relay/channel/task/hailuo/adaptor.go b/relay/channel/task/hailuo/adaptor.go index ab83d659b..28b3a97f1 100644 --- a/relay/channel/task/hailuo/adaptor.go +++ b/relay/channel/task/hailuo/adaptor.go @@ -61,7 +61,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn return nil, fmt.Errorf("invalid request type in context") } - body, err := a.convertToRequestPayload(&req) + body, err := a.convertToRequestPayload(&req, info) if err != nil { return nil, errors.Wrap(err, "convert request payload failed") } @@ -142,8 +142,8 @@ func (a *TaskAdaptor) GetChannelName() string { return ChannelName } -func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*VideoRequest, error) { - modelConfig := GetModelConfig(req.Model) +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*VideoRequest, error) { + modelConfig := GetModelConfig(info.UpstreamModelName) duration := DefaultDuration if req.Duration > 0 { duration = req.Duration @@ -154,7 +154,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* } videoRequest := &VideoRequest{ - Model: req.Model, + Model: info.UpstreamModelName, Prompt: req.Prompt, Duration: &duration, Resolution: resolution, diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index b61cca418..e6211b1e4 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -165,7 +165,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } } - body, err := a.convertToRequestPayload(&req) + body, err := a.convertToRequestPayload(&req, info) if err != nil { return nil, errors.Wrap(err, "convert request payload failed") } @@ -378,9 +378,9 @@ func hmacSHA256(key []byte, data []byte) []byte { return h.Sum(nil) } -func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) { r := requestPayload{ - ReqKey: req.Model, + ReqKey: info.UpstreamModelName, Prompt: req.Prompt, } diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 46e210f19..cdbb56878 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -150,7 +150,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } req := v.(relaycommon.TaskSubmitReq) - body, err := a.convertToRequestPayload(&req) + body, err := a.convertToRequestPayload(&req, info) if err != nil { return nil, err } @@ -248,15 +248,15 @@ func (a *TaskAdaptor) GetChannelName() string { // helpers // ============================ -func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) { r := requestPayload{ Prompt: req.Prompt, Image: req.Image, Mode: taskcommon.DefaultString(req.Mode, "std"), Duration: fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)), AspectRatio: a.getAspectRatio(req.Size), - ModelName: req.Model, - Model: req.Model, // Keep consistent with model_name, double writing improves compatibility + ModelName: info.UpstreamModelName, + Model: info.UpstreamModelName, CfgScale: 0.5, StaticMask: "", DynamicMasks: []DynamicMask{}, @@ -266,6 +266,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* } if r.ModelName == "" { r.ModelName = "kling-v1" + r.Model = "kling-v1" } if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index bf2f70053..33db8fe55 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -1,8 +1,10 @@ package sora import ( + "bytes" "fmt" "io" + "mime/multipart" "net/http" "strconv" "strings" @@ -145,6 +147,59 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn if err != nil { return nil, errors.Wrap(err, "get_request_body_failed") } + cachedBody, err := storage.Bytes() + if err != nil { + return nil, errors.Wrap(err, "read_body_bytes_failed") + } + contentType := c.GetHeader("Content-Type") + + if strings.HasPrefix(contentType, "application/json") { + var bodyMap map[string]interface{} + if err := common.Unmarshal(cachedBody, &bodyMap); err == nil { + bodyMap["model"] = info.UpstreamModelName + if newBody, err := common.Marshal(bodyMap); err == nil { + return bytes.NewReader(newBody), nil + } + } + return bytes.NewReader(cachedBody), nil + } + + if strings.Contains(contentType, "multipart/form-data") { + formData, err := common.ParseMultipartFormReusable(c) + if err != nil { + return bytes.NewReader(cachedBody), nil + } + var buf bytes.Buffer + writer := multipart.NewWriter(&buf) + writer.WriteField("model", info.UpstreamModelName) + for key, values := range formData.Value { + if key == "model" { + continue + } + for _, v := range values { + writer.WriteField(key, v) + } + } + for fieldName, fileHeaders := range formData.File { + for _, fh := range fileHeaders { + f, err := fh.Open() + if err != nil { + continue + } + part, err := writer.CreateFormFile(fieldName, fh.Filename) + if err != nil { + f.Close() + continue + } + io.Copy(part, f) + f.Close() + } + } + writer.Close() + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + return &buf, nil + } + return common.ReaderOnly(storage), nil } diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go index 4931002dd..700e60976 100644 --- a/relay/channel/task/vertex/adaptor.go +++ b/relay/channel/task/vertex/adaptor.go @@ -86,7 +86,7 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil { return "", fmt.Errorf("failed to decode credentials: %w", err) } - modelName := info.OriginModelName + modelName := info.UpstreamModelName if modelName == "" { modelName = "veo-3.0-generate-001" } diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index e689bf888..6ae1c181b 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -116,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } req := v.(relaycommon.TaskSubmitReq) - body, err := a.convertToRequestPayload(&req) + body, err := a.convertToRequestPayload(&req, info) if err != nil { return nil, err } @@ -224,9 +224,9 @@ func (a *TaskAdaptor) GetChannelName() string { // helpers // ============================ -func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) { r := requestPayload{ - Model: taskcommon.DefaultString(req.Model, "viduq1"), + Model: taskcommon.DefaultString(info.UpstreamModelName, "viduq1"), Images: req.Images, Prompt: req.Prompt, Duration: taskcommon.DefaultInt(req.Duration, 5), diff --git a/relay/relay_task.go b/relay/relay_task.go index cd43e6ebb..c740facdb 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -26,7 +26,6 @@ type TaskSubmitResult struct { UpstreamTaskID string TaskData []byte Platform constant.TaskPlatform - ModelName string Quota int //PerCallPrice types.PriceData } @@ -163,6 +162,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe modelName = service.CoverTaskActionToModelName(platform, info.Action) } + // 2.5 应用渠道的模型映射(与同步任务对齐) + info.OriginModelName = modelName + info.UpstreamModelName = modelName + if err := helper.ModelMappedHelper(c, info, nil); err != nil { + return nil, service.TaskErrorWrapperLocal(err, "model_mapping_failed", http.StatusBadRequest) + } + // 3. 预生成公开 task ID(仅首次) if info.PublicTaskID == "" { info.PublicTaskID = model.GenerateTaskID() @@ -241,7 +247,6 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe UpstreamTaskID: upstreamTaskID, TaskData: taskData, Platform: platform, - ModelName: modelName, Quota: finalQuota, }, nil } diff --git a/service/task_billing.go b/service/task_billing.go index 78ad0fc09..0da4cf431 100644 --- a/service/task_billing.go +++ b/service/task_billing.go @@ -16,11 +16,11 @@ import ( // LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。 // 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。 -func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName string) { +func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("操作 %s", info.Action) // 支持任务仅按次计费 - if common.StringsContains(constant.TaskPricePatches, modelName) { + if common.StringsContains(constant.TaskPricePatches, info.OriginModelName) { logContent = fmt.Sprintf("%s,按次计费", logContent) } else { if len(info.PriceData.OtherRatios) > 0 { @@ -42,9 +42,13 @@ func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName s if info.PriceData.GroupRatioInfo.HasSpecialRatio { other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio } + if info.IsModelMapped { + other["is_model_mapped"] = true + other["upstream_model_name"] = info.UpstreamModelName + } model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ ChannelId: info.ChannelId, - ModelName: modelName, + ModelName: info.OriginModelName, TokenName: tokenName, Quota: info.PriceData.Quota, Content: logContent, @@ -120,13 +124,18 @@ func taskBillingOther(task *model.Task) map[string]interface{} { } } } + props := task.Properties + if props.UpstreamModelName != "" && props.UpstreamModelName != props.OriginModelName { + other["is_model_mapped"] = true + other["upstream_model_name"] = props.UpstreamModelName + } return other } // taskModelName 从 BillingContext 或 Properties 中获取模型名称。 func taskModelName(task *model.Task) string { - if bc := task.PrivateData.BillingContext; bc != nil && bc.ModelName != "" { - return bc.ModelName + if bc := task.PrivateData.BillingContext; bc != nil && bc.OriginModelName != "" { + return bc.OriginModelName } return task.Properties.OriginModelName } @@ -237,15 +246,7 @@ func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTo return } - // 获取模型名称 - var taskData map[string]interface{} - if err := common.Unmarshal(task.Data, &taskData); err != nil { - return - } - modelName, ok := taskData["model"].(string) - if !ok || modelName == "" { - return - } + modelName := taskModelName(task) // 获取模型价格和倍率 modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName) diff --git a/service/task_billing_test.go b/service/task_billing_test.go index 6c2d231d5..1145bba54 100644 --- a/service/task_billing_test.go +++ b/service/task_billing_test.go @@ -3,12 +3,14 @@ package service import ( "context" "encoding/json" + "net/http" "os" "testing" "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/model" + relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/glebarez/sqlite" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -125,7 +127,7 @@ func makeTask(userId, channelId, quota, tokenId int, billingSource string, subsc BillingContext: &model.TaskBillingContext{ ModelPrice: 0.02, GroupRatio: 1.0, - ModelName: "test-model", + OriginModelName: "test-model", }, }, } @@ -604,3 +606,107 @@ func TestNonTerminalUpdate_NoBilling(t *testing.T) { require.NoError(t, model.DB.First(&reloaded, task.ID).Error) assert.Equal(t, "50%", reloaded.Progress) } + +// =========================================================================== +// Mock adaptor for settleTaskBillingOnComplete tests +// =========================================================================== + +type mockAdaptor struct { + adjustReturn int +} + +func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo) {} +func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error) { return nil, nil } +func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, nil } +func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int { + return m.adjustReturn +} + +// =========================================================================== +// PerCallBilling tests — settleTaskBillingOnComplete +// =========================================================================== + +func TestSettle_PerCallBilling_SkipsAdaptorAdjust(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 30, 30, 30 + const initQuota, preConsumed = 10000, 5000 + const tokenRemain = 8000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-percall-adaptor", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.PrivateData.BillingContext.PerCallBilling = true + + adaptor := &mockAdaptor{adjustReturn: 2000} + taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess} + + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + + // Per-call: no adjustment despite adaptor returning 2000 + assert.Equal(t, initQuota, getUserQuota(t, userID)) + assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID)) + assert.Equal(t, preConsumed, task.Quota) + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestSettle_PerCallBilling_SkipsTotalTokens(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 31, 31, 31 + const initQuota, preConsumed = 10000, 4000 + const tokenRemain = 7000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-percall-tokens", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.PrivateData.BillingContext.PerCallBilling = true + + adaptor := &mockAdaptor{adjustReturn: 0} + taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess, TotalTokens: 9999} + + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + + // Per-call: no recalculation by tokens + assert.Equal(t, initQuota, getUserQuota(t, userID)) + assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID)) + assert.Equal(t, preConsumed, task.Quota) + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestSettle_NonPerCall_AdaptorAdjustWorks(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 32, 32, 32 + const initQuota, preConsumed = 10000, 5000 + const adaptorQuota = 3000 + const tokenRemain = 8000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-nonpercall-adj", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + // PerCallBilling defaults to false + + adaptor := &mockAdaptor{adjustReturn: adaptorQuota} + taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess} + + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + + // Non-per-call: adaptor adjustment applies (refund 2000) + assert.Equal(t, initQuota+(preConsumed-adaptorQuota), getUserQuota(t, userID)) + assert.Equal(t, tokenRemain+(preConsumed-adaptorQuota), getTokenRemainQuota(t, tokenID)) + assert.Equal(t, adaptorQuota, task.Quota) + + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} diff --git a/service/task_polling.go b/service/task_polling.go index 7e92d14ba..a03fc9b88 100644 --- a/service/task_polling.go +++ b/service/task_polling.go @@ -467,6 +467,11 @@ func truncateBase64(s string) string { // 2. taskResult.TotalTokens > 0 → 按 token 重算 // 3. 都不满足 → 保持预扣额度不变 func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) { + // 0. 按次计费的任务不做差额结算 + if bc := task.PrivateData.BillingContext; bc != nil && bc.PerCallBilling { + logger.LogInfo(ctx, fmt.Sprintf("任务 %s 按次计费,跳过差额结算", task.TaskID)) + return + } // 1. 优先让 adaptor 决定最终额度 if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 { RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整") diff --git a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx index 4bce45256..7fddb0a50 100644 --- a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx +++ b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx @@ -84,8 +84,8 @@ function renderDuration(submit_time, finishTime) { // 返回带有样式的颜色标签 return ( - }> - {durationSec} 秒 + + {durationSec} s ); } @@ -149,7 +149,7 @@ const renderPlatform = (platform, t) => { ); if (option) { return ( - }> + {option.label} ); @@ -157,13 +157,13 @@ const renderPlatform = (platform, t) => { switch (platform) { case 'suno': return ( - }> + Suno ); default: return ( - }> + {t('未知')} ); @@ -240,7 +240,6 @@ export const getTaskLogsColumns = ({ openContentModal, isAdminUser, openVideoModal, - showUserInfoFunc, }) => { return [ { @@ -278,7 +277,6 @@ export const getTaskLogsColumns = ({ color={colors[parseInt(text) % colors.length]} size='large' shape='circle' - prefixIcon={} onClick={() => { copyText(text); }} @@ -294,7 +292,7 @@ export const getTaskLogsColumns = ({ { key: COLUMN_KEYS.USERNAME, title: t('用户'), - dataIndex: 'user_id', + dataIndex: 'username', render: (userId, record, index) => { if (!isAdminUser) { return <>; @@ -302,22 +300,14 @@ export const getTaskLogsColumns = ({ const displayText = String(record.username || userId || '?'); return ( - - showUserInfoFunc && showUserInfoFunc(userId)} - > - {displayText.slice(0, 1)} - - - showUserInfoFunc && showUserInfoFunc(userId)} + - {userId} + {displayText.slice(0, 1)} + + + {displayText} ); diff --git a/web/src/components/table/task-logs/index.jsx b/web/src/components/table/task-logs/index.jsx index 140725a89..bc5b91787 100644 --- a/web/src/components/table/task-logs/index.jsx +++ b/web/src/components/table/task-logs/index.jsx @@ -25,7 +25,6 @@ import TaskLogsActions from './TaskLogsActions'; import TaskLogsFilters from './TaskLogsFilters'; import ColumnSelectorModal from './modals/ColumnSelectorModal'; import ContentModal from './modals/ContentModal'; -import UserInfoModal from '../usage-logs/modals/UserInfoModal'; import { useTaskLogsData } from '../../../hooks/task-logs/useTaskLogsData'; import { useIsMobile } from '../../../hooks/common/useIsMobile'; import { createCardProPagination } from '../../../helpers/utils'; @@ -46,7 +45,6 @@ const TaskLogsPage = () => { modalContent={taskLogsData.videoUrl} isVideo={true} /> - Date: Sun, 22 Feb 2026 16:45:35 +0800 Subject: [PATCH 19/41] fix(i18n): remove duplicate task ID translations and clean up unused keys across multiple languages --- web/src/i18n/locales/en.json | 43 --------------------------------- web/src/i18n/locales/fr.json | 2 -- web/src/i18n/locales/ja.json | 2 -- web/src/i18n/locales/ru.json | 2 -- web/src/i18n/locales/vi.json | 2 -- web/src/i18n/locales/zh-CN.json | 6 ----- 6 files changed, 57 deletions(-) diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index c25468339..93b5f18c3 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -302,7 +302,6 @@ "价格重新计算中...": "Recalculating price...", "价格预估": "Price Estimate", "任务 ID": "Task ID", - "任务ID": "Task ID", "任务日志": "Task Logs", "任务状态": "Status", "任务记录": "Task Records", @@ -544,7 +543,6 @@ "创建": "Create", "创建令牌默认选择auto分组,初始令牌也将设为auto(否则留空,为用户默认分组)": "Create token with auto group by default, initial token will also be set to auto (otherwise leave blank for user default group)", "创建失败": "Creation failed", - "创建成功": "Creation successful", "创建或选择密钥时,将 Project 设置为 io.cloud": "When creating or selecting a key, set Project to io.cloud", "创建新用户账户": "Create new user account", "创建新的令牌": "Create New Token", @@ -787,7 +785,6 @@ "天": "day", "天前": "days ago", "失败": "Failed", - "失败原因": "Failure reason", "失败时自动禁用通道": "Automatically disable channel on failure", "失败重试次数": "Failed retry times", "奖励说明": "Reward description", @@ -1336,7 +1333,6 @@ "更新失败,请检查输入信息": "Update failed, please check the input information", "更新容器配置": "Update Container Configuration", "更新容器配置可能会导致容器重启,请确保在合适的时间进行此操作。": "Updating container configuration may cause the container to restart, please ensure you perform this operation at an appropriate time.", - "更新成功": "Update successful", "更新所有已启用通道余额": "Update balance for all enabled channels", "更新支付设置": "Update payment settings", "更新时间": "Update time", @@ -1767,7 +1763,6 @@ "确认清理不活跃的磁盘缓存?": "Confirm cleanup of inactive disk cache?", "确认禁用": "Confirm disable", "确认补单": "Confirm Order Completion", - "确认解绑": "Confirm Unbind", "确认解绑 Passkey": "Confirm Unbind Passkey", "确认设置并完成初始化": "Confirm settings and complete initialization", "确认重置 Passkey": "Confirm Passkey Reset", @@ -1945,7 +1940,6 @@ "自动分组auto,从第一个开始选择": "Auto grouping auto, select from the first one", "自动刷新": "Auto Refresh", "自动刷新中": "Auto refreshing", - "自动检测": "Auto Detect", "自动模式": "Auto Mode", "自动测试所有通道间隔时间": "Auto test interval for all channels", "自动禁用": "Auto disabled", @@ -2343,46 +2337,9 @@ "输入验证码完成设置": "Enter verification code to complete setup", "输出": "Output", "输出 {{completion}} tokens / 1M tokens * {{symbol}}{{compPrice}}) * {{ratioType}} {{ratio}}": "Output {{completion}} tokens / 1M tokens * {{symbol}}{{compPrice}} * {{ratioType}} {{ratio}}", - "磁盘缓存设置(磁盘换内存)": "Disk Cache Settings (Disk Swap Memory)", - "启用磁盘缓存后,大请求体将临时存储到磁盘而非内存,可显著降低内存占用,适用于处理包含大量图片/文件的请求。建议在 SSD 环境下使用。": "When enabled, large request bodies are temporarily stored on disk instead of memory, significantly reducing memory usage. Suitable for requests with large images/files. SSD recommended.", - "启用磁盘缓存": "Enable Disk Cache", - "将大请求体临时存储到磁盘": "Store large request bodies temporarily on disk", - "磁盘缓存阈值 (MB)": "Disk Cache Threshold (MB)", - "请求体超过此大小时使用磁盘缓存": "Use disk cache when request body exceeds this size", - "磁盘缓存最大总量 (MB)": "Max Disk Cache Size (MB)", - "可用空间: {{free}} / 总空间: {{total}}": "Free: {{free}} / Total: {{total}}", - "磁盘缓存占用的最大空间": "Maximum space occupied by disk cache", - "留空使用系统临时目录": "Leave empty to use system temp directory", - "例如 /var/cache/new-api": "e.g. /var/cache/new-api", - "性能监控": "Performance Monitor", - "刷新统计": "Refresh Stats", - "重置统计": "Reset Stats", - "执行 GC": "Run GC", - "请求体磁盘缓存": "Request Body Disk Cache", - "活跃文件": "Active Files", - "磁盘命中": "Disk Hits", - "请求体内存缓存": "Request Body Memory Cache", - "当前缓存大小": "Current Cache Size", - "活跃缓存数": "Active Cache Count", - "内存命中": "Memory Hits", - "缓存目录磁盘空间": "Cache Directory Disk Space", - "磁盘可用空间小于缓存最大总量设置": "Disk free space is less than max cache size setting", - "已分配内存": "Allocated Memory", - "总分配内存": "Total Allocated Memory", - "系统内存": "System Memory", - "GC 次数": "GC Count", - "Goroutine 数": "Goroutine Count", - "目录文件数": "Directory File Count", - "目录总大小": "Directory Total Size", - "磁盘缓存已清理": "Disk cache cleared", - "清理失败": "Cleanup failed", - "统计已重置": "Statistics reset", - "重置失败": "Reset failed", - "GC 已执行": "GC executed", "GC 执行失败": "GC execution failed", "缓存目录": "Cache Directory", "可用": "Available", - "输出价格": "Output Price", "输出价格:{{symbol}}{{price}} * {{completionRatio}} = {{symbol}}{{total}} / 1M tokens (补全倍率: {{completionRatio}})": "Output price: {{symbol}}{{price}} * {{completionRatio}} = {{symbol}}{{total}} / 1M tokens (Completion ratio: {{completionRatio}})", "输出倍率 {{completionRatio}}": "Output ratio {{completionRatio}}", "边栏设置": "Sidebar Settings", diff --git a/web/src/i18n/locales/fr.json b/web/src/i18n/locales/fr.json index 54fd3617e..702a61dee 100644 --- a/web/src/i18n/locales/fr.json +++ b/web/src/i18n/locales/fr.json @@ -304,7 +304,6 @@ "价格重新计算中...": "Recalculating price...", "价格预估": "Price Estimate", "任务 ID": "ID de la tâche", - "任务ID": "ID de la tâche", "任务日志": "Tâches", "任务状态": "Statut de la tâche", "任务记录": "Tâches", @@ -792,7 +791,6 @@ "天": "Jour", "天前": "il y a des jours", "失败": "Échec", - "失败原因": "Raison de l'échec", "失败时自动禁用通道": "Désactiver automatiquement le canal en cas d'échec", "失败重试次数": "Nombre de tentatives en cas d'échec", "奖励说明": "Description de la récompense", diff --git a/web/src/i18n/locales/ja.json b/web/src/i18n/locales/ja.json index d9a49aa50..d1e770e9e 100644 --- a/web/src/i18n/locales/ja.json +++ b/web/src/i18n/locales/ja.json @@ -300,7 +300,6 @@ "价格重新计算中...": "Recalculating price...", "价格预估": "Price Estimate", "任务 ID": "タスクID", - "任务ID": "タスクID", "任务日志": "タスク履歴", "任务状态": "タスクステータス", "任务记录": "タスク履歴", @@ -783,7 +782,6 @@ "天": "日", "天前": "日前", "失败": "失敗", - "失败原因": "失敗理由", "失败时自动禁用通道": "失敗時にチャネルを自動的に無効にする", "失败重试次数": "再試行回数", "奖励说明": "特典説明", diff --git a/web/src/i18n/locales/ru.json b/web/src/i18n/locales/ru.json index fc117a51a..e2a529041 100644 --- a/web/src/i18n/locales/ru.json +++ b/web/src/i18n/locales/ru.json @@ -307,7 +307,6 @@ "价格重新计算中...": "Recalculating price...", "价格预估": "Price Estimate", "任务 ID": "ID задачи", - "任务ID": "ID задачи", "任务日志": "Журнал задач", "任务状态": "Статус задачи", "任务记录": "Записи задач", @@ -798,7 +797,6 @@ "天": "день", "天前": "дней назад", "失败": "Неудача", - "失败原因": "Причина неудачи", "失败时自动禁用通道": "Автоматически отключать канал при неудаче", "失败重试次数": "Количество повторных попыток при неудаче", "奖励说明": "Описание награды", diff --git a/web/src/i18n/locales/vi.json b/web/src/i18n/locales/vi.json index 89d8715e2..a311ca9ec 100644 --- a/web/src/i18n/locales/vi.json +++ b/web/src/i18n/locales/vi.json @@ -301,7 +301,6 @@ "价格重新计算中...": "Recalculating price...", "价格预估": "Price Estimate", "任务 ID": "ID tác vụ", - "任务ID": "ID tác vụ", "任务日志": "Nhật ký tác vụ", "任务状态": "Trạng thái", "任务记录": "Hồ sơ tác vụ", @@ -784,7 +783,6 @@ "天": "ngày", "天前": "ngày trước", "失败": "Thất bại", - "失败原因": "Lý do thất bại", "失败时自动禁用通道": "Tự động vô hiệu hóa kênh khi thất bại", "失败重试次数": "Số lần thử lại thất bại", "奖励说明": "Mô tả phần thưởng", diff --git a/web/src/i18n/locales/zh-CN.json b/web/src/i18n/locales/zh-CN.json index 3cfcc0326..a5bace57f 100644 --- a/web/src/i18n/locales/zh-CN.json +++ b/web/src/i18n/locales/zh-CN.json @@ -298,7 +298,6 @@ "价格重新计算中...": "价格重新计算中...", "价格预估": "价格预估", "任务 ID": "任务 ID", - "任务ID": "任务ID", "任务日志": "任务日志", "任务状态": "任务状态", "任务记录": "任务记录", @@ -539,7 +538,6 @@ "创建": "创建", "创建令牌默认选择auto分组,初始令牌也将设为auto(否则留空,为用户默认分组)": "创建令牌默认选择auto分组,初始令牌也将设为auto(否则留空,为用户默认分组)", "创建失败": "创建失败", - "创建成功": "创建成功", "创建或选择密钥时,将 Project 设置为 io.cloud": "创建或选择密钥时,将 Project 设置为 io.cloud", "创建新用户账户": "创建新用户账户", "创建新的令牌": "创建新的令牌", @@ -782,7 +780,6 @@ "天": "天", "天前": "天前", "失败": "失败", - "失败原因": "失败原因", "失败时自动禁用通道": "失败时自动禁用通道", "失败重试次数": "失败重试次数", "奖励说明": "奖励说明", @@ -1326,7 +1323,6 @@ "更新失败,请检查输入信息": "更新失败,请检查输入信息", "更新容器配置": "更新容器配置", "更新容器配置可能会导致容器重启,请确保在合适的时间进行此操作。": "更新容器配置可能会导致容器重启,请确保在合适的时间进行此操作。", - "更新成功": "更新成功", "更新所有已启用通道余额": "更新所有已启用通道余额", "更新支付设置": "更新支付设置", "更新时间": "更新时间", @@ -1754,7 +1750,6 @@ "确认清除历史日志": "确认清除历史日志", "确认禁用": "确认禁用", "确认补单": "确认补单", - "确认解绑": "确认解绑", "确认解绑 Passkey": "确认解绑 Passkey", "确认设置并完成初始化": "确认设置并完成初始化", "确认重置 Passkey": "确认重置 Passkey", @@ -1932,7 +1927,6 @@ "自动分组auto,从第一个开始选择": "自动分组auto,从第一个开始选择", "自动刷新": "自动刷新", "自动刷新中": "自动刷新中", - "自动检测": "自动检测", "自动模式": "自动模式", "自动测试所有通道间隔时间": "自动测试所有通道间隔时间", "自动禁用": "自动禁用", From 13ada6484a6d7b6ef8142f5e55ab8fe0f3e8eb06 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sun, 22 Feb 2026 17:59:38 +0800 Subject: [PATCH 20/41] feat(task): introduce task timeout configuration and cleanup unfinished tasks - Added TaskTimeoutMinutes constant to configure the timeout duration for asynchronous tasks. - Implemented sweepTimedOutTasks function to identify and handle unfinished tasks that exceed the timeout limit, marking them as failed and processing refunds if applicable. - Enhanced task polling loop to include the new timeout handling logic, ensuring timely cleanup of stale tasks. --- common/init.go | 2 ++ constant/env.go | 1 + model/task.go | 19 +++++++++++++++ service/task_polling.go | 53 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 75 insertions(+) diff --git a/common/init.go b/common/init.go index 6d2c3572b..e4ddbb453 100644 --- a/common/init.go +++ b/common/init.go @@ -145,6 +145,8 @@ func initConstantEnv() { constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false) // 任务轮询时查询的最大数量 constant.TaskQueryLimit = GetEnvOrDefault("TASK_QUERY_LIMIT", 1000) + // 异步任务超时时间(分钟),超过此时间未完成的任务将被标记为失败并退款。0 表示禁用。 + constant.TaskTimeoutMinutes = GetEnvOrDefault("TASK_TIMEOUT_MINUTES", 1440) soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "") if soraPatchStr != "" { diff --git a/constant/env.go b/constant/env.go index 957f68669..d5aff1b0b 100644 --- a/constant/env.go +++ b/constant/env.go @@ -16,6 +16,7 @@ var NotificationLimitDurationMinute int var GenerateDefaultToken bool var ErrorLogEnabled bool var TaskQueryLimit int +var TaskTimeoutMinutes int // temporary variable for sora patch, will be removed in future var TaskPricePatches []string diff --git a/model/task.go b/model/task.go index da3be34ed..984445083 100644 --- a/model/task.go +++ b/model/task.go @@ -288,6 +288,20 @@ func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []* return tasks } +func GetTimedOutUnfinishedTasks(cutoffUnix int64, limit int) []*Task { + var tasks []*Task + err := DB.Where("progress != ?", "100%"). + Where("status NOT IN ?", []string{TaskStatusFailure, TaskStatusSuccess}). + Where("submit_time < ?", cutoffUnix). + Order("submit_time"). + Limit(limit). + Find(&tasks).Error + if err != nil { + return nil + } + return tasks +} + func GetAllUnFinishSyncTasks(limit int) []*Task { var tasks []*Task var err error @@ -401,6 +415,11 @@ func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) { return result.RowsAffected > 0, nil } +// TaskBulkUpdateByID performs an unconditional bulk UPDATE by primary key IDs. +// WARNING: This function has NO CAS (Compare-And-Swap) guard — it will overwrite +// any concurrent status changes. DO NOT use in billing/quota lifecycle flows +// (e.g., timeout, success, failure transitions that trigger refunds or settlements). +// For status transitions that involve billing, use Task.UpdateWithStatus() instead. func TaskBulkUpdateByID(ids []int64, params map[string]any) error { if len(ids) == 0 { return nil diff --git a/service/task_polling.go b/service/task_polling.go index a03fc9b88..9ac4deddc 100644 --- a/service/task_polling.go +++ b/service/task_polling.go @@ -35,12 +35,65 @@ type TaskPollingAdaptor interface { // 打破 service -> relay -> relay/channel -> service 的循环依赖。 var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor +// sweepTimedOutTasks 在主轮询之前独立清理超时任务。 +// 每次最多处理 100 条,剩余的下个周期继续处理。 +// 使用 per-task CAS (UpdateWithStatus) 防止覆盖被正常轮询已推进的任务。 +func sweepTimedOutTasks(ctx context.Context) { + if constant.TaskTimeoutMinutes <= 0 { + return + } + cutoff := time.Now().Unix() - int64(constant.TaskTimeoutMinutes)*60 + tasks := model.GetTimedOutUnfinishedTasks(cutoff, 100) + if len(tasks) == 0 { + return + } + + const legacyTaskCutoff int64 = 1740182400 // 2026-02-22 00:00:00 UTC + reason := fmt.Sprintf("任务超时(%d分钟)", constant.TaskTimeoutMinutes) + legacyReason := "任务超时(旧系统遗留任务,不进行退款,请联系管理员)" + now := time.Now().Unix() + timedOutCount := 0 + + for _, task := range tasks { + isLegacy := task.SubmitTime > 0 && task.SubmitTime < legacyTaskCutoff + + oldStatus := task.Status + task.Status = model.TaskStatusFailure + task.Progress = "100%" + task.FinishTime = now + if isLegacy { + task.FailReason = legacyReason + } else { + task.FailReason = reason + } + + won, err := task.UpdateWithStatus(oldStatus) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("sweepTimedOutTasks CAS update error for task %s: %v", task.TaskID, err)) + continue + } + if !won { + logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: task %s already transitioned, skip", task.TaskID)) + continue + } + timedOutCount++ + if !isLegacy && task.Quota != 0 { + RefundTaskQuota(ctx, task, reason) + } + } + + if timedOutCount > 0 { + logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: timed out %d tasks", timedOutCount)) + } +} + // TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务 func TaskPollingLoop() { for { time.Sleep(time.Duration(15) * time.Second) common.SysLog("任务进度轮询开始") ctx := context.TODO() + sweepTimedOutTasks(ctx) allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit) platformTask := make(map[constant.TaskPlatform][]*model.Task) for _, t := range allTasks { From f4dded51ab913423abd0e6aae412c79f077ab564 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sun, 22 Feb 2026 18:24:42 +0800 Subject: [PATCH 21/41] Update README --- README.fr.md | 4 ++-- README.ja.md | 4 ++-- README.md | 4 ++-- README.zh_CN.md | 4 ++-- README.zh_TW.md | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/README.fr.md b/README.fr.md index 77fd0cd1c..6b4d0ceba 100644 --- a/README.fr.md +++ b/README.fr.md @@ -30,8 +30,8 @@

- - Calcium-Ion%2Fnew-api | Trendshift + + QuantumNous%2Fnew-api | Trendshift
diff --git a/README.ja.md b/README.ja.md index 2cb00affb..2b35bdfe9 100644 --- a/README.ja.md +++ b/README.ja.md @@ -30,8 +30,8 @@

- - Calcium-Ion%2Fnew-api | Trendshift + + QuantumNous%2Fnew-api | Trendshift
diff --git a/README.md b/README.md index 5f64a0d0b..8f23d5dcd 100644 --- a/README.md +++ b/README.md @@ -30,8 +30,8 @@

- - Calcium-Ion%2Fnew-api | Trendshift + + QuantumNous%2Fnew-api | Trendshift
diff --git a/README.zh_CN.md b/README.zh_CN.md index 55265d9a8..fd3204950 100644 --- a/README.zh_CN.md +++ b/README.zh_CN.md @@ -30,8 +30,8 @@

- - Calcium-Ion%2Fnew-api | Trendshift + + QuantumNous%2Fnew-api | Trendshift
diff --git a/README.zh_TW.md b/README.zh_TW.md index 2fa93157e..9264bc722 100644 --- a/README.zh_TW.md +++ b/README.zh_TW.md @@ -30,8 +30,8 @@

- - Calcium-Ion%2Fnew-api | Trendshift + + QuantumNous%2Fnew-api | Trendshift
From 4831bb7b5bef72c2c84fc4ac94e9bd154e4102cb Mon Sep 17 00:00:00 2001 From: Seefs Date: Sun, 22 Feb 2026 20:03:46 +0800 Subject: [PATCH 22/41] feat: guard new 504/524 status remaps with risk confirmation --- controller/relay.go | 2 +- .../operation_setting/status_code_ranges.go | 13 ++ .../status_code_ranges_test.go | 8 + .../modals/RiskAcknowledgementModal.jsx | 165 ++++++++++++++++++ .../channels/modals/EditChannelModal.jsx | 54 ++++++ .../modals/StatusCodeRiskGuardModal.jsx | 37 ++++ .../channels/modals/statusCodeRiskGuard.js | 101 +++++++++++ web/src/i18n/locales/en.json | 13 ++ web/src/i18n/locales/zh-CN.json | 13 ++ web/src/i18n/locales/zh-TW.json | 13 ++ .../Setting/Operation/SettingsMonitoring.jsx | 2 +- 11 files changed, 419 insertions(+), 2 deletions(-) create mode 100644 web/src/components/common/modals/RiskAcknowledgementModal.jsx create mode 100644 web/src/components/table/channels/modals/StatusCodeRiskGuardModal.jsx create mode 100644 web/src/components/table/channels/modals/statusCodeRiskGuard.js diff --git a/controller/relay.go b/controller/relay.go index 7e7922e75..edea1586f 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -614,7 +614,7 @@ func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, } if taskErr.StatusCode/100 == 5 { // 超时不重试 - if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 { + if operation_setting.IsAlwaysSkipRetryStatusCode(taskErr.StatusCode) { return false } return true diff --git a/setting/operation_setting/status_code_ranges.go b/setting/operation_setting/status_code_ranges.go index 698c87c91..7e3bc847a 100644 --- a/setting/operation_setting/status_code_ranges.go +++ b/setting/operation_setting/status_code_ranges.go @@ -26,6 +26,11 @@ var AutomaticRetryStatusCodeRanges = []StatusCodeRange{ {Start: 525, End: 599}, } +var alwaysSkipRetryStatusCodes = map[int]struct{}{ + 504: {}, + 524: {}, +} + func AutomaticDisableStatusCodesToString() string { return statusCodeRangesToString(AutomaticDisableStatusCodeRanges) } @@ -56,7 +61,15 @@ func AutomaticRetryStatusCodesFromString(s string) error { return nil } +func IsAlwaysSkipRetryStatusCode(code int) bool { + _, exists := alwaysSkipRetryStatusCodes[code] + return exists +} + func ShouldRetryByStatusCode(code int) bool { + if IsAlwaysSkipRetryStatusCode(code) { + return false + } return shouldMatchStatusCodeRanges(AutomaticRetryStatusCodeRanges, code) } diff --git a/setting/operation_setting/status_code_ranges_test.go b/setting/operation_setting/status_code_ranges_test.go index 5801824ac..4e292a368 100644 --- a/setting/operation_setting/status_code_ranges_test.go +++ b/setting/operation_setting/status_code_ranges_test.go @@ -62,6 +62,8 @@ func TestShouldRetryByStatusCode(t *testing.T) { require.True(t, ShouldRetryByStatusCode(429)) require.True(t, ShouldRetryByStatusCode(500)) + require.False(t, ShouldRetryByStatusCode(504)) + require.False(t, ShouldRetryByStatusCode(524)) require.False(t, ShouldRetryByStatusCode(400)) require.False(t, ShouldRetryByStatusCode(200)) } @@ -77,3 +79,9 @@ func TestShouldRetryByStatusCode_DefaultMatchesLegacyBehavior(t *testing.T) { require.False(t, ShouldRetryByStatusCode(524)) require.True(t, ShouldRetryByStatusCode(599)) } + +func TestIsAlwaysSkipRetryStatusCode(t *testing.T) { + require.True(t, IsAlwaysSkipRetryStatusCode(504)) + require.True(t, IsAlwaysSkipRetryStatusCode(524)) + require.False(t, IsAlwaysSkipRetryStatusCode(500)) +} diff --git a/web/src/components/common/modals/RiskAcknowledgementModal.jsx b/web/src/components/common/modals/RiskAcknowledgementModal.jsx new file mode 100644 index 000000000..1ed12166e --- /dev/null +++ b/web/src/components/common/modals/RiskAcknowledgementModal.jsx @@ -0,0 +1,165 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React, { useEffect, useMemo, useState } from 'react'; +import { + Modal, + Button, + Typography, + Checkbox, + Input, + Space, +} from '@douyinfe/semi-ui'; +import { IconAlertTriangle } from '@douyinfe/semi-icons'; +import { useIsMobile } from '../../../hooks/common/useIsMobile'; +import MarkdownRenderer from '../markdown/MarkdownRenderer'; + +const { Text } = Typography; + +const RiskAcknowledgementModal = ({ + visible, + title, + markdownContent = '', + detailTitle = '', + detailItems = [], + checklist = [], + inputPrompt = '', + requiredText = '', + inputPlaceholder = '', + mismatchText = '', + cancelText = '', + confirmText = '', + onCancel, + onConfirm, +}) => { + const isMobile = useIsMobile(); + const [checkedItems, setCheckedItems] = useState([]); + const [typedText, setTypedText] = useState(''); + + useEffect(() => { + if (!visible) return; + setCheckedItems(Array(checklist.length).fill(false)); + setTypedText(''); + }, [visible, checklist.length]); + + const allChecked = useMemo(() => { + if (checklist.length === 0) return true; + return checkedItems.length === checklist.length && checkedItems.every(Boolean); + }, [checkedItems, checklist.length]); + + const typedMatched = useMemo(() => { + if (!requiredText) return true; + return typedText.trim() === requiredText.trim(); + }, [typedText, requiredText]); + + return ( + + + {title} + + } + width={isMobile ? '100%' : 860} + centered + maskClosable={false} + closeOnEsc={false} + onCancel={onCancel} + bodyStyle={{ + maxHeight: isMobile ? '70vh' : '72vh', + overflowY: 'auto', + padding: isMobile ? '12px 16px' : '16px 20px', + }} + footer={ + + + + + } + > +

+ + ); +}; + +export default RiskAcknowledgementModal; diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index 6e85ca982..2935006dd 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -61,9 +61,11 @@ import OllamaModelModal from './OllamaModelModal'; import CodexOAuthModal from './CodexOAuthModal'; import JSONEditor from '../../../common/ui/JSONEditor'; import SecureVerificationModal from '../../../common/modals/SecureVerificationModal'; +import StatusCodeRiskGuardModal from './StatusCodeRiskGuardModal'; import ChannelKeyDisplay from '../../../common/ui/ChannelKeyDisplay'; import { useSecureVerification } from '../../../../hooks/common/useSecureVerification'; import { createApiCalls } from '../../../../services/secureVerification'; +import { collectNewDisallowedStatusCodeRedirects } from './statusCodeRiskGuard'; import { IconSave, IconClose, @@ -255,6 +257,12 @@ const EditChannelModal = (props) => { window.open(targetUrl, '_blank', 'noopener'); }; const [verifyLoading, setVerifyLoading] = useState(false); + const statusCodeRiskConfirmResolverRef = useRef(null); + const [statusCodeRiskConfirmVisible, setStatusCodeRiskConfirmVisible] = + useState(false); + const [statusCodeRiskDetailItems, setStatusCodeRiskDetailItems] = useState( + [], + ); // 表单块导航相关状态 const formSectionRefs = useRef({ @@ -276,6 +284,7 @@ const EditChannelModal = (props) => { const doubaoApiClickCountRef = useRef(0); const initialModelsRef = useRef([]); const initialModelMappingRef = useRef(''); + const initialStatusCodeMappingRef = useRef(''); // 2FA状态更新辅助函数 const updateTwoFAState = (updates) => { @@ -691,6 +700,7 @@ const EditChannelModal = (props) => { .map((model) => (model || '').trim()) .filter(Boolean); initialModelMappingRef.current = data.model_mapping || ''; + initialStatusCodeMappingRef.current = data.status_code_mapping || ''; let parsedIonet = null; if (data.other_info) { @@ -1017,11 +1027,22 @@ const EditChannelModal = (props) => { if (!isEdit) { initialModelsRef.current = []; initialModelMappingRef.current = ''; + initialStatusCodeMappingRef.current = ''; } }, [isEdit, props.visible]); + useEffect(() => { + return () => { + if (statusCodeRiskConfirmResolverRef.current) { + statusCodeRiskConfirmResolverRef.current(false); + statusCodeRiskConfirmResolverRef.current = null; + } + }; + }, []); + // 统一的模态框重置函数 const resetModalState = () => { + resolveStatusCodeRiskConfirm(false); formApiRef.current?.reset(); // 重置渠道设置状态 setChannelSettings({ @@ -1151,6 +1172,22 @@ const EditChannelModal = (props) => { }); }); + const resolveStatusCodeRiskConfirm = (confirmed) => { + setStatusCodeRiskConfirmVisible(false); + setStatusCodeRiskDetailItems([]); + if (statusCodeRiskConfirmResolverRef.current) { + statusCodeRiskConfirmResolverRef.current(confirmed); + statusCodeRiskConfirmResolverRef.current = null; + } + }; + + const confirmStatusCodeRisk = (detailItems) => + new Promise((resolve) => { + statusCodeRiskConfirmResolverRef.current = resolve; + setStatusCodeRiskDetailItems(detailItems); + setStatusCodeRiskConfirmVisible(true); + }); + const hasModelConfigChanged = (normalizedModels, modelMappingStr) => { if (!isEdit) return true; const initialModels = initialModelsRef.current; @@ -1340,6 +1377,17 @@ const EditChannelModal = (props) => { } } + const riskyStatusCodeRedirects = collectNewDisallowedStatusCodeRedirects( + initialStatusCodeMappingRef.current, + localInputs.status_code_mapping, + ); + if (riskyStatusCodeRedirects.length > 0) { + const confirmed = await confirmStatusCodeRisk(riskyStatusCodeRedirects); + if (!confirmed) { + return; + } + } + if (localInputs.base_url && localInputs.base_url.endsWith('/')) { localInputs.base_url = localInputs.base_url.slice( 0, @@ -3440,6 +3488,12 @@ const EditChannelModal = (props) => { onVisibleChange={(visible) => setIsModalOpenurl(visible)} /> + resolveStatusCodeRiskConfirm(false)} + onConfirm={() => resolveStatusCodeRiskConfirm(true)} + /> {/* 使用通用安全验证模态框 */} { + const { t } = useTranslation(); + + return ( + t(item))} + inputPrompt={t(STATUS_CODE_RISK_I18N_KEYS.inputPrompt)} + requiredText={t(STATUS_CODE_RISK_I18N_KEYS.confirmText)} + inputPlaceholder={t(STATUS_CODE_RISK_I18N_KEYS.inputPlaceholder)} + mismatchText={t(STATUS_CODE_RISK_I18N_KEYS.mismatchText)} + cancelText={t('取消')} + confirmText={t(STATUS_CODE_RISK_I18N_KEYS.confirmButton)} + onCancel={onCancel} + onConfirm={onConfirm} + /> + ); +}; + +export default StatusCodeRiskGuardModal; diff --git a/web/src/components/table/channels/modals/statusCodeRiskGuard.js b/web/src/components/table/channels/modals/statusCodeRiskGuard.js new file mode 100644 index 000000000..7ea983f86 --- /dev/null +++ b/web/src/components/table/channels/modals/statusCodeRiskGuard.js @@ -0,0 +1,101 @@ +const NON_REDIRECTABLE_STATUS_CODES = new Set([504, 524]); + +export const STATUS_CODE_RISK_I18N_KEYS = { + title: '高危操作确认', + detailTitle: '检测到以下高危状态码重定向规则', + inputPrompt: '操作确认', + confirmButton: '我确认开启高危重试', + markdown: '高危状态码重试风险告知与免责声明Markdown', + confirmText: '高危状态码重试风险确认输入文本', + inputPlaceholder: '高危状态码重试风险输入框占位文案', + mismatchText: '高危状态码重试风险输入不匹配提示', +}; + +export const STATUS_CODE_RISK_CHECKLIST_KEYS = [ + '高危状态码重试风险确认项1', + '高危状态码重试风险确认项2', + '高危状态码重试风险确认项3', + '高危状态码重试风险确认项4', +]; + +function parseStatusCodeKey(rawKey) { + if (typeof rawKey !== 'string') { + return null; + } + const normalized = rawKey.trim(); + if (!/^[1-5]\d{2}$/.test(normalized)) { + return null; + } + return Number.parseInt(normalized, 10); +} + +function parseStatusCodeMappingTarget(rawValue) { + if (typeof rawValue === 'number' && Number.isInteger(rawValue)) { + return rawValue >= 100 && rawValue <= 599 ? rawValue : null; + } + if (typeof rawValue === 'string') { + const normalized = rawValue.trim(); + if (!/^[1-5]\d{2}$/.test(normalized)) { + return null; + } + const code = Number.parseInt(normalized, 10); + return code >= 100 && code <= 599 ? code : null; + } + return null; +} + +export function collectDisallowedStatusCodeRedirects(statusCodeMappingStr) { + if ( + typeof statusCodeMappingStr !== 'string' || + statusCodeMappingStr.trim() === '' + ) { + return []; + } + + let parsed; + try { + parsed = JSON.parse(statusCodeMappingStr); + } catch (error) { + return []; + } + + if (!parsed || typeof parsed !== 'object' || Array.isArray(parsed)) { + return []; + } + + const riskyMappings = []; + Object.entries(parsed).forEach(([rawFrom, rawTo]) => { + const fromCode = parseStatusCodeKey(rawFrom); + const toCode = parseStatusCodeMappingTarget(rawTo); + if (fromCode === null || toCode === null) { + return; + } + if (!NON_REDIRECTABLE_STATUS_CODES.has(fromCode)) { + return; + } + if (fromCode === toCode) { + return; + } + riskyMappings.push(`${fromCode} -> ${toCode}`); + }); + + return Array.from(new Set(riskyMappings)).sort(); +} + +export function collectNewDisallowedStatusCodeRedirects( + originalStatusCodeMappingStr, + currentStatusCodeMappingStr, +) { + const currentRisky = collectDisallowedStatusCodeRedirects( + currentStatusCodeMappingStr, + ); + if (currentRisky.length === 0) { + return []; + } + + const originalRiskySet = new Set( + collectDisallowedStatusCodeRedirects(originalStatusCodeMappingStr), + ); + + return currentRisky.filter((mapping) => !originalRiskySet.has(mapping)); +} diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index 93b5f18c3..b5ab87353 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -1949,6 +1949,19 @@ "自动重试状态码": "Auto-retry status codes", "自动重试状态码格式不正确": "Invalid auto-retry status code format", "支持填写单个状态码或范围(含首尾),使用逗号分隔": "Supports single status codes or inclusive ranges; separate with commas", + "支持填写单个状态码或范围(含首尾),使用逗号分隔;504 和 524 始终不重试,不受此处配置影响": "Supports single status codes or inclusive ranges; separate with commas. 504 and 524 are never retried and are not affected by this setting", + "高危操作确认": "High-risk operation confirmation", + "检测到以下高危状态码重定向规则": "Detected high-risk status-code redirect rules", + "操作确认": "Operation confirmation", + "我确认开启高危重试": "I confirm enabling high-risk retry", + "高危状态码重试风险告知与免责声明Markdown": "### ⚠️ High-Risk Operation: Risk Notice and Disclaimer for 504/524 Retry\n\n[Background]\nBy default, this project does not retry for status codes `400` (bad request), `504` (gateway timeout), and `524` (timeout occurred). In many cases, `504` and `524` mean the request has reached the upstream AI service and processing has started, but the connection was closed due to long processing time.\n\nEnabling redirection/retry for these timeout status codes is a **high-risk operation**. Before enabling it, you must read and understand the consequences below:\n\n#### 1. Core Risks (Read Carefully)\n1. 💸 Duplicate/multiple billing risk: Most upstream AI providers **still charge** for requests that started processing but got interrupted by network timeout (`504`/`524`). If retry is triggered, a new upstream request will be sent, which can lead to **duplicate or multiple charges**.\n2. ⏳ Severe client timeout: If a single request already timed out, adding retries can multiply total latency and cause severe or unacceptable timeout behavior for your final client/caller.\n3. 💥 Request backlog and system crash risk: Forcing retries on timeout requests keeps threads and connections occupied for longer. Under high concurrency, this can cause serious backlog, exhaust system resources, trigger a cascading failure, and crash your proxy service.\n\n#### 2. Risk Acknowledgement\nIf you still choose to enable this feature, you acknowledge all of the following:\n\n- [ ] I have fully read and understood the risks and fully understand the destructive consequences of forcing retries for status codes `504` and `524`.\n- [ ] I have communicated with the upstream provider and confirmed that the timeout issue is an upstream bottleneck and cannot be resolved upstream at this time.\n- [ ] I voluntarily accept all duplicate/multiple billing risks and will not file issues or complaints in this project repository regarding billing anomalies caused by this retry behavior.\n- [ ] I voluntarily accept system stability risks, including severe client timeout and possible service crash. Any consequences caused by enabling this feature are my own responsibility.\n\n> **[Operation Confirmation]**\n> To unlock this feature, manually type the text below in the input box:\n> I understand the duplicate billing and crash risks, and confirm enabling it.", + "高危状态码重试风险确认输入文本": "I understand the duplicate billing and crash risks, and confirm enabling it.", + "高危状态码重试风险确认项1": "I have fully read and understood the risks and fully understand the destructive consequences of forcing retries for status codes 504 and 524.", + "高危状态码重试风险确认项2": "I have communicated with the upstream provider and confirmed that the timeout issue is an upstream bottleneck and cannot be resolved upstream at this time.", + "高危状态码重试风险确认项3": "I voluntarily accept all duplicate/multiple billing risks and will not file issues or complaints in this project repository regarding billing anomalies caused by this retry behavior.", + "高危状态码重试风险确认项4": "I voluntarily accept system stability risks, including severe client timeout and possible service crash. Any consequences caused by enabling this feature are my own responsibility.", + "高危状态码重试风险输入框占位文案": "Please type the exact text above", + "高危状态码重试风险输入不匹配提示": "The input does not match the required text", "例如:401, 403, 429, 500-599": "e.g. 401,403,429,500-599", "自动选择": "Auto Select", "自定义充值数量选项": "Custom Recharge Amount Options", diff --git a/web/src/i18n/locales/zh-CN.json b/web/src/i18n/locales/zh-CN.json index a5bace57f..e99441050 100644 --- a/web/src/i18n/locales/zh-CN.json +++ b/web/src/i18n/locales/zh-CN.json @@ -1936,6 +1936,19 @@ "自动重试状态码": "自动重试状态码", "自动重试状态码格式不正确": "自动重试状态码格式不正确", "支持填写单个状态码或范围(含首尾),使用逗号分隔": "支持填写单个状态码或范围(含首尾),使用逗号分隔", + "支持填写单个状态码或范围(含首尾),使用逗号分隔;504 和 524 始终不重试,不受此处配置影响": "支持填写单个状态码或范围(含首尾),使用逗号分隔;504 和 524 始终不重试,不受此处配置影响", + "高危操作确认": "高危操作确认", + "检测到以下高危状态码重定向规则": "检测到以下高危状态码重定向规则", + "操作确认": "操作确认", + "我确认开启高危重试": "我确认开启高危重试", + "高危状态码重试风险告知与免责声明Markdown": "### ⚠️ 高危操作:504/524 状态码重试风险告知与免责声明\n\n【背景提示】\n本项目默认对 `400`(请求错误)、`504`(网关超时)和 `524`(发生超时)状态码不进行重试。504 和 524 错误通常意味着**请求已成功送达上游 AI 服务,且上游正在处理,但因处理时间过长导致连接断开**。\n\n开启对此类超时状态码的重定向/重试属于**极高风险操作**。作为本开源项目的使用者,在开启该功能前,您必须仔细阅读并知悉以下严重后果:\n\n#### 一、 核心风险告知(请仔细阅读)\n1. 💸 双重/多重计费风险: 绝大多数 AI 上游厂商对于已经开始处理但因网络原因中断(504/524)的请求**依然会进行扣费**。此时若触发重试,将会向上游发起全新请求,导致您被**双重甚至多重计费**。\n2. ⏳ 客户端严重超时: 单次请求已经触发超时,叠加重试机制将会使总请求耗时成倍增加,导致您的最终客户端(或调用方)出现严重甚至完全无法接受的超时现象。\n3. 💥 请求积压与系统崩溃风险: 强制重试超时请求会长时间占用系统线程和连接数。在高并发场景下,这会导致严重的**请求积压**,进而耗尽系统资源,引发雪崩效应,导致您的整个代理服务崩溃。\n\n#### 二、 风险确认声明\n如果您坚持开启该功能,即代表您作出以下确认:\n\n- [ ] 我已充分阅读并理解**:本人已完整阅读上述全部风险提示,完全理解强制重试 `504` 和 `524` 状态码可能带来的破坏性后果。\n- [ ] **我已与上游沟通并确认**:本人确认,当前出现的超时问题属于上游服务的瓶颈。**本人已与上游提供商进行过沟通,确认上游无法解决该超时问题**,因此才采取强制重试方案作为妥协手段。\n- [ ] **我自愿承担计费损失**:本人知晓并接受由此产生的全部双重/多重计费风险,**承诺不会因重试导致的账单异常在本项目仓库中提交 Issue 或抱怨**。\n- [ ] **我自愿承担系统稳定性风险**:本人知晓该操作可能导致客户端严重超时及服务崩溃。若因本人开启此功能导致请求积压或服务不可用,后果由本人自行承担。\n\n> **【操作确认】\n> 为确认您已清晰了解上述风险,请在下方输入框内手动输入以下文字以解锁功能:\n> 我已了解多重计费与崩溃风险,确认开启", + "高危状态码重试风险确认输入文本": "我已了解多重计费与崩溃风险,确认开启", + "高危状态码重试风险确认项1": "我已充分阅读并理解:本人已完整阅读上述全部风险提示,完全理解强制重试 504 和 524 状态码可能带来的破坏性后果。", + "高危状态码重试风险确认项2": "我已与上游沟通并确认:本人确认,当前出现的超时问题属于上游服务的瓶颈。本人已与上游提供商进行过沟通,确认上游无法解决该超时问题,因此才采取强制重试方案作为妥协手段。", + "高危状态码重试风险确认项3": "我自愿承担计费损失:本人知晓并接受由此产生的全部双重/多重计费风险,承诺不会因重试导致的账单异常在本项目仓库中提交 Issue 或抱怨。", + "高危状态码重试风险确认项4": "我自愿承担系统稳定性风险:本人知晓该操作可能导致客户端严重超时及服务崩溃。若因本人开启此功能导致请求积压或服务不可用,后果由本人自行承担。", + "高危状态码重试风险输入框占位文案": "请完整输入上方文字", + "高危状态码重试风险输入不匹配提示": "输入内容与要求不一致", "例如:401, 403, 429, 500-599": "例如:401,403,429,500-599", "自动选择": "自动选择", "自定义充值数量选项": "自定义充值数量选项", diff --git a/web/src/i18n/locales/zh-TW.json b/web/src/i18n/locales/zh-TW.json index 562a7d543..635255c79 100644 --- a/web/src/i18n/locales/zh-TW.json +++ b/web/src/i18n/locales/zh-TW.json @@ -1942,6 +1942,19 @@ "自动重试状态码": "自動重試狀態碼", "自动重试状态码格式不正确": "自動重試狀態碼格式不正確", "支持填写单个状态码或范围(含首尾),使用逗号分隔": "支援填寫單個狀態碼或範圍(含首尾),使用逗號分隔", + "支持填写单个状态码或范围(含首尾),使用逗号分隔;504 和 524 始终不重试,不受此处配置影响": "支援填寫單個狀態碼或範圍(含首尾),使用逗號分隔;504 和 524 一律不重試,不受此處設定影響", + "高危操作确认": "高風險操作確認", + "检测到以下高危状态码重定向规则": "檢測到以下高風險狀態碼重定向規則", + "操作确认": "操作確認", + "我确认开启高危重试": "我確認開啟高風險重試", + "高危状态码重试风险告知与免责声明Markdown": "### ⚠️ 高風險操作:504/524 狀態碼重試風險告知與免責聲明\n\n【背景提示】\n本專案預設對 `400`(請求錯誤)、`504`(閘道逾時)與 `524`(發生逾時)狀態碼不進行重試。504 與 524 錯誤通常代表**請求已成功送達上游 AI 服務,且上游正在處理,但因處理時間過長導致連線中斷**。\n\n開啟此類逾時狀態碼的重定向/重試屬於**極高風險操作**。作為本開源專案使用者,在開啟該功能前,您必須仔細閱讀並知悉以下嚴重後果:\n\n#### 一、 核心風險告知(請仔細閱讀)\n1. 💸 雙重/多重計費風險:多數 AI 上游廠商對於已開始處理但因網路原因中斷(504/524)的請求**仍然會扣費**。此時若觸發重試,將會向上游發起全新請求,導致您被**雙重甚至多重計費**。\n2. ⏳ 用戶端嚴重逾時:單次請求已觸發逾時,疊加重試機制會使總請求耗時成倍增加,導致最終用戶端(或呼叫方)出現嚴重甚至無法接受的逾時現象。\n3. 💥 請求積壓與系統崩潰風險:強制重試逾時請求會長時間占用系統執行緒與連線數。在高併發場景下,這將導致嚴重**請求積壓**,進而耗盡系統資源,引發雪崩效應,造成整個代理服務崩潰。\n\n#### 二、 風險確認聲明\n若您堅持開啟該功能,即代表您作出以下確認:\n\n- [ ] 我已充分閱讀並理解:本人已完整閱讀上述全部風險提示,完全理解強制重試 `504` 與 `524` 狀態碼可能帶來的破壞性後果。\n- [ ] 我已與上游溝通並確認:本人確認,當前逾時問題屬於上游服務瓶頸。本人已與上游供應商溝通,確認上游無法解決該逾時問題,因此才採取強制重試方案作為妥協手段。\n- [ ] 我自願承擔計費損失:本人知悉並接受由此產生的全部雙重/多重計費風險,承諾不會因重試導致的帳單異常在本專案倉庫提交 Issue 或抱怨。\n- [ ] 我自願承擔系統穩定性風險:本人知悉該操作可能導致用戶端嚴重逾時及服務崩潰。若因本人開啟此功能導致請求積壓或服務不可用,後果由本人自行承擔。\n\n> **【操作確認】**\n> 為確認您已清楚了解上述風險,請在下方輸入框內手動輸入以下文字以解鎖功能:\n> 我已了解多重計費與崩潰風險,確認開啟", + "高危状态码重试风险确认输入文本": "我已了解多重計費與崩潰風險,確認開啟", + "高危状态码重试风险确认项1": "我已充分閱讀並理解:本人已完整閱讀上述全部風險提示,完全理解強制重試 504 與 524 狀態碼可能帶來的破壞性後果。", + "高危状态码重试风险确认项2": "我已與上游溝通並確認:本人確認,當前逾時問題屬於上游服務瓶頸。本人已與上游供應商溝通,確認上游無法解決該逾時問題,因此才採取強制重試方案作為妥協手段。", + "高危状态码重试风险确认项3": "我自願承擔計費損失:本人知悉並接受由此產生的全部雙重/多重計費風險,承諾不會因重試導致的帳單異常在本專案倉庫提交 Issue 或抱怨。", + "高危状态码重试风险确认项4": "我自願承擔系統穩定性風險:本人知悉該操作可能導致用戶端嚴重逾時及服務崩潰。若因本人開啟此功能導致請求積壓或服務不可用,後果由本人自行承擔。", + "高危状态码重试风险输入框占位文案": "請完整輸入上方文字", + "高危状态码重试风险输入不匹配提示": "輸入內容與要求不一致", "例如:401, 403, 429, 500-599": "例如:401,403,429,500-599", "自动选择": "自動選擇", "自定义充值数量选项": "自訂儲值數量選項", diff --git a/web/src/pages/Setting/Operation/SettingsMonitoring.jsx b/web/src/pages/Setting/Operation/SettingsMonitoring.jsx index 29b55e56c..e4ee116f2 100644 --- a/web/src/pages/Setting/Operation/SettingsMonitoring.jsx +++ b/web/src/pages/Setting/Operation/SettingsMonitoring.jsx @@ -254,7 +254,7 @@ export default function SettingsMonitoring(props) { label={t('自动重试状态码')} placeholder={t('例如:401, 403, 429, 500-599')} extraText={t( - '支持填写单个状态码或范围(含首尾),使用逗号分隔', + '支持填写单个状态码或范围(含首尾),使用逗号分隔;504 和 524 始终不重试,不受此处配置影响', )} field={'AutomaticRetryStatusCodes'} onChange={(value) => From c4c4e5eda6bdf41b9c54999359c0c9b7b14481f2 Mon Sep 17 00:00:00 2001 From: Seefs Date: Sun, 22 Feb 2026 20:14:56 +0800 Subject: [PATCH 23/41] feat: add localized high-risk status remap guard with optimized modal UX --- .../modals/RiskAcknowledgementModal.jsx | 98 +++++++++++++++---- .../modals/StatusCodeRiskGuardModal.jsx | 16 +-- 2 files changed, 89 insertions(+), 25 deletions(-) diff --git a/web/src/components/common/modals/RiskAcknowledgementModal.jsx b/web/src/components/common/modals/RiskAcknowledgementModal.jsx index 1ed12166e..63806ad73 100644 --- a/web/src/components/common/modals/RiskAcknowledgementModal.jsx +++ b/web/src/components/common/modals/RiskAcknowledgementModal.jsx @@ -17,7 +17,7 @@ along with this program. If not, see . For commercial licensing, please contact support@quantumnous.com */ -import React, { useEffect, useMemo, useState } from 'react'; +import React, { useCallback, useEffect, useMemo, useState } from 'react'; import { Modal, Button, @@ -32,7 +32,30 @@ import MarkdownRenderer from '../markdown/MarkdownRenderer'; const { Text } = Typography; -const RiskAcknowledgementModal = ({ +const RiskMarkdownBlock = React.memo(function RiskMarkdownBlock({ + markdownContent, +}) { + if (!markdownContent) { + return null; + } + + return ( +
+ +
+ ); +}); + +const RiskAcknowledgementModal = React.memo(function RiskAcknowledgementModal({ visible, title, markdownContent = '', @@ -47,7 +70,7 @@ const RiskAcknowledgementModal = ({ confirmText = '', onCancel, onConfirm, -}) => { +}) { const isMobile = useIsMobile(); const [checkedItems, setCheckedItems] = useState([]); const [typedText, setTypedText] = useState(''); @@ -68,6 +91,17 @@ const RiskAcknowledgementModal = ({ return typedText.trim() === requiredText.trim(); }, [typedText, requiredText]); + const detailText = useMemo(() => detailItems.join(', '), [detailItems]); + const canConfirm = allChecked && typedMatched; + + const handleChecklistChange = useCallback((index, checked) => { + setCheckedItems((previous) => { + const next = [...previous]; + next[index] = checked; + return next; + }); + }, []); + return ( @@ -93,7 +127,7 @@ const RiskAcknowledgementModal = ({
+ +
+ ) : ( +
+
+ ); +}; + +const AudioPreviewModal = ({ isModalOpen, setIsModalOpen, audioClips }) => { + const { t } = useTranslation(); + const clips = Array.isArray(audioClips) ? audioClips : []; + + return ( + setIsModalOpen(false)} + onCancel={() => setIsModalOpen(false)} + closable={null} + footer={null} + bodyStyle={{ + maxHeight: '70vh', + overflow: 'auto', + padding: '16px', + }} + width={560} + > + {clips.length === 0 ? ( + {t('无')} + ) : ( +
+ {clips.map((clip, idx) => ( + + ))} +
+ )} +
+ ); +}; + +export default AudioPreviewModal; diff --git a/web/src/hooks/task-logs/useTaskLogsData.js b/web/src/hooks/task-logs/useTaskLogsData.js index a461e3522..6ba3de388 100644 --- a/web/src/hooks/task-logs/useTaskLogsData.js +++ b/web/src/hooks/task-logs/useTaskLogsData.js @@ -72,6 +72,10 @@ export const useTaskLogsData = () => { const [isVideoModalOpen, setIsVideoModalOpen] = useState(false); const [videoUrl, setVideoUrl] = useState(''); + // Audio preview modal state + const [isAudioModalOpen, setIsAudioModalOpen] = useState(false); + const [audioClips, setAudioClips] = useState([]); + // User info modal state const [showUserInfo, setShowUserInfoModal] = useState(false); const [userInfoData, setUserInfoData] = useState(null); @@ -277,6 +281,11 @@ export const useTaskLogsData = () => { setIsVideoModalOpen(true); }; + const openAudioModal = (clips) => { + setAudioClips(clips); + setIsAudioModalOpen(true); + }; + // User info function const showUserInfoFunc = async (userId) => { if (!isAdminUser) { @@ -319,6 +328,11 @@ export const useTaskLogsData = () => { setIsVideoModalOpen, videoUrl, + // Audio preview modal + isAudioModalOpen, + setIsAudioModalOpen, + audioClips, + // Form state formApi, setFormApi, @@ -351,7 +365,8 @@ export const useTaskLogsData = () => { refresh, copyText, openContentModal, - openVideoModal, // 新增 + openVideoModal, + openAudioModal, enrichLogs, syncPageData, diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index 4c37baa76..e06c68362 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -1634,6 +1634,9 @@ "点击查看差异": "Click to view differences", "点击此处": "click here", "点击预览视频": "Click to preview video", + "点击预览音乐": "Click to preview music", + "音乐预览": "Music Preview", + "音频无法播放": "Audio cannot be played", "点击验证按钮,使用您的生物特征或安全密钥": "Click the verification button and use your biometrics or security key", "版权所有": "All rights reserved", "状态": "Status", diff --git a/web/src/i18n/locales/fr.json b/web/src/i18n/locales/fr.json index 4fdca8b86..2843728b8 100644 --- a/web/src/i18n/locales/fr.json +++ b/web/src/i18n/locales/fr.json @@ -1646,6 +1646,9 @@ "点击查看差异": "Cliquez pour voir les différences", "点击此处": "cliquez ici", "点击预览视频": "Cliquez pour prévisualiser la vidéo", + "点击预览音乐": "Cliquez pour écouter la musique", + "音乐预览": "Aperçu musical", + "音频无法播放": "Impossible de lire l'audio", "点击验证按钮,使用您的生物特征或安全密钥": "Cliquez sur le bouton de vérification pour utiliser vos caractéristiques biométriques ou votre clé de sécurité", "版权所有": "Tous droits réservés", "状态": "Statut", diff --git a/web/src/i18n/locales/ja.json b/web/src/i18n/locales/ja.json index d9f739a5d..d18a62923 100644 --- a/web/src/i18n/locales/ja.json +++ b/web/src/i18n/locales/ja.json @@ -1631,6 +1631,9 @@ "点击查看差异": "差分を表示", "点击此处": "こちらをクリック", "点击预览视频": "動画をプレビュー", + "点击预览音乐": "音楽をプレビュー", + "音乐预览": "音楽プレビュー", + "音频无法播放": "音声を再生できません", "点击验证按钮,使用您的生物特征或安全密钥": "認証ボタンをクリックし、生体情報またはセキュリティキーを使用してください", "版权所有": "All rights reserved", "状态": "ステータス", diff --git a/web/src/i18n/locales/ru.json b/web/src/i18n/locales/ru.json index cb2dec1c6..099f405c9 100644 --- a/web/src/i18n/locales/ru.json +++ b/web/src/i18n/locales/ru.json @@ -1657,6 +1657,9 @@ "点击查看差异": "Нажмите для просмотра различий", "点击此处": "Нажмите здесь", "点击预览视频": "Нажмите для предварительного просмотра видео", + "点击预览音乐": "Нажмите для прослушивания музыки", + "音乐预览": "Предварительное прослушивание", + "音频无法播放": "Не удалось воспроизвести аудио", "点击验证按钮,使用您的生物特征或安全密钥": "Нажмите кнопку проверки, используйте ваши биометрические данные или ключ безопасности", "版权所有": "Все права защищены", "状态": "Статус", diff --git a/web/src/i18n/locales/vi.json b/web/src/i18n/locales/vi.json index 4b81fac6e..d2602efdf 100644 --- a/web/src/i18n/locales/vi.json +++ b/web/src/i18n/locales/vi.json @@ -1773,6 +1773,9 @@ "点击链接重置密码": "Nhấp vào liên kết để đặt lại mật khẩu", "点击阅读": "Nhấp để đọc", "点击预览视频": "Nhấp để xem trước video", + "点击预览音乐": "Nhấp để nghe nhạc", + "音乐预览": "Xem trước nhạc", + "音频无法播放": "Không thể phát âm thanh", "点击验证按钮,使用您的生物特征或安全密钥": "Nhấp vào nút xác minh và sử dụng sinh trắc học hoặc khóa bảo mật của bạn", "版": "Phiên bản", "版本": "Phiên bản", diff --git a/web/src/i18n/locales/zh-CN.json b/web/src/i18n/locales/zh-CN.json index a41db52d0..d067ad569 100644 --- a/web/src/i18n/locales/zh-CN.json +++ b/web/src/i18n/locales/zh-CN.json @@ -1624,6 +1624,9 @@ "点击查看差异": "点击查看差异", "点击此处": "点击此处", "点击预览视频": "点击预览视频", + "点击预览音乐": "点击预览音乐", + "音乐预览": "音乐预览", + "音频无法播放": "音频无法播放", "点击验证按钮,使用您的生物特征或安全密钥": "点击验证按钮,使用您的生物特征或安全密钥", "版权所有": "版权所有", "状态": "状态", diff --git a/web/src/i18n/locales/zh-TW.json b/web/src/i18n/locales/zh-TW.json index 7ca22d2a6..26f7092b7 100644 --- a/web/src/i18n/locales/zh-TW.json +++ b/web/src/i18n/locales/zh-TW.json @@ -1628,6 +1628,9 @@ "点击查看差异": "點擊查看差異", "点击此处": "點擊此處", "点击预览视频": "點擊預覽影片", + "点击预览音乐": "點擊預覽音樂", + "音乐预览": "音樂預覽", + "音频无法播放": "音訊無法播放", "点击验证按钮,使用您的生物特征或安全密钥": "點擊驗證按鈕,使用您的生物特徵或安全密鑰", "版权所有": "版權所有", "状态": "狀態", From a01a77fc6f0e467b815b2f89a8bedb561f32117b Mon Sep 17 00:00:00 2001 From: Seefs <40468931+seefs001@users.noreply.github.com> Date: Sun, 22 Feb 2026 23:30:02 +0800 Subject: [PATCH 28/41] fix: claude affinity cache counter (#2980) * fix: claude affinity cache counter * fix: claude affinity cache counter * fix: stabilize cache usage stats format and simplify modal rendering --- relay/common/relay_info.go | 16 ++- relay/common/relay_info_test.go | 40 ++++++ relay/compatible_handler.go | 4 +- service/channel_affinity.go | 60 +++++++- service/channel_affinity_usage_cache_test.go | 105 ++++++++++++++ service/quota.go | 3 + .../modals/ChannelAffinityUsageCacheModal.jsx | 129 ++++++++++++------ 7 files changed, 304 insertions(+), 53 deletions(-) create mode 100644 relay/common/relay_info_test.go create mode 100644 service/channel_affinity_usage_cache_test.go diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 541f1b9f8..e88f4e51f 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -152,7 +152,8 @@ type RelayInfo struct { // RequestConversionChain records request format conversions in order, e.g. // ["openai", "openai_responses"] or ["openai", "claude"]. RequestConversionChain []types.RelayFormat - // 最终请求到上游的格式 TODO: 当前仅设置了Claude + // 最终请求到上游的格式。可由 adaptor 显式设置; + // 若为空,调用 GetFinalRequestRelayFormat 会回退到 RequestConversionChain 的最后一项或 RelayFormat。 FinalRequestRelayFormat types.RelayFormat ThinkingContentInfo @@ -579,6 +580,19 @@ func (info *RelayInfo) AppendRequestConversion(format types.RelayFormat) { info.RequestConversionChain = append(info.RequestConversionChain, format) } +func (info *RelayInfo) GetFinalRequestRelayFormat() types.RelayFormat { + if info == nil { + return "" + } + if info.FinalRequestRelayFormat != "" { + return info.FinalRequestRelayFormat + } + if n := len(info.RequestConversionChain); n > 0 { + return info.RequestConversionChain[n-1] + } + return info.RelayFormat +} + func GenRelayInfoResponsesCompaction(c *gin.Context, request *dto.OpenAIResponsesCompactionRequest) *RelayInfo { info := genBaseRelayInfo(c, request) if info.RelayMode == relayconstant.RelayModeUnknown { diff --git a/relay/common/relay_info_test.go b/relay/common/relay_info_test.go new file mode 100644 index 000000000..e53ec804c --- /dev/null +++ b/relay/common/relay_info_test.go @@ -0,0 +1,40 @@ +package common + +import ( + "testing" + + "github.com/QuantumNous/new-api/types" + "github.com/stretchr/testify/require" +) + +func TestRelayInfoGetFinalRequestRelayFormatPrefersExplicitFinal(t *testing.T) { + info := &RelayInfo{ + RelayFormat: types.RelayFormatOpenAI, + RequestConversionChain: []types.RelayFormat{types.RelayFormatOpenAI, types.RelayFormatClaude}, + FinalRequestRelayFormat: types.RelayFormatOpenAIResponses, + } + + require.Equal(t, types.RelayFormat(types.RelayFormatOpenAIResponses), info.GetFinalRequestRelayFormat()) +} + +func TestRelayInfoGetFinalRequestRelayFormatFallsBackToConversionChain(t *testing.T) { + info := &RelayInfo{ + RelayFormat: types.RelayFormatOpenAI, + RequestConversionChain: []types.RelayFormat{types.RelayFormatOpenAI, types.RelayFormatClaude}, + } + + require.Equal(t, types.RelayFormat(types.RelayFormatClaude), info.GetFinalRequestRelayFormat()) +} + +func TestRelayInfoGetFinalRequestRelayFormatFallsBackToRelayFormat(t *testing.T) { + info := &RelayInfo{ + RelayFormat: types.RelayFormatGemini, + } + + require.Equal(t, types.RelayFormat(types.RelayFormatGemini), info.GetFinalRequestRelayFormat()) +} + +func TestRelayInfoGetFinalRequestRelayFormatNilReceiver(t *testing.T) { + var info *RelayInfo + require.Equal(t, types.RelayFormat(""), info.GetFinalRequestRelayFormat()) +} diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index e7adddbbf..cb25da0b3 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -232,7 +232,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage } if originUsage != nil { - service.ObserveChannelAffinityUsageCacheFromContext(ctx, usage) + service.ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat()) } adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason) @@ -336,7 +336,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage var audioInputQuota decimal.Decimal var audioInputPrice float64 - isClaudeUsageSemantic := relayInfo.FinalRequestRelayFormat == types.RelayFormatClaude + isClaudeUsageSemantic := relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude if !relayInfo.PriceData.UsePrice { baseTokens := dPromptTokens // 减去 cached tokens diff --git a/service/channel_affinity.go b/service/channel_affinity.go index fe1524c59..524c6574a 100644 --- a/service/channel_affinity.go +++ b/service/channel_affinity.go @@ -13,6 +13,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/pkg/cachex" "github.com/QuantumNous/new-api/setting/operation_setting" + "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" "github.com/samber/hot" "github.com/tidwall/gjson" @@ -61,6 +62,12 @@ type ChannelAffinityStatsContext struct { TTLSeconds int64 } +const ( + cacheTokenRateModeCachedOverPrompt = "cached_over_prompt" + cacheTokenRateModeCachedOverPromptPlusCached = "cached_over_prompt_plus_cached" + cacheTokenRateModeMixed = "mixed" +) + type ChannelAffinityCacheStats struct { Enabled bool `json:"enabled"` Total int `json:"total"` @@ -565,9 +572,10 @@ func RecordChannelAffinity(c *gin.Context, channelID int) { } type ChannelAffinityUsageCacheStats struct { - RuleName string `json:"rule_name"` - UsingGroup string `json:"using_group"` - KeyFingerprint string `json:"key_fp"` + RuleName string `json:"rule_name"` + UsingGroup string `json:"using_group"` + KeyFingerprint string `json:"key_fp"` + CachedTokenRateMode string `json:"cached_token_rate_mode"` Hit int64 `json:"hit"` Total int64 `json:"total"` @@ -582,6 +590,8 @@ type ChannelAffinityUsageCacheStats struct { } type ChannelAffinityUsageCacheCounters struct { + CachedTokenRateMode string `json:"cached_token_rate_mode"` + Hit int64 `json:"hit"` Total int64 `json:"total"` WindowSeconds int64 `json:"window_seconds"` @@ -596,12 +606,17 @@ type ChannelAffinityUsageCacheCounters struct { var channelAffinityUsageCacheStatsLocks [64]sync.Mutex -func ObserveChannelAffinityUsageCacheFromContext(c *gin.Context, usage *dto.Usage) { +// ObserveChannelAffinityUsageCacheByRelayFormat records usage cache stats with a stable rate mode derived from relay format. +func ObserveChannelAffinityUsageCacheByRelayFormat(c *gin.Context, usage *dto.Usage, relayFormat types.RelayFormat) { + ObserveChannelAffinityUsageCacheFromContext(c, usage, cachedTokenRateModeByRelayFormat(relayFormat)) +} + +func ObserveChannelAffinityUsageCacheFromContext(c *gin.Context, usage *dto.Usage, cachedTokenRateMode string) { statsCtx, ok := GetChannelAffinityStatsContext(c) if !ok { return } - observeChannelAffinityUsageCache(statsCtx, usage) + observeChannelAffinityUsageCache(statsCtx, usage, cachedTokenRateMode) } func GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFp string) ChannelAffinityUsageCacheStats { @@ -628,6 +643,7 @@ func GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFp string) Chann } } return ChannelAffinityUsageCacheStats{ + CachedTokenRateMode: v.CachedTokenRateMode, RuleName: ruleName, UsingGroup: usingGroup, KeyFingerprint: keyFp, @@ -643,7 +659,7 @@ func GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFp string) Chann } } -func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usage *dto.Usage) { +func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usage *dto.Usage, cachedTokenRateMode string) { entryKey := channelAffinityUsageCacheEntryKey(statsCtx.RuleName, statsCtx.UsingGroup, statsCtx.KeyFingerprint) if entryKey == "" { return @@ -669,6 +685,14 @@ func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usag if !found { next = ChannelAffinityUsageCacheCounters{} } + currentMode := normalizeCachedTokenRateMode(cachedTokenRateMode) + if currentMode != "" { + if next.CachedTokenRateMode == "" { + next.CachedTokenRateMode = currentMode + } else if next.CachedTokenRateMode != currentMode && next.CachedTokenRateMode != cacheTokenRateModeMixed { + next.CachedTokenRateMode = cacheTokenRateModeMixed + } + } next.Total++ hit, cachedTokens, promptCacheHitTokens := usageCacheSignals(usage) if hit { @@ -684,6 +708,30 @@ func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usag _ = cache.SetWithTTL(entryKey, next, ttl) } +func normalizeCachedTokenRateMode(mode string) string { + switch mode { + case cacheTokenRateModeCachedOverPrompt: + return cacheTokenRateModeCachedOverPrompt + case cacheTokenRateModeCachedOverPromptPlusCached: + return cacheTokenRateModeCachedOverPromptPlusCached + case cacheTokenRateModeMixed: + return cacheTokenRateModeMixed + default: + return "" + } +} + +func cachedTokenRateModeByRelayFormat(relayFormat types.RelayFormat) string { + switch relayFormat { + case types.RelayFormatOpenAI, types.RelayFormatOpenAIResponses, types.RelayFormatOpenAIResponsesCompaction: + return cacheTokenRateModeCachedOverPrompt + case types.RelayFormatClaude: + return cacheTokenRateModeCachedOverPromptPlusCached + default: + return "" + } +} + func channelAffinityUsageCacheEntryKey(ruleName, usingGroup, keyFp string) string { ruleName = strings.TrimSpace(ruleName) usingGroup = strings.TrimSpace(usingGroup) diff --git a/service/channel_affinity_usage_cache_test.go b/service/channel_affinity_usage_cache_test.go new file mode 100644 index 000000000..64d3d715b --- /dev/null +++ b/service/channel_affinity_usage_cache_test.go @@ -0,0 +1,105 @@ +package service + +import ( + "fmt" + "net/http/httptest" + "testing" + "time" + + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/types" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP string) *gin.Context { + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + setChannelAffinityContext(ctx, channelAffinityMeta{ + CacheKey: fmt.Sprintf("test:%s:%s:%s", ruleName, usingGroup, keyFP), + TTLSeconds: 600, + RuleName: ruleName, + UsingGroup: usingGroup, + KeyFingerprint: keyFP, + }) + return ctx +} + +func TestObserveChannelAffinityUsageCacheByRelayFormat_ClaudeMode(t *testing.T) { + ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano()) + usingGroup := "default" + keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano()) + ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP) + + usage := &dto.Usage{ + PromptTokens: 100, + CompletionTokens: 40, + TotalTokens: 140, + PromptTokensDetails: dto.InputTokenDetails{ + CachedTokens: 30, + }, + } + + ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, types.RelayFormatClaude) + stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP) + + require.EqualValues(t, 1, stats.Total) + require.EqualValues(t, 1, stats.Hit) + require.EqualValues(t, 100, stats.PromptTokens) + require.EqualValues(t, 40, stats.CompletionTokens) + require.EqualValues(t, 140, stats.TotalTokens) + require.EqualValues(t, 30, stats.CachedTokens) + require.Equal(t, cacheTokenRateModeCachedOverPromptPlusCached, stats.CachedTokenRateMode) +} + +func TestObserveChannelAffinityUsageCacheByRelayFormat_MixedMode(t *testing.T) { + ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano()) + usingGroup := "default" + keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano()) + ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP) + + openAIUsage := &dto.Usage{ + PromptTokens: 100, + PromptTokensDetails: dto.InputTokenDetails{ + CachedTokens: 10, + }, + } + claudeUsage := &dto.Usage{ + PromptTokens: 80, + PromptTokensDetails: dto.InputTokenDetails{ + CachedTokens: 20, + }, + } + + ObserveChannelAffinityUsageCacheByRelayFormat(ctx, openAIUsage, types.RelayFormatOpenAI) + ObserveChannelAffinityUsageCacheByRelayFormat(ctx, claudeUsage, types.RelayFormatClaude) + stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP) + + require.EqualValues(t, 2, stats.Total) + require.EqualValues(t, 2, stats.Hit) + require.EqualValues(t, 180, stats.PromptTokens) + require.EqualValues(t, 30, stats.CachedTokens) + require.Equal(t, cacheTokenRateModeMixed, stats.CachedTokenRateMode) +} + +func TestObserveChannelAffinityUsageCacheByRelayFormat_UnsupportedModeKeepsEmpty(t *testing.T) { + ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano()) + usingGroup := "default" + keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano()) + ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP) + + usage := &dto.Usage{ + PromptTokens: 100, + PromptTokensDetails: dto.InputTokenDetails{ + CachedTokens: 25, + }, + } + + ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, types.RelayFormatGemini) + stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP) + + require.EqualValues(t, 1, stats.Total) + require.EqualValues(t, 1, stats.Hit) + require.EqualValues(t, 25, stats.CachedTokens) + require.Equal(t, "", stats.CachedTokenRateMode) +} diff --git a/service/quota.go b/service/quota.go index 50421017e..7ee70edd5 100644 --- a/service/quota.go +++ b/service/quota.go @@ -236,6 +236,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod } func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) { + if usage != nil { + ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat()) + } useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens diff --git a/web/src/components/table/usage-logs/modals/ChannelAffinityUsageCacheModal.jsx b/web/src/components/table/usage-logs/modals/ChannelAffinityUsageCacheModal.jsx index ea1a5c7fb..383ebabc1 100644 --- a/web/src/components/table/usage-logs/modals/ChannelAffinityUsageCacheModal.jsx +++ b/web/src/components/table/usage-logs/modals/ChannelAffinityUsageCacheModal.jsx @@ -39,6 +39,21 @@ function formatTokenRate(n, d) { return `${r.toFixed(2)}%`; } +function formatCachedTokenRate(cachedTokens, promptTokens, mode) { + if (mode === 'cached_over_prompt_plus_cached') { + const denominator = Number(promptTokens || 0) + Number(cachedTokens || 0); + return formatTokenRate(cachedTokens, denominator); + } + if (mode === 'cached_over_prompt') { + return formatTokenRate(cachedTokens, promptTokens); + } + return '-'; +} + +function hasTextValue(value) { + return typeof value === 'string' && value.trim() !== ''; +} + const ChannelAffinityUsageCacheModal = ({ t, showChannelAffinityUsageCacheModal, @@ -107,7 +122,7 @@ const ChannelAffinityUsageCacheModal = ({ t, ]); - const rows = useMemo(() => { + const { rows, supportsTokenStats } = useMemo(() => { const s = stats || {}; const hit = Number(s.hit || 0); const total = Number(s.total || 0); @@ -118,48 +133,62 @@ const ChannelAffinityUsageCacheModal = ({ const totalTokens = Number(s.total_tokens || 0); const cachedTokens = Number(s.cached_tokens || 0); const promptCacheHitTokens = Number(s.prompt_cache_hit_tokens || 0); + const cachedTokenRateMode = String(s.cached_token_rate_mode || '').trim(); + const supportsTokenStats = + cachedTokenRateMode === 'cached_over_prompt' || + cachedTokenRateMode === 'cached_over_prompt_plus_cached' || + cachedTokenRateMode === 'mixed'; - return [ - { key: t('规则'), value: s.rule_name || params.rule_name || '-' }, - { key: t('分组'), value: s.using_group || params.using_group || '-' }, - { - key: t('Key 摘要'), - value: params.key_hint || '-', - }, - { - key: t('Key 指纹'), - value: s.key_fp || params.key_fp || '-', - }, - { key: t('TTL(秒)'), value: windowSeconds > 0 ? windowSeconds : '-' }, - { - key: t('命中率'), - value: `${hit}/${total} (${formatRate(hit, total)})`, - }, - { - key: t('Prompt tokens'), - value: promptTokens, - }, - { - key: t('Cached tokens'), - value: `${cachedTokens} (${formatTokenRate(cachedTokens, promptTokens)})`, - }, - { - key: t('Prompt cache hit tokens'), - value: promptCacheHitTokens, - }, - { - key: t('Completion tokens'), - value: completionTokens, - }, - { - key: t('Total tokens'), - value: totalTokens, - }, - { - key: t('最近一次'), - value: lastSeenAt > 0 ? timestamp2string(lastSeenAt) : '-', - }, - ]; + const data = []; + const ruleName = String(s.rule_name || params.rule_name || '').trim(); + const usingGroup = String(s.using_group || params.using_group || '').trim(); + const keyHint = String(params.key_hint || '').trim(); + const keyFp = String(s.key_fp || params.key_fp || '').trim(); + + if (hasTextValue(ruleName)) { + data.push({ key: t('规则'), value: ruleName }); + } + if (hasTextValue(usingGroup)) { + data.push({ key: t('分组'), value: usingGroup }); + } + if (hasTextValue(keyHint)) { + data.push({ key: t('Key 摘要'), value: keyHint }); + } + if (hasTextValue(keyFp)) { + data.push({ key: t('Key 指纹'), value: keyFp }); + } + if (windowSeconds > 0) { + data.push({ key: t('TTL(秒)'), value: windowSeconds }); + } + if (total > 0) { + data.push({ key: t('命中率'), value: `${hit}/${total} (${formatRate(hit, total)})` }); + } + if (lastSeenAt > 0) { + data.push({ key: t('最近一次'), value: timestamp2string(lastSeenAt) }); + } + + if (supportsTokenStats) { + if (promptTokens > 0) { + data.push({ key: t('Prompt tokens'), value: promptTokens }); + } + if (promptTokens > 0 || cachedTokens > 0) { + data.push({ + key: t('Cached tokens'), + value: `${cachedTokens} (${formatCachedTokenRate(cachedTokens, promptTokens, cachedTokenRateMode)})`, + }); + } + if (promptCacheHitTokens > 0) { + data.push({ key: t('Prompt cache hit tokens'), value: promptCacheHitTokens }); + } + if (completionTokens > 0) { + data.push({ key: t('Completion tokens'), value: completionTokens }); + } + if (totalTokens > 0) { + data.push({ key: t('Total tokens'), value: totalTokens }); + } + } + + return { rows: data, supportsTokenStats }; }, [stats, params, t]); return ( @@ -179,15 +208,27 @@ const ChannelAffinityUsageCacheModal = ({ {t( '命中判定:usage 中存在 cached tokens(例如 cached_tokens/prompt_cache_hit_tokens)即视为命中。', )} + {' '} + {t( + 'Cached tokens 占比口径由后端返回:Claude 语义按 cached/(prompt+cached),其余按 cached/prompt。', + )} + {' '} + {t('当前仅 OpenAI / Claude 语义支持缓存 token 统计,其他通道将隐藏 token 相关字段。')} + {stats && !supportsTokenStats ? ( + <> + {' '} + {t('该记录不包含可用的 token 统计口径。')} + + ) : null}
- {stats ? ( + {stats && rows.length > 0 ? ( ) : (
- {loading ? t('加载中...') : t('暂无数据')} + {loading ? t('加载中...') : t('暂无可展示数据')}
)} From 016812baa621d9962b0152e9829763f953817b5e Mon Sep 17 00:00:00 2001 From: CaIon Date: Mon, 23 Feb 2026 14:11:11 +0800 Subject: [PATCH 29/41] feat: implement caching for channel retrieval --- model/log.go | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/model/log.go b/model/log.go index 1f521b1e5..2d4782fa5 100644 --- a/model/log.go +++ b/model/log.go @@ -295,8 +295,24 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName Id int `gorm:"column:id"` Name string `gorm:"column:name"` } - if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds.Items()).Find(&channels).Error; err != nil { - return logs, total, err + if common.MemoryCacheEnabled { + // Cache get channel + for _, channelId := range channelIds.Items() { + if cacheChannel, err := CacheGetChannel(channelId); err == nil { + channels = append(channels, struct { + Id int `gorm:"column:id"` + Name string `gorm:"column:name"` + }{ + Id: channelId, + Name: cacheChannel.Name, + }) + } + } + } else { + // Bulk query channels from DB + if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds.Items()).Find(&channels).Error; err != nil { + return logs, total, err + } } channelMap := make(map[int]string, len(channels)) for _, channel := range channels { From 9a5f8222bd8acaf7052c374d7153b9320fe7c1cc Mon Sep 17 00:00:00 2001 From: Seefs Date: Mon, 23 Feb 2026 14:51:55 +0800 Subject: [PATCH 30/41] feat: move user bindings to dedicated management modal --- controller/custom_oauth.go | 121 +++++- controller/user.go | 38 ++ model/user.go | 31 ++ router/api-router.go | 3 + .../table/users/modals/EditUserModal.jsx | 107 +++-- .../modals/UserBindingManagementModal.jsx | 396 ++++++++++++++++++ 6 files changed, 629 insertions(+), 67 deletions(-) create mode 100644 web/src/components/table/users/modals/UserBindingManagementModal.jsx diff --git a/controller/custom_oauth.go b/controller/custom_oauth.go index 3197a9163..c21ec7910 100644 --- a/controller/custom_oauth.go +++ b/controller/custom_oauth.go @@ -38,6 +38,14 @@ type CustomOAuthProviderResponse struct { AccessDeniedMessage string `json:"access_denied_message"` } +type UserOAuthBindingResponse struct { + ProviderId int `json:"provider_id"` + ProviderName string `json:"provider_name"` + ProviderSlug string `json:"provider_slug"` + ProviderIcon string `json:"provider_icon"` + ProviderUserId string `json:"provider_user_id"` +} + func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse { return &CustomOAuthProviderResponse{ Id: p.Id, @@ -433,6 +441,30 @@ func DeleteCustomOAuthProvider(c *gin.Context) { }) } +func buildUserOAuthBindingsResponse(userId int) ([]UserOAuthBindingResponse, error) { + bindings, err := model.GetUserOAuthBindingsByUserId(userId) + if err != nil { + return nil, err + } + + response := make([]UserOAuthBindingResponse, 0, len(bindings)) + for _, binding := range bindings { + provider, err := model.GetCustomOAuthProviderById(binding.ProviderId) + if err != nil { + continue + } + response = append(response, UserOAuthBindingResponse{ + ProviderId: binding.ProviderId, + ProviderName: provider.Name, + ProviderSlug: provider.Slug, + ProviderIcon: provider.Icon, + ProviderUserId: binding.ProviderUserId, + }) + } + + return response, nil +} + // GetUserOAuthBindings returns all OAuth bindings for the current user func GetUserOAuthBindings(c *gin.Context) { userId := c.GetInt("id") @@ -441,34 +473,43 @@ func GetUserOAuthBindings(c *gin.Context) { return } - bindings, err := model.GetUserOAuthBindingsByUserId(userId) + response, err := buildUserOAuthBindingsResponse(userId) if err != nil { common.ApiError(c, err) return } - // Build response with provider info - type BindingResponse struct { - ProviderId int `json:"provider_id"` - ProviderName string `json:"provider_name"` - ProviderSlug string `json:"provider_slug"` - ProviderIcon string `json:"provider_icon"` - ProviderUserId string `json:"provider_user_id"` + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": response, + }) +} + +func GetUserOAuthBindingsByAdmin(c *gin.Context) { + userIdStr := c.Param("id") + userId, err := strconv.Atoi(userIdStr) + if err != nil { + common.ApiErrorMsg(c, "invalid user id") + return } - response := make([]BindingResponse, 0) - for _, binding := range bindings { - provider, err := model.GetCustomOAuthProviderById(binding.ProviderId) - if err != nil { - continue // Skip if provider not found - } - response = append(response, BindingResponse{ - ProviderId: binding.ProviderId, - ProviderName: provider.Name, - ProviderSlug: provider.Slug, - ProviderIcon: provider.Icon, - ProviderUserId: binding.ProviderUserId, - }) + targetUser, err := model.GetUserById(userId, false) + if err != nil { + common.ApiError(c, err) + return + } + + myRole := c.GetInt("role") + if myRole <= targetUser.Role && myRole != common.RoleRootUser { + common.ApiErrorMsg(c, "no permission") + return + } + + response, err := buildUserOAuthBindingsResponse(userId) + if err != nil { + common.ApiError(c, err) + return } c.JSON(http.StatusOK, gin.H{ @@ -503,3 +544,41 @@ func UnbindCustomOAuth(c *gin.Context) { "message": "解绑成功", }) } + +func UnbindCustomOAuthByAdmin(c *gin.Context) { + userIdStr := c.Param("id") + userId, err := strconv.Atoi(userIdStr) + if err != nil { + common.ApiErrorMsg(c, "invalid user id") + return + } + + targetUser, err := model.GetUserById(userId, false) + if err != nil { + common.ApiError(c, err) + return + } + + myRole := c.GetInt("role") + if myRole <= targetUser.Role && myRole != common.RoleRootUser { + common.ApiErrorMsg(c, "no permission") + return + } + + providerIdStr := c.Param("provider_id") + providerId, err := strconv.Atoi(providerIdStr) + if err != nil { + common.ApiErrorMsg(c, "invalid provider id") + return + } + + if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil { + common.ApiError(c, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "success", + }) +} diff --git a/controller/user.go b/controller/user.go index db078071e..b58eab88f 100644 --- a/controller/user.go +++ b/controller/user.go @@ -582,6 +582,44 @@ func UpdateUser(c *gin.Context) { return } +func AdminClearUserBinding(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + common.ApiErrorI18n(c, i18n.MsgInvalidParams) + return + } + + bindingType := strings.ToLower(strings.TrimSpace(c.Param("binding_type"))) + if bindingType == "" { + common.ApiErrorI18n(c, i18n.MsgInvalidParams) + return + } + + user, err := model.GetUserById(id, false) + if err != nil { + common.ApiError(c, err) + return + } + + myRole := c.GetInt("role") + if myRole <= user.Role && myRole != common.RoleRootUser { + common.ApiErrorI18n(c, i18n.MsgUserNoPermissionSameLevel) + return + } + + if err := user.ClearBinding(bindingType); err != nil { + common.ApiError(c, err) + return + } + + model.RecordLog(user.Id, model.LogTypeManage, fmt.Sprintf("admin cleared %s binding for user %s", bindingType, user.Username)) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "success", + }) +} + func UpdateSelf(c *gin.Context) { var requestData map[string]interface{} err := json.NewDecoder(c.Request.Body).Decode(&requestData) diff --git a/model/user.go b/model/user.go index e0c9c686f..e0f803a90 100644 --- a/model/user.go +++ b/model/user.go @@ -536,6 +536,37 @@ func (user *User) Edit(updatePassword bool) error { return updateUserCache(*user) } +func (user *User) ClearBinding(bindingType string) error { + if user.Id == 0 { + return errors.New("user id is empty") + } + + bindingColumnMap := map[string]string{ + "email": "email", + "github": "github_id", + "discord": "discord_id", + "oidc": "oidc_id", + "wechat": "wechat_id", + "telegram": "telegram_id", + "linuxdo": "linux_do_id", + } + + column, ok := bindingColumnMap[bindingType] + if !ok { + return errors.New("invalid binding type") + } + + if err := DB.Model(&User{}).Where("id = ?", user.Id).Update(column, "").Error; err != nil { + return err + } + + if err := DB.Where("id = ?", user.Id).First(user).Error; err != nil { + return err + } + + return updateUserCache(*user) +} + func (user *User) Delete() error { if user.Id == 0 { return errors.New("id 为空!") diff --git a/router/api-router.go b/router/api-router.go index d60ba39b2..b6e418c6e 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -114,6 +114,9 @@ func SetApiRouter(router *gin.Engine) { adminRoute.GET("/topup", controller.GetAllTopUps) adminRoute.POST("/topup/complete", controller.AdminCompleteTopUp) adminRoute.GET("/search", controller.SearchUsers) + adminRoute.GET("/:id/oauth/bindings", controller.GetUserOAuthBindingsByAdmin) + adminRoute.DELETE("/:id/oauth/bindings/:provider_id", controller.UnbindCustomOAuthByAdmin) + adminRoute.DELETE("/:id/bindings/:binding_type", controller.AdminClearUserBinding) adminRoute.GET("/:id", controller.GetUser) adminRoute.POST("/", controller.CreateUser) adminRoute.POST("/manage", controller.ManageUser) diff --git a/web/src/components/table/users/modals/EditUserModal.jsx b/web/src/components/table/users/modals/EditUserModal.jsx index 32601daa8..297f18116 100644 --- a/web/src/components/table/users/modals/EditUserModal.jsx +++ b/web/src/components/table/users/modals/EditUserModal.jsx @@ -45,7 +45,6 @@ import { Avatar, Row, Col, - Input, InputNumber, } from '@douyinfe/semi-ui'; import { @@ -56,6 +55,7 @@ import { IconUserGroup, IconPlus, } from '@douyinfe/semi-icons'; +import UserBindingManagementModal from './UserBindingManagementModal'; const { Text, Title } = Typography; @@ -68,6 +68,7 @@ const EditUserModal = (props) => { const [addAmountLocal, setAddAmountLocal] = useState(''); const isMobile = useIsMobile(); const [groupOptions, setGroupOptions] = useState([]); + const [bindingModalVisible, setBindingModalVisible] = useState(false); const formApiRef = useRef(null); const isEdit = Boolean(userId); @@ -81,6 +82,7 @@ const EditUserModal = (props) => { discord_id: '', wechat_id: '', telegram_id: '', + linux_do_id: '', email: '', quota: 0, group: 'default', @@ -115,8 +117,17 @@ const EditUserModal = (props) => { useEffect(() => { loadUser(); if (userId) fetchGroups(); + setBindingModalVisible(false); }, [props.editingUser.id]); + const openBindingModal = () => { + setBindingModalVisible(true); + }; + + const closeBindingModal = () => { + setBindingModalVisible(false); + }; + /* ----------------------- submit ----------------------- */ const submit = async (values) => { setLoading(true); @@ -316,56 +327,51 @@ const EditUserModal = (props) => { )} - {/* 绑定信息 */} - -
- - - -
- - {t('绑定信息')} - -
- {t('第三方账户绑定状态(只读)')} + {/* 绑定信息入口 */} + {userId && ( + +
+
+ + + +
+ + {t('绑定信息')} + +
+ {t('第三方账户绑定状态(只读)')} +
+
+
-
- - - {[ - 'github_id', - 'discord_id', - 'oidc_id', - 'wechat_id', - 'email', - 'telegram_id', - ].map((field) => ( - - - - ))} - - + + )}
)} + + {/* 添加额度模态框 */} {
{t('金额')} - ({t('仅用于换算,实际保存的是额度')}) + + {' '} + ({t('仅用于换算,实际保存的是额度')}) +
{ onChange={(val) => { setAddAmountLocal(val); setAddQuotaLocal( - val != null && val !== '' ? displayAmountToQuota(Math.abs(val)) * Math.sign(val) : '', + val != null && val !== '' + ? displayAmountToQuota(Math.abs(val)) * Math.sign(val) + : '', ); }} style={{ width: '100%' }} @@ -430,7 +441,11 @@ const EditUserModal = (props) => { setAddQuotaLocal(val); setAddAmountLocal( val != null && val !== '' - ? Number((quotaToDisplayAmount(Math.abs(val)) * Math.sign(val)).toFixed(2)) + ? Number( + ( + quotaToDisplayAmount(Math.abs(val)) * Math.sign(val) + ).toFixed(2), + ) : '', ); }} diff --git a/web/src/components/table/users/modals/UserBindingManagementModal.jsx b/web/src/components/table/users/modals/UserBindingManagementModal.jsx new file mode 100644 index 000000000..547c04f7d --- /dev/null +++ b/web/src/components/table/users/modals/UserBindingManagementModal.jsx @@ -0,0 +1,396 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React from 'react'; +import { useTranslation } from 'react-i18next'; +import { + API, + showError, + showSuccess, + getOAuthProviderIcon, +} from '../../../../helpers'; +import { + Modal, + Spin, + Typography, + Card, + Checkbox, + Tag, + Button, +} from '@douyinfe/semi-ui'; +import { + IconLink, + IconMail, + IconDelete, + IconGithubLogo, +} from '@douyinfe/semi-icons'; +import { SiDiscord, SiTelegram, SiWechat, SiLinux } from 'react-icons/si'; + +const { Text } = Typography; + +const UserBindingManagementModal = ({ + visible, + onCancel, + userId, + isMobile, + formApiRef, +}) => { + const { t } = useTranslation(); + const [bindingLoading, setBindingLoading] = React.useState(false); + const [showUnboundOnly, setShowUnboundOnly] = React.useState(false); + const [statusInfo, setStatusInfo] = React.useState({}); + const [customOAuthBindings, setCustomOAuthBindings] = React.useState([]); + const [bindingActionLoading, setBindingActionLoading] = React.useState({}); + + const loadBindingData = React.useCallback(async () => { + if (!userId) return; + + setBindingLoading(true); + try { + const [statusRes, customBindingRes] = await Promise.all([ + API.get('/api/status'), + API.get(`/api/user/${userId}/oauth/bindings`), + ]); + + if (statusRes.data?.success) { + setStatusInfo(statusRes.data.data || {}); + } else { + showError(statusRes.data?.message || t('操作失败')); + } + + if (customBindingRes.data?.success) { + setCustomOAuthBindings(customBindingRes.data.data || []); + } else { + showError(customBindingRes.data?.message || t('操作失败')); + } + } catch (error) { + showError( + error.response?.data?.message || error.message || t('操作失败'), + ); + } finally { + setBindingLoading(false); + } + }, [t, userId]); + + React.useEffect(() => { + if (!visible) return; + setShowUnboundOnly(false); + setBindingActionLoading({}); + loadBindingData(); + }, [visible, loadBindingData]); + + const setBindingLoadingState = (key, value) => { + setBindingActionLoading((prev) => ({ ...prev, [key]: value })); + }; + + const handleUnbindBuiltInAccount = (bindingItem) => { + if (!userId) return; + + Modal.confirm({ + title: t('确认解绑'), + content: t('确定要解绑 {{name}} 吗?', { name: bindingItem.name }), + okText: t('确认'), + cancelText: t('取消'), + onOk: async () => { + const loadingKey = `builtin-${bindingItem.key}`; + setBindingLoadingState(loadingKey, true); + try { + const res = await API.delete( + `/api/user/${userId}/bindings/${bindingItem.key}`, + ); + if (!res.data?.success) { + showError(res.data?.message || t('操作失败')); + return; + } + formApiRef.current?.setValue(bindingItem.field, ''); + showSuccess(t('解绑成功')); + } catch (error) { + showError( + error.response?.data?.message || error.message || t('操作失败'), + ); + } finally { + setBindingLoadingState(loadingKey, false); + } + }, + }); + }; + + const handleUnbindCustomOAuthAccount = (provider) => { + if (!userId) return; + + Modal.confirm({ + title: t('确认解绑'), + content: t('确定要解绑 {{name}} 吗?', { name: provider.name }), + okText: t('确认'), + cancelText: t('取消'), + onOk: async () => { + const loadingKey = `custom-${provider.id}`; + setBindingLoadingState(loadingKey, true); + try { + const res = await API.delete( + `/api/user/${userId}/oauth/bindings/${provider.id}`, + ); + if (!res.data?.success) { + showError(res.data?.message || t('操作失败')); + return; + } + setCustomOAuthBindings((prev) => + prev.filter( + (item) => Number(item.provider_id) !== Number(provider.id), + ), + ); + showSuccess(t('解绑成功')); + } catch (error) { + showError( + error.response?.data?.message || error.message || t('操作失败'), + ); + } finally { + setBindingLoadingState(loadingKey, false); + } + }, + }); + }; + + const currentValues = formApiRef.current?.getValues?.() || {}; + + const builtInBindingItems = [ + { + key: 'email', + field: 'email', + name: t('邮箱'), + enabled: true, + value: currentValues.email, + icon: ( + + ), + }, + { + key: 'github', + field: 'github_id', + name: 'GitHub', + enabled: Boolean(statusInfo.github_oauth), + value: currentValues.github_id, + icon: ( + + ), + }, + { + key: 'discord', + field: 'discord_id', + name: 'Discord', + enabled: Boolean(statusInfo.discord_oauth), + value: currentValues.discord_id, + icon: ( + + ), + }, + { + key: 'oidc', + field: 'oidc_id', + name: 'OIDC', + enabled: Boolean(statusInfo.oidc_enabled), + value: currentValues.oidc_id, + icon: ( + + ), + }, + { + key: 'wechat', + field: 'wechat_id', + name: t('微信'), + enabled: Boolean(statusInfo.wechat_login), + value: currentValues.wechat_id, + icon: ( + + ), + }, + { + key: 'telegram', + field: 'telegram_id', + name: 'Telegram', + enabled: Boolean(statusInfo.telegram_oauth), + value: currentValues.telegram_id, + icon: ( + + ), + }, + { + key: 'linuxdo', + field: 'linux_do_id', + name: 'LinuxDO', + enabled: Boolean(statusInfo.linuxdo_oauth), + value: currentValues.linux_do_id, + icon: ( + + ), + }, + ]; + + const customBindingMap = new Map( + customOAuthBindings.map((item) => [Number(item.provider_id), item]), + ); + + const customProviderMap = new Map( + (statusInfo.custom_oauth_providers || []).map((provider) => [ + Number(provider.id), + provider, + ]), + ); + + customOAuthBindings.forEach((binding) => { + if (!customProviderMap.has(Number(binding.provider_id))) { + customProviderMap.set(Number(binding.provider_id), { + id: binding.provider_id, + name: binding.provider_name, + icon: binding.provider_icon, + }); + } + }); + + const customBindingItems = Array.from(customProviderMap.values()).map( + (provider) => { + const binding = customBindingMap.get(Number(provider.id)); + return { + key: `custom-${provider.id}`, + providerId: provider.id, + name: provider.name, + enabled: true, + value: binding?.provider_user_id || '', + icon: getOAuthProviderIcon( + provider.icon || binding?.provider_icon || '', + 20, + ), + }; + }, + ); + + const allBindingItems = [ + ...builtInBindingItems.map((item) => ({ ...item, type: 'builtin' })), + ...customBindingItems.map((item) => ({ ...item, type: 'custom' })), + ]; + + const visibleBindingItems = showUnboundOnly + ? allBindingItems.filter((item) => !item.value) + : allBindingItems; + + return ( + + + {t('绑定信息')} +
+ } + > + +
+ setShowUnboundOnly(Boolean(e.target.checked))} + > + {`${t('筛选')} ${t('未绑定')}`} + + + {t('筛选')} · {visibleBindingItems.length} + +
+ + {visibleBindingItems.length === 0 ? ( + + {t('暂无自定义 OAuth 提供商')} + + ) : ( +
+ {visibleBindingItems.map((item) => { + const isBound = Boolean(item.value); + const loadingKey = + item.type === 'builtin' + ? `builtin-${item.key}` + : `custom-${item.providerId}`; + const statusText = isBound + ? item.value + : item.enabled + ? t('未绑定') + : t('未启用'); + + return ( + +
+
+
+ {item.icon} +
+
+
+ {item.name} + + {item.type === 'builtin' ? 'Built-in' : 'Custom'} + +
+
+ {statusText} +
+
+
+ +
+
+ ); + })} +
+ )} +
+
+ ); +}; + +export default UserBindingManagementModal; From 2f4d38fefd19a77f43f75e036da71be4cec41dfa Mon Sep 17 00:00:00 2001 From: Seefs Date: Mon, 23 Feb 2026 15:16:22 +0800 Subject: [PATCH 31/41] refactor: extract binding modal and polish binding management UX --- .../table/users/modals/EditUserModal.jsx | 6 +- .../modals/UserBindingManagementModal.jsx | 172 ++++++++++-------- 2 files changed, 96 insertions(+), 82 deletions(-) diff --git a/web/src/components/table/users/modals/EditUserModal.jsx b/web/src/components/table/users/modals/EditUserModal.jsx index 297f18116..90676d840 100644 --- a/web/src/components/table/users/modals/EditUserModal.jsx +++ b/web/src/components/table/users/modals/EditUserModal.jsx @@ -207,7 +207,7 @@ const EditUserModal = (props) => { onSubmit={submit} > {({ values }) => ( -
+
{/* 基本信息 */}
@@ -344,7 +344,7 @@ const EditUserModal = (props) => { {t('绑定信息')}
- {t('第三方账户绑定状态(只读)')} + {t('管理用户已绑定的第三方账户,支持筛选与解绑')}
@@ -353,7 +353,7 @@ const EditUserModal = (props) => { theme='outline' onClick={openBindingModal} > - {t('修改绑定')} + {t('管理绑定')}
diff --git a/web/src/components/table/users/modals/UserBindingManagementModal.jsx b/web/src/components/table/users/modals/UserBindingManagementModal.jsx index 547c04f7d..c5b2a3a15 100644 --- a/web/src/components/table/users/modals/UserBindingManagementModal.jsx +++ b/web/src/components/table/users/modals/UserBindingManagementModal.jsx @@ -53,7 +53,7 @@ const UserBindingManagementModal = ({ }) => { const { t } = useTranslation(); const [bindingLoading, setBindingLoading] = React.useState(false); - const [showUnboundOnly, setShowUnboundOnly] = React.useState(false); + const [showBoundOnly, setShowBoundOnly] = React.useState(true); const [statusInfo, setStatusInfo] = React.useState({}); const [customOAuthBindings, setCustomOAuthBindings] = React.useState([]); const [bindingActionLoading, setBindingActionLoading] = React.useState({}); @@ -90,7 +90,7 @@ const UserBindingManagementModal = ({ React.useEffect(() => { if (!visible) return; - setShowUnboundOnly(false); + setShowBoundOnly(true); setBindingActionLoading({}); loadBindingData(); }, [visible, loadBindingData]); @@ -294,8 +294,12 @@ const UserBindingManagementModal = ({ ...customBindingItems.map((item) => ({ ...item, type: 'custom' })), ]; - const visibleBindingItems = showUnboundOnly - ? allBindingItems.filter((item) => !item.value) + const boundCount = allBindingItems.filter((item) => + Boolean(item.value), + ).length; + + const visibleBindingItems = showBoundOnly + ? allBindingItems.filter((item) => Boolean(item.value)) : allBindingItems; return ( @@ -308,86 +312,96 @@ const UserBindingManagementModal = ({ title={
- {t('绑定信息')} + {t('账户绑定管理')}
} > -
- setShowUnboundOnly(Boolean(e.target.checked))} - > - {`${t('筛选')} ${t('未绑定')}`} - - - {t('筛选')} · {visibleBindingItems.length} - -
- - {visibleBindingItems.length === 0 ? ( - - {t('暂无自定义 OAuth 提供商')} - - ) : ( -
- {visibleBindingItems.map((item) => { - const isBound = Boolean(item.value); - const loadingKey = - item.type === 'builtin' - ? `builtin-${item.key}` - : `custom-${item.providerId}`; - const statusText = isBound - ? item.value - : item.enabled - ? t('未绑定') - : t('未启用'); - - return ( - -
-
-
- {item.icon} -
-
-
- {item.name} - - {item.type === 'builtin' ? 'Built-in' : 'Custom'} - -
-
- {statusText} -
-
-
- -
-
- ); - })} +
+
+ setShowBoundOnly(Boolean(e.target.checked))} + > + {t('仅显示已绑定')} + + + {t('已绑定')} {boundCount} / {allBindingItems.length} +
- )} + + {visibleBindingItems.length === 0 ? ( + + {t('暂无已绑定项')} + + ) : ( +
+ {visibleBindingItems.map((item, index) => { + const isBound = Boolean(item.value); + const loadingKey = + item.type === 'builtin' + ? `builtin-${item.key}` + : `custom-${item.providerId}`; + const statusText = isBound + ? item.value + : item.enabled + ? t('未绑定') + : t('未启用'); + const shouldSpanTwoColsOnDesktop = + visibleBindingItems.length % 2 === 1 && + index === visibleBindingItems.length - 1; + + return ( + +
+
+
+ {item.icon} +
+
+
+ {item.name} + + {item.type === 'builtin' + ? t('内置') + : t('自定义')} + +
+
+ {statusText} +
+
+
+ +
+
+ ); + })} +
+ )} +
); From 80c213072ce709394b1e1cfdf57002a9384dcde5 Mon Sep 17 00:00:00 2001 From: CaIon Date: Mon, 23 Feb 2026 16:59:46 +0800 Subject: [PATCH 32/41] fix: improve multipart form data handling in gin context - Added caching for the original Content-Type header in the parseMultipartFormData function. - This change ensures that the Content-Type is retrieved from the context if previously set, enhancing performance and consistency. --- common/gin.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/common/gin.go b/common/gin.go index 009e39080..5cad6e5c9 100644 --- a/common/gin.go +++ b/common/gin.go @@ -303,7 +303,13 @@ func parseFormData(data []byte, v any) error { } func parseMultipartFormData(c *gin.Context, data []byte, v any) error { - contentType := c.Request.Header.Get("Content-Type") + var contentType string + if saved, ok := c.Get("_original_multipart_ct"); ok { + contentType = saved.(string) + } else { + contentType = c.Request.Header.Get("Content-Type") + c.Set("_original_multipart_ct", contentType) + } boundary, err := parseBoundary(contentType) if err != nil { if errors.Is(err, errBoundaryNotFound) { From 0835e15091d26d730cf44c164d6e71b399a6b7eb Mon Sep 17 00:00:00 2001 From: CaIon Date: Mon, 23 Feb 2026 17:42:22 +0800 Subject: [PATCH 33/41] fix: enhance data trimming and validation in stream scanner --- relay/helper/stream_scanner.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index 4f3ab2363..b28941403 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -215,8 +215,10 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon continue } data = data[5:] - data = strings.TrimLeft(data, " ") - data = strings.TrimSuffix(data, "\r") + data = strings.TrimSpace(data) + if data == "" { + continue + } if !strings.HasPrefix(data, "[DONE]") { info.SetFirstResponseTime() info.ReceivedResponseCount++ From 532691b06bc1dc2eab0dce055b46511510f21a99 Mon Sep 17 00:00:00 2001 From: Seefs Date: Mon, 23 Feb 2026 22:02:59 +0800 Subject: [PATCH 34/41] fix: violation fee check --- service/violation_fee.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/service/violation_fee.go b/service/violation_fee.go index 400c10dd5..455088561 100644 --- a/service/violation_fee.go +++ b/service/violation_fee.go @@ -18,8 +18,9 @@ import ( ) const ( - ViolationFeeCodePrefix = "violation_fee." - CSAMViolationMarker = "Failed check: SAFETY_CHECK_TYPE_CSAM" + ViolationFeeCodePrefix = "violation_fee." + CSAMViolationMarker = "Failed check: SAFETY_CHECK_TYPE" + ContentViolatesUsageMarker = "Content violates usage guidelines" ) func IsViolationFeeCode(code types.ErrorCode) bool { @@ -30,11 +31,11 @@ func HasCSAMViolationMarker(err *types.NewAPIError) bool { if err == nil { return false } - if strings.Contains(err.Error(), CSAMViolationMarker) { + if strings.Contains(err.Error(), CSAMViolationMarker) || strings.Contains(err.Error(), ContentViolatesUsageMarker) { return true } msg := err.ToOpenAIError().Message - return strings.Contains(msg, CSAMViolationMarker) + return strings.Contains(msg, CSAMViolationMarker) || strings.Contains(err.Error(), ContentViolatesUsageMarker) } func WrapAsViolationFeeGrokCSAM(err *types.NewAPIError) *types.NewAPIError { From 98de08280484b6a01d975c9015c77a6f8902b842 Mon Sep 17 00:00:00 2001 From: hekx Date: Tue, 24 Feb 2026 09:58:50 +0800 Subject: [PATCH 35/41] fix: skip Accept-Encoding during header passthrough (#2214) --- relay/channel/api_request.go | 5 +++-- relay/channel/api_request_test.go | 27 +++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index ec5573ab1..09ca855dd 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -61,8 +61,9 @@ var passthroughSkipHeaderNamesLower = map[string]struct{}{ "cookie": {}, // Additional headers that should not be forwarded by name-matching passthrough rules. - "host": {}, - "content-length": {}, + "host": {}, + "content-length": {}, + "accept-encoding": {}, // Do not passthrough credentials by wildcard/regex. "authorization": {}, diff --git a/relay/channel/api_request_test.go b/relay/channel/api_request_test.go index c55ffcab2..6c7834ef9 100644 --- a/relay/channel/api_request_test.go +++ b/relay/channel/api_request_test.go @@ -79,3 +79,30 @@ func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) require.NoError(t, err) require.Equal(t, "trace-123", headers["X-Upstream-Trace"]) } + +func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + ctx.Request.Header.Set("X-Trace-Id", "trace-123") + ctx.Request.Header.Set("Accept-Encoding", "gzip") + + info := &relaycommon.RelayInfo{ + IsChannelTest: false, + ChannelMeta: &relaycommon.ChannelMeta{ + HeadersOverride: map[string]any{ + "*": "", + }, + }, + } + + headers, err := processHeaderOverride(info, ctx) + require.NoError(t, err) + require.Equal(t, "trace-123", headers["X-Trace-Id"]) + + _, hasAcceptEncoding := headers["Accept-Encoding"] + require.False(t, hasAcceptEncoding) +} From af31935102df21e6df5c99fa3a07315084b4d13a Mon Sep 17 00:00:00 2001 From: Seefs Date: Tue, 24 Feb 2026 13:26:19 +0800 Subject: [PATCH 36/41] fix: check oauthUser.Username length --- controller/oauth.go | 17 ++++++++++------- model/user.go | 2 ++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/controller/oauth.go b/controller/oauth.go index faa22dd4f..0bb33d2cd 100644 --- a/controller/oauth.go +++ b/controller/oauth.go @@ -240,7 +240,10 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o if oauthUser.Username != "" { if exists, err := model.CheckUserExistOrDeleted(oauthUser.Username, ""); err == nil && !exists { - user.Username = oauthUser.Username + // 防止索引退化 + if len(oauthUser.Username) <= model.UserNameMaxLength { + user.Username = oauthUser.Username + } } } @@ -302,12 +305,12 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o // Set the provider user ID on the user model and update provider.SetProviderUserID(user, oauthUser.ProviderUserID) if err := tx.Model(user).Updates(map[string]interface{}{ - "github_id": user.GitHubId, - "discord_id": user.DiscordId, - "oidc_id": user.OidcId, - "linux_do_id": user.LinuxDOId, - "wechat_id": user.WeChatId, - "telegram_id": user.TelegramId, + "github_id": user.GitHubId, + "discord_id": user.DiscordId, + "oidc_id": user.OidcId, + "linux_do_id": user.LinuxDOId, + "wechat_id": user.WeChatId, + "telegram_id": user.TelegramId, }).Error; err != nil { return err } diff --git a/model/user.go b/model/user.go index e0c9c686f..f5ba5cebe 100644 --- a/model/user.go +++ b/model/user.go @@ -15,6 +15,8 @@ import ( "gorm.io/gorm" ) +const UserNameMaxLength = 20 + // User if you add sensitive fields, don't forget to clean them in setupLogin function. // Otherwise, the sensitive information will be saved on local storage in plain text! type User struct { From 4c7e65cb2409bd44b4e0ac62c9e811b6f7e3cd20 Mon Sep 17 00:00:00 2001 From: CaIon Date: Tue, 24 Feb 2026 17:35:54 +0800 Subject: [PATCH 37/41] feat: add comprehensive tests for StreamScannerHandler functionality - Introduced a new test file for StreamScannerHandler, covering various scenarios including nil inputs, empty bodies, chunk processing, order preservation, and handler failures. - Enhanced error handling and data processing logic in StreamScannerHandler to improve robustness and performance. --- relay/helper/stream_scanner.go | 37 +- relay/helper/stream_scanner_test.go | 521 ++++++++++++++++++++++++++++ 2 files changed, 544 insertions(+), 14 deletions(-) create mode 100644 relay/helper/stream_scanner_test.go diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index b28941403..ae70f53c0 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -176,10 +176,32 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon }) } + dataChan := make(chan string, 10) + + wg.Add(1) + gopool.Go(func() { + defer func() { + wg.Done() + if r := recover(); r != nil { + logger.LogError(c, fmt.Sprintf("data handler goroutine panic: %v", r)) + } + common.SafeSendBool(stopChan, true) + }() + for data := range dataChan { + writeMutex.Lock() + success := dataHandler(data) + writeMutex.Unlock() + if !success { + return + } + } + }) + // Scanner goroutine with improved error handling wg.Add(1) common.RelayCtxGo(ctx, func() { defer func() { + close(dataChan) wg.Done() if r := recover(); r != nil { logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r)) @@ -222,22 +244,9 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon if !strings.HasPrefix(data, "[DONE]") { info.SetFirstResponseTime() info.ReceivedResponseCount++ - // 使用超时机制防止写操作阻塞 - done := make(chan bool, 1) - gopool.Go(func() { - writeMutex.Lock() - defer writeMutex.Unlock() - done <- dataHandler(data) - }) select { - case success := <-done: - if !success { - return - } - case <-time.After(10 * time.Second): - logger.LogError(c, "data handler timeout") - return + case dataChan <- data: case <-ctx.Done(): return case <-stopChan: diff --git a/relay/helper/stream_scanner_test.go b/relay/helper/stream_scanner_test.go new file mode 100644 index 000000000..6890d82a5 --- /dev/null +++ b/relay/helper/stream_scanner_test.go @@ -0,0 +1,521 @@ +package helper + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/QuantumNous/new-api/constant" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/setting/operation_setting" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func setupStreamTest(t *testing.T, body io.Reader) (*gin.Context, *http.Response, *relaycommon.RelayInfo) { + t.Helper() + + oldTimeout := constant.StreamingTimeout + constant.StreamingTimeout = 30 + t.Cleanup(func() { + constant.StreamingTimeout = oldTimeout + }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + resp := &http.Response{ + Body: io.NopCloser(body), + } + + info := &relaycommon.RelayInfo{ + ChannelMeta: &relaycommon.ChannelMeta{}, + } + + return c, resp, info +} + +func buildSSEBody(n int) string { + var b strings.Builder + for i := 0; i < n; i++ { + fmt.Fprintf(&b, "data: {\"id\":%d,\"choices\":[{\"delta\":{\"content\":\"token_%d\"}}]}\n", i, i) + } + b.WriteString("data: [DONE]\n") + return b.String() +} + +// slowReader wraps a reader and injects a delay before each Read call, +// simulating a slow upstream that trickles data. +type slowReader struct { + r io.Reader + delay time.Duration +} + +func (s *slowReader) Read(p []byte) (int, error) { + time.Sleep(s.delay) + return s.r.Read(p) +} + +// ---------- Basic correctness ---------- + +func TestStreamScannerHandler_NilInputs(t *testing.T) { + t.Parallel() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}} + + StreamScannerHandler(c, nil, info, func(data string) bool { return true }) + StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil) +} + +func TestStreamScannerHandler_EmptyBody(t *testing.T) { + t.Parallel() + + c, resp, info := setupStreamTest(t, strings.NewReader("")) + + var called atomic.Bool + StreamScannerHandler(c, resp, info, func(data string) bool { + called.Store(true) + return true + }) + + assert.False(t, called.Load(), "handler should not be called for empty body") +} + +func TestStreamScannerHandler_1000Chunks(t *testing.T) { + t.Parallel() + + const numChunks = 1000 + body := buildSSEBody(numChunks) + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + var count atomic.Int64 + StreamScannerHandler(c, resp, info, func(data string) bool { + count.Add(1) + return true + }) + + assert.Equal(t, int64(numChunks), count.Load()) + assert.Equal(t, numChunks, info.ReceivedResponseCount) +} + +func TestStreamScannerHandler_10000Chunks(t *testing.T) { + t.Parallel() + + const numChunks = 10000 + body := buildSSEBody(numChunks) + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + var count atomic.Int64 + start := time.Now() + + StreamScannerHandler(c, resp, info, func(data string) bool { + count.Add(1) + return true + }) + + elapsed := time.Since(start) + assert.Equal(t, int64(numChunks), count.Load()) + assert.Equal(t, numChunks, info.ReceivedResponseCount) + t.Logf("10000 chunks processed in %v", elapsed) +} + +func TestStreamScannerHandler_OrderPreserved(t *testing.T) { + t.Parallel() + + const numChunks = 500 + body := buildSSEBody(numChunks) + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + var mu sync.Mutex + received := make([]string, 0, numChunks) + + StreamScannerHandler(c, resp, info, func(data string) bool { + mu.Lock() + received = append(received, data) + mu.Unlock() + return true + }) + + require.Equal(t, numChunks, len(received)) + for i := 0; i < numChunks; i++ { + expected := fmt.Sprintf("{\"id\":%d,\"choices\":[{\"delta\":{\"content\":\"token_%d\"}}]}", i, i) + assert.Equal(t, expected, received[i], "chunk %d out of order", i) + } +} + +func TestStreamScannerHandler_DoneStopsScanner(t *testing.T) { + t.Parallel() + + body := buildSSEBody(50) + "data: should_not_appear\n" + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + var count atomic.Int64 + StreamScannerHandler(c, resp, info, func(data string) bool { + count.Add(1) + return true + }) + + assert.Equal(t, int64(50), count.Load(), "data after [DONE] must not be processed") +} + +func TestStreamScannerHandler_HandlerFailureStops(t *testing.T) { + t.Parallel() + + const numChunks = 200 + body := buildSSEBody(numChunks) + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + const failAt = 50 + var count atomic.Int64 + StreamScannerHandler(c, resp, info, func(data string) bool { + n := count.Add(1) + return n < failAt + }) + + // The worker stops at failAt; the scanner may have read ahead, + // but the handler should not be called beyond failAt. + assert.Equal(t, int64(failAt), count.Load()) +} + +func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) { + t.Parallel() + + var b strings.Builder + b.WriteString(": comment line\n") + b.WriteString("event: message\n") + b.WriteString("id: 12345\n") + b.WriteString("retry: 5000\n") + for i := 0; i < 100; i++ { + fmt.Fprintf(&b, "data: payload_%d\n", i) + b.WriteString(": interleaved comment\n") + } + b.WriteString("data: [DONE]\n") + + c, resp, info := setupStreamTest(t, strings.NewReader(b.String())) + + var count atomic.Int64 + StreamScannerHandler(c, resp, info, func(data string) bool { + count.Add(1) + return true + }) + + assert.Equal(t, int64(100), count.Load()) +} + +func TestStreamScannerHandler_DataWithExtraSpaces(t *testing.T) { + t.Parallel() + + body := "data: {\"trimmed\":true} \ndata: [DONE]\n" + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + var got string + StreamScannerHandler(c, resp, info, func(data string) bool { + got = data + return true + }) + + assert.Equal(t, "{\"trimmed\":true}", got) +} + +// ---------- Decoupling: scanner not blocked by slow handler ---------- + +func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) { + t.Parallel() + + // Strategy: use a slow upstream (io.Pipe, 10ms per chunk) AND a slow handler (20ms per chunk). + // If the scanner were synchronously coupled to the handler, total time would be + // ~numChunks * (10ms + 20ms) = 30ms * 50 = 1500ms. + // With decoupling, total time should be closer to + // ~numChunks * max(10ms, 20ms) = 20ms * 50 = 1000ms + // because the scanner reads ahead into the buffer while the handler processes. + const numChunks = 50 + const upstreamDelay = 10 * time.Millisecond + const handlerDelay = 20 * time.Millisecond + + pr, pw := io.Pipe() + go func() { + defer pw.Close() + for i := 0; i < numChunks; i++ { + fmt.Fprintf(pw, "data: {\"id\":%d}\n", i) + time.Sleep(upstreamDelay) + } + fmt.Fprint(pw, "data: [DONE]\n") + }() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + oldTimeout := constant.StreamingTimeout + constant.StreamingTimeout = 30 + t.Cleanup(func() { constant.StreamingTimeout = oldTimeout }) + + resp := &http.Response{Body: pr} + info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}} + + var count atomic.Int64 + start := time.Now() + done := make(chan struct{}) + go func() { + StreamScannerHandler(c, resp, info, func(data string) bool { + time.Sleep(handlerDelay) + count.Add(1) + return true + }) + close(done) + }() + + select { + case <-done: + case <-time.After(15 * time.Second): + t.Fatal("StreamScannerHandler did not complete in time") + } + + elapsed := time.Since(start) + assert.Equal(t, int64(numChunks), count.Load()) + + coupledTime := time.Duration(numChunks) * (upstreamDelay + handlerDelay) + t.Logf("elapsed=%v, coupled_estimate=%v", elapsed, coupledTime) + + // If decoupled, elapsed should be well under the coupled estimate. + assert.Less(t, elapsed, coupledTime*85/100, + "decoupled elapsed time (%v) should be significantly less than coupled estimate (%v)", elapsed, coupledTime) +} + +func TestStreamScannerHandler_SlowUpstreamFastHandler(t *testing.T) { + t.Parallel() + + const numChunks = 50 + body := buildSSEBody(numChunks) + reader := &slowReader{r: strings.NewReader(body), delay: 2 * time.Millisecond} + c, resp, info := setupStreamTest(t, reader) + + var count atomic.Int64 + start := time.Now() + + done := make(chan struct{}) + go func() { + StreamScannerHandler(c, resp, info, func(data string) bool { + count.Add(1) + return true + }) + close(done) + }() + + select { + case <-done: + case <-time.After(15 * time.Second): + t.Fatal("timed out with slow upstream") + } + + elapsed := time.Since(start) + assert.Equal(t, int64(numChunks), count.Load()) + t.Logf("slow upstream (%d chunks, 2ms/read): %v", numChunks, elapsed) +} + +// ---------- Ping tests ---------- + +func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) { + t.Parallel() + + setting := operation_setting.GetGeneralSetting() + oldEnabled := setting.PingIntervalEnabled + oldSeconds := setting.PingIntervalSeconds + setting.PingIntervalEnabled = true + setting.PingIntervalSeconds = 1 + t.Cleanup(func() { + setting.PingIntervalEnabled = oldEnabled + setting.PingIntervalSeconds = oldSeconds + }) + + // Create a reader that delivers data slowly: one chunk every 500ms over 3.5 seconds. + // The ping interval is 1s, so we should see at least 2 pings. + pr, pw := io.Pipe() + go func() { + defer pw.Close() + for i := 0; i < 7; i++ { + fmt.Fprintf(pw, "data: chunk_%d\n", i) + time.Sleep(500 * time.Millisecond) + } + fmt.Fprint(pw, "data: [DONE]\n") + }() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + oldTimeout := constant.StreamingTimeout + constant.StreamingTimeout = 30 + t.Cleanup(func() { + constant.StreamingTimeout = oldTimeout + }) + + resp := &http.Response{Body: pr} + info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}} + + var count atomic.Int64 + done := make(chan struct{}) + go func() { + StreamScannerHandler(c, resp, info, func(data string) bool { + count.Add(1) + return true + }) + close(done) + }() + + select { + case <-done: + case <-time.After(15 * time.Second): + t.Fatal("timed out waiting for stream to finish") + } + + assert.Equal(t, int64(7), count.Load()) + + body := recorder.Body.String() + pingCount := strings.Count(body, ": PING") + t.Logf("received %d pings in response body", pingCount) + assert.GreaterOrEqual(t, pingCount, 2, + "expected at least 2 pings during 3.5s stream with 1s interval; got %d", pingCount) +} + +func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) { + t.Parallel() + + setting := operation_setting.GetGeneralSetting() + oldEnabled := setting.PingIntervalEnabled + oldSeconds := setting.PingIntervalSeconds + setting.PingIntervalEnabled = true + setting.PingIntervalSeconds = 1 + t.Cleanup(func() { + setting.PingIntervalEnabled = oldEnabled + setting.PingIntervalSeconds = oldSeconds + }) + + pr, pw := io.Pipe() + go func() { + defer pw.Close() + for i := 0; i < 5; i++ { + fmt.Fprintf(pw, "data: chunk_%d\n", i) + time.Sleep(500 * time.Millisecond) + } + fmt.Fprint(pw, "data: [DONE]\n") + }() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + oldTimeout := constant.StreamingTimeout + constant.StreamingTimeout = 30 + t.Cleanup(func() { + constant.StreamingTimeout = oldTimeout + }) + + resp := &http.Response{Body: pr} + info := &relaycommon.RelayInfo{ + DisablePing: true, + ChannelMeta: &relaycommon.ChannelMeta{}, + } + + var count atomic.Int64 + done := make(chan struct{}) + go func() { + StreamScannerHandler(c, resp, info, func(data string) bool { + count.Add(1) + return true + }) + close(done) + }() + + select { + case <-done: + case <-time.After(15 * time.Second): + t.Fatal("timed out") + } + + assert.Equal(t, int64(5), count.Load()) + + body := recorder.Body.String() + pingCount := strings.Count(body, ": PING") + assert.Equal(t, 0, pingCount, "pings should be disabled when DisablePing=true") +} + +func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) { + t.Parallel() + + setting := operation_setting.GetGeneralSetting() + oldEnabled := setting.PingIntervalEnabled + oldSeconds := setting.PingIntervalSeconds + setting.PingIntervalEnabled = true + setting.PingIntervalSeconds = 1 + t.Cleanup(func() { + setting.PingIntervalEnabled = oldEnabled + setting.PingIntervalSeconds = oldSeconds + }) + + // Slow upstream + slow handler. Total stream takes ~5 seconds. + // The ping goroutine stays alive as long as the scanner is reading, + // so pings should fire between data writes. + pr, pw := io.Pipe() + go func() { + defer pw.Close() + for i := 0; i < 10; i++ { + fmt.Fprintf(pw, "data: chunk_%d\n", i) + time.Sleep(500 * time.Millisecond) + } + fmt.Fprint(pw, "data: [DONE]\n") + }() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + oldTimeout := constant.StreamingTimeout + constant.StreamingTimeout = 30 + t.Cleanup(func() { + constant.StreamingTimeout = oldTimeout + }) + + resp := &http.Response{Body: pr} + info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}} + + var count atomic.Int64 + done := make(chan struct{}) + go func() { + StreamScannerHandler(c, resp, info, func(data string) bool { + count.Add(1) + return true + }) + close(done) + }() + + select { + case <-done: + case <-time.After(15 * time.Second): + t.Fatal("timed out") + } + + assert.Equal(t, int64(10), count.Load()) + + body := recorder.Body.String() + pingCount := strings.Count(body, ": PING") + t.Logf("received %d pings interleaved with 10 chunks over 5s", pingCount) + assert.GreaterOrEqual(t, pingCount, 3, + "expected at least 3 pings during 5s stream with 1s ping interval; got %d", pingCount) +} From 79e1daff5a6545c018fcd63911fe6fda81215be4 Mon Sep 17 00:00:00 2001 From: RedwindA Date: Tue, 24 Feb 2026 21:44:21 +0800 Subject: [PATCH 38/41] feat(web): add custom-model create hint and i18n translations --- .../channels/modals/EditChannelModal.jsx | 31 ++++++++++++++++++ .../table/channels/modals/EditTagModal.jsx | 32 ++++++++++++++++++- web/src/i18n/locales/en.json | 3 +- web/src/i18n/locales/fr.json | 3 +- web/src/i18n/locales/ja.json | 3 +- web/src/i18n/locales/ru.json | 3 +- web/src/i18n/locales/vi.json | 3 +- web/src/i18n/locales/zh-CN.json | 3 +- web/src/i18n/locales/zh-TW.json | 3 +- 9 files changed, 76 insertions(+), 8 deletions(-) diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index 8d30a5a34..3d3afcc38 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -191,6 +191,7 @@ const EditChannelModal = (props) => { const [fullModels, setFullModels] = useState([]); const [modelGroups, setModelGroups] = useState([]); const [customModel, setCustomModel] = useState(''); + const [modelSearchValue, setModelSearchValue] = useState(''); const [modalImageUrl, setModalImageUrl] = useState(''); const [isModalOpenurl, setIsModalOpenurl] = useState(false); const [modelModalVisible, setModelModalVisible] = useState(false); @@ -231,6 +232,25 @@ const EditChannelModal = (props) => { return []; } }, [inputs.model_mapping]); + const modelSearchMatchedCount = useMemo(() => { + const keyword = modelSearchValue.trim(); + if (!keyword) { + return modelOptions.length; + } + return modelOptions.reduce( + (count, option) => count + (selectFilter(keyword, option) ? 1 : 0), + 0, + ); + }, [modelOptions, modelSearchValue]); + const modelSearchHintText = useMemo(() => { + const keyword = modelSearchValue.trim(); + if (!keyword || modelSearchMatchedCount !== 0) { + return ''; + } + return t('未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加', { + name: keyword, + }); + }, [modelSearchMatchedCount, modelSearchValue, t]); const [isIonetChannel, setIsIonetChannel] = useState(false); const [ionetMetadata, setIonetMetadata] = useState(null); const [codexOAuthModalVisible, setCodexOAuthModalVisible] = useState(false); @@ -1019,6 +1039,7 @@ const EditChannelModal = (props) => { }, [inputs]); useEffect(() => { + setModelSearchValue(''); if (props.visible) { if (isEdit) { loadChannel(); @@ -1073,6 +1094,7 @@ const EditChannelModal = (props) => { // 重置豆包隐藏入口状态 setDoubaoApiEditUnlocked(false); doubaoApiClickCountRef.current = 0; + setModelSearchValue(''); // 清空表单中的key_mode字段 if (formApiRef.current) { formApiRef.current.setValue('key_mode', undefined); @@ -2815,9 +2837,18 @@ const EditChannelModal = (props) => { rules={[{ required: true, message: t('请选择模型') }]} multiple filter={selectFilter} + allowCreate autoClearSearchValue={false} searchPosition='dropdown' optionList={modelOptions} + onSearch={(value) => setModelSearchValue(value)} + innerBottomSlot={ + modelSearchHintText ? ( + + {modelSearchHintText} + + ) : null + } style={{ width: '100%' }} onChange={(value) => handleInputChange('models', value)} renderSelectedItem={(optionNode) => { diff --git a/web/src/components/table/channels/modals/EditTagModal.jsx b/web/src/components/table/channels/modals/EditTagModal.jsx index d4d060690..fbb00be58 100644 --- a/web/src/components/table/channels/modals/EditTagModal.jsx +++ b/web/src/components/table/channels/modals/EditTagModal.jsx @@ -17,7 +17,7 @@ along with this program. If not, see . For commercial licensing, please contact support@quantumnous.com */ -import React, { useState, useEffect, useRef } from 'react'; +import React, { useState, useEffect, useRef, useMemo } from 'react'; import { API, showError, @@ -64,6 +64,7 @@ const EditTagModal = (props) => { const [modelOptions, setModelOptions] = useState([]); const [groupOptions, setGroupOptions] = useState([]); const [customModel, setCustomModel] = useState(''); + const [modelSearchValue, setModelSearchValue] = useState(''); const originInputs = { tag: '', new_tag: null, @@ -74,6 +75,25 @@ const EditTagModal = (props) => { header_override: null, }; const [inputs, setInputs] = useState(originInputs); + const modelSearchMatchedCount = useMemo(() => { + const keyword = modelSearchValue.trim(); + if (!keyword) { + return modelOptions.length; + } + return modelOptions.reduce( + (count, option) => count + (selectFilter(keyword, option) ? 1 : 0), + 0, + ); + }, [modelOptions, modelSearchValue]); + const modelSearchHintText = useMemo(() => { + const keyword = modelSearchValue.trim(); + if (!keyword || modelSearchMatchedCount !== 0) { + return ''; + } + return t('未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加', { + name: keyword, + }); + }, [modelSearchMatchedCount, modelSearchValue, t]); const formApiRef = useRef(null); const getInitValues = () => ({ ...originInputs }); @@ -292,6 +312,7 @@ const EditTagModal = (props) => { fetchModels().then(); fetchGroups().then(); fetchTagModels().then(); + setModelSearchValue(''); if (formApiRef.current) { formApiRef.current.setValues({ ...getInitValues(), @@ -461,9 +482,18 @@ const EditTagModal = (props) => { placeholder={t('请选择该渠道所支持的模型,留空则不更改')} multiple filter={selectFilter} + allowCreate autoClearSearchValue={false} searchPosition='dropdown' optionList={modelOptions} + onSearch={(value) => setModelSearchValue(value)} + innerBottomSlot={ + modelSearchHintText ? ( + + {modelSearchHintText} + + ) : null + } style={{ width: '100%' }} onChange={(value) => handleInputChange('models', value)} /> diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index e06c68362..f6c13e7d8 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -2835,6 +2835,7 @@ "缓存写": "Cache Write", "写": "Write", "根据 Anthropic 协定,/v1/messages 的输入 tokens 仅统计非缓存输入,不包含缓存读取与缓存写入 tokens。": "Per Anthropic conventions, /v1/messages input tokens count only non-cached input and exclude cache read/write tokens.", - "设计版本": "b80c3466cb6feafeb3990c7820e10e50" + "设计版本": "b80c3466cb6feafeb3990c7820e10e50", + "未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加": "No matching models. Press Enter to add \"{{name}}\" as a custom model name." } } diff --git a/web/src/i18n/locales/fr.json b/web/src/i18n/locales/fr.json index 2843728b8..c36b969dd 100644 --- a/web/src/i18n/locales/fr.json +++ b/web/src/i18n/locales/fr.json @@ -2737,6 +2737,7 @@ "缓存写": "Écriture cache", "写": "Écriture", "根据 Anthropic 协定,/v1/messages 的输入 tokens 仅统计非缓存输入,不包含缓存读取与缓存写入 tokens。": "Selon la convention Anthropic, les tokens d'entrée de /v1/messages ne comptent que les entrées non mises en cache et excluent les tokens de lecture/écriture du cache.", - "设计版本": "b80c3466cb6feafeb3990c7820e10e50" + "设计版本": "b80c3466cb6feafeb3990c7820e10e50", + "未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加": "Aucun modèle correspondant. Appuyez sur Entrée pour ajouter «{{name}}» comme nom de modèle personnalisé." } } diff --git a/web/src/i18n/locales/ja.json b/web/src/i18n/locales/ja.json index d18a62923..2951e9ea3 100644 --- a/web/src/i18n/locales/ja.json +++ b/web/src/i18n/locales/ja.json @@ -2720,6 +2720,7 @@ "缓存写": "キャッシュ書込", "写": "書込", "根据 Anthropic 协定,/v1/messages 的输入 tokens 仅统计非缓存输入,不包含缓存读取与缓存写入 tokens。": "Anthropic の仕様により、/v1/messages の入力 tokens は非キャッシュ入力のみを集計し、キャッシュ読み取り/書き込み tokens は含みません。", - "设计版本": "b80c3466cb6feafeb3990c7820e10e50" + "设计版本": "b80c3466cb6feafeb3990c7820e10e50", + "未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加": "一致するモデルが見つかりません。Enterキーで「{{name}}」をカスタムモデル名として追加できます。" } } diff --git a/web/src/i18n/locales/ru.json b/web/src/i18n/locales/ru.json index 099f405c9..82ccb0edf 100644 --- a/web/src/i18n/locales/ru.json +++ b/web/src/i18n/locales/ru.json @@ -2750,6 +2750,7 @@ "缓存写": "Запись в кэш", "写": "Запись", "根据 Anthropic 协定,/v1/messages 的输入 tokens 仅统计非缓存输入,不包含缓存读取与缓存写入 tokens。": "Согласно соглашению Anthropic, входные токены /v1/messages учитывают только некэшированный ввод и не включают токены чтения/записи кэша.", - "设计版本": "b80c3466cb6feafeb3990c7820e10e50" + "设计版本": "b80c3466cb6feafeb3990c7820e10e50", + "未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加": "Совпадающих моделей не найдено. Нажмите Enter, чтобы добавить «{{name}}» как пользовательское имя модели." } } diff --git a/web/src/i18n/locales/vi.json b/web/src/i18n/locales/vi.json index d2602efdf..f78620cff 100644 --- a/web/src/i18n/locales/vi.json +++ b/web/src/i18n/locales/vi.json @@ -3296,6 +3296,7 @@ "缓存写": "Ghi bộ nhớ đệm", "写": "Ghi", "根据 Anthropic 协定,/v1/messages 的输入 tokens 仅统计非缓存输入,不包含缓存读取与缓存写入 tokens。": "Theo quy ước của Anthropic, input tokens của /v1/messages chỉ tính phần đầu vào không dùng cache và không bao gồm tokens đọc/ghi cache.", - "设计版本": "b80c3466cb6feafeb3990c7820e10e50" + "设计版本": "b80c3466cb6feafeb3990c7820e10e50", + "未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加": "Không tìm thấy mô hình khớp. Nhấn Enter để thêm \"{{name}}\" làm tên mô hình tùy chỉnh." } } diff --git a/web/src/i18n/locales/zh-CN.json b/web/src/i18n/locales/zh-CN.json index d067ad569..fb135f6fb 100644 --- a/web/src/i18n/locales/zh-CN.json +++ b/web/src/i18n/locales/zh-CN.json @@ -2812,6 +2812,7 @@ "缓存读": "缓存读", "缓存写": "缓存写", "写": "写", - "根据 Anthropic 协定,/v1/messages 的输入 tokens 仅统计非缓存输入,不包含缓存读取与缓存写入 tokens。": "根据 Anthropic 协定,/v1/messages 的输入 tokens 仅统计非缓存输入,不包含缓存读取与缓存写入 tokens。" + "根据 Anthropic 协定,/v1/messages 的输入 tokens 仅统计非缓存输入,不包含缓存读取与缓存写入 tokens。": "根据 Anthropic 协定,/v1/messages 的输入 tokens 仅统计非缓存输入,不包含缓存读取与缓存写入 tokens。", + "未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加": "未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加" } } diff --git a/web/src/i18n/locales/zh-TW.json b/web/src/i18n/locales/zh-TW.json index 26f7092b7..85be3f9f7 100644 --- a/web/src/i18n/locales/zh-TW.json +++ b/web/src/i18n/locales/zh-TW.json @@ -2805,6 +2805,7 @@ "填写服务器地址后自动生成:": "填寫伺服器位址後自動生成:", "自动生成:": "自動生成:", "请先填写服务器地址,以自动生成完整的端点 URL": "請先填寫伺服器位址,以自動生成完整的端點 URL", - "端点 URL 必须是完整地址(以 http:// 或 https:// 开头)": "端點 URL 必須是完整位址(以 http:// 或 https:// 開頭)" + "端点 URL 必须是完整地址(以 http:// 或 https:// 开头)": "端點 URL 必須是完整位址(以 http:// 或 https:// 開頭)", + "未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加": "未匹配到模型,按下 Enter 鍵可將「{{name}}」作為自訂模型名稱新增" } } From 0da0d8064768aacc4328cb434ce0d421f7888773 Mon Sep 17 00:00:00 2001 From: CaIon Date: Tue, 24 Feb 2026 23:46:17 +0800 Subject: [PATCH 39/41] fix: handle nil setting in user retrieval from database --- model/user.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/model/user.go b/model/user.go index f3e66e6c7..1210b5435 100644 --- a/model/user.go +++ b/model/user.go @@ -1,6 +1,7 @@ package model import ( + "database/sql" "encoding/json" "errors" "fmt" @@ -853,10 +854,17 @@ func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) // Don't return error - fall through to DB } fromDB = true - err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error + // can be nil setting + var safeSetting sql.NullString + err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&safeSetting).Error if err != nil { return settingMap, err } + if safeSetting.Valid { + setting = safeSetting.String + } else { + setting = "" + } userBase := &UserBase{ Setting: setting, } From c5365e4b4306e0084c1cb9f2020c7f7a5d7bcae8 Mon Sep 17 00:00:00 2001 From: CaIon Date: Wed, 25 Feb 2026 00:11:24 +0800 Subject: [PATCH 40/41] feat(middleware): add RouteTag middleware for enhanced logging and routing - Introduced RouteTag middleware to set route tags for different API endpoints. - Updated logger to include route tags in log output. - Applied RouteTag middleware across various routers including API, dashboard, relay, video, and web routers for consistent logging. --- middleware/logger.go | 18 ++++++++++++++++-- router/api-router.go | 1 + router/dashboard.go | 1 + router/main.go | 2 ++ router/relay-router.go | 9 +++++++++ router/video-router.go | 4 ++++ router/web-router.go | 1 + 7 files changed, 34 insertions(+), 2 deletions(-) diff --git a/middleware/logger.go b/middleware/logger.go index b4ed8c89d..151008d9f 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -7,14 +7,28 @@ import ( "github.com/gin-gonic/gin" ) +const RouteTagKey = "route_tag" + +func RouteTag(tag string) gin.HandlerFunc { + return func(c *gin.Context) { + c.Set(RouteTagKey, tag) + c.Next() + } +} + func SetUpLogger(server *gin.Engine) { server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { var requestID string if param.Keys != nil { - requestID = param.Keys[common.RequestIdKey].(string) + requestID, _ = param.Keys[common.RequestIdKey].(string) } - return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", + tag, _ := param.Keys[RouteTagKey].(string) + if tag == "" { + tag = "web" + } + return fmt.Sprintf("[GIN] %s | %s | %s | %3d | %13v | %15s | %7s %s\n", param.TimeStamp.Format("2006/01/02 - 15:04:05"), + tag, requestID, param.StatusCode, param.Latency, diff --git a/router/api-router.go b/router/api-router.go index b6e418c6e..d48934000 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -13,6 +13,7 @@ import ( func SetApiRouter(router *gin.Engine) { apiRouter := router.Group("/api") + apiRouter.Use(middleware.RouteTag("api")) apiRouter.Use(gzip.Gzip(gzip.DefaultCompression)) apiRouter.Use(middleware.BodyStorageCleanup()) // 清理请求体存储 apiRouter.Use(middleware.GlobalAPIRateLimit()) diff --git a/router/dashboard.go b/router/dashboard.go index 17132dfb2..2e486156d 100644 --- a/router/dashboard.go +++ b/router/dashboard.go @@ -9,6 +9,7 @@ import ( func SetDashboardRouter(router *gin.Engine) { apiRouter := router.Group("/") + apiRouter.Use(middleware.RouteTag("old_api")) apiRouter.Use(gzip.Gzip(gzip.DefaultCompression)) apiRouter.Use(middleware.GlobalAPIRateLimit()) apiRouter.Use(middleware.CORS()) diff --git a/router/main.go b/router/main.go index 45b3080f2..ac9506fe4 100644 --- a/router/main.go +++ b/router/main.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/middleware" "github.com/gin-gonic/gin" ) @@ -27,6 +28,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { } else { frontendBaseUrl = strings.TrimSuffix(frontendBaseUrl, "/") router.NoRoute(func(c *gin.Context) { + c.Set(middleware.RouteTagKey, "web") c.Redirect(http.StatusMovedPermanently, fmt.Sprintf("%s%s", frontendBaseUrl, c.Request.RequestURI)) }) } diff --git a/router/relay-router.go b/router/relay-router.go index dcec439cb..3d38be5ee 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -17,6 +17,7 @@ func SetRelayRouter(router *gin.Engine) { router.Use(middleware.StatsMiddleware()) // https://platform.openai.com/docs/api-reference/introduction modelsRouter := router.Group("/v1/models") + modelsRouter.Use(middleware.RouteTag("relay")) modelsRouter.Use(middleware.TokenAuth()) { modelsRouter.GET("", func(c *gin.Context) { @@ -41,6 +42,7 @@ func SetRelayRouter(router *gin.Engine) { } geminiRouter := router.Group("/v1beta/models") + geminiRouter.Use(middleware.RouteTag("relay")) geminiRouter.Use(middleware.TokenAuth()) { geminiRouter.GET("", func(c *gin.Context) { @@ -49,6 +51,7 @@ func SetRelayRouter(router *gin.Engine) { } geminiCompatibleRouter := router.Group("/v1beta/openai/models") + geminiCompatibleRouter.Use(middleware.RouteTag("relay")) geminiCompatibleRouter.Use(middleware.TokenAuth()) { geminiCompatibleRouter.GET("", func(c *gin.Context) { @@ -57,12 +60,14 @@ func SetRelayRouter(router *gin.Engine) { } playgroundRouter := router.Group("/pg") + playgroundRouter.Use(middleware.RouteTag("relay")) playgroundRouter.Use(middleware.SystemPerformanceCheck()) playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute()) { playgroundRouter.POST("/chat/completions", controller.Playground) } relayV1Router := router.Group("/v1") + relayV1Router.Use(middleware.RouteTag("relay")) relayV1Router.Use(middleware.SystemPerformanceCheck()) relayV1Router.Use(middleware.TokenAuth()) relayV1Router.Use(middleware.ModelRequestRateLimit()) @@ -161,15 +166,18 @@ func SetRelayRouter(router *gin.Engine) { } relayMjRouter := router.Group("/mj") + relayMjRouter.Use(middleware.RouteTag("relay")) relayMjRouter.Use(middleware.SystemPerformanceCheck()) registerMjRouterGroup(relayMjRouter) relayMjModeRouter := router.Group("/:mode/mj") + relayMjModeRouter.Use(middleware.RouteTag("relay")) relayMjModeRouter.Use(middleware.SystemPerformanceCheck()) registerMjRouterGroup(relayMjModeRouter) //relayMjRouter.Use() relaySunoRouter := router.Group("/suno") + relaySunoRouter.Use(middleware.RouteTag("relay")) relaySunoRouter.Use(middleware.SystemPerformanceCheck()) relaySunoRouter.Use(middleware.TokenAuth(), middleware.Distribute()) { @@ -179,6 +187,7 @@ func SetRelayRouter(router *gin.Engine) { } relayGeminiRouter := router.Group("/v1beta") + relayGeminiRouter.Use(middleware.RouteTag("relay")) relayGeminiRouter.Use(middleware.SystemPerformanceCheck()) relayGeminiRouter.Use(middleware.TokenAuth()) relayGeminiRouter.Use(middleware.ModelRequestRateLimit()) diff --git a/router/video-router.go b/router/video-router.go index 875b0af86..461451104 100644 --- a/router/video-router.go +++ b/router/video-router.go @@ -10,12 +10,14 @@ import ( func SetVideoRouter(router *gin.Engine) { // Video proxy: accepts either session auth (dashboard) or token auth (API clients) videoProxyRouter := router.Group("/v1") + videoProxyRouter.Use(middleware.RouteTag("relay")) videoProxyRouter.Use(middleware.TokenOrUserAuth()) { videoProxyRouter.GET("/videos/:task_id/content", controller.VideoProxy) } videoV1Router := router.Group("/v1") + videoV1Router.Use(middleware.RouteTag("relay")) videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) { videoV1Router.POST("/video/generations", controller.RelayTask) @@ -30,6 +32,7 @@ func SetVideoRouter(router *gin.Engine) { } klingV1Router := router.Group("/kling/v1") + klingV1Router.Use(middleware.RouteTag("relay")) klingV1Router.Use(middleware.KlingRequestConvert(), middleware.TokenAuth(), middleware.Distribute()) { klingV1Router.POST("/videos/text2video", controller.RelayTask) @@ -40,6 +43,7 @@ func SetVideoRouter(router *gin.Engine) { // Jimeng official API routes - direct mapping to official API format jimengOfficialGroup := router.Group("jimeng") + jimengOfficialGroup.Use(middleware.RouteTag("relay")) jimengOfficialGroup.Use(middleware.JimengRequestConvert(), middleware.TokenAuth(), middleware.Distribute()) { // Maps to: /?Action=CVSync2AsyncSubmitTask&Version=2022-08-31 and /?Action=CVSync2AsyncGetResult&Version=2022-08-31 diff --git a/router/web-router.go b/router/web-router.go index b053a3e63..17a8378dd 100644 --- a/router/web-router.go +++ b/router/web-router.go @@ -19,6 +19,7 @@ func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { router.Use(middleware.Cache()) router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/dist"))) router.NoRoute(func(c *gin.Context) { + c.Set(middleware.RouteTagKey, "web") if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") || strings.HasPrefix(c.Request.RequestURI, "/assets") { controller.RelayNotFound(c) return From 4a4cf0a0dfb545b90051ade7db9cd145605d47db Mon Sep 17 00:00:00 2001 From: CaIon Date: Wed, 25 Feb 2026 12:51:46 +0800 Subject: [PATCH 41/41] fix: improve multipart form data handling by detecting content type. fix #3007 --- relay/channel/task/sora/adaptor.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index 33db8fe55..e9029aa20 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -6,6 +6,7 @@ import ( "io" "mime/multipart" "net/http" + "net/textproto" "strconv" "strings" @@ -186,7 +187,22 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn if err != nil { continue } - part, err := writer.CreateFormFile(fieldName, fh.Filename) + ct := fh.Header.Get("Content-Type") + if ct == "" || ct == "application/octet-stream" { + buf512 := make([]byte, 512) + n, _ := io.ReadFull(f, buf512) + ct = http.DetectContentType(buf512[:n]) + // Re-open after sniffing so the full content is copied below + f.Close() + f, err = fh.Open() + if err != nil { + continue + } + } + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fh.Filename)) + h.Set("Content-Type", ct) + part, err := writer.CreatePart(h) if err != nil { f.Close() continue