feat(plugins): add modelOverride/providerOverride to before_agent_start hook

Enable plugins to override the model and provider for agent runs by
returning modelOverride/providerOverride from the before_agent_start
hook. The hook is now invoked early in run.ts (before resolveModel)
so overrides take effect. The result is passed to attempt.ts via
earlyHookResult to prevent double-firing.

This enables security-critical use cases like routing PII-containing
prompts to local models instead of cloud providers.
This commit is contained in:
Nate Fikru
2026-02-15 12:05:29 -05:00
committed by Peter Steinberger
parent 15dd2cda20
commit b90eb51520
5 changed files with 73 additions and 25 deletions

View File

@@ -1,7 +1,9 @@
import fs from "node:fs/promises"; import fs from "node:fs/promises";
import type { ThinkLevel } from "../../auto-reply/thinking.js"; import type { ThinkLevel } from "../../auto-reply/thinking.js";
import type { PluginHookBeforeAgentStartResult } from "../../plugins/types.js";
import type { RunEmbeddedPiAgentParams } from "./run/params.js"; import type { RunEmbeddedPiAgentParams } from "./run/params.js";
import type { EmbeddedPiAgentMeta, EmbeddedPiRunResult } from "./types.js"; import type { EmbeddedPiAgentMeta, EmbeddedPiRunResult } from "./types.js";
import { getGlobalHookRunner } from "../../plugins/hook-runner-global.js";
import { enqueueCommandInLane } from "../../process/command-queue.js"; import { enqueueCommandInLane } from "../../process/command-queue.js";
import { isMarkdownCapableMessageChannel } from "../../utils/message-channel.js"; import { isMarkdownCapableMessageChannel } from "../../utils/message-channel.js";
import { resolveOpenClawAgentDir } from "../agent-paths.js"; import { resolveOpenClawAgentDir } from "../agent-paths.js";
@@ -198,13 +200,43 @@ export async function runEmbeddedPiAgent(
} }
const prevCwd = process.cwd(); const prevCwd = process.cwd();
const provider = (params.provider ?? DEFAULT_PROVIDER).trim() || DEFAULT_PROVIDER; let provider = (params.provider ?? DEFAULT_PROVIDER).trim() || DEFAULT_PROVIDER;
const modelId = (params.model ?? DEFAULT_MODEL).trim() || DEFAULT_MODEL; let modelId = (params.model ?? DEFAULT_MODEL).trim() || DEFAULT_MODEL;
const agentDir = params.agentDir ?? resolveOpenClawAgentDir(); const agentDir = params.agentDir ?? resolveOpenClawAgentDir();
const fallbackConfigured = const fallbackConfigured =
(params.config?.agents?.defaults?.model?.fallbacks?.length ?? 0) > 0; (params.config?.agents?.defaults?.model?.fallbacks?.length ?? 0) > 0;
await ensureOpenClawModelsJson(params.config, agentDir); await ensureOpenClawModelsJson(params.config, agentDir);
// Run before_agent_start hooks early so plugins can override the model
// before it gets resolved. The hook result is passed downstream to
// attempt.ts to avoid double-firing.
let earlyHookResult: PluginHookBeforeAgentStartResult | undefined;
const hookRunner = getGlobalHookRunner();
if (hookRunner?.hasHooks("before_agent_start")) {
try {
earlyHookResult = await hookRunner.runBeforeAgentStart(
{ prompt: params.prompt },
{
agentId: params.agentId,
sessionKey: params.sessionKey,
sessionId: params.sessionId,
workspaceDir: params.workspaceDir,
messageProvider: params.messageProvider ?? undefined,
},
);
if (earlyHookResult?.providerOverride) {
provider = earlyHookResult.providerOverride;
log.info(`[hooks] provider overridden to ${provider}`);
}
if (earlyHookResult?.modelOverride) {
modelId = earlyHookResult.modelOverride;
log.info(`[hooks] model overridden to ${modelId}`);
}
} catch (hookErr) {
log.warn(`before_agent_start hook (early) failed: ${String(hookErr)}`);
}
}
const { model, error, authStorage, modelRegistry } = resolveModel( const { model, error, authStorage, modelRegistry } = resolveModel(
provider, provider,
modelId, modelId,
@@ -479,6 +511,7 @@ export async function runEmbeddedPiAgent(
streamParams: params.streamParams, streamParams: params.streamParams,
ownerNumbers: params.ownerNumbers, ownerNumbers: params.ownerNumbers,
enforceFinalTag: params.enforceFinalTag, enforceFinalTag: params.enforceFinalTag,
earlyHookResult,
}); });
const { const {

View File

@@ -850,31 +850,37 @@ export async function runEmbeddedAttempt(
try { try {
const promptStartedAt = Date.now(); const promptStartedAt = Date.now();
// Run before_agent_start hooks to allow plugins to inject context // Run before_agent_start hooks to allow plugins to inject context.
// If run.ts already fired the hook (for model override), reuse its result.
let effectivePrompt = params.prompt; let effectivePrompt = params.prompt;
if (hookRunner?.hasHooks("before_agent_start")) { const hookResult =
try { params.earlyHookResult ??
const hookResult = await hookRunner.runBeforeAgentStart( (hookRunner?.hasHooks("before_agent_start")
{ ? await hookRunner
prompt: params.prompt, .runBeforeAgentStart(
messages: activeSession.messages, {
}, prompt: params.prompt,
{ messages: activeSession.messages,
agentId: hookAgentId, },
sessionKey: params.sessionKey, {
sessionId: params.sessionId, agentId: hookAgentId,
workspaceDir: params.workspaceDir, sessionKey: params.sessionKey,
messageProvider: params.messageProvider ?? undefined, sessionId: params.sessionId,
}, workspaceDir: params.workspaceDir,
messageProvider: params.messageProvider ?? undefined,
},
)
.catch((hookErr: unknown) => {
log.warn(`before_agent_start hook failed: ${String(hookErr)}`);
return undefined;
})
: undefined);
{
if (hookResult?.prependContext) {
effectivePrompt = `${hookResult.prependContext}\n\n${params.prompt}`;
log.debug(
`hooks: prepended context to prompt (${hookResult.prependContext.length} chars)`,
); );
if (hookResult?.prependContext) {
effectivePrompt = `${hookResult.prependContext}\n\n${params.prompt}`;
log.debug(
`hooks: prepended context to prompt (${hookResult.prependContext.length} chars)`,
);
}
} catch (hookErr) {
log.warn(`before_agent_start hook failed: ${String(hookErr)}`);
} }
} }

View File

@@ -2,6 +2,7 @@ import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { Api, AssistantMessage, Model } from "@mariozechner/pi-ai"; import type { Api, AssistantMessage, Model } from "@mariozechner/pi-ai";
import type { ThinkLevel } from "../../../auto-reply/thinking.js"; import type { ThinkLevel } from "../../../auto-reply/thinking.js";
import type { SessionSystemPromptReport } from "../../../config/sessions/types.js"; import type { SessionSystemPromptReport } from "../../../config/sessions/types.js";
import type { PluginHookBeforeAgentStartResult } from "../../../plugins/types.js";
import type { MessagingToolSend } from "../../pi-embedded-messaging.js"; import type { MessagingToolSend } from "../../pi-embedded-messaging.js";
import type { AuthStorage, ModelRegistry } from "../../pi-model-discovery.js"; import type { AuthStorage, ModelRegistry } from "../../pi-model-discovery.js";
import type { NormalizedUsage } from "../../usage.js"; import type { NormalizedUsage } from "../../usage.js";
@@ -19,6 +20,8 @@ export type EmbeddedRunAttemptParams = EmbeddedRunAttemptBase & {
authStorage: AuthStorage; authStorage: AuthStorage;
modelRegistry: ModelRegistry; modelRegistry: ModelRegistry;
thinkLevel: ThinkLevel; thinkLevel: ThinkLevel;
/** Pre-computed hook result from run.ts to avoid double-firing before_agent_start. */
earlyHookResult?: PluginHookBeforeAgentStartResult;
}; };
export type EmbeddedRunAttemptResult = { export type EmbeddedRunAttemptResult = {

View File

@@ -200,6 +200,8 @@ export function createHookRunner(registry: PluginRegistry, options: HookRunnerOp
acc?.prependContext && next.prependContext acc?.prependContext && next.prependContext
? `${acc.prependContext}\n\n${next.prependContext}` ? `${acc.prependContext}\n\n${next.prependContext}`
: (next.prependContext ?? acc?.prependContext), : (next.prependContext ?? acc?.prependContext),
modelOverride: next.modelOverride ?? acc?.modelOverride,
providerOverride: next.providerOverride ?? acc?.providerOverride,
}), }),
); );
} }

View File

@@ -332,6 +332,10 @@ export type PluginHookBeforeAgentStartEvent = {
export type PluginHookBeforeAgentStartResult = { export type PluginHookBeforeAgentStartResult = {
systemPrompt?: string; systemPrompt?: string;
prependContext?: string; prependContext?: string;
/** Override the model for this agent run. E.g. "llama3.3:8b" */
modelOverride?: string;
/** Override the provider for this agent run. E.g. "ollama" */
providerOverride?: string;
}; };
// llm_input hook // llm_input hook