mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-28 19:48:37 +00:00
feat: embedding param override && internal params
This commit is contained in:
@@ -123,7 +123,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
|||||||
|
|
||||||
// apply param override
|
// apply param override
|
||||||
if len(info.ParamOverride) > 0 {
|
if len(info.ParamOverride) > 0 {
|
||||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/common"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
@@ -30,7 +30,7 @@ type ParamOperation struct {
|
|||||||
Logic string `json:"logic,omitempty"` // AND, OR (默认OR)
|
Logic string `json:"logic,omitempty"` // AND, OR (默认OR)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
|
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) {
|
||||||
if len(paramOverride) == 0 {
|
if len(paramOverride) == 0 {
|
||||||
return jsonData, nil
|
return jsonData, nil
|
||||||
}
|
}
|
||||||
@@ -38,7 +38,7 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) (
|
|||||||
// 尝试断言为操作格式
|
// 尝试断言为操作格式
|
||||||
if operations, ok := tryParseOperations(paramOverride); ok {
|
if operations, ok := tryParseOperations(paramOverride); ok {
|
||||||
// 使用新方法
|
// 使用新方法
|
||||||
result, err := applyOperations(string(jsonData), operations)
|
result, err := applyOperations(string(jsonData), operations, conditionContext)
|
||||||
return []byte(result), err
|
return []byte(result), err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,13 +123,13 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation,
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) {
|
func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
|
||||||
if len(conditions) == 0 {
|
if len(conditions) == 0 {
|
||||||
return true, nil // 没有条件,直接通过
|
return true, nil // 没有条件,直接通过
|
||||||
}
|
}
|
||||||
results := make([]bool, len(conditions))
|
results := make([]bool, len(conditions))
|
||||||
for i, condition := range conditions {
|
for i, condition := range conditions {
|
||||||
result, err := checkSingleCondition(jsonStr, condition)
|
result, err := checkSingleCondition(jsonStr, contextJSON, condition)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@@ -153,10 +153,13 @@ func checkConditions(jsonStr string, conditions []ConditionOperation, logic stri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) {
|
func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) {
|
||||||
// 处理负数索引
|
// 处理负数索引
|
||||||
path := processNegativeIndex(jsonStr, condition.Path)
|
path := processNegativeIndex(jsonStr, condition.Path)
|
||||||
value := gjson.Get(jsonStr, path)
|
value := gjson.Get(jsonStr, path)
|
||||||
|
if !value.Exists() && contextJSON != "" {
|
||||||
|
value = gjson.Get(contextJSON, condition.Path)
|
||||||
|
}
|
||||||
if !value.Exists() {
|
if !value.Exists() {
|
||||||
if condition.PassMissingKey {
|
if condition.PassMissingKey {
|
||||||
return true, nil
|
return true, nil
|
||||||
@@ -165,7 +168,7 @@ func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 利用gjson的类型解析
|
// 利用gjson的类型解析
|
||||||
targetBytes, err := json.Marshal(condition.Value)
|
targetBytes, err := common.Marshal(condition.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("failed to marshal condition value: %v", err)
|
return false, fmt.Errorf("failed to marshal condition value: %v", err)
|
||||||
}
|
}
|
||||||
@@ -292,7 +295,7 @@ func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool,
|
|||||||
// applyOperationsLegacy 原参数覆盖方法
|
// applyOperationsLegacy 原参数覆盖方法
|
||||||
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
|
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
|
||||||
reqMap := make(map[string]interface{})
|
reqMap := make(map[string]interface{})
|
||||||
err := json.Unmarshal(jsonData, &reqMap)
|
err := common.Unmarshal(jsonData, &reqMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -301,14 +304,23 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}
|
|||||||
reqMap[key] = value
|
reqMap[key] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
return json.Marshal(reqMap)
|
return common.Marshal(reqMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyOperations(jsonStr string, operations []ParamOperation) (string, error) {
|
func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) {
|
||||||
|
var contextJSON string
|
||||||
|
if conditionContext != nil && len(conditionContext) > 0 {
|
||||||
|
ctxBytes, err := common.Marshal(conditionContext)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to marshal condition context: %v", err)
|
||||||
|
}
|
||||||
|
contextJSON = string(ctxBytes)
|
||||||
|
}
|
||||||
|
|
||||||
result := jsonStr
|
result := jsonStr
|
||||||
for _, op := range operations {
|
for _, op := range operations {
|
||||||
// 检查条件是否满足
|
// 检查条件是否满足
|
||||||
ok, err := checkConditions(result, op.Conditions, op.Logic)
|
ok, err := checkConditions(result, contextJSON, op.Conditions, op.Logic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -414,7 +426,7 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
|
|||||||
var currentMap, newMap map[string]interface{}
|
var currentMap, newMap map[string]interface{}
|
||||||
|
|
||||||
// 解析当前值
|
// 解析当前值
|
||||||
if err := json.Unmarshal([]byte(current.Raw), ¤tMap); err != nil {
|
if err := common.Unmarshal([]byte(current.Raw), ¤tMap); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
// 解析新值
|
// 解析新值
|
||||||
@@ -422,8 +434,8 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
|
|||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
newMap = v
|
newMap = v
|
||||||
default:
|
default:
|
||||||
jsonBytes, _ := json.Marshal(v)
|
jsonBytes, _ := common.Marshal(v)
|
||||||
if err := json.Unmarshal(jsonBytes, &newMap); err != nil {
|
if err := common.Unmarshal(jsonBytes, &newMap); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -439,3 +451,31 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
|
|||||||
}
|
}
|
||||||
return sjson.Set(jsonStr, path, result)
|
return sjson.Set(jsonStr, path, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。
|
||||||
|
// 目前内置以下字段:
|
||||||
|
// - model:优先使用上游模型名(UpstreamModelName),若不存在则回落到原始模型名(OriginModelName)。
|
||||||
|
// - upstream_model:始终为通道映射后的上游模型名。
|
||||||
|
// - original_model:请求最初指定的模型名。
|
||||||
|
func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
|
||||||
|
if info == nil || info.ChannelMeta == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := make(map[string]interface{})
|
||||||
|
if info.UpstreamModelName != "" {
|
||||||
|
ctx["model"] = info.UpstreamModelName
|
||||||
|
ctx["upstream_model"] = info.UpstreamModelName
|
||||||
|
}
|
||||||
|
if info.OriginModelName != "" {
|
||||||
|
ctx["original_model"] = info.OriginModelName
|
||||||
|
if _, exists := ctx["model"]; !exists {
|
||||||
|
ctx["model"] = info.OriginModelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ctx) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
|||||||
|
|
||||||
// apply param override
|
// apply param override
|
||||||
if len(info.ParamOverride) > 0 {
|
if len(info.ParamOverride) > 0 {
|
||||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,6 +49,14 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(info.ParamOverride) > 0 {
|
||||||
|
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||||
|
if err != nil {
|
||||||
|
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
logger.LogDebug(c, fmt.Sprintf("converted embedding request body: %s", string(jsonData)))
|
logger.LogDebug(c, fmt.Sprintf("converted embedding request body: %s", string(jsonData)))
|
||||||
requestBody := bytes.NewBuffer(jsonData)
|
requestBody := bytes.NewBuffer(jsonData)
|
||||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
|
|||||||
@@ -156,7 +156,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
|||||||
|
|
||||||
// apply param override
|
// apply param override
|
||||||
if len(info.ParamOverride) > 0 {
|
if len(info.ParamOverride) > 0 {
|
||||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
|||||||
|
|
||||||
// apply param override
|
// apply param override
|
||||||
if len(info.ParamOverride) > 0 {
|
if len(info.ParamOverride) > 0 {
|
||||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
|||||||
|
|
||||||
// apply param override
|
// apply param override
|
||||||
if len(info.ParamOverride) > 0 {
|
if len(info.ParamOverride) > 0 {
|
||||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
|||||||
|
|
||||||
// apply param override
|
// apply param override
|
||||||
if len(info.ParamOverride) > 0 {
|
if len(info.ParamOverride) > 0 {
|
||||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user