From adac9cb67f0a724e3394e3ebc3b2b6403190d97a Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Wed, 18 Feb 2026 04:03:34 +0000 Subject: [PATCH] refactor: dedupe gateway and scheduler test scaffolding --- src/cron/service.issue-regressions.test.ts | 104 +++-- src/gateway/channel-health-monitor.test.ts | 374 +++++++----------- ...ver.chat.gateway-server-chat-b.e2e.test.ts | 137 +++---- .../server/ws-connection/message-handler.ts | 136 +++---- src/infra/update-runner.test.ts | 127 +++--- 5 files changed, 334 insertions(+), 544 deletions(-) diff --git a/src/cron/service.issue-regressions.test.ts b/src/cron/service.issue-regressions.test.ts index e4ebc413bd0..a28682faeac 100644 --- a/src/cron/service.issue-regressions.test.ts +++ b/src/cron/service.issue-regressions.test.ts @@ -18,6 +18,7 @@ const noopLogger = { trace: vi.fn(), }; const TOP_OF_HOUR_STAGGER_MS = 5 * 60 * 1_000; +type CronServiceOptions = ConstructorParameters[0]; function topOfHourOffsetMs(jobId: string) { const digest = crypto.createHash("sha256").update(jobId).digest(); @@ -68,6 +69,40 @@ function createDueIsolatedJob(params: { }; } +function createDefaultIsolatedRunner(): CronServiceOptions["runIsolatedAgentJob"] { + return vi.fn().mockResolvedValue({ + status: "ok", + summary: "ok", + }) as CronServiceOptions["runIsolatedAgentJob"]; +} + +async function startCronForStore(params: { + storePath: string; + cronEnabled?: boolean; + enqueueSystemEvent?: CronServiceOptions["enqueueSystemEvent"]; + requestHeartbeatNow?: CronServiceOptions["requestHeartbeatNow"]; + runIsolatedAgentJob?: CronServiceOptions["runIsolatedAgentJob"]; + onEvent?: CronServiceOptions["onEvent"]; +}) { + const enqueueSystemEvent = + params.enqueueSystemEvent ?? (vi.fn() as unknown as CronServiceOptions["enqueueSystemEvent"]); + const requestHeartbeatNow = + params.requestHeartbeatNow ?? (vi.fn() as unknown as CronServiceOptions["requestHeartbeatNow"]); + const runIsolatedAgentJob = params.runIsolatedAgentJob ?? createDefaultIsolatedRunner(); + + const cron = new CronService({ + cronEnabled: params.cronEnabled ?? true, + storePath: params.storePath, + log: noopLogger, + enqueueSystemEvent, + requestHeartbeatNow, + runIsolatedAgentJob, + ...(params.onEvent ? { onEvent: params.onEvent } : {}), + }); + await cron.start(); + return cron; +} + describe("Cron issue regressions", () => { beforeAll(async () => { fixtureRoot = await fs.mkdtemp(path.join(os.tmpdir(), "cron-issues-")); @@ -90,15 +125,10 @@ describe("Cron issue regressions", () => { it("covers schedule updates, force runs, isolated wake scheduling, and payload patching", async () => { const store = await makeStorePath(); const enqueueSystemEvent = vi.fn(); - const cron = new CronService({ - cronEnabled: true, + const cron = await startCronForStore({ storePath: store.storePath, - log: noopLogger, enqueueSystemEvent, - requestHeartbeatNow: vi.fn(), - runIsolatedAgentJob: vi.fn().mockResolvedValue({ status: "ok", summary: "ok" }), }); - await cron.start(); const created = await cron.add({ name: "hourly", @@ -171,15 +201,7 @@ describe("Cron issue regressions", () => { it("repairs missing nextRunAtMs on non-schedule updates without touching other jobs", async () => { const store = await makeStorePath(); - const cron = new CronService({ - cronEnabled: true, - storePath: store.storePath, - log: noopLogger, - enqueueSystemEvent: vi.fn(), - requestHeartbeatNow: vi.fn(), - runIsolatedAgentJob: vi.fn().mockResolvedValue({ status: "ok", summary: "ok" }), - }); - await cron.start(); + const cron = await startCronForStore({ storePath: store.storePath }); const created = await cron.add({ name: "repair-target", @@ -205,15 +227,7 @@ describe("Cron issue regressions", () => { const store = await makeStorePath(); const now = Date.parse("2026-02-06T10:05:00.000Z"); vi.setSystemTime(now); - const cron = new CronService({ - cronEnabled: false, - storePath: store.storePath, - log: noopLogger, - enqueueSystemEvent: vi.fn(), - requestHeartbeatNow: vi.fn(), - runIsolatedAgentJob: vi.fn().mockResolvedValue({ status: "ok", summary: "ok" }), - }); - await cron.start(); + const cron = await startCronForStore({ storePath: store.storePath, cronEnabled: false }); const dueJob = await cron.add({ name: "due-preserved", @@ -279,15 +293,7 @@ describe("Cron issue regressions", () => { "utf-8", ); - const cron = new CronService({ - cronEnabled: true, - storePath: store.storePath, - log: noopLogger, - enqueueSystemEvent: vi.fn(), - requestHeartbeatNow: vi.fn(), - runIsolatedAgentJob: vi.fn().mockResolvedValue({ status: "ok", summary: "ok" }), - }); - await cron.start(); + const cron = await startCronForStore({ storePath: store.storePath }); const listed = await cron.list(); expect(listed.some((job) => job.id === "missing-enabled-update")).toBe(true); @@ -332,15 +338,11 @@ describe("Cron issue regressions", () => { ); const enqueueSystemEvent = vi.fn(); - const cron = new CronService({ - cronEnabled: false, + const cron = await startCronForStore({ storePath: store.storePath, - log: noopLogger, + cronEnabled: false, enqueueSystemEvent, - requestHeartbeatNow: vi.fn(), - runIsolatedAgentJob: vi.fn().mockResolvedValue({ status: "ok", summary: "ok" }), }); - await cron.start(); const result = await cron.run("missing-enabled-due", "due"); expect(result).toEqual({ ok: true, ran: true }); @@ -355,15 +357,7 @@ describe("Cron issue regressions", () => { it("caps timer delay to 60s for far-future schedules", async () => { const timeoutSpy = vi.spyOn(globalThis, "setTimeout"); const store = await makeStorePath(); - const cron = new CronService({ - cronEnabled: true, - storePath: store.storePath, - log: noopLogger, - enqueueSystemEvent: vi.fn(), - requestHeartbeatNow: vi.fn(), - runIsolatedAgentJob: vi.fn().mockResolvedValue({ status: "ok", summary: "ok" }), - }); - await cron.start(); + const cron = await startCronForStore({ storePath: store.storePath }); const callsBeforeAdd = timeoutSpy.mock.calls.length; await cron.add({ @@ -436,12 +430,8 @@ describe("Cron issue regressions", () => { const finished = createDeferred(); let targetJobId = ""; - const cron = new CronService({ - cronEnabled: true, + const cron = await startCronForStore({ storePath: store.storePath, - log: noopLogger, - enqueueSystemEvent: vi.fn(), - requestHeartbeatNow: vi.fn(), runIsolatedAgentJob, onEvent: (evt: CronEvent) => { if (evt.jobId !== targetJobId) { @@ -454,7 +444,6 @@ describe("Cron issue regressions", () => { } }, }); - await cron.start(); const runAt = Date.now() + 1; const job = await cron.add({ @@ -525,16 +514,11 @@ describe("Cron issue regressions", () => { "utf-8", ); const enqueueSystemEvent = vi.fn(); - const cron = new CronService({ - cronEnabled: true, + const cron = await startCronForStore({ storePath: store.storePath, - log: noopLogger, enqueueSystemEvent, - requestHeartbeatNow: vi.fn(), runIsolatedAgentJob: vi.fn().mockResolvedValue({ status: "ok" }), }); - - await cron.start(); expect(enqueueSystemEvent).not.toHaveBeenCalled(); cron.stop(); } diff --git a/src/gateway/channel-health-monitor.test.ts b/src/gateway/channel-health-monitor.test.ts index 4f3992dabb7..4726405690f 100644 --- a/src/gateway/channel-health-monitor.test.ts +++ b/src/gateway/channel-health-monitor.test.ts @@ -36,6 +36,50 @@ function snapshotWith( return { channels, channelAccounts }; } +const DEFAULT_CHECK_INTERVAL_MS = 5_000; + +function createSnapshotManager( + accounts: Record>>, + overrides?: Partial, +): ChannelManager { + return createMockChannelManager({ + getRuntimeSnapshot: vi.fn(() => snapshotWith(accounts)), + ...overrides, + }); +} + +function startDefaultMonitor( + manager: ChannelManager, + overrides: Partial[0], "channelManager">> = {}, +) { + return startChannelHealthMonitor({ + channelManager: manager, + checkIntervalMs: DEFAULT_CHECK_INTERVAL_MS, + startupGraceMs: 0, + ...overrides, + }); +} + +async function startAndRunCheck( + manager: ChannelManager, + overrides: Partial[0], "channelManager">> = {}, +) { + const monitor = startDefaultMonitor(manager, overrides); + const startupGraceMs = overrides.startupGraceMs ?? 0; + const checkIntervalMs = overrides.checkIntervalMs ?? DEFAULT_CHECK_INTERVAL_MS; + await vi.advanceTimersByTimeAsync(startupGraceMs + checkIntervalMs + 500); + return monitor; +} + +function managedStoppedAccount(lastError: string): Partial { + return { + running: false, + enabled: true, + configured: true, + lastError, + }; +} + describe("channel-health-monitor", () => { beforeEach(() => { vi.useFakeTimers(); @@ -46,11 +90,7 @@ describe("channel-health-monitor", () => { it("does not run before the grace period", async () => { const manager = createMockChannelManager(); - const monitor = startChannelHealthMonitor({ - channelManager: manager, - checkIntervalMs: 5_000, - startupGraceMs: 60_000, - }); + const monitor = startDefaultMonitor(manager, { startupGraceMs: 60_000 }); await vi.advanceTimersByTimeAsync(10_000); expect(manager.getRuntimeSnapshot).not.toHaveBeenCalled(); monitor.stop(); @@ -58,125 +98,77 @@ describe("channel-health-monitor", () => { it("runs health check after grace period", async () => { const manager = createMockChannelManager(); - const monitor = startChannelHealthMonitor({ - channelManager: manager, - checkIntervalMs: 5_000, - startupGraceMs: 1_000, - }); - await vi.advanceTimersByTimeAsync(6_500); + const monitor = await startAndRunCheck(manager, { startupGraceMs: 1_000 }); expect(manager.getRuntimeSnapshot).toHaveBeenCalled(); monitor.stop(); }); it("skips healthy channels (running + connected)", async () => { - const manager = createMockChannelManager({ - getRuntimeSnapshot: vi.fn(() => - snapshotWith({ - discord: { - default: { running: true, connected: true, enabled: true, configured: true }, - }, - }), - ), + const manager = createSnapshotManager({ + discord: { + default: { running: true, connected: true, enabled: true, configured: true }, + }, }); - const monitor = startChannelHealthMonitor({ - channelManager: manager, - checkIntervalMs: 5_000, - startupGraceMs: 0, - }); - await vi.advanceTimersByTimeAsync(5_500); + const monitor = await startAndRunCheck(manager); expect(manager.stopChannel).not.toHaveBeenCalled(); expect(manager.startChannel).not.toHaveBeenCalled(); monitor.stop(); }); it("skips disabled channels", async () => { - const manager = createMockChannelManager({ - getRuntimeSnapshot: vi.fn(() => - snapshotWith({ - imessage: { - default: { - running: false, - enabled: false, - configured: true, - lastError: "disabled", - }, - }, - }), - ), + const manager = createSnapshotManager({ + imessage: { + default: { + running: false, + enabled: false, + configured: true, + lastError: "disabled", + }, + }, }); - const monitor = startChannelHealthMonitor({ - channelManager: manager, - checkIntervalMs: 5_000, - startupGraceMs: 0, - }); - await vi.advanceTimersByTimeAsync(5_500); + const monitor = await startAndRunCheck(manager); expect(manager.startChannel).not.toHaveBeenCalled(); monitor.stop(); }); it("skips unconfigured channels", async () => { - const manager = createMockChannelManager({ - getRuntimeSnapshot: vi.fn(() => - snapshotWith({ - discord: { - default: { running: false, enabled: true, configured: false }, - }, - }), - ), + const manager = createSnapshotManager({ + discord: { + default: { running: false, enabled: true, configured: false }, + }, }); - const monitor = startChannelHealthMonitor({ - channelManager: manager, - checkIntervalMs: 5_000, - startupGraceMs: 0, - }); - await vi.advanceTimersByTimeAsync(5_500); + const monitor = await startAndRunCheck(manager); expect(manager.startChannel).not.toHaveBeenCalled(); monitor.stop(); }); it("skips manually stopped channels", async () => { - const manager = createMockChannelManager({ - getRuntimeSnapshot: vi.fn(() => - snapshotWith({ - discord: { - default: { running: false, enabled: true, configured: true }, - }, - }), - ), - isManuallyStopped: vi.fn(() => true), - }); - const monitor = startChannelHealthMonitor({ - channelManager: manager, - checkIntervalMs: 5_000, - startupGraceMs: 0, - }); - await vi.advanceTimersByTimeAsync(5_500); + const manager = createSnapshotManager( + { + discord: { + default: { running: false, enabled: true, configured: true }, + }, + }, + { isManuallyStopped: vi.fn(() => true) }, + ); + const monitor = await startAndRunCheck(manager); expect(manager.startChannel).not.toHaveBeenCalled(); monitor.stop(); }); it("restarts a stuck channel (running but not connected)", async () => { - const manager = createMockChannelManager({ - getRuntimeSnapshot: vi.fn(() => - snapshotWith({ - whatsapp: { - default: { - running: true, - connected: false, - enabled: true, - configured: true, - linked: true, - }, - }, - }), - ), + const manager = createSnapshotManager({ + whatsapp: { + default: { + running: true, + connected: false, + enabled: true, + configured: true, + linked: true, + }, + }, }); - const monitor = startChannelHealthMonitor({ - channelManager: manager, - checkIntervalMs: 5_000, - startupGraceMs: 0, - }); - await vi.advanceTimersByTimeAsync(5_500); + const monitor = await startAndRunCheck(manager); expect(manager.stopChannel).toHaveBeenCalledWith("whatsapp", "default"); expect(manager.resetRestartAttempts).toHaveBeenCalledWith("whatsapp", "default"); expect(manager.startChannel).toHaveBeenCalledWith("whatsapp", "default"); @@ -184,131 +176,71 @@ describe("channel-health-monitor", () => { }); it("restarts a stopped channel that gave up (reconnectAttempts >= 10)", async () => { - const manager = createMockChannelManager({ - getRuntimeSnapshot: vi.fn(() => - snapshotWith({ - discord: { - default: { - running: false, - enabled: true, - configured: true, - reconnectAttempts: 10, - lastError: "Failed to resolve Discord application id", - }, - }, - }), - ), + const manager = createSnapshotManager({ + discord: { + default: { + ...managedStoppedAccount("Failed to resolve Discord application id"), + reconnectAttempts: 10, + }, + }, }); - const monitor = startChannelHealthMonitor({ - channelManager: manager, - checkIntervalMs: 5_000, - startupGraceMs: 0, - }); - await vi.advanceTimersByTimeAsync(5_500); + const monitor = await startAndRunCheck(manager); expect(manager.resetRestartAttempts).toHaveBeenCalledWith("discord", "default"); expect(manager.startChannel).toHaveBeenCalledWith("discord", "default"); monitor.stop(); }); it("restarts a channel that stopped unexpectedly (not running, not manual)", async () => { - const manager = createMockChannelManager({ - getRuntimeSnapshot: vi.fn(() => - snapshotWith({ - telegram: { - default: { - running: false, - enabled: true, - configured: true, - lastError: "polling stopped unexpectedly", - }, - }, - }), - ), + const manager = createSnapshotManager({ + telegram: { + default: managedStoppedAccount("polling stopped unexpectedly"), + }, }); - const monitor = startChannelHealthMonitor({ - channelManager: manager, - checkIntervalMs: 5_000, - startupGraceMs: 0, - }); - await vi.advanceTimersByTimeAsync(5_500); + const monitor = await startAndRunCheck(manager); expect(manager.resetRestartAttempts).toHaveBeenCalledWith("telegram", "default"); expect(manager.startChannel).toHaveBeenCalledWith("telegram", "default"); monitor.stop(); }); it("treats missing enabled/configured flags as managed accounts", async () => { - const manager = createMockChannelManager({ - getRuntimeSnapshot: vi.fn(() => - snapshotWith({ - telegram: { - default: { - running: false, - lastError: "polling stopped unexpectedly", - }, - }, - }), - ), + const manager = createSnapshotManager({ + telegram: { + default: { + running: false, + lastError: "polling stopped unexpectedly", + }, + }, }); - const monitor = startChannelHealthMonitor({ - channelManager: manager, - checkIntervalMs: 5_000, - startupGraceMs: 0, - }); - await vi.advanceTimersByTimeAsync(5_500); + const monitor = await startAndRunCheck(manager); expect(manager.startChannel).toHaveBeenCalledWith("telegram", "default"); monitor.stop(); }); it("applies cooldown — skips recently restarted channels for 2 cycles", async () => { - const manager = createMockChannelManager({ - getRuntimeSnapshot: vi.fn(() => - snapshotWith({ - discord: { - default: { - running: false, - enabled: true, - configured: true, - lastError: "crashed", - }, - }, - }), - ), + const manager = createSnapshotManager({ + discord: { + default: managedStoppedAccount("crashed"), + }, }); - const monitor = startChannelHealthMonitor({ - channelManager: manager, - checkIntervalMs: 5_000, - startupGraceMs: 0, - }); - await vi.advanceTimersByTimeAsync(5_500); + const monitor = await startAndRunCheck(manager); expect(manager.startChannel).toHaveBeenCalledTimes(1); - await vi.advanceTimersByTimeAsync(5_000); + await vi.advanceTimersByTimeAsync(DEFAULT_CHECK_INTERVAL_MS); expect(manager.startChannel).toHaveBeenCalledTimes(1); - await vi.advanceTimersByTimeAsync(5_000); + await vi.advanceTimersByTimeAsync(DEFAULT_CHECK_INTERVAL_MS); expect(manager.startChannel).toHaveBeenCalledTimes(1); - await vi.advanceTimersByTimeAsync(5_000); + await vi.advanceTimersByTimeAsync(DEFAULT_CHECK_INTERVAL_MS); expect(manager.startChannel).toHaveBeenCalledTimes(2); monitor.stop(); }); it("caps at 3 health-monitor restarts per channel per hour", async () => { - const manager = createMockChannelManager({ - getRuntimeSnapshot: vi.fn(() => - snapshotWith({ - discord: { - default: { - running: false, - enabled: true, - configured: true, - lastError: "keeps crashing", - }, - }, - }), - ), + const manager = createSnapshotManager({ + discord: { + default: managedStoppedAccount("keeps crashing"), + }, }); - const monitor = startChannelHealthMonitor({ - channelManager: manager, + const monitor = startDefaultMonitor(manager, { checkIntervalMs: 1_000, - startupGraceMs: 0, cooldownCycles: 1, maxRestartsPerHour: 3, }); @@ -326,29 +258,19 @@ describe("channel-health-monitor", () => { const startGate = new Promise((resolve) => { releaseStart = () => resolve(); }); - const manager = createMockChannelManager({ - getRuntimeSnapshot: vi.fn(() => - snapshotWith({ - telegram: { - default: { - running: false, - enabled: true, - configured: true, - lastError: "stopped", - }, - }, + const manager = createSnapshotManager( + { + telegram: { + default: managedStoppedAccount("stopped"), + }, + }, + { + startChannel: vi.fn(async () => { + await startGate; }), - ), - startChannel: vi.fn(async () => { - await startGate; - }), - }); - const monitor = startChannelHealthMonitor({ - channelManager: manager, - checkIntervalMs: 100, - startupGraceMs: 0, - cooldownCycles: 0, - }); + }, + ); + const monitor = startDefaultMonitor(manager, { checkIntervalMs: 100, cooldownCycles: 0 }); await vi.advanceTimersByTimeAsync(120); expect(manager.startChannel).toHaveBeenCalledTimes(1); await vi.advanceTimersByTimeAsync(500); @@ -360,11 +282,7 @@ describe("channel-health-monitor", () => { it("stops cleanly", async () => { const manager = createMockChannelManager(); - const monitor = startChannelHealthMonitor({ - channelManager: manager, - checkIntervalMs: 5_000, - startupGraceMs: 0, - }); + const monitor = startDefaultMonitor(manager); monitor.stop(); await vi.advanceTimersByTimeAsync(10_000); expect(manager.getRuntimeSnapshot).not.toHaveBeenCalled(); @@ -373,12 +291,7 @@ describe("channel-health-monitor", () => { it("stops via abort signal", async () => { const manager = createMockChannelManager(); const abort = new AbortController(); - const monitor = startChannelHealthMonitor({ - channelManager: manager, - checkIntervalMs: 5_000, - startupGraceMs: 0, - abortSignal: abort.signal, - }); + const monitor = startDefaultMonitor(manager, { abortSignal: abort.signal }); abort.abort(); await vi.advanceTimersByTimeAsync(10_000); expect(manager.getRuntimeSnapshot).not.toHaveBeenCalled(); @@ -386,21 +299,12 @@ describe("channel-health-monitor", () => { }); it("treats running channels without a connected field as healthy", async () => { - const manager = createMockChannelManager({ - getRuntimeSnapshot: vi.fn(() => - snapshotWith({ - slack: { - default: { running: true, enabled: true, configured: true }, - }, - }), - ), + const manager = createSnapshotManager({ + slack: { + default: { running: true, enabled: true, configured: true }, + }, }); - const monitor = startChannelHealthMonitor({ - channelManager: manager, - checkIntervalMs: 5_000, - startupGraceMs: 0, - }); - await vi.advanceTimersByTimeAsync(5_500); + const monitor = await startAndRunCheck(manager); expect(manager.stopChannel).not.toHaveBeenCalled(); monitor.stop(); }); diff --git a/src/gateway/server.chat.gateway-server-chat-b.e2e.test.ts b/src/gateway/server.chat.gateway-server-chat-b.e2e.test.ts index 5521dea21f5..47cc432bd25 100644 --- a/src/gateway/server.chat.gateway-server-chat-b.e2e.test.ts +++ b/src/gateway/server.chat.gateway-server-chat-b.e2e.test.ts @@ -43,24 +43,49 @@ const sendReq = ( ); }; +async function withGatewayChatHarness( + run: (ctx: { + ws: Awaited>["ws"]; + createSessionDir: () => Promise; + }) => Promise, +) { + const tempDirs: string[] = []; + const { server, ws } = await startServerWithClient(); + const createSessionDir = async () => { + const sessionDir = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-gw-")); + tempDirs.push(sessionDir); + testState.sessionStorePath = path.join(sessionDir, "sessions.json"); + return sessionDir; + }; + + try { + await run({ ws, createSessionDir }); + } finally { + __setMaxChatHistoryMessagesBytesForTest(); + testState.sessionStorePath = undefined; + ws.close(); + await server.close(); + await Promise.all(tempDirs.map((dir) => fs.rm(dir, { recursive: true, force: true }))); + } +} + +async function writeMainSessionStore() { + await writeSessionStore({ + entries: { + main: { sessionId: "sess-main", updatedAt: Date.now() }, + }, + }); +} + describe("gateway server chat", () => { test("smoke: caps history payload and preserves routing metadata", async () => { - const tempDirs: string[] = []; - const { server, ws } = await startServerWithClient(); - try { + await withGatewayChatHarness(async ({ ws, createSessionDir }) => { const historyMaxBytes = 192 * 1024; __setMaxChatHistoryMessagesBytesForTest(historyMaxBytes); await connectOk(ws); - const sessionDir = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-gw-")); - tempDirs.push(sessionDir); - testState.sessionStorePath = path.join(sessionDir, "sessions.json"); - - await writeSessionStore({ - entries: { - main: { sessionId: "sess-main", updatedAt: Date.now() }, - }, - }); + const sessionDir = await createSessionDir(); + await writeMainSessionStore(); const bigText = "x".repeat(4_000); const historyLines: string[] = []; @@ -109,38 +134,27 @@ describe("gateway server chat", () => { }); expect(sendRes.ok).toBe(true); - const stored = JSON.parse(await fs.readFile(testState.sessionStorePath, "utf-8")) as Record< + const sessionStorePath = testState.sessionStorePath; + if (!sessionStorePath) { + throw new Error("expected session store path"); + } + const stored = JSON.parse(await fs.readFile(sessionStorePath, "utf-8")) as Record< string, { lastChannel?: string; lastTo?: string } | undefined >; expect(stored["agent:main:main"]?.lastChannel).toBe("whatsapp"); expect(stored["agent:main:main"]?.lastTo).toBe("+1555"); - } finally { - __setMaxChatHistoryMessagesBytesForTest(); - testState.sessionStorePath = undefined; - ws.close(); - await server.close(); - await Promise.all(tempDirs.map((dir) => fs.rm(dir, { recursive: true, force: true }))); - } + }); }); test("chat.history hard-caps single oversized nested payloads", async () => { - const tempDirs: string[] = []; - const { server, ws } = await startServerWithClient(); - try { + await withGatewayChatHarness(async ({ ws, createSessionDir }) => { const historyMaxBytes = 64 * 1024; __setMaxChatHistoryMessagesBytesForTest(historyMaxBytes); await connectOk(ws); - const sessionDir = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-gw-")); - tempDirs.push(sessionDir); - testState.sessionStorePath = path.join(sessionDir, "sessions.json"); - - await writeSessionStore({ - entries: { - main: { sessionId: "sess-main", updatedAt: Date.now() }, - }, - }); + const sessionDir = await createSessionDir(); + await writeMainSessionStore(); const hugeNestedText = "n".repeat(450_000); const oversizedLine = JSON.stringify({ @@ -175,32 +189,17 @@ describe("gateway server chat", () => { expect(bytes).toBeLessThanOrEqual(historyMaxBytes); expect(serialized).toContain("[chat.history omitted: message too large]"); expect(serialized.includes(hugeNestedText.slice(0, 256))).toBe(false); - } finally { - __setMaxChatHistoryMessagesBytesForTest(); - testState.sessionStorePath = undefined; - ws.close(); - await server.close(); - await Promise.all(tempDirs.map((dir) => fs.rm(dir, { recursive: true, force: true }))); - } + }); }); test("chat.history keeps recent small messages when latest message is oversized", async () => { - const tempDirs: string[] = []; - const { server, ws } = await startServerWithClient(); - try { + await withGatewayChatHarness(async ({ ws, createSessionDir }) => { const historyMaxBytes = 64 * 1024; __setMaxChatHistoryMessagesBytesForTest(historyMaxBytes); await connectOk(ws); - const sessionDir = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-gw-")); - tempDirs.push(sessionDir); - testState.sessionStorePath = path.join(sessionDir, "sessions.json"); - - await writeSessionStore({ - entries: { - main: { sessionId: "sess-main", updatedAt: Date.now() }, - }, - }); + const sessionDir = await createSessionDir(); + await writeMainSessionStore(); const baseText = "s".repeat(1_200); const lines: string[] = []; @@ -258,33 +257,17 @@ describe("gateway server chat", () => { expect(serialized).toContain("small-29:"); expect(serialized).toContain("[chat.history omitted: message too large]"); expect(serialized.includes(hugeNestedText.slice(0, 256))).toBe(false); - } finally { - __setMaxChatHistoryMessagesBytesForTest(); - testState.sessionStorePath = undefined; - ws.close(); - await server.close(); - await Promise.all(tempDirs.map((dir) => fs.rm(dir, { recursive: true, force: true }))); - } + }); }); test("smoke: supports abort and idempotent completion", async () => { - const tempDirs: string[] = []; - const { server, ws } = await startServerWithClient(); - const spy = vi.mocked(getReplyFromConfig) as unknown as ReturnType; - let aborted = false; - - try { + await withGatewayChatHarness(async ({ ws, createSessionDir }) => { + const spy = vi.mocked(getReplyFromConfig) as unknown as ReturnType; + let aborted = false; await connectOk(ws); - const sessionDir = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-gw-")); - tempDirs.push(sessionDir); - testState.sessionStorePath = path.join(sessionDir, "sessions.json"); - - await writeSessionStore({ - entries: { - main: { sessionId: "sess-main", updatedAt: Date.now() }, - }, - }); + await createSessionDir(); + await writeMainSessionStore(); spy.mockReset(); spy.mockImplementationOnce(async (_ctx, opts) => { @@ -359,12 +342,6 @@ describe("gateway server chat", () => { await new Promise((resolve) => setTimeout(resolve, 10)); } expect(completed).toBe(true); - } finally { - __setMaxChatHistoryMessagesBytesForTest(); - testState.sessionStorePath = undefined; - ws.close(); - await server.close(); - await Promise.all(tempDirs.map((dir) => fs.rm(dir, { recursive: true, force: true }))); - } + }); }); }); diff --git a/src/gateway/server/ws-connection/message-handler.ts b/src/gateway/server/ws-connection/message-handler.ts index c265b09f880..51008a35909 100644 --- a/src/gateway/server/ws-connection/message-handler.ts +++ b/src/gateway/server/ws-connection/message-handler.ts @@ -245,30 +245,42 @@ export function attachGatewayWsMessageHandler(params: { const frame = parsed; const connectParams = frame.params as ConnectParams; const clientLabel = connectParams.client.displayName ?? connectParams.client.id; - - // protocol negotiation - const { minProtocol, maxProtocol } = connectParams; - if (maxProtocol < PROTOCOL_VERSION || minProtocol > PROTOCOL_VERSION) { + const clientMeta = { + client: connectParams.client.id, + clientDisplayName: connectParams.client.displayName, + mode: connectParams.client.mode, + version: connectParams.client.version, + }; + const markHandshakeFailure = (cause: string, meta?: Record) => { setHandshakeState("failed"); - logWsControl.warn( - `protocol mismatch conn=${connId} remote=${remoteAddr ?? "?"} client=${clientLabel} ${connectParams.client.mode} v${connectParams.client.version}`, - ); - setCloseCause("protocol-mismatch", { - minProtocol, - maxProtocol, - expectedProtocol: PROTOCOL_VERSION, - client: connectParams.client.id, - clientDisplayName: connectParams.client.displayName, - mode: connectParams.client.mode, - version: connectParams.client.version, - }); + setCloseCause(cause, { ...meta, ...clientMeta }); + }; + const sendHandshakeErrorResponse = ( + code: Parameters[0], + message: string, + options?: Parameters[2], + ) => { send({ type: "res", id: frame.id, ok: false, - error: errorShape(ErrorCodes.INVALID_REQUEST, "protocol mismatch", { - details: { expectedProtocol: PROTOCOL_VERSION }, - }), + error: errorShape(code, message, options), + }); + }; + + // protocol negotiation + const { minProtocol, maxProtocol } = connectParams; + if (maxProtocol < PROTOCOL_VERSION || minProtocol > PROTOCOL_VERSION) { + markHandshakeFailure("protocol-mismatch", { + minProtocol, + maxProtocol, + expectedProtocol: PROTOCOL_VERSION, + }); + logWsControl.warn( + `protocol mismatch conn=${connId} remote=${remoteAddr ?? "?"} client=${clientLabel} ${connectParams.client.mode} v${connectParams.client.version}`, + ); + sendHandshakeErrorResponse(ErrorCodes.INVALID_REQUEST, "protocol mismatch", { + details: { expectedProtocol: PROTOCOL_VERSION }, }); close(1002, "protocol mismatch"); return; @@ -277,20 +289,10 @@ export function attachGatewayWsMessageHandler(params: { const roleRaw = connectParams.role ?? "operator"; const role = roleRaw === "operator" || roleRaw === "node" ? roleRaw : null; if (!role) { - setHandshakeState("failed"); - setCloseCause("invalid-role", { + markHandshakeFailure("invalid-role", { role: roleRaw, - client: connectParams.client.id, - clientDisplayName: connectParams.client.displayName, - mode: connectParams.client.mode, - version: connectParams.client.version, - }); - send({ - type: "res", - id: frame.id, - ok: false, - error: errorShape(ErrorCodes.INVALID_REQUEST, "invalid role"), }); + sendHandshakeErrorResponse(ErrorCodes.INVALID_REQUEST, "invalid role"); close(1008, "invalid role"); return; } @@ -312,22 +314,12 @@ export function attachGatewayWsMessageHandler(params: { if (!originCheck.ok) { const errorMessage = "origin not allowed (open the Control UI from the gateway host or allow it in gateway.controlUi.allowedOrigins)"; - setHandshakeState("failed"); - setCloseCause("origin-mismatch", { + markHandshakeFailure("origin-mismatch", { origin: requestOrigin ?? "n/a", host: requestHost ?? "n/a", reason: originCheck.reason, - client: connectParams.client.id, - clientDisplayName: connectParams.client.displayName, - mode: connectParams.client.mode, - version: connectParams.client.version, - }); - send({ - type: "res", - id: frame.id, - ok: false, - error: errorShape(ErrorCodes.INVALID_REQUEST, errorMessage), }); + sendHandshakeErrorResponse(ErrorCodes.INVALID_REQUEST, errorMessage); close(1008, truncateCloseReason(errorMessage)); return; } @@ -393,7 +385,16 @@ export function attachGatewayWsMessageHandler(params: { sharedAuthResult?.ok === true && (sharedAuthResult.method === "token" || sharedAuthResult.method === "password"); const rejectUnauthorized = (failedAuth: GatewayAuthResult) => { - setHandshakeState("failed"); + markHandshakeFailure("unauthorized", { + authMode: resolvedAuth.mode, + authProvided: connectParams.auth?.token + ? "token" + : connectParams.auth?.password + ? "password" + : "none", + authReason: failedAuth.reason, + allowTailscale: resolvedAuth.allowTailscale, + }); logWsControl.warn( `unauthorized conn=${connId} remote=${remoteAddr ?? "?"} client=${clientLabel} ${connectParams.client.mode} v${connectParams.client.version} reason=${failedAuth.reason ?? "unknown"}`, ); @@ -408,22 +409,7 @@ export function attachGatewayWsMessageHandler(params: { reason: failedAuth.reason, client: connectParams.client, }); - setCloseCause("unauthorized", { - authMode: resolvedAuth.mode, - authProvided, - authReason: failedAuth.reason, - allowTailscale: resolvedAuth.allowTailscale, - client: connectParams.client.id, - clientDisplayName: connectParams.client.displayName, - mode: connectParams.client.mode, - version: connectParams.client.version, - }); - send({ - type: "res", - id: frame.id, - ok: false, - error: errorShape(ErrorCodes.INVALID_REQUEST, authMessage), - }); + sendHandshakeErrorResponse(ErrorCodes.INVALID_REQUEST, authMessage); close(1008, truncateCloseReason(authMessage)); }; if (!device) { @@ -435,19 +421,8 @@ export function attachGatewayWsMessageHandler(params: { if (isControlUi && !allowControlUiBypass) { const errorMessage = "control ui requires HTTPS or localhost (secure context)"; - setHandshakeState("failed"); - setCloseCause("control-ui-insecure-auth", { - client: connectParams.client.id, - clientDisplayName: connectParams.client.displayName, - mode: connectParams.client.mode, - version: connectParams.client.version, - }); - send({ - type: "res", - id: frame.id, - ok: false, - error: errorShape(ErrorCodes.INVALID_REQUEST, errorMessage), - }); + markHandshakeFailure("control-ui-insecure-auth"); + sendHandshakeErrorResponse(ErrorCodes.INVALID_REQUEST, errorMessage); close(1008, errorMessage); return; } @@ -458,19 +433,8 @@ export function attachGatewayWsMessageHandler(params: { rejectUnauthorized(authResult); return; } - setHandshakeState("failed"); - setCloseCause("device-required", { - client: connectParams.client.id, - clientDisplayName: connectParams.client.displayName, - mode: connectParams.client.mode, - version: connectParams.client.version, - }); - send({ - type: "res", - id: frame.id, - ok: false, - error: errorShape(ErrorCodes.NOT_PAIRED, "device identity required"), - }); + markHandshakeFailure("device-required"); + sendHandshakeErrorResponse(ErrorCodes.NOT_PAIRED, "device identity required"); close(1008, "device identity required"); return; } diff --git a/src/infra/update-runner.test.ts b/src/infra/update-runner.test.ts index 31766593bc5..81e5cb884e0 100644 --- a/src/infra/update-runner.test.ts +++ b/src/infra/update-runner.test.ts @@ -5,9 +5,10 @@ import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it, vi } import { pathExists } from "../utils.js"; import { runGatewayUpdate } from "./update-runner.js"; -type CommandResult = { stdout?: string; stderr?: string; code?: number }; +type CommandResponse = { stdout?: string; stderr?: string; code?: number | null }; +type CommandResult = { stdout: string; stderr: string; code: number | null }; -function createRunner(responses: Record) { +function createRunner(responses: Record) { const calls: string[] = []; const runner = async (argv: string[]) => { const key = argv.join(" "); @@ -125,6 +126,32 @@ describe("runGatewayUpdate", () => { await fs.rm(path.join(tempDir, "dist", "control-ui"), { recursive: true, force: true }); } + async function runWithRunner( + runner: (argv: string[]) => Promise, + options?: { channel?: "stable" | "beta"; tag?: string; cwd?: string }, + ) { + return runGatewayUpdate({ + cwd: options?.cwd ?? tempDir, + runCommand: async (argv, _runOptions) => runner(argv), + timeoutMs: 5000, + ...(options?.channel ? { channel: options.channel } : {}), + ...(options?.tag ? { tag: options.tag } : {}), + }); + } + + async function runWithCommand( + runCommand: (argv: string[]) => Promise, + options?: { channel?: "stable" | "beta"; tag?: string; cwd?: string }, + ) { + return runGatewayUpdate({ + cwd: options?.cwd ?? tempDir, + runCommand: async (argv, _runOptions) => runCommand(argv), + timeoutMs: 5000, + ...(options?.channel ? { channel: options.channel } : {}), + ...(options?.tag ? { tag: options.tag } : {}), + }); + } + it("skips git update when worktree is dirty", async () => { await setupGitCheckout(); const { runner, calls } = createRunner({ @@ -134,11 +161,7 @@ describe("runGatewayUpdate", () => { [`git -C ${tempDir} status --porcelain -- :!dist/control-ui/`]: { stdout: " M README.md" }, }); - const result = await runGatewayUpdate({ - cwd: tempDir, - runCommand: async (argv, _options) => runner(argv), - timeoutMs: 5000, - }); + const result = await runWithRunner(runner); expect(result.status).toBe("skipped"); expect(result.reason).toBe("dirty"); @@ -162,11 +185,7 @@ describe("runGatewayUpdate", () => { [`git -C ${tempDir} rebase --abort`]: { stdout: "" }, }); - const result = await runGatewayUpdate({ - cwd: tempDir, - runCommand: async (argv, _options) => runner(argv), - timeoutMs: 5000, - }); + const result = await runWithRunner(runner); expect(result.status).toBe("error"); expect(result.reason).toBe("rebase-failed"); @@ -174,12 +193,7 @@ describe("runGatewayUpdate", () => { }); it("returns error and stops early when deps install fails", async () => { - await fs.mkdir(path.join(tempDir, ".git")); - await fs.writeFile( - path.join(tempDir, "package.json"), - JSON.stringify({ name: "openclaw", version: "1.0.0", packageManager: "pnpm@8.0.0" }), - "utf-8", - ); + await setupGitCheckout({ packageManager: "pnpm@8.0.0" }); const stableTag = "v1.0.1-1"; const { runner, calls } = createRunner({ [`git -C ${tempDir} rev-parse --show-toplevel`]: { stdout: tempDir }, @@ -191,12 +205,7 @@ describe("runGatewayUpdate", () => { "pnpm install": { code: 1, stderr: "ERR_PNPM_NETWORK" }, }); - const result = await runGatewayUpdate({ - cwd: tempDir, - runCommand: async (argv, _options) => runner(argv), - timeoutMs: 5000, - channel: "stable", - }); + const result = await runWithRunner(runner, { channel: "stable" }); expect(result.status).toBe("error"); expect(result.reason).toBe("deps-install-failed"); @@ -205,12 +214,7 @@ describe("runGatewayUpdate", () => { }); it("returns error and stops early when build fails", async () => { - await fs.mkdir(path.join(tempDir, ".git")); - await fs.writeFile( - path.join(tempDir, "package.json"), - JSON.stringify({ name: "openclaw", version: "1.0.0", packageManager: "pnpm@8.0.0" }), - "utf-8", - ); + await setupGitCheckout({ packageManager: "pnpm@8.0.0" }); const stableTag = "v1.0.1-1"; const { runner, calls } = createRunner({ [`git -C ${tempDir} rev-parse --show-toplevel`]: { stdout: tempDir }, @@ -223,12 +227,7 @@ describe("runGatewayUpdate", () => { "pnpm build": { code: 1, stderr: "tsc: error TS2345" }, }); - const result = await runGatewayUpdate({ - cwd: tempDir, - runCommand: async (argv, _options) => runner(argv), - timeoutMs: 5000, - channel: "stable", - }); + const result = await runWithRunner(runner, { channel: "stable" }); expect(result.status).toBe("error"); expect(result.reason).toBe("build-failed"); @@ -259,12 +258,7 @@ describe("runGatewayUpdate", () => { }, }); - const result = await runGatewayUpdate({ - cwd: tempDir, - runCommand: async (argv, _options) => runner(argv), - timeoutMs: 5000, - channel: "beta", - }); + const result = await runWithRunner(runner, { channel: "beta" }); expect(result.status).toBe("ok"); expect(calls).toContain(`git -C ${tempDir} checkout --detach ${stableTag}`); @@ -284,11 +278,7 @@ describe("runGatewayUpdate", () => { "pnpm root -g": { code: 1 }, }); - const result = await runGatewayUpdate({ - cwd: tempDir, - runCommand: async (argv, _options) => runner(argv), - timeoutMs: 5000, - }); + const result = await runWithRunner(runner); expect(result.status).toBe("skipped"); expect(result.reason).toBe("not-git-install"); @@ -323,10 +313,8 @@ describe("runGatewayUpdate", () => { }, }); - const result = await runGatewayUpdate({ + const result = await runWithCommand(runCommand, { cwd: pkgRoot, - runCommand: async (argv, _options) => runCommand(argv), - timeoutMs: 5000, channel: params.channel, tag: params.tag, }); @@ -425,11 +413,7 @@ describe("runGatewayUpdate", () => { return { stdout: "", stderr: "", code: 0 }; }; - const result = await runGatewayUpdate({ - cwd: pkgRoot, - runCommand: async (argv, _options) => runCommand(argv), - timeoutMs: 5000, - }); + const result = await runWithCommand(runCommand, { cwd: pkgRoot }); expect(result.status).toBe("ok"); expect(stalePresentAtInstall).toBe(false); @@ -463,11 +447,7 @@ describe("runGatewayUpdate", () => { }, }); - const result = await runGatewayUpdate({ - cwd: pkgRoot, - runCommand: async (argv, _options) => runCommand(argv), - timeoutMs: 5000, - }); + const result = await runWithCommand(runCommand, { cwd: pkgRoot }); expect(result.status).toBe("ok"); expect(result.mode).toBe("bun"); @@ -490,11 +470,7 @@ describe("runGatewayUpdate", () => { [`git -C ${tempDir} rev-parse --show-toplevel`]: { stdout: tempDir }, }); - const result = await runGatewayUpdate({ - cwd: tempDir, - runCommand: async (argv, _options) => runner(argv), - timeoutMs: 5000, - }); + const result = await runWithRunner(runner); cwdSpy.mockRestore(); @@ -520,12 +496,7 @@ describe("runGatewayUpdate", () => { "pnpm ui:build": { stdout: "" }, }); - const result = await runGatewayUpdate({ - cwd: tempDir, - runCommand: async (argv, _options) => runner(argv), - timeoutMs: 5000, - channel: "stable", - }); + const result = await runWithRunner(runner, { channel: "stable" }); expect(result.status).toBe("error"); expect(result.reason).toBe("doctor-entry-missing"); @@ -547,12 +518,7 @@ describe("runGatewayUpdate", () => { onDoctor: removeControlUiAssets, }); - const result = await runGatewayUpdate({ - cwd: tempDir, - runCommand: async (argv, _options) => runCommand(argv), - timeoutMs: 5000, - channel: "stable", - }); + const result = await runWithCommand(runCommand, { channel: "stable" }); expect(result.status).toBe("ok"); expect(getUiBuildCount()).toBe(2); @@ -577,12 +543,7 @@ describe("runGatewayUpdate", () => { onDoctor: removeControlUiAssets, }); - const result = await runGatewayUpdate({ - cwd: tempDir, - runCommand: async (argv, _options) => runCommand(argv), - timeoutMs: 5000, - channel: "stable", - }); + const result = await runWithCommand(runCommand, { channel: "stable" }); expect(result.status).toBe("error"); expect(result.reason).toBe("ui-assets-missing");