mirror of
https://github.com/openclaw/openclaw.git
synced 2026-04-18 11:47:28 +00:00
fix(memory): guard local embedding init race and failure recovery
This commit is contained in:
committed by
Gustavo Madeira Santana
parent
88ee57124e
commit
81dbef0030
@@ -471,6 +471,74 @@ describe("local embedding normalization", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("local embedding ensureContext concurrency", () => {
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks();
|
||||
vi.resetModules();
|
||||
vi.unstubAllGlobals();
|
||||
vi.doUnmock("./node-llama.js");
|
||||
});
|
||||
|
||||
it("loads the model only once when embedBatch is called concurrently", async () => {
|
||||
const getLlamaSpy = vi.fn();
|
||||
const loadModelSpy = vi.fn();
|
||||
const createContextSpy = vi.fn();
|
||||
|
||||
const nodeLlamaModule = await import("./node-llama.js");
|
||||
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
|
||||
getLlama: async (...args: unknown[]) => {
|
||||
getLlamaSpy(...args);
|
||||
await new Promise((r) => setTimeout(r, 50));
|
||||
return {
|
||||
loadModel: async (...modelArgs: unknown[]) => {
|
||||
loadModelSpy(...modelArgs);
|
||||
await new Promise((r) => setTimeout(r, 50));
|
||||
return {
|
||||
createEmbeddingContext: async () => {
|
||||
createContextSpy();
|
||||
return {
|
||||
getEmbeddingFor: vi.fn().mockResolvedValue({
|
||||
vector: new Float32Array([1, 0, 0, 0]),
|
||||
}),
|
||||
};
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
},
|
||||
resolveModelFile: async () => "/fake/model.gguf",
|
||||
LlamaLogLevel: { error: 0 },
|
||||
} as never);
|
||||
|
||||
const { createEmbeddingProvider } = await import("./embeddings.js");
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "local",
|
||||
model: "",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
const provider = requireProvider(result);
|
||||
const results = await Promise.all([
|
||||
provider.embedBatch(["text1"]),
|
||||
provider.embedBatch(["text2"]),
|
||||
provider.embedBatch(["text3"]),
|
||||
provider.embedBatch(["text4"]),
|
||||
]);
|
||||
|
||||
expect(results).toHaveLength(4);
|
||||
for (const embeddings of results) {
|
||||
expect(embeddings).toHaveLength(1);
|
||||
expect(embeddings[0]).toHaveLength(4);
|
||||
}
|
||||
|
||||
expect(getLlamaSpy).toHaveBeenCalledTimes(1);
|
||||
expect(loadModelSpy).toHaveBeenCalledTimes(1);
|
||||
expect(createContextSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe("FTS-only fallback when no provider available", () => {
|
||||
it("returns null provider with reason when auto mode finds no providers", async () => {
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue(
|
||||
|
||||
@@ -111,19 +111,34 @@ async function createLocalEmbeddingProvider(
|
||||
let llama: Llama | null = null;
|
||||
let embeddingModel: LlamaModel | null = null;
|
||||
let embeddingContext: LlamaEmbeddingContext | null = null;
|
||||
let initPromise: Promise<LlamaEmbeddingContext> | null = null;
|
||||
|
||||
const ensureContext = async () => {
|
||||
if (!llama) {
|
||||
llama = await getLlama({ logLevel: LlamaLogLevel.error });
|
||||
const ensureContext = async (): Promise<LlamaEmbeddingContext> => {
|
||||
if (embeddingContext) {
|
||||
return embeddingContext;
|
||||
}
|
||||
if (!embeddingModel) {
|
||||
const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined);
|
||||
embeddingModel = await llama.loadModel({ modelPath: resolved });
|
||||
if (initPromise) {
|
||||
return initPromise;
|
||||
}
|
||||
if (!embeddingContext) {
|
||||
embeddingContext = await embeddingModel.createEmbeddingContext();
|
||||
}
|
||||
return embeddingContext;
|
||||
initPromise = (async () => {
|
||||
try {
|
||||
if (!llama) {
|
||||
llama = await getLlama({ logLevel: LlamaLogLevel.error });
|
||||
}
|
||||
if (!embeddingModel) {
|
||||
const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined);
|
||||
embeddingModel = await llama.loadModel({ modelPath: resolved });
|
||||
}
|
||||
if (!embeddingContext) {
|
||||
embeddingContext = await embeddingModel.createEmbeddingContext();
|
||||
}
|
||||
return embeddingContext;
|
||||
} catch (err) {
|
||||
initPromise = null;
|
||||
throw err;
|
||||
}
|
||||
})();
|
||||
return initPromise;
|
||||
};
|
||||
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user