mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:25:00 +00:00
feat: add retry-aware param override with return_error and prune_objects
This commit is contained in:
@@ -366,7 +366,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
|||||||
newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
|
newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(convertedRequest)
|
jsonData, err := common.Marshal(convertedRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return testResult{
|
return testResult{
|
||||||
context: c,
|
context: c,
|
||||||
@@ -387,6 +387,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
|||||||
if len(info.ParamOverride) > 0 {
|
if len(info.ParamOverride) > 0 {
|
||||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok {
|
||||||
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: fixedErr,
|
||||||
|
newAPIError: relaycommon.NewAPIErrorFromParamOverride(fixedErr),
|
||||||
|
}
|
||||||
|
}
|
||||||
return testResult{
|
return testResult{
|
||||||
context: c,
|
context: c,
|
||||||
localErr: err,
|
localErr: err,
|
||||||
|
|||||||
@@ -182,8 +182,11 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
|||||||
ModelName: relayInfo.OriginModelName,
|
ModelName: relayInfo.OriginModelName,
|
||||||
Retry: common.GetPointer(0),
|
Retry: common.GetPointer(0),
|
||||||
}
|
}
|
||||||
|
relayInfo.RetryIndex = 0
|
||||||
|
relayInfo.LastError = nil
|
||||||
|
|
||||||
for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
|
for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
|
||||||
|
relayInfo.RetryIndex = retryParam.GetRetry()
|
||||||
channel, channelErr := getChannel(c, relayInfo, retryParam)
|
channel, channelErr := getChannel(c, relayInfo, retryParam)
|
||||||
if channelErr != nil {
|
if channelErr != nil {
|
||||||
logger.LogError(c, channelErr.Error())
|
logger.LogError(c, channelErr.Error())
|
||||||
@@ -216,10 +219,12 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if newAPIError == nil {
|
if newAPIError == nil {
|
||||||
|
relayInfo.LastError = nil
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newAPIError = service.NormalizeViolationFeeError(newAPIError)
|
newAPIError = service.NormalizeViolationFeeError(newAPIError)
|
||||||
|
relayInfo.LastError = newAPIError
|
||||||
|
|
||||||
processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||||
|
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
|
|||||||
if len(info.ParamOverride) > 0 {
|
if len(info.ParamOverride) > 0 {
|
||||||
chatJSON, err = relaycommon.ApplyParamOverride(chatJSON, info.ParamOverride, overrideCtx)
|
chatJSON, err = relaycommon.ApplyParamOverride(chatJSON, info.ParamOverride, overrideCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
return nil, newAPIErrorFromParamOverride(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
|||||||
if len(info.ParamOverride) > 0 {
|
if len(info.ParamOverride) > 0 {
|
||||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
return newAPIErrorFromParamOverride(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/QuantumNous/new-api/common"
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
"github.com/QuantumNous/new-api/types"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
@@ -23,7 +26,7 @@ type ConditionOperation struct {
|
|||||||
|
|
||||||
type ParamOperation struct {
|
type ParamOperation struct {
|
||||||
Path string `json:"path"`
|
Path string `json:"path"`
|
||||||
Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace
|
Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace, return_error, prune_objects
|
||||||
Value interface{} `json:"value"`
|
Value interface{} `json:"value"`
|
||||||
KeepOrigin bool `json:"keep_origin"`
|
KeepOrigin bool `json:"keep_origin"`
|
||||||
From string `json:"from,omitempty"`
|
From string `json:"from,omitempty"`
|
||||||
@@ -32,6 +35,76 @@ type ParamOperation struct {
|
|||||||
Logic string `json:"logic,omitempty"` // AND, OR (默认OR)
|
Logic string `json:"logic,omitempty"` // AND, OR (默认OR)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ParamOverrideReturnError struct {
|
||||||
|
Message string
|
||||||
|
StatusCode int
|
||||||
|
Code string
|
||||||
|
Type string
|
||||||
|
SkipRetry bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ParamOverrideReturnError) Error() string {
|
||||||
|
if e == nil {
|
||||||
|
return "param override return error"
|
||||||
|
}
|
||||||
|
if e.Message == "" {
|
||||||
|
return "param override return error"
|
||||||
|
}
|
||||||
|
return e.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
func AsParamOverrideReturnError(err error) (*ParamOverrideReturnError, bool) {
|
||||||
|
if err == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
var target *ParamOverrideReturnError
|
||||||
|
if errors.As(err, &target) {
|
||||||
|
return target, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAPIErrorFromParamOverride(err *ParamOverrideReturnError) *types.NewAPIError {
|
||||||
|
if err == nil {
|
||||||
|
return types.NewError(
|
||||||
|
errors.New("param override return error is nil"),
|
||||||
|
types.ErrorCodeChannelParamOverrideInvalid,
|
||||||
|
types.ErrOptionWithSkipRetry(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
statusCode := err.StatusCode
|
||||||
|
if statusCode < http.StatusContinue || statusCode > http.StatusNetworkAuthenticationRequired {
|
||||||
|
statusCode = http.StatusBadRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
errorCode := err.Code
|
||||||
|
if strings.TrimSpace(errorCode) == "" {
|
||||||
|
errorCode = string(types.ErrorCodeInvalidRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
errorType := err.Type
|
||||||
|
if strings.TrimSpace(errorType) == "" {
|
||||||
|
errorType = "invalid_request_error"
|
||||||
|
}
|
||||||
|
|
||||||
|
message := strings.TrimSpace(err.Message)
|
||||||
|
if message == "" {
|
||||||
|
message = "request blocked by param override"
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := make([]types.NewAPIErrorOptions, 0, 1)
|
||||||
|
if err.SkipRetry {
|
||||||
|
opts = append(opts, types.ErrOptionWithSkipRetry())
|
||||||
|
}
|
||||||
|
|
||||||
|
return types.WithOpenAIError(types.OpenAIError{
|
||||||
|
Message: message,
|
||||||
|
Type: errorType,
|
||||||
|
Code: errorCode,
|
||||||
|
}, statusCode, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) {
|
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) {
|
||||||
if len(paramOverride) == 0 {
|
if len(paramOverride) == 0 {
|
||||||
return jsonData, nil
|
return jsonData, nil
|
||||||
@@ -372,16 +445,104 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
|||||||
result, err = replaceStringValue(result, opPath, op.From, op.To)
|
result, err = replaceStringValue(result, opPath, op.From, op.To)
|
||||||
case "regex_replace":
|
case "regex_replace":
|
||||||
result, err = regexReplaceStringValue(result, opPath, op.From, op.To)
|
result, err = regexReplaceStringValue(result, opPath, op.From, op.To)
|
||||||
|
case "return_error":
|
||||||
|
returnErr, parseErr := parseParamOverrideReturnError(op.Value)
|
||||||
|
if parseErr != nil {
|
||||||
|
return "", parseErr
|
||||||
|
}
|
||||||
|
return "", returnErr
|
||||||
|
case "prune_objects":
|
||||||
|
result, err = pruneObjects(result, opPath, contextJSON, op.Value)
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("unknown operation: %s", op.Mode)
|
return "", fmt.Errorf("unknown operation: %s", op.Mode)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("operation %s failed: %v", op.Mode, err)
|
return "", fmt.Errorf("operation %s failed: %w", op.Mode, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseParamOverrideReturnError(value interface{}) (*ParamOverrideReturnError, error) {
|
||||||
|
result := &ParamOverrideReturnError{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Code: string(types.ErrorCodeInvalidRequest),
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
SkipRetry: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
switch raw := value.(type) {
|
||||||
|
case nil:
|
||||||
|
return nil, fmt.Errorf("return_error value is required")
|
||||||
|
case string:
|
||||||
|
result.Message = strings.TrimSpace(raw)
|
||||||
|
case map[string]interface{}:
|
||||||
|
if message, ok := raw["message"].(string); ok {
|
||||||
|
result.Message = strings.TrimSpace(message)
|
||||||
|
}
|
||||||
|
if result.Message == "" {
|
||||||
|
if message, ok := raw["msg"].(string); ok {
|
||||||
|
result.Message = strings.TrimSpace(message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if code, exists := raw["code"]; exists {
|
||||||
|
codeStr := strings.TrimSpace(fmt.Sprintf("%v", code))
|
||||||
|
if codeStr != "" {
|
||||||
|
result.Code = codeStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if errType, ok := raw["type"].(string); ok {
|
||||||
|
errType = strings.TrimSpace(errType)
|
||||||
|
if errType != "" {
|
||||||
|
result.Type = errType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if skipRetry, ok := raw["skip_retry"].(bool); ok {
|
||||||
|
result.SkipRetry = skipRetry
|
||||||
|
}
|
||||||
|
|
||||||
|
if statusCodeRaw, exists := raw["status_code"]; exists {
|
||||||
|
statusCode, ok := parseOverrideInt(statusCodeRaw)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("return_error status_code must be an integer")
|
||||||
|
}
|
||||||
|
result.StatusCode = statusCode
|
||||||
|
} else if statusRaw, exists := raw["status"]; exists {
|
||||||
|
statusCode, ok := parseOverrideInt(statusRaw)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("return_error status must be an integer")
|
||||||
|
}
|
||||||
|
result.StatusCode = statusCode
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("return_error value must be string or object")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Message == "" {
|
||||||
|
return nil, fmt.Errorf("return_error message is required")
|
||||||
|
}
|
||||||
|
if result.StatusCode < http.StatusContinue || result.StatusCode > http.StatusNetworkAuthenticationRequired {
|
||||||
|
return nil, fmt.Errorf("return_error status code out of range: %d", result.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseOverrideInt(v interface{}) (int, bool) {
|
||||||
|
switch value := v.(type) {
|
||||||
|
case int:
|
||||||
|
return value, true
|
||||||
|
case float64:
|
||||||
|
if value != float64(int(value)) {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return int(value), true
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
|
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
|
||||||
sourceValue := gjson.Get(jsonStr, fromPath)
|
sourceValue := gjson.Get(jsonStr, fromPath)
|
||||||
if !sourceValue.Exists() {
|
if !sourceValue.Exists() {
|
||||||
@@ -537,6 +698,217 @@ func regexReplaceStringValue(jsonStr, path, pattern, replacement string) (string
|
|||||||
return sjson.Set(jsonStr, path, re.ReplaceAllString(current.String(), replacement))
|
return sjson.Set(jsonStr, path, re.ReplaceAllString(current.String(), replacement))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type pruneObjectsOptions struct {
|
||||||
|
conditions []ConditionOperation
|
||||||
|
logic string
|
||||||
|
recursive bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string, error) {
|
||||||
|
options, err := parsePruneObjectsOptions(value)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if path == "" {
|
||||||
|
var root interface{}
|
||||||
|
if err := common.Unmarshal([]byte(jsonStr), &root); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
cleaned, _, err := pruneObjectsNode(root, options, contextJSON, true)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
cleanedBytes, err := common.Marshal(cleaned)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(cleanedBytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
target := gjson.Get(jsonStr, path)
|
||||||
|
if !target.Exists() {
|
||||||
|
return jsonStr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var targetNode interface{}
|
||||||
|
if target.Type == gjson.JSON {
|
||||||
|
if err := common.Unmarshal([]byte(target.Raw), &targetNode); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
targetNode = target.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
cleaned, _, err := pruneObjectsNode(targetNode, options, contextJSON, true)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
cleanedBytes, err := common.Marshal(cleaned)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return sjson.SetRaw(jsonStr, path, string(cleanedBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) {
|
||||||
|
opts := pruneObjectsOptions{
|
||||||
|
logic: "AND",
|
||||||
|
recursive: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
switch raw := value.(type) {
|
||||||
|
case nil:
|
||||||
|
return opts, fmt.Errorf("prune_objects value is required")
|
||||||
|
case string:
|
||||||
|
v := strings.TrimSpace(raw)
|
||||||
|
if v == "" {
|
||||||
|
return opts, fmt.Errorf("prune_objects value is required")
|
||||||
|
}
|
||||||
|
opts.conditions = []ConditionOperation{
|
||||||
|
{
|
||||||
|
Path: "type",
|
||||||
|
Mode: "full",
|
||||||
|
Value: v,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
case map[string]interface{}:
|
||||||
|
if logic, ok := raw["logic"].(string); ok && strings.TrimSpace(logic) != "" {
|
||||||
|
opts.logic = logic
|
||||||
|
}
|
||||||
|
if recursive, ok := raw["recursive"].(bool); ok {
|
||||||
|
opts.recursive = recursive
|
||||||
|
}
|
||||||
|
|
||||||
|
if condRaw, exists := raw["conditions"]; exists {
|
||||||
|
conditions, err := parseConditionOperations(condRaw)
|
||||||
|
if err != nil {
|
||||||
|
return opts, err
|
||||||
|
}
|
||||||
|
opts.conditions = append(opts.conditions, conditions...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if whereRaw, exists := raw["where"]; exists {
|
||||||
|
whereMap, ok := whereRaw.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return opts, fmt.Errorf("prune_objects where must be object")
|
||||||
|
}
|
||||||
|
for key, val := range whereMap {
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
opts.conditions = append(opts.conditions, ConditionOperation{
|
||||||
|
Path: key,
|
||||||
|
Mode: "full",
|
||||||
|
Value: val,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if matchType, exists := raw["type"]; exists {
|
||||||
|
opts.conditions = append(opts.conditions, ConditionOperation{
|
||||||
|
Path: "type",
|
||||||
|
Mode: "full",
|
||||||
|
Value: matchType,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return opts, fmt.Errorf("prune_objects value must be string or object")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(opts.conditions) == 0 {
|
||||||
|
return opts, fmt.Errorf("prune_objects conditions are required")
|
||||||
|
}
|
||||||
|
return opts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseConditionOperations(raw interface{}) ([]ConditionOperation, error) {
|
||||||
|
items, ok := raw.([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("conditions must be an array")
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]ConditionOperation, 0, len(items))
|
||||||
|
for _, item := range items {
|
||||||
|
itemMap, ok := item.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("condition must be object")
|
||||||
|
}
|
||||||
|
path, _ := itemMap["path"].(string)
|
||||||
|
mode, _ := itemMap["mode"].(string)
|
||||||
|
if strings.TrimSpace(path) == "" || strings.TrimSpace(mode) == "" {
|
||||||
|
return nil, fmt.Errorf("condition path/mode is required")
|
||||||
|
}
|
||||||
|
condition := ConditionOperation{
|
||||||
|
Path: path,
|
||||||
|
Mode: mode,
|
||||||
|
}
|
||||||
|
if value, exists := itemMap["value"]; exists {
|
||||||
|
condition.Value = value
|
||||||
|
}
|
||||||
|
if invert, ok := itemMap["invert"].(bool); ok {
|
||||||
|
condition.Invert = invert
|
||||||
|
}
|
||||||
|
if passMissingKey, ok := itemMap["pass_missing_key"].(bool); ok {
|
||||||
|
condition.PassMissingKey = passMissingKey
|
||||||
|
}
|
||||||
|
result = append(result, condition)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func pruneObjectsNode(node interface{}, options pruneObjectsOptions, contextJSON string, isRoot bool) (interface{}, bool, error) {
|
||||||
|
switch value := node.(type) {
|
||||||
|
case []interface{}:
|
||||||
|
result := make([]interface{}, 0, len(value))
|
||||||
|
for _, item := range value {
|
||||||
|
next, drop, err := pruneObjectsNode(item, options, contextJSON, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
if drop {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, next)
|
||||||
|
}
|
||||||
|
return result, false, nil
|
||||||
|
case map[string]interface{}:
|
||||||
|
shouldDrop, err := shouldPruneObject(value, options, contextJSON)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
if shouldDrop && !isRoot {
|
||||||
|
return nil, true, nil
|
||||||
|
}
|
||||||
|
if !options.recursive {
|
||||||
|
return value, false, nil
|
||||||
|
}
|
||||||
|
for key, child := range value {
|
||||||
|
next, drop, err := pruneObjectsNode(child, options, contextJSON, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
if drop {
|
||||||
|
delete(value, key)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
value[key] = next
|
||||||
|
}
|
||||||
|
return value, false, nil
|
||||||
|
default:
|
||||||
|
return node, false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldPruneObject(node map[string]interface{}, options pruneObjectsOptions, contextJSON string) (bool, error) {
|
||||||
|
nodeBytes, err := common.Marshal(node)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return checkConditions(string(nodeBytes), contextJSON, options.conditions, options.logic)
|
||||||
|
}
|
||||||
|
|
||||||
func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
|
func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
|
||||||
current := gjson.Get(jsonStr, path)
|
current := gjson.Get(jsonStr, path)
|
||||||
var currentMap, newMap map[string]interface{}
|
var currentMap, newMap map[string]interface{}
|
||||||
@@ -598,6 +970,32 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx["retry_index"] = info.RetryIndex
|
||||||
|
ctx["is_retry"] = info.RetryIndex > 0
|
||||||
|
ctx["retry"] = map[string]interface{}{
|
||||||
|
"index": info.RetryIndex,
|
||||||
|
"is_retry": info.RetryIndex > 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.LastError != nil {
|
||||||
|
code := string(info.LastError.GetErrorCode())
|
||||||
|
errorType := string(info.LastError.GetErrorType())
|
||||||
|
lastError := map[string]interface{}{
|
||||||
|
"status_code": info.LastError.StatusCode,
|
||||||
|
"message": info.LastError.Error(),
|
||||||
|
"code": code,
|
||||||
|
"error_code": code,
|
||||||
|
"type": errorType,
|
||||||
|
"error_type": errorType,
|
||||||
|
"skip_retry": types.IsSkipRetryError(info.LastError),
|
||||||
|
}
|
||||||
|
ctx["last_error"] = lastError
|
||||||
|
ctx["last_error_status_code"] = info.LastError.StatusCode
|
||||||
|
ctx["last_error_message"] = info.LastError.Error()
|
||||||
|
ctx["last_error_code"] = code
|
||||||
|
ctx["last_error_type"] = errorType
|
||||||
|
}
|
||||||
|
|
||||||
ctx["is_channel_test"] = info.IsChannelTest
|
ctx["is_channel_test"] = info.IsChannelTest
|
||||||
return ctx
|
return ctx
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestApplyParamOverrideTrimPrefix(t *testing.T) {
|
func TestApplyParamOverrideTrimPrefix(t *testing.T) {
|
||||||
@@ -772,6 +774,188 @@ func TestApplyParamOverrideToUpper(t *testing.T) {
|
|||||||
assertJSONEqual(t, `{"model":"GPT-4"}`, string(out))
|
assertJSONEqual(t, `{"model":"GPT-4"}`, string(out))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyParamOverrideReturnError(t *testing.T) {
|
||||||
|
input := []byte(`{"model":"gemini-2.5-pro"}`)
|
||||||
|
override := map[string]interface{}{
|
||||||
|
"operations": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"mode": "return_error",
|
||||||
|
"value": map[string]interface{}{
|
||||||
|
"message": "forced bad request by param override",
|
||||||
|
"status_code": 422,
|
||||||
|
"code": "forced_bad_request",
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"skip_retry": true,
|
||||||
|
},
|
||||||
|
"conditions": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"path": "retry.is_retry",
|
||||||
|
"mode": "full",
|
||||||
|
"value": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := map[string]interface{}{
|
||||||
|
"retry": map[string]interface{}{
|
||||||
|
"index": 1,
|
||||||
|
"is_retry": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := ApplyParamOverride(input, override, ctx)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error, got nil")
|
||||||
|
}
|
||||||
|
returnErr, ok := AsParamOverrideReturnError(err)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected ParamOverrideReturnError, got %T: %v", err, err)
|
||||||
|
}
|
||||||
|
if returnErr.StatusCode != 422 {
|
||||||
|
t.Fatalf("expected status 422, got %d", returnErr.StatusCode)
|
||||||
|
}
|
||||||
|
if returnErr.Code != "forced_bad_request" {
|
||||||
|
t.Fatalf("expected code forced_bad_request, got %s", returnErr.Code)
|
||||||
|
}
|
||||||
|
if !returnErr.SkipRetry {
|
||||||
|
t.Fatalf("expected skip_retry true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyParamOverridePruneObjectsByTypeString(t *testing.T) {
|
||||||
|
input := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","content":[
|
||||||
|
{"type":"output_text","text":"a"},
|
||||||
|
{"type":"redacted_thinking","text":"secret"},
|
||||||
|
{"type":"tool_call","name":"tool_a"}
|
||||||
|
]},
|
||||||
|
{"role":"assistant","content":[
|
||||||
|
{"type":"output_text","text":"b"},
|
||||||
|
{"type":"wrapper","parts":[
|
||||||
|
{"type":"redacted_thinking","text":"secret2"},
|
||||||
|
{"type":"output_text","text":"c"}
|
||||||
|
]}
|
||||||
|
]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
override := map[string]interface{}{
|
||||||
|
"operations": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"mode": "prune_objects",
|
||||||
|
"value": "redacted_thinking",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := ApplyParamOverride(input, override, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||||
|
}
|
||||||
|
assertJSONEqual(t, `{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","content":[
|
||||||
|
{"type":"output_text","text":"a"},
|
||||||
|
{"type":"tool_call","name":"tool_a"}
|
||||||
|
]},
|
||||||
|
{"role":"assistant","content":[
|
||||||
|
{"type":"output_text","text":"b"},
|
||||||
|
{"type":"wrapper","parts":[
|
||||||
|
{"type":"output_text","text":"c"}
|
||||||
|
]}
|
||||||
|
]}
|
||||||
|
]
|
||||||
|
}`, string(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyParamOverridePruneObjectsWhereAndPath(t *testing.T) {
|
||||||
|
input := []byte(`{
|
||||||
|
"a":{"items":[{"type":"redacted_thinking","id":1},{"type":"output_text","id":2}]},
|
||||||
|
"b":{"items":[{"type":"redacted_thinking","id":3},{"type":"output_text","id":4}]}
|
||||||
|
}`)
|
||||||
|
override := map[string]interface{}{
|
||||||
|
"operations": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"path": "a",
|
||||||
|
"mode": "prune_objects",
|
||||||
|
"value": map[string]interface{}{
|
||||||
|
"where": map[string]interface{}{
|
||||||
|
"type": "redacted_thinking",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := ApplyParamOverride(input, override, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||||
|
}
|
||||||
|
assertJSONEqual(t, `{
|
||||||
|
"a":{"items":[{"type":"output_text","id":2}]},
|
||||||
|
"b":{"items":[{"type":"redacted_thinking","id":3},{"type":"output_text","id":4}]}
|
||||||
|
}`, string(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyParamOverrideNormalizeThinkingSignatureUnsupported(t *testing.T) {
|
||||||
|
input := []byte(`{"items":[{"type":"redacted_thinking"}]}`)
|
||||||
|
override := map[string]interface{}{
|
||||||
|
"operations": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"mode": "normalize_thinking_signature",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := ApplyParamOverride(input, override, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyParamOverrideConditionFromRetryAndLastErrorContext(t *testing.T) {
|
||||||
|
info := &RelayInfo{
|
||||||
|
RetryIndex: 1,
|
||||||
|
LastError: types.WithOpenAIError(types.OpenAIError{
|
||||||
|
Message: "invalid thinking signature",
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
Code: "bad_thought_signature",
|
||||||
|
}, 400),
|
||||||
|
}
|
||||||
|
ctx := BuildParamOverrideContext(info)
|
||||||
|
|
||||||
|
input := []byte(`{"temperature":0.7}`)
|
||||||
|
override := map[string]interface{}{
|
||||||
|
"operations": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"path": "temperature",
|
||||||
|
"mode": "set",
|
||||||
|
"value": 0.1,
|
||||||
|
"logic": "AND",
|
||||||
|
"conditions": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"path": "is_retry",
|
||||||
|
"mode": "full",
|
||||||
|
"value": true,
|
||||||
|
},
|
||||||
|
map[string]interface{}{
|
||||||
|
"path": "last_error.code",
|
||||||
|
"mode": "contains",
|
||||||
|
"value": "thought_signature",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := ApplyParamOverride(input, override, ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||||
|
}
|
||||||
|
assertJSONEqual(t, `{"temperature":0.1}`, string(out))
|
||||||
|
}
|
||||||
|
|
||||||
func assertJSONEqual(t *testing.T, want, got string) {
|
func assertJSONEqual(t *testing.T, want, got string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
@@ -140,6 +140,8 @@ type RelayInfo struct {
|
|||||||
SubscriptionAmountUsedAfterPreConsume int64
|
SubscriptionAmountUsedAfterPreConsume int64
|
||||||
IsClaudeBetaQuery bool // /v1/messages?beta=true
|
IsClaudeBetaQuery bool // /v1/messages?beta=true
|
||||||
IsChannelTest bool // channel test request
|
IsChannelTest bool // channel test request
|
||||||
|
RetryIndex int
|
||||||
|
LastError *types.NewAPIError
|
||||||
|
|
||||||
PriceData types.PriceData
|
PriceData types.PriceData
|
||||||
|
|
||||||
|
|||||||
@@ -174,7 +174,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
|||||||
if len(info.ParamOverride) > 0 {
|
if len(info.ParamOverride) > 0 {
|
||||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
return newAPIErrorFromParamOverride(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package relay
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
@@ -46,7 +45,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
|||||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
|
relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
|
||||||
jsonData, err := json.Marshal(convertedRequest)
|
jsonData, err := common.Marshal(convertedRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
@@ -54,7 +53,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
|||||||
if len(info.ParamOverride) > 0 {
|
if len(info.ParamOverride) > 0 {
|
||||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
return newAPIErrorFromParamOverride(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -159,7 +159,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
|||||||
if len(info.ParamOverride) > 0 {
|
if len(info.ParamOverride) > 0 {
|
||||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
return newAPIErrorFromParamOverride(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -257,14 +257,9 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
|||||||
|
|
||||||
// apply param override
|
// apply param override
|
||||||
if len(info.ParamOverride) > 0 {
|
if len(info.ParamOverride) > 0 {
|
||||||
reqMap := make(map[string]interface{})
|
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||||
_ = common.Unmarshal(jsonData, &reqMap)
|
|
||||||
for key, value := range info.ParamOverride {
|
|
||||||
reqMap[key] = value
|
|
||||||
}
|
|
||||||
jsonData, err = common.Marshal(reqMap)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
return newAPIErrorFromParamOverride(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
logger.LogDebug(c, "Gemini embedding request body: "+string(jsonData))
|
logger.LogDebug(c, "Gemini embedding request body: "+string(jsonData))
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
|||||||
if len(info.ParamOverride) > 0 {
|
if len(info.ParamOverride) > 0 {
|
||||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
return newAPIErrorFromParamOverride(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
13
relay/param_override_error.go
Normal file
13
relay/param_override_error.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
|
"github.com/QuantumNous/new-api/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAPIErrorFromParamOverride(err error) *types.NewAPIError {
|
||||||
|
if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok {
|
||||||
|
return relaycommon.NewAPIErrorFromParamOverride(fixedErr)
|
||||||
|
}
|
||||||
|
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||||
|
}
|
||||||
@@ -63,7 +63,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
|||||||
if len(info.ParamOverride) > 0 {
|
if len(info.ParamOverride) > 0 {
|
||||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
return newAPIErrorFromParamOverride(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
|||||||
if len(info.ParamOverride) > 0 {
|
if len(info.ParamOverride) > 0 {
|
||||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
return newAPIErrorFromParamOverride(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user