mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-06-07 22:09:57 +00:00
Merge branch 'upstream-main' into feature/improve-param-override
# Conflicts: # relay/channel/api_request_test.go # relay/common/override_test.go # web/src/components/table/channels/modals/EditChannelModal.jsx
This commit is contained in:
@@ -6,6 +6,9 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
)
|
||||
|
||||
func TestApplyParamOverrideTrimPrefix(t *testing.T) {
|
||||
@@ -1311,6 +1314,76 @@ func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveDisabledFieldsSkipWhenChannelPassThroughEnabled(t *testing.T) {
|
||||
input := `{
|
||||
"service_tier":"flex",
|
||||
"safety_identifier":"user-123",
|
||||
"store":true,
|
||||
"stream_options":{"include_obfuscation":false}
|
||||
}`
|
||||
settings := dto.ChannelOtherSettings{}
|
||||
|
||||
out, err := RemoveDisabledFields([]byte(input), settings, true)
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveDisabledFields returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, input, string(out))
|
||||
}
|
||||
|
||||
func TestRemoveDisabledFieldsSkipWhenGlobalPassThroughEnabled(t *testing.T) {
|
||||
original := model_setting.GetGlobalSettings().PassThroughRequestEnabled
|
||||
model_setting.GetGlobalSettings().PassThroughRequestEnabled = true
|
||||
t.Cleanup(func() {
|
||||
model_setting.GetGlobalSettings().PassThroughRequestEnabled = original
|
||||
})
|
||||
|
||||
input := `{
|
||||
"service_tier":"flex",
|
||||
"safety_identifier":"user-123",
|
||||
"stream_options":{"include_obfuscation":false}
|
||||
}`
|
||||
settings := dto.ChannelOtherSettings{}
|
||||
|
||||
out, err := RemoveDisabledFields([]byte(input), settings, false)
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveDisabledFields returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, input, string(out))
|
||||
}
|
||||
|
||||
func TestRemoveDisabledFieldsDefaultFiltering(t *testing.T) {
|
||||
input := `{
|
||||
"service_tier":"flex",
|
||||
"inference_geo":"eu",
|
||||
"safety_identifier":"user-123",
|
||||
"store":true,
|
||||
"stream_options":{"include_obfuscation":false}
|
||||
}`
|
||||
settings := dto.ChannelOtherSettings{}
|
||||
|
||||
out, err := RemoveDisabledFields([]byte(input), settings, false)
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveDisabledFields returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"store":true}`, string(out))
|
||||
}
|
||||
|
||||
func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) {
|
||||
input := `{
|
||||
"inference_geo":"eu",
|
||||
"store":true
|
||||
}`
|
||||
settings := dto.ChannelOtherSettings{
|
||||
AllowInferenceGeo: true,
|
||||
}
|
||||
|
||||
out, err := RemoveDisabledFields([]byte(input), settings, false)
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveDisabledFields returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"inference_geo":"eu","store":true}`, string(out))
|
||||
}
|
||||
|
||||
func assertJSONEqual(t *testing.T, want, got string) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -119,8 +119,12 @@ type RelayInfo struct {
|
||||
SendResponseCount int
|
||||
ReceivedResponseCount int
|
||||
FinalPreConsumedQuota int // 最终预消耗的配额
|
||||
// ForcePreConsume 为 true 时禁用 BillingSession 的信任额度旁路,
|
||||
// 强制预扣全额。用于异步任务(视频/音乐生成等),因为请求返回后任务仍在运行,
|
||||
// 必须在提交前锁定全额。
|
||||
ForcePreConsume bool
|
||||
// Billing 是计费会话,封装了预扣费/结算/退款的统一生命周期。
|
||||
// 免费模型和按次计费(MJ/Task)时为 nil。
|
||||
// 免费模型时为 nil。
|
||||
Billing BillingSettler
|
||||
// BillingSource indicates whether this request is billed from wallet quota or subscription.
|
||||
// "" or "wallet" => wallet; "subscription" => subscription
|
||||
@@ -153,7 +157,8 @@ type RelayInfo struct {
|
||||
// RequestConversionChain records request format conversions in order, e.g.
|
||||
// ["openai", "openai_responses"] or ["openai", "claude"].
|
||||
RequestConversionChain []types.RelayFormat
|
||||
// 最终请求到上游的格式 TODO: 当前仅设置了Claude
|
||||
// 最终请求到上游的格式。可由 adaptor 显式设置;
|
||||
// 若为空,调用 GetFinalRequestRelayFormat 会回退到 RequestConversionChain 的最后一项或 RelayFormat。
|
||||
FinalRequestRelayFormat types.RelayFormat
|
||||
|
||||
ThinkingContentInfo
|
||||
@@ -552,8 +557,10 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req
|
||||
return nil, errors.New("request is not a OpenAIResponsesCompactionRequest")
|
||||
case types.RelayFormatTask:
|
||||
info = genBaseRelayInfo(c, nil)
|
||||
info.TaskRelayInfo = &TaskRelayInfo{}
|
||||
case types.RelayFormatMjProxy:
|
||||
info = genBaseRelayInfo(c, nil)
|
||||
info.TaskRelayInfo = &TaskRelayInfo{}
|
||||
default:
|
||||
err = errors.New("invalid relay format")
|
||||
}
|
||||
@@ -600,6 +607,19 @@ func (info *RelayInfo) AppendRequestConversion(format types.RelayFormat) {
|
||||
info.RequestConversionChain = append(info.RequestConversionChain, format)
|
||||
}
|
||||
|
||||
func (info *RelayInfo) GetFinalRequestRelayFormat() types.RelayFormat {
|
||||
if info == nil {
|
||||
return ""
|
||||
}
|
||||
if info.FinalRequestRelayFormat != "" {
|
||||
return info.FinalRequestRelayFormat
|
||||
}
|
||||
if n := len(info.RequestConversionChain); n > 0 {
|
||||
return info.RequestConversionChain[n-1]
|
||||
}
|
||||
return info.RelayFormat
|
||||
}
|
||||
|
||||
func GenRelayInfoResponsesCompaction(c *gin.Context, request *dto.OpenAIResponsesCompactionRequest) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
if info.RelayMode == relayconstant.RelayModeUnknown {
|
||||
@@ -635,8 +655,16 @@ func (info *RelayInfo) HasSendResponse() bool {
|
||||
type TaskRelayInfo struct {
|
||||
Action string
|
||||
OriginTaskID string
|
||||
// PublicTaskID 是提交时预生成的 task_xxxx 格式公开 ID,
|
||||
// 供 DoResponse 在返回给客户端时使用(避免暴露上游真实 ID)。
|
||||
PublicTaskID string
|
||||
|
||||
ConsumeQuota bool
|
||||
|
||||
// LockedChannel holds the full channel object when the request is bound to
|
||||
// a specific channel (e.g., remix on origin task's channel). Stored as any
|
||||
// to avoid an import cycle with model; callers type-assert to *model.Channel.
|
||||
LockedChannel any
|
||||
}
|
||||
|
||||
type TaskSubmitReq struct {
|
||||
@@ -694,11 +722,11 @@ func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error {
|
||||
func (t *TaskSubmitReq) UnmarshalMetadata(v any) error {
|
||||
metadata := t.Metadata
|
||||
if metadata != nil {
|
||||
metadataBytes, err := json.Marshal(metadata)
|
||||
metadataBytes, err := common.Marshal(metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal metadata failed: %w", err)
|
||||
}
|
||||
err = json.Unmarshal(metadataBytes, v)
|
||||
err = common.Unmarshal(metadataBytes, v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshal metadata to target failed: %w", err)
|
||||
}
|
||||
@@ -727,9 +755,15 @@ func FailTaskInfo(reason string) *TaskInfo {
|
||||
|
||||
// RemoveDisabledFields 从请求 JSON 数据中移除渠道设置中禁用的字段
|
||||
// service_tier: 服务层级字段,可能导致额外计费(OpenAI、Claude、Responses API 支持)
|
||||
// inference_geo: Claude 数据驻留推理区域字段(仅 Claude 支持,默认过滤)
|
||||
// store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用)
|
||||
// safety_identifier: 安全标识符,用于向 OpenAI 报告违规用户(仅 OpenAI 支持,涉及用户隐私)
|
||||
func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings) ([]byte, error) {
|
||||
// stream_options.include_obfuscation: 响应流混淆控制字段(仅 OpenAI Responses API 支持)
|
||||
func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings, channelPassThroughEnabled bool) ([]byte, error) {
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled {
|
||||
return jsonData, nil
|
||||
}
|
||||
|
||||
var data map[string]interface{}
|
||||
if err := common.Unmarshal(jsonData, &data); err != nil {
|
||||
common.SysError("RemoveDisabledFields Unmarshal error :" + err.Error())
|
||||
@@ -743,6 +777,13 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
|
||||
}
|
||||
}
|
||||
|
||||
// 默认移除 inference_geo,除非明确允许(避免在未授权情况下透传数据驻留区域)
|
||||
if !channelOtherSettings.AllowInferenceGeo {
|
||||
if _, exists := data["inference_geo"]; exists {
|
||||
delete(data, "inference_geo")
|
||||
}
|
||||
}
|
||||
|
||||
// 默认允许 store 透传,除非明确禁用(禁用可能影响 Codex 使用)
|
||||
if channelOtherSettings.DisableStore {
|
||||
if _, exists := data["store"]; exists {
|
||||
@@ -757,6 +798,22 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
|
||||
}
|
||||
}
|
||||
|
||||
// 默认移除 stream_options.include_obfuscation,除非明确允许(避免关闭响应流混淆保护)
|
||||
if !channelOtherSettings.AllowIncludeObfuscation {
|
||||
if streamOptionsAny, exists := data["stream_options"]; exists {
|
||||
if streamOptions, ok := streamOptionsAny.(map[string]interface{}); ok {
|
||||
if _, includeExists := streamOptions["include_obfuscation"]; includeExists {
|
||||
delete(streamOptions, "include_obfuscation")
|
||||
}
|
||||
if len(streamOptions) == 0 {
|
||||
delete(data, "stream_options")
|
||||
} else {
|
||||
data["stream_options"] = streamOptions
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jsonDataAfter, err := common.Marshal(data)
|
||||
if err != nil {
|
||||
common.SysError("RemoveDisabledFields Marshal error :" + err.Error())
|
||||
|
||||
40
relay/common/relay_info_test.go
Normal file
40
relay/common/relay_info_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRelayInfoGetFinalRequestRelayFormatPrefersExplicitFinal(t *testing.T) {
|
||||
info := &RelayInfo{
|
||||
RelayFormat: types.RelayFormatOpenAI,
|
||||
RequestConversionChain: []types.RelayFormat{types.RelayFormatOpenAI, types.RelayFormatClaude},
|
||||
FinalRequestRelayFormat: types.RelayFormatOpenAIResponses,
|
||||
}
|
||||
|
||||
require.Equal(t, types.RelayFormat(types.RelayFormatOpenAIResponses), info.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
func TestRelayInfoGetFinalRequestRelayFormatFallsBackToConversionChain(t *testing.T) {
|
||||
info := &RelayInfo{
|
||||
RelayFormat: types.RelayFormatOpenAI,
|
||||
RequestConversionChain: []types.RelayFormat{types.RelayFormatOpenAI, types.RelayFormatClaude},
|
||||
}
|
||||
|
||||
require.Equal(t, types.RelayFormat(types.RelayFormatClaude), info.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
func TestRelayInfoGetFinalRequestRelayFormatFallsBackToRelayFormat(t *testing.T) {
|
||||
info := &RelayInfo{
|
||||
RelayFormat: types.RelayFormatGemini,
|
||||
}
|
||||
|
||||
require.Equal(t, types.RelayFormat(types.RelayFormatGemini), info.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
func TestRelayInfoGetFinalRequestRelayFormatNilReceiver(t *testing.T) {
|
||||
var info *RelayInfo
|
||||
require.Equal(t, types.RelayFormat(""), info.GetFinalRequestRelayFormat())
|
||||
}
|
||||
@@ -173,16 +173,10 @@ func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
|
||||
if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
|
||||
return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
|
||||
}
|
||||
info.PriceData.OtherRatios = map[string]float64{
|
||||
"seconds": float64(seconds),
|
||||
"size": 1,
|
||||
}
|
||||
if lo.Contains([]string{"1792x1024", "1024x1792"}, size) {
|
||||
info.PriceData.OtherRatios["size"] = 1.666667
|
||||
}
|
||||
// OtherRatios 已移到 Sora adaptor 的 EstimateBilling 中设置
|
||||
}
|
||||
|
||||
info.Action = action
|
||||
storeTaskRequest(c, info, action, req)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user