mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 00:46:42 +00:00
Merge pull request #2647 from seefs001/feature/status-code-auto-disable
feat: status code auto-disable configuration
This commit is contained in:
147
setting/operation_setting/status_code_ranges.go
Normal file
147
setting/operation_setting/status_code_ranges.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package operation_setting
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type StatusCodeRange struct {
|
||||
Start int
|
||||
End int
|
||||
}
|
||||
|
||||
var AutomaticDisableStatusCodeRanges = []StatusCodeRange{{Start: 401, End: 401}}
|
||||
|
||||
func AutomaticDisableStatusCodesToString() string {
|
||||
if len(AutomaticDisableStatusCodeRanges) == 0 {
|
||||
return ""
|
||||
}
|
||||
parts := make([]string, 0, len(AutomaticDisableStatusCodeRanges))
|
||||
for _, r := range AutomaticDisableStatusCodeRanges {
|
||||
if r.Start == r.End {
|
||||
parts = append(parts, strconv.Itoa(r.Start))
|
||||
continue
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("%d-%d", r.Start, r.End))
|
||||
}
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
func AutomaticDisableStatusCodesFromString(s string) error {
|
||||
ranges, err := ParseHTTPStatusCodeRanges(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
AutomaticDisableStatusCodeRanges = ranges
|
||||
return nil
|
||||
}
|
||||
|
||||
func ShouldDisableByStatusCode(code int) bool {
|
||||
if code < 100 || code > 599 {
|
||||
return false
|
||||
}
|
||||
for _, r := range AutomaticDisableStatusCodeRanges {
|
||||
if code < r.Start {
|
||||
return false
|
||||
}
|
||||
if code <= r.End {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ParseHTTPStatusCodeRanges(input string) ([]StatusCodeRange, error) {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
input = strings.NewReplacer(",", ",").Replace(input)
|
||||
segments := strings.Split(input, ",")
|
||||
|
||||
var ranges []StatusCodeRange
|
||||
var invalid []string
|
||||
|
||||
for _, seg := range segments {
|
||||
seg = strings.TrimSpace(seg)
|
||||
if seg == "" {
|
||||
continue
|
||||
}
|
||||
r, err := parseHTTPStatusCodeToken(seg)
|
||||
if err != nil {
|
||||
invalid = append(invalid, seg)
|
||||
continue
|
||||
}
|
||||
ranges = append(ranges, r)
|
||||
}
|
||||
|
||||
if len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("invalid http status code rules: %s", strings.Join(invalid, ", "))
|
||||
}
|
||||
if len(ranges) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
sort.Slice(ranges, func(i, j int) bool {
|
||||
if ranges[i].Start == ranges[j].Start {
|
||||
return ranges[i].End < ranges[j].End
|
||||
}
|
||||
return ranges[i].Start < ranges[j].Start
|
||||
})
|
||||
|
||||
merged := []StatusCodeRange{ranges[0]}
|
||||
for _, r := range ranges[1:] {
|
||||
last := &merged[len(merged)-1]
|
||||
if r.Start <= last.End+1 {
|
||||
if r.End > last.End {
|
||||
last.End = r.End
|
||||
}
|
||||
continue
|
||||
}
|
||||
merged = append(merged, r)
|
||||
}
|
||||
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
func parseHTTPStatusCodeToken(token string) (StatusCodeRange, error) {
|
||||
token = strings.TrimSpace(token)
|
||||
token = strings.ReplaceAll(token, " ", "")
|
||||
if token == "" {
|
||||
return StatusCodeRange{}, fmt.Errorf("empty token")
|
||||
}
|
||||
|
||||
if strings.Contains(token, "-") {
|
||||
parts := strings.Split(token, "-")
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return StatusCodeRange{}, fmt.Errorf("invalid range token: %s", token)
|
||||
}
|
||||
start, err := strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
return StatusCodeRange{}, fmt.Errorf("invalid range start: %s", token)
|
||||
}
|
||||
end, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return StatusCodeRange{}, fmt.Errorf("invalid range end: %s", token)
|
||||
}
|
||||
if start > end {
|
||||
return StatusCodeRange{}, fmt.Errorf("range start > end: %s", token)
|
||||
}
|
||||
if start < 100 || end > 599 {
|
||||
return StatusCodeRange{}, fmt.Errorf("range out of bounds: %s", token)
|
||||
}
|
||||
return StatusCodeRange{Start: start, End: end}, nil
|
||||
}
|
||||
|
||||
code, err := strconv.Atoi(token)
|
||||
if err != nil {
|
||||
return StatusCodeRange{}, fmt.Errorf("invalid status code: %s", token)
|
||||
}
|
||||
if code < 100 || code > 599 {
|
||||
return StatusCodeRange{}, fmt.Errorf("status code out of bounds: %s", token)
|
||||
}
|
||||
return StatusCodeRange{Start: code, End: code}, nil
|
||||
}
|
||||
52
setting/operation_setting/status_code_ranges_test.go
Normal file
52
setting/operation_setting/status_code_ranges_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package operation_setting
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseHTTPStatusCodeRanges_CommaSeparated(t *testing.T) {
|
||||
ranges, err := ParseHTTPStatusCodeRanges("401,403,500-599")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []StatusCodeRange{
|
||||
{Start: 401, End: 401},
|
||||
{Start: 403, End: 403},
|
||||
{Start: 500, End: 599},
|
||||
}, ranges)
|
||||
}
|
||||
|
||||
func TestParseHTTPStatusCodeRanges_MergeAndNormalize(t *testing.T) {
|
||||
ranges, err := ParseHTTPStatusCodeRanges("500-505,504,401,403,402")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []StatusCodeRange{
|
||||
{Start: 401, End: 403},
|
||||
{Start: 500, End: 505},
|
||||
}, ranges)
|
||||
}
|
||||
|
||||
func TestParseHTTPStatusCodeRanges_Invalid(t *testing.T) {
|
||||
_, err := ParseHTTPStatusCodeRanges("99,600,foo,500-400,500-")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseHTTPStatusCodeRanges_NoComma_IsInvalid(t *testing.T) {
|
||||
_, err := ParseHTTPStatusCodeRanges("401 403")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestShouldDisableByStatusCode(t *testing.T) {
|
||||
orig := AutomaticDisableStatusCodeRanges
|
||||
t.Cleanup(func() { AutomaticDisableStatusCodeRanges = orig })
|
||||
|
||||
AutomaticDisableStatusCodeRanges = []StatusCodeRange{
|
||||
{Start: 401, End: 403},
|
||||
{Start: 500, End: 599},
|
||||
}
|
||||
|
||||
require.True(t, ShouldDisableByStatusCode(401))
|
||||
require.True(t, ShouldDisableByStatusCode(403))
|
||||
require.False(t, ShouldDisableByStatusCode(404))
|
||||
require.True(t, ShouldDisableByStatusCode(500))
|
||||
require.False(t, ShouldDisableByStatusCode(200))
|
||||
}
|
||||
Reference in New Issue
Block a user