diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index 49773e1e6..c1954cc02 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -173,10 +173,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s return headerOverride, nil } - headerOverrideSource := info.HeadersOverride - if info.UseRuntimeHeadersOverride { - headerOverrideSource = info.RuntimeHeadersOverride - } + headerOverrideSource := common.GetEffectiveHeaderOverride(info) passAll := false var passthroughRegex []*regexp.Regexp diff --git a/relay/channel/api_request_test.go b/relay/channel/api_request_test.go index fc39f54ae..84406ba48 100644 --- a/relay/channel/api_request_test.go +++ b/relay/channel/api_request_test.go @@ -80,7 +80,7 @@ func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) require.Equal(t, "trace-123", headers["X-Upstream-Trace"]) } -func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) { +func TestProcessHeaderOverride_RuntimeOverrideMergesWithChannelOverride(t *testing.T) { t.Parallel() gin.SetMode(gin.TestMode) @@ -107,8 +107,7 @@ func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) { require.NoError(t, err) require.Equal(t, "runtime-value", headers["X-Static"]) require.Equal(t, "runtime-only", headers["X-Runtime"]) - _, ok := headers["X-Legacy"] - require.False(t, ok) + require.Equal(t, "legacy-only", headers["X-Legacy"]) } func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) { diff --git a/relay/common/override.go b/relay/common/override.go index 4a5994291..78cf60b92 100644 --- a/relay/common/override.go +++ b/relay/common/override.go @@ -22,6 +22,7 @@ const ( paramOverrideContextRequestHeadersRaw = "request_headers_raw" paramOverrideContextHeaderOverride = "header_override" paramOverrideContextHeaderOverrideNormalized = "header_override_normalized" + paramOverrideContextHeaderOverrideDeleted = "header_override_deleted_normalized" ) var errSourceHeaderNotFound = errors.New("source header does not exist") @@ -160,6 +161,84 @@ func getHeaderOverrideMap(info *RelayInfo) map[string]interface{} { return info.ChannelMeta.HeadersOverride } +func cloneHeaderOverrideMap(source map[string]interface{}) map[string]interface{} { + if len(source) == 0 { + return map[string]interface{}{} + } + target := make(map[string]interface{}, len(source)) + for key, value := range source { + target[key] = value + } + return target +} + +func setHeaderOverrideEntry(target map[string]interface{}, key string, value interface{}) { + key = strings.TrimSpace(key) + if key == "" { + return + } + for existing := range target { + if strings.EqualFold(strings.TrimSpace(existing), key) { + delete(target, existing) + } + } + target[key] = value +} + +func isHeaderDeletedByRuntime(headerName string, deleted map[string]bool) bool { + if len(deleted) == 0 { + return false + } + normalized := normalizeHeaderContextKey(headerName) + if normalized == "" { + return false + } + return deleted[normalized] +} + +func mergeHeaderOverrideSource(base, runtime map[string]interface{}, deleted map[string]bool) map[string]interface{} { + merged := make(map[string]interface{}, len(base)+len(runtime)) + for key, value := range base { + if isHeaderDeletedByRuntime(key, deleted) { + continue + } + setHeaderOverrideEntry(merged, key, value) + } + for key, value := range runtime { + setHeaderOverrideEntry(merged, key, value) + } + return merged +} + +func cloneDeletedHeaderKeys(source map[string]bool) map[string]bool { + if len(source) == 0 { + return map[string]bool{} + } + target := make(map[string]bool, len(source)) + for key, value := range source { + if !value { + continue + } + normalized := normalizeHeaderContextKey(key) + if normalized == "" { + continue + } + target[normalized] = true + } + return target +} + +func GetEffectiveHeaderOverride(info *RelayInfo) map[string]interface{} { + if info == nil { + return map[string]interface{}{} + } + base := getHeaderOverrideMap(info) + if !info.UseRuntimeHeadersOverride { + return cloneHeaderOverrideMap(base) + } + return mergeHeaderOverrideSource(base, info.RuntimeHeadersOverride, cloneDeletedHeaderKeys(info.RuntimeHeadersDeletedNormalized)) +} + func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) { // 检查是否包含 "operations" 字段 if opsValue, exists := paramOverride["operations"]; exists { @@ -480,6 +559,9 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte targetHeader = strings.TrimSpace(op.Path) } err = copyHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin) + if errors.Is(err, errSourceHeaderNotFound) { + err = nil + } if err == nil { contextJSON, err = marshalContextJSON(context) } @@ -493,6 +575,9 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte targetHeader = strings.TrimSpace(op.Path) } err = moveHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin) + if errors.Is(err, errSourceHeaderNotFound) { + err = nil + } if err == nil { contextJSON, err = marshalContextJSON(context) } @@ -647,8 +732,13 @@ func setHeaderOverrideInContext(context map[string]interface{}, headerName strin rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride) rawHeaders[headerName] = headerValue + normalizedHeaderName := normalizeHeaderContextKey(headerName) normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized) - normalizedHeaders[normalizeHeaderContextKey(headerName)] = headerValue + normalizedHeaders[normalizedHeaderName] = headerValue + if normalizedHeaderName != "" { + deletedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideDeleted) + delete(deletedHeaders, normalizedHeaderName) + } return nil } @@ -693,7 +783,12 @@ func deleteHeaderOverrideInContext(context map[string]interface{}, headerName st } normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized) - delete(normalizedHeaders, normalizeHeaderContextKey(headerName)) + normalizedHeaderName := normalizeHeaderContextKey(headerName) + delete(normalizedHeaders, normalizedHeaderName) + if normalizedHeaderName != "" { + deletedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideDeleted) + deletedHeaders[normalizedHeaderName] = true + } return nil } @@ -1062,9 +1157,39 @@ func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]in info.RuntimeHeadersOverride = lo.SliceToMap(sanitized, func(item lo.Entry[string, interface{}]) (string, interface{}) { return item.Key, item.Value }) + info.RuntimeHeadersDeletedNormalized = sanitizeRuntimeDeletedHeadersFromContext(context) info.UseRuntimeHeadersOverride = true } +func sanitizeRuntimeDeletedHeadersFromContext(context map[string]interface{}) map[string]bool { + deletedRaw, exists := context[paramOverrideContextHeaderOverrideDeleted] + if !exists { + return nil + } + deletedMap, ok := deletedRaw.(map[string]interface{}) + if !ok || len(deletedMap) == 0 { + return nil + } + entries := lo.Entries(deletedMap) + sanitized := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (string, bool) { + if keep, ok := item.Value.(bool); ok && !keep { + return "", false + } + normalized := normalizeHeaderContextKey(item.Key) + if normalized == "" { + return "", false + } + return normalized, true + }) + if len(sanitized) == 0 { + return nil + } + keys := lo.Uniq(sanitized) + return lo.SliceToMap(keys, func(item string) (string, bool) { + return item, true + }) +} + func moveValue(jsonStr, fromPath, toPath string) (string, error) { sourceValue := gjson.Get(jsonStr, fromPath) if !sourceValue.Exists() { @@ -1513,13 +1638,13 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} { ctx[paramOverrideContextRequestHeaders] = buildNormalizedHeaders(info.RequestHeaders) ctx[paramOverrideContextRequestHeadersRaw] = buildRawHeaders(info.RequestHeaders) - headerOverrideSource := getHeaderOverrideMap(info) - if info.UseRuntimeHeadersOverride { - headerOverrideSource = info.RuntimeHeadersOverride - } + headerOverrideSource := GetEffectiveHeaderOverride(info) rawHeaderOverride, normalizedHeaderOverride := buildHeaderOverrideContext(headerOverrideSource) ctx[paramOverrideContextHeaderOverride] = rawHeaderOverride ctx[paramOverrideContextHeaderOverrideNormalized] = normalizedHeaderOverride + ctx[paramOverrideContextHeaderOverrideDeleted] = lo.SliceToMap(lo.Keys(cloneDeletedHeaderKeys(info.RuntimeHeadersDeletedNormalized)), func(item string) (string, interface{}) { + return item, true + }) ctx["retry_index"] = info.RetryIndex ctx["is_retry"] = info.RetryIndex > 0 diff --git a/relay/common/override_test.go b/relay/common/override_test.go index 7905d42bd..a9dad7177 100644 --- a/relay/common/override_test.go +++ b/relay/common/override_test.go @@ -1097,6 +1097,76 @@ func TestApplyParamOverridePassHeadersSkipsMissingHeaders(t *testing.T) { } } +func TestApplyParamOverrideCopyHeaderSkipsMissingSource(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "copy_header", + "from": "X-Missing-Header", + "to": "X-Upstream-Auth", + }, + }, + } + ctx := map[string]interface{}{ + "request_headers_raw": map[string]interface{}{ + "Authorization": "Bearer token-123", + }, + "request_headers": map[string]interface{}{ + "authorization": "Bearer token-123", + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.7}`, string(out)) + + headers, ok := ctx["header_override"].(map[string]interface{}) + if !ok { + return + } + if _, exists := headers["X-Upstream-Auth"]; exists { + t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing") + } +} + +func TestApplyParamOverrideMoveHeaderSkipsMissingSource(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "move_header", + "from": "X-Missing-Header", + "to": "X-Upstream-Auth", + }, + }, + } + ctx := map[string]interface{}{ + "request_headers_raw": map[string]interface{}{ + "Authorization": "Bearer token-123", + }, + "request_headers": map[string]interface{}{ + "authorization": "Bearer token-123", + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.7}`, string(out)) + + headers, ok := ctx["header_override"].(map[string]interface{}) + if !ok { + return + } + if _, exists := headers["X-Upstream-Auth"]; exists { + t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing") + } +} + func TestApplyParamOverrideSyncFieldsHeaderToJSON(t *testing.T) { input := []byte(`{"model":"gpt-4"}`) override := map[string]interface{}{ @@ -1351,6 +1421,35 @@ func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) { } } +func TestGetEffectiveHeaderOverrideMergesRuntimeAndChannelOverrides(t *testing.T) { + info := &RelayInfo{ + UseRuntimeHeadersOverride: true, + RuntimeHeadersOverride: map[string]interface{}{ + "X-Runtime": "runtime-only", + }, + RuntimeHeadersDeletedNormalized: map[string]bool{ + "x-deleted": true, + }, + ChannelMeta: &ChannelMeta{ + HeadersOverride: map[string]interface{}{ + "X-Static": "static-value", + "X-Deleted": "should-not-exist", + }, + }, + } + + effective := GetEffectiveHeaderOverride(info) + if effective["X-Static"] != "static-value" { + t.Fatalf("expected X-Static from channel override, got: %v", effective["X-Static"]) + } + if effective["X-Runtime"] != "runtime-only" { + t.Fatalf("expected X-Runtime from runtime override, got: %v", effective["X-Runtime"]) + } + if _, exists := effective["X-Deleted"]; exists { + t.Fatalf("expected deleted headers to stay deleted in effective override") + } +} + func TestRemoveDisabledFieldsSkipWhenChannelPassThroughEnabled(t *testing.T) { input := `{ "service_tier":"flex", diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 8b0789c0d..946294d49 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -148,6 +148,7 @@ type RelayInfo struct { RetryIndex int LastError *types.NewAPIError RuntimeHeadersOverride map[string]interface{} + RuntimeHeadersDeletedNormalized map[string]bool UseRuntimeHeadersOverride bool PriceData types.PriceData