feat: add configurable tool loop detection

This commit is contained in:
Peter Steinberger
2026-02-17 00:17:01 +01:00
parent dacffd7ac8
commit 076df941a3
14 changed files with 557 additions and 30 deletions

View File

@@ -49,7 +49,7 @@ import {
resolveCompactionReserveTokensFloor,
} from "../../pi-settings.js";
import { toClientToolDefinitions } from "../../pi-tool-definition-adapter.js";
import { createOpenClawCodingTools } from "../../pi-tools.js";
import { createOpenClawCodingTools, resolveToolLoopDetectionConfig } from "../../pi-tools.js";
import { resolveSandboxContext } from "../../sandbox.js";
import { resolveSandboxRuntimeStatus } from "../../sandbox/runtime-status.js";
import { repairSessionFileIfNeeded } from "../../session-file-repair.js";
@@ -544,6 +544,10 @@ export async function runEmbeddedAttempt(
// Add client tools (OpenResponses hosted tools) to customTools
let clientToolCallDetected: { name: string; params: Record<string, unknown> } | null = null;
const clientToolLoopDetection = resolveToolLoopDetectionConfig({
cfg: params.config,
agentId: sessionAgentId,
});
const clientToolDefs = params.clientTools
? toClientToolDefinitions(
params.clientTools,
@@ -553,6 +557,7 @@ export async function runEmbeddedAttempt(
{
agentId: sessionAgentId,
sessionKey: params.sessionKey,
loopDetection: clientToolLoopDetection,
},
)
: [];

View File

@@ -5,6 +5,7 @@ import type {
} from "@mariozechner/pi-agent-core";
import type { ToolDefinition } from "@mariozechner/pi-coding-agent";
import type { ClientToolDefinition } from "./pi-embedded-runner/run/params.js";
import type { HookContext } from "./pi-tools.before-tool-call.js";
import { logDebug, logError } from "../logger.js";
import { getGlobalHookRunner } from "../plugins/hook-runner-global.js";
import { isPlainObject } from "../utils.js";
@@ -190,7 +191,7 @@ export function toToolDefinitions(tools: AnyAgentTool[]): ToolDefinition[] {
export function toClientToolDefinitions(
tools: ClientToolDefinition[],
onClientToolCall?: (toolName: string, params: Record<string, unknown>) => void,
hookContext?: { agentId?: string; sessionKey?: string },
hookContext?: HookContext,
): ToolDefinition[] {
return tools.map((tool) => {
const func = tool.function;

View File

@@ -19,7 +19,17 @@ describe("before_tool_call loop detection behavior", () => {
hasHooks: ReturnType<typeof vi.fn>;
runBeforeToolCall: ReturnType<typeof vi.fn>;
};
const defaultToolContext = { agentId: "main", sessionKey: "main" };
// Hook context with loop detection explicitly switched on; used as the default
// context by createWrappedTool in this suite.
const enabledLoopDetectionContext = {
agentId: "main",
sessionKey: "main",
loopDetection: { enabled: true },
};
// Same agent/session identifiers, but with loop detection explicitly switched off,
// for tests asserting that repeated calls are never blocked.
const disabledLoopDetectionContext = {
agentId: "main",
sessionKey: "main",
loopDetection: { enabled: false },
};
beforeEach(() => {
resetDiagnosticSessionStateForTest();
@@ -33,10 +43,14 @@ describe("before_tool_call loop detection behavior", () => {
hookRunner.hasHooks.mockReturnValue(false);
});
function createWrappedTool(name: string, execute: ReturnType<typeof vi.fn>) {
function createWrappedTool(
name: string,
execute: ReturnType<typeof vi.fn>,
loopDetectionContext = enabledLoopDetectionContext,
) {
return wrapToolWithBeforeToolCallHook(
{ name, execute } as unknown as AnyAgentTool,
defaultToolContext,
loopDetectionContext,
);
}
@@ -95,7 +109,6 @@ describe("before_tool_call loop detection behavior", () => {
}
}
}
it("blocks known poll loops when no progress repeats", async () => {
const execute = vi.fn().mockResolvedValue({
content: [{ type: "text", text: "(no new output)\n\nProcess still running." }],
@@ -113,6 +126,22 @@ describe("before_tool_call loop detection behavior", () => {
).rejects.toThrow("CRITICAL");
});
// With `loopDetection.enabled === false` the wrapper must never block: every
// repeated no-progress poll has to reach the underlying tool.
it("does nothing when loopDetection.enabled is false", async () => {
  const execute = vi.fn().mockResolvedValue({
    content: [{ type: "text", text: "(no new output)\n\nProcess still running." }],
    details: { status: "running", aggregated: "steady" },
  });
  // Go through the shared helper (which accepts a context override) rather than
  // wrapping by hand with an `as any` cast and a lint suppression.
  const tool = createWrappedTool("process", execute, disabledLoopDetectionContext);
  const params = { action: "poll", sessionId: "sess-off" };

  for (let i = 0; i < CRITICAL_THRESHOLD; i += 1) {
    await expect(tool.execute(`poll-${i}`, params, undefined, undefined)).resolves.toBeDefined();
  }

  // Each call must have been forwarded to the tool; none may be short-circuited.
  expect(execute).toHaveBeenCalledTimes(CRITICAL_THRESHOLD);
});
it("does not block known poll loops when output progresses", async () => {
const execute = vi.fn().mockImplementation(async (toolCallId: string) => {
return {

View File

@@ -1,3 +1,4 @@
import type { ToolLoopDetectionConfig } from "../config/types.tools.js";
import type { SessionState } from "../logging/diagnostic-session-state.js";
import type { AnyAgentTool } from "./tools/common.js";
import { createSubsystemLogger } from "../logging/subsystem.js";
@@ -5,9 +6,10 @@ import { getGlobalHookRunner } from "../plugins/hook-runner-global.js";
import { isPlainObject } from "../utils.js";
import { normalizeToolName } from "./tool-policy.js";
type HookContext = {
export type HookContext = {
agentId?: string;
sessionKey?: string;
loopDetection?: ToolLoopDetectionConfig;
};
type HookOutcome = { blocked: true; reason: string } | { blocked: false; params: unknown };
@@ -62,6 +64,7 @@ async function recordLoopOutcome(args: {
toolCallId: args.toolCallId,
result: args.result,
error: args.error,
config: args.ctx.loopDetection,
});
} catch (err) {
log.warn(`tool loop outcome tracking failed: tool=${args.toolName} error=${String(err)}`);
@@ -87,7 +90,7 @@ export async function runBeforeToolCallHook(args: {
sessionId: args.ctx?.agentId,
});
const loopResult = detectToolCallLoop(sessionState, toolName, params);
const loopResult = detectToolCallLoop(sessionState, toolName, params, args.ctx.loopDetection);
if (loopResult.stuck) {
if (loopResult.level === "critical") {
@@ -126,7 +129,7 @@ export async function runBeforeToolCallHook(args: {
}
}
recordToolCall(sessionState, toolName, params, args.toolCallId);
recordToolCall(sessionState, toolName, params, args.toolCallId, args.ctx.loopDetection);
}
const hookRunner = getGlobalHookRunner();

View File

@@ -6,6 +6,7 @@ import {
readTool,
} from "@mariozechner/pi-coding-agent";
import type { OpenClawConfig } from "../config/config.js";
import type { ToolLoopDetectionConfig } from "../config/types.tools.js";
import type { ModelAuthMode } from "./model-auth.js";
import type { AnyAgentTool } from "./pi-tools.types.js";
import type { SandboxContext } from "./sandbox.js";
@@ -124,6 +125,33 @@ function resolveFsConfig(params: { cfg?: OpenClawConfig; agentId?: string }) {
};
}
/**
 * Resolves the effective tool loop-detection config for an agent.
 *
 * Reads the global `tools.loopDetection` section and, when both a config and an
 * agent id are supplied, the agent-specific `tools.loopDetection` section.
 *
 * @param params.cfg - Loaded configuration, if any.
 * @param params.agentId - Agent whose overrides should be applied, if any.
 * @returns The only section that is present when just one exists; a merged
 *   object (agent values layered over global values, with `detectors` merged
 *   key by key) when both exist; otherwise whatever the global lookup yielded.
 */
export function resolveToolLoopDetectionConfig(params: {
  cfg?: OpenClawConfig;
  agentId?: string;
}): ToolLoopDetectionConfig | undefined {
  const { cfg, agentId } = params;
  const globalConfig = cfg?.tools?.loopDetection;
  const agentConfig =
    agentId && cfg ? resolveAgentConfig(cfg, agentId)?.tools?.loopDetection : undefined;

  if (agentConfig && globalConfig) {
    // Both layers exist: agent settings win, detectors are merged one level deep
    // so an agent can flip a single detector without restating the others.
    return {
      ...globalConfig,
      ...agentConfig,
      detectors: {
        ...globalConfig.detectors,
        ...agentConfig.detectors,
      },
    };
  }

  // At most one layer exists: prefer the agent layer, else fall back to global.
  return agentConfig ? agentConfig : globalConfig;
}
export const __testing = {
cleanToolSchemaForGemini,
normalizeToolParams,
@@ -451,6 +479,7 @@ export function createOpenClawCodingTools(options?: {
wrapToolWithBeforeToolCallHook(tool, {
agentId,
sessionKey: options?.sessionKey,
loopDetection: resolveToolLoopDetectionConfig({ cfg: options?.config, agentId }),
}),
);
const withAbort = options?.abortSignal

View File

@@ -1,4 +1,5 @@
import { describe, expect, it } from "vitest";
import type { ToolLoopDetectionConfig } from "../config/types.tools.js";
import type { SessionState } from "../logging/diagnostic-session-state.js";
import {
CRITICAL_THRESHOLD,
@@ -20,6 +21,13 @@ function createState(): SessionState {
};
}
// Minimal config that only turns detection on; every other setting is left unset.
const enabledLoopDetectionConfig: ToolLoopDetectionConfig = { enabled: true };
// Enabled config with a small history window (4 entries) for the history-trimming test.
const shortHistoryLoopConfig: ToolLoopDetectionConfig = {
enabled: true,
historySize: 4,
};
function recordSuccessfulCall(
state: SessionState,
toolName: string,
@@ -111,9 +119,31 @@ describe("tool-loop-detection", () => {
expect(timestamp).toBeGreaterThanOrEqual(before);
expect(timestamp).toBeLessThanOrEqual(after);
});
it("respects configured historySize", () => {
  const state = createState();
  const totalCalls = 10;

  // Record more calls than the configured window (historySize: 4) can hold.
  for (let iteration = 0; iteration < totalCalls; iteration += 1) {
    recordToolCall(state, "tool", { iteration }, `call-${iteration}`, shortHistoryLoopConfig);
  }

  // Only the newest four entries remain, so the oldest retained call is iteration 6.
  const history = state.toolCallHistory;
  expect(history).toHaveLength(4);
  expect(history?.[0]?.argsHash).toBe(hashToolCall("tool", { iteration: 6 }));
});
});
describe("detectToolCallLoop", () => {
it("is disabled by default", () => {
  const state = createState();
  const repeatedPath = "/same.txt";

  // Twenty identical calls recorded without passing any loop-detection config.
  let callIndex = 0;
  while (callIndex < 20) {
    recordToolCall(state, "read", { path: repeatedPath }, `default-${callIndex}`);
    callIndex += 1;
  }

  // With no config supplied, the detector must report nothing.
  const { stuck } = detectToolCallLoop(state, "read", { path: repeatedPath });
  expect(stuck).toBe(false);
});
it("does not flag unique tool calls", () => {
const state = createState();
@@ -121,7 +151,12 @@ describe("tool-loop-detection", () => {
recordToolCall(state, "read", { path: `/file${i}.txt` }, `call-${i}`);
}
const result = detectToolCallLoop(state, "read", { path: "/new-file.txt" });
const result = detectToolCallLoop(
state,
"read",
{ path: "/new-file.txt" },
enabledLoopDetectionConfig,
);
expect(result.stuck).toBe(false);
});
@@ -131,7 +166,12 @@ describe("tool-loop-detection", () => {
recordToolCall(state, "read", { path: "/same.txt" }, `warn-${i}`);
}
const result = detectToolCallLoop(state, "read", { path: "/same.txt" });
const result = detectToolCallLoop(
state,
"read",
{ path: "/same.txt" },
enabledLoopDetectionConfig,
);
expect(result.stuck).toBe(true);
if (result.stuck) {
@@ -155,13 +195,74 @@ describe("tool-loop-detection", () => {
recordSuccessfulCall(state, "read", params, result, i);
}
const loopResult = detectToolCallLoop(state, "read", params);
const loopResult = detectToolCallLoop(state, "read", params, enabledLoopDetectionConfig);
expect(loopResult.stuck).toBe(true);
if (loopResult.stuck) {
expect(loopResult.level).toBe("warning");
}
});
// Verifies that configured warning/critical thresholds replace the module
// defaults, using only the known-poll detector.
it("applies custom thresholds when detection is enabled", () => {
  const state = createState();
  const params = { action: "poll", sessionId: "sess-custom" };
  const result = {
    content: [{ type: "text", text: "(no new output)\n\nProcess still running." }],
    details: { status: "running", aggregated: "steady" },
  };
  const config: ToolLoopDetectionConfig = {
    enabled: true,
    warningThreshold: 2,
    criticalThreshold: 4,
    detectors: {
      genericRepeat: false,
      knownPollNoProgress: true,
      pingPong: false,
    },
  };

  // Two identical no-progress polls reach the custom warning threshold.
  for (let i = 0; i < 2; i += 1) {
    recordSuccessfulCall(state, "process", params, result, i);
  }
  const warningResult = detectToolCallLoop(state, "process", params, config);
  expect(warningResult.stuck).toBe(true);
  if (warningResult.stuck) {
    expect(warningResult.level).toBe("warning");
  }

  // Two more bring the streak to the custom critical threshold.
  recordSuccessfulCall(state, "process", params, result, 2);
  recordSuccessfulCall(state, "process", params, result, 3);
  const criticalResult = detectToolCallLoop(state, "process", params, config);
  expect(criticalResult.stuck).toBe(true);
  if (criticalResult.stuck) {
    expect(criticalResult.level).toBe("critical");
    // Read `detector` only after narrowing on `stuck`, like `level`: the
    // non-stuck result is `{ stuck: false }` and carries no detector field.
    expect(criticalResult.detector).toBe("known_poll_no_progress");
  }
});
it("can disable specific detectors", () => {
  const state = createState();
  const pollParams = { action: "poll", sessionId: "sess-no-detectors" };
  const stalledResult = {
    content: [{ type: "text", text: "(no new output)\n\nProcess still running." }],
    details: { status: "running", aggregated: "steady" },
  };
  // Detection as a whole stays enabled; only the individual detectors are off.
  const allDetectorsOff: ToolLoopDetectionConfig = {
    enabled: true,
    detectors: {
      genericRepeat: false,
      knownPollNoProgress: false,
      pingPong: false,
    },
  };

  // Record CRITICAL_THRESHOLD identical no-progress polls.
  let callIndex = 0;
  while (callIndex < CRITICAL_THRESHOLD) {
    recordSuccessfulCall(state, "process", pollParams, stalledResult, callIndex);
    callIndex += 1;
  }

  const { stuck } = detectToolCallLoop(state, "process", pollParams, allDetectorsOff);
  expect(stuck).toBe(false);
});
it("warns for known polling no-progress loops", () => {
const state = createState();
const params = { action: "poll", sessionId: "sess-1" };
@@ -174,7 +275,7 @@ describe("tool-loop-detection", () => {
recordSuccessfulCall(state, "process", params, result, i);
}
const loopResult = detectToolCallLoop(state, "process", params);
const loopResult = detectToolCallLoop(state, "process", params, enabledLoopDetectionConfig);
expect(loopResult.stuck).toBe(true);
if (loopResult.stuck) {
expect(loopResult.level).toBe("warning");
@@ -195,7 +296,7 @@ describe("tool-loop-detection", () => {
recordSuccessfulCall(state, "process", params, result, i);
}
const loopResult = detectToolCallLoop(state, "process", params);
const loopResult = detectToolCallLoop(state, "process", params, enabledLoopDetectionConfig);
expect(loopResult.stuck).toBe(true);
if (loopResult.stuck) {
expect(loopResult.level).toBe("critical");
@@ -216,7 +317,7 @@ describe("tool-loop-detection", () => {
recordSuccessfulCall(state, "process", params, result, i);
}
const loopResult = detectToolCallLoop(state, "process", params);
const loopResult = detectToolCallLoop(state, "process", params, enabledLoopDetectionConfig);
expect(loopResult.stuck).toBe(false);
});
@@ -232,7 +333,7 @@ describe("tool-loop-detection", () => {
recordSuccessfulCall(state, "read", params, result, i);
}
const loopResult = detectToolCallLoop(state, "read", params);
const loopResult = detectToolCallLoop(state, "read", params, enabledLoopDetectionConfig);
expect(loopResult.stuck).toBe(true);
if (loopResult.stuck) {
expect(loopResult.level).toBe("critical");
@@ -254,7 +355,7 @@ describe("tool-loop-detection", () => {
}
}
const loopResult = detectToolCallLoop(state, "list", listParams);
const loopResult = detectToolCallLoop(state, "list", listParams, enabledLoopDetectionConfig);
expect(loopResult.stuck).toBe(true);
if (loopResult.stuck) {
expect(loopResult.level).toBe("warning");
@@ -289,7 +390,7 @@ describe("tool-loop-detection", () => {
}
}
const loopResult = detectToolCallLoop(state, "list", listParams);
const loopResult = detectToolCallLoop(state, "list", listParams, enabledLoopDetectionConfig);
expect(loopResult.stuck).toBe(true);
if (loopResult.stuck) {
expect(loopResult.level).toBe("critical");
@@ -325,7 +426,7 @@ describe("tool-loop-detection", () => {
}
}
const loopResult = detectToolCallLoop(state, "list", listParams);
const loopResult = detectToolCallLoop(state, "list", listParams, enabledLoopDetectionConfig);
expect(loopResult.stuck).toBe(true);
if (loopResult.stuck) {
expect(loopResult.level).toBe("warning");
@@ -341,7 +442,12 @@ describe("tool-loop-detection", () => {
recordToolCall(state, "read", { path: "/a.txt" }, "a2");
recordToolCall(state, "write", { path: "/tmp/out.txt" }, "c1"); // breaks alternation
const loopResult = detectToolCallLoop(state, "list", { dir: "/workspace" });
const loopResult = detectToolCallLoop(
state,
"list",
{ dir: "/workspace" },
enabledLoopDetectionConfig,
);
expect(loopResult.stuck).toBe(false);
});
@@ -368,7 +474,7 @@ describe("tool-loop-detection", () => {
it("handles empty history", () => {
const state = createState();
const result = detectToolCallLoop(state, "tool", { arg: 1 });
const result = detectToolCallLoop(state, "tool", { arg: 1 }, enabledLoopDetectionConfig);
expect(result.stuck).toBe(false);
});
});

View File

@@ -1,4 +1,5 @@
import { createHash } from "node:crypto";
import type { ToolLoopDetectionConfig } from "../config/types.tools.js";
import type { SessionState } from "../logging/diagnostic-session-state.js";
import { createSubsystemLogger } from "../logging/subsystem.js";
import { isPlainObject } from "../utils.js";
@@ -27,6 +28,76 @@ export const TOOL_CALL_HISTORY_SIZE = 30;
export const WARNING_THRESHOLD = 10;
export const CRITICAL_THRESHOLD = 20;
export const GLOBAL_CIRCUIT_BREAKER_THRESHOLD = 30;
/**
 * Fallback values applied by resolveLoopDetectionConfig for any setting that is
 * missing (or, for the numeric limits, not a positive integer) in the supplied config.
 * Detection is off unless a config enables it; the numeric limits reuse the
 * exported module constants, and every detector defaults to on.
 */
const DEFAULT_LOOP_DETECTION_CONFIG = {
enabled: false,
historySize: TOOL_CALL_HISTORY_SIZE,
warningThreshold: WARNING_THRESHOLD,
criticalThreshold: CRITICAL_THRESHOLD,
globalCircuitBreakerThreshold: GLOBAL_CIRCUIT_BREAKER_THRESHOLD,
detectors: {
genericRepeat: true,
knownPollNoProgress: true,
pingPong: true,
},
};
/**
 * Fully populated loop-detection settings as returned by resolveLoopDetectionConfig:
 * every field is required, so consumers never deal with missing values.
 */
type ResolvedLoopDetectionConfig = {
// When false, detectToolCallLoop returns a non-stuck result without inspecting history.
enabled: boolean;
// Maximum number of entries kept in the session's toolCallHistory.
historySize: number;
// Count at which an enabled detector reports a warning.
warningThreshold: number;
// Count at which the known-poll and ping-pong detectors escalate to critical.
criticalThreshold: number;
// No-progress streak at which the global circuit breaker fires.
globalCircuitBreakerThreshold: number;
// Per-detector on/off switches consulted by detectToolCallLoop.
detectors: {
genericRepeat: boolean;
knownPollNoProgress: boolean;
pingPong: boolean;
};
};
/**
 * Returns `value` when it is a positive integer, otherwise `fallback`.
 *
 * `undefined`, `NaN`, infinities, fractional numbers, zero and negative numbers
 * all yield `fallback`.
 *
 * @param value - Candidate setting, possibly absent.
 * @param fallback - Value to use when `value` is not a positive integer.
 * @returns `value` if it is an integer greater than zero, else `fallback`.
 */
function asPositiveInt(value: number | undefined, fallback: number): number {
  // `Number.isInteger` is not a type guard, so `undefined` has to be ruled out
  // explicitly for the comparison and the return to type-check under
  // `strictNullChecks`. Runtime behaviour is unchanged.
  if (value === undefined || !Number.isInteger(value) || value <= 0) {
    return fallback;
  }
  return value;
}
/**
 * Expands an optional, partial loop-detection config into a fully populated one.
 *
 * Numeric settings that are missing or not positive integers fall back to
 * DEFAULT_LOOP_DETECTION_CONFIG. The thresholds are then forced into strictly
 * increasing order: warning < critical < global circuit breaker.
 *
 * @param config - User-supplied settings; may be undefined or partial.
 * @returns Settings with every field present.
 */
function resolveLoopDetectionConfig(config?: ToolLoopDetectionConfig): ResolvedLoopDetectionConfig {
  const defaults = DEFAULT_LOOP_DETECTION_CONFIG;

  const warningThreshold = asPositiveInt(config?.warningThreshold, defaults.warningThreshold);
  // Critical must sit strictly above warning.
  const criticalThreshold = Math.max(
    asPositiveInt(config?.criticalThreshold, defaults.criticalThreshold),
    warningThreshold + 1,
  );
  // The circuit breaker must sit strictly above the (possibly bumped) critical threshold.
  const globalCircuitBreakerThreshold = Math.max(
    asPositiveInt(config?.globalCircuitBreakerThreshold, defaults.globalCircuitBreakerThreshold),
    criticalThreshold + 1,
  );

  const detectorOverrides = config?.detectors;

  return {
    enabled: config?.enabled ?? defaults.enabled,
    historySize: asPositiveInt(config?.historySize, defaults.historySize),
    warningThreshold,
    criticalThreshold,
    globalCircuitBreakerThreshold,
    detectors: {
      genericRepeat: detectorOverrides?.genericRepeat ?? defaults.detectors.genericRepeat,
      knownPollNoProgress:
        detectorOverrides?.knownPollNoProgress ?? defaults.detectors.knownPollNoProgress,
      pingPong: detectorOverrides?.pingPong ?? defaults.detectors.pingPong,
    },
  };
}
/**
* Hash a tool call for pattern matching.
@@ -302,7 +373,12 @@ export function detectToolCallLoop(
state: SessionState,
toolName: string,
params: unknown,
config?: ToolLoopDetectionConfig,
): LoopDetectionResult {
const resolvedConfig = resolveLoopDetectionConfig(config);
if (!resolvedConfig.enabled) {
return { stuck: false };
}
const history = state.toolCallHistory ?? [];
const currentHash = hashToolCall(toolName, params);
const noProgress = getNoProgressStreak(history, toolName, currentHash);
@@ -310,7 +386,7 @@ export function detectToolCallLoop(
const knownPollTool = isKnownPollToolCall(toolName, params);
const pingPong = getPingPongStreak(history, currentHash);
if (noProgressStreak >= GLOBAL_CIRCUIT_BREAKER_THRESHOLD) {
if (noProgressStreak >= resolvedConfig.globalCircuitBreakerThreshold) {
log.error(
`Global circuit breaker triggered: ${toolName} repeated ${noProgressStreak} times with no progress`,
);
@@ -324,7 +400,11 @@ export function detectToolCallLoop(
};
}
if (knownPollTool && noProgressStreak >= CRITICAL_THRESHOLD) {
if (
knownPollTool &&
resolvedConfig.detectors.knownPollNoProgress &&
noProgressStreak >= resolvedConfig.criticalThreshold
) {
log.error(`Critical polling loop detected: ${toolName} repeated ${noProgressStreak} times`);
return {
stuck: true,
@@ -336,7 +416,11 @@ export function detectToolCallLoop(
};
}
if (knownPollTool && noProgressStreak >= WARNING_THRESHOLD) {
if (
knownPollTool &&
resolvedConfig.detectors.knownPollNoProgress &&
noProgressStreak >= resolvedConfig.warningThreshold
) {
log.warn(`Polling loop warning: ${toolName} repeated ${noProgressStreak} times`);
return {
stuck: true,
@@ -352,7 +436,11 @@ export function detectToolCallLoop(
? `pingpong:${canonicalPairKey(currentHash, pingPong.pairedSignature)}`
: `pingpong:${toolName}:${currentHash}`;
if (pingPong.count >= CRITICAL_THRESHOLD && pingPong.noProgressEvidence) {
if (
resolvedConfig.detectors.pingPong &&
pingPong.count >= resolvedConfig.criticalThreshold &&
pingPong.noProgressEvidence
) {
log.error(
`Critical ping-pong loop detected: alternating calls count=${pingPong.count} currentTool=${toolName}`,
);
@@ -367,7 +455,7 @@ export function detectToolCallLoop(
};
}
if (pingPong.count >= WARNING_THRESHOLD) {
if (resolvedConfig.detectors.pingPong && pingPong.count >= resolvedConfig.warningThreshold) {
log.warn(
`Ping-pong loop warning: alternating calls count=${pingPong.count} currentTool=${toolName}`,
);
@@ -387,7 +475,11 @@ export function detectToolCallLoop(
(h) => h.toolName === toolName && h.argsHash === currentHash,
).length;
if (!knownPollTool && recentCount >= WARNING_THRESHOLD) {
if (
!knownPollTool &&
resolvedConfig.detectors.genericRepeat &&
recentCount >= resolvedConfig.warningThreshold
) {
log.warn(`Loop warning: ${toolName} called ${recentCount} times with identical arguments`);
return {
stuck: true,
@@ -411,7 +503,9 @@ export function recordToolCall(
toolName: string,
params: unknown,
toolCallId?: string,
config?: ToolLoopDetectionConfig,
): void {
const resolvedConfig = resolveLoopDetectionConfig(config);
if (!state.toolCallHistory) {
state.toolCallHistory = [];
}
@@ -423,7 +517,7 @@ export function recordToolCall(
timestamp: Date.now(),
});
if (state.toolCallHistory.length > TOOL_CALL_HISTORY_SIZE) {
if (state.toolCallHistory.length > resolvedConfig.historySize) {
state.toolCallHistory.shift();
}
}
@@ -439,8 +533,10 @@ export function recordToolCallOutcome(
toolCallId?: string;
result?: unknown;
error?: unknown;
config?: ToolLoopDetectionConfig;
},
): void {
const resolvedConfig = resolveLoopDetectionConfig(params.config);
const resultHash = hashToolOutcome(
params.toolName,
params.toolParams,
@@ -486,8 +582,8 @@ export function recordToolCallOutcome(
});
}
if (state.toolCallHistory.length > TOOL_CALL_HISTORY_SIZE) {
state.toolCallHistory.splice(0, state.toolCallHistory.length - TOOL_CALL_HISTORY_SIZE);
if (state.toolCallHistory.length > resolvedConfig.historySize) {
state.toolCallHistory.splice(0, state.toolCallHistory.length - resolvedConfig.historySize);
}
}