feat: embedding param override && internal params

This commit is contained in:
Seefs
2025-11-22 18:27:17 +08:00
parent efb8f1f5b8
commit 0885597427
8 changed files with 68 additions and 20 deletions

View File

@@ -123,7 +123,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
// apply param override // apply param override
if len(info.ParamOverride) > 0 { if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
} }

View File

@@ -1,12 +1,12 @@
package common package common
import ( import (
"encoding/json"
"fmt" "fmt"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"github.com/QuantumNous/new-api/common"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@@ -30,7 +30,7 @@ type ParamOperation struct {
Logic string `json:"logic,omitempty"` // AND, OR (默认OR) Logic string `json:"logic,omitempty"` // AND, OR (默认OR)
} }
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) { func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) {
if len(paramOverride) == 0 { if len(paramOverride) == 0 {
return jsonData, nil return jsonData, nil
} }
@@ -38,7 +38,7 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) (
// 尝试断言为操作格式 // 尝试断言为操作格式
if operations, ok := tryParseOperations(paramOverride); ok { if operations, ok := tryParseOperations(paramOverride); ok {
// 使用新方法 // 使用新方法
result, err := applyOperations(string(jsonData), operations) result, err := applyOperations(string(jsonData), operations, conditionContext)
return []byte(result), err return []byte(result), err
} }
@@ -123,13 +123,13 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation,
return nil, false return nil, false
} }
func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) { func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
if len(conditions) == 0 { if len(conditions) == 0 {
return true, nil // 没有条件,直接通过 return true, nil // 没有条件,直接通过
} }
results := make([]bool, len(conditions)) results := make([]bool, len(conditions))
for i, condition := range conditions { for i, condition := range conditions {
result, err := checkSingleCondition(jsonStr, condition) result, err := checkSingleCondition(jsonStr, contextJSON, condition)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -153,10 +153,13 @@ func checkConditions(jsonStr string, conditions []ConditionOperation, logic stri
} }
} }
func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) { func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) {
// 处理负数索引 // 处理负数索引
path := processNegativeIndex(jsonStr, condition.Path) path := processNegativeIndex(jsonStr, condition.Path)
value := gjson.Get(jsonStr, path) value := gjson.Get(jsonStr, path)
if !value.Exists() && contextJSON != "" {
value = gjson.Get(contextJSON, condition.Path)
}
if !value.Exists() { if !value.Exists() {
if condition.PassMissingKey { if condition.PassMissingKey {
return true, nil return true, nil
@@ -165,7 +168,7 @@ func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, e
} }
// 利用gjson的类型解析 // 利用gjson的类型解析
targetBytes, err := json.Marshal(condition.Value) targetBytes, err := common.Marshal(condition.Value)
if err != nil { if err != nil {
return false, fmt.Errorf("failed to marshal condition value: %v", err) return false, fmt.Errorf("failed to marshal condition value: %v", err)
} }
@@ -292,7 +295,7 @@ func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool,
// applyOperationsLegacy 原参数覆盖方法 // applyOperationsLegacy 原参数覆盖方法
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) { func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
reqMap := make(map[string]interface{}) reqMap := make(map[string]interface{})
err := json.Unmarshal(jsonData, &reqMap) err := common.Unmarshal(jsonData, &reqMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -301,14 +304,23 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}
reqMap[key] = value reqMap[key] = value
} }
return json.Marshal(reqMap) return common.Marshal(reqMap)
} }
func applyOperations(jsonStr string, operations []ParamOperation) (string, error) { func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) {
var contextJSON string
if conditionContext != nil && len(conditionContext) > 0 {
ctxBytes, err := common.Marshal(conditionContext)
if err != nil {
return "", fmt.Errorf("failed to marshal condition context: %v", err)
}
contextJSON = string(ctxBytes)
}
result := jsonStr result := jsonStr
for _, op := range operations { for _, op := range operations {
// 检查条件是否满足 // 检查条件是否满足
ok, err := checkConditions(result, op.Conditions, op.Logic) ok, err := checkConditions(result, contextJSON, op.Conditions, op.Logic)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -414,7 +426,7 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
var currentMap, newMap map[string]interface{} var currentMap, newMap map[string]interface{}
// 解析当前值 // 解析当前值
if err := json.Unmarshal([]byte(current.Raw), &currentMap); err != nil { if err := common.Unmarshal([]byte(current.Raw), &currentMap); err != nil {
return "", err return "", err
} }
// 解析新值 // 解析新值
@@ -422,8 +434,8 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
case map[string]interface{}: case map[string]interface{}:
newMap = v newMap = v
default: default:
jsonBytes, _ := json.Marshal(v) jsonBytes, _ := common.Marshal(v)
if err := json.Unmarshal(jsonBytes, &newMap); err != nil { if err := common.Unmarshal(jsonBytes, &newMap); err != nil {
return "", err return "", err
} }
} }
@@ -439,3 +451,31 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
} }
return sjson.Set(jsonStr, path, result) return sjson.Set(jsonStr, path, result)
} }
// BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。
// 目前内置以下字段:
// - model优先使用上游模型名UpstreamModelName若不存在则回落到原始模型名OriginModelName
// - upstream_model始终为通道映射后的上游模型名。
// - original_model请求最初指定的模型名。
func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
if info == nil || info.ChannelMeta == nil {
return nil
}
ctx := make(map[string]interface{})
if info.UpstreamModelName != "" {
ctx["model"] = info.UpstreamModelName
ctx["upstream_model"] = info.UpstreamModelName
}
if info.OriginModelName != "" {
ctx["original_model"] = info.OriginModelName
if _, exists := ctx["model"]; !exists {
ctx["model"] = info.OriginModelName
}
}
if len(ctx) == 0 {
return nil
}
return ctx
}

View File

@@ -144,7 +144,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
// apply param override // apply param override
if len(info.ParamOverride) > 0 { if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
} }

View File

@@ -49,6 +49,14 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
} }
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}
}
logger.LogDebug(c, fmt.Sprintf("converted embedding request body: %s", string(jsonData))) logger.LogDebug(c, fmt.Sprintf("converted embedding request body: %s", string(jsonData)))
requestBody := bytes.NewBuffer(jsonData) requestBody := bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping") statusCodeMappingStr := c.GetString("status_code_mapping")

View File

@@ -156,7 +156,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
// apply param override // apply param override
if len(info.ParamOverride) > 0 { if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
} }

View File

@@ -69,7 +69,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
// apply param override // apply param override
if len(info.ParamOverride) > 0 { if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
} }

View File

@@ -60,7 +60,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
// apply param override // apply param override
if len(info.ParamOverride) > 0 { if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
} }

View File

@@ -66,7 +66,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
// apply param override // apply param override
if len(info.ParamOverride) > 0 { if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
} }