diff --git a/controller/channel-test.go b/controller/channel-test.go index 7ffee9fdf..3947c8d5c 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -385,7 +385,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string, //} if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok { return testResult{ diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index ec5573ab1..dcdff584b 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -168,11 +168,19 @@ func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey str // Passthrough rules are applied first, then normal overrides are applied, so explicit overrides win. func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) { headerOverride := make(map[string]string) + if info == nil { + return headerOverride, nil + } + + headerOverrideSource := info.HeadersOverride + if info.UseRuntimeHeadersOverride { + headerOverrideSource = info.RuntimeHeadersOverride + } passAll := false var passthroughRegex []*regexp.Regexp if !info.IsChannelTest { - for k := range info.HeadersOverride { + for k := range headerOverrideSource { key := strings.TrimSpace(k) if key == "" { continue @@ -232,7 +240,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s } } - for k, v := range info.HeadersOverride { + for k, v := range headerOverrideSource { if isHeaderPassthroughRuleKey(k) { continue } diff --git a/relay/channel/api_request_test.go b/relay/channel/api_request_test.go index c55ffcab2..31e15340a 100644 --- a/relay/channel/api_request_test.go +++ b/relay/channel/api_request_test.go @@ -79,3 +79,34 @@ func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) require.NoError(t, err) require.Equal(t, "trace-123", headers["X-Upstream-Trace"]) } + +func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + info := &relaycommon.RelayInfo{ + IsChannelTest: false, + UseRuntimeHeadersOverride: true, + RuntimeHeadersOverride: map[string]any{ + "X-Static": "runtime-value", + "X-Runtime": "runtime-only", + }, + ChannelMeta: &relaycommon.ChannelMeta{ + HeadersOverride: map[string]any{ + "X-Static": "legacy-value", + "X-Legacy": "legacy-only", + }, + }, + } + + headers, err := processHeaderOverride(info, ctx) + require.NoError(t, err) + require.Equal(t, "runtime-value", headers["X-Static"]) + require.Equal(t, "runtime-only", headers["X-Runtime"]) + _, ok := headers["X-Legacy"] + require.False(t, ok) +} diff --git a/relay/chat_completions_via_responses.go b/relay/chat_completions_via_responses.go index 3f2fb1874..580cba5f4 100644 --- a/relay/chat_completions_via_responses.go +++ b/relay/chat_completions_via_responses.go @@ -70,7 +70,6 @@ func applySystemPromptIfNeeded(c *gin.Context, info *relaycommon.RelayInfo, requ } func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, adaptor channel.Adaptor, request *dto.GeneralOpenAIRequest) (*dto.Usage, *types.NewAPIError) { - overrideCtx := relaycommon.BuildParamOverrideContext(info) chatJSON, err := common.Marshal(request) if err != nil { return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) @@ -82,7 +81,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad } if len(info.ParamOverride) > 0 { - chatJSON, err = relaycommon.ApplyParamOverride(chatJSON, info.ParamOverride, overrideCtx) + chatJSON, err = relaycommon.ApplyParamOverrideWithRelayInfo(chatJSON, info) if err != nil { return nil, newAPIErrorFromParamOverride(err) } diff --git a/relay/claude_handler.go b/relay/claude_handler.go index b9d9936e9..2dfa09df5 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -153,7 +153,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { return newAPIErrorFromParamOverride(err) } diff --git a/relay/common/override.go b/relay/common/override.go index 070a8e7af..9ac007ecd 100644 --- a/relay/common/override.go +++ b/relay/common/override.go @@ -10,12 +10,20 @@ import ( "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/types" + "github.com/samber/lo" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) var negativeIndexRegexp = regexp.MustCompile(`\.(-\d+)`) +const ( + paramOverrideContextRequestHeaders = "request_headers" + paramOverrideContextRequestHeadersRaw = "request_headers_raw" + paramOverrideContextHeaderOverride = "header_override" + paramOverrideContextHeaderOverrideNormalized = "header_override_normalized" +) + type ConditionOperation struct { Path string `json:"path"` // JSON路径 Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte @@ -26,7 +34,7 @@ type ConditionOperation struct { type ParamOperation struct { Path string `json:"path"` - Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace, return_error, prune_objects + Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace, return_error, prune_objects, set_header, delete_header, copy_header, move_header Value interface{} `json:"value"` KeepOrigin bool `json:"keep_origin"` From string `json:"from,omitempty"` @@ -121,6 +129,35 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c return applyOperationsLegacy(jsonData, paramOverride) } +func ApplyParamOverrideWithRelayInfo(jsonData []byte, info *RelayInfo) ([]byte, error) { + paramOverride := getParamOverrideMap(info) + if len(paramOverride) == 0 { + return jsonData, nil + } + + overrideCtx := BuildParamOverrideContext(info) + result, err := ApplyParamOverride(jsonData, paramOverride, overrideCtx) + if err != nil { + return nil, err + } + syncRuntimeHeaderOverrideFromContext(info, overrideCtx) + return result, nil +} + +func getParamOverrideMap(info *RelayInfo) map[string]interface{} { + if info == nil || info.ChannelMeta == nil { + return nil + } + return info.ChannelMeta.ParamOverride +} + +func getHeaderOverrideMap(info *RelayInfo) map[string]interface{} { + if info == nil || info.ChannelMeta == nil { + return nil + } + return info.ChannelMeta.HeadersOverride +} + func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) { // 检查是否包含 "operations" 字段 if opsValue, exists := paramOverride["operations"]; exists { @@ -161,29 +198,11 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, // 解析条件 if conditions, exists := opMap["conditions"]; exists { - if condSlice, ok := conditions.([]interface{}); ok { - for _, cond := range condSlice { - if condMap, ok := cond.(map[string]interface{}); ok { - condition := ConditionOperation{} - if path, ok := condMap["path"].(string); ok { - condition.Path = path - } - if mode, ok := condMap["mode"].(string); ok { - condition.Mode = mode - } - if value, ok := condMap["value"]; ok { - condition.Value = value - } - if invert, ok := condMap["invert"].(bool); ok { - condition.Invert = invert - } - if passMissingKey, ok := condMap["pass_missing_key"].(bool); ok { - condition.PassMissingKey = passMissingKey - } - operation.Conditions = append(operation.Conditions, condition) - } - } + parsedConditions, err := parseConditionOperations(conditions) + if err != nil { + return nil, false } + operation.Conditions = append(operation.Conditions, parsedConditions...) } operations = append(operations, operation) @@ -212,20 +231,9 @@ func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperatio } if strings.ToUpper(logic) == "AND" { - for _, result := range results { - if !result { - return false, nil - } - } - return true, nil - } else { - for _, result := range results { - if result { - return true, nil - } - } - return false, nil + return lo.EveryBy(results, func(item bool) bool { return item }), nil } + return lo.SomeBy(results, func(item bool) bool { return item }), nil } func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) { @@ -382,13 +390,10 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{} } func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) { - var contextJSON string - if conditionContext != nil && len(conditionContext) > 0 { - ctxBytes, err := common.Marshal(conditionContext) - if err != nil { - return "", fmt.Errorf("failed to marshal condition context: %v", err) - } - contextJSON = string(ctxBytes) + context := ensureContextMap(conditionContext) + contextJSON, err := marshalContextJSON(context) + if err != nil { + return "", fmt.Errorf("failed to marshal condition context: %v", err) } result := jsonStr @@ -453,6 +458,42 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte return "", returnErr case "prune_objects": result, err = pruneObjects(result, opPath, contextJSON, op.Value) + case "set_header": + err = setHeaderOverrideInContext(context, op.Path, op.Value, op.KeepOrigin) + if err == nil { + contextJSON, err = marshalContextJSON(context) + } + case "delete_header": + err = deleteHeaderOverrideInContext(context, op.Path) + if err == nil { + contextJSON, err = marshalContextJSON(context) + } + case "copy_header": + sourceHeader := strings.TrimSpace(op.From) + targetHeader := strings.TrimSpace(op.To) + if sourceHeader == "" { + sourceHeader = strings.TrimSpace(op.Path) + } + if targetHeader == "" { + targetHeader = strings.TrimSpace(op.Path) + } + err = copyHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin) + if err == nil { + contextJSON, err = marshalContextJSON(context) + } + case "move_header": + sourceHeader := strings.TrimSpace(op.From) + targetHeader := strings.TrimSpace(op.To) + if sourceHeader == "" { + sourceHeader = strings.TrimSpace(op.Path) + } + if targetHeader == "" { + targetHeader = strings.TrimSpace(op.Path) + } + err = moveHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin) + if err == nil { + contextJSON, err = marshalContextJSON(context) + } default: return "", fmt.Errorf("unknown operation: %s", op.Mode) } @@ -543,6 +584,276 @@ func parseOverrideInt(v interface{}) (int, bool) { } } +func ensureContextMap(conditionContext map[string]interface{}) map[string]interface{} { + if conditionContext != nil { + return conditionContext + } + return make(map[string]interface{}) +} + +func marshalContextJSON(context map[string]interface{}) (string, error) { + if context == nil || len(context) == 0 { + return "", nil + } + ctxBytes, err := common.Marshal(context) + if err != nil { + return "", err + } + return string(ctxBytes), nil +} + +func setHeaderOverrideInContext(context map[string]interface{}, headerName string, value interface{}, keepOrigin bool) error { + headerName = strings.TrimSpace(headerName) + if headerName == "" { + return fmt.Errorf("header name is required") + } + if keepOrigin { + if _, exists := getHeaderValueFromContext(context, headerName); exists { + return nil + } + } + if value == nil { + return fmt.Errorf("header value is required") + } + headerValue := strings.TrimSpace(fmt.Sprintf("%v", value)) + if headerValue == "" { + return fmt.Errorf("header value is required") + } + + rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride) + rawHeaders[headerName] = headerValue + + normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized) + normalizedHeaders[normalizeHeaderContextKey(headerName)] = headerValue + return nil +} + +func copyHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error { + fromHeader = strings.TrimSpace(fromHeader) + toHeader = strings.TrimSpace(toHeader) + if fromHeader == "" || toHeader == "" { + return fmt.Errorf("copy_header from/to is required") + } + value, exists := getHeaderValueFromContext(context, fromHeader) + if !exists { + return fmt.Errorf("source header does not exist: %s", fromHeader) + } + return setHeaderOverrideInContext(context, toHeader, value, keepOrigin) +} + +func moveHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error { + fromHeader = strings.TrimSpace(fromHeader) + toHeader = strings.TrimSpace(toHeader) + if fromHeader == "" || toHeader == "" { + return fmt.Errorf("move_header from/to is required") + } + if err := copyHeaderInContext(context, fromHeader, toHeader, keepOrigin); err != nil { + return err + } + if strings.EqualFold(fromHeader, toHeader) { + return nil + } + return deleteHeaderOverrideInContext(context, fromHeader) +} + +func deleteHeaderOverrideInContext(context map[string]interface{}, headerName string) error { + headerName = strings.TrimSpace(headerName) + if headerName == "" { + return fmt.Errorf("header name is required") + } + rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride) + for key := range rawHeaders { + if strings.EqualFold(strings.TrimSpace(key), headerName) { + delete(rawHeaders, key) + } + } + + normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized) + delete(normalizedHeaders, normalizeHeaderContextKey(headerName)) + return nil +} + +func ensureMapKeyInContext(context map[string]interface{}, key string) map[string]interface{} { + if context == nil { + return map[string]interface{}{} + } + if existing, ok := context[key]; ok { + if mapVal, ok := existing.(map[string]interface{}); ok { + return mapVal + } + } + result := make(map[string]interface{}) + context[key] = result + return result +} + +func getHeaderValueFromContext(context map[string]interface{}, headerName string) (string, bool) { + headerName = strings.TrimSpace(headerName) + if headerName == "" { + return "", false + } + if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextHeaderOverride), headerName); ok { + return value, true + } + if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextRequestHeadersRaw), headerName); ok { + return value, true + } + + normalizedName := normalizeHeaderContextKey(headerName) + if normalizedName == "" { + return "", false + } + if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized), normalizedName); ok { + return value, true + } + if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextRequestHeaders), normalizedName); ok { + return value, true + } + return "", false +} + +func findHeaderValueInMap(source map[string]interface{}, key string) (string, bool) { + if len(source) == 0 { + return "", false + } + entries := lo.Entries(source) + entry, ok := lo.Find(entries, func(item lo.Entry[string, interface{}]) bool { + return strings.EqualFold(strings.TrimSpace(item.Key), key) + }) + if !ok { + return "", false + } + value := strings.TrimSpace(fmt.Sprintf("%v", entry.Value)) + if value == "" { + return "", false + } + return value, true +} + +func normalizeHeaderContextKey(key string) string { + key = strings.TrimSpace(strings.ToLower(key)) + if key == "" { + return "" + } + var b strings.Builder + b.Grow(len(key)) + previousUnderscore := false + for _, r := range key { + switch { + case r >= 'a' && r <= 'z': + b.WriteRune(r) + previousUnderscore = false + case r >= '0' && r <= '9': + b.WriteRune(r) + previousUnderscore = false + default: + if !previousUnderscore { + b.WriteByte('_') + previousUnderscore = true + } + } + } + result := strings.Trim(b.String(), "_") + return result +} + +func buildNormalizedHeaders(headers map[string]string) map[string]interface{} { + if len(headers) == 0 { + return map[string]interface{}{} + } + entries := lo.Entries(headers) + normalizedEntries := lo.FilterMap(entries, func(item lo.Entry[string, string], _ int) (lo.Entry[string, string], bool) { + normalized := normalizeHeaderContextKey(item.Key) + value := strings.TrimSpace(item.Value) + if normalized == "" || value == "" { + return lo.Entry[string, string]{}, false + } + return lo.Entry[string, string]{Key: normalized, Value: value}, true + }) + return lo.SliceToMap(normalizedEntries, func(item lo.Entry[string, string]) (string, interface{}) { + return item.Key, item.Value + }) +} + +func buildRawHeaders(headers map[string]string) map[string]interface{} { + if len(headers) == 0 { + return map[string]interface{}{} + } + entries := lo.Entries(headers) + rawEntries := lo.FilterMap(entries, func(item lo.Entry[string, string], _ int) (lo.Entry[string, string], bool) { + key := strings.TrimSpace(item.Key) + value := strings.TrimSpace(item.Value) + if key == "" || value == "" { + return lo.Entry[string, string]{}, false + } + return lo.Entry[string, string]{Key: key, Value: value}, true + }) + return lo.SliceToMap(rawEntries, func(item lo.Entry[string, string]) (string, interface{}) { + return item.Key, item.Value + }) +} + +func buildHeaderOverrideContext(headers map[string]interface{}) (map[string]interface{}, map[string]interface{}) { + if len(headers) == 0 { + return map[string]interface{}{}, map[string]interface{}{} + } + entries := lo.Entries(headers) + rawEntries := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (lo.Entry[string, string], bool) { + key := strings.TrimSpace(item.Key) + value := strings.TrimSpace(fmt.Sprintf("%v", item.Value)) + if key == "" || value == "" { + return lo.Entry[string, string]{}, false + } + return lo.Entry[string, string]{Key: key, Value: value}, true + }) + + raw := lo.SliceToMap(rawEntries, func(item lo.Entry[string, string]) (string, interface{}) { + return item.Key, item.Value + }) + normalizedEntries := lo.FilterMap(rawEntries, func(item lo.Entry[string, string], _ int) (lo.Entry[string, string], bool) { + normalized := normalizeHeaderContextKey(item.Key) + if normalized == "" { + return lo.Entry[string, string]{}, false + } + return lo.Entry[string, string]{Key: normalized, Value: item.Value}, true + }) + normalized := lo.SliceToMap(normalizedEntries, func(item lo.Entry[string, string]) (string, interface{}) { + return item.Key, item.Value + }) + return raw, normalized +} + +func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]interface{}) { + if info == nil || context == nil { + return + } + raw, exists := context[paramOverrideContextHeaderOverride] + if !exists { + return + } + rawMap, ok := raw.(map[string]interface{}) + if !ok { + return + } + + entries := lo.Entries(rawMap) + sanitized := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (lo.Entry[string, interface{}], bool) { + key := strings.TrimSpace(item.Key) + if key == "" { + return lo.Entry[string, interface{}]{}, false + } + value := strings.TrimSpace(fmt.Sprintf("%v", item.Value)) + if value == "" { + return lo.Entry[string, interface{}]{}, false + } + return lo.Entry[string, interface{}]{Key: key, Value: value}, true + }) + info.RuntimeHeadersOverride = lo.SliceToMap(sanitized, func(item lo.Entry[string, interface{}]) (string, interface{}) { + return item.Key, item.Value + }) + info.UseRuntimeHeadersOverride = true +} + func moveValue(jsonStr, fromPath, toPath string) (string, error) { sourceValue := gjson.Get(jsonStr, fromPath) if !sourceValue.Exists() { @@ -824,38 +1135,56 @@ func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) { } func parseConditionOperations(raw interface{}) ([]ConditionOperation, error) { - items, ok := raw.([]interface{}) - if !ok { - return nil, fmt.Errorf("conditions must be an array") + switch typed := raw.(type) { + case map[string]interface{}: + entries := lo.Entries(typed) + conditions := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (ConditionOperation, bool) { + path := strings.TrimSpace(item.Key) + if path == "" { + return ConditionOperation{}, false + } + return ConditionOperation{ + Path: path, + Mode: "full", + Value: item.Value, + }, true + }) + if len(conditions) == 0 { + return nil, fmt.Errorf("conditions object must contain at least one key") + } + return conditions, nil + case []interface{}: + items := typed + result := make([]ConditionOperation, 0, len(items)) + for _, item := range items { + itemMap, ok := item.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("condition must be object") + } + path, _ := itemMap["path"].(string) + mode, _ := itemMap["mode"].(string) + if strings.TrimSpace(path) == "" || strings.TrimSpace(mode) == "" { + return nil, fmt.Errorf("condition path/mode is required") + } + condition := ConditionOperation{ + Path: path, + Mode: mode, + } + if value, exists := itemMap["value"]; exists { + condition.Value = value + } + if invert, ok := itemMap["invert"].(bool); ok { + condition.Invert = invert + } + if passMissingKey, ok := itemMap["pass_missing_key"].(bool); ok { + condition.PassMissingKey = passMissingKey + } + result = append(result, condition) + } + return result, nil + default: + return nil, fmt.Errorf("conditions must be an array or object") } - - result := make([]ConditionOperation, 0, len(items)) - for _, item := range items { - itemMap, ok := item.(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("condition must be object") - } - path, _ := itemMap["path"].(string) - mode, _ := itemMap["mode"].(string) - if strings.TrimSpace(path) == "" || strings.TrimSpace(mode) == "" { - return nil, fmt.Errorf("condition path/mode is required") - } - condition := ConditionOperation{ - Path: path, - Mode: mode, - } - if value, exists := itemMap["value"]; exists { - condition.Value = value - } - if invert, ok := itemMap["invert"].(bool); ok { - condition.Invert = invert - } - if passMissingKey, ok := itemMap["pass_missing_key"].(bool); ok { - condition.PassMissingKey = passMissingKey - } - result = append(result, condition) - } - return result, nil } func pruneObjectsNode(node interface{}, options pruneObjectsOptions, contextJSON string, isRoot bool) (interface{}, bool, error) { @@ -970,6 +1299,17 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} { } } + ctx[paramOverrideContextRequestHeaders] = buildNormalizedHeaders(info.RequestHeaders) + ctx[paramOverrideContextRequestHeadersRaw] = buildRawHeaders(info.RequestHeaders) + + headerOverrideSource := getHeaderOverrideMap(info) + if info.UseRuntimeHeadersOverride { + headerOverrideSource = info.RuntimeHeadersOverride + } + rawHeaderOverride, normalizedHeaderOverride := buildHeaderOverrideContext(headerOverrideSource) + ctx[paramOverrideContextHeaderOverride] = rawHeaderOverride + ctx[paramOverrideContextHeaderOverrideNormalized] = normalizedHeaderOverride + ctx["retry_index"] = info.RetryIndex ctx["is_retry"] = info.RetryIndex > 0 ctx["retry"] = map[string]interface{}{ diff --git a/relay/common/override_test.go b/relay/common/override_test.go index cc1489f74..653a87f6a 100644 --- a/relay/common/override_test.go +++ b/relay/common/override_test.go @@ -956,6 +956,254 @@ func TestApplyParamOverrideConditionFromRetryAndLastErrorContext(t *testing.T) { assertJSONEqual(t, `{"temperature":0.1}`, string(out)) } +func TestApplyParamOverrideConditionFromRequestHeaders(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "path": "temperature", + "mode": "set", + "value": 0.1, + "conditions": []interface{}{ + map[string]interface{}{ + "path": "request_headers.authorization", + "mode": "contains", + "value": "Bearer ", + }, + }, + }, + }, + } + ctx := map[string]interface{}{ + "request_headers": map[string]interface{}{ + "authorization": "Bearer token-123", + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.1}`, string(out)) +} + +func TestApplyParamOverrideSetHeaderAndUseInLaterCondition(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "set_header", + "path": "X-Debug-Mode", + "value": "enabled", + }, + map[string]interface{}{ + "path": "temperature", + "mode": "set", + "value": 0.1, + "conditions": []interface{}{ + map[string]interface{}{ + "path": "header_override_normalized.x_debug_mode", + "mode": "full", + "value": "enabled", + }, + }, + }, + }, + } + + out, err := ApplyParamOverride(input, override, nil) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.1}`, string(out)) +} + +func TestApplyParamOverrideCopyHeaderFromRequestHeaders(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "copy_header", + "from": "Authorization", + "to": "X-Upstream-Auth", + }, + map[string]interface{}{ + "path": "temperature", + "mode": "set", + "value": 0.1, + "conditions": []interface{}{ + map[string]interface{}{ + "path": "header_override_normalized.x_upstream_auth", + "mode": "contains", + "value": "Bearer ", + }, + }, + }, + }, + } + ctx := map[string]interface{}{ + "request_headers_raw": map[string]interface{}{ + "Authorization": "Bearer token-123", + }, + "request_headers": map[string]interface{}{ + "authorization": "Bearer token-123", + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.1}`, string(out)) +} + +func TestApplyParamOverrideSetHeaderKeepOrigin(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "set_header", + "path": "X-Feature-Flag", + "value": "new-value", + "keep_origin": true, + }, + }, + } + ctx := map[string]interface{}{ + "header_override": map[string]interface{}{ + "X-Feature-Flag": "legacy-value", + }, + "header_override_normalized": map[string]interface{}{ + "x_feature_flag": "legacy-value", + }, + } + + _, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + headers, ok := ctx["header_override"].(map[string]interface{}) + if !ok { + t.Fatalf("expected header_override context map") + } + if headers["X-Feature-Flag"] != "legacy-value" { + t.Fatalf("expected keep_origin to preserve old value, got: %v", headers["X-Feature-Flag"]) + } +} + +func TestApplyParamOverrideConditionsObjectShorthand(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "path": "temperature", + "mode": "set", + "value": 0.1, + "logic": "AND", + "conditions": map[string]interface{}{ + "is_retry": true, + "last_error.status_code": 400.0, + }, + }, + }, + } + ctx := map[string]interface{}{ + "is_retry": true, + "last_error": map[string]interface{}{ + "status_code": 400.0, + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.1}`, string(out)) +} + +func TestApplyParamOverrideWithRelayInfoSyncRuntimeHeaders(t *testing.T) { + info := &RelayInfo{ + ChannelMeta: &ChannelMeta{ + ParamOverride: map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "set_header", + "path": "X-Injected-By-Param-Override", + "value": "enabled", + }, + map[string]interface{}{ + "mode": "delete_header", + "path": "X-Delete-Me", + }, + }, + }, + HeadersOverride: map[string]interface{}{ + "X-Delete-Me": "legacy", + "X-Keep-Me": "keep", + }, + }, + } + + input := []byte(`{"temperature":0.7}`) + out, err := ApplyParamOverrideWithRelayInfo(input, info) + if err != nil { + t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.7}`, string(out)) + + if !info.UseRuntimeHeadersOverride { + t.Fatalf("expected runtime header override to be enabled") + } + if info.RuntimeHeadersOverride["X-Keep-Me"] != "keep" { + t.Fatalf("expected X-Keep-Me header to be preserved, got: %v", info.RuntimeHeadersOverride["X-Keep-Me"]) + } + if info.RuntimeHeadersOverride["X-Injected-By-Param-Override"] != "enabled" { + t.Fatalf("expected X-Injected-By-Param-Override header to be set, got: %v", info.RuntimeHeadersOverride["X-Injected-By-Param-Override"]) + } + if _, exists := info.RuntimeHeadersOverride["X-Delete-Me"]; exists { + t.Fatalf("expected X-Delete-Me header to be deleted") + } +} + +func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) { + info := &RelayInfo{ + ChannelMeta: &ChannelMeta{ + ParamOverride: map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "move_header", + "from": "X-Legacy-Trace", + "to": "X-Trace", + }, + map[string]interface{}{ + "mode": "copy_header", + "from": "X-Trace", + "to": "X-Trace-Backup", + }, + }, + }, + HeadersOverride: map[string]interface{}{ + "X-Legacy-Trace": "trace-123", + }, + }, + } + + input := []byte(`{"temperature":0.7}`) + _, err := ApplyParamOverrideWithRelayInfo(input, info) + if err != nil { + t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err) + } + if _, exists := info.RuntimeHeadersOverride["X-Legacy-Trace"]; exists { + t.Fatalf("expected source header to be removed after move") + } + if info.RuntimeHeadersOverride["X-Trace"] != "trace-123" { + t.Fatalf("expected X-Trace to be set, got: %v", info.RuntimeHeadersOverride["X-Trace"]) + } + if info.RuntimeHeadersOverride["X-Trace-Backup"] != "trace-123" { + t.Fatalf("expected X-Trace-Backup to be copied, got: %v", info.RuntimeHeadersOverride["X-Trace-Backup"]) + } +} + func assertJSONEqual(t *testing.T, want, got string) { t.Helper() diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index c10e6d5fb..e5a0a06f5 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -101,6 +101,7 @@ type RelayInfo struct { RelayMode int OriginModelName string RequestURLPath string + RequestHeaders map[string]string ShouldIncludeUsage bool DisablePing bool // 是否禁止向下游发送自定义 Ping ClientWs *websocket.Conn @@ -142,6 +143,8 @@ type RelayInfo struct { IsChannelTest bool // channel test request RetryIndex int LastError *types.NewAPIError + RuntimeHeadersOverride map[string]interface{} + UseRuntimeHeadersOverride bool PriceData types.PriceData @@ -458,6 +461,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { isFirstResponse: true, RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), RequestURLPath: c.Request.URL.String(), + RequestHeaders: cloneRequestHeaders(c), IsStream: isStream, StartTime: startTime, @@ -490,6 +494,27 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { return info } +func cloneRequestHeaders(c *gin.Context) map[string]string { + if c == nil || c.Request == nil { + return nil + } + if len(c.Request.Header) == 0 { + return nil + } + headers := make(map[string]string, len(c.Request.Header)) + for key := range c.Request.Header { + value := strings.TrimSpace(c.Request.Header.Get(key)) + if value == "" { + continue + } + headers[key] = value + } + if len(headers) == 0 { + return nil + } + return headers +} + func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) { var info *RelayInfo var err error diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index 4cf5e0411..7f4b99488 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -172,7 +172,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { return newAPIErrorFromParamOverride(err) } diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index edbd1f7e6..d8ca42230 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -51,7 +51,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * } if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { return newAPIErrorFromParamOverride(err) } diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index a58c404f5..39bd44e62 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -157,7 +157,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { return newAPIErrorFromParamOverride(err) } @@ -257,7 +257,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { return newAPIErrorFromParamOverride(err) } diff --git a/relay/image_handler.go b/relay/image_handler.go index 21a5be2fa..fc8ef500e 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -70,7 +70,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { return newAPIErrorFromParamOverride(err) } diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index 9c4bef6e1..40d686f70 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -61,7 +61,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { return newAPIErrorFromParamOverride(err) } diff --git a/relay/responses_handler.go b/relay/responses_handler.go index 2190be87f..3bcaa673f 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -96,7 +96,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { return newAPIErrorFromParamOverride(err) }