From 0885597427fdd11d2120cb0222df13e839fcadfc Mon Sep 17 00:00:00 2001 From: Seefs Date: Sat, 22 Nov 2025 18:27:17 +0800 Subject: [PATCH] feat: embedding param override && internal params --- relay/claude_handler.go | 2 +- relay/common/override.go | 68 +++++++++++++++++++++++++++++-------- relay/compatible_handler.go | 2 +- relay/embedding_handler.go | 8 +++++ relay/gemini_handler.go | 2 +- relay/image_handler.go | 2 +- relay/rerank_handler.go | 2 +- relay/responses_handler.go | 2 +- 8 files changed, 68 insertions(+), 20 deletions(-) diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 395d1e37b..7a18c1737 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -123,7 +123,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) + jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) if err != nil { return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } diff --git a/relay/common/override.go b/relay/common/override.go index ab60f1ab5..1d0794d26 100644 --- a/relay/common/override.go +++ b/relay/common/override.go @@ -1,12 +1,12 @@ package common import ( - "encoding/json" "fmt" "regexp" "strconv" "strings" + "github.com/QuantumNous/new-api/common" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -30,7 +30,7 @@ type ParamOperation struct { Logic string `json:"logic,omitempty"` // AND, OR (默认OR) } -func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) { +func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) { if len(paramOverride) == 0 { return jsonData, nil } @@ -38,7 +38,7 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) ( // 尝试断言为操作格式 if operations, ok := tryParseOperations(paramOverride); ok { // 使用新方法 - result, err := applyOperations(string(jsonData), operations) + result, err := applyOperations(string(jsonData), operations, conditionContext) return []byte(result), err } @@ -123,13 +123,13 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, return nil, false } -func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) { +func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) { if len(conditions) == 0 { return true, nil // 没有条件,直接通过 } results := make([]bool, len(conditions)) for i, condition := range conditions { - result, err := checkSingleCondition(jsonStr, condition) + result, err := checkSingleCondition(jsonStr, contextJSON, condition) if err != nil { return false, err } @@ -153,10 +153,13 @@ func checkConditions(jsonStr string, conditions []ConditionOperation, logic stri } } -func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) { +func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) { // 处理负数索引 path := processNegativeIndex(jsonStr, condition.Path) value := gjson.Get(jsonStr, path) + if !value.Exists() && contextJSON != "" { + value = gjson.Get(contextJSON, condition.Path) + } if !value.Exists() { if condition.PassMissingKey { return true, nil @@ -165,7 +168,7 @@ func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, e } // 利用gjson的类型解析 - targetBytes, err := json.Marshal(condition.Value) + targetBytes, err := common.Marshal(condition.Value) if err != nil { return false, fmt.Errorf("failed to marshal condition value: %v", err) } @@ -292,7 +295,7 @@ func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool, // applyOperationsLegacy 原参数覆盖方法 func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) { reqMap := make(map[string]interface{}) - err := json.Unmarshal(jsonData, &reqMap) + err := common.Unmarshal(jsonData, &reqMap) if err != nil { return nil, err } @@ -301,14 +304,23 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{} reqMap[key] = value } - return json.Marshal(reqMap) + return common.Marshal(reqMap) } -func applyOperations(jsonStr string, operations []ParamOperation) (string, error) { +func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) { + var contextJSON string + if conditionContext != nil && len(conditionContext) > 0 { + ctxBytes, err := common.Marshal(conditionContext) + if err != nil { + return "", fmt.Errorf("failed to marshal condition context: %v", err) + } + contextJSON = string(ctxBytes) + } + result := jsonStr for _, op := range operations { // 检查条件是否满足 - ok, err := checkConditions(result, op.Conditions, op.Logic) + ok, err := checkConditions(result, contextJSON, op.Conditions, op.Logic) if err != nil { return "", err } @@ -414,7 +426,7 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str var currentMap, newMap map[string]interface{} // 解析当前值 - if err := json.Unmarshal([]byte(current.Raw), ¤tMap); err != nil { + if err := common.Unmarshal([]byte(current.Raw), ¤tMap); err != nil { return "", err } // 解析新值 @@ -422,8 +434,8 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str case map[string]interface{}: newMap = v default: - jsonBytes, _ := json.Marshal(v) - if err := json.Unmarshal(jsonBytes, &newMap); err != nil { + jsonBytes, _ := common.Marshal(v) + if err := common.Unmarshal(jsonBytes, &newMap); err != nil { return "", err } } @@ -439,3 +451,31 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str } return sjson.Set(jsonStr, path, result) } + +// BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。 +// 目前内置以下字段: +// - model:优先使用上游模型名(UpstreamModelName),若不存在则回落到原始模型名(OriginModelName)。 +// - upstream_model:始终为通道映射后的上游模型名。 +// - original_model:请求最初指定的模型名。 +func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} { + if info == nil || info.ChannelMeta == nil { + return nil + } + + ctx := make(map[string]interface{}) + if info.UpstreamModelName != "" { + ctx["model"] = info.UpstreamModelName + ctx["upstream_model"] = info.UpstreamModelName + } + if info.OriginModelName != "" { + ctx["original_model"] = info.OriginModelName + if _, exists := ctx["model"]; !exists { + ctx["model"] = info.OriginModelName + } + } + + if len(ctx) == 0 { + return nil + } + return ctx +} diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index 1975eb423..cb3b5d5f2 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -144,7 +144,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) + jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) if err != nil { return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index 9bb76df03..740ca400e 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -49,6 +49,14 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } + + if len(info.ParamOverride) > 0 { + jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + if err != nil { + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + } + } + logger.LogDebug(c, fmt.Sprintf("converted embedding request body: %s", string(jsonData))) requestBody := bytes.NewBuffer(jsonData) statusCodeMappingStr := c.GetString("status_code_mapping") diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index b3eb7f336..af13341bf 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -156,7 +156,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) + jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) if err != nil { return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } diff --git a/relay/image_handler.go b/relay/image_handler.go index 101447c13..b58968402 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -69,7 +69,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) + jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) if err != nil { return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index 06aebbd1c..3efc45079 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -60,7 +60,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) + jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) if err != nil { return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } diff --git a/relay/responses_handler.go b/relay/responses_handler.go index 8087e2391..9460356d6 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -66,7 +66,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) + jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) if err != nil { return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) }