memory-neo4j: fix high-severity review findings — security, concurrency, silent failures

- Add safety comment for RELATIONSHIP_TYPE_PATTERN Cypher interpolation
- Add concurrency batching (8) to findDuplicateClusters vector queries
- Bounds-validate memory_recall limit parameter (1-50)
- Fix maxRetries comment (default 2 = 3 attempts, not 1 = 2)
- Fix countByExtractionStatus passing undefined agentId to Cypher
- Fix assistant auto-capture silently disabled when extraction disabled
- Add agentId scoping to findSimilar (dedup + auto-capture)
- Fix BM25 single-result normalization (0.5 instead of inflated 1.0)
- Wrap pruneMemories in retryOnTransient for resilience
- Use UNWIND batch update in reindex instead of N individual queries
- Raise auto-delete threshold from 0.9 to 0.95 to reduce false positives

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Tarun Sukhani
2026-02-09 23:46:12 +08:00
parent 03e4768732
commit 806c5e2d13
13 changed files with 1090 additions and 188 deletions

View File

@@ -78,7 +78,10 @@ export function passesAttentionGate(text: string): boolean {
}
// Excessive emoji (likely reaction, not substance)
const emojiCount = (trimmed.match(/[\u{1F300}-\u{1F9FF}]/gu) || []).length;
const emojiCount = (
trimmed.match(/[\u{1F300}-\u{1F9FF}\u{2600}-\u{26FF}\u{2700}-\u{27BF}\u{1FA00}-\u{1FAFF}]/gu) ||
[]
).length;
if (emojiCount > 3) {
return false;
}
@@ -142,7 +145,10 @@ export function passesAssistantAttentionGate(text: string): boolean {
}
// Excessive emoji (likely reaction, not substance)
const emojiCount = (trimmed.match(/[\u{1F300}-\u{1F9FF}]/gu) || []).length;
const emojiCount = (
trimmed.match(/[\u{1F300}-\u{1F9FF}\u{2600}-\u{26FF}\u{2700}-\u{27BF}\u{1FA00}-\u{1FAFF}]/gu) ||
[]
).length;
if (emojiCount > 3) {
return false;
}

View File

@@ -171,6 +171,32 @@ describe("memoryNeo4jConfigSchema.parse", () => {
expect(config.embedding.apiKey).toBe("sk-from-env");
});
it("should resolve ${ENV_VAR} in neo4j.user (username)", () => {
  // The variable is set only for the duration of this test and removed in
  // `finally`, so a failed assertion cannot leak it into later tests.
  process.env.TEST_NEO4J_USER = "resolved-user";
  try {
    const config = memoryNeo4jConfigSchema.parse({
      neo4j: {
        uri: "bolt://localhost:7687",
        user: "${TEST_NEO4J_USER}",
        password: "",
      },
      embedding: { provider: "ollama" },
    });
    // The `user` alias is surfaced as `username` on the parsed config.
    expect(config.neo4j.username).toBe("resolved-user");
  } finally {
    delete process.env.TEST_NEO4J_USER;
  }
});
it("should resolve ${ENV_VAR} in neo4j.username", () => {
  // The variable is set only for the duration of this test and removed in
  // `finally`, so a failed assertion cannot leak it into later tests.
  process.env.TEST_NEO4J_USERNAME = "resolved-username";
  try {
    const config = memoryNeo4jConfigSchema.parse({
      neo4j: {
        uri: "bolt://localhost:7687",
        username: "${TEST_NEO4J_USERNAME}",
        password: "",
      },
      embedding: { provider: "ollama" },
    });
    expect(config.neo4j.username).toBe("resolved-username");
  } finally {
    delete process.env.TEST_NEO4J_USERNAME;
  }
});
it("should throw when referenced env var is not set", () => {
delete process.env.NONEXISTENT_VAR;
expect(() =>

View File

@@ -247,9 +247,9 @@ export const memoryNeo4jConfigSchema = {
// Support both 'user' and 'username' for neo4j config
const neo4jUsername =
typeof neo4jRaw.user === "string"
? neo4jRaw.user
? resolveEnvVars(neo4jRaw.user)
: typeof neo4jRaw.username === "string"
? neo4jRaw.username
? resolveEnvVars(neo4jRaw.username)
: "neo4j";
// Parse embedding section (optional for ollama without apiKey)

View File

@@ -78,6 +78,40 @@ describe("Embeddings - Ollama provider", () => {
);
});
it("should strip trailing slashes from baseUrl", async () => {
  // A base URL ending in one "/" is expected to be requested as
  // exactly "<host>/api/embed".
  const { Embeddings } = await import("./embeddings.js");
  const fetchStub = vi.fn().mockResolvedValue({
    ok: true,
    json: () => Promise.resolve({ embeddings: [[0.1, 0.2]] }),
  });
  globalThis.fetch = fetchStub;

  const client = new Embeddings(undefined, "mxbai-embed-large", "ollama", "http://my-host:11434/");
  await client.embed("test");

  expect(globalThis.fetch).toHaveBeenCalledWith("http://my-host:11434/api/embed", expect.any(Object));
});
it("should strip multiple trailing slashes from baseUrl", async () => {
  // A base URL ending in several "/" characters is expected to be requested
  // as exactly "<host>/api/embed".
  const { Embeddings } = await import("./embeddings.js");
  const fetchStub = vi.fn().mockResolvedValue({
    ok: true,
    json: () => Promise.resolve({ embeddings: [[0.1, 0.2]] }),
  });
  globalThis.fetch = fetchStub;

  const client = new Embeddings(undefined, "mxbai-embed-large", "ollama", "http://my-host:11434///");
  await client.embed("test");

  expect(globalThis.fetch).toHaveBeenCalledWith("http://my-host:11434/api/embed", expect.any(Object));
});
it("should throw when Ollama returns error status", async () => {
const { Embeddings } = await import("./embeddings.js");
globalThis.fetch = vi.fn().mockResolvedValue({
@@ -298,3 +332,150 @@ describe("Embeddings - Ollama context-length truncation", () => {
expect(body2.input).toBe(shortText);
});
});
// ============================================================================
// OpenAI embed — functional tests with mocked OpenAI client
// ============================================================================
describe("Embeddings - OpenAI functional", () => {
  const OPENAI_MODEL = "text-embedding-3-small";

  /**
   * Registers a stub `openai` module whose default-exported client exposes
   * `create` as `embeddings.create`, then imports ./embeddings.js and returns
   * an Embeddings instance constructed for the "openai" provider.
   */
  async function buildOpenAIEmbeddings(create: ReturnType<typeof vi.fn>) {
    vi.doMock("openai", () => ({
      default: class MockOpenAI {
        embeddings = { create };
      },
    }));
    const { Embeddings } = await import("./embeddings.js");
    return new Embeddings("sk-test-key", OPENAI_MODEL, "openai");
  }

  beforeEach(() => {
    vi.resetModules();
  });

  afterEach(() => {
    vi.restoreAllMocks();
  });

  it("embed() should call OpenAI API with correct model and input", async () => {
    const create = vi.fn().mockResolvedValue({
      data: [{ index: 0, embedding: [0.1, 0.2, 0.3] }],
    });
    const embeddings = await buildOpenAIEmbeddings(create);

    const vector = await embeddings.embed("hello world");

    expect(vector).toEqual([0.1, 0.2, 0.3]);
    expect(create).toHaveBeenCalledWith({
      model: "text-embedding-3-small",
      input: "hello world",
    });
  });

  it("embedBatch() should send all texts in a single API call and return correctly ordered results", async () => {
    // The stubbed response is deliberately out of order so the assertion
    // below checks that results come back ordered by `index`.
    const create = vi.fn().mockResolvedValue({
      data: [
        { index: 2, embedding: [0.7, 0.8, 0.9] },
        { index: 0, embedding: [0.1, 0.2, 0.3] },
        { index: 1, embedding: [0.4, 0.5, 0.6] },
      ],
    });
    const embeddings = await buildOpenAIEmbeddings(create);

    const vectors = await embeddings.embedBatch(["first", "second", "third"]);

    // Exactly one request carrying every text
    expect(create).toHaveBeenCalledTimes(1);
    expect(create).toHaveBeenCalledWith({
      model: "text-embedding-3-small",
      input: ["first", "second", "third"],
    });
    // Ordered by index 0, 1, 2
    expect(vectors).toEqual([
      [0.1, 0.2, 0.3],
      [0.4, 0.5, 0.6],
      [0.7, 0.8, 0.9],
    ]);
  });

  it("embed() should propagate OpenAI API errors", async () => {
    const create = vi.fn().mockRejectedValue(new Error("API rate limit exceeded"));
    const embeddings = await buildOpenAIEmbeddings(create);

    await expect(embeddings.embed("test")).rejects.toThrow("API rate limit exceeded");
  });

  it("embed() should return cached result on second call for same text", async () => {
    const create = vi.fn().mockResolvedValue({
      data: [{ index: 0, embedding: [0.1, 0.2, 0.3] }],
    });
    const embeddings = await buildOpenAIEmbeddings(create);

    const first = await embeddings.embed("cached text");
    const second = await embeddings.embed("cached text");

    expect(first).toEqual([0.1, 0.2, 0.3]);
    expect(second).toEqual([0.1, 0.2, 0.3]);
    // The second lookup is expected to be served without another API call
    expect(create).toHaveBeenCalledTimes(1);
  });

  it("embedBatch() should use cache for previously embedded texts", async () => {
    const create = vi
      .fn()
      .mockResolvedValueOnce({
        data: [{ index: 0, embedding: [0.1, 0.2, 0.3] }],
      })
      .mockResolvedValueOnce({
        data: [{ index: 0, embedding: [0.7, 0.8, 0.9] }],
      });
    const embeddings = await buildOpenAIEmbeddings(create);

    // Warm the cache with "alpha"
    await embeddings.embed("alpha");
    expect(create).toHaveBeenCalledTimes(1);

    // Batch mixes the cached "alpha" with the uncached "beta"
    const vectors = await embeddings.embedBatch(["alpha", "beta"]);

    // Only "beta" should have triggered a further request
    expect(create).toHaveBeenCalledTimes(2);
    expect(create).toHaveBeenLastCalledWith({
      model: "text-embedding-3-small",
      input: ["beta"],
    });
    expect(vectors).toEqual([
      [0.1, 0.2, 0.3], // cached
      [0.7, 0.8, 0.9], // freshly computed
    ]);
  });
});

View File

@@ -83,7 +83,10 @@ export class Embeddings {
logger?: Logger,
) {
this.provider = provider;
this.baseUrl = baseUrl ?? (provider === "ollama" ? "http://localhost:11434" : "");
this.baseUrl = (baseUrl ?? (provider === "ollama" ? "http://localhost:11434" : "")).replace(
/\/+$/,
"",
);
this.logger = logger;
this.contextLength = contextLengthForModel(model);
@@ -250,7 +253,7 @@ export class Embeddings {
input: texts,
});
// Sort by index to ensure correct order
return response.data.toSorted((a, b) => a.index - b.index).map((d) => d.embedding);
return [...response.data].sort((a, b) => a.index - b.index).map((d) => d.embedding);
}
// Timeout for Ollama embedding fetch calls to prevent hanging indefinitely

View File

@@ -17,6 +17,7 @@ import {
rateImportance,
resolveConflict,
isSemanticDuplicate,
isTransientError,
runSleepCycle,
} from "./extractor.js";
import { passesAttentionGate, passesAssistantAttentionGate } from "./index.js";
@@ -820,7 +821,7 @@ describe("runBackgroundExtraction", () => {
}
it("should skip extraction and mark as 'skipped' when disabled", async () => {
await runBackgroundExtraction(
const result = await runBackgroundExtraction(
"mem-1",
"test text",
mockDb as never,
@@ -829,6 +830,7 @@ describe("runBackgroundExtraction", () => {
mockLogger,
);
expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "skipped");
expect(result).toEqual({ success: true, memoryId: "mem-1" });
});
it("should mark as 'failed' when extraction returns null", async () => {
@@ -838,7 +840,7 @@ describe("runBackgroundExtraction", () => {
text: () => Promise.resolve("error"),
});
await runBackgroundExtraction(
const result = await runBackgroundExtraction(
"mem-1",
"test text",
mockDb as never,
@@ -847,6 +849,7 @@ describe("runBackgroundExtraction", () => {
mockLogger,
);
expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "failed");
expect(result).toEqual({ success: false, memoryId: "mem-1" });
});
it("should mark as 'complete' when extraction result is empty", async () => {
@@ -858,7 +861,7 @@ describe("runBackgroundExtraction", () => {
}),
);
await runBackgroundExtraction(
const result = await runBackgroundExtraction(
"mem-1",
"test text",
mockDb as never,
@@ -867,6 +870,7 @@ describe("runBackgroundExtraction", () => {
mockLogger,
);
expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "complete");
expect(result).toEqual({ success: true, memoryId: "mem-1" });
});
it("should batch entities, relationships, tags, and category in one call", async () => {
@@ -2274,10 +2278,10 @@ describe("runSleepCycle", () => {
const result = await runSleepCycle(mockDb, mockEmbeddings, mockConfig, mockLogger);
expect(result.extraction.processed).toBe(1);
// runBackgroundExtraction doesn't throw on HTTP errors, it just marks the extraction status as failed/pending
// The sleep cycle counts it as succeeded because Promise.allSettled reports it as fulfilled
expect(result.extraction.succeeded).toBe(1);
expect(result.extraction.failed).toBe(0);
// runBackgroundExtraction returns { success: false } on HTTP errors,
// so the sleep cycle correctly counts it as failed via outcome.value.success
expect(result.extraction.succeeded).toBe(0);
expect(result.extraction.failed).toBe(1);
});
it("should respect batch size and delay", async () => {
@@ -2570,3 +2574,82 @@ describe("runSleepCycle", () => {
});
});
});
// ============================================================================
// isTransientError()
// ============================================================================
describe("isTransientError", () => {
  // [DOMException name, message] — each is expected to be classified transient.
  const transientDomExceptions: Array<[string, string]> = [
    ["AbortError", "signal aborted"],
    ["TimeoutError", "signal timed out"],
  ];

  // [test title, Error message] — each is expected to be classified transient.
  const transientMessages: Array<[string, string]> = [
    ["should classify timeout messages as transient", "Request timeout after 30s"],
    ["should classify ECONNREFUSED as transient", "connect ECONNREFUSED 127.0.0.1:7687"],
    ["should classify ECONNRESET as transient", "read ECONNRESET"],
    ["should classify ETIMEDOUT as transient", "connect ETIMEDOUT 10.0.0.1:443"],
    ["should classify DNS failure (ENOTFOUND) as transient", "getaddrinfo ENOTFOUND api.openrouter.ai"],
    ["should classify HTTP 429 (rate limit) as transient", "OpenRouter API error 429: rate limited"],
    ["should classify HTTP 502 (bad gateway) as transient", "OpenRouter API error 502: bad gateway"],
    [
      "should classify HTTP 503 (service unavailable) as transient",
      "OpenRouter API error 503: service unavailable",
    ],
    ["should classify HTTP 504 (gateway timeout) as transient", "OpenRouter API error 504: gateway timeout"],
  ];

  // [test title, Error message] — each is expected to be classified non-transient.
  const permanentMessages: Array<[string, string]> = [
    ["should classify HTTP 500 as non-transient", "OpenRouter API error 500: internal server error"],
    ["should classify JSON parse errors as non-transient", "Unexpected token < in JSON"],
    ["should classify generic errors as non-transient", "something went wrong"],
  ];

  it("should return false for non-Error values", () => {
    for (const value of ["string error", 42, null, undefined]) {
      expect(isTransientError(value)).toBe(false);
    }
  });

  for (const [name, message] of transientDomExceptions) {
    it(`should classify ${name} as transient`, () => {
      expect(isTransientError(new DOMException(message, name))).toBe(true);
    });
  }

  for (const [title, message] of transientMessages) {
    it(title, () => {
      expect(isTransientError(new Error(message))).toBe(true);
    });
  }

  it("should classify network errors as transient", () => {
    for (const message of ["network error", "fetch failed", "socket hang up"]) {
      expect(isTransientError(new Error(message))).toBe(true);
    }
  });

  for (const [title, message] of permanentMessages) {
    it(title, () => {
      expect(isTransientError(new Error(message))).toBe(false);
    });
  }
});

View File

@@ -217,13 +217,20 @@ async function callOpenRouterStream(
// Entity Extraction
// ============================================================================
/** Max retries for transient extraction failures before marking permanently failed */
/**
* Max retries for transient extraction failures before marking permanently failed.
*
* Retry budget accounting — two layers of retry:
* Layer 1: callOpenRouter/callOpenRouterStream internal retries (config.maxRetries, default 2 = 3 attempts)
* Layer 2: Sleep cycle retries (MAX_EXTRACTION_RETRIES = 3 sleep cycles)
* Total worst-case: 3 × 3 = 9 LLM attempts per memory
*/
const MAX_EXTRACTION_RETRIES = 3;
/**
* Check if an error is transient (network/timeout) vs permanent (JSON parse, etc.)
*/
function isTransientError(err: unknown): boolean {
export function isTransientError(err: unknown): boolean {
if (!(err instanceof Error)) {
return false;
}
@@ -234,6 +241,7 @@ function isTransientError(err: unknown): boolean {
msg.includes("timeout") ||
msg.includes("econnrefused") ||
msg.includes("econnreset") ||
msg.includes("etimedout") ||
msg.includes("enotfound") ||
msg.includes("network") ||
msg.includes("fetch failed") ||
@@ -434,10 +442,10 @@ export async function runBackgroundExtraction(
logger: Logger,
currentRetries: number = 0,
abortSignal?: AbortSignal,
): Promise<void> {
): Promise<{ success: boolean; memoryId: string }> {
if (!config.enabled) {
await db.updateExtractionStatus(memoryId, "skipped").catch(() => {});
return;
return { success: true, memoryId };
}
try {
@@ -463,7 +471,7 @@ export async function runBackgroundExtraction(
// Permanent failure (JSON parse, empty response, etc.)
await db.updateExtractionStatus(memoryId, "failed");
}
return;
return { success: false, memoryId };
}
// Empty extraction is valid — not all memories have extractable entities
@@ -473,7 +481,7 @@ export async function runBackgroundExtraction(
result.tags.length === 0
) {
await db.updateExtractionStatus(memoryId, "complete");
return;
return { success: true, memoryId };
}
// Batch all entity operations into a single transaction:
@@ -497,6 +505,7 @@ export async function runBackgroundExtraction(
`${result.entities.length} entities, ${result.relationships.length} rels, ${result.tags.length} tags` +
(result.category ? `, category=${result.category}` : ""),
);
return { success: true, memoryId };
} catch (err) {
// Unexpected error during graph operations — treat as transient if retry budget remains
const isTransient = isTransientError(err);
@@ -513,6 +522,7 @@ export async function runBackgroundExtraction(
.updateExtractionStatus(memoryId, "failed", { incrementRetries: true })
.catch(() => {});
}
return { success: false, memoryId };
}
}
@@ -1058,7 +1068,7 @@ export async function runSleepCycle(
for (const outcome of outcomes) {
result.extraction.processed++;
if (outcome.status === "fulfilled") {
if (outcome.status === "fulfilled" && outcome.value.success) {
result.extraction.succeeded++;
} else {
result.extraction.failed++;

View File

@@ -114,10 +114,11 @@ const memoryNeo4jPlugin = {
limit: Type.Optional(Type.Number({ description: "Max results (default: 5)" })),
}),
async execute(_toolCallId: string, params: unknown) {
const { query, limit = 5 } = params as {
const { query, limit: rawLimit = 5 } = params as {
query: string;
limit?: number;
};
// Clamp the requested limit to an integer in [1, 50]. Math.max/Math.min
// propagate NaN, and `params` is only cast (not validated), so a non-numeric
// value would otherwise escape the clamp as NaN; fall back to the default of 5.
const requestedLimit = Number(rawLimit);
const limit = Number.isNaN(requestedLimit)
  ? 5
  : Math.floor(Math.min(50, Math.max(1, requestedLimit)));
const results = await hybridSearch(
db,
@@ -197,7 +198,7 @@ const memoryNeo4jPlugin = {
const vector = await embeddings.embed(text);
// 2. Check for duplicates (vector similarity > 0.95)
const existing = await db.findSimilar(vector, 0.95, 1);
const existing = await db.findSimilar(vector, 0.95, 1, agentId);
if (existing.length > 0) {
return {
content: [
@@ -301,8 +302,9 @@ const memoryNeo4jPlugin = {
};
}
// Auto-delete if single high-confidence match
if (results.length === 1 && results[0].score > 0.9) {
// Auto-delete if single high-confidence match (0.95 threshold
// reduces false positives — 0.9 cosine similarity is not exact match)
if (results.length === 1 && results[0].score > 0.95) {
await db.deleteMemory(results[0].id, agentId);
return {
content: [
@@ -1191,8 +1193,10 @@ async function captureMessage(
extractionConfig: import("./config.js").ExtractionConfig,
logger: AutoCaptureLogger,
): Promise<{ stored: boolean; semanticDeduped: boolean }> {
// For assistant messages, rate importance first (before embedding) to skip early
const rateFirst = source === "auto-capture-assistant";
// For assistant messages, rate importance first (before embedding) to skip early.
// When extraction is disabled, rateImportance returns 0.5 (the fallback), so we
// skip the early importance gate to avoid silently blocking all assistant captures.
const rateFirst = source === "auto-capture-assistant" && extractionConfig.enabled;
let importance: number | undefined;
if (rateFirst) {
@@ -1205,15 +1209,17 @@ async function captureMessage(
const vector = await embeddings.embed(text);
// Quick dedup (same content already stored — cosine >= 0.95)
const existing = await db.findSimilar(vector, 0.95, 1);
const existing = await db.findSimilar(vector, 0.95, 1, agentId);
if (existing.length > 0) {
return { stored: false, semanticDeduped: false };
}
// Rate importance if not already done
// Rate importance if not already done.
// When extraction is disabled, rateImportance returns a fixed 0.5 fallback,
// so skip the threshold check to avoid silently blocking all captures.
if (importance === undefined) {
importance = await rateImportance(text, extractionConfig);
if (importance < importanceThreshold) {
if (extractionConfig.enabled && importance < importanceThreshold) {
return { stored: false, semanticDeduped: false };
}
}
@@ -1221,7 +1227,7 @@ async function captureMessage(
// Semantic dedup: check moderate-similarity memories (0.75-0.95)
// Pass the vector similarity score as a pre-screen to skip LLM calls
// for pairs below SEMANTIC_DEDUP_VECTOR_THRESHOLD.
const candidates = await db.findSimilar(vector, 0.75, 3);
const candidates = await db.findSimilar(vector, 0.75, 3, agentId);
if (candidates.length > 0) {
for (const candidate of candidates) {
if (await isSemanticDuplicate(text, candidate.text, extractionConfig, candidate.score)) {

View File

@@ -2,60 +2,16 @@
* Tests for mid-session core memory refresh feature.
*
* Verifies that core memories are re-injected when context usage exceeds threshold.
* Tests config parsing, threshold calculation, shouldRefresh logic, and edge cases.
*/
import { describe, it, expect } from "vitest";
// ============================================================================
// Config parsing for refreshAtContextPercent
// ============================================================================
describe("mid-session core memory refresh", () => {
// Test context threshold calculation
describe("context threshold calculation", () => {
it("should calculate usage percentage correctly", () => {
const contextWindowTokens = 200_000;
const estimatedUsedTokens = 100_000;
const usagePercent = (estimatedUsedTokens / contextWindowTokens) * 100;
expect(usagePercent).toBe(50);
});
it("should detect when threshold is exceeded", () => {
const threshold = 50;
const usagePercent = 55;
expect(usagePercent >= threshold).toBe(true);
});
it("should not trigger when below threshold", () => {
const threshold = 50;
const usagePercent = 45;
expect(usagePercent >= threshold).toBe(false);
});
});
// Test refresh frequency limiting
describe("refresh frequency limiting", () => {
const MIN_TOKENS_SINCE_REFRESH = 10_000;
it("should allow refresh when enough tokens have accumulated", () => {
const lastRefreshTokens = 50_000;
const currentTokens = 65_000;
const tokensSinceRefresh = currentTokens - lastRefreshTokens;
expect(tokensSinceRefresh >= MIN_TOKENS_SINCE_REFRESH).toBe(true);
});
it("should block refresh when not enough tokens have accumulated", () => {
const lastRefreshTokens = 50_000;
const currentTokens = 55_000;
const tokensSinceRefresh = currentTokens - lastRefreshTokens;
expect(tokensSinceRefresh >= MIN_TOKENS_SINCE_REFRESH).toBe(false);
});
it("should allow first refresh (no previous refresh)", () => {
const lastRefreshTokens = 0; // No previous refresh
const currentTokens = 100_000;
const tokensSinceRefresh = currentTokens - lastRefreshTokens;
expect(tokensSinceRefresh >= MIN_TOKENS_SINCE_REFRESH).toBe(true);
});
});
// Test config parsing
describe("config parsing", () => {
it("should accept valid refreshAtContextPercent values", async () => {
const { memoryNeo4jConfigSchema } = await import("./config.js");
@@ -67,7 +23,27 @@ describe("mid-session core memory refresh", () => {
expect(config.coreMemory.refreshAtContextPercent).toBe(50);
});
it("should reject refreshAtContextPercent of 0", async () => {
// Both ends of the accepted range are expected to parse through unchanged.
for (const [percent, label] of [
  [1, "minimum"],
  [100, "maximum"],
] as const) {
  it(`should accept refreshAtContextPercent of ${percent} (${label})`, async () => {
    const { memoryNeo4jConfigSchema } = await import("./config.js");
    const { coreMemory } = memoryNeo4jConfigSchema.parse({
      neo4j: { uri: "bolt://localhost:7687", user: "neo4j", password: "test" },
      embedding: { provider: "ollama" },
      coreMemory: { refreshAtContextPercent: percent },
    });
    expect(coreMemory.refreshAtContextPercent).toBe(percent);
  });
}
it("should treat refreshAtContextPercent of 0 as disabled (undefined)", async () => {
const { memoryNeo4jConfigSchema } = await import("./config.js");
const config = memoryNeo4jConfigSchema.parse({
neo4j: { uri: "bolt://localhost:7687", user: "neo4j", password: "test" },
@@ -77,6 +53,16 @@ describe("mid-session core memory refresh", () => {
expect(config.coreMemory.refreshAtContextPercent).toBeUndefined();
});
it("should treat negative refreshAtContextPercent as disabled (undefined)", async () => {
  // A negative percentage is expected to parse without throwing and leave
  // the refresh threshold unset.
  const configModule = await import("./config.js");
  const { coreMemory } = configModule.memoryNeo4jConfigSchema.parse({
    neo4j: { uri: "bolt://localhost:7687", user: "neo4j", password: "test" },
    embedding: { provider: "ollama" },
    coreMemory: { refreshAtContextPercent: -10 },
  });
  expect(coreMemory.refreshAtContextPercent).toBeUndefined();
});
it("should throw for refreshAtContextPercent over 100", async () => {
const { memoryNeo4jConfigSchema } = await import("./config.js");
expect(() =>
@@ -88,7 +74,7 @@ describe("mid-session core memory refresh", () => {
).toThrow("coreMemory.refreshAtContextPercent must be between 1 and 100");
});
it("should default to undefined when not specified", async () => {
it("should default to undefined when coreMemory section is omitted", async () => {
const { memoryNeo4jConfigSchema } = await import("./config.js");
const config = memoryNeo4jConfigSchema.parse({
neo4j: { uri: "bolt://localhost:7687", user: "neo4j", password: "test" },
@@ -96,11 +82,231 @@ describe("mid-session core memory refresh", () => {
});
expect(config.coreMemory.refreshAtContextPercent).toBeUndefined();
});
it("should default to undefined when refreshAtContextPercent is omitted", async () => {
  // The coreMemory section is present but carries no refresh threshold.
  const configModule = await import("./config.js");
  const { coreMemory } = configModule.memoryNeo4jConfigSchema.parse({
    neo4j: { uri: "bolt://localhost:7687", user: "neo4j", password: "test" },
    embedding: { provider: "ollama" },
    coreMemory: { enabled: true },
  });
  expect(coreMemory.refreshAtContextPercent).toBeUndefined();
});
});
// Test output format
// ============================================================================
// shouldRefresh logic (tests the decision flow from index.ts)
// ============================================================================
describe("shouldRefresh decision logic", () => {
// These tests mirror the logic from index.ts lines 893-916:
// 1. Skip if contextWindowTokens or estimatedUsedTokens not available
// 2. Calculate usagePercent = (estimatedUsedTokens / contextWindowTokens) * 100
// 3. Skip if usagePercent < refreshThreshold
// 4. Skip if tokens since last refresh < MIN_TOKENS_SINCE_REFRESH (10_000)
// 5. Otherwise, refresh
/** Minimum growth in used tokens since the previous refresh before another one is allowed. */
const MIN_TOKENS_SINCE_REFRESH = 10_000;

/**
 * Decides whether core memories should be refreshed.
 *
 * Returns false when either token count is missing or zero, when the usage
 * percentage is below `refreshThreshold`, or when fewer than
 * MIN_TOKENS_SINCE_REFRESH tokens were added since `lastRefreshTokens`;
 * returns true otherwise.
 */
function shouldRefresh(params: {
  contextWindowTokens: number | undefined;
  estimatedUsedTokens: number | undefined;
  refreshThreshold: number;
  lastRefreshTokens: number;
}): boolean {
  const windowSize = params.contextWindowTokens;
  const usedTokens = params.estimatedUsedTokens;

  // Without both figures there is nothing to compare against
  if (!windowSize || !usedTokens) {
    return false;
  }

  // Usage must have reached the threshold...
  const belowThreshold = (usedTokens / windowSize) * 100 < params.refreshThreshold;
  // ...and the context must have grown enough since the last refresh
  const refreshedRecently = usedTokens - params.lastRefreshTokens < MIN_TOKENS_SINCE_REFRESH;

  return !belowThreshold && !refreshedRecently;
}
// Single-call scenarios: each entry is one test with its input and expected decision.
const scenarios: Array<{
  title: string;
  params: Parameters<typeof shouldRefresh>[0];
  expected: boolean;
}> = [
  {
    title: "should trigger refresh when usage exceeds threshold and enough tokens accumulated",
    // 60% usage, never refreshed
    params: {
      contextWindowTokens: 200_000,
      estimatedUsedTokens: 120_000,
      refreshThreshold: 50,
      lastRefreshTokens: 0,
    },
    expected: true,
  },
  {
    title: "should not trigger when usage is below threshold",
    // 40% usage
    params: {
      contextWindowTokens: 200_000,
      estimatedUsedTokens: 80_000,
      refreshThreshold: 50,
      lastRefreshTokens: 0,
    },
    expected: false,
  },
  {
    title: "should not trigger when not enough tokens since last refresh",
    // 52.5% usage, only 5k tokens since last refresh
    params: {
      contextWindowTokens: 200_000,
      estimatedUsedTokens: 105_000,
      refreshThreshold: 50,
      lastRefreshTokens: 100_000,
    },
    expected: false,
  },
  {
    title: "should trigger when enough tokens accumulated since last refresh",
    // 57.5% usage, 15k tokens since last refresh
    params: {
      contextWindowTokens: 200_000,
      estimatedUsedTokens: 115_000,
      refreshThreshold: 50,
      lastRefreshTokens: 100_000,
    },
    expected: true,
  },
  {
    title: "should not trigger when contextWindowTokens is undefined",
    params: {
      contextWindowTokens: undefined,
      estimatedUsedTokens: 120_000,
      refreshThreshold: 50,
      lastRefreshTokens: 0,
    },
    expected: false,
  },
  {
    title: "should not trigger when estimatedUsedTokens is undefined",
    params: {
      contextWindowTokens: 200_000,
      estimatedUsedTokens: undefined,
      refreshThreshold: 50,
      lastRefreshTokens: 0,
    },
    expected: false,
  },
  {
    title: "should handle 0% usage (empty context)",
    params: {
      contextWindowTokens: 200_000,
      estimatedUsedTokens: 0,
      refreshThreshold: 50,
      lastRefreshTokens: 0,
    },
    expected: false,
  },
  {
    title: "should handle 100% usage",
    params: {
      contextWindowTokens: 200_000,
      estimatedUsedTokens: 200_000,
      refreshThreshold: 50,
      lastRefreshTokens: 0,
    },
    expected: true,
  },
  {
    title: "should handle exact threshold boundary (50% == 50% threshold)",
    // usagePercent == refreshThreshold is not "below threshold", so the check proceeds
    params: {
      contextWindowTokens: 200_000,
      estimatedUsedTokens: 100_000,
      refreshThreshold: 50,
      lastRefreshTokens: 0,
    },
    expected: true,
  },
  {
    title: "should handle threshold of 1 (refresh almost immediately)",
    // 7.5% usage
    params: {
      contextWindowTokens: 200_000,
      estimatedUsedTokens: 15_000,
      refreshThreshold: 1,
      lastRefreshTokens: 0,
    },
    expected: true,
  },
  {
    title: "should handle threshold of 100 (refresh only at full context)",
    // 95% usage
    params: {
      contextWindowTokens: 200_000,
      estimatedUsedTokens: 190_000,
      refreshThreshold: 100,
      lastRefreshTokens: 0,
    },
    expected: false,
  },
  {
    title: "should allow first refresh even when lastRefreshTokens is 0",
    params: {
      contextWindowTokens: 200_000,
      estimatedUsedTokens: 110_000,
      refreshThreshold: 50,
      lastRefreshTokens: 0,
    },
    expected: true,
  },
];

for (const { title, params, expected } of scenarios) {
  it(title, () => {
    expect(shouldRefresh(params)).toBe(expected);
  });
}

it("should support multiple refresh cycles with cumulative token growth", () => {
  const shared = { contextWindowTokens: 200_000, refreshThreshold: 50 };

  // First refresh at 110k tokens
  expect(shouldRefresh({ ...shared, estimatedUsedTokens: 110_000, lastRefreshTokens: 0 })).toBe(
    true,
  );

  // Second attempt too soon (only 5k since first)
  expect(
    shouldRefresh({ ...shared, estimatedUsedTokens: 115_000, lastRefreshTokens: 110_000 }),
  ).toBe(false);

  // Third attempt after enough growth (15k since first refresh)
  expect(
    shouldRefresh({ ...shared, estimatedUsedTokens: 125_000, lastRefreshTokens: 110_000 }),
  ).toBe(true);
});
});
// ============================================================================
// Output format
// ============================================================================
describe("refresh output format", () => {
it("should format core memories correctly", () => {
it("should format core memories as XML-wrapped bullet list", () => {
const coreMemories = [
{ text: "User prefers TypeScript over JavaScript" },
{ text: "User works at Acme Corp" },
@@ -113,5 +319,14 @@ describe("mid-session core memory refresh", () => {
expect(output).toContain("- User prefers TypeScript over JavaScript");
expect(output).toContain("- User works at Acme Corp");
});
it("should handle single core memory", () => {
  const coreMemories = [{ text: "Only memory" }];

  // Render each memory as a "- " bullet, one per line.
  const bullets: string[] = [];
  for (const memory of coreMemories) {
    bullets.push(`- ${memory.text}`);
  }
  const output = `<core-memory-refresh>\nReminder of persistent context (you may have seen this earlier, re-stating for recency):\n${bullets.join("\n")}\n</core-memory-refresh>`;

  expect(output).toContain("- Only memory");
  // Exactly one line of the wrapped output starts with the bullet marker.
  const bulletLines = output.match(/^- /gm);
  expect(bulletLines?.length).toBe(1);
});
});
});

View File

@@ -276,6 +276,60 @@ describe("Neo4jMemoryClient", () => {
expect(result).toEqual([]);
expect(mockLogger.debug).toHaveBeenCalled();
});
it("should filter by agentId when provided", async () => {
  // The mocked session yields a single record above the similarity threshold.
  const similarRecord = {
    get: vi.fn((key) => {
      switch (key) {
        case "id":
          return "mem-1";
        case "text":
          return "similar text";
        case "similarity":
          return 0.96;
        default:
          return null;
      }
    }),
  };
  mockSession.run.mockResolvedValue({ records: [similarRecord] });

  const matches = await client.findSimilar([0.1, 0.2, 0.3], 0.95, 5, "agent-1");

  expect(matches).toHaveLength(1);
  expect(matches[0]).toEqual({ id: "mem-1", text: "similar text", score: 0.96 });

  // Both the Cypher text and the parameter map must carry the agent filter.
  const queryWithAgentFilter = expect.stringContaining("node.agentId = $agentId");
  const paramsWithAgentId = expect.objectContaining({ agentId: "agent-1" });
  expect(mockSession.run).toHaveBeenCalledWith(queryWithAgentFilter, paramsWithAgentId);
});
it("should fetch extra candidates and trim when agentId is provided", async () => {
  // Builds a mocked Neo4j record exposing id / text / similarity via get().
  const makeRecord = (id: string, text: string, similarity: number) => ({
    get: vi.fn((key) => {
      switch (key) {
        case "id":
          return id;
        case "text":
          return text;
        case "similarity":
          return similarity;
        default:
          return null;
      }
    }),
  });

  // The session returns two candidates, best first.
  mockSession.run.mockResolvedValue({
    records: [makeRecord("mem-1", "text 1", 0.99), makeRecord("mem-2", "text 2", 0.97)],
  });

  // With limit=1 and an agentId, findSimilar over-fetches (limit * 3) and
  // must trim the post-filtered results back down to a single entry.
  const matches = await client.findSimilar([0.1, 0.2], 0.95, 1, "agent-1");

  expect(matches).toHaveLength(1);
  expect(matches[0].id).toBe("mem-1");
});
});
// ------------------------------------------------------------------------
@@ -1298,6 +1352,15 @@ describe("Neo4jMemoryClient", () => {
extractionRetries: 1,
});
});
it("should not pass agentId: undefined in listPendingExtractions params", async () => {
  mockSession.run.mockResolvedValue({ records: [] });

  await client.listPendingExtractions(50);

  // The second argument of the first session.run call is the Cypher parameter map;
  // without an agentId it must not contain the key at all.
  const [firstCall] = mockSession.run.mock.calls;
  const cypherParams = firstCall[1] as Record<string, unknown>;
  expect(cypherParams).not.toHaveProperty("agentId");
});
});
// ------------------------------------------------------------------------
@@ -1356,8 +1419,49 @@ describe("Neo4jMemoryClient", () => {
const result = await client.bm25Search("test query", 10);
expect(result).toHaveLength(1);
// Score should be normalized (divided by max)
// Single result: score should be moderate 0.5 (not 1.0) to avoid inflating weak matches
expect(result[0].score).toBe(0.5);
});
it("should normalize BM25 scores with min-max when multiple results exist", async () => {
  // Wraps a plain field map in a mocked record whose get() looks fields up by name.
  const toRecord = (fields: Record<string, unknown>) => ({
    get: vi.fn((key) => fields[key]),
  });

  mockSession.run.mockResolvedValue({
    records: [
      toRecord({
        id: "m1",
        text: "best match",
        category: "fact",
        importance: 0.8,
        createdAt: "2024-01-01",
        bm25Score: 10.0,
      }),
      toRecord({
        id: "m2",
        text: "worst match",
        category: "fact",
        importance: 0.5,
        createdAt: "2024-01-02",
        bm25Score: 2.0,
      }),
    ],
  });

  const hits = await client.bm25Search("test", 10);

  expect(hits).toHaveLength(2);
  const [best, worst] = hits;
  // Top raw score maps to 1.0: FLOOR + (1 - FLOOR) * 1.
  expect(best.score).toBe(1.0);
  // Lowest raw score maps to the FLOOR value (0.3).
  expect(worst.score).toBeCloseTo(0.3);
});
it("should escape Lucene special characters in BM25 query", async () => {
@@ -1404,6 +1508,75 @@ describe("Neo4jMemoryClient", () => {
});
});
// ------------------------------------------------------------------------
// reindex()
// ------------------------------------------------------------------------
describe("reindex", () => {
it("should use UNWIND batch update instead of individual queries", async () => {
  // Sessions are handed out in the order they are queued below:
  // drop index, fetch memories, batch update, recreate index.
  const dropSession = createMockSession();
  const fetchSession = createMockSession();
  const updateSession = createMockSession();
  const indexSession = createMockSession();

  // The fetch phase returns two memories to re-embed.
  const memoryRecord = (id: string, text: string) => ({
    get: vi.fn((key) => (key === "id" ? id : text)),
  });
  fetchSession.run.mockResolvedValueOnce({
    records: [memoryRecord("m1", "text 1"), memoryRecord("m2", "text 2")],
  });

  for (const session of [dropSession, fetchSession, updateSession, indexSession]) {
    mockDriver.session.mockReturnValueOnce(session);
  }

  // One embedding per fetched memory, in the same order.
  const embeddings = [
    [0.1, 0.2],
    [0.3, 0.4],
  ];
  const embedFn = vi.fn().mockResolvedValue(embeddings);

  await client.reindex(embedFn, { batchSize: 50 });

  // A single UNWIND statement carries the whole batch; no per-memory queries.
  expect(updateSession.run).toHaveBeenCalledTimes(1);
  const unwindQuery = expect.stringContaining("UNWIND $items");
  const batchParams = expect.objectContaining({
    items: [
      { id: "m1", embedding: [0.1, 0.2] },
      { id: "m2", embedding: [0.3, 0.4] },
    ],
  });
  expect(updateSession.run).toHaveBeenCalledWith(unwindQuery, batchParams);
});
it("should skip batch update when all embeddings are empty", async () => {
  // Only three sessions are queued: drop index, fetch memories, recreate index.
  const dropSession = createMockSession();
  const fetchSession = createMockSession();
  const indexSession = createMockSession();

  const onlyMemory = { get: vi.fn((key) => (key === "id" ? "m1" : "text 1")) };
  fetchSession.run.mockResolvedValueOnce({ records: [onlyMemory] });

  for (const session of [dropSession, fetchSession, indexSession]) {
    mockDriver.session.mockReturnValueOnce(session);
  }

  // The embedder returns an empty vector for the single memory.
  const embedFn = vi.fn().mockResolvedValue([[]]);

  await client.reindex(embedFn, { batchSize: 50 });

  // With nothing to write, no update session is opened: three sessions in total.
  expect(mockDriver.session).toHaveBeenCalledTimes(3);
});
});
// ------------------------------------------------------------------------
// Retrieval tracking
// ------------------------------------------------------------------------

View File

@@ -18,6 +18,10 @@ import type {
} from "./schema.js";
import { ALLOWED_RELATIONSHIP_TYPES, escapeLucene, validateRelationshipType } from "./schema.js";
// SAFETY: the alternation below is assembled solely from the hardcoded
// ALLOWED_RELATIONSHIP_TYPES constant, never from user input. It is interpolated
// into Cypher variable-length path patterns such as
// (e1)-[:WORKS_AT|LIVES_AT|...*1..N]-(e2); because every alternative comes from
// that fixed constant, the interpolation carries no injection risk.
const RELATIONSHIP_TYPE_PATTERN = Array.from(ALLOWED_RELATIONSHIP_TYPES).join("|");
// ============================================================================
@@ -501,7 +505,7 @@ export class Neo4jMemoryClient {
const FLOOR = 0.3; // Minimum normalized score for the lowest-ranked result
return records.map((r) => ({
...r,
score: range > 0 ? FLOOR + ((1 - FLOOR) * (r.rawScore - minScore)) / range : 1.0, // All scores identical → all get 1.0
score: range > 0 ? FLOOR + ((1 - FLOOR) * (r.rawScore - minScore)) / range : 0.5, // Single result or identical scores → moderate 0.5 to avoid inflating weak matches
}));
} finally {
await session.close();
@@ -609,7 +613,8 @@ export class Neo4jMemoryClient {
}
return Array.from(byId.values())
.toSorted((a, b) => b.score - a.score)
.slice()
.sort((a, b) => b.score - a.score)
.slice(0, limit);
} finally {
await session.close();
@@ -624,31 +629,45 @@ export class Neo4jMemoryClient {
/**
* Find similar memories by vector similarity. Used for deduplication.
* When agentId is provided, results are post-filtered to that agent
* (HNSW indexes don't support pre-filtering, so we fetch extra candidates).
*/
async findSimilar(
embedding: number[],
threshold: number = 0.95,
limit: number = 1,
agentId?: string,
): Promise<Array<{ id: string; text: string; score: number }>> {
await this.ensureInitialized();
try {
return await this.retryOnTransient(async () => {
const session = this.driver!.session();
try {
// Fetch extra candidates when filtering by agentId since HNSW
// doesn't support pre-filtering; post-filter and trim to limit.
const fetchLimit = agentId ? limit * 3 : limit;
const agentFilter = agentId ? "AND node.agentId = $agentId" : "";
const result = await session.run(
`CALL db.index.vector.queryNodes('memory_embedding_index', $limit, $embedding)
YIELD node, score
WHERE score >= $threshold
WHERE score >= $threshold ${agentFilter}
RETURN node.id AS id, node.text AS text, score AS similarity
ORDER BY score DESC`,
{ embedding, limit: neo4j.int(limit), threshold },
{
embedding,
limit: neo4j.int(fetchLimit),
threshold,
...(agentId ? { agentId } : {}),
},
);
return result.records.map((r) => ({
const results = result.records.map((r) => ({
id: r.get("id") as string,
text: r.get("text") as string,
score: r.get("similarity") as number,
}));
// Trim to requested limit after post-filtering
return agentId ? results.slice(0, limit) : results;
} finally {
await session.close();
}
@@ -1056,7 +1075,7 @@ export class Neo4jMemoryClient {
coalesce(m.extractionRetries, 0) AS extractionRetries
ORDER BY m.createdAt ASC
LIMIT $limit`,
{ limit: neo4j.int(limit), agentId },
{ limit: neo4j.int(limit), ...(agentId ? { agentId } : {}) },
);
return result.records.map((r) => ({
id: r.get("id") as string,
@@ -1082,7 +1101,7 @@ export class Neo4jMemoryClient {
`MATCH (m:Memory)
${agentFilter}
RETURN m.extractionStatus AS status, count(m) AS count`,
{ agentId },
agentId ? { agentId } : {},
);
const counts: Record<string, number> = {
pending: 0,
@@ -1193,51 +1212,64 @@ export class Neo4jMemoryClient {
return a < b ? `${a}:${b}` : `${b}:${a}`;
};
// Process vector queries in concurrent batches to avoid overwhelming Neo4j
// while still being much faster than fully sequential execution.
const DEDUP_CONCURRENCY = 8;
let pairsFound = 0;
for (const id of memoryData.keys()) {
// Retry individual vector queries on transient errors (each uses a fresh session)
const similar = await this.retryOnTransient(async () => {
const session = this.driver!.session();
try {
return await session.run(
`MATCH (src:Memory {id: $id})
CALL db.index.vector.queryNodes('memory_embedding_index', $k, src.embedding)
YIELD node, score
WHERE node.id <> $id AND score >= $threshold
RETURN node.id AS matchId, score`,
{ id, k: neo4j.int(10), threshold },
);
} finally {
await session.close();
}
});
const allIds = [...memoryData.keys()];
for (const r of similar.records) {
const matchId = r.get("matchId") as string;
if (memoryData.has(matchId)) {
union(id, matchId);
pairsFound++;
// Capture similarity score if requested
if (pairwiseSimilarities) {
const score = r.get("score") as number;
const pairKey = makePairKey(id, matchId);
// Keep the highest score if we see this pair multiple times
const existing = pairwiseSimilarities.get(pairKey);
if (existing === undefined || score > existing) {
pairwiseSimilarities.set(pairKey, score);
}
}
}
}
// Early exit if we've found many pairs (safety bound)
for (let batchStart = 0; batchStart < allIds.length; batchStart += DEDUP_CONCURRENCY) {
if (pairsFound > 500) {
this.logger.warn(
`memory-neo4j: findDuplicateClusters hit safety bound (500 pairs) — some duplicates may not be detected. Consider running with a higher threshold.`,
);
break;
}
const batch = allIds.slice(batchStart, batchStart + DEDUP_CONCURRENCY);
const results = await Promise.all(
batch.map((id) =>
this.retryOnTransient(async () => {
const session = this.driver!.session();
try {
return await session.run(
`MATCH (src:Memory {id: $id})
CALL db.index.vector.queryNodes('memory_embedding_index', $k, src.embedding)
YIELD node, score
WHERE node.id <> $id AND score >= $threshold
RETURN node.id AS matchId, score`,
{ id, k: neo4j.int(10), threshold },
);
} finally {
await session.close();
}
}),
),
);
for (let idx = 0; idx < batch.length; idx++) {
const id = batch[idx];
const similar = results[idx];
for (const r of similar.records) {
const matchId = r.get("matchId") as string;
if (memoryData.has(matchId)) {
union(id, matchId);
pairsFound++;
// Capture similarity score if requested
if (pairwiseSimilarities) {
const score = r.get("score") as number;
const pairKey = makePairKey(id, matchId);
// Keep the highest score if we see this pair multiple times
const existing = pairwiseSimilarities.get(pairKey);
if (existing === undefined || score > existing) {
pairwiseSimilarities.set(pairKey, score);
}
}
}
}
}
}
// Step 3: Group by root
@@ -1490,25 +1522,27 @@ export class Neo4jMemoryClient {
}
await this.ensureInitialized();
const session = this.driver!.session();
try {
// Atomic: decrement mentionCount and delete in a single Cypher statement
// to prevent inconsistent state if a crash occurs between operations
const result = await session.run(
`UNWIND $ids AS memId
MATCH (m:Memory {id: memId})
OPTIONAL MATCH (m)-[:MENTIONS]->(e:Entity)
SET e.mentionCount = CASE WHEN e.mentionCount > 0 THEN e.mentionCount - 1 ELSE 0 END
WITH m, count(e) AS _
DETACH DELETE m
RETURN count(*) AS deleted`,
{ ids: memoryIds },
);
return this.retryOnTransient(async () => {
const session = this.driver!.session();
try {
// Atomic: decrement mentionCount and delete in a single Cypher statement
// to prevent inconsistent state if a crash occurs between operations
const result = await session.run(
`UNWIND $ids AS memId
MATCH (m:Memory {id: memId})
OPTIONAL MATCH (m)-[:MENTIONS]->(e:Entity)
SET e.mentionCount = CASE WHEN e.mentionCount > 0 THEN e.mentionCount - 1 ELSE 0 END
WITH m, count(e) AS _
DETACH DELETE m
RETURN count(*) AS deleted`,
{ ids: memoryIds },
);
return (result.records[0]?.get("deleted") as number) ?? 0;
} finally {
await session.close();
}
return (result.records[0]?.get("deleted") as number) ?? 0;
} finally {
await session.close();
}
});
}
// --------------------------------------------------------------------------
@@ -1766,7 +1800,7 @@ export class Neo4jMemoryClient {
}
// Scores should already be sorted descending, but ensure it
const sorted = scores.toSorted((a, b) => b.effectiveScore - a.effectiveScore);
const sorted = [...scores].sort((a, b) => b.effectiveScore - a.effectiveScore);
// Find the index at the percentile boundary
// For top 20%, we want the score at index = 20% of total
@@ -1888,18 +1922,25 @@ export class Neo4jMemoryClient {
const batch = memories.slice(i, i + batchSize);
const vectors = await embedFn(batch.map((m) => m.text));
const session = this.driver!.session();
try {
for (let j = 0; j < batch.length; j++) {
if (vectors[j] && vectors[j].length > 0) {
await session.run("MATCH (m:Memory {id: $id}) SET m.embedding = $embedding", {
id: batch[j].id,
embedding: vectors[j],
});
}
// Build items array for batch UNWIND update
const items: Array<{ id: string; embedding: number[] }> = [];
for (let j = 0; j < batch.length; j++) {
if (vectors[j] && vectors[j].length > 0) {
items.push({ id: batch[j].id, embedding: vectors[j] });
}
}
if (items.length > 0) {
const session = this.driver!.session();
try {
await session.run(
`UNWIND $items AS item
MATCH (m:Memory {id: item.id})
SET m.embedding = item.embedding`,
{ items },
);
} finally {
await session.close();
}
} finally {
await session.close();
}
progress("memories", Math.min(i + batchSize, memories.length), memories.length);
}

View File

@@ -1,15 +1,20 @@
/**
* Tests for search.ts — Hybrid Search & RRF Fusion.
*
* Tests the exported pure logic: classifyQuery() and getAdaptiveWeights().
* Note: fuseWithConfidenceRRF() is not exported (private module-level function)
* and is tested indirectly through hybridSearch().
* Tests the exported pure logic: classifyQuery(), getAdaptiveWeights(), and fuseWithConfidenceRRF().
* hybridSearch() is tested with mocked Neo4j client and Embeddings.
*/
import { describe, it, expect, vi, beforeEach } from "vitest";
import type { Embeddings } from "./embeddings.js";
import type { Neo4jMemoryClient } from "./neo4j-client.js";
import type { SearchSignalResult } from "./schema.js";
import { classifyQuery, getAdaptiveWeights, hybridSearch } from "./search.js";
import {
classifyQuery,
getAdaptiveWeights,
fuseWithConfidenceRRF,
hybridSearch,
} from "./search.js";
// ============================================================================
// classifyQuery()
@@ -168,15 +173,27 @@ describe("getAdaptiveWeights", () => {
// ============================================================================
describe("hybridSearch", () => {
// Create mock db and embeddings
const mockDb = {
// Properly typed mocks matching the interfaces hybridSearch depends on.
// Using Pick<> to extract only the methods hybridSearch actually calls,
// so TypeScript will catch interface changes (e.g. renamed or removed methods).
type MockedDb = {
[K in keyof Pick<
Neo4jMemoryClient,
"vectorSearch" | "bm25Search" | "graphSearch" | "recordRetrievals"
>]: ReturnType<typeof vi.fn>;
};
type MockedEmbeddings = {
[K in keyof Pick<Embeddings, "embed" | "embedBatch">]: ReturnType<typeof vi.fn>;
};
const mockDb: MockedDb = {
vectorSearch: vi.fn(),
bm25Search: vi.fn(),
graphSearch: vi.fn(),
recordRetrievals: vi.fn(),
};
const mockEmbeddings = {
const mockEmbeddings: MockedEmbeddings = {
embed: vi.fn(),
embedBatch: vi.fn(),
};
@@ -204,8 +221,8 @@ describe("hybridSearch", () => {
mockDb.bm25Search.mockResolvedValue([]);
const results = await hybridSearch(
mockDb as never,
mockEmbeddings as never,
mockDb as unknown as Neo4jMemoryClient,
mockEmbeddings as unknown as Embeddings,
"test query",
5,
"agent-1",
@@ -224,8 +241,8 @@ describe("hybridSearch", () => {
mockDb.bm25Search.mockResolvedValue([bm25Result]);
const results = await hybridSearch(
mockDb as never,
mockEmbeddings as never,
mockDb as unknown as Neo4jMemoryClient,
mockEmbeddings as unknown as Embeddings,
"test query",
5,
"agent-1",
@@ -247,8 +264,8 @@ describe("hybridSearch", () => {
mockDb.bm25Search.mockResolvedValue([{ ...sharedResult, score: 0.85 }]);
const results = await hybridSearch(
mockDb as never,
mockEmbeddings as never,
mockDb as unknown as Neo4jMemoryClient,
mockEmbeddings as unknown as Embeddings,
"test query",
5,
"agent-1",
@@ -270,8 +287,8 @@ describe("hybridSearch", () => {
]);
const results = await hybridSearch(
mockDb as never,
mockEmbeddings as never,
mockDb as unknown as Neo4jMemoryClient,
mockEmbeddings as unknown as Embeddings,
"tell me about Tarun",
5,
"agent-1",
@@ -287,7 +304,14 @@ describe("hybridSearch", () => {
mockDb.vectorSearch.mockResolvedValue([]);
mockDb.bm25Search.mockResolvedValue([]);
await hybridSearch(mockDb as never, mockEmbeddings as never, "test query", 5, "agent-1", false);
await hybridSearch(
mockDb as unknown as Neo4jMemoryClient,
mockEmbeddings as unknown as Embeddings,
"test query",
5,
"agent-1",
false,
);
expect(mockDb.graphSearch).not.toHaveBeenCalled();
});
@@ -301,8 +325,8 @@ describe("hybridSearch", () => {
mockDb.bm25Search.mockResolvedValue([]);
const results = await hybridSearch(
mockDb as never,
mockEmbeddings as never,
mockDb as unknown as Neo4jMemoryClient,
mockEmbeddings as unknown as Embeddings,
"test query",
3,
"agent-1",
@@ -319,7 +343,14 @@ describe("hybridSearch", () => {
]);
mockDb.bm25Search.mockResolvedValue([]);
await hybridSearch(mockDb as never, mockEmbeddings as never, "test query", 5, "agent-1", false);
await hybridSearch(
mockDb as unknown as Neo4jMemoryClient,
mockEmbeddings as unknown as Embeddings,
"test query",
5,
"agent-1",
false,
);
expect(mockDb.recordRetrievals).toHaveBeenCalledWith(["mem-1", "mem-2"]);
});
@@ -331,8 +362,8 @@ describe("hybridSearch", () => {
// Should not throw
const results = await hybridSearch(
mockDb as never,
mockEmbeddings as never,
mockDb as unknown as Neo4jMemoryClient,
mockEmbeddings as unknown as Embeddings,
"test query",
5,
"agent-1",
@@ -350,8 +381,8 @@ describe("hybridSearch", () => {
mockDb.bm25Search.mockResolvedValue([]);
const results = await hybridSearch(
mockDb as never,
mockEmbeddings as never,
mockDb as unknown as Neo4jMemoryClient,
mockEmbeddings as unknown as Embeddings,
"test query",
5,
"agent-1",
@@ -369,8 +400,8 @@ describe("hybridSearch", () => {
mockDb.bm25Search.mockResolvedValue([]);
await hybridSearch(
mockDb as never,
mockEmbeddings as never,
mockDb as unknown as Neo4jMemoryClient,
mockEmbeddings as unknown as Embeddings,
"test query",
5,
"agent-1",
@@ -387,7 +418,11 @@ describe("hybridSearch", () => {
mockDb.vectorSearch.mockResolvedValue([]);
mockDb.bm25Search.mockResolvedValue([]);
await hybridSearch(mockDb as never, mockEmbeddings as never, "test query");
await hybridSearch(
mockDb as unknown as Neo4jMemoryClient,
mockEmbeddings as unknown as Embeddings,
"test query",
);
expect(mockDb.vectorSearch).toHaveBeenCalledWith(
expect.any(Array),
@@ -397,3 +432,123 @@ describe("hybridSearch", () => {
);
});
});
// ============================================================================
// fuseWithConfidenceRRF()
// ============================================================================
describe("fuseWithConfidenceRRF", () => {
/**
 * Builds a search signal entry for the fusion tests.
 *
 * @param id - Candidate id.
 * @param score - Signal score for the candidate.
 * @param text - Memory text; defaults to `Memory <id>`.
 * @returns A "fact" entry with fixed importance (0.7) and creation timestamp.
 */
function makeSignal(id: string, score: number, text = `Memory ${id}`): SearchSignalResult {
  const entry: SearchSignalResult = {
    id,
    text,
    category: "fact",
    importance: 0.7,
    createdAt: "2025-01-01T00:00:00Z",
    score,
  };
  return entry;
}
it("should return empty array when all signals are empty", () => {
  // Three signals, none of which contributes a candidate.
  const emptySignals: SearchSignalResult[][] = [[], [], []];
  const fused = fuseWithConfidenceRRF(emptySignals, 60, [1.0, 1.0, 1.0]);
  expect(fused).toEqual([]);
});
it("should handle a single signal with results", () => {
  // Only the first signal has candidates; the other two are empty.
  const onlySignal = [makeSignal("a", 0.9), makeSignal("b", 0.5)];
  const fused = fuseWithConfidenceRRF([onlySignal, [], []], 60, [1.0, 1.0, 1.0]);

  expect(fused).toHaveLength(2);
  const [top, runnerUp] = fused;
  expect(top.id).toBe("a");
  expect(runnerUp.id).toBe("b");
  // The higher-ranked candidate must also carry the higher fused score.
  expect(top.rrfScore).toBeGreaterThan(runnerUp.rrfScore);
});
it("should boost candidates appearing in multiple signals", () => {
  const fromVector = [makeSignal("shared", 0.9), makeSignal("vec-only", 0.8)];
  const fromBm25 = [makeSignal("shared", 0.85)];

  const fused = fuseWithConfidenceRRF([fromVector, fromBm25, []], 60, [1.0, 1.0, 1.0]);

  // "shared" is present in two signals, so it should outrank "vec-only"
  // even though their individual scores are close.
  const [first, second] = fused;
  expect(first.id).toBe("shared");
  expect(second.id).toBe("vec-only");
});
it("should handle ties (same score, same rank) consistently", () => {
  const tiedSignal = [makeSignal("a", 0.5), makeSignal("b", 0.5)];
  const fused = fuseWithConfidenceRRF([tiedSignal], 60, [1.0]);

  expect(fused).toHaveLength(2);
  // Equal scores: the entry listed first in the signal holds rank 1 and
  // therefore the higher RRF contribution than the entry at rank 2.
  const [first, second] = fused;
  expect(first.id).toBe("a");
  expect(second.id).toBe("b");
});
it("should respect different k values", () => {
  const signal = [makeSignal("a", 0.9), makeSignal("b", 0.5)];

  // Ratio of the top fused score to the second one for a given k.
  const topToSecondRatio = (k: number) => {
    const fused = fuseWithConfidenceRRF([signal], k, [1.0]);
    return fused[0].rrfScore / fused[1].rrfScore;
  };

  // A small k amplifies rank differences while a large k smooths them,
  // so the ratio must be larger for k = 1 than for k = 1000.
  const ratioSmallK = topToSecondRatio(1);
  const ratioLargeK = topToSecondRatio(1000);
  expect(ratioSmallK).toBeGreaterThan(ratioLargeK);
});
it("should handle zero-score entries", () => {
  const withZero = [makeSignal("a", 0.9), makeSignal("b", 0)];
  const fused = fuseWithConfidenceRRF([withZero], 60, [1.0]);

  expect(fused).toHaveLength(2);
  const [scored, unscored] = fused;
  // A zero signal score contributes nothing to the fused score.
  expect(unscored.rrfScore).toBe(0);
  expect(scored.rrfScore).toBeGreaterThan(0);
});
it("should apply signal weights correctly", () => {
  // The same item appears in two signals; only the per-signal weights differ
  // between the two fusion runs.
  const signal1 = [makeSignal("a", 0.8)];
  const signal2 = [makeSignal("a", 0.8)];
  const resultEqual = fuseWithConfidenceRRF([signal1, signal2], 60, [1.0, 1.0]);
  const resultWeighted = fuseWithConfidenceRRF([signal1, signal2], 60, [2.0, 0.5]);

  // Both runs must surface the same single item.
  expect(resultEqual[0].id).toBe("a");
  expect(resultWeighted[0].id).toBe("a");

  // With unequal weights the overall score must differ. toBeCloseTo defaults to
  // 2 decimal digits (a tolerance of 0.005), which is the same order of magnitude
  // as the fused scores here (presumably ~1/(k + rank) per signal with k = 60),
  // so compare at 5 digits to make the check meaningful at RRF scale.
  const RRF_COMPARE_DIGITS = 5;
  expect(resultEqual[0].rrfScore).not.toBeCloseTo(resultWeighted[0].rrfScore, RRF_COMPARE_DIGITS);
});
it("should sort results by RRF score descending", () => {
  // One candidate per signal, deliberately supplied out of score order.
  const signals = [
    [makeSignal("low", 0.3)],
    [makeSignal("high", 0.95)],
    [makeSignal("mid", 0.6)],
  ];
  const fused = fuseWithConfidenceRRF(signals, 60, [1.0, 1.0, 1.0]);

  // Each position in the fused list must hold the candidate with the matching score rank.
  const expectedOrder = ["high", "mid", "low"];
  expectedOrder.forEach((expectedId, position) => {
    expect(fused[position].id).toBe(expectedId);
  });
});
it("should deduplicate within a single signal (keep first occurrence)", () => {
  // "dup" is listed twice in the same signal; the second listing should be ignored.
  const withDuplicate = [
    makeSignal("dup", 0.9),
    makeSignal("dup", 0.5),
    makeSignal("other", 0.7),
  ];
  const fused = fuseWithConfidenceRRF([withDuplicate], 60, [1.0]);

  // "dup" is still present in the fused output...
  const dupEntry = fused.find((entry) => entry.id === "dup");
  expect(dupEntry).toBeDefined();
  // ...and only the two unique candidates remain.
  expect(fused).toHaveLength(2);
});
});

View File

@@ -120,7 +120,7 @@ type FusedCandidate = {
*
* Reference: Cormack et al. (2009), extended with confidence weighting.
*/
function fuseWithConfidenceRRF(
export function fuseWithConfidenceRRF(
signals: SearchSignalResult[][],
k: number,
weights: number[],
@@ -249,9 +249,12 @@ export async function hybridSearch(
// 4. Fuse with confidence-weighted RRF
const fused = fuseWithConfidenceRRF([vectorResults, bm25Results, graphResults], rrfK, weights);
// 5. Return top results, normalized to 0-100% display scores
const maxRrf = fused.length > 0 ? fused[0].rrfScore : 1;
const normalizer = maxRrf > 0 ? 1 / maxRrf : 1;
// 5. Return top results, normalized to 0-100% display scores.
// Only normalize when maxRrf is above a minimum threshold to avoid
// inflating weak matches (e.g., a single low-score result becoming 1.0).
const maxRrf = fused.length > 0 ? fused[0].rrfScore : 0;
const MIN_RRF_FOR_NORMALIZATION = 0.01;
const normalizer = maxRrf >= MIN_RRF_FOR_NORMALIZATION ? 1 / maxRrf : 1;
const results = fused.slice(0, limit).map((r) => ({
id: r.id,