Merge pull request #858 from james-6-23/fix/pool-mode-03bf3485

支持 API Key 上游池模式的同账号重试次数配置与自定义错误策略
This commit is contained in:
Wesley Liddick
2026-03-09 08:48:53 +08:00
committed by GitHub
13 changed files with 558 additions and 40 deletions

View File

@@ -647,6 +647,75 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
return false
}
// IsPoolMode 检查 API Key 账号是否启用池模式。
// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。
func (a *Account) IsPoolMode() bool {
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["pool_mode"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
const (
defaultPoolModeRetryCount = 3
maxPoolModeRetryCount = 10
)
// GetPoolModeRetryCount 返回池模式同账号重试次数。
// 未配置或配置非法时回退为默认值 3小于 0 按 0 处理;过大则截断到 10。
func (a *Account) GetPoolModeRetryCount() int {
if a == nil || !a.IsPoolMode() || a.Credentials == nil {
return defaultPoolModeRetryCount
}
raw, ok := a.Credentials["pool_mode_retry_count"]
if !ok || raw == nil {
return defaultPoolModeRetryCount
}
count := parsePoolModeRetryCount(raw)
if count < 0 {
return 0
}
if count > maxPoolModeRetryCount {
return maxPoolModeRetryCount
}
return count
}
func parsePoolModeRetryCount(value any) int {
switch v := value.(type) {
case int:
return v
case int64:
return int(v)
case float64:
return int(v)
case json.Number:
if i, err := v.Int64(); err == nil {
return int(i)
}
case string:
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
return i
}
}
return defaultPoolModeRetryCount
}
// isPoolModeRetryableStatus 池模式下应触发同账号重试的状态码
func isPoolModeRetryableStatus(statusCode int) bool {
switch statusCode {
case 401, 403, 429:
return true
default:
return false
}
}
func (a *Account) GetCustomErrorCodes() []int {
if a.Credentials == nil {
return nil

View File

@@ -0,0 +1,117 @@
//go:build unit
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestGetPoolModeRetryCount(t *testing.T) {
tests := []struct {
name string
account *Account
expected int
}{
{
name: "default_when_not_pool_mode",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{},
},
expected: defaultPoolModeRetryCount,
},
{
name: "default_when_missing_retry_count",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
},
},
expected: defaultPoolModeRetryCount,
},
{
name: "supports_float64_from_json_credentials",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": float64(5),
},
},
expected: 5,
},
{
name: "supports_json_number",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": json.Number("4"),
},
},
expected: 4,
},
{
name: "supports_string_value",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": "2",
},
},
expected: 2,
},
{
name: "negative_value_is_clamped_to_zero",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": -1,
},
},
expected: 0,
},
{
name: "oversized_value_is_clamped_to_max",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": 99,
},
},
expected: maxPoolModeRetryCount,
},
{
name: "invalid_value_falls_back_to_default",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": "oops",
},
},
expected: defaultPoolModeRetryCount,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.expected, tt.account.GetPoolModeRetryCount())
})
}
}

View File

@@ -177,6 +177,36 @@ func TestCheckErrorPolicy(t *testing.T) {
body: []byte(`overloaded`),
expected: ErrorPolicyMatched, // custom codes take precedence
},
{
name: "pool_mode_custom_error_codes_hit_returns_matched",
account: &Account{
ID: 7,
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(401), float64(403)},
},
},
statusCode: 401,
body: []byte(`unauthorized`),
expected: ErrorPolicyMatched,
},
{
name: "pool_mode_without_custom_error_codes_returns_skipped",
account: &Account{
ID: 8,
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
},
},
statusCode: 401,
body: []byte(`unauthorized`),
expected: ErrorPolicySkipped,
},
}
for _, tt := range tests {
@@ -190,6 +220,48 @@ func TestCheckErrorPolicy(t *testing.T) {
}
}
func TestHandleUpstreamError_PoolModeCustomErrorCodesOverride(t *testing.T) {
t.Run("pool_mode_without_custom_error_codes_still_skips", func(t *testing.T) {
repo := &errorPolicyRepoStub{}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 30,
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
},
}
shouldDisable := svc.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.False(t, shouldDisable)
require.Equal(t, 0, repo.setErrCalls)
require.Equal(t, 0, repo.tempCalls)
})
t.Run("pool_mode_with_custom_error_codes_uses_local_error_policy", func(t *testing.T) {
repo := &errorPolicyRepoStub{}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 31,
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(401)},
},
}
shouldDisable := svc.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrCalls)
require.Equal(t, 0, repo.tempCalls)
})
}
// ---------------------------------------------------------------------------
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
// ---------------------------------------------------------------------------

View File

@@ -4319,7 +4319,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
return s.handleRetryExhaustedError(ctx, resp, c, account)
}
@@ -4349,7 +4353,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
@@ -4584,7 +4592,11 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
return s.handleRetryExhaustedError(ctx, resp, c, account)
}
@@ -4614,7 +4626,11 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
if resp.StatusCode >= 400 {

View File

@@ -2040,7 +2040,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
})
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
return s.handleErrorResponse(ctx, resp, c, account, body)
}
@@ -2853,7 +2857,11 @@ func (s *OpenAIGatewayService) handleErrorResponse(
Detail: upstreamDetail,
})
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: body,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
// Return appropriate error response

View File

@@ -98,6 +98,9 @@ func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Accoun
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
return ErrorPolicySkipped
}
if account.IsPoolMode() {
return ErrorPolicySkipped
}
if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) {
return ErrorPolicyTempUnscheduled
}
@@ -107,9 +110,16 @@ func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Accoun
// HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
// 池模式默认不标记本地账号状态;仅当用户显式配置自定义错误码时按本地策略处理。
if account.IsPoolMode() && !customErrorCodesEnabled {
slog.Info("pool_mode_error_skipped", "account_id", account.ID, "status_code", statusCode)
return false
}
// apikey 类型账号:检查自定义错误码配置
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
if !account.ShouldHandleErrorCode(statusCode) {
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
return false