diff --git a/src/cli/program/command-selector.test.ts b/src/cli/program/command-selector.test.ts index 4e2e53bd6a5..c214d0e81d1 100644 --- a/src/cli/program/command-selector.test.ts +++ b/src/cli/program/command-selector.test.ts @@ -52,6 +52,28 @@ describe("command-selector", () => { expect(ranked.some((candidate) => candidate.label === "status")).toBe(false); }); + it("prioritizes deep commands when querying a shared subcommand name", () => { + const program = new Command(); + const models = program.command("models").description("Model commands"); + const aliases = models.command("aliases").description("Alias commands"); + aliases.command("add").description("Add alias"); + const fallbacks = models.command("fallbacks").description("Fallback commands"); + fallbacks.command("add").description("Add fallback"); + + const candidates = collectCommandSelectorCandidates(program); + const ranked = rankCommandSelectorCandidates(candidates, "add"); + const topLabels = ranked.slice(0, 2).map((candidate) => candidate.label); + + expect(topLabels).toEqual(["models aliases add", "models fallbacks add"]); + const aliasesParentIndex = ranked.findIndex( + (candidate) => candidate.label === "models aliases", + ); + const aliasesAddIndex = ranked.findIndex( + (candidate) => candidate.label === "models aliases add", + ); + expect(aliasesParentIndex).toBeGreaterThan(aliasesAddIndex); + }); + it("resolves commands by path", () => { const program = new Command(); const models = program.command("models"); diff --git a/src/cli/program/command-selector.ts b/src/cli/program/command-selector.ts index e19b61ab5fe..5e3a7abc424 100644 --- a/src/cli/program/command-selector.ts +++ b/src/cli/program/command-selector.ts @@ -10,6 +10,7 @@ const SHOW_HELP_VALUE = "__show_help__"; const BACK_TO_MAIN_VALUE = "__back_to_main__"; const RUN_CURRENT_VALUE = "__run_current__"; const PATH_SEPARATOR = "\u0000"; +const SELECTION_VALUE_SEPARATOR = "\u0001"; const MAX_RESULTS = 200; type CommandSelectorCandidate = { @@ -142,6 +143,36 @@ export function collectDirectSubcommandSelectorCandidates( return prepareSortedCandidates(raw); } +function prioritizeDeepCommandsForSubcommandQuery(params: { + ranked: PreparedCommandSelectorCandidate[]; + queryLower: string; +}): PreparedCommandSelectorCandidate[] { + const tokens = params.queryLower.split(/\s+/).filter((token) => token.length > 0); + if (tokens.length !== 1) { + return params.ranked; + } + const [token] = tokens; + if (!token) { + return params.ranked; + } + + const deepExact: PreparedCommandSelectorCandidate[] = []; + const remaining: PreparedCommandSelectorCandidate[] = []; + for (const candidate of params.ranked) { + const last = candidate.path[candidate.path.length - 1]?.toLowerCase(); + if (candidate.path.length >= 2 && last === token) { + deepExact.push(candidate); + continue; + } + remaining.push(candidate); + } + + if (deepExact.length === 0) { + return params.ranked; + } + return [...deepExact, ...remaining]; +} + export function rankCommandSelectorCandidates( candidates: PreparedCommandSelectorCandidate[], query: string, @@ -150,7 +181,8 @@ export function rankCommandSelectorCandidates( if (!queryLower) { return candidates; } - return fuzzyFilterLower(candidates, queryLower); + const ranked = fuzzyFilterLower(candidates, queryLower); + return prioritizeDeepCommandsForSubcommandQuery({ ranked, queryLower }); } async function hydrateProgramCommandsForSelector(program: Command): Promise { @@ -185,6 +217,16 @@ function deserializePath(value: string): string[] { .filter(Boolean); } +function serializeSelectionValue(params: { path: string[]; query: string }): string { + return `${params.query}${SELECTION_VALUE_SEPARATOR}${serializePath(params.path)}`; +} + +function deserializeSelectionPath(value: string): string[] { + const separatorIndex = value.indexOf(SELECTION_VALUE_SEPARATOR); + const pathValue = separatorIndex >= 0 ? value.slice(separatorIndex + 1) : value; + return deserializePath(pathValue); +} + async function promptForCommandSelection(params: { message: string; placeholder: string; @@ -205,7 +247,7 @@ async function promptForCommandSelection(params: { const ranked = rankCommandSelectorCandidates(params.candidates, query).slice(0, MAX_RESULTS); return [ ...ranked.map((candidate) => ({ - value: serializePath(candidate.path), + value: serializeSelectionValue({ path: candidate.path, query }), label: candidate.label, hint: stylePromptHint(candidate.description), })), @@ -247,7 +289,7 @@ async function promptForCommandSelection(params: { if (selection === RUN_CURRENT_VALUE) { return "run_current"; } - return deserializePath(selection); + return deserializeSelectionPath(selection); } export async function runInteractiveCommandSelector(program: Command): Promise {