diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 210c033f..0c94d50b 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -127,6 +127,7 @@ func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acc return result, nil } func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil } +func (f *fakeConcurrencyCache) CleanupStaleProcessSlots(context.Context, string) error { return nil } func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) { t.Helper() diff --git a/backend/internal/handler/gateway_helper_fastpath_test.go b/backend/internal/handler/gateway_helper_fastpath_test.go index 31d489f0..c7c0fb6c 100644 --- a/backend/internal/handler/gateway_helper_fastpath_test.go +++ b/backend/internal/handler/gateway_helper_fastpath_test.go @@ -89,6 +89,10 @@ func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, a return nil } +func (m *concurrencyCacheMock) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error { + return nil +} + func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) { cache := &concurrencyCacheMock{ acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { diff --git a/backend/internal/handler/gateway_helper_hotpath_test.go b/backend/internal/handler/gateway_helper_hotpath_test.go index f8f7eaca..9e904107 100644 --- a/backend/internal/handler/gateway_helper_hotpath_test.go +++ b/backend/internal/handler/gateway_helper_hotpath_test.go @@ -120,6 +120,10 @@ func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Cont return nil } +func (s *helperConcurrencyCacheStub) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error { + return nil +} + func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index a2552715..8732b2ce 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -147,17 +147,47 @@ var ( return 1 `) - // cleanupExpiredSlotsScript - remove expired slots - // KEYS[1] = concurrency:account:{accountID} - // ARGV[1] = TTL (seconds) + // cleanupExpiredSlotsScript 清理单个账号/用户有序集合中过期槽位 + // KEYS[1] = 有序集合键 + // ARGV[1] = TTL(秒) cleanupExpiredSlotsScript = redis.NewScript(` - local key = KEYS[1] - local ttl = tonumber(ARGV[1]) - local timeResult = redis.call('TIME') - local now = tonumber(timeResult[1]) - local expireBefore = now - ttl - return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore) - `) + local key = KEYS[1] + local ttl = tonumber(ARGV[1]) + local timeResult = redis.call('TIME') + local now = tonumber(timeResult[1]) + local expireBefore = now - ttl + redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore) + if redis.call('ZCARD', key) == 0 then + redis.call('DEL', key) + else + redis.call('EXPIRE', key, ttl) + end + return 1 + `) + + // startupCleanupScript 清理非当前进程前缀的槽位成员。 + // KEYS 是有序集合键列表,ARGV[1] 是当前进程前缀,ARGV[2] 是槽位 TTL。 + // 遍历每个 KEYS[i],移除前缀不匹配的成员,清空后删 key,否则刷新 EXPIRE。 + startupCleanupScript = redis.NewScript(` + local activePrefix = ARGV[1] + local slotTTL = tonumber(ARGV[2]) + local removed = 0 + for i = 1, #KEYS do + local key = KEYS[i] + local members = redis.call('ZRANGE', key, 0, -1) + for _, member in ipairs(members) do + if string.sub(member, 1, string.len(activePrefix)) ~= activePrefix then + removed = removed + redis.call('ZREM', key, member) + end + end + if redis.call('ZCARD', key) == 0 then + redis.call('DEL', key) + else + redis.call('EXPIRE', key, slotTTL) + end + end + return removed + `) ) type concurrencyCache struct { @@ -463,3 +493,72 @@ func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accou _, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result() return err } + +func (c *concurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error { + if activeRequestPrefix == "" { + return nil + } + + // 1. 清理有序集合中非当前进程前缀的成员 + slotPatterns := []string{accountSlotKeyPrefix + "*", userSlotKeyPrefix + "*"} + for _, pattern := range slotPatterns { + if err := c.cleanupSlotsByPattern(ctx, pattern, activeRequestPrefix); err != nil { + return err + } + } + + // 2. 删除所有等待队列计数器(重启后计数器失效) + waitPatterns := []string{accountWaitKeyPrefix + "*", waitQueueKeyPrefix + "*"} + for _, pattern := range waitPatterns { + if err := c.deleteKeysByPattern(ctx, pattern); err != nil { + return err + } + } + + return nil +} + +// cleanupSlotsByPattern 扫描匹配 pattern 的有序集合键,批量调用 Lua 脚本清理非当前进程成员。 +func (c *concurrencyCache) cleanupSlotsByPattern(ctx context.Context, pattern, activePrefix string) error { + const scanCount = 200 + var cursor uint64 + for { + keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result() + if err != nil { + return fmt.Errorf("scan %s: %w", pattern, err) + } + if len(keys) > 0 { + _, err := startupCleanupScript.Run(ctx, c.rdb, keys, activePrefix, c.slotTTLSeconds).Result() + if err != nil { + return fmt.Errorf("cleanup slots %s: %w", pattern, err) + } + } + cursor = nextCursor + if cursor == 0 { + break + } + } + return nil +} + +// deleteKeysByPattern 扫描匹配 pattern 的键并删除。 +func (c *concurrencyCache) deleteKeysByPattern(ctx context.Context, pattern string) error { + const scanCount = 200 + var cursor uint64 + for { + keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result() + if err != nil { + return fmt.Errorf("scan %s: %w", pattern, err) + } + if len(keys) > 0 { + if err := c.rdb.Del(ctx, keys...).Err(); err != nil { + return fmt.Errorf("del %s: %w", pattern, err) + } + } + cursor = nextCursor + if cursor == 0 { + break + } + } + return nil +} diff --git a/backend/internal/repository/concurrency_cache_integration_test.go b/backend/internal/repository/concurrency_cache_integration_test.go index 5983c832..5da94fc2 100644 --- a/backend/internal/repository/concurrency_cache_integration_test.go +++ b/backend/internal/repository/concurrency_cache_integration_test.go @@ -25,6 +25,10 @@ type ConcurrencyCacheSuite struct { cache service.ConcurrencyCache } +func TestConcurrencyCacheSuite(t *testing.T) { + suite.Run(t, new(ConcurrencyCacheSuite)) +} + func (s *ConcurrencyCacheSuite) SetupTest() { s.IntegrationRedisSuite.SetupTest() s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds())) @@ -247,17 +251,41 @@ func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() { require.Equal(s.T(), 1, val, "expected account wait count 1") } -func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() { - accountID := int64(301) - waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) +func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots() { + accountID := int64(901) + userID := int64(902) + accountKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID) + userKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID) + userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) + accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) - require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key") + now := time.Now().Unix() + require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountKey, + redis.Z{Score: float64(now), Member: "oldproc-1"}, + redis.Z{Score: float64(now), Member: "keep-1"}, + ).Err()) + require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userKey, + redis.Z{Score: float64(now), Member: "oldproc-2"}, + redis.Z{Score: float64(now), Member: "keep-2"}, + ).Err()) + require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, time.Minute).Err()) + require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, time.Minute).Err()) - val, err := s.rdb.Get(s.ctx, waitKey).Int() - if !errors.Is(err, redis.Nil) { - require.NoError(s.T(), err, "Get waitKey") - } - require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty") + require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "keep-")) + + accountMembers, err := s.rdb.ZRange(s.ctx, accountKey, 0, -1).Result() + require.NoError(s.T(), err) + require.Equal(s.T(), []string{"keep-1"}, accountMembers) + + userMembers, err := s.rdb.ZRange(s.ctx, userKey, 0, -1).Result() + require.NoError(s.T(), err) + require.Equal(s.T(), []string{"keep-2"}, userMembers) + + _, err = s.rdb.Get(s.ctx, userWaitKey).Result() + require.True(s.T(), errors.Is(err, redis.Nil)) + + _, err = s.rdb.Get(s.ctx, accountWaitKey).Result() + require.True(s.T(), errors.Is(err, redis.Nil)) } func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() { @@ -407,6 +435,53 @@ func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() { require.Equal(s.T(), 2, cur) } -func TestConcurrencyCacheSuite(t *testing.T) { - suite.Run(t, new(ConcurrencyCacheSuite)) +func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_RemovesOldPrefixesAndWaitCounters() { + accountID := int64(901) + userID := int64(902) + accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID) + userSlotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID) + userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) + accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) + + now := float64(time.Now().Unix()) + require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, + redis.Z{Score: now, Member: "oldproc-1"}, + redis.Z{Score: now, Member: "activeproc-1"}, + ).Err()) + require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err()) + require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userSlotKey, + redis.Z{Score: now, Member: "oldproc-2"}, + redis.Z{Score: now, Member: "activeproc-2"}, + ).Err()) + require.NoError(s.T(), s.rdb.Expire(s.ctx, userSlotKey, testSlotTTL).Err()) + require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, testSlotTTL).Err()) + require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, testSlotTTL).Err()) + + require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-")) + + accountMembers, err := s.rdb.ZRange(s.ctx, accountSlotKey, 0, -1).Result() + require.NoError(s.T(), err) + require.Equal(s.T(), []string{"activeproc-1"}, accountMembers) + + userMembers, err := s.rdb.ZRange(s.ctx, userSlotKey, 0, -1).Result() + require.NoError(s.T(), err) + require.Equal(s.T(), []string{"activeproc-2"}, userMembers) + + _, err = s.rdb.Get(s.ctx, userWaitKey).Result() + require.ErrorIs(s.T(), err, redis.Nil) + _, err = s.rdb.Get(s.ctx, accountWaitKey).Result() + require.ErrorIs(s.T(), err, redis.Nil) +} + +func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_DeletesEmptySlotKeys() { + accountID := int64(903) + accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID) + require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, redis.Z{Score: float64(time.Now().Unix()), Member: "oldproc-1"}).Err()) + require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err()) + + require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-")) + + exists, err := s.rdb.Exists(s.ctx, accountSlotKey).Result() + require.NoError(s.T(), err) + require.EqualValues(s.T(), 0, exists) } diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index 4dcf84e0..217b83d6 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -43,6 +43,9 @@ type ConcurrencyCache interface { // 清理过期槽位(后台任务) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error + + // 启动时清理旧进程遗留槽位与等待计数 + CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error } var ( @@ -59,13 +62,22 @@ func initRequestIDPrefix() string { return "r" + strconv.FormatUint(fallback, 36) } -// generateRequestID generates a unique request ID for concurrency slot tracking. -// Format: {process_random_prefix}-{base36_counter} +func RequestIDPrefix() string { + return requestIDPrefix +} + func generateRequestID() string { seq := requestIDCounter.Add(1) return requestIDPrefix + "-" + strconv.FormatUint(seq, 36) } +func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error { + if s == nil || s.cache == nil { + return nil + } + return s.cache.CleanupStaleProcessSlots(ctx, RequestIDPrefix()) +} + const ( // Default extra wait slots beyond concurrency limit defaultExtraWaitSlots = 20 diff --git a/backend/internal/service/concurrency_service_test.go b/backend/internal/service/concurrency_service_test.go index 9ba43d93..078ba0dc 100644 --- a/backend/internal/service/concurrency_service_test.go +++ b/backend/internal/service/concurrency_service_test.go @@ -91,6 +91,32 @@ func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Conte return c.cleanupErr } +func (c *stubConcurrencyCacheForTest) CleanupStaleProcessSlots(_ context.Context, _ string) error { + return c.cleanupErr +} + +type trackingConcurrencyCache struct { + stubConcurrencyCacheForTest + cleanupPrefix string +} + +func (c *trackingConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, prefix string) error { + c.cleanupPrefix = prefix + return c.cleanupErr +} + +func TestCleanupStaleProcessSlots_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + require.NoError(t, svc.CleanupStaleProcessSlots(context.Background())) +} + +func TestCleanupStaleProcessSlots_DelegatesPrefix(t *testing.T) { + cache := &trackingConcurrencyCache{} + svc := NewConcurrencyService(cache) + require.NoError(t, svc.CleanupStaleProcessSlots(context.Background())) + require.Equal(t, RequestIDPrefix(), cache.cleanupPrefix) +} + func TestAcquireAccountSlot_Success(t *testing.T) { cache := &stubConcurrencyCacheForTest{acquireResult: true} svc := NewConcurrencyService(cache) diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 320ceaa7..f947a8ee 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -1986,6 +1986,10 @@ func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, a return nil } +func (m *mockConcurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error { + return nil +} + func (m *mockConcurrencyCache) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) { result := make(map[int64]*UserLoadInfo, len(users)) for _, user := range users { diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index f91fbb88..7457b77e 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -105,6 +105,9 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh // ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker. func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService { svc := NewConcurrencyService(cache) + if err := svc.CleanupStaleProcessSlots(context.Background()); err != nil { + logger.LegacyPrintf("service.concurrency", "Warning: startup cleanup stale process slots failed: %v", err) + } if cfg != nil { svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval) } diff --git a/backend/internal/testutil/stubs.go b/backend/internal/testutil/stubs.go index 217a5f56..bc572e11 100644 --- a/backend/internal/testutil/stubs.go +++ b/backend/internal/testutil/stubs.go @@ -76,6 +76,9 @@ func (c StubConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acco func (c StubConcurrencyCache) CleanupExpiredAccountSlots(_ context.Context, _ int64) error { return nil } +func (c StubConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, _ string) error { + return nil +} // ============================================================ // StubGatewayCache — service.GatewayCache 的空实现