Merge branch 'upstream-main' into feature/improve-param-override

# Conflicts:
#	relay/channel/api_request_test.go
#	relay/common/override_test.go
#	web/src/components/table/channels/modals/EditChannelModal.jsx
This commit is contained in:
Seefs
2026-02-25 13:39:54 +08:00
124 changed files with 7729 additions and 1971 deletions

View File

@@ -6,6 +6,9 @@ import (
"testing"
"github.com/QuantumNous/new-api/types"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/setting/model_setting"
)
func TestApplyParamOverrideTrimPrefix(t *testing.T) {
@@ -1311,6 +1314,76 @@ func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) {
}
}
func TestRemoveDisabledFieldsSkipWhenChannelPassThroughEnabled(t *testing.T) {
input := `{
"service_tier":"flex",
"safety_identifier":"user-123",
"store":true,
"stream_options":{"include_obfuscation":false}
}`
settings := dto.ChannelOtherSettings{}
out, err := RemoveDisabledFields([]byte(input), settings, true)
if err != nil {
t.Fatalf("RemoveDisabledFields returned error: %v", err)
}
assertJSONEqual(t, input, string(out))
}
func TestRemoveDisabledFieldsSkipWhenGlobalPassThroughEnabled(t *testing.T) {
original := model_setting.GetGlobalSettings().PassThroughRequestEnabled
model_setting.GetGlobalSettings().PassThroughRequestEnabled = true
t.Cleanup(func() {
model_setting.GetGlobalSettings().PassThroughRequestEnabled = original
})
input := `{
"service_tier":"flex",
"safety_identifier":"user-123",
"stream_options":{"include_obfuscation":false}
}`
settings := dto.ChannelOtherSettings{}
out, err := RemoveDisabledFields([]byte(input), settings, false)
if err != nil {
t.Fatalf("RemoveDisabledFields returned error: %v", err)
}
assertJSONEqual(t, input, string(out))
}
func TestRemoveDisabledFieldsDefaultFiltering(t *testing.T) {
input := `{
"service_tier":"flex",
"inference_geo":"eu",
"safety_identifier":"user-123",
"store":true,
"stream_options":{"include_obfuscation":false}
}`
settings := dto.ChannelOtherSettings{}
out, err := RemoveDisabledFields([]byte(input), settings, false)
if err != nil {
t.Fatalf("RemoveDisabledFields returned error: %v", err)
}
assertJSONEqual(t, `{"store":true}`, string(out))
}
func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) {
input := `{
"inference_geo":"eu",
"store":true
}`
settings := dto.ChannelOtherSettings{
AllowInferenceGeo: true,
}
out, err := RemoveDisabledFields([]byte(input), settings, false)
if err != nil {
t.Fatalf("RemoveDisabledFields returned error: %v", err)
}
assertJSONEqual(t, `{"inference_geo":"eu","store":true}`, string(out))
}
func assertJSONEqual(t *testing.T, want, got string) {
t.Helper()

View File

@@ -119,8 +119,12 @@ type RelayInfo struct {
SendResponseCount int
ReceivedResponseCount int
FinalPreConsumedQuota int // 最终预消耗的配额
// ForcePreConsume 为 true 时禁用 BillingSession 的信任额度旁路,
// 强制预扣全额。用于异步任务(视频/音乐生成等),因为请求返回后任务仍在运行,
// 必须在提交前锁定全额。
ForcePreConsume bool
// Billing 是计费会话,封装了预扣费/结算/退款的统一生命周期。
// 免费模型和按次计费MJ/Task时为 nil。
// 免费模型时为 nil。
Billing BillingSettler
// BillingSource indicates whether this request is billed from wallet quota or subscription.
// "" or "wallet" => wallet; "subscription" => subscription
@@ -153,7 +157,8 @@ type RelayInfo struct {
// RequestConversionChain records request format conversions in order, e.g.
// ["openai", "openai_responses"] or ["openai", "claude"].
RequestConversionChain []types.RelayFormat
// 最终请求到上游的格式 TODO: 当前仅设置了Claude
// 最终请求到上游的格式。可由 adaptor 显式设置;
// 若为空,调用 GetFinalRequestRelayFormat 会回退到 RequestConversionChain 的最后一项或 RelayFormat。
FinalRequestRelayFormat types.RelayFormat
ThinkingContentInfo
@@ -552,8 +557,10 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req
return nil, errors.New("request is not a OpenAIResponsesCompactionRequest")
case types.RelayFormatTask:
info = genBaseRelayInfo(c, nil)
info.TaskRelayInfo = &TaskRelayInfo{}
case types.RelayFormatMjProxy:
info = genBaseRelayInfo(c, nil)
info.TaskRelayInfo = &TaskRelayInfo{}
default:
err = errors.New("invalid relay format")
}
@@ -600,6 +607,19 @@ func (info *RelayInfo) AppendRequestConversion(format types.RelayFormat) {
info.RequestConversionChain = append(info.RequestConversionChain, format)
}
func (info *RelayInfo) GetFinalRequestRelayFormat() types.RelayFormat {
if info == nil {
return ""
}
if info.FinalRequestRelayFormat != "" {
return info.FinalRequestRelayFormat
}
if n := len(info.RequestConversionChain); n > 0 {
return info.RequestConversionChain[n-1]
}
return info.RelayFormat
}
func GenRelayInfoResponsesCompaction(c *gin.Context, request *dto.OpenAIResponsesCompactionRequest) *RelayInfo {
info := genBaseRelayInfo(c, request)
if info.RelayMode == relayconstant.RelayModeUnknown {
@@ -635,8 +655,16 @@ func (info *RelayInfo) HasSendResponse() bool {
type TaskRelayInfo struct {
Action string
OriginTaskID string
// PublicTaskID 是提交时预生成的 task_xxxx 格式公开 ID
// 供 DoResponse 在返回给客户端时使用(避免暴露上游真实 ID
PublicTaskID string
ConsumeQuota bool
// LockedChannel holds the full channel object when the request is bound to
// a specific channel (e.g., remix on origin task's channel). Stored as any
// to avoid an import cycle with model; callers type-assert to *model.Channel.
LockedChannel any
}
type TaskSubmitReq struct {
@@ -694,11 +722,11 @@ func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error {
func (t *TaskSubmitReq) UnmarshalMetadata(v any) error {
metadata := t.Metadata
if metadata != nil {
metadataBytes, err := json.Marshal(metadata)
metadataBytes, err := common.Marshal(metadata)
if err != nil {
return fmt.Errorf("marshal metadata failed: %w", err)
}
err = json.Unmarshal(metadataBytes, v)
err = common.Unmarshal(metadataBytes, v)
if err != nil {
return fmt.Errorf("unmarshal metadata to target failed: %w", err)
}
@@ -727,9 +755,15 @@ func FailTaskInfo(reason string) *TaskInfo {
// RemoveDisabledFields 从请求 JSON 数据中移除渠道设置中禁用的字段
// service_tier: 服务层级字段可能导致额外计费OpenAI、Claude、Responses API 支持)
// inference_geo: Claude 数据驻留推理区域字段(仅 Claude 支持,默认过滤)
// store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用)
// safety_identifier: 安全标识符,用于向 OpenAI 报告违规用户(仅 OpenAI 支持,涉及用户隐私)
func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings) ([]byte, error) {
// stream_options.include_obfuscation: 响应流混淆控制字段(仅 OpenAI Responses API 支持)
func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings, channelPassThroughEnabled bool) ([]byte, error) {
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled {
return jsonData, nil
}
var data map[string]interface{}
if err := common.Unmarshal(jsonData, &data); err != nil {
common.SysError("RemoveDisabledFields Unmarshal error :" + err.Error())
@@ -743,6 +777,13 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
}
}
// 默认移除 inference_geo除非明确允许避免在未授权情况下透传数据驻留区域
if !channelOtherSettings.AllowInferenceGeo {
if _, exists := data["inference_geo"]; exists {
delete(data, "inference_geo")
}
}
// 默认允许 store 透传,除非明确禁用(禁用可能影响 Codex 使用)
if channelOtherSettings.DisableStore {
if _, exists := data["store"]; exists {
@@ -757,6 +798,22 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
}
}
// 默认移除 stream_options.include_obfuscation除非明确允许避免关闭响应流混淆保护
if !channelOtherSettings.AllowIncludeObfuscation {
if streamOptionsAny, exists := data["stream_options"]; exists {
if streamOptions, ok := streamOptionsAny.(map[string]interface{}); ok {
if _, includeExists := streamOptions["include_obfuscation"]; includeExists {
delete(streamOptions, "include_obfuscation")
}
if len(streamOptions) == 0 {
delete(data, "stream_options")
} else {
data["stream_options"] = streamOptions
}
}
}
}
jsonDataAfter, err := common.Marshal(data)
if err != nil {
common.SysError("RemoveDisabledFields Marshal error :" + err.Error())

View File

@@ -0,0 +1,40 @@
package common
import (
"testing"
"github.com/QuantumNous/new-api/types"
"github.com/stretchr/testify/require"
)
func TestRelayInfoGetFinalRequestRelayFormatPrefersExplicitFinal(t *testing.T) {
info := &RelayInfo{
RelayFormat: types.RelayFormatOpenAI,
RequestConversionChain: []types.RelayFormat{types.RelayFormatOpenAI, types.RelayFormatClaude},
FinalRequestRelayFormat: types.RelayFormatOpenAIResponses,
}
require.Equal(t, types.RelayFormat(types.RelayFormatOpenAIResponses), info.GetFinalRequestRelayFormat())
}
func TestRelayInfoGetFinalRequestRelayFormatFallsBackToConversionChain(t *testing.T) {
info := &RelayInfo{
RelayFormat: types.RelayFormatOpenAI,
RequestConversionChain: []types.RelayFormat{types.RelayFormatOpenAI, types.RelayFormatClaude},
}
require.Equal(t, types.RelayFormat(types.RelayFormatClaude), info.GetFinalRequestRelayFormat())
}
func TestRelayInfoGetFinalRequestRelayFormatFallsBackToRelayFormat(t *testing.T) {
info := &RelayInfo{
RelayFormat: types.RelayFormatGemini,
}
require.Equal(t, types.RelayFormat(types.RelayFormatGemini), info.GetFinalRequestRelayFormat())
}
func TestRelayInfoGetFinalRequestRelayFormatNilReceiver(t *testing.T) {
var info *RelayInfo
require.Equal(t, types.RelayFormat(""), info.GetFinalRequestRelayFormat())
}

View File

@@ -173,16 +173,10 @@ func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
}
info.PriceData.OtherRatios = map[string]float64{
"seconds": float64(seconds),
"size": 1,
}
if lo.Contains([]string{"1792x1024", "1024x1792"}, size) {
info.PriceData.OtherRatios["size"] = 1.666667
}
// OtherRatios 已移到 Sora adaptor 的 EstimateBilling 中设置
}
info.Action = action
storeTaskRequest(c, info, action, req)
return nil
}