diff --git a/controller/option.go b/controller/option.go index 4d5b4e8d2..a2db95326 100644 --- a/controller/option.go +++ b/controller/option.go @@ -10,6 +10,7 @@ import ( "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/setting" "github.com/QuantumNous/new-api/setting/console_setting" + "github.com/QuantumNous/new-api/setting/operation_setting" "github.com/QuantumNous/new-api/setting/ratio_setting" "github.com/QuantumNous/new-api/setting/system_setting" @@ -177,6 +178,15 @@ func UpdateOption(c *gin.Context) { }) return } + case "AutomaticDisableStatusCodes": + _, err = operation_setting.ParseHTTPStatusCodeRanges(option.Value.(string)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } case "console_setting.api_info": err = console_setting.ValidateConsoleSettings(option.Value.(string), "ApiInfo") if err != nil { diff --git a/controller/relay.go b/controller/relay.go index 9759fa30c..72ea3e24c 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -348,7 +348,7 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously if service.ShouldDisableChannel(channelError.ChannelType, err) && channelError.AutoBan { gopool.Go(func() { - service.DisableChannel(channelError, err.Error()) + service.DisableChannel(channelError, err.ErrorWithStatusCode()) }) } @@ -378,7 +378,7 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex) } other["admin_info"] = adminInfo - model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other) + model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveErrorWithStatusCode(), tokenId, 0, false, userGroup, other) } } diff --git a/model/option.go b/model/option.go index e9fd50d7f..24cf7862d 100644 --- a/model/option.go +++ b/model/option.go @@ -143,6 +143,7 @@ func InitOptionMap() { common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString() + common.OptionMap["AutomaticDisableStatusCodes"] = operation_setting.AutomaticDisableStatusCodesToString() common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled()) // 自动添加所有注册的模型配置 @@ -444,6 +445,8 @@ func updateOptionMap(key string, value string) (err error) { setting.SensitiveWordsFromString(value) case "AutomaticDisableKeywords": operation_setting.AutomaticDisableKeywordsFromString(value) + case "AutomaticDisableStatusCodes": + err = operation_setting.AutomaticDisableStatusCodesFromString(value) case "StreamCacheQueueLength": setting.StreamCacheQueueLength, _ = strconv.Atoi(value) case "PayMethods": diff --git a/service/channel.go b/service/channel.go index 8f8a35726..96bc1efe7 100644 --- a/service/channel.go +++ b/service/channel.go @@ -57,9 +57,12 @@ func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool { if types.IsSkipRetryError(err) { return false } - if err.StatusCode == http.StatusUnauthorized { + if operation_setting.ShouldDisableByStatusCode(err.StatusCode) { return true } + //if err.StatusCode == http.StatusUnauthorized { + // return true + //} if err.StatusCode == http.StatusForbidden { switch channelType { case constant.ChannelTypeGemini: diff --git a/setting/operation_setting/status_code_ranges.go b/setting/operation_setting/status_code_ranges.go new file mode 100644 index 000000000..7a763008e --- /dev/null +++ b/setting/operation_setting/status_code_ranges.go @@ -0,0 +1,147 @@ +package operation_setting + +import ( + "fmt" + "sort" + "strconv" + "strings" +) + +type StatusCodeRange struct { + Start int + End int +} + +var AutomaticDisableStatusCodeRanges = []StatusCodeRange{{Start: 401, End: 401}} + +func AutomaticDisableStatusCodesToString() string { + if len(AutomaticDisableStatusCodeRanges) == 0 { + return "" + } + parts := make([]string, 0, len(AutomaticDisableStatusCodeRanges)) + for _, r := range AutomaticDisableStatusCodeRanges { + if r.Start == r.End { + parts = append(parts, strconv.Itoa(r.Start)) + continue + } + parts = append(parts, fmt.Sprintf("%d-%d", r.Start, r.End)) + } + return strings.Join(parts, ",") +} + +func AutomaticDisableStatusCodesFromString(s string) error { + ranges, err := ParseHTTPStatusCodeRanges(s) + if err != nil { + return err + } + AutomaticDisableStatusCodeRanges = ranges + return nil +} + +func ShouldDisableByStatusCode(code int) bool { + if code < 100 || code > 599 { + return false + } + for _, r := range AutomaticDisableStatusCodeRanges { + if code < r.Start { + return false + } + if code <= r.End { + return true + } + } + return false +} + +func ParseHTTPStatusCodeRanges(input string) ([]StatusCodeRange, error) { + input = strings.TrimSpace(input) + if input == "" { + return nil, nil + } + + input = strings.NewReplacer(",", ",").Replace(input) + segments := strings.Split(input, ",") + + var ranges []StatusCodeRange + var invalid []string + + for _, seg := range segments { + seg = strings.TrimSpace(seg) + if seg == "" { + continue + } + r, err := parseHTTPStatusCodeToken(seg) + if err != nil { + invalid = append(invalid, seg) + continue + } + ranges = append(ranges, r) + } + + if len(invalid) > 0 { + return nil, fmt.Errorf("invalid http status code rules: %s", strings.Join(invalid, ", ")) + } + if len(ranges) == 0 { + return nil, nil + } + + sort.Slice(ranges, func(i, j int) bool { + if ranges[i].Start == ranges[j].Start { + return ranges[i].End < ranges[j].End + } + return ranges[i].Start < ranges[j].Start + }) + + merged := []StatusCodeRange{ranges[0]} + for _, r := range ranges[1:] { + last := &merged[len(merged)-1] + if r.Start <= last.End+1 { + if r.End > last.End { + last.End = r.End + } + continue + } + merged = append(merged, r) + } + + return merged, nil +} + +func parseHTTPStatusCodeToken(token string) (StatusCodeRange, error) { + token = strings.TrimSpace(token) + token = strings.ReplaceAll(token, " ", "") + if token == "" { + return StatusCodeRange{}, fmt.Errorf("empty token") + } + + if strings.Contains(token, "-") { + parts := strings.Split(token, "-") + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return StatusCodeRange{}, fmt.Errorf("invalid range token: %s", token) + } + start, err := strconv.Atoi(parts[0]) + if err != nil { + return StatusCodeRange{}, fmt.Errorf("invalid range start: %s", token) + } + end, err := strconv.Atoi(parts[1]) + if err != nil { + return StatusCodeRange{}, fmt.Errorf("invalid range end: %s", token) + } + if start > end { + return StatusCodeRange{}, fmt.Errorf("range start > end: %s", token) + } + if start < 100 || end > 599 { + return StatusCodeRange{}, fmt.Errorf("range out of bounds: %s", token) + } + return StatusCodeRange{Start: start, End: end}, nil + } + + code, err := strconv.Atoi(token) + if err != nil { + return StatusCodeRange{}, fmt.Errorf("invalid status code: %s", token) + } + if code < 100 || code > 599 { + return StatusCodeRange{}, fmt.Errorf("status code out of bounds: %s", token) + } + return StatusCodeRange{Start: code, End: code}, nil +} diff --git a/setting/operation_setting/status_code_ranges_test.go b/setting/operation_setting/status_code_ranges_test.go new file mode 100644 index 000000000..1712efd75 --- /dev/null +++ b/setting/operation_setting/status_code_ranges_test.go @@ -0,0 +1,52 @@ +package operation_setting + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseHTTPStatusCodeRanges_CommaSeparated(t *testing.T) { + ranges, err := ParseHTTPStatusCodeRanges("401,403,500-599") + require.NoError(t, err) + require.Equal(t, []StatusCodeRange{ + {Start: 401, End: 401}, + {Start: 403, End: 403}, + {Start: 500, End: 599}, + }, ranges) +} + +func TestParseHTTPStatusCodeRanges_MergeAndNormalize(t *testing.T) { + ranges, err := ParseHTTPStatusCodeRanges("500-505,504,401,403,402") + require.NoError(t, err) + require.Equal(t, []StatusCodeRange{ + {Start: 401, End: 403}, + {Start: 500, End: 505}, + }, ranges) +} + +func TestParseHTTPStatusCodeRanges_Invalid(t *testing.T) { + _, err := ParseHTTPStatusCodeRanges("99,600,foo,500-400,500-") + require.Error(t, err) +} + +func TestParseHTTPStatusCodeRanges_NoComma_IsInvalid(t *testing.T) { + _, err := ParseHTTPStatusCodeRanges("401 403") + require.Error(t, err) +} + +func TestShouldDisableByStatusCode(t *testing.T) { + orig := AutomaticDisableStatusCodeRanges + t.Cleanup(func() { AutomaticDisableStatusCodeRanges = orig }) + + AutomaticDisableStatusCodeRanges = []StatusCodeRange{ + {Start: 401, End: 403}, + {Start: 500, End: 599}, + } + + require.True(t, ShouldDisableByStatusCode(401)) + require.True(t, ShouldDisableByStatusCode(403)) + require.False(t, ShouldDisableByStatusCode(404)) + require.True(t, ShouldDisableByStatusCode(500)) + require.False(t, ShouldDisableByStatusCode(200)) +} diff --git a/types/error.go b/types/error.go index b060a9db6..e112eeefb 100644 --- a/types/error.go +++ b/types/error.go @@ -130,6 +130,20 @@ func (e *NewAPIError) Error() string { return e.Err.Error() } +func (e *NewAPIError) ErrorWithStatusCode() string { + if e == nil { + return "" + } + msg := e.Error() + if e.StatusCode == 0 { + return msg + } + if msg == "" { + return fmt.Sprintf("status_code=%d", e.StatusCode) + } + return fmt.Sprintf("status_code=%d, %s", e.StatusCode, msg) +} + func (e *NewAPIError) MaskSensitiveError() string { if e == nil { return "" @@ -144,6 +158,20 @@ func (e *NewAPIError) MaskSensitiveError() string { return common.MaskSensitiveInfo(errStr) } +func (e *NewAPIError) MaskSensitiveErrorWithStatusCode() string { + if e == nil { + return "" + } + msg := e.MaskSensitiveError() + if e.StatusCode == 0 { + return msg + } + if msg == "" { + return fmt.Sprintf("status_code=%d", e.StatusCode) + } + return fmt.Sprintf("status_code=%d, %s", e.StatusCode, msg) +} + func (e *NewAPIError) SetMessage(message string) { e.Err = errors.New(message) } diff --git a/web/src/components/settings/OperationSetting.jsx b/web/src/components/settings/OperationSetting.jsx index 92591db45..4a77bcf10 100644 --- a/web/src/components/settings/OperationSetting.jsx +++ b/web/src/components/settings/OperationSetting.jsx @@ -70,6 +70,7 @@ const OperationSetting = () => { AutomaticDisableChannelEnabled: false, AutomaticEnableChannelEnabled: false, AutomaticDisableKeywords: '', + AutomaticDisableStatusCodes: '401', 'monitor_setting.auto_test_channel_enabled': false, 'monitor_setting.auto_test_channel_minutes': 10 /* 签到设置 */, 'checkin_setting.enabled': false, diff --git a/web/src/helpers/index.js b/web/src/helpers/index.js index cfb0270e7..a86c3bca5 100644 --- a/web/src/helpers/index.js +++ b/web/src/helpers/index.js @@ -29,3 +29,4 @@ export * from './token'; export * from './boolean'; export * from './dashboard'; export * from './passkey'; +export * from './statusCodeRules'; diff --git a/web/src/helpers/statusCodeRules.js b/web/src/helpers/statusCodeRules.js new file mode 100644 index 000000000..a0d5e75f9 --- /dev/null +++ b/web/src/helpers/statusCodeRules.js @@ -0,0 +1,96 @@ +export function parseHttpStatusCodeRules(input) { + const raw = (input ?? '').toString().trim(); + if (raw.length === 0) { + return { + ok: true, + ranges: [], + tokens: [], + normalized: '', + invalidTokens: [], + }; + } + + const sanitized = raw.replace(/[,]/g, ','); + const segments = sanitized.split(/[,]/g); + + const ranges = []; + const invalidTokens = []; + + for (const segment of segments) { + const trimmed = segment.trim(); + if (!trimmed) continue; + const parsed = parseToken(trimmed); + if (!parsed) invalidTokens.push(trimmed); + else ranges.push(parsed); + } + + if (invalidTokens.length > 0) { + return { + ok: false, + ranges: [], + tokens: [], + normalized: raw, + invalidTokens, + }; + } + + const merged = mergeRanges(ranges); + const tokens = merged.map((r) => (r.start === r.end ? `${r.start}` : `${r.start}-${r.end}`)); + const normalized = tokens.join(','); + + return { + ok: true, + ranges: merged, + tokens, + normalized, + invalidTokens: [], + }; +} + +function parseToken(token) { + const cleaned = (token ?? '').toString().trim().replaceAll(' ', ''); + if (!cleaned) return null; + + if (cleaned.includes('-')) { + const parts = cleaned.split('-'); + if (parts.length !== 2) return null; + const [a, b] = parts; + if (!isNumber(a) || !isNumber(b)) return null; + const start = Number.parseInt(a, 10); + const end = Number.parseInt(b, 10); + if (!Number.isFinite(start) || !Number.isFinite(end)) return null; + if (start > end) return null; + if (start < 100 || end > 599) return null; + return { start, end }; + } + + if (!isNumber(cleaned)) return null; + const code = Number.parseInt(cleaned, 10); + if (!Number.isFinite(code)) return null; + if (code < 100 || code > 599) return null; + return { start: code, end: code }; +} + +function isNumber(s) { + return typeof s === 'string' && /^\d+$/.test(s); +} + +function mergeRanges(ranges) { + if (!Array.isArray(ranges) || ranges.length === 0) return []; + + const sorted = [...ranges].sort((a, b) => (a.start !== b.start ? a.start - b.start : a.end - b.end)); + const merged = [sorted[0]]; + + for (let i = 1; i < sorted.length; i += 1) { + const current = sorted[i]; + const last = merged[merged.length - 1]; + + if (current.start <= last.end + 1) { + last.end = Math.max(last.end, current.end); + continue; + } + merged.push({ ...current }); + } + + return merged; +} diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index f36f70e93..f6d55544d 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -1923,6 +1923,10 @@ "自动测试所有通道间隔时间": "Auto test interval for all channels", "自动禁用": "Auto disabled", "自动禁用关键词": "Automatic disable keywords", + "自动禁用状态码": "Auto-disable status codes", + "自动禁用状态码格式不正确": "Invalid auto-disable status code format", + "支持填写单个状态码或范围(含首尾),使用逗号分隔": "Supports single status codes or inclusive ranges; separate with commas", + "例如:401, 403, 429, 500-599": "e.g. 401,403,429,500-599", "自动选择": "Auto Select", "自定义充值数量选项": "Custom Recharge Amount Options", "自定义充值数量选项不是合法的 JSON 数组": "Custom recharge amount options is not a valid JSON array", diff --git a/web/src/i18n/locales/zh.json b/web/src/i18n/locales/zh.json index 3b67fba18..e91f50a4e 100644 --- a/web/src/i18n/locales/zh.json +++ b/web/src/i18n/locales/zh.json @@ -1909,6 +1909,10 @@ "自动测试所有通道间隔时间": "自动测试所有通道间隔时间", "自动禁用": "自动禁用", "自动禁用关键词": "自动禁用关键词", + "自动禁用状态码": "自动禁用状态码", + "自动禁用状态码格式不正确": "自动禁用状态码格式不正确", + "支持填写单个状态码或范围(含首尾),使用逗号分隔": "支持填写单个状态码或范围(含首尾),使用逗号分隔", + "例如:401, 403, 429, 500-599": "例如:401,403,429,500-599", "自动选择": "自动选择", "自定义充值数量选项": "自定义充值数量选项", "自定义充值数量选项不是合法的 JSON 数组": "自定义充值数量选项不是合法的 JSON 数组", diff --git a/web/src/pages/Setting/Operation/SettingsMonitoring.jsx b/web/src/pages/Setting/Operation/SettingsMonitoring.jsx index b93a5ff09..9715ef3cb 100644 --- a/web/src/pages/Setting/Operation/SettingsMonitoring.jsx +++ b/web/src/pages/Setting/Operation/SettingsMonitoring.jsx @@ -18,19 +18,29 @@ For commercial licensing, please contact support@quantumnous.com */ import React, { useEffect, useState, useRef } from 'react'; -import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui'; +import { + Button, + Col, + Form, + Row, + Spin, + Tag, + Typography, +} from '@douyinfe/semi-ui'; import { compareObjects, API, showError, showSuccess, showWarning, + parseHttpStatusCodeRules, verifyJSON, } from '../../../helpers'; import { useTranslation } from 'react-i18next'; export default function SettingsMonitoring(props) { const { t } = useTranslation(); + const { Text } = Typography; const [loading, setLoading] = useState(false); const [inputs, setInputs] = useState({ ChannelDisableThreshold: '', @@ -38,21 +48,37 @@ export default function SettingsMonitoring(props) { AutomaticDisableChannelEnabled: false, AutomaticEnableChannelEnabled: false, AutomaticDisableKeywords: '', + AutomaticDisableStatusCodes: '401', 'monitor_setting.auto_test_channel_enabled': false, 'monitor_setting.auto_test_channel_minutes': 10, }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); + const parsedAutoDisableStatusCodes = parseHttpStatusCodeRules( + inputs.AutomaticDisableStatusCodes || '', + ); function onSubmit() { const updateArray = compareObjects(inputs, inputsRow); if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); + if (!parsedAutoDisableStatusCodes.ok) { + const details = + parsedAutoDisableStatusCodes.invalidTokens && + parsedAutoDisableStatusCodes.invalidTokens.length > 0 + ? `: ${parsedAutoDisableStatusCodes.invalidTokens.join(', ')}` + : ''; + return showError(`${t('自动禁用状态码格式不正确')}${details}`); + } const requestQueue = updateArray.map((item) => { let value = ''; if (typeof inputs[item.key] === 'boolean') { value = String(inputs[item.key]); } else { - value = inputs[item.key]; + if (item.key === 'AutomaticDisableStatusCodes') { + value = parsedAutoDisableStatusCodes.normalized; + } else { + value = inputs[item.key]; + } } return API.put('/api/option/', { key: item.key, @@ -207,6 +233,45 @@ export default function SettingsMonitoring(props) { + + setInputs({ ...inputs, AutomaticDisableStatusCodes: value }) + } + /> + {parsedAutoDisableStatusCodes.ok && + parsedAutoDisableStatusCodes.tokens.length > 0 && ( +
+ {parsedAutoDisableStatusCodes.tokens.map((token) => ( + + {token} + + ))} +
+ )} + {!parsedAutoDisableStatusCodes.ok && ( + + {t('自动禁用状态码格式不正确')} + {parsedAutoDisableStatusCodes.invalidTokens && + parsedAutoDisableStatusCodes.invalidTokens.length > 0 + ? `: ${parsedAutoDisableStatusCodes.invalidTokens.join( + ', ', + )}` + : ''} + + )}