feat(agent): opt-in tool-result context pruning

This commit is contained in:
Max Sumrall
2026-01-07 12:02:46 +01:00
committed by Peter Steinberger
parent 937e0265a3
commit eeaa6ea46f
9 changed files with 779 additions and 26 deletions

View File

@@ -0,0 +1,27 @@
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type {
ContextEvent,
ExtensionAPI,
ExtensionContext,
} from "@mariozechner/pi-coding-agent";
import { pruneContextMessages } from "./pruner.js";
import { getContextPruningRuntime } from "./runtime.js";
/**
 * Registers the context-pruning extension: on every "context" event, runs the
 * pruner against the session's opt-in runtime configuration (if any) and
 * returns a replacement message list only when pruning actually changed it.
 */
export default function contextPruningExtension(api: ExtensionAPI): void {
	api.on("context", (event: ContextEvent, ctx: ExtensionContext) => {
		// No runtime registered for this session => the feature is disabled.
		const runtime = getContextPruningRuntime(ctx.sessionManager);
		if (!runtime) return undefined;
		const pruned = pruneContextMessages({
			messages: event.messages as AgentMessage[],
			settings: runtime.settings,
			ctx,
			isToolPrunable: runtime.isToolPrunable,
			contextWindowTokensOverride: runtime.contextWindowTokens ?? undefined,
		});
		// The pruner returns the same array instance when nothing changed.
		return pruned === event.messages ? undefined : { messages: pruned };
	});
}

View File

@@ -0,0 +1,310 @@
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type {
ImageContent,
TextContent,
ToolResultMessage,
} from "@mariozechner/pi-ai";
import type { ExtensionContext } from "@mariozechner/pi-coding-agent";
import type { EffectiveContextPruningSettings } from "./settings.js";
import { makeToolPrunablePredicate } from "./tools.js";
// Crude chars->tokens conversion used to turn the model's token window into a
// character budget (roughly 4 chars per token for typical English/code text).
const CHARS_PER_TOKEN_ESTIMATE = 4;
// We currently skip pruning tool results that contain images. Still, we count them (approx.) so
// we start trimming prunable tool results earlier when image-heavy context is consuming the window.
const IMAGE_CHAR_ESTIMATE = 8_000;
/** Wraps a plain string into a TextContent block. */
function asText(text: string): TextContent {
	return { text, type: "text" };
}
/** Extracts the text of every "text" block, preserving order; images are skipped. */
function collectTextSegments(
	content: ReadonlyArray<TextContent | ImageContent>,
): string[] {
	return content
		.filter((block): block is TextContent => block.type === "text")
		.map((block) => block.text);
}
/** Length the segments would have once joined with "\n" separators. */
function estimateJoinedTextLength(parts: string[]): number {
	if (parts.length === 0) return 0;
	const textChars = parts.reduce((sum, part) => sum + part.length, 0);
	// One "\n" separator sits between each adjacent pair of blocks.
	return textChars + parts.length - 1;
}
/**
 * First maxChars characters of parts.join("\n"), computed without
 * materializing the full joined string.
 */
function takeHeadFromJoinedText(parts: string[], maxChars: number): string {
	if (maxChars <= 0 || parts.length === 0) return "";
	const pieces: string[] = [];
	let budget = maxChars;
	for (const part of parts) {
		if (pieces.length > 0) {
			// Separator between blocks counts against the budget too.
			pieces.push("\n");
			budget -= 1;
			if (budget <= 0) break;
		}
		if (part.length > budget) {
			pieces.push(part.slice(0, budget));
			break;
		}
		pieces.push(part);
		budget -= part.length;
		if (budget <= 0) break;
	}
	return pieces.join("");
}
/**
 * Last maxChars characters of parts.join("\n"), computed without
 * materializing the full joined string.
 */
function takeTailFromJoinedText(parts: string[], maxChars: number): string {
	if (maxChars <= 0 || parts.length === 0) return "";
	let result = "";
	let budget = maxChars;
	let i = parts.length - 1;
	while (i >= 0 && budget > 0) {
		const part = parts[i];
		if (part.length > budget) {
			result = part.slice(part.length - budget) + result;
			break;
		}
		result = part + result;
		budget -= part.length;
		// Separator before this block counts against the budget too.
		if (budget > 0 && i > 0) {
			result = "\n" + result;
			budget -= 1;
		}
		i--;
	}
	return result;
}
/** True when the content contains at least one "image" block. */
function hasImageBlocks(
	content: ReadonlyArray<TextContent | ImageContent>,
): boolean {
	return content.some((block) => block.type === "image");
}
/**
 * Rough character count for one message, used as a cheap proxy for token
 * usage. Images are charged a flat IMAGE_CHAR_ESTIMATE; tool-call arguments
 * are sized via JSON serialization (flat 128 when unserializable); unknown
 * roles get a small fixed cost.
 */
function estimateMessageChars(message: AgentMessage): number {
	switch (message.role) {
		case "user": {
			const content = message.content;
			if (typeof content === "string") return content.length;
			let total = 0;
			for (const block of content) {
				if (block.type === "text") total += block.text.length;
				else if (block.type === "image") total += IMAGE_CHAR_ESTIMATE;
			}
			return total;
		}
		case "assistant": {
			let total = 0;
			for (const block of message.content) {
				if (block.type === "text") total += block.text.length;
				else if (block.type === "thinking") total += block.thinking.length;
				else if (block.type === "toolCall") {
					try {
						total += JSON.stringify(block.arguments ?? {}).length;
					} catch {
						// Circular/unserializable arguments: charge a flat cost.
						total += 128;
					}
				}
			}
			return total;
		}
		case "toolResult": {
			let total = 0;
			for (const block of message.content) {
				if (block.type === "text") total += block.text.length;
				else if (block.type === "image") total += IMAGE_CHAR_ESTIMATE;
			}
			return total;
		}
		default:
			return 256;
	}
}
/** Sum of per-message character estimates across the whole context. */
function estimateContextChars(messages: AgentMessage[]): number {
	let total = 0;
	for (const message of messages) total += estimateMessageChars(message);
	return total;
}
/**
 * Index of the oldest assistant message in the protected tail; everything at
 * a smaller index is a pruning candidate. Returns messages.length when
 * keepLastAssistants <= 0 (nothing protected) and null when there are fewer
 * than keepLastAssistants assistant messages (everything protected).
 */
function findAssistantCutoffIndex(
	messages: AgentMessage[],
	keepLastAssistants: number,
): number | null {
	// keepLastAssistants <= 0 => everything is potentially prunable.
	if (keepLastAssistants <= 0) return messages.length;
	let toSkip = keepLastAssistants;
	for (let i = messages.length - 1; i >= 0; i--) {
		if (messages[i]?.role === "assistant" && --toSkip === 0) return i;
	}
	// Not enough assistant messages to establish a protected tail.
	return null;
}
/**
 * Replaces an oversized text-only tool result with a head + "..." + tail
 * excerpt plus a note describing the trim. Returns null when the message is
 * left untouched: it contains images, fits under softTrim.maxChars, or
 * trimming would not actually shrink it.
 */
function softTrimToolResultMessage(params: {
	msg: ToolResultMessage;
	settings: EffectiveContextPruningSettings;
}): ToolResultMessage | null {
	const { msg, settings } = params;
	// Image-bearing results are skipped: often directly relevant and hard to
	// partially prune safely.
	if (hasImageBlocks(msg.content)) return null;
	const segments = collectTextSegments(msg.content);
	const rawLen = estimateJoinedTextLength(segments);
	if (rawLen <= settings.softTrim.maxChars) return null;
	const headChars = Math.max(0, settings.softTrim.headChars);
	const tailChars = Math.max(0, settings.softTrim.tailChars);
	// Only trim when head + tail is strictly smaller than the original.
	if (headChars + tailChars >= rawLen) return null;
	const head = takeHeadFromJoinedText(segments, headChars);
	const tail = takeTailFromJoinedText(segments, tailChars);
	const trimmed = head + "\n...\n" + tail;
	const note = `\n[Tool result trimmed: kept first ${headChars} chars and last ${tailChars} chars of ${rawLen} chars.]`;
	return { ...msg, content: [asText(trimmed + note)] };
}
/**
 * Prunes old prunable tool-result messages so the estimated context size
 * stays within the model's window.
 *
 * Strategy:
 * - "aggressive" mode: unconditionally replace every prunable, image-free
 *   tool result before the protected tail with the hard-clear placeholder.
 * - "adaptive" mode: once estimated usage crosses softTrimRatio, soft-trim
 *   oversized tool results to head/tail excerpts; if usage still exceeds
 *   hardClearRatio (and hard clearing is enabled and there are enough
 *   prunable chars), clear prunable results oldest-first until the ratio
 *   drops below the threshold.
 *
 * Returns the original `messages` array (same reference) when nothing was
 * pruned, so callers can detect "no change" by identity. The input array and
 * its messages are never mutated; changes go into a slice() copy.
 */
export function pruneContextMessages(params: {
	messages: AgentMessage[];
	settings: EffectiveContextPruningSettings;
	ctx: Pick<ExtensionContext, "model">;
	isToolPrunable?: (toolName: string) => boolean;
	contextWindowTokensOverride?: number;
}): AgentMessage[] {
	const { messages, settings, ctx } = params;
	// Prefer an explicit positive, finite override; otherwise fall back to the
	// model's declared context window.
	const contextWindowTokens =
		typeof params.contextWindowTokensOverride === "number" &&
		Number.isFinite(params.contextWindowTokensOverride) &&
		params.contextWindowTokensOverride > 0
			? params.contextWindowTokensOverride
			: ctx.model?.contextWindow;
	// Without a known window size we cannot compute usage ratios: do nothing.
	if (!contextWindowTokens || contextWindowTokens <= 0) return messages;
	const charWindow = contextWindowTokens * CHARS_PER_TOKEN_ESTIMATE;
	if (charWindow <= 0) return messages;
	// Messages at index >= cutoffIndex form the protected tail and are never
	// touched; null means there are not enough assistant turns to prune at all.
	const cutoffIndex = findAssistantCutoffIndex(
		messages,
		settings.keepLastAssistants,
	);
	if (cutoffIndex === null) return messages;
	const isToolPrunable =
		params.isToolPrunable ?? makeToolPrunablePredicate(settings.tools);
	if (settings.mode === "aggressive") {
		// Copy-on-write: `next` stays null until the first actual change.
		let next: AgentMessage[] | null = null;
		for (let i = 0; i < cutoffIndex; i++) {
			const msg = messages[i];
			if (!msg || msg.role !== "toolResult") continue;
			if (!isToolPrunable(msg.toolName)) continue;
			if (hasImageBlocks(msg.content)) {
				continue;
			}
			// Skip results already reduced to the placeholder (idempotence).
			const alreadyCleared =
				msg.content.length === 1 &&
				msg.content[0]?.type === "text" &&
				msg.content[0].text === settings.hardClear.placeholder;
			if (alreadyCleared) continue;
			const cleared: ToolResultMessage = {
				...msg,
				content: [asText(settings.hardClear.placeholder)],
			};
			if (!next) next = messages.slice();
			next[i] = cleared as unknown as AgentMessage;
		}
		return next ?? messages;
	}
	// Adaptive mode: keep a running char estimate so we can stop as soon as
	// usage drops below the relevant threshold.
	const totalCharsBefore = estimateContextChars(messages);
	let totalChars = totalCharsBefore;
	let ratio = totalChars / charWindow;
	if (ratio < settings.softTrimRatio) {
		return messages;
	}
	// Stage 1: soft-trim oversized prunable tool results to head/tail excerpts.
	const prunableToolIndexes: number[] = [];
	let next: AgentMessage[] | null = null;
	for (let i = 0; i < cutoffIndex; i++) {
		const msg = messages[i];
		if (!msg || msg.role !== "toolResult") continue;
		if (!isToolPrunable(msg.toolName)) continue;
		if (hasImageBlocks(msg.content)) {
			continue;
		}
		// Remember every prunable candidate (trimmed or not) for stage 2.
		prunableToolIndexes.push(i);
		const updated = softTrimToolResultMessage({
			msg: msg as unknown as ToolResultMessage,
			settings,
		});
		if (!updated) continue;
		const beforeChars = estimateMessageChars(msg);
		const afterChars = estimateMessageChars(updated as unknown as AgentMessage);
		totalChars += afterChars - beforeChars;
		if (!next) next = messages.slice();
		next[i] = updated as unknown as AgentMessage;
	}
	const outputAfterSoftTrim = next ?? messages;
	ratio = totalChars / charWindow;
	if (ratio < settings.hardClearRatio) {
		return outputAfterSoftTrim;
	}
	if (!settings.hardClear.enabled) {
		return outputAfterSoftTrim;
	}
	// Stage 2 gate: only hard-clear when the prunable results hold enough
	// characters for clearing to make a meaningful difference.
	let prunableToolChars = 0;
	for (const i of prunableToolIndexes) {
		const msg = outputAfterSoftTrim[i];
		if (!msg || msg.role !== "toolResult") continue;
		prunableToolChars += estimateMessageChars(msg);
	}
	if (prunableToolChars < settings.minPrunableToolChars) {
		return outputAfterSoftTrim;
	}
	// Stage 2: clear oldest-first until the estimated ratio drops below the
	// hard-clear threshold.
	for (const i of prunableToolIndexes) {
		if (ratio < settings.hardClearRatio) break;
		const msg = (next ?? messages)[i];
		if (!msg || msg.role !== "toolResult") continue;
		const beforeChars = estimateMessageChars(msg);
		const cleared: ToolResultMessage = {
			...msg,
			content: [asText(settings.hardClear.placeholder)],
		};
		if (!next) next = messages.slice();
		next[i] = cleared as unknown as AgentMessage;
		const afterChars = estimateMessageChars(cleared as unknown as AgentMessage);
		totalChars += afterChars - beforeChars;
		ratio = totalChars / charWindow;
	}
	return next ?? messages;
}

View File

@@ -0,0 +1,39 @@
import type { EffectiveContextPruningSettings } from "./settings.js";
/** Per-session configuration consumed by the context-pruning extension. */
export type ContextPruningRuntimeValue = {
	settings: EffectiveContextPruningSettings;
	contextWindowTokens?: number | null;
	isToolPrunable: (toolName: string) => boolean;
};

// Session-scoped runtime registry keyed by object identity.
// Important: this relies on Pi passing the same SessionManager object instance
// into ExtensionContext (ctx.sessionManager) that we used when calling
// setContextPruningRuntime. A WeakMap keeps entries from outliving sessions.
const REGISTRY = new WeakMap<object, ContextPruningRuntimeValue>();

/**
 * Associates a pruning runtime with a session manager, or removes the
 * association when value is null. Non-object keys are ignored: a WeakMap
 * cannot hold primitives and there is nothing meaningful to register against.
 */
export function setContextPruningRuntime(
	sessionManager: unknown,
	value: ContextPruningRuntimeValue | null,
): void {
	if (typeof sessionManager !== "object" || sessionManager === null) return;
	if (value === null) {
		REGISTRY.delete(sessionManager);
	} else {
		REGISTRY.set(sessionManager, value);
	}
}

/** Looks up the runtime registered for this session manager, if any. */
export function getContextPruningRuntime(
	sessionManager: unknown,
): ContextPruningRuntimeValue | null {
	if (typeof sessionManager !== "object" || sessionManager === null) {
		return null;
	}
	return REGISTRY.get(sessionManager) ?? null;
}

View File

@@ -0,0 +1,135 @@
/** Glob-style allow/deny lists ("*" wildcard) selecting prunable tools. */
export type ContextPruningToolMatch = {
	allow?: string[];
	deny?: string[];
};
/** "off" disables pruning; the other modes select a pruning strategy. */
export type ContextPruningMode = "off" | "adaptive" | "aggressive";
/** Raw, user-supplied configuration; every field is optional. */
export type ContextPruningConfig = {
	mode?: ContextPruningMode;
	// How many most-recent assistant turns (and everything after them) to protect.
	keepLastAssistants?: number;
	// Fraction of the context window at which soft trimming kicks in.
	softTrimRatio?: number;
	// Fraction of the context window at which hard clearing kicks in.
	hardClearRatio?: number;
	// Minimum total prunable tool-result chars before hard clearing is attempted.
	minPrunableToolChars?: number;
	tools?: ContextPruningToolMatch;
	softTrim?: {
		maxChars?: number;
		headChars?: number;
		tailChars?: number;
	};
	hardClear?: {
		enabled?: boolean;
		placeholder?: string;
	};
};
/**
 * Fully-resolved settings with defaults applied; a config with mode "off"
 * (or invalid) never produces one of these.
 */
export type EffectiveContextPruningSettings = {
	mode: Exclude<ContextPruningMode, "off">;
	keepLastAssistants: number;
	softTrimRatio: number;
	hardClearRatio: number;
	minPrunableToolChars: number;
	tools: ContextPruningToolMatch;
	softTrim: {
		maxChars: number;
		headChars: number;
		tailChars: number;
	};
	hardClear: {
		enabled: boolean;
		placeholder: string;
	};
};
/**
 * Defaults for the "adaptive" strategy: start soft-trimming at 30% estimated
 * window usage, hard-clear at 50%, and always protect the tail starting at
 * the 3rd-most-recent assistant turn.
 */
export const DEFAULT_CONTEXT_PRUNING_SETTINGS: EffectiveContextPruningSettings =
	{
		mode: "adaptive",
		keepLastAssistants: 3,
		softTrimRatio: 0.3,
		hardClearRatio: 0.5,
		minPrunableToolChars: 50_000,
		tools: {},
		softTrim: {
			maxChars: 4_000,
			headChars: 1_500,
			tailChars: 1_500,
		},
		hardClear: {
			enabled: true,
			placeholder: "[Old tool result content cleared]",
		},
	};

/**
 * Validates raw (user-supplied) config and merges it over the defaults.
 *
 * Returns null when pruning is not enabled: raw is not an object, or mode is
 * anything other than "adaptive"/"aggressive". Numeric fields are coerced to
 * sane ranges (ratios clamped to [0, 1]; counts floored and clamped to >= 0).
 * `hardClear.enabled` can only be overridden in "adaptive" mode; aggressive
 * mode always clears. The result shares no state with `raw`, so later
 * mutation of the raw config cannot change the effective settings.
 */
export function computeEffectiveSettings(
	raw: unknown,
): EffectiveContextPruningSettings | null {
	if (!raw || typeof raw !== "object") return null;
	const cfg = raw as ContextPruningConfig;
	if (cfg.mode !== "adaptive" && cfg.mode !== "aggressive") return null;
	const s: EffectiveContextPruningSettings = structuredClone(
		DEFAULT_CONTEXT_PRUNING_SETTINGS,
	);
	s.mode = cfg.mode;
	if (
		typeof cfg.keepLastAssistants === "number" &&
		Number.isFinite(cfg.keepLastAssistants)
	) {
		s.keepLastAssistants = Math.max(0, Math.floor(cfg.keepLastAssistants));
	}
	if (
		typeof cfg.softTrimRatio === "number" &&
		Number.isFinite(cfg.softTrimRatio)
	) {
		s.softTrimRatio = Math.min(1, Math.max(0, cfg.softTrimRatio));
	}
	if (
		typeof cfg.hardClearRatio === "number" &&
		Number.isFinite(cfg.hardClearRatio)
	) {
		s.hardClearRatio = Math.min(1, Math.max(0, cfg.hardClearRatio));
	}
	if (
		typeof cfg.minPrunableToolChars === "number" &&
		Number.isFinite(cfg.minPrunableToolChars)
	) {
		s.minPrunableToolChars = Math.max(0, Math.floor(cfg.minPrunableToolChars));
	}
	if (cfg.tools) {
		// Fix: copy instead of aliasing the caller-owned object. Every other
		// field of the result is detached from `raw` (via structuredClone);
		// aliasing `tools` would let later edits to the raw config silently
		// change pruning behavior mid-session.
		s.tools = {
			allow: Array.isArray(cfg.tools.allow) ? [...cfg.tools.allow] : undefined,
			deny: Array.isArray(cfg.tools.deny) ? [...cfg.tools.deny] : undefined,
		};
	}
	if (cfg.softTrim) {
		if (
			typeof cfg.softTrim.maxChars === "number" &&
			Number.isFinite(cfg.softTrim.maxChars)
		) {
			s.softTrim.maxChars = Math.max(0, Math.floor(cfg.softTrim.maxChars));
		}
		if (
			typeof cfg.softTrim.headChars === "number" &&
			Number.isFinite(cfg.softTrim.headChars)
		) {
			s.softTrim.headChars = Math.max(0, Math.floor(cfg.softTrim.headChars));
		}
		if (
			typeof cfg.softTrim.tailChars === "number" &&
			Number.isFinite(cfg.softTrim.tailChars)
		) {
			s.softTrim.tailChars = Math.max(0, Math.floor(cfg.softTrim.tailChars));
		}
	}
	if (cfg.hardClear) {
		// Aggressive mode always hard-clears; only adaptive may opt out.
		if (s.mode === "adaptive" && typeof cfg.hardClear.enabled === "boolean") {
			s.hardClear.enabled = cfg.hardClear.enabled;
		}
		if (
			typeof cfg.hardClear.placeholder === "string" &&
			cfg.hardClear.placeholder.trim()
		) {
			s.hardClear.placeholder = cfg.hardClear.placeholder.trim();
		}
	}
	return s;
}

View File

@@ -0,0 +1,46 @@
import type { ContextPruningToolMatch } from "./settings.js";
/** Trims, stringifies, and drops empty entries from a raw pattern list. */
function normalizePatterns(patterns?: string[]): string[] {
	if (!Array.isArray(patterns)) return [];
	const cleaned: string[] = [];
	for (const raw of patterns) {
		const trimmed = String(raw ?? "").trim();
		if (trimmed) cleaned.push(trimmed);
	}
	return cleaned;
}

type CompiledPattern =
	| { kind: "all" }
	| { kind: "exact"; value: string }
	| { kind: "regex"; value: RegExp };

/** Compiles one glob-ish pattern (only "*" is a wildcard) for fast matching. */
function compilePattern(pattern: string): CompiledPattern {
	if (pattern === "*") return { kind: "all" };
	if (!pattern.includes("*")) return { kind: "exact", value: pattern };
	// Escape regex metacharacters, then turn each (escaped) "*" into ".*".
	const escaped = pattern.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
	const source = `^${escaped.replaceAll("\\*", ".*")}$`;
	return { kind: "regex", value: new RegExp(source) };
}

function compilePatterns(patterns?: string[]): CompiledPattern[] {
	return normalizePatterns(patterns).map(compilePattern);
}

function matchesAny(toolName: string, patterns: CompiledPattern[]): boolean {
	return patterns.some((p) =>
		p.kind === "all"
			? true
			: p.kind === "exact"
				? toolName === p.value
				: p.value.test(toolName),
	);
}

/**
 * Builds a predicate deciding whether a tool's results may be pruned.
 * Deny patterns always win; an empty allow list means "allow everything".
 */
export function makeToolPrunablePredicate(
	match: ContextPruningToolMatch,
): (toolName: string) => boolean {
	const denyPatterns = compilePatterns(match.deny);
	const allowPatterns = compilePatterns(match.allow);
	return (toolName: string) => {
		if (matchesAny(toolName, denyPatterns)) return false;
		return allowPatterns.length === 0 || matchesAny(toolName, allowPatterns);
	};
}