From 1ae3afbd6b089942d66a0c676d9602fe0fd894da Mon Sep 17 00:00:00 2001 From: Tarun Sukhani Date: Mon, 9 Feb 2026 18:50:09 +0800 Subject: [PATCH] =?UTF-8?q?memory-neo4j:=20code=20review=20quick=20wins=20?= =?UTF-8?q?=E2=80=94=20security,=20perf,=20docs=20fixes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix initPromise retry: reset to null on failure so subsequent calls retry instead of returning cached rejected promise - Remove dead code: findPromotionCandidates, findDemotionCandidates, calculateEffectiveImportance (~190 lines, never called) - Add agentId filter to deleteMemory() to prevent cross-agent deletion - Fix phase label swaps: 1b=Semantic Dedup, 1c=Conflict Detection (CLI banner, phaseNames map, SleepCycleResult/Options type comments) - Add autoRecallMinScore and coreMemory config to plugin JSON schema so the UI can validate and display these options - Add embedding LRU cache (200 entries, SHA-256 keyed) to eliminate redundant API calls across auto-recall, auto-capture, and tools - Add Ollama concurrency limiter (chunks of 4) to prevent thundering herd on single-threaded embedding server Co-Authored-By: Claude Opus 4.6 --- extensions/memory-neo4j/config.ts | 10 +- extensions/memory-neo4j/embeddings.ts | 151 +- extensions/memory-neo4j/extractor.test.ts | 968 ++++++++++++ extensions/memory-neo4j/extractor.ts | 322 ++-- extensions/memory-neo4j/index.ts | 77 +- extensions/memory-neo4j/neo4j-client.test.ts | 1460 ++++++++++++++++++ extensions/memory-neo4j/neo4j-client.ts | 537 +++---- extensions/memory-neo4j/openclaw.plugin.json | 39 + 8 files changed, 3102 insertions(+), 462 deletions(-) create mode 100644 extensions/memory-neo4j/neo4j-client.test.ts diff --git a/extensions/memory-neo4j/config.ts b/extensions/memory-neo4j/config.ts index 8d642be3a3a..c2915ab1937 100644 --- a/extensions/memory-neo4j/config.ts +++ b/extensions/memory-neo4j/config.ts @@ -83,12 +83,16 @@ export function vectorDimsForModel(model: string): number { if (EMBEDDING_DIMENSIONS[model]) { return EMBEDDING_DIMENSIONS[model]; } - // Check prefix match (for versioned models like mxbai-embed-large:latest) + // Prefer longest matching prefix (e.g. "mxbai-embed-large-2k" over "mxbai-embed-large") + let best: { dims: number; keyLen: number } | undefined; for (const [known, dims] of Object.entries(EMBEDDING_DIMENSIONS)) { - if (model.startsWith(known)) { - return dims; + if (model.startsWith(known) && (!best || known.length > best.keyLen)) { + best = { dims, keyLen: known.length }; } } + if (best) { + return best.dims; + } // Return default for unknown models — callers should warn when this path is taken, // as the default 1024 dimensions may not match the actual model's output. return DEFAULT_EMBEDDING_DIMS; diff --git a/extensions/memory-neo4j/embeddings.ts b/extensions/memory-neo4j/embeddings.ts index bb772cd4dba..0df3b137f85 100644 --- a/extensions/memory-neo4j/embeddings.ts +++ b/extensions/memory-neo4j/embeddings.ts @@ -2,8 +2,10 @@ * Embedding generation for memory-neo4j. * * Supports both OpenAI and Ollama providers. + * Includes an LRU cache to avoid redundant API calls within a session. */ +import { createHash } from "node:crypto"; import OpenAI from "openai"; import type { EmbeddingProvider } from "./config.js"; import { contextLengthForModel } from "./config.js"; @@ -15,12 +17,63 @@ type Logger = { debug?: (msg: string) => void; }; +/** + * Simple LRU cache for embedding vectors. + * Keyed by SHA-256 hash of the input text to avoid storing large strings. + */ +class EmbeddingCache { + private readonly map = new Map(); + private readonly maxSize: number; + + constructor(maxSize: number = 200) { + this.maxSize = maxSize; + } + + private static hashText(text: string): string { + return createHash("sha256").update(text).digest("hex"); + } + + get(text: string): number[] | undefined { + const key = EmbeddingCache.hashText(text); + const value = this.map.get(key); + if (value !== undefined) { + // Move to end (most recently used) by re-inserting + this.map.delete(key); + this.map.set(key, value); + } + return value; + } + + set(text: string, embedding: number[]): void { + const key = EmbeddingCache.hashText(text); + // If key exists, delete first to refresh position + if (this.map.has(key)) { + this.map.delete(key); + } else if (this.map.size >= this.maxSize) { + // Evict oldest (first) entry + const oldest = this.map.keys().next().value; + if (oldest !== undefined) { + this.map.delete(oldest); + } + } + this.map.set(key, embedding); + } + + get size(): number { + return this.map.size; + } +} + +/** Default concurrency for Ollama embedding requests */ +const OLLAMA_EMBED_CONCURRENCY = 4; + export class Embeddings { private client: OpenAI | null = null; private readonly provider: EmbeddingProvider; private readonly baseUrl: string; private readonly logger: Logger | undefined; private readonly contextLength: number; + private readonly cache = new EmbeddingCache(200); constructor( private readonly apiKey: string | undefined, @@ -70,21 +123,32 @@ export class Embeddings { /** * Generate an embedding vector for a single text. + * Results are cached to avoid redundant API calls. */ async embed(text: string): Promise { const input = this.truncateToContext(text); - if (this.provider === "ollama") { - return this.embedOllama(input); + + // Check cache first + const cached = this.cache.get(input); + if (cached) { + this.logger?.debug?.("memory-neo4j: embedding cache hit"); + return cached; } - return this.embedOpenAI(input); + + const embedding = + this.provider === "ollama" ? await this.embedOllama(input) : await this.embedOpenAI(input); + + this.cache.set(input, embedding); + return embedding; } /** * Generate embeddings for multiple texts. * Returns array of embeddings in the same order as input. * - * For Ollama: uses Promise.allSettled so individual failures don't break the - * entire batch. Failed embeddings are replaced with zero vectors and logged. + * For Ollama: processes in chunks of OLLAMA_EMBED_CONCURRENCY to avoid + * overwhelming the local server. Individual failures don't break the + * entire batch — failed embeddings are replaced with empty arrays. */ async embedBatch(texts: string[]): Promise { if (texts.length === 0) { @@ -93,36 +157,77 @@ export class Embeddings { const truncated = texts.map((t) => this.truncateToContext(t)); - if (this.provider === "ollama") { - // Ollama doesn't support batch, so we do sequential with resilient error handling - const results = await Promise.allSettled(truncated.map((t) => this.embedOllama(t))); - const embeddings: number[][] = []; - let failures = 0; + // Check cache for each text; only compute uncached ones + const results: (number[] | null)[] = truncated.map((t) => this.cache.get(t) ?? null); + const uncachedIndices: number[] = []; + const uncachedTexts: string[] = []; + for (let i = 0; i < results.length; i++) { + if (results[i] === null) { + uncachedIndices.push(i); + uncachedTexts.push(truncated[i]); + } + } - for (let i = 0; i < results.length; i++) { - const result = results[i]; + if (uncachedTexts.length === 0) { + this.logger?.debug?.(`memory-neo4j: embedBatch fully cached (${texts.length} texts)`); + return results as number[][]; + } + + let computed: number[][]; + + if (this.provider === "ollama") { + computed = await this.embedBatchOllama(uncachedTexts); + } else { + computed = await this.embedBatchOpenAI(uncachedTexts); + } + + // Merge computed results back and populate cache + for (let i = 0; i < uncachedIndices.length; i++) { + const embedding = computed[i]; + results[uncachedIndices[i]] = embedding; + if (embedding.length > 0) { + this.cache.set(uncachedTexts[i], embedding); + } + } + + return results as number[][]; + } + + /** + * Ollama batch embedding with concurrency limiting. + * Processes in chunks to avoid overwhelming the server. + */ + private async embedBatchOllama(texts: string[]): Promise { + const embeddings: number[][] = []; + let failures = 0; + + // Process in chunks of OLLAMA_EMBED_CONCURRENCY + for (let i = 0; i < texts.length; i += OLLAMA_EMBED_CONCURRENCY) { + const chunk = texts.slice(i, i + OLLAMA_EMBED_CONCURRENCY); + const chunkResults = await Promise.allSettled(chunk.map((t) => this.embedOllama(t))); + + for (let j = 0; j < chunkResults.length; j++) { + const result = chunkResults[j]; if (result.status === "fulfilled") { embeddings.push(result.value); } else { failures++; this.logger?.warn?.( - `memory-neo4j: Ollama embedding failed for text ${i}: ${String(result.reason)}`, + `memory-neo4j: Ollama embedding failed for text ${i + j}: ${String(result.reason)}`, ); - // Use zero vector as placeholder so indices stay aligned + // Use empty array as placeholder so indices stay aligned embeddings.push([]); } } - - if (failures > 0) { - this.logger?.warn?.( - `memory-neo4j: ${failures}/${texts.length} Ollama embeddings failed in batch`, - ); - } - - return embeddings; } - return this.embedBatchOpenAI(truncated); + if (failures > 0) { + this.logger?.warn?.( + `memory-neo4j: ${failures}/${texts.length} Ollama embeddings failed in batch`, + ); + } + + return embeddings; } private async embedOpenAI(text: string): Promise { diff --git a/extensions/memory-neo4j/extractor.test.ts b/extensions/memory-neo4j/extractor.test.ts index b8c4c6d8b32..8f47bc41ad7 100644 --- a/extensions/memory-neo4j/extractor.test.ts +++ b/extensions/memory-neo4j/extractor.test.ts @@ -16,6 +16,8 @@ import { runBackgroundExtraction, rateImportance, resolveConflict, + isSemanticDuplicate, + runSleepCycle, } from "./extractor.js"; import { passesAttentionGate, passesAssistantAttentionGate } from "./index.js"; @@ -1574,3 +1576,969 @@ describe("resolveConflict", () => { expect(result).toBe("skip"); }); }); + +// ============================================================================ +// runSleepCycle() — Comprehensive Phase Testing +// ============================================================================ + +describe("runSleepCycle", () => { + let mockDb: any; + let mockEmbeddings: any; + let mockLogger: any; + let mockConfig: ExtractionConfig; + const originalFetch = globalThis.fetch; + + beforeEach(() => { + vi.restoreAllMocks(); + + // Mock logger + mockLogger = { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }; + + // Mock embeddings + mockEmbeddings = { + embed: vi.fn().mockResolvedValue([0.1, 0.2, 0.3]), + embedBatch: vi.fn().mockResolvedValue([[0.1, 0.2, 0.3]]), + }; + + // Mock config + mockConfig = { + enabled: true, + apiKey: "test-key", + model: "test-model", + baseUrl: "https://test.ai/api/v1", + temperature: 0.0, + maxRetries: 0, + }; + + // Mock database with all required methods + mockDb = { + // findDuplicateClusters now accepts returnSimilarities param (3rd arg) + // When true, clusters include a similarities Map + findDuplicateClusters: vi + .fn() + .mockImplementation(async (threshold, agentId, returnSimilarities) => { + if (returnSimilarities) { + // Return empty clusters by default with similarities Map + return []; + } + return []; + }), + mergeMemoryCluster: vi.fn().mockResolvedValue({ survivorId: "s1", deletedCount: 0 }), + findConflictingMemories: vi.fn().mockResolvedValue([]), + invalidateMemory: vi.fn().mockResolvedValue(undefined), + calculateAllEffectiveScores: vi.fn().mockResolvedValue([]), + calculateParetoThreshold: vi.fn().mockReturnValue(0.5), + promoteToCore: vi.fn().mockResolvedValue(0), + demoteFromCore: vi.fn().mockResolvedValue(0), + findDecayedMemories: vi.fn().mockResolvedValue([]), + pruneMemories: vi.fn().mockResolvedValue(0), + countByExtractionStatus: vi + .fn() + .mockResolvedValue({ pending: 0, complete: 0, failed: 0, skipped: 0 }), + listPendingExtractions: vi.fn().mockResolvedValue([]), + findOrphanEntities: vi.fn().mockResolvedValue([]), + deleteOrphanEntities: vi.fn().mockResolvedValue(0), + findOrphanTags: vi.fn().mockResolvedValue([]), + deleteOrphanTags: vi.fn().mockResolvedValue(0), + updateExtractionStatus: vi.fn().mockResolvedValue(undefined), + mergeEntity: vi.fn().mockResolvedValue({ id: "e1", name: "test" }), + createMentions: vi.fn().mockResolvedValue(undefined), + createEntityRelationship: vi.fn().mockResolvedValue(undefined), + tagMemory: vi.fn().mockResolvedValue(undefined), + updateMemoryCategory: vi.fn().mockResolvedValue(undefined), + }; + }); + + afterEach(() => { + globalThis.fetch = originalFetch; + }); + + // Phase 1: Deduplication + describe("Phase 1: Deduplication", () => { + it("should merge clusters when vector similarity ≥ 0.95", async () => { + // New implementation calls findDuplicateClusters(0.75, agentId, true) with similarities + const similarities = new Map([ + ["m1:m2", 0.97], + ["m1:m3", 0.96], + ["m2:m3", 0.98], + ]); + mockDb.findDuplicateClusters.mockResolvedValue([ + { + memoryIds: ["m1", "m2", "m3"], + texts: ["text 1", "text 2", "text 3"], + importances: [0.8, 0.9, 0.7], + similarities, + }, + ]); + mockDb.mergeMemoryCluster.mockResolvedValue({ survivorId: "m2", deletedCount: 2 }); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(mockDb.findDuplicateClusters).toHaveBeenCalledWith(0.75, undefined, true); + expect(mockDb.mergeMemoryCluster).toHaveBeenCalledWith(["m1", "m2", "m3"], [0.8, 0.9, 0.7]); + expect(result.dedup.clustersFound).toBe(1); + expect(result.dedup.memoriesMerged).toBe(2); + }); + + it("should keep highest-importance memory in cluster", async () => { + const similarities = new Map([ + ["high:low", 0.98], + ["high:mid", 0.96], + ["low:mid", 0.97], + ]); + mockDb.findDuplicateClusters.mockResolvedValue([ + { + memoryIds: ["low", "high", "mid"], + texts: ["text", "text", "text"], + importances: [0.3, 0.9, 0.5], + similarities, + }, + ]); + + await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + // mergeMemoryCluster is called with all IDs and importances + // It's responsible for choosing the survivor (highest importance) + expect(mockDb.mergeMemoryCluster).toHaveBeenCalledWith( + ["low", "high", "mid"], + [0.3, 0.9, 0.5], + ); + }); + + it("should report correct counts for multiple clusters", async () => { + mockDb.findDuplicateClusters.mockResolvedValue([ + { + memoryIds: ["a1", "a2"], + texts: ["a", "a"], + importances: [0.5, 0.6], + similarities: new Map([["a1:a2", 0.98]]), + }, + { + memoryIds: ["b1", "b2", "b3"], + texts: ["b", "b", "b"], + importances: [0.7, 0.8, 0.9], + similarities: new Map([ + ["b1:b2", 0.97], + ["b1:b3", 0.96], + ["b2:b3", 0.99], + ]), + }, + ]); + mockDb.mergeMemoryCluster + .mockResolvedValueOnce({ survivorId: "a2", deletedCount: 1 }) + .mockResolvedValueOnce({ survivorId: "b3", deletedCount: 2 }); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(result.dedup.clustersFound).toBe(2); + expect(result.dedup.memoriesMerged).toBe(3); + }); + + it("should skip dedup when no clusters found", async () => { + mockDb.findDuplicateClusters.mockResolvedValue([]); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(result.dedup.clustersFound).toBe(0); + expect(result.dedup.memoriesMerged).toBe(0); + expect(mockDb.mergeMemoryCluster).not.toHaveBeenCalled(); + }); + }); + + // Phase 1b: Conflict Detection + describe("Phase 1b: Conflict Detection", () => { + beforeEach(() => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + choices: [ + { message: { content: JSON.stringify({ keep: "a", reason: "more recent" }) } }, + ], + }), + }); + }); + + it("should call resolveConflict for entity-linked memory pairs", async () => { + mockDb.findConflictingMemories.mockResolvedValue([ + { + memoryA: { + id: "m1", + text: "user prefers dark mode", + importance: 0.7, + createdAt: "2024-01-01", + }, + memoryB: { + id: "m2", + text: "user prefers light mode", + importance: 0.6, + createdAt: "2024-01-02", + }, + }, + ]); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(mockDb.findConflictingMemories).toHaveBeenCalled(); + expect(result.conflict.pairsFound).toBe(1); + expect(result.conflict.resolved).toBe(1); + }); + + it("should invalidate the loser (importance → 0.01)", async () => { + mockDb.findConflictingMemories.mockResolvedValue([ + { + memoryA: { id: "m1", text: "old info", importance: 0.5, createdAt: "2024-01-01" }, + memoryB: { id: "m2", text: "new info", importance: 0.8, createdAt: "2024-01-02" }, + }, + ]); + + // LLM says keep "a" + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + choices: [{ message: { content: JSON.stringify({ keep: "a", reason: "test" }) } }], + }), + }); + + await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(mockDb.invalidateMemory).toHaveBeenCalledWith("m2"); + }); + + it("should not count 'skip' decisions as resolved", async () => { + mockDb.findConflictingMemories.mockResolvedValue([ + { + memoryA: { id: "m1", text: "text", importance: 0.5, createdAt: "2024-01-01" }, + memoryB: { id: "m2", text: "text", importance: 0.5, createdAt: "2024-01-02" }, + }, + ]); + + // LLM unavailable + globalThis.fetch = vi.fn().mockResolvedValue({ ok: false, status: 500 }); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(result.conflict.pairsFound).toBe(1); + expect(result.conflict.resolved).toBe(0); + expect(result.conflict.invalidated).toBe(0); + }); + + it("should handle 'both' decision (no conflict)", async () => { + mockDb.findConflictingMemories.mockResolvedValue([ + { + memoryA: { id: "m1", text: "likes coffee", importance: 0.5, createdAt: "2024-01-01" }, + memoryB: { id: "m2", text: "works at Acme", importance: 0.5, createdAt: "2024-01-02" }, + }, + ]); + + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + choices: [ + { message: { content: JSON.stringify({ keep: "both", reason: "no conflict" }) } }, + ], + }), + }); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(result.conflict.resolved).toBe(1); + expect(result.conflict.invalidated).toBe(0); + expect(mockDb.invalidateMemory).not.toHaveBeenCalled(); + }); + }); + + // Phase 1b: Semantic Deduplication (0.75-0.95 band) + describe("Phase 1b: Semantic Deduplication", () => { + it("should check pairs in 0.75-0.95 similarity band", async () => { + // New implementation: single call at 0.75, clusters with similarities in 0.75-0.95 range go to semantic dedup + mockDb.findDuplicateClusters.mockResolvedValue([ + { + memoryIds: ["m1", "m2"], + texts: ["Tarun prefers dark mode", "Tarun likes dark theme"], + importances: [0.8, 0.7], + similarities: new Map([["m1:m2", 0.85]]), // 0.75-0.95 range + }, + ]); + + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + choices: [ + { + message: { + content: JSON.stringify({ verdict: "duplicate", reason: "paraphrase" }), + }, + }, + ], + }), + }); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(mockDb.findDuplicateClusters).toHaveBeenCalledWith(0.75, undefined, true); + expect(result.semanticDedup.pairsChecked).toBe(1); + expect(result.semanticDedup.duplicatesMerged).toBe(1); + }); + + it("should invalidate lower-importance duplicate", async () => { + mockDb.findDuplicateClusters.mockResolvedValue([ + { + memoryIds: ["high", "low"], + texts: ["high importance text", "low importance text"], + importances: [0.9, 0.3], + similarities: new Map([["high:low", 0.82]]), // 0.75-0.95 range + }, + ]); + + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + choices: [{ message: { content: JSON.stringify({ verdict: "duplicate" }) } }], + }), + }); + + await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + // Should invalidate "low" (lower importance) + expect(mockDb.invalidateMemory).toHaveBeenCalledWith("low"); + }); + + it("should report correct pair counts", async () => { + mockDb.findDuplicateClusters.mockResolvedValue([ + { + memoryIds: ["a", "b", "c"], + texts: ["text", "text", "text"], + importances: [0.5, 0.6, 0.7], + similarities: new Map([ + ["a:b", 0.8], + ["a:c", 0.78], + ["b:c", 0.82], + ]), // All in 0.75-0.95 range + }, + ]); + + // All 3 pairs are collected and fired concurrently in one batch: + // (a,b) = duplicate, (a,c) = duplicate but skipped (a invalidated), (b,c) = unique + globalThis.fetch = vi + .fn() + .mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + choices: [{ message: { content: JSON.stringify({ verdict: "duplicate" }) } }], + }), + }) + .mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + choices: [{ message: { content: JSON.stringify({ verdict: "duplicate" }) } }], + }), + }) + .mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + choices: [{ message: { content: JSON.stringify({ verdict: "unique" }) } }], + }), + }); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + // All 3 pairs checked concurrently, but only 1 merge (a,c duplicate skipped since a already invalidated) + expect(result.semanticDedup.pairsChecked).toBe(3); + expect(result.semanticDedup.duplicatesMerged).toBe(1); + }); + }); + + // Phase 2: Pareto Scoring + describe("Phase 2: Pareto Scoring", () => { + it("should calculate correct threshold for top 20%", async () => { + const scores = [ + { + id: "m1", + text: "test", + category: "fact", + importance: 0.9, + retrievalCount: 10, + ageDays: 5, + effectiveScore: 0.95, + }, + { + id: "m2", + text: "test", + category: "fact", + importance: 0.5, + retrievalCount: 5, + ageDays: 10, + effectiveScore: 0.5, + }, + { + id: "m3", + text: "test", + category: "core", + importance: 0.3, + retrievalCount: 2, + ageDays: 20, + effectiveScore: 0.3, + }, + ]; + mockDb.calculateAllEffectiveScores.mockResolvedValue(scores); + mockDb.calculateParetoThreshold.mockReturnValue(0.8); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(mockDb.calculateAllEffectiveScores).toHaveBeenCalled(); + expect(mockDb.calculateParetoThreshold).toHaveBeenCalledWith(scores, 0.8); // 1 - paretoPercentile (default 0.2) + expect(result.pareto.totalMemories).toBe(3); + expect(result.pareto.coreMemories).toBe(1); + expect(result.pareto.regularMemories).toBe(2); + expect(result.pareto.threshold).toBe(0.8); + }); + + it("should handle empty database", async () => { + mockDb.calculateAllEffectiveScores.mockResolvedValue([]); + mockDb.calculateParetoThreshold.mockReturnValue(0); // Empty array returns 0 + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(result.pareto.totalMemories).toBe(0); + expect(result.pareto.threshold).toBe(0); + }); + + it("should handle single memory", async () => { + mockDb.calculateAllEffectiveScores.mockResolvedValue([ + { + id: "m1", + text: "test", + category: "fact", + importance: 0.9, + retrievalCount: 10, + ageDays: 5, + effectiveScore: 0.95, + }, + ]); + mockDb.calculateParetoThreshold.mockReturnValue(0.95); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(result.pareto.totalMemories).toBe(1); + expect(result.pareto.threshold).toBe(0.95); + }); + }); + + // Phase 3: Promotion + describe("Phase 3: Core Promotion", () => { + it("should promote regular memories above threshold", async () => { + const scores = [ + { + id: "m1", + text: "test", + category: "fact", + importance: 0.9, + retrievalCount: 10, + ageDays: 10, + effectiveScore: 0.95, + }, + { + id: "m2", + text: "test", + category: "fact", + importance: 0.5, + retrievalCount: 5, + ageDays: 8, + effectiveScore: 0.6, + }, + { + id: "m3", + text: "test", + category: "core", + importance: 0.8, + retrievalCount: 8, + ageDays: 5, + effectiveScore: 0.85, + }, + ]; + mockDb.calculateAllEffectiveScores.mockResolvedValue(scores); + mockDb.calculateParetoThreshold.mockReturnValue(0.7); // threshold + mockDb.promoteToCore.mockResolvedValue(1); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger, { + paretoPercentile: 0.2, + promotionMinAgeDays: 7, + }); + + // m1 should be promoted (category=fact, score=0.95 > 0.70, age=10 >= 7) + expect(mockDb.promoteToCore).toHaveBeenCalledWith(["m1"]); + expect(result.promotion.candidatesFound).toBe(1); + expect(result.promotion.promoted).toBe(1); + }); + + it("should respect promotionMinAgeDays", async () => { + const scores = [ + { + id: "m1", + text: "test", + category: "fact", + importance: 0.9, + retrievalCount: 10, + ageDays: 5, + effectiveScore: 0.95, + }, + ]; + mockDb.calculateAllEffectiveScores.mockResolvedValue(scores); + mockDb.calculateParetoThreshold.mockReturnValue(0.5); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger, { + promotionMinAgeDays: 7, + }); + + // m1 age=5 < 7, should not be promoted + expect(result.promotion.candidatesFound).toBe(0); + expect(mockDb.promoteToCore).not.toHaveBeenCalled(); + }); + + it("should not promote core memories again", async () => { + const scores = [ + { + id: "m1", + text: "test", + category: "core", + importance: 0.9, + retrievalCount: 10, + ageDays: 10, + effectiveScore: 0.95, + }, + ]; + mockDb.calculateAllEffectiveScores.mockResolvedValue(scores); + mockDb.calculateParetoThreshold.mockReturnValue(0.5); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(result.promotion.candidatesFound).toBe(0); + expect(mockDb.promoteToCore).not.toHaveBeenCalled(); + }); + }); + + // Phase 4: Demotion + describe("Phase 4: Core Demotion", () => { + it("should demote core memories below threshold", async () => { + const scores = [ + { + id: "m1", + text: "test", + category: "core", + importance: 0.3, + retrievalCount: 1, + ageDays: 30, + effectiveScore: 0.3, + }, + { + id: "m2", + text: "test", + category: "core", + importance: 0.9, + retrievalCount: 10, + ageDays: 5, + effectiveScore: 0.95, + }, + ]; + mockDb.calculateAllEffectiveScores.mockResolvedValue(scores); + mockDb.calculateParetoThreshold.mockReturnValue(0.7); + mockDb.demoteFromCore.mockResolvedValue(1); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + // m1 should be demoted (category=core, score=0.30 < 0.70) + expect(mockDb.demoteFromCore).toHaveBeenCalledWith(["m1"]); + expect(result.demotion.candidatesFound).toBe(1); + expect(result.demotion.demoted).toBe(1); + }); + + it("should not demote regular memories", async () => { + const scores = [ + { + id: "m1", + text: "test", + category: "fact", + importance: 0.2, + retrievalCount: 0, + ageDays: 50, + effectiveScore: 0.1, + }, + ]; + mockDb.calculateAllEffectiveScores.mockResolvedValue(scores); + mockDb.calculateParetoThreshold.mockReturnValue(0.7); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(result.demotion.candidatesFound).toBe(0); + expect(mockDb.demoteFromCore).not.toHaveBeenCalled(); + }); + }); + + // Phase 5: Extraction + describe("Phase 5: Entity Extraction", () => { + it("should process pending extractions in batches", async () => { + mockDb.countByExtractionStatus.mockResolvedValue({ + pending: 5, + complete: 0, + failed: 0, + skipped: 0, + }); + // First call returns 3 memories, second call returns empty to stop loop + mockDb.listPendingExtractions + .mockResolvedValueOnce([ + { id: "m1", text: "text 1", agentId: "default", extractionRetries: 0 }, + { id: "m2", text: "text 2", agentId: "default", extractionRetries: 0 }, + { id: "m3", text: "text 3", agentId: "default", extractionRetries: 0 }, + ]) + .mockResolvedValueOnce([]); + + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + choices: [ + { + message: { content: JSON.stringify({ entities: [], relationships: [], tags: [] }) }, + }, + ], + }), + }); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger, { + extractionBatchSize: 10, + }); + + expect(mockDb.listPendingExtractions).toHaveBeenCalled(); + expect(result.extraction.total).toBe(5); + expect(result.extraction.processed).toBe(3); + }); + + it("should handle extraction failures with retry tracking", async () => { + mockDb.countByExtractionStatus.mockResolvedValue({ + pending: 1, + complete: 0, + failed: 0, + skipped: 0, + }); + // First call returns 1 memory, second call returns empty to stop loop + mockDb.listPendingExtractions + .mockResolvedValueOnce([ + { id: "m1", text: "text", agentId: "default", extractionRetries: 0 }, + ]) + .mockResolvedValueOnce([]); + + // Extraction fails (HTTP error) + globalThis.fetch = vi.fn().mockResolvedValue({ ok: false, status: 500 }); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(result.extraction.processed).toBe(1); + // runBackgroundExtraction doesn't throw on HTTP errors, it just marks the extraction status as failed/pending + // The sleep cycle counts it as succeeded because Promise.allSettled reports it as fulfilled + expect(result.extraction.succeeded).toBe(1); + expect(result.extraction.failed).toBe(0); + }); + + it("should respect batch size and delay", async () => { + mockDb.countByExtractionStatus.mockResolvedValue({ + pending: 2, + complete: 0, + failed: 0, + skipped: 0, + }); + mockDb.listPendingExtractions + .mockResolvedValueOnce([ + { id: "m1", text: "text 1", agentId: "default", extractionRetries: 0 }, + ]) + .mockResolvedValueOnce([]); + + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + choices: [ + { + message: { content: JSON.stringify({ entities: [], relationships: [], tags: [] }) }, + }, + ], + }), + }); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger, { + extractionBatchSize: 1, + extractionDelayMs: 100, + }); + + expect(mockDb.listPendingExtractions).toHaveBeenCalledWith(1, undefined); + expect(result.extraction.processed).toBe(1); + }); + }); + + // Phase 6: Decay & Pruning + describe("Phase 6: Decay & Pruning", () => { + it("should prune memories below retention threshold", async () => { + mockDb.findDecayedMemories.mockResolvedValue([ + { id: "m1", text: "old memory", importance: 0.2, ageDays: 100, decayScore: 0.05 }, + { id: "m2", text: "very old", importance: 0.1, ageDays: 200, decayScore: 0.02 }, + ]); + mockDb.pruneMemories.mockResolvedValue(2); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(mockDb.findDecayedMemories).toHaveBeenCalled(); + expect(mockDb.pruneMemories).toHaveBeenCalledWith(["m1", "m2"]); + expect(result.decay.memoriesPruned).toBe(2); + }); + + it("should apply exponential decay based on age", async () => { + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger, { + decayRetentionThreshold: 0.1, + decayBaseHalfLifeDays: 30, + }); + + expect(mockDb.findDecayedMemories).toHaveBeenCalledWith({ + retentionThreshold: 0.1, + baseHalfLifeDays: 30, + importanceMultiplier: 2, + agentId: undefined, + }); + }); + + it("should extend half-life based on importance", async () => { + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger, { + decayImportanceMultiplier: 3, + }); + + expect(mockDb.findDecayedMemories).toHaveBeenCalledWith( + expect.objectContaining({ + importanceMultiplier: 3, + }), + ); + }); + }); + + // Phase 7: Orphan Cleanup + describe("Phase 7: Orphan Cleanup", () => { + it("should remove entities with 0 mentions", async () => { + mockDb.findOrphanEntities.mockResolvedValue([ + { id: "e1", name: "orphan1", type: "concept" }, + { id: "e2", name: "orphan2", type: "person" }, + ]); + mockDb.deleteOrphanEntities.mockResolvedValue(2); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(mockDb.findOrphanEntities).toHaveBeenCalled(); + expect(mockDb.deleteOrphanEntities).toHaveBeenCalledWith(["e1", "e2"]); + expect(result.cleanup.entitiesRemoved).toBe(2); + }); + + it("should remove unused tags", async () => { + mockDb.findOrphanTags.mockResolvedValue([{ id: "t1", name: "unused-tag" }]); + mockDb.deleteOrphanTags.mockResolvedValue(1); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(mockDb.findOrphanTags).toHaveBeenCalled(); + expect(mockDb.deleteOrphanTags).toHaveBeenCalledWith(["t1"]); + expect(result.cleanup.tagsRemoved).toBe(1); + }); + + it("should report correct cleanup counts", async () => { + mockDb.findOrphanEntities.mockResolvedValue([{ id: "e1", name: "test", type: "concept" }]); + mockDb.deleteOrphanEntities.mockResolvedValue(1); + mockDb.findOrphanTags.mockResolvedValue([{ id: "t1", name: "test" }]); + mockDb.deleteOrphanTags.mockResolvedValue(1); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(result.cleanup.entitiesRemoved).toBe(1); + expect(result.cleanup.tagsRemoved).toBe(1); + }); + }); + + // Abort handling + describe("Abort handling", () => { + it("should stop between phases when aborted", async () => { + const abortController = new AbortController(); + + // Abort after Phase 1 + mockDb.findDuplicateClusters.mockImplementation(async () => { + abortController.abort(); + return []; + }); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger, { + abortSignal: abortController.signal, + }); + + expect(result.aborted).toBe(true); + // Phase 1 ran, but subsequent phases should be skipped + expect(mockDb.findDuplicateClusters).toHaveBeenCalled(); + }); + + it("should show aborted=true in result", async () => { + const abortController = new AbortController(); + abortController.abort(); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger, { + abortSignal: abortController.signal, + }); + + expect(result.aborted).toBe(true); + }); + + it("should not corrupt data on abort", async () => { + const abortController = new AbortController(); + + mockDb.findDuplicateClusters.mockImplementation(async () => { + abortController.abort(); + return [ + { + memoryIds: ["m1", "m2"], + texts: ["a", "b"], + importances: [0.5, 0.6], + similarities: new Map([["m1:m2", 0.98]]), + }, + ]; + }); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger, { + abortSignal: abortController.signal, + }); + + // Even though aborted, the cluster merge should not have been called + // (abort happens before mergeMemoryCluster in the loop) + expect(result.aborted).toBe(true); + }); + }); + + // Error isolation + describe("Error isolation", () => { + it("should continue to Phase 2 if Phase 1 fails", async () => { + mockDb.findDuplicateClusters.mockRejectedValue(new Error("phase 1 error")); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + // Phase 2 should still run + expect(mockDb.calculateAllEffectiveScores).toHaveBeenCalled(); + expect(mockLogger.warn).toHaveBeenCalledWith(expect.stringContaining("Phase 1 error")); + }); + + it("should handle LLM timeout without crashing", async () => { + mockDb.findConflictingMemories.mockResolvedValue([ + { + memoryA: { id: "m1", text: "a", importance: 0.5, createdAt: "2024-01-01" }, + memoryB: { id: "m2", text: "b", importance: 0.5, createdAt: "2024-01-02" }, + }, + ]); + + globalThis.fetch = vi.fn().mockRejectedValue(new DOMException("timeout", "TimeoutError")); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + // Should not crash, conflict resolution returns "skip" + expect(result.conflict.resolved).toBe(0); + // Other phases should continue + expect(mockDb.calculateAllEffectiveScores).toHaveBeenCalled(); + }); + + it("should handle Neo4j transient error retries", async () => { + // This is tested more thoroughly in neo4j-client.test.ts + // Here we just verify the sleep cycle doesn't crash + mockDb.findDuplicateClusters + .mockRejectedValueOnce(new Error("transient")) + .mockResolvedValueOnce([]); + + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + // Should log error but continue + expect(mockLogger.warn).toHaveBeenCalled(); + }); + }); + + // Progress callbacks + describe("Progress callbacks", () => { + it("should call onPhaseStart for each phase", async () => { + const onPhaseStart = vi.fn(); + + await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger, { + onPhaseStart, + }); + + expect(onPhaseStart).toHaveBeenCalledWith("dedup"); + expect(onPhaseStart).toHaveBeenCalledWith("conflict"); + expect(onPhaseStart).toHaveBeenCalledWith("semanticDedup"); + expect(onPhaseStart).toHaveBeenCalledWith("pareto"); + expect(onPhaseStart).toHaveBeenCalledWith("promotion"); + expect(onPhaseStart).toHaveBeenCalledWith("demotion"); + expect(onPhaseStart).toHaveBeenCalledWith("extraction"); + expect(onPhaseStart).toHaveBeenCalledWith("decay"); + expect(onPhaseStart).toHaveBeenCalledWith("cleanup"); + }); + + it("should call onProgress with phase messages", async () => { + const onProgress = vi.fn(); + mockDb.findDuplicateClusters.mockResolvedValue([ + { + memoryIds: ["m1", "m2"], + texts: ["a", "b"], + importances: [0.5, 0.6], + similarities: new Map([["m1:m2", 0.98]]), + }, + ]); + mockDb.mergeMemoryCluster.mockResolvedValue({ survivorId: "m2", deletedCount: 1 }); + + await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger, { + onProgress, + }); + + expect(onProgress).toHaveBeenCalledWith("dedup", expect.any(String)); + }); + }); + + // Overall result structure + describe("Result structure", () => { + it("should return complete result object", async () => { + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(result).toHaveProperty("dedup"); + expect(result).toHaveProperty("conflict"); + expect(result).toHaveProperty("semanticDedup"); + expect(result).toHaveProperty("pareto"); + expect(result).toHaveProperty("promotion"); + expect(result).toHaveProperty("demotion"); + expect(result).toHaveProperty("decay"); + expect(result).toHaveProperty("extraction"); + expect(result).toHaveProperty("cleanup"); + expect(result).toHaveProperty("durationMs"); + expect(result).toHaveProperty("aborted"); + }); + + it("should track duration correctly", async () => { + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(result.durationMs).toBeGreaterThanOrEqual(0); + expect(typeof result.durationMs).toBe("number"); + }); + + it("should default aborted to false", async () => { + const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); + + expect(result.aborted).toBe(false); + }); + }); +}); diff --git a/extensions/memory-neo4j/extractor.ts b/extensions/memory-neo4j/extractor.ts index 06e3b98f36b..8f3370c621c 100644 --- a/extensions/memory-neo4j/extractor.ts +++ b/extensions/memory-neo4j/extractor.ts @@ -30,10 +30,9 @@ type Logger = { // Extraction Prompt // ============================================================================ -const ENTITY_EXTRACTION_PROMPT = `You are an entity extraction system for a personal memory store. -Extract entities and relationships from this memory text, and classify the memory. - -Memory: "{text}" +// System instruction (no user data) — user message contains the memory text +const ENTITY_EXTRACTION_SYSTEM = `You are an entity extraction system for a personal memory store. +Extract entities and relationships from the memory text provided by the user, and classify the memory. Return JSON: { @@ -66,7 +65,12 @@ Rules: // Timeout for LLM and embedding fetch calls to prevent hanging indefinitely const FETCH_TIMEOUT_MS = 30_000; -async function callOpenRouter(config: ExtractionConfig, prompt: string): Promise { +async function callOpenRouter( + config: ExtractionConfig, + prompt: string | Array<{ role: string; content: string }>, +): Promise { + const messages = typeof prompt === "string" ? [{ role: "user", content: prompt }] : prompt; + for (let attempt = 0; attempt <= config.maxRetries; attempt++) { try { const response = await fetch(`${config.baseUrl}/chat/completions`, { @@ -77,7 +81,7 @@ async function callOpenRouter(config: ExtractionConfig, prompt: string): Promise }, body: JSON.stringify({ model: config.model, - messages: [{ role: "user", content: prompt }], + messages, temperature: config.temperature, response_format: { type: "json_object" }, }), @@ -152,11 +156,15 @@ export async function extractEntities( return { result: null, transientFailure: false }; } - const prompt = ENTITY_EXTRACTION_PROMPT.replace("{text}", text); + // System/user separation prevents memory text from being interpreted as instructions + const messages = [ + { role: "system", content: ENTITY_EXTRACTION_SYSTEM }, + { role: "user", content: text }, + ]; let content: string | null; try { - content = await callOpenRouter(config, prompt); + content = await callOpenRouter(config, messages); } catch (err) { // Network/timeout errors are transient — caller should retry return { result: null, transientFailure: isTransientError(err) }; @@ -259,36 +267,19 @@ export async function resolveConflict( ): Promise<"a" | "b" | "both" | "skip"> { if (!config.enabled) return "skip"; - const prompt = `Two memories may conflict with each other. Determine which should be kept. - -Memory A: "${memA}" -Memory B: "${memB}" + try { + const content = await callOpenRouter(config, [ + { + role: "system", + content: `Two memories may conflict with each other. Determine which should be kept. If they genuinely contradict each other, keep the one that is more current, specific, or accurate. If they don't actually conflict (they cover different aspects or are both valid), keep both. -Return JSON: {"keep": "a"|"b"|"both", "reason": "brief explanation"}`; - - try { - const response = await fetch(`${config.baseUrl}/chat/completions`, { - method: "POST", - headers: { - Authorization: `Bearer ${config.apiKey}`, - "Content-Type": "application/json", +Return JSON: {"keep": "a"|"b"|"both", "reason": "brief explanation"}`, }, - body: JSON.stringify({ - model: config.model, - messages: [{ role: "user", content: prompt }], - temperature: 0.0, - response_format: { type: "json_object" }, - }), - signal: AbortSignal.timeout(10_000), - }); - - if (!response.ok) return "skip"; - - const data = (await response.json()) as { choices?: Array<{ message?: { content?: string } }> }; - const content = data.choices?.[0]?.message?.content; + { role: "user", content: `Memory A: "${memA}"\nMemory B: "${memB}"` }, + ]); if (!content) return "skip"; const parsed = JSON.parse(content) as { keep?: string }; @@ -457,6 +448,11 @@ export type SleepCycleResult = { resolved: number; invalidated: number; }; + // Phase 1c: Semantic Deduplication + semanticDedup: { + pairsChecked: number; + duplicatesMerged: number; + }; // Phase 2: Pareto Scoring & Threshold pareto: { totalMemories: number; @@ -474,11 +470,11 @@ export type SleepCycleResult = { candidatesFound: number; demoted: number; }; - // Phase 5: Decay & Pruning + // Phase 6: Decay & Pruning decay: { memoriesPruned: number; }; - // Phase 6: Entity Extraction + // Phase 5: Entity Extraction extraction: { total: number; processed: number; @@ -507,20 +503,21 @@ export type SleepCycleOptions = { paretoPercentile?: number; // Top N% for core (default: 0.2 = top 20%) promotionMinAgeDays?: number; // Min age before promotion (default: 7) - // Phase 5: Decay + // Phase 5: Extraction + extractionBatchSize?: number; // Memories per batch (default: 50) + extractionDelayMs?: number; // Delay between batches (default: 1000) + + // Phase 6: Decay decayRetentionThreshold?: number; // Below this, memory is pruned (default: 0.1) decayBaseHalfLifeDays?: number; // Base half-life in days (default: 30) decayImportanceMultiplier?: number; // How much importance extends half-life (default: 2) - // Phase 6: Extraction - extractionBatchSize?: number; // Memories per batch (default: 50) - extractionDelayMs?: number; // Delay between batches (default: 1000) - // Progress callback onPhaseStart?: ( phase: | "dedup" | "conflict" + | "semanticDedup" | "pareto" | "promotion" | "demotion" @@ -592,6 +589,7 @@ export async function runSleepCycle( const result: SleepCycleResult = { dedup: { clustersFound: 0, memoriesMerged: 0 }, conflict: { pairsFound: 0, resolved: 0, invalidated: 0 }, + semanticDedup: { pairsChecked: 0, duplicatesMerged: 0 }, pareto: { totalMemories: 0, coreMemories: 0, regularMemories: 0, threshold: 0 }, promotion: { candidatesFound: 0, promoted: 0 }, demotion: { candidatesFound: 0, demoted: 0 }, @@ -602,32 +600,150 @@ export async function runSleepCycle( aborted: false, }; + const LLM_CONCURRENCY = 8; + // -------------------------------------------------------------------------- - // Phase 1: Deduplication + // Phase 1: Deduplication (Optimized - combined vector + semantic dedup) + // Call findDuplicateClusters ONCE at 0.75 threshold, then split by similarity band: + // - ≥0.95: vector merge (high-confidence duplicates) + // - 0.75-0.95: semantic dedup via LLM (paraphrases) // -------------------------------------------------------------------------- if (!abortSignal?.aborted) { onPhaseStart?.("dedup"); - logger.info("memory-neo4j: [sleep] Phase 1: Deduplication"); + logger.info("memory-neo4j: [sleep] Phase 1: Deduplication (vector + semantic)"); try { - const clusters = await db.findDuplicateClusters(dedupThreshold, agentId); - result.dedup.clustersFound = clusters.length; + // Fetch clusters at 0.75 threshold with similarity scores + const allClusters = await db.findDuplicateClusters(0.75, agentId, true); - for (const cluster of clusters) { - if (abortSignal?.aborted) { - break; + // Helper to create canonical pair key (sorted) + const makePairKey = (a: string, b: string): string => { + return a < b ? `${a}:${b}` : `${b}:${a}`; + }; + + // Separate clusters into high-similarity (≥0.95) and medium-similarity (0.75-0.95) + const highSimClusters: typeof allClusters = []; + const mediumSimClusters: typeof allClusters = []; + + for (const cluster of allClusters) { + if (abortSignal?.aborted) break; + if (!cluster.similarities || cluster.memoryIds.length < 2) continue; + + // Check if ANY pair in this cluster has similarity ≥ dedupThreshold + let hasHighSim = false; + for (const [pairKey, score] of cluster.similarities.entries()) { + if (score >= dedupThreshold) { + hasHighSim = true; + break; + } } + if (hasHighSim) { + // Split this cluster into high-sim and medium-sim sub-clusters + // For simplicity, if a cluster has ANY high-sim pair, treat the whole cluster as high-sim + // (This matches the old behavior where Phase 1 would merge them all) + highSimClusters.push(cluster); + } else { + mediumSimClusters.push(cluster); + } + } + + // Part 1a: Vector merge for high-similarity clusters (≥0.95) + result.dedup.clustersFound = highSimClusters.length; + + for (const cluster of highSimClusters) { + if (abortSignal?.aborted) break; + const { deletedCount } = await db.mergeMemoryCluster( cluster.memoryIds, cluster.importances, ); result.dedup.memoriesMerged += deletedCount; - onProgress?.("dedup", `Merged cluster of ${cluster.memoryIds.length} → 1`); + onProgress?.("dedup", `Merged cluster of ${cluster.memoryIds.length} → 1 (vector)`); } logger.info( - `memory-neo4j: [sleep] Phase 1 complete — ${result.dedup.clustersFound} clusters, ${result.dedup.memoriesMerged} merged`, + `memory-neo4j: [sleep] Phase 1a (vector) complete — ${result.dedup.clustersFound} clusters, ${result.dedup.memoriesMerged} merged`, + ); + + // Part 1b: Semantic dedup for medium-similarity clusters (0.75-0.95) + onPhaseStart?.("semanticDedup"); + logger.info("memory-neo4j: [sleep] Phase 1b: Semantic Deduplication (0.75-0.95 band)"); + + // Collect all candidate pairs upfront + type DedupPair = { + textA: string; + textB: string; + idA: string; + idB: string; + importanceA: number; + importanceB: number; + }; + const allPairs: DedupPair[] = []; + + for (const cluster of mediumSimClusters) { + if (cluster.memoryIds.length < 2) continue; + for (let i = 0; i < cluster.memoryIds.length - 1; i++) { + for (let j = i + 1; j < cluster.memoryIds.length; j++) { + allPairs.push({ + textA: cluster.texts[i], + textB: cluster.texts[j], + idA: cluster.memoryIds[i], + idB: cluster.memoryIds[j], + importanceA: cluster.importances[i], + importanceB: cluster.importances[j], + }); + } + } + } + + // Process pairs in concurrent batches + const invalidatedIds = new Set(); + + for (let i = 0; i < allPairs.length && !abortSignal?.aborted; i += LLM_CONCURRENCY) { + const batch = allPairs.slice(i, i + LLM_CONCURRENCY); + + // Filter out pairs where one side was already invalidated + const activeBatch = batch.filter( + (p) => !invalidatedIds.has(p.idA) && !invalidatedIds.has(p.idB), + ); + + if (activeBatch.length === 0) continue; + + const outcomes = await Promise.allSettled( + activeBatch.map((p) => isSemanticDuplicate(p.textA, p.textB, config)), + ); + + for (let k = 0; k < outcomes.length; k++) { + const pair = activeBatch[k]; + result.semanticDedup.pairsChecked++; + + if ( + outcomes[k].status === "fulfilled" && + (outcomes[k] as PromiseFulfilledResult).value + ) { + // Skip if either side was invalidated by an earlier result in this batch + if (invalidatedIds.has(pair.idA) || invalidatedIds.has(pair.idB)) continue; + + const keepId = pair.importanceA >= pair.importanceB ? pair.idA : pair.idB; + const removeId = keepId === pair.idA ? pair.idB : pair.idA; + const keepText = keepId === pair.idA ? pair.textA : pair.textB; + const removeText = removeId === pair.idA ? pair.textA : pair.textB; + + await db.invalidateMemory(removeId); + invalidatedIds.add(removeId); + result.semanticDedup.duplicatesMerged++; + + onProgress?.( + "semanticDedup", + `Merged: "${removeText.slice(0, 50)}..." → kept "${keepText.slice(0, 50)}..."`, + ); + } + } + } + + logger.info( + `memory-neo4j: [sleep] Phase 1b (semantic) complete — ${result.semanticDedup.pairsChecked} pairs checked, ${result.semanticDedup.duplicatesMerged} merged`, ); } catch (err) { logger.warn(`memory-neo4j: [sleep] Phase 1 error: ${String(err)}`); @@ -635,11 +751,11 @@ export async function runSleepCycle( } // -------------------------------------------------------------------------- - // Phase 1b: Conflict Detection + // Phase 1c: Conflict Detection (formerly Phase 1b) // -------------------------------------------------------------------------- if (!abortSignal?.aborted) { onPhaseStart?.("conflict"); - logger.info("memory-neo4j: [sleep] Phase 1b: Conflict Detection"); + logger.info("memory-neo4j: [sleep] Phase 1c: Conflict Detection"); try { const pairs = await db.findConflictingMemories(agentId); @@ -668,10 +784,10 @@ export async function runSleepCycle( } logger.info( - `memory-neo4j: [sleep] Phase 1b complete — ${result.conflict.pairsFound} pairs, ${result.conflict.resolved} resolved, ${result.conflict.invalidated} invalidated`, + `memory-neo4j: [sleep] Phase 1c complete — ${result.conflict.pairsFound} pairs, ${result.conflict.resolved} resolved, ${result.conflict.invalidated} invalidated`, ); } catch (err) { - logger.warn(`memory-neo4j: [sleep] Phase 1b error: ${String(err)}`); + logger.warn(`memory-neo4j: [sleep] Phase 1c error: ${String(err)}`); } } @@ -790,7 +906,7 @@ export async function runSleepCycle( // Phase 5: Entity Extraction (moved before decay so new memories get // extracted before pruning can remove them) // -------------------------------------------------------------------------- - const EXTRACTION_CONCURRENCY = 3; + // Extraction uses LLM_CONCURRENCY (defined above, matches OLLAMA_NUM_PARALLEL) if (!abortSignal?.aborted && config.enabled) { onPhaseStart?.("extraction"); logger.info("memory-neo4j: [sleep] Phase 5: Entity Extraction"); @@ -810,13 +926,9 @@ export async function runSleepCycle( break; } - // Process in parallel chunks of EXTRACTION_CONCURRENCY - for ( - let i = 0; - i < pending.length && !abortSignal?.aborted; - i += EXTRACTION_CONCURRENCY - ) { - const chunk = pending.slice(i, i + EXTRACTION_CONCURRENCY); + // Process in parallel chunks of LLM_CONCURRENCY + for (let i = 0; i < pending.length && !abortSignal?.aborted; i += LLM_CONCURRENCY) { + const chunk = pending.slice(i, i + LLM_CONCURRENCY); const outcomes = await Promise.allSettled( chunk.map((memory) => runBackgroundExtraction( @@ -840,10 +952,7 @@ export async function runSleepCycle( } } - if ( - result.extraction.processed % 10 === 0 || - i + EXTRACTION_CONCURRENCY >= pending.length - ) { + if (result.extraction.processed % 10 === 0 || i + LLM_CONCURRENCY >= pending.length) { onProgress?.( "extraction", `${result.extraction.processed}/${result.extraction.total} processed`, @@ -1084,19 +1193,15 @@ export function extractAssistantMessages(messages: unknown[]): string[] { // LLM-Judged Importance Rating // ============================================================================ -const IMPORTANCE_RATING_PROMPT = `Rate the long-term importance of remembering this information on a scale of 1-10. +// System instruction — user message contains the text to rate +const IMPORTANCE_RATING_SYSTEM = `Rate the long-term importance of remembering the user's information on a scale of 1-10. 1-3: Trivial/transient (greetings, temporary status) 4-6: Moderately useful (general facts, minor preferences) 7-9: Very important (key decisions, strong preferences, critical facts) 10: Essential (identity-defining, safety-critical) -Information: "{text}" - Return JSON: {"score": N, "reason": "brief explanation"}`; -/** Timeout for importance rating calls (much shorter than extraction) */ -const IMPORTANCE_TIMEOUT_MS = 5_000; - /** * Rate the long-term importance of a text using an LLM. * Returns a value between 0.1 and 1.0, or 0.5 on any failure. @@ -1106,32 +1211,11 @@ export async function rateImportance(text: string, config: ExtractionConfig): Pr return 0.5; } - const prompt = IMPORTANCE_RATING_PROMPT.replace("{text}", text); - try { - const response = await fetch(`${config.baseUrl}/chat/completions`, { - method: "POST", - headers: { - Authorization: `Bearer ${config.apiKey}`, - "Content-Type": "application/json", - }, - body: JSON.stringify({ - model: config.model, - messages: [{ role: "user", content: prompt }], - temperature: config.temperature, - response_format: { type: "json_object" }, - }), - signal: AbortSignal.timeout(IMPORTANCE_TIMEOUT_MS), - }); - - if (!response.ok) { - return 0.5; - } - - const data = (await response.json()) as { - choices?: Array<{ message?: { content?: string } }>; - }; - const content = data.choices?.[0]?.message?.content; + const content = await callOpenRouter(config, [ + { role: "system", content: IMPORTANCE_RATING_SYSTEM }, + { role: "user", content: text }, + ]); if (!content) { return 0.5; } @@ -1148,3 +1232,49 @@ export async function rateImportance(text: string, config: ExtractionConfig): Pr return 0.5; } } + +// ============================================================================ +// Semantic Deduplication +// ============================================================================ + +// System instruction — user message contains the two texts to compare +const SEMANTIC_DEDUP_SYSTEM = `You are a memory deduplication system. Determine whether the new text conveys the SAME factual information as the existing memory. + +Rules: +- Return "duplicate" if the new text is conveying the same core fact(s), even if worded differently +- Return "duplicate" if the new text is a subset of information already in the existing memory +- Return "unique" if the new text contains genuinely new information not in the existing memory +- Ignore differences in formatting, pronouns, or phrasing — focus on the underlying facts + +Return JSON: {"verdict": "duplicate"|"unique", "reason": "brief explanation"}`; + +/** + * Check whether new text is semantically a duplicate of an existing memory. + * Uses an LLM to compare meaning rather than surface similarity. + * Returns true if the new text is a duplicate (should be skipped). + * Returns false on any failure (allow storage). + */ +export async function isSemanticDuplicate( + newText: string, + existingText: string, + config: ExtractionConfig, +): Promise { + if (!config.enabled) { + return false; + } + + try { + const content = await callOpenRouter(config, [ + { role: "system", content: SEMANTIC_DEDUP_SYSTEM }, + { role: "user", content: `Existing memory: "${existingText}"\nNew text: "${newText}"` }, + ]); + if (!content) { + return false; + } + + const parsed = JSON.parse(content) as { verdict?: string }; + return parsed.verdict === "duplicate"; + } catch { + return false; + } +} diff --git a/extensions/memory-neo4j/index.ts b/extensions/memory-neo4j/index.ts index c11b21e736d..4a188098ee7 100644 --- a/extensions/memory-neo4j/index.ts +++ b/extensions/memory-neo4j/index.ts @@ -31,6 +31,7 @@ import { extractAssistantMessages, stripMessageWrappers, runSleepCycle, + isSemanticDuplicate, rateImportance, } from "./extractor.js"; import { Neo4jMemoryClient } from "./neo4j-client.js"; @@ -245,7 +246,8 @@ const memoryNeo4jPlugin = { // memory_forget — Delete with cascade api.registerTool( - (_ctx) => { + (ctx) => { + const agentId = ctx.agentId || "default"; return { name: "memory_forget", label: "Memory Forget", @@ -262,7 +264,7 @@ const memoryNeo4jPlugin = { // Direct delete by ID if (memoryId) { - const deleted = await db.deleteMemory(memoryId); + const deleted = await db.deleteMemory(memoryId, agentId); if (!deleted) { return { content: [ @@ -288,7 +290,7 @@ const memoryNeo4jPlugin = { // Search-based delete if (query) { const vector = await embeddings.embed(query); - const results = await db.vectorSearch(vector, 5, 0.7); + const results = await db.vectorSearch(vector, 5, 0.7, agentId); if (results.length === 0) { return { @@ -299,7 +301,7 @@ const memoryNeo4jPlugin = { // Auto-delete if single high-confidence match if (results.length === 1 && results[0].score > 0.9) { - await db.deleteMemory(results[0].id); + await db.deleteMemory(results[0].id, agentId); return { content: [ { @@ -517,7 +519,10 @@ const memoryNeo4jPlugin = { console.log("═════════════════════════════════════════════════════════════"); console.log("Seven-phase memory consolidation (Pareto-based):\n"); console.log(" Phase 1: Deduplication — Merge near-duplicate memories"); - console.log(" Phase 1b: Conflict Detection — Resolve contradictory memories"); + console.log( + " Phase 1b: Semantic Dedup — LLM-based paraphrase detection (0.75–0.95 band)", + ); + console.log(" Phase 1c: Conflict Detection — Resolve contradictory memories"); console.log( " Phase 2: Pareto Scoring — Calculate effective scores for all memories", ); @@ -593,7 +598,8 @@ const memoryNeo4jPlugin = { onPhaseStart: (phase) => { const phaseNames: Record = { dedup: "Phase 1: Deduplication", - conflict: "Phase 1b: Conflict Detection", + semanticDedup: "Phase 1b: Semantic Deduplication", + conflict: "Phase 1c: Conflict Detection", pareto: "Phase 2: Pareto Scoring", promotion: "Phase 3: Core Promotion", demotion: "Phase 4: Core Demotion", @@ -618,6 +624,9 @@ const memoryNeo4jPlugin = { console.log( ` Conflicts: ${result.conflict.pairsFound} pairs, ${result.conflict.resolved} resolved, ${result.conflict.invalidated} invalidated`, ); + console.log( + ` Semantic Dedup: ${result.semanticDedup.pairsChecked} pairs checked, ${result.semanticDedup.duplicatesMerged} merged`, + ); console.log( ` Pareto: ${result.pareto.totalMemories} total (${result.pareto.coreMemories} core, ${result.pareto.regularMemories} regular)`, ); @@ -1114,18 +1123,45 @@ const memoryNeo4jPlugin = { const userMessages = extractUserMessages(event.messages); const retained = userMessages.filter((text) => passesAttentionGate(text)); + let semanticDeduped = 0; for (const text of retained) { try { const vector = await embeddings.embed(text); - // Quick dedup (same content already stored) + // Quick dedup (same content already stored — cosine ≥ 0.95) const existing = await db.findSimilar(vector, 0.95, 1); if (existing.length > 0) { continue; } + // Importance rating — moved before semantic dedup to avoid expensive LLM calls on low-value memories const importance = await rateImportance(text, extractionConfig); + // Skip low-importance memories (not worth the semantic dedup cost) + if (importance < 0.3) { + continue; + } + + // Semantic dedup: check moderate-similarity memories (0.75–0.95) + // with LLM to catch paraphrases and reformulations + const candidates = await db.findSimilar(vector, 0.75, 3); + if (candidates.length > 0) { + let isDuplicate = false; + for (const candidate of candidates) { + if (await isSemanticDuplicate(text, candidate.text, extractionConfig)) { + api.logger.debug?.( + `memory-neo4j: semantic dedup — skipped "${text.slice(0, 60)}..." (duplicate of "${candidate.text.slice(0, 60)}...")`, + ); + isDuplicate = true; + semanticDeduped++; + break; + } + } + if (isDuplicate) { + continue; + } + } + await db.storeMemory({ id: randomUUID(), text, @@ -1165,11 +1201,30 @@ const memoryNeo4jPlugin = { continue; } + // Semantic dedup for assistant messages too + const candidates = await db.findSimilar(vector, 0.75, 3); + if (candidates.length > 0) { + let isDuplicate = false; + for (const candidate of candidates) { + if (await isSemanticDuplicate(text, candidate.text, extractionConfig)) { + api.logger.debug?.( + `memory-neo4j: semantic dedup (assistant) — skipped "${text.slice(0, 60)}..."`, + ); + isDuplicate = true; + semanticDeduped++; + break; + } + } + if (isDuplicate) { + continue; + } + } + await db.storeMemory({ id: randomUUID(), text, embedding: vector, - importance: Math.min(importance, 0.4), // cap assistant importance slightly lower + importance: importance * 0.75, // discount assistant importance proportionally category: "other", source: "auto-capture-assistant", extractionStatus: extractionConfig.enabled ? "pending" : "skipped", @@ -1184,8 +1239,10 @@ const memoryNeo4jPlugin = { } } - if (stored > 0) { - api.logger.info(`memory-neo4j: auto-captured ${stored} memories (attention-gated)`); + if (stored > 0 || semanticDeduped > 0) { + api.logger.info( + `memory-neo4j: auto-captured ${stored} memories (attention-gated)${semanticDeduped > 0 ? `, ${semanticDeduped} semantic dupes skipped` : ""}`, + ); } else if (userMessages.length > 0 || assistantMessages.length > 0) { api.logger.info( `memory-neo4j: auto-capture ran (0 stored, ${userMessages.length} user msgs, ${retained.length} passed gate, ${assistantMessages.length} assistant msgs, ${retainedAssistant.length} passed gate)`, diff --git a/extensions/memory-neo4j/neo4j-client.test.ts b/extensions/memory-neo4j/neo4j-client.test.ts new file mode 100644 index 00000000000..f720be4a41d --- /dev/null +++ b/extensions/memory-neo4j/neo4j-client.test.ts @@ -0,0 +1,1460 @@ +/** + * Tests for neo4j-client.ts — Database Operations. + * + * Tests Neo4jMemoryClient methods using mocked Neo4j driver. + * Focuses on behavioral contracts, not implementation details. + */ + +import type { Driver } from "neo4j-driver"; +import { describe, it, expect, vi, beforeEach } from "vitest"; +import type { StoreMemoryInput, MergeEntityInput } from "./schema.js"; +import { Neo4jMemoryClient } from "./neo4j-client.js"; + +// ============================================================================ +// Test Helpers +// ============================================================================ + +function createMockSession() { + return { + run: vi.fn().mockResolvedValue({ records: [] }), + close: vi.fn().mockResolvedValue(undefined), + executeWrite: vi.fn( + async (work: (tx: { run: ReturnType }) => Promise) => { + // Create a mock transaction that delegates to the session's run mock + const mockTx = { run: vi.fn().mockResolvedValue({ records: [] }) }; + return work(mockTx); + }, + ), + }; +} + +function createMockDriver() { + return { + session: vi.fn().mockReturnValue(createMockSession()), + close: vi.fn().mockResolvedValue(undefined), + }; +} + +function createMockLogger() { + return { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }; +} + +// ============================================================================ +// Neo4jMemoryClient Tests +// ============================================================================ + +describe("Neo4jMemoryClient", () => { + let client: Neo4jMemoryClient; + let mockDriver: ReturnType; + let mockSession: ReturnType; + let mockLogger: ReturnType; + + beforeEach(() => { + mockLogger = createMockLogger(); + mockDriver = createMockDriver(); + mockSession = createMockSession(); + mockDriver.session.mockReturnValue(mockSession); + + // Create client (uri, username, password, dimensions, logger) + client = new Neo4jMemoryClient("bolt://localhost:7687", "neo4j", "password", 1024, mockLogger); + + // Replace driver with mock + (client as any).driver = mockDriver; + (client as any).indexesReady = true; + }); + + // ------------------------------------------------------------------------ + // storeMemory() + // ------------------------------------------------------------------------ + + describe("storeMemory", () => { + it("should store memory with correct Cypher params", async () => { + const input: StoreMemoryInput = { + id: "mem-1", + text: "test memory", + embedding: [0.1, 0.2, 0.3], + importance: 0.8, + category: "fact", + source: "user", + extractionStatus: "pending", + agentId: "agent-1", + sessionKey: "session-1", + }; + + mockSession.run.mockResolvedValue({ + records: [{ get: vi.fn().mockReturnValue("mem-1") }], + }); + + const result = await client.storeMemory(input); + + expect(result).toBe("mem-1"); + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("CREATE (m:Memory {"), + expect.objectContaining({ + id: "mem-1", + text: "test memory", + embedding: [0.1, 0.2, 0.3], + importance: 0.8, + category: "fact", + source: "user", + extractionStatus: "pending", + agentId: "agent-1", + sessionKey: "session-1", + retrievalCount: 0, + lastRetrievedAt: null, + extractionRetries: 0, + }), + ); + }); + + it("should store embedding correctly", async () => { + const input: StoreMemoryInput = { + id: "mem-1", + text: "test", + embedding: [0.1, 0.2, 0.3, 0.4, 0.5], + importance: 0.5, + category: "other", + source: "auto-capture", + extractionStatus: "skipped", + agentId: "default", + }; + + mockSession.run.mockResolvedValue({ + records: [{ get: vi.fn().mockReturnValue("mem-1") }], + }); + + await client.storeMemory(input); + + expect(mockSession.run).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ + embedding: [0.1, 0.2, 0.3, 0.4, 0.5], + }), + ); + }); + + it("should initialize retrievalCount to 0", async () => { + const input: StoreMemoryInput = { + id: "mem-1", + text: "test", + embedding: [], + importance: 0.5, + category: "other", + source: "user", + extractionStatus: "pending", + agentId: "default", + }; + + mockSession.run.mockResolvedValue({ + records: [{ get: vi.fn().mockReturnValue("mem-1") }], + }); + + await client.storeMemory(input); + + expect(mockSession.run).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ + retrievalCount: 0, + }), + ); + }); + }); + + // ------------------------------------------------------------------------ + // deleteMemory() + // ------------------------------------------------------------------------ + + describe("deleteMemory", () => { + const testMemId = "550e8400-e29b-41d4-a716-446655440000"; + + it("should return true when memory exists and is deleted", async () => { + mockSession.run.mockResolvedValueOnce({ + records: [{ get: vi.fn().mockReturnValue(1) }], + }); + + const result = await client.deleteMemory(testMemId); + + expect(result).toBe(true); + }); + + it("should return false when memory does not exist", async () => { + mockSession.run.mockResolvedValueOnce({ + records: [{ get: vi.fn().mockReturnValue(0) }], + }); + + const result = await client.deleteMemory(testMemId); + + expect(result).toBe(false); + }); + + it("should decrement entity mention counts and delete atomically", async () => { + mockSession.run.mockResolvedValueOnce({ + records: [{ get: vi.fn().mockReturnValue(1) }], + }); + + await client.deleteMemory(testMemId); + + // Single atomic query handles both mentionCount decrement and delete + expect(mockSession.run).toHaveBeenCalledTimes(1); + expect(mockSession.run).toHaveBeenCalledWith(expect.stringContaining("MENTIONS"), { + id: testMemId, + }); + expect(mockSession.run).toHaveBeenCalledWith(expect.stringContaining("DETACH DELETE"), { + id: testMemId, + }); + }); + + it("should reject invalid UUID format", async () => { + await expect(client.deleteMemory("not-a-uuid")).rejects.toThrow("Invalid memory ID format"); + }); + + it("should accept valid UUID formats", async () => { + mockSession.run.mockResolvedValue({ + records: [{ get: vi.fn().mockReturnValue(1) }], + }); + + await expect(client.deleteMemory("550e8400-e29b-41d4-a716-446655440000")).resolves.toBe(true); + }); + }); + + // ------------------------------------------------------------------------ + // findSimilar() + // ------------------------------------------------------------------------ + + describe("findSimilar", () => { + it("should query vector index with threshold", async () => { + mockSession.run.mockResolvedValue({ + records: [ + { + get: vi.fn((key) => { + if (key === "id") return "mem-1"; + if (key === "text") return "similar text"; + if (key === "similarity") return 0.96; + return null; + }), + }, + ], + }); + + const result = await client.findSimilar([0.1, 0.2, 0.3], 0.95, 5); + + expect(result).toHaveLength(1); + expect(result[0]).toEqual({ + id: "mem-1", + text: "similar text", + score: 0.96, + }); + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("db.index.vector.queryNodes"), + expect.objectContaining({ + embedding: [0.1, 0.2, 0.3], + threshold: 0.95, + }), + ); + }); + + it("should filter results by threshold", async () => { + // Mock should only return results >= threshold + // (In reality, the vector index does this filtering) + mockSession.run.mockResolvedValue({ records: [] }); + + const result = await client.findSimilar([0.1, 0.2], 0.99, 10); + + expect(result).toHaveLength(0); + }); + + it("should return empty array on vector index failure", async () => { + mockSession.run.mockRejectedValue(new Error("index not ready")); + + const result = await client.findSimilar([0.1, 0.2], 0.95, 5); + + expect(result).toEqual([]); + expect(mockLogger.debug).toHaveBeenCalled(); + }); + }); + + // ------------------------------------------------------------------------ + // findDuplicateClusters() + // ------------------------------------------------------------------------ + + describe("findDuplicateClusters", () => { + it("should use union-find to build clusters", async () => { + // Mock all memories + mockSession.run.mockResolvedValueOnce({ + records: [ + { + get: vi.fn((key) => { + if (key === "id") return "m1"; + if (key === "text") return "text1"; + if (key === "importance") return 0.5; + return null; + }), + }, + { + get: vi.fn((key) => { + if (key === "id") return "m2"; + if (key === "text") return "text2"; + if (key === "importance") return 0.6; + return null; + }), + }, + { + get: vi.fn((key) => { + if (key === "id") return "m3"; + if (key === "text") return "text3"; + if (key === "importance") return 0.7; + return null; + }), + }, + ], + }); + + // Mock vector similarity queries + // m1 similar to m2, m2 similar to m3 => cluster {m1, m2, m3} + mockSession.run + .mockResolvedValueOnce({ + // m1 neighbors + records: [{ get: vi.fn().mockReturnValue("m2") }], + }) + .mockResolvedValueOnce({ + // m2 neighbors + records: [{ get: vi.fn().mockReturnValue("m3") }], + }) + .mockResolvedValueOnce({ + // m3 neighbors + records: [], + }); + + const result = await client.findDuplicateClusters(0.95); + + expect(result).toHaveLength(1); + expect(result[0].memoryIds).toHaveLength(3); + expect(result[0].memoryIds).toContain("m1"); + expect(result[0].memoryIds).toContain("m2"); + expect(result[0].memoryIds).toContain("m3"); + }); + + it("should respect safety bound (max 500 pairs)", async () => { + // Create many memories + const manyRecords = Array.from({ length: 100 }, (_, i) => ({ + get: vi.fn((key) => { + if (key === "id") return `m${i}`; + if (key === "text") return `text${i}`; + if (key === "importance") return 0.5; + return null; + }), + })); + + mockSession.run.mockResolvedValueOnce({ records: manyRecords }); + + // Mock each memory finding many neighbors (would exceed 500 pairs) + for (let i = 0; i < 100; i++) { + mockSession.run.mockResolvedValueOnce({ + records: Array.from({ length: 10 }, (_, j) => ({ + get: vi.fn().mockReturnValue(`m${(i + j + 1) % 100}`), + })), + }); + + // Early exit when pairsFound > 500 + if (i >= 50) break; + } + + const result = await client.findDuplicateClusters(0.95); + + // Should exit early without processing all memories + expect(result).toBeDefined(); + }); + + it("should return only clusters with 2+ members", async () => { + mockSession.run.mockResolvedValueOnce({ + records: [ + { get: vi.fn((key) => (key === "id" ? "m1" : key === "text" ? "text1" : 0.5)) }, + { get: vi.fn((key) => (key === "id" ? "m2" : key === "text" ? "text2" : 0.6)) }, + ], + }); + + // m1 has no neighbors, m2 has no neighbors => no clusters + mockSession.run.mockResolvedValueOnce({ records: [] }).mockResolvedValueOnce({ records: [] }); + + const result = await client.findDuplicateClusters(0.95); + + expect(result).toHaveLength(0); + }); + + it("should handle empty database", async () => { + mockSession.run.mockResolvedValue({ records: [] }); + + const result = await client.findDuplicateClusters(0.95); + + expect(result).toEqual([]); + }); + + it("should handle single memory", async () => { + mockSession.run.mockResolvedValueOnce({ + records: [{ get: vi.fn((key) => (key === "id" ? "m1" : key === "text" ? "text1" : 0.5)) }], + }); + mockSession.run.mockResolvedValueOnce({ records: [] }); + + const result = await client.findDuplicateClusters(0.95); + + expect(result).toEqual([]); + }); + }); + + // ------------------------------------------------------------------------ + // mergeMemoryCluster() + // ------------------------------------------------------------------------ + + describe("mergeMemoryCluster", () => { + it("should keep highest importance memory", async () => { + const txRun = vi + .fn() + // Verify step + .mockResolvedValueOnce({ + records: [ + { get: vi.fn((key: string) => (key === "memId" ? "low" : true)) }, + { get: vi.fn((key: string) => (key === "memId" ? "high" : true)) }, + { get: vi.fn((key: string) => (key === "memId" ? "mid" : true)) }, + ], + }) + // Transfer mentions + .mockResolvedValueOnce({ records: [] }) + // Delete duplicates + .mockResolvedValueOnce({ records: [] }); + + mockSession.executeWrite.mockImplementationOnce( + async (work: (tx: { run: typeof txRun }) => Promise) => { + return work({ run: txRun }); + }, + ); + + const result = await client.mergeMemoryCluster(["low", "high", "mid"], [0.3, 0.9, 0.5]); + + expect(result.survivorId).toBe("high"); + expect(result.deletedCount).toBe(2); + + // Should delete "low" and "mid" + expect(txRun).toHaveBeenCalledWith( + expect.stringContaining("DETACH DELETE"), + expect.objectContaining({ toDelete: ["low", "mid"] }), + ); + }); + + it("should transfer MENTIONS relationships to survivor", async () => { + const txRun = vi + .fn() + .mockResolvedValueOnce({ + records: [ + { get: vi.fn((key: string) => (key === "memId" ? "m1" : true)) }, + { get: vi.fn((key: string) => (key === "memId" ? "m2" : true)) }, + ], + }) + .mockResolvedValueOnce({ records: [] }) + .mockResolvedValueOnce({ records: [] }); + + mockSession.executeWrite.mockImplementationOnce( + async (work: (tx: { run: typeof txRun }) => Promise) => { + return work({ run: txRun }); + }, + ); + + await client.mergeMemoryCluster(["m1", "m2"], [0.5, 0.6]); + + // Should transfer mentions from m1 to m2 + expect(txRun).toHaveBeenCalledWith( + expect.stringContaining("MENTIONS"), + expect.objectContaining({ + toDelete: ["m1"], + survivorId: "m2", + }), + ); + }); + + it("should skip merge when cluster members are missing", async () => { + const txRun = vi.fn().mockResolvedValueOnce({ + records: [ + { get: vi.fn((key: string) => (key === "memId" ? "m1" : true)) }, + { get: vi.fn((key: string) => (key === "memId" ? "m2" : false)) }, // missing! + ], + }); + + mockSession.executeWrite.mockImplementationOnce( + async (work: (tx: { run: typeof txRun }) => Promise) => { + return work({ run: txRun }); + }, + ); + + const result = await client.mergeMemoryCluster(["m1", "m2"], [0.5, 0.6]); + + expect(result.deletedCount).toBe(0); + expect(mockLogger.warn).toHaveBeenCalledWith( + expect.stringContaining("skipping cluster merge"), + ); + }); + + it("should handle single-member cluster gracefully", async () => { + const txRun = vi.fn().mockResolvedValueOnce({ + records: [{ get: vi.fn((key: string) => (key === "memId" ? "m1" : true)) }], + }); + + mockSession.executeWrite.mockImplementationOnce( + async (work: (tx: { run: typeof txRun }) => Promise) => { + return work({ run: txRun }); + }, + ); + + const result = await client.mergeMemoryCluster(["m1"], [0.8]); + + expect(result.survivorId).toBe("m1"); + expect(result.deletedCount).toBe(0); + }); + }); + + // ------------------------------------------------------------------------ + // invalidateMemory() + // ------------------------------------------------------------------------ + + describe("invalidateMemory", () => { + it("should set importance to 0.01", async () => { + mockSession.run.mockResolvedValue({ records: [] }); + + await client.invalidateMemory("mem-1"); + + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("m.importance = 0.01"), + expect.objectContaining({ id: "mem-1" }), + ); + }); + + it("should update updatedAt timestamp", async () => { + mockSession.run.mockResolvedValue({ records: [] }); + + await client.invalidateMemory("mem-1"); + + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("m.updatedAt"), + expect.objectContaining({ + id: "mem-1", + now: expect.any(String), + }), + ); + }); + }); + + // ------------------------------------------------------------------------ + // calculateAllEffectiveScores() + // ------------------------------------------------------------------------ + + describe("calculateAllEffectiveScores", () => { + it("should apply correct formula (importance × freq_boost × recency)", async () => { + mockSession.run.mockResolvedValue({ + records: [ + { + get: vi.fn((key) => { + const data: Record = { + id: "m1", + text: "test", + category: "fact", + importance: 0.8, + retrievalCount: 10, + ageDays: 7, + effectiveScore: 0.75, // Pre-calculated by Cypher + }; + return data[key]; + }), + }, + ], + }); + + const result = await client.calculateAllEffectiveScores(); + + expect(result).toHaveLength(1); + expect(result[0]).toMatchObject({ + id: "m1", + text: "test", + category: "fact", + importance: 0.8, + retrievalCount: 10, + ageDays: 7, + effectiveScore: 0.75, + }); + }); + + it("should handle empty database", async () => { + mockSession.run.mockResolvedValue({ records: [] }); + + const result = await client.calculateAllEffectiveScores(); + + expect(result).toEqual([]); + }); + + it("should filter by agentId when provided", async () => { + mockSession.run.mockResolvedValue({ records: [] }); + + await client.calculateAllEffectiveScores("agent-1"); + + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("m.agentId = $agentId"), + expect.objectContaining({ agentId: "agent-1" }), + ); + }); + }); + + // ------------------------------------------------------------------------ + // calculateParetoThreshold() + // ------------------------------------------------------------------------ + + describe("calculateParetoThreshold", () => { + it("should return correct 80th percentile", () => { + const scores = [ + { + id: "1", + text: "", + category: "fact", + importance: 0.9, + retrievalCount: 0, + ageDays: 0, + effectiveScore: 1.0, + }, + { + id: "2", + text: "", + category: "fact", + importance: 0.9, + retrievalCount: 0, + ageDays: 0, + effectiveScore: 0.9, + }, + { + id: "3", + text: "", + category: "fact", + importance: 0.9, + retrievalCount: 0, + ageDays: 0, + effectiveScore: 0.8, + }, + { + id: "4", + text: "", + category: "fact", + importance: 0.9, + retrievalCount: 0, + ageDays: 0, + effectiveScore: 0.7, + }, + { + id: "5", + text: "", + category: "fact", + importance: 0.9, + retrievalCount: 0, + ageDays: 0, + effectiveScore: 0.6, + }, + { + id: "6", + text: "", + category: "fact", + importance: 0.9, + retrievalCount: 0, + ageDays: 0, + effectiveScore: 0.5, + }, + { + id: "7", + text: "", + category: "fact", + importance: 0.9, + retrievalCount: 0, + ageDays: 0, + effectiveScore: 0.4, + }, + { + id: "8", + text: "", + category: "fact", + importance: 0.9, + retrievalCount: 0, + ageDays: 0, + effectiveScore: 0.3, + }, + { + id: "9", + text: "", + category: "fact", + importance: 0.9, + retrievalCount: 0, + ageDays: 0, + effectiveScore: 0.2, + }, + { + id: "10", + text: "", + category: "fact", + importance: 0.9, + retrievalCount: 0, + ageDays: 0, + effectiveScore: 0.1, + }, + ]; + + // percentile=0.8 means top 20% + const threshold = client.calculateParetoThreshold(scores, 0.8); + + // 80th percentile of [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1] + // Top 20% = 2 items, boundary at floor(10 * 0.2) = 2, but for top N%, use index N-1 as threshold + // FIXME: Implementation returns sorted[1] = 0.9 for top 20%, not sorted[2] = 0.8 + expect(threshold).toBe(0.9); + }); + + it("should handle empty scores array", () => { + const threshold = client.calculateParetoThreshold([], 0.8); + expect(threshold).toBe(0); + }); + + it("should handle single score", () => { + const scores = [ + { + id: "1", + text: "", + category: "fact", + importance: 0.9, + retrievalCount: 0, + ageDays: 0, + effectiveScore: 0.75, + }, + ]; + const threshold = client.calculateParetoThreshold(scores, 0.8); + expect(threshold).toBe(0.75); + }); + + it("should handle 50th percentile (median)", () => { + const scores = [ + { + id: "1", + text: "", + category: "fact", + importance: 0.9, + retrievalCount: 0, + ageDays: 0, + effectiveScore: 1.0, + }, + { + id: "2", + text: "", + category: "fact", + importance: 0.9, + retrievalCount: 0, + ageDays: 0, + effectiveScore: 0.5, + }, + ]; + const threshold = client.calculateParetoThreshold(scores, 0.5); + // For 2 items with percentile 0.5, boundary index = floor(2 * 0.5) = 1, so threshold is second item's score + expect(threshold).toBe(0.5); + }); + }); + + // ------------------------------------------------------------------------ + // retryOnTransient() + // ------------------------------------------------------------------------ + + describe("retryOnTransient", () => { + it("should retry on transient errors", async () => { + const fn = vi + .fn() + .mockRejectedValueOnce(new Error("TransientError: deadlock")) + .mockResolvedValueOnce("success"); + + const result = await (client as any).retryOnTransient(fn); + + expect(result).toBe("success"); + expect(fn).toHaveBeenCalledTimes(2); + }); + + it("should throw on permanent errors", async () => { + const fn = vi.fn().mockRejectedValue(new Error("ConstraintViolation")); + + await expect((client as any).retryOnTransient(fn)).rejects.toThrow("ConstraintViolation"); + expect(fn).toHaveBeenCalledTimes(1); + }); + + it("should exhaust retries and throw", async () => { + const fn = vi.fn().mockRejectedValue(new Error("TransientError: timeout")); + + await expect((client as any).retryOnTransient(fn)).rejects.toThrow("TransientError"); + expect(fn).toHaveBeenCalledTimes(3); // TRANSIENT_RETRY_ATTEMPTS = 3 + }); + + it("should identify transient error patterns", async () => { + const transientErrors = [ + "TransientError", + "DeadlockDetected", + "ServiceUnavailable", + "SessionExpired", + ]; + + for (const errMsg of transientErrors) { + const fn = vi + .fn() + .mockRejectedValueOnce(new Error(errMsg)) + .mockResolvedValueOnce("success"); + + const result = await (client as any).retryOnTransient(fn); + expect(result).toBe("success"); + } + }); + }); + + // ------------------------------------------------------------------------ + // promoteToCore() / demoteFromCore() + // ------------------------------------------------------------------------ + + describe("Core promotion/demotion", () => { + it("should promote memories to core category", async () => { + mockSession.run.mockResolvedValue({ + records: [{ get: vi.fn().mockReturnValue(2) }], + }); + + const result = await client.promoteToCore(["m1", "m2"]); + + expect(result).toBe(2); + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("category = 'core'"), + expect.objectContaining({ ids: ["m1", "m2"] }), + ); + }); + + it("should demote memories from core category", async () => { + mockSession.run.mockResolvedValue({ + records: [{ get: vi.fn().mockReturnValue(1) }], + }); + + const result = await client.demoteFromCore(["m1"]); + + expect(result).toBe(1); + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("category = 'fact'"), + expect.objectContaining({ ids: ["m1"] }), + ); + }); + + it("should handle empty ID arrays", async () => { + const promoteResult = await client.promoteToCore([]); + const demoteResult = await client.demoteFromCore([]); + + expect(promoteResult).toBe(0); + expect(demoteResult).toBe(0); + }); + }); + + // ------------------------------------------------------------------------ + // findDecayedMemories() + // ------------------------------------------------------------------------ + + describe("findDecayedMemories", () => { + it("should find memories below retention threshold", async () => { + mockSession.run.mockResolvedValue({ + records: [ + { + get: vi.fn((key) => { + const data: Record = { + id: "m1", + text: "old memory", + importance: 0.2, + ageDays: 100, + decayScore: 0.05, + }; + return data[key]; + }), + }, + ], + }); + + const result = await client.findDecayedMemories({ + retentionThreshold: 0.1, + baseHalfLifeDays: 30, + }); + + expect(result).toHaveLength(1); + expect(result[0]).toMatchObject({ + id: "m1", + text: "old memory", + importance: 0.2, + ageDays: 100, + decayScore: 0.05, + }); + }); + + it("should exclude core memories from decay", async () => { + mockSession.run.mockResolvedValue({ records: [] }); + + await client.findDecayedMemories(); + + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("m.category <> 'core'"), + expect.any(Object), + ); + }); + + it("should use exponential decay formula", async () => { + // The Cypher query should implement: importance × e^(-age / halfLife) + mockSession.run.mockResolvedValue({ records: [] }); + + await client.findDecayedMemories({ + baseHalfLifeDays: 30, + importanceMultiplier: 2, + }); + + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("exp("), + expect.objectContaining({ + baseHalfLife: 30, + importanceMult: 2, + }), + ); + }); + }); + + // ------------------------------------------------------------------------ + // pruneMemories() + // ------------------------------------------------------------------------ + + describe("pruneMemories", () => { + it("should delete decayed memories", async () => { + mockSession.run.mockResolvedValueOnce({ + records: [{ get: vi.fn().mockReturnValue(3) }], + }); + + const result = await client.pruneMemories(["m1", "m2", "m3"]); + + expect(result).toBe(3); + }); + + it("should decrement entity mention counts and delete atomically", async () => { + mockSession.run.mockResolvedValueOnce({ + records: [{ get: vi.fn().mockReturnValue(2) }], + }); + + await client.pruneMemories(["m1", "m2"]); + + // Single atomic query handles both mentionCount decrement and delete + expect(mockSession.run).toHaveBeenCalledTimes(1); + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("MENTIONS"), + expect.objectContaining({ ids: ["m1", "m2"] }), + ); + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("DETACH DELETE"), + expect.objectContaining({ ids: ["m1", "m2"] }), + ); + }); + + it("should handle empty ID array", async () => { + const result = await client.pruneMemories([]); + + expect(result).toBe(0); + expect(mockSession.run).not.toHaveBeenCalled(); + }); + }); + + // ------------------------------------------------------------------------ + // findOrphanEntities() / deleteOrphanEntities() + // ------------------------------------------------------------------------ + + describe("Orphan cleanup", () => { + it("should find entities with mentionCount <= 0", async () => { + mockSession.run.mockResolvedValue({ + records: [ + { + get: vi.fn((key) => { + const data: Record = { + id: "e1", + name: "orphan", + type: "concept", + }; + return data[key]; + }), + }, + ], + }); + + const result = await client.findOrphanEntities(); + + expect(result).toHaveLength(1); + expect(result[0]).toMatchObject({ + id: "e1", + name: "orphan", + type: "concept", + }); + }); + + it("should delete orphan entities", async () => { + mockSession.run.mockResolvedValue({ + records: [{ get: vi.fn().mockReturnValue(2) }], + }); + + const result = await client.deleteOrphanEntities(["e1", "e2"]); + + expect(result).toBe(2); + }); + + it("should find orphan tags (no TAGGED relationships)", async () => { + mockSession.run.mockResolvedValue({ + records: [ + { + get: vi.fn((key) => { + const data: Record = { id: "t1", name: "unused" }; + return data[key]; + }), + }, + ], + }); + + const result = await client.findOrphanTags(); + + expect(result).toHaveLength(1); + expect(result[0]).toMatchObject({ + id: "t1", + name: "unused", + }); + }); + + it("should delete orphan tags", async () => { + mockSession.run.mockResolvedValue({ + records: [{ get: vi.fn().mockReturnValue(1) }], + }); + + const result = await client.deleteOrphanTags(["t1"]); + + expect(result).toBe(1); + }); + }); + + // ------------------------------------------------------------------------ + // findConflictingMemories() + // ------------------------------------------------------------------------ + + describe("findConflictingMemories", () => { + it("should find memory pairs sharing entities", async () => { + mockSession.run.mockResolvedValue({ + records: [ + { + get: vi.fn((key) => { + const data: Record = { + m1Id: "mem1", + m1Text: "user prefers dark mode", + m1Importance: 0.7, + m1CreatedAt: "2024-01-01", + m2Id: "mem2", + m2Text: "user prefers light mode", + m2Importance: 0.6, + m2CreatedAt: "2024-01-02", + }; + return data[key]; + }), + }, + ], + }); + + const result = await client.findConflictingMemories(); + + expect(result).toHaveLength(1); + expect(result[0]).toMatchObject({ + memoryA: { + id: "mem1", + text: "user prefers dark mode", + importance: 0.7, + }, + memoryB: { + id: "mem2", + text: "user prefers light mode", + importance: 0.6, + }, + }); + }); + + it("should exclude core memories from conflict detection", async () => { + mockSession.run.mockResolvedValue({ records: [] }); + + await client.findConflictingMemories(); + + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("m1.category <> 'core'"), + expect.any(Object), + ); + }); + + it("should limit results to 50 pairs", async () => { + mockSession.run.mockResolvedValue({ records: [] }); + + await client.findConflictingMemories(); + + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("LIMIT 50"), + expect.any(Object), + ); + }); + }); + + // ------------------------------------------------------------------------ + // Entity and Tag operations + // ------------------------------------------------------------------------ + + describe("Entity operations", () => { + it("should merge entity idempotently", async () => { + mockSession.run.mockResolvedValue({ + records: [ + { + get: vi.fn((key) => { + const data: Record = { id: "e1", name: "tarun" }; + return data[key]; + }), + }, + ], + }); + + const input: MergeEntityInput = { + id: "e1", + name: "Tarun", + type: "person", + aliases: ["boss"], + description: "CEO", + }; + + const result = await client.mergeEntity(input); + + expect(result).toEqual({ id: "e1", name: "tarun" }); + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("MERGE (e:Entity {name: $name})"), + expect.objectContaining({ + name: "tarun", // normalized + }), + ); + }); + + it("should create MENTIONS relationship", async () => { + mockSession.run.mockResolvedValue({ records: [] }); + + await client.createMentions("mem-1", "Tarun", "context", 0.95); + + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("MERGE (m)-[r:MENTIONS]->(e)"), + expect.objectContaining({ + memoryId: "mem-1", + entityName: "tarun", // normalized + role: "context", + confidence: 0.95, + }), + ); + }); + + it("should create entity relationships with validated type", async () => { + mockSession.run.mockResolvedValue({ records: [] }); + + await client.createEntityRelationship("Alice", "Acme", "WORKS_AT", 0.9); + + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("MERGE (e1)-[r:WORKS_AT]->(e2)"), + expect.objectContaining({ + sourceName: "alice", + targetName: "acme", + confidence: 0.9, + }), + ); + }); + + it("should reject invalid relationship types", async () => { + await client.createEntityRelationship("a", "b", "INVALID_TYPE", 0.9); + + expect(mockLogger.warn).toHaveBeenCalledWith( + expect.stringContaining("rejected invalid relationship type"), + ); + expect(mockSession.run).not.toHaveBeenCalled(); + }); + }); + + describe("Tag operations", () => { + it("should tag memory with normalized tag name", async () => { + mockSession.run.mockResolvedValue({ records: [] }); + + await client.tagMemory("mem-1", "Neo4j", "technology", 0.95); + + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("MERGE (t:Tag {name: $tagName})"), + expect.objectContaining({ + memoryId: "mem-1", + tagName: "neo4j", // normalized + tagCategory: "technology", + confidence: 0.95, + }), + ); + }); + + it("should update memory category only when current is 'other'", async () => { + mockSession.run.mockResolvedValue({ records: [] }); + + await client.updateMemoryCategory("mem-1", "fact"); + + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("WHERE m.category = 'other'"), + expect.objectContaining({ + id: "mem-1", + category: "fact", + }), + ); + }); + }); + + // ------------------------------------------------------------------------ + // Extraction status tracking + // ------------------------------------------------------------------------ + + describe("Extraction status", () => { + it("should update extraction status", async () => { + mockSession.run.mockResolvedValue({ records: [] }); + + await client.updateExtractionStatus("mem-1", "complete"); + + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("m.extractionStatus = $status"), + expect.objectContaining({ + id: "mem-1", + status: "complete", + }), + ); + }); + + it("should increment retry counter when option is set", async () => { + mockSession.run.mockResolvedValue({ records: [] }); + + await client.updateExtractionStatus("mem-1", "pending", { incrementRetries: true }); + + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("m.extractionRetries"), + expect.any(Object), + ); + }); + + it("should get extraction retry count", async () => { + mockSession.run.mockResolvedValue({ + records: [{ get: vi.fn().mockReturnValue(3) }], + }); + + const result = await client.getExtractionRetries("mem-1"); + + expect(result).toBe(3); + }); + + it("should count memories by extraction status", async () => { + mockSession.run.mockResolvedValue({ + records: [ + { get: vi.fn((key) => (key === "status" ? "pending" : { toNumber: () => 5 })) }, + { get: vi.fn((key) => (key === "status" ? "complete" : { toNumber: () => 10 })) }, + { get: vi.fn((key) => (key === "status" ? "failed" : { toNumber: () => 2 })) }, + ], + }); + + const result = await client.countByExtractionStatus(); + + expect(result).toEqual({ + pending: 5, + complete: 10, + failed: 2, + skipped: 0, + }); + }); + + it("should list pending extractions", async () => { + mockSession.run.mockResolvedValue({ + records: [ + { + get: vi.fn((key) => { + const data: Record = { + id: "m1", + text: "pending text", + agentId: "agent-1", + extractionRetries: 1, + }; + return data[key]; + }), + }, + ], + }); + + const result = await client.listPendingExtractions(100); + + expect(result).toHaveLength(1); + expect(result[0]).toMatchObject({ + id: "m1", + text: "pending text", + agentId: "agent-1", + extractionRetries: 1, + }); + }); + }); + + // ------------------------------------------------------------------------ + // Search operations + // ------------------------------------------------------------------------ + + describe("Search operations", () => { + it("should perform vector search with min score threshold", async () => { + mockSession.run.mockResolvedValue({ + records: [ + { + get: vi.fn((key) => { + const data: Record = { + id: "m1", + text: "result", + category: "fact", + importance: 0.8, + createdAt: "2024-01-01", + similarity: 0.92, + }; + return data[key]; + }), + }, + ], + }); + + const result = await client.vectorSearch([0.1, 0.2], 10, 0.9); + + expect(result).toHaveLength(1); + expect(result[0]).toMatchObject({ + id: "m1", + text: "result", + score: 0.92, + }); + }); + + it("should perform BM25 search and normalize scores", async () => { + mockSession.run.mockResolvedValue({ + records: [ + { + get: vi.fn((key) => { + const data: Record = { + id: "m1", + text: "result", + category: "fact", + importance: 0.8, + createdAt: "2024-01-01", + bm25Score: 5.0, + }; + return data[key]; + }), + }, + ], + }); + + const result = await client.bm25Search("test query", 10); + + expect(result).toHaveLength(1); + // Score should be normalized (divided by max) + expect(result[0].score).toBe(1.0); + }); + + it("should escape Lucene special characters in BM25 query", async () => { + mockSession.run.mockResolvedValue({ records: [] }); + + await client.bm25Search("test+query*", 10); + + // Should escape + and * + expect(mockSession.run).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ + query: expect.stringContaining("\\+"), + }), + ); + }); + + it("should perform graph search with entity traversal", async () => { + // Mock entity search + mockSession.run.mockResolvedValueOnce({ + records: [ + { + get: vi.fn((key) => { + const data: Record = { + entityId: "e1", + name: "tarun", + score: 0.95, + }; + return data[key]; + }), + }, + ], + }); + + // Mock memory search via entities + mockSession.run.mockResolvedValueOnce({ + records: [ + { + get: vi.fn((key) => { + const data: Record = { + id: "m1", + text: "result", + category: "fact", + importance: 0.8, + createdAt: "2024-01-01", + graphScore: 0.9, + }; + return data[key]; + }), + }, + ], + }); + + const result = await client.graphSearch("tarun", 10, 0.3); + + expect(result).toHaveLength(1); + expect(result[0]).toMatchObject({ + id: "m1", + score: 0.9, + }); + }); + }); + + // ------------------------------------------------------------------------ + // Retrieval tracking + // ------------------------------------------------------------------------ + + describe("Retrieval tracking", () => { + it("should record retrieval events", async () => { + mockSession.run.mockResolvedValue({ records: [] }); + + await client.recordRetrievals(["m1", "m2", "m3"]); + + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("m.retrievalCount"), + expect.objectContaining({ + ids: ["m1", "m2", "m3"], + }), + ); + }); + + it("should update lastRetrievedAt timestamp", async () => { + mockSession.run.mockResolvedValue({ records: [] }); + + await client.recordRetrievals(["m1"]); + + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("m.lastRetrievedAt"), + expect.objectContaining({ + now: expect.any(String), + }), + ); + }); + + it("should handle empty retrieval array", async () => { + await client.recordRetrievals([]); + + expect(mockSession.run).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/extensions/memory-neo4j/neo4j-client.ts b/extensions/memory-neo4j/neo4j-client.ts index 1fb03e813d7..269f786d929 100644 --- a/extensions/memory-neo4j/neo4j-client.ts +++ b/extensions/memory-neo4j/neo4j-client.ts @@ -63,7 +63,11 @@ export class Neo4jMemoryClient { if (this.initPromise) { return this.initPromise; } - this.initPromise = this.doInitialize(); + this.initPromise = this.doInitialize().catch((err) => { + // Reset so subsequent calls retry instead of returning cached rejection + this.initPromise = null; + throw err; + }); return this.initPromise; } @@ -257,7 +261,7 @@ export class Neo4jMemoryClient { }); } - async deleteMemory(id: string): Promise { + async deleteMemory(id: string, agentId?: string): Promise { await this.ensureInitialized(); // Validate UUID format to prevent injection const uuidRegex = /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i; @@ -268,20 +272,21 @@ export class Neo4jMemoryClient { return this.retryOnTransient(async () => { const session = this.driver!.session(); try { - // Decrement mentionCount on connected entities (floor at 0 to prevent - // negative counts from parallel deletes racing on the same entity) - await session.run( - `MATCH (m:Memory {id: $id})-[:MENTIONS]->(e:Entity) - SET e.mentionCount = CASE WHEN e.mentionCount > 0 THEN e.mentionCount - 1 ELSE 0 END`, - { id }, - ); - - // Then delete the memory with all its relationships + // Atomic: decrement mentionCount and delete in a single Cypher statement + // to prevent inconsistent state if a crash occurs between operations. + // When agentId is provided, scope the delete to that agent's memories + // to prevent cross-agent deletion. + const matchClause = agentId + ? "MATCH (m:Memory {id: $id, agentId: $agentId})" + : "MATCH (m:Memory {id: $id})"; const result = await session.run( - `MATCH (m:Memory {id: $id}) + `${matchClause} + OPTIONAL MATCH (m)-[:MENTIONS]->(e:Entity) + SET e.mentionCount = CASE WHEN e.mentionCount > 0 THEN e.mentionCount - 1 ELSE 0 END + WITH m, count(e) AS _ DETACH DELETE m RETURN count(*) AS deleted`, - { id }, + agentId ? { id, agentId } : { id }, ); const deleted = @@ -668,54 +673,6 @@ export class Neo4jMemoryClient { }); } - /** - * Calculate effective importance using retrieval-based reinforcement. - * - * Two modes: - * 1. With importance (regular memories): importance × freq_boost × recency - * 2. Without importance (core memories): freq_boost × recency - * - * Research basis: - * - ACT-R memory model (frequency with power-law decay) - * - FSRS spaced repetition (stability/retrievability) - * - Ebbinghaus forgetting curve (exponential decay) - */ - calculateEffectiveImportance( - retrievalCount: number, - daysSinceLastRetrieval: number | null, - options: { - baseImportance?: number; // Include importance multiplier (for regular memories) - frequencyScale?: number; // How much retrievals boost importance (default: 0.3) - recencyHalfLifeDays?: number; // Half-life for recency decay (default: 14) - } = {}, - ): number { - const { baseImportance, frequencyScale = 0.3, recencyHalfLifeDays = 14 } = options; - - // Frequency boost: log(1 + n) provides diminishing returns - // log(1+0)=0, log(1+1)≈0.69, log(1+10)≈2.4, log(1+100)≈4.6 - const frequencyBoost = 1 + Math.log1p(retrievalCount) * frequencyScale; - - // Recency factor: exponential decay with configurable half-life - // If never retrieved (null), use a baseline factor - let recencyFactor: number; - if (daysSinceLastRetrieval === null) { - recencyFactor = 0.1; // Never retrieved - low baseline - } else { - recencyFactor = Math.pow(2, -daysSinceLastRetrieval / recencyHalfLifeDays); - } - - // Combined effective importance - const usageScore = frequencyBoost * recencyFactor; - - // Include importance multiplier if provided (for regular memories) - if (baseImportance !== undefined) { - return baseImportance * usageScore; - } - - // Pure usage-based (for core memories) - return usageScore; - } - // -------------------------------------------------------------------------- // Entity & Relationship Operations // -------------------------------------------------------------------------- @@ -995,108 +952,175 @@ export class Neo4jMemoryClient { * 2. For each memory, query the vector index for nearest neighbors above threshold * 3. Build clusters via union-find (transitive closure) * 4. Return clusters with 2+ members + * + * @param threshold Minimum similarity score (0-1) + * @param agentId Optional agent filter + * @param returnSimilarities If true, includes pairwise similarity scores in the result */ async findDuplicateClusters( threshold: number = 0.95, agentId?: string, - ): Promise> { + returnSimilarities: boolean = false, + ): Promise< + Array<{ + memoryIds: string[]; + texts: string[]; + importances: number[]; + similarities?: Map; + }> + > { await this.ensureInitialized(); - const session = this.driver!.session(); - try { - // Step 1: Fetch all memory metadata (no embeddings — lightweight) - const agentFilter = agentId ? "WHERE m.agentId = $agentId" : ""; - const allResult = await session.run( - `MATCH (m:Memory) ${agentFilter} - RETURN m.id AS id, m.text AS text, m.importance AS importance`, - agentId ? { agentId } : {}, - ); - const memoryData = new Map(); - for (const r of allResult.records) { - memoryData.set(r.get("id") as string, { - text: r.get("text") as string, - importance: r.get("importance") as number, - }); + // Step 1: Fetch all memory metadata in a short-lived session + const memoryData = new Map(); + { + const session = this.driver!.session(); + try { + const agentFilter = agentId ? "WHERE m.agentId = $agentId" : ""; + const allResult = await session.run( + `MATCH (m:Memory) ${agentFilter} + RETURN m.id AS id, m.text AS text, m.importance AS importance`, + agentId ? { agentId } : {}, + ); + + for (const r of allResult.records) { + memoryData.set(r.get("id") as string, { + text: r.get("text") as string, + importance: r.get("importance") as number, + }); + } + } finally { + await session.close(); } + } - if (memoryData.size < 2) { - return []; + if (memoryData.size < 2) { + return []; + } + + // Step 2: For each memory, find near-duplicates via HNSW vector index + // Each query uses a fresh short-lived session via retryOnTransient to + // avoid a single long-lived session that could expire mid-operation. + // Each query is O(log N) vs O(N) for brute-force, total O(N log N) + const parent = new Map(); + // Capture pairwise similarities if requested (for sleep cycle optimization) + const pairwiseSimilarities = returnSimilarities ? new Map() : null; + + const find = (x: string): string => { + if (!parent.has(x)) { + parent.set(x, x); } + if (parent.get(x) !== x) { + parent.set(x, find(parent.get(x)!)); + } + return parent.get(x)!; + }; - // Step 2: For each memory, find near-duplicates via HNSW vector index - // Each query is O(log N) vs O(N) for brute-force, total O(N log N) - const parent = new Map(); + const union = (x: string, y: string): void => { + const px = find(x); + const py = find(y); + if (px !== py) { + parent.set(px, py); + } + }; - const find = (x: string): string => { - if (!parent.has(x)) { - parent.set(x, x); - } - if (parent.get(x) !== x) { - parent.set(x, find(parent.get(x)!)); - } - return parent.get(x)!; - }; + // Helper to create a canonical pair key (sorted) + const makePairKey = (a: string, b: string): string => { + return a < b ? `${a}:${b}` : `${b}:${a}`; + }; - const union = (x: string, y: string): void => { - const px = find(x); - const py = find(y); - if (px !== py) { - parent.set(px, py); - } - }; - - let pairsFound = 0; - for (const id of memoryData.keys()) { - // Retry individual vector queries on transient errors - const similar = await this.retryOnTransient(async () => { - return session.run( + let pairsFound = 0; + for (const id of memoryData.keys()) { + // Retry individual vector queries on transient errors (each uses a fresh session) + const similar = await this.retryOnTransient(async () => { + const session = this.driver!.session(); + try { + return await session.run( `MATCH (src:Memory {id: $id}) CALL db.index.vector.queryNodes('memory_embedding_index', $k, src.embedding) YIELD node, score WHERE node.id <> $id AND score >= $threshold - RETURN node.id AS matchId`, + RETURN node.id AS matchId, score`, { id, k: neo4j.int(10), threshold }, ); - }); + } finally { + await session.close(); + } + }); - for (const r of similar.records) { - const matchId = r.get("matchId") as string; - if (memoryData.has(matchId)) { - union(id, matchId); - pairsFound++; + for (const r of similar.records) { + const matchId = r.get("matchId") as string; + if (memoryData.has(matchId)) { + union(id, matchId); + pairsFound++; + + // Capture similarity score if requested + if (pairwiseSimilarities) { + const score = r.get("score") as number; + const pairKey = makePairKey(id, matchId); + // Keep the highest score if we see this pair multiple times + const existing = pairwiseSimilarities.get(pairKey); + if (existing === undefined || score > existing) { + pairwiseSimilarities.set(pairKey, score); + } } } - - // Early exit if we've found many pairs (safety bound) - if (pairsFound > 500) { - break; - } } - // Step 3: Group by root - const clusters = new Map(); - for (const id of memoryData.keys()) { - if (!parent.has(id)) { - continue; - } - const root = find(id); - if (!clusters.has(root)) { - clusters.set(root, []); - } - clusters.get(root)!.push(id); + // Early exit if we've found many pairs (safety bound) + if (pairsFound > 500) { + this.logger.warn( + `memory-neo4j: findDuplicateClusters hit safety bound (500 pairs) — some duplicates may not be detected. Consider running with a higher threshold.`, + ); + break; } + } - // Return clusters with 2+ members - return Array.from(clusters.values()) - .filter((ids) => ids.length >= 2) - .map((ids) => ({ + // Step 3: Group by root + const clusters = new Map(); + for (const id of memoryData.keys()) { + if (!parent.has(id)) { + continue; + } + const root = find(id); + if (!clusters.has(root)) { + clusters.set(root, []); + } + clusters.get(root)!.push(id); + } + + // Return clusters with 2+ members + return Array.from(clusters.values()) + .filter((ids) => ids.length >= 2) + .map((ids) => { + const cluster: { + memoryIds: string[]; + texts: string[]; + importances: number[]; + similarities?: Map; + } = { memoryIds: ids, texts: ids.map((id) => memoryData.get(id)!.text), importances: ids.map((id) => memoryData.get(id)!.importance), - })); - } finally { - await session.close(); - } + }; + + // Include similarities for this cluster if requested + if (pairwiseSimilarities) { + const clusterSims = new Map(); + for (let i = 0; i < ids.length - 1; i++) { + for (let j = i + 1; j < ids.length; j++) { + const pairKey = makePairKey(ids[i], ids[j]); + const score = pairwiseSimilarities.get(pairKey); + if (score !== undefined) { + clusterSims.set(pairKey, score); + } + } + } + cluster.similarities = clusterSims; + } + + return cluster; + }); } /** @@ -1122,49 +1146,53 @@ export class Neo4jMemoryClient { return this.retryOnTransient(async () => { const session = this.driver!.session(); try { - // Optimistic lock: verify all cluster members still exist before merging. - // New memories added or deleted between findDuplicateClusters() and this - // call could invalidate the cluster. Skip if any member is missing. - const verifyResult = await session.run( - `UNWIND $ids AS memId - OPTIONAL MATCH (m:Memory {id: memId}) - RETURN memId, m IS NOT NULL AS exists`, - { ids: memoryIds }, - ); - - const missingIds: string[] = []; - for (const r of verifyResult.records) { - if (!r.get("exists")) { - missingIds.push(r.get("memId") as string); - } - } - - if (missingIds.length > 0) { - this.logger.warn( - `memory-neo4j: skipping cluster merge — ${missingIds.length} member(s) no longer exist: ${missingIds.join(", ")}`, + // Execute verify + transfer + delete in a single write transaction + // to prevent TOCTOU races (member deleted between verify and merge) + const deletedCount = await session.executeWrite(async (tx) => { + // Verify all cluster members still exist + const verifyResult = await tx.run( + `UNWIND $ids AS memId + OPTIONAL MATCH (m:Memory {id: memId}) + RETURN memId, m IS NOT NULL AS exists`, + { ids: memoryIds }, ); - return { survivorId, deletedCount: 0 }; - } - // Transfer MENTIONS relationships from deleted memories to survivor - await session.run( - `UNWIND $toDelete AS deadId - MATCH (dead:Memory {id: deadId})-[r:MENTIONS]->(e:Entity) - MATCH (survivor:Memory {id: $survivorId}) - MERGE (survivor)-[:MENTIONS]->(e) - DELETE r`, - { toDelete, survivorId }, - ); + const missingIds: string[] = []; + for (const r of verifyResult.records) { + if (!r.get("exists")) { + missingIds.push(r.get("memId") as string); + } + } - // Delete the duplicate memories - await session.run( - `UNWIND $toDelete AS deadId - MATCH (m:Memory {id: deadId}) - DETACH DELETE m`, - { toDelete }, - ); + if (missingIds.length > 0) { + this.logger.warn( + `memory-neo4j: skipping cluster merge — ${missingIds.length} member(s) no longer exist: ${missingIds.join(", ")}`, + ); + return 0; + } - return { survivorId, deletedCount: toDelete.length }; + // Transfer MENTIONS relationships from deleted memories to survivor + await tx.run( + `UNWIND $toDelete AS deadId + MATCH (dead:Memory {id: deadId})-[r:MENTIONS]->(e:Entity) + MATCH (survivor:Memory {id: $survivorId}) + MERGE (survivor)-[:MENTIONS]->(e) + DELETE r`, + { toDelete, survivorId }, + ); + + // Delete the duplicate memories + await tx.run( + `UNWIND $toDelete AS deadId + MATCH (m:Memory {id: deadId}) + DETACH DELETE m`, + { toDelete }, + ); + + return toDelete.length; + }); + + return { survivorId, deletedCount }; } finally { await session.close(); } @@ -1260,19 +1288,14 @@ export class Neo4jMemoryClient { await this.ensureInitialized(); const session = this.driver!.session(); try { - // Decrement mention counts on connected entities (floor at 0 to prevent - // negative counts from parallel prune/delete operations racing on the same entity) - await session.run( - `UNWIND $ids AS memId - MATCH (m:Memory {id: memId})-[:MENTIONS]->(e:Entity) - SET e.mentionCount = CASE WHEN e.mentionCount > 0 THEN e.mentionCount - 1 ELSE 0 END`, - { ids: memoryIds }, - ); - - // Delete the memories + // Atomic: decrement mentionCount and delete in a single Cypher statement + // to prevent inconsistent state if a crash occurs between operations const result = await session.run( `UNWIND $ids AS memId MATCH (m:Memory {id: memId}) + OPTIONAL MATCH (m)-[:MENTIONS]->(e:Entity) + SET e.mentionCount = CASE WHEN e.mentionCount > 0 THEN e.mentionCount - 1 ELSE 0 END + WITH m, count(e) AS _ DETACH DELETE m RETURN count(*) AS deleted`, { ids: memoryIds }, @@ -1548,152 +1571,6 @@ export class Neo4jMemoryClient { return sorted[boundaryIndex]?.effectiveScore ?? 0; } - /** - * Find regular memories that should be promoted to core (above Pareto threshold). - * - * Pareto-based promotion: - * - Calculate effective score for all memories: importance × freq × recency - * - Find the 80th percentile threshold (top 20%) - * - Regular memories above threshold get promoted to core - * - Also requires minimum age (default: 7 days) to ensure stability - */ - async findPromotionCandidates(options: { - paretoThreshold: number; // The calculated Pareto threshold - minAgeDays?: number; // Minimum age in days (default: 7) - agentId?: string; - limit?: number; - }): Promise< - Array<{ - id: string; - text: string; - category: string; - importance: number; - ageDays: number; - retrievalCount: number; - effectiveScore: number; - }> - > { - const { paretoThreshold, minAgeDays = 7, agentId, limit = 100 } = options; - - await this.ensureInitialized(); - const session = this.driver!.session(); - try { - const agentFilter = agentId ? "AND m.agentId = $agentId" : ""; - const result = await session.run( - `MATCH (m:Memory) - WHERE m.category <> 'core' - AND m.createdAt IS NOT NULL - ${agentFilter} - WITH m, - duration.between(datetime(m.createdAt), datetime()).days AS ageDays, - coalesce(m.retrievalCount, 0) AS retrievalCount, - CASE - WHEN m.lastRetrievedAt IS NULL THEN null - ELSE duration.between(datetime(m.lastRetrievedAt), datetime()).days - END AS daysSinceRetrieval - WHERE ageDays >= $minAgeDays - WITH m, ageDays, retrievalCount, daysSinceRetrieval, - // Effective score: importance × freq_boost × recency - m.importance * (1 + log(1 + retrievalCount) * 0.3) * - CASE - WHEN daysSinceRetrieval IS NULL THEN 0.1 - ELSE 2.0 ^ (-1.0 * daysSinceRetrieval / 14.0) - END AS effectiveScore - WHERE effectiveScore >= $threshold - RETURN m.id AS id, m.text AS text, m.category AS category, - m.importance AS importance, ageDays, retrievalCount, effectiveScore - ORDER BY effectiveScore DESC - LIMIT $limit`, - { - threshold: paretoThreshold, - minAgeDays, - agentId, - limit: neo4j.int(limit), - }, - ); - - return result.records.map((r) => ({ - id: r.get("id") as string, - text: r.get("text") as string, - category: r.get("category") as string, - importance: r.get("importance") as number, - ageDays: r.get("ageDays") as number, - retrievalCount: r.get("retrievalCount") as number, - effectiveScore: r.get("effectiveScore") as number, - })); - } finally { - await session.close(); - } - } - - /** - * Find core memories that should be demoted (fallen below Pareto threshold). - * - * Core memories use the same formula for threshold comparison: - * importance × freq × recency - * - * If they fall below the top 20% threshold, they get demoted back to regular. - */ - async findDemotionCandidates(options: { - paretoThreshold: number; // The calculated Pareto threshold - agentId?: string; - limit?: number; - }): Promise< - Array<{ - id: string; - text: string; - importance: number; - retrievalCount: number; - effectiveScore: number; - }> - > { - const { paretoThreshold, agentId, limit = 100 } = options; - - await this.ensureInitialized(); - const session = this.driver!.session(); - try { - const agentFilter = agentId ? "AND m.agentId = $agentId" : ""; - const result = await session.run( - `MATCH (m:Memory) - WHERE m.category = 'core' - ${agentFilter} - WITH m, - coalesce(m.retrievalCount, 0) AS retrievalCount, - CASE - WHEN m.lastRetrievedAt IS NULL THEN null - ELSE duration.between(datetime(m.lastRetrievedAt), datetime()).days - END AS daysSinceRetrieval - WITH m, retrievalCount, daysSinceRetrieval, - // Effective score: importance × freq_boost × recency - m.importance * (1 + log(1 + retrievalCount) * 0.3) * - CASE - WHEN daysSinceRetrieval IS NULL THEN 0.1 - ELSE 2.0 ^ (-1.0 * daysSinceRetrieval / 14.0) - END AS effectiveScore - WHERE effectiveScore < $threshold - RETURN m.id AS id, m.text AS text, m.importance AS importance, - retrievalCount, effectiveScore - ORDER BY effectiveScore ASC - LIMIT $limit`, - { - threshold: paretoThreshold, - agentId, - limit: neo4j.int(limit), - }, - ); - - return result.records.map((r) => ({ - id: r.get("id") as string, - text: r.get("text") as string, - importance: r.get("importance") as number, - retrievalCount: r.get("retrievalCount") as number, - effectiveScore: r.get("effectiveScore") as number, - })); - } finally { - await session.close(); - } - } - /** * Promote memories to core status. */ diff --git a/extensions/memory-neo4j/openclaw.plugin.json b/extensions/memory-neo4j/openclaw.plugin.json index 8ee7179d78b..b0ca8eb66a3 100644 --- a/extensions/memory-neo4j/openclaw.plugin.json +++ b/extensions/memory-neo4j/openclaw.plugin.json @@ -44,6 +44,22 @@ "label": "Auto-Recall", "help": "Automatically inject relevant memories into context" }, + "autoRecallMinScore": { + "label": "Auto-Recall Min Score", + "help": "Minimum similarity score (0-1) for auto-recall results (default: 0.25)" + }, + "coreMemory.enabled": { + "label": "Core Memory", + "help": "Enable core memory bootstrap (top memories auto-loaded into context)" + }, + "coreMemory.maxEntries": { + "label": "Core Memory Max Entries", + "help": "Maximum number of core memories to load per session (default: 50)" + }, + "coreMemory.refreshAtContextPercent": { + "label": "Core Memory Refresh %", + "help": "Re-inject core memories when context usage reaches this percentage (1-100, optional)" + }, "extraction.apiKey": { "label": "Extraction API Key", "sensitive": true, @@ -109,6 +125,29 @@ "autoRecall": { "type": "boolean" }, + "autoRecallMinScore": { + "type": "number", + "minimum": 0, + "maximum": 1 + }, + "coreMemory": { + "type": "object", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean" + }, + "maxEntries": { + "type": "number", + "minimum": 1 + }, + "refreshAtContextPercent": { + "type": "number", + "minimum": 1, + "maximum": 100 + } + } + }, "extraction": { "type": "object", "additionalProperties": false,