fix: merge runtime and channel header overrides, skip missing source headers

This commit is contained in:
Seefs
2026-02-25 16:12:34 +08:00
parent bb0c663dbe
commit 305dbce4ad
5 changed files with 234 additions and 13 deletions

View File

@@ -173,10 +173,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
return headerOverride, nil
}
headerOverrideSource := info.HeadersOverride
if info.UseRuntimeHeadersOverride {
headerOverrideSource = info.RuntimeHeadersOverride
}
headerOverrideSource := common.GetEffectiveHeaderOverride(info)
passAll := false
var passthroughRegex []*regexp.Regexp

View File

@@ -80,7 +80,7 @@ func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T)
require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
}
func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) {
func TestProcessHeaderOverride_RuntimeOverrideMergesWithChannelOverride(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
@@ -107,8 +107,7 @@ func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "runtime-value", headers["X-Static"])
require.Equal(t, "runtime-only", headers["X-Runtime"])
_, ok := headers["X-Legacy"]
require.False(t, ok)
require.Equal(t, "legacy-only", headers["X-Legacy"])
}
func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {

View File

@@ -22,6 +22,7 @@ const (
paramOverrideContextRequestHeadersRaw = "request_headers_raw"
paramOverrideContextHeaderOverride = "header_override"
paramOverrideContextHeaderOverrideNormalized = "header_override_normalized"
paramOverrideContextHeaderOverrideDeleted = "header_override_deleted_normalized"
)
var errSourceHeaderNotFound = errors.New("source header does not exist")
@@ -160,6 +161,84 @@ func getHeaderOverrideMap(info *RelayInfo) map[string]interface{} {
return info.ChannelMeta.HeadersOverride
}
func cloneHeaderOverrideMap(source map[string]interface{}) map[string]interface{} {
if len(source) == 0 {
return map[string]interface{}{}
}
target := make(map[string]interface{}, len(source))
for key, value := range source {
target[key] = value
}
return target
}
func setHeaderOverrideEntry(target map[string]interface{}, key string, value interface{}) {
key = strings.TrimSpace(key)
if key == "" {
return
}
for existing := range target {
if strings.EqualFold(strings.TrimSpace(existing), key) {
delete(target, existing)
}
}
target[key] = value
}
func isHeaderDeletedByRuntime(headerName string, deleted map[string]bool) bool {
if len(deleted) == 0 {
return false
}
normalized := normalizeHeaderContextKey(headerName)
if normalized == "" {
return false
}
return deleted[normalized]
}
func mergeHeaderOverrideSource(base, runtime map[string]interface{}, deleted map[string]bool) map[string]interface{} {
merged := make(map[string]interface{}, len(base)+len(runtime))
for key, value := range base {
if isHeaderDeletedByRuntime(key, deleted) {
continue
}
setHeaderOverrideEntry(merged, key, value)
}
for key, value := range runtime {
setHeaderOverrideEntry(merged, key, value)
}
return merged
}
func cloneDeletedHeaderKeys(source map[string]bool) map[string]bool {
if len(source) == 0 {
return map[string]bool{}
}
target := make(map[string]bool, len(source))
for key, value := range source {
if !value {
continue
}
normalized := normalizeHeaderContextKey(key)
if normalized == "" {
continue
}
target[normalized] = true
}
return target
}
func GetEffectiveHeaderOverride(info *RelayInfo) map[string]interface{} {
if info == nil {
return map[string]interface{}{}
}
base := getHeaderOverrideMap(info)
if !info.UseRuntimeHeadersOverride {
return cloneHeaderOverrideMap(base)
}
return mergeHeaderOverrideSource(base, info.RuntimeHeadersOverride, cloneDeletedHeaderKeys(info.RuntimeHeadersDeletedNormalized))
}
func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
// 检查是否包含 "operations" 字段
if opsValue, exists := paramOverride["operations"]; exists {
@@ -480,6 +559,9 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
targetHeader = strings.TrimSpace(op.Path)
}
err = copyHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin)
if errors.Is(err, errSourceHeaderNotFound) {
err = nil
}
if err == nil {
contextJSON, err = marshalContextJSON(context)
}
@@ -493,6 +575,9 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
targetHeader = strings.TrimSpace(op.Path)
}
err = moveHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin)
if errors.Is(err, errSourceHeaderNotFound) {
err = nil
}
if err == nil {
contextJSON, err = marshalContextJSON(context)
}
@@ -647,8 +732,13 @@ func setHeaderOverrideInContext(context map[string]interface{}, headerName strin
rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride)
rawHeaders[headerName] = headerValue
normalizedHeaderName := normalizeHeaderContextKey(headerName)
normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
normalizedHeaders[normalizeHeaderContextKey(headerName)] = headerValue
normalizedHeaders[normalizedHeaderName] = headerValue
if normalizedHeaderName != "" {
deletedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideDeleted)
delete(deletedHeaders, normalizedHeaderName)
}
return nil
}
@@ -693,7 +783,12 @@ func deleteHeaderOverrideInContext(context map[string]interface{}, headerName st
}
normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
delete(normalizedHeaders, normalizeHeaderContextKey(headerName))
normalizedHeaderName := normalizeHeaderContextKey(headerName)
delete(normalizedHeaders, normalizedHeaderName)
if normalizedHeaderName != "" {
deletedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideDeleted)
deletedHeaders[normalizedHeaderName] = true
}
return nil
}
@@ -1062,9 +1157,39 @@ func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]in
info.RuntimeHeadersOverride = lo.SliceToMap(sanitized, func(item lo.Entry[string, interface{}]) (string, interface{}) {
return item.Key, item.Value
})
info.RuntimeHeadersDeletedNormalized = sanitizeRuntimeDeletedHeadersFromContext(context)
info.UseRuntimeHeadersOverride = true
}
func sanitizeRuntimeDeletedHeadersFromContext(context map[string]interface{}) map[string]bool {
deletedRaw, exists := context[paramOverrideContextHeaderOverrideDeleted]
if !exists {
return nil
}
deletedMap, ok := deletedRaw.(map[string]interface{})
if !ok || len(deletedMap) == 0 {
return nil
}
entries := lo.Entries(deletedMap)
sanitized := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (string, bool) {
if keep, ok := item.Value.(bool); ok && !keep {
return "", false
}
normalized := normalizeHeaderContextKey(item.Key)
if normalized == "" {
return "", false
}
return normalized, true
})
if len(sanitized) == 0 {
return nil
}
keys := lo.Uniq(sanitized)
return lo.SliceToMap(keys, func(item string) (string, bool) {
return item, true
})
}
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
sourceValue := gjson.Get(jsonStr, fromPath)
if !sourceValue.Exists() {
@@ -1513,13 +1638,13 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
ctx[paramOverrideContextRequestHeaders] = buildNormalizedHeaders(info.RequestHeaders)
ctx[paramOverrideContextRequestHeadersRaw] = buildRawHeaders(info.RequestHeaders)
headerOverrideSource := getHeaderOverrideMap(info)
if info.UseRuntimeHeadersOverride {
headerOverrideSource = info.RuntimeHeadersOverride
}
headerOverrideSource := GetEffectiveHeaderOverride(info)
rawHeaderOverride, normalizedHeaderOverride := buildHeaderOverrideContext(headerOverrideSource)
ctx[paramOverrideContextHeaderOverride] = rawHeaderOverride
ctx[paramOverrideContextHeaderOverrideNormalized] = normalizedHeaderOverride
ctx[paramOverrideContextHeaderOverrideDeleted] = lo.SliceToMap(lo.Keys(cloneDeletedHeaderKeys(info.RuntimeHeadersDeletedNormalized)), func(item string) (string, interface{}) {
return item, true
})
ctx["retry_index"] = info.RetryIndex
ctx["is_retry"] = info.RetryIndex > 0

View File

@@ -1097,6 +1097,76 @@ func TestApplyParamOverridePassHeadersSkipsMissingHeaders(t *testing.T) {
}
}
func TestApplyParamOverrideCopyHeaderSkipsMissingSource(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "copy_header",
"from": "X-Missing-Header",
"to": "X-Upstream-Auth",
},
},
}
ctx := map[string]interface{}{
"request_headers_raw": map[string]interface{}{
"Authorization": "Bearer token-123",
},
"request_headers": map[string]interface{}{
"authorization": "Bearer token-123",
},
}
out, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.7}`, string(out))
headers, ok := ctx["header_override"].(map[string]interface{})
if !ok {
return
}
if _, exists := headers["X-Upstream-Auth"]; exists {
t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing")
}
}
func TestApplyParamOverrideMoveHeaderSkipsMissingSource(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "move_header",
"from": "X-Missing-Header",
"to": "X-Upstream-Auth",
},
},
}
ctx := map[string]interface{}{
"request_headers_raw": map[string]interface{}{
"Authorization": "Bearer token-123",
},
"request_headers": map[string]interface{}{
"authorization": "Bearer token-123",
},
}
out, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.7}`, string(out))
headers, ok := ctx["header_override"].(map[string]interface{})
if !ok {
return
}
if _, exists := headers["X-Upstream-Auth"]; exists {
t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing")
}
}
func TestApplyParamOverrideSyncFieldsHeaderToJSON(t *testing.T) {
input := []byte(`{"model":"gpt-4"}`)
override := map[string]interface{}{
@@ -1351,6 +1421,35 @@ func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) {
}
}
func TestGetEffectiveHeaderOverrideMergesRuntimeAndChannelOverrides(t *testing.T) {
info := &RelayInfo{
UseRuntimeHeadersOverride: true,
RuntimeHeadersOverride: map[string]interface{}{
"X-Runtime": "runtime-only",
},
RuntimeHeadersDeletedNormalized: map[string]bool{
"x-deleted": true,
},
ChannelMeta: &ChannelMeta{
HeadersOverride: map[string]interface{}{
"X-Static": "static-value",
"X-Deleted": "should-not-exist",
},
},
}
effective := GetEffectiveHeaderOverride(info)
if effective["X-Static"] != "static-value" {
t.Fatalf("expected X-Static from channel override, got: %v", effective["X-Static"])
}
if effective["X-Runtime"] != "runtime-only" {
t.Fatalf("expected X-Runtime from runtime override, got: %v", effective["X-Runtime"])
}
if _, exists := effective["X-Deleted"]; exists {
t.Fatalf("expected deleted headers to stay deleted in effective override")
}
}
func TestRemoveDisabledFieldsSkipWhenChannelPassThroughEnabled(t *testing.T) {
input := `{
"service_tier":"flex",

View File

@@ -148,6 +148,7 @@ type RelayInfo struct {
RetryIndex int
LastError *types.NewAPIError
RuntimeHeadersOverride map[string]interface{}
RuntimeHeadersDeletedNormalized map[string]bool
UseRuntimeHeadersOverride bool
PriceData types.PriceData