mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 00:46:42 +00:00
fix: merge runtime and channel header overrides, skip missing source headers
This commit is contained in:
@@ -173,10 +173,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
|
||||
return headerOverride, nil
|
||||
}
|
||||
|
||||
headerOverrideSource := info.HeadersOverride
|
||||
if info.UseRuntimeHeadersOverride {
|
||||
headerOverrideSource = info.RuntimeHeadersOverride
|
||||
}
|
||||
headerOverrideSource := common.GetEffectiveHeaderOverride(info)
|
||||
|
||||
passAll := false
|
||||
var passthroughRegex []*regexp.Regexp
|
||||
|
||||
@@ -80,7 +80,7 @@ func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T)
|
||||
require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
|
||||
}
|
||||
|
||||
func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) {
|
||||
func TestProcessHeaderOverride_RuntimeOverrideMergesWithChannelOverride(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
@@ -107,8 +107,7 @@ func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "runtime-value", headers["X-Static"])
|
||||
require.Equal(t, "runtime-only", headers["X-Runtime"])
|
||||
_, ok := headers["X-Legacy"]
|
||||
require.False(t, ok)
|
||||
require.Equal(t, "legacy-only", headers["X-Legacy"])
|
||||
}
|
||||
|
||||
func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
|
||||
|
||||
@@ -22,6 +22,7 @@ const (
|
||||
paramOverrideContextRequestHeadersRaw = "request_headers_raw"
|
||||
paramOverrideContextHeaderOverride = "header_override"
|
||||
paramOverrideContextHeaderOverrideNormalized = "header_override_normalized"
|
||||
paramOverrideContextHeaderOverrideDeleted = "header_override_deleted_normalized"
|
||||
)
|
||||
|
||||
var errSourceHeaderNotFound = errors.New("source header does not exist")
|
||||
@@ -160,6 +161,84 @@ func getHeaderOverrideMap(info *RelayInfo) map[string]interface{} {
|
||||
return info.ChannelMeta.HeadersOverride
|
||||
}
|
||||
|
||||
func cloneHeaderOverrideMap(source map[string]interface{}) map[string]interface{} {
|
||||
if len(source) == 0 {
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
target := make(map[string]interface{}, len(source))
|
||||
for key, value := range source {
|
||||
target[key] = value
|
||||
}
|
||||
return target
|
||||
}
|
||||
|
||||
func setHeaderOverrideEntry(target map[string]interface{}, key string, value interface{}) {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
for existing := range target {
|
||||
if strings.EqualFold(strings.TrimSpace(existing), key) {
|
||||
delete(target, existing)
|
||||
}
|
||||
}
|
||||
target[key] = value
|
||||
}
|
||||
|
||||
func isHeaderDeletedByRuntime(headerName string, deleted map[string]bool) bool {
|
||||
if len(deleted) == 0 {
|
||||
return false
|
||||
}
|
||||
normalized := normalizeHeaderContextKey(headerName)
|
||||
if normalized == "" {
|
||||
return false
|
||||
}
|
||||
return deleted[normalized]
|
||||
}
|
||||
|
||||
func mergeHeaderOverrideSource(base, runtime map[string]interface{}, deleted map[string]bool) map[string]interface{} {
|
||||
merged := make(map[string]interface{}, len(base)+len(runtime))
|
||||
for key, value := range base {
|
||||
if isHeaderDeletedByRuntime(key, deleted) {
|
||||
continue
|
||||
}
|
||||
setHeaderOverrideEntry(merged, key, value)
|
||||
}
|
||||
for key, value := range runtime {
|
||||
setHeaderOverrideEntry(merged, key, value)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func cloneDeletedHeaderKeys(source map[string]bool) map[string]bool {
|
||||
if len(source) == 0 {
|
||||
return map[string]bool{}
|
||||
}
|
||||
target := make(map[string]bool, len(source))
|
||||
for key, value := range source {
|
||||
if !value {
|
||||
continue
|
||||
}
|
||||
normalized := normalizeHeaderContextKey(key)
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
target[normalized] = true
|
||||
}
|
||||
return target
|
||||
}
|
||||
|
||||
func GetEffectiveHeaderOverride(info *RelayInfo) map[string]interface{} {
|
||||
if info == nil {
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
base := getHeaderOverrideMap(info)
|
||||
if !info.UseRuntimeHeadersOverride {
|
||||
return cloneHeaderOverrideMap(base)
|
||||
}
|
||||
return mergeHeaderOverrideSource(base, info.RuntimeHeadersOverride, cloneDeletedHeaderKeys(info.RuntimeHeadersDeletedNormalized))
|
||||
}
|
||||
|
||||
func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
|
||||
// 检查是否包含 "operations" 字段
|
||||
if opsValue, exists := paramOverride["operations"]; exists {
|
||||
@@ -480,6 +559,9 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
targetHeader = strings.TrimSpace(op.Path)
|
||||
}
|
||||
err = copyHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin)
|
||||
if errors.Is(err, errSourceHeaderNotFound) {
|
||||
err = nil
|
||||
}
|
||||
if err == nil {
|
||||
contextJSON, err = marshalContextJSON(context)
|
||||
}
|
||||
@@ -493,6 +575,9 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
targetHeader = strings.TrimSpace(op.Path)
|
||||
}
|
||||
err = moveHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin)
|
||||
if errors.Is(err, errSourceHeaderNotFound) {
|
||||
err = nil
|
||||
}
|
||||
if err == nil {
|
||||
contextJSON, err = marshalContextJSON(context)
|
||||
}
|
||||
@@ -647,8 +732,13 @@ func setHeaderOverrideInContext(context map[string]interface{}, headerName strin
|
||||
rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride)
|
||||
rawHeaders[headerName] = headerValue
|
||||
|
||||
normalizedHeaderName := normalizeHeaderContextKey(headerName)
|
||||
normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
|
||||
normalizedHeaders[normalizeHeaderContextKey(headerName)] = headerValue
|
||||
normalizedHeaders[normalizedHeaderName] = headerValue
|
||||
if normalizedHeaderName != "" {
|
||||
deletedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideDeleted)
|
||||
delete(deletedHeaders, normalizedHeaderName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -693,7 +783,12 @@ func deleteHeaderOverrideInContext(context map[string]interface{}, headerName st
|
||||
}
|
||||
|
||||
normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
|
||||
delete(normalizedHeaders, normalizeHeaderContextKey(headerName))
|
||||
normalizedHeaderName := normalizeHeaderContextKey(headerName)
|
||||
delete(normalizedHeaders, normalizedHeaderName)
|
||||
if normalizedHeaderName != "" {
|
||||
deletedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideDeleted)
|
||||
deletedHeaders[normalizedHeaderName] = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1062,9 +1157,39 @@ func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]in
|
||||
info.RuntimeHeadersOverride = lo.SliceToMap(sanitized, func(item lo.Entry[string, interface{}]) (string, interface{}) {
|
||||
return item.Key, item.Value
|
||||
})
|
||||
info.RuntimeHeadersDeletedNormalized = sanitizeRuntimeDeletedHeadersFromContext(context)
|
||||
info.UseRuntimeHeadersOverride = true
|
||||
}
|
||||
|
||||
func sanitizeRuntimeDeletedHeadersFromContext(context map[string]interface{}) map[string]bool {
|
||||
deletedRaw, exists := context[paramOverrideContextHeaderOverrideDeleted]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
deletedMap, ok := deletedRaw.(map[string]interface{})
|
||||
if !ok || len(deletedMap) == 0 {
|
||||
return nil
|
||||
}
|
||||
entries := lo.Entries(deletedMap)
|
||||
sanitized := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (string, bool) {
|
||||
if keep, ok := item.Value.(bool); ok && !keep {
|
||||
return "", false
|
||||
}
|
||||
normalized := normalizeHeaderContextKey(item.Key)
|
||||
if normalized == "" {
|
||||
return "", false
|
||||
}
|
||||
return normalized, true
|
||||
})
|
||||
if len(sanitized) == 0 {
|
||||
return nil
|
||||
}
|
||||
keys := lo.Uniq(sanitized)
|
||||
return lo.SliceToMap(keys, func(item string) (string, bool) {
|
||||
return item, true
|
||||
})
|
||||
}
|
||||
|
||||
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
|
||||
sourceValue := gjson.Get(jsonStr, fromPath)
|
||||
if !sourceValue.Exists() {
|
||||
@@ -1513,13 +1638,13 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
|
||||
ctx[paramOverrideContextRequestHeaders] = buildNormalizedHeaders(info.RequestHeaders)
|
||||
ctx[paramOverrideContextRequestHeadersRaw] = buildRawHeaders(info.RequestHeaders)
|
||||
|
||||
headerOverrideSource := getHeaderOverrideMap(info)
|
||||
if info.UseRuntimeHeadersOverride {
|
||||
headerOverrideSource = info.RuntimeHeadersOverride
|
||||
}
|
||||
headerOverrideSource := GetEffectiveHeaderOverride(info)
|
||||
rawHeaderOverride, normalizedHeaderOverride := buildHeaderOverrideContext(headerOverrideSource)
|
||||
ctx[paramOverrideContextHeaderOverride] = rawHeaderOverride
|
||||
ctx[paramOverrideContextHeaderOverrideNormalized] = normalizedHeaderOverride
|
||||
ctx[paramOverrideContextHeaderOverrideDeleted] = lo.SliceToMap(lo.Keys(cloneDeletedHeaderKeys(info.RuntimeHeadersDeletedNormalized)), func(item string) (string, interface{}) {
|
||||
return item, true
|
||||
})
|
||||
|
||||
ctx["retry_index"] = info.RetryIndex
|
||||
ctx["is_retry"] = info.RetryIndex > 0
|
||||
|
||||
@@ -1097,6 +1097,76 @@ func TestApplyParamOverridePassHeadersSkipsMissingHeaders(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideCopyHeaderSkipsMissingSource(t *testing.T) {
|
||||
input := []byte(`{"temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"mode": "copy_header",
|
||||
"from": "X-Missing-Header",
|
||||
"to": "X-Upstream-Auth",
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := map[string]interface{}{
|
||||
"request_headers_raw": map[string]interface{}{
|
||||
"Authorization": "Bearer token-123",
|
||||
},
|
||||
"request_headers": map[string]interface{}{
|
||||
"authorization": "Bearer token-123",
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"temperature":0.7}`, string(out))
|
||||
|
||||
headers, ok := ctx["header_override"].(map[string]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, exists := headers["X-Upstream-Auth"]; exists {
|
||||
t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideMoveHeaderSkipsMissingSource(t *testing.T) {
|
||||
input := []byte(`{"temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"mode": "move_header",
|
||||
"from": "X-Missing-Header",
|
||||
"to": "X-Upstream-Auth",
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := map[string]interface{}{
|
||||
"request_headers_raw": map[string]interface{}{
|
||||
"Authorization": "Bearer token-123",
|
||||
},
|
||||
"request_headers": map[string]interface{}{
|
||||
"authorization": "Bearer token-123",
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"temperature":0.7}`, string(out))
|
||||
|
||||
headers, ok := ctx["header_override"].(map[string]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, exists := headers["X-Upstream-Auth"]; exists {
|
||||
t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideSyncFieldsHeaderToJSON(t *testing.T) {
|
||||
input := []byte(`{"model":"gpt-4"}`)
|
||||
override := map[string]interface{}{
|
||||
@@ -1351,6 +1421,35 @@ func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetEffectiveHeaderOverrideMergesRuntimeAndChannelOverrides(t *testing.T) {
|
||||
info := &RelayInfo{
|
||||
UseRuntimeHeadersOverride: true,
|
||||
RuntimeHeadersOverride: map[string]interface{}{
|
||||
"X-Runtime": "runtime-only",
|
||||
},
|
||||
RuntimeHeadersDeletedNormalized: map[string]bool{
|
||||
"x-deleted": true,
|
||||
},
|
||||
ChannelMeta: &ChannelMeta{
|
||||
HeadersOverride: map[string]interface{}{
|
||||
"X-Static": "static-value",
|
||||
"X-Deleted": "should-not-exist",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
effective := GetEffectiveHeaderOverride(info)
|
||||
if effective["X-Static"] != "static-value" {
|
||||
t.Fatalf("expected X-Static from channel override, got: %v", effective["X-Static"])
|
||||
}
|
||||
if effective["X-Runtime"] != "runtime-only" {
|
||||
t.Fatalf("expected X-Runtime from runtime override, got: %v", effective["X-Runtime"])
|
||||
}
|
||||
if _, exists := effective["X-Deleted"]; exists {
|
||||
t.Fatalf("expected deleted headers to stay deleted in effective override")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveDisabledFieldsSkipWhenChannelPassThroughEnabled(t *testing.T) {
|
||||
input := `{
|
||||
"service_tier":"flex",
|
||||
|
||||
@@ -148,6 +148,7 @@ type RelayInfo struct {
|
||||
RetryIndex int
|
||||
LastError *types.NewAPIError
|
||||
RuntimeHeadersOverride map[string]interface{}
|
||||
RuntimeHeadersDeletedNormalized map[string]bool
|
||||
UseRuntimeHeadersOverride bool
|
||||
|
||||
PriceData types.PriceData
|
||||
|
||||
Reference in New Issue
Block a user