fix: ignore header passthrough during channel tests

This commit is contained in:
Seefs
2026-02-12 15:16:24 +08:00
parent 30fed3cc5c
commit c78b37662b
2 changed files with 112 additions and 26 deletions

View File

@@ -171,35 +171,37 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
passAll := false
var passthroughRegex []*regexp.Regexp
for k := range info.HeadersOverride {
key := strings.TrimSpace(k)
if key == "" {
continue
}
if key == headerPassthroughAllKey {
passAll = true
continue
}
if !info.IsChannelTest {
for k := range info.HeadersOverride {
key := strings.TrimSpace(k)
if key == "" {
continue
}
if key == headerPassthroughAllKey {
passAll = true
continue
}
lower := strings.ToLower(key)
var pattern string
switch {
case strings.HasPrefix(lower, headerPassthroughRegexPrefix):
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2):
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
default:
continue
}
lower := strings.ToLower(key)
var pattern string
switch {
case strings.HasPrefix(lower, headerPassthroughRegexPrefix):
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2):
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
default:
continue
}
if pattern == "" {
return nil, types.NewError(fmt.Errorf("header passthrough regex pattern is empty: %q", k), types.ErrorCodeChannelHeaderOverrideInvalid)
if pattern == "" {
return nil, types.NewError(fmt.Errorf("header passthrough regex pattern is empty: %q", k), types.ErrorCodeChannelHeaderOverrideInvalid)
}
compiled, err := getHeaderPassthroughRegex(pattern)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
}
passthroughRegex = append(passthroughRegex, compiled)
}
compiled, err := getHeaderPassthroughRegex(pattern)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
}
passthroughRegex = append(passthroughRegex, compiled)
}
if passAll || len(passthroughRegex) > 0 {
@@ -243,6 +245,9 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
if !ok {
return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid)
}
if info.IsChannelTest && strings.HasPrefix(strings.TrimSpace(str), clientHeaderPlaceholderPrefix) {
continue
}
value, include, err := applyHeaderOverridePlaceholders(str, c, info.ApiKey)
if err != nil {

View File

@@ -0,0 +1,81 @@
package channel
import (
"net/http"
"net/http/httptest"
"testing"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
ctx.Request.Header.Set("X-Trace-Id", "trace-123")
info := &relaycommon.RelayInfo{
IsChannelTest: true,
ChannelMeta: &relaycommon.ChannelMeta{
HeadersOverride: map[string]any{
"*": "",
},
},
}
headers, err := processHeaderOverride(info, ctx)
require.NoError(t, err)
require.Empty(t, headers)
}
func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
ctx.Request.Header.Set("X-Trace-Id", "trace-123")
info := &relaycommon.RelayInfo{
IsChannelTest: true,
ChannelMeta: &relaycommon.ChannelMeta{
HeadersOverride: map[string]any{
"X-Upstream-Trace": "{client_header:X-Trace-Id}",
},
},
}
headers, err := processHeaderOverride(info, ctx)
require.NoError(t, err)
_, ok := headers["X-Upstream-Trace"]
require.False(t, ok)
}
func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
ctx.Request.Header.Set("X-Trace-Id", "trace-123")
info := &relaycommon.RelayInfo{
IsChannelTest: false,
ChannelMeta: &relaycommon.ChannelMeta{
HeadersOverride: map[string]any{
"X-Upstream-Trace": "{client_header:X-Trace-Id}",
},
},
}
headers, err := processHeaderOverride(info, ctx)
require.NoError(t, err)
require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
}