refactor(memory): dedupe local embedding init concurrency fixtures

Author: Peter Steinberger
Date: 2026-03-07 17:36:42 +00:00
Parent: 98ed7f57c6
Commit: a96ef12061

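The change collapses three copies of the inline node-llama mock scaffolding into a single parameterized setupLocalProviderWithMockedInit helper that returns the provider together with its spies. For orientation, here is a rough, self-contained vitest sketch of the same fixture pattern; makeProvider, setupProvider, InitFn, and the embed shape are hypothetical stand-ins for illustration only, while the repository's real fixture mocks importNodeLlamaCpp from ./node-llama.js as the diff below shows.

import { describe, expect, it, vi } from "vitest";

// Hypothetical stand-in for the real init chain
// (getLlama -> loadModel -> createEmbeddingContext).
type InitFn = () => Promise<{ embed: (text: string) => Promise<number[]> }>;

function makeProvider(init: InitFn) {
  // Share one in-flight init across concurrent callers; clear it on
  // failure so the next call can retry (the behavior the tests exercise).
  let pending: ReturnType<InitFn> | undefined;
  const ensureContext = () => {
    pending ??= init().catch((err) => {
      pending = undefined;
      throw err;
    });
    return pending;
  };
  return {
    embedBatch: async (texts: string[]) => {
      const ctx = await ensureContext();
      return Promise.all(texts.map((t) => ctx.embed(t)));
    },
  };
}

// The deduped fixture: one parameterized setup returning provider + spies.
function setupProvider(params?: { failFirst?: boolean; delayMs?: number }) {
  const initSpy = vi.fn();
  let shouldFail = params?.failFirst ?? false;
  const provider = makeProvider(async () => {
    initSpy();
    if (shouldFail) {
      shouldFail = false;
      throw new Error("transient init failure");
    }
    if (params?.delayMs) {
      await new Promise((r) => setTimeout(r, params.delayMs));
    }
    return { embed: async () => [1, 0, 0, 0] };
  });
  return { provider, initSpy };
}

describe("shared init fixture (sketch)", () => {
  it("initializes once under concurrency", async () => {
    const { provider, initSpy } = setupProvider({ delayMs: 20 });
    await Promise.all([provider.embedBatch(["a"]), provider.embedBatch(["b"])]);
    expect(initSpy).toHaveBeenCalledTimes(1);
  });

  it("retries after a transient failure", async () => {
    const { provider, initSpy } = setupProvider({ failFirst: true });
    await expect(provider.embedBatch(["a"])).rejects.toThrow("transient init failure");
    await expect(provider.embedBatch(["b"])).resolves.toBeDefined();
    expect(initSpy).toHaveBeenCalledTimes(2);
  });
});

Storing the in-flight promise and clearing it on rejection is what lets the same fixture back both the single-initialization and the retry assertions.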

@@ -516,20 +516,32 @@ describe("local embedding ensureContext concurrency", () => {
     vi.doUnmock("./node-llama.js");
   });
-  it("loads the model only once when embedBatch is called concurrently", async () => {
+  async function setupLocalProviderWithMockedInit(params?: {
+    initializationDelayMs?: number;
+    failFirstGetLlama?: boolean;
+  }) {
     const getLlamaSpy = vi.fn();
     const loadModelSpy = vi.fn();
     const createContextSpy = vi.fn();
+    let shouldFail = params?.failFirstGetLlama ?? false;
     const nodeLlamaModule = await import("./node-llama.js");
     vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
       getLlama: async (...args: unknown[]) => {
         getLlamaSpy(...args);
-        await new Promise((r) => setTimeout(r, 50));
+        if (shouldFail) {
+          shouldFail = false;
+          throw new Error("transient init failure");
+        }
+        if (params?.initializationDelayMs) {
+          await new Promise((r) => setTimeout(r, params.initializationDelayMs));
+        }
         return {
           loadModel: async (...modelArgs: unknown[]) => {
             loadModelSpy(...modelArgs);
-            await new Promise((r) => setTimeout(r, 50));
+            if (params?.initializationDelayMs) {
+              await new Promise((r) => setTimeout(r, params.initializationDelayMs));
+            }
             return {
               createEmbeddingContext: async () => {
                 createContextSpy();
@@ -548,7 +560,6 @@ describe("local embedding ensureContext concurrency", () => {
     } as never);
     const { createEmbeddingProvider } = await import("./embeddings.js");
     const result = await createEmbeddingProvider({
       config: {} as never,
       provider: "local",
@@ -556,7 +567,20 @@ describe("local embedding ensureContext concurrency", () => {
       fallback: "none",
     });
-    const provider = requireProvider(result);
+    return {
+      provider: requireProvider(result),
+      getLlamaSpy,
+      loadModelSpy,
+      createContextSpy,
+    };
+  }
+
+  it("loads the model only once when embedBatch is called concurrently", async () => {
+    const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
+      await setupLocalProviderWithMockedInit({
+        initializationDelayMs: 50,
+      });
     const results = await Promise.all([
       provider.embedBatch(["text1"]),
       provider.embedBatch(["text2"]),
@@ -576,49 +600,11 @@ describe("local embedding ensureContext concurrency", () => {
   });

   it("retries initialization after a transient ensureContext failure", async () => {
-    const getLlamaSpy = vi.fn();
-    const loadModelSpy = vi.fn();
-    const createContextSpy = vi.fn();
+    const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
+      await setupLocalProviderWithMockedInit({
+        failFirstGetLlama: true,
+      });
-    let failFirstGetLlama = true;
-    const nodeLlamaModule = await import("./node-llama.js");
-    vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
-      getLlama: async (...args: unknown[]) => {
-        getLlamaSpy(...args);
-        if (failFirstGetLlama) {
-          failFirstGetLlama = false;
-          throw new Error("transient init failure");
-        }
-        return {
-          loadModel: async (...modelArgs: unknown[]) => {
-            loadModelSpy(...modelArgs);
-            return {
-              createEmbeddingContext: async () => {
-                createContextSpy();
-                return {
-                  getEmbeddingFor: vi.fn().mockResolvedValue({
-                    vector: new Float32Array([1, 0, 0, 0]),
-                  }),
-                };
-              },
-            };
-          },
-        };
-      },
-      resolveModelFile: async () => "/fake/model.gguf",
-      LlamaLogLevel: { error: 0 },
-    } as never);
-    const { createEmbeddingProvider } = await import("./embeddings.js");
-    const result = await createEmbeddingProvider({
-      config: {} as never,
-      provider: "local",
-      model: "",
-      fallback: "none",
-    });
-    const provider = requireProvider(result);
     await expect(provider.embedBatch(["first"])).rejects.toThrow("transient init failure");
     const recovered = await provider.embedBatch(["second"]);
@@ -631,46 +617,11 @@ describe("local embedding ensureContext concurrency", () => {
   });

   it("shares initialization when embedQuery and embedBatch start concurrently", async () => {
-    const getLlamaSpy = vi.fn();
-    const loadModelSpy = vi.fn();
-    const createContextSpy = vi.fn();
+    const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
+      await setupLocalProviderWithMockedInit({
+        initializationDelayMs: 50,
+      });
-    const nodeLlamaModule = await import("./node-llama.js");
-    vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
-      getLlama: async (...args: unknown[]) => {
-        getLlamaSpy(...args);
-        await new Promise((r) => setTimeout(r, 50));
-        return {
-          loadModel: async (...modelArgs: unknown[]) => {
-            loadModelSpy(...modelArgs);
-            await new Promise((r) => setTimeout(r, 50));
-            return {
-              createEmbeddingContext: async () => {
-                createContextSpy();
-                return {
-                  getEmbeddingFor: vi.fn().mockResolvedValue({
-                    vector: new Float32Array([1, 0, 0, 0]),
-                  }),
-                };
-              },
-            };
-          },
-        };
-      },
-      resolveModelFile: async () => "/fake/model.gguf",
-      LlamaLogLevel: { error: 0 },
-    } as never);
-    const { createEmbeddingProvider } = await import("./embeddings.js");
-    const result = await createEmbeddingProvider({
-      config: {} as never,
-      provider: "local",
-      model: "",
-      fallback: "none",
-    });
-    const provider = requireProvider(result);
     const [queryA, batch, queryB] = await Promise.all([
       provider.embedQuery("query-a"),
       provider.embedBatch(["batch-a", "batch-b"]),