diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 2e4c7ec9..8d9a270e 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -925,6 +925,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err) } + r.syncSchedulerAccountSnapshot(ctx, id) return nil } @@ -1040,6 +1041,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err) } + r.syncSchedulerAccountSnapshot(ctx, id) return nil } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 9557e175..b44f29fd 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -406,8 +406,27 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account } defer func() { _ = resp.Body.Close() }() + if isOAuth && s.accountRepo != nil { + if updates, err := extractOpenAICodexProbeUpdates(resp); err == nil && len(updates) > 0 { + _ = s.accountRepo.UpdateExtra(ctx, account.ID, updates) + mergeAccountExtra(account, updates) + } + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + if resetAt := codexRateLimitResetAtFromSnapshot(snapshot, time.Now()); resetAt != nil { + _ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt) + account.RateLimitResetAt = resetAt + } + } + } + if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) + if isOAuth && s.accountRepo != nil { + if resetAt := (&RateLimitService{}).calculateOpenAI429ResetTime(resp.Header); resetAt != nil { + _ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt) + account.RateLimitResetAt = resetAt + } + } return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) } diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go new file mode 100644 index 00000000..efa6f7da --- /dev/null +++ b/backend/internal/service/account_test_service_openai_test.go @@ -0,0 +1,102 @@ +//go:build unit + +package service + +import ( + "context" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type openAIAccountTestRepo struct { + mockAccountRepoForGemini + updatedExtra map[string]any + rateLimitedID int64 + rateLimitedAt *time.Time +} + +func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + r.updatedExtra = updates + return nil +} + +func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error { + r.rateLimitedID = id + r.rateLimitedAt = &resetAt + return nil +} + +func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, recorder := newSoraTestContext() + + resp := newJSONResponse(http.StatusOK, "") + resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"} + +`)) + resp.Header.Set("x-codex-primary-used-percent", "88") + resp.Header.Set("x-codex-primary-reset-after-seconds", "604800") + resp.Header.Set("x-codex-primary-window-minutes", "10080") + resp.Header.Set("x-codex-secondary-used-percent", "42") + resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000") + resp.Header.Set("x-codex-secondary-window-minutes", "300") + + repo := &openAIAccountTestRepo{} + upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}} + svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream} + account := &Account{ + ID: 89, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "test-token"}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + require.NoError(t, err) + require.NotEmpty(t, repo.updatedExtra) + require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"]) + require.Equal(t, 88.0, repo.updatedExtra["codex_7d_used_percent"]) + require.Contains(t, recorder.Body.String(), "test_complete") +} + +func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := newSoraTestContext() + + resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`) + resp.Header.Set("x-codex-primary-used-percent", "100") + resp.Header.Set("x-codex-primary-reset-after-seconds", "604800") + resp.Header.Set("x-codex-primary-window-minutes", "10080") + resp.Header.Set("x-codex-secondary-used-percent", "100") + resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000") + resp.Header.Set("x-codex-secondary-window-minutes", "300") + + repo := &openAIAccountTestRepo{} + upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}} + svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream} + account := &Account{ + ID: 88, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "test-token"}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + require.Error(t, err) + require.NotEmpty(t, repo.updatedExtra) + require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"]) + require.Equal(t, int64(88), repo.rateLimitedID) + require.NotNil(t, repo.rateLimitedAt) + require.NotNil(t, account.RateLimitResetAt) + if account.RateLimitResetAt != nil && repo.rateLimitedAt != nil { + require.WithinDuration(t, *repo.rateLimitedAt, *account.RateLimitResetAt, time.Second) + } +} diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index b0a4900d..7c001118 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -359,6 +359,7 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou if account == nil { return usage, nil } + syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, now) if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil { usage.FiveHour = progress @@ -367,7 +368,7 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou usage.SevenDay = progress } - if (usage.FiveHour == nil || usage.SevenDay == nil) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) { + if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) { if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 { mergeAccountExtra(account, updates) if usage.UpdatedAt == nil { @@ -409,6 +410,40 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou return usage, nil } +func shouldRefreshOpenAICodexSnapshot(account *Account, usage *UsageInfo, now time.Time) bool { + if account == nil { + return false + } + if usage == nil { + return true + } + if usage.FiveHour == nil || usage.SevenDay == nil { + return true + } + if account.IsRateLimited() { + return true + } + return isOpenAICodexSnapshotStale(account, now) +} + +func isOpenAICodexSnapshotStale(account *Account, now time.Time) bool { + if account == nil || !account.IsOpenAIOAuth() || !account.IsOpenAIResponsesWebSocketV2Enabled() { + return false + } + if account.Extra == nil { + return true + } + raw, ok := account.Extra["codex_usage_updated_at"] + if !ok { + return true + } + ts, err := parseTime(fmt.Sprint(raw)) + if err != nil { + return true + } + return now.Sub(ts) >= openAIProbeCacheTTL +} + func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, now time.Time) bool { if s == nil || s.cache == nil || accountID <= 0 { return true @@ -478,20 +513,34 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode) + updates, err := extractOpenAICodexProbeUpdates(resp) + if err != nil { + return nil, err + } + if len(updates) > 0 { + go func(accountID int64, updates map[string]any) { + updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer updateCancel() + _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) + }(account.ID, updates) + return updates, nil + } + return nil, nil +} + +func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) { + if resp == nil { + return nil, nil } if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { updates := buildCodexUsageExtraUpdates(snapshot, time.Now()) if len(updates) > 0 { - go func(accountID int64, updates map[string]any) { - updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer updateCancel() - _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) - }(account.ID, updates) return updates, nil } } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode) + } return nil, nil } diff --git a/backend/internal/service/account_usage_service_test.go b/backend/internal/service/account_usage_service_test.go new file mode 100644 index 00000000..974d9029 --- /dev/null +++ b/backend/internal/service/account_usage_service_test.go @@ -0,0 +1,68 @@ +package service + +import ( + "net/http" + "testing" + "time" +) + +func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) { + t.Parallel() + + rateLimitedUntil := time.Now().Add(5 * time.Minute) + now := time.Now() + usage := &UsageInfo{ + FiveHour: &UsageProgress{Utilization: 0}, + SevenDay: &UsageProgress{Utilization: 0}, + } + + if !shouldRefreshOpenAICodexSnapshot(&Account{RateLimitResetAt: &rateLimitedUntil}, usage, now) { + t.Fatal("expected rate-limited account to force codex snapshot refresh") + } + + if shouldRefreshOpenAICodexSnapshot(&Account{}, usage, now) { + t.Fatal("expected complete non-rate-limited usage to skip codex snapshot refresh") + } + + if !shouldRefreshOpenAICodexSnapshot(&Account{}, &UsageInfo{FiveHour: nil, SevenDay: &UsageProgress{}}, now) { + t.Fatal("expected missing 5h snapshot to require refresh") + } + + staleAt := now.Add(-(openAIProbeCacheTTL + time.Minute)).Format(time.RFC3339) + if !shouldRefreshOpenAICodexSnapshot(&Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + "codex_usage_updated_at": staleAt, + }, + }, usage, now) { + t.Fatal("expected stale ws snapshot to trigger refresh") + } +} + +func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T) { + t.Parallel() + + headers := make(http.Header) + headers.Set("x-codex-primary-used-percent", "100") + headers.Set("x-codex-primary-reset-after-seconds", "604800") + headers.Set("x-codex-primary-window-minutes", "10080") + headers.Set("x-codex-secondary-used-percent", "100") + headers.Set("x-codex-secondary-reset-after-seconds", "18000") + headers.Set("x-codex-secondary-window-minutes", "300") + + updates, err := extractOpenAICodexProbeUpdates(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers}) + if err != nil { + t.Fatalf("extractOpenAICodexProbeUpdates() error = %v", err) + } + if len(updates) == 0 { + t.Fatal("expected codex probe updates from 429 headers") + } + if got := updates["codex_5h_used_percent"]; got != 100.0 { + t.Fatalf("codex_5h_used_percent = %v, want 100", got) + } + if got := updates["codex_7d_used_percent"]; got != 100.0 { + t.Fatalf("codex_7d_used_percent = %v, want 100", got) + } +} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index a3ed4233..fc658316 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -1349,6 +1349,10 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, if err != nil { return nil, 0, err } + now := time.Now() + for i := range accounts { + syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, &accounts[i], now) + } return accounts, result.Total, nil } diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index cf4bc26e..0fcf450b 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -319,7 +319,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil } - if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() { + if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() || !account.IsSchedulable() { _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil } @@ -687,16 +687,20 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( for i := 0; i < len(selectionOrder); i++ { candidate := selectionOrder[i] - result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency) + fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + continue + } + result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if acquireErr != nil { return nil, len(candidates), topK, loadSkew, acquireErr } if result != nil && result.Acquired { if req.SessionHash != "" { - _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID) + _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID) } return &AccountSelectionResult{ - Account: candidate.account, + Account: fresh, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, len(candidates), topK, loadSkew, nil @@ -705,16 +709,23 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( cfg := s.service.schedulingConfig() // WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。 - candidate := selectionOrder[0] - return &AccountSelectionResult{ - Account: candidate.account, - WaitPlan: &AccountWaitPlan{ - AccountID: candidate.account.ID, - MaxConcurrency: candidate.account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, len(candidates), topK, loadSkew, nil + for _, candidate := range selectionOrder { + fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + continue + } + return &AccountSelectionResult{ + Account: fresh, + WaitPlan: &AccountWaitPlan{ + AccountID: fresh.ID, + MaxConcurrency: fresh.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, len(candidates), topK, loadSkew, nil + } + + return nil, len(candidates), topK, loadSkew, errors.New("no available accounts") } func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool { diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index 7f6f1b66..977c4ee8 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -12,6 +12,78 @@ import ( "github.com/stretchr/testify/require" ) +type openAISnapshotCacheStub struct { + SchedulerCache + snapshotAccounts []*Account + accountsByID map[int64]*Account +} + +func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) { + if len(s.snapshotAccounts) == 0 { + return nil, false, nil + } + out := make([]*Account, 0, len(s.snapshotAccounts)) + for _, account := range s.snapshotAccounts { + if account == nil { + continue + } + cloned := *account + out = append(out, &cloned) + } + return out, true, nil +} + +func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int64) (*Account, error) { + if s.accountsByID == nil { + return nil, nil + } + account := s.accountsByID[accountID] + if account == nil { + return nil, nil + } + cloned := *account + return &cloned, nil +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) { + ctx := context.Background() + groupID := int64(10101) + rateLimitedUntil := time.Now().Add(30 * time.Minute) + staleSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0} + staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} + freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}} + snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}} + snapshotService := &SchedulerSnapshotService{cache: snapshotCache} + svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})} + + selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(31002), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) +} + +func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRateLimitedSnapshotCandidate(t *testing.T) { + ctx := context.Background() + groupID := int64(10102) + rateLimitedUntil := time.Now().Add(30 * time.Minute) + stalePrimary := &Account{ID: 32001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0} + staleSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + freshPrimary := &Account{ID: 32001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} + freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}} + snapshotService := &SchedulerSnapshotService{cache: snapshotCache} + svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService} + + account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, int64(32002), account.ID) +} + func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) { ctx := context.Background() groupID := int64(9) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 5c8c2710..709ee808 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1026,7 +1026,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C // 3. 按优先级 + LRU 选择最佳账号 // Select by priority + LRU - selected := s.selectBestAccount(accounts, requestedModel, excludedIDs) + selected := s.selectBestAccount(ctx, accounts, requestedModel, excludedIDs) if selected == nil { if requestedModel != "" { @@ -1099,7 +1099,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // // selectBestAccount selects the best account from candidates (priority + LRU). // Returns nil if no available account. -func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { +func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { var selected *Account for i := range accounts { @@ -1111,27 +1111,20 @@ func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedMo continue } - // 调度器快照可能暂时过时,这里重新检查可调度性和平台 - // Scheduler snapshots can be temporarily stale; re-check schedulability and platform - if !acc.IsSchedulable() || !acc.IsOpenAI() { - continue - } - - // 检查模型支持 - // Check model support - if requestedModel != "" && !acc.IsModelSupported(requestedModel) { + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel) + if fresh == nil { continue } // 选择优先级最高且最久未使用的账号 // Select highest priority and least recently used if selected == nil { - selected = acc + selected = fresh continue } - if s.isBetterAccount(acc, selected) { - selected = acc + if s.isBetterAccount(fresh, selected) { + selected = fresh } } @@ -1309,13 +1302,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ordered := append([]*Account(nil), candidates...) sortAccountsByPriorityAndLastUsed(ordered, false) for _, acc := range ordered { - result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel) + if fresh == nil { + continue + } + result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, acc.ID, openaiStickySessionTTL) + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ - Account: acc, + Account: fresh, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, nil @@ -1359,13 +1356,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex shuffleWithinSortGroups(available) for _, item := range available { - result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel) + if fresh == nil { + continue + } + result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, item.account.ID, openaiStickySessionTTL) + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ - Account: item.account, + Account: fresh, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, nil @@ -1377,11 +1378,15 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex // ============ Layer 3: Fallback wait ============ sortAccountsByPriorityAndLastUsed(candidates, false) for _, acc := range candidates { + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel) + if fresh == nil { + continue + } return &AccountSelectionResult{ - Account: acc, + Account: fresh, WaitPlan: &AccountWaitPlan{ - AccountID: acc.ID, - MaxConcurrency: acc.Concurrency, + AccountID: fresh.ID, + MaxConcurrency: fresh.Concurrency, Timeout: cfg.FallbackWaitTimeout, MaxWaiting: cfg.FallbackMaxWaiting, }, @@ -1418,11 +1423,44 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) } -func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { - if s.schedulerSnapshot != nil { - return s.schedulerSnapshot.GetAccount(ctx, accountID) +func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string) *Account { + if account == nil { + return nil } - return s.accountRepo.GetByID(ctx, accountID) + + fresh := account + if s.schedulerSnapshot != nil { + current, err := s.getSchedulableAccount(ctx, account.ID) + if err != nil || current == nil { + return nil + } + fresh = current + } + + if !fresh.IsSchedulable() || !fresh.IsOpenAI() { + return nil + } + if requestedModel != "" && !fresh.IsModelSupported(requestedModel) { + return nil + } + return fresh +} + +func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { + var ( + account *Account + err error + ) + if s.schedulerSnapshot != nil { + account, err = s.schedulerSnapshot.GetAccount(ctx, accountID) + } else { + account, err = s.accountRepo.GetByID(ctx, accountID) + } + if err != nil || account == nil { + return account, err + } + syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, time.Now()) + return account, nil } func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig { @@ -3871,6 +3909,69 @@ func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow return updates } +func codexUsagePercentExhausted(value *float64) bool { + return value != nil && *value >= 100-1e-9 +} + +func codexRateLimitResetAtFromSnapshot(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) *time.Time { + if snapshot == nil { + return nil + } + normalized := snapshot.Normalize() + if normalized == nil { + return nil + } + baseTime := codexSnapshotBaseTime(snapshot, fallbackNow) + if codexUsagePercentExhausted(normalized.Used7dPercent) && normalized.Reset7dSeconds != nil { + resetAt := baseTime.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second) + return &resetAt + } + if codexUsagePercentExhausted(normalized.Used5hPercent) && normalized.Reset5hSeconds != nil { + resetAt := baseTime.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second) + return &resetAt + } + return nil +} + +func codexRateLimitResetAtFromExtra(extra map[string]any, now time.Time) *time.Time { + if len(extra) == 0 { + return nil + } + if progress := buildCodexUsageProgressFromExtra(extra, "7d", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) { + resetAt := progress.ResetsAt.UTC() + return &resetAt + } + if progress := buildCodexUsageProgressFromExtra(extra, "5h", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) { + resetAt := progress.ResetsAt.UTC() + return &resetAt + } + return nil +} + +func applyOpenAICodexRateLimitFromExtra(account *Account, now time.Time) (*time.Time, bool) { + if account == nil || !account.IsOpenAI() { + return nil, false + } + resetAt := codexRateLimitResetAtFromExtra(account.Extra, now) + if resetAt == nil { + return nil, false + } + if account.RateLimitResetAt != nil && now.Before(*account.RateLimitResetAt) && !account.RateLimitResetAt.Before(*resetAt) { + return account.RateLimitResetAt, false + } + account.RateLimitResetAt = resetAt + return resetAt, true +} + +func syncOpenAICodexRateLimitFromExtra(ctx context.Context, repo AccountRepository, account *Account, now time.Time) *time.Time { + resetAt, changed := applyOpenAICodexRateLimitFromExtra(account, now) + if !changed || resetAt == nil || repo == nil || account == nil || account.ID <= 0 { + return resetAt + } + _ = repo.SetRateLimited(ctx, account.ID, *resetAt) + return resetAt +} + // updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) { if snapshot == nil { @@ -3880,16 +3981,22 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc return } - updates := buildCodexUsageExtraUpdates(snapshot, time.Now()) - if len(updates) == 0 { + now := time.Now() + updates := buildCodexUsageExtraUpdates(snapshot, now) + resetAt := codexRateLimitResetAtFromSnapshot(snapshot, now) + if len(updates) == 0 && resetAt == nil { return } - // Update account's Extra field asynchronously go func() { updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) + if len(updates) > 0 { + _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) + } + if resetAt != nil { + _ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt) + } }() } diff --git a/backend/internal/service/openai_ws_account_sticky_test.go b/backend/internal/service/openai_ws_account_sticky_test.go index 3fe08179..9a8803d3 100644 --- a/backend/internal/service/openai_ws_account_sticky_test.go +++ b/backend/internal/service/openai_ws_account_sticky_test.go @@ -48,6 +48,43 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T } } +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(t *testing.T) { + ctx := context.Background() + groupID := int64(23) + rateLimitedUntil := time.Now().Add(30 * time.Minute) + account := Account{ + ID: 12, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + RateLimitResetAt: &rateLimitedUntil, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_rl", account.ID, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_rl", "gpt-5.1", nil) + require.NoError(t, err) + require.Nil(t, selection, "限额中的账号不应继续命中 previous_response_id 粘连") + boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_rl") + require.NoError(t, getErr) + require.Zero(t, boundAccountID) +} + func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) { ctx := context.Background() groupID := int64(23) diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index f9e93f85..f2f8edd9 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -1853,6 +1853,10 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( wsPath, account.ProxyID != nil && account.Proxy != nil, ) + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests { + s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(err.Error())) + } return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err) } defer lease.Release() @@ -2136,6 +2140,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( if eventType == "error" { errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw) errMsg := strings.TrimSpace(errMsgRaw) if errMsg == "" { errMsg = "Upstream websocket error" @@ -2639,6 +2644,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( wsPath, account.ProxyID != nil && account.Proxy != nil, ) + var dialErr *openAIWSDialError + if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests { + s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error())) + } if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) { return nil, NewOpenAIWSClientCloseError( coderws.StatusPolicyViolation, @@ -2777,6 +2786,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } if eventType == "error" { errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage) + s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), upstreamMessage, errCodeRaw, errTypeRaw, errMsgRaw) fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound && @@ -3604,6 +3614,7 @@ func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm( if eventType == "error" { errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw) errMsg := strings.TrimSpace(errMsgRaw) if errMsg == "" { errMsg = "OpenAI websocket prewarm error" @@ -3798,7 +3809,7 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( if s.getOpenAIWSProtocolResolver().Resolve(account).Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { return nil, nil } - if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() { + if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() || !account.IsSchedulable() { _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) return nil, nil } @@ -3867,6 +3878,36 @@ func classifyOpenAIWSAcquireError(err error) string { return "acquire_conn" } +func isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw string) bool { + code := strings.ToLower(strings.TrimSpace(codeRaw)) + errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) + msg := strings.ToLower(strings.TrimSpace(msgRaw)) + + if strings.Contains(errType, "rate_limit") || strings.Contains(errType, "usage_limit") { + return true + } + if strings.Contains(code, "rate_limit") || strings.Contains(code, "usage_limit") || strings.Contains(code, "insufficient_quota") { + return true + } + if strings.Contains(msg, "usage limit") && strings.Contains(msg, "reached") { + return true + } + if strings.Contains(msg, "rate limit") && (strings.Contains(msg, "reached") || strings.Contains(msg, "exceeded")) { + return true + } + return false +} + +func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Context, account *Account, headers http.Header, responseBody []byte, codeRaw, errTypeRaw, msgRaw string) { + if s == nil || s.rateLimitService == nil || account == nil || account.Platform != PlatformOpenAI { + return + } + if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) { + return + } + s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody) +} + func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) { code := strings.ToLower(strings.TrimSpace(codeRaw)) errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) @@ -3882,6 +3923,9 @@ func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (stri case "previous_response_not_found": return "previous_response_not_found", true } + if isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) { + return "upstream_rate_limited", false + } if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") { return "upgrade_required", true } @@ -3927,9 +3971,7 @@ func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int { case strings.Contains(errType, "permission"), strings.Contains(code, "forbidden"): return http.StatusForbidden - case strings.Contains(errType, "rate_limit"), - strings.Contains(code, "rate_limit"), - strings.Contains(code, "insufficient_quota"): + case isOpenAIWSRateLimitError(codeRaw, errTypeRaw, ""): return http.StatusTooManyRequests default: return http.StatusBadGateway diff --git a/backend/internal/service/openai_ws_ratelimit_signal_test.go b/backend/internal/service/openai_ws_ratelimit_signal_test.go new file mode 100644 index 00000000..28cb8e00 --- /dev/null +++ b/backend/internal/service/openai_ws_ratelimit_signal_test.go @@ -0,0 +1,477 @@ +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" +) + +type openAIWSRateLimitSignalRepo struct { + stubOpenAIAccountRepo + rateLimitCalls []time.Time + updateExtra []map[string]any +} + +type openAICodexSnapshotAsyncRepo struct { + stubOpenAIAccountRepo + updateExtraCh chan map[string]any + rateLimitCh chan time.Time +} + +type openAICodexExtraListRepo struct { + stubOpenAIAccountRepo + rateLimitCh chan time.Time +} + +func (r *openAIWSRateLimitSignalRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error { + r.rateLimitCalls = append(r.rateLimitCalls, resetAt) + return nil +} + +func (r *openAIWSRateLimitSignalRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + copied := make(map[string]any, len(updates)) + for k, v := range updates { + copied[k] = v + } + r.updateExtra = append(r.updateExtra, copied) + return nil +} + +func (r *openAICodexSnapshotAsyncRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error { + if r.rateLimitCh != nil { + r.rateLimitCh <- resetAt + } + return nil +} + +func (r *openAICodexSnapshotAsyncRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + if r.updateExtraCh != nil { + copied := make(map[string]any, len(updates)) + for k, v := range updates { + copied[k] = v + } + r.updateExtraCh <- copied + } + return nil +} + +func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error { + if r.rateLimitCh != nil { + r.rateLimitCh <- resetAt + } + return nil +} + +func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { + _ = platform + _ = accountType + _ = status + _ = search + _ = groupID + return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil +} + +func TestOpenAIGatewayService_Forward_WSv2ErrorEventUsageLimitPersistsRateLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + + resetAt := time.Now().Add(2 * time.Hour).Unix() + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { _ = conn.Close() }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "rate_limit_exceeded", + "type": "usage_limit_reached", + "message": "The usage limit has been reached", + "resets_at": resetAt, + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "unit-test-agent/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_should_not_run"}`)), + }, + } + + cfg := newOpenAIWSV2TestConfig() + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + + account := Account{ + ID: 501, + Name: "openai-ws-rate-limit-event", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}} + rateSvc := &RateLimitService{accountRepo: repo} + svc := &OpenAIGatewayService{ + accountRepo: repo, + rateLimitService: rateSvc, + httpUpstream: upstream, + cache: &stubGatewayCache{}, + cfg: cfg, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, &account, body) + require.Error(t, err) + require.Nil(t, result) + require.Equal(t, http.StatusTooManyRequests, rec.Code) + require.Nil(t, upstream.lastReq, "WS 限流 error event 不应回退到同账号 HTTP") + require.Len(t, repo.rateLimitCalls, 1) + require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second) +} + +func TestOpenAIGatewayService_Forward_WSv2Handshake429PersistsRateLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("x-codex-primary-used-percent", "100") + w.Header().Set("x-codex-primary-reset-after-seconds", "7200") + w.Header().Set("x-codex-primary-window-minutes", "10080") + w.Header().Set("x-codex-secondary-used-percent", "3") + w.Header().Set("x-codex-secondary-reset-after-seconds", "1800") + w.Header().Set("x-codex-secondary-window-minutes", "300") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"error":{"type":"rate_limit_exceeded","message":"rate limited"}}`)) + })) + defer server.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "unit-test-agent/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_should_not_run"}`)), + }, + } + + cfg := newOpenAIWSV2TestConfig() + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + + account := Account{ + ID: 502, + Name: "openai-ws-rate-limit-handshake", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": server.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}} + rateSvc := &RateLimitService{accountRepo: repo} + svc := &OpenAIGatewayService{ + accountRepo: repo, + rateLimitService: rateSvc, + httpUpstream: upstream, + cache: &stubGatewayCache{}, + cfg: cfg, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, &account, body) + require.Error(t, err) + require.Nil(t, result) + require.Equal(t, http.StatusTooManyRequests, rec.Code) + require.Nil(t, upstream.lastReq, "WS 握手 429 不应回退到同账号 HTTP") + require.Len(t, repo.rateLimitCalls, 1) + require.NotEmpty(t, repo.updateExtra, "握手 429 的 x-codex 头应立即落库") + require.Contains(t, repo.updateExtra[0], "codex_usage_updated_at") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageLimitPersistsRateLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := newOpenAIWSV2TestConfig() + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + resetAt := time.Now().Add(90 * time.Minute).Unix() + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"error","error":{"code":"rate_limit_exceeded","type":"usage_limit_reached","message":"The usage limit has been reached","resets_at":PLACEHOLDER}}`), + }, + } + captureConn.events[0] = []byte(strings.ReplaceAll(string(captureConn.events[0]), "PLACEHOLDER", strconv.FormatInt(resetAt, 10))) + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + account := Account{ + ID: 503, + Name: "openai-ingress-rate-limit", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}} + rateSvc := &RateLimitService{accountRepo: repo} + svc := &OpenAIGatewayService{ + accountRepo: repo, + rateLimitService: rateSvc, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + cfg: cfg, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover}) + if err != nil { + serverErrCh <- err + return + } + defer func() { _ = conn.CloseNow() }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- io.ErrUnexpectedEOF + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, &account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { _ = clientConn.CloseNow() }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`)) + cancelWrite() + require.NoError(t, err) + + select { + case serverErr := <-serverErrCh: + require.Error(t, serverErr) + require.Len(t, repo.rateLimitCalls, 1) + require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } +} + +func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotSetsRateLimit(t *testing.T) { + repo := &openAICodexSnapshotAsyncRepo{ + updateExtraCh: make(chan map[string]any, 1), + rateLimitCh: make(chan time.Time, 1), + } + svc := &OpenAIGatewayService{accountRepo: repo} + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: ptrFloat64WS(100), + PrimaryResetAfterSeconds: ptrIntWS(3600), + PrimaryWindowMinutes: ptrIntWS(10080), + SecondaryUsedPercent: ptrFloat64WS(12), + SecondaryResetAfterSeconds: ptrIntWS(1200), + SecondaryWindowMinutes: ptrIntWS(300), + } + before := time.Now() + svc.updateCodexUsageSnapshot(context.Background(), 601, snapshot) + + select { + case updates := <-repo.updateExtraCh: + require.Equal(t, 100.0, updates["codex_7d_used_percent"]) + case <-time.After(2 * time.Second): + t.Fatal("等待 codex 快照落库超时") + } + + select { + case resetAt := <-repo.rateLimitCh: + require.WithinDuration(t, before.Add(time.Hour), resetAt, 2*time.Second) + case <-time.After(2 * time.Second): + t.Fatal("等待 codex 100% 自动切换限流超时") + } +} + +func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesNotSetRateLimit(t *testing.T) { + repo := &openAICodexSnapshotAsyncRepo{ + updateExtraCh: make(chan map[string]any, 1), + rateLimitCh: make(chan time.Time, 1), + } + svc := &OpenAIGatewayService{accountRepo: repo} + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: ptrFloat64WS(94), + PrimaryResetAfterSeconds: ptrIntWS(3600), + PrimaryWindowMinutes: ptrIntWS(10080), + SecondaryUsedPercent: ptrFloat64WS(22), + SecondaryResetAfterSeconds: ptrIntWS(1200), + SecondaryWindowMinutes: ptrIntWS(300), + } + svc.updateCodexUsageSnapshot(context.Background(), 602, snapshot) + + select { + case <-repo.updateExtraCh: + case <-time.After(2 * time.Second): + t.Fatal("等待 codex 快照落库超时") + } + + select { + case resetAt := <-repo.rateLimitCh: + t.Fatalf("unexpected rate limit reset at: %v", resetAt) + case <-time.After(200 * time.Millisecond): + } +} + +func ptrFloat64WS(v float64) *float64 { return &v } +func ptrIntWS(v int) *int { return &v } + +func TestOpenAIGatewayService_GetSchedulableAccount_ExhaustedCodexExtraSetsRateLimit(t *testing.T) { + resetAt := time.Now().Add(6 * 24 * time.Hour) + account := Account{ + ID: 701, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "codex_7d_used_percent": 100.0, + "codex_7d_reset_at": resetAt.UTC().Format(time.RFC3339), + }, + } + repo := &openAICodexExtraListRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, rateLimitCh: make(chan time.Time, 1)} + svc := &OpenAIGatewayService{accountRepo: repo} + + fresh, err := svc.getSchedulableAccount(context.Background(), account.ID) + require.NoError(t, err) + require.NotNil(t, fresh) + require.NotNil(t, fresh.RateLimitResetAt) + require.WithinDuration(t, resetAt.UTC(), *fresh.RateLimitResetAt, time.Second) + select { + case persisted := <-repo.rateLimitCh: + require.WithinDuration(t, resetAt.UTC(), persisted, time.Second) + case <-time.After(2 * time.Second): + t.Fatal("等待旧快照补写限流状态超时") + } +} + +func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(t *testing.T) { + resetAt := time.Now().Add(4 * 24 * time.Hour) + repo := &openAICodexExtraListRepo{ + stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{{ + ID: 702, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "codex_7d_used_percent": 100.0, + "codex_7d_reset_at": resetAt.UTC().Format(time.RFC3339), + }, + }}}, + rateLimitCh: make(chan time.Time, 1), + } + svc := &adminServiceImpl{accountRepo: repo} + + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0) + require.NoError(t, err) + require.Equal(t, int64(1), total) + require.Len(t, accounts, 1) + require.NotNil(t, accounts[0].RateLimitResetAt) + require.WithinDuration(t, resetAt.UTC(), *accounts[0].RateLimitResetAt, time.Second) + select { + case persisted := <-repo.rateLimitCh: + require.WithinDuration(t, resetAt.UTC(), persisted, time.Second) + case <-time.After(2 * time.Second): + t.Fatal("等待列表补写限流状态超时") + } +} + +func TestOpenAIWSErrorHTTPStatusFromRaw_UsageLimitReachedIs429(t *testing.T) { + require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("", "usage_limit_reached")) + require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("rate_limit_exceeded", "")) +} diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index f8f3154b..60ad99d0 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -615,6 +615,7 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) { // 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded) if account.Platform == PlatformOpenAI { + s.persistOpenAICodexSnapshot(ctx, account, headers) if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil { if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) @@ -878,6 +879,23 @@ func pickSooner(a, b *time.Time) *time.Time { } } +func (s *RateLimitService) persistOpenAICodexSnapshot(ctx context.Context, account *Account, headers http.Header) { + if s == nil || s.accountRepo == nil || account == nil || headers == nil { + return + } + snapshot := ParseCodexRateLimitHeaders(headers) + if snapshot == nil { + return + } + updates := buildCodexUsageExtraUpdates(snapshot, time.Now()) + if len(updates) == 0 { + return + } + if err := s.accountRepo.UpdateExtra(ctx, account.ID, updates); err != nil { + slog.Warn("openai_codex_snapshot_persist_failed", "account_id", account.ID, "error", err) + } +} + // parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳 // OpenAI 的 usage_limit_reached 错误格式: // diff --git a/backend/internal/service/ratelimit_service_openai_test.go b/backend/internal/service/ratelimit_service_openai_test.go index 51d7c62a..89c754c8 100644 --- a/backend/internal/service/ratelimit_service_openai_test.go +++ b/backend/internal/service/ratelimit_service_openai_test.go @@ -3,6 +3,7 @@ package service import ( + "context" "net/http" "testing" "time" @@ -143,6 +144,51 @@ func TestCalculateOpenAI429ResetTime_ReversedWindowOrder(t *testing.T) { } } +type openAI429SnapshotRepo struct { + mockAccountRepoForGemini + rateLimitedID int64 + updatedExtra map[string]any +} + +func (r *openAI429SnapshotRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error { + r.rateLimitedID = id + return nil +} + +func (r *openAI429SnapshotRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + r.updatedExtra = updates + return nil +} + +func TestHandle429_OpenAIPersistsCodexSnapshotImmediately(t *testing.T) { + repo := &openAI429SnapshotRepo{} + svc := NewRateLimitService(repo, nil, nil, nil, nil) + account := &Account{ID: 123, Platform: PlatformOpenAI, Type: AccountTypeOAuth} + + headers := http.Header{} + headers.Set("x-codex-primary-used-percent", "100") + headers.Set("x-codex-primary-reset-after-seconds", "604800") + headers.Set("x-codex-primary-window-minutes", "10080") + headers.Set("x-codex-secondary-used-percent", "100") + headers.Set("x-codex-secondary-reset-after-seconds", "18000") + headers.Set("x-codex-secondary-window-minutes", "300") + + svc.handle429(context.Background(), account, headers, nil) + + if repo.rateLimitedID != account.ID { + t.Fatalf("rateLimitedID = %d, want %d", repo.rateLimitedID, account.ID) + } + if len(repo.updatedExtra) == 0 { + t.Fatal("expected codex snapshot to be persisted on 429") + } + if got := repo.updatedExtra["codex_5h_used_percent"]; got != 100.0 { + t.Fatalf("codex_5h_used_percent = %v, want 100", got) + } + if got := repo.updatedExtra["codex_7d_used_percent"]; got != 100.0 { + t.Fatalf("codex_7d_used_percent = %v, want 100", got) + } +} + func TestNormalizedCodexLimits(t *testing.T) { // Test the Normalize() method directly pUsed := 100.0 diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue index e8331c25..1dc4f287 100644 --- a/frontend/src/components/account/AccountStatusIndicator.vue +++ b/frontend/src/components/account/AccountStatusIndicator.vue @@ -3,7 +3,7 @@