Merge pull request #3009 from seefs001/feature/improve-param-override

feat: improve channel override ui/ux
This commit is contained in:
Seefs
2026-02-28 18:19:40 +08:00
committed by GitHub
26 changed files with 6180 additions and 261 deletions

View File

@@ -169,12 +169,17 @@ func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey str
// Passthrough rules are applied first, then normal overrides are applied, so explicit overrides win.
func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) {
headerOverride := make(map[string]string)
if info == nil {
return headerOverride, nil
}
headerOverrideSource := common.GetEffectiveHeaderOverride(info)
passAll := false
var passthroughRegex []*regexp.Regexp
if !info.IsChannelTest {
for k := range info.HeadersOverride {
key := strings.TrimSpace(k)
for k := range headerOverrideSource {
key := strings.TrimSpace(strings.ToLower(k))
if key == "" {
continue
}
@@ -183,12 +188,11 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
continue
}
lower := strings.ToLower(key)
var pattern string
switch {
case strings.HasPrefix(lower, headerPassthroughRegexPrefix):
case strings.HasPrefix(key, headerPassthroughRegexPrefix):
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2):
case strings.HasPrefix(key, headerPassthroughRegexPrefixV2):
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
default:
continue
@@ -229,15 +233,15 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
if value == "" {
continue
}
headerOverride[name] = value
headerOverride[strings.ToLower(strings.TrimSpace(name))] = value
}
}
for k, v := range info.HeadersOverride {
for k, v := range headerOverrideSource {
if isHeaderPassthroughRuleKey(k) {
continue
}
key := strings.TrimSpace(k)
key := strings.TrimSpace(strings.ToLower(k))
if key == "" {
continue
}

View File

@@ -53,7 +53,7 @@ func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testin
headers, err := processHeaderOverride(info, ctx)
require.NoError(t, err)
_, ok := headers["X-Upstream-Trace"]
_, ok := headers["x-upstream-trace"]
require.False(t, ok)
}
@@ -77,7 +77,38 @@ func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T)
headers, err := processHeaderOverride(info, ctx)
require.NoError(t, err)
require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
require.Equal(t, "trace-123", headers["x-upstream-trace"])
}
func TestProcessHeaderOverride_RuntimeOverrideIsFinalHeaderMap(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
info := &relaycommon.RelayInfo{
IsChannelTest: false,
UseRuntimeHeadersOverride: true,
RuntimeHeadersOverride: map[string]any{
"x-static": "runtime-value",
"x-runtime": "runtime-only",
},
ChannelMeta: &relaycommon.ChannelMeta{
HeadersOverride: map[string]any{
"X-Static": "legacy-value",
"X-Legacy": "legacy-only",
},
},
}
headers, err := processHeaderOverride(info, ctx)
require.NoError(t, err)
require.Equal(t, "runtime-value", headers["x-static"])
require.Equal(t, "runtime-only", headers["x-runtime"])
_, exists := headers["x-legacy"]
require.False(t, exists)
}
func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
@@ -101,8 +132,62 @@ func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
headers, err := processHeaderOverride(info, ctx)
require.NoError(t, err)
require.Equal(t, "trace-123", headers["X-Trace-Id"])
require.Equal(t, "trace-123", headers["x-trace-id"])
_, hasAcceptEncoding := headers["Accept-Encoding"]
_, hasAcceptEncoding := headers["accept-encoding"]
require.False(t, hasAcceptEncoding)
}
func TestProcessHeaderOverride_PassHeadersTemplateSetsRuntimeHeaders(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
ctx.Request.Header.Set("Originator", "Codex CLI")
ctx.Request.Header.Set("Session_id", "sess-123")
info := &relaycommon.RelayInfo{
IsChannelTest: false,
RequestHeaders: map[string]string{
"Originator": "Codex CLI",
"Session_id": "sess-123",
},
ChannelMeta: &relaycommon.ChannelMeta{
ParamOverride: map[string]any{
"operations": []any{
map[string]any{
"mode": "pass_headers",
"value": []any{"Originator", "Session_id", "X-Codex-Beta-Features"},
},
},
},
HeadersOverride: map[string]any{
"X-Static": "legacy-value",
},
},
}
_, err := relaycommon.ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-4.1"}`), info)
require.NoError(t, err)
require.True(t, info.UseRuntimeHeadersOverride)
require.Equal(t, "Codex CLI", info.RuntimeHeadersOverride["originator"])
require.Equal(t, "sess-123", info.RuntimeHeadersOverride["session_id"])
_, exists := info.RuntimeHeadersOverride["x-codex-beta-features"]
require.False(t, exists)
require.Equal(t, "legacy-value", info.RuntimeHeadersOverride["x-static"])
headers, err := processHeaderOverride(info, ctx)
require.NoError(t, err)
require.Equal(t, "Codex CLI", headers["originator"])
require.Equal(t, "sess-123", headers["session_id"])
_, exists = headers["x-codex-beta-features"]
require.False(t, exists)
upstreamReq := httptest.NewRequest(http.MethodPost, "https://example.com/v1/responses", nil)
applyHeaderOverrideToRequest(upstreamReq, headers)
require.Equal(t, "Codex CLI", upstreamReq.Header.Get("Originator"))
require.Equal(t, "sess-123", upstreamReq.Header.Get("Session_id"))
require.Empty(t, upstreamReq.Header.Get("X-Codex-Beta-Features"))
}

View File

@@ -70,7 +70,6 @@ func applySystemPromptIfNeeded(c *gin.Context, info *relaycommon.RelayInfo, requ
}
func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, adaptor channel.Adaptor, request *dto.GeneralOpenAIRequest) (*dto.Usage, *types.NewAPIError) {
overrideCtx := relaycommon.BuildParamOverrideContext(info)
chatJSON, err := common.Marshal(request)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
@@ -82,9 +81,9 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
}
if len(info.ParamOverride) > 0 {
chatJSON, err = relaycommon.ApplyParamOverride(chatJSON, info.ParamOverride, overrideCtx)
chatJSON, err = relaycommon.ApplyParamOverrideWithRelayInfo(chatJSON, info)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
return nil, newAPIErrorFromParamOverride(err)
}
}

View File

@@ -153,9 +153,9 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
return newAPIErrorFromParamOverride(err)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -5,6 +5,8 @@ import (
"reflect"
"testing"
"github.com/QuantumNous/new-api/types"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/setting/model_setting"
)
@@ -775,6 +777,754 @@ func TestApplyParamOverrideToUpper(t *testing.T) {
assertJSONEqual(t, `{"model":"GPT-4"}`, string(out))
}
func TestApplyParamOverrideReturnError(t *testing.T) {
input := []byte(`{"model":"gemini-2.5-pro"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "return_error",
"value": map[string]interface{}{
"message": "forced bad request by param override",
"status_code": 422,
"code": "forced_bad_request",
"type": "invalid_request_error",
"skip_retry": true,
},
"conditions": []interface{}{
map[string]interface{}{
"path": "retry.is_retry",
"mode": "full",
"value": true,
},
},
},
},
}
ctx := map[string]interface{}{
"retry": map[string]interface{}{
"index": 1,
"is_retry": true,
},
}
_, err := ApplyParamOverride(input, override, ctx)
if err == nil {
t.Fatalf("expected error, got nil")
}
returnErr, ok := AsParamOverrideReturnError(err)
if !ok {
t.Fatalf("expected ParamOverrideReturnError, got %T: %v", err, err)
}
if returnErr.StatusCode != 422 {
t.Fatalf("expected status 422, got %d", returnErr.StatusCode)
}
if returnErr.Code != "forced_bad_request" {
t.Fatalf("expected code forced_bad_request, got %s", returnErr.Code)
}
if !returnErr.SkipRetry {
t.Fatalf("expected skip_retry true")
}
}
func TestApplyParamOverridePruneObjectsByTypeString(t *testing.T) {
input := []byte(`{
"messages":[
{"role":"assistant","content":[
{"type":"output_text","text":"a"},
{"type":"redacted_thinking","text":"secret"},
{"type":"tool_call","name":"tool_a"}
]},
{"role":"assistant","content":[
{"type":"output_text","text":"b"},
{"type":"wrapper","parts":[
{"type":"redacted_thinking","text":"secret2"},
{"type":"output_text","text":"c"}
]}
]}
]
}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "prune_objects",
"value": "redacted_thinking",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{
"messages":[
{"role":"assistant","content":[
{"type":"output_text","text":"a"},
{"type":"tool_call","name":"tool_a"}
]},
{"role":"assistant","content":[
{"type":"output_text","text":"b"},
{"type":"wrapper","parts":[
{"type":"output_text","text":"c"}
]}
]}
]
}`, string(out))
}
func TestApplyParamOverridePruneObjectsWhereAndPath(t *testing.T) {
input := []byte(`{
"a":{"items":[{"type":"redacted_thinking","id":1},{"type":"output_text","id":2}]},
"b":{"items":[{"type":"redacted_thinking","id":3},{"type":"output_text","id":4}]}
}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "a",
"mode": "prune_objects",
"value": map[string]interface{}{
"where": map[string]interface{}{
"type": "redacted_thinking",
},
},
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{
"a":{"items":[{"type":"output_text","id":2}]},
"b":{"items":[{"type":"redacted_thinking","id":3},{"type":"output_text","id":4}]}
}`, string(out))
}
func TestApplyParamOverrideNormalizeThinkingSignatureUnsupported(t *testing.T) {
input := []byte(`{"items":[{"type":"redacted_thinking"}]}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "normalize_thinking_signature",
},
},
}
_, err := ApplyParamOverride(input, override, nil)
if err == nil {
t.Fatalf("expected error, got nil")
}
}
func TestApplyParamOverrideConditionFromRetryAndLastErrorContext(t *testing.T) {
info := &RelayInfo{
RetryIndex: 1,
LastError: types.WithOpenAIError(types.OpenAIError{
Message: "invalid thinking signature",
Type: "invalid_request_error",
Code: "bad_thought_signature",
}, 400),
}
ctx := BuildParamOverrideContext(info)
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "temperature",
"mode": "set",
"value": 0.1,
"logic": "AND",
"conditions": []interface{}{
map[string]interface{}{
"path": "is_retry",
"mode": "full",
"value": true,
},
map[string]interface{}{
"path": "last_error.code",
"mode": "contains",
"value": "thought_signature",
},
},
},
},
}
out, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.1}`, string(out))
}
func TestApplyParamOverrideConditionFromRequestHeaders(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "temperature",
"mode": "set",
"value": 0.1,
"conditions": []interface{}{
map[string]interface{}{
"path": "request_headers.authorization",
"mode": "contains",
"value": "Bearer ",
},
},
},
},
}
ctx := map[string]interface{}{
"request_headers": map[string]interface{}{
"authorization": "Bearer token-123",
},
}
out, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.1}`, string(out))
}
func TestApplyParamOverrideSetHeaderAndUseInLaterCondition(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "set_header",
"path": "X-Debug-Mode",
"value": "enabled",
},
map[string]interface{}{
"path": "temperature",
"mode": "set",
"value": 0.1,
"conditions": []interface{}{
map[string]interface{}{
"path": "header_override.x-debug-mode",
"mode": "full",
"value": "enabled",
},
},
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.1}`, string(out))
}
func TestApplyParamOverrideCopyHeaderFromRequestHeaders(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "copy_header",
"from": "Authorization",
"to": "X-Upstream-Auth",
},
map[string]interface{}{
"path": "temperature",
"mode": "set",
"value": 0.1,
"conditions": []interface{}{
map[string]interface{}{
"path": "header_override.x-upstream-auth",
"mode": "contains",
"value": "Bearer ",
},
},
},
},
}
ctx := map[string]interface{}{
"request_headers": map[string]interface{}{
"authorization": "Bearer token-123",
},
}
out, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.1}`, string(out))
}
func TestApplyParamOverridePassHeadersSkipsMissingHeaders(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "pass_headers",
"value": []interface{}{"X-Codex-Beta-Features", "Session_id"},
},
},
}
ctx := map[string]interface{}{
"request_headers": map[string]interface{}{
"session_id": "sess-123",
},
}
out, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.7}`, string(out))
headers, ok := ctx["header_override"].(map[string]interface{})
if !ok {
t.Fatalf("expected header_override context map")
}
if headers["session_id"] != "sess-123" {
t.Fatalf("expected session_id to be passed, got: %v", headers["session_id"])
}
if _, exists := headers["x-codex-beta-features"]; exists {
t.Fatalf("expected missing header to be skipped")
}
}
func TestApplyParamOverrideCopyHeaderSkipsMissingSource(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "copy_header",
"from": "X-Missing-Header",
"to": "X-Upstream-Auth",
},
},
}
ctx := map[string]interface{}{
"request_headers": map[string]interface{}{
"authorization": "Bearer token-123",
},
}
out, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.7}`, string(out))
headers, ok := ctx["header_override"].(map[string]interface{})
if !ok {
return
}
if _, exists := headers["x-upstream-auth"]; exists {
t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing")
}
}
func TestApplyParamOverrideMoveHeaderSkipsMissingSource(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "move_header",
"from": "X-Missing-Header",
"to": "X-Upstream-Auth",
},
},
}
ctx := map[string]interface{}{
"request_headers": map[string]interface{}{
"authorization": "Bearer token-123",
},
}
out, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.7}`, string(out))
headers, ok := ctx["header_override"].(map[string]interface{})
if !ok {
return
}
if _, exists := headers["x-upstream-auth"]; exists {
t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing")
}
}
func TestApplyParamOverrideSyncFieldsHeaderToJSON(t *testing.T) {
input := []byte(`{"model":"gpt-4"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "sync_fields",
"from": "header:session_id",
"to": "json:prompt_cache_key",
},
},
}
ctx := map[string]interface{}{
"request_headers": map[string]interface{}{
"session_id": "sess-123",
},
}
out, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4","prompt_cache_key":"sess-123"}`, string(out))
}
func TestApplyParamOverrideSyncFieldsJSONToHeader(t *testing.T) {
input := []byte(`{"model":"gpt-4","prompt_cache_key":"cache-abc"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "sync_fields",
"from": "header:session_id",
"to": "json:prompt_cache_key",
},
},
}
ctx := map[string]interface{}{}
out, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4","prompt_cache_key":"cache-abc"}`, string(out))
headers, ok := ctx["header_override"].(map[string]interface{})
if !ok {
t.Fatalf("expected header_override context map")
}
if headers["session_id"] != "cache-abc" {
t.Fatalf("expected session_id to be synced from prompt_cache_key, got: %v", headers["session_id"])
}
}
func TestApplyParamOverrideSyncFieldsNoChangeWhenBothExist(t *testing.T) {
input := []byte(`{"model":"gpt-4","prompt_cache_key":"cache-body"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "sync_fields",
"from": "header:session_id",
"to": "json:prompt_cache_key",
},
},
}
ctx := map[string]interface{}{
"request_headers": map[string]interface{}{
"session_id": "cache-header",
},
}
out, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4","prompt_cache_key":"cache-body"}`, string(out))
headers, _ := ctx["header_override"].(map[string]interface{})
if headers != nil {
if _, exists := headers["session_id"]; exists {
t.Fatalf("expected no override when both sides already have value")
}
}
}
func TestApplyParamOverrideSyncFieldsInvalidTarget(t *testing.T) {
input := []byte(`{"model":"gpt-4"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "sync_fields",
"from": "foo:session_id",
"to": "json:prompt_cache_key",
},
},
}
_, err := ApplyParamOverride(input, override, nil)
if err == nil {
t.Fatalf("expected error, got nil")
}
}
func TestApplyParamOverrideSetHeaderKeepOrigin(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "set_header",
"path": "X-Feature-Flag",
"value": "new-value",
"keep_origin": true,
},
},
}
ctx := map[string]interface{}{
"header_override": map[string]interface{}{
"x-feature-flag": "legacy-value",
},
}
_, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
headers, ok := ctx["header_override"].(map[string]interface{})
if !ok {
t.Fatalf("expected header_override context map")
}
if headers["x-feature-flag"] != "legacy-value" {
t.Fatalf("expected keep_origin to preserve old value, got: %v", headers["x-feature-flag"])
}
}
func TestApplyParamOverrideSetHeaderMapRewritesCommaSeparatedHeader(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "set_header",
"path": "anthropic-beta",
"value": map[string]interface{}{
"advanced-tool-use-2025-11-20": nil,
"computer-use-2025-01-24": "computer-use-2025-01-24",
},
},
},
}
ctx := map[string]interface{}{
"request_headers": map[string]interface{}{
"anthropic-beta": "advanced-tool-use-2025-11-20, computer-use-2025-01-24",
},
}
_, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
headers, ok := ctx["header_override"].(map[string]interface{})
if !ok {
t.Fatalf("expected header_override context map")
}
if headers["anthropic-beta"] != "computer-use-2025-01-24" {
t.Fatalf("expected anthropic-beta to keep only mapped value, got: %v", headers["anthropic-beta"])
}
}
func TestApplyParamOverrideSetHeaderMapDeleteWholeHeaderWhenAllTokensCleared(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "set_header",
"path": "anthropic-beta",
"value": map[string]interface{}{
"advanced-tool-use-2025-11-20": nil,
"computer-use-2025-01-24": nil,
},
},
},
}
ctx := map[string]interface{}{
"header_override": map[string]interface{}{
"anthropic-beta": "advanced-tool-use-2025-11-20,computer-use-2025-01-24",
},
}
_, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
headers, ok := ctx["header_override"].(map[string]interface{})
if !ok {
t.Fatalf("expected header_override context map")
}
if _, exists := headers["anthropic-beta"]; exists {
t.Fatalf("expected anthropic-beta to be deleted when all mapped values are null")
}
}
func TestApplyParamOverrideConditionsObjectShorthand(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "temperature",
"mode": "set",
"value": 0.1,
"logic": "AND",
"conditions": map[string]interface{}{
"is_retry": true,
"last_error.status_code": 400.0,
},
},
},
}
ctx := map[string]interface{}{
"is_retry": true,
"last_error": map[string]interface{}{
"status_code": 400.0,
},
}
out, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.1}`, string(out))
}
func TestApplyParamOverrideWithRelayInfoSyncRuntimeHeaders(t *testing.T) {
info := &RelayInfo{
ChannelMeta: &ChannelMeta{
ParamOverride: map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "set_header",
"path": "X-Injected-By-Param-Override",
"value": "enabled",
},
map[string]interface{}{
"mode": "delete_header",
"path": "X-Delete-Me",
},
},
},
HeadersOverride: map[string]interface{}{
"X-Delete-Me": "legacy",
"X-Keep-Me": "keep",
},
},
}
input := []byte(`{"temperature":0.7}`)
out, err := ApplyParamOverrideWithRelayInfo(input, info)
if err != nil {
t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.7}`, string(out))
if !info.UseRuntimeHeadersOverride {
t.Fatalf("expected runtime header override to be enabled")
}
if info.RuntimeHeadersOverride["x-keep-me"] != "keep" {
t.Fatalf("expected x-keep-me header to be preserved, got: %v", info.RuntimeHeadersOverride["x-keep-me"])
}
if info.RuntimeHeadersOverride["x-injected-by-param-override"] != "enabled" {
t.Fatalf("expected x-injected-by-param-override header to be set, got: %v", info.RuntimeHeadersOverride["x-injected-by-param-override"])
}
if _, exists := info.RuntimeHeadersOverride["x-delete-me"]; exists {
t.Fatalf("expected x-delete-me header to be deleted")
}
}
func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) {
info := &RelayInfo{
ChannelMeta: &ChannelMeta{
ParamOverride: map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "move_header",
"from": "X-Legacy-Trace",
"to": "X-Trace",
},
map[string]interface{}{
"mode": "copy_header",
"from": "X-Trace",
"to": "X-Trace-Backup",
},
},
},
HeadersOverride: map[string]interface{}{
"X-Legacy-Trace": "trace-123",
},
},
}
input := []byte(`{"temperature":0.7}`)
_, err := ApplyParamOverrideWithRelayInfo(input, info)
if err != nil {
t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err)
}
if _, exists := info.RuntimeHeadersOverride["x-legacy-trace"]; exists {
t.Fatalf("expected source header to be removed after move")
}
if info.RuntimeHeadersOverride["x-trace"] != "trace-123" {
t.Fatalf("expected x-trace to be set, got: %v", info.RuntimeHeadersOverride["x-trace"])
}
if info.RuntimeHeadersOverride["x-trace-backup"] != "trace-123" {
t.Fatalf("expected x-trace-backup to be copied, got: %v", info.RuntimeHeadersOverride["x-trace-backup"])
}
}
func TestApplyParamOverrideWithRelayInfoSetHeaderMapRewritesAnthropicBeta(t *testing.T) {
info := &RelayInfo{
ChannelMeta: &ChannelMeta{
ParamOverride: map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "set_header",
"path": "anthropic-beta",
"value": map[string]interface{}{
"advanced-tool-use-2025-11-20": nil,
"computer-use-2025-01-24": "computer-use-2025-01-24",
},
},
},
},
HeadersOverride: map[string]interface{}{
"anthropic-beta": "advanced-tool-use-2025-11-20, computer-use-2025-01-24",
},
},
}
_, err := ApplyParamOverrideWithRelayInfo([]byte(`{"temperature":0.7}`), info)
if err != nil {
t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err)
}
if !info.UseRuntimeHeadersOverride {
t.Fatalf("expected runtime header override to be enabled")
}
if info.RuntimeHeadersOverride["anthropic-beta"] != "computer-use-2025-01-24" {
t.Fatalf("expected anthropic-beta to be rewritten, got: %v", info.RuntimeHeadersOverride["anthropic-beta"])
}
}
func TestGetEffectiveHeaderOverrideUsesRuntimeOverrideAsFinalResult(t *testing.T) {
info := &RelayInfo{
UseRuntimeHeadersOverride: true,
RuntimeHeadersOverride: map[string]interface{}{
"x-runtime": "runtime-only",
},
ChannelMeta: &ChannelMeta{
HeadersOverride: map[string]interface{}{
"X-Static": "static-value",
"X-Deleted": "should-not-exist",
},
},
}
effective := GetEffectiveHeaderOverride(info)
if effective["x-runtime"] != "runtime-only" {
t.Fatalf("expected x-runtime from runtime override, got: %v", effective["x-runtime"])
}
if _, exists := effective["x-static"]; exists {
t.Fatalf("expected runtime override to be final and not merge channel headers")
}
}
func TestRemoveDisabledFieldsSkipWhenChannelPassThroughEnabled(t *testing.T) {
input := `{
"service_tier":"flex",

View File

@@ -101,6 +101,7 @@ type RelayInfo struct {
RelayMode int
OriginModelName string
RequestURLPath string
RequestHeaders map[string]string
ShouldIncludeUsage bool
DisablePing bool // 是否禁止向下游发送自定义 Ping
ClientWs *websocket.Conn
@@ -144,6 +145,10 @@ type RelayInfo struct {
SubscriptionAmountUsedAfterPreConsume int64
IsClaudeBetaQuery bool // /v1/messages?beta=true
IsChannelTest bool // channel test request
RetryIndex int
LastError *types.NewAPIError
RuntimeHeadersOverride map[string]interface{}
UseRuntimeHeadersOverride bool
PriceData types.PriceData
@@ -461,6 +466,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
isFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
RequestURLPath: c.Request.URL.String(),
RequestHeaders: cloneRequestHeaders(c),
IsStream: isStream,
StartTime: startTime,
@@ -493,6 +499,27 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
return info
}
func cloneRequestHeaders(c *gin.Context) map[string]string {
if c == nil || c.Request == nil {
return nil
}
if len(c.Request.Header) == 0 {
return nil
}
headers := make(map[string]string, len(c.Request.Header))
for key := range c.Request.Header {
value := strings.TrimSpace(c.Request.Header.Get(key))
if value == "" {
continue
}
headers[key] = value
}
if len(headers) == 0 {
return nil
}
return headers
}
func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) {
var info *RelayInfo
var err error

View File

@@ -172,9 +172,9 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
return newAPIErrorFromParamOverride(err)
}
}

View File

@@ -2,7 +2,6 @@ package relay
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
@@ -46,15 +45,15 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
jsonData, err := json.Marshal(convertedRequest)
jsonData, err := common.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
return newAPIErrorFromParamOverride(err)
}
}

View File

@@ -157,9 +157,9 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
return newAPIErrorFromParamOverride(err)
}
}
@@ -257,14 +257,9 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
// apply param override
if len(info.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
for key, value := range info.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
return newAPIErrorFromParamOverride(err)
}
}
logger.LogDebug(c, "Gemini embedding request body: "+string(jsonData))

View File

@@ -70,9 +70,9 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
return newAPIErrorFromParamOverride(err)
}
}

View File

@@ -0,0 +1,13 @@
package relay
import (
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/types"
)
func newAPIErrorFromParamOverride(err error) *types.NewAPIError {
if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok {
return relaycommon.NewAPIErrorFromParamOverride(fixedErr)
}
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}

View File

@@ -61,9 +61,9 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
return newAPIErrorFromParamOverride(err)
}
}

View File

@@ -96,9 +96,9 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
return newAPIErrorFromParamOverride(err)
}
}