Merge pull request #894 from touwaeriol/pr/startup-concurrency-cleanup

feat: cleanup stale concurrency slots on startup
This commit is contained in:
Wesley Liddick
2026-03-10 09:08:33 +08:00
committed by GitHub
10 changed files with 254 additions and 23 deletions

View File

@@ -127,6 +127,7 @@ func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acc
return result, nil
}
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
func (f *fakeConcurrencyCache) CleanupStaleProcessSlots(context.Context, string) error { return nil }
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
t.Helper()

View File

@@ -89,6 +89,10 @@ func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, a
return nil
}
func (m *concurrencyCacheMock) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
return nil
}
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
cache := &concurrencyCacheMock{
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {

View File

@@ -120,6 +120,10 @@ func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Cont
return nil
}
func (s *helperConcurrencyCacheStub) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
return nil
}
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()

View File

@@ -147,17 +147,47 @@ var (
return 1
`)
// cleanupExpiredSlotsScript - remove expired slots
// KEYS[1] = concurrency:account:{accountID}
// ARGV[1] = TTL (seconds)
// cleanupExpiredSlotsScript 清理单个账号/用户有序集合中过期槽位
// KEYS[1] = 有序集合键
// ARGV[1] = TTL(秒)
cleanupExpiredSlotsScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
`)
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
if redis.call('ZCARD', key) == 0 then
redis.call('DEL', key)
else
redis.call('EXPIRE', key, ttl)
end
return 1
`)
// startupCleanupScript 清理非当前进程前缀的槽位成员。
// KEYS 是有序集合键列表ARGV[1] 是当前进程前缀ARGV[2] 是槽位 TTL。
// 遍历每个 KEYS[i],移除前缀不匹配的成员,清空后删 key否则刷新 EXPIRE。
startupCleanupScript = redis.NewScript(`
local activePrefix = ARGV[1]
local slotTTL = tonumber(ARGV[2])
local removed = 0
for i = 1, #KEYS do
local key = KEYS[i]
local members = redis.call('ZRANGE', key, 0, -1)
for _, member in ipairs(members) do
if string.sub(member, 1, string.len(activePrefix)) ~= activePrefix then
removed = removed + redis.call('ZREM', key, member)
end
end
if redis.call('ZCARD', key) == 0 then
redis.call('DEL', key)
else
redis.call('EXPIRE', key, slotTTL)
end
end
return removed
`)
)
type concurrencyCache struct {
@@ -463,3 +493,72 @@ func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accou
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
return err
}
func (c *concurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
if activeRequestPrefix == "" {
return nil
}
// 1. 清理有序集合中非当前进程前缀的成员
slotPatterns := []string{accountSlotKeyPrefix + "*", userSlotKeyPrefix + "*"}
for _, pattern := range slotPatterns {
if err := c.cleanupSlotsByPattern(ctx, pattern, activeRequestPrefix); err != nil {
return err
}
}
// 2. 删除所有等待队列计数器(重启后计数器失效)
waitPatterns := []string{accountWaitKeyPrefix + "*", waitQueueKeyPrefix + "*"}
for _, pattern := range waitPatterns {
if err := c.deleteKeysByPattern(ctx, pattern); err != nil {
return err
}
}
return nil
}
// cleanupSlotsByPattern 扫描匹配 pattern 的有序集合键,批量调用 Lua 脚本清理非当前进程成员。
func (c *concurrencyCache) cleanupSlotsByPattern(ctx context.Context, pattern, activePrefix string) error {
const scanCount = 200
var cursor uint64
for {
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
if err != nil {
return fmt.Errorf("scan %s: %w", pattern, err)
}
if len(keys) > 0 {
_, err := startupCleanupScript.Run(ctx, c.rdb, keys, activePrefix, c.slotTTLSeconds).Result()
if err != nil {
return fmt.Errorf("cleanup slots %s: %w", pattern, err)
}
}
cursor = nextCursor
if cursor == 0 {
break
}
}
return nil
}
// deleteKeysByPattern 扫描匹配 pattern 的键并删除。
func (c *concurrencyCache) deleteKeysByPattern(ctx context.Context, pattern string) error {
const scanCount = 200
var cursor uint64
for {
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
if err != nil {
return fmt.Errorf("scan %s: %w", pattern, err)
}
if len(keys) > 0 {
if err := c.rdb.Del(ctx, keys...).Err(); err != nil {
return fmt.Errorf("del %s: %w", pattern, err)
}
}
cursor = nextCursor
if cursor == 0 {
break
}
}
return nil
}

View File

@@ -25,6 +25,10 @@ type ConcurrencyCacheSuite struct {
cache service.ConcurrencyCache
}
func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite))
}
func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
@@ -247,17 +251,41 @@ func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
require.Equal(s.T(), 1, val, "expected account wait count 1")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
accountID := int64(301)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots() {
accountID := int64(901)
userID := int64(902)
accountKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
userKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
now := time.Now().Unix()
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountKey,
redis.Z{Score: float64(now), Member: "oldproc-1"},
redis.Z{Score: float64(now), Member: "keep-1"},
).Err())
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userKey,
redis.Z{Score: float64(now), Member: "oldproc-2"},
redis.Z{Score: float64(now), Member: "keep-2"},
).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, time.Minute).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, time.Minute).Err())
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "keep-"))
accountMembers, err := s.rdb.ZRange(s.ctx, accountKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Equal(s.T(), []string{"keep-1"}, accountMembers)
userMembers, err := s.rdb.ZRange(s.ctx, userKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Equal(s.T(), []string{"keep-2"}, userMembers)
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
require.True(s.T(), errors.Is(err, redis.Nil))
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
require.True(s.T(), errors.Is(err, redis.Nil))
}
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
@@ -407,6 +435,53 @@ func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
require.Equal(s.T(), 2, cur)
}
func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite))
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_RemovesOldPrefixesAndWaitCounters() {
accountID := int64(901)
userID := int64(902)
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
userSlotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
now := float64(time.Now().Unix())
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey,
redis.Z{Score: now, Member: "oldproc-1"},
redis.Z{Score: now, Member: "activeproc-1"},
).Err())
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userSlotKey,
redis.Z{Score: now, Member: "oldproc-2"},
redis.Z{Score: now, Member: "activeproc-2"},
).Err())
require.NoError(s.T(), s.rdb.Expire(s.ctx, userSlotKey, testSlotTTL).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, testSlotTTL).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, testSlotTTL).Err())
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
accountMembers, err := s.rdb.ZRange(s.ctx, accountSlotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Equal(s.T(), []string{"activeproc-1"}, accountMembers)
userMembers, err := s.rdb.ZRange(s.ctx, userSlotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Equal(s.T(), []string{"activeproc-2"}, userMembers)
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
require.ErrorIs(s.T(), err, redis.Nil)
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
require.ErrorIs(s.T(), err, redis.Nil)
}
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_DeletesEmptySlotKeys() {
accountID := int64(903)
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, redis.Z{Score: float64(time.Now().Unix()), Member: "oldproc-1"}).Err())
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
exists, err := s.rdb.Exists(s.ctx, accountSlotKey).Result()
require.NoError(s.T(), err)
require.EqualValues(s.T(), 0, exists)
}

View File

@@ -43,6 +43,9 @@ type ConcurrencyCache interface {
// 清理过期槽位(后台任务)
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
// 启动时清理旧进程遗留槽位与等待计数
CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error
}
var (
@@ -59,13 +62,22 @@ func initRequestIDPrefix() string {
return "r" + strconv.FormatUint(fallback, 36)
}
// generateRequestID generates a unique request ID for concurrency slot tracking.
// Format: {process_random_prefix}-{base36_counter}
func RequestIDPrefix() string {
return requestIDPrefix
}
func generateRequestID() string {
seq := requestIDCounter.Add(1)
return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
}
func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error {
if s == nil || s.cache == nil {
return nil
}
return s.cache.CleanupStaleProcessSlots(ctx, RequestIDPrefix())
}
const (
// Default extra wait slots beyond concurrency limit
defaultExtraWaitSlots = 20

View File

@@ -91,6 +91,32 @@ func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Conte
return c.cleanupErr
}
func (c *stubConcurrencyCacheForTest) CleanupStaleProcessSlots(_ context.Context, _ string) error {
return c.cleanupErr
}
type trackingConcurrencyCache struct {
stubConcurrencyCacheForTest
cleanupPrefix string
}
func (c *trackingConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, prefix string) error {
c.cleanupPrefix = prefix
return c.cleanupErr
}
func TestCleanupStaleProcessSlots_NilCache(t *testing.T) {
svc := &ConcurrencyService{cache: nil}
require.NoError(t, svc.CleanupStaleProcessSlots(context.Background()))
}
func TestCleanupStaleProcessSlots_DelegatesPrefix(t *testing.T) {
cache := &trackingConcurrencyCache{}
svc := NewConcurrencyService(cache)
require.NoError(t, svc.CleanupStaleProcessSlots(context.Background()))
require.Equal(t, RequestIDPrefix(), cache.cleanupPrefix)
}
func TestAcquireAccountSlot_Success(t *testing.T) {
cache := &stubConcurrencyCacheForTest{acquireResult: true}
svc := NewConcurrencyService(cache)

View File

@@ -1986,6 +1986,10 @@ func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, a
return nil
}
func (m *mockConcurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
return nil
}
func (m *mockConcurrencyCache) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
result := make(map[int64]*UserLoadInfo, len(users))
for _, user := range users {

View File

@@ -105,6 +105,9 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh
// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker.
func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
svc := NewConcurrencyService(cache)
if err := svc.CleanupStaleProcessSlots(context.Background()); err != nil {
logger.LegacyPrintf("service.concurrency", "Warning: startup cleanup stale process slots failed: %v", err)
}
if cfg != nil {
svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval)
}

View File

@@ -76,6 +76,9 @@ func (c StubConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acco
func (c StubConcurrencyCache) CleanupExpiredAccountSlots(_ context.Context, _ int64) error {
return nil
}
func (c StubConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, _ string) error {
return nil
}
// ============================================================
// StubGatewayCache — service.GatewayCache 的空实现