fix: 修复 OpenAI WS 限流状态与调度同步

This commit is contained in:
神乐
2026-03-07 23:59:39 +08:00
parent 0c1dcad429
commit 45d57018eb
3 changed files with 471 additions and 7 deletions

View File

@@ -3899,6 +3899,30 @@ func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow
return updates
}
func codexUsagePercentExhausted(value *float64) bool {
return value != nil && *value >= 100-1e-9
}
func codexRateLimitResetAtFromSnapshot(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) *time.Time {
if snapshot == nil {
return nil
}
normalized := snapshot.Normalize()
if normalized == nil {
return nil
}
baseTime := codexSnapshotBaseTime(snapshot, fallbackNow)
if codexUsagePercentExhausted(normalized.Used7dPercent) && normalized.Reset7dSeconds != nil {
resetAt := baseTime.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second)
return &resetAt
}
if codexUsagePercentExhausted(normalized.Used5hPercent) && normalized.Reset5hSeconds != nil {
resetAt := baseTime.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second)
return &resetAt
}
return nil
}
// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field
func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) {
if snapshot == nil {
@@ -3908,16 +3932,22 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
return
}
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
if len(updates) == 0 {
now := time.Now()
updates := buildCodexUsageExtraUpdates(snapshot, now)
resetAt := codexRateLimitResetAtFromSnapshot(snapshot, now)
if len(updates) == 0 && resetAt == nil {
return
}
// Update account's Extra field asynchronously
go func() {
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
if len(updates) > 0 {
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}
if resetAt != nil {
_ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt)
}
}()
}

View File

@@ -1853,6 +1853,10 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
wsPath,
account.ProxyID != nil && account.Proxy != nil,
)
var dialErr *openAIWSDialError
if errors.As(err, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests {
s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(err.Error()))
}
return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err)
}
defer lease.Release()
@@ -2136,6 +2140,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
if eventType == "error" {
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw)
errMsg := strings.TrimSpace(errMsgRaw)
if errMsg == "" {
errMsg = "Upstream websocket error"
@@ -2639,6 +2644,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
wsPath,
account.ProxyID != nil && account.Proxy != nil,
)
var dialErr *openAIWSDialError
if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests {
s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error()))
}
if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) {
return nil, NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
@@ -2777,6 +2786,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
if eventType == "error" {
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage)
s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), upstreamMessage, errCodeRaw, errTypeRaw, errMsgRaw)
fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound &&
@@ -3604,6 +3614,7 @@ func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm(
if eventType == "error" {
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw)
errMsg := strings.TrimSpace(errMsgRaw)
if errMsg == "" {
errMsg = "OpenAI websocket prewarm error"
@@ -3867,6 +3878,36 @@ func classifyOpenAIWSAcquireError(err error) string {
return "acquire_conn"
}
func isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw string) bool {
code := strings.ToLower(strings.TrimSpace(codeRaw))
errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
msg := strings.ToLower(strings.TrimSpace(msgRaw))
if strings.Contains(errType, "rate_limit") || strings.Contains(errType, "usage_limit") {
return true
}
if strings.Contains(code, "rate_limit") || strings.Contains(code, "usage_limit") || strings.Contains(code, "insufficient_quota") {
return true
}
if strings.Contains(msg, "usage limit") && strings.Contains(msg, "reached") {
return true
}
if strings.Contains(msg, "rate limit") && (strings.Contains(msg, "reached") || strings.Contains(msg, "exceeded")) {
return true
}
return false
}
func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Context, account *Account, headers http.Header, responseBody []byte, codeRaw, errTypeRaw, msgRaw string) {
if s == nil || s.rateLimitService == nil || account == nil || account.Platform != PlatformOpenAI {
return
}
if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) {
return
}
s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody)
}
func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {
code := strings.ToLower(strings.TrimSpace(codeRaw))
errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
@@ -3882,6 +3923,9 @@ func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (stri
case "previous_response_not_found":
return "previous_response_not_found", true
}
if isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) {
return "upstream_rate_limited", false
}
if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") {
return "upgrade_required", true
}
@@ -3927,9 +3971,7 @@ func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int {
case strings.Contains(errType, "permission"),
strings.Contains(code, "forbidden"):
return http.StatusForbidden
case strings.Contains(errType, "rate_limit"),
strings.Contains(code, "rate_limit"),
strings.Contains(code, "insufficient_quota"):
case isOpenAIWSRateLimitError(codeRaw, errTypeRaw, ""):
return http.StatusTooManyRequests
default:
return http.StatusBadGateway

View File

@@ -0,0 +1,392 @@
package service
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"time"
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
)
type openAIWSRateLimitSignalRepo struct {
stubOpenAIAccountRepo
rateLimitCalls []time.Time
updateExtra []map[string]any
}
type openAICodexSnapshotAsyncRepo struct {
stubOpenAIAccountRepo
updateExtraCh chan map[string]any
rateLimitCh chan time.Time
}
func (r *openAIWSRateLimitSignalRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
r.rateLimitCalls = append(r.rateLimitCalls, resetAt)
return nil
}
func (r *openAIWSRateLimitSignalRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
copied := make(map[string]any, len(updates))
for k, v := range updates {
copied[k] = v
}
r.updateExtra = append(r.updateExtra, copied)
return nil
}
func (r *openAICodexSnapshotAsyncRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
if r.rateLimitCh != nil {
r.rateLimitCh <- resetAt
}
return nil
}
func (r *openAICodexSnapshotAsyncRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
if r.updateExtraCh != nil {
copied := make(map[string]any, len(updates))
for k, v := range updates {
copied[k] = v
}
r.updateExtraCh <- copied
}
return nil
}
func TestOpenAIGatewayService_Forward_WSv2ErrorEventUsageLimitPersistsRateLimit(t *testing.T) {
gin.SetMode(gin.TestMode)
resetAt := time.Now().Add(2 * time.Hour).Unix()
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket failed: %v", err)
return
}
defer func() { _ = conn.Close() }()
var req map[string]any
if err := conn.ReadJSON(&req); err != nil {
t.Errorf("read ws request failed: %v", err)
return
}
_ = conn.WriteJSON(map[string]any{
"type": "error",
"error": map[string]any{
"code": "rate_limit_exceeded",
"type": "usage_limit_reached",
"message": "The usage limit has been reached",
"resets_at": resetAt,
},
})
}))
defer wsServer.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "unit-test-agent/1.0")
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_should_not_run"}`)),
},
}
cfg := newOpenAIWSV2TestConfig()
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
account := Account{
ID: 501,
Name: "openai-ws-rate-limit-event",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}}
rateSvc := &RateLimitService{accountRepo: repo}
svc := &OpenAIGatewayService{
accountRepo: repo,
rateLimitService: rateSvc,
httpUpstream: upstream,
cache: &stubGatewayCache{},
cfg: cfg,
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, &account, body)
require.Error(t, err)
require.Nil(t, result)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
require.Nil(t, upstream.lastReq, "WS 限流 error event 不应回退到同账号 HTTP")
require.Len(t, repo.rateLimitCalls, 1)
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
}
func TestOpenAIGatewayService_Forward_WSv2Handshake429PersistsRateLimit(t *testing.T) {
gin.SetMode(gin.TestMode)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("x-codex-primary-used-percent", "100")
w.Header().Set("x-codex-primary-reset-after-seconds", "7200")
w.Header().Set("x-codex-primary-window-minutes", "10080")
w.Header().Set("x-codex-secondary-used-percent", "3")
w.Header().Set("x-codex-secondary-reset-after-seconds", "1800")
w.Header().Set("x-codex-secondary-window-minutes", "300")
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"error":{"type":"rate_limit_exceeded","message":"rate limited"}}`))
}))
defer server.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "unit-test-agent/1.0")
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_should_not_run"}`)),
},
}
cfg := newOpenAIWSV2TestConfig()
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
account := Account{
ID: 502,
Name: "openai-ws-rate-limit-handshake",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": server.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}}
rateSvc := &RateLimitService{accountRepo: repo}
svc := &OpenAIGatewayService{
accountRepo: repo,
rateLimitService: rateSvc,
httpUpstream: upstream,
cache: &stubGatewayCache{},
cfg: cfg,
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, &account, body)
require.Error(t, err)
require.Nil(t, result)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
require.Nil(t, upstream.lastReq, "WS 握手 429 不应回退到同账号 HTTP")
require.Len(t, repo.rateLimitCalls, 1)
require.NotEmpty(t, repo.updateExtra, "握手 429 的 x-codex 头应立即落库")
require.Contains(t, repo.updateExtra[0], "codex_usage_updated_at")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageLimitPersistsRateLimit(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := newOpenAIWSV2TestConfig()
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
resetAt := time.Now().Add(90 * time.Minute).Unix()
captureConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"error","error":{"code":"rate_limit_exceeded","type":"usage_limit_reached","message":"The usage limit has been reached","resets_at":PLACEHOLDER}}`),
},
}
captureConn.events[0] = []byte(strings.ReplaceAll(string(captureConn.events[0]), "PLACEHOLDER", strconv.FormatInt(resetAt, 10)))
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)
account := Account{
ID: 503,
Name: "openai-ingress-rate-limit",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}}
rateSvc := &RateLimitService{accountRepo: repo}
svc := &OpenAIGatewayService{
accountRepo: repo,
rateLimitService: rateSvc,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
cfg: cfg,
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}
serverErrCh := make(chan error, 1)
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover})
if err != nil {
serverErrCh <- err
return
}
defer func() { _ = conn.CloseNow() }()
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
msgType, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
serverErrCh <- io.ErrUnexpectedEOF
return
}
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, &account, "sk-test", firstMessage, nil)
}))
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() { _ = clientConn.CloseNow() }()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
cancelWrite()
require.NoError(t, err)
select {
case serverErr := <-serverErrCh:
require.Error(t, serverErr)
require.Len(t, repo.rateLimitCalls, 1)
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
case <-time.After(5 * time.Second):
t.Fatal("等待 ingress websocket 结束超时")
}
}
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotSetsRateLimit(t *testing.T) {
repo := &openAICodexSnapshotAsyncRepo{
updateExtraCh: make(chan map[string]any, 1),
rateLimitCh: make(chan time.Time, 1),
}
svc := &OpenAIGatewayService{accountRepo: repo}
snapshot := &OpenAICodexUsageSnapshot{
PrimaryUsedPercent: ptrFloat64WS(100),
PrimaryResetAfterSeconds: ptrIntWS(3600),
PrimaryWindowMinutes: ptrIntWS(10080),
SecondaryUsedPercent: ptrFloat64WS(12),
SecondaryResetAfterSeconds: ptrIntWS(1200),
SecondaryWindowMinutes: ptrIntWS(300),
}
before := time.Now()
svc.updateCodexUsageSnapshot(context.Background(), 601, snapshot)
select {
case updates := <-repo.updateExtraCh:
require.Equal(t, 100.0, updates["codex_7d_used_percent"])
case <-time.After(2 * time.Second):
t.Fatal("等待 codex 快照落库超时")
}
select {
case resetAt := <-repo.rateLimitCh:
require.WithinDuration(t, before.Add(time.Hour), resetAt, 2*time.Second)
case <-time.After(2 * time.Second):
t.Fatal("等待 codex 100% 自动切换限流超时")
}
}
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesNotSetRateLimit(t *testing.T) {
repo := &openAICodexSnapshotAsyncRepo{
updateExtraCh: make(chan map[string]any, 1),
rateLimitCh: make(chan time.Time, 1),
}
svc := &OpenAIGatewayService{accountRepo: repo}
snapshot := &OpenAICodexUsageSnapshot{
PrimaryUsedPercent: ptrFloat64WS(94),
PrimaryResetAfterSeconds: ptrIntWS(3600),
PrimaryWindowMinutes: ptrIntWS(10080),
SecondaryUsedPercent: ptrFloat64WS(22),
SecondaryResetAfterSeconds: ptrIntWS(1200),
SecondaryWindowMinutes: ptrIntWS(300),
}
svc.updateCodexUsageSnapshot(context.Background(), 602, snapshot)
select {
case <-repo.updateExtraCh:
case <-time.After(2 * time.Second):
t.Fatal("等待 codex 快照落库超时")
}
select {
case resetAt := <-repo.rateLimitCh:
t.Fatalf("unexpected rate limit reset at: %v", resetAt)
case <-time.After(200 * time.Millisecond):
}
}
func ptrFloat64WS(v float64) *float64 { return &v }
func ptrIntWS(v int) *int { return &v }
func TestOpenAIWSErrorHTTPStatusFromRaw_UsageLimitReachedIs429(t *testing.T) {
require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("", "usage_limit_reached"))
require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("rate_limit_exceeded", ""))
}