From ff76e75f4c1bbd10442f3397da7a1df3a4099ead Mon Sep 17 00:00:00 2001 From: Seefs Date: Sun, 22 Feb 2026 00:10:49 +0800 Subject: [PATCH] feat: add retry-aware param override with return_error and prune_objects --- controller/channel-test.go | 9 +- controller/relay.go | 5 + relay/chat_completions_via_responses.go | 2 +- relay/claude_handler.go | 2 +- relay/common/override.go | 402 +++++++++++++++++++++++- relay/common/override_test.go | 184 +++++++++++ relay/common/relay_info.go | 2 + relay/compatible_handler.go | 2 +- relay/embedding_handler.go | 5 +- relay/gemini_handler.go | 11 +- relay/image_handler.go | 2 +- relay/param_override_error.go | 13 + relay/rerank_handler.go | 2 +- relay/responses_handler.go | 2 +- 14 files changed, 623 insertions(+), 20 deletions(-) create mode 100644 relay/param_override_error.go diff --git a/controller/channel-test.go b/controller/channel-test.go index ab12132b1..7ffee9fdf 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -366,7 +366,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string, newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed), } } - jsonData, err := json.Marshal(convertedRequest) + jsonData, err := common.Marshal(convertedRequest) if err != nil { return testResult{ context: c, @@ -387,6 +387,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string, if len(info.ParamOverride) > 0 { jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) if err != nil { + if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok { + return testResult{ + context: c, + localErr: fixedErr, + newAPIError: relaycommon.NewAPIErrorFromParamOverride(fixedErr), + } + } return testResult{ context: c, localErr: err, diff --git a/controller/relay.go b/controller/relay.go index 0b30e6e9e..e3e92bc51 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -182,8 +182,11 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { ModelName: relayInfo.OriginModelName, Retry: common.GetPointer(0), } + relayInfo.RetryIndex = 0 + relayInfo.LastError = nil for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() { + relayInfo.RetryIndex = retryParam.GetRetry() channel, channelErr := getChannel(c, relayInfo, retryParam) if channelErr != nil { logger.LogError(c, channelErr.Error()) @@ -216,10 +219,12 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { } if newAPIError == nil { + relayInfo.LastError = nil return } newAPIError = service.NormalizeViolationFeeError(newAPIError) + relayInfo.LastError = newAPIError processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) diff --git a/relay/chat_completions_via_responses.go b/relay/chat_completions_via_responses.go index 38dae3c56..3f2fb1874 100644 --- a/relay/chat_completions_via_responses.go +++ b/relay/chat_completions_via_responses.go @@ -84,7 +84,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad if len(info.ParamOverride) > 0 { chatJSON, err = relaycommon.ApplyParamOverride(chatJSON, info.ParamOverride, overrideCtx) if err != nil { - return nil, types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return nil, newAPIErrorFromParamOverride(err) } } diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 81adb276a..b9d9936e9 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -155,7 +155,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ if len(info.ParamOverride) > 0 { jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } diff --git a/relay/common/override.go b/relay/common/override.go index 1a0c2478d..070a8e7af 100644 --- a/relay/common/override.go +++ b/relay/common/override.go @@ -1,12 +1,15 @@ package common import ( + "errors" "fmt" + "net/http" "regexp" "strconv" "strings" "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/types" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -23,7 +26,7 @@ type ConditionOperation struct { type ParamOperation struct { Path string `json:"path"` - Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace + Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace, return_error, prune_objects Value interface{} `json:"value"` KeepOrigin bool `json:"keep_origin"` From string `json:"from,omitempty"` @@ -32,6 +35,76 @@ type ParamOperation struct { Logic string `json:"logic,omitempty"` // AND, OR (默认OR) } +type ParamOverrideReturnError struct { + Message string + StatusCode int + Code string + Type string + SkipRetry bool +} + +func (e *ParamOverrideReturnError) Error() string { + if e == nil { + return "param override return error" + } + if e.Message == "" { + return "param override return error" + } + return e.Message +} + +func AsParamOverrideReturnError(err error) (*ParamOverrideReturnError, bool) { + if err == nil { + return nil, false + } + var target *ParamOverrideReturnError + if errors.As(err, &target) { + return target, true + } + return nil, false +} + +func NewAPIErrorFromParamOverride(err *ParamOverrideReturnError) *types.NewAPIError { + if err == nil { + return types.NewError( + errors.New("param override return error is nil"), + types.ErrorCodeChannelParamOverrideInvalid, + types.ErrOptionWithSkipRetry(), + ) + } + + statusCode := err.StatusCode + if statusCode < http.StatusContinue || statusCode > http.StatusNetworkAuthenticationRequired { + statusCode = http.StatusBadRequest + } + + errorCode := err.Code + if strings.TrimSpace(errorCode) == "" { + errorCode = string(types.ErrorCodeInvalidRequest) + } + + errorType := err.Type + if strings.TrimSpace(errorType) == "" { + errorType = "invalid_request_error" + } + + message := strings.TrimSpace(err.Message) + if message == "" { + message = "request blocked by param override" + } + + opts := make([]types.NewAPIErrorOptions, 0, 1) + if err.SkipRetry { + opts = append(opts, types.ErrOptionWithSkipRetry()) + } + + return types.WithOpenAIError(types.OpenAIError{ + Message: message, + Type: errorType, + Code: errorCode, + }, statusCode, opts...) +} + func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) { if len(paramOverride) == 0 { return jsonData, nil @@ -372,16 +445,104 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte result, err = replaceStringValue(result, opPath, op.From, op.To) case "regex_replace": result, err = regexReplaceStringValue(result, opPath, op.From, op.To) + case "return_error": + returnErr, parseErr := parseParamOverrideReturnError(op.Value) + if parseErr != nil { + return "", parseErr + } + return "", returnErr + case "prune_objects": + result, err = pruneObjects(result, opPath, contextJSON, op.Value) default: return "", fmt.Errorf("unknown operation: %s", op.Mode) } if err != nil { - return "", fmt.Errorf("operation %s failed: %v", op.Mode, err) + return "", fmt.Errorf("operation %s failed: %w", op.Mode, err) } } return result, nil } +func parseParamOverrideReturnError(value interface{}) (*ParamOverrideReturnError, error) { + result := &ParamOverrideReturnError{ + StatusCode: http.StatusBadRequest, + Code: string(types.ErrorCodeInvalidRequest), + Type: "invalid_request_error", + SkipRetry: true, + } + + switch raw := value.(type) { + case nil: + return nil, fmt.Errorf("return_error value is required") + case string: + result.Message = strings.TrimSpace(raw) + case map[string]interface{}: + if message, ok := raw["message"].(string); ok { + result.Message = strings.TrimSpace(message) + } + if result.Message == "" { + if message, ok := raw["msg"].(string); ok { + result.Message = strings.TrimSpace(message) + } + } + + if code, exists := raw["code"]; exists { + codeStr := strings.TrimSpace(fmt.Sprintf("%v", code)) + if codeStr != "" { + result.Code = codeStr + } + } + if errType, ok := raw["type"].(string); ok { + errType = strings.TrimSpace(errType) + if errType != "" { + result.Type = errType + } + } + if skipRetry, ok := raw["skip_retry"].(bool); ok { + result.SkipRetry = skipRetry + } + + if statusCodeRaw, exists := raw["status_code"]; exists { + statusCode, ok := parseOverrideInt(statusCodeRaw) + if !ok { + return nil, fmt.Errorf("return_error status_code must be an integer") + } + result.StatusCode = statusCode + } else if statusRaw, exists := raw["status"]; exists { + statusCode, ok := parseOverrideInt(statusRaw) + if !ok { + return nil, fmt.Errorf("return_error status must be an integer") + } + result.StatusCode = statusCode + } + default: + return nil, fmt.Errorf("return_error value must be string or object") + } + + if result.Message == "" { + return nil, fmt.Errorf("return_error message is required") + } + if result.StatusCode < http.StatusContinue || result.StatusCode > http.StatusNetworkAuthenticationRequired { + return nil, fmt.Errorf("return_error status code out of range: %d", result.StatusCode) + } + + return result, nil +} + +func parseOverrideInt(v interface{}) (int, bool) { + switch value := v.(type) { + case int: + return value, true + case float64: + if value != float64(int(value)) { + return 0, false + } + return int(value), true + default: + return 0, false + } +} + func moveValue(jsonStr, fromPath, toPath string) (string, error) { sourceValue := gjson.Get(jsonStr, fromPath) if !sourceValue.Exists() { @@ -537,6 +698,217 @@ func regexReplaceStringValue(jsonStr, path, pattern, replacement string) (string return sjson.Set(jsonStr, path, re.ReplaceAllString(current.String(), replacement)) } +type pruneObjectsOptions struct { + conditions []ConditionOperation + logic string + recursive bool +} + +func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string, error) { + options, err := parsePruneObjectsOptions(value) + if err != nil { + return "", err + } + + if path == "" { + var root interface{} + if err := common.Unmarshal([]byte(jsonStr), &root); err != nil { + return "", err + } + cleaned, _, err := pruneObjectsNode(root, options, contextJSON, true) + if err != nil { + return "", err + } + cleanedBytes, err := common.Marshal(cleaned) + if err != nil { + return "", err + } + return string(cleanedBytes), nil + } + + target := gjson.Get(jsonStr, path) + if !target.Exists() { + return jsonStr, nil + } + + var targetNode interface{} + if target.Type == gjson.JSON { + if err := common.Unmarshal([]byte(target.Raw), &targetNode); err != nil { + return "", err + } + } else { + targetNode = target.Value() + } + + cleaned, _, err := pruneObjectsNode(targetNode, options, contextJSON, true) + if err != nil { + return "", err + } + cleanedBytes, err := common.Marshal(cleaned) + if err != nil { + return "", err + } + return sjson.SetRaw(jsonStr, path, string(cleanedBytes)) +} + +func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) { + opts := pruneObjectsOptions{ + logic: "AND", + recursive: true, + } + + switch raw := value.(type) { + case nil: + return opts, fmt.Errorf("prune_objects value is required") + case string: + v := strings.TrimSpace(raw) + if v == "" { + return opts, fmt.Errorf("prune_objects value is required") + } + opts.conditions = []ConditionOperation{ + { + Path: "type", + Mode: "full", + Value: v, + }, + } + case map[string]interface{}: + if logic, ok := raw["logic"].(string); ok && strings.TrimSpace(logic) != "" { + opts.logic = logic + } + if recursive, ok := raw["recursive"].(bool); ok { + opts.recursive = recursive + } + + if condRaw, exists := raw["conditions"]; exists { + conditions, err := parseConditionOperations(condRaw) + if err != nil { + return opts, err + } + opts.conditions = append(opts.conditions, conditions...) + } + + if whereRaw, exists := raw["where"]; exists { + whereMap, ok := whereRaw.(map[string]interface{}) + if !ok { + return opts, fmt.Errorf("prune_objects where must be object") + } + for key, val := range whereMap { + key = strings.TrimSpace(key) + if key == "" { + continue + } + opts.conditions = append(opts.conditions, ConditionOperation{ + Path: key, + Mode: "full", + Value: val, + }) + } + } + + if matchType, exists := raw["type"]; exists { + opts.conditions = append(opts.conditions, ConditionOperation{ + Path: "type", + Mode: "full", + Value: matchType, + }) + } + default: + return opts, fmt.Errorf("prune_objects value must be string or object") + } + + if len(opts.conditions) == 0 { + return opts, fmt.Errorf("prune_objects conditions are required") + } + return opts, nil +} + +func parseConditionOperations(raw interface{}) ([]ConditionOperation, error) { + items, ok := raw.([]interface{}) + if !ok { + return nil, fmt.Errorf("conditions must be an array") + } + + result := make([]ConditionOperation, 0, len(items)) + for _, item := range items { + itemMap, ok := item.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("condition must be object") + } + path, _ := itemMap["path"].(string) + mode, _ := itemMap["mode"].(string) + if strings.TrimSpace(path) == "" || strings.TrimSpace(mode) == "" { + return nil, fmt.Errorf("condition path/mode is required") + } + condition := ConditionOperation{ + Path: path, + Mode: mode, + } + if value, exists := itemMap["value"]; exists { + condition.Value = value + } + if invert, ok := itemMap["invert"].(bool); ok { + condition.Invert = invert + } + if passMissingKey, ok := itemMap["pass_missing_key"].(bool); ok { + condition.PassMissingKey = passMissingKey + } + result = append(result, condition) + } + return result, nil +} + +func pruneObjectsNode(node interface{}, options pruneObjectsOptions, contextJSON string, isRoot bool) (interface{}, bool, error) { + switch value := node.(type) { + case []interface{}: + result := make([]interface{}, 0, len(value)) + for _, item := range value { + next, drop, err := pruneObjectsNode(item, options, contextJSON, false) + if err != nil { + return nil, false, err + } + if drop { + continue + } + result = append(result, next) + } + return result, false, nil + case map[string]interface{}: + shouldDrop, err := shouldPruneObject(value, options, contextJSON) + if err != nil { + return nil, false, err + } + if shouldDrop && !isRoot { + return nil, true, nil + } + if !options.recursive { + return value, false, nil + } + for key, child := range value { + next, drop, err := pruneObjectsNode(child, options, contextJSON, false) + if err != nil { + return nil, false, err + } + if drop { + delete(value, key) + continue + } + value[key] = next + } + return value, false, nil + default: + return node, false, nil + } +} + +func shouldPruneObject(node map[string]interface{}, options pruneObjectsOptions, contextJSON string) (bool, error) { + nodeBytes, err := common.Marshal(node) + if err != nil { + return false, err + } + return checkConditions(string(nodeBytes), contextJSON, options.conditions, options.logic) +} + func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) { current := gjson.Get(jsonStr, path) var currentMap, newMap map[string]interface{} @@ -598,6 +970,32 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} { } } + ctx["retry_index"] = info.RetryIndex + ctx["is_retry"] = info.RetryIndex > 0 + ctx["retry"] = map[string]interface{}{ + "index": info.RetryIndex, + "is_retry": info.RetryIndex > 0, + } + + if info.LastError != nil { + code := string(info.LastError.GetErrorCode()) + errorType := string(info.LastError.GetErrorType()) + lastError := map[string]interface{}{ + "status_code": info.LastError.StatusCode, + "message": info.LastError.Error(), + "code": code, + "error_code": code, + "type": errorType, + "error_type": errorType, + "skip_retry": types.IsSkipRetryError(info.LastError), + } + ctx["last_error"] = lastError + ctx["last_error_status_code"] = info.LastError.StatusCode + ctx["last_error_message"] = info.LastError.Error() + ctx["last_error_code"] = code + ctx["last_error_type"] = errorType + } + ctx["is_channel_test"] = info.IsChannelTest return ctx } diff --git a/relay/common/override_test.go b/relay/common/override_test.go index 021df3f60..cc1489f74 100644 --- a/relay/common/override_test.go +++ b/relay/common/override_test.go @@ -4,6 +4,8 @@ import ( "encoding/json" "reflect" "testing" + + "github.com/QuantumNous/new-api/types" ) func TestApplyParamOverrideTrimPrefix(t *testing.T) { @@ -772,6 +774,188 @@ func TestApplyParamOverrideToUpper(t *testing.T) { assertJSONEqual(t, `{"model":"GPT-4"}`, string(out)) } +func TestApplyParamOverrideReturnError(t *testing.T) { + input := []byte(`{"model":"gemini-2.5-pro"}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "return_error", + "value": map[string]interface{}{ + "message": "forced bad request by param override", + "status_code": 422, + "code": "forced_bad_request", + "type": "invalid_request_error", + "skip_retry": true, + }, + "conditions": []interface{}{ + map[string]interface{}{ + "path": "retry.is_retry", + "mode": "full", + "value": true, + }, + }, + }, + }, + } + ctx := map[string]interface{}{ + "retry": map[string]interface{}{ + "index": 1, + "is_retry": true, + }, + } + + _, err := ApplyParamOverride(input, override, ctx) + if err == nil { + t.Fatalf("expected error, got nil") + } + returnErr, ok := AsParamOverrideReturnError(err) + if !ok { + t.Fatalf("expected ParamOverrideReturnError, got %T: %v", err, err) + } + if returnErr.StatusCode != 422 { + t.Fatalf("expected status 422, got %d", returnErr.StatusCode) + } + if returnErr.Code != "forced_bad_request" { + t.Fatalf("expected code forced_bad_request, got %s", returnErr.Code) + } + if !returnErr.SkipRetry { + t.Fatalf("expected skip_retry true") + } +} + +func TestApplyParamOverridePruneObjectsByTypeString(t *testing.T) { + input := []byte(`{ + "messages":[ + {"role":"assistant","content":[ + {"type":"output_text","text":"a"}, + {"type":"redacted_thinking","text":"secret"}, + {"type":"tool_call","name":"tool_a"} + ]}, + {"role":"assistant","content":[ + {"type":"output_text","text":"b"}, + {"type":"wrapper","parts":[ + {"type":"redacted_thinking","text":"secret2"}, + {"type":"output_text","text":"c"} + ]} + ]} + ] + }`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "prune_objects", + "value": "redacted_thinking", + }, + }, + } + + out, err := ApplyParamOverride(input, override, nil) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{ + "messages":[ + {"role":"assistant","content":[ + {"type":"output_text","text":"a"}, + {"type":"tool_call","name":"tool_a"} + ]}, + {"role":"assistant","content":[ + {"type":"output_text","text":"b"}, + {"type":"wrapper","parts":[ + {"type":"output_text","text":"c"} + ]} + ]} + ] + }`, string(out)) +} + +func TestApplyParamOverridePruneObjectsWhereAndPath(t *testing.T) { + input := []byte(`{ + "a":{"items":[{"type":"redacted_thinking","id":1},{"type":"output_text","id":2}]}, + "b":{"items":[{"type":"redacted_thinking","id":3},{"type":"output_text","id":4}]} + }`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "path": "a", + "mode": "prune_objects", + "value": map[string]interface{}{ + "where": map[string]interface{}{ + "type": "redacted_thinking", + }, + }, + }, + }, + } + + out, err := ApplyParamOverride(input, override, nil) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{ + "a":{"items":[{"type":"output_text","id":2}]}, + "b":{"items":[{"type":"redacted_thinking","id":3},{"type":"output_text","id":4}]} + }`, string(out)) +} + +func TestApplyParamOverrideNormalizeThinkingSignatureUnsupported(t *testing.T) { + input := []byte(`{"items":[{"type":"redacted_thinking"}]}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "normalize_thinking_signature", + }, + }, + } + + _, err := ApplyParamOverride(input, override, nil) + if err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestApplyParamOverrideConditionFromRetryAndLastErrorContext(t *testing.T) { + info := &RelayInfo{ + RetryIndex: 1, + LastError: types.WithOpenAIError(types.OpenAIError{ + Message: "invalid thinking signature", + Type: "invalid_request_error", + Code: "bad_thought_signature", + }, 400), + } + ctx := BuildParamOverrideContext(info) + + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "path": "temperature", + "mode": "set", + "value": 0.1, + "logic": "AND", + "conditions": []interface{}{ + map[string]interface{}{ + "path": "is_retry", + "mode": "full", + "value": true, + }, + map[string]interface{}{ + "path": "last_error.code", + "mode": "contains", + "value": "thought_signature", + }, + }, + }, + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.1}`, string(out)) +} + func assertJSONEqual(t *testing.T, want, got string) { t.Helper() diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 81b7d21d6..c10e6d5fb 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -140,6 +140,8 @@ type RelayInfo struct { SubscriptionAmountUsedAfterPreConsume int64 IsClaudeBetaQuery bool // /v1/messages?beta=true IsChannelTest bool // channel test request + RetryIndex int + LastError *types.NewAPIError PriceData types.PriceData diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index e7adddbbf..4cf5e0411 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -174,7 +174,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types if len(info.ParamOverride) > 0 { jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index 1a41756b8..edbd1f7e6 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "encoding/json" "fmt" "net/http" @@ -46,7 +45,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } relaycommon.AppendRequestConversionFromRequest(info, convertedRequest) - jsonData, err := json.Marshal(convertedRequest) + jsonData, err := common.Marshal(convertedRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -54,7 +53,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * if len(info.ParamOverride) > 0 { jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index a1b8e592e..a58c404f5 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -159,7 +159,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ if len(info.ParamOverride) > 0 { jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } @@ -257,14 +257,9 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI // apply param override if len(info.ParamOverride) > 0 { - reqMap := make(map[string]interface{}) - _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range info.ParamOverride { - reqMap[key] = value - } - jsonData, err = common.Marshal(reqMap) + jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } logger.LogDebug(c, "Gemini embedding request body: "+string(jsonData)) diff --git a/relay/image_handler.go b/relay/image_handler.go index e83294268..21a5be2fa 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -72,7 +72,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type if len(info.ParamOverride) > 0 { jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } diff --git a/relay/param_override_error.go b/relay/param_override_error.go new file mode 100644 index 000000000..c23382985 --- /dev/null +++ b/relay/param_override_error.go @@ -0,0 +1,13 @@ +package relay + +import ( + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/types" +) + +func newAPIErrorFromParamOverride(err error) *types.NewAPIError { + if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok { + return relaycommon.NewAPIErrorFromParamOverride(fixedErr) + } + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) +} diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index 8fe2930e9..9c4bef6e1 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -63,7 +63,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ if len(info.ParamOverride) > 0 { jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } diff --git a/relay/responses_handler.go b/relay/responses_handler.go index 04fc3470e..2190be87f 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -98,7 +98,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * if len(info.ParamOverride) > 0 { jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } }