refactor(gateway): dedupe wizard param validation

This commit is contained in:
Peter Steinberger
2026-02-16 01:08:36 +00:00
parent 260a514467
commit a5cbd036de

View File

@@ -1,5 +1,6 @@
import type { ErrorObject } from "ajv";
import { randomUUID } from "node:crypto"; import { randomUUID } from "node:crypto";
import type { GatewayRequestHandlers } from "./types.js"; import type { GatewayRequestHandlers, RespondFn } from "./types.js";
import { defaultRuntime } from "../../runtime.js"; import { defaultRuntime } from "../../runtime.js";
import { WizardSession } from "../../wizard/session.js"; import { WizardSession } from "../../wizard/session.js";
import { import {
@@ -13,17 +14,40 @@ import {
} from "../protocol/index.js"; } from "../protocol/index.js";
import { formatForLog } from "../ws-log.js"; import { formatForLog } from "../ws-log.js";
export const wizardHandlers: GatewayRequestHandlers = { type Validator<T> = ((params: unknown) => params is T) & {
"wizard.start": async ({ params, respond, context }) => { errors?: ErrorObject[] | null;
if (!validateWizardStartParams(params)) { };
function assertValidParams<T>(
params: unknown,
validate: Validator<T>,
method: string,
respond: RespondFn,
): params is T {
if (validate(params)) {
return true;
}
respond( respond(
false, false,
undefined, undefined,
errorShape( errorShape(
ErrorCodes.INVALID_REQUEST, ErrorCodes.INVALID_REQUEST,
`invalid wizard.start params: ${formatValidationErrors(validateWizardStartParams.errors)}`, `invalid ${method} params: ${formatValidationErrors(validate.errors)}`,
), ),
); );
return false;
}
function readWizardStatus(session: WizardSession) {
return {
status: session.getStatus(),
error: session.getError(),
};
}
export const wizardHandlers: GatewayRequestHandlers = {
"wizard.start": async ({ params, respond, context }) => {
if (!assertValidParams(params, validateWizardStartParams, "wizard.start", respond)) {
return; return;
} }
const running = context.findRunningWizard(); const running = context.findRunningWizard();
@@ -47,15 +71,7 @@ export const wizardHandlers: GatewayRequestHandlers = {
respond(true, { sessionId, ...result }, undefined); respond(true, { sessionId, ...result }, undefined);
}, },
"wizard.next": async ({ params, respond, context }) => { "wizard.next": async ({ params, respond, context }) => {
if (!validateWizardNextParams(params)) { if (!assertValidParams(params, validateWizardNextParams, "wizard.next", respond)) {
respond(
false,
undefined,
errorShape(
ErrorCodes.INVALID_REQUEST,
`invalid wizard.next params: ${formatValidationErrors(validateWizardNextParams.errors)}`,
),
);
return; return;
} }
const sessionId = params.sessionId; const sessionId = params.sessionId;
@@ -84,15 +100,7 @@ export const wizardHandlers: GatewayRequestHandlers = {
respond(true, result, undefined); respond(true, result, undefined);
}, },
"wizard.cancel": ({ params, respond, context }) => { "wizard.cancel": ({ params, respond, context }) => {
if (!validateWizardCancelParams(params)) { if (!assertValidParams(params, validateWizardCancelParams, "wizard.cancel", respond)) {
respond(
false,
undefined,
errorShape(
ErrorCodes.INVALID_REQUEST,
`invalid wizard.cancel params: ${formatValidationErrors(validateWizardCancelParams.errors)}`,
),
);
return; return;
} }
const sessionId = params.sessionId; const sessionId = params.sessionId;
@@ -102,23 +110,12 @@ export const wizardHandlers: GatewayRequestHandlers = {
return; return;
} }
session.cancel(); session.cancel();
const status = { const status = readWizardStatus(session);
status: session.getStatus(),
error: session.getError(),
};
context.wizardSessions.delete(sessionId); context.wizardSessions.delete(sessionId);
respond(true, status, undefined); respond(true, status, undefined);
}, },
"wizard.status": ({ params, respond, context }) => { "wizard.status": ({ params, respond, context }) => {
if (!validateWizardStatusParams(params)) { if (!assertValidParams(params, validateWizardStatusParams, "wizard.status", respond)) {
respond(
false,
undefined,
errorShape(
ErrorCodes.INVALID_REQUEST,
`invalid wizard.status params: ${formatValidationErrors(validateWizardStatusParams.errors)}`,
),
);
return; return;
} }
const sessionId = params.sessionId; const sessionId = params.sessionId;
@@ -127,10 +124,7 @@ export const wizardHandlers: GatewayRequestHandlers = {
respond(false, undefined, errorShape(ErrorCodes.INVALID_REQUEST, "wizard not found")); respond(false, undefined, errorShape(ErrorCodes.INVALID_REQUEST, "wizard not found"));
return; return;
} }
const status = { const status = readWizardStatus(session);
status: session.getStatus(),
error: session.getError(),
};
if (status.status !== "running") { if (status.status !== "running") {
context.wizardSessions.delete(sessionId); context.wizardSessions.delete(sessionId);
} }