mirror of
https://github.com/openclaw/openclaw.git
synced 2026-04-18 03:57:28 +00:00
refactor(memory): dedupe local embedding init concurrency fixtures
This commit is contained in:
@@ -516,20 +516,32 @@ describe("local embedding ensureContext concurrency", () => {
|
||||
vi.doUnmock("./node-llama.js");
|
||||
});
|
||||
|
||||
it("loads the model only once when embedBatch is called concurrently", async () => {
|
||||
async function setupLocalProviderWithMockedInit(params?: {
|
||||
initializationDelayMs?: number;
|
||||
failFirstGetLlama?: boolean;
|
||||
}) {
|
||||
const getLlamaSpy = vi.fn();
|
||||
const loadModelSpy = vi.fn();
|
||||
const createContextSpy = vi.fn();
|
||||
let shouldFail = params?.failFirstGetLlama ?? false;
|
||||
|
||||
const nodeLlamaModule = await import("./node-llama.js");
|
||||
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
|
||||
getLlama: async (...args: unknown[]) => {
|
||||
getLlamaSpy(...args);
|
||||
await new Promise((r) => setTimeout(r, 50));
|
||||
if (shouldFail) {
|
||||
shouldFail = false;
|
||||
throw new Error("transient init failure");
|
||||
}
|
||||
if (params?.initializationDelayMs) {
|
||||
await new Promise((r) => setTimeout(r, params.initializationDelayMs));
|
||||
}
|
||||
return {
|
||||
loadModel: async (...modelArgs: unknown[]) => {
|
||||
loadModelSpy(...modelArgs);
|
||||
await new Promise((r) => setTimeout(r, 50));
|
||||
if (params?.initializationDelayMs) {
|
||||
await new Promise((r) => setTimeout(r, params.initializationDelayMs));
|
||||
}
|
||||
return {
|
||||
createEmbeddingContext: async () => {
|
||||
createContextSpy();
|
||||
@@ -548,7 +560,6 @@ describe("local embedding ensureContext concurrency", () => {
|
||||
} as never);
|
||||
|
||||
const { createEmbeddingProvider } = await import("./embeddings.js");
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "local",
|
||||
@@ -556,7 +567,20 @@ describe("local embedding ensureContext concurrency", () => {
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
const provider = requireProvider(result);
|
||||
return {
|
||||
provider: requireProvider(result),
|
||||
getLlamaSpy,
|
||||
loadModelSpy,
|
||||
createContextSpy,
|
||||
};
|
||||
}
|
||||
|
||||
it("loads the model only once when embedBatch is called concurrently", async () => {
|
||||
const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
|
||||
await setupLocalProviderWithMockedInit({
|
||||
initializationDelayMs: 50,
|
||||
});
|
||||
|
||||
const results = await Promise.all([
|
||||
provider.embedBatch(["text1"]),
|
||||
provider.embedBatch(["text2"]),
|
||||
@@ -576,49 +600,11 @@ describe("local embedding ensureContext concurrency", () => {
|
||||
});
|
||||
|
||||
it("retries initialization after a transient ensureContext failure", async () => {
|
||||
const getLlamaSpy = vi.fn();
|
||||
const loadModelSpy = vi.fn();
|
||||
const createContextSpy = vi.fn();
|
||||
const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
|
||||
await setupLocalProviderWithMockedInit({
|
||||
failFirstGetLlama: true,
|
||||
});
|
||||
|
||||
let failFirstGetLlama = true;
|
||||
const nodeLlamaModule = await import("./node-llama.js");
|
||||
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
|
||||
getLlama: async (...args: unknown[]) => {
|
||||
getLlamaSpy(...args);
|
||||
if (failFirstGetLlama) {
|
||||
failFirstGetLlama = false;
|
||||
throw new Error("transient init failure");
|
||||
}
|
||||
return {
|
||||
loadModel: async (...modelArgs: unknown[]) => {
|
||||
loadModelSpy(...modelArgs);
|
||||
return {
|
||||
createEmbeddingContext: async () => {
|
||||
createContextSpy();
|
||||
return {
|
||||
getEmbeddingFor: vi.fn().mockResolvedValue({
|
||||
vector: new Float32Array([1, 0, 0, 0]),
|
||||
}),
|
||||
};
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
},
|
||||
resolveModelFile: async () => "/fake/model.gguf",
|
||||
LlamaLogLevel: { error: 0 },
|
||||
} as never);
|
||||
|
||||
const { createEmbeddingProvider } = await import("./embeddings.js");
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "local",
|
||||
model: "",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
const provider = requireProvider(result);
|
||||
await expect(provider.embedBatch(["first"])).rejects.toThrow("transient init failure");
|
||||
|
||||
const recovered = await provider.embedBatch(["second"]);
|
||||
@@ -631,46 +617,11 @@ describe("local embedding ensureContext concurrency", () => {
|
||||
});
|
||||
|
||||
it("shares initialization when embedQuery and embedBatch start concurrently", async () => {
|
||||
const getLlamaSpy = vi.fn();
|
||||
const loadModelSpy = vi.fn();
|
||||
const createContextSpy = vi.fn();
|
||||
const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
|
||||
await setupLocalProviderWithMockedInit({
|
||||
initializationDelayMs: 50,
|
||||
});
|
||||
|
||||
const nodeLlamaModule = await import("./node-llama.js");
|
||||
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
|
||||
getLlama: async (...args: unknown[]) => {
|
||||
getLlamaSpy(...args);
|
||||
await new Promise((r) => setTimeout(r, 50));
|
||||
return {
|
||||
loadModel: async (...modelArgs: unknown[]) => {
|
||||
loadModelSpy(...modelArgs);
|
||||
await new Promise((r) => setTimeout(r, 50));
|
||||
return {
|
||||
createEmbeddingContext: async () => {
|
||||
createContextSpy();
|
||||
return {
|
||||
getEmbeddingFor: vi.fn().mockResolvedValue({
|
||||
vector: new Float32Array([1, 0, 0, 0]),
|
||||
}),
|
||||
};
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
},
|
||||
resolveModelFile: async () => "/fake/model.gguf",
|
||||
LlamaLogLevel: { error: 0 },
|
||||
} as never);
|
||||
|
||||
const { createEmbeddingProvider } = await import("./embeddings.js");
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "local",
|
||||
model: "",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
const provider = requireProvider(result);
|
||||
const [queryA, batch, queryB] = await Promise.all([
|
||||
provider.embedQuery("query-a"),
|
||||
provider.embedBatch(["batch-a", "batch-b"]),
|
||||
|
||||
Reference in New Issue
Block a user