feat: unify param/header overrides with retry-aware conditions and flexible header operations

This commit is contained in:
Seefs
2026-02-22 00:45:49 +08:00
parent ff76e75f4c
commit 91b300f522
14 changed files with 738 additions and 87 deletions

View File

@@ -385,7 +385,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
//}
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
if err != nil {
if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok {
return testResult{

View File

@@ -168,11 +168,19 @@ func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey str
// Passthrough rules are applied first, then normal overrides are applied, so explicit overrides win.
func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) {
headerOverride := make(map[string]string)
if info == nil {
return headerOverride, nil
}
headerOverrideSource := info.HeadersOverride
if info.UseRuntimeHeadersOverride {
headerOverrideSource = info.RuntimeHeadersOverride
}
passAll := false
var passthroughRegex []*regexp.Regexp
if !info.IsChannelTest {
for k := range info.HeadersOverride {
for k := range headerOverrideSource {
key := strings.TrimSpace(k)
if key == "" {
continue
@@ -232,7 +240,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
}
}
for k, v := range info.HeadersOverride {
for k, v := range headerOverrideSource {
if isHeaderPassthroughRuleKey(k) {
continue
}

View File

@@ -79,3 +79,34 @@ func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T)
require.NoError(t, err)
require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
}
func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
info := &relaycommon.RelayInfo{
IsChannelTest: false,
UseRuntimeHeadersOverride: true,
RuntimeHeadersOverride: map[string]any{
"X-Static": "runtime-value",
"X-Runtime": "runtime-only",
},
ChannelMeta: &relaycommon.ChannelMeta{
HeadersOverride: map[string]any{
"X-Static": "legacy-value",
"X-Legacy": "legacy-only",
},
},
}
headers, err := processHeaderOverride(info, ctx)
require.NoError(t, err)
require.Equal(t, "runtime-value", headers["X-Static"])
require.Equal(t, "runtime-only", headers["X-Runtime"])
_, ok := headers["X-Legacy"]
require.False(t, ok)
}

View File

@@ -70,7 +70,6 @@ func applySystemPromptIfNeeded(c *gin.Context, info *relaycommon.RelayInfo, requ
}
func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, adaptor channel.Adaptor, request *dto.GeneralOpenAIRequest) (*dto.Usage, *types.NewAPIError) {
overrideCtx := relaycommon.BuildParamOverrideContext(info)
chatJSON, err := common.Marshal(request)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
@@ -82,7 +81,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
}
if len(info.ParamOverride) > 0 {
chatJSON, err = relaycommon.ApplyParamOverride(chatJSON, info.ParamOverride, overrideCtx)
chatJSON, err = relaycommon.ApplyParamOverrideWithRelayInfo(chatJSON, info)
if err != nil {
return nil, newAPIErrorFromParamOverride(err)
}

View File

@@ -153,7 +153,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
if err != nil {
return newAPIErrorFromParamOverride(err)
}

View File

@@ -10,12 +10,20 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/types"
"github.com/samber/lo"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
var negativeIndexRegexp = regexp.MustCompile(`\.(-\d+)`)
const (
paramOverrideContextRequestHeaders = "request_headers"
paramOverrideContextRequestHeadersRaw = "request_headers_raw"
paramOverrideContextHeaderOverride = "header_override"
paramOverrideContextHeaderOverrideNormalized = "header_override_normalized"
)
type ConditionOperation struct {
Path string `json:"path"` // JSON路径
Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte
@@ -26,7 +34,7 @@ type ConditionOperation struct {
type ParamOperation struct {
Path string `json:"path"`
Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace, return_error, prune_objects
Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace, return_error, prune_objects, set_header, delete_header, copy_header, move_header
Value interface{} `json:"value"`
KeepOrigin bool `json:"keep_origin"`
From string `json:"from,omitempty"`
@@ -121,6 +129,35 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c
return applyOperationsLegacy(jsonData, paramOverride)
}
func ApplyParamOverrideWithRelayInfo(jsonData []byte, info *RelayInfo) ([]byte, error) {
paramOverride := getParamOverrideMap(info)
if len(paramOverride) == 0 {
return jsonData, nil
}
overrideCtx := BuildParamOverrideContext(info)
result, err := ApplyParamOverride(jsonData, paramOverride, overrideCtx)
if err != nil {
return nil, err
}
syncRuntimeHeaderOverrideFromContext(info, overrideCtx)
return result, nil
}
func getParamOverrideMap(info *RelayInfo) map[string]interface{} {
if info == nil || info.ChannelMeta == nil {
return nil
}
return info.ChannelMeta.ParamOverride
}
func getHeaderOverrideMap(info *RelayInfo) map[string]interface{} {
if info == nil || info.ChannelMeta == nil {
return nil
}
return info.ChannelMeta.HeadersOverride
}
func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
// 检查是否包含 "operations" 字段
if opsValue, exists := paramOverride["operations"]; exists {
@@ -161,29 +198,11 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation,
// 解析条件
if conditions, exists := opMap["conditions"]; exists {
if condSlice, ok := conditions.([]interface{}); ok {
for _, cond := range condSlice {
if condMap, ok := cond.(map[string]interface{}); ok {
condition := ConditionOperation{}
if path, ok := condMap["path"].(string); ok {
condition.Path = path
}
if mode, ok := condMap["mode"].(string); ok {
condition.Mode = mode
}
if value, ok := condMap["value"]; ok {
condition.Value = value
}
if invert, ok := condMap["invert"].(bool); ok {
condition.Invert = invert
}
if passMissingKey, ok := condMap["pass_missing_key"].(bool); ok {
condition.PassMissingKey = passMissingKey
}
operation.Conditions = append(operation.Conditions, condition)
}
}
parsedConditions, err := parseConditionOperations(conditions)
if err != nil {
return nil, false
}
operation.Conditions = append(operation.Conditions, parsedConditions...)
}
operations = append(operations, operation)
@@ -212,20 +231,9 @@ func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperatio
}
if strings.ToUpper(logic) == "AND" {
for _, result := range results {
if !result {
return false, nil
}
}
return true, nil
} else {
for _, result := range results {
if result {
return true, nil
}
}
return false, nil
return lo.EveryBy(results, func(item bool) bool { return item }), nil
}
return lo.SomeBy(results, func(item bool) bool { return item }), nil
}
func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) {
@@ -382,13 +390,10 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}
}
func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) {
var contextJSON string
if conditionContext != nil && len(conditionContext) > 0 {
ctxBytes, err := common.Marshal(conditionContext)
if err != nil {
return "", fmt.Errorf("failed to marshal condition context: %v", err)
}
contextJSON = string(ctxBytes)
context := ensureContextMap(conditionContext)
contextJSON, err := marshalContextJSON(context)
if err != nil {
return "", fmt.Errorf("failed to marshal condition context: %v", err)
}
result := jsonStr
@@ -453,6 +458,42 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
return "", returnErr
case "prune_objects":
result, err = pruneObjects(result, opPath, contextJSON, op.Value)
case "set_header":
err = setHeaderOverrideInContext(context, op.Path, op.Value, op.KeepOrigin)
if err == nil {
contextJSON, err = marshalContextJSON(context)
}
case "delete_header":
err = deleteHeaderOverrideInContext(context, op.Path)
if err == nil {
contextJSON, err = marshalContextJSON(context)
}
case "copy_header":
sourceHeader := strings.TrimSpace(op.From)
targetHeader := strings.TrimSpace(op.To)
if sourceHeader == "" {
sourceHeader = strings.TrimSpace(op.Path)
}
if targetHeader == "" {
targetHeader = strings.TrimSpace(op.Path)
}
err = copyHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin)
if err == nil {
contextJSON, err = marshalContextJSON(context)
}
case "move_header":
sourceHeader := strings.TrimSpace(op.From)
targetHeader := strings.TrimSpace(op.To)
if sourceHeader == "" {
sourceHeader = strings.TrimSpace(op.Path)
}
if targetHeader == "" {
targetHeader = strings.TrimSpace(op.Path)
}
err = moveHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin)
if err == nil {
contextJSON, err = marshalContextJSON(context)
}
default:
return "", fmt.Errorf("unknown operation: %s", op.Mode)
}
@@ -543,6 +584,276 @@ func parseOverrideInt(v interface{}) (int, bool) {
}
}
func ensureContextMap(conditionContext map[string]interface{}) map[string]interface{} {
if conditionContext != nil {
return conditionContext
}
return make(map[string]interface{})
}
func marshalContextJSON(context map[string]interface{}) (string, error) {
if context == nil || len(context) == 0 {
return "", nil
}
ctxBytes, err := common.Marshal(context)
if err != nil {
return "", err
}
return string(ctxBytes), nil
}
func setHeaderOverrideInContext(context map[string]interface{}, headerName string, value interface{}, keepOrigin bool) error {
headerName = strings.TrimSpace(headerName)
if headerName == "" {
return fmt.Errorf("header name is required")
}
if keepOrigin {
if _, exists := getHeaderValueFromContext(context, headerName); exists {
return nil
}
}
if value == nil {
return fmt.Errorf("header value is required")
}
headerValue := strings.TrimSpace(fmt.Sprintf("%v", value))
if headerValue == "" {
return fmt.Errorf("header value is required")
}
rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride)
rawHeaders[headerName] = headerValue
normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
normalizedHeaders[normalizeHeaderContextKey(headerName)] = headerValue
return nil
}
func copyHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error {
fromHeader = strings.TrimSpace(fromHeader)
toHeader = strings.TrimSpace(toHeader)
if fromHeader == "" || toHeader == "" {
return fmt.Errorf("copy_header from/to is required")
}
value, exists := getHeaderValueFromContext(context, fromHeader)
if !exists {
return fmt.Errorf("source header does not exist: %s", fromHeader)
}
return setHeaderOverrideInContext(context, toHeader, value, keepOrigin)
}
func moveHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error {
fromHeader = strings.TrimSpace(fromHeader)
toHeader = strings.TrimSpace(toHeader)
if fromHeader == "" || toHeader == "" {
return fmt.Errorf("move_header from/to is required")
}
if err := copyHeaderInContext(context, fromHeader, toHeader, keepOrigin); err != nil {
return err
}
if strings.EqualFold(fromHeader, toHeader) {
return nil
}
return deleteHeaderOverrideInContext(context, fromHeader)
}
func deleteHeaderOverrideInContext(context map[string]interface{}, headerName string) error {
headerName = strings.TrimSpace(headerName)
if headerName == "" {
return fmt.Errorf("header name is required")
}
rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride)
for key := range rawHeaders {
if strings.EqualFold(strings.TrimSpace(key), headerName) {
delete(rawHeaders, key)
}
}
normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
delete(normalizedHeaders, normalizeHeaderContextKey(headerName))
return nil
}
func ensureMapKeyInContext(context map[string]interface{}, key string) map[string]interface{} {
if context == nil {
return map[string]interface{}{}
}
if existing, ok := context[key]; ok {
if mapVal, ok := existing.(map[string]interface{}); ok {
return mapVal
}
}
result := make(map[string]interface{})
context[key] = result
return result
}
func getHeaderValueFromContext(context map[string]interface{}, headerName string) (string, bool) {
headerName = strings.TrimSpace(headerName)
if headerName == "" {
return "", false
}
if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextHeaderOverride), headerName); ok {
return value, true
}
if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextRequestHeadersRaw), headerName); ok {
return value, true
}
normalizedName := normalizeHeaderContextKey(headerName)
if normalizedName == "" {
return "", false
}
if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized), normalizedName); ok {
return value, true
}
if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextRequestHeaders), normalizedName); ok {
return value, true
}
return "", false
}
func findHeaderValueInMap(source map[string]interface{}, key string) (string, bool) {
if len(source) == 0 {
return "", false
}
entries := lo.Entries(source)
entry, ok := lo.Find(entries, func(item lo.Entry[string, interface{}]) bool {
return strings.EqualFold(strings.TrimSpace(item.Key), key)
})
if !ok {
return "", false
}
value := strings.TrimSpace(fmt.Sprintf("%v", entry.Value))
if value == "" {
return "", false
}
return value, true
}
func normalizeHeaderContextKey(key string) string {
key = strings.TrimSpace(strings.ToLower(key))
if key == "" {
return ""
}
var b strings.Builder
b.Grow(len(key))
previousUnderscore := false
for _, r := range key {
switch {
case r >= 'a' && r <= 'z':
b.WriteRune(r)
previousUnderscore = false
case r >= '0' && r <= '9':
b.WriteRune(r)
previousUnderscore = false
default:
if !previousUnderscore {
b.WriteByte('_')
previousUnderscore = true
}
}
}
result := strings.Trim(b.String(), "_")
return result
}
func buildNormalizedHeaders(headers map[string]string) map[string]interface{} {
if len(headers) == 0 {
return map[string]interface{}{}
}
entries := lo.Entries(headers)
normalizedEntries := lo.FilterMap(entries, func(item lo.Entry[string, string], _ int) (lo.Entry[string, string], bool) {
normalized := normalizeHeaderContextKey(item.Key)
value := strings.TrimSpace(item.Value)
if normalized == "" || value == "" {
return lo.Entry[string, string]{}, false
}
return lo.Entry[string, string]{Key: normalized, Value: value}, true
})
return lo.SliceToMap(normalizedEntries, func(item lo.Entry[string, string]) (string, interface{}) {
return item.Key, item.Value
})
}
func buildRawHeaders(headers map[string]string) map[string]interface{} {
if len(headers) == 0 {
return map[string]interface{}{}
}
entries := lo.Entries(headers)
rawEntries := lo.FilterMap(entries, func(item lo.Entry[string, string], _ int) (lo.Entry[string, string], bool) {
key := strings.TrimSpace(item.Key)
value := strings.TrimSpace(item.Value)
if key == "" || value == "" {
return lo.Entry[string, string]{}, false
}
return lo.Entry[string, string]{Key: key, Value: value}, true
})
return lo.SliceToMap(rawEntries, func(item lo.Entry[string, string]) (string, interface{}) {
return item.Key, item.Value
})
}
func buildHeaderOverrideContext(headers map[string]interface{}) (map[string]interface{}, map[string]interface{}) {
if len(headers) == 0 {
return map[string]interface{}{}, map[string]interface{}{}
}
entries := lo.Entries(headers)
rawEntries := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (lo.Entry[string, string], bool) {
key := strings.TrimSpace(item.Key)
value := strings.TrimSpace(fmt.Sprintf("%v", item.Value))
if key == "" || value == "" {
return lo.Entry[string, string]{}, false
}
return lo.Entry[string, string]{Key: key, Value: value}, true
})
raw := lo.SliceToMap(rawEntries, func(item lo.Entry[string, string]) (string, interface{}) {
return item.Key, item.Value
})
normalizedEntries := lo.FilterMap(rawEntries, func(item lo.Entry[string, string], _ int) (lo.Entry[string, string], bool) {
normalized := normalizeHeaderContextKey(item.Key)
if normalized == "" {
return lo.Entry[string, string]{}, false
}
return lo.Entry[string, string]{Key: normalized, Value: item.Value}, true
})
normalized := lo.SliceToMap(normalizedEntries, func(item lo.Entry[string, string]) (string, interface{}) {
return item.Key, item.Value
})
return raw, normalized
}
func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]interface{}) {
if info == nil || context == nil {
return
}
raw, exists := context[paramOverrideContextHeaderOverride]
if !exists {
return
}
rawMap, ok := raw.(map[string]interface{})
if !ok {
return
}
entries := lo.Entries(rawMap)
sanitized := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (lo.Entry[string, interface{}], bool) {
key := strings.TrimSpace(item.Key)
if key == "" {
return lo.Entry[string, interface{}]{}, false
}
value := strings.TrimSpace(fmt.Sprintf("%v", item.Value))
if value == "" {
return lo.Entry[string, interface{}]{}, false
}
return lo.Entry[string, interface{}]{Key: key, Value: value}, true
})
info.RuntimeHeadersOverride = lo.SliceToMap(sanitized, func(item lo.Entry[string, interface{}]) (string, interface{}) {
return item.Key, item.Value
})
info.UseRuntimeHeadersOverride = true
}
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
sourceValue := gjson.Get(jsonStr, fromPath)
if !sourceValue.Exists() {
@@ -824,38 +1135,56 @@ func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) {
}
func parseConditionOperations(raw interface{}) ([]ConditionOperation, error) {
items, ok := raw.([]interface{})
if !ok {
return nil, fmt.Errorf("conditions must be an array")
switch typed := raw.(type) {
case map[string]interface{}:
entries := lo.Entries(typed)
conditions := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (ConditionOperation, bool) {
path := strings.TrimSpace(item.Key)
if path == "" {
return ConditionOperation{}, false
}
return ConditionOperation{
Path: path,
Mode: "full",
Value: item.Value,
}, true
})
if len(conditions) == 0 {
return nil, fmt.Errorf("conditions object must contain at least one key")
}
return conditions, nil
case []interface{}:
items := typed
result := make([]ConditionOperation, 0, len(items))
for _, item := range items {
itemMap, ok := item.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("condition must be object")
}
path, _ := itemMap["path"].(string)
mode, _ := itemMap["mode"].(string)
if strings.TrimSpace(path) == "" || strings.TrimSpace(mode) == "" {
return nil, fmt.Errorf("condition path/mode is required")
}
condition := ConditionOperation{
Path: path,
Mode: mode,
}
if value, exists := itemMap["value"]; exists {
condition.Value = value
}
if invert, ok := itemMap["invert"].(bool); ok {
condition.Invert = invert
}
if passMissingKey, ok := itemMap["pass_missing_key"].(bool); ok {
condition.PassMissingKey = passMissingKey
}
result = append(result, condition)
}
return result, nil
default:
return nil, fmt.Errorf("conditions must be an array or object")
}
result := make([]ConditionOperation, 0, len(items))
for _, item := range items {
itemMap, ok := item.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("condition must be object")
}
path, _ := itemMap["path"].(string)
mode, _ := itemMap["mode"].(string)
if strings.TrimSpace(path) == "" || strings.TrimSpace(mode) == "" {
return nil, fmt.Errorf("condition path/mode is required")
}
condition := ConditionOperation{
Path: path,
Mode: mode,
}
if value, exists := itemMap["value"]; exists {
condition.Value = value
}
if invert, ok := itemMap["invert"].(bool); ok {
condition.Invert = invert
}
if passMissingKey, ok := itemMap["pass_missing_key"].(bool); ok {
condition.PassMissingKey = passMissingKey
}
result = append(result, condition)
}
return result, nil
}
func pruneObjectsNode(node interface{}, options pruneObjectsOptions, contextJSON string, isRoot bool) (interface{}, bool, error) {
@@ -970,6 +1299,17 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
}
}
ctx[paramOverrideContextRequestHeaders] = buildNormalizedHeaders(info.RequestHeaders)
ctx[paramOverrideContextRequestHeadersRaw] = buildRawHeaders(info.RequestHeaders)
headerOverrideSource := getHeaderOverrideMap(info)
if info.UseRuntimeHeadersOverride {
headerOverrideSource = info.RuntimeHeadersOverride
}
rawHeaderOverride, normalizedHeaderOverride := buildHeaderOverrideContext(headerOverrideSource)
ctx[paramOverrideContextHeaderOverride] = rawHeaderOverride
ctx[paramOverrideContextHeaderOverrideNormalized] = normalizedHeaderOverride
ctx["retry_index"] = info.RetryIndex
ctx["is_retry"] = info.RetryIndex > 0
ctx["retry"] = map[string]interface{}{

View File

@@ -956,6 +956,254 @@ func TestApplyParamOverrideConditionFromRetryAndLastErrorContext(t *testing.T) {
assertJSONEqual(t, `{"temperature":0.1}`, string(out))
}
func TestApplyParamOverrideConditionFromRequestHeaders(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "temperature",
"mode": "set",
"value": 0.1,
"conditions": []interface{}{
map[string]interface{}{
"path": "request_headers.authorization",
"mode": "contains",
"value": "Bearer ",
},
},
},
},
}
ctx := map[string]interface{}{
"request_headers": map[string]interface{}{
"authorization": "Bearer token-123",
},
}
out, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.1}`, string(out))
}
func TestApplyParamOverrideSetHeaderAndUseInLaterCondition(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "set_header",
"path": "X-Debug-Mode",
"value": "enabled",
},
map[string]interface{}{
"path": "temperature",
"mode": "set",
"value": 0.1,
"conditions": []interface{}{
map[string]interface{}{
"path": "header_override_normalized.x_debug_mode",
"mode": "full",
"value": "enabled",
},
},
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.1}`, string(out))
}
func TestApplyParamOverrideCopyHeaderFromRequestHeaders(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "copy_header",
"from": "Authorization",
"to": "X-Upstream-Auth",
},
map[string]interface{}{
"path": "temperature",
"mode": "set",
"value": 0.1,
"conditions": []interface{}{
map[string]interface{}{
"path": "header_override_normalized.x_upstream_auth",
"mode": "contains",
"value": "Bearer ",
},
},
},
},
}
ctx := map[string]interface{}{
"request_headers_raw": map[string]interface{}{
"Authorization": "Bearer token-123",
},
"request_headers": map[string]interface{}{
"authorization": "Bearer token-123",
},
}
out, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.1}`, string(out))
}
func TestApplyParamOverrideSetHeaderKeepOrigin(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "set_header",
"path": "X-Feature-Flag",
"value": "new-value",
"keep_origin": true,
},
},
}
ctx := map[string]interface{}{
"header_override": map[string]interface{}{
"X-Feature-Flag": "legacy-value",
},
"header_override_normalized": map[string]interface{}{
"x_feature_flag": "legacy-value",
},
}
_, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
headers, ok := ctx["header_override"].(map[string]interface{})
if !ok {
t.Fatalf("expected header_override context map")
}
if headers["X-Feature-Flag"] != "legacy-value" {
t.Fatalf("expected keep_origin to preserve old value, got: %v", headers["X-Feature-Flag"])
}
}
func TestApplyParamOverrideConditionsObjectShorthand(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "temperature",
"mode": "set",
"value": 0.1,
"logic": "AND",
"conditions": map[string]interface{}{
"is_retry": true,
"last_error.status_code": 400.0,
},
},
},
}
ctx := map[string]interface{}{
"is_retry": true,
"last_error": map[string]interface{}{
"status_code": 400.0,
},
}
out, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.1}`, string(out))
}
func TestApplyParamOverrideWithRelayInfoSyncRuntimeHeaders(t *testing.T) {
info := &RelayInfo{
ChannelMeta: &ChannelMeta{
ParamOverride: map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "set_header",
"path": "X-Injected-By-Param-Override",
"value": "enabled",
},
map[string]interface{}{
"mode": "delete_header",
"path": "X-Delete-Me",
},
},
},
HeadersOverride: map[string]interface{}{
"X-Delete-Me": "legacy",
"X-Keep-Me": "keep",
},
},
}
input := []byte(`{"temperature":0.7}`)
out, err := ApplyParamOverrideWithRelayInfo(input, info)
if err != nil {
t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.7}`, string(out))
if !info.UseRuntimeHeadersOverride {
t.Fatalf("expected runtime header override to be enabled")
}
if info.RuntimeHeadersOverride["X-Keep-Me"] != "keep" {
t.Fatalf("expected X-Keep-Me header to be preserved, got: %v", info.RuntimeHeadersOverride["X-Keep-Me"])
}
if info.RuntimeHeadersOverride["X-Injected-By-Param-Override"] != "enabled" {
t.Fatalf("expected X-Injected-By-Param-Override header to be set, got: %v", info.RuntimeHeadersOverride["X-Injected-By-Param-Override"])
}
if _, exists := info.RuntimeHeadersOverride["X-Delete-Me"]; exists {
t.Fatalf("expected X-Delete-Me header to be deleted")
}
}
func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) {
info := &RelayInfo{
ChannelMeta: &ChannelMeta{
ParamOverride: map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "move_header",
"from": "X-Legacy-Trace",
"to": "X-Trace",
},
map[string]interface{}{
"mode": "copy_header",
"from": "X-Trace",
"to": "X-Trace-Backup",
},
},
},
HeadersOverride: map[string]interface{}{
"X-Legacy-Trace": "trace-123",
},
},
}
input := []byte(`{"temperature":0.7}`)
_, err := ApplyParamOverrideWithRelayInfo(input, info)
if err != nil {
t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err)
}
if _, exists := info.RuntimeHeadersOverride["X-Legacy-Trace"]; exists {
t.Fatalf("expected source header to be removed after move")
}
if info.RuntimeHeadersOverride["X-Trace"] != "trace-123" {
t.Fatalf("expected X-Trace to be set, got: %v", info.RuntimeHeadersOverride["X-Trace"])
}
if info.RuntimeHeadersOverride["X-Trace-Backup"] != "trace-123" {
t.Fatalf("expected X-Trace-Backup to be copied, got: %v", info.RuntimeHeadersOverride["X-Trace-Backup"])
}
}
func assertJSONEqual(t *testing.T, want, got string) {
t.Helper()

View File

@@ -101,6 +101,7 @@ type RelayInfo struct {
RelayMode int
OriginModelName string
RequestURLPath string
RequestHeaders map[string]string
ShouldIncludeUsage bool
DisablePing bool // 是否禁止向下游发送自定义 Ping
ClientWs *websocket.Conn
@@ -142,6 +143,8 @@ type RelayInfo struct {
IsChannelTest bool // channel test request
RetryIndex int
LastError *types.NewAPIError
RuntimeHeadersOverride map[string]interface{}
UseRuntimeHeadersOverride bool
PriceData types.PriceData
@@ -458,6 +461,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
isFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
RequestURLPath: c.Request.URL.String(),
RequestHeaders: cloneRequestHeaders(c),
IsStream: isStream,
StartTime: startTime,
@@ -490,6 +494,27 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
return info
}
func cloneRequestHeaders(c *gin.Context) map[string]string {
if c == nil || c.Request == nil {
return nil
}
if len(c.Request.Header) == 0 {
return nil
}
headers := make(map[string]string, len(c.Request.Header))
for key := range c.Request.Header {
value := strings.TrimSpace(c.Request.Header.Get(key))
if value == "" {
continue
}
headers[key] = value
}
if len(headers) == 0 {
return nil
}
return headers
}
func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) {
var info *RelayInfo
var err error

View File

@@ -172,7 +172,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
if err != nil {
return newAPIErrorFromParamOverride(err)
}

View File

@@ -51,7 +51,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
}
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
if err != nil {
return newAPIErrorFromParamOverride(err)
}

View File

@@ -157,7 +157,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
if err != nil {
return newAPIErrorFromParamOverride(err)
}
@@ -257,7 +257,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
if err != nil {
return newAPIErrorFromParamOverride(err)
}

View File

@@ -70,7 +70,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
if err != nil {
return newAPIErrorFromParamOverride(err)
}

View File

@@ -61,7 +61,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
if err != nil {
return newAPIErrorFromParamOverride(err)
}

View File

@@ -96,7 +96,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
if err != nil {
return newAPIErrorFromParamOverride(err)
}