mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:25:00 +00:00
fix: merge runtime and channel header overrides, skip missing source headers
This commit is contained in:
@@ -173,10 +173,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
|
|||||||
return headerOverride, nil
|
return headerOverride, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
headerOverrideSource := info.HeadersOverride
|
headerOverrideSource := common.GetEffectiveHeaderOverride(info)
|
||||||
if info.UseRuntimeHeadersOverride {
|
|
||||||
headerOverrideSource = info.RuntimeHeadersOverride
|
|
||||||
}
|
|
||||||
|
|
||||||
passAll := false
|
passAll := false
|
||||||
var passthroughRegex []*regexp.Regexp
|
var passthroughRegex []*regexp.Regexp
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T)
|
|||||||
require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
|
require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) {
|
func TestProcessHeaderOverride_RuntimeOverrideMergesWithChannelOverride(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
@@ -107,8 +107,7 @@ func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "runtime-value", headers["X-Static"])
|
require.Equal(t, "runtime-value", headers["X-Static"])
|
||||||
require.Equal(t, "runtime-only", headers["X-Runtime"])
|
require.Equal(t, "runtime-only", headers["X-Runtime"])
|
||||||
_, ok := headers["X-Legacy"]
|
require.Equal(t, "legacy-only", headers["X-Legacy"])
|
||||||
require.False(t, ok)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
|
func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ const (
|
|||||||
paramOverrideContextRequestHeadersRaw = "request_headers_raw"
|
paramOverrideContextRequestHeadersRaw = "request_headers_raw"
|
||||||
paramOverrideContextHeaderOverride = "header_override"
|
paramOverrideContextHeaderOverride = "header_override"
|
||||||
paramOverrideContextHeaderOverrideNormalized = "header_override_normalized"
|
paramOverrideContextHeaderOverrideNormalized = "header_override_normalized"
|
||||||
|
paramOverrideContextHeaderOverrideDeleted = "header_override_deleted_normalized"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errSourceHeaderNotFound = errors.New("source header does not exist")
|
var errSourceHeaderNotFound = errors.New("source header does not exist")
|
||||||
@@ -160,6 +161,84 @@ func getHeaderOverrideMap(info *RelayInfo) map[string]interface{} {
|
|||||||
return info.ChannelMeta.HeadersOverride
|
return info.ChannelMeta.HeadersOverride
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cloneHeaderOverrideMap(source map[string]interface{}) map[string]interface{} {
|
||||||
|
if len(source) == 0 {
|
||||||
|
return map[string]interface{}{}
|
||||||
|
}
|
||||||
|
target := make(map[string]interface{}, len(source))
|
||||||
|
for key, value := range source {
|
||||||
|
target[key] = value
|
||||||
|
}
|
||||||
|
return target
|
||||||
|
}
|
||||||
|
|
||||||
|
func setHeaderOverrideEntry(target map[string]interface{}, key string, value interface{}) {
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
if key == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for existing := range target {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(existing), key) {
|
||||||
|
delete(target, existing)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
target[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
func isHeaderDeletedByRuntime(headerName string, deleted map[string]bool) bool {
|
||||||
|
if len(deleted) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
normalized := normalizeHeaderContextKey(headerName)
|
||||||
|
if normalized == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return deleted[normalized]
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeHeaderOverrideSource(base, runtime map[string]interface{}, deleted map[string]bool) map[string]interface{} {
|
||||||
|
merged := make(map[string]interface{}, len(base)+len(runtime))
|
||||||
|
for key, value := range base {
|
||||||
|
if isHeaderDeletedByRuntime(key, deleted) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
setHeaderOverrideEntry(merged, key, value)
|
||||||
|
}
|
||||||
|
for key, value := range runtime {
|
||||||
|
setHeaderOverrideEntry(merged, key, value)
|
||||||
|
}
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneDeletedHeaderKeys(source map[string]bool) map[string]bool {
|
||||||
|
if len(source) == 0 {
|
||||||
|
return map[string]bool{}
|
||||||
|
}
|
||||||
|
target := make(map[string]bool, len(source))
|
||||||
|
for key, value := range source {
|
||||||
|
if !value {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
normalized := normalizeHeaderContextKey(key)
|
||||||
|
if normalized == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
target[normalized] = true
|
||||||
|
}
|
||||||
|
return target
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetEffectiveHeaderOverride(info *RelayInfo) map[string]interface{} {
|
||||||
|
if info == nil {
|
||||||
|
return map[string]interface{}{}
|
||||||
|
}
|
||||||
|
base := getHeaderOverrideMap(info)
|
||||||
|
if !info.UseRuntimeHeadersOverride {
|
||||||
|
return cloneHeaderOverrideMap(base)
|
||||||
|
}
|
||||||
|
return mergeHeaderOverrideSource(base, info.RuntimeHeadersOverride, cloneDeletedHeaderKeys(info.RuntimeHeadersDeletedNormalized))
|
||||||
|
}
|
||||||
|
|
||||||
func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
|
func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
|
||||||
// 检查是否包含 "operations" 字段
|
// 检查是否包含 "operations" 字段
|
||||||
if opsValue, exists := paramOverride["operations"]; exists {
|
if opsValue, exists := paramOverride["operations"]; exists {
|
||||||
@@ -480,6 +559,9 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
|||||||
targetHeader = strings.TrimSpace(op.Path)
|
targetHeader = strings.TrimSpace(op.Path)
|
||||||
}
|
}
|
||||||
err = copyHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin)
|
err = copyHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin)
|
||||||
|
if errors.Is(err, errSourceHeaderNotFound) {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
if err == nil {
|
if err == nil {
|
||||||
contextJSON, err = marshalContextJSON(context)
|
contextJSON, err = marshalContextJSON(context)
|
||||||
}
|
}
|
||||||
@@ -493,6 +575,9 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
|||||||
targetHeader = strings.TrimSpace(op.Path)
|
targetHeader = strings.TrimSpace(op.Path)
|
||||||
}
|
}
|
||||||
err = moveHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin)
|
err = moveHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin)
|
||||||
|
if errors.Is(err, errSourceHeaderNotFound) {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
if err == nil {
|
if err == nil {
|
||||||
contextJSON, err = marshalContextJSON(context)
|
contextJSON, err = marshalContextJSON(context)
|
||||||
}
|
}
|
||||||
@@ -647,8 +732,13 @@ func setHeaderOverrideInContext(context map[string]interface{}, headerName strin
|
|||||||
rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride)
|
rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride)
|
||||||
rawHeaders[headerName] = headerValue
|
rawHeaders[headerName] = headerValue
|
||||||
|
|
||||||
|
normalizedHeaderName := normalizeHeaderContextKey(headerName)
|
||||||
normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
|
normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
|
||||||
normalizedHeaders[normalizeHeaderContextKey(headerName)] = headerValue
|
normalizedHeaders[normalizedHeaderName] = headerValue
|
||||||
|
if normalizedHeaderName != "" {
|
||||||
|
deletedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideDeleted)
|
||||||
|
delete(deletedHeaders, normalizedHeaderName)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -693,7 +783,12 @@ func deleteHeaderOverrideInContext(context map[string]interface{}, headerName st
|
|||||||
}
|
}
|
||||||
|
|
||||||
normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
|
normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
|
||||||
delete(normalizedHeaders, normalizeHeaderContextKey(headerName))
|
normalizedHeaderName := normalizeHeaderContextKey(headerName)
|
||||||
|
delete(normalizedHeaders, normalizedHeaderName)
|
||||||
|
if normalizedHeaderName != "" {
|
||||||
|
deletedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideDeleted)
|
||||||
|
deletedHeaders[normalizedHeaderName] = true
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1062,9 +1157,39 @@ func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]in
|
|||||||
info.RuntimeHeadersOverride = lo.SliceToMap(sanitized, func(item lo.Entry[string, interface{}]) (string, interface{}) {
|
info.RuntimeHeadersOverride = lo.SliceToMap(sanitized, func(item lo.Entry[string, interface{}]) (string, interface{}) {
|
||||||
return item.Key, item.Value
|
return item.Key, item.Value
|
||||||
})
|
})
|
||||||
|
info.RuntimeHeadersDeletedNormalized = sanitizeRuntimeDeletedHeadersFromContext(context)
|
||||||
info.UseRuntimeHeadersOverride = true
|
info.UseRuntimeHeadersOverride = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sanitizeRuntimeDeletedHeadersFromContext(context map[string]interface{}) map[string]bool {
|
||||||
|
deletedRaw, exists := context[paramOverrideContextHeaderOverrideDeleted]
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
deletedMap, ok := deletedRaw.(map[string]interface{})
|
||||||
|
if !ok || len(deletedMap) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
entries := lo.Entries(deletedMap)
|
||||||
|
sanitized := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (string, bool) {
|
||||||
|
if keep, ok := item.Value.(bool); ok && !keep {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
normalized := normalizeHeaderContextKey(item.Key)
|
||||||
|
if normalized == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return normalized, true
|
||||||
|
})
|
||||||
|
if len(sanitized) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
keys := lo.Uniq(sanitized)
|
||||||
|
return lo.SliceToMap(keys, func(item string) (string, bool) {
|
||||||
|
return item, true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
|
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
|
||||||
sourceValue := gjson.Get(jsonStr, fromPath)
|
sourceValue := gjson.Get(jsonStr, fromPath)
|
||||||
if !sourceValue.Exists() {
|
if !sourceValue.Exists() {
|
||||||
@@ -1513,13 +1638,13 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
|
|||||||
ctx[paramOverrideContextRequestHeaders] = buildNormalizedHeaders(info.RequestHeaders)
|
ctx[paramOverrideContextRequestHeaders] = buildNormalizedHeaders(info.RequestHeaders)
|
||||||
ctx[paramOverrideContextRequestHeadersRaw] = buildRawHeaders(info.RequestHeaders)
|
ctx[paramOverrideContextRequestHeadersRaw] = buildRawHeaders(info.RequestHeaders)
|
||||||
|
|
||||||
headerOverrideSource := getHeaderOverrideMap(info)
|
headerOverrideSource := GetEffectiveHeaderOverride(info)
|
||||||
if info.UseRuntimeHeadersOverride {
|
|
||||||
headerOverrideSource = info.RuntimeHeadersOverride
|
|
||||||
}
|
|
||||||
rawHeaderOverride, normalizedHeaderOverride := buildHeaderOverrideContext(headerOverrideSource)
|
rawHeaderOverride, normalizedHeaderOverride := buildHeaderOverrideContext(headerOverrideSource)
|
||||||
ctx[paramOverrideContextHeaderOverride] = rawHeaderOverride
|
ctx[paramOverrideContextHeaderOverride] = rawHeaderOverride
|
||||||
ctx[paramOverrideContextHeaderOverrideNormalized] = normalizedHeaderOverride
|
ctx[paramOverrideContextHeaderOverrideNormalized] = normalizedHeaderOverride
|
||||||
|
ctx[paramOverrideContextHeaderOverrideDeleted] = lo.SliceToMap(lo.Keys(cloneDeletedHeaderKeys(info.RuntimeHeadersDeletedNormalized)), func(item string) (string, interface{}) {
|
||||||
|
return item, true
|
||||||
|
})
|
||||||
|
|
||||||
ctx["retry_index"] = info.RetryIndex
|
ctx["retry_index"] = info.RetryIndex
|
||||||
ctx["is_retry"] = info.RetryIndex > 0
|
ctx["is_retry"] = info.RetryIndex > 0
|
||||||
|
|||||||
@@ -1097,6 +1097,76 @@ func TestApplyParamOverridePassHeadersSkipsMissingHeaders(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyParamOverrideCopyHeaderSkipsMissingSource(t *testing.T) {
|
||||||
|
input := []byte(`{"temperature":0.7}`)
|
||||||
|
override := map[string]interface{}{
|
||||||
|
"operations": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"mode": "copy_header",
|
||||||
|
"from": "X-Missing-Header",
|
||||||
|
"to": "X-Upstream-Auth",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := map[string]interface{}{
|
||||||
|
"request_headers_raw": map[string]interface{}{
|
||||||
|
"Authorization": "Bearer token-123",
|
||||||
|
},
|
||||||
|
"request_headers": map[string]interface{}{
|
||||||
|
"authorization": "Bearer token-123",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := ApplyParamOverride(input, override, ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||||
|
}
|
||||||
|
assertJSONEqual(t, `{"temperature":0.7}`, string(out))
|
||||||
|
|
||||||
|
headers, ok := ctx["header_override"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, exists := headers["X-Upstream-Auth"]; exists {
|
||||||
|
t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyParamOverrideMoveHeaderSkipsMissingSource(t *testing.T) {
|
||||||
|
input := []byte(`{"temperature":0.7}`)
|
||||||
|
override := map[string]interface{}{
|
||||||
|
"operations": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"mode": "move_header",
|
||||||
|
"from": "X-Missing-Header",
|
||||||
|
"to": "X-Upstream-Auth",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := map[string]interface{}{
|
||||||
|
"request_headers_raw": map[string]interface{}{
|
||||||
|
"Authorization": "Bearer token-123",
|
||||||
|
},
|
||||||
|
"request_headers": map[string]interface{}{
|
||||||
|
"authorization": "Bearer token-123",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := ApplyParamOverride(input, override, ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||||
|
}
|
||||||
|
assertJSONEqual(t, `{"temperature":0.7}`, string(out))
|
||||||
|
|
||||||
|
headers, ok := ctx["header_override"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, exists := headers["X-Upstream-Auth"]; exists {
|
||||||
|
t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestApplyParamOverrideSyncFieldsHeaderToJSON(t *testing.T) {
|
func TestApplyParamOverrideSyncFieldsHeaderToJSON(t *testing.T) {
|
||||||
input := []byte(`{"model":"gpt-4"}`)
|
input := []byte(`{"model":"gpt-4"}`)
|
||||||
override := map[string]interface{}{
|
override := map[string]interface{}{
|
||||||
@@ -1351,6 +1421,35 @@ func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetEffectiveHeaderOverrideMergesRuntimeAndChannelOverrides(t *testing.T) {
|
||||||
|
info := &RelayInfo{
|
||||||
|
UseRuntimeHeadersOverride: true,
|
||||||
|
RuntimeHeadersOverride: map[string]interface{}{
|
||||||
|
"X-Runtime": "runtime-only",
|
||||||
|
},
|
||||||
|
RuntimeHeadersDeletedNormalized: map[string]bool{
|
||||||
|
"x-deleted": true,
|
||||||
|
},
|
||||||
|
ChannelMeta: &ChannelMeta{
|
||||||
|
HeadersOverride: map[string]interface{}{
|
||||||
|
"X-Static": "static-value",
|
||||||
|
"X-Deleted": "should-not-exist",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
effective := GetEffectiveHeaderOverride(info)
|
||||||
|
if effective["X-Static"] != "static-value" {
|
||||||
|
t.Fatalf("expected X-Static from channel override, got: %v", effective["X-Static"])
|
||||||
|
}
|
||||||
|
if effective["X-Runtime"] != "runtime-only" {
|
||||||
|
t.Fatalf("expected X-Runtime from runtime override, got: %v", effective["X-Runtime"])
|
||||||
|
}
|
||||||
|
if _, exists := effective["X-Deleted"]; exists {
|
||||||
|
t.Fatalf("expected deleted headers to stay deleted in effective override")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRemoveDisabledFieldsSkipWhenChannelPassThroughEnabled(t *testing.T) {
|
func TestRemoveDisabledFieldsSkipWhenChannelPassThroughEnabled(t *testing.T) {
|
||||||
input := `{
|
input := `{
|
||||||
"service_tier":"flex",
|
"service_tier":"flex",
|
||||||
|
|||||||
@@ -148,6 +148,7 @@ type RelayInfo struct {
|
|||||||
RetryIndex int
|
RetryIndex int
|
||||||
LastError *types.NewAPIError
|
LastError *types.NewAPIError
|
||||||
RuntimeHeadersOverride map[string]interface{}
|
RuntimeHeadersOverride map[string]interface{}
|
||||||
|
RuntimeHeadersDeletedNormalized map[string]bool
|
||||||
UseRuntimeHeadersOverride bool
|
UseRuntimeHeadersOverride bool
|
||||||
|
|
||||||
PriceData types.PriceData
|
PriceData types.PriceData
|
||||||
|
|||||||
Reference in New Issue
Block a user