mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-29 23:10:35 +00:00
feat: add retry-aware param override with return_error and prune_objects
This commit is contained in:
@@ -366,7 +366,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
|
||||
}
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
@@ -387,6 +387,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: fixedErr,
|
||||
newAPIError: relaycommon.NewAPIErrorFromParamOverride(fixedErr),
|
||||
}
|
||||
}
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
|
||||
@@ -182,8 +182,11 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
ModelName: relayInfo.OriginModelName,
|
||||
Retry: common.GetPointer(0),
|
||||
}
|
||||
relayInfo.RetryIndex = 0
|
||||
relayInfo.LastError = nil
|
||||
|
||||
for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
|
||||
relayInfo.RetryIndex = retryParam.GetRetry()
|
||||
channel, channelErr := getChannel(c, relayInfo, retryParam)
|
||||
if channelErr != nil {
|
||||
logger.LogError(c, channelErr.Error())
|
||||
@@ -216,10 +219,12 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
}
|
||||
|
||||
if newAPIError == nil {
|
||||
relayInfo.LastError = nil
|
||||
return
|
||||
}
|
||||
|
||||
newAPIError = service.NormalizeViolationFeeError(newAPIError)
|
||||
relayInfo.LastError = newAPIError
|
||||
|
||||
processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
|
||||
if len(info.ParamOverride) > 0 {
|
||||
chatJSON, err = relaycommon.ApplyParamOverride(chatJSON, info.ParamOverride, overrideCtx)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
return nil, newAPIErrorFromParamOverride(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -155,7 +155,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
return newAPIErrorFromParamOverride(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -23,7 +26,7 @@ type ConditionOperation struct {
|
||||
|
||||
type ParamOperation struct {
|
||||
Path string `json:"path"`
|
||||
Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace
|
||||
Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace, return_error, prune_objects
|
||||
Value interface{} `json:"value"`
|
||||
KeepOrigin bool `json:"keep_origin"`
|
||||
From string `json:"from,omitempty"`
|
||||
@@ -32,6 +35,76 @@ type ParamOperation struct {
|
||||
Logic string `json:"logic,omitempty"` // AND, OR (默认OR)
|
||||
}
|
||||
|
||||
type ParamOverrideReturnError struct {
|
||||
Message string
|
||||
StatusCode int
|
||||
Code string
|
||||
Type string
|
||||
SkipRetry bool
|
||||
}
|
||||
|
||||
func (e *ParamOverrideReturnError) Error() string {
|
||||
if e == nil {
|
||||
return "param override return error"
|
||||
}
|
||||
if e.Message == "" {
|
||||
return "param override return error"
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
func AsParamOverrideReturnError(err error) (*ParamOverrideReturnError, bool) {
|
||||
if err == nil {
|
||||
return nil, false
|
||||
}
|
||||
var target *ParamOverrideReturnError
|
||||
if errors.As(err, &target) {
|
||||
return target, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func NewAPIErrorFromParamOverride(err *ParamOverrideReturnError) *types.NewAPIError {
|
||||
if err == nil {
|
||||
return types.NewError(
|
||||
errors.New("param override return error is nil"),
|
||||
types.ErrorCodeChannelParamOverrideInvalid,
|
||||
types.ErrOptionWithSkipRetry(),
|
||||
)
|
||||
}
|
||||
|
||||
statusCode := err.StatusCode
|
||||
if statusCode < http.StatusContinue || statusCode > http.StatusNetworkAuthenticationRequired {
|
||||
statusCode = http.StatusBadRequest
|
||||
}
|
||||
|
||||
errorCode := err.Code
|
||||
if strings.TrimSpace(errorCode) == "" {
|
||||
errorCode = string(types.ErrorCodeInvalidRequest)
|
||||
}
|
||||
|
||||
errorType := err.Type
|
||||
if strings.TrimSpace(errorType) == "" {
|
||||
errorType = "invalid_request_error"
|
||||
}
|
||||
|
||||
message := strings.TrimSpace(err.Message)
|
||||
if message == "" {
|
||||
message = "request blocked by param override"
|
||||
}
|
||||
|
||||
opts := make([]types.NewAPIErrorOptions, 0, 1)
|
||||
if err.SkipRetry {
|
||||
opts = append(opts, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
Message: message,
|
||||
Type: errorType,
|
||||
Code: errorCode,
|
||||
}, statusCode, opts...)
|
||||
}
|
||||
|
||||
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) {
|
||||
if len(paramOverride) == 0 {
|
||||
return jsonData, nil
|
||||
@@ -372,16 +445,104 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
result, err = replaceStringValue(result, opPath, op.From, op.To)
|
||||
case "regex_replace":
|
||||
result, err = regexReplaceStringValue(result, opPath, op.From, op.To)
|
||||
case "return_error":
|
||||
returnErr, parseErr := parseParamOverrideReturnError(op.Value)
|
||||
if parseErr != nil {
|
||||
return "", parseErr
|
||||
}
|
||||
return "", returnErr
|
||||
case "prune_objects":
|
||||
result, err = pruneObjects(result, opPath, contextJSON, op.Value)
|
||||
default:
|
||||
return "", fmt.Errorf("unknown operation: %s", op.Mode)
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("operation %s failed: %v", op.Mode, err)
|
||||
return "", fmt.Errorf("operation %s failed: %w", op.Mode, err)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func parseParamOverrideReturnError(value interface{}) (*ParamOverrideReturnError, error) {
|
||||
result := &ParamOverrideReturnError{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Code: string(types.ErrorCodeInvalidRequest),
|
||||
Type: "invalid_request_error",
|
||||
SkipRetry: true,
|
||||
}
|
||||
|
||||
switch raw := value.(type) {
|
||||
case nil:
|
||||
return nil, fmt.Errorf("return_error value is required")
|
||||
case string:
|
||||
result.Message = strings.TrimSpace(raw)
|
||||
case map[string]interface{}:
|
||||
if message, ok := raw["message"].(string); ok {
|
||||
result.Message = strings.TrimSpace(message)
|
||||
}
|
||||
if result.Message == "" {
|
||||
if message, ok := raw["msg"].(string); ok {
|
||||
result.Message = strings.TrimSpace(message)
|
||||
}
|
||||
}
|
||||
|
||||
if code, exists := raw["code"]; exists {
|
||||
codeStr := strings.TrimSpace(fmt.Sprintf("%v", code))
|
||||
if codeStr != "" {
|
||||
result.Code = codeStr
|
||||
}
|
||||
}
|
||||
if errType, ok := raw["type"].(string); ok {
|
||||
errType = strings.TrimSpace(errType)
|
||||
if errType != "" {
|
||||
result.Type = errType
|
||||
}
|
||||
}
|
||||
if skipRetry, ok := raw["skip_retry"].(bool); ok {
|
||||
result.SkipRetry = skipRetry
|
||||
}
|
||||
|
||||
if statusCodeRaw, exists := raw["status_code"]; exists {
|
||||
statusCode, ok := parseOverrideInt(statusCodeRaw)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("return_error status_code must be an integer")
|
||||
}
|
||||
result.StatusCode = statusCode
|
||||
} else if statusRaw, exists := raw["status"]; exists {
|
||||
statusCode, ok := parseOverrideInt(statusRaw)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("return_error status must be an integer")
|
||||
}
|
||||
result.StatusCode = statusCode
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("return_error value must be string or object")
|
||||
}
|
||||
|
||||
if result.Message == "" {
|
||||
return nil, fmt.Errorf("return_error message is required")
|
||||
}
|
||||
if result.StatusCode < http.StatusContinue || result.StatusCode > http.StatusNetworkAuthenticationRequired {
|
||||
return nil, fmt.Errorf("return_error status code out of range: %d", result.StatusCode)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func parseOverrideInt(v interface{}) (int, bool) {
|
||||
switch value := v.(type) {
|
||||
case int:
|
||||
return value, true
|
||||
case float64:
|
||||
if value != float64(int(value)) {
|
||||
return 0, false
|
||||
}
|
||||
return int(value), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
|
||||
sourceValue := gjson.Get(jsonStr, fromPath)
|
||||
if !sourceValue.Exists() {
|
||||
@@ -537,6 +698,217 @@ func regexReplaceStringValue(jsonStr, path, pattern, replacement string) (string
|
||||
return sjson.Set(jsonStr, path, re.ReplaceAllString(current.String(), replacement))
|
||||
}
|
||||
|
||||
type pruneObjectsOptions struct {
|
||||
conditions []ConditionOperation
|
||||
logic string
|
||||
recursive bool
|
||||
}
|
||||
|
||||
func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string, error) {
|
||||
options, err := parsePruneObjectsOptions(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if path == "" {
|
||||
var root interface{}
|
||||
if err := common.Unmarshal([]byte(jsonStr), &root); err != nil {
|
||||
return "", err
|
||||
}
|
||||
cleaned, _, err := pruneObjectsNode(root, options, contextJSON, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
cleanedBytes, err := common.Marshal(cleaned)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(cleanedBytes), nil
|
||||
}
|
||||
|
||||
target := gjson.Get(jsonStr, path)
|
||||
if !target.Exists() {
|
||||
return jsonStr, nil
|
||||
}
|
||||
|
||||
var targetNode interface{}
|
||||
if target.Type == gjson.JSON {
|
||||
if err := common.Unmarshal([]byte(target.Raw), &targetNode); err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
targetNode = target.Value()
|
||||
}
|
||||
|
||||
cleaned, _, err := pruneObjectsNode(targetNode, options, contextJSON, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
cleanedBytes, err := common.Marshal(cleaned)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sjson.SetRaw(jsonStr, path, string(cleanedBytes))
|
||||
}
|
||||
|
||||
func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) {
|
||||
opts := pruneObjectsOptions{
|
||||
logic: "AND",
|
||||
recursive: true,
|
||||
}
|
||||
|
||||
switch raw := value.(type) {
|
||||
case nil:
|
||||
return opts, fmt.Errorf("prune_objects value is required")
|
||||
case string:
|
||||
v := strings.TrimSpace(raw)
|
||||
if v == "" {
|
||||
return opts, fmt.Errorf("prune_objects value is required")
|
||||
}
|
||||
opts.conditions = []ConditionOperation{
|
||||
{
|
||||
Path: "type",
|
||||
Mode: "full",
|
||||
Value: v,
|
||||
},
|
||||
}
|
||||
case map[string]interface{}:
|
||||
if logic, ok := raw["logic"].(string); ok && strings.TrimSpace(logic) != "" {
|
||||
opts.logic = logic
|
||||
}
|
||||
if recursive, ok := raw["recursive"].(bool); ok {
|
||||
opts.recursive = recursive
|
||||
}
|
||||
|
||||
if condRaw, exists := raw["conditions"]; exists {
|
||||
conditions, err := parseConditionOperations(condRaw)
|
||||
if err != nil {
|
||||
return opts, err
|
||||
}
|
||||
opts.conditions = append(opts.conditions, conditions...)
|
||||
}
|
||||
|
||||
if whereRaw, exists := raw["where"]; exists {
|
||||
whereMap, ok := whereRaw.(map[string]interface{})
|
||||
if !ok {
|
||||
return opts, fmt.Errorf("prune_objects where must be object")
|
||||
}
|
||||
for key, val := range whereMap {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
opts.conditions = append(opts.conditions, ConditionOperation{
|
||||
Path: key,
|
||||
Mode: "full",
|
||||
Value: val,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if matchType, exists := raw["type"]; exists {
|
||||
opts.conditions = append(opts.conditions, ConditionOperation{
|
||||
Path: "type",
|
||||
Mode: "full",
|
||||
Value: matchType,
|
||||
})
|
||||
}
|
||||
default:
|
||||
return opts, fmt.Errorf("prune_objects value must be string or object")
|
||||
}
|
||||
|
||||
if len(opts.conditions) == 0 {
|
||||
return opts, fmt.Errorf("prune_objects conditions are required")
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func parseConditionOperations(raw interface{}) ([]ConditionOperation, error) {
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("conditions must be an array")
|
||||
}
|
||||
|
||||
result := make([]ConditionOperation, 0, len(items))
|
||||
for _, item := range items {
|
||||
itemMap, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("condition must be object")
|
||||
}
|
||||
path, _ := itemMap["path"].(string)
|
||||
mode, _ := itemMap["mode"].(string)
|
||||
if strings.TrimSpace(path) == "" || strings.TrimSpace(mode) == "" {
|
||||
return nil, fmt.Errorf("condition path/mode is required")
|
||||
}
|
||||
condition := ConditionOperation{
|
||||
Path: path,
|
||||
Mode: mode,
|
||||
}
|
||||
if value, exists := itemMap["value"]; exists {
|
||||
condition.Value = value
|
||||
}
|
||||
if invert, ok := itemMap["invert"].(bool); ok {
|
||||
condition.Invert = invert
|
||||
}
|
||||
if passMissingKey, ok := itemMap["pass_missing_key"].(bool); ok {
|
||||
condition.PassMissingKey = passMissingKey
|
||||
}
|
||||
result = append(result, condition)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func pruneObjectsNode(node interface{}, options pruneObjectsOptions, contextJSON string, isRoot bool) (interface{}, bool, error) {
|
||||
switch value := node.(type) {
|
||||
case []interface{}:
|
||||
result := make([]interface{}, 0, len(value))
|
||||
for _, item := range value {
|
||||
next, drop, err := pruneObjectsNode(item, options, contextJSON, false)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if drop {
|
||||
continue
|
||||
}
|
||||
result = append(result, next)
|
||||
}
|
||||
return result, false, nil
|
||||
case map[string]interface{}:
|
||||
shouldDrop, err := shouldPruneObject(value, options, contextJSON)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if shouldDrop && !isRoot {
|
||||
return nil, true, nil
|
||||
}
|
||||
if !options.recursive {
|
||||
return value, false, nil
|
||||
}
|
||||
for key, child := range value {
|
||||
next, drop, err := pruneObjectsNode(child, options, contextJSON, false)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if drop {
|
||||
delete(value, key)
|
||||
continue
|
||||
}
|
||||
value[key] = next
|
||||
}
|
||||
return value, false, nil
|
||||
default:
|
||||
return node, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func shouldPruneObject(node map[string]interface{}, options pruneObjectsOptions, contextJSON string) (bool, error) {
|
||||
nodeBytes, err := common.Marshal(node)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return checkConditions(string(nodeBytes), contextJSON, options.conditions, options.logic)
|
||||
}
|
||||
|
||||
func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
var currentMap, newMap map[string]interface{}
|
||||
@@ -598,6 +970,32 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
ctx["retry_index"] = info.RetryIndex
|
||||
ctx["is_retry"] = info.RetryIndex > 0
|
||||
ctx["retry"] = map[string]interface{}{
|
||||
"index": info.RetryIndex,
|
||||
"is_retry": info.RetryIndex > 0,
|
||||
}
|
||||
|
||||
if info.LastError != nil {
|
||||
code := string(info.LastError.GetErrorCode())
|
||||
errorType := string(info.LastError.GetErrorType())
|
||||
lastError := map[string]interface{}{
|
||||
"status_code": info.LastError.StatusCode,
|
||||
"message": info.LastError.Error(),
|
||||
"code": code,
|
||||
"error_code": code,
|
||||
"type": errorType,
|
||||
"error_type": errorType,
|
||||
"skip_retry": types.IsSkipRetryError(info.LastError),
|
||||
}
|
||||
ctx["last_error"] = lastError
|
||||
ctx["last_error_status_code"] = info.LastError.StatusCode
|
||||
ctx["last_error_message"] = info.LastError.Error()
|
||||
ctx["last_error_code"] = code
|
||||
ctx["last_error_type"] = errorType
|
||||
}
|
||||
|
||||
ctx["is_channel_test"] = info.IsChannelTest
|
||||
return ctx
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
)
|
||||
|
||||
func TestApplyParamOverrideTrimPrefix(t *testing.T) {
|
||||
@@ -772,6 +774,188 @@ func TestApplyParamOverrideToUpper(t *testing.T) {
|
||||
assertJSONEqual(t, `{"model":"GPT-4"}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideReturnError(t *testing.T) {
|
||||
input := []byte(`{"model":"gemini-2.5-pro"}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"mode": "return_error",
|
||||
"value": map[string]interface{}{
|
||||
"message": "forced bad request by param override",
|
||||
"status_code": 422,
|
||||
"code": "forced_bad_request",
|
||||
"type": "invalid_request_error",
|
||||
"skip_retry": true,
|
||||
},
|
||||
"conditions": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "retry.is_retry",
|
||||
"mode": "full",
|
||||
"value": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := map[string]interface{}{
|
||||
"retry": map[string]interface{}{
|
||||
"index": 1,
|
||||
"is_retry": true,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ApplyParamOverride(input, override, ctx)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
returnErr, ok := AsParamOverrideReturnError(err)
|
||||
if !ok {
|
||||
t.Fatalf("expected ParamOverrideReturnError, got %T: %v", err, err)
|
||||
}
|
||||
if returnErr.StatusCode != 422 {
|
||||
t.Fatalf("expected status 422, got %d", returnErr.StatusCode)
|
||||
}
|
||||
if returnErr.Code != "forced_bad_request" {
|
||||
t.Fatalf("expected code forced_bad_request, got %s", returnErr.Code)
|
||||
}
|
||||
if !returnErr.SkipRetry {
|
||||
t.Fatalf("expected skip_retry true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverridePruneObjectsByTypeString(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","content":[
|
||||
{"type":"output_text","text":"a"},
|
||||
{"type":"redacted_thinking","text":"secret"},
|
||||
{"type":"tool_call","name":"tool_a"}
|
||||
]},
|
||||
{"role":"assistant","content":[
|
||||
{"type":"output_text","text":"b"},
|
||||
{"type":"wrapper","parts":[
|
||||
{"type":"redacted_thinking","text":"secret2"},
|
||||
{"type":"output_text","text":"c"}
|
||||
]}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"mode": "prune_objects",
|
||||
"value": "redacted_thinking",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{
|
||||
"messages":[
|
||||
{"role":"assistant","content":[
|
||||
{"type":"output_text","text":"a"},
|
||||
{"type":"tool_call","name":"tool_a"}
|
||||
]},
|
||||
{"role":"assistant","content":[
|
||||
{"type":"output_text","text":"b"},
|
||||
{"type":"wrapper","parts":[
|
||||
{"type":"output_text","text":"c"}
|
||||
]}
|
||||
]}
|
||||
]
|
||||
}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverridePruneObjectsWhereAndPath(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"a":{"items":[{"type":"redacted_thinking","id":1},{"type":"output_text","id":2}]},
|
||||
"b":{"items":[{"type":"redacted_thinking","id":3},{"type":"output_text","id":4}]}
|
||||
}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "a",
|
||||
"mode": "prune_objects",
|
||||
"value": map[string]interface{}{
|
||||
"where": map[string]interface{}{
|
||||
"type": "redacted_thinking",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{
|
||||
"a":{"items":[{"type":"output_text","id":2}]},
|
||||
"b":{"items":[{"type":"redacted_thinking","id":3},{"type":"output_text","id":4}]}
|
||||
}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideNormalizeThinkingSignatureUnsupported(t *testing.T) {
|
||||
input := []byte(`{"items":[{"type":"redacted_thinking"}]}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"mode": "normalize_thinking_signature",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ApplyParamOverride(input, override, nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideConditionFromRetryAndLastErrorContext(t *testing.T) {
|
||||
info := &RelayInfo{
|
||||
RetryIndex: 1,
|
||||
LastError: types.WithOpenAIError(types.OpenAIError{
|
||||
Message: "invalid thinking signature",
|
||||
Type: "invalid_request_error",
|
||||
Code: "bad_thought_signature",
|
||||
}, 400),
|
||||
}
|
||||
ctx := BuildParamOverrideContext(info)
|
||||
|
||||
input := []byte(`{"temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "temperature",
|
||||
"mode": "set",
|
||||
"value": 0.1,
|
||||
"logic": "AND",
|
||||
"conditions": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "is_retry",
|
||||
"mode": "full",
|
||||
"value": true,
|
||||
},
|
||||
map[string]interface{}{
|
||||
"path": "last_error.code",
|
||||
"mode": "contains",
|
||||
"value": "thought_signature",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"temperature":0.1}`, string(out))
|
||||
}
|
||||
|
||||
func assertJSONEqual(t *testing.T, want, got string) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -140,6 +140,8 @@ type RelayInfo struct {
|
||||
SubscriptionAmountUsedAfterPreConsume int64
|
||||
IsClaudeBetaQuery bool // /v1/messages?beta=true
|
||||
IsChannelTest bool // channel test request
|
||||
RetryIndex int
|
||||
LastError *types.NewAPIError
|
||||
|
||||
PriceData types.PriceData
|
||||
|
||||
|
||||
@@ -174,7 +174,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
return newAPIErrorFromParamOverride(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
@@ -46,7 +45,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -54,7 +53,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
return newAPIErrorFromParamOverride(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -159,7 +159,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
return newAPIErrorFromParamOverride(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,14 +257,9 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
_ = common.Unmarshal(jsonData, &reqMap)
|
||||
for key, value := range info.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
return newAPIErrorFromParamOverride(err)
|
||||
}
|
||||
}
|
||||
logger.LogDebug(c, "Gemini embedding request body: "+string(jsonData))
|
||||
|
||||
@@ -72,7 +72,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
return newAPIErrorFromParamOverride(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
13
relay/param_override_error.go
Normal file
13
relay/param_override_error.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
)
|
||||
|
||||
func newAPIErrorFromParamOverride(err error) *types.NewAPIError {
|
||||
if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok {
|
||||
return relaycommon.NewAPIErrorFromParamOverride(fixedErr)
|
||||
}
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -63,7 +63,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
return newAPIErrorFromParamOverride(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
return newAPIErrorFromParamOverride(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user