From 269414948976700f05a6636b1a9d53f0412ec776 Mon Sep 17 00:00:00 2001 From: ius Date: Wed, 11 Mar 2026 13:53:19 +0800 Subject: [PATCH] Reduce DB write amplification on quota and account extra updates --- backend/internal/repository/account_repo.go | 31 +++- .../account_repo_integration_test.go | 32 ++++ backend/internal/repository/api_key_repo.go | 26 +++ .../api_key_repo_integration_test.go | 21 +++ backend/internal/service/api_key_service.go | 25 +++ .../service/api_key_service_quota_test.go | 170 ++++++++++++++++++ .../service/openai_gateway_service.go | 66 ++++++- .../openai_ws_ratelimit_signal_test.go | 34 ++++ 8 files changed, 396 insertions(+), 9 deletions(-) create mode 100644 backend/internal/service/api_key_service_quota_test.go diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index c7642152..2aa72ebb 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -16,6 +16,7 @@ import ( "encoding/json" "errors" "strconv" + "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -1185,12 +1186,38 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m if affected == 0 { return service.ErrAccountNotFound } - if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) + if shouldEnqueueSchedulerOutboxForExtraUpdates(updates) { + if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) + } } return nil } +func shouldEnqueueSchedulerOutboxForExtraUpdates(updates map[string]any) bool { + if len(updates) == 0 { + return false + } + for key := range updates { + if isSchedulerNeutralAccountExtraKey(key) { + continue + } + return true + } + return false +} + +func isSchedulerNeutralAccountExtraKey(key string) bool { + key = strings.TrimSpace(key) + if key == "" { + return false + } + if key == "session_window_utilization" { + return true + } + return strings.HasPrefix(key, "codex_") +} + func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { if len(ids) == 0 { return 0, nil diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index 58971933..caf8d3f3 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -623,6 +623,38 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() { s.Require().Equal("val", got.Extra["key"]) } +func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralKeysSkipOutbox() { + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-extra-neutral", Extra: map[string]any{}}) + _, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox") + s.Require().NoError(err) + + s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{ + "codex_usage_updated_at": "2026-03-11T13:00:00Z", + "codex_5h_used_percent": 12.5, + "session_window_utilization": 0.42, + })) + + var count int + err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count) + s.Require().NoError(err) + s.Require().Equal(0, count) +} + +func (s *AccountRepoSuite) TestUpdateExtra_CustomKeysStillEnqueueOutbox() { + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-extra-custom", Extra: map[string]any{}}) + _, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox") + s.Require().NoError(err) + + s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{ + "custom_scheduler_sensitive_key": true, + })) + + var count int + err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count) + s.Require().NoError(err) + s.Require().Equal(1, count) +} + // --- GetByCRSAccountID --- func (s *AccountRepoSuite) TestGetByCRSAccountID() { diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 95db1819..4c7f38a8 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -452,6 +452,32 @@ func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amo return updated.QuotaUsed, nil } +// IncrementQuotaUsedAndGetState atomically increments quota_used, conditionally marks the key +// as quota_exhausted, and returns the latest quota state in one round trip. +func (r *apiKeyRepository) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*service.APIKeyQuotaUsageState, error) { + query := ` + UPDATE api_keys + SET + quota_used = quota_used + $1, + status = CASE + WHEN quota > 0 AND quota_used + $1 >= quota THEN $2 + ELSE status + END, + updated_at = NOW() + WHERE id = $3 AND deleted_at IS NULL + RETURNING quota_used, quota, key, status + ` + + state := &service.APIKeyQuotaUsageState{} + if err := scanSingleRow(ctx, r.sql, query, []any{amount, service.StatusAPIKeyQuotaExhausted, id}, &state.QuotaUsed, &state.Quota, &state.Key, &state.Status); err != nil { + if err == sql.ErrNoRows { + return nil, service.ErrAPIKeyNotFound + } + return nil, err + } + return state, nil +} + func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { affected, err := r.client.APIKey.Update(). Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()). diff --git a/backend/internal/repository/api_key_repo_integration_test.go b/backend/internal/repository/api_key_repo_integration_test.go index 80714614..a8989ff2 100644 --- a/backend/internal/repository/api_key_repo_integration_test.go +++ b/backend/internal/repository/api_key_repo_integration_test.go @@ -417,6 +417,27 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() { s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound") } +func (s *APIKeyRepoSuite) TestIncrementQuotaUsedAndGetState() { + user := s.mustCreateUser("quota-state@test.com") + key := s.mustCreateApiKey(user.ID, "sk-quota-state", "QuotaState", nil) + key.Quota = 3 + key.QuotaUsed = 1 + s.Require().NoError(s.repo.Update(s.ctx, key), "Update quota") + + state, err := s.repo.IncrementQuotaUsedAndGetState(s.ctx, key.ID, 2.5) + s.Require().NoError(err, "IncrementQuotaUsedAndGetState") + s.Require().NotNil(state) + s.Require().Equal(3.5, state.QuotaUsed) + s.Require().Equal(3.0, state.Quota) + s.Require().Equal(service.StatusAPIKeyQuotaExhausted, state.Status) + s.Require().Equal(key.Key, state.Key) + + got, err := s.repo.GetByID(s.ctx, key.ID) + s.Require().NoError(err, "GetByID") + s.Require().Equal(3.5, got.QuotaUsed) + s.Require().Equal(service.StatusAPIKeyQuotaExhausted, got.Status) +} + // TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。 // 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。 func TestIncrementQuotaUsed_Concurrent(t *testing.T) { diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index 17c5b486..18e9ff7a 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "strconv" + "strings" "sync" "time" @@ -110,6 +111,15 @@ func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 { return d.Usage7d } +// APIKeyQuotaUsageState captures the latest quota fields after an atomic quota update. +// It is intentionally small so repositories can return it from a single SQL statement. +type APIKeyQuotaUsageState struct { + QuotaUsed float64 + Quota float64 + Key string + Status string +} + // APIKeyCache defines cache operations for API key service type APIKeyCache interface { GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) @@ -817,6 +827,21 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos return nil } + type quotaStateReader interface { + IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) + } + + if repo, ok := s.apiKeyRepo.(quotaStateReader); ok { + state, err := repo.IncrementQuotaUsedAndGetState(ctx, apiKeyID, cost) + if err != nil { + return fmt.Errorf("increment quota used: %w", err) + } + if state != nil && state.Status == StatusAPIKeyQuotaExhausted && strings.TrimSpace(state.Key) != "" { + s.InvalidateAuthCacheByKey(ctx, state.Key) + } + return nil + } + // Use repository to atomically increment quota_used newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost) if err != nil { diff --git a/backend/internal/service/api_key_service_quota_test.go b/backend/internal/service/api_key_service_quota_test.go new file mode 100644 index 00000000..2e2f6f78 --- /dev/null +++ b/backend/internal/service/api_key_service_quota_test.go @@ -0,0 +1,170 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type quotaStateRepoStub struct { + quotaBaseAPIKeyRepoStub + stateCalls int + state *APIKeyQuotaUsageState + stateErr error +} + +func (s *quotaStateRepoStub) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) { + s.stateCalls++ + if s.stateErr != nil { + return nil, s.stateErr + } + if s.state == nil { + return nil, nil + } + out := *s.state + return &out, nil +} + +type quotaStateCacheStub struct { + deleteAuthKeys []string +} + +func (s *quotaStateCacheStub) GetCreateAttemptCount(context.Context, int64) (int, error) { + return 0, nil +} + +func (s *quotaStateCacheStub) IncrementCreateAttemptCount(context.Context, int64) error { + return nil +} + +func (s *quotaStateCacheStub) DeleteCreateAttemptCount(context.Context, int64) error { + return nil +} + +func (s *quotaStateCacheStub) IncrementDailyUsage(context.Context, string) error { + return nil +} + +func (s *quotaStateCacheStub) SetDailyUsageExpiry(context.Context, string, time.Duration) error { + return nil +} + +func (s *quotaStateCacheStub) GetAuthCache(context.Context, string) (*APIKeyAuthCacheEntry, error) { + return nil, nil +} + +func (s *quotaStateCacheStub) SetAuthCache(context.Context, string, *APIKeyAuthCacheEntry, time.Duration) error { + return nil +} + +func (s *quotaStateCacheStub) DeleteAuthCache(_ context.Context, key string) error { + s.deleteAuthKeys = append(s.deleteAuthKeys, key) + return nil +} + +func (s *quotaStateCacheStub) PublishAuthCacheInvalidation(context.Context, string) error { + return nil +} + +func (s *quotaStateCacheStub) SubscribeAuthCacheInvalidation(context.Context, func(string)) error { + return nil +} + +type quotaBaseAPIKeyRepoStub struct { + getByIDCalls int +} + +func (s *quotaBaseAPIKeyRepoStub) Create(context.Context, *APIKey) error { + panic("unexpected Create call") +} +func (s *quotaBaseAPIKeyRepoStub) GetByID(context.Context, int64) (*APIKey, error) { + s.getByIDCalls++ + return nil, nil +} +func (s *quotaBaseAPIKeyRepoStub) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) { + panic("unexpected GetKeyAndOwnerID call") +} +func (s *quotaBaseAPIKeyRepoStub) GetByKey(context.Context, string) (*APIKey, error) { + panic("unexpected GetByKey call") +} +func (s *quotaBaseAPIKeyRepoStub) GetByKeyForAuth(context.Context, string) (*APIKey, error) { + panic("unexpected GetByKeyForAuth call") +} +func (s *quotaBaseAPIKeyRepoStub) Update(context.Context, *APIKey) error { + panic("unexpected Update call") +} +func (s *quotaBaseAPIKeyRepoStub) Delete(context.Context, int64) error { + panic("unexpected Delete call") +} +func (s *quotaBaseAPIKeyRepoStub) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected ListByUserID call") +} +func (s *quotaBaseAPIKeyRepoStub) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) { + panic("unexpected VerifyOwnership call") +} +func (s *quotaBaseAPIKeyRepoStub) CountByUserID(context.Context, int64) (int64, error) { + panic("unexpected CountByUserID call") +} +func (s *quotaBaseAPIKeyRepoStub) ExistsByKey(context.Context, string) (bool, error) { + panic("unexpected ExistsByKey call") +} +func (s *quotaBaseAPIKeyRepoStub) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected ListByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) { + panic("unexpected SearchAPIKeys call") +} +func (s *quotaBaseAPIKeyRepoStub) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { + panic("unexpected ClearGroupIDByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) CountByGroupID(context.Context, int64) (int64, error) { + panic("unexpected CountByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) ListKeysByUserID(context.Context, int64) ([]string, error) { + panic("unexpected ListKeysByUserID call") +} +func (s *quotaBaseAPIKeyRepoStub) ListKeysByGroupID(context.Context, int64) ([]string, error) { + panic("unexpected ListKeysByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) { + panic("unexpected IncrementQuotaUsed call") +} +func (s *quotaBaseAPIKeyRepoStub) UpdateLastUsed(context.Context, int64, time.Time) error { + panic("unexpected UpdateLastUsed call") +} +func (s *quotaBaseAPIKeyRepoStub) IncrementRateLimitUsage(context.Context, int64, float64) error { + panic("unexpected IncrementRateLimitUsage call") +} +func (s *quotaBaseAPIKeyRepoStub) ResetRateLimitWindows(context.Context, int64) error { + panic("unexpected ResetRateLimitWindows call") +} +func (s *quotaBaseAPIKeyRepoStub) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) { + panic("unexpected GetRateLimitData call") +} + +func TestAPIKeyService_UpdateQuotaUsed_UsesAtomicStatePath(t *testing.T) { + repo := "aStateRepoStub{ + state: &APIKeyQuotaUsageState{ + QuotaUsed: 12, + Quota: 10, + Key: "sk-test-quota", + Status: StatusAPIKeyQuotaExhausted, + }, + } + cache := "aStateCacheStub{} + svc := &APIKeyService{ + apiKeyRepo: repo, + cache: cache, + } + + err := svc.UpdateQuotaUsed(context.Background(), 101, 2) + require.NoError(t, err) + require.Equal(t, 1, repo.stateCalls) + require.Equal(t, 0, repo.getByIDCalls, "fast path should not re-read API key by id") + require.Equal(t, []string{svc.authCacheKey("sk-test-quota")}, cache.deleteAuthKeys) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 44cfc83a..233566f3 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -52,6 +52,8 @@ const ( openAIWSRetryJitterRatioDefault = 0.2 openAICompactSessionSeedKey = "openai_compact_session_seed" codexCLIVersion = "0.104.0" + // Codex 限额快照仅用于后台展示/诊断,不需要每个成功请求都立即落库。 + openAICodexSnapshotPersistMinInterval = 30 * time.Second ) // OpenAI allowed headers whitelist (for non-passthrough). @@ -255,6 +257,46 @@ type openAIWSRetryMetrics struct { nonRetryableFastFallback atomic.Int64 } +type accountWriteThrottle struct { + minInterval time.Duration + mu sync.Mutex + lastByID map[int64]time.Time +} + +func newAccountWriteThrottle(minInterval time.Duration) *accountWriteThrottle { + return &accountWriteThrottle{ + minInterval: minInterval, + lastByID: make(map[int64]time.Time), + } +} + +func (t *accountWriteThrottle) Allow(id int64, now time.Time) bool { + if t == nil || id <= 0 || t.minInterval <= 0 { + return true + } + + t.mu.Lock() + defer t.mu.Unlock() + + if last, ok := t.lastByID[id]; ok && now.Sub(last) < t.minInterval { + return false + } + t.lastByID[id] = now + + if len(t.lastByID) > 4096 { + cutoff := now.Add(-4 * t.minInterval) + for accountID, writtenAt := range t.lastByID { + if writtenAt.Before(cutoff) { + delete(t.lastByID, accountID) + } + } + } + + return true +} + +var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval) + // OpenAIGatewayService handles OpenAI API gateway operations type OpenAIGatewayService struct { accountRepo AccountRepository @@ -289,6 +331,7 @@ type OpenAIGatewayService struct { openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time openaiWSRetryMetrics openAIWSRetryMetrics responseHeaderFilter *responseheaders.CompiledHeaderFilter + codexSnapshotThrottle *accountWriteThrottle } // NewOpenAIGatewayService creates a new OpenAIGatewayService @@ -329,17 +372,25 @@ func NewOpenAIGatewayService( nil, "service.openai_gateway", ), - httpUpstream: httpUpstream, - deferredService: deferredService, - openAITokenProvider: openAITokenProvider, - toolCorrector: NewCodexToolCorrector(), - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - responseHeaderFilter: compileResponseHeaderFilter(cfg), + httpUpstream: httpUpstream, + deferredService: deferredService, + openAITokenProvider: openAITokenProvider, + toolCorrector: NewCodexToolCorrector(), + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + responseHeaderFilter: compileResponseHeaderFilter(cfg), + codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), } svc.logOpenAIWSModeBootstrap() return svc } +func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle { + if s != nil && s.codexSnapshotThrottle != nil { + return s.codexSnapshotThrottle + } + return defaultOpenAICodexSnapshotPersistThrottle +} + func (s *OpenAIGatewayService) billingDeps() *billingDeps { return &billingDeps{ accountRepo: s.accountRepo, @@ -4050,11 +4101,12 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc if len(updates) == 0 && resetAt == nil { return } + shouldPersistUpdates := len(updates) > 0 && s.getCodexSnapshotThrottle().Allow(accountID, now) go func() { updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - if len(updates) > 0 { + if shouldPersistUpdates { _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) } if resetAt != nil { diff --git a/backend/internal/service/openai_ws_ratelimit_signal_test.go b/backend/internal/service/openai_ws_ratelimit_signal_test.go index 28cb8e00..f5c79923 100644 --- a/backend/internal/service/openai_ws_ratelimit_signal_test.go +++ b/backend/internal/service/openai_ws_ratelimit_signal_test.go @@ -405,6 +405,40 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesN } } +func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ThrottlesExtraWrites(t *testing.T) { + repo := &openAICodexSnapshotAsyncRepo{ + updateExtraCh: make(chan map[string]any, 2), + rateLimitCh: make(chan time.Time, 2), + } + svc := &OpenAIGatewayService{ + accountRepo: repo, + codexSnapshotThrottle: newAccountWriteThrottle(time.Hour), + } + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: ptrFloat64WS(94), + PrimaryResetAfterSeconds: ptrIntWS(3600), + PrimaryWindowMinutes: ptrIntWS(10080), + SecondaryUsedPercent: ptrFloat64WS(22), + SecondaryResetAfterSeconds: ptrIntWS(1200), + SecondaryWindowMinutes: ptrIntWS(300), + } + + svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot) + svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot) + + select { + case <-repo.updateExtraCh: + case <-time.After(2 * time.Second): + t.Fatal("等待第一次 codex 快照落库超时") + } + + select { + case updates := <-repo.updateExtraCh: + t.Fatalf("unexpected second codex snapshot write: %v", updates) + case <-time.After(200 * time.Millisecond): + } +} + func ptrFloat64WS(v float64) *float64 { return &v } func ptrIntWS(v int) *int { return &v }