From 4ac59ca6e6d63fddfc3effa66e86f0208c2638fd Mon Sep 17 00:00:00 2001 From: Seefs Date: Thu, 12 Feb 2026 14:58:17 +0800 Subject: [PATCH] fix: support numeric status code mapping in ResetStatusCode --- service/error.go | 43 +++++++++++++++++++++++++++++--- service/error_test.go | 57 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 3 deletions(-) create mode 100644 service/error_test.go diff --git a/service/error.go b/service/error.go index 889964beb..7a9d7a815 100644 --- a/service/error.go +++ b/service/error.go @@ -2,9 +2,11 @@ package service import ( "context" + "encoding/json" "errors" "fmt" "io" + "math" "net/http" "strconv" "strings" @@ -127,10 +129,13 @@ func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFai } func ResetStatusCode(newApiErr *types.NewAPIError, statusCodeMappingStr string) { + if newApiErr == nil { + return + } if statusCodeMappingStr == "" || statusCodeMappingStr == "{}" { return } - statusCodeMapping := make(map[string]string) + statusCodeMapping := make(map[string]any) err := common.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping) if err != nil { return @@ -139,12 +144,44 @@ func ResetStatusCode(newApiErr *types.NewAPIError, statusCodeMappingStr string) return } codeStr := strconv.Itoa(newApiErr.StatusCode) - if _, ok := statusCodeMapping[codeStr]; ok { - intCode, _ := strconv.Atoi(statusCodeMapping[codeStr]) + if value, ok := statusCodeMapping[codeStr]; ok { + intCode, ok := parseStatusCodeMappingValue(value) + if !ok { + return + } newApiErr.StatusCode = intCode } } +func parseStatusCodeMappingValue(value any) (int, bool) { + switch v := value.(type) { + case string: + if v == "" { + return 0, false + } + statusCode, err := strconv.Atoi(v) + if err != nil { + return 0, false + } + return statusCode, true + case float64: + if v != math.Trunc(v) { + return 0, false + } + return int(v), true + case int: + return v, true + case json.Number: + statusCode, err := strconv.Atoi(v.String()) + if err != nil { + return 0, false + } + return statusCode, true + default: + return 0, false + } +} + func TaskErrorWrapperLocal(err error, code string, statusCode int) *dto.TaskError { openaiErr := TaskErrorWrapper(err, code, statusCode) openaiErr.LocalError = true diff --git a/service/error_test.go b/service/error_test.go new file mode 100644 index 000000000..2303e8f4a --- /dev/null +++ b/service/error_test.go @@ -0,0 +1,57 @@ +package service + +import ( + "testing" + + "github.com/QuantumNous/new-api/types" + "github.com/stretchr/testify/require" +) + +func TestResetStatusCode(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + statusCode int + statusCodeConfig string + expectedCode int + }{ + { + name: "map string value", + statusCode: 429, + statusCodeConfig: `{"429":"503"}`, + expectedCode: 503, + }, + { + name: "map int value", + statusCode: 429, + statusCodeConfig: `{"429":503}`, + expectedCode: 503, + }, + { + name: "skip invalid string value", + statusCode: 429, + statusCodeConfig: `{"429":"bad-code"}`, + expectedCode: 429, + }, + { + name: "skip status code 200", + statusCode: 200, + statusCodeConfig: `{"200":503}`, + expectedCode: 200, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + newAPIError := &types.NewAPIError{ + StatusCode: tc.statusCode, + } + ResetStatusCode(newAPIError, tc.statusCodeConfig) + require.Equal(t, tc.expectedCode, newAPIError.StatusCode) + }) + } +}