From 0c1c34c9506d885253428286e56cf02f223ac780 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Tue, 17 Feb 2026 03:28:10 +0100 Subject: [PATCH] refactor(plugins): split before-agent hooks by model and prompt phases --- docs/concepts/agent-loop.md | 4 +- src/agents/pi-embedded-runner/run.ts | 68 +++-- src/agents/pi-embedded-runner/run/attempt.ts | 75 +++-- src/agents/pi-embedded-runner/run/types.ts | 3 - .../hooks.model-override-wiring.test.ts | 267 ++++++++---------- src/plugins/hooks.phase-hooks.test.ts | 75 +++++ src/plugins/hooks.ts | 75 ++++- src/plugins/types.ts | 47 ++- 8 files changed, 389 insertions(+), 225 deletions(-) create mode 100644 src/plugins/hooks.phase-hooks.test.ts diff --git a/docs/concepts/agent-loop.md b/docs/concepts/agent-loop.md index b0d99ca907e..8699535aa6b 100644 --- a/docs/concepts/agent-loop.md +++ b/docs/concepts/agent-loop.md @@ -81,7 +81,9 @@ See [Hooks](/automation/hooks) for setup and examples. These run inside the agent loop or gateway pipeline: -- **`before_agent_start`**: inject context or override system prompt before the run starts. +- **`before_model_resolve`**: runs pre-session (no `messages`) to deterministically override provider/model before model resolution. +- **`before_prompt_build`**: runs after session load (with `messages`) to inject `prependContext`/`systemPrompt` before prompt submission. +- **`before_agent_start`**: legacy compatibility hook that may run in either phase; prefer the explicit hooks above. - **`agent_end`**: inspect the final message list and run metadata after completion. - **`before_compaction` / `after_compaction`**: observe or annotate compaction cycles. - **`before_tool_call` / `after_tool_call`**: intercept tool params/results. diff --git a/src/agents/pi-embedded-runner/run.ts b/src/agents/pi-embedded-runner/run.ts index ec86d3d0cfe..381b23b27fb 100644 --- a/src/agents/pi-embedded-runner/run.ts +++ b/src/agents/pi-embedded-runner/run.ts @@ -1,7 +1,8 @@ import fs from "node:fs/promises"; import type { ThinkLevel } from "../../auto-reply/thinking.js"; +import type { RunEmbeddedPiAgentParams } from "./run/params.js"; +import type { EmbeddedPiAgentMeta, EmbeddedPiRunResult } from "./types.js"; import { getGlobalHookRunner } from "../../plugins/hook-runner-global.js"; -import type { PluginHookBeforeAgentStartResult } from "../../plugins/types.js"; import { enqueueCommandInLane } from "../../process/command-queue.js"; import { isMarkdownCapableMessageChannel } from "../../utils/message-channel.js"; import { resolveOpenClawAgentDir } from "../agent-paths.js"; @@ -51,13 +52,11 @@ import { resolveGlobalLane, resolveSessionLane } from "./lanes.js"; import { log } from "./logger.js"; import { resolveModel } from "./model.js"; import { runEmbeddedAttempt } from "./run/attempt.js"; -import type { RunEmbeddedPiAgentParams } from "./run/params.js"; import { buildEmbeddedRunPayloads } from "./run/payloads.js"; import { truncateOversizedToolResultsInSession, sessionLikelyHasOversizedToolResults, } from "./tool-result-truncation.js"; -import type { EmbeddedPiAgentMeta, EmbeddedPiRunResult } from "./types.js"; import { describeUnknownError } from "./utils.js"; type ApiKeyInfo = ResolvedProviderAuth; @@ -207,35 +206,55 @@ export async function runEmbeddedPiAgent( (params.config?.agents?.defaults?.model?.fallbacks?.length ?? 0) > 0; await ensureOpenClawModelsJson(params.config, agentDir); - // Run before_agent_start hooks early so plugins can override the model - // before it gets resolved. The hook result is passed downstream to - // attempt.ts to avoid double-firing. - let earlyHookResult: PluginHookBeforeAgentStartResult | undefined; + // Run before_model_resolve hooks early so plugins can override the + // provider/model before resolveModel(). + // + // Legacy compatibility: before_agent_start is also checked for override + // fields if present. New hook takes precedence when both are set. + let modelResolveOverride: { providerOverride?: string; modelOverride?: string } | undefined; const hookRunner = getGlobalHookRunner(); + const hookCtx = { + agentId: workspaceResolution.agentId, + sessionKey: params.sessionKey, + sessionId: params.sessionId, + workspaceDir: resolvedWorkspace, + messageProvider: params.messageProvider ?? undefined, + }; + if (hookRunner?.hasHooks("before_model_resolve")) { + try { + modelResolveOverride = await hookRunner.runBeforeModelResolve( + { prompt: params.prompt }, + hookCtx, + ); + } catch (hookErr) { + log.warn(`before_model_resolve hook failed: ${String(hookErr)}`); + } + } if (hookRunner?.hasHooks("before_agent_start")) { try { - earlyHookResult = await hookRunner.runBeforeAgentStart( + const legacyResult = await hookRunner.runBeforeAgentStart( { prompt: params.prompt }, - { - agentId: params.agentId, - sessionKey: params.sessionKey, - sessionId: params.sessionId, - workspaceDir: params.workspaceDir, - messageProvider: params.messageProvider ?? undefined, - }, + hookCtx, ); - if (earlyHookResult?.providerOverride) { - provider = earlyHookResult.providerOverride; - log.info(`[hooks] provider overridden to ${provider}`); - } - if (earlyHookResult?.modelOverride) { - modelId = earlyHookResult.modelOverride; - log.info(`[hooks] model overridden to ${modelId}`); - } + modelResolveOverride = { + providerOverride: + modelResolveOverride?.providerOverride ?? legacyResult?.providerOverride, + modelOverride: modelResolveOverride?.modelOverride ?? legacyResult?.modelOverride, + }; } catch (hookErr) { - log.warn(`before_agent_start hook (early) failed: ${String(hookErr)}`); + log.warn( + `before_agent_start hook (legacy model resolve path) failed: ${String(hookErr)}`, + ); } } + if (modelResolveOverride?.providerOverride) { + provider = modelResolveOverride.providerOverride; + log.info(`[hooks] provider overridden to ${provider}`); + } + if (modelResolveOverride?.modelOverride) { + modelId = modelResolveOverride.modelOverride; + log.info(`[hooks] model overridden to ${modelId}`); + } const { model, error, authStorage, modelRegistry } = resolveModel( provider, @@ -511,7 +530,6 @@ export async function runEmbeddedPiAgent( streamParams: params.streamParams, ownerNumbers: params.ownerNumbers, enforceFinalTag: params.enforceFinalTag, - earlyHookResult, }); const { diff --git a/src/agents/pi-embedded-runner/run/attempt.ts b/src/agents/pi-embedded-runner/run/attempt.ts index dcc4a216498..1d79a7c5cc5 100644 --- a/src/agents/pi-embedded-runner/run/attempt.ts +++ b/src/agents/pi-embedded-runner/run/attempt.ts @@ -1,9 +1,10 @@ -import fs from "node:fs/promises"; -import os from "node:os"; import type { AgentMessage } from "@mariozechner/pi-agent-core"; import type { ImageContent } from "@mariozechner/pi-ai"; import { streamSimple } from "@mariozechner/pi-ai"; import { createAgentSession, SessionManager, SettingsManager } from "@mariozechner/pi-coding-agent"; +import fs from "node:fs/promises"; +import os from "node:os"; +import type { EmbeddedRunAttemptParams, EmbeddedRunAttemptResult } from "./types.js"; import { resolveHeartbeatPrompt } from "../../../auto-reply/heartbeat.js"; import { resolveChannelCapabilities } from "../../../config/channel-capabilities.js"; import { getMachineDisplayName } from "../../../infra/machine-name.js"; @@ -103,7 +104,6 @@ import { shouldFlagCompactionTimeout, } from "./compaction-timeout.js"; import { detectAndLoadPromptImages } from "./images.js"; -import type { EmbeddedRunAttemptParams, EmbeddedRunAttemptResult } from "./types.js"; export function injectHistoryImagesIntoMessages( messages: AgentMessage[], @@ -863,31 +863,52 @@ export async function runEmbeddedAttempt( try { const promptStartedAt = Date.now(); - // Run before_agent_start hooks to allow plugins to inject context. - // If run.ts already fired the hook (for model override), reuse its result. + // Run before_prompt_build hooks to allow plugins to inject prompt context. + // Legacy compatibility: before_agent_start is also checked for context fields. let effectivePrompt = params.prompt; - const hookResult = - params.earlyHookResult ?? - (hookRunner?.hasHooks("before_agent_start") - ? await hookRunner - .runBeforeAgentStart( - { - prompt: params.prompt, - messages: activeSession.messages, - }, - { - agentId: hookAgentId, - sessionKey: params.sessionKey, - sessionId: params.sessionId, - workspaceDir: params.workspaceDir, - messageProvider: params.messageProvider ?? undefined, - }, - ) - .catch((hookErr: unknown) => { - log.warn(`before_agent_start hook failed: ${String(hookErr)}`); - return undefined; - }) - : undefined); + const hookCtx = { + agentId: hookAgentId, + sessionKey: params.sessionKey, + sessionId: params.sessionId, + workspaceDir: params.workspaceDir, + messageProvider: params.messageProvider ?? undefined, + }; + const promptBuildResult = hookRunner?.hasHooks("before_prompt_build") + ? await hookRunner + .runBeforePromptBuild( + { + prompt: params.prompt, + messages: activeSession.messages, + }, + hookCtx, + ) + .catch((hookErr: unknown) => { + log.warn(`before_prompt_build hook failed: ${String(hookErr)}`); + return undefined; + }) + : undefined; + const legacyResult = hookRunner?.hasHooks("before_agent_start") + ? await hookRunner + .runBeforeAgentStart( + { + prompt: params.prompt, + messages: activeSession.messages, + }, + hookCtx, + ) + .catch((hookErr: unknown) => { + log.warn( + `before_agent_start hook (legacy prompt build path) failed: ${String(hookErr)}`, + ); + return undefined; + }) + : undefined; + const hookResult = { + systemPrompt: promptBuildResult?.systemPrompt ?? legacyResult?.systemPrompt, + prependContext: [promptBuildResult?.prependContext, legacyResult?.prependContext] + .filter((value): value is string => Boolean(value)) + .join("\n\n"), + }; { if (hookResult?.prependContext) { effectivePrompt = `${hookResult.prependContext}\n\n${params.prompt}`; diff --git a/src/agents/pi-embedded-runner/run/types.ts b/src/agents/pi-embedded-runner/run/types.ts index d1371618bcd..f0d1234875e 100644 --- a/src/agents/pi-embedded-runner/run/types.ts +++ b/src/agents/pi-embedded-runner/run/types.ts @@ -2,7 +2,6 @@ import type { AgentMessage } from "@mariozechner/pi-agent-core"; import type { Api, AssistantMessage, Model } from "@mariozechner/pi-ai"; import type { ThinkLevel } from "../../../auto-reply/thinking.js"; import type { SessionSystemPromptReport } from "../../../config/sessions/types.js"; -import type { PluginHookBeforeAgentStartResult } from "../../../plugins/types.js"; import type { MessagingToolSend } from "../../pi-embedded-messaging.js"; import type { AuthStorage, ModelRegistry } from "../../pi-model-discovery.js"; import type { NormalizedUsage } from "../../usage.js"; @@ -20,8 +19,6 @@ export type EmbeddedRunAttemptParams = EmbeddedRunAttemptBase & { authStorage: AuthStorage; modelRegistry: ModelRegistry; thinkLevel: ThinkLevel; - /** Pre-computed hook result from run.ts to avoid double-firing before_agent_start. */ - earlyHookResult?: PluginHookBeforeAgentStartResult; }; export type EmbeddedRunAttemptResult = { diff --git a/src/plugins/hooks.model-override-wiring.test.ts b/src/plugins/hooks.model-override-wiring.test.ts index 541d29caf18..901e0c9d936 100644 --- a/src/plugins/hooks.model-override-wiring.test.ts +++ b/src/plugins/hooks.model-override-wiring.test.ts @@ -1,32 +1,64 @@ /** - * Layer 2: Model Override Pipeline Wiring Tests + * Layer 2: Explicit model/prompt hook wiring tests. * - * Tests the integration between the hook runner and model override flow. - * Verifies that: - * 1. When hooks return modelOverride/providerOverride, the run pipeline applies them - * 2. The earlyHookResult mechanism prevents double-firing of before_agent_start - * 3. Graceful degradation when hooks throw errors - * - * These tests verify the hook runner contract at the boundary — the same runner - * that's used by both run.ts (early invocation) and attempt.ts (fallback invocation). + * Verifies: + * 1. before_model_resolve applies deterministic provider/model overrides + * 2. before_prompt_build receives session messages and prepends prompt context + * 3. before_agent_start remains a legacy compatibility fallback */ import { beforeEach, describe, expect, it, vi } from "vitest"; -import { createHookRunner } from "./hooks.js"; -import { createEmptyPluginRegistry, type PluginRegistry } from "./registry.js"; import type { - PluginHookBeforeAgentStartEvent, - PluginHookBeforeAgentStartResult, PluginHookAgentContext, + PluginHookBeforeAgentStartResult, + PluginHookBeforeModelResolveEvent, + PluginHookBeforeModelResolveResult, + PluginHookBeforePromptBuildEvent, + PluginHookBeforePromptBuildResult, TypedPluginHookRegistration, } from "./types.js"; +import { createHookRunner } from "./hooks.js"; +import { createEmptyPluginRegistry, type PluginRegistry } from "./registry.js"; -function addBeforeAgentStartHook( +function addBeforeModelResolveHook( registry: PluginRegistry, pluginId: string, handler: ( - event: PluginHookBeforeAgentStartEvent, + event: PluginHookBeforeModelResolveEvent, ctx: PluginHookAgentContext, - ) => PluginHookBeforeAgentStartResult | Promise, + ) => PluginHookBeforeModelResolveResult | Promise, + priority?: number, +) { + registry.typedHooks.push({ + pluginId, + hookName: "before_model_resolve", + handler, + priority, + source: "test", + } as TypedPluginHookRegistration); +} + +function addBeforePromptBuildHook( + registry: PluginRegistry, + pluginId: string, + handler: ( + event: PluginHookBeforePromptBuildEvent, + ctx: PluginHookAgentContext, + ) => PluginHookBeforePromptBuildResult | Promise, + priority?: number, +) { + registry.typedHooks.push({ + pluginId, + hookName: "before_prompt_build", + handler, + priority, + source: "test", + } as TypedPluginHookRegistration); +} + +function addLegacyBeforeAgentStartHook( + registry: PluginRegistry, + pluginId: string, + handler: () => PluginHookBeforeAgentStartResult | Promise, priority?: number, ) { registry.typedHooks.push({ @@ -52,203 +84,134 @@ describe("model override pipeline wiring", () => { registry = createEmptyPluginRegistry(); }); - describe("early invocation (run.ts pattern)", () => { - it("hook receives prompt-only event and returns model override", async () => { + describe("before_model_resolve (run.ts pattern)", () => { + it("hook receives prompt-only event and returns provider/model override", async () => { const handlerSpy = vi.fn( - (_event: PluginHookBeforeAgentStartEvent) => + (_event: PluginHookBeforeModelResolveEvent) => ({ modelOverride: "llama3.3:8b", providerOverride: "ollama", - prependContext: "PII detected: routing to local model", - }) as PluginHookBeforeAgentStartResult, + }) as PluginHookBeforeModelResolveResult, ); - addBeforeAgentStartHook(registry, "router-plugin", handlerSpy); + addBeforeModelResolveHook(registry, "router-plugin", handlerSpy); const runner = createHookRunner(registry); - - // Simulate run.ts early invocation: prompt only, no messages - const result = await runner.runBeforeAgentStart({ prompt: "My SSN is 123-45-6789" }, stubCtx); + const result = await runner.runBeforeModelResolve({ prompt: "PII text" }, stubCtx); expect(handlerSpy).toHaveBeenCalledTimes(1); - expect(handlerSpy).toHaveBeenCalledWith({ prompt: "My SSN is 123-45-6789" }, stubCtx); + expect(handlerSpy).toHaveBeenCalledWith({ prompt: "PII text" }, stubCtx); expect(result?.modelOverride).toBe("llama3.3:8b"); expect(result?.providerOverride).toBe("ollama"); - expect(result?.prependContext).toBe("PII detected: routing to local model"); }); - it("overrides can be applied to mutable provider/model variables", async () => { - addBeforeAgentStartHook(registry, "router-plugin", () => ({ + it("new hook overrides beat legacy before_agent_start fallback", async () => { + addBeforeModelResolveHook(registry, "new-hook", () => ({ modelOverride: "llama3.3:8b", providerOverride: "ollama", })); - - const runner = createHookRunner(registry); - const result = await runner.runBeforeAgentStart({ prompt: "sensitive data" }, stubCtx); - - // Simulate run.ts override application - let provider = "anthropic"; - let modelId = "claude-sonnet-4-5-20250929"; - - if (result?.providerOverride) { - provider = result.providerOverride; - } - if (result?.modelOverride) { - modelId = result.modelOverride; - } - - expect(provider).toBe("ollama"); - expect(modelId).toBe("llama3.3:8b"); - }); - - it("no overrides when hook returns only prependContext", async () => { - addBeforeAgentStartHook(registry, "context-plugin", () => ({ - prependContext: "Additional instructions", + addLegacyBeforeAgentStartHook(registry, "legacy-hook", () => ({ + modelOverride: "gpt-4o", + providerOverride: "openai", })); const runner = createHookRunner(registry); - const result = await runner.runBeforeAgentStart({ prompt: "normal query" }, stubCtx); + const explicit = await runner.runBeforeModelResolve({ prompt: "sensitive" }, stubCtx); + const legacy = await runner.runBeforeAgentStart({ prompt: "sensitive" }, stubCtx); + const merged = { + providerOverride: explicit?.providerOverride ?? legacy?.providerOverride, + modelOverride: explicit?.modelOverride ?? legacy?.modelOverride, + }; - // Simulate run.ts override application - let provider = "anthropic"; - let modelId = "claude-sonnet-4-5-20250929"; - - if (result?.providerOverride) { - provider = result.providerOverride; - } - if (result?.modelOverride) { - modelId = result.modelOverride; - } - - // Original values preserved - expect(provider).toBe("anthropic"); - expect(modelId).toBe("claude-sonnet-4-5-20250929"); + expect(merged.providerOverride).toBe("ollama"); + expect(merged.modelOverride).toBe("llama3.3:8b"); }); }); - describe("earlyHookResult passthrough (attempt.ts pattern)", () => { - it("when earlyHookResult exists, hook does not need to fire again", async () => { - const handlerSpy = vi.fn(() => ({ - modelOverride: "should-not-be-called", - })); - - addBeforeAgentStartHook(registry, "router-plugin", handlerSpy); - const runner = createHookRunner(registry); - - // Simulate the earlyHookResult already computed by run.ts - const earlyHookResult: PluginHookBeforeAgentStartResult = { - modelOverride: "llama3.3:8b", - providerOverride: "ollama", - prependContext: "PII detected", - }; - - // Simulate attempt.ts pattern: use earlyHookResult if present - const hookResult = - earlyHookResult ?? - (runner.hasHooks("before_agent_start") - ? await runner.runBeforeAgentStart({ prompt: "test", messages: [] }, stubCtx) - : undefined); - - expect(handlerSpy).not.toHaveBeenCalled(); - expect(hookResult?.modelOverride).toBe("llama3.3:8b"); - expect(hookResult?.prependContext).toBe("PII detected"); - }); - - it("when earlyHookResult is undefined, hook fires normally with messages", async () => { + describe("before_prompt_build (attempt.ts pattern)", () => { + it("hook receives prompt and messages and can prepend context", async () => { const handlerSpy = vi.fn( - (event: PluginHookBeforeAgentStartEvent) => + (event: PluginHookBeforePromptBuildEvent) => ({ - prependContext: `Saw ${(event.messages ?? []).length} messages`, - }) as PluginHookBeforeAgentStartResult, + prependContext: `Saw ${event.messages.length} messages`, + }) as PluginHookBeforePromptBuildResult, ); - addBeforeAgentStartHook(registry, "context-plugin", handlerSpy); + addBeforePromptBuildHook(registry, "context-plugin", handlerSpy); const runner = createHookRunner(registry); - - const earlyHookResult: PluginHookBeforeAgentStartResult | undefined = undefined; - - // Simulate attempt.ts pattern: fire hook since no early result - const hookResult = - earlyHookResult ?? - (runner.hasHooks("before_agent_start") - ? await runner.runBeforeAgentStart( - { prompt: "test", messages: [{}, {}] as unknown[] }, - stubCtx, - ) - : undefined); + const result = await runner.runBeforePromptBuild( + { prompt: "test", messages: [{}, {}] as unknown[] }, + stubCtx, + ); expect(handlerSpy).toHaveBeenCalledTimes(1); - expect(hookResult?.prependContext).toBe("Saw 2 messages"); + expect(result?.prependContext).toBe("Saw 2 messages"); }); - it("prependContext from earlyHookResult is applied to prompt", async () => { - const earlyHookResult: PluginHookBeforeAgentStartResult = { - prependContext: "PII detected: SSN found. Routing to local model.", - modelOverride: "llama3.3:8b", - providerOverride: "ollama", - }; + it("legacy before_agent_start context can still be merged as fallback", async () => { + addBeforePromptBuildHook(registry, "new-hook", () => ({ + prependContext: "new context", + })); + addLegacyBeforeAgentStartHook(registry, "legacy-hook", () => ({ + prependContext: "legacy context", + })); - // Simulate attempt.ts prompt modification - const originalPrompt = "My SSN is 123-45-6789"; - let effectivePrompt = originalPrompt; - if (earlyHookResult.prependContext) { - effectivePrompt = `${earlyHookResult.prependContext}\n\n${originalPrompt}`; - } - - expect(effectivePrompt).toBe( - "PII detected: SSN found. Routing to local model.\n\nMy SSN is 123-45-6789", + const runner = createHookRunner(registry); + const promptBuild = await runner.runBeforePromptBuild( + { prompt: "test", messages: [{ role: "user", content: "x" }] as unknown[] }, + stubCtx, ); + const legacy = await runner.runBeforeAgentStart( + { prompt: "test", messages: [{ role: "user", content: "x" }] as unknown[] }, + stubCtx, + ); + const prependContext = [promptBuild?.prependContext, legacy?.prependContext] + .filter((value): value is string => Boolean(value)) + .join("\n\n"); + + expect(prependContext).toBe("new context\n\nlegacy context"); }); }); - describe("graceful degradation", () => { - it("hook error does not produce override (run.ts pattern)", async () => { - addBeforeAgentStartHook(registry, "broken-plugin", () => { - throw new Error("plugin crashed"); - }); - - const runner = createHookRunner(registry, { catchErrors: true }); - - // The runner catches errors internally when catchErrors is true - const result = await runner.runBeforeAgentStart({ prompt: "test" }, stubCtx); - - // Result should be undefined since the handler threw - expect(result?.modelOverride).toBeUndefined(); - expect(result?.providerOverride).toBeUndefined(); - }); - - it("one broken plugin does not prevent other plugins from providing overrides", async () => { - addBeforeAgentStartHook( + describe("graceful degradation + hook detection", () => { + it("one broken before_model_resolve plugin does not block other overrides", async () => { + addBeforeModelResolveHook( registry, "broken-plugin", () => { throw new Error("plugin crashed"); }, - 10, // Higher priority, runs first + 10, ); - addBeforeAgentStartHook( + addBeforeModelResolveHook( registry, "router-plugin", () => ({ modelOverride: "llama3.3:8b", providerOverride: "ollama", }), - 1, // Lower priority, runs second + 1, ); const runner = createHookRunner(registry, { catchErrors: true }); - const result = await runner.runBeforeAgentStart({ prompt: "PII data" }, stubCtx); + const result = await runner.runBeforeModelResolve({ prompt: "PII data" }, stubCtx); - // The router plugin's result should still be returned expect(result?.modelOverride).toBe("llama3.3:8b"); expect(result?.providerOverride).toBe("ollama"); }); - it("hasHooks correctly reports when before_agent_start hooks exist", () => { + it("hasHooks reports new and legacy hooks independently", () => { const runner1 = createHookRunner(registry); + expect(runner1.hasHooks("before_model_resolve")).toBe(false); + expect(runner1.hasHooks("before_prompt_build")).toBe(false); expect(runner1.hasHooks("before_agent_start")).toBe(false); - addBeforeAgentStartHook(registry, "plugin-a", () => ({})); + addBeforeModelResolveHook(registry, "plugin-a", () => ({})); + addBeforePromptBuildHook(registry, "plugin-b", () => ({})); + addLegacyBeforeAgentStartHook(registry, "plugin-c", () => ({})); + const runner2 = createHookRunner(registry); + expect(runner2.hasHooks("before_model_resolve")).toBe(true); + expect(runner2.hasHooks("before_prompt_build")).toBe(true); expect(runner2.hasHooks("before_agent_start")).toBe(true); }); }); diff --git a/src/plugins/hooks.phase-hooks.test.ts b/src/plugins/hooks.phase-hooks.test.ts new file mode 100644 index 00000000000..9dcafd917fb --- /dev/null +++ b/src/plugins/hooks.phase-hooks.test.ts @@ -0,0 +1,75 @@ +import { beforeEach, describe, expect, it } from "vitest"; +import type { + PluginHookBeforeModelResolveResult, + PluginHookBeforePromptBuildResult, + TypedPluginHookRegistration, +} from "./types.js"; +import { createHookRunner } from "./hooks.js"; +import { createEmptyPluginRegistry, type PluginRegistry } from "./registry.js"; + +function addTypedHook( + registry: PluginRegistry, + hookName: "before_model_resolve" | "before_prompt_build", + pluginId: string, + handler: () => + | PluginHookBeforeModelResolveResult + | PluginHookBeforePromptBuildResult + | Promise, + priority?: number, +) { + registry.typedHooks.push({ + pluginId, + hookName, + handler, + priority, + source: "test", + } as TypedPluginHookRegistration); +} + +describe("phase hooks merger", () => { + let registry: PluginRegistry; + + beforeEach(() => { + registry = createEmptyPluginRegistry(); + }); + + it("before_model_resolve keeps higher-priority override values", async () => { + addTypedHook(registry, "before_model_resolve", "low", () => ({ modelOverride: "gpt-4o" }), 1); + addTypedHook( + registry, + "before_model_resolve", + "high", + () => ({ modelOverride: "llama3.3:8b", providerOverride: "ollama" }), + 10, + ); + + const runner = createHookRunner(registry); + const result = await runner.runBeforeModelResolve({ prompt: "test" }, {}); + + expect(result?.modelOverride).toBe("llama3.3:8b"); + expect(result?.providerOverride).toBe("ollama"); + }); + + it("before_prompt_build concatenates prependContext and preserves systemPrompt precedence", async () => { + addTypedHook( + registry, + "before_prompt_build", + "high", + () => ({ prependContext: "context A", systemPrompt: "system A" }), + 10, + ); + addTypedHook( + registry, + "before_prompt_build", + "low", + () => ({ prependContext: "context B" }), + 1, + ); + + const runner = createHookRunner(registry); + const result = await runner.runBeforePromptBuild({ prompt: "test", messages: [] }, {}); + + expect(result?.prependContext).toBe("context A\n\ncontext B"); + expect(result?.systemPrompt).toBe("system A"); + }); +}); diff --git a/src/plugins/hooks.ts b/src/plugins/hooks.ts index 24f6f6a91df..19b10404262 100644 --- a/src/plugins/hooks.ts +++ b/src/plugins/hooks.ts @@ -13,6 +13,10 @@ import type { PluginHookAgentEndEvent, PluginHookBeforeAgentStartEvent, PluginHookBeforeAgentStartResult, + PluginHookBeforeModelResolveEvent, + PluginHookBeforeModelResolveResult, + PluginHookBeforePromptBuildEvent, + PluginHookBeforePromptBuildResult, PluginHookBeforeCompactionEvent, PluginHookLlmInputEvent, PluginHookLlmOutputEvent, @@ -45,6 +49,10 @@ export type { PluginHookAgentContext, PluginHookBeforeAgentStartEvent, PluginHookBeforeAgentStartResult, + PluginHookBeforeModelResolveEvent, + PluginHookBeforeModelResolveResult, + PluginHookBeforePromptBuildEvent, + PluginHookBeforePromptBuildResult, PluginHookLlmInputEvent, PluginHookLlmOutputEvent, PluginHookAgentEndEvent, @@ -104,6 +112,26 @@ export function createHookRunner(registry: PluginRegistry, options: HookRunnerOp const logger = options.logger; const catchErrors = options.catchErrors ?? true; + const mergeBeforeModelResolve = ( + acc: PluginHookBeforeModelResolveResult | undefined, + next: PluginHookBeforeModelResolveResult, + ): PluginHookBeforeModelResolveResult => ({ + // Keep the first defined override so higher-priority hooks win. + modelOverride: acc?.modelOverride ?? next.modelOverride, + providerOverride: acc?.providerOverride ?? next.providerOverride, + }); + + const mergeBeforePromptBuild = ( + acc: PluginHookBeforePromptBuildResult | undefined, + next: PluginHookBeforePromptBuildResult, + ): PluginHookBeforePromptBuildResult => ({ + systemPrompt: next.systemPrompt ?? acc?.systemPrompt, + prependContext: + acc?.prependContext && next.prependContext + ? `${acc.prependContext}\n\n${next.prependContext}` + : (next.prependContext ?? acc?.prependContext), + }); + /** * Run a hook that doesn't return a value (fire-and-forget style). * All handlers are executed in parallel for performance. @@ -185,10 +213,41 @@ export function createHookRunner(registry: PluginRegistry, options: HookRunnerOp // Agent Hooks // ========================================================================= + /** + * Run before_model_resolve hook. + * Allows plugins to override provider/model before model resolution. + */ + async function runBeforeModelResolve( + event: PluginHookBeforeModelResolveEvent, + ctx: PluginHookAgentContext, + ): Promise { + return runModifyingHook<"before_model_resolve", PluginHookBeforeModelResolveResult>( + "before_model_resolve", + event, + ctx, + mergeBeforeModelResolve, + ); + } + + /** + * Run before_prompt_build hook. + * Allows plugins to inject context and system prompt before prompt submission. + */ + async function runBeforePromptBuild( + event: PluginHookBeforePromptBuildEvent, + ctx: PluginHookAgentContext, + ): Promise { + return runModifyingHook<"before_prompt_build", PluginHookBeforePromptBuildResult>( + "before_prompt_build", + event, + ctx, + mergeBeforePromptBuild, + ); + } + /** * Run before_agent_start hook. - * Allows plugins to inject context into the system prompt. - * Runs sequentially, merging systemPrompt and prependContext from all handlers. + * Legacy compatibility hook that combines model resolve + prompt build phases. */ async function runBeforeAgentStart( event: PluginHookBeforeAgentStartEvent, @@ -199,14 +258,8 @@ export function createHookRunner(registry: PluginRegistry, options: HookRunnerOp event, ctx, (acc, next) => ({ - systemPrompt: next.systemPrompt ?? acc?.systemPrompt, - prependContext: - acc?.prependContext && next.prependContext - ? `${acc.prependContext}\n\n${next.prependContext}` - : (next.prependContext ?? acc?.prependContext), - // Keep the first defined override so higher-priority hooks win. - modelOverride: acc?.modelOverride ?? next.modelOverride, - providerOverride: acc?.providerOverride ?? next.providerOverride, + ...mergeBeforePromptBuild(acc, next), + ...mergeBeforeModelResolve(acc, next), }), ); } @@ -563,6 +616,8 @@ export function createHookRunner(registry: PluginRegistry, options: HookRunnerOp return { // Agent hooks + runBeforeModelResolve, + runBeforePromptBuild, runBeforeAgentStart, runLlmInput, runLlmOutput, diff --git a/src/plugins/types.ts b/src/plugins/types.ts index d5a4f6b41fe..03d9ffbae56 100644 --- a/src/plugins/types.ts +++ b/src/plugins/types.ts @@ -1,6 +1,6 @@ -import type { IncomingMessage, ServerResponse } from "node:http"; import type { AgentMessage } from "@mariozechner/pi-agent-core"; import type { Command } from "commander"; +import type { IncomingMessage, ServerResponse } from "node:http"; import type { AuthProfileCredential, OAuthCredential } from "../agents/auth-profiles/types.js"; import type { AnyAgentTool } from "../agents/tools/common.js"; import type { ReplyPayload } from "../auto-reply/types.js"; @@ -296,6 +296,8 @@ export type PluginDiagnostic = { // ============================================================================ export type PluginHookName = + | "before_model_resolve" + | "before_prompt_build" | "before_agent_start" | "llm_input" | "llm_output" @@ -324,21 +326,41 @@ export type PluginHookAgentContext = { messageProvider?: string; }; -// before_agent_start hook -export type PluginHookBeforeAgentStartEvent = { +// before_model_resolve hook +export type PluginHookBeforeModelResolveEvent = { + /** User prompt for this run. No session messages are available yet in this phase. */ prompt: string; - messages?: unknown[]; }; -export type PluginHookBeforeAgentStartResult = { - systemPrompt?: string; - prependContext?: string; +export type PluginHookBeforeModelResolveResult = { /** Override the model for this agent run. E.g. "llama3.3:8b" */ modelOverride?: string; /** Override the provider for this agent run. E.g. "ollama" */ providerOverride?: string; }; +// before_prompt_build hook +export type PluginHookBeforePromptBuildEvent = { + prompt: string; + /** Session messages prepared for this run. */ + messages: unknown[]; +}; + +export type PluginHookBeforePromptBuildResult = { + systemPrompt?: string; + prependContext?: string; +}; + +// before_agent_start hook (legacy compatibility: combines both phases) +export type PluginHookBeforeAgentStartEvent = { + prompt: string; + /** Optional because legacy hook can run in pre-session phase. */ + messages?: unknown[]; +}; + +export type PluginHookBeforeAgentStartResult = PluginHookBeforePromptBuildResult & + PluginHookBeforeModelResolveResult; + // llm_input hook export type PluginHookLlmInputEvent = { runId: string; @@ -542,6 +564,17 @@ export type PluginHookGatewayStopEvent = { // Hook handler types mapped by hook name export type PluginHookHandlerMap = { + before_model_resolve: ( + event: PluginHookBeforeModelResolveEvent, + ctx: PluginHookAgentContext, + ) => + | Promise + | PluginHookBeforeModelResolveResult + | void; + before_prompt_build: ( + event: PluginHookBeforePromptBuildEvent, + ctx: PluginHookAgentContext, + ) => Promise | PluginHookBeforePromptBuildResult | void; before_agent_start: ( event: PluginHookBeforeAgentStartEvent, ctx: PluginHookAgentContext,