fix: harden ACP gateway startup sequencing (#23390) (thanks @janckerchen)

This commit is contained in:
Peter Steinberger
2026-02-22 10:42:33 +01:00
parent 7499e0f619
commit 9f0b6a8c92
3 changed files with 189 additions and 12 deletions

View File

@@ -0,0 +1,152 @@
import { beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
type GatewayClientCallbacks = {
onHelloOk?: () => void;
onConnectError?: (err: Error) => void;
onClose?: (code: number, reason: string) => void;
};
const mockState = {
gateways: [] as MockGatewayClient[],
agentSideConnectionCtor: vi.fn(),
agentStart: vi.fn(),
};
class MockGatewayClient {
private callbacks: GatewayClientCallbacks;
constructor(opts: GatewayClientCallbacks) {
this.callbacks = opts;
mockState.gateways.push(this);
}
start(): void {}
stop(): void {
this.callbacks.onClose?.(1000, "gateway stopped");
}
emitHello(): void {
this.callbacks.onHelloOk?.();
}
emitConnectError(message: string): void {
this.callbacks.onConnectError?.(new Error(message));
}
}
vi.mock("@agentclientprotocol/sdk", () => ({
AgentSideConnection: class {
constructor(factory: (conn: unknown) => unknown, stream: unknown) {
mockState.agentSideConnectionCtor(factory, stream);
factory({});
}
},
ndJsonStream: vi.fn(() => ({ type: "mock-stream" })),
}));
vi.mock("../config/config.js", () => ({
loadConfig: () => ({
gateway: {
mode: "local",
},
}),
}));
vi.mock("../gateway/auth.js", () => ({
resolveGatewayAuth: () => ({}),
}));
vi.mock("../gateway/call.js", () => ({
buildGatewayConnectionDetails: () => ({
url: "ws://127.0.0.1:18789",
}),
}));
vi.mock("../gateway/client.js", () => ({
GatewayClient: MockGatewayClient,
}));
vi.mock("./translator.js", () => ({
AcpGatewayAgent: class {
start(): void {
mockState.agentStart();
}
handleGatewayReconnect(): void {}
handleGatewayDisconnect(): void {}
async handleGatewayEvent(): Promise<void> {}
},
}));
describe("serveAcpGateway startup", () => {
let serveAcpGateway: typeof import("./server.js").serveAcpGateway;
beforeAll(async () => {
({ serveAcpGateway } = await import("./server.js"));
});
beforeEach(() => {
mockState.gateways.length = 0;
mockState.agentSideConnectionCtor.mockReset();
mockState.agentStart.mockReset();
});
it("waits for gateway hello before creating AgentSideConnection", async () => {
const signalHandlers = new Map<NodeJS.Signals, () => void>();
const onceSpy = vi.spyOn(process, "once").mockImplementation(((
signal: NodeJS.Signals,
handler: () => void,
) => {
signalHandlers.set(signal, handler);
return process;
}) as typeof process.once);
try {
const servePromise = serveAcpGateway({});
await Promise.resolve();
expect(mockState.agentSideConnectionCtor).not.toHaveBeenCalled();
const gateway = mockState.gateways[0];
if (!gateway) {
throw new Error("Expected mocked gateway instance");
}
gateway.emitHello();
await vi.waitFor(() => {
expect(mockState.agentSideConnectionCtor).toHaveBeenCalledTimes(1);
});
signalHandlers.get("SIGINT")?.();
await servePromise;
} finally {
onceSpy.mockRestore();
}
});
it("rejects startup when gateway connect fails before hello", async () => {
const onceSpy = vi
.spyOn(process, "once")
.mockImplementation(
((_signal: NodeJS.Signals, _handler: () => void) => process) as typeof process.once,
);
try {
const servePromise = serveAcpGateway({});
await Promise.resolve();
const gateway = mockState.gateways[0];
if (!gateway) {
throw new Error("Expected mocked gateway instance");
}
gateway.emitConnectError("connect failed");
await expect(servePromise).rejects.toThrow("connect failed");
expect(mockState.agentSideConnectionCtor).not.toHaveBeenCalled();
} finally {
onceSpy.mockRestore();
}
});
});

View File

@@ -40,6 +40,27 @@ export async function serveAcpGateway(opts: AcpServerOptions = {}): Promise<void
onClosed = resolve;
});
let stopped = false;
let onGatewayReadyResolve!: () => void;
let onGatewayReadyReject!: (err: Error) => void;
let gatewayReadySettled = false;
const gatewayReady = new Promise<void>((resolve, reject) => {
onGatewayReadyResolve = resolve;
onGatewayReadyReject = reject;
});
const resolveGatewayReady = () => {
if (gatewayReadySettled) {
return;
}
gatewayReadySettled = true;
onGatewayReadyResolve();
};
const rejectGatewayReady = (err: unknown) => {
if (gatewayReadySettled) {
return;
}
gatewayReadySettled = true;
onGatewayReadyReject(err instanceof Error ? err : new Error(String(err)));
};
const gateway = new GatewayClient({
url: connection.url,
@@ -53,9 +74,16 @@ export async function serveAcpGateway(opts: AcpServerOptions = {}): Promise<void
void agent?.handleGatewayEvent(evt);
},
onHelloOk: () => {
resolveGatewayReady();
agent?.handleGatewayReconnect();
},
onConnectError: (err) => {
rejectGatewayReady(err);
},
onClose: (code, reason) => {
if (!stopped) {
rejectGatewayReady(new Error(`gateway closed before ready (${code}): ${reason}`));
}
agent?.handleGatewayDisconnect(`${code}: ${reason}`);
// Resolve only on intentional shutdown (gateway.stop() sets closed
// which skips scheduleReconnect, then fires onClose). Transient
@@ -71,6 +99,7 @@ export async function serveAcpGateway(opts: AcpServerOptions = {}): Promise<void
return;
}
stopped = true;
resolveGatewayReady();
gateway.stop();
// If no WebSocket is active (e.g. between reconnect attempts),
// gateway.stop() won't trigger onClose, so resolve directly.
@@ -80,20 +109,15 @@ export async function serveAcpGateway(opts: AcpServerOptions = {}): Promise<void
process.once("SIGINT", shutdown);
process.once("SIGTERM", shutdown);
// Start gateway first and wait for connection before processing ACP messages
// Start gateway first and wait for hello before accepting ACP requests.
gateway.start();
// Use a promise to wait for hello (connection established)
const helloReceived = new Promise<void>((resolve) => {
const originalOnHelloOk = gateway.opts.onHelloOk;
gateway.opts.onHelloOk = (hello) => {
originalOnHelloOk?.(hello);
resolve();
};
await gatewayReady.catch((err) => {
shutdown();
throw err;
});
// Wait for gateway connection before creating AgentSideConnection
await helloReceived;
if (stopped) {
return closed;
}
const input = Writable.toWeb(process.stdout);
const output = Readable.toWeb(process.stdin) as unknown as ReadableStream<Uint8Array>;