diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index c1954cc02..79eac3ad6 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -179,7 +179,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s var passthroughRegex []*regexp.Regexp if !info.IsChannelTest { for k := range headerOverrideSource { - key := strings.TrimSpace(k) + key := strings.TrimSpace(strings.ToLower(k)) if key == "" { continue } @@ -188,12 +188,11 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s continue } - lower := strings.ToLower(key) var pattern string switch { - case strings.HasPrefix(lower, headerPassthroughRegexPrefix): + case strings.HasPrefix(key, headerPassthroughRegexPrefix): pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):]) - case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2): + case strings.HasPrefix(key, headerPassthroughRegexPrefixV2): pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):]) default: continue @@ -234,7 +233,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s if value == "" { continue } - headerOverride[name] = value + headerOverride[strings.ToLower(strings.TrimSpace(name))] = value } } @@ -242,7 +241,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s if isHeaderPassthroughRuleKey(k) { continue } - key := strings.TrimSpace(k) + key := strings.TrimSpace(strings.ToLower(k)) if key == "" { continue } diff --git a/relay/channel/api_request_test.go b/relay/channel/api_request_test.go index 84406ba48..f697f8555 100644 --- a/relay/channel/api_request_test.go +++ b/relay/channel/api_request_test.go @@ -53,7 +53,7 @@ func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testin headers, err := processHeaderOverride(info, ctx) require.NoError(t, err) - _, ok := headers["X-Upstream-Trace"] + _, ok := headers["x-upstream-trace"] require.False(t, ok) } @@ -77,10 +77,10 @@ func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) headers, err := processHeaderOverride(info, ctx) require.NoError(t, err) - require.Equal(t, "trace-123", headers["X-Upstream-Trace"]) + require.Equal(t, "trace-123", headers["x-upstream-trace"]) } -func TestProcessHeaderOverride_RuntimeOverrideMergesWithChannelOverride(t *testing.T) { +func TestProcessHeaderOverride_RuntimeOverrideIsFinalHeaderMap(t *testing.T) { t.Parallel() gin.SetMode(gin.TestMode) @@ -92,8 +92,8 @@ func TestProcessHeaderOverride_RuntimeOverrideMergesWithChannelOverride(t *testi IsChannelTest: false, UseRuntimeHeadersOverride: true, RuntimeHeadersOverride: map[string]any{ - "X-Static": "runtime-value", - "X-Runtime": "runtime-only", + "x-static": "runtime-value", + "x-runtime": "runtime-only", }, ChannelMeta: &relaycommon.ChannelMeta{ HeadersOverride: map[string]any{ @@ -105,9 +105,10 @@ func TestProcessHeaderOverride_RuntimeOverrideMergesWithChannelOverride(t *testi headers, err := processHeaderOverride(info, ctx) require.NoError(t, err) - require.Equal(t, "runtime-value", headers["X-Static"]) - require.Equal(t, "runtime-only", headers["X-Runtime"]) - require.Equal(t, "legacy-only", headers["X-Legacy"]) + require.Equal(t, "runtime-value", headers["x-static"]) + require.Equal(t, "runtime-only", headers["x-runtime"]) + _, exists := headers["x-legacy"] + require.False(t, exists) } func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) { @@ -131,9 +132,9 @@ func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) { headers, err := processHeaderOverride(info, ctx) require.NoError(t, err) - require.Equal(t, "trace-123", headers["X-Trace-Id"]) + require.Equal(t, "trace-123", headers["x-trace-id"]) - _, hasAcceptEncoding := headers["Accept-Encoding"] + _, hasAcceptEncoding := headers["accept-encoding"] require.False(t, hasAcceptEncoding) } @@ -171,16 +172,17 @@ func TestProcessHeaderOverride_PassHeadersTemplateSetsRuntimeHeaders(t *testing. _, err := relaycommon.ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-4.1"}`), info) require.NoError(t, err) require.True(t, info.UseRuntimeHeadersOverride) - require.Equal(t, "Codex CLI", info.RuntimeHeadersOverride["Originator"]) - require.Equal(t, "sess-123", info.RuntimeHeadersOverride["Session_id"]) - _, exists := info.RuntimeHeadersOverride["X-Codex-Beta-Features"] + require.Equal(t, "Codex CLI", info.RuntimeHeadersOverride["originator"]) + require.Equal(t, "sess-123", info.RuntimeHeadersOverride["session_id"]) + _, exists := info.RuntimeHeadersOverride["x-codex-beta-features"] require.False(t, exists) + require.Equal(t, "legacy-value", info.RuntimeHeadersOverride["x-static"]) headers, err := processHeaderOverride(info, ctx) require.NoError(t, err) - require.Equal(t, "Codex CLI", headers["Originator"]) - require.Equal(t, "sess-123", headers["Session_id"]) - _, exists = headers["X-Codex-Beta-Features"] + require.Equal(t, "Codex CLI", headers["originator"]) + require.Equal(t, "sess-123", headers["session_id"]) + _, exists = headers["x-codex-beta-features"] require.False(t, exists) upstreamReq := httptest.NewRequest(http.MethodPost, "https://example.com/v1/responses", nil) diff --git a/relay/common/override.go b/relay/common/override.go index 78cf60b92..95af8cfae 100644 --- a/relay/common/override.go +++ b/relay/common/override.go @@ -18,11 +18,8 @@ import ( var negativeIndexRegexp = regexp.MustCompile(`\.(-\d+)`) const ( - paramOverrideContextRequestHeaders = "request_headers" - paramOverrideContextRequestHeadersRaw = "request_headers_raw" - paramOverrideContextHeaderOverride = "header_override" - paramOverrideContextHeaderOverrideNormalized = "header_override_normalized" - paramOverrideContextHeaderOverrideDeleted = "header_override_deleted_normalized" + paramOverrideContextRequestHeaders = "request_headers" + paramOverrideContextHeaderOverride = "header_override" ) var errSourceHeaderNotFound = errors.New("source header does not exist") @@ -161,141 +158,118 @@ func getHeaderOverrideMap(info *RelayInfo) map[string]interface{} { return info.ChannelMeta.HeadersOverride } -func cloneHeaderOverrideMap(source map[string]interface{}) map[string]interface{} { +func sanitizeHeaderOverrideMap(source map[string]interface{}) map[string]interface{} { if len(source) == 0 { return map[string]interface{}{} } target := make(map[string]interface{}, len(source)) for key, value := range source { - target[key] = value + normalizedKey := normalizeHeaderContextKey(key) + if normalizedKey == "" { + continue + } + normalizedValue := strings.TrimSpace(fmt.Sprintf("%v", value)) + if normalizedValue == "" { + if isHeaderPassthroughRuleKeyForOverride(normalizedKey) { + target[normalizedKey] = "" + } + continue + } + target[normalizedKey] = normalizedValue } return target } -func setHeaderOverrideEntry(target map[string]interface{}, key string, value interface{}) { - key = strings.TrimSpace(key) +func isHeaderPassthroughRuleKeyForOverride(key string) bool { + key = strings.TrimSpace(strings.ToLower(key)) if key == "" { - return - } - for existing := range target { - if strings.EqualFold(strings.TrimSpace(existing), key) { - delete(target, existing) - } - } - target[key] = value -} - -func isHeaderDeletedByRuntime(headerName string, deleted map[string]bool) bool { - if len(deleted) == 0 { return false } - normalized := normalizeHeaderContextKey(headerName) - if normalized == "" { - return false + if key == "*" { + return true } - return deleted[normalized] -} - -func mergeHeaderOverrideSource(base, runtime map[string]interface{}, deleted map[string]bool) map[string]interface{} { - merged := make(map[string]interface{}, len(base)+len(runtime)) - for key, value := range base { - if isHeaderDeletedByRuntime(key, deleted) { - continue - } - setHeaderOverrideEntry(merged, key, value) - } - for key, value := range runtime { - setHeaderOverrideEntry(merged, key, value) - } - return merged -} - -func cloneDeletedHeaderKeys(source map[string]bool) map[string]bool { - if len(source) == 0 { - return map[string]bool{} - } - target := make(map[string]bool, len(source)) - for key, value := range source { - if !value { - continue - } - normalized := normalizeHeaderContextKey(key) - if normalized == "" { - continue - } - target[normalized] = true - } - return target + return strings.HasPrefix(key, "re:") || strings.HasPrefix(key, "regex:") } func GetEffectiveHeaderOverride(info *RelayInfo) map[string]interface{} { if info == nil { return map[string]interface{}{} } - base := getHeaderOverrideMap(info) - if !info.UseRuntimeHeadersOverride { - return cloneHeaderOverrideMap(base) + if info.UseRuntimeHeadersOverride { + return sanitizeHeaderOverrideMap(info.RuntimeHeadersOverride) } - return mergeHeaderOverrideSource(base, info.RuntimeHeadersOverride, cloneDeletedHeaderKeys(info.RuntimeHeadersDeletedNormalized)) + return sanitizeHeaderOverrideMap(getHeaderOverrideMap(info)) } func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) { // 检查是否包含 "operations" 字段 - if opsValue, exists := paramOverride["operations"]; exists { - if opsSlice, ok := opsValue.([]interface{}); ok { - var operations []ParamOperation - for _, op := range opsSlice { - if opMap, ok := op.(map[string]interface{}); ok { - operation := ParamOperation{} - - // 断言必要字段 - if path, ok := opMap["path"].(string); ok { - operation.Path = path - } - if mode, ok := opMap["mode"].(string); ok { - operation.Mode = mode - } else { - return nil, false // mode 是必需的 - } - - // 可选字段 - if value, exists := opMap["value"]; exists { - operation.Value = value - } - if keepOrigin, ok := opMap["keep_origin"].(bool); ok { - operation.KeepOrigin = keepOrigin - } - if from, ok := opMap["from"].(string); ok { - operation.From = from - } - if to, ok := opMap["to"].(string); ok { - operation.To = to - } - if logic, ok := opMap["logic"].(string); ok { - operation.Logic = logic - } else { - operation.Logic = "OR" // 默认为OR - } - - // 解析条件 - if conditions, exists := opMap["conditions"]; exists { - parsedConditions, err := parseConditionOperations(conditions) - if err != nil { - return nil, false - } - operation.Conditions = append(operation.Conditions, parsedConditions...) - } - - operations = append(operations, operation) - } else { - return nil, false - } - } - return operations, true - } + opsValue, exists := paramOverride["operations"] + if !exists { + return nil, false } - return nil, false + var opMaps []map[string]interface{} + switch ops := opsValue.(type) { + case []interface{}: + opMaps = make([]map[string]interface{}, 0, len(ops)) + for _, op := range ops { + opMap, ok := op.(map[string]interface{}) + if !ok { + return nil, false + } + opMaps = append(opMaps, opMap) + } + case []map[string]interface{}: + opMaps = ops + default: + return nil, false + } + + operations := make([]ParamOperation, 0, len(opMaps)) + for _, opMap := range opMaps { + operation := ParamOperation{} + + // 断言必要字段 + if path, ok := opMap["path"].(string); ok { + operation.Path = path + } + if mode, ok := opMap["mode"].(string); ok { + operation.Mode = mode + } else { + return nil, false // mode 是必需的 + } + + // 可选字段 + if value, exists := opMap["value"]; exists { + operation.Value = value + } + if keepOrigin, ok := opMap["keep_origin"].(bool); ok { + operation.KeepOrigin = keepOrigin + } + if from, ok := opMap["from"].(string); ok { + operation.From = from + } + if to, ok := opMap["to"].(string); ok { + operation.To = to + } + if logic, ok := opMap["logic"].(string); ok { + operation.Logic = logic + } else { + operation.Logic = "OR" // 默认为OR + } + + // 解析条件 + if conditions, exists := opMap["conditions"]; exists { + parsedConditions, err := parseConditionOperations(conditions) + if err != nil { + return nil, false + } + operation.Conditions = append(operation.Conditions, parsedConditions...) + } + + operations = append(operations, operation) + } + return operations, true } func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) { @@ -712,15 +686,10 @@ func marshalContextJSON(context map[string]interface{}) (string, error) { } func setHeaderOverrideInContext(context map[string]interface{}, headerName string, value interface{}, keepOrigin bool) error { - headerName = strings.TrimSpace(headerName) + headerName = normalizeHeaderContextKey(headerName) if headerName == "" { return fmt.Errorf("header name is required") } - if keepOrigin { - if _, exists := getHeaderValueFromContext(context, headerName); exists { - return nil - } - } if value == nil { return fmt.Errorf("header value is required") } @@ -730,21 +699,21 @@ func setHeaderOverrideInContext(context map[string]interface{}, headerName strin } rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride) - rawHeaders[headerName] = headerValue - - normalizedHeaderName := normalizeHeaderContextKey(headerName) - normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized) - normalizedHeaders[normalizedHeaderName] = headerValue - if normalizedHeaderName != "" { - deletedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideDeleted) - delete(deletedHeaders, normalizedHeaderName) + if keepOrigin { + if existing, ok := rawHeaders[headerName]; ok { + existingValue := strings.TrimSpace(fmt.Sprintf("%v", existing)) + if existingValue != "" { + return nil + } + } } + rawHeaders[headerName] = headerValue return nil } func copyHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error { - fromHeader = strings.TrimSpace(fromHeader) - toHeader = strings.TrimSpace(toHeader) + fromHeader = normalizeHeaderContextKey(fromHeader) + toHeader = normalizeHeaderContextKey(toHeader) if fromHeader == "" || toHeader == "" { return fmt.Errorf("copy_header from/to is required") } @@ -756,8 +725,8 @@ func copyHeaderInContext(context map[string]interface{}, fromHeader, toHeader st } func moveHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error { - fromHeader = strings.TrimSpace(fromHeader) - toHeader = strings.TrimSpace(toHeader) + fromHeader = normalizeHeaderContextKey(fromHeader) + toHeader = normalizeHeaderContextKey(toHeader) if fromHeader == "" || toHeader == "" { return fmt.Errorf("move_header from/to is required") } @@ -771,31 +740,19 @@ func moveHeaderInContext(context map[string]interface{}, fromHeader, toHeader st } func deleteHeaderOverrideInContext(context map[string]interface{}, headerName string) error { - headerName = strings.TrimSpace(headerName) + headerName = normalizeHeaderContextKey(headerName) if headerName == "" { return fmt.Errorf("header name is required") } rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride) - for key := range rawHeaders { - if strings.EqualFold(strings.TrimSpace(key), headerName) { - delete(rawHeaders, key) - } - } - - normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized) - normalizedHeaderName := normalizeHeaderContextKey(headerName) - delete(normalizedHeaders, normalizedHeaderName) - if normalizedHeaderName != "" { - deletedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideDeleted) - deletedHeaders[normalizedHeaderName] = true - } + delete(rawHeaders, headerName) return nil } func parseHeaderPassThroughNames(value interface{}) ([]string, error) { normalizeNames := func(values []string) []string { names := lo.FilterMap(values, func(item string, _ int) (string, bool) { - headerName := strings.TrimSpace(item) + headerName := normalizeHeaderContextKey(item) if headerName == "" { return "", false } @@ -825,7 +782,20 @@ func parseHeaderPassThroughNames(value interface{}) ([]string, error) { return names, nil case []interface{}: names := lo.FilterMap(raw, func(item interface{}, _ int) (string, bool) { - headerName := strings.TrimSpace(fmt.Sprintf("%v", item)) + headerName := normalizeHeaderContextKey(fmt.Sprintf("%v", item)) + if headerName == "" { + return "", false + } + return headerName, true + }) + names = lo.Uniq(names) + if len(names) == 0 { + return nil, fmt.Errorf("pass_headers value is invalid") + } + return names, nil + case []string: + names := lo.FilterMap(raw, func(item string, _ int) (string, bool) { + headerName := normalizeHeaderContextKey(item) if headerName == "" { return "", false } @@ -994,76 +964,29 @@ func ensureMapKeyInContext(context map[string]interface{}, key string) map[strin } func getHeaderValueFromContext(context map[string]interface{}, headerName string) (string, bool) { - headerName = strings.TrimSpace(headerName) + headerName = normalizeHeaderContextKey(headerName) if headerName == "" { return "", false } - if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextHeaderOverride), headerName); ok { - return value, true - } - if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextRequestHeadersRaw), headerName); ok { - return value, true - } - - normalizedName := normalizeHeaderContextKey(headerName) - if normalizedName == "" { - return "", false - } - if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized), normalizedName); ok { - return value, true - } - if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextRequestHeaders), normalizedName); ok { - return value, true + for _, key := range []string{paramOverrideContextHeaderOverride, paramOverrideContextRequestHeaders} { + source := ensureMapKeyInContext(context, key) + raw, ok := source[headerName] + if !ok { + continue + } + value := strings.TrimSpace(fmt.Sprintf("%v", raw)) + if value != "" { + return value, true + } } return "", false } -func findHeaderValueInMap(source map[string]interface{}, key string) (string, bool) { - if len(source) == 0 { - return "", false - } - entries := lo.Entries(source) - entry, ok := lo.Find(entries, func(item lo.Entry[string, interface{}]) bool { - return strings.EqualFold(strings.TrimSpace(item.Key), key) - }) - if !ok { - return "", false - } - value := strings.TrimSpace(fmt.Sprintf("%v", entry.Value)) - if value == "" { - return "", false - } - return value, true -} - func normalizeHeaderContextKey(key string) string { - key = strings.TrimSpace(strings.ToLower(key)) - if key == "" { - return "" - } - var b strings.Builder - b.Grow(len(key)) - previousUnderscore := false - for _, r := range key { - switch { - case r >= 'a' && r <= 'z': - b.WriteRune(r) - previousUnderscore = false - case r >= '0' && r <= '9': - b.WriteRune(r) - previousUnderscore = false - default: - if !previousUnderscore { - b.WriteByte('_') - previousUnderscore = true - } - } - } - result := strings.Trim(b.String(), "_") - return result + return strings.TrimSpace(strings.ToLower(key)) } -func buildNormalizedHeaders(headers map[string]string) map[string]interface{} { +func buildRequestHeadersContext(headers map[string]string) map[string]interface{} { if len(headers) == 0 { return map[string]interface{}{} } @@ -1081,54 +1004,6 @@ func buildNormalizedHeaders(headers map[string]string) map[string]interface{} { }) } -func buildRawHeaders(headers map[string]string) map[string]interface{} { - if len(headers) == 0 { - return map[string]interface{}{} - } - entries := lo.Entries(headers) - rawEntries := lo.FilterMap(entries, func(item lo.Entry[string, string], _ int) (lo.Entry[string, string], bool) { - key := strings.TrimSpace(item.Key) - value := strings.TrimSpace(item.Value) - if key == "" || value == "" { - return lo.Entry[string, string]{}, false - } - return lo.Entry[string, string]{Key: key, Value: value}, true - }) - return lo.SliceToMap(rawEntries, func(item lo.Entry[string, string]) (string, interface{}) { - return item.Key, item.Value - }) -} - -func buildHeaderOverrideContext(headers map[string]interface{}) (map[string]interface{}, map[string]interface{}) { - if len(headers) == 0 { - return map[string]interface{}{}, map[string]interface{}{} - } - entries := lo.Entries(headers) - rawEntries := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (lo.Entry[string, string], bool) { - key := strings.TrimSpace(item.Key) - value := strings.TrimSpace(fmt.Sprintf("%v", item.Value)) - if key == "" || value == "" { - return lo.Entry[string, string]{}, false - } - return lo.Entry[string, string]{Key: key, Value: value}, true - }) - - raw := lo.SliceToMap(rawEntries, func(item lo.Entry[string, string]) (string, interface{}) { - return item.Key, item.Value - }) - normalizedEntries := lo.FilterMap(rawEntries, func(item lo.Entry[string, string], _ int) (lo.Entry[string, string], bool) { - normalized := normalizeHeaderContextKey(item.Key) - if normalized == "" { - return lo.Entry[string, string]{}, false - } - return lo.Entry[string, string]{Key: normalized, Value: item.Value}, true - }) - normalized := lo.SliceToMap(normalizedEntries, func(item lo.Entry[string, string]) (string, interface{}) { - return item.Key, item.Value - }) - return raw, normalized -} - func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]interface{}) { if info == nil || context == nil { return @@ -1141,55 +1016,10 @@ func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]in if !ok { return } - - entries := lo.Entries(rawMap) - sanitized := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (lo.Entry[string, interface{}], bool) { - key := strings.TrimSpace(item.Key) - if key == "" { - return lo.Entry[string, interface{}]{}, false - } - value := strings.TrimSpace(fmt.Sprintf("%v", item.Value)) - if value == "" { - return lo.Entry[string, interface{}]{}, false - } - return lo.Entry[string, interface{}]{Key: key, Value: value}, true - }) - info.RuntimeHeadersOverride = lo.SliceToMap(sanitized, func(item lo.Entry[string, interface{}]) (string, interface{}) { - return item.Key, item.Value - }) - info.RuntimeHeadersDeletedNormalized = sanitizeRuntimeDeletedHeadersFromContext(context) + info.RuntimeHeadersOverride = sanitizeHeaderOverrideMap(rawMap) info.UseRuntimeHeadersOverride = true } -func sanitizeRuntimeDeletedHeadersFromContext(context map[string]interface{}) map[string]bool { - deletedRaw, exists := context[paramOverrideContextHeaderOverrideDeleted] - if !exists { - return nil - } - deletedMap, ok := deletedRaw.(map[string]interface{}) - if !ok || len(deletedMap) == 0 { - return nil - } - entries := lo.Entries(deletedMap) - sanitized := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (string, bool) { - if keep, ok := item.Value.(bool); ok && !keep { - return "", false - } - normalized := normalizeHeaderContextKey(item.Key) - if normalized == "" { - return "", false - } - return normalized, true - }) - if len(sanitized) == 0 { - return nil - } - keys := lo.Uniq(sanitized) - return lo.SliceToMap(keys, func(item string) (string, bool) { - return item, true - }) -} - func moveValue(jsonStr, fromPath, toPath string) (string, error) { sourceValue := gjson.Get(jsonStr, fromPath) if !sourceValue.Exists() { @@ -1635,16 +1465,10 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} { } } - ctx[paramOverrideContextRequestHeaders] = buildNormalizedHeaders(info.RequestHeaders) - ctx[paramOverrideContextRequestHeadersRaw] = buildRawHeaders(info.RequestHeaders) + ctx[paramOverrideContextRequestHeaders] = buildRequestHeadersContext(info.RequestHeaders) headerOverrideSource := GetEffectiveHeaderOverride(info) - rawHeaderOverride, normalizedHeaderOverride := buildHeaderOverrideContext(headerOverrideSource) - ctx[paramOverrideContextHeaderOverride] = rawHeaderOverride - ctx[paramOverrideContextHeaderOverrideNormalized] = normalizedHeaderOverride - ctx[paramOverrideContextHeaderOverrideDeleted] = lo.SliceToMap(lo.Keys(cloneDeletedHeaderKeys(info.RuntimeHeadersDeletedNormalized)), func(item string) (string, interface{}) { - return item, true - }) + ctx[paramOverrideContextHeaderOverride] = sanitizeHeaderOverrideMap(headerOverrideSource) ctx["retry_index"] = info.RetryIndex ctx["is_retry"] = info.RetryIndex > 0 diff --git a/relay/common/override_test.go b/relay/common/override_test.go index a9dad7177..7a27ca407 100644 --- a/relay/common/override_test.go +++ b/relay/common/override_test.go @@ -1005,7 +1005,7 @@ func TestApplyParamOverrideSetHeaderAndUseInLaterCondition(t *testing.T) { "value": 0.1, "conditions": []interface{}{ map[string]interface{}{ - "path": "header_override_normalized.x_debug_mode", + "path": "header_override.x-debug-mode", "mode": "full", "value": "enabled", }, @@ -1036,7 +1036,7 @@ func TestApplyParamOverrideCopyHeaderFromRequestHeaders(t *testing.T) { "value": 0.1, "conditions": []interface{}{ map[string]interface{}{ - "path": "header_override_normalized.x_upstream_auth", + "path": "header_override.x-upstream-auth", "mode": "contains", "value": "Bearer ", }, @@ -1045,9 +1045,6 @@ func TestApplyParamOverrideCopyHeaderFromRequestHeaders(t *testing.T) { }, } ctx := map[string]interface{}{ - "request_headers_raw": map[string]interface{}{ - "Authorization": "Bearer token-123", - }, "request_headers": map[string]interface{}{ "authorization": "Bearer token-123", }, @@ -1071,9 +1068,6 @@ func TestApplyParamOverridePassHeadersSkipsMissingHeaders(t *testing.T) { }, } ctx := map[string]interface{}{ - "request_headers_raw": map[string]interface{}{ - "Session_id": "sess-123", - }, "request_headers": map[string]interface{}{ "session_id": "sess-123", }, @@ -1089,10 +1083,10 @@ func TestApplyParamOverridePassHeadersSkipsMissingHeaders(t *testing.T) { if !ok { t.Fatalf("expected header_override context map") } - if headers["Session_id"] != "sess-123" { - t.Fatalf("expected Session_id to be passed, got: %v", headers["Session_id"]) + if headers["session_id"] != "sess-123" { + t.Fatalf("expected session_id to be passed, got: %v", headers["session_id"]) } - if _, exists := headers["X-Codex-Beta-Features"]; exists { + if _, exists := headers["x-codex-beta-features"]; exists { t.Fatalf("expected missing header to be skipped") } } @@ -1109,9 +1103,6 @@ func TestApplyParamOverrideCopyHeaderSkipsMissingSource(t *testing.T) { }, } ctx := map[string]interface{}{ - "request_headers_raw": map[string]interface{}{ - "Authorization": "Bearer token-123", - }, "request_headers": map[string]interface{}{ "authorization": "Bearer token-123", }, @@ -1127,7 +1118,7 @@ func TestApplyParamOverrideCopyHeaderSkipsMissingSource(t *testing.T) { if !ok { return } - if _, exists := headers["X-Upstream-Auth"]; exists { + if _, exists := headers["x-upstream-auth"]; exists { t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing") } } @@ -1144,9 +1135,6 @@ func TestApplyParamOverrideMoveHeaderSkipsMissingSource(t *testing.T) { }, } ctx := map[string]interface{}{ - "request_headers_raw": map[string]interface{}{ - "Authorization": "Bearer token-123", - }, "request_headers": map[string]interface{}{ "authorization": "Bearer token-123", }, @@ -1162,7 +1150,7 @@ func TestApplyParamOverrideMoveHeaderSkipsMissingSource(t *testing.T) { if !ok { return } - if _, exists := headers["X-Upstream-Auth"]; exists { + if _, exists := headers["x-upstream-auth"]; exists { t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing") } } @@ -1179,9 +1167,6 @@ func TestApplyParamOverrideSyncFieldsHeaderToJSON(t *testing.T) { }, } ctx := map[string]interface{}{ - "request_headers_raw": map[string]interface{}{ - "session_id": "sess-123", - }, "request_headers": map[string]interface{}{ "session_id": "sess-123", }, @@ -1234,9 +1219,6 @@ func TestApplyParamOverrideSyncFieldsNoChangeWhenBothExist(t *testing.T) { }, } ctx := map[string]interface{}{ - "request_headers_raw": map[string]interface{}{ - "session_id": "cache-header", - }, "request_headers": map[string]interface{}{ "session_id": "cache-header", }, @@ -1288,10 +1270,7 @@ func TestApplyParamOverrideSetHeaderKeepOrigin(t *testing.T) { } ctx := map[string]interface{}{ "header_override": map[string]interface{}{ - "X-Feature-Flag": "legacy-value", - }, - "header_override_normalized": map[string]interface{}{ - "x_feature_flag": "legacy-value", + "x-feature-flag": "legacy-value", }, } @@ -1303,8 +1282,8 @@ func TestApplyParamOverrideSetHeaderKeepOrigin(t *testing.T) { if !ok { t.Fatalf("expected header_override context map") } - if headers["X-Feature-Flag"] != "legacy-value" { - t.Fatalf("expected keep_origin to preserve old value, got: %v", headers["X-Feature-Flag"]) + if headers["x-feature-flag"] != "legacy-value" { + t.Fatalf("expected keep_origin to preserve old value, got: %v", headers["x-feature-flag"]) } } @@ -1371,14 +1350,14 @@ func TestApplyParamOverrideWithRelayInfoSyncRuntimeHeaders(t *testing.T) { if !info.UseRuntimeHeadersOverride { t.Fatalf("expected runtime header override to be enabled") } - if info.RuntimeHeadersOverride["X-Keep-Me"] != "keep" { - t.Fatalf("expected X-Keep-Me header to be preserved, got: %v", info.RuntimeHeadersOverride["X-Keep-Me"]) + if info.RuntimeHeadersOverride["x-keep-me"] != "keep" { + t.Fatalf("expected x-keep-me header to be preserved, got: %v", info.RuntimeHeadersOverride["x-keep-me"]) } - if info.RuntimeHeadersOverride["X-Injected-By-Param-Override"] != "enabled" { - t.Fatalf("expected X-Injected-By-Param-Override header to be set, got: %v", info.RuntimeHeadersOverride["X-Injected-By-Param-Override"]) + if info.RuntimeHeadersOverride["x-injected-by-param-override"] != "enabled" { + t.Fatalf("expected x-injected-by-param-override header to be set, got: %v", info.RuntimeHeadersOverride["x-injected-by-param-override"]) } - if _, exists := info.RuntimeHeadersOverride["X-Delete-Me"]; exists { - t.Fatalf("expected X-Delete-Me header to be deleted") + if _, exists := info.RuntimeHeadersOverride["x-delete-me"]; exists { + t.Fatalf("expected x-delete-me header to be deleted") } } @@ -1410,25 +1389,22 @@ func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) { if err != nil { t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err) } - if _, exists := info.RuntimeHeadersOverride["X-Legacy-Trace"]; exists { + if _, exists := info.RuntimeHeadersOverride["x-legacy-trace"]; exists { t.Fatalf("expected source header to be removed after move") } - if info.RuntimeHeadersOverride["X-Trace"] != "trace-123" { - t.Fatalf("expected X-Trace to be set, got: %v", info.RuntimeHeadersOverride["X-Trace"]) + if info.RuntimeHeadersOverride["x-trace"] != "trace-123" { + t.Fatalf("expected x-trace to be set, got: %v", info.RuntimeHeadersOverride["x-trace"]) } - if info.RuntimeHeadersOverride["X-Trace-Backup"] != "trace-123" { - t.Fatalf("expected X-Trace-Backup to be copied, got: %v", info.RuntimeHeadersOverride["X-Trace-Backup"]) + if info.RuntimeHeadersOverride["x-trace-backup"] != "trace-123" { + t.Fatalf("expected x-trace-backup to be copied, got: %v", info.RuntimeHeadersOverride["x-trace-backup"]) } } -func TestGetEffectiveHeaderOverrideMergesRuntimeAndChannelOverrides(t *testing.T) { +func TestGetEffectiveHeaderOverrideUsesRuntimeOverrideAsFinalResult(t *testing.T) { info := &RelayInfo{ UseRuntimeHeadersOverride: true, RuntimeHeadersOverride: map[string]interface{}{ - "X-Runtime": "runtime-only", - }, - RuntimeHeadersDeletedNormalized: map[string]bool{ - "x-deleted": true, + "x-runtime": "runtime-only", }, ChannelMeta: &ChannelMeta{ HeadersOverride: map[string]interface{}{ @@ -1439,14 +1415,11 @@ func TestGetEffectiveHeaderOverrideMergesRuntimeAndChannelOverrides(t *testing.T } effective := GetEffectiveHeaderOverride(info) - if effective["X-Static"] != "static-value" { - t.Fatalf("expected X-Static from channel override, got: %v", effective["X-Static"]) + if effective["x-runtime"] != "runtime-only" { + t.Fatalf("expected x-runtime from runtime override, got: %v", effective["x-runtime"]) } - if effective["X-Runtime"] != "runtime-only" { - t.Fatalf("expected X-Runtime from runtime override, got: %v", effective["X-Runtime"]) - } - if _, exists := effective["X-Deleted"]; exists { - t.Fatalf("expected deleted headers to stay deleted in effective override") + if _, exists := effective["x-static"]; exists { + t.Fatalf("expected runtime override to be final and not merge channel headers") } } diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 946294d49..8b0789c0d 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -148,7 +148,6 @@ type RelayInfo struct { RetryIndex int LastError *types.NewAPIError RuntimeHeadersOverride map[string]interface{} - RuntimeHeadersDeletedNormalized map[string]bool UseRuntimeHeadersOverride bool PriceData types.PriceData diff --git a/service/channel_affinity_template_test.go b/service/channel_affinity_template_test.go index 71e29d668..acf301543 100644 --- a/service/channel_affinity_template_test.go +++ b/service/channel_affinity_template_test.go @@ -1,9 +1,15 @@ package service import ( + "fmt" + "net/http" "net/http/httptest" + "strings" "testing" + "time" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/setting/operation_setting" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -67,3 +73,73 @@ func TestApplyChannelAffinityOverrideTemplate_MergeTemplate(t *testing.T) { require.Equal(t, "rule-with-template", overrideInfo["rule_name"]) require.EqualValues(t, 2, overrideInfo["param_override_keys"]) } + +func TestChannelAffinityHitCodexTemplatePassHeadersEffective(t *testing.T) { + gin.SetMode(gin.TestMode) + + setting := operation_setting.GetChannelAffinitySetting() + require.NotNil(t, setting) + + var codexRule *operation_setting.ChannelAffinityRule + for i := range setting.Rules { + rule := &setting.Rules[i] + if strings.EqualFold(strings.TrimSpace(rule.Name), "codex cli trace") { + codexRule = rule + break + } + } + require.NotNil(t, codexRule) + + affinityValue := fmt.Sprintf("pc-hit-%d", time.Now().UnixNano()) + cacheKeySuffix := buildChannelAffinityCacheKeySuffix(*codexRule, "default", affinityValue) + + cache := getChannelAffinityCache() + require.NoError(t, cache.SetWithTTL(cacheKeySuffix, 9527, time.Minute)) + t.Cleanup(func() { + _, _ = cache.DeleteMany([]string{cacheKeySuffix}) + }) + + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(fmt.Sprintf(`{"prompt_cache_key":"%s"}`, affinityValue))) + ctx.Request.Header.Set("Content-Type", "application/json") + + channelID, found := GetPreferredChannelByAffinity(ctx, "gpt-5", "default") + require.True(t, found) + require.Equal(t, 9527, channelID) + + baseOverride := map[string]interface{}{ + "temperature": 0.2, + } + mergedOverride, applied := ApplyChannelAffinityOverrideTemplate(ctx, baseOverride) + require.True(t, applied) + require.Equal(t, 0.2, mergedOverride["temperature"]) + + info := &relaycommon.RelayInfo{ + RequestHeaders: map[string]string{ + "Originator": "Codex CLI", + "Session_id": "sess-123", + "User-Agent": "codex-cli-test", + }, + ChannelMeta: &relaycommon.ChannelMeta{ + ParamOverride: mergedOverride, + HeadersOverride: map[string]interface{}{ + "X-Static": "legacy-static", + }, + }, + } + + _, err := relaycommon.ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-5"}`), info) + require.NoError(t, err) + require.True(t, info.UseRuntimeHeadersOverride) + + require.Equal(t, "legacy-static", info.RuntimeHeadersOverride["x-static"]) + require.Equal(t, "Codex CLI", info.RuntimeHeadersOverride["originator"]) + require.Equal(t, "sess-123", info.RuntimeHeadersOverride["session_id"]) + require.Equal(t, "codex-cli-test", info.RuntimeHeadersOverride["user-agent"]) + + _, exists := info.RuntimeHeadersOverride["x-codex-beta-features"] + require.False(t, exists) + _, exists = info.RuntimeHeadersOverride["x-codex-turn-metadata"] + require.False(t, exists) +}