fix(memory): serialize local embedding initialization to avoid duplicate model loads (#15639)

Merged via squash.

Prepared head SHA: a085fc21a8
Co-authored-by: SubtleSpark <43933609+SubtleSpark@users.noreply.github.com>
Co-authored-by: gumadeiras <5599352+gumadeiras@users.noreply.github.com>
Reviewed-by: @gumadeiras
This commit is contained in:
huangcj
2026-03-04 23:09:03 +08:00
committed by GitHub
parent 88ee57124e
commit dc8253a84d
3 changed files with 207 additions and 10 deletions

View File

@@ -82,6 +82,7 @@ Docs: https://docs.openclaw.ai
- Agents/Compaction continuity: expand staged-summary merge instructions to preserve active task status, batch progress, latest user request, and follow-up commitments so compaction handoffs retain in-flight work context. (#8903) thanks @joetomasone.
- Gateway/status self version reporting: make Gateway self version in `openclaw status` prefer runtime `VERSION` (while preserving explicit `OPENCLAW_VERSION` override), preventing stale post-upgrade app version output. (#32655) thanks @liuxiaopai-ai.
- Memory/QMD index isolation: set `QMD_CONFIG_DIR` alongside `XDG_CONFIG_HOME` so QMD config state stays per-agent despite upstream XDG handling bugs, preventing cross-agent collection indexing and excess disk/CPU usage. (#27028) thanks @HenryLoenwind.
- Memory/local embedding initialization hardening: add regression coverage for transient initialization retry and mixed `embedQuery` + `embedBatch` concurrent startup to lock single-flight initialization behavior. (#15639) thanks @SubtleSpark.
- CLI/Coding-agent reliability: switch default `claude-cli` non-interactive args to `--permission-mode bypassPermissions`, auto-normalize legacy `--dangerously-skip-permissions` backend overrides to the modern permission-mode form, align coding-agent + live-test docs with the non-PTY Claude path, and emit session system-event heartbeat notices when CLI watchdog no-output timeouts terminate runs. Related to #28261. Landed from contributor PRs #28610 and #31149. Thanks @niceysam, @cryptomaltese and @vincentkoc.
- ACP/ACPX session bootstrap: retry with `sessions new` when `sessions ensure` returns no session identifiers so ACP spawns avoid `NO_SESSION`/`ACP_TURN_FAILED` failures on affected agents. Related to #28786. Landed from contributor PR #31338. Thanks @Sid-Qin and @vincentkoc.
- LINE/auth boundary hardening synthesis: enforce strict LINE webhook authn/z boundary semantics across pairing-store account scoping, DM/group allowlist separation, fail-closed webhook auth/runtime behavior, and replay/duplication controls (including in-flight replay reservation and post-success dedupe marking). (from #26701, #26683, #25978, #17593, #16619, #31990, #26047, #30584, #18777) Thanks @bmendonca3, @davidahmann, @harshang03, @haosenwang1018, @liuxiaopai-ai, @coygeek, and @Takhoffman.

View File

@@ -471,6 +471,187 @@ describe("local embedding normalization", () => {
});
});
describe("local embedding ensureContext concurrency", () => {
afterEach(() => {
vi.resetAllMocks();
vi.resetModules();
vi.unstubAllGlobals();
vi.doUnmock("./node-llama.js");
});
it("loads the model only once when embedBatch is called concurrently", async () => {
const getLlamaSpy = vi.fn();
const loadModelSpy = vi.fn();
const createContextSpy = vi.fn();
const nodeLlamaModule = await import("./node-llama.js");
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
getLlama: async (...args: unknown[]) => {
getLlamaSpy(...args);
await new Promise((r) => setTimeout(r, 50));
return {
loadModel: async (...modelArgs: unknown[]) => {
loadModelSpy(...modelArgs);
await new Promise((r) => setTimeout(r, 50));
return {
createEmbeddingContext: async () => {
createContextSpy();
return {
getEmbeddingFor: vi.fn().mockResolvedValue({
vector: new Float32Array([1, 0, 0, 0]),
}),
};
},
};
},
};
},
resolveModelFile: async () => "/fake/model.gguf",
LlamaLogLevel: { error: 0 },
} as never);
const { createEmbeddingProvider } = await import("./embeddings.js");
const result = await createEmbeddingProvider({
config: {} as never,
provider: "local",
model: "",
fallback: "none",
});
const provider = requireProvider(result);
const results = await Promise.all([
provider.embedBatch(["text1"]),
provider.embedBatch(["text2"]),
provider.embedBatch(["text3"]),
provider.embedBatch(["text4"]),
]);
expect(results).toHaveLength(4);
for (const embeddings of results) {
expect(embeddings).toHaveLength(1);
expect(embeddings[0]).toHaveLength(4);
}
expect(getLlamaSpy).toHaveBeenCalledTimes(1);
expect(loadModelSpy).toHaveBeenCalledTimes(1);
expect(createContextSpy).toHaveBeenCalledTimes(1);
});
it("retries initialization after a transient ensureContext failure", async () => {
const getLlamaSpy = vi.fn();
const loadModelSpy = vi.fn();
const createContextSpy = vi.fn();
let failFirstGetLlama = true;
const nodeLlamaModule = await import("./node-llama.js");
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
getLlama: async (...args: unknown[]) => {
getLlamaSpy(...args);
if (failFirstGetLlama) {
failFirstGetLlama = false;
throw new Error("transient init failure");
}
return {
loadModel: async (...modelArgs: unknown[]) => {
loadModelSpy(...modelArgs);
return {
createEmbeddingContext: async () => {
createContextSpy();
return {
getEmbeddingFor: vi.fn().mockResolvedValue({
vector: new Float32Array([1, 0, 0, 0]),
}),
};
},
};
},
};
},
resolveModelFile: async () => "/fake/model.gguf",
LlamaLogLevel: { error: 0 },
} as never);
const { createEmbeddingProvider } = await import("./embeddings.js");
const result = await createEmbeddingProvider({
config: {} as never,
provider: "local",
model: "",
fallback: "none",
});
const provider = requireProvider(result);
await expect(provider.embedBatch(["first"])).rejects.toThrow("transient init failure");
const recovered = await provider.embedBatch(["second"]);
expect(recovered).toHaveLength(1);
expect(recovered[0]).toHaveLength(4);
expect(getLlamaSpy).toHaveBeenCalledTimes(2);
expect(loadModelSpy).toHaveBeenCalledTimes(1);
expect(createContextSpy).toHaveBeenCalledTimes(1);
});
it("shares initialization when embedQuery and embedBatch start concurrently", async () => {
const getLlamaSpy = vi.fn();
const loadModelSpy = vi.fn();
const createContextSpy = vi.fn();
const nodeLlamaModule = await import("./node-llama.js");
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
getLlama: async (...args: unknown[]) => {
getLlamaSpy(...args);
await new Promise((r) => setTimeout(r, 50));
return {
loadModel: async (...modelArgs: unknown[]) => {
loadModelSpy(...modelArgs);
await new Promise((r) => setTimeout(r, 50));
return {
createEmbeddingContext: async () => {
createContextSpy();
return {
getEmbeddingFor: vi.fn().mockResolvedValue({
vector: new Float32Array([1, 0, 0, 0]),
}),
};
},
};
},
};
},
resolveModelFile: async () => "/fake/model.gguf",
LlamaLogLevel: { error: 0 },
} as never);
const { createEmbeddingProvider } = await import("./embeddings.js");
const result = await createEmbeddingProvider({
config: {} as never,
provider: "local",
model: "",
fallback: "none",
});
const provider = requireProvider(result);
const [queryA, batch, queryB] = await Promise.all([
provider.embedQuery("query-a"),
provider.embedBatch(["batch-a", "batch-b"]),
provider.embedQuery("query-b"),
]);
expect(queryA).toHaveLength(4);
expect(batch).toHaveLength(2);
expect(queryB).toHaveLength(4);
expect(batch[0]).toHaveLength(4);
expect(batch[1]).toHaveLength(4);
expect(getLlamaSpy).toHaveBeenCalledTimes(1);
expect(loadModelSpy).toHaveBeenCalledTimes(1);
expect(createContextSpy).toHaveBeenCalledTimes(1);
});
});
describe("FTS-only fallback when no provider available", () => {
it("returns null provider with reason when auto mode finds no providers", async () => {
vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue(

View File

@@ -111,19 +111,34 @@ async function createLocalEmbeddingProvider(
let llama: Llama | null = null;
let embeddingModel: LlamaModel | null = null;
let embeddingContext: LlamaEmbeddingContext | null = null;
let initPromise: Promise<LlamaEmbeddingContext> | null = null;
const ensureContext = async () => {
if (!llama) {
llama = await getLlama({ logLevel: LlamaLogLevel.error });
const ensureContext = async (): Promise<LlamaEmbeddingContext> => {
if (embeddingContext) {
return embeddingContext;
}
if (!embeddingModel) {
const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined);
embeddingModel = await llama.loadModel({ modelPath: resolved });
if (initPromise) {
return initPromise;
}
if (!embeddingContext) {
embeddingContext = await embeddingModel.createEmbeddingContext();
}
return embeddingContext;
initPromise = (async () => {
try {
if (!llama) {
llama = await getLlama({ logLevel: LlamaLogLevel.error });
}
if (!embeddingModel) {
const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined);
embeddingModel = await llama.loadModel({ modelPath: resolved });
}
if (!embeddingContext) {
embeddingContext = await embeddingModel.createEmbeddingContext();
}
return embeddingContext;
} catch (err) {
initPromise = null;
throw err;
}
})();
return initPromise;
};
return {