mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 07:37:23 +00:00
fix: ignore header passthrough during channel tests
This commit is contained in:
@@ -171,35 +171,37 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
|
||||
|
||||
passAll := false
|
||||
var passthroughRegex []*regexp.Regexp
|
||||
for k := range info.HeadersOverride {
|
||||
key := strings.TrimSpace(k)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if key == headerPassthroughAllKey {
|
||||
passAll = true
|
||||
continue
|
||||
}
|
||||
if !info.IsChannelTest {
|
||||
for k := range info.HeadersOverride {
|
||||
key := strings.TrimSpace(k)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if key == headerPassthroughAllKey {
|
||||
passAll = true
|
||||
continue
|
||||
}
|
||||
|
||||
lower := strings.ToLower(key)
|
||||
var pattern string
|
||||
switch {
|
||||
case strings.HasPrefix(lower, headerPassthroughRegexPrefix):
|
||||
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
|
||||
case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2):
|
||||
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
|
||||
default:
|
||||
continue
|
||||
}
|
||||
lower := strings.ToLower(key)
|
||||
var pattern string
|
||||
switch {
|
||||
case strings.HasPrefix(lower, headerPassthroughRegexPrefix):
|
||||
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
|
||||
case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2):
|
||||
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
if pattern == "" {
|
||||
return nil, types.NewError(fmt.Errorf("header passthrough regex pattern is empty: %q", k), types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||
if pattern == "" {
|
||||
return nil, types.NewError(fmt.Errorf("header passthrough regex pattern is empty: %q", k), types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||
}
|
||||
compiled, err := getHeaderPassthroughRegex(pattern)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||
}
|
||||
passthroughRegex = append(passthroughRegex, compiled)
|
||||
}
|
||||
compiled, err := getHeaderPassthroughRegex(pattern)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||
}
|
||||
passthroughRegex = append(passthroughRegex, compiled)
|
||||
}
|
||||
|
||||
if passAll || len(passthroughRegex) > 0 {
|
||||
@@ -243,6 +245,9 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
|
||||
if !ok {
|
||||
return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||
}
|
||||
if info.IsChannelTest && strings.HasPrefix(strings.TrimSpace(str), clientHeaderPlaceholderPrefix) {
|
||||
continue
|
||||
}
|
||||
|
||||
value, include, err := applyHeaderOverridePlaceholders(str, c, info.ApiKey)
|
||||
if err != nil {
|
||||
|
||||
81
relay/channel/api_request_test.go
Normal file
81
relay/channel/api_request_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
ctx.Request.Header.Set("X-Trace-Id", "trace-123")
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
IsChannelTest: true,
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
HeadersOverride: map[string]any{
|
||||
"*": "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
headers, err := processHeaderOverride(info, ctx)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, headers)
|
||||
}
|
||||
|
||||
func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
ctx.Request.Header.Set("X-Trace-Id", "trace-123")
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
IsChannelTest: true,
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
HeadersOverride: map[string]any{
|
||||
"X-Upstream-Trace": "{client_header:X-Trace-Id}",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
headers, err := processHeaderOverride(info, ctx)
|
||||
require.NoError(t, err)
|
||||
_, ok := headers["X-Upstream-Trace"]
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
ctx.Request.Header.Set("X-Trace-Id", "trace-123")
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
IsChannelTest: false,
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
HeadersOverride: map[string]any{
|
||||
"X-Upstream-Trace": "{client_header:X-Trace-Id}",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
headers, err := processHeaderOverride(info, ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
|
||||
}
|
||||
Reference in New Issue
Block a user