refactor(agents): centralize model fallback resolution

This commit is contained in:
Peter Steinberger
2026-02-25 04:32:25 +00:00
parent dd6ad0da8c
commit 9beec48e9c
7 changed files with 205 additions and 61 deletions

View File

@@ -2,13 +2,16 @@ import path from "node:path";
import { afterEach, describe, expect, it, vi } from "vitest"; import { afterEach, describe, expect, it, vi } from "vitest";
import type { OpenClawConfig } from "../config/config.js"; import type { OpenClawConfig } from "../config/config.js";
import { import {
hasConfiguredModelFallbacks,
resolveAgentConfig, resolveAgentConfig,
resolveAgentDir, resolveAgentDir,
resolveAgentEffectiveModelPrimary, resolveAgentEffectiveModelPrimary,
resolveAgentExplicitModelPrimary, resolveAgentExplicitModelPrimary,
resolveFallbackAgentId,
resolveEffectiveModelFallbacks, resolveEffectiveModelFallbacks,
resolveAgentModelFallbacksOverride, resolveAgentModelFallbacksOverride,
resolveAgentModelPrimary, resolveAgentModelPrimary,
resolveRunModelFallbacksOverride,
resolveAgentWorkspaceDir, resolveAgentWorkspaceDir,
} from "./agent-scope.js"; } from "./agent-scope.js";
@@ -210,6 +213,109 @@ describe("resolveAgentConfig", () => {
).toEqual([]); ).toEqual([]);
}); });
it("resolves fallback agent id from explicit agent id first", () => {
expect(
resolveFallbackAgentId({
agentId: "Support",
sessionKey: "agent:main:session",
}),
).toBe("support");
});
it("resolves fallback agent id from session key when explicit id is missing", () => {
expect(
resolveFallbackAgentId({
sessionKey: "agent:worker:session",
}),
).toBe("worker");
});
it("resolves run fallback overrides via shared helper", () => {
const cfg: OpenClawConfig = {
agents: {
defaults: {
model: {
fallbacks: ["openai/gpt-4.1"],
},
},
list: [
{
id: "support",
model: {
fallbacks: ["openai/gpt-5.2"],
},
},
],
},
};
expect(
resolveRunModelFallbacksOverride({
cfg,
agentId: "support",
sessionKey: "agent:main:session",
}),
).toEqual(["openai/gpt-5.2"]);
expect(
resolveRunModelFallbacksOverride({
cfg,
agentId: undefined,
sessionKey: "agent:support:session",
}),
).toEqual(["openai/gpt-5.2"]);
});
it("computes whether any model fallbacks are configured via shared helper", () => {
const cfgDefaultsOnly: OpenClawConfig = {
agents: {
defaults: {
model: {
fallbacks: ["openai/gpt-4.1"],
},
},
list: [{ id: "main" }],
},
};
expect(
hasConfiguredModelFallbacks({
cfg: cfgDefaultsOnly,
sessionKey: "agent:main:session",
}),
).toBe(true);
const cfgAgentOverrideOnly: OpenClawConfig = {
agents: {
defaults: {
model: {
fallbacks: [],
},
},
list: [
{
id: "support",
model: {
fallbacks: ["openai/gpt-5.2"],
},
},
],
},
};
expect(
hasConfiguredModelFallbacks({
cfg: cfgAgentOverrideOnly,
agentId: "support",
sessionKey: "agent:support:session",
}),
).toBe(true);
expect(
hasConfiguredModelFallbacks({
cfg: cfgAgentOverrideOnly,
agentId: "main",
sessionKey: "agent:main:session",
}),
).toBe(false);
});
it("should return agent-specific sandbox config", () => { it("should return agent-specific sandbox config", () => {
const cfg: OpenClawConfig = { const cfg: OpenClawConfig = {
agents: { agents: {

View File

@@ -7,6 +7,7 @@ import {
DEFAULT_AGENT_ID, DEFAULT_AGENT_ID,
normalizeAgentId, normalizeAgentId,
parseAgentSessionKey, parseAgentSessionKey,
resolveAgentIdFromSessionKey,
} from "../routing/session-key.js"; } from "../routing/session-key.js";
import { resolveUserPath } from "../utils.js"; import { resolveUserPath } from "../utils.js";
import { normalizeSkillFilter } from "./skills/filter.js"; import { normalizeSkillFilter } from "./skills/filter.js";
@@ -19,7 +20,7 @@ function stripNullBytes(s: string): string {
return s.replace(/\0/g, ""); return s.replace(/\0/g, "");
} }
export { resolveAgentIdFromSessionKey } from "../routing/session-key.js"; export { resolveAgentIdFromSessionKey };
type AgentEntry = NonNullable<NonNullable<OpenClawConfig["agents"]>["list"]>[number]; type AgentEntry = NonNullable<NonNullable<OpenClawConfig["agents"]>["list"]>[number];
@@ -203,6 +204,41 @@ export function resolveAgentModelFallbacksOverride(
return Array.isArray(raw.fallbacks) ? raw.fallbacks : undefined; return Array.isArray(raw.fallbacks) ? raw.fallbacks : undefined;
} }
export function resolveFallbackAgentId(params: {
agentId?: string | null;
sessionKey?: string | null;
}): string {
const explicitAgentId = typeof params.agentId === "string" ? params.agentId.trim() : "";
if (explicitAgentId) {
return normalizeAgentId(explicitAgentId);
}
return resolveAgentIdFromSessionKey(params.sessionKey);
}
export function resolveRunModelFallbacksOverride(params: {
cfg: OpenClawConfig | undefined;
agentId?: string | null;
sessionKey?: string | null;
}): string[] | undefined {
if (!params.cfg) {
return undefined;
}
return resolveAgentModelFallbacksOverride(
params.cfg,
resolveFallbackAgentId({ agentId: params.agentId, sessionKey: params.sessionKey }),
);
}
export function hasConfiguredModelFallbacks(params: {
cfg: OpenClawConfig | undefined;
agentId?: string | null;
sessionKey?: string | null;
}): boolean {
const fallbacksOverride = resolveRunModelFallbacksOverride(params);
const defaultFallbacks = resolveAgentModelFallbackValues(params.cfg?.agents?.defaults?.model);
return (fallbacksOverride ?? defaultFallbacks).length > 0;
}
export function resolveEffectiveModelFallbacks(params: { export function resolveEffectiveModelFallbacks(params: {
cfg: OpenClawConfig; cfg: OpenClawConfig;
agentId: string; agentId: string;

View File

@@ -63,7 +63,8 @@ function shouldRethrowAbort(err: unknown): boolean {
function createModelCandidateCollector(allowlist: Set<string> | null | undefined): { function createModelCandidateCollector(allowlist: Set<string> | null | undefined): {
candidates: ModelCandidate[]; candidates: ModelCandidate[];
addCandidate: (candidate: ModelCandidate, enforceAllowlist: boolean) => void; addExplicitCandidate: (candidate: ModelCandidate) => void;
addAllowlistedCandidate: (candidate: ModelCandidate) => void;
} { } {
const seen = new Set<string>(); const seen = new Set<string>();
const candidates: ModelCandidate[] = []; const candidates: ModelCandidate[] = [];
@@ -83,7 +84,14 @@ function createModelCandidateCollector(allowlist: Set<string> | null | undefined
candidates.push(candidate); candidates.push(candidate);
}; };
return { candidates, addCandidate }; const addExplicitCandidate = (candidate: ModelCandidate) => {
addCandidate(candidate, false);
};
const addAllowlistedCandidate = (candidate: ModelCandidate) => {
addCandidate(candidate, true);
};
return { candidates, addExplicitCandidate, addAllowlistedCandidate };
} }
type ModelFallbackErrorHandler = (attempt: { type ModelFallbackErrorHandler = (attempt: {
@@ -138,9 +146,10 @@ function resolveImageFallbackCandidates(params: {
cfg: params.cfg, cfg: params.cfg,
defaultProvider: params.defaultProvider, defaultProvider: params.defaultProvider,
}); });
const { candidates, addCandidate } = createModelCandidateCollector(allowlist); const { candidates, addExplicitCandidate, addAllowlistedCandidate } =
createModelCandidateCollector(allowlist);
const addRaw = (raw: string, enforceAllowlist: boolean) => { const addRaw = (raw: string, opts?: { allowlist?: boolean }) => {
const resolved = resolveModelRefFromString({ const resolved = resolveModelRefFromString({
raw: String(raw ?? ""), raw: String(raw ?? ""),
defaultProvider: params.defaultProvider, defaultProvider: params.defaultProvider,
@@ -149,15 +158,19 @@ function resolveImageFallbackCandidates(params: {
if (!resolved) { if (!resolved) {
return; return;
} }
addCandidate(resolved.ref, enforceAllowlist); if (opts?.allowlist) {
addAllowlistedCandidate(resolved.ref);
return;
}
addExplicitCandidate(resolved.ref);
}; };
if (params.modelOverride?.trim()) { if (params.modelOverride?.trim()) {
addRaw(params.modelOverride, false); addRaw(params.modelOverride);
} else { } else {
const primary = resolveAgentModelPrimaryValue(params.cfg?.agents?.defaults?.imageModel); const primary = resolveAgentModelPrimaryValue(params.cfg?.agents?.defaults?.imageModel);
if (primary?.trim()) { if (primary?.trim()) {
addRaw(primary, false); addRaw(primary);
} }
} }
@@ -166,7 +179,7 @@ function resolveImageFallbackCandidates(params: {
for (const raw of imageFallbacks) { for (const raw of imageFallbacks) {
// Explicitly configured image fallbacks should remain reachable even when a // Explicitly configured image fallbacks should remain reachable even when a
// model allowlist is present. // model allowlist is present.
addRaw(raw, false); addRaw(raw);
} }
return candidates; return candidates;
@@ -200,9 +213,9 @@ function resolveFallbackCandidates(params: {
cfg: params.cfg, cfg: params.cfg,
defaultProvider, defaultProvider,
}); });
const { candidates, addCandidate } = createModelCandidateCollector(allowlist); const { candidates, addExplicitCandidate } = createModelCandidateCollector(allowlist);
addCandidate(normalizedPrimary, false); addExplicitCandidate(normalizedPrimary);
const modelFallbacks = (() => { const modelFallbacks = (() => {
if (params.fallbacksOverride !== undefined) { if (params.fallbacksOverride !== undefined) {
@@ -239,11 +252,11 @@ function resolveFallbackCandidates(params: {
} }
// Fallbacks are explicit user intent; do not silently filter them by the // Fallbacks are explicit user intent; do not silently filter them by the
// model allowlist. // model allowlist.
addCandidate(resolved.ref, false); addExplicitCandidate(resolved.ref);
} }
if (params.fallbacksOverride === undefined && primary?.provider && primary.model) { if (params.fallbacksOverride === undefined && primary?.provider && primary.model) {
addCandidate({ provider: primary.provider, model: primary.model }, false); addExplicitCandidate({ provider: primary.provider, model: primary.model });
} }
return candidates; return candidates;

View File

@@ -1,14 +1,13 @@
import { randomBytes } from "node:crypto"; import { randomBytes } from "node:crypto";
import fs from "node:fs/promises"; import fs from "node:fs/promises";
import type { ThinkLevel } from "../../auto-reply/thinking.js"; import type { ThinkLevel } from "../../auto-reply/thinking.js";
import { resolveAgentModelFallbackValues } from "../../config/model-input.js";
import { generateSecureToken } from "../../infra/secure-random.js"; import { generateSecureToken } from "../../infra/secure-random.js";
import { getGlobalHookRunner } from "../../plugins/hook-runner-global.js"; import { getGlobalHookRunner } from "../../plugins/hook-runner-global.js";
import type { PluginHookBeforeAgentStartResult } from "../../plugins/types.js"; import type { PluginHookBeforeAgentStartResult } from "../../plugins/types.js";
import { enqueueCommandInLane } from "../../process/command-queue.js"; import { enqueueCommandInLane } from "../../process/command-queue.js";
import { isMarkdownCapableMessageChannel } from "../../utils/message-channel.js"; import { isMarkdownCapableMessageChannel } from "../../utils/message-channel.js";
import { resolveOpenClawAgentDir } from "../agent-paths.js"; import { resolveOpenClawAgentDir } from "../agent-paths.js";
import { resolveAgentModelFallbacksOverride } from "../agent-scope.js"; import { hasConfiguredModelFallbacks } from "../agent-scope.js";
import { import {
isProfileInCooldown, isProfileInCooldown,
markAuthProfileFailure, markAuthProfileFailure,
@@ -232,15 +231,11 @@ export async function runEmbeddedPiAgent(
let provider = (params.provider ?? DEFAULT_PROVIDER).trim() || DEFAULT_PROVIDER; let provider = (params.provider ?? DEFAULT_PROVIDER).trim() || DEFAULT_PROVIDER;
let modelId = (params.model ?? DEFAULT_MODEL).trim() || DEFAULT_MODEL; let modelId = (params.model ?? DEFAULT_MODEL).trim() || DEFAULT_MODEL;
const agentDir = params.agentDir ?? resolveOpenClawAgentDir(); const agentDir = params.agentDir ?? resolveOpenClawAgentDir();
const agentFallbacksOverride = const fallbackConfigured = hasConfiguredModelFallbacks({
params.config && params.agentId cfg: params.config,
? resolveAgentModelFallbacksOverride(params.config, params.agentId) agentId: params.agentId,
: undefined; sessionKey: params.sessionKey,
const fallbackConfigured = });
(
agentFallbacksOverride ??
resolveAgentModelFallbackValues(params.config?.agents?.defaults?.model)
).length > 0;
await ensureOpenClawModelsJson(params.config, agentDir); await ensureOpenClawModelsJson(params.config, agentDir);
// Run before_model_resolve hooks early so plugins can override the // Run before_model_resolve hooks early so plugins can override the

View File

@@ -2,19 +2,13 @@ import { beforeEach, describe, expect, it, vi } from "vitest";
import type { FollowupRun } from "./queue.js"; import type { FollowupRun } from "./queue.js";
const hoisted = vi.hoisted(() => { const hoisted = vi.hoisted(() => {
const resolveAgentModelFallbacksOverrideMock = vi.fn(); const resolveRunModelFallbacksOverrideMock = vi.fn();
const resolveAgentIdFromSessionKeyMock = vi.fn(); return { resolveRunModelFallbacksOverrideMock };
return { resolveAgentModelFallbacksOverrideMock, resolveAgentIdFromSessionKeyMock };
}); });
vi.mock("../../agents/agent-scope.js", () => ({ vi.mock("../../agents/agent-scope.js", () => ({
resolveAgentModelFallbacksOverride: (...args: unknown[]) => resolveRunModelFallbacksOverride: (...args: unknown[]) =>
hoisted.resolveAgentModelFallbacksOverrideMock(...args), hoisted.resolveRunModelFallbacksOverrideMock(...args),
}));
vi.mock("../../config/sessions.js", () => ({
resolveAgentIdFromSessionKey: (...args: unknown[]) =>
hoisted.resolveAgentIdFromSessionKeyMock(...args),
})); }));
const { const {
@@ -50,22 +44,20 @@ function makeRun(overrides: Partial<FollowupRun["run"]> = {}): FollowupRun["run"
describe("agent-runner-utils", () => { describe("agent-runner-utils", () => {
beforeEach(() => { beforeEach(() => {
hoisted.resolveAgentModelFallbacksOverrideMock.mockClear(); hoisted.resolveRunModelFallbacksOverrideMock.mockClear();
hoisted.resolveAgentIdFromSessionKeyMock.mockClear();
}); });
it("resolves model fallback options from run context", () => { it("resolves model fallback options from run context", () => {
hoisted.resolveAgentIdFromSessionKeyMock.mockReturnValue("agent-id"); hoisted.resolveRunModelFallbacksOverrideMock.mockReturnValue(["fallback-model"]);
hoisted.resolveAgentModelFallbacksOverrideMock.mockReturnValue(["fallback-model"]);
const run = makeRun(); const run = makeRun();
const resolved = resolveModelFallbackOptions(run); const resolved = resolveModelFallbackOptions(run);
expect(hoisted.resolveAgentIdFromSessionKeyMock).not.toHaveBeenCalled(); expect(hoisted.resolveRunModelFallbacksOverrideMock).toHaveBeenCalledWith({
expect(hoisted.resolveAgentModelFallbacksOverrideMock).toHaveBeenCalledWith( cfg: run.config,
run.config, agentId: run.agentId,
run.agentId, sessionKey: run.sessionKey,
); });
expect(resolved).toEqual({ expect(resolved).toEqual({
cfg: run.config, cfg: run.config,
provider: run.provider, provider: run.provider,
@@ -75,18 +67,17 @@ describe("agent-runner-utils", () => {
}); });
}); });
it("falls back to sessionKey agent id when run.agentId is missing", () => { it("passes through missing agentId for helper-based fallback resolution", () => {
hoisted.resolveAgentIdFromSessionKeyMock.mockReturnValue("agent-from-session-key"); hoisted.resolveRunModelFallbacksOverrideMock.mockReturnValue(["fallback-model"]);
hoisted.resolveAgentModelFallbacksOverrideMock.mockReturnValue(["fallback-model"]);
const run = makeRun({ agentId: undefined }); const run = makeRun({ agentId: undefined });
const resolved = resolveModelFallbackOptions(run); const resolved = resolveModelFallbackOptions(run);
expect(hoisted.resolveAgentIdFromSessionKeyMock).toHaveBeenCalledWith(run.sessionKey); expect(hoisted.resolveRunModelFallbacksOverrideMock).toHaveBeenCalledWith({
expect(hoisted.resolveAgentModelFallbacksOverrideMock).toHaveBeenCalledWith( cfg: run.config,
run.config, agentId: undefined,
"agent-from-session-key", sessionKey: run.sessionKey,
); });
expect(resolved.fallbacksOverride).toEqual(["fallback-model"]); expect(resolved.fallbacksOverride).toEqual(["fallback-model"]);
}); });

View File

@@ -1,10 +1,9 @@
import { resolveAgentModelFallbacksOverride } from "../../agents/agent-scope.js"; import { resolveRunModelFallbacksOverride } from "../../agents/agent-scope.js";
import type { NormalizedUsage } from "../../agents/usage.js"; import type { NormalizedUsage } from "../../agents/usage.js";
import { getChannelDock } from "../../channels/dock.js"; import { getChannelDock } from "../../channels/dock.js";
import type { ChannelId, ChannelThreadingToolContext } from "../../channels/plugins/types.js"; import type { ChannelId, ChannelThreadingToolContext } from "../../channels/plugins/types.js";
import { normalizeAnyChannelId, normalizeChannelId } from "../../channels/registry.js"; import { normalizeAnyChannelId, normalizeChannelId } from "../../channels/registry.js";
import type { OpenClawConfig } from "../../config/config.js"; import type { OpenClawConfig } from "../../config/config.js";
import { resolveAgentIdFromSessionKey } from "../../config/sessions.js";
import { isReasoningTagProvider } from "../../utils/provider-utils.js"; import { isReasoningTagProvider } from "../../utils/provider-utils.js";
import { estimateUsageCost, formatTokenCount, formatUsd } from "../../utils/usage-format.js"; import { estimateUsageCost, formatTokenCount, formatUsd } from "../../utils/usage-format.js";
import type { TemplateContext } from "../templating.js"; import type { TemplateContext } from "../templating.js";
@@ -147,13 +146,16 @@ export const resolveEnforceFinalTag = (run: FollowupRun["run"], provider: string
Boolean(run.enforceFinalTag || isReasoningTagProvider(provider)); Boolean(run.enforceFinalTag || isReasoningTagProvider(provider));
export function resolveModelFallbackOptions(run: FollowupRun["run"]) { export function resolveModelFallbackOptions(run: FollowupRun["run"]) {
const fallbackAgentId = run.agentId ?? resolveAgentIdFromSessionKey(run.sessionKey);
return { return {
cfg: run.config, cfg: run.config,
provider: run.provider, provider: run.provider,
model: run.model, model: run.model,
agentDir: run.agentDir, agentDir: run.agentDir,
fallbacksOverride: resolveAgentModelFallbacksOverride(run.config, fallbackAgentId), fallbacksOverride: resolveRunModelFallbacksOverride({
cfg: run.config,
agentId: run.agentId,
sessionKey: run.sessionKey,
}),
}; };
} }

View File

@@ -1,10 +1,10 @@
import crypto from "node:crypto"; import crypto from "node:crypto";
import { resolveAgentModelFallbacksOverride } from "../../agents/agent-scope.js"; import { resolveRunModelFallbacksOverride } from "../../agents/agent-scope.js";
import { lookupContextTokens } from "../../agents/context.js"; import { lookupContextTokens } from "../../agents/context.js";
import { DEFAULT_CONTEXT_TOKENS } from "../../agents/defaults.js"; import { DEFAULT_CONTEXT_TOKENS } from "../../agents/defaults.js";
import { runWithModelFallback } from "../../agents/model-fallback.js"; import { runWithModelFallback } from "../../agents/model-fallback.js";
import { runEmbeddedPiAgent } from "../../agents/pi-embedded.js"; import { runEmbeddedPiAgent } from "../../agents/pi-embedded.js";
import { resolveAgentIdFromSessionKey, type SessionEntry } from "../../config/sessions.js"; import type { SessionEntry } from "../../config/sessions.js";
import type { TypingMode } from "../../config/types.js"; import type { TypingMode } from "../../config/types.js";
import { logVerbose } from "../../globals.js"; import { logVerbose } from "../../globals.js";
import { registerAgentRunContext } from "../../infra/agent-events.js"; import { registerAgentRunContext } from "../../infra/agent-events.js";
@@ -133,10 +133,11 @@ export function createFollowupRunner(params: {
provider: queued.run.provider, provider: queued.run.provider,
model: queued.run.model, model: queued.run.model,
agentDir: queued.run.agentDir, agentDir: queued.run.agentDir,
fallbacksOverride: resolveAgentModelFallbacksOverride( fallbacksOverride: resolveRunModelFallbacksOverride({
queued.run.config, cfg: queued.run.config,
queued.run.agentId ?? resolveAgentIdFromSessionKey(queued.run.sessionKey), agentId: queued.run.agentId,
), sessionKey: queued.run.sessionKey,
}),
run: (provider, model) => { run: (provider, model) => {
const authProfile = resolveRunAuthProfile(queued.run, provider); const authProfile = resolveRunAuthProfile(queued.run, provider);
return runEmbeddedPiAgent({ return runEmbeddedPiAgent({