diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index dfa0e4fab..ec5573ab1 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -171,35 +171,37 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s passAll := false var passthroughRegex []*regexp.Regexp - for k := range info.HeadersOverride { - key := strings.TrimSpace(k) - if key == "" { - continue - } - if key == headerPassthroughAllKey { - passAll = true - continue - } + if !info.IsChannelTest { + for k := range info.HeadersOverride { + key := strings.TrimSpace(k) + if key == "" { + continue + } + if key == headerPassthroughAllKey { + passAll = true + continue + } - lower := strings.ToLower(key) - var pattern string - switch { - case strings.HasPrefix(lower, headerPassthroughRegexPrefix): - pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):]) - case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2): - pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):]) - default: - continue - } + lower := strings.ToLower(key) + var pattern string + switch { + case strings.HasPrefix(lower, headerPassthroughRegexPrefix): + pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):]) + case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2): + pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):]) + default: + continue + } - if pattern == "" { - return nil, types.NewError(fmt.Errorf("header passthrough regex pattern is empty: %q", k), types.ErrorCodeChannelHeaderOverrideInvalid) + if pattern == "" { + return nil, types.NewError(fmt.Errorf("header passthrough regex pattern is empty: %q", k), types.ErrorCodeChannelHeaderOverrideInvalid) + } + compiled, err := getHeaderPassthroughRegex(pattern) + if err != nil { + return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid) + } + passthroughRegex = append(passthroughRegex, compiled) } - compiled, err := getHeaderPassthroughRegex(pattern) - if err != nil { - return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid) - } - passthroughRegex = append(passthroughRegex, compiled) } if passAll || len(passthroughRegex) > 0 { @@ -243,6 +245,9 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s if !ok { return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid) } + if info.IsChannelTest && strings.HasPrefix(strings.TrimSpace(str), clientHeaderPlaceholderPrefix) { + continue + } value, include, err := applyHeaderOverridePlaceholders(str, c, info.ApiKey) if err != nil { diff --git a/relay/channel/api_request_test.go b/relay/channel/api_request_test.go new file mode 100644 index 000000000..c55ffcab2 --- /dev/null +++ b/relay/channel/api_request_test.go @@ -0,0 +1,81 @@ +package channel + +import ( + "net/http" + "net/http/httptest" + "testing" + + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + ctx.Request.Header.Set("X-Trace-Id", "trace-123") + + info := &relaycommon.RelayInfo{ + IsChannelTest: true, + ChannelMeta: &relaycommon.ChannelMeta{ + HeadersOverride: map[string]any{ + "*": "", + }, + }, + } + + headers, err := processHeaderOverride(info, ctx) + require.NoError(t, err) + require.Empty(t, headers) +} + +func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + ctx.Request.Header.Set("X-Trace-Id", "trace-123") + + info := &relaycommon.RelayInfo{ + IsChannelTest: true, + ChannelMeta: &relaycommon.ChannelMeta{ + HeadersOverride: map[string]any{ + "X-Upstream-Trace": "{client_header:X-Trace-Id}", + }, + }, + } + + headers, err := processHeaderOverride(info, ctx) + require.NoError(t, err) + _, ok := headers["X-Upstream-Trace"] + require.False(t, ok) +} + +func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + ctx.Request.Header.Set("X-Trace-Id", "trace-123") + + info := &relaycommon.RelayInfo{ + IsChannelTest: false, + ChannelMeta: &relaycommon.ChannelMeta{ + HeadersOverride: map[string]any{ + "X-Upstream-Trace": "{client_header:X-Trace-Id}", + }, + }, + } + + headers, err := processHeaderOverride(info, ctx) + require.NoError(t, err) + require.Equal(t, "trace-123", headers["X-Upstream-Trace"]) +}