Agents: validate persisted tool-call names

This commit is contained in:
Vignesh Natarajan
2026-02-21 23:06:44 -08:00
parent 29a782b9cd
commit cdfe45eeb8
11 changed files with 248 additions and 8 deletions

View File

@@ -1,6 +1,9 @@
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import { extractToolCallsFromAssistant, extractToolResultId } from "./tool-call-id.js";
const TOOL_CALL_NAME_MAX_CHARS = 64;
const TOOL_CALL_NAME_RE = /^[A-Za-z0-9_-]+$/;
type ToolCallBlock = {
type?: unknown;
id?: unknown;
@@ -35,8 +38,38 @@ function hasToolCallId(block: ToolCallBlock): boolean {
return hasNonEmptyStringField(block.id);
}
function hasToolCallName(block: ToolCallBlock): boolean {
return hasNonEmptyStringField(block.name);
function normalizeAllowedToolNames(allowedToolNames?: Iterable<string>): Set<string> | null {
if (!allowedToolNames) {
return null;
}
const normalized = new Set<string>();
for (const name of allowedToolNames) {
if (typeof name !== "string") {
continue;
}
const trimmed = name.trim();
if (trimmed) {
normalized.add(trimmed.toLowerCase());
}
}
return normalized.size > 0 ? normalized : null;
}
function hasToolCallName(block: ToolCallBlock, allowedToolNames: Set<string> | null): boolean {
if (typeof block.name !== "string") {
return false;
}
const trimmed = block.name.trim();
if (!trimmed || trimmed !== block.name) {
return false;
}
if (trimmed.length > TOOL_CALL_NAME_MAX_CHARS || !TOOL_CALL_NAME_RE.test(trimmed)) {
return false;
}
if (!allowedToolNames) {
return true;
}
return allowedToolNames.has(trimmed.toLowerCase());
}
function makeMissingToolResult(params: {
@@ -66,6 +99,10 @@ export type ToolCallInputRepairReport = {
droppedAssistantMessages: number;
};
export type ToolCallInputRepairOptions = {
allowedToolNames?: Iterable<string>;
};
export function stripToolResultDetails(messages: AgentMessage[]): AgentMessage[] {
let touched = false;
const out: AgentMessage[] = [];
@@ -85,11 +122,15 @@ export function stripToolResultDetails(messages: AgentMessage[]): AgentMessage[]
return touched ? out : messages;
}
export function repairToolCallInputs(messages: AgentMessage[]): ToolCallInputRepairReport {
export function repairToolCallInputs(
messages: AgentMessage[],
options?: ToolCallInputRepairOptions,
): ToolCallInputRepairReport {
let droppedToolCalls = 0;
let droppedAssistantMessages = 0;
let changed = false;
const out: AgentMessage[] = [];
const allowedToolNames = normalizeAllowedToolNames(options?.allowedToolNames);
for (const msg of messages) {
if (!msg || typeof msg !== "object") {
@@ -108,7 +149,9 @@ export function repairToolCallInputs(messages: AgentMessage[]): ToolCallInputRep
for (const block of msg.content) {
if (
isToolCallBlock(block) &&
(!hasToolCallInput(block) || !hasToolCallId(block) || !hasToolCallName(block))
(!hasToolCallInput(block) ||
!hasToolCallId(block) ||
!hasToolCallName(block, allowedToolNames))
) {
droppedToolCalls += 1;
droppedInMessage += 1;
@@ -138,8 +181,11 @@ export function repairToolCallInputs(messages: AgentMessage[]): ToolCallInputRep
};
}
export function sanitizeToolCallInputs(messages: AgentMessage[]): AgentMessage[] {
return repairToolCallInputs(messages).messages;
export function sanitizeToolCallInputs(
messages: AgentMessage[],
options?: ToolCallInputRepairOptions,
): AgentMessage[] {
return repairToolCallInputs(messages, options).messages;
}
export function sanitizeToolUseResultPairing(messages: AgentMessage[]): AgentMessage[] {