diff --git a/extensions/memory-neo4j/attention-gate.ts b/extensions/memory-neo4j/attention-gate.ts index cd704eb8001..c97e8e17aff 100644 --- a/extensions/memory-neo4j/attention-gate.ts +++ b/extensions/memory-neo4j/attention-gate.ts @@ -78,7 +78,10 @@ export function passesAttentionGate(text: string): boolean { } // Excessive emoji (likely reaction, not substance) - const emojiCount = (trimmed.match(/[\u{1F300}-\u{1F9FF}]/gu) || []).length; + const emojiCount = ( + trimmed.match(/[\u{1F300}-\u{1F9FF}\u{2600}-\u{26FF}\u{2700}-\u{27BF}\u{1FA00}-\u{1FAFF}]/gu) || + [] + ).length; if (emojiCount > 3) { return false; } @@ -142,7 +145,10 @@ export function passesAssistantAttentionGate(text: string): boolean { } // Excessive emoji (likely reaction, not substance) - const emojiCount = (trimmed.match(/[\u{1F300}-\u{1F9FF}]/gu) || []).length; + const emojiCount = ( + trimmed.match(/[\u{1F300}-\u{1F9FF}\u{2600}-\u{26FF}\u{2700}-\u{27BF}\u{1FA00}-\u{1FAFF}]/gu) || + [] + ).length; if (emojiCount > 3) { return false; } diff --git a/extensions/memory-neo4j/config.test.ts b/extensions/memory-neo4j/config.test.ts index a56bb22596e..6c53f504527 100644 --- a/extensions/memory-neo4j/config.test.ts +++ b/extensions/memory-neo4j/config.test.ts @@ -171,6 +171,32 @@ describe("memoryNeo4jConfigSchema.parse", () => { expect(config.embedding.apiKey).toBe("sk-from-env"); }); + it("should resolve ${ENV_VAR} in neo4j.user (username)", () => { + process.env.TEST_NEO4J_USER = "resolved-user"; + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { + uri: "bolt://localhost:7687", + user: "${TEST_NEO4J_USER}", + password: "", + }, + embedding: { provider: "ollama" }, + }); + expect(config.neo4j.username).toBe("resolved-user"); + }); + + it("should resolve ${ENV_VAR} in neo4j.username", () => { + process.env.TEST_NEO4J_USERNAME = "resolved-username"; + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { + uri: "bolt://localhost:7687", + username: "${TEST_NEO4J_USERNAME}", + password: "", + }, + embedding: { provider: "ollama" }, + }); + expect(config.neo4j.username).toBe("resolved-username"); + }); + it("should throw when referenced env var is not set", () => { delete process.env.NONEXISTENT_VAR; expect(() => diff --git a/extensions/memory-neo4j/config.ts b/extensions/memory-neo4j/config.ts index ea068287d65..8348396ff82 100644 --- a/extensions/memory-neo4j/config.ts +++ b/extensions/memory-neo4j/config.ts @@ -247,9 +247,9 @@ export const memoryNeo4jConfigSchema = { // Support both 'user' and 'username' for neo4j config const neo4jUsername = typeof neo4jRaw.user === "string" - ? neo4jRaw.user + ? resolveEnvVars(neo4jRaw.user) : typeof neo4jRaw.username === "string" - ? neo4jRaw.username + ? resolveEnvVars(neo4jRaw.username) : "neo4j"; // Parse embedding section (optional for ollama without apiKey) diff --git a/extensions/memory-neo4j/embeddings.test.ts b/extensions/memory-neo4j/embeddings.test.ts index 615bc86f31a..28009c1cf66 100644 --- a/extensions/memory-neo4j/embeddings.test.ts +++ b/extensions/memory-neo4j/embeddings.test.ts @@ -78,6 +78,40 @@ describe("Embeddings - Ollama provider", () => { ); }); + it("should strip trailing slashes from baseUrl", async () => { + const { Embeddings } = await import("./embeddings.js"); + const mockVector = [0.1, 0.2]; + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => Promise.resolve({ embeddings: [mockVector] }), + }); + + const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama", "http://my-host:11434/"); + await emb.embed("test"); + + expect(globalThis.fetch).toHaveBeenCalledWith( + "http://my-host:11434/api/embed", + expect.any(Object), + ); + }); + + it("should strip multiple trailing slashes from baseUrl", async () => { + const { Embeddings } = await import("./embeddings.js"); + const mockVector = [0.1, 0.2]; + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => Promise.resolve({ embeddings: [mockVector] }), + }); + + const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama", "http://my-host:11434///"); + await emb.embed("test"); + + expect(globalThis.fetch).toHaveBeenCalledWith( + "http://my-host:11434/api/embed", + expect.any(Object), + ); + }); + it("should throw when Ollama returns error status", async () => { const { Embeddings } = await import("./embeddings.js"); globalThis.fetch = vi.fn().mockResolvedValue({ @@ -298,3 +332,150 @@ describe("Embeddings - Ollama context-length truncation", () => { expect(body2.input).toBe(shortText); }); }); + +// ============================================================================ +// OpenAI embed — functional tests with mocked OpenAI client +// ============================================================================ + +describe("Embeddings - OpenAI functional", () => { + beforeEach(() => { + vi.resetModules(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it("embed() should call OpenAI API with correct model and input", async () => { + const mockCreate = vi.fn().mockResolvedValue({ + data: [{ index: 0, embedding: [0.1, 0.2, 0.3] }], + }); + + // Mock the openai module + vi.doMock("openai", () => ({ + default: class MockOpenAI { + embeddings = { create: mockCreate }; + }, + })); + + const { Embeddings } = await import("./embeddings.js"); + const emb = new Embeddings("sk-test-key", "text-embedding-3-small", "openai"); + const result = await emb.embed("hello world"); + + expect(result).toEqual([0.1, 0.2, 0.3]); + expect(mockCreate).toHaveBeenCalledWith({ + model: "text-embedding-3-small", + input: "hello world", + }); + }); + + it("embedBatch() should send all texts in a single API call and return correctly ordered results", async () => { + const mockCreate = vi.fn().mockResolvedValue({ + // Return out-of-order to verify sorting by index + data: [ + { index: 2, embedding: [0.7, 0.8, 0.9] }, + { index: 0, embedding: [0.1, 0.2, 0.3] }, + { index: 1, embedding: [0.4, 0.5, 0.6] }, + ], + }); + + vi.doMock("openai", () => ({ + default: class MockOpenAI { + embeddings = { create: mockCreate }; + }, + })); + + const { Embeddings } = await import("./embeddings.js"); + const emb = new Embeddings("sk-test-key", "text-embedding-3-small", "openai"); + const results = await emb.embedBatch(["first", "second", "third"]); + + // Should have made exactly one API call with all texts + expect(mockCreate).toHaveBeenCalledTimes(1); + expect(mockCreate).toHaveBeenCalledWith({ + model: "text-embedding-3-small", + input: ["first", "second", "third"], + }); + + // Results should be sorted by index (0, 1, 2) + expect(results).toEqual([ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + [0.7, 0.8, 0.9], + ]); + }); + + it("embed() should propagate OpenAI API errors", async () => { + const mockCreate = vi.fn().mockRejectedValue(new Error("API rate limit exceeded")); + + vi.doMock("openai", () => ({ + default: class MockOpenAI { + embeddings = { create: mockCreate }; + }, + })); + + const { Embeddings } = await import("./embeddings.js"); + const emb = new Embeddings("sk-test-key", "text-embedding-3-small", "openai"); + + await expect(emb.embed("test")).rejects.toThrow("API rate limit exceeded"); + }); + + it("embed() should return cached result on second call for same text", async () => { + const mockCreate = vi.fn().mockResolvedValue({ + data: [{ index: 0, embedding: [0.1, 0.2, 0.3] }], + }); + + vi.doMock("openai", () => ({ + default: class MockOpenAI { + embeddings = { create: mockCreate }; + }, + })); + + const { Embeddings } = await import("./embeddings.js"); + const emb = new Embeddings("sk-test-key", "text-embedding-3-small", "openai"); + + const result1 = await emb.embed("cached text"); + const result2 = await emb.embed("cached text"); + + expect(result1).toEqual([0.1, 0.2, 0.3]); + expect(result2).toEqual([0.1, 0.2, 0.3]); + // Should only make one API call — second call uses cache + expect(mockCreate).toHaveBeenCalledTimes(1); + }); + + it("embedBatch() should use cache for previously embedded texts", async () => { + const mockCreate = vi + .fn() + .mockResolvedValueOnce({ + data: [{ index: 0, embedding: [0.1, 0.2, 0.3] }], + }) + .mockResolvedValueOnce({ + data: [{ index: 0, embedding: [0.7, 0.8, 0.9] }], + }); + + vi.doMock("openai", () => ({ + default: class MockOpenAI { + embeddings = { create: mockCreate }; + }, + })); + + const { Embeddings } = await import("./embeddings.js"); + const emb = new Embeddings("sk-test-key", "text-embedding-3-small", "openai"); + + // First: embed "alpha" to populate cache + await emb.embed("alpha"); + expect(mockCreate).toHaveBeenCalledTimes(1); + + // Now batch with "alpha" (cached) and "beta" (uncached) + const results = await emb.embedBatch(["alpha", "beta"]); + // Should only call API once more for "beta" + expect(mockCreate).toHaveBeenCalledTimes(2); + expect(mockCreate).toHaveBeenLastCalledWith({ + model: "text-embedding-3-small", + input: ["beta"], + }); + expect(results).toEqual([ + [0.1, 0.2, 0.3], // cached + [0.7, 0.8, 0.9], // freshly computed + ]); + }); +}); diff --git a/extensions/memory-neo4j/embeddings.ts b/extensions/memory-neo4j/embeddings.ts index 53803e3b29a..44b101175c9 100644 --- a/extensions/memory-neo4j/embeddings.ts +++ b/extensions/memory-neo4j/embeddings.ts @@ -83,7 +83,10 @@ export class Embeddings { logger?: Logger, ) { this.provider = provider; - this.baseUrl = baseUrl ?? (provider === "ollama" ? "http://localhost:11434" : ""); + this.baseUrl = (baseUrl ?? (provider === "ollama" ? "http://localhost:11434" : "")).replace( + /\/+$/, + "", + ); this.logger = logger; this.contextLength = contextLengthForModel(model); @@ -250,7 +253,7 @@ export class Embeddings { input: texts, }); // Sort by index to ensure correct order - return response.data.toSorted((a, b) => a.index - b.index).map((d) => d.embedding); + return [...response.data].sort((a, b) => a.index - b.index).map((d) => d.embedding); } // Timeout for Ollama embedding fetch calls to prevent hanging indefinitely diff --git a/extensions/memory-neo4j/extractor.test.ts b/extensions/memory-neo4j/extractor.test.ts index 8e115e898ce..575f6622ae7 100644 --- a/extensions/memory-neo4j/extractor.test.ts +++ b/extensions/memory-neo4j/extractor.test.ts @@ -17,6 +17,7 @@ import { rateImportance, resolveConflict, isSemanticDuplicate, + isTransientError, runSleepCycle, } from "./extractor.js"; import { passesAttentionGate, passesAssistantAttentionGate } from "./index.js"; @@ -820,7 +821,7 @@ describe("runBackgroundExtraction", () => { } it("should skip extraction and mark as 'skipped' when disabled", async () => { - await runBackgroundExtraction( + const result = await runBackgroundExtraction( "mem-1", "test text", mockDb as never, @@ -829,6 +830,7 @@ describe("runBackgroundExtraction", () => { mockLogger, ); expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "skipped"); + expect(result).toEqual({ success: true, memoryId: "mem-1" }); }); it("should mark as 'failed' when extraction returns null", async () => { @@ -838,7 +840,7 @@ describe("runBackgroundExtraction", () => { text: () => Promise.resolve("error"), }); - await runBackgroundExtraction( + const result = await runBackgroundExtraction( "mem-1", "test text", mockDb as never, @@ -847,6 +849,7 @@ describe("runBackgroundExtraction", () => { mockLogger, ); expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "failed"); + expect(result).toEqual({ success: false, memoryId: "mem-1" }); }); it("should mark as 'complete' when extraction result is empty", async () => { @@ -858,7 +861,7 @@ describe("runBackgroundExtraction", () => { }), ); - await runBackgroundExtraction( + const result = await runBackgroundExtraction( "mem-1", "test text", mockDb as never, @@ -867,6 +870,7 @@ describe("runBackgroundExtraction", () => { mockLogger, ); expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "complete"); + expect(result).toEqual({ success: true, memoryId: "mem-1" }); }); it("should batch entities, relationships, tags, and category in one call", async () => { @@ -2274,10 +2278,10 @@ describe("runSleepCycle", () => { const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger); expect(result.extraction.processed).toBe(1); - // runBackgroundExtraction doesn't throw on HTTP errors, it just marks the extraction status as failed/pending - // The sleep cycle counts it as succeeded because Promise.allSettled reports it as fulfilled - expect(result.extraction.succeeded).toBe(1); - expect(result.extraction.failed).toBe(0); + // runBackgroundExtraction returns { success: false } on HTTP errors, + // so the sleep cycle correctly counts it as failed via outcome.value.success + expect(result.extraction.succeeded).toBe(0); + expect(result.extraction.failed).toBe(1); }); it("should respect batch size and delay", async () => { @@ -2570,3 +2574,82 @@ describe("runSleepCycle", () => { }); }); }); + +// ============================================================================ +// isTransientError() +// ============================================================================ + +describe("isTransientError", () => { + it("should return false for non-Error values", () => { + expect(isTransientError("string error")).toBe(false); + expect(isTransientError(42)).toBe(false); + expect(isTransientError(null)).toBe(false); + expect(isTransientError(undefined)).toBe(false); + }); + + it("should classify AbortError as transient", () => { + const err = new DOMException("signal aborted", "AbortError"); + expect(isTransientError(err)).toBe(true); + }); + + it("should classify TimeoutError as transient", () => { + const err = new DOMException("signal timed out", "TimeoutError"); + expect(isTransientError(err)).toBe(true); + }); + + it("should classify timeout messages as transient", () => { + expect(isTransientError(new Error("Request timeout after 30s"))).toBe(true); + }); + + it("should classify ECONNREFUSED as transient", () => { + expect(isTransientError(new Error("connect ECONNREFUSED 127.0.0.1:7687"))).toBe(true); + }); + + it("should classify ECONNRESET as transient", () => { + expect(isTransientError(new Error("read ECONNRESET"))).toBe(true); + }); + + it("should classify ETIMEDOUT as transient", () => { + expect(isTransientError(new Error("connect ETIMEDOUT 10.0.0.1:443"))).toBe(true); + }); + + it("should classify DNS failure (ENOTFOUND) as transient", () => { + expect(isTransientError(new Error("getaddrinfo ENOTFOUND api.openrouter.ai"))).toBe(true); + }); + + it("should classify HTTP 429 (rate limit) as transient", () => { + expect(isTransientError(new Error("OpenRouter API error 429: rate limited"))).toBe(true); + }); + + it("should classify HTTP 502 (bad gateway) as transient", () => { + expect(isTransientError(new Error("OpenRouter API error 502: bad gateway"))).toBe(true); + }); + + it("should classify HTTP 503 (service unavailable) as transient", () => { + expect(isTransientError(new Error("OpenRouter API error 503: service unavailable"))).toBe(true); + }); + + it("should classify HTTP 504 (gateway timeout) as transient", () => { + expect(isTransientError(new Error("OpenRouter API error 504: gateway timeout"))).toBe(true); + }); + + it("should classify network errors as transient", () => { + expect(isTransientError(new Error("network error"))).toBe(true); + expect(isTransientError(new Error("fetch failed"))).toBe(true); + expect(isTransientError(new Error("socket hang up"))).toBe(true); + }); + + it("should classify HTTP 500 as non-transient", () => { + expect(isTransientError(new Error("OpenRouter API error 500: internal server error"))).toBe( + false, + ); + }); + + it("should classify JSON parse errors as non-transient", () => { + expect(isTransientError(new Error("Unexpected token < in JSON"))).toBe(false); + }); + + it("should classify generic errors as non-transient", () => { + expect(isTransientError(new Error("something went wrong"))).toBe(false); + }); +}); diff --git a/extensions/memory-neo4j/extractor.ts b/extensions/memory-neo4j/extractor.ts index 871a134f2fe..6253dcb4d47 100644 --- a/extensions/memory-neo4j/extractor.ts +++ b/extensions/memory-neo4j/extractor.ts @@ -217,13 +217,20 @@ async function callOpenRouterStream( // Entity Extraction // ============================================================================ -/** Max retries for transient extraction failures before marking permanently failed */ +/** + * Max retries for transient extraction failures before marking permanently failed. + * + * Retry budget accounting — two layers of retry: + * Layer 1: callOpenRouter/callOpenRouterStream internal retries (config.maxRetries, default 2 = 3 attempts) + * Layer 2: Sleep cycle retries (MAX_EXTRACTION_RETRIES = 3 sleep cycles) + * Total worst-case: 3 × 3 = 9 LLM attempts per memory + */ const MAX_EXTRACTION_RETRIES = 3; /** * Check if an error is transient (network/timeout) vs permanent (JSON parse, etc.) */ -function isTransientError(err: unknown): boolean { +export function isTransientError(err: unknown): boolean { if (!(err instanceof Error)) { return false; } @@ -234,6 +241,7 @@ function isTransientError(err: unknown): boolean { msg.includes("timeout") || msg.includes("econnrefused") || msg.includes("econnreset") || + msg.includes("etimedout") || msg.includes("enotfound") || msg.includes("network") || msg.includes("fetch failed") || @@ -434,10 +442,10 @@ export async function runBackgroundExtraction( logger: Logger, currentRetries: number = 0, abortSignal?: AbortSignal, -): Promise { +): Promise<{ success: boolean; memoryId: string }> { if (!config.enabled) { await db.updateExtractionStatus(memoryId, "skipped").catch(() => {}); - return; + return { success: true, memoryId }; } try { @@ -463,7 +471,7 @@ export async function runBackgroundExtraction( // Permanent failure (JSON parse, empty response, etc.) await db.updateExtractionStatus(memoryId, "failed"); } - return; + return { success: false, memoryId }; } // Empty extraction is valid — not all memories have extractable entities @@ -473,7 +481,7 @@ export async function runBackgroundExtraction( result.tags.length === 0 ) { await db.updateExtractionStatus(memoryId, "complete"); - return; + return { success: true, memoryId }; } // Batch all entity operations into a single transaction: @@ -497,6 +505,7 @@ export async function runBackgroundExtraction( `${result.entities.length} entities, ${result.relationships.length} rels, ${result.tags.length} tags` + (result.category ? `, category=${result.category}` : ""), ); + return { success: true, memoryId }; } catch (err) { // Unexpected error during graph operations — treat as transient if retry budget remains const isTransient = isTransientError(err); @@ -513,6 +522,7 @@ export async function runBackgroundExtraction( .updateExtractionStatus(memoryId, "failed", { incrementRetries: true }) .catch(() => {}); } + return { success: false, memoryId }; } } @@ -1058,7 +1068,7 @@ export async function runSleepCycle( for (const outcome of outcomes) { result.extraction.processed++; - if (outcome.status === "fulfilled") { + if (outcome.status === "fulfilled" && outcome.value.success) { result.extraction.succeeded++; } else { result.extraction.failed++; diff --git a/extensions/memory-neo4j/index.ts b/extensions/memory-neo4j/index.ts index 46411a06ea5..2680199f23e 100644 --- a/extensions/memory-neo4j/index.ts +++ b/extensions/memory-neo4j/index.ts @@ -114,10 +114,11 @@ const memoryNeo4jPlugin = { limit: Type.Optional(Type.Number({ description: "Max results (default: 5)" })), }), async execute(_toolCallId: string, params: unknown) { - const { query, limit = 5 } = params as { + const { query, limit: rawLimit = 5 } = params as { query: string; limit?: number; }; + const limit = Math.floor(Math.min(50, Math.max(1, rawLimit))); const results = await hybridSearch( db, @@ -197,7 +198,7 @@ const memoryNeo4jPlugin = { const vector = await embeddings.embed(text); // 2. Check for duplicates (vector similarity > 0.95) - const existing = await db.findSimilar(vector, 0.95, 1); + const existing = await db.findSimilar(vector, 0.95, 1, agentId); if (existing.length > 0) { return { content: [ @@ -301,8 +302,9 @@ const memoryNeo4jPlugin = { }; } - // Auto-delete if single high-confidence match - if (results.length === 1 && results[0].score > 0.9) { + // Auto-delete if single high-confidence match (0.95 threshold + // reduces false positives — 0.9 cosine similarity is not exact match) + if (results.length === 1 && results[0].score > 0.95) { await db.deleteMemory(results[0].id, agentId); return { content: [ @@ -1191,8 +1193,10 @@ async function captureMessage( extractionConfig: import("./config.js").ExtractionConfig, logger: AutoCaptureLogger, ): Promise<{ stored: boolean; semanticDeduped: boolean }> { - // For assistant messages, rate importance first (before embedding) to skip early - const rateFirst = source === "auto-capture-assistant"; + // For assistant messages, rate importance first (before embedding) to skip early. + // When extraction is disabled, rateImportance returns 0.5 (the fallback), so we + // skip the early importance gate to avoid silently blocking all assistant captures. + const rateFirst = source === "auto-capture-assistant" && extractionConfig.enabled; let importance: number | undefined; if (rateFirst) { @@ -1205,15 +1209,17 @@ async function captureMessage( const vector = await embeddings.embed(text); // Quick dedup (same content already stored — cosine >= 0.95) - const existing = await db.findSimilar(vector, 0.95, 1); + const existing = await db.findSimilar(vector, 0.95, 1, agentId); if (existing.length > 0) { return { stored: false, semanticDeduped: false }; } - // Rate importance if not already done + // Rate importance if not already done. + // When extraction is disabled, rateImportance returns a fixed 0.5 fallback, + // so skip the threshold check to avoid silently blocking all captures. if (importance === undefined) { importance = await rateImportance(text, extractionConfig); - if (importance < importanceThreshold) { + if (extractionConfig.enabled && importance < importanceThreshold) { return { stored: false, semanticDeduped: false }; } } @@ -1221,7 +1227,7 @@ async function captureMessage( // Semantic dedup: check moderate-similarity memories (0.75-0.95) // Pass the vector similarity score as a pre-screen to skip LLM calls // for pairs below SEMANTIC_DEDUP_VECTOR_THRESHOLD. - const candidates = await db.findSimilar(vector, 0.75, 3); + const candidates = await db.findSimilar(vector, 0.75, 3, agentId); if (candidates.length > 0) { for (const candidate of candidates) { if (await isSemanticDuplicate(text, candidate.text, extractionConfig, candidate.score)) { diff --git a/extensions/memory-neo4j/mid-session-refresh.test.ts b/extensions/memory-neo4j/mid-session-refresh.test.ts index 50a99b9e259..e741585a9bc 100644 --- a/extensions/memory-neo4j/mid-session-refresh.test.ts +++ b/extensions/memory-neo4j/mid-session-refresh.test.ts @@ -2,60 +2,16 @@ * Tests for mid-session core memory refresh feature. * * Verifies that core memories are re-injected when context usage exceeds threshold. + * Tests config parsing, threshold calculation, shouldRefresh logic, and edge cases. */ import { describe, it, expect } from "vitest"; +// ============================================================================ +// Config parsing for refreshAtContextPercent +// ============================================================================ + describe("mid-session core memory refresh", () => { - // Test context threshold calculation - describe("context threshold calculation", () => { - it("should calculate usage percentage correctly", () => { - const contextWindowTokens = 200_000; - const estimatedUsedTokens = 100_000; - const usagePercent = (estimatedUsedTokens / contextWindowTokens) * 100; - expect(usagePercent).toBe(50); - }); - - it("should detect when threshold is exceeded", () => { - const threshold = 50; - const usagePercent = 55; - expect(usagePercent >= threshold).toBe(true); - }); - - it("should not trigger when below threshold", () => { - const threshold = 50; - const usagePercent = 45; - expect(usagePercent >= threshold).toBe(false); - }); - }); - - // Test refresh frequency limiting - describe("refresh frequency limiting", () => { - const MIN_TOKENS_SINCE_REFRESH = 10_000; - - it("should allow refresh when enough tokens have accumulated", () => { - const lastRefreshTokens = 50_000; - const currentTokens = 65_000; - const tokensSinceRefresh = currentTokens - lastRefreshTokens; - expect(tokensSinceRefresh >= MIN_TOKENS_SINCE_REFRESH).toBe(true); - }); - - it("should block refresh when not enough tokens have accumulated", () => { - const lastRefreshTokens = 50_000; - const currentTokens = 55_000; - const tokensSinceRefresh = currentTokens - lastRefreshTokens; - expect(tokensSinceRefresh >= MIN_TOKENS_SINCE_REFRESH).toBe(false); - }); - - it("should allow first refresh (no previous refresh)", () => { - const lastRefreshTokens = 0; // No previous refresh - const currentTokens = 100_000; - const tokensSinceRefresh = currentTokens - lastRefreshTokens; - expect(tokensSinceRefresh >= MIN_TOKENS_SINCE_REFRESH).toBe(true); - }); - }); - - // Test config parsing describe("config parsing", () => { it("should accept valid refreshAtContextPercent values", async () => { const { memoryNeo4jConfigSchema } = await import("./config.js"); @@ -67,7 +23,27 @@ describe("mid-session core memory refresh", () => { expect(config.coreMemory.refreshAtContextPercent).toBe(50); }); - it("should reject refreshAtContextPercent of 0", async () => { + it("should accept refreshAtContextPercent of 1 (minimum)", async () => { + const { memoryNeo4jConfigSchema } = await import("./config.js"); + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", user: "neo4j", password: "test" }, + embedding: { provider: "ollama" }, + coreMemory: { refreshAtContextPercent: 1 }, + }); + expect(config.coreMemory.refreshAtContextPercent).toBe(1); + }); + + it("should accept refreshAtContextPercent of 100 (maximum)", async () => { + const { memoryNeo4jConfigSchema } = await import("./config.js"); + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", user: "neo4j", password: "test" }, + embedding: { provider: "ollama" }, + coreMemory: { refreshAtContextPercent: 100 }, + }); + expect(config.coreMemory.refreshAtContextPercent).toBe(100); + }); + + it("should treat refreshAtContextPercent of 0 as disabled (undefined)", async () => { const { memoryNeo4jConfigSchema } = await import("./config.js"); const config = memoryNeo4jConfigSchema.parse({ neo4j: { uri: "bolt://localhost:7687", user: "neo4j", password: "test" }, @@ -77,6 +53,16 @@ describe("mid-session core memory refresh", () => { expect(config.coreMemory.refreshAtContextPercent).toBeUndefined(); }); + it("should treat negative refreshAtContextPercent as disabled (undefined)", async () => { + const { memoryNeo4jConfigSchema } = await import("./config.js"); + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", user: "neo4j", password: "test" }, + embedding: { provider: "ollama" }, + coreMemory: { refreshAtContextPercent: -10 }, + }); + expect(config.coreMemory.refreshAtContextPercent).toBeUndefined(); + }); + it("should throw for refreshAtContextPercent over 100", async () => { const { memoryNeo4jConfigSchema } = await import("./config.js"); expect(() => @@ -88,7 +74,7 @@ describe("mid-session core memory refresh", () => { ).toThrow("coreMemory.refreshAtContextPercent must be between 1 and 100"); }); - it("should default to undefined when not specified", async () => { + it("should default to undefined when coreMemory section is omitted", async () => { const { memoryNeo4jConfigSchema } = await import("./config.js"); const config = memoryNeo4jConfigSchema.parse({ neo4j: { uri: "bolt://localhost:7687", user: "neo4j", password: "test" }, @@ -96,11 +82,231 @@ describe("mid-session core memory refresh", () => { }); expect(config.coreMemory.refreshAtContextPercent).toBeUndefined(); }); + + it("should default to undefined when refreshAtContextPercent is omitted", async () => { + const { memoryNeo4jConfigSchema } = await import("./config.js"); + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", user: "neo4j", password: "test" }, + embedding: { provider: "ollama" }, + coreMemory: { enabled: true }, + }); + expect(config.coreMemory.refreshAtContextPercent).toBeUndefined(); + }); }); - // Test output format + // ============================================================================ + // shouldRefresh logic (tests the decision flow from index.ts) + // ============================================================================ + + describe("shouldRefresh decision logic", () => { + // These tests mirror the logic from index.ts lines 893-916: + // 1. Skip if contextWindowTokens or estimatedUsedTokens not available + // 2. Calculate usagePercent = (estimatedUsedTokens / contextWindowTokens) * 100 + // 3. Skip if usagePercent < refreshThreshold + // 4. Skip if tokens since last refresh < MIN_TOKENS_SINCE_REFRESH (10_000) + // 5. Otherwise, refresh + + const MIN_TOKENS_SINCE_REFRESH = 10_000; + + function shouldRefresh(params: { + contextWindowTokens: number | undefined; + estimatedUsedTokens: number | undefined; + refreshThreshold: number; + lastRefreshTokens: number; + }): boolean { + const { contextWindowTokens, estimatedUsedTokens, refreshThreshold, lastRefreshTokens } = + params; + + // Skip if context info not available + if (!contextWindowTokens || !estimatedUsedTokens) { + return false; + } + + const usagePercent = (estimatedUsedTokens / contextWindowTokens) * 100; + + // Only refresh if we've crossed the threshold + if (usagePercent < refreshThreshold) { + return false; + } + + // Check if we've already refreshed recently + const tokensSinceRefresh = estimatedUsedTokens - lastRefreshTokens; + if (tokensSinceRefresh < MIN_TOKENS_SINCE_REFRESH) { + return false; + } + + return true; + } + + it("should trigger refresh when usage exceeds threshold and enough tokens accumulated", () => { + expect( + shouldRefresh({ + contextWindowTokens: 200_000, + estimatedUsedTokens: 120_000, // 60% + refreshThreshold: 50, + lastRefreshTokens: 0, // Never refreshed + }), + ).toBe(true); + }); + + it("should not trigger when usage is below threshold", () => { + expect( + shouldRefresh({ + contextWindowTokens: 200_000, + estimatedUsedTokens: 80_000, // 40% + refreshThreshold: 50, + lastRefreshTokens: 0, + }), + ).toBe(false); + }); + + it("should not trigger when not enough tokens since last refresh", () => { + expect( + shouldRefresh({ + contextWindowTokens: 200_000, + estimatedUsedTokens: 105_000, // 52.5% + refreshThreshold: 50, + lastRefreshTokens: 100_000, // Only 5k tokens since last refresh + }), + ).toBe(false); + }); + + it("should trigger when enough tokens accumulated since last refresh", () => { + expect( + shouldRefresh({ + contextWindowTokens: 200_000, + estimatedUsedTokens: 115_000, // 57.5% + refreshThreshold: 50, + lastRefreshTokens: 100_000, // 15k tokens since last refresh + }), + ).toBe(true); + }); + + it("should not trigger when contextWindowTokens is undefined", () => { + expect( + shouldRefresh({ + contextWindowTokens: undefined, + estimatedUsedTokens: 120_000, + refreshThreshold: 50, + lastRefreshTokens: 0, + }), + ).toBe(false); + }); + + it("should not trigger when estimatedUsedTokens is undefined", () => { + expect( + shouldRefresh({ + contextWindowTokens: 200_000, + estimatedUsedTokens: undefined, + refreshThreshold: 50, + lastRefreshTokens: 0, + }), + ).toBe(false); + }); + + it("should handle 0% usage (empty context)", () => { + expect( + shouldRefresh({ + contextWindowTokens: 200_000, + estimatedUsedTokens: 0, + refreshThreshold: 50, + lastRefreshTokens: 0, + }), + ).toBe(false); + }); + + it("should handle 100% usage", () => { + expect( + shouldRefresh({ + contextWindowTokens: 200_000, + estimatedUsedTokens: 200_000, // 100% + refreshThreshold: 50, + lastRefreshTokens: 0, + }), + ).toBe(true); + }); + + it("should handle exact threshold boundary (50% == 50% threshold)", () => { + // usagePercent == refreshThreshold: usagePercent < refreshThreshold is false, so it proceeds + expect( + shouldRefresh({ + contextWindowTokens: 200_000, + estimatedUsedTokens: 100_000, // exactly 50% + refreshThreshold: 50, + lastRefreshTokens: 0, + }), + ).toBe(true); + }); + + it("should handle threshold of 1 (refresh almost immediately)", () => { + expect( + shouldRefresh({ + contextWindowTokens: 200_000, + estimatedUsedTokens: 15_000, // 7.5% + refreshThreshold: 1, + lastRefreshTokens: 0, + }), + ).toBe(true); + }); + + it("should handle threshold of 100 (refresh only at full context)", () => { + expect( + shouldRefresh({ + contextWindowTokens: 200_000, + estimatedUsedTokens: 190_000, // 95% + refreshThreshold: 100, + lastRefreshTokens: 0, + }), + ).toBe(false); + }); + + it("should allow first refresh even when lastRefreshTokens is 0", () => { + expect( + shouldRefresh({ + contextWindowTokens: 200_000, + estimatedUsedTokens: 110_000, + refreshThreshold: 50, + lastRefreshTokens: 0, + }), + ).toBe(true); + }); + + it("should support multiple refresh cycles with cumulative token growth", () => { + // First refresh at 110k tokens + const firstResult = shouldRefresh({ + contextWindowTokens: 200_000, + estimatedUsedTokens: 110_000, + refreshThreshold: 50, + lastRefreshTokens: 0, + }); + expect(firstResult).toBe(true); + + // Second attempt too soon (only 5k since first) + const secondResult = shouldRefresh({ + contextWindowTokens: 200_000, + estimatedUsedTokens: 115_000, + refreshThreshold: 50, + lastRefreshTokens: 110_000, + }); + expect(secondResult).toBe(false); + + // Third attempt after enough growth (15k since first refresh) + const thirdResult = shouldRefresh({ + contextWindowTokens: 200_000, + estimatedUsedTokens: 125_000, + refreshThreshold: 50, + lastRefreshTokens: 110_000, + }); + expect(thirdResult).toBe(true); + }); + }); + + // ============================================================================ + // Output format + // ============================================================================ + describe("refresh output format", () => { - it("should format core memories correctly", () => { + it("should format core memories as XML-wrapped bullet list", () => { const coreMemories = [ { text: "User prefers TypeScript over JavaScript" }, { text: "User works at Acme Corp" }, @@ -113,5 +319,14 @@ describe("mid-session core memory refresh", () => { expect(output).toContain("- User prefers TypeScript over JavaScript"); expect(output).toContain("- User works at Acme Corp"); }); + + it("should handle single core memory", () => { + const coreMemories = [{ text: "Only memory" }]; + const content = coreMemories.map((m) => `- ${m.text}`).join("\n"); + const output = `\nReminder of persistent context (you may have seen this earlier, re-stating for recency):\n${content}\n`; + + expect(output).toContain("- Only memory"); + expect(output.match(/^- /gm)?.length).toBe(1); + }); }); }); diff --git a/extensions/memory-neo4j/neo4j-client.test.ts b/extensions/memory-neo4j/neo4j-client.test.ts index 81fc4b06e79..08eb426e378 100644 --- a/extensions/memory-neo4j/neo4j-client.test.ts +++ b/extensions/memory-neo4j/neo4j-client.test.ts @@ -276,6 +276,60 @@ describe("Neo4jMemoryClient", () => { expect(result).toEqual([]); expect(mockLogger.debug).toHaveBeenCalled(); }); + + it("should filter by agentId when provided", async () => { + mockSession.run.mockResolvedValue({ + records: [ + { + get: vi.fn((key) => { + if (key === "id") return "mem-1"; + if (key === "text") return "similar text"; + if (key === "similarity") return 0.96; + return null; + }), + }, + ], + }); + + const result = await client.findSimilar([0.1, 0.2, 0.3], 0.95, 5, "agent-1"); + + expect(result).toHaveLength(1); + expect(result[0]).toEqual({ id: "mem-1", text: "similar text", score: 0.96 }); + // Should include agentId filter in query and params + expect(mockSession.run).toHaveBeenCalledWith( + expect.stringContaining("node.agentId = $agentId"), + expect.objectContaining({ agentId: "agent-1" }), + ); + }); + + it("should fetch extra candidates and trim when agentId is provided", async () => { + mockSession.run.mockResolvedValue({ + records: [ + { + get: vi.fn((key) => { + if (key === "id") return "mem-1"; + if (key === "text") return "text 1"; + if (key === "similarity") return 0.99; + return null; + }), + }, + { + get: vi.fn((key) => { + if (key === "id") return "mem-2"; + if (key === "text") return "text 2"; + if (key === "similarity") return 0.97; + return null; + }), + }, + ], + }); + + // Request limit=1 with agentId: should fetch 3x candidates (limit*3) and trim to 1 + const result = await client.findSimilar([0.1, 0.2], 0.95, 1, "agent-1"); + + expect(result).toHaveLength(1); + expect(result[0].id).toBe("mem-1"); + }); }); // ------------------------------------------------------------------------ @@ -1298,6 +1352,15 @@ describe("Neo4jMemoryClient", () => { extractionRetries: 1, }); }); + + it("should not pass agentId: undefined in listPendingExtractions params", async () => { + mockSession.run.mockResolvedValue({ records: [] }); + + await client.listPendingExtractions(50); + + const params = mockSession.run.mock.calls[0][1] as Record; + expect(params).not.toHaveProperty("agentId"); + }); }); // ------------------------------------------------------------------------ @@ -1356,8 +1419,49 @@ describe("Neo4jMemoryClient", () => { const result = await client.bm25Search("test query", 10); expect(result).toHaveLength(1); - // Score should be normalized (divided by max) + // Single result: score should be moderate 0.5 (not 1.0) to avoid inflating weak matches + expect(result[0].score).toBe(0.5); + }); + + it("should normalize BM25 scores with min-max when multiple results exist", async () => { + mockSession.run.mockResolvedValue({ + records: [ + { + get: vi.fn((key) => { + const data: Record = { + id: "m1", + text: "best match", + category: "fact", + importance: 0.8, + createdAt: "2024-01-01", + bm25Score: 10.0, + }; + return data[key]; + }), + }, + { + get: vi.fn((key) => { + const data: Record = { + id: "m2", + text: "worst match", + category: "fact", + importance: 0.5, + createdAt: "2024-01-02", + bm25Score: 2.0, + }; + return data[key]; + }), + }, + ], + }); + + const result = await client.bm25Search("test", 10); + + expect(result).toHaveLength(2); + // Best result gets score 1.0 (FLOOR + (1-FLOOR)*1) expect(result[0].score).toBe(1.0); + // Worst result gets FLOOR (0.3) + expect(result[1].score).toBeCloseTo(0.3); }); it("should escape Lucene special characters in BM25 query", async () => { @@ -1404,6 +1508,75 @@ describe("Neo4jMemoryClient", () => { }); }); + // ------------------------------------------------------------------------ + // reindex() + // ------------------------------------------------------------------------ + + describe("reindex", () => { + it("should use UNWIND batch update instead of individual queries", async () => { + // Mock drop index session + const dropSession = createMockSession(); + // Mock fetch session (returns 2 memories) + const fetchSession = createMockSession(); + fetchSession.run.mockResolvedValueOnce({ + records: [ + { get: vi.fn((key) => (key === "id" ? "m1" : "text 1")) }, + { get: vi.fn((key) => (key === "id" ? "m2" : "text 2")) }, + ], + }); + // Mock batch update session + const updateSession = createMockSession(); + // Mock recreate index session + const indexSession = createMockSession(); + + mockDriver.session + .mockReturnValueOnce(dropSession) + .mockReturnValueOnce(fetchSession) + .mockReturnValueOnce(updateSession) + .mockReturnValueOnce(indexSession); + + const embedFn = vi.fn().mockResolvedValue([ + [0.1, 0.2], + [0.3, 0.4], + ]); + + await client.reindex(embedFn, { batchSize: 50 }); + + // Should call UNWIND batch, not individual queries + expect(updateSession.run).toHaveBeenCalledTimes(1); + expect(updateSession.run).toHaveBeenCalledWith( + expect.stringContaining("UNWIND $items"), + expect.objectContaining({ + items: [ + { id: "m1", embedding: [0.1, 0.2] }, + { id: "m2", embedding: [0.3, 0.4] }, + ], + }), + ); + }); + + it("should skip batch update when all embeddings are empty", async () => { + const dropSession = createMockSession(); + const fetchSession = createMockSession(); + fetchSession.run.mockResolvedValueOnce({ + records: [{ get: vi.fn((key) => (key === "id" ? "m1" : "text 1")) }], + }); + const indexSession = createMockSession(); + + mockDriver.session + .mockReturnValueOnce(dropSession) + .mockReturnValueOnce(fetchSession) + .mockReturnValueOnce(indexSession); + + const embedFn = vi.fn().mockResolvedValue([[]]); + + await client.reindex(embedFn, { batchSize: 50 }); + + // No update session should be created (only drop, fetch, and index sessions) + expect(mockDriver.session).toHaveBeenCalledTimes(3); + }); + }); + // ------------------------------------------------------------------------ // Retrieval tracking // ------------------------------------------------------------------------ diff --git a/extensions/memory-neo4j/neo4j-client.ts b/extensions/memory-neo4j/neo4j-client.ts index 484d762e3d5..2443221e41d 100644 --- a/extensions/memory-neo4j/neo4j-client.ts +++ b/extensions/memory-neo4j/neo4j-client.ts @@ -18,6 +18,10 @@ import type { } from "./schema.js"; import { ALLOWED_RELATIONSHIP_TYPES, escapeLucene, validateRelationshipType } from "./schema.js"; +// SAFETY: This pattern is built from the hardcoded ALLOWED_RELATIONSHIP_TYPES constant, +// not from user input. It's used in Cypher variable-length path patterns like +// (e1)-[:WORKS_AT|LIVES_AT|...*1..N]-(e2). Since the source is a compile-time +// constant, there is no injection risk. const RELATIONSHIP_TYPE_PATTERN = [...ALLOWED_RELATIONSHIP_TYPES].join("|"); // ============================================================================ @@ -501,7 +505,7 @@ export class Neo4jMemoryClient { const FLOOR = 0.3; // Minimum normalized score for the lowest-ranked result return records.map((r) => ({ ...r, - score: range > 0 ? FLOOR + ((1 - FLOOR) * (r.rawScore - minScore)) / range : 1.0, // All scores identical → all get 1.0 + score: range > 0 ? FLOOR + ((1 - FLOOR) * (r.rawScore - minScore)) / range : 0.5, // Single result or identical scores → moderate 0.5 to avoid inflating weak matches })); } finally { await session.close(); @@ -609,7 +613,8 @@ export class Neo4jMemoryClient { } return Array.from(byId.values()) - .toSorted((a, b) => b.score - a.score) + .slice() + .sort((a, b) => b.score - a.score) .slice(0, limit); } finally { await session.close(); @@ -624,31 +629,45 @@ export class Neo4jMemoryClient { /** * Find similar memories by vector similarity. Used for deduplication. + * When agentId is provided, results are post-filtered to that agent + * (HNSW indexes don't support pre-filtering, so we fetch extra candidates). */ async findSimilar( embedding: number[], threshold: number = 0.95, limit: number = 1, + agentId?: string, ): Promise> { await this.ensureInitialized(); try { return await this.retryOnTransient(async () => { const session = this.driver!.session(); try { + // Fetch extra candidates when filtering by agentId since HNSW + // doesn't support pre-filtering; post-filter and trim to limit. + const fetchLimit = agentId ? limit * 3 : limit; + const agentFilter = agentId ? "AND node.agentId = $agentId" : ""; const result = await session.run( `CALL db.index.vector.queryNodes('memory_embedding_index', $limit, $embedding) YIELD node, score - WHERE score >= $threshold + WHERE score >= $threshold ${agentFilter} RETURN node.id AS id, node.text AS text, score AS similarity ORDER BY score DESC`, - { embedding, limit: neo4j.int(limit), threshold }, + { + embedding, + limit: neo4j.int(fetchLimit), + threshold, + ...(agentId ? { agentId } : {}), + }, ); - return result.records.map((r) => ({ + const results = result.records.map((r) => ({ id: r.get("id") as string, text: r.get("text") as string, score: r.get("similarity") as number, })); + // Trim to requested limit after post-filtering + return agentId ? results.slice(0, limit) : results; } finally { await session.close(); } @@ -1056,7 +1075,7 @@ export class Neo4jMemoryClient { coalesce(m.extractionRetries, 0) AS extractionRetries ORDER BY m.createdAt ASC LIMIT $limit`, - { limit: neo4j.int(limit), agentId }, + { limit: neo4j.int(limit), ...(agentId ? { agentId } : {}) }, ); return result.records.map((r) => ({ id: r.get("id") as string, @@ -1082,7 +1101,7 @@ export class Neo4jMemoryClient { `MATCH (m:Memory) ${agentFilter} RETURN m.extractionStatus AS status, count(m) AS count`, - { agentId }, + agentId ? { agentId } : {}, ); const counts: Record = { pending: 0, @@ -1193,51 +1212,64 @@ export class Neo4jMemoryClient { return a < b ? `${a}:${b}` : `${b}:${a}`; }; + // Process vector queries in concurrent batches to avoid overwhelming Neo4j + // while still being much faster than fully sequential execution. + const DEDUP_CONCURRENCY = 8; let pairsFound = 0; - for (const id of memoryData.keys()) { - // Retry individual vector queries on transient errors (each uses a fresh session) - const similar = await this.retryOnTransient(async () => { - const session = this.driver!.session(); - try { - return await session.run( - `MATCH (src:Memory {id: $id}) - CALL db.index.vector.queryNodes('memory_embedding_index', $k, src.embedding) - YIELD node, score - WHERE node.id <> $id AND score >= $threshold - RETURN node.id AS matchId, score`, - { id, k: neo4j.int(10), threshold }, - ); - } finally { - await session.close(); - } - }); + const allIds = [...memoryData.keys()]; - for (const r of similar.records) { - const matchId = r.get("matchId") as string; - if (memoryData.has(matchId)) { - union(id, matchId); - pairsFound++; - - // Capture similarity score if requested - if (pairwiseSimilarities) { - const score = r.get("score") as number; - const pairKey = makePairKey(id, matchId); - // Keep the highest score if we see this pair multiple times - const existing = pairwiseSimilarities.get(pairKey); - if (existing === undefined || score > existing) { - pairwiseSimilarities.set(pairKey, score); - } - } - } - } - - // Early exit if we've found many pairs (safety bound) + for (let batchStart = 0; batchStart < allIds.length; batchStart += DEDUP_CONCURRENCY) { if (pairsFound > 500) { this.logger.warn( `memory-neo4j: findDuplicateClusters hit safety bound (500 pairs) — some duplicates may not be detected. Consider running with a higher threshold.`, ); break; } + + const batch = allIds.slice(batchStart, batchStart + DEDUP_CONCURRENCY); + const results = await Promise.all( + batch.map((id) => + this.retryOnTransient(async () => { + const session = this.driver!.session(); + try { + return await session.run( + `MATCH (src:Memory {id: $id}) + CALL db.index.vector.queryNodes('memory_embedding_index', $k, src.embedding) + YIELD node, score + WHERE node.id <> $id AND score >= $threshold + RETURN node.id AS matchId, score`, + { id, k: neo4j.int(10), threshold }, + ); + } finally { + await session.close(); + } + }), + ), + ); + + for (let idx = 0; idx < batch.length; idx++) { + const id = batch[idx]; + const similar = results[idx]; + + for (const r of similar.records) { + const matchId = r.get("matchId") as string; + if (memoryData.has(matchId)) { + union(id, matchId); + pairsFound++; + + // Capture similarity score if requested + if (pairwiseSimilarities) { + const score = r.get("score") as number; + const pairKey = makePairKey(id, matchId); + // Keep the highest score if we see this pair multiple times + const existing = pairwiseSimilarities.get(pairKey); + if (existing === undefined || score > existing) { + pairwiseSimilarities.set(pairKey, score); + } + } + } + } + } } // Step 3: Group by root @@ -1490,25 +1522,27 @@ export class Neo4jMemoryClient { } await this.ensureInitialized(); - const session = this.driver!.session(); - try { - // Atomic: decrement mentionCount and delete in a single Cypher statement - // to prevent inconsistent state if a crash occurs between operations - const result = await session.run( - `UNWIND $ids AS memId - MATCH (m:Memory {id: memId}) - OPTIONAL MATCH (m)-[:MENTIONS]->(e:Entity) - SET e.mentionCount = CASE WHEN e.mentionCount > 0 THEN e.mentionCount - 1 ELSE 0 END - WITH m, count(e) AS _ - DETACH DELETE m - RETURN count(*) AS deleted`, - { ids: memoryIds }, - ); + return this.retryOnTransient(async () => { + const session = this.driver!.session(); + try { + // Atomic: decrement mentionCount and delete in a single Cypher statement + // to prevent inconsistent state if a crash occurs between operations + const result = await session.run( + `UNWIND $ids AS memId + MATCH (m:Memory {id: memId}) + OPTIONAL MATCH (m)-[:MENTIONS]->(e:Entity) + SET e.mentionCount = CASE WHEN e.mentionCount > 0 THEN e.mentionCount - 1 ELSE 0 END + WITH m, count(e) AS _ + DETACH DELETE m + RETURN count(*) AS deleted`, + { ids: memoryIds }, + ); - return (result.records[0]?.get("deleted") as number) ?? 0; - } finally { - await session.close(); - } + return (result.records[0]?.get("deleted") as number) ?? 0; + } finally { + await session.close(); + } + }); } // -------------------------------------------------------------------------- @@ -1766,7 +1800,7 @@ export class Neo4jMemoryClient { } // Scores should already be sorted descending, but ensure it - const sorted = scores.toSorted((a, b) => b.effectiveScore - a.effectiveScore); + const sorted = [...scores].sort((a, b) => b.effectiveScore - a.effectiveScore); // Find the index at the percentile boundary // For top 20%, we want the score at index = 20% of total @@ -1888,18 +1922,25 @@ export class Neo4jMemoryClient { const batch = memories.slice(i, i + batchSize); const vectors = await embedFn(batch.map((m) => m.text)); - const session = this.driver!.session(); - try { - for (let j = 0; j < batch.length; j++) { - if (vectors[j] && vectors[j].length > 0) { - await session.run("MATCH (m:Memory {id: $id}) SET m.embedding = $embedding", { - id: batch[j].id, - embedding: vectors[j], - }); - } + // Build items array for batch UNWIND update + const items: Array<{ id: string; embedding: number[] }> = []; + for (let j = 0; j < batch.length; j++) { + if (vectors[j] && vectors[j].length > 0) { + items.push({ id: batch[j].id, embedding: vectors[j] }); + } + } + if (items.length > 0) { + const session = this.driver!.session(); + try { + await session.run( + `UNWIND $items AS item + MATCH (m:Memory {id: item.id}) + SET m.embedding = item.embedding`, + { items }, + ); + } finally { + await session.close(); } - } finally { - await session.close(); } progress("memories", Math.min(i + batchSize, memories.length), memories.length); } diff --git a/extensions/memory-neo4j/search.test.ts b/extensions/memory-neo4j/search.test.ts index 46b6822dad5..85048d9e926 100644 --- a/extensions/memory-neo4j/search.test.ts +++ b/extensions/memory-neo4j/search.test.ts @@ -1,15 +1,20 @@ /** * Tests for search.ts — Hybrid Search & RRF Fusion. * - * Tests the exported pure logic: classifyQuery() and getAdaptiveWeights(). - * Note: fuseWithConfidenceRRF() is not exported (private module-level function) - * and is tested indirectly through hybridSearch(). + * Tests the exported pure logic: classifyQuery(), getAdaptiveWeights(), and fuseWithConfidenceRRF(). * hybridSearch() is tested with mocked Neo4j client and Embeddings. */ import { describe, it, expect, vi, beforeEach } from "vitest"; +import type { Embeddings } from "./embeddings.js"; +import type { Neo4jMemoryClient } from "./neo4j-client.js"; import type { SearchSignalResult } from "./schema.js"; -import { classifyQuery, getAdaptiveWeights, hybridSearch } from "./search.js"; +import { + classifyQuery, + getAdaptiveWeights, + fuseWithConfidenceRRF, + hybridSearch, +} from "./search.js"; // ============================================================================ // classifyQuery() @@ -168,15 +173,27 @@ describe("getAdaptiveWeights", () => { // ============================================================================ describe("hybridSearch", () => { - // Create mock db and embeddings - const mockDb = { + // Properly typed mocks matching the interfaces hybridSearch depends on. + // Using Pick<> to extract only the methods hybridSearch actually calls, + // so TypeScript will catch interface changes (e.g. renamed or removed methods). + type MockedDb = { + [K in keyof Pick< + Neo4jMemoryClient, + "vectorSearch" | "bm25Search" | "graphSearch" | "recordRetrievals" + >]: ReturnType; + }; + type MockedEmbeddings = { + [K in keyof Pick]: ReturnType; + }; + + const mockDb: MockedDb = { vectorSearch: vi.fn(), bm25Search: vi.fn(), graphSearch: vi.fn(), recordRetrievals: vi.fn(), }; - const mockEmbeddings = { + const mockEmbeddings: MockedEmbeddings = { embed: vi.fn(), embedBatch: vi.fn(), }; @@ -204,8 +221,8 @@ describe("hybridSearch", () => { mockDb.bm25Search.mockResolvedValue([]); const results = await hybridSearch( - mockDb as never, - mockEmbeddings as never, + mockDb as unknown as Neo4jMemoryClient, + mockEmbeddings as unknown as Embeddings, "test query", 5, "agent-1", @@ -224,8 +241,8 @@ describe("hybridSearch", () => { mockDb.bm25Search.mockResolvedValue([bm25Result]); const results = await hybridSearch( - mockDb as never, - mockEmbeddings as never, + mockDb as unknown as Neo4jMemoryClient, + mockEmbeddings as unknown as Embeddings, "test query", 5, "agent-1", @@ -247,8 +264,8 @@ describe("hybridSearch", () => { mockDb.bm25Search.mockResolvedValue([{ ...sharedResult, score: 0.85 }]); const results = await hybridSearch( - mockDb as never, - mockEmbeddings as never, + mockDb as unknown as Neo4jMemoryClient, + mockEmbeddings as unknown as Embeddings, "test query", 5, "agent-1", @@ -270,8 +287,8 @@ describe("hybridSearch", () => { ]); const results = await hybridSearch( - mockDb as never, - mockEmbeddings as never, + mockDb as unknown as Neo4jMemoryClient, + mockEmbeddings as unknown as Embeddings, "tell me about Tarun", 5, "agent-1", @@ -287,7 +304,14 @@ describe("hybridSearch", () => { mockDb.vectorSearch.mockResolvedValue([]); mockDb.bm25Search.mockResolvedValue([]); - await hybridSearch(mockDb as never, mockEmbeddings as never, "test query", 5, "agent-1", false); + await hybridSearch( + mockDb as unknown as Neo4jMemoryClient, + mockEmbeddings as unknown as Embeddings, + "test query", + 5, + "agent-1", + false, + ); expect(mockDb.graphSearch).not.toHaveBeenCalled(); }); @@ -301,8 +325,8 @@ describe("hybridSearch", () => { mockDb.bm25Search.mockResolvedValue([]); const results = await hybridSearch( - mockDb as never, - mockEmbeddings as never, + mockDb as unknown as Neo4jMemoryClient, + mockEmbeddings as unknown as Embeddings, "test query", 3, "agent-1", @@ -319,7 +343,14 @@ describe("hybridSearch", () => { ]); mockDb.bm25Search.mockResolvedValue([]); - await hybridSearch(mockDb as never, mockEmbeddings as never, "test query", 5, "agent-1", false); + await hybridSearch( + mockDb as unknown as Neo4jMemoryClient, + mockEmbeddings as unknown as Embeddings, + "test query", + 5, + "agent-1", + false, + ); expect(mockDb.recordRetrievals).toHaveBeenCalledWith(["mem-1", "mem-2"]); }); @@ -331,8 +362,8 @@ describe("hybridSearch", () => { // Should not throw const results = await hybridSearch( - mockDb as never, - mockEmbeddings as never, + mockDb as unknown as Neo4jMemoryClient, + mockEmbeddings as unknown as Embeddings, "test query", 5, "agent-1", @@ -350,8 +381,8 @@ describe("hybridSearch", () => { mockDb.bm25Search.mockResolvedValue([]); const results = await hybridSearch( - mockDb as never, - mockEmbeddings as never, + mockDb as unknown as Neo4jMemoryClient, + mockEmbeddings as unknown as Embeddings, "test query", 5, "agent-1", @@ -369,8 +400,8 @@ describe("hybridSearch", () => { mockDb.bm25Search.mockResolvedValue([]); await hybridSearch( - mockDb as never, - mockEmbeddings as never, + mockDb as unknown as Neo4jMemoryClient, + mockEmbeddings as unknown as Embeddings, "test query", 5, "agent-1", @@ -387,7 +418,11 @@ describe("hybridSearch", () => { mockDb.vectorSearch.mockResolvedValue([]); mockDb.bm25Search.mockResolvedValue([]); - await hybridSearch(mockDb as never, mockEmbeddings as never, "test query"); + await hybridSearch( + mockDb as unknown as Neo4jMemoryClient, + mockEmbeddings as unknown as Embeddings, + "test query", + ); expect(mockDb.vectorSearch).toHaveBeenCalledWith( expect.any(Array), @@ -397,3 +432,123 @@ describe("hybridSearch", () => { ); }); }); + +// ============================================================================ +// fuseWithConfidenceRRF() +// ============================================================================ + +describe("fuseWithConfidenceRRF", () => { + function makeSignal(id: string, score: number, text = `Memory ${id}`): SearchSignalResult { + return { + id, + text, + category: "fact", + importance: 0.7, + createdAt: "2025-01-01T00:00:00Z", + score, + }; + } + + it("should return empty array when all signals are empty", () => { + const result = fuseWithConfidenceRRF([[], [], []], 60, [1.0, 1.0, 1.0]); + expect(result).toEqual([]); + }); + + it("should handle a single signal with results", () => { + const signal = [makeSignal("a", 0.9), makeSignal("b", 0.5)]; + const result = fuseWithConfidenceRRF([signal, [], []], 60, [1.0, 1.0, 1.0]); + + expect(result).toHaveLength(2); + expect(result[0].id).toBe("a"); + expect(result[1].id).toBe("b"); + // First result should have higher RRF score than second + expect(result[0].rrfScore).toBeGreaterThan(result[1].rrfScore); + }); + + it("should boost candidates appearing in multiple signals", () => { + const vectorSignal = [makeSignal("shared", 0.9), makeSignal("vec-only", 0.8)]; + const bm25Signal = [makeSignal("shared", 0.85)]; + + const result = fuseWithConfidenceRRF([vectorSignal, bm25Signal, []], 60, [1.0, 1.0, 1.0]); + + // "shared" should rank higher than "vec-only" despite similar scores + // because it appears in two signals + expect(result[0].id).toBe("shared"); + expect(result[1].id).toBe("vec-only"); + }); + + it("should handle ties (same score, same rank) consistently", () => { + const signal = [makeSignal("a", 0.5), makeSignal("b", 0.5)]; + const result = fuseWithConfidenceRRF([signal], 60, [1.0]); + + expect(result).toHaveLength(2); + // With same score, first in signal should have higher RRF (rank 1 vs rank 2) + expect(result[0].id).toBe("a"); + expect(result[1].id).toBe("b"); + }); + + it("should respect different k values", () => { + const signal = [makeSignal("a", 0.9), makeSignal("b", 0.5)]; + + // Small k amplifies rank differences, large k smooths them + const resultSmallK = fuseWithConfidenceRRF([signal], 1, [1.0]); + const resultLargeK = fuseWithConfidenceRRF([signal], 1000, [1.0]); + + // The ratio between first and second should be larger with smaller k + const ratioSmallK = resultSmallK[0].rrfScore / resultSmallK[1].rrfScore; + const ratioLargeK = resultLargeK[0].rrfScore / resultLargeK[1].rrfScore; + expect(ratioSmallK).toBeGreaterThan(ratioLargeK); + }); + + it("should handle zero-score entries", () => { + const signal = [makeSignal("a", 0.9), makeSignal("b", 0)]; + const result = fuseWithConfidenceRRF([signal], 60, [1.0]); + + expect(result).toHaveLength(2); + // Zero score entry should have zero RRF contribution + expect(result[1].rrfScore).toBe(0); + expect(result[0].rrfScore).toBeGreaterThan(0); + }); + + it("should apply signal weights correctly", () => { + // Same item appears in two signals with different weights + const signal1 = [makeSignal("a", 0.8)]; + const signal2 = [makeSignal("a", 0.8)]; + + const resultEqual = fuseWithConfidenceRRF([signal1, signal2], 60, [1.0, 1.0]); + const resultWeighted = fuseWithConfidenceRRF([signal1, signal2], 60, [2.0, 0.5]); + + // Both should have the same item, but weighted version uses different signal contributions + expect(resultEqual[0].id).toBe("a"); + expect(resultWeighted[0].id).toBe("a"); + // With unequal weights, overall score differs + expect(resultEqual[0].rrfScore).not.toBeCloseTo(resultWeighted[0].rrfScore); + }); + + it("should sort results by RRF score descending", () => { + const signal1 = [makeSignal("low", 0.3)]; + const signal2 = [makeSignal("high", 0.95)]; + const signal3 = [makeSignal("mid", 0.6)]; + + const result = fuseWithConfidenceRRF([signal1, signal2, signal3], 60, [1.0, 1.0, 1.0]); + + expect(result[0].id).toBe("high"); + expect(result[1].id).toBe("mid"); + expect(result[2].id).toBe("low"); + }); + + it("should deduplicate within a single signal (keep first occurrence)", () => { + const signal = [ + makeSignal("dup", 0.9), + makeSignal("dup", 0.5), // duplicate — should be ignored + makeSignal("other", 0.7), + ]; + const result = fuseWithConfidenceRRF([signal], 60, [1.0]); + + // "dup" should appear once using its first occurrence (rank 1, score 0.9) + const dupEntry = result.find((r) => r.id === "dup"); + expect(dupEntry).toBeDefined(); + // Only 2 unique candidates + expect(result).toHaveLength(2); + }); +}); diff --git a/extensions/memory-neo4j/search.ts b/extensions/memory-neo4j/search.ts index 113b2c197ec..60a787e747e 100644 --- a/extensions/memory-neo4j/search.ts +++ b/extensions/memory-neo4j/search.ts @@ -120,7 +120,7 @@ type FusedCandidate = { * * Reference: Cormack et al. (2009), extended with confidence weighting. */ -function fuseWithConfidenceRRF( +export function fuseWithConfidenceRRF( signals: SearchSignalResult[][], k: number, weights: number[], @@ -249,9 +249,12 @@ export async function hybridSearch( // 4. Fuse with confidence-weighted RRF const fused = fuseWithConfidenceRRF([vectorResults, bm25Results, graphResults], rrfK, weights); - // 5. Return top results, normalized to 0-100% display scores - const maxRrf = fused.length > 0 ? fused[0].rrfScore : 1; - const normalizer = maxRrf > 0 ? 1 / maxRrf : 1; + // 5. Return top results, normalized to 0-100% display scores. + // Only normalize when maxRrf is above a minimum threshold to avoid + // inflating weak matches (e.g., a single low-score result becoming 1.0). + const maxRrf = fused.length > 0 ? fused[0].rrfScore : 0; + const MIN_RRF_FOR_NORMALIZATION = 0.01; + const normalizer = maxRrf >= MIN_RRF_FOR_NORMALIZATION ? 1 / maxRrf : 1; const results = fused.slice(0, limit).map((r) => ({ id: r.id,