feat: add retry-aware param override with return_error and prune_objects

This commit is contained in:
Seefs
2026-02-22 00:10:49 +08:00
parent dbc3236245
commit ff76e75f4c
14 changed files with 623 additions and 20 deletions

View File

@@ -366,7 +366,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
}
}
jsonData, err := json.Marshal(convertedRequest)
jsonData, err := common.Marshal(convertedRequest)
if err != nil {
return testResult{
context: c,
@@ -387,6 +387,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil {
if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok {
return testResult{
context: c,
localErr: fixedErr,
newAPIError: relaycommon.NewAPIErrorFromParamOverride(fixedErr),
}
}
return testResult{
context: c,
localErr: err,

View File

@@ -182,8 +182,11 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
ModelName: relayInfo.OriginModelName,
Retry: common.GetPointer(0),
}
relayInfo.RetryIndex = 0
relayInfo.LastError = nil
for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
relayInfo.RetryIndex = retryParam.GetRetry()
channel, channelErr := getChannel(c, relayInfo, retryParam)
if channelErr != nil {
logger.LogError(c, channelErr.Error())
@@ -216,10 +219,12 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
}
if newAPIError == nil {
relayInfo.LastError = nil
return
}
newAPIError = service.NormalizeViolationFeeError(newAPIError)
relayInfo.LastError = newAPIError
processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)

View File

@@ -84,7 +84,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
if len(info.ParamOverride) > 0 {
chatJSON, err = relaycommon.ApplyParamOverride(chatJSON, info.ParamOverride, overrideCtx)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
return nil, newAPIErrorFromParamOverride(err)
}
}

View File

@@ -155,7 +155,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
return newAPIErrorFromParamOverride(err)
}
}

View File

@@ -1,12 +1,15 @@
package common
import (
"errors"
"fmt"
"net/http"
"regexp"
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/types"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -23,7 +26,7 @@ type ConditionOperation struct {
type ParamOperation struct {
Path string `json:"path"`
Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace
Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace, return_error, prune_objects
Value interface{} `json:"value"`
KeepOrigin bool `json:"keep_origin"`
From string `json:"from,omitempty"`
@@ -32,6 +35,76 @@ type ParamOperation struct {
Logic string `json:"logic,omitempty"` // AND, OR (默认OR)
}
type ParamOverrideReturnError struct {
Message string
StatusCode int
Code string
Type string
SkipRetry bool
}
func (e *ParamOverrideReturnError) Error() string {
if e == nil {
return "param override return error"
}
if e.Message == "" {
return "param override return error"
}
return e.Message
}
func AsParamOverrideReturnError(err error) (*ParamOverrideReturnError, bool) {
if err == nil {
return nil, false
}
var target *ParamOverrideReturnError
if errors.As(err, &target) {
return target, true
}
return nil, false
}
func NewAPIErrorFromParamOverride(err *ParamOverrideReturnError) *types.NewAPIError {
if err == nil {
return types.NewError(
errors.New("param override return error is nil"),
types.ErrorCodeChannelParamOverrideInvalid,
types.ErrOptionWithSkipRetry(),
)
}
statusCode := err.StatusCode
if statusCode < http.StatusContinue || statusCode > http.StatusNetworkAuthenticationRequired {
statusCode = http.StatusBadRequest
}
errorCode := err.Code
if strings.TrimSpace(errorCode) == "" {
errorCode = string(types.ErrorCodeInvalidRequest)
}
errorType := err.Type
if strings.TrimSpace(errorType) == "" {
errorType = "invalid_request_error"
}
message := strings.TrimSpace(err.Message)
if message == "" {
message = "request blocked by param override"
}
opts := make([]types.NewAPIErrorOptions, 0, 1)
if err.SkipRetry {
opts = append(opts, types.ErrOptionWithSkipRetry())
}
return types.WithOpenAIError(types.OpenAIError{
Message: message,
Type: errorType,
Code: errorCode,
}, statusCode, opts...)
}
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) {
if len(paramOverride) == 0 {
return jsonData, nil
@@ -372,16 +445,104 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
result, err = replaceStringValue(result, opPath, op.From, op.To)
case "regex_replace":
result, err = regexReplaceStringValue(result, opPath, op.From, op.To)
case "return_error":
returnErr, parseErr := parseParamOverrideReturnError(op.Value)
if parseErr != nil {
return "", parseErr
}
return "", returnErr
case "prune_objects":
result, err = pruneObjects(result, opPath, contextJSON, op.Value)
default:
return "", fmt.Errorf("unknown operation: %s", op.Mode)
}
if err != nil {
return "", fmt.Errorf("operation %s failed: %v", op.Mode, err)
return "", fmt.Errorf("operation %s failed: %w", op.Mode, err)
}
}
return result, nil
}
func parseParamOverrideReturnError(value interface{}) (*ParamOverrideReturnError, error) {
result := &ParamOverrideReturnError{
StatusCode: http.StatusBadRequest,
Code: string(types.ErrorCodeInvalidRequest),
Type: "invalid_request_error",
SkipRetry: true,
}
switch raw := value.(type) {
case nil:
return nil, fmt.Errorf("return_error value is required")
case string:
result.Message = strings.TrimSpace(raw)
case map[string]interface{}:
if message, ok := raw["message"].(string); ok {
result.Message = strings.TrimSpace(message)
}
if result.Message == "" {
if message, ok := raw["msg"].(string); ok {
result.Message = strings.TrimSpace(message)
}
}
if code, exists := raw["code"]; exists {
codeStr := strings.TrimSpace(fmt.Sprintf("%v", code))
if codeStr != "" {
result.Code = codeStr
}
}
if errType, ok := raw["type"].(string); ok {
errType = strings.TrimSpace(errType)
if errType != "" {
result.Type = errType
}
}
if skipRetry, ok := raw["skip_retry"].(bool); ok {
result.SkipRetry = skipRetry
}
if statusCodeRaw, exists := raw["status_code"]; exists {
statusCode, ok := parseOverrideInt(statusCodeRaw)
if !ok {
return nil, fmt.Errorf("return_error status_code must be an integer")
}
result.StatusCode = statusCode
} else if statusRaw, exists := raw["status"]; exists {
statusCode, ok := parseOverrideInt(statusRaw)
if !ok {
return nil, fmt.Errorf("return_error status must be an integer")
}
result.StatusCode = statusCode
}
default:
return nil, fmt.Errorf("return_error value must be string or object")
}
if result.Message == "" {
return nil, fmt.Errorf("return_error message is required")
}
if result.StatusCode < http.StatusContinue || result.StatusCode > http.StatusNetworkAuthenticationRequired {
return nil, fmt.Errorf("return_error status code out of range: %d", result.StatusCode)
}
return result, nil
}
func parseOverrideInt(v interface{}) (int, bool) {
switch value := v.(type) {
case int:
return value, true
case float64:
if value != float64(int(value)) {
return 0, false
}
return int(value), true
default:
return 0, false
}
}
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
sourceValue := gjson.Get(jsonStr, fromPath)
if !sourceValue.Exists() {
@@ -537,6 +698,217 @@ func regexReplaceStringValue(jsonStr, path, pattern, replacement string) (string
return sjson.Set(jsonStr, path, re.ReplaceAllString(current.String(), replacement))
}
type pruneObjectsOptions struct {
conditions []ConditionOperation
logic string
recursive bool
}
func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string, error) {
options, err := parsePruneObjectsOptions(value)
if err != nil {
return "", err
}
if path == "" {
var root interface{}
if err := common.Unmarshal([]byte(jsonStr), &root); err != nil {
return "", err
}
cleaned, _, err := pruneObjectsNode(root, options, contextJSON, true)
if err != nil {
return "", err
}
cleanedBytes, err := common.Marshal(cleaned)
if err != nil {
return "", err
}
return string(cleanedBytes), nil
}
target := gjson.Get(jsonStr, path)
if !target.Exists() {
return jsonStr, nil
}
var targetNode interface{}
if target.Type == gjson.JSON {
if err := common.Unmarshal([]byte(target.Raw), &targetNode); err != nil {
return "", err
}
} else {
targetNode = target.Value()
}
cleaned, _, err := pruneObjectsNode(targetNode, options, contextJSON, true)
if err != nil {
return "", err
}
cleanedBytes, err := common.Marshal(cleaned)
if err != nil {
return "", err
}
return sjson.SetRaw(jsonStr, path, string(cleanedBytes))
}
func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) {
opts := pruneObjectsOptions{
logic: "AND",
recursive: true,
}
switch raw := value.(type) {
case nil:
return opts, fmt.Errorf("prune_objects value is required")
case string:
v := strings.TrimSpace(raw)
if v == "" {
return opts, fmt.Errorf("prune_objects value is required")
}
opts.conditions = []ConditionOperation{
{
Path: "type",
Mode: "full",
Value: v,
},
}
case map[string]interface{}:
if logic, ok := raw["logic"].(string); ok && strings.TrimSpace(logic) != "" {
opts.logic = logic
}
if recursive, ok := raw["recursive"].(bool); ok {
opts.recursive = recursive
}
if condRaw, exists := raw["conditions"]; exists {
conditions, err := parseConditionOperations(condRaw)
if err != nil {
return opts, err
}
opts.conditions = append(opts.conditions, conditions...)
}
if whereRaw, exists := raw["where"]; exists {
whereMap, ok := whereRaw.(map[string]interface{})
if !ok {
return opts, fmt.Errorf("prune_objects where must be object")
}
for key, val := range whereMap {
key = strings.TrimSpace(key)
if key == "" {
continue
}
opts.conditions = append(opts.conditions, ConditionOperation{
Path: key,
Mode: "full",
Value: val,
})
}
}
if matchType, exists := raw["type"]; exists {
opts.conditions = append(opts.conditions, ConditionOperation{
Path: "type",
Mode: "full",
Value: matchType,
})
}
default:
return opts, fmt.Errorf("prune_objects value must be string or object")
}
if len(opts.conditions) == 0 {
return opts, fmt.Errorf("prune_objects conditions are required")
}
return opts, nil
}
func parseConditionOperations(raw interface{}) ([]ConditionOperation, error) {
items, ok := raw.([]interface{})
if !ok {
return nil, fmt.Errorf("conditions must be an array")
}
result := make([]ConditionOperation, 0, len(items))
for _, item := range items {
itemMap, ok := item.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("condition must be object")
}
path, _ := itemMap["path"].(string)
mode, _ := itemMap["mode"].(string)
if strings.TrimSpace(path) == "" || strings.TrimSpace(mode) == "" {
return nil, fmt.Errorf("condition path/mode is required")
}
condition := ConditionOperation{
Path: path,
Mode: mode,
}
if value, exists := itemMap["value"]; exists {
condition.Value = value
}
if invert, ok := itemMap["invert"].(bool); ok {
condition.Invert = invert
}
if passMissingKey, ok := itemMap["pass_missing_key"].(bool); ok {
condition.PassMissingKey = passMissingKey
}
result = append(result, condition)
}
return result, nil
}
func pruneObjectsNode(node interface{}, options pruneObjectsOptions, contextJSON string, isRoot bool) (interface{}, bool, error) {
switch value := node.(type) {
case []interface{}:
result := make([]interface{}, 0, len(value))
for _, item := range value {
next, drop, err := pruneObjectsNode(item, options, contextJSON, false)
if err != nil {
return nil, false, err
}
if drop {
continue
}
result = append(result, next)
}
return result, false, nil
case map[string]interface{}:
shouldDrop, err := shouldPruneObject(value, options, contextJSON)
if err != nil {
return nil, false, err
}
if shouldDrop && !isRoot {
return nil, true, nil
}
if !options.recursive {
return value, false, nil
}
for key, child := range value {
next, drop, err := pruneObjectsNode(child, options, contextJSON, false)
if err != nil {
return nil, false, err
}
if drop {
delete(value, key)
continue
}
value[key] = next
}
return value, false, nil
default:
return node, false, nil
}
}
func shouldPruneObject(node map[string]interface{}, options pruneObjectsOptions, contextJSON string) (bool, error) {
nodeBytes, err := common.Marshal(node)
if err != nil {
return false, err
}
return checkConditions(string(nodeBytes), contextJSON, options.conditions, options.logic)
}
func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
current := gjson.Get(jsonStr, path)
var currentMap, newMap map[string]interface{}
@@ -598,6 +970,32 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
}
}
ctx["retry_index"] = info.RetryIndex
ctx["is_retry"] = info.RetryIndex > 0
ctx["retry"] = map[string]interface{}{
"index": info.RetryIndex,
"is_retry": info.RetryIndex > 0,
}
if info.LastError != nil {
code := string(info.LastError.GetErrorCode())
errorType := string(info.LastError.GetErrorType())
lastError := map[string]interface{}{
"status_code": info.LastError.StatusCode,
"message": info.LastError.Error(),
"code": code,
"error_code": code,
"type": errorType,
"error_type": errorType,
"skip_retry": types.IsSkipRetryError(info.LastError),
}
ctx["last_error"] = lastError
ctx["last_error_status_code"] = info.LastError.StatusCode
ctx["last_error_message"] = info.LastError.Error()
ctx["last_error_code"] = code
ctx["last_error_type"] = errorType
}
ctx["is_channel_test"] = info.IsChannelTest
return ctx
}

View File

@@ -4,6 +4,8 @@ import (
"encoding/json"
"reflect"
"testing"
"github.com/QuantumNous/new-api/types"
)
func TestApplyParamOverrideTrimPrefix(t *testing.T) {
@@ -772,6 +774,188 @@ func TestApplyParamOverrideToUpper(t *testing.T) {
assertJSONEqual(t, `{"model":"GPT-4"}`, string(out))
}
func TestApplyParamOverrideReturnError(t *testing.T) {
input := []byte(`{"model":"gemini-2.5-pro"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "return_error",
"value": map[string]interface{}{
"message": "forced bad request by param override",
"status_code": 422,
"code": "forced_bad_request",
"type": "invalid_request_error",
"skip_retry": true,
},
"conditions": []interface{}{
map[string]interface{}{
"path": "retry.is_retry",
"mode": "full",
"value": true,
},
},
},
},
}
ctx := map[string]interface{}{
"retry": map[string]interface{}{
"index": 1,
"is_retry": true,
},
}
_, err := ApplyParamOverride(input, override, ctx)
if err == nil {
t.Fatalf("expected error, got nil")
}
returnErr, ok := AsParamOverrideReturnError(err)
if !ok {
t.Fatalf("expected ParamOverrideReturnError, got %T: %v", err, err)
}
if returnErr.StatusCode != 422 {
t.Fatalf("expected status 422, got %d", returnErr.StatusCode)
}
if returnErr.Code != "forced_bad_request" {
t.Fatalf("expected code forced_bad_request, got %s", returnErr.Code)
}
if !returnErr.SkipRetry {
t.Fatalf("expected skip_retry true")
}
}
func TestApplyParamOverridePruneObjectsByTypeString(t *testing.T) {
input := []byte(`{
"messages":[
{"role":"assistant","content":[
{"type":"output_text","text":"a"},
{"type":"redacted_thinking","text":"secret"},
{"type":"tool_call","name":"tool_a"}
]},
{"role":"assistant","content":[
{"type":"output_text","text":"b"},
{"type":"wrapper","parts":[
{"type":"redacted_thinking","text":"secret2"},
{"type":"output_text","text":"c"}
]}
]}
]
}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "prune_objects",
"value": "redacted_thinking",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{
"messages":[
{"role":"assistant","content":[
{"type":"output_text","text":"a"},
{"type":"tool_call","name":"tool_a"}
]},
{"role":"assistant","content":[
{"type":"output_text","text":"b"},
{"type":"wrapper","parts":[
{"type":"output_text","text":"c"}
]}
]}
]
}`, string(out))
}
func TestApplyParamOverridePruneObjectsWhereAndPath(t *testing.T) {
input := []byte(`{
"a":{"items":[{"type":"redacted_thinking","id":1},{"type":"output_text","id":2}]},
"b":{"items":[{"type":"redacted_thinking","id":3},{"type":"output_text","id":4}]}
}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "a",
"mode": "prune_objects",
"value": map[string]interface{}{
"where": map[string]interface{}{
"type": "redacted_thinking",
},
},
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{
"a":{"items":[{"type":"output_text","id":2}]},
"b":{"items":[{"type":"redacted_thinking","id":3},{"type":"output_text","id":4}]}
}`, string(out))
}
func TestApplyParamOverrideNormalizeThinkingSignatureUnsupported(t *testing.T) {
input := []byte(`{"items":[{"type":"redacted_thinking"}]}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "normalize_thinking_signature",
},
},
}
_, err := ApplyParamOverride(input, override, nil)
if err == nil {
t.Fatalf("expected error, got nil")
}
}
func TestApplyParamOverrideConditionFromRetryAndLastErrorContext(t *testing.T) {
info := &RelayInfo{
RetryIndex: 1,
LastError: types.WithOpenAIError(types.OpenAIError{
Message: "invalid thinking signature",
Type: "invalid_request_error",
Code: "bad_thought_signature",
}, 400),
}
ctx := BuildParamOverrideContext(info)
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "temperature",
"mode": "set",
"value": 0.1,
"logic": "AND",
"conditions": []interface{}{
map[string]interface{}{
"path": "is_retry",
"mode": "full",
"value": true,
},
map[string]interface{}{
"path": "last_error.code",
"mode": "contains",
"value": "thought_signature",
},
},
},
},
}
out, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.1}`, string(out))
}
func assertJSONEqual(t *testing.T, want, got string) {
t.Helper()

View File

@@ -140,6 +140,8 @@ type RelayInfo struct {
SubscriptionAmountUsedAfterPreConsume int64
IsClaudeBetaQuery bool // /v1/messages?beta=true
IsChannelTest bool // channel test request
RetryIndex int
LastError *types.NewAPIError
PriceData types.PriceData

View File

@@ -174,7 +174,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
return newAPIErrorFromParamOverride(err)
}
}

View File

@@ -2,7 +2,6 @@ package relay
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
@@ -46,7 +45,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
jsonData, err := json.Marshal(convertedRequest)
jsonData, err := common.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
@@ -54,7 +53,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
return newAPIErrorFromParamOverride(err)
}
}

View File

@@ -159,7 +159,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
return newAPIErrorFromParamOverride(err)
}
}
@@ -257,14 +257,9 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
// apply param override
if len(info.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
for key, value := range info.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
return newAPIErrorFromParamOverride(err)
}
}
logger.LogDebug(c, "Gemini embedding request body: "+string(jsonData))

View File

@@ -72,7 +72,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
return newAPIErrorFromParamOverride(err)
}
}

View File

@@ -0,0 +1,13 @@
package relay
import (
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/types"
)
func newAPIErrorFromParamOverride(err error) *types.NewAPIError {
if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok {
return relaycommon.NewAPIErrorFromParamOverride(fixedErr)
}
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}

View File

@@ -63,7 +63,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
return newAPIErrorFromParamOverride(err)
}
}

View File

@@ -98,7 +98,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
return newAPIErrorFromParamOverride(err)
}
}