From 269414948976700f05a6636b1a9d53f0412ec776 Mon Sep 17 00:00:00 2001 From: ius Date: Wed, 11 Mar 2026 13:53:19 +0800 Subject: [PATCH 1/2] Reduce DB write amplification on quota and account extra updates --- backend/internal/repository/account_repo.go | 31 +++- .../account_repo_integration_test.go | 32 ++++ backend/internal/repository/api_key_repo.go | 26 +++ .../api_key_repo_integration_test.go | 21 +++ backend/internal/service/api_key_service.go | 25 +++ .../service/api_key_service_quota_test.go | 170 ++++++++++++++++++ .../service/openai_gateway_service.go | 66 ++++++- .../openai_ws_ratelimit_signal_test.go | 34 ++++ 8 files changed, 396 insertions(+), 9 deletions(-) create mode 100644 backend/internal/service/api_key_service_quota_test.go diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index c7642152..2aa72ebb 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -16,6 +16,7 @@ import ( "encoding/json" "errors" "strconv" + "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -1185,12 +1186,38 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m if affected == 0 { return service.ErrAccountNotFound } - if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) + if shouldEnqueueSchedulerOutboxForExtraUpdates(updates) { + if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) + } } return nil } +func shouldEnqueueSchedulerOutboxForExtraUpdates(updates map[string]any) bool { + if len(updates) == 0 { + return false + } + for key := range updates { + if 
isSchedulerNeutralAccountExtraKey(key) { + continue + } + return true + } + return false +} + +func isSchedulerNeutralAccountExtraKey(key string) bool { + key = strings.TrimSpace(key) + if key == "" { + return false + } + if key == "session_window_utilization" { + return true + } + return strings.HasPrefix(key, "codex_") +} + func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { if len(ids) == 0 { return 0, nil diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index 58971933..caf8d3f3 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -623,6 +623,38 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() { s.Require().Equal("val", got.Extra["key"]) } +func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralKeysSkipOutbox() { + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-extra-neutral", Extra: map[string]any{}}) + _, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox") + s.Require().NoError(err) + + s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{ + "codex_usage_updated_at": "2026-03-11T13:00:00Z", + "codex_5h_used_percent": 12.5, + "session_window_utilization": 0.42, + })) + + var count int + err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count) + s.Require().NoError(err) + s.Require().Equal(0, count) +} + +func (s *AccountRepoSuite) TestUpdateExtra_CustomKeysStillEnqueueOutbox() { + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-extra-custom", Extra: map[string]any{}}) + _, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox") + s.Require().NoError(err) + + s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{ + "custom_scheduler_sensitive_key": true, 
+ })) + + var count int + err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count) + s.Require().NoError(err) + s.Require().Equal(1, count) +} + // --- GetByCRSAccountID --- func (s *AccountRepoSuite) TestGetByCRSAccountID() { diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 95db1819..4c7f38a8 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -452,6 +452,32 @@ func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amo return updated.QuotaUsed, nil } +// IncrementQuotaUsedAndGetState atomically increments quota_used, conditionally marks the key +// as quota_exhausted, and returns the latest quota state in one round trip. +func (r *apiKeyRepository) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*service.APIKeyQuotaUsageState, error) { + query := ` + UPDATE api_keys + SET + quota_used = quota_used + $1, + status = CASE + WHEN quota > 0 AND quota_used + $1 >= quota THEN $2 + ELSE status + END, + updated_at = NOW() + WHERE id = $3 AND deleted_at IS NULL + RETURNING quota_used, quota, key, status + ` + + state := &service.APIKeyQuotaUsageState{} + if err := scanSingleRow(ctx, r.sql, query, []any{amount, service.StatusAPIKeyQuotaExhausted, id}, &state.QuotaUsed, &state.Quota, &state.Key, &state.Status); err != nil { + if err == sql.ErrNoRows { + return nil, service.ErrAPIKeyNotFound + } + return nil, err + } + return state, nil +} + func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { affected, err := r.client.APIKey.Update(). Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()). 
diff --git a/backend/internal/repository/api_key_repo_integration_test.go b/backend/internal/repository/api_key_repo_integration_test.go index 80714614..a8989ff2 100644 --- a/backend/internal/repository/api_key_repo_integration_test.go +++ b/backend/internal/repository/api_key_repo_integration_test.go @@ -417,6 +417,27 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() { s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound") } +func (s *APIKeyRepoSuite) TestIncrementQuotaUsedAndGetState() { + user := s.mustCreateUser("quota-state@test.com") + key := s.mustCreateApiKey(user.ID, "sk-quota-state", "QuotaState", nil) + key.Quota = 3 + key.QuotaUsed = 1 + s.Require().NoError(s.repo.Update(s.ctx, key), "Update quota") + + state, err := s.repo.IncrementQuotaUsedAndGetState(s.ctx, key.ID, 2.5) + s.Require().NoError(err, "IncrementQuotaUsedAndGetState") + s.Require().NotNil(state) + s.Require().Equal(3.5, state.QuotaUsed) + s.Require().Equal(3.0, state.Quota) + s.Require().Equal(service.StatusAPIKeyQuotaExhausted, state.Status) + s.Require().Equal(key.Key, state.Key) + + got, err := s.repo.GetByID(s.ctx, key.ID) + s.Require().NoError(err, "GetByID") + s.Require().Equal(3.5, got.QuotaUsed) + s.Require().Equal(service.StatusAPIKeyQuotaExhausted, got.Status) +} + // TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。 // 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。 func TestIncrementQuotaUsed_Concurrent(t *testing.T) { diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index 17c5b486..18e9ff7a 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "strconv" + "strings" "sync" "time" @@ -110,6 +111,15 @@ func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 { return d.Usage7d } +// APIKeyQuotaUsageState captures the latest quota fields after an atomic quota update. 
+// It is intentionally small so repositories can return it from a single SQL statement. +type APIKeyQuotaUsageState struct { + QuotaUsed float64 + Quota float64 + Key string + Status string +} + // APIKeyCache defines cache operations for API key service type APIKeyCache interface { GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) @@ -817,6 +827,21 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos return nil } + type quotaStateReader interface { + IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) + } + + if repo, ok := s.apiKeyRepo.(quotaStateReader); ok { + state, err := repo.IncrementQuotaUsedAndGetState(ctx, apiKeyID, cost) + if err != nil { + return fmt.Errorf("increment quota used: %w", err) + } + if state != nil && state.Status == StatusAPIKeyQuotaExhausted && strings.TrimSpace(state.Key) != "" { + s.InvalidateAuthCacheByKey(ctx, state.Key) + } + return nil + } + // Use repository to atomically increment quota_used newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost) if err != nil { diff --git a/backend/internal/service/api_key_service_quota_test.go b/backend/internal/service/api_key_service_quota_test.go new file mode 100644 index 00000000..2e2f6f78 --- /dev/null +++ b/backend/internal/service/api_key_service_quota_test.go @@ -0,0 +1,170 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type quotaStateRepoStub struct { + quotaBaseAPIKeyRepoStub + stateCalls int + state *APIKeyQuotaUsageState + stateErr error +} + +func (s *quotaStateRepoStub) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) { + s.stateCalls++ + if s.stateErr != nil { + return nil, s.stateErr + } + if s.state == nil { + return nil, nil + } + out := *s.state + 
return &out, nil +} + +type quotaStateCacheStub struct { + deleteAuthKeys []string +} + +func (s *quotaStateCacheStub) GetCreateAttemptCount(context.Context, int64) (int, error) { + return 0, nil +} + +func (s *quotaStateCacheStub) IncrementCreateAttemptCount(context.Context, int64) error { + return nil +} + +func (s *quotaStateCacheStub) DeleteCreateAttemptCount(context.Context, int64) error { + return nil +} + +func (s *quotaStateCacheStub) IncrementDailyUsage(context.Context, string) error { + return nil +} + +func (s *quotaStateCacheStub) SetDailyUsageExpiry(context.Context, string, time.Duration) error { + return nil +} + +func (s *quotaStateCacheStub) GetAuthCache(context.Context, string) (*APIKeyAuthCacheEntry, error) { + return nil, nil +} + +func (s *quotaStateCacheStub) SetAuthCache(context.Context, string, *APIKeyAuthCacheEntry, time.Duration) error { + return nil +} + +func (s *quotaStateCacheStub) DeleteAuthCache(_ context.Context, key string) error { + s.deleteAuthKeys = append(s.deleteAuthKeys, key) + return nil +} + +func (s *quotaStateCacheStub) PublishAuthCacheInvalidation(context.Context, string) error { + return nil +} + +func (s *quotaStateCacheStub) SubscribeAuthCacheInvalidation(context.Context, func(string)) error { + return nil +} + +type quotaBaseAPIKeyRepoStub struct { + getByIDCalls int +} + +func (s *quotaBaseAPIKeyRepoStub) Create(context.Context, *APIKey) error { + panic("unexpected Create call") +} +func (s *quotaBaseAPIKeyRepoStub) GetByID(context.Context, int64) (*APIKey, error) { + s.getByIDCalls++ + return nil, nil +} +func (s *quotaBaseAPIKeyRepoStub) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) { + panic("unexpected GetKeyAndOwnerID call") +} +func (s *quotaBaseAPIKeyRepoStub) GetByKey(context.Context, string) (*APIKey, error) { + panic("unexpected GetByKey call") +} +func (s *quotaBaseAPIKeyRepoStub) GetByKeyForAuth(context.Context, string) (*APIKey, error) { + panic("unexpected GetByKeyForAuth call") +} 
+func (s *quotaBaseAPIKeyRepoStub) Update(context.Context, *APIKey) error { + panic("unexpected Update call") +} +func (s *quotaBaseAPIKeyRepoStub) Delete(context.Context, int64) error { + panic("unexpected Delete call") +} +func (s *quotaBaseAPIKeyRepoStub) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected ListByUserID call") +} +func (s *quotaBaseAPIKeyRepoStub) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) { + panic("unexpected VerifyOwnership call") +} +func (s *quotaBaseAPIKeyRepoStub) CountByUserID(context.Context, int64) (int64, error) { + panic("unexpected CountByUserID call") +} +func (s *quotaBaseAPIKeyRepoStub) ExistsByKey(context.Context, string) (bool, error) { + panic("unexpected ExistsByKey call") +} +func (s *quotaBaseAPIKeyRepoStub) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected ListByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) { + panic("unexpected SearchAPIKeys call") +} +func (s *quotaBaseAPIKeyRepoStub) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { + panic("unexpected ClearGroupIDByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) CountByGroupID(context.Context, int64) (int64, error) { + panic("unexpected CountByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) ListKeysByUserID(context.Context, int64) ([]string, error) { + panic("unexpected ListKeysByUserID call") +} +func (s *quotaBaseAPIKeyRepoStub) ListKeysByGroupID(context.Context, int64) ([]string, error) { + panic("unexpected ListKeysByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) { + panic("unexpected IncrementQuotaUsed call") +} +func (s *quotaBaseAPIKeyRepoStub) UpdateLastUsed(context.Context, int64, 
time.Time) error { + panic("unexpected UpdateLastUsed call") +} +func (s *quotaBaseAPIKeyRepoStub) IncrementRateLimitUsage(context.Context, int64, float64) error { + panic("unexpected IncrementRateLimitUsage call") +} +func (s *quotaBaseAPIKeyRepoStub) ResetRateLimitWindows(context.Context, int64) error { + panic("unexpected ResetRateLimitWindows call") +} +func (s *quotaBaseAPIKeyRepoStub) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) { + panic("unexpected GetRateLimitData call") +} + +func TestAPIKeyService_UpdateQuotaUsed_UsesAtomicStatePath(t *testing.T) { + repo := &quotaStateRepoStub{ + state: &APIKeyQuotaUsageState{ + QuotaUsed: 12, + Quota: 10, + Key: "sk-test-quota", + Status: StatusAPIKeyQuotaExhausted, + }, + } + cache := &quotaStateCacheStub{} + svc := &APIKeyService{ + apiKeyRepo: repo, + cache: cache, + } + + err := svc.UpdateQuotaUsed(context.Background(), 101, 2) + require.NoError(t, err) + require.Equal(t, 1, repo.stateCalls) + require.Equal(t, 0, repo.getByIDCalls, "fast path should not re-read API key by id") + require.Equal(t, []string{svc.authCacheKey("sk-test-quota")}, cache.deleteAuthKeys) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 44cfc83a..233566f3 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -52,6 +52,8 @@ const ( openAIWSRetryJitterRatioDefault = 0.2 openAICompactSessionSeedKey = "openai_compact_session_seed" codexCLIVersion = "0.104.0" + // Codex 限额快照仅用于后台展示/诊断,不需要每个成功请求都立即落库。 + openAICodexSnapshotPersistMinInterval = 30 * time.Second ) // OpenAI allowed headers whitelist (for non-passthrough). 
@@ -255,6 +257,46 @@ type openAIWSRetryMetrics struct { nonRetryableFastFallback atomic.Int64 } +type accountWriteThrottle struct { + minInterval time.Duration + mu sync.Mutex + lastByID map[int64]time.Time +} + +func newAccountWriteThrottle(minInterval time.Duration) *accountWriteThrottle { + return &accountWriteThrottle{ + minInterval: minInterval, + lastByID: make(map[int64]time.Time), + } +} + +func (t *accountWriteThrottle) Allow(id int64, now time.Time) bool { + if t == nil || id <= 0 || t.minInterval <= 0 { + return true + } + + t.mu.Lock() + defer t.mu.Unlock() + + if last, ok := t.lastByID[id]; ok && now.Sub(last) < t.minInterval { + return false + } + t.lastByID[id] = now + + if len(t.lastByID) > 4096 { + cutoff := now.Add(-4 * t.minInterval) + for accountID, writtenAt := range t.lastByID { + if writtenAt.Before(cutoff) { + delete(t.lastByID, accountID) + } + } + } + + return true +} + +var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval) + // OpenAIGatewayService handles OpenAI API gateway operations type OpenAIGatewayService struct { accountRepo AccountRepository @@ -289,6 +331,7 @@ type OpenAIGatewayService struct { openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time openaiWSRetryMetrics openAIWSRetryMetrics responseHeaderFilter *responseheaders.CompiledHeaderFilter + codexSnapshotThrottle *accountWriteThrottle } // NewOpenAIGatewayService creates a new OpenAIGatewayService @@ -329,17 +372,25 @@ func NewOpenAIGatewayService( nil, "service.openai_gateway", ), - httpUpstream: httpUpstream, - deferredService: deferredService, - openAITokenProvider: openAITokenProvider, - toolCorrector: NewCodexToolCorrector(), - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - responseHeaderFilter: compileResponseHeaderFilter(cfg), + httpUpstream: httpUpstream, + deferredService: deferredService, + openAITokenProvider: openAITokenProvider, + toolCorrector: NewCodexToolCorrector(), + 
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + responseHeaderFilter: compileResponseHeaderFilter(cfg), + codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), } svc.logOpenAIWSModeBootstrap() return svc } +func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle { + if s != nil && s.codexSnapshotThrottle != nil { + return s.codexSnapshotThrottle + } + return defaultOpenAICodexSnapshotPersistThrottle +} + func (s *OpenAIGatewayService) billingDeps() *billingDeps { return &billingDeps{ accountRepo: s.accountRepo, @@ -4050,11 +4101,12 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc if len(updates) == 0 && resetAt == nil { return } + shouldPersistUpdates := len(updates) > 0 && s.getCodexSnapshotThrottle().Allow(accountID, now) go func() { updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - if len(updates) > 0 { + if shouldPersistUpdates { _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) } if resetAt != nil { diff --git a/backend/internal/service/openai_ws_ratelimit_signal_test.go b/backend/internal/service/openai_ws_ratelimit_signal_test.go index 28cb8e00..f5c79923 100644 --- a/backend/internal/service/openai_ws_ratelimit_signal_test.go +++ b/backend/internal/service/openai_ws_ratelimit_signal_test.go @@ -405,6 +405,40 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesN } } +func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ThrottlesExtraWrites(t *testing.T) { + repo := &openAICodexSnapshotAsyncRepo{ + updateExtraCh: make(chan map[string]any, 2), + rateLimitCh: make(chan time.Time, 2), + } + svc := &OpenAIGatewayService{ + accountRepo: repo, + codexSnapshotThrottle: newAccountWriteThrottle(time.Hour), + } + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: ptrFloat64WS(94), + PrimaryResetAfterSeconds: ptrIntWS(3600), + PrimaryWindowMinutes: ptrIntWS(10080), + 
SecondaryUsedPercent: ptrFloat64WS(22), + SecondaryResetAfterSeconds: ptrIntWS(1200), + SecondaryWindowMinutes: ptrIntWS(300), + } + + svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot) + svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot) + + select { + case <-repo.updateExtraCh: + case <-time.After(2 * time.Second): + t.Fatal("等待第一次 codex 快照落库超时") + } + + select { + case updates := <-repo.updateExtraCh: + t.Fatalf("unexpected second codex snapshot write: %v", updates) + case <-time.After(200 * time.Millisecond): + } +} + func ptrFloat64WS(v float64) *float64 { return &v } func ptrIntWS(v int) *int { return &v } From 2fc6aaf9364f50f9af3c119a1b1abd3e43c064e6 Mon Sep 17 00:00:00 2001 From: ius Date: Wed, 11 Mar 2026 15:47:39 +0800 Subject: [PATCH 2/2] Fix Codex exhausted snapshot propagation --- backend/internal/repository/account_repo.go | 79 +++++++++++++++++ .../account_repo_integration_test.go | 27 ++++++ .../internal/service/account_usage_service.go | 87 ++++++++++++------- .../service/account_usage_service_test.go | 82 +++++++++++++++++ .../service/openai_gateway_service.go | 3 + 5 files changed, 248 insertions(+), 30 deletions(-) diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 2aa72ebb..8083d3d1 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -1190,6 +1190,9 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) } + } else if shouldSyncSchedulerSnapshotForExtraUpdates(updates) { + // codex 限流快照仍需要让调度缓存尽快看见,避免 DB 抖动时丢失自愈链路。 + r.syncSchedulerAccountSnapshot(ctx, id) } return nil } @@ -1207,6 +1210,10 @@ func 
shouldEnqueueSchedulerOutboxForExtraUpdates(updates map[string]any) bool { return false } +func shouldSyncSchedulerSnapshotForExtraUpdates(updates map[string]any) bool { + return codexExtraIndicatesRateLimit(updates, "7d") || codexExtraIndicatesRateLimit(updates, "5h") +} + func isSchedulerNeutralAccountExtraKey(key string) bool { key = strings.TrimSpace(key) if key == "" { @@ -1218,6 +1225,78 @@ func isSchedulerNeutralAccountExtraKey(key string) bool { return strings.HasPrefix(key, "codex_") } +func codexExtraIndicatesRateLimit(updates map[string]any, window string) bool { + if len(updates) == 0 { + return false + } + usedValue, ok := updates["codex_"+window+"_used_percent"] + if !ok || !extraValueIndicatesExhausted(usedValue) { + return false + } + return extraValueHasResetMarker(updates["codex_"+window+"_reset_at"]) || + extraValueHasPositiveNumber(updates["codex_"+window+"_reset_after_seconds"]) +} + +func extraValueIndicatesExhausted(value any) bool { + number, ok := extraValueToFloat64(value) + return ok && number >= 100-1e-9 +} + +func extraValueHasPositiveNumber(value any) bool { + number, ok := extraValueToFloat64(value) + return ok && number > 0 +} + +func extraValueHasResetMarker(value any) bool { + switch v := value.(type) { + case string: + return strings.TrimSpace(v) != "" + case time.Time: + return !v.IsZero() + case *time.Time: + return v != nil && !v.IsZero() + default: + return false + } +} + +func extraValueToFloat64(value any) (float64, bool) { + switch v := value.(type) { + case float64: + return v, true + case float32: + return float64(v), true + case int: + return float64(v), true + case int8: + return float64(v), true + case int16: + return float64(v), true + case int32: + return float64(v), true + case int64: + return float64(v), true + case uint: + return float64(v), true + case uint8: + return float64(v), true + case uint16: + return float64(v), true + case uint32: + return float64(v), true + case uint64: + return float64(v), true + case 
json.Number: + parsed, err := v.Float64() + return parsed, err == nil + case string: + parsed, err := strconv.ParseFloat(strings.TrimSpace(v), 64) + return parsed, err == nil + default: + return 0, false + } +} + func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { if len(ids) == 0 { return 0, nil diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index caf8d3f3..56f62491 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -640,6 +640,33 @@ func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralKeysSkipOutbox() { s.Require().Equal(0, count) } +func (s *AccountRepoSuite) TestUpdateExtra_ExhaustedCodexSnapshotSyncsSchedulerCache() { + account := mustCreateAccount(s.T(), s.client, &service.Account{ + Name: "acc-extra-codex-exhausted", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeOAuth, + Extra: map[string]any{}, + }) + cacheRecorder := &schedulerCacheRecorder{} + s.repo.schedulerCache = cacheRecorder + _, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox") + s.Require().NoError(err) + + s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{ + "codex_7d_used_percent": 100.0, + "codex_7d_reset_at": "2026-03-12T13:00:00Z", + "codex_7d_reset_after_seconds": 86400, + })) + + var count int + err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count) + s.Require().NoError(err) + s.Require().Equal(0, count) + s.Require().Len(cacheRecorder.setAccounts, 1) + s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID) + s.Require().Equal(100.0, cacheRecorder.setAccounts[0].Extra["codex_7d_used_percent"]) +} + func (s *AccountRepoSuite) TestUpdateExtra_CustomKeysStillEnqueueOutbox() { account := mustCreateAccount(s.T(), s.client, 
&service.Account{Name: "acc-extra-custom", Extra: map[string]any{}}) _, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox") diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 7c001118..e4245133 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -369,8 +369,11 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou } if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) { - if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 { + if updates, resetAt, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && (len(updates) > 0 || resetAt != nil) { mergeAccountExtra(account, updates) + if resetAt != nil { + account.RateLimitResetAt = resetAt + } if usage.UpdatedAt == nil { usage.UpdatedAt = &now } @@ -457,26 +460,26 @@ func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, no return true } -func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) { +func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, *time.Time, error) { if account == nil || !account.IsOAuth() { - return nil, nil + return nil, nil, nil } accessToken := account.GetOpenAIAccessToken() if accessToken == "" { - return nil, fmt.Errorf("no access token available") + return nil, nil, fmt.Errorf("no access token available") } modelID := openaipkg.DefaultTestModel payload := createOpenAITestPayload(modelID, true) payloadBytes, err := json.Marshal(payload) if err != nil { - return nil, fmt.Errorf("marshal openai probe payload: %w", err) + return nil, nil, fmt.Errorf("marshal openai probe payload: %w", err) } reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second) defer cancel() req, err := 
http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes)) if err != nil { - return nil, fmt.Errorf("create openai probe request: %w", err) + return nil, nil, fmt.Errorf("create openai probe request: %w", err) } req.Host = "chatgpt.com" req.Header.Set("Content-Type", "application/json") @@ -505,43 +508,67 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco ResponseHeaderTimeout: 10 * time.Second, }) if err != nil { - return nil, fmt.Errorf("build openai probe client: %w", err) + return nil, nil, fmt.Errorf("build openai probe client: %w", err) } resp, err := client.Do(req) if err != nil { - return nil, fmt.Errorf("openai codex probe request failed: %w", err) + return nil, nil, fmt.Errorf("openai codex probe request failed: %w", err) } defer func() { _ = resp.Body.Close() }() - updates, err := extractOpenAICodexProbeUpdates(resp) + updates, resetAt, err := extractOpenAICodexProbeSnapshot(resp) if err != nil { - return nil, err + return nil, nil, err } - if len(updates) > 0 { - go func(accountID int64, updates map[string]any) { - updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer updateCancel() + if len(updates) > 0 || resetAt != nil { + s.persistOpenAICodexProbeSnapshot(account.ID, updates, resetAt) + return updates, resetAt, nil + } + return nil, nil, nil +} + +func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any, resetAt *time.Time) { + if s == nil || s.accountRepo == nil || accountID <= 0 { + return + } + if len(updates) == 0 && resetAt == nil { + return + } + + go func() { + updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer updateCancel() + if len(updates) > 0 { _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) - }(account.ID, updates) - return updates, nil + } + if resetAt != nil { + _ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt) + 
} + }() +} + +func extractOpenAICodexProbeSnapshot(resp *http.Response) (map[string]any, *time.Time, error) { + if resp == nil { + return nil, nil, nil } - return nil, nil + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + baseTime := time.Now() + updates := buildCodexUsageExtraUpdates(snapshot, baseTime) + resetAt := codexRateLimitResetAtFromSnapshot(snapshot, baseTime) + if len(updates) > 0 { + return updates, resetAt, nil + } + return nil, resetAt, nil + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode) + } + return nil, nil, nil } func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) { - if resp == nil { - return nil, nil - } - if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { - updates := buildCodexUsageExtraUpdates(snapshot, time.Now()) - if len(updates) > 0 { - return updates, nil - } - } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode) - } - return nil, nil + updates, _, err := extractOpenAICodexProbeSnapshot(resp) + return updates, err } func mergeAccountExtra(account *Account, updates map[string]any) { diff --git a/backend/internal/service/account_usage_service_test.go b/backend/internal/service/account_usage_service_test.go index 974d9029..a063fe26 100644 --- a/backend/internal/service/account_usage_service_test.go +++ b/backend/internal/service/account_usage_service_test.go @@ -1,11 +1,36 @@ package service import ( + "context" "net/http" "testing" "time" ) +type accountUsageCodexProbeRepo struct { + stubOpenAIAccountRepo + updateExtraCh chan map[string]any + rateLimitCh chan time.Time +} + +func (r *accountUsageCodexProbeRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + if r.updateExtraCh != nil { + copied := make(map[string]any, len(updates)) + for k, v := range 
updates { + copied[k] = v + } + r.updateExtraCh <- copied + } + return nil +} + +func (r *accountUsageCodexProbeRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error { + if r.rateLimitCh != nil { + r.rateLimitCh <- resetAt + } + return nil +} + func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) { t.Parallel() @@ -66,3 +91,60 @@ func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T) t.Fatalf("codex_7d_used_percent = %v, want 100", got) } } + +func TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt(t *testing.T) { + t.Parallel() + + headers := make(http.Header) + headers.Set("x-codex-primary-used-percent", "100") + headers.Set("x-codex-primary-reset-after-seconds", "604800") + headers.Set("x-codex-primary-window-minutes", "10080") + headers.Set("x-codex-secondary-used-percent", "100") + headers.Set("x-codex-secondary-reset-after-seconds", "18000") + headers.Set("x-codex-secondary-window-minutes", "300") + + updates, resetAt, err := extractOpenAICodexProbeSnapshot(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers}) + if err != nil { + t.Fatalf("extractOpenAICodexProbeSnapshot() error = %v", err) + } + if len(updates) == 0 { + t.Fatal("expected codex probe updates from 429 headers") + } + if resetAt == nil { + t.Fatal("expected resetAt from exhausted codex headers") + } +} + +func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *testing.T) { + t.Parallel() + + repo := &accountUsageCodexProbeRepo{ + updateExtraCh: make(chan map[string]any, 1), + rateLimitCh: make(chan time.Time, 1), + } + svc := &AccountUsageService{accountRepo: repo} + resetAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second) + + svc.persistOpenAICodexProbeSnapshot(321, map[string]any{ + "codex_7d_used_percent": 100.0, + "codex_7d_reset_at": resetAt.Format(time.RFC3339), + }, &resetAt) + + select { + case updates := <-repo.updateExtraCh: + if got := updates["codex_7d_used_percent"]; got != 
100.0 { + t.Fatalf("codex_7d_used_percent = %v, want 100", got) + } + case <-time.After(2 * time.Second): + t.Fatal("waiting for codex probe extra persistence timed out") + } + + select { + case got := <-repo.rateLimitCh: + if got.Before(resetAt.Add(-time.Second)) || got.After(resetAt.Add(time.Second)) { + t.Fatalf("rate limit resetAt = %v, want around %v", got, resetAt) + } + case <-time.After(2 * time.Second): + t.Fatal("waiting for codex probe rate limit persistence timed out") + } +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 233566f3..0bf924b8 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -4102,6 +4102,9 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc return } shouldPersistUpdates := len(updates) > 0 && s.getCodexSnapshotThrottle().Allow(accountID, now) + if !shouldPersistUpdates && resetAt == nil { + return + } go func() { updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)