feat: 支持 API Key 上游池模式同账号重试次数配置与自定义错误策略

This commit is contained in:
kyx236
2026-03-08 13:57:23 +08:00
parent 03bf348530
commit e643fc382c
13 changed files with 558 additions and 40 deletions

View File

@@ -30,7 +30,7 @@ const (
const (
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
maxSameAccountRetries = 2
maxSameAccountRetries = 3
// sameAccountRetryDelay 同账号重试间隔
sameAccountRetryDelay = 500 * time.Millisecond
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。

View File

@@ -291,35 +291,31 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
require.Less(t, elapsed, 2*time.Second)
})
t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) {
t.Run("达到最大重试次数前均返回FailoverContinue", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
// 第一次
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[100])
for i := 1; i <= maxSameAccountRetries; i++ {
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, i, fs.SameAccountRetryCount[100])
}
// 第二次
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 2, fs.SameAccountRetryCount[100])
require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule")
require.Empty(t, mock.calls, "达到最大重试次数前均不应调用 TempUnschedule")
})
t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) {
t.Run("超过最大重试次数后触发TempUnschedule并切换", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
// 第一次、第二次重试
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, 2, fs.SameAccountRetryCount[100])
for i := 0; i < maxSameAccountRetries; i++ {
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
}
require.Equal(t, maxSameAccountRetries, fs.SameAccountRetryCount[100])
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
// 第 maxSameAccountRetries+1 次:重试耗尽,应切换账号
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
@@ -354,13 +350,14 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
err := newTestFailoverErr(400, true, false)
// 耗尽账号 100 的重试
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
// 第三次: 重试耗尽 → 切换
for i := 0; i < maxSameAccountRetries; i++ {
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
}
// 第 maxSameAccountRetries+1 次: 重试耗尽 → 切换
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
// 再次遇到账号 100计数仍为 2,条件不满足 → 直接切换
// 再次遇到账号 100计数仍为 maxSameAccountRetries,条件不满足 → 直接切换
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule")
@@ -386,9 +383,10 @@ func TestHandleFailoverError_TempUnschedule(t *testing.T) {
fs := NewFailoverState(3, false)
err := newTestFailoverErr(502, true, false)
// 耗尽重试
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
for i := 0; i < maxSameAccountRetries; i++ {
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
}
// 再次触发时才会执行 TempUnschedule + 切换
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
require.Len(t, mock.calls, 1)
@@ -521,17 +519,16 @@ func TestHandleFailoverError_IntegrationScenario(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, true) // hasBoundSession=true
// 1. 账号 100 遇到可重试错误,同账号重试 2
// 1. 账号 100 遇到可重试错误,同账号重试 maxSameAccountRetries
retryErr := newTestFailoverErr(400, true, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
require.Equal(t, FailoverContinue, action)
for i := 0; i < maxSameAccountRetries; i++ {
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
require.Equal(t, FailoverContinue, action)
}
require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling")
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
require.Equal(t, FailoverContinue, action)
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
// 2. 账号 100 超过重试上限 → TempUnschedule + 切换
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
require.Len(t, mock.calls, 1)

View File

@@ -20,6 +20,7 @@ import (
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
@@ -212,6 +213,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int)
var lastFailoverErr *service.UpstreamFailoverError
for {
@@ -259,6 +261,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
zap.Float64("load_skew", scheduleDecision.LoadSkew),
)
account := selection.Account
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
setOpsSelectedAccount(c, account.ID, account.Platform)
@@ -288,6 +291,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// 池模式:同账号重试
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
}
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
@@ -541,6 +563,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int)
var lastFailoverErr *service.UpstreamFailoverError
for {
@@ -602,6 +625,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
return
}
account := selection.Account
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
reqLog.Debug("openai_messages.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
_ = scheduleDecision
setOpsSelectedAccount(c, account.ID, account.Platform)
@@ -641,6 +665,25 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// 池模式:同账号重试
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
}
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
@@ -1456,6 +1499,14 @@ func setOpenAIClientTransportWS(c *gin.Context) {
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
}
func ensureOpenAIPoolModeSessionHash(sessionHash string, account *service.Account) string {
if sessionHash != "" || account == nil || !account.IsPoolMode() {
return sessionHash
}
// 为当前请求生成一次性粘性会话键,确保同账号重试不会重新负载均衡到其他账号。
return "openai-pool-retry-" + uuid.NewString()
}
func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
gid := int64(0)
if groupID != nil {

View File

@@ -647,6 +647,75 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
return false
}
// IsPoolMode 检查 API Key 账号是否启用池模式。
// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。
func (a *Account) IsPoolMode() bool {
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["pool_mode"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
const (
defaultPoolModeRetryCount = 3
maxPoolModeRetryCount = 10
)
// GetPoolModeRetryCount 返回池模式同账号重试次数。
// 未配置或配置非法时回退为默认值 3小于 0 按 0 处理;过大则截断到 10。
func (a *Account) GetPoolModeRetryCount() int {
if a == nil || !a.IsPoolMode() || a.Credentials == nil {
return defaultPoolModeRetryCount
}
raw, ok := a.Credentials["pool_mode_retry_count"]
if !ok || raw == nil {
return defaultPoolModeRetryCount
}
count := parsePoolModeRetryCount(raw)
if count < 0 {
return 0
}
if count > maxPoolModeRetryCount {
return maxPoolModeRetryCount
}
return count
}
func parsePoolModeRetryCount(value any) int {
switch v := value.(type) {
case int:
return v
case int64:
return int(v)
case float64:
return int(v)
case json.Number:
if i, err := v.Int64(); err == nil {
return int(i)
}
case string:
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
return i
}
}
return defaultPoolModeRetryCount
}
// isPoolModeRetryableStatus 池模式下应触发同账号重试的状态码
func isPoolModeRetryableStatus(statusCode int) bool {
switch statusCode {
case 401, 403, 429:
return true
default:
return false
}
}
func (a *Account) GetCustomErrorCodes() []int {
if a.Credentials == nil {
return nil

View File

@@ -0,0 +1,117 @@
//go:build unit
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestGetPoolModeRetryCount(t *testing.T) {
tests := []struct {
name string
account *Account
expected int
}{
{
name: "default_when_not_pool_mode",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{},
},
expected: defaultPoolModeRetryCount,
},
{
name: "default_when_missing_retry_count",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
},
},
expected: defaultPoolModeRetryCount,
},
{
name: "supports_float64_from_json_credentials",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": float64(5),
},
},
expected: 5,
},
{
name: "supports_json_number",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": json.Number("4"),
},
},
expected: 4,
},
{
name: "supports_string_value",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": "2",
},
},
expected: 2,
},
{
name: "negative_value_is_clamped_to_zero",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": -1,
},
},
expected: 0,
},
{
name: "oversized_value_is_clamped_to_max",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": 99,
},
},
expected: maxPoolModeRetryCount,
},
{
name: "invalid_value_falls_back_to_default",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": "oops",
},
},
expected: defaultPoolModeRetryCount,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.expected, tt.account.GetPoolModeRetryCount())
})
}
}

View File

@@ -177,6 +177,36 @@ func TestCheckErrorPolicy(t *testing.T) {
body: []byte(`overloaded`),
expected: ErrorPolicyMatched, // custom codes take precedence
},
{
name: "pool_mode_custom_error_codes_hit_returns_matched",
account: &Account{
ID: 7,
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(401), float64(403)},
},
},
statusCode: 401,
body: []byte(`unauthorized`),
expected: ErrorPolicyMatched,
},
{
name: "pool_mode_without_custom_error_codes_returns_skipped",
account: &Account{
ID: 8,
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
},
},
statusCode: 401,
body: []byte(`unauthorized`),
expected: ErrorPolicySkipped,
},
}
for _, tt := range tests {
@@ -190,6 +220,48 @@ func TestCheckErrorPolicy(t *testing.T) {
}
}
func TestHandleUpstreamError_PoolModeCustomErrorCodesOverride(t *testing.T) {
t.Run("pool_mode_without_custom_error_codes_still_skips", func(t *testing.T) {
repo := &errorPolicyRepoStub{}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 30,
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
},
}
shouldDisable := svc.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.False(t, shouldDisable)
require.Equal(t, 0, repo.setErrCalls)
require.Equal(t, 0, repo.tempCalls)
})
t.Run("pool_mode_with_custom_error_codes_uses_local_error_policy", func(t *testing.T) {
repo := &errorPolicyRepoStub{}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 31,
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(401)},
},
}
shouldDisable := svc.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrCalls)
require.Equal(t, 0, repo.tempCalls)
})
}
// ---------------------------------------------------------------------------
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
// ---------------------------------------------------------------------------

View File

@@ -4319,7 +4319,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
return s.handleRetryExhaustedError(ctx, resp, c, account)
}
@@ -4349,7 +4353,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
@@ -4584,7 +4592,11 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
return s.handleRetryExhaustedError(ctx, resp, c, account)
}
@@ -4614,7 +4626,11 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
if resp.StatusCode >= 400 {

View File

@@ -2002,7 +2002,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
})
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
return s.handleErrorResponse(ctx, resp, c, account, body)
}
@@ -2815,7 +2819,11 @@ func (s *OpenAIGatewayService) handleErrorResponse(
Detail: upstreamDetail,
})
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: body,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
// Return appropriate error response

View File

@@ -87,6 +87,9 @@ func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Accoun
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
return ErrorPolicySkipped
}
if account.IsPoolMode() {
return ErrorPolicySkipped
}
if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) {
return ErrorPolicyTempUnscheduled
}
@@ -96,9 +99,16 @@ func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Accoun
// HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
// 池模式默认不标记本地账号状态;仅当用户显式配置自定义错误码时按本地策略处理。
if account.IsPoolMode() && !customErrorCodesEnabled {
slog.Info("pool_mode_error_skipped", "account_id", account.ID, "status_code", statusCode)
return false
}
// apikey 类型账号:检查自定义错误码配置
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
if !account.ShouldHandleErrorCode(statusCode) {
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
return false