mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:25:00 +00:00
refactor(override): simplify header overrides to a lowercase single map
This commit is contained in:
@@ -179,7 +179,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
|
|||||||
var passthroughRegex []*regexp.Regexp
|
var passthroughRegex []*regexp.Regexp
|
||||||
if !info.IsChannelTest {
|
if !info.IsChannelTest {
|
||||||
for k := range headerOverrideSource {
|
for k := range headerOverrideSource {
|
||||||
key := strings.TrimSpace(k)
|
key := strings.TrimSpace(strings.ToLower(k))
|
||||||
if key == "" {
|
if key == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -188,12 +188,11 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
lower := strings.ToLower(key)
|
|
||||||
var pattern string
|
var pattern string
|
||||||
switch {
|
switch {
|
||||||
case strings.HasPrefix(lower, headerPassthroughRegexPrefix):
|
case strings.HasPrefix(key, headerPassthroughRegexPrefix):
|
||||||
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
|
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
|
||||||
case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2):
|
case strings.HasPrefix(key, headerPassthroughRegexPrefixV2):
|
||||||
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
|
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
|
||||||
default:
|
default:
|
||||||
continue
|
continue
|
||||||
@@ -234,7 +233,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
|
|||||||
if value == "" {
|
if value == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
headerOverride[name] = value
|
headerOverride[strings.ToLower(strings.TrimSpace(name))] = value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -242,7 +241,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
|
|||||||
if isHeaderPassthroughRuleKey(k) {
|
if isHeaderPassthroughRuleKey(k) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
key := strings.TrimSpace(k)
|
key := strings.TrimSpace(strings.ToLower(k))
|
||||||
if key == "" {
|
if key == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testin
|
|||||||
|
|
||||||
headers, err := processHeaderOverride(info, ctx)
|
headers, err := processHeaderOverride(info, ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, ok := headers["X-Upstream-Trace"]
|
_, ok := headers["x-upstream-trace"]
|
||||||
require.False(t, ok)
|
require.False(t, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,10 +77,10 @@ func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T)
|
|||||||
|
|
||||||
headers, err := processHeaderOverride(info, ctx)
|
headers, err := processHeaderOverride(info, ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
|
require.Equal(t, "trace-123", headers["x-upstream-trace"])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessHeaderOverride_RuntimeOverrideMergesWithChannelOverride(t *testing.T) {
|
func TestProcessHeaderOverride_RuntimeOverrideIsFinalHeaderMap(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
@@ -92,8 +92,8 @@ func TestProcessHeaderOverride_RuntimeOverrideMergesWithChannelOverride(t *testi
|
|||||||
IsChannelTest: false,
|
IsChannelTest: false,
|
||||||
UseRuntimeHeadersOverride: true,
|
UseRuntimeHeadersOverride: true,
|
||||||
RuntimeHeadersOverride: map[string]any{
|
RuntimeHeadersOverride: map[string]any{
|
||||||
"X-Static": "runtime-value",
|
"x-static": "runtime-value",
|
||||||
"X-Runtime": "runtime-only",
|
"x-runtime": "runtime-only",
|
||||||
},
|
},
|
||||||
ChannelMeta: &relaycommon.ChannelMeta{
|
ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
HeadersOverride: map[string]any{
|
HeadersOverride: map[string]any{
|
||||||
@@ -105,9 +105,10 @@ func TestProcessHeaderOverride_RuntimeOverrideMergesWithChannelOverride(t *testi
|
|||||||
|
|
||||||
headers, err := processHeaderOverride(info, ctx)
|
headers, err := processHeaderOverride(info, ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "runtime-value", headers["X-Static"])
|
require.Equal(t, "runtime-value", headers["x-static"])
|
||||||
require.Equal(t, "runtime-only", headers["X-Runtime"])
|
require.Equal(t, "runtime-only", headers["x-runtime"])
|
||||||
require.Equal(t, "legacy-only", headers["X-Legacy"])
|
_, exists := headers["x-legacy"]
|
||||||
|
require.False(t, exists)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
|
func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
|
||||||
@@ -131,9 +132,9 @@ func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
|
|||||||
|
|
||||||
headers, err := processHeaderOverride(info, ctx)
|
headers, err := processHeaderOverride(info, ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "trace-123", headers["X-Trace-Id"])
|
require.Equal(t, "trace-123", headers["x-trace-id"])
|
||||||
|
|
||||||
_, hasAcceptEncoding := headers["Accept-Encoding"]
|
_, hasAcceptEncoding := headers["accept-encoding"]
|
||||||
require.False(t, hasAcceptEncoding)
|
require.False(t, hasAcceptEncoding)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,16 +172,17 @@ func TestProcessHeaderOverride_PassHeadersTemplateSetsRuntimeHeaders(t *testing.
|
|||||||
_, err := relaycommon.ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-4.1"}`), info)
|
_, err := relaycommon.ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-4.1"}`), info)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.True(t, info.UseRuntimeHeadersOverride)
|
require.True(t, info.UseRuntimeHeadersOverride)
|
||||||
require.Equal(t, "Codex CLI", info.RuntimeHeadersOverride["Originator"])
|
require.Equal(t, "Codex CLI", info.RuntimeHeadersOverride["originator"])
|
||||||
require.Equal(t, "sess-123", info.RuntimeHeadersOverride["Session_id"])
|
require.Equal(t, "sess-123", info.RuntimeHeadersOverride["session_id"])
|
||||||
_, exists := info.RuntimeHeadersOverride["X-Codex-Beta-Features"]
|
_, exists := info.RuntimeHeadersOverride["x-codex-beta-features"]
|
||||||
require.False(t, exists)
|
require.False(t, exists)
|
||||||
|
require.Equal(t, "legacy-value", info.RuntimeHeadersOverride["x-static"])
|
||||||
|
|
||||||
headers, err := processHeaderOverride(info, ctx)
|
headers, err := processHeaderOverride(info, ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "Codex CLI", headers["Originator"])
|
require.Equal(t, "Codex CLI", headers["originator"])
|
||||||
require.Equal(t, "sess-123", headers["Session_id"])
|
require.Equal(t, "sess-123", headers["session_id"])
|
||||||
_, exists = headers["X-Codex-Beta-Features"]
|
_, exists = headers["x-codex-beta-features"]
|
||||||
require.False(t, exists)
|
require.False(t, exists)
|
||||||
|
|
||||||
upstreamReq := httptest.NewRequest(http.MethodPost, "https://example.com/v1/responses", nil)
|
upstreamReq := httptest.NewRequest(http.MethodPost, "https://example.com/v1/responses", nil)
|
||||||
|
|||||||
@@ -18,11 +18,8 @@ import (
|
|||||||
var negativeIndexRegexp = regexp.MustCompile(`\.(-\d+)`)
|
var negativeIndexRegexp = regexp.MustCompile(`\.(-\d+)`)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
paramOverrideContextRequestHeaders = "request_headers"
|
paramOverrideContextRequestHeaders = "request_headers"
|
||||||
paramOverrideContextRequestHeadersRaw = "request_headers_raw"
|
paramOverrideContextHeaderOverride = "header_override"
|
||||||
paramOverrideContextHeaderOverride = "header_override"
|
|
||||||
paramOverrideContextHeaderOverrideNormalized = "header_override_normalized"
|
|
||||||
paramOverrideContextHeaderOverrideDeleted = "header_override_deleted_normalized"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var errSourceHeaderNotFound = errors.New("source header does not exist")
|
var errSourceHeaderNotFound = errors.New("source header does not exist")
|
||||||
@@ -161,141 +158,118 @@ func getHeaderOverrideMap(info *RelayInfo) map[string]interface{} {
|
|||||||
return info.ChannelMeta.HeadersOverride
|
return info.ChannelMeta.HeadersOverride
|
||||||
}
|
}
|
||||||
|
|
||||||
func cloneHeaderOverrideMap(source map[string]interface{}) map[string]interface{} {
|
func sanitizeHeaderOverrideMap(source map[string]interface{}) map[string]interface{} {
|
||||||
if len(source) == 0 {
|
if len(source) == 0 {
|
||||||
return map[string]interface{}{}
|
return map[string]interface{}{}
|
||||||
}
|
}
|
||||||
target := make(map[string]interface{}, len(source))
|
target := make(map[string]interface{}, len(source))
|
||||||
for key, value := range source {
|
for key, value := range source {
|
||||||
target[key] = value
|
normalizedKey := normalizeHeaderContextKey(key)
|
||||||
|
if normalizedKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
normalizedValue := strings.TrimSpace(fmt.Sprintf("%v", value))
|
||||||
|
if normalizedValue == "" {
|
||||||
|
if isHeaderPassthroughRuleKeyForOverride(normalizedKey) {
|
||||||
|
target[normalizedKey] = ""
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
target[normalizedKey] = normalizedValue
|
||||||
}
|
}
|
||||||
return target
|
return target
|
||||||
}
|
}
|
||||||
|
|
||||||
func setHeaderOverrideEntry(target map[string]interface{}, key string, value interface{}) {
|
func isHeaderPassthroughRuleKeyForOverride(key string) bool {
|
||||||
key = strings.TrimSpace(key)
|
key = strings.TrimSpace(strings.ToLower(key))
|
||||||
if key == "" {
|
if key == "" {
|
||||||
return
|
|
||||||
}
|
|
||||||
for existing := range target {
|
|
||||||
if strings.EqualFold(strings.TrimSpace(existing), key) {
|
|
||||||
delete(target, existing)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
target[key] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
func isHeaderDeletedByRuntime(headerName string, deleted map[string]bool) bool {
|
|
||||||
if len(deleted) == 0 {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
normalized := normalizeHeaderContextKey(headerName)
|
if key == "*" {
|
||||||
if normalized == "" {
|
return true
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
return deleted[normalized]
|
return strings.HasPrefix(key, "re:") || strings.HasPrefix(key, "regex:")
|
||||||
}
|
|
||||||
|
|
||||||
func mergeHeaderOverrideSource(base, runtime map[string]interface{}, deleted map[string]bool) map[string]interface{} {
|
|
||||||
merged := make(map[string]interface{}, len(base)+len(runtime))
|
|
||||||
for key, value := range base {
|
|
||||||
if isHeaderDeletedByRuntime(key, deleted) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
setHeaderOverrideEntry(merged, key, value)
|
|
||||||
}
|
|
||||||
for key, value := range runtime {
|
|
||||||
setHeaderOverrideEntry(merged, key, value)
|
|
||||||
}
|
|
||||||
return merged
|
|
||||||
}
|
|
||||||
|
|
||||||
func cloneDeletedHeaderKeys(source map[string]bool) map[string]bool {
|
|
||||||
if len(source) == 0 {
|
|
||||||
return map[string]bool{}
|
|
||||||
}
|
|
||||||
target := make(map[string]bool, len(source))
|
|
||||||
for key, value := range source {
|
|
||||||
if !value {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
normalized := normalizeHeaderContextKey(key)
|
|
||||||
if normalized == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
target[normalized] = true
|
|
||||||
}
|
|
||||||
return target
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetEffectiveHeaderOverride(info *RelayInfo) map[string]interface{} {
|
func GetEffectiveHeaderOverride(info *RelayInfo) map[string]interface{} {
|
||||||
if info == nil {
|
if info == nil {
|
||||||
return map[string]interface{}{}
|
return map[string]interface{}{}
|
||||||
}
|
}
|
||||||
base := getHeaderOverrideMap(info)
|
if info.UseRuntimeHeadersOverride {
|
||||||
if !info.UseRuntimeHeadersOverride {
|
return sanitizeHeaderOverrideMap(info.RuntimeHeadersOverride)
|
||||||
return cloneHeaderOverrideMap(base)
|
|
||||||
}
|
}
|
||||||
return mergeHeaderOverrideSource(base, info.RuntimeHeadersOverride, cloneDeletedHeaderKeys(info.RuntimeHeadersDeletedNormalized))
|
return sanitizeHeaderOverrideMap(getHeaderOverrideMap(info))
|
||||||
}
|
}
|
||||||
|
|
||||||
func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
|
func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
|
||||||
// 检查是否包含 "operations" 字段
|
// 检查是否包含 "operations" 字段
|
||||||
if opsValue, exists := paramOverride["operations"]; exists {
|
opsValue, exists := paramOverride["operations"]
|
||||||
if opsSlice, ok := opsValue.([]interface{}); ok {
|
if !exists {
|
||||||
var operations []ParamOperation
|
return nil, false
|
||||||
for _, op := range opsSlice {
|
|
||||||
if opMap, ok := op.(map[string]interface{}); ok {
|
|
||||||
operation := ParamOperation{}
|
|
||||||
|
|
||||||
// 断言必要字段
|
|
||||||
if path, ok := opMap["path"].(string); ok {
|
|
||||||
operation.Path = path
|
|
||||||
}
|
|
||||||
if mode, ok := opMap["mode"].(string); ok {
|
|
||||||
operation.Mode = mode
|
|
||||||
} else {
|
|
||||||
return nil, false // mode 是必需的
|
|
||||||
}
|
|
||||||
|
|
||||||
// 可选字段
|
|
||||||
if value, exists := opMap["value"]; exists {
|
|
||||||
operation.Value = value
|
|
||||||
}
|
|
||||||
if keepOrigin, ok := opMap["keep_origin"].(bool); ok {
|
|
||||||
operation.KeepOrigin = keepOrigin
|
|
||||||
}
|
|
||||||
if from, ok := opMap["from"].(string); ok {
|
|
||||||
operation.From = from
|
|
||||||
}
|
|
||||||
if to, ok := opMap["to"].(string); ok {
|
|
||||||
operation.To = to
|
|
||||||
}
|
|
||||||
if logic, ok := opMap["logic"].(string); ok {
|
|
||||||
operation.Logic = logic
|
|
||||||
} else {
|
|
||||||
operation.Logic = "OR" // 默认为OR
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析条件
|
|
||||||
if conditions, exists := opMap["conditions"]; exists {
|
|
||||||
parsedConditions, err := parseConditionOperations(conditions)
|
|
||||||
if err != nil {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
operation.Conditions = append(operation.Conditions, parsedConditions...)
|
|
||||||
}
|
|
||||||
|
|
||||||
operations = append(operations, operation)
|
|
||||||
} else {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return operations, true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, false
|
var opMaps []map[string]interface{}
|
||||||
|
switch ops := opsValue.(type) {
|
||||||
|
case []interface{}:
|
||||||
|
opMaps = make([]map[string]interface{}, 0, len(ops))
|
||||||
|
for _, op := range ops {
|
||||||
|
opMap, ok := op.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
opMaps = append(opMaps, opMap)
|
||||||
|
}
|
||||||
|
case []map[string]interface{}:
|
||||||
|
opMaps = ops
|
||||||
|
default:
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
operations := make([]ParamOperation, 0, len(opMaps))
|
||||||
|
for _, opMap := range opMaps {
|
||||||
|
operation := ParamOperation{}
|
||||||
|
|
||||||
|
// 断言必要字段
|
||||||
|
if path, ok := opMap["path"].(string); ok {
|
||||||
|
operation.Path = path
|
||||||
|
}
|
||||||
|
if mode, ok := opMap["mode"].(string); ok {
|
||||||
|
operation.Mode = mode
|
||||||
|
} else {
|
||||||
|
return nil, false // mode 是必需的
|
||||||
|
}
|
||||||
|
|
||||||
|
// 可选字段
|
||||||
|
if value, exists := opMap["value"]; exists {
|
||||||
|
operation.Value = value
|
||||||
|
}
|
||||||
|
if keepOrigin, ok := opMap["keep_origin"].(bool); ok {
|
||||||
|
operation.KeepOrigin = keepOrigin
|
||||||
|
}
|
||||||
|
if from, ok := opMap["from"].(string); ok {
|
||||||
|
operation.From = from
|
||||||
|
}
|
||||||
|
if to, ok := opMap["to"].(string); ok {
|
||||||
|
operation.To = to
|
||||||
|
}
|
||||||
|
if logic, ok := opMap["logic"].(string); ok {
|
||||||
|
operation.Logic = logic
|
||||||
|
} else {
|
||||||
|
operation.Logic = "OR" // 默认为OR
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析条件
|
||||||
|
if conditions, exists := opMap["conditions"]; exists {
|
||||||
|
parsedConditions, err := parseConditionOperations(conditions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
operation.Conditions = append(operation.Conditions, parsedConditions...)
|
||||||
|
}
|
||||||
|
|
||||||
|
operations = append(operations, operation)
|
||||||
|
}
|
||||||
|
return operations, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
|
func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
|
||||||
@@ -712,15 +686,10 @@ func marshalContextJSON(context map[string]interface{}) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func setHeaderOverrideInContext(context map[string]interface{}, headerName string, value interface{}, keepOrigin bool) error {
|
func setHeaderOverrideInContext(context map[string]interface{}, headerName string, value interface{}, keepOrigin bool) error {
|
||||||
headerName = strings.TrimSpace(headerName)
|
headerName = normalizeHeaderContextKey(headerName)
|
||||||
if headerName == "" {
|
if headerName == "" {
|
||||||
return fmt.Errorf("header name is required")
|
return fmt.Errorf("header name is required")
|
||||||
}
|
}
|
||||||
if keepOrigin {
|
|
||||||
if _, exists := getHeaderValueFromContext(context, headerName); exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if value == nil {
|
if value == nil {
|
||||||
return fmt.Errorf("header value is required")
|
return fmt.Errorf("header value is required")
|
||||||
}
|
}
|
||||||
@@ -730,21 +699,21 @@ func setHeaderOverrideInContext(context map[string]interface{}, headerName strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride)
|
rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride)
|
||||||
rawHeaders[headerName] = headerValue
|
if keepOrigin {
|
||||||
|
if existing, ok := rawHeaders[headerName]; ok {
|
||||||
normalizedHeaderName := normalizeHeaderContextKey(headerName)
|
existingValue := strings.TrimSpace(fmt.Sprintf("%v", existing))
|
||||||
normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
|
if existingValue != "" {
|
||||||
normalizedHeaders[normalizedHeaderName] = headerValue
|
return nil
|
||||||
if normalizedHeaderName != "" {
|
}
|
||||||
deletedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideDeleted)
|
}
|
||||||
delete(deletedHeaders, normalizedHeaderName)
|
|
||||||
}
|
}
|
||||||
|
rawHeaders[headerName] = headerValue
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func copyHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error {
|
func copyHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error {
|
||||||
fromHeader = strings.TrimSpace(fromHeader)
|
fromHeader = normalizeHeaderContextKey(fromHeader)
|
||||||
toHeader = strings.TrimSpace(toHeader)
|
toHeader = normalizeHeaderContextKey(toHeader)
|
||||||
if fromHeader == "" || toHeader == "" {
|
if fromHeader == "" || toHeader == "" {
|
||||||
return fmt.Errorf("copy_header from/to is required")
|
return fmt.Errorf("copy_header from/to is required")
|
||||||
}
|
}
|
||||||
@@ -756,8 +725,8 @@ func copyHeaderInContext(context map[string]interface{}, fromHeader, toHeader st
|
|||||||
}
|
}
|
||||||
|
|
||||||
func moveHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error {
|
func moveHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error {
|
||||||
fromHeader = strings.TrimSpace(fromHeader)
|
fromHeader = normalizeHeaderContextKey(fromHeader)
|
||||||
toHeader = strings.TrimSpace(toHeader)
|
toHeader = normalizeHeaderContextKey(toHeader)
|
||||||
if fromHeader == "" || toHeader == "" {
|
if fromHeader == "" || toHeader == "" {
|
||||||
return fmt.Errorf("move_header from/to is required")
|
return fmt.Errorf("move_header from/to is required")
|
||||||
}
|
}
|
||||||
@@ -771,31 +740,19 @@ func moveHeaderInContext(context map[string]interface{}, fromHeader, toHeader st
|
|||||||
}
|
}
|
||||||
|
|
||||||
func deleteHeaderOverrideInContext(context map[string]interface{}, headerName string) error {
|
func deleteHeaderOverrideInContext(context map[string]interface{}, headerName string) error {
|
||||||
headerName = strings.TrimSpace(headerName)
|
headerName = normalizeHeaderContextKey(headerName)
|
||||||
if headerName == "" {
|
if headerName == "" {
|
||||||
return fmt.Errorf("header name is required")
|
return fmt.Errorf("header name is required")
|
||||||
}
|
}
|
||||||
rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride)
|
rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride)
|
||||||
for key := range rawHeaders {
|
delete(rawHeaders, headerName)
|
||||||
if strings.EqualFold(strings.TrimSpace(key), headerName) {
|
|
||||||
delete(rawHeaders, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
|
|
||||||
normalizedHeaderName := normalizeHeaderContextKey(headerName)
|
|
||||||
delete(normalizedHeaders, normalizedHeaderName)
|
|
||||||
if normalizedHeaderName != "" {
|
|
||||||
deletedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideDeleted)
|
|
||||||
deletedHeaders[normalizedHeaderName] = true
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseHeaderPassThroughNames(value interface{}) ([]string, error) {
|
func parseHeaderPassThroughNames(value interface{}) ([]string, error) {
|
||||||
normalizeNames := func(values []string) []string {
|
normalizeNames := func(values []string) []string {
|
||||||
names := lo.FilterMap(values, func(item string, _ int) (string, bool) {
|
names := lo.FilterMap(values, func(item string, _ int) (string, bool) {
|
||||||
headerName := strings.TrimSpace(item)
|
headerName := normalizeHeaderContextKey(item)
|
||||||
if headerName == "" {
|
if headerName == "" {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
@@ -825,7 +782,20 @@ func parseHeaderPassThroughNames(value interface{}) ([]string, error) {
|
|||||||
return names, nil
|
return names, nil
|
||||||
case []interface{}:
|
case []interface{}:
|
||||||
names := lo.FilterMap(raw, func(item interface{}, _ int) (string, bool) {
|
names := lo.FilterMap(raw, func(item interface{}, _ int) (string, bool) {
|
||||||
headerName := strings.TrimSpace(fmt.Sprintf("%v", item))
|
headerName := normalizeHeaderContextKey(fmt.Sprintf("%v", item))
|
||||||
|
if headerName == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return headerName, true
|
||||||
|
})
|
||||||
|
names = lo.Uniq(names)
|
||||||
|
if len(names) == 0 {
|
||||||
|
return nil, fmt.Errorf("pass_headers value is invalid")
|
||||||
|
}
|
||||||
|
return names, nil
|
||||||
|
case []string:
|
||||||
|
names := lo.FilterMap(raw, func(item string, _ int) (string, bool) {
|
||||||
|
headerName := normalizeHeaderContextKey(item)
|
||||||
if headerName == "" {
|
if headerName == "" {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
@@ -994,76 +964,29 @@ func ensureMapKeyInContext(context map[string]interface{}, key string) map[strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getHeaderValueFromContext(context map[string]interface{}, headerName string) (string, bool) {
|
func getHeaderValueFromContext(context map[string]interface{}, headerName string) (string, bool) {
|
||||||
headerName = strings.TrimSpace(headerName)
|
headerName = normalizeHeaderContextKey(headerName)
|
||||||
if headerName == "" {
|
if headerName == "" {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextHeaderOverride), headerName); ok {
|
for _, key := range []string{paramOverrideContextHeaderOverride, paramOverrideContextRequestHeaders} {
|
||||||
return value, true
|
source := ensureMapKeyInContext(context, key)
|
||||||
}
|
raw, ok := source[headerName]
|
||||||
if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextRequestHeadersRaw), headerName); ok {
|
if !ok {
|
||||||
return value, true
|
continue
|
||||||
}
|
}
|
||||||
|
value := strings.TrimSpace(fmt.Sprintf("%v", raw))
|
||||||
normalizedName := normalizeHeaderContextKey(headerName)
|
if value != "" {
|
||||||
if normalizedName == "" {
|
return value, true
|
||||||
return "", false
|
}
|
||||||
}
|
|
||||||
if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized), normalizedName); ok {
|
|
||||||
return value, true
|
|
||||||
}
|
|
||||||
if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextRequestHeaders), normalizedName); ok {
|
|
||||||
return value, true
|
|
||||||
}
|
}
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
func findHeaderValueInMap(source map[string]interface{}, key string) (string, bool) {
|
|
||||||
if len(source) == 0 {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
entries := lo.Entries(source)
|
|
||||||
entry, ok := lo.Find(entries, func(item lo.Entry[string, interface{}]) bool {
|
|
||||||
return strings.EqualFold(strings.TrimSpace(item.Key), key)
|
|
||||||
})
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
value := strings.TrimSpace(fmt.Sprintf("%v", entry.Value))
|
|
||||||
if value == "" {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
return value, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeHeaderContextKey(key string) string {
|
func normalizeHeaderContextKey(key string) string {
|
||||||
key = strings.TrimSpace(strings.ToLower(key))
|
return strings.TrimSpace(strings.ToLower(key))
|
||||||
if key == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
var b strings.Builder
|
|
||||||
b.Grow(len(key))
|
|
||||||
previousUnderscore := false
|
|
||||||
for _, r := range key {
|
|
||||||
switch {
|
|
||||||
case r >= 'a' && r <= 'z':
|
|
||||||
b.WriteRune(r)
|
|
||||||
previousUnderscore = false
|
|
||||||
case r >= '0' && r <= '9':
|
|
||||||
b.WriteRune(r)
|
|
||||||
previousUnderscore = false
|
|
||||||
default:
|
|
||||||
if !previousUnderscore {
|
|
||||||
b.WriteByte('_')
|
|
||||||
previousUnderscore = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result := strings.Trim(b.String(), "_")
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildNormalizedHeaders(headers map[string]string) map[string]interface{} {
|
func buildRequestHeadersContext(headers map[string]string) map[string]interface{} {
|
||||||
if len(headers) == 0 {
|
if len(headers) == 0 {
|
||||||
return map[string]interface{}{}
|
return map[string]interface{}{}
|
||||||
}
|
}
|
||||||
@@ -1081,54 +1004,6 @@ func buildNormalizedHeaders(headers map[string]string) map[string]interface{} {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildRawHeaders(headers map[string]string) map[string]interface{} {
|
|
||||||
if len(headers) == 0 {
|
|
||||||
return map[string]interface{}{}
|
|
||||||
}
|
|
||||||
entries := lo.Entries(headers)
|
|
||||||
rawEntries := lo.FilterMap(entries, func(item lo.Entry[string, string], _ int) (lo.Entry[string, string], bool) {
|
|
||||||
key := strings.TrimSpace(item.Key)
|
|
||||||
value := strings.TrimSpace(item.Value)
|
|
||||||
if key == "" || value == "" {
|
|
||||||
return lo.Entry[string, string]{}, false
|
|
||||||
}
|
|
||||||
return lo.Entry[string, string]{Key: key, Value: value}, true
|
|
||||||
})
|
|
||||||
return lo.SliceToMap(rawEntries, func(item lo.Entry[string, string]) (string, interface{}) {
|
|
||||||
return item.Key, item.Value
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildHeaderOverrideContext(headers map[string]interface{}) (map[string]interface{}, map[string]interface{}) {
|
|
||||||
if len(headers) == 0 {
|
|
||||||
return map[string]interface{}{}, map[string]interface{}{}
|
|
||||||
}
|
|
||||||
entries := lo.Entries(headers)
|
|
||||||
rawEntries := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (lo.Entry[string, string], bool) {
|
|
||||||
key := strings.TrimSpace(item.Key)
|
|
||||||
value := strings.TrimSpace(fmt.Sprintf("%v", item.Value))
|
|
||||||
if key == "" || value == "" {
|
|
||||||
return lo.Entry[string, string]{}, false
|
|
||||||
}
|
|
||||||
return lo.Entry[string, string]{Key: key, Value: value}, true
|
|
||||||
})
|
|
||||||
|
|
||||||
raw := lo.SliceToMap(rawEntries, func(item lo.Entry[string, string]) (string, interface{}) {
|
|
||||||
return item.Key, item.Value
|
|
||||||
})
|
|
||||||
normalizedEntries := lo.FilterMap(rawEntries, func(item lo.Entry[string, string], _ int) (lo.Entry[string, string], bool) {
|
|
||||||
normalized := normalizeHeaderContextKey(item.Key)
|
|
||||||
if normalized == "" {
|
|
||||||
return lo.Entry[string, string]{}, false
|
|
||||||
}
|
|
||||||
return lo.Entry[string, string]{Key: normalized, Value: item.Value}, true
|
|
||||||
})
|
|
||||||
normalized := lo.SliceToMap(normalizedEntries, func(item lo.Entry[string, string]) (string, interface{}) {
|
|
||||||
return item.Key, item.Value
|
|
||||||
})
|
|
||||||
return raw, normalized
|
|
||||||
}
|
|
||||||
|
|
||||||
func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]interface{}) {
|
func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]interface{}) {
|
||||||
if info == nil || context == nil {
|
if info == nil || context == nil {
|
||||||
return
|
return
|
||||||
@@ -1141,55 +1016,10 @@ func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]in
|
|||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
info.RuntimeHeadersOverride = sanitizeHeaderOverrideMap(rawMap)
|
||||||
entries := lo.Entries(rawMap)
|
|
||||||
sanitized := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (lo.Entry[string, interface{}], bool) {
|
|
||||||
key := strings.TrimSpace(item.Key)
|
|
||||||
if key == "" {
|
|
||||||
return lo.Entry[string, interface{}]{}, false
|
|
||||||
}
|
|
||||||
value := strings.TrimSpace(fmt.Sprintf("%v", item.Value))
|
|
||||||
if value == "" {
|
|
||||||
return lo.Entry[string, interface{}]{}, false
|
|
||||||
}
|
|
||||||
return lo.Entry[string, interface{}]{Key: key, Value: value}, true
|
|
||||||
})
|
|
||||||
info.RuntimeHeadersOverride = lo.SliceToMap(sanitized, func(item lo.Entry[string, interface{}]) (string, interface{}) {
|
|
||||||
return item.Key, item.Value
|
|
||||||
})
|
|
||||||
info.RuntimeHeadersDeletedNormalized = sanitizeRuntimeDeletedHeadersFromContext(context)
|
|
||||||
info.UseRuntimeHeadersOverride = true
|
info.UseRuntimeHeadersOverride = true
|
||||||
}
|
}
|
||||||
|
|
||||||
func sanitizeRuntimeDeletedHeadersFromContext(context map[string]interface{}) map[string]bool {
|
|
||||||
deletedRaw, exists := context[paramOverrideContextHeaderOverrideDeleted]
|
|
||||||
if !exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
deletedMap, ok := deletedRaw.(map[string]interface{})
|
|
||||||
if !ok || len(deletedMap) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
entries := lo.Entries(deletedMap)
|
|
||||||
sanitized := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (string, bool) {
|
|
||||||
if keep, ok := item.Value.(bool); ok && !keep {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
normalized := normalizeHeaderContextKey(item.Key)
|
|
||||||
if normalized == "" {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
return normalized, true
|
|
||||||
})
|
|
||||||
if len(sanitized) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
keys := lo.Uniq(sanitized)
|
|
||||||
return lo.SliceToMap(keys, func(item string) (string, bool) {
|
|
||||||
return item, true
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
|
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
|
||||||
sourceValue := gjson.Get(jsonStr, fromPath)
|
sourceValue := gjson.Get(jsonStr, fromPath)
|
||||||
if !sourceValue.Exists() {
|
if !sourceValue.Exists() {
|
||||||
@@ -1635,16 +1465,10 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx[paramOverrideContextRequestHeaders] = buildNormalizedHeaders(info.RequestHeaders)
|
ctx[paramOverrideContextRequestHeaders] = buildRequestHeadersContext(info.RequestHeaders)
|
||||||
ctx[paramOverrideContextRequestHeadersRaw] = buildRawHeaders(info.RequestHeaders)
|
|
||||||
|
|
||||||
headerOverrideSource := GetEffectiveHeaderOverride(info)
|
headerOverrideSource := GetEffectiveHeaderOverride(info)
|
||||||
rawHeaderOverride, normalizedHeaderOverride := buildHeaderOverrideContext(headerOverrideSource)
|
ctx[paramOverrideContextHeaderOverride] = sanitizeHeaderOverrideMap(headerOverrideSource)
|
||||||
ctx[paramOverrideContextHeaderOverride] = rawHeaderOverride
|
|
||||||
ctx[paramOverrideContextHeaderOverrideNormalized] = normalizedHeaderOverride
|
|
||||||
ctx[paramOverrideContextHeaderOverrideDeleted] = lo.SliceToMap(lo.Keys(cloneDeletedHeaderKeys(info.RuntimeHeadersDeletedNormalized)), func(item string) (string, interface{}) {
|
|
||||||
return item, true
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx["retry_index"] = info.RetryIndex
|
ctx["retry_index"] = info.RetryIndex
|
||||||
ctx["is_retry"] = info.RetryIndex > 0
|
ctx["is_retry"] = info.RetryIndex > 0
|
||||||
|
|||||||
@@ -1005,7 +1005,7 @@ func TestApplyParamOverrideSetHeaderAndUseInLaterCondition(t *testing.T) {
|
|||||||
"value": 0.1,
|
"value": 0.1,
|
||||||
"conditions": []interface{}{
|
"conditions": []interface{}{
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"path": "header_override_normalized.x_debug_mode",
|
"path": "header_override.x-debug-mode",
|
||||||
"mode": "full",
|
"mode": "full",
|
||||||
"value": "enabled",
|
"value": "enabled",
|
||||||
},
|
},
|
||||||
@@ -1036,7 +1036,7 @@ func TestApplyParamOverrideCopyHeaderFromRequestHeaders(t *testing.T) {
|
|||||||
"value": 0.1,
|
"value": 0.1,
|
||||||
"conditions": []interface{}{
|
"conditions": []interface{}{
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"path": "header_override_normalized.x_upstream_auth",
|
"path": "header_override.x-upstream-auth",
|
||||||
"mode": "contains",
|
"mode": "contains",
|
||||||
"value": "Bearer ",
|
"value": "Bearer ",
|
||||||
},
|
},
|
||||||
@@ -1045,9 +1045,6 @@ func TestApplyParamOverrideCopyHeaderFromRequestHeaders(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
ctx := map[string]interface{}{
|
ctx := map[string]interface{}{
|
||||||
"request_headers_raw": map[string]interface{}{
|
|
||||||
"Authorization": "Bearer token-123",
|
|
||||||
},
|
|
||||||
"request_headers": map[string]interface{}{
|
"request_headers": map[string]interface{}{
|
||||||
"authorization": "Bearer token-123",
|
"authorization": "Bearer token-123",
|
||||||
},
|
},
|
||||||
@@ -1071,9 +1068,6 @@ func TestApplyParamOverridePassHeadersSkipsMissingHeaders(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
ctx := map[string]interface{}{
|
ctx := map[string]interface{}{
|
||||||
"request_headers_raw": map[string]interface{}{
|
|
||||||
"Session_id": "sess-123",
|
|
||||||
},
|
|
||||||
"request_headers": map[string]interface{}{
|
"request_headers": map[string]interface{}{
|
||||||
"session_id": "sess-123",
|
"session_id": "sess-123",
|
||||||
},
|
},
|
||||||
@@ -1089,10 +1083,10 @@ func TestApplyParamOverridePassHeadersSkipsMissingHeaders(t *testing.T) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("expected header_override context map")
|
t.Fatalf("expected header_override context map")
|
||||||
}
|
}
|
||||||
if headers["Session_id"] != "sess-123" {
|
if headers["session_id"] != "sess-123" {
|
||||||
t.Fatalf("expected Session_id to be passed, got: %v", headers["Session_id"])
|
t.Fatalf("expected session_id to be passed, got: %v", headers["session_id"])
|
||||||
}
|
}
|
||||||
if _, exists := headers["X-Codex-Beta-Features"]; exists {
|
if _, exists := headers["x-codex-beta-features"]; exists {
|
||||||
t.Fatalf("expected missing header to be skipped")
|
t.Fatalf("expected missing header to be skipped")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1109,9 +1103,6 @@ func TestApplyParamOverrideCopyHeaderSkipsMissingSource(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
ctx := map[string]interface{}{
|
ctx := map[string]interface{}{
|
||||||
"request_headers_raw": map[string]interface{}{
|
|
||||||
"Authorization": "Bearer token-123",
|
|
||||||
},
|
|
||||||
"request_headers": map[string]interface{}{
|
"request_headers": map[string]interface{}{
|
||||||
"authorization": "Bearer token-123",
|
"authorization": "Bearer token-123",
|
||||||
},
|
},
|
||||||
@@ -1127,7 +1118,7 @@ func TestApplyParamOverrideCopyHeaderSkipsMissingSource(t *testing.T) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, exists := headers["X-Upstream-Auth"]; exists {
|
if _, exists := headers["x-upstream-auth"]; exists {
|
||||||
t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing")
|
t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1144,9 +1135,6 @@ func TestApplyParamOverrideMoveHeaderSkipsMissingSource(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
ctx := map[string]interface{}{
|
ctx := map[string]interface{}{
|
||||||
"request_headers_raw": map[string]interface{}{
|
|
||||||
"Authorization": "Bearer token-123",
|
|
||||||
},
|
|
||||||
"request_headers": map[string]interface{}{
|
"request_headers": map[string]interface{}{
|
||||||
"authorization": "Bearer token-123",
|
"authorization": "Bearer token-123",
|
||||||
},
|
},
|
||||||
@@ -1162,7 +1150,7 @@ func TestApplyParamOverrideMoveHeaderSkipsMissingSource(t *testing.T) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, exists := headers["X-Upstream-Auth"]; exists {
|
if _, exists := headers["x-upstream-auth"]; exists {
|
||||||
t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing")
|
t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1179,9 +1167,6 @@ func TestApplyParamOverrideSyncFieldsHeaderToJSON(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
ctx := map[string]interface{}{
|
ctx := map[string]interface{}{
|
||||||
"request_headers_raw": map[string]interface{}{
|
|
||||||
"session_id": "sess-123",
|
|
||||||
},
|
|
||||||
"request_headers": map[string]interface{}{
|
"request_headers": map[string]interface{}{
|
||||||
"session_id": "sess-123",
|
"session_id": "sess-123",
|
||||||
},
|
},
|
||||||
@@ -1234,9 +1219,6 @@ func TestApplyParamOverrideSyncFieldsNoChangeWhenBothExist(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
ctx := map[string]interface{}{
|
ctx := map[string]interface{}{
|
||||||
"request_headers_raw": map[string]interface{}{
|
|
||||||
"session_id": "cache-header",
|
|
||||||
},
|
|
||||||
"request_headers": map[string]interface{}{
|
"request_headers": map[string]interface{}{
|
||||||
"session_id": "cache-header",
|
"session_id": "cache-header",
|
||||||
},
|
},
|
||||||
@@ -1288,10 +1270,7 @@ func TestApplyParamOverrideSetHeaderKeepOrigin(t *testing.T) {
|
|||||||
}
|
}
|
||||||
ctx := map[string]interface{}{
|
ctx := map[string]interface{}{
|
||||||
"header_override": map[string]interface{}{
|
"header_override": map[string]interface{}{
|
||||||
"X-Feature-Flag": "legacy-value",
|
"x-feature-flag": "legacy-value",
|
||||||
},
|
|
||||||
"header_override_normalized": map[string]interface{}{
|
|
||||||
"x_feature_flag": "legacy-value",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1303,8 +1282,8 @@ func TestApplyParamOverrideSetHeaderKeepOrigin(t *testing.T) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("expected header_override context map")
|
t.Fatalf("expected header_override context map")
|
||||||
}
|
}
|
||||||
if headers["X-Feature-Flag"] != "legacy-value" {
|
if headers["x-feature-flag"] != "legacy-value" {
|
||||||
t.Fatalf("expected keep_origin to preserve old value, got: %v", headers["X-Feature-Flag"])
|
t.Fatalf("expected keep_origin to preserve old value, got: %v", headers["x-feature-flag"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1371,14 +1350,14 @@ func TestApplyParamOverrideWithRelayInfoSyncRuntimeHeaders(t *testing.T) {
|
|||||||
if !info.UseRuntimeHeadersOverride {
|
if !info.UseRuntimeHeadersOverride {
|
||||||
t.Fatalf("expected runtime header override to be enabled")
|
t.Fatalf("expected runtime header override to be enabled")
|
||||||
}
|
}
|
||||||
if info.RuntimeHeadersOverride["X-Keep-Me"] != "keep" {
|
if info.RuntimeHeadersOverride["x-keep-me"] != "keep" {
|
||||||
t.Fatalf("expected X-Keep-Me header to be preserved, got: %v", info.RuntimeHeadersOverride["X-Keep-Me"])
|
t.Fatalf("expected x-keep-me header to be preserved, got: %v", info.RuntimeHeadersOverride["x-keep-me"])
|
||||||
}
|
}
|
||||||
if info.RuntimeHeadersOverride["X-Injected-By-Param-Override"] != "enabled" {
|
if info.RuntimeHeadersOverride["x-injected-by-param-override"] != "enabled" {
|
||||||
t.Fatalf("expected X-Injected-By-Param-Override header to be set, got: %v", info.RuntimeHeadersOverride["X-Injected-By-Param-Override"])
|
t.Fatalf("expected x-injected-by-param-override header to be set, got: %v", info.RuntimeHeadersOverride["x-injected-by-param-override"])
|
||||||
}
|
}
|
||||||
if _, exists := info.RuntimeHeadersOverride["X-Delete-Me"]; exists {
|
if _, exists := info.RuntimeHeadersOverride["x-delete-me"]; exists {
|
||||||
t.Fatalf("expected X-Delete-Me header to be deleted")
|
t.Fatalf("expected x-delete-me header to be deleted")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1410,25 +1389,22 @@ func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err)
|
t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err)
|
||||||
}
|
}
|
||||||
if _, exists := info.RuntimeHeadersOverride["X-Legacy-Trace"]; exists {
|
if _, exists := info.RuntimeHeadersOverride["x-legacy-trace"]; exists {
|
||||||
t.Fatalf("expected source header to be removed after move")
|
t.Fatalf("expected source header to be removed after move")
|
||||||
}
|
}
|
||||||
if info.RuntimeHeadersOverride["X-Trace"] != "trace-123" {
|
if info.RuntimeHeadersOverride["x-trace"] != "trace-123" {
|
||||||
t.Fatalf("expected X-Trace to be set, got: %v", info.RuntimeHeadersOverride["X-Trace"])
|
t.Fatalf("expected x-trace to be set, got: %v", info.RuntimeHeadersOverride["x-trace"])
|
||||||
}
|
}
|
||||||
if info.RuntimeHeadersOverride["X-Trace-Backup"] != "trace-123" {
|
if info.RuntimeHeadersOverride["x-trace-backup"] != "trace-123" {
|
||||||
t.Fatalf("expected X-Trace-Backup to be copied, got: %v", info.RuntimeHeadersOverride["X-Trace-Backup"])
|
t.Fatalf("expected x-trace-backup to be copied, got: %v", info.RuntimeHeadersOverride["x-trace-backup"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetEffectiveHeaderOverrideMergesRuntimeAndChannelOverrides(t *testing.T) {
|
func TestGetEffectiveHeaderOverrideUsesRuntimeOverrideAsFinalResult(t *testing.T) {
|
||||||
info := &RelayInfo{
|
info := &RelayInfo{
|
||||||
UseRuntimeHeadersOverride: true,
|
UseRuntimeHeadersOverride: true,
|
||||||
RuntimeHeadersOverride: map[string]interface{}{
|
RuntimeHeadersOverride: map[string]interface{}{
|
||||||
"X-Runtime": "runtime-only",
|
"x-runtime": "runtime-only",
|
||||||
},
|
|
||||||
RuntimeHeadersDeletedNormalized: map[string]bool{
|
|
||||||
"x-deleted": true,
|
|
||||||
},
|
},
|
||||||
ChannelMeta: &ChannelMeta{
|
ChannelMeta: &ChannelMeta{
|
||||||
HeadersOverride: map[string]interface{}{
|
HeadersOverride: map[string]interface{}{
|
||||||
@@ -1439,14 +1415,11 @@ func TestGetEffectiveHeaderOverrideMergesRuntimeAndChannelOverrides(t *testing.T
|
|||||||
}
|
}
|
||||||
|
|
||||||
effective := GetEffectiveHeaderOverride(info)
|
effective := GetEffectiveHeaderOverride(info)
|
||||||
if effective["X-Static"] != "static-value" {
|
if effective["x-runtime"] != "runtime-only" {
|
||||||
t.Fatalf("expected X-Static from channel override, got: %v", effective["X-Static"])
|
t.Fatalf("expected x-runtime from runtime override, got: %v", effective["x-runtime"])
|
||||||
}
|
}
|
||||||
if effective["X-Runtime"] != "runtime-only" {
|
if _, exists := effective["x-static"]; exists {
|
||||||
t.Fatalf("expected X-Runtime from runtime override, got: %v", effective["X-Runtime"])
|
t.Fatalf("expected runtime override to be final and not merge channel headers")
|
||||||
}
|
|
||||||
if _, exists := effective["X-Deleted"]; exists {
|
|
||||||
t.Fatalf("expected deleted headers to stay deleted in effective override")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -148,7 +148,6 @@ type RelayInfo struct {
|
|||||||
RetryIndex int
|
RetryIndex int
|
||||||
LastError *types.NewAPIError
|
LastError *types.NewAPIError
|
||||||
RuntimeHeadersOverride map[string]interface{}
|
RuntimeHeadersOverride map[string]interface{}
|
||||||
RuntimeHeadersDeletedNormalized map[string]bool
|
|
||||||
UseRuntimeHeadersOverride bool
|
UseRuntimeHeadersOverride bool
|
||||||
|
|
||||||
PriceData types.PriceData
|
PriceData types.PriceData
|
||||||
|
|||||||
@@ -1,9 +1,15 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
|
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -67,3 +73,73 @@ func TestApplyChannelAffinityOverrideTemplate_MergeTemplate(t *testing.T) {
|
|||||||
require.Equal(t, "rule-with-template", overrideInfo["rule_name"])
|
require.Equal(t, "rule-with-template", overrideInfo["rule_name"])
|
||||||
require.EqualValues(t, 2, overrideInfo["param_override_keys"])
|
require.EqualValues(t, 2, overrideInfo["param_override_keys"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestChannelAffinityHitCodexTemplatePassHeadersEffective(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
setting := operation_setting.GetChannelAffinitySetting()
|
||||||
|
require.NotNil(t, setting)
|
||||||
|
|
||||||
|
var codexRule *operation_setting.ChannelAffinityRule
|
||||||
|
for i := range setting.Rules {
|
||||||
|
rule := &setting.Rules[i]
|
||||||
|
if strings.EqualFold(strings.TrimSpace(rule.Name), "codex cli trace") {
|
||||||
|
codexRule = rule
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, codexRule)
|
||||||
|
|
||||||
|
affinityValue := fmt.Sprintf("pc-hit-%d", time.Now().UnixNano())
|
||||||
|
cacheKeySuffix := buildChannelAffinityCacheKeySuffix(*codexRule, "default", affinityValue)
|
||||||
|
|
||||||
|
cache := getChannelAffinityCache()
|
||||||
|
require.NoError(t, cache.SetWithTTL(cacheKeySuffix, 9527, time.Minute))
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_, _ = cache.DeleteMany([]string{cacheKeySuffix})
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(fmt.Sprintf(`{"prompt_cache_key":"%s"}`, affinityValue)))
|
||||||
|
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
channelID, found := GetPreferredChannelByAffinity(ctx, "gpt-5", "default")
|
||||||
|
require.True(t, found)
|
||||||
|
require.Equal(t, 9527, channelID)
|
||||||
|
|
||||||
|
baseOverride := map[string]interface{}{
|
||||||
|
"temperature": 0.2,
|
||||||
|
}
|
||||||
|
mergedOverride, applied := ApplyChannelAffinityOverrideTemplate(ctx, baseOverride)
|
||||||
|
require.True(t, applied)
|
||||||
|
require.Equal(t, 0.2, mergedOverride["temperature"])
|
||||||
|
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
RequestHeaders: map[string]string{
|
||||||
|
"Originator": "Codex CLI",
|
||||||
|
"Session_id": "sess-123",
|
||||||
|
"User-Agent": "codex-cli-test",
|
||||||
|
},
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
|
ParamOverride: mergedOverride,
|
||||||
|
HeadersOverride: map[string]interface{}{
|
||||||
|
"X-Static": "legacy-static",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := relaycommon.ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-5"}`), info)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, info.UseRuntimeHeadersOverride)
|
||||||
|
|
||||||
|
require.Equal(t, "legacy-static", info.RuntimeHeadersOverride["x-static"])
|
||||||
|
require.Equal(t, "Codex CLI", info.RuntimeHeadersOverride["originator"])
|
||||||
|
require.Equal(t, "sess-123", info.RuntimeHeadersOverride["session_id"])
|
||||||
|
require.Equal(t, "codex-cli-test", info.RuntimeHeadersOverride["user-agent"])
|
||||||
|
|
||||||
|
_, exists := info.RuntimeHeadersOverride["x-codex-beta-features"]
|
||||||
|
require.False(t, exists)
|
||||||
|
_, exists = info.RuntimeHeadersOverride["x-codex-turn-metadata"]
|
||||||
|
require.False(t, exists)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user