fix(memory): guard local embedding init race and failure recovery

This commit is contained in:
huangcj
2026-02-25 03:09:41 +08:00
committed by Gustavo Madeira Santana
parent 88ee57124e
commit 81dbef0030
2 changed files with 93 additions and 10 deletions

View File

@@ -471,6 +471,74 @@ describe("local embedding normalization", () => {
});
});
describe("local embedding ensureContext concurrency", () => {
  afterEach(() => {
    vi.resetAllMocks();
    vi.resetModules();
    vi.unstubAllGlobals();
    vi.doUnmock("./node-llama.js");
  });

  it("loads the model only once when embedBatch is called concurrently", async () => {
    // One spy per lazy-init stage so we can count how often each runs.
    const getLlamaSpy = vi.fn();
    const loadModelSpy = vi.fn();
    const createContextSpy = vi.fn();
    // Artificial latency widens the race window so overlapping callers
    // genuinely hit initialization at the same time.
    const delay = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));

    const nodeLlamaModule = await import("./node-llama.js");
    vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
      getLlama: async (...args: unknown[]) => {
        getLlamaSpy(...args);
        await delay(50);
        return {
          loadModel: async (...modelArgs: unknown[]) => {
            loadModelSpy(...modelArgs);
            await delay(50);
            return {
              createEmbeddingContext: async () => {
                createContextSpy();
                return {
                  getEmbeddingFor: vi.fn().mockResolvedValue({
                    vector: new Float32Array([1, 0, 0, 0]),
                  }),
                };
              },
            };
          },
        };
      },
      resolveModelFile: async () => "/fake/model.gguf",
      LlamaLogLevel: { error: 0 },
    } as never);

    const { createEmbeddingProvider } = await import("./embeddings.js");
    const result = await createEmbeddingProvider({
      config: {} as never,
      provider: "local",
      model: "",
      fallback: "none",
    });
    const provider = requireProvider(result);

    // Fire four overlapping embed calls; they must share one init pipeline.
    const inputs = ["text1", "text2", "text3", "text4"];
    const results = await Promise.all(inputs.map((text) => provider.embedBatch([text])));

    expect(results).toHaveLength(4);
    for (const embeddings of results) {
      expect(embeddings).toHaveLength(1);
      expect(embeddings[0]).toHaveLength(4);
    }
    // Every init stage ran exactly once despite four concurrent callers.
    expect(getLlamaSpy).toHaveBeenCalledTimes(1);
    expect(loadModelSpy).toHaveBeenCalledTimes(1);
    expect(createContextSpy).toHaveBeenCalledTimes(1);
  });
});
describe("FTS-only fallback when no provider available", () => {
it("returns null provider with reason when auto mode finds no providers", async () => {
vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue(

View File

@@ -111,19 +111,34 @@ async function createLocalEmbeddingProvider(
let llama: Llama | null = null;
let embeddingModel: LlamaModel | null = null;
let embeddingContext: LlamaEmbeddingContext | null = null;
let initPromise: Promise<LlamaEmbeddingContext> | null = null;
const ensureContext = async () => {
if (!llama) {
llama = await getLlama({ logLevel: LlamaLogLevel.error });
const ensureContext = async (): Promise<LlamaEmbeddingContext> => {
if (embeddingContext) {
return embeddingContext;
}
if (!embeddingModel) {
const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined);
embeddingModel = await llama.loadModel({ modelPath: resolved });
if (initPromise) {
return initPromise;
}
if (!embeddingContext) {
embeddingContext = await embeddingModel.createEmbeddingContext();
}
return embeddingContext;
initPromise = (async () => {
try {
if (!llama) {
llama = await getLlama({ logLevel: LlamaLogLevel.error });
}
if (!embeddingModel) {
const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined);
embeddingModel = await llama.loadModel({ modelPath: resolved });
}
if (!embeddingContext) {
embeddingContext = await embeddingModel.createEmbeddingContext();
}
return embeddingContext;
} catch (err) {
initPromise = null;
throw err;
}
})();
return initPromise;
};
return {