mirror of
https://github.com/openclaw/openclaw.git
synced 2026-04-18 12:37:27 +00:00
This commit is contained in:
committed by
GitHub
parent
6c6f1e9660
commit
8eb11bd304
@@ -443,9 +443,16 @@ export async function runEmbeddedAttempt(
|
||||
// Add client tools (OpenResponses hosted tools) to customTools
|
||||
let clientToolCallDetected: { name: string; params: Record<string, unknown> } | null = null;
|
||||
const clientToolDefs = params.clientTools
|
||||
? toClientToolDefinitions(params.clientTools, (toolName, toolParams) => {
|
||||
clientToolCallDetected = { name: toolName, params: toolParams };
|
||||
})
|
||||
? toClientToolDefinitions(
|
||||
params.clientTools,
|
||||
(toolName, toolParams) => {
|
||||
clientToolCallDetected = { name: toolName, params: toolParams };
|
||||
},
|
||||
{
|
||||
agentId: sessionAgentId,
|
||||
sessionKey: params.sessionKey,
|
||||
},
|
||||
)
|
||||
: [];
|
||||
|
||||
const allCustomTools = [...customTools, ...clientToolDefs];
|
||||
|
||||
@@ -1,351 +0,0 @@
|
||||
import type { AgentEvent } from "@mariozechner/pi-agent-core";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import type { EmbeddedPiSubscribeContext } from "./pi-embedded-subscribe.handlers.types.js";
|
||||
import { getGlobalHookRunner } from "../plugins/hook-runner-global.js";
|
||||
import { handleToolExecutionStart } from "./pi-embedded-subscribe.handlers.tools.js";
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock("../plugins/hook-runner-global.js");
|
||||
vi.mock("../infra/agent-events.js", () => ({
|
||||
emitAgentEvent: vi.fn(),
|
||||
}));
|
||||
vi.mock("./pi-embedded-helpers.js");
|
||||
vi.mock("./pi-embedded-messaging.js");
|
||||
vi.mock("./pi-embedded-subscribe.tools.js");
|
||||
vi.mock("./pi-embedded-utils.js", () => ({
|
||||
inferToolMetaFromArgs: vi.fn(() => undefined),
|
||||
}));
|
||||
vi.mock("./tool-policy.js", () => ({
|
||||
normalizeToolName: vi.fn((name: string) => name.toLowerCase()),
|
||||
}));
|
||||
|
||||
const mockGetGlobalHookRunner = vi.mocked(getGlobalHookRunner);
|
||||
|
||||
describe("before_tool_call hook integration", () => {
|
||||
let mockContext: EmbeddedPiSubscribeContext;
|
||||
let mockHookRunner: any;
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset mocks
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Mock context
|
||||
mockContext = {
|
||||
params: {
|
||||
runId: "test-run-123",
|
||||
session: { key: "test-session" },
|
||||
onBlockReplyFlush: vi.fn(),
|
||||
onAgentEvent: vi.fn(),
|
||||
},
|
||||
state: {
|
||||
toolMetaById: {
|
||||
set: vi.fn(),
|
||||
get: vi.fn(),
|
||||
has: vi.fn(),
|
||||
},
|
||||
},
|
||||
log: {
|
||||
debug: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
},
|
||||
flushBlockReplyBuffer: vi.fn(),
|
||||
shouldEmitToolResult: vi.fn().mockReturnValue(true),
|
||||
} as any;
|
||||
|
||||
// Mock hook runner
|
||||
mockHookRunner = {
|
||||
hasHooks: vi.fn(),
|
||||
runBeforeToolCall: vi.fn(),
|
||||
};
|
||||
|
||||
mockGetGlobalHookRunner.mockReturnValue(mockHookRunner);
|
||||
});
|
||||
|
||||
describe("when no hooks are registered", () => {
|
||||
beforeEach(() => {
|
||||
mockHookRunner.hasHooks.mockReturnValue(false);
|
||||
});
|
||||
|
||||
it("should proceed with tool execution normally", async () => {
|
||||
const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = {
|
||||
type: "tool_start",
|
||||
toolName: "TestTool",
|
||||
toolCallId: "tool-call-123",
|
||||
args: { param: "value" },
|
||||
};
|
||||
|
||||
// Should not throw
|
||||
await expect(handleToolExecutionStart(mockContext, event)).resolves.toBeUndefined();
|
||||
|
||||
// Hook runner should check for hooks but not run them
|
||||
expect(mockHookRunner.hasHooks).toHaveBeenCalledWith("before_tool_call");
|
||||
expect(mockHookRunner.runBeforeToolCall).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("when hooks are registered", () => {
|
||||
beforeEach(() => {
|
||||
mockHookRunner.hasHooks.mockReturnValue(true);
|
||||
});
|
||||
|
||||
it("should call the hook with correct parameters", async () => {
|
||||
mockHookRunner.runBeforeToolCall.mockResolvedValue(undefined);
|
||||
|
||||
const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = {
|
||||
type: "tool_start",
|
||||
toolName: "TestTool",
|
||||
toolCallId: "tool-call-123",
|
||||
args: { param: "value" },
|
||||
};
|
||||
|
||||
await handleToolExecutionStart(mockContext, event);
|
||||
|
||||
expect(mockHookRunner.runBeforeToolCall).toHaveBeenCalledWith(
|
||||
{
|
||||
toolName: "testtool", // normalized
|
||||
params: { param: "value" },
|
||||
},
|
||||
{
|
||||
toolName: "testtool",
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
it("should allow hook to modify parameters", async () => {
|
||||
const modifiedParams = { param: "modified_value", newParam: "added" };
|
||||
mockHookRunner.runBeforeToolCall.mockResolvedValue({
|
||||
params: modifiedParams,
|
||||
});
|
||||
|
||||
const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = {
|
||||
type: "tool_start",
|
||||
toolName: "TestTool",
|
||||
toolCallId: "tool-call-123",
|
||||
args: { param: "value" },
|
||||
};
|
||||
|
||||
// The function should complete without error
|
||||
await expect(handleToolExecutionStart(mockContext, event)).resolves.toBeUndefined();
|
||||
|
||||
expect(mockHookRunner.runBeforeToolCall).toHaveBeenCalledWith(
|
||||
{
|
||||
toolName: "testtool",
|
||||
params: { param: "value" },
|
||||
},
|
||||
{
|
||||
toolName: "testtool",
|
||||
},
|
||||
);
|
||||
|
||||
// Hook should be called and parameter modification should work
|
||||
expect(mockHookRunner.runBeforeToolCall).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should handle parameter modification with non-object args safely", async () => {
|
||||
const modifiedParams = { newParam: "replaced" };
|
||||
mockHookRunner.runBeforeToolCall.mockResolvedValue({
|
||||
params: modifiedParams,
|
||||
});
|
||||
|
||||
const testCases = [
|
||||
{ args: null, description: "null args" },
|
||||
{ args: "string", description: "string args" },
|
||||
{ args: 123, description: "number args" },
|
||||
{ args: [1, 2, 3], description: "array args" },
|
||||
];
|
||||
|
||||
for (const { args, description } of testCases) {
|
||||
mockHookRunner.runBeforeToolCall.mockClear();
|
||||
|
||||
const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = {
|
||||
type: "tool_start",
|
||||
toolName: "TestTool",
|
||||
toolCallId: `call-${description}`,
|
||||
args,
|
||||
};
|
||||
|
||||
// Should not crash even with non-object args
|
||||
await expect(handleToolExecutionStart(mockContext, event)).resolves.toBeUndefined();
|
||||
|
||||
// Hook should be called with normalized empty params
|
||||
expect(mockHookRunner.runBeforeToolCall).toHaveBeenCalledWith(
|
||||
{
|
||||
toolName: "testtool",
|
||||
params: {}, // Non-objects normalized to empty object
|
||||
},
|
||||
{
|
||||
toolName: "testtool",
|
||||
},
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it("should block tool call when hook returns block=true", async () => {
|
||||
const blockReason = "Tool blocked by security policy";
|
||||
const mockResult = {
|
||||
block: true,
|
||||
blockReason,
|
||||
};
|
||||
|
||||
mockHookRunner.runBeforeToolCall.mockResolvedValue(mockResult);
|
||||
|
||||
const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = {
|
||||
type: "tool_start",
|
||||
toolName: "BlockedTool",
|
||||
toolCallId: "tool-call-456",
|
||||
args: { dangerous: "payload" },
|
||||
};
|
||||
|
||||
// Should throw an error with the block reason
|
||||
await expect(handleToolExecutionStart(mockContext, event)).rejects.toThrow(blockReason);
|
||||
|
||||
// Should log the block
|
||||
expect(mockContext.log.debug).toHaveBeenCalledWith(
|
||||
expect.stringContaining("Tool call blocked by plugin hook"),
|
||||
);
|
||||
expect(mockContext.log.debug).toHaveBeenCalledWith(expect.stringContaining(blockReason));
|
||||
|
||||
// Should update internal state like normal tool flow
|
||||
expect(mockContext.state.toolMetaById.set).toHaveBeenCalled();
|
||||
expect(mockContext.params.onAgentEvent).toHaveBeenCalledWith({
|
||||
stream: "tool",
|
||||
data: { phase: "start", name: "blockedtool", toolCallId: "tool-call-456" },
|
||||
});
|
||||
});
|
||||
|
||||
it("should block tool call with default reason when no blockReason provided", async () => {
|
||||
mockHookRunner.runBeforeToolCall.mockResolvedValue({
|
||||
block: true,
|
||||
// no blockReason
|
||||
});
|
||||
|
||||
const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = {
|
||||
type: "tool_start",
|
||||
toolName: "BlockedTool",
|
||||
toolCallId: "tool-call-789",
|
||||
args: {},
|
||||
};
|
||||
|
||||
// Should throw with default message
|
||||
await expect(handleToolExecutionStart(mockContext, event)).rejects.toThrow(
|
||||
"Tool call blocked by plugin hook",
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle hook errors gracefully and continue execution", async () => {
|
||||
const hookError = new Error("Hook implementation error");
|
||||
mockHookRunner.runBeforeToolCall.mockRejectedValue(hookError);
|
||||
|
||||
const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = {
|
||||
type: "tool_start",
|
||||
toolName: "TestTool",
|
||||
toolCallId: "tool-call-999",
|
||||
args: { param: "value" },
|
||||
};
|
||||
|
||||
// Should not throw - hook errors should be caught
|
||||
await expect(handleToolExecutionStart(mockContext, event)).resolves.toBeUndefined();
|
||||
|
||||
// Should log the hook error
|
||||
expect(mockContext.log.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining("before_tool_call hook failed"),
|
||||
);
|
||||
expect(mockContext.log.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining("Hook implementation error"),
|
||||
);
|
||||
});
|
||||
|
||||
it("should re-throw blocking errors even when caught", async () => {
|
||||
const blockReason = "Blocked by security";
|
||||
mockHookRunner.runBeforeToolCall.mockResolvedValue({
|
||||
block: true,
|
||||
blockReason,
|
||||
});
|
||||
|
||||
const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = {
|
||||
type: "tool_start",
|
||||
toolName: "TestTool",
|
||||
toolCallId: "tool-call-000",
|
||||
args: {},
|
||||
};
|
||||
|
||||
// The blocking error should still be thrown
|
||||
await expect(handleToolExecutionStart(mockContext, event)).rejects.toThrow(blockReason);
|
||||
});
|
||||
});
|
||||
|
||||
describe("hook context handling", () => {
|
||||
beforeEach(() => {
|
||||
mockHookRunner.hasHooks.mockReturnValue(true);
|
||||
mockHookRunner.runBeforeToolCall.mockResolvedValue(undefined);
|
||||
});
|
||||
|
||||
it("should handle various tool name formats", async () => {
|
||||
const testCases = [
|
||||
{ input: "ReadFile", expected: "readfile" },
|
||||
{ input: "EXEC", expected: "exec" },
|
||||
{ input: "bash-command", expected: "bash-command" },
|
||||
{ input: " SpacedTool ", expected: " spacedtool " },
|
||||
];
|
||||
|
||||
for (const { input, expected } of testCases) {
|
||||
mockHookRunner.runBeforeToolCall.mockClear();
|
||||
|
||||
const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = {
|
||||
type: "tool_start",
|
||||
toolName: input,
|
||||
toolCallId: `call-${input}`,
|
||||
args: {},
|
||||
};
|
||||
|
||||
await handleToolExecutionStart(mockContext, event);
|
||||
|
||||
expect(mockHookRunner.runBeforeToolCall).toHaveBeenCalledWith(
|
||||
{
|
||||
toolName: expected,
|
||||
params: {},
|
||||
},
|
||||
{
|
||||
toolName: expected,
|
||||
},
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it("should handle different argument types", async () => {
|
||||
const testCases = [
|
||||
// Non-objects get normalized to {} for hook params (to maintain hook contract)
|
||||
{ args: null, expectedParams: {} },
|
||||
{ args: undefined, expectedParams: {} },
|
||||
{ args: "string", expectedParams: {} },
|
||||
{ args: 123, expectedParams: {} },
|
||||
{ args: [1, 2, 3], expectedParams: {} }, // arrays are not plain objects
|
||||
// Only plain objects are passed through
|
||||
{ args: { key: "value" }, expectedParams: { key: "value" } },
|
||||
];
|
||||
|
||||
for (const { args, expectedParams } of testCases) {
|
||||
mockHookRunner.runBeforeToolCall.mockClear();
|
||||
|
||||
const event: AgentEvent & { toolName: string; toolCallId: string; args: unknown } = {
|
||||
type: "tool_start",
|
||||
toolName: "TestTool",
|
||||
toolCallId: `call-${typeof args}`,
|
||||
args,
|
||||
};
|
||||
|
||||
await handleToolExecutionStart(mockContext, event);
|
||||
|
||||
expect(mockHookRunner.runBeforeToolCall).toHaveBeenCalledWith(
|
||||
{
|
||||
toolName: "testtool",
|
||||
params: expectedParams,
|
||||
},
|
||||
{
|
||||
toolName: "testtool",
|
||||
},
|
||||
);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,16 +1,7 @@
|
||||
import type { AgentEvent } from "@mariozechner/pi-agent-core";
|
||||
import type { EmbeddedPiSubscribeContext } from "./pi-embedded-subscribe.handlers.types.js";
|
||||
import { emitAgentEvent } from "../infra/agent-events.js";
|
||||
import { getGlobalHookRunner } from "../plugins/hook-runner-global.js";
|
||||
import { normalizeTextForComparison } from "./pi-embedded-helpers.js";
|
||||
|
||||
// Dedicated error class for hook blocking to avoid magic property issues
|
||||
class ToolBlockedError extends Error {
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = "ToolBlockedError";
|
||||
}
|
||||
}
|
||||
import { isMessagingTool, isMessagingToolSendAction } from "./pi-embedded-messaging.js";
|
||||
import {
|
||||
extractToolErrorMessage,
|
||||
@@ -58,94 +49,7 @@ export async function handleToolExecutionStart(
|
||||
const rawToolName = String(evt.toolName);
|
||||
const toolName = normalizeToolName(rawToolName);
|
||||
const toolCallId = String(evt.toolCallId);
|
||||
let args = evt.args;
|
||||
|
||||
// Run before_tool_call hook - allows plugins to modify or block tool calls
|
||||
const hookRunner = getGlobalHookRunner();
|
||||
if (hookRunner?.hasHooks("before_tool_call")) {
|
||||
try {
|
||||
// Normalize args to object for hook contract - plugins expect params to be an object
|
||||
const normalizedParams =
|
||||
args && typeof args === "object" && !Array.isArray(args)
|
||||
? (args as Record<string, unknown>)
|
||||
: {};
|
||||
|
||||
const hookResult = await hookRunner.runBeforeToolCall(
|
||||
{
|
||||
toolName,
|
||||
params: normalizedParams,
|
||||
},
|
||||
{
|
||||
toolName,
|
||||
},
|
||||
);
|
||||
|
||||
// Check if hook blocked the tool call
|
||||
if (hookResult?.block) {
|
||||
const blockReason = hookResult.blockReason || "Tool call blocked by plugin hook";
|
||||
|
||||
// Update internal state to match normal tool execution flow
|
||||
const meta = extendExecMeta(toolName, args, inferToolMetaFromArgs(toolName, args));
|
||||
ctx.state.toolMetaById.set(toolCallId, meta);
|
||||
|
||||
ctx.log.debug(
|
||||
`Tool call blocked by plugin hook: runId=${ctx.params.runId} tool=${toolName} toolCallId=${toolCallId} reason=${blockReason}`,
|
||||
);
|
||||
|
||||
// Emit tool start/end events with error to maintain event consistency
|
||||
emitAgentEvent({
|
||||
runId: ctx.params.runId,
|
||||
stream: "tool",
|
||||
data: {
|
||||
phase: "start",
|
||||
name: toolName,
|
||||
toolCallId,
|
||||
args: args as Record<string, unknown>,
|
||||
},
|
||||
});
|
||||
|
||||
// Call onAgentEvent callback to match normal flow
|
||||
void ctx.params.onAgentEvent?.({
|
||||
stream: "tool",
|
||||
data: { phase: "start", name: toolName, toolCallId },
|
||||
});
|
||||
|
||||
emitAgentEvent({
|
||||
runId: ctx.params.runId,
|
||||
stream: "tool",
|
||||
data: {
|
||||
phase: "end",
|
||||
name: toolName,
|
||||
toolCallId,
|
||||
error: blockReason,
|
||||
},
|
||||
});
|
||||
|
||||
// Throw dedicated error class instead of using magic properties
|
||||
throw new ToolBlockedError(blockReason);
|
||||
}
|
||||
|
||||
// If hook modified params, update args safely
|
||||
if (hookResult?.params) {
|
||||
if (args && typeof args === "object" && !Array.isArray(args)) {
|
||||
// Safe to merge with existing object args
|
||||
args = { ...(args as Record<string, unknown>), ...hookResult.params };
|
||||
} else {
|
||||
// For non-object args, replace entirely with hook params
|
||||
args = hookResult.params;
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
// If it's our blocking error, re-throw it
|
||||
if (err instanceof ToolBlockedError) {
|
||||
throw err;
|
||||
}
|
||||
// For other hook errors, log but don't block the tool call
|
||||
ctx.log.warn(
|
||||
`before_tool_call hook failed: runId=${ctx.params.runId} tool=${toolName} toolCallId=${toolCallId} error=${String(err)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
const args = evt.args;
|
||||
|
||||
if (toolName === "read") {
|
||||
const record = args && typeof args === "object" ? (args as Record<string, unknown>) : {};
|
||||
|
||||
@@ -6,12 +6,17 @@ import type {
|
||||
import type { ToolDefinition } from "@mariozechner/pi-coding-agent";
|
||||
import type { ClientToolDefinition } from "./pi-embedded-runner/run/params.js";
|
||||
import { logDebug, logError } from "../logger.js";
|
||||
import { runBeforeToolCallHook } from "./pi-tools.before-tool-call.js";
|
||||
import { normalizeToolName } from "./tool-policy.js";
|
||||
import { jsonResult } from "./tools/common.js";
|
||||
|
||||
// biome-ignore lint/suspicious/noExplicitAny: TypeBox schema type from pi-agent-core uses a different module instance.
|
||||
type AnyAgentTool = AgentTool<any, unknown>;
|
||||
|
||||
function isPlainObject(value: unknown): value is Record<string, unknown> {
|
||||
return typeof value === "object" && value !== null && !Array.isArray(value);
|
||||
}
|
||||
|
||||
function describeToolExecutionError(err: unknown): {
|
||||
message: string;
|
||||
stack?: string;
|
||||
@@ -76,6 +81,7 @@ export function toToolDefinitions(tools: AnyAgentTool[]): ToolDefinition[] {
|
||||
export function toClientToolDefinitions(
|
||||
tools: ClientToolDefinition[],
|
||||
onClientToolCall?: (toolName: string, params: Record<string, unknown>) => void,
|
||||
hookContext?: { agentId?: string; sessionKey?: string },
|
||||
): ToolDefinition[] {
|
||||
return tools.map((tool) => {
|
||||
const func = tool.function;
|
||||
@@ -91,9 +97,20 @@ export function toClientToolDefinitions(
|
||||
_ctx,
|
||||
_signal,
|
||||
): Promise<AgentToolResult<unknown>> => {
|
||||
const outcome = await runBeforeToolCallHook({
|
||||
toolName: func.name,
|
||||
params,
|
||||
toolCallId,
|
||||
ctx: hookContext,
|
||||
});
|
||||
if (outcome.blocked) {
|
||||
throw new Error(outcome.reason);
|
||||
}
|
||||
const adjustedParams = outcome.params;
|
||||
const paramsRecord = isPlainObject(adjustedParams) ? adjustedParams : {};
|
||||
// Notify handler that a client tool was called
|
||||
if (onClientToolCall) {
|
||||
onClientToolCall(func.name, params as Record<string, unknown>);
|
||||
onClientToolCall(func.name, paramsRecord);
|
||||
}
|
||||
// Return a pending result - the client will execute this tool
|
||||
return jsonResult({
|
||||
|
||||
145
src/agents/pi-tools.before-tool-call.test.ts
Normal file
145
src/agents/pi-tools.before-tool-call.test.ts
Normal file
@@ -0,0 +1,145 @@
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { getGlobalHookRunner } from "../plugins/hook-runner-global.js";
|
||||
import { toClientToolDefinitions } from "./pi-tool-definition-adapter.js";
|
||||
import { wrapToolWithBeforeToolCallHook } from "./pi-tools.before-tool-call.js";
|
||||
|
||||
vi.mock("../plugins/hook-runner-global.js");
|
||||
|
||||
const mockGetGlobalHookRunner = vi.mocked(getGlobalHookRunner);
|
||||
|
||||
describe("before_tool_call hook integration", () => {
|
||||
let hookRunner: {
|
||||
hasHooks: ReturnType<typeof vi.fn>;
|
||||
runBeforeToolCall: ReturnType<typeof vi.fn>;
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
hookRunner = {
|
||||
hasHooks: vi.fn(),
|
||||
runBeforeToolCall: vi.fn(),
|
||||
};
|
||||
mockGetGlobalHookRunner.mockReturnValue(hookRunner as any);
|
||||
});
|
||||
|
||||
it("executes tool normally when no hook is registered", async () => {
|
||||
hookRunner.hasHooks.mockReturnValue(false);
|
||||
const execute = vi.fn().mockResolvedValue({ content: [], details: { ok: true } });
|
||||
const tool = wrapToolWithBeforeToolCallHook({ name: "Read", execute } as any, {
|
||||
agentId: "main",
|
||||
sessionKey: "main",
|
||||
});
|
||||
|
||||
await tool.execute("call-1", { path: "/tmp/file" }, undefined, undefined);
|
||||
|
||||
expect(hookRunner.runBeforeToolCall).not.toHaveBeenCalled();
|
||||
expect(execute).toHaveBeenCalledWith("call-1", { path: "/tmp/file" }, undefined, undefined);
|
||||
});
|
||||
|
||||
it("allows hook to modify parameters", async () => {
|
||||
hookRunner.hasHooks.mockReturnValue(true);
|
||||
hookRunner.runBeforeToolCall.mockResolvedValue({ params: { mode: "safe" } });
|
||||
const execute = vi.fn().mockResolvedValue({ content: [], details: { ok: true } });
|
||||
const tool = wrapToolWithBeforeToolCallHook({ name: "exec", execute } as any);
|
||||
|
||||
await tool.execute("call-2", { cmd: "ls" }, undefined, undefined);
|
||||
|
||||
expect(execute).toHaveBeenCalledWith(
|
||||
"call-2",
|
||||
{ cmd: "ls", mode: "safe" },
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it("blocks tool execution when hook returns block=true", async () => {
|
||||
hookRunner.hasHooks.mockReturnValue(true);
|
||||
hookRunner.runBeforeToolCall.mockResolvedValue({
|
||||
block: true,
|
||||
blockReason: "blocked",
|
||||
});
|
||||
const execute = vi.fn().mockResolvedValue({ content: [], details: { ok: true } });
|
||||
const tool = wrapToolWithBeforeToolCallHook({ name: "exec", execute } as any);
|
||||
|
||||
await expect(tool.execute("call-3", { cmd: "rm -rf /" }, undefined, undefined)).rejects.toThrow(
|
||||
"blocked",
|
||||
);
|
||||
expect(execute).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("continues execution when hook throws", async () => {
|
||||
hookRunner.hasHooks.mockReturnValue(true);
|
||||
hookRunner.runBeforeToolCall.mockRejectedValue(new Error("boom"));
|
||||
const execute = vi.fn().mockResolvedValue({ content: [], details: { ok: true } });
|
||||
const tool = wrapToolWithBeforeToolCallHook({ name: "read", execute } as any);
|
||||
|
||||
await tool.execute("call-4", { path: "/tmp/file" }, undefined, undefined);
|
||||
|
||||
expect(execute).toHaveBeenCalledWith("call-4", { path: "/tmp/file" }, undefined, undefined);
|
||||
});
|
||||
|
||||
it("normalizes non-object params for hook contract", async () => {
|
||||
hookRunner.hasHooks.mockReturnValue(true);
|
||||
hookRunner.runBeforeToolCall.mockResolvedValue(undefined);
|
||||
const execute = vi.fn().mockResolvedValue({ content: [], details: { ok: true } });
|
||||
const tool = wrapToolWithBeforeToolCallHook({ name: "ReAd", execute } as any, {
|
||||
agentId: "main",
|
||||
sessionKey: "main",
|
||||
});
|
||||
|
||||
await tool.execute("call-5", "not-an-object", undefined, undefined);
|
||||
|
||||
expect(hookRunner.runBeforeToolCall).toHaveBeenCalledWith(
|
||||
{
|
||||
toolName: "read",
|
||||
params: {},
|
||||
},
|
||||
{
|
||||
toolName: "read",
|
||||
agentId: "main",
|
||||
sessionKey: "main",
|
||||
},
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("before_tool_call hook integration for client tools", () => {
|
||||
let hookRunner: {
|
||||
hasHooks: ReturnType<typeof vi.fn>;
|
||||
runBeforeToolCall: ReturnType<typeof vi.fn>;
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
hookRunner = {
|
||||
hasHooks: vi.fn(),
|
||||
runBeforeToolCall: vi.fn(),
|
||||
};
|
||||
mockGetGlobalHookRunner.mockReturnValue(hookRunner as any);
|
||||
});
|
||||
|
||||
it("passes modified params to client tool callbacks", async () => {
|
||||
hookRunner.hasHooks.mockReturnValue(true);
|
||||
hookRunner.runBeforeToolCall.mockResolvedValue({ params: { extra: true } });
|
||||
const onClientToolCall = vi.fn();
|
||||
const [tool] = toClientToolDefinitions(
|
||||
[
|
||||
{
|
||||
type: "function",
|
||||
function: {
|
||||
name: "client_tool",
|
||||
description: "Client tool",
|
||||
parameters: { type: "object", properties: { value: { type: "string" } } },
|
||||
},
|
||||
},
|
||||
],
|
||||
onClientToolCall,
|
||||
{ agentId: "main", sessionKey: "main" },
|
||||
);
|
||||
|
||||
await tool.execute("client-call-1", { value: "ok" }, undefined, undefined, undefined);
|
||||
|
||||
expect(onClientToolCall).toHaveBeenCalledWith("client_tool", {
|
||||
value: "ok",
|
||||
extra: true,
|
||||
});
|
||||
});
|
||||
});
|
||||
96
src/agents/pi-tools.before-tool-call.ts
Normal file
96
src/agents/pi-tools.before-tool-call.ts
Normal file
@@ -0,0 +1,96 @@
|
||||
import type { AnyAgentTool } from "./tools/common.js";
|
||||
import { createSubsystemLogger } from "../logging/subsystem.js";
|
||||
import { getGlobalHookRunner } from "../plugins/hook-runner-global.js";
|
||||
import { normalizeToolName } from "./tool-policy.js";
|
||||
|
||||
type HookContext = {
|
||||
agentId?: string;
|
||||
sessionKey?: string;
|
||||
};
|
||||
|
||||
type HookOutcome = { blocked: true; reason: string } | { blocked: false; params: unknown };
|
||||
|
||||
const log = createSubsystemLogger("agents/tools");
|
||||
|
||||
function isPlainObject(value: unknown): value is Record<string, unknown> {
|
||||
return typeof value === "object" && value !== null && !Array.isArray(value);
|
||||
}
|
||||
|
||||
export async function runBeforeToolCallHook(args: {
|
||||
toolName: string;
|
||||
params: unknown;
|
||||
toolCallId?: string;
|
||||
ctx?: HookContext;
|
||||
}): Promise<HookOutcome> {
|
||||
const hookRunner = getGlobalHookRunner();
|
||||
if (!hookRunner?.hasHooks("before_tool_call")) {
|
||||
return { blocked: false, params: args.params };
|
||||
}
|
||||
|
||||
const toolName = normalizeToolName(args.toolName || "tool");
|
||||
const params = args.params;
|
||||
try {
|
||||
const normalizedParams = isPlainObject(params) ? params : {};
|
||||
const hookResult = await hookRunner.runBeforeToolCall(
|
||||
{
|
||||
toolName,
|
||||
params: normalizedParams,
|
||||
},
|
||||
{
|
||||
toolName,
|
||||
agentId: args.ctx?.agentId,
|
||||
sessionKey: args.ctx?.sessionKey,
|
||||
},
|
||||
);
|
||||
|
||||
if (hookResult?.block) {
|
||||
return {
|
||||
blocked: true,
|
||||
reason: hookResult.blockReason || "Tool call blocked by plugin hook",
|
||||
};
|
||||
}
|
||||
|
||||
if (hookResult?.params && isPlainObject(hookResult.params)) {
|
||||
if (isPlainObject(params)) {
|
||||
return { blocked: false, params: { ...params, ...hookResult.params } };
|
||||
}
|
||||
return { blocked: false, params: hookResult.params };
|
||||
}
|
||||
} catch (err) {
|
||||
const toolCallId = args.toolCallId ? ` toolCallId=${args.toolCallId}` : "";
|
||||
log.warn(`before_tool_call hook failed: tool=${toolName}${toolCallId} error=${String(err)}`);
|
||||
}
|
||||
|
||||
return { blocked: false, params };
|
||||
}
|
||||
|
||||
export function wrapToolWithBeforeToolCallHook(
|
||||
tool: AnyAgentTool,
|
||||
ctx?: HookContext,
|
||||
): AnyAgentTool {
|
||||
const execute = tool.execute;
|
||||
if (!execute) {
|
||||
return tool;
|
||||
}
|
||||
const toolName = tool.name || "tool";
|
||||
return {
|
||||
...tool,
|
||||
execute: async (toolCallId, params, signal, onUpdate) => {
|
||||
const outcome = await runBeforeToolCallHook({
|
||||
toolName,
|
||||
params,
|
||||
toolCallId,
|
||||
ctx,
|
||||
});
|
||||
if (outcome.blocked) {
|
||||
throw new Error(outcome.reason);
|
||||
}
|
||||
return await execute(toolCallId, outcome.params, signal, onUpdate);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export const __testing = {
|
||||
runBeforeToolCallHook,
|
||||
isPlainObject,
|
||||
};
|
||||
@@ -23,6 +23,7 @@ import {
|
||||
import { listChannelAgentTools } from "./channel-tools.js";
|
||||
import { createOpenClawTools } from "./openclaw-tools.js";
|
||||
import { wrapToolWithAbortSignal } from "./pi-tools.abort.js";
|
||||
import { wrapToolWithBeforeToolCallHook } from "./pi-tools.before-tool-call.js";
|
||||
import {
|
||||
filterToolsByPolicy,
|
||||
isToolAllowedByPolicies,
|
||||
@@ -423,9 +424,15 @@ export function createOpenClawCodingTools(options?: {
|
||||
// Always normalize tool JSON Schemas before handing them to pi-agent/pi-ai.
|
||||
// Without this, some providers (notably OpenAI) will reject root-level union schemas.
|
||||
const normalized = subagentFiltered.map(normalizeToolParameters);
|
||||
const withHooks = normalized.map((tool) =>
|
||||
wrapToolWithBeforeToolCallHook(tool, {
|
||||
agentId,
|
||||
sessionKey: options?.sessionKey,
|
||||
}),
|
||||
);
|
||||
const withAbort = options?.abortSignal
|
||||
? normalized.map((tool) => wrapToolWithAbortSignal(tool, options.abortSignal))
|
||||
: normalized;
|
||||
? withHooks.map((tool) => wrapToolWithAbortSignal(tool, options.abortSignal))
|
||||
: withHooks;
|
||||
|
||||
// NOTE: Keep canonical (lowercase) tool names here.
|
||||
// pi-ai's Anthropic OAuth transport remaps tool names to Claude Code-style names
|
||||
|
||||
Reference in New Issue
Block a user