Agents: flush pending tool results on drop

This commit is contained in:
Shakker
2026-02-02 23:46:34 +00:00
committed by Shakker
parent e6fdac7bfb
commit befa421a57
2 changed files with 99 additions and 48 deletions

View File

@@ -3,10 +3,14 @@ import { SessionManager } from "@mariozechner/pi-coding-agent";
import { describe, expect, it } from "vitest"; import { describe, expect, it } from "vitest";
import { installSessionToolResultGuard } from "./session-tool-result-guard.js"; import { installSessionToolResultGuard } from "./session-tool-result-guard.js";
const toolCallMessage = { type AppendMessage = Parameters<SessionManager["appendMessage"]>[0];
const asAppendMessage = (message: unknown) => message as AppendMessage;
const toolCallMessage = asAppendMessage({
role: "assistant", role: "assistant",
content: [{ type: "toolCall", id: "call_1", name: "read", arguments: {} }], content: [{ type: "toolCall", id: "call_1", name: "read", arguments: {} }],
} satisfies AgentMessage; });
describe("installSessionToolResultGuard", () => { describe("installSessionToolResultGuard", () => {
it("inserts synthetic toolResult before non-tool message when pending", () => { it("inserts synthetic toolResult before non-tool message when pending", () => {
@@ -14,11 +18,13 @@ describe("installSessionToolResultGuard", () => {
installSessionToolResultGuard(sm); installSessionToolResultGuard(sm);
sm.appendMessage(toolCallMessage); sm.appendMessage(toolCallMessage);
sm.appendMessage({ sm.appendMessage(
role: "assistant", asAppendMessage({
content: [{ type: "text", text: "error" }], role: "assistant",
stopReason: "error", content: [{ type: "text", text: "error" }],
} as AgentMessage); stopReason: "error",
}),
);
const entries = sm const entries = sm
.getEntries() .getEntries()
@@ -56,12 +62,14 @@ describe("installSessionToolResultGuard", () => {
installSessionToolResultGuard(sm); installSessionToolResultGuard(sm);
sm.appendMessage(toolCallMessage); sm.appendMessage(toolCallMessage);
sm.appendMessage({ sm.appendMessage(
role: "toolResult", asAppendMessage({
toolCallId: "call_1", role: "toolResult",
content: [{ type: "text", text: "ok" }], toolCallId: "call_1",
isError: false, content: [{ type: "text", text: "ok" }],
} as AgentMessage); isError: false,
}),
);
const messages = sm const messages = sm
.getEntries() .getEntries()
@@ -75,23 +83,29 @@ describe("installSessionToolResultGuard", () => {
const sm = SessionManager.inMemory(); const sm = SessionManager.inMemory();
const guard = installSessionToolResultGuard(sm); const guard = installSessionToolResultGuard(sm);
sm.appendMessage({ sm.appendMessage(
role: "assistant", asAppendMessage({
content: [ role: "assistant",
{ type: "toolCall", id: "call_a", name: "one", arguments: {} }, content: [
{ type: "toolUse", id: "call_b", name: "two", arguments: {} }, { type: "toolCall", id: "call_a", name: "one", arguments: {} },
], { type: "toolUse", id: "call_b", name: "two", arguments: {} },
} as AgentMessage); ],
sm.appendMessage({ }),
role: "toolResult", );
toolUseId: "call_a", sm.appendMessage(
content: [{ type: "text", text: "a" }], asAppendMessage({
isError: false, role: "toolResult",
} as AgentMessage); toolUseId: "call_a",
sm.appendMessage({ content: [{ type: "text", text: "a" }],
role: "assistant", isError: false,
content: [{ type: "text", text: "after tools" }], }),
} as AgentMessage); );
sm.appendMessage(
asAppendMessage({
role: "assistant",
content: [{ type: "text", text: "after tools" }],
}),
);
const messages = sm const messages = sm
.getEntries() .getEntries()
@@ -113,11 +127,13 @@ describe("installSessionToolResultGuard", () => {
const guard = installSessionToolResultGuard(sm); const guard = installSessionToolResultGuard(sm);
sm.appendMessage(toolCallMessage); sm.appendMessage(toolCallMessage);
sm.appendMessage({ sm.appendMessage(
role: "assistant", asAppendMessage({
content: [{ type: "text", text: "hard error" }], role: "assistant",
stopReason: "error", content: [{ type: "text", text: "hard error" }],
} as AgentMessage); stopReason: "error",
}),
);
expect(guard.getPendingIds()).toEqual([]); expect(guard.getPendingIds()).toEqual([]);
}); });
@@ -125,15 +141,19 @@ describe("installSessionToolResultGuard", () => {
const sm = SessionManager.inMemory(); const sm = SessionManager.inMemory();
installSessionToolResultGuard(sm); installSessionToolResultGuard(sm);
sm.appendMessage({ sm.appendMessage(
role: "assistant", asAppendMessage({
content: [{ type: "toolUse", id: "use_1", name: "f", arguments: {} }], role: "assistant",
} as AgentMessage); content: [{ type: "toolUse", id: "use_1", name: "f", arguments: {} }],
sm.appendMessage({ }),
role: "toolResult", );
toolUseId: "use_1", sm.appendMessage(
content: [{ type: "text", text: "ok" }], asAppendMessage({
} as AgentMessage); role: "toolResult",
toolUseId: "use_1",
content: [{ type: "text", text: "ok" }],
}),
);
const messages = sm const messages = sm
.getEntries() .getEntries()
@@ -146,10 +166,12 @@ describe("installSessionToolResultGuard", () => {
const sm = SessionManager.inMemory(); const sm = SessionManager.inMemory();
installSessionToolResultGuard(sm); installSessionToolResultGuard(sm);
sm.appendMessage({ sm.appendMessage(
role: "assistant", asAppendMessage({
content: [{ type: "toolCall", id: "call_1", name: "read" }], role: "assistant",
} as AgentMessage); content: [{ type: "toolCall", id: "call_1", name: "read" }],
}),
);
const messages = sm const messages = sm
.getEntries() .getEntries()
@@ -158,4 +180,30 @@ describe("installSessionToolResultGuard", () => {
expect(messages).toHaveLength(0); expect(messages).toHaveLength(0);
}); });
it("flushes pending tool results when a sanitized assistant message is dropped", () => {
const sm = SessionManager.inMemory();
installSessionToolResultGuard(sm);
sm.appendMessage(
asAppendMessage({
role: "assistant",
content: [{ type: "toolCall", id: "call_1", name: "read", arguments: {} }],
}),
);
sm.appendMessage(
asAppendMessage({
role: "assistant",
content: [{ type: "toolCall", id: "call_2", name: "read" }],
}),
);
const messages = sm
.getEntries()
.filter((e) => e.type === "message")
.map((e) => (e as { message: AgentMessage }).message);
expect(messages.map((m) => m.role)).toEqual(["assistant", "toolResult"]);
});
}); });

View File

@@ -101,6 +101,9 @@ export function installSessionToolResultGuard(
if (role === "assistant") { if (role === "assistant") {
const sanitized = sanitizeToolCallInputs([message]); const sanitized = sanitizeToolCallInputs([message]);
if (sanitized.length === 0) { if (sanitized.length === 0) {
if (allowSyntheticToolResults && pending.size > 0) {
flushPendingToolResults();
}
return undefined; return undefined;
} }
nextMessage = sanitized[0]; nextMessage = sanitized[0];