mirror of
https://github.com/Wei-Shaw/sub2api.git
synced 2026-03-30 11:35:30 +00:00
Merge pull request #854 from james-6-23/main
feat(admin): 支持定时测试自动恢复并统一账号恢复入口
This commit is contained in:
@@ -660,6 +660,42 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
||||
// Error already sent via SSE, just log
|
||||
return
|
||||
}
|
||||
|
||||
if h.rateLimitService != nil {
|
||||
if _, err := h.rateLimitService.RecoverAccountAfterSuccessfulTest(c.Request.Context(), accountID); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecoverState handles unified recovery of recoverable account runtime state.
|
||||
// POST /api/v1/admin/accounts/:id/recover-state
|
||||
func (h *AccountHandler) RecoverState(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
if h.rateLimitService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Rate limit service unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.rateLimitService.RecoverAccountState(c.Request.Context(), accountID, service.AccountRecoveryOptions{
|
||||
InvalidateToken: true,
|
||||
}); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)
|
||||
|
||||
@@ -25,6 +25,7 @@ type createScheduledTestPlanRequest struct {
|
||||
CronExpression string `json:"cron_expression" binding:"required"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover *bool `json:"auto_recover"`
|
||||
}
|
||||
|
||||
type updateScheduledTestPlanRequest struct {
|
||||
@@ -32,6 +33,7 @@ type updateScheduledTestPlanRequest struct {
|
||||
CronExpression string `json:"cron_expression"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover *bool `json:"auto_recover"`
|
||||
}
|
||||
|
||||
// ListByAccount GET /admin/accounts/:id/scheduled-test-plans
|
||||
@@ -68,6 +70,9 @@ func (h *ScheduledTestHandler) Create(c *gin.Context) {
|
||||
if req.Enabled != nil {
|
||||
plan.Enabled = *req.Enabled
|
||||
}
|
||||
if req.AutoRecover != nil {
|
||||
plan.AutoRecover = *req.AutoRecover
|
||||
}
|
||||
|
||||
created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan)
|
||||
if err != nil {
|
||||
@@ -109,6 +114,9 @@ func (h *ScheduledTestHandler) Update(c *gin.Context) {
|
||||
if req.MaxResults > 0 {
|
||||
existing.MaxResults = req.MaxResults
|
||||
}
|
||||
if req.AutoRecover != nil {
|
||||
existing.AutoRecover = *req.AutoRecover
|
||||
}
|
||||
|
||||
updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing)
|
||||
if err != nil {
|
||||
|
||||
@@ -659,13 +659,10 @@ func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 清除临时不可调度状态,重置 401 升级链
|
||||
_, _ = r.sql.ExecContext(ctx, `
|
||||
UPDATE accounts
|
||||
SET temp_unschedulable_until = NULL,
|
||||
temp_unschedulable_reason = NULL
|
||||
WHERE id = $1 AND deleted_at IS NULL
|
||||
`, id)
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear error failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -558,6 +558,26 @@ func (s *AccountRepoSuite) TestSetError() {
|
||||
s.Require().Equal("something went wrong", got.ErrorMessage)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestClearError_SyncSchedulerSnapshotOnRecovery() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-clear-err",
|
||||
Status: service.StatusError,
|
||||
ErrorMessage: "temporary error",
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
s.Require().NoError(s.repo.ClearError(s.ctx, account.ID))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(service.StatusActive, got.Status)
|
||||
s.Require().Empty(got.ErrorMessage)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
// --- UpdateSessionWindow ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
|
||||
|
||||
@@ -20,16 +20,16 @@ func NewScheduledTestPlanRepository(db *sql.DB) service.ScheduledTestPlanReposit
|
||||
|
||||
func (r *scheduledTestPlanRepository) Create(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, next_run_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW())
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, auto_recover, next_run_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW())
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE id = $1
|
||||
`, id)
|
||||
return scanPlan(row)
|
||||
@@ -37,7 +37,7 @@ func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*s
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accountID int64) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE account_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`, accountID)
|
||||
@@ -50,7 +50,7 @@ func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accou
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans
|
||||
WHERE enabled = true AND next_run_at <= $1
|
||||
ORDER BY next_run_at ASC
|
||||
@@ -65,10 +65,10 @@ func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time
|
||||
func (r *scheduledTestPlanRepository) Update(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
UPDATE scheduled_test_plans
|
||||
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, next_run_at = $6, updated_at = NOW()
|
||||
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, auto_recover = $6, next_run_at = $7, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ type scannable interface {
|
||||
func scanPlan(row scannable) (*service.ScheduledTestPlan, error) {
|
||||
p := &service.ScheduledTestPlan{}
|
||||
if err := row.Scan(
|
||||
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults,
|
||||
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults, &p.AutoRecover,
|
||||
&p.LastRunAt, &p.NextRunAt, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -244,6 +244,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
||||
accounts.POST("/:id/test", h.Admin.Account.Test)
|
||||
accounts.POST("/:id/recover-state", h.Admin.Account.RecoverState)
|
||||
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
|
||||
accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier)
|
||||
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
|
||||
|
||||
@@ -1723,16 +1723,10 @@ func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if err := s.accountRepo.ClearError(ctx, id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
account.Status = StatusActive
|
||||
account.ErrorMessage = ""
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account, nil
|
||||
return s.accountRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
|
||||
|
||||
@@ -28,6 +28,17 @@ type RateLimitService struct {
|
||||
usageCache map[int64]*geminiUsageCacheEntry
|
||||
}
|
||||
|
||||
// SuccessfulTestRecoveryResult 表示测试成功后恢复了哪些运行时状态。
|
||||
type SuccessfulTestRecoveryResult struct {
|
||||
ClearedError bool
|
||||
ClearedRateLimit bool
|
||||
}
|
||||
|
||||
// AccountRecoveryOptions 控制账号恢复时的附加行为。
|
||||
type AccountRecoveryOptions struct {
|
||||
InvalidateToken bool
|
||||
}
|
||||
|
||||
type geminiUsageCacheEntry struct {
|
||||
windowStart time.Time
|
||||
cachedAt time.Time
|
||||
@@ -1040,6 +1051,42 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecoverAccountState 按需恢复账号的可恢复运行时状态。
|
||||
func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID int64, options AccountRecoveryOptions) (*SuccessfulTestRecoveryResult, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &SuccessfulTestRecoveryResult{}
|
||||
if account.Status == StatusError {
|
||||
if err := s.accountRepo.ClearError(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.ClearedError = true
|
||||
if options.InvalidateToken && s.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := s.tokenCacheInvalidator.InvalidateToken(ctx, account); invalidateErr != nil {
|
||||
slog.Warn("recover_account_state_invalidate_token_failed", "account_id", accountID, "error", invalidateErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasRecoverableRuntimeState(account) {
|
||||
if err := s.ClearRateLimit(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.ClearedRateLimit = true
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// RecoverAccountAfterSuccessfulTest 将一次成功测试视为正常请求,
|
||||
// 按需恢复 error / rate-limit / overload / temp-unsched / model-rate-limit 等运行时状态。
|
||||
func (s *RateLimitService) RecoverAccountAfterSuccessfulTest(ctx context.Context, accountID int64) (*SuccessfulTestRecoveryResult, error) {
|
||||
return s.RecoverAccountState(ctx, accountID, AccountRecoveryOptions{})
|
||||
}
|
||||
|
||||
func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error {
|
||||
if err := s.accountRepo.ClearTempUnschedulable(ctx, accountID); err != nil {
|
||||
return err
|
||||
@@ -1056,6 +1103,36 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasRecoverableRuntimeState(account *Account) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if account.RateLimitedAt != nil || account.RateLimitResetAt != nil || account.OverloadUntil != nil || account.TempUnschedulableUntil != nil {
|
||||
return true
|
||||
}
|
||||
if len(account.Extra) == 0 {
|
||||
return false
|
||||
}
|
||||
return hasNonEmptyMapValue(account.Extra, "model_rate_limits") || hasNonEmptyMapValue(account.Extra, "antigravity_quota_scopes")
|
||||
}
|
||||
|
||||
func hasNonEmptyMapValue(extra map[string]any, key string) bool {
|
||||
raw, ok := extra[key]
|
||||
if !ok || raw == nil {
|
||||
return false
|
||||
}
|
||||
switch typed := raw.(type) {
|
||||
case map[string]any:
|
||||
return len(typed) > 0
|
||||
case map[string]string:
|
||||
return len(typed) > 0
|
||||
case []any:
|
||||
return len(typed) > 0
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID int64) (*TempUnschedState, error) {
|
||||
now := time.Now().Unix()
|
||||
if s.tempUnschedCache != nil {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -13,16 +14,34 @@ import (
|
||||
|
||||
type rateLimitClearRepoStub struct {
|
||||
mockAccountRepoForGemini
|
||||
getByIDAccount *Account
|
||||
getByIDErr error
|
||||
getByIDCalls int
|
||||
clearErrorCalls int
|
||||
clearRateLimitCalls int
|
||||
clearAntigravityCalls int
|
||||
clearModelRateLimitCalls int
|
||||
clearTempUnschedCalls int
|
||||
clearErrorErr error
|
||||
clearRateLimitErr error
|
||||
clearAntigravityErr error
|
||||
clearModelRateLimitErr error
|
||||
clearTempUnschedulableErr error
|
||||
}
|
||||
|
||||
func (r *rateLimitClearRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
r.getByIDCalls++
|
||||
if r.getByIDErr != nil {
|
||||
return nil, r.getByIDErr
|
||||
}
|
||||
return r.getByIDAccount, nil
|
||||
}
|
||||
|
||||
func (r *rateLimitClearRepoStub) ClearError(ctx context.Context, id int64) error {
|
||||
r.clearErrorCalls++
|
||||
return r.clearErrorErr
|
||||
}
|
||||
|
||||
func (r *rateLimitClearRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
r.clearRateLimitCalls++
|
||||
return r.clearRateLimitErr
|
||||
@@ -48,6 +67,11 @@ type tempUnschedCacheRecorder struct {
|
||||
deleteErr error
|
||||
}
|
||||
|
||||
type recoverTokenInvalidatorStub struct {
|
||||
accounts []*Account
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *tempUnschedCacheRecorder) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error {
|
||||
return nil
|
||||
}
|
||||
@@ -61,6 +85,11 @@ func (c *tempUnschedCacheRecorder) DeleteTempUnsched(ctx context.Context, accoun
|
||||
return c.deleteErr
|
||||
}
|
||||
|
||||
func (s *recoverTokenInvalidatorStub) InvalidateToken(ctx context.Context, account *Account) error {
|
||||
s.accounts = append(s.accounts, account)
|
||||
return s.err
|
||||
}
|
||||
|
||||
func TestRateLimitService_ClearRateLimit_AlsoClearsTempUnschedulable(t *testing.T) {
|
||||
repo := &rateLimitClearRepoStub{}
|
||||
cache := &tempUnschedCacheRecorder{}
|
||||
@@ -170,3 +199,108 @@ func TestRateLimitService_ClearRateLimit_WithoutTempUnschedCache(t *testing.T) {
|
||||
require.Equal(t, 1, repo.clearModelRateLimitCalls)
|
||||
require.Equal(t, 1, repo.clearTempUnschedCalls)
|
||||
}
|
||||
|
||||
func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLimitRelatedState(t *testing.T) {
|
||||
now := time.Now()
|
||||
repo := &rateLimitClearRepoStub{
|
||||
getByIDAccount: &Account{
|
||||
ID: 42,
|
||||
Status: StatusError,
|
||||
RateLimitedAt: &now,
|
||||
TempUnschedulableUntil: &now,
|
||||
Extra: map[string]any{
|
||||
"model_rate_limits": map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limit_reset_at": now.Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
"antigravity_quota_scopes": map[string]any{"gemini": true},
|
||||
},
|
||||
},
|
||||
}
|
||||
cache := &tempUnschedCacheRecorder{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache)
|
||||
|
||||
result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 42)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.ClearedError)
|
||||
require.True(t, result.ClearedRateLimit)
|
||||
|
||||
require.Equal(t, 1, repo.getByIDCalls)
|
||||
require.Equal(t, 1, repo.clearErrorCalls)
|
||||
require.Equal(t, 1, repo.clearRateLimitCalls)
|
||||
require.Equal(t, 1, repo.clearAntigravityCalls)
|
||||
require.Equal(t, 1, repo.clearModelRateLimitCalls)
|
||||
require.Equal(t, 1, repo.clearTempUnschedCalls)
|
||||
require.Equal(t, []int64{42}, cache.deletedIDs)
|
||||
}
|
||||
|
||||
func TestRateLimitService_RecoverAccountAfterSuccessfulTest_NoRecoverableStateIsNoop(t *testing.T) {
|
||||
repo := &rateLimitClearRepoStub{
|
||||
getByIDAccount: &Account{
|
||||
ID: 7,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Extra: map[string]any{},
|
||||
},
|
||||
}
|
||||
cache := &tempUnschedCacheRecorder{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache)
|
||||
|
||||
result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 7)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.False(t, result.ClearedError)
|
||||
require.False(t, result.ClearedRateLimit)
|
||||
|
||||
require.Equal(t, 1, repo.getByIDCalls)
|
||||
require.Equal(t, 0, repo.clearErrorCalls)
|
||||
require.Equal(t, 0, repo.clearRateLimitCalls)
|
||||
require.Equal(t, 0, repo.clearAntigravityCalls)
|
||||
require.Equal(t, 0, repo.clearModelRateLimitCalls)
|
||||
require.Equal(t, 0, repo.clearTempUnschedCalls)
|
||||
require.Empty(t, cache.deletedIDs)
|
||||
}
|
||||
|
||||
func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearErrorFailed(t *testing.T) {
|
||||
repo := &rateLimitClearRepoStub{
|
||||
getByIDAccount: &Account{
|
||||
ID: 9,
|
||||
Status: StatusError,
|
||||
},
|
||||
clearErrorErr: errors.New("clear error failed"),
|
||||
}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 9)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, 1, repo.getByIDCalls)
|
||||
require.Equal(t, 1, repo.clearErrorCalls)
|
||||
require.Equal(t, 0, repo.clearRateLimitCalls)
|
||||
}
|
||||
|
||||
func TestRateLimitService_RecoverAccountState_InvalidatesOAuthTokenOnErrorRecovery(t *testing.T) {
|
||||
repo := &rateLimitClearRepoStub{
|
||||
getByIDAccount: &Account{
|
||||
ID: 21,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusError,
|
||||
},
|
||||
}
|
||||
invalidator := &recoverTokenInvalidatorStub{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc.SetTokenCacheInvalidator(invalidator)
|
||||
|
||||
result, err := svc.RecoverAccountState(context.Background(), 21, AccountRecoveryOptions{
|
||||
InvalidateToken: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.ClearedError)
|
||||
require.False(t, result.ClearedRateLimit)
|
||||
require.Equal(t, 1, repo.clearErrorCalls)
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
require.Equal(t, int64(21), invalidator.accounts[0].ID)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ type ScheduledTestPlan struct {
|
||||
CronExpression string `json:"cron_expression"`
|
||||
Enabled bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover bool `json:"auto_recover"`
|
||||
LastRunAt *time.Time `json:"last_run_at"`
|
||||
NextRunAt *time.Time `json:"next_run_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
@@ -17,6 +17,7 @@ type ScheduledTestRunnerService struct {
|
||||
planRepo ScheduledTestPlanRepository
|
||||
scheduledSvc *ScheduledTestService
|
||||
accountTestSvc *AccountTestService
|
||||
rateLimitSvc *RateLimitService
|
||||
cfg *config.Config
|
||||
|
||||
cron *cron.Cron
|
||||
@@ -29,12 +30,14 @@ func NewScheduledTestRunnerService(
|
||||
planRepo ScheduledTestPlanRepository,
|
||||
scheduledSvc *ScheduledTestService,
|
||||
accountTestSvc *AccountTestService,
|
||||
rateLimitSvc *RateLimitService,
|
||||
cfg *config.Config,
|
||||
) *ScheduledTestRunnerService {
|
||||
return &ScheduledTestRunnerService{
|
||||
planRepo: planRepo,
|
||||
scheduledSvc: scheduledSvc,
|
||||
accountTestSvc: accountTestSvc,
|
||||
rateLimitSvc: rateLimitSvc,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
@@ -127,6 +130,11 @@ func (s *ScheduledTestRunnerService) runOnePlan(ctx context.Context, plan *Sched
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d SaveResult error: %v", plan.ID, err)
|
||||
}
|
||||
|
||||
// Auto-recover account if test succeeded and auto_recover is enabled.
|
||||
if result.Status == "success" && plan.AutoRecover {
|
||||
s.tryRecoverAccount(ctx, plan.AccountID, plan.ID)
|
||||
}
|
||||
|
||||
nextRun, err := computeNextRun(plan.CronExpression, time.Now())
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d computeNextRun error: %v", plan.ID, err)
|
||||
@@ -137,3 +145,26 @@ func (s *ScheduledTestRunnerService) runOnePlan(ctx context.Context, plan *Sched
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d UpdateAfterRun error: %v", plan.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// tryRecoverAccount attempts to recover an account from recoverable runtime state.
|
||||
func (s *ScheduledTestRunnerService) tryRecoverAccount(ctx context.Context, accountID int64, planID int64) {
|
||||
if s.rateLimitSvc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
recovery, err := s.rateLimitSvc.RecoverAccountAfterSuccessfulTest(ctx, accountID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d auto-recover failed: %v", planID, err)
|
||||
return
|
||||
}
|
||||
if recovery == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if recovery.ClearedError {
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d auto-recover: account=%d recovered from error status", planID, accountID)
|
||||
}
|
||||
if recovery.ClearedRateLimit {
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d auto-recover: account=%d cleared rate-limit/runtime state", planID, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -287,9 +287,10 @@ func ProvideScheduledTestRunnerService(
|
||||
planRepo ScheduledTestPlanRepository,
|
||||
scheduledSvc *ScheduledTestService,
|
||||
accountTestSvc *AccountTestService,
|
||||
rateLimitSvc *RateLimitService,
|
||||
cfg *config.Config,
|
||||
) *ScheduledTestRunnerService {
|
||||
svc := NewScheduledTestRunnerService(planRepo, scheduledSvc, accountTestSvc, cfg)
|
||||
svc := NewScheduledTestRunnerService(planRepo, scheduledSvc, accountTestSvc, rateLimitSvc, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user