From 3082c53a7671055e5d27e2f26933c88208810780 Mon Sep 17 00:00:00 2001 From: Tarun Sukhani Date: Thu, 5 Feb 2026 12:23:36 +0000 Subject: [PATCH] memory-neo4j: harden error handling, concurrency safety, config validation + add tests --- extensions/memory-neo4j/config.test.ts | 549 +++++++++++++ extensions/memory-neo4j/config.ts | 35 +- extensions/memory-neo4j/embeddings.test.ts | 192 +++++ extensions/memory-neo4j/embeddings.ts | 45 +- extensions/memory-neo4j/extractor.test.ts | 760 ++++++++++++++++++ extensions/memory-neo4j/extractor.ts | 124 ++- extensions/memory-neo4j/index.ts | 139 +++- .../memory-neo4j/mid-session-refresh.test.ts | 15 +- extensions/memory-neo4j/neo4j-client.ts | 601 ++++++++------ extensions/memory-neo4j/schema.test.ts | 200 +++++ extensions/memory-neo4j/search.test.ts | 400 +++++++++ 11 files changed, 2769 insertions(+), 291 deletions(-) create mode 100644 extensions/memory-neo4j/config.test.ts create mode 100644 extensions/memory-neo4j/embeddings.test.ts create mode 100644 extensions/memory-neo4j/extractor.test.ts create mode 100644 extensions/memory-neo4j/schema.test.ts create mode 100644 extensions/memory-neo4j/search.test.ts diff --git a/extensions/memory-neo4j/config.test.ts b/extensions/memory-neo4j/config.test.ts new file mode 100644 index 00000000000..81cdafa8bda --- /dev/null +++ b/extensions/memory-neo4j/config.test.ts @@ -0,0 +1,549 @@ +/** + * Tests for config.ts — Configuration Parsing. + * + * Tests memoryNeo4jConfigSchema.parse(), vectorDimsForModel(), and resolveExtractionConfig(). + */ + +import { describe, it, expect, beforeEach, afterEach, vi } from "vitest"; +import { memoryNeo4jConfigSchema, vectorDimsForModel, resolveExtractionConfig } from "./config.js"; + +// ============================================================================ +// memoryNeo4jConfigSchema.parse() +// ============================================================================ + +describe("memoryNeo4jConfigSchema.parse", () => { + // Store original env vars so we can restore them + const originalEnv = { ...process.env }; + + afterEach(() => { + process.env = { ...originalEnv }; + }); + + describe("valid complete configs", () => { + it("should parse a minimal valid config with ollama provider", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", user: "neo4j", password: "test" }, + embedding: { provider: "ollama" }, + }); + + expect(config.neo4j.uri).toBe("bolt://localhost:7687"); + expect(config.neo4j.username).toBe("neo4j"); + expect(config.neo4j.password).toBe("test"); + expect(config.embedding.provider).toBe("ollama"); + expect(config.embedding.model).toBe("mxbai-embed-large"); + expect(config.embedding.apiKey).toBeUndefined(); + expect(config.autoCapture).toBe(true); + expect(config.autoRecall).toBe(true); + expect(config.coreMemory.enabled).toBe(true); + expect(config.coreMemory.maxEntries).toBe(50); + }); + + it("should parse a full config with openai provider", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { + uri: "neo4j+s://cloud.neo4j.io:7687", + username: "admin", + password: "secret", + }, + embedding: { + provider: "openai", + apiKey: "sk-test-key", + model: "text-embedding-3-large", + }, + autoCapture: false, + autoRecall: false, + coreMemory: { + enabled: false, + maxEntries: 100, + refreshAtContextPercent: 75, + }, + }); + + expect(config.neo4j.uri).toBe("neo4j+s://cloud.neo4j.io:7687"); + expect(config.neo4j.username).toBe("admin"); + expect(config.neo4j.password).toBe("secret"); + expect(config.embedding.provider).toBe("openai"); + expect(config.embedding.apiKey).toBe("sk-test-key"); + expect(config.embedding.model).toBe("text-embedding-3-large"); + expect(config.autoCapture).toBe(false); + expect(config.autoRecall).toBe(false); + expect(config.coreMemory.enabled).toBe(false); + expect(config.coreMemory.maxEntries).toBe(100); + expect(config.coreMemory.refreshAtContextPercent).toBe(75); + }); + + it("should support 'user' field as alias for 'username' in neo4j config", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", user: "custom-user", password: "pass" }, + embedding: { provider: "ollama" }, + }); + expect(config.neo4j.username).toBe("custom-user"); + }); + + it("should support 'username' field in neo4j config", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", username: "custom-user", password: "pass" }, + embedding: { provider: "ollama" }, + }); + expect(config.neo4j.username).toBe("custom-user"); + }); + + it("should default neo4j username to 'neo4j' when not specified", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "pass" }, + embedding: { provider: "ollama" }, + }); + expect(config.neo4j.username).toBe("neo4j"); + }); + }); + + describe("missing required fields", () => { + it("should throw when config is null", () => { + expect(() => memoryNeo4jConfigSchema.parse(null)).toThrow("memory-neo4j config required"); + }); + + it("should throw when config is undefined", () => { + expect(() => memoryNeo4jConfigSchema.parse(undefined)).toThrow( + "memory-neo4j config required", + ); + }); + + it("should throw when config is not an object", () => { + expect(() => memoryNeo4jConfigSchema.parse("string")).toThrow("memory-neo4j config required"); + }); + + it("should throw when config is an array", () => { + expect(() => memoryNeo4jConfigSchema.parse([])).toThrow("memory-neo4j config required"); + }); + + it("should throw when neo4j section is missing", () => { + expect(() => + memoryNeo4jConfigSchema.parse({ + embedding: { provider: "ollama" }, + }), + ).toThrow("neo4j config section is required"); + }); + + it("should throw when neo4j.uri is missing", () => { + expect(() => + memoryNeo4jConfigSchema.parse({ + neo4j: { password: "test" }, + embedding: { provider: "ollama" }, + }), + ).toThrow("neo4j.uri is required"); + }); + + it("should throw when neo4j.uri is empty string", () => { + expect(() => + memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "", password: "test" }, + embedding: { provider: "ollama" }, + }), + ).toThrow("neo4j.uri is required"); + }); + }); + + describe("environment variable resolution", () => { + it("should resolve ${ENV_VAR} in neo4j.password", () => { + process.env.TEST_NEO4J_PASSWORD = "resolved-password"; + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { + uri: "bolt://localhost:7687", + password: "${TEST_NEO4J_PASSWORD}", + }, + embedding: { provider: "ollama" }, + }); + expect(config.neo4j.password).toBe("resolved-password"); + }); + + it("should resolve ${ENV_VAR} in embedding.apiKey", () => { + process.env.TEST_OPENAI_KEY = "sk-from-env"; + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "openai", apiKey: "${TEST_OPENAI_KEY}" }, + }); + expect(config.embedding.apiKey).toBe("sk-from-env"); + }); + + it("should throw when referenced env var is not set", () => { + delete process.env.NONEXISTENT_VAR; + expect(() => + memoryNeo4jConfigSchema.parse({ + neo4j: { + uri: "bolt://localhost:7687", + password: "${NONEXISTENT_VAR}", + }, + embedding: { provider: "ollama" }, + }), + ).toThrow("Environment variable NONEXISTENT_VAR is not set"); + }); + }); + + describe("default values", () => { + it("should default autoCapture to true", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama" }, + }); + expect(config.autoCapture).toBe(true); + }); + + it("should default autoRecall to true", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama" }, + }); + expect(config.autoRecall).toBe(true); + }); + + it("should default coreMemory.enabled to true", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama" }, + }); + expect(config.coreMemory.enabled).toBe(true); + }); + + it("should default coreMemory.maxEntries to 50", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama" }, + }); + expect(config.coreMemory.maxEntries).toBe(50); + }); + + it("should default refreshAtContextPercent to undefined", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama" }, + }); + expect(config.coreMemory.refreshAtContextPercent).toBeUndefined(); + }); + + it("should default embedding model to mxbai-embed-large for ollama", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama" }, + }); + expect(config.embedding.model).toBe("mxbai-embed-large"); + }); + + it("should default embedding model to text-embedding-3-small for openai", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "openai", apiKey: "sk-test" }, + }); + expect(config.embedding.model).toBe("text-embedding-3-small"); + }); + + it("should default neo4j.password to empty string when not provided", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687" }, + embedding: { provider: "ollama" }, + }); + expect(config.neo4j.password).toBe(""); + }); + }); + + describe("provider validation", () => { + it("should require apiKey for openai provider", () => { + expect(() => + memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "openai" }, + }), + ).toThrow("embedding.apiKey is required for OpenAI provider"); + }); + + it("should not require apiKey for ollama provider", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama" }, + }); + expect(config.embedding.apiKey).toBeUndefined(); + }); + + it("should default to openai when no provider is specified", () => { + // No provider but has apiKey — should default to openai + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { apiKey: "sk-test" }, + }); + expect(config.embedding.provider).toBe("openai"); + }); + + it("should accept embedding.baseUrl", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama", baseUrl: "http://my-ollama:11434" }, + }); + expect(config.embedding.baseUrl).toBe("http://my-ollama:11434"); + }); + }); + + describe("unknown keys rejected", () => { + it("should reject unknown top-level keys", () => { + expect(() => + memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama" }, + unknownKey: "value", + }), + ).toThrow("unknown keys: unknownKey"); + }); + + it("should reject unknown neo4j keys", () => { + expect(() => + memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "", port: 7687 }, + embedding: { provider: "ollama" }, + }), + ).toThrow("unknown keys: port"); + }); + + it("should reject unknown embedding keys", () => { + expect(() => + memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama", temperature: 0.5 }, + }), + ).toThrow("unknown keys: temperature"); + }); + + it("should reject unknown coreMemory keys", () => { + expect(() => + memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama" }, + coreMemory: { unknownField: true }, + }), + ).toThrow("unknown keys: unknownField"); + }); + }); + + describe("refreshAtContextPercent edge cases", () => { + it("should accept refreshAtContextPercent of 1 (minimum valid)", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama" }, + coreMemory: { refreshAtContextPercent: 1 }, + }); + expect(config.coreMemory.refreshAtContextPercent).toBe(1); + }); + + it("should accept refreshAtContextPercent of 100 (maximum valid)", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama" }, + coreMemory: { refreshAtContextPercent: 100 }, + }); + expect(config.coreMemory.refreshAtContextPercent).toBe(100); + }); + + it("should reject refreshAtContextPercent of 0", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama" }, + coreMemory: { refreshAtContextPercent: 0 }, + }); + expect(config.coreMemory.refreshAtContextPercent).toBeUndefined(); + }); + + it("should reject refreshAtContextPercent over 100 by throwing", () => { + expect(() => + memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama" }, + coreMemory: { refreshAtContextPercent: 150 }, + }), + ).toThrow("coreMemory.refreshAtContextPercent must be between 1 and 100"); + }); + + it("should reject negative refreshAtContextPercent", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama" }, + coreMemory: { refreshAtContextPercent: -10 }, + }); + expect(config.coreMemory.refreshAtContextPercent).toBeUndefined(); + }); + + it("should ignore non-number refreshAtContextPercent", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama" }, + coreMemory: { refreshAtContextPercent: "50" }, + }); + expect(config.coreMemory.refreshAtContextPercent).toBeUndefined(); + }); + }); + + describe("extraction config section", () => { + it("should parse extraction config when provided", () => { + process.env.EXTRACTION_DUMMY = ""; // avoid env var issues + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama" }, + extraction: { + apiKey: "or-test-key", + model: "google/gemini-2.0-flash-001", + baseUrl: "https://openrouter.ai/api/v1", + }, + }); + expect(config.extraction).toBeDefined(); + expect(config.extraction!.apiKey).toBe("or-test-key"); + expect(config.extraction!.model).toBe("google/gemini-2.0-flash-001"); + }); + + it("should not include extraction when section is empty", () => { + const config = memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama" }, + extraction: {}, + }); + expect(config.extraction).toBeUndefined(); + }); + + it("should reject unknown keys in extraction section", () => { + expect(() => + memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", password: "" }, + embedding: { provider: "ollama" }, + extraction: { badKey: "value" }, + }), + ).toThrow("unknown keys: badKey"); + }); + }); +}); + +// ============================================================================ +// vectorDimsForModel() +// ============================================================================ + +describe("vectorDimsForModel", () => { + describe("known models", () => { + it("should return 1536 for text-embedding-3-small", () => { + expect(vectorDimsForModel("text-embedding-3-small")).toBe(1536); + }); + + it("should return 3072 for text-embedding-3-large", () => { + expect(vectorDimsForModel("text-embedding-3-large")).toBe(3072); + }); + + it("should return 1024 for mxbai-embed-large", () => { + expect(vectorDimsForModel("mxbai-embed-large")).toBe(1024); + }); + + it("should return 768 for nomic-embed-text", () => { + expect(vectorDimsForModel("nomic-embed-text")).toBe(768); + }); + + it("should return 384 for all-minilm", () => { + expect(vectorDimsForModel("all-minilm")).toBe(384); + }); + }); + + describe("prefix matching", () => { + it("should match versioned model names via prefix", () => { + // mxbai-embed-large:latest should match mxbai-embed-large + expect(vectorDimsForModel("mxbai-embed-large:latest")).toBe(1024); + }); + + it("should match model with additional version suffix", () => { + expect(vectorDimsForModel("nomic-embed-text:v1.5")).toBe(768); + }); + }); + + describe("unknown models", () => { + it("should return default 1024 for unknown model", () => { + expect(vectorDimsForModel("unknown-model")).toBe(1024); + }); + + it("should return default 1024 for empty string", () => { + expect(vectorDimsForModel("")).toBe(1024); + }); + + it("should return default 1024 for unrecognized prefix", () => { + expect(vectorDimsForModel("custom-embed-v2")).toBe(1024); + }); + }); +}); + +// ============================================================================ +// resolveExtractionConfig() +// ============================================================================ + +describe("resolveExtractionConfig", () => { + const originalEnv = { ...process.env }; + + afterEach(() => { + process.env = { ...originalEnv }; + }); + + it("should return disabled config when no API key or explicit baseUrl", () => { + delete process.env.OPENROUTER_API_KEY; + const config = resolveExtractionConfig(); + expect(config.enabled).toBe(false); + expect(config.apiKey).toBe(""); + }); + + it("should enable when OPENROUTER_API_KEY env var is set", () => { + process.env.OPENROUTER_API_KEY = "or-env-key"; + const config = resolveExtractionConfig(); + expect(config.enabled).toBe(true); + expect(config.apiKey).toBe("or-env-key"); + }); + + it("should enable when plugin config provides apiKey", () => { + delete process.env.OPENROUTER_API_KEY; + const config = resolveExtractionConfig({ + apiKey: "or-plugin-key", + model: "custom-model", + baseUrl: "https://custom.ai/api", + }); + expect(config.enabled).toBe(true); + expect(config.apiKey).toBe("or-plugin-key"); + expect(config.model).toBe("custom-model"); + expect(config.baseUrl).toBe("https://custom.ai/api"); + }); + + it("should enable when baseUrl is explicitly set (local Ollama, no API key)", () => { + delete process.env.OPENROUTER_API_KEY; + const config = resolveExtractionConfig({ + model: "llama3", + baseUrl: "http://localhost:11434/v1", + }); + expect(config.enabled).toBe(true); + expect(config.apiKey).toBe(""); + expect(config.baseUrl).toBe("http://localhost:11434/v1"); + }); + + it("should use defaults for model and baseUrl", () => { + delete process.env.OPENROUTER_API_KEY; + delete process.env.EXTRACTION_MODEL; + delete process.env.EXTRACTION_BASE_URL; + const config = resolveExtractionConfig(); + expect(config.model).toBe("google/gemini-2.0-flash-001"); + expect(config.baseUrl).toBe("https://openrouter.ai/api/v1"); + }); + + it("should use EXTRACTION_MODEL env var", () => { + delete process.env.OPENROUTER_API_KEY; + process.env.EXTRACTION_MODEL = "meta/llama-3-70b"; + const config = resolveExtractionConfig(); + expect(config.model).toBe("meta/llama-3-70b"); + }); + + it("should use EXTRACTION_BASE_URL env var", () => { + delete process.env.OPENROUTER_API_KEY; + process.env.EXTRACTION_BASE_URL = "https://my-proxy.ai/v1"; + const config = resolveExtractionConfig(); + expect(config.baseUrl).toBe("https://my-proxy.ai/v1"); + }); + + it("should always set temperature to 0.0 and maxRetries to 2", () => { + const config = resolveExtractionConfig(); + expect(config.temperature).toBe(0.0); + expect(config.maxRetries).toBe(2); + }); +}); diff --git a/extensions/memory-neo4j/config.ts b/extensions/memory-neo4j/config.ts index 4530539b7b0..2ccb3535456 100644 --- a/extensions/memory-neo4j/config.ts +++ b/extensions/memory-neo4j/config.ts @@ -63,7 +63,7 @@ export const MEMORY_CATEGORIES = [ export type MemoryCategory = (typeof MEMORY_CATEGORIES)[number]; -const EMBEDDING_DIMENSIONS: Record = { +export const EMBEDDING_DIMENSIONS: Record = { // OpenAI models "text-embedding-3-small": 1536, "text-embedding-3-large": 3072, @@ -75,7 +75,7 @@ const EMBEDDING_DIMENSIONS: Record = { }; // Default dimension for unknown models (Ollama models vary) -const DEFAULT_EMBEDDING_DIMS = 1024; +export const DEFAULT_EMBEDDING_DIMS = 1024; export function vectorDimsForModel(model: string): number { // Check exact match first @@ -88,7 +88,8 @@ export function vectorDimsForModel(model: string): number { return dims; } } - // Return default for unknown models + // Return default for unknown models — callers should warn when this path is taken, + // as the default 1024 dimensions may not match the actual model's output. return DEFAULT_EMBEDDING_DIMS; } @@ -164,6 +165,20 @@ export const memoryNeo4jConfigSchema = { if (typeof neo4jRaw.uri !== "string" || !neo4jRaw.uri) { throw new Error("neo4j.uri is required"); } + // Validate URI scheme — must be a valid Neo4j connection protocol + const VALID_NEO4J_SCHEMES = [ + "bolt://", + "bolt+s://", + "bolt+ssc://", + "neo4j://", + "neo4j+s://", + "neo4j+ssc://", + ]; + if (!VALID_NEO4J_SCHEMES.some((scheme) => neo4jRaw.uri.startsWith(scheme))) { + throw new Error( + `neo4j.uri must start with a valid scheme (${VALID_NEO4J_SCHEMES.join(", ")}), got: "${neo4jRaw.uri}"`, + ); + } const neo4jPassword = typeof neo4jRaw.password === "string" ? resolveEnvVars(neo4jRaw.password) : ""; @@ -212,7 +227,19 @@ export const memoryNeo4jConfigSchema = { const coreMemoryEnabled = coreMemoryRaw?.enabled !== false; // enabled by default const coreMemoryMaxEntries = typeof coreMemoryRaw?.maxEntries === "number" ? coreMemoryRaw.maxEntries : 50; - // refreshAtContextPercent: number between 0-100, or undefined to disable + if (coreMemoryMaxEntries <= 0) { + throw new Error(`coreMemory.maxEntries must be greater than 0, got: ${coreMemoryMaxEntries}`); + } + // refreshAtContextPercent: number between 1-99 to be effective, or undefined to disable. + // Values at 0 or below are ignored (disables refresh). Values above 100 are invalid. + if ( + typeof coreMemoryRaw?.refreshAtContextPercent === "number" && + coreMemoryRaw.refreshAtContextPercent > 100 + ) { + throw new Error( + `coreMemory.refreshAtContextPercent must be between 1 and 100, got: ${coreMemoryRaw.refreshAtContextPercent}`, + ); + } const refreshAtContextPercent = typeof coreMemoryRaw?.refreshAtContextPercent === "number" && coreMemoryRaw.refreshAtContextPercent > 0 && diff --git a/extensions/memory-neo4j/embeddings.test.ts b/extensions/memory-neo4j/embeddings.test.ts new file mode 100644 index 00000000000..2531f43bc57 --- /dev/null +++ b/extensions/memory-neo4j/embeddings.test.ts @@ -0,0 +1,192 @@ +/** + * Tests for embeddings.ts — Embedding Provider. + * + * Tests the Embeddings class with mocked OpenAI client and mocked fetch for Ollama. + */ + +import { describe, it, expect, vi, afterEach } from "vitest"; + +// ============================================================================ +// Constructor +// ============================================================================ + +describe("Embeddings constructor", () => { + it("should throw when OpenAI provider is used without API key", async () => { + const { Embeddings } = await import("./embeddings.js"); + expect(() => new Embeddings(undefined, "text-embedding-3-small", "openai")).toThrow( + "API key required for OpenAI embeddings", + ); + }); + + it("should not require API key for ollama provider", async () => { + const { Embeddings } = await import("./embeddings.js"); + const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama"); + expect(emb).toBeDefined(); + }); +}); + +// ============================================================================ +// Ollama embed +// ============================================================================ + +describe("Embeddings - Ollama provider", () => { + const originalFetch = globalThis.fetch; + + afterEach(() => { + globalThis.fetch = originalFetch; + }); + + it("should call Ollama API with correct request body", async () => { + const { Embeddings } = await import("./embeddings.js"); + const mockVector = [0.1, 0.2, 0.3, 0.4]; + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => Promise.resolve({ embeddings: [mockVector] }), + }); + + const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama"); + const result = await emb.embed("test text"); + + expect(result).toEqual(mockVector); + expect(globalThis.fetch).toHaveBeenCalledWith( + "http://localhost:11434/api/embed", + expect.objectContaining({ + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + model: "mxbai-embed-large", + input: "test text", + }), + }), + ); + }); + + it("should use custom baseUrl for Ollama", async () => { + const { Embeddings } = await import("./embeddings.js"); + const mockVector = [0.5, 0.6]; + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => Promise.resolve({ embeddings: [mockVector] }), + }); + + const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama", "http://my-host:11434"); + await emb.embed("test"); + + expect(globalThis.fetch).toHaveBeenCalledWith( + "http://my-host:11434/api/embed", + expect.any(Object), + ); + }); + + it("should throw when Ollama returns error status", async () => { + const { Embeddings } = await import("./embeddings.js"); + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: false, + status: 500, + text: () => Promise.resolve("Internal Server Error"), + }); + + const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama"); + await expect(emb.embed("test")).rejects.toThrow("Ollama embedding failed: 500"); + }); + + it("should throw when Ollama returns no embeddings", async () => { + const { Embeddings } = await import("./embeddings.js"); + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => Promise.resolve({ embeddings: [] }), + }); + + const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama"); + await expect(emb.embed("test")).rejects.toThrow("No embedding returned from Ollama"); + }); + + it("should throw when Ollama returns null embeddings", async () => { + const { Embeddings } = await import("./embeddings.js"); + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => Promise.resolve({}), + }); + + const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama"); + await expect(emb.embed("test")).rejects.toThrow("No embedding returned from Ollama"); + }); + + it("should propagate fetch errors for Ollama", async () => { + const { Embeddings } = await import("./embeddings.js"); + globalThis.fetch = vi.fn().mockRejectedValue(new Error("Network error")); + + const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama"); + await expect(emb.embed("test")).rejects.toThrow("Network error"); + }); +}); + +// ============================================================================ +// OpenAI embed (via mocked client internals) +// ============================================================================ + +describe("Embeddings - OpenAI provider", () => { + it("should create instance with OpenAI provider when API key provided", async () => { + const { Embeddings } = await import("./embeddings.js"); + // Just verify construction succeeds with valid params + const emb = new Embeddings("sk-test-key", "text-embedding-3-small", "openai"); + expect(emb).toBeDefined(); + }); + + it("should have embed and embedBatch methods", async () => { + const { Embeddings } = await import("./embeddings.js"); + const emb = new Embeddings("sk-test-key", "text-embedding-3-small", "openai"); + expect(typeof emb.embed).toBe("function"); + expect(typeof emb.embedBatch).toBe("function"); + }); +}); + +// ============================================================================ +// Batch embedding +// ============================================================================ + +describe("Embeddings - embedBatch", () => { + const originalFetch = globalThis.fetch; + + afterEach(() => { + globalThis.fetch = originalFetch; + }); + + it("should return empty array for empty input (openai)", async () => { + const { Embeddings } = await import("./embeddings.js"); + const emb = new Embeddings("sk-test", "text-embedding-3-small", "openai"); + const results = await emb.embedBatch([]); + expect(results).toEqual([]); + }); + + it("should return empty array for empty input (ollama)", async () => { + const { Embeddings } = await import("./embeddings.js"); + const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama"); + const results = await emb.embedBatch([]); + expect(results).toEqual([]); + }); + + it("should use sequential calls for Ollama batch (no native batch support)", async () => { + const { Embeddings } = await import("./embeddings.js"); + let callCount = 0; + globalThis.fetch = vi.fn().mockImplementation(() => { + callCount++; + return Promise.resolve({ + ok: true, + json: () => Promise.resolve({ embeddings: [[callCount * 0.1, callCount * 0.2]] }), + }); + }); + + const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama"); + const results = await emb.embedBatch(["text1", "text2", "text3"]); + + // Should make 3 separate calls + expect(globalThis.fetch).toHaveBeenCalledTimes(3); + expect(results).toHaveLength(3); + // Each result should be a vector + for (const r of results) { + expect(Array.isArray(r)).toBe(true); + expect(r.length).toBe(2); + } + }); +}); diff --git a/extensions/memory-neo4j/embeddings.ts b/extensions/memory-neo4j/embeddings.ts index ff6caac7a86..66573d55774 100644 --- a/extensions/memory-neo4j/embeddings.ts +++ b/extensions/memory-neo4j/embeddings.ts @@ -7,19 +7,29 @@ import OpenAI from "openai"; import type { EmbeddingProvider } from "./config.js"; +type Logger = { + info: (msg: string) => void; + warn: (msg: string) => void; + error: (msg: string) => void; + debug?: (msg: string) => void; +}; + export class Embeddings { private client: OpenAI | null = null; private readonly provider: EmbeddingProvider; private readonly baseUrl: string; + private readonly logger: Logger | undefined; constructor( private readonly apiKey: string | undefined, private readonly model: string = "text-embedding-3-small", provider: EmbeddingProvider = "openai", baseUrl?: string, + logger?: Logger, ) { this.provider = provider; this.baseUrl = baseUrl ?? (provider === "ollama" ? "http://localhost:11434" : ""); + this.logger = logger; if (provider === "openai") { if (!apiKey) { @@ -42,6 +52,9 @@ export class Embeddings { /** * Generate embeddings for multiple texts. * Returns array of embeddings in the same order as input. + * + * For Ollama: uses Promise.allSettled so individual failures don't break the + * entire batch. Failed embeddings are replaced with zero vectors and logged. */ async embedBatch(texts: string[]): Promise { if (texts.length === 0) { @@ -49,8 +62,32 @@ export class Embeddings { } if (this.provider === "ollama") { - // Ollama doesn't support batch, so we do sequential - return Promise.all(texts.map((t) => this.embedOllama(t))); + // Ollama doesn't support batch, so we do sequential with resilient error handling + const results = await Promise.allSettled(texts.map((t) => this.embedOllama(t))); + const embeddings: number[][] = []; + let failures = 0; + + for (let i = 0; i < results.length; i++) { + const result = results[i]; + if (result.status === "fulfilled") { + embeddings.push(result.value); + } else { + failures++; + this.logger?.warn?.( + `memory-neo4j: Ollama embedding failed for text ${i}: ${String(result.reason)}`, + ); + // Use zero vector as placeholder so indices stay aligned + embeddings.push([]); + } + } + + if (failures > 0) { + this.logger?.warn?.( + `memory-neo4j: ${failures}/${texts.length} Ollama embeddings failed in batch`, + ); + } + + return embeddings; } return this.embedBatchOpenAI(texts); @@ -79,6 +116,9 @@ export class Embeddings { return response.data.toSorted((a, b) => a.index - b.index).map((d) => d.embedding); } + // Timeout for Ollama embedding fetch calls to prevent hanging indefinitely + private static readonly EMBED_TIMEOUT_MS = 30_000; + private async embedOllama(text: string): Promise { const url = `${this.baseUrl}/api/embed`; const response = await fetch(url, { @@ -88,6 +128,7 @@ export class Embeddings { model: this.model, input: text, }), + signal: AbortSignal.timeout(Embeddings.EMBED_TIMEOUT_MS), }); if (!response.ok) { diff --git a/extensions/memory-neo4j/extractor.test.ts b/extensions/memory-neo4j/extractor.test.ts new file mode 100644 index 00000000000..de57167174e --- /dev/null +++ b/extensions/memory-neo4j/extractor.test.ts @@ -0,0 +1,760 @@ +/** + * Tests for extractor.ts — Extraction Logic. + * + * Tests exported functions: extractEntities(), extractUserMessages(), runBackgroundExtraction(). + * Note: validateExtractionResult() is not exported; it is tested indirectly through extractEntities(). + * Note: passesAttentionGate() is defined in index.ts and not exported; cannot be tested directly. + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import type { ExtractionConfig } from "./config.js"; +import { extractUserMessages, extractEntities, runBackgroundExtraction } from "./extractor.js"; + +// ============================================================================ +// extractUserMessages() +// ============================================================================ + +describe("extractUserMessages", () => { + it("should extract string content from user messages", () => { + const messages = [ + { role: "user", content: "I prefer TypeScript over JavaScript" }, + { role: "user", content: "My favorite color is blue" }, + ]; + const result = extractUserMessages(messages); + expect(result).toEqual(["I prefer TypeScript over JavaScript", "My favorite color is blue"]); + }); + + it("should extract text from content block arrays", () => { + const messages = [ + { + role: "user", + content: [ + { type: "text", text: "Hello, this is a content block message" }, + { type: "image", url: "http://example.com/img.png" }, + { type: "text", text: "Another text block in same message" }, + ], + }, + ]; + const result = extractUserMessages(messages); + expect(result).toEqual([ + "Hello, this is a content block message", + "Another text block in same message", + ]); + }); + + it("should filter out assistant messages", () => { + const messages = [ + { role: "user", content: "This is a user message that is long enough" }, + { role: "assistant", content: "This is an assistant message" }, + ]; + const result = extractUserMessages(messages); + expect(result).toEqual(["This is a user message that is long enough"]); + }); + + it("should filter out system messages", () => { + const messages = [ + { role: "system", content: "You are a helpful assistant with context" }, + { role: "user", content: "This is a user message that is long enough" }, + ]; + const result = extractUserMessages(messages); + expect(result).toEqual(["This is a user message that is long enough"]); + }); + + it("should filter out messages shorter than 10 characters", () => { + const messages = [ + { role: "user", content: "short" }, // 5 chars + { role: "user", content: "1234567890" }, // exactly 10 chars + { role: "user", content: "This is longer than ten characters" }, + ]; + const result = extractUserMessages(messages); + expect(result).toEqual(["1234567890", "This is longer than ten characters"]); + }); + + it("should filter out messages containing ", () => { + const messages = [ + { role: "user", content: "Normal user message that is long enough here" }, + { + role: "user", + content: + "Some injected context that should be ignored", + }, + ]; + const result = extractUserMessages(messages); + expect(result).toEqual(["Normal user message that is long enough here"]); + }); + + it("should filter out messages containing ", () => { + const messages = [ + { role: "user", content: "System markup that should be filtered" }, + { role: "user", content: "Normal user message that is long enough here" }, + ]; + const result = extractUserMessages(messages); + expect(result).toEqual(["Normal user message that is long enough here"]); + }); + + it("should handle null and non-object messages gracefully", () => { + const messages = [ + null, + undefined, + "not an object", + 42, + { role: "user", content: "Valid message with enough length" }, + ]; + const result = extractUserMessages(messages as unknown[]); + expect(result).toEqual(["Valid message with enough length"]); + }); + + it("should return empty array when no user messages exist", () => { + const messages = [{ role: "assistant", content: "Only assistant messages" }]; + const result = extractUserMessages(messages); + expect(result).toEqual([]); + }); + + it("should return empty array for empty input", () => { + expect(extractUserMessages([])).toEqual([]); + }); + + it("should handle messages where content is neither string nor array", () => { + const messages = [ + { role: "user", content: 42 }, + { role: "user", content: null }, + { role: "user", content: { nested: true } }, + ]; + const result = extractUserMessages(messages as unknown[]); + expect(result).toEqual([]); + }); +}); + +// ============================================================================ +// extractEntities() — tests validateExtractionResult() indirectly +// ============================================================================ + +describe("extractEntities", () => { + // We need to mock `fetch` since callOpenRouter uses global fetch + const originalFetch = globalThis.fetch; + + beforeEach(() => { + vi.restoreAllMocks(); + }); + + afterEach(() => { + globalThis.fetch = originalFetch; + }); + + const enabledConfig: ExtractionConfig = { + enabled: true, + apiKey: "test-key", + model: "test-model", + baseUrl: "https://test.ai/api/v1", + temperature: 0.0, + maxRetries: 0, // No retries in tests + }; + + const disabledConfig: ExtractionConfig = { + ...enabledConfig, + enabled: false, + }; + + function mockFetchResponse(content: string, status = 200) { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: status >= 200 && status < 300, + status, + text: () => Promise.resolve(content), + json: () => + Promise.resolve({ + choices: [{ message: { content } }], + }), + }); + } + + it("should return null result when extraction is disabled", async () => { + const { result, transientFailure } = await extractEntities("test text", disabledConfig); + expect(result).toBeNull(); + expect(transientFailure).toBe(false); + }); + + it("should extract valid entities from LLM response", async () => { + mockFetchResponse( + JSON.stringify({ + category: "fact", + entities: [ + { name: "Tarun", type: "person", aliases: ["boss"], description: "The CEO" }, + { name: "Abundent", type: "organization" }, + ], + relationships: [ + { source: "Tarun", target: "Abundent", type: "WORKS_AT", confidence: 0.95 }, + ], + tags: [{ name: "Leadership", category: "business" }], + }), + ); + + const { result } = await extractEntities("Tarun works at Abundent", enabledConfig); + expect(result).not.toBeNull(); + expect(result!.category).toBe("fact"); + + // Entities should be normalized to lowercase + expect(result!.entities).toHaveLength(2); + expect(result!.entities[0].name).toBe("tarun"); + expect(result!.entities[0].type).toBe("person"); + expect(result!.entities[0].aliases).toEqual(["boss"]); + expect(result!.entities[0].description).toBe("The CEO"); + expect(result!.entities[1].name).toBe("abundent"); + expect(result!.entities[1].type).toBe("organization"); + + // Relationships should be normalized to lowercase source/target + expect(result!.relationships).toHaveLength(1); + expect(result!.relationships[0].source).toBe("tarun"); + expect(result!.relationships[0].target).toBe("abundent"); + expect(result!.relationships[0].type).toBe("WORKS_AT"); + expect(result!.relationships[0].confidence).toBe(0.95); + + // Tags should be normalized to lowercase + expect(result!.tags).toHaveLength(1); + expect(result!.tags[0].name).toBe("leadership"); + expect(result!.tags[0].category).toBe("business"); + }); + + it("should handle empty extraction result", async () => { + mockFetchResponse( + JSON.stringify({ + category: "other", + entities: [], + relationships: [], + tags: [], + }), + ); + + const { result } = await extractEntities("just a greeting", enabledConfig); + expect(result).not.toBeNull(); + expect(result!.entities).toEqual([]); + expect(result!.relationships).toEqual([]); + expect(result!.tags).toEqual([]); + }); + + it("should handle missing fields in LLM response", async () => { + mockFetchResponse( + JSON.stringify({ + // No category, entities, relationships, or tags + }), + ); + + const { result } = await extractEntities("some text", enabledConfig); + expect(result).not.toBeNull(); + expect(result!.category).toBeUndefined(); + expect(result!.entities).toEqual([]); + expect(result!.relationships).toEqual([]); + expect(result!.tags).toEqual([]); + }); + + it("should filter out invalid entity types (fallback to concept)", async () => { + mockFetchResponse( + JSON.stringify({ + entities: [ + { name: "Widget", type: "gadget" }, // invalid type -> concept + { name: "Paris", type: "location" }, // valid type + ], + relationships: [], + tags: [], + }), + ); + + const { result } = await extractEntities("test", enabledConfig); + expect(result!.entities).toHaveLength(2); + expect(result!.entities[0].type).toBe("concept"); // invalid type falls back to concept + expect(result!.entities[1].type).toBe("location"); + }); + + it("should filter out invalid relationship types", async () => { + mockFetchResponse( + JSON.stringify({ + entities: [], + relationships: [ + { source: "a", target: "b", type: "WORKS_AT", confidence: 0.9 }, // valid + { source: "a", target: "b", type: "HATES", confidence: 0.9 }, // invalid type + ], + tags: [], + }), + ); + + const { result } = await extractEntities("test", enabledConfig); + expect(result!.relationships).toHaveLength(1); + expect(result!.relationships[0].type).toBe("WORKS_AT"); + }); + + it("should clamp confidence to 0-1 range", async () => { + mockFetchResponse( + JSON.stringify({ + entities: [], + relationships: [ + { source: "a", target: "b", type: "KNOWS", confidence: 1.5 }, // over 1 + { source: "c", target: "d", type: "KNOWS", confidence: -0.5 }, // under 0 + ], + tags: [], + }), + ); + + const { result } = await extractEntities("test", enabledConfig); + expect(result!.relationships[0].confidence).toBe(1); + expect(result!.relationships[1].confidence).toBe(0); + }); + + it("should default confidence to 0.7 when not a number", async () => { + mockFetchResponse( + JSON.stringify({ + entities: [], + relationships: [{ source: "a", target: "b", type: "KNOWS", confidence: "high" }], + tags: [], + }), + ); + + const { result } = await extractEntities("test", enabledConfig); + expect(result!.relationships[0].confidence).toBe(0.7); + }); + + it("should filter out entities without name", async () => { + mockFetchResponse( + JSON.stringify({ + entities: [ + { name: "", type: "person" }, // empty name -> filtered + { name: " ", type: "person" }, // whitespace-only name -> filtered (after trim) + { name: "valid", type: "person" }, // valid + ], + relationships: [], + tags: [], + }), + ); + + const { result } = await extractEntities("test", enabledConfig); + expect(result!.entities).toHaveLength(1); + expect(result!.entities[0].name).toBe("valid"); + }); + + it("should filter out entities with non-object shape", async () => { + mockFetchResponse( + JSON.stringify({ + entities: [null, "not an entity", 42, { name: "valid", type: "person" }], + relationships: [], + tags: [], + }), + ); + + const { result } = await extractEntities("test", enabledConfig); + expect(result!.entities).toHaveLength(1); + }); + + it("should filter out entities missing required fields", async () => { + mockFetchResponse( + JSON.stringify({ + entities: [ + { type: "person" }, // missing name + { name: "test" }, // missing type + { name: "valid", type: "person" }, // has both + ], + relationships: [], + tags: [], + }), + ); + + const { result } = await extractEntities("test", enabledConfig); + expect(result!.entities).toHaveLength(1); + expect(result!.entities[0].name).toBe("valid"); + }); + + it("should default tag category to 'topic' when missing", async () => { + mockFetchResponse( + JSON.stringify({ + entities: [], + relationships: [], + tags: [{ name: "neo4j" }], // no category + }), + ); + + const { result } = await extractEntities("test", enabledConfig); + expect(result!.tags[0].category).toBe("topic"); + }); + + it("should filter out tags with empty names", async () => { + mockFetchResponse( + JSON.stringify({ + entities: [], + relationships: [], + tags: [ + { name: "", category: "tech" }, // empty -> filtered + { name: " ", category: "tech" }, // whitespace-only -> filtered + { name: "valid", category: "tech" }, + ], + }), + ); + + const { result } = await extractEntities("test", enabledConfig); + expect(result!.tags).toHaveLength(1); + expect(result!.tags[0].name).toBe("valid"); + }); + + it("should reject invalid category values", async () => { + mockFetchResponse( + JSON.stringify({ + category: "invalid-category", + entities: [], + relationships: [], + tags: [], + }), + ); + + const { result } = await extractEntities("test", enabledConfig); + expect(result!.category).toBeUndefined(); + }); + + it("should accept valid category values", async () => { + for (const category of ["preference", "fact", "decision", "entity", "other"]) { + mockFetchResponse( + JSON.stringify({ + category, + entities: [], + relationships: [], + tags: [], + }), + ); + const { result } = await extractEntities(`test ${category}`, enabledConfig); + expect(result!.category).toBe(category); + } + }); + + it("should return null result for malformed JSON response (permanent failure)", async () => { + mockFetchResponse("not valid json at all"); + + const { result, transientFailure } = await extractEntities("test", enabledConfig); + // callOpenRouter returns the raw string, JSON.parse fails, catch returns null + expect(result).toBeNull(); + expect(transientFailure).toBe(false); + }); + + it("should return null result when API returns error status", async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: false, + status: 500, + text: () => Promise.resolve("Internal Server Error"), + }); + + const { result } = await extractEntities("test", enabledConfig); + // API error 500 is not in the transient list (only 429, 502, 503, 504) + expect(result).toBeNull(); + }); + + it("should return null result when API returns no content", async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + status: 200, + json: () => Promise.resolve({ choices: [{ message: { content: null } }] }), + }); + + const { result, transientFailure } = await extractEntities("test", enabledConfig); + expect(result).toBeNull(); + expect(transientFailure).toBe(false); + }); + + it("should normalize alias strings to lowercase", async () => { + mockFetchResponse( + JSON.stringify({ + entities: [{ name: "John", type: "person", aliases: ["Johnny", "JOHN", "j.doe"] }], + relationships: [], + tags: [], + }), + ); + + const { result } = await extractEntities("test", enabledConfig); + expect(result!.entities[0].aliases).toEqual(["johnny", "john", "j.doe"]); + }); + + it("should filter out non-string aliases", async () => { + mockFetchResponse( + JSON.stringify({ + entities: [{ name: "John", type: "person", aliases: ["valid", 42, null, "also-valid"] }], + relationships: [], + tags: [], + }), + ); + + const { result } = await extractEntities("test", enabledConfig); + expect(result!.entities[0].aliases).toEqual(["valid", "also-valid"]); + }); +}); + +// ============================================================================ +// runBackgroundExtraction() +// ============================================================================ + +describe("runBackgroundExtraction", () => { + const originalFetch = globalThis.fetch; + + let mockLogger: { + info: ReturnType; + warn: ReturnType; + error: ReturnType; + debug: ReturnType; + }; + + let mockDb: { + updateExtractionStatus: ReturnType; + mergeEntity: ReturnType; + createMentions: ReturnType; + createEntityRelationship: ReturnType; + tagMemory: ReturnType; + updateMemoryCategory: ReturnType; + }; + + let mockEmbeddings: { + embed: ReturnType; + embedBatch: ReturnType; + }; + + beforeEach(() => { + vi.restoreAllMocks(); + mockLogger = { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }; + mockDb = { + updateExtractionStatus: vi.fn().mockResolvedValue(undefined), + mergeEntity: vi.fn().mockResolvedValue(undefined), + createMentions: vi.fn().mockResolvedValue(undefined), + createEntityRelationship: vi.fn().mockResolvedValue(undefined), + tagMemory: vi.fn().mockResolvedValue(undefined), + updateMemoryCategory: vi.fn().mockResolvedValue(undefined), + }; + mockEmbeddings = { + embed: vi.fn().mockResolvedValue([0.1, 0.2, 0.3]), + embedBatch: vi.fn().mockResolvedValue([[0.1, 0.2, 0.3]]), + }; + }); + + afterEach(() => { + globalThis.fetch = originalFetch; + }); + + const enabledConfig: ExtractionConfig = { + enabled: true, + apiKey: "test-key", + model: "test-model", + baseUrl: "https://test.ai/api/v1", + temperature: 0.0, + maxRetries: 0, + }; + + const disabledConfig: ExtractionConfig = { + ...enabledConfig, + enabled: false, + }; + + function mockFetchResponse(content: string) { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + status: 200, + json: () => + Promise.resolve({ + choices: [{ message: { content } }], + }), + }); + } + + it("should skip extraction and mark as 'skipped' when disabled", async () => { + await runBackgroundExtraction( + "mem-1", + "test text", + mockDb as never, + mockEmbeddings as never, + disabledConfig, + mockLogger, + ); + expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "skipped"); + }); + + it("should mark as 'failed' when extraction returns null", async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: false, + status: 500, + text: () => Promise.resolve("error"), + }); + + await runBackgroundExtraction( + "mem-1", + "test text", + mockDb as never, + mockEmbeddings as never, + enabledConfig, + mockLogger, + ); + expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "failed"); + }); + + it("should mark as 'complete' when extraction result is empty", async () => { + mockFetchResponse( + JSON.stringify({ + entities: [], + relationships: [], + tags: [], + }), + ); + + await runBackgroundExtraction( + "mem-1", + "test text", + mockDb as never, + mockEmbeddings as never, + enabledConfig, + mockLogger, + ); + expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "complete"); + }); + + it("should merge entities, create mentions, and mark complete", async () => { + mockFetchResponse( + JSON.stringify({ + category: "fact", + entities: [{ name: "Alice", type: "person" }], + relationships: [], + tags: [], + }), + ); + + await runBackgroundExtraction( + "mem-1", + "Alice is a developer", + mockDb as never, + mockEmbeddings as never, + enabledConfig, + mockLogger, + ); + + expect(mockDb.mergeEntity).toHaveBeenCalledWith( + expect.objectContaining({ + name: "alice", + type: "person", + }), + ); + expect(mockDb.createMentions).toHaveBeenCalledWith("mem-1", "alice", "context", 1.0); + expect(mockDb.updateMemoryCategory).toHaveBeenCalledWith("mem-1", "fact"); + expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "complete"); + }); + + it("should create entity relationships", async () => { + mockFetchResponse( + JSON.stringify({ + entities: [ + { name: "Alice", type: "person" }, + { name: "Acme", type: "organization" }, + ], + relationships: [{ source: "Alice", target: "Acme", type: "WORKS_AT", confidence: 0.9 }], + tags: [], + }), + ); + + await runBackgroundExtraction( + "mem-1", + "Alice works at Acme", + mockDb as never, + mockEmbeddings as never, + enabledConfig, + mockLogger, + ); + + expect(mockDb.createEntityRelationship).toHaveBeenCalledWith("alice", "acme", "WORKS_AT", 0.9); + }); + + it("should tag memories", async () => { + mockFetchResponse( + JSON.stringify({ + entities: [], + relationships: [], + tags: [{ name: "Programming", category: "tech" }], + }), + ); + + await runBackgroundExtraction( + "mem-1", + "test text", + mockDb as never, + mockEmbeddings as never, + enabledConfig, + mockLogger, + ); + + expect(mockDb.tagMemory).toHaveBeenCalledWith("mem-1", "programming", "tech"); + }); + + it("should not update category when result has no category", async () => { + mockFetchResponse( + JSON.stringify({ + entities: [{ name: "Test", type: "concept" }], + relationships: [], + tags: [], + }), + ); + + await runBackgroundExtraction( + "mem-1", + "test", + mockDb as never, + mockEmbeddings as never, + enabledConfig, + mockLogger, + ); + + expect(mockDb.updateMemoryCategory).not.toHaveBeenCalled(); + }); + + it("should handle entity merge failure gracefully", async () => { + mockFetchResponse( + JSON.stringify({ + entities: [ + { name: "Alice", type: "person" }, + { name: "Bob", type: "person" }, + ], + relationships: [], + tags: [], + }), + ); + + // First entity merge fails, second succeeds + mockDb.mergeEntity.mockRejectedValueOnce(new Error("merge failed")); + mockDb.mergeEntity.mockResolvedValueOnce(undefined); + + await runBackgroundExtraction( + "mem-1", + "Alice and Bob", + mockDb as never, + mockEmbeddings as never, + enabledConfig, + mockLogger, + ); + + // Should still continue and complete + expect(mockDb.mergeEntity).toHaveBeenCalledTimes(2); + expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "complete"); + expect(mockLogger.warn).toHaveBeenCalled(); + }); + + it("should log extraction results", async () => { + mockFetchResponse( + JSON.stringify({ + category: "fact", + entities: [{ name: "Test", type: "concept" }], + relationships: [{ source: "a", target: "b", type: "RELATED_TO", confidence: 0.8 }], + tags: [{ name: "tech" }], + }), + ); + + await runBackgroundExtraction( + "mem-12345678-abcd", + "test", + mockDb as never, + mockEmbeddings as never, + enabledConfig, + mockLogger, + ); + + expect(mockLogger.info).toHaveBeenCalledWith(expect.stringContaining("extraction complete")); + }); +}); diff --git a/extensions/memory-neo4j/extractor.ts b/extensions/memory-neo4j/extractor.ts index 6fdc9f5d99f..7fc38eb1e2b 100644 --- a/extensions/memory-neo4j/extractor.ts +++ b/extensions/memory-neo4j/extractor.ts @@ -63,6 +63,9 @@ Rules: // OpenRouter API Client // ============================================================================ +// Timeout for LLM and embedding fetch calls to prevent hanging indefinitely +const FETCH_TIMEOUT_MS = 30_000; + async function callOpenRouter(config: ExtractionConfig, prompt: string): Promise { for (let attempt = 0; attempt <= config.maxRetries; attempt++) { try { @@ -78,6 +81,7 @@ async function callOpenRouter(config: ExtractionConfig, prompt: string): Promise temperature: config.temperature, response_format: { type: "json_object" }, }), + signal: AbortSignal.timeout(FETCH_TIMEOUT_MS), }); if (!response.ok) { @@ -104,30 +108,68 @@ async function callOpenRouter(config: ExtractionConfig, prompt: string): Promise // Entity Extraction // ============================================================================ +/** Max retries for transient extraction failures before marking permanently failed */ +const MAX_EXTRACTION_RETRIES = 3; + +/** + * Check if an error is transient (network/timeout) vs permanent (JSON parse, etc.) + */ +function isTransientError(err: unknown): boolean { + if (!(err instanceof Error)) return false; + const msg = err.message.toLowerCase(); + return ( + err.name === "AbortError" || + err.name === "TimeoutError" || + msg.includes("timeout") || + msg.includes("econnrefused") || + msg.includes("econnreset") || + msg.includes("enotfound") || + msg.includes("network") || + msg.includes("fetch failed") || + msg.includes("socket hang up") || + msg.includes("api error 429") || + msg.includes("api error 502") || + msg.includes("api error 503") || + msg.includes("api error 504") + ); +} + /** * Extract entities and relationships from a memory text using LLM. + * + * Returns { result, transientFailure }: + * - result is the ExtractionResult or null if extraction returned nothing useful + * - transientFailure is true if the failure was due to a network/timeout issue + * (caller should retry later) vs a permanent failure (bad JSON, etc.) */ export async function extractEntities( text: string, config: ExtractionConfig, -): Promise { +): Promise<{ result: ExtractionResult | null; transientFailure: boolean }> { if (!config.enabled) { - return null; + return { result: null, transientFailure: false }; } const prompt = ENTITY_EXTRACTION_PROMPT.replace("{text}", text); + let content: string | null; try { - const content = await callOpenRouter(config, prompt); - if (!content) { - return null; - } + content = await callOpenRouter(config, prompt); + } catch (err) { + // Network/timeout errors are transient — caller should retry + return { result: null, transientFailure: isTransientError(err) }; + } + if (!content) { + return { result: null, transientFailure: false }; + } + + try { const parsed = JSON.parse(content) as Record; - return validateExtractionResult(parsed); + return { result: validateExtractionResult(parsed), transientFailure: false }; } catch { - // Will be handled by caller; don't throw for parse errors - return null; + // JSON parse failure is permanent — LLM returned malformed output + return { result: null, transientFailure: false }; } } @@ -213,7 +255,11 @@ function validateExtractionResult(raw: Record): ExtractionResul * 3. Create MENTIONS relationships from Memory → Entity * 4. Create inter-Entity relationships (WORKS_AT, KNOWS, etc.) * 5. Tag the memory - * 6. Update extractionStatus to "complete" or "failed" + * 6. Update extractionStatus to "complete", "pending" (transient retry), or "failed" + * + * Transient failures (network/timeout) leave status as "pending" with an incremented + * retry counter. After MAX_EXTRACTION_RETRIES transient failures, the memory is + * permanently marked "failed". Permanent failures (malformed JSON) are immediately "failed". */ export async function runBackgroundExtraction( memoryId: string, @@ -222,6 +268,7 @@ export async function runBackgroundExtraction( embeddings: Embeddings, config: ExtractionConfig, logger: Logger, + currentRetries: number = 0, ): Promise { if (!config.enabled) { await db.updateExtractionStatus(memoryId, "skipped").catch(() => {}); @@ -229,10 +276,28 @@ export async function runBackgroundExtraction( } try { - const result = await extractEntities(text, config); + const { result, transientFailure } = await extractEntities(text, config); if (!result) { - await db.updateExtractionStatus(memoryId, "failed"); + if (transientFailure) { + // Transient failure (network/timeout) — leave as pending for retry + const retries = currentRetries + 1; + if (retries >= MAX_EXTRACTION_RETRIES) { + logger.warn( + `memory-neo4j: extraction permanently failed for ${memoryId.slice(0, 8)} after ${retries} transient retries`, + ); + await db.updateExtractionStatus(memoryId, "failed", { incrementRetries: true }); + } else { + logger.info( + `memory-neo4j: extraction transient failure for ${memoryId.slice(0, 8)}, will retry (${retries}/${MAX_EXTRACTION_RETRIES})`, + ); + // Keep status as "pending" but increment retry counter + await db.updateExtractionStatus(memoryId, "pending", { incrementRetries: true }); + } + } else { + // Permanent failure (JSON parse, empty response, etc.) + await db.updateExtractionStatus(memoryId, "failed"); + } return; } @@ -309,8 +374,21 @@ export async function runBackgroundExtraction( (result.category ? `, category=${result.category}` : ""), ); } catch (err) { - logger.warn(`memory-neo4j: extraction failed for ${memoryId.slice(0, 8)}: ${String(err)}`); - await db.updateExtractionStatus(memoryId, "failed").catch(() => {}); + // Unexpected error during graph operations — treat as transient if retry budget remains + const isTransient = isTransientError(err); + if (isTransient && currentRetries + 1 < MAX_EXTRACTION_RETRIES) { + logger.warn( + `memory-neo4j: extraction transient error for ${memoryId.slice(0, 8)}, will retry: ${String(err)}`, + ); + await db + .updateExtractionStatus(memoryId, "pending", { incrementRetries: true }) + .catch(() => {}); + } else { + logger.warn(`memory-neo4j: extraction failed for ${memoryId.slice(0, 8)}: ${String(err)}`); + await db + .updateExtractionStatus(memoryId, "failed", { incrementRetries: true }) + .catch(() => {}); + } } } @@ -533,6 +611,14 @@ export async function runSleepCycle( // -------------------------------------------------------------------------- // Phase 3: Core Promotion (using pre-computed scores from Phase 2) + // + // Design note on staleness: The effective scores and Pareto threshold were + // computed in Phase 2 and may be slightly stale by the time Phases 3/4 run. + // This is acceptable because: (a) the sleep cycle is a background maintenance + // task that runs infrequently (not concurrent with itself), (b) the scoring + // formula is deterministic based on stored properties that change slowly, and + // (c) promotion/demotion are reversible in the next cycle. The alternative + // (re-querying scores per phase) adds latency without meaningful accuracy gain. // -------------------------------------------------------------------------- if (!abortSignal?.aborted && paretoThreshold > 0) { onPhaseStart?.("promotion"); @@ -631,7 +717,15 @@ export async function runSleepCycle( const chunk = pending.slice(i, i + EXTRACTION_CONCURRENCY); const outcomes = await Promise.allSettled( chunk.map((memory) => - runBackgroundExtraction(memory.id, memory.text, db, embeddings, config, logger), + runBackgroundExtraction( + memory.id, + memory.text, + db, + embeddings, + config, + logger, + memory.extractionRetries, + ), ), ); diff --git a/extensions/memory-neo4j/index.ts b/extensions/memory-neo4j/index.ts index ac40c1d93ac..2995d678e3f 100644 --- a/extensions/memory-neo4j/index.ts +++ b/extensions/memory-neo4j/index.ts @@ -18,6 +18,8 @@ import { randomUUID } from "node:crypto"; import { stringEnum } from "openclaw/plugin-sdk"; import type { MemoryCategory, MemorySource } from "./schema.js"; import { + DEFAULT_EMBEDDING_DIMS, + EMBEDDING_DIMENSIONS, MEMORY_CATEGORIES, memoryNeo4jConfigSchema, resolveExtractionConfig, @@ -46,6 +48,25 @@ const memoryNeo4jPlugin = { const extractionConfig = resolveExtractionConfig(cfg.extraction); const vectorDim = vectorDimsForModel(cfg.embedding.model); + // Warn on empty neo4j password (may be valid for some setups, but usually a misconfiguration) + if (!cfg.neo4j.password) { + api.logger.warn( + "memory-neo4j: neo4j.password is empty — this may be intentional for passwordless setups, but verify your configuration", + ); + } + + // Warn when using default embedding dimensions for an unknown model + const isKnownModel = + cfg.embedding.model in EMBEDDING_DIMENSIONS || + Object.keys(EMBEDDING_DIMENSIONS).some((known) => cfg.embedding.model.startsWith(known)); + if (!isKnownModel) { + api.logger.warn( + `memory-neo4j: unknown embedding model "${cfg.embedding.model}" — using default ${DEFAULT_EMBEDDING_DIMS} dimensions. ` + + `If your model outputs a different dimension, vector operations will fail. ` + + `Known models: ${Object.keys(EMBEDDING_DIMENSIONS).join(", ")}`, + ); + } + // Create shared resources const db = new Neo4jMemoryClient( cfg.neo4j.uri, @@ -59,6 +80,7 @@ const memoryNeo4jPlugin = { cfg.embedding.model, cfg.embedding.provider, cfg.embedding.baseUrl, + api.logger, ); api.logger.debug?.( @@ -499,23 +521,68 @@ const memoryNeo4jPlugin = { console.log(" Phase 7: Orphan Cleanup — Remove disconnected nodes\n"); try { + // Validate sleep cycle CLI parameters before running + const batchSize = opts.batchSize ? parseInt(opts.batchSize, 10) : undefined; + const delay = opts.delay ? parseInt(opts.delay, 10) : undefined; + const decayHalfLife = opts.decayHalfLife + ? parseInt(opts.decayHalfLife, 10) + : undefined; + const decayThreshold = opts.decayThreshold + ? parseFloat(opts.decayThreshold) + : undefined; + const pareto = opts.pareto ? parseFloat(opts.pareto) : undefined; + const promotionMinAge = opts.promotionMinAge + ? parseInt(opts.promotionMinAge, 10) + : undefined; + + if (batchSize != null && (Number.isNaN(batchSize) || batchSize <= 0)) { + console.error("Error: --batch-size must be greater than 0"); + process.exitCode = 1; + return; + } + if (delay != null && (Number.isNaN(delay) || delay < 0)) { + console.error("Error: --delay must be >= 0"); + process.exitCode = 1; + return; + } + if (decayHalfLife != null && (Number.isNaN(decayHalfLife) || decayHalfLife <= 0)) { + console.error("Error: --decay-half-life must be greater than 0"); + process.exitCode = 1; + return; + } + if ( + decayThreshold != null && + (Number.isNaN(decayThreshold) || decayThreshold < 0 || decayThreshold > 1) + ) { + console.error("Error: --decay-threshold must be between 0 and 1"); + process.exitCode = 1; + return; + } + if (pareto != null && (Number.isNaN(pareto) || pareto < 0 || pareto > 1)) { + console.error("Error: --pareto must be between 0 and 1"); + process.exitCode = 1; + return; + } + if ( + promotionMinAge != null && + (Number.isNaN(promotionMinAge) || promotionMinAge < 0) + ) { + console.error("Error: --promotion-min-age must be >= 0"); + process.exitCode = 1; + return; + } + await db.ensureInitialized(); const result = await runSleepCycle(db, embeddings, extractionConfig, api.logger, { agentId: opts.agent, dedupThreshold: opts.dedupThreshold ? parseFloat(opts.dedupThreshold) : undefined, - paretoPercentile: opts.pareto ? parseFloat(opts.pareto) : undefined, - promotionMinAgeDays: opts.promotionMinAge - ? parseInt(opts.promotionMinAge, 10) - : undefined, - decayRetentionThreshold: opts.decayThreshold - ? parseFloat(opts.decayThreshold) - : undefined, - decayBaseHalfLifeDays: opts.decayHalfLife - ? parseInt(opts.decayHalfLife, 10) - : undefined, - extractionBatchSize: opts.batchSize ? parseInt(opts.batchSize, 10) : undefined, - extractionDelayMs: opts.delay ? parseInt(opts.delay, 10) : undefined, + paretoPercentile: pareto, + promotionMinAgeDays: promotionMinAge, + decayRetentionThreshold: decayThreshold, + decayBaseHalfLifeDays: decayHalfLife, + extractionBatchSize: batchSize, + extractionDelayMs: delay, onPhaseStart: (phase) => { const phaseNames = { dedup: "Phase 1: Deduplication", @@ -611,12 +678,45 @@ const memoryNeo4jPlugin = { const midSessionRefreshAt = new Map(); const MIN_TOKENS_SINCE_REFRESH = 10_000; // Only refresh if context grew by 10k+ tokens + // Track session timestamps for TTL-based cleanup. Without this, bootstrappedSessions + // and midSessionRefreshAt leak entries for sessions that ended without an explicit + // after_compaction event (e.g., normal session end on long-running gateways). + const SESSION_TTL_MS = 24 * 60 * 60 * 1000; // 24 hours + const sessionLastSeen = new Map(); + let lastTtlSweep = Date.now(); + + /** Evict stale entries from session tracking maps older than SESSION_TTL_MS. */ + function pruneStaleSessionEntries(): void { + const now = Date.now(); + // Only sweep at most once per 5 minutes to avoid overhead + if (now - lastTtlSweep < 5 * 60 * 1000) { + return; + } + lastTtlSweep = now; + + const cutoff = now - SESSION_TTL_MS; + for (const [key, ts] of sessionLastSeen) { + if (ts < cutoff) { + bootstrappedSessions.delete(key); + midSessionRefreshAt.delete(key); + sessionLastSeen.delete(key); + } + } + } + + /** Mark a session as recently active for TTL tracking. */ + function touchSession(sessionKey: string): void { + sessionLastSeen.set(sessionKey, Date.now()); + pruneStaleSessionEntries(); + } + // After compaction: clear bootstrap flag and mid-session refresh tracking if (cfg.coreMemory.enabled) { api.on("after_compaction", async (_event, ctx) => { if (ctx.sessionKey) { bootstrappedSessions.delete(ctx.sessionKey); midSessionRefreshAt.delete(ctx.sessionKey); + sessionLastSeen.delete(ctx.sessionKey); api.logger.info?.( `memory-neo4j: cleared bootstrap/refresh flags for session ${ctx.sessionKey} after compaction`, ); @@ -624,6 +724,18 @@ const memoryNeo4jPlugin = { }); } + // Session end: clean up tracking entries for completed sessions. + // The sessionId from session_end may match sessionKey in some implementations; + // this provides best-effort cleanup alongside the TTL-based sweep above. + api.on("session_end", async (_event, ctx) => { + const key = ctx.sessionId; + if (key) { + bootstrappedSessions.delete(key); + midSessionRefreshAt.delete(key); + sessionLastSeen.delete(key); + } + }); + // Mid-session core memory refresh: re-inject core memories when context grows past threshold // This counters the "lost in the middle" phenomenon by placing core memories closer to end of context const refreshThreshold = cfg.coreMemory.refreshAtContextPercent; @@ -666,6 +778,7 @@ const memoryNeo4jPlugin = { // Record this refresh midSessionRefreshAt.set(sessionKey, event.estimatedUsedTokens); + touchSession(sessionKey); const content = coreMemories.map((m) => `- ${m.text}`).join("\n"); api.logger.info?.( @@ -769,6 +882,7 @@ const memoryNeo4jPlugin = { if (coreMemories.length === 0) { if (sessionKey) { bootstrappedSessions.add(sessionKey); + touchSession(sessionKey); } api.logger.debug?.( `memory-neo4j: no core memories found for agent=${agentId}, marking session as bootstrapped`, @@ -805,6 +919,7 @@ const memoryNeo4jPlugin = { if (sessionKey) { bootstrappedSessions.add(sessionKey); + touchSession(sessionKey); } // Log at info level when actually injecting, debug for skips api.logger.info?.( diff --git a/extensions/memory-neo4j/mid-session-refresh.test.ts b/extensions/memory-neo4j/mid-session-refresh.test.ts index 60adf39b8b8..50a99b9e259 100644 --- a/extensions/memory-neo4j/mid-session-refresh.test.ts +++ b/extensions/memory-neo4j/mid-session-refresh.test.ts @@ -77,14 +77,15 @@ describe("mid-session core memory refresh", () => { expect(config.coreMemory.refreshAtContextPercent).toBeUndefined(); }); - it("should reject refreshAtContextPercent over 100", async () => { + it("should throw for refreshAtContextPercent over 100", async () => { const { memoryNeo4jConfigSchema } = await import("./config.js"); - const config = memoryNeo4jConfigSchema.parse({ - neo4j: { uri: "bolt://localhost:7687", user: "neo4j", password: "test" }, - embedding: { provider: "ollama" }, - coreMemory: { refreshAtContextPercent: 150 }, - }); - expect(config.coreMemory.refreshAtContextPercent).toBeUndefined(); + expect(() => + memoryNeo4jConfigSchema.parse({ + neo4j: { uri: "bolt://localhost:7687", user: "neo4j", password: "test" }, + embedding: { provider: "ollama" }, + coreMemory: { refreshAtContextPercent: 150 }, + }), + ).toThrow("coreMemory.refreshAtContextPercent must be between 1 and 100"); }); it("should default to undefined when not specified", async () => { diff --git a/extensions/memory-neo4j/neo4j-client.ts b/extensions/memory-neo4j/neo4j-client.ts index fe67e7696f7..7fe8e4ce8ff 100644 --- a/extensions/memory-neo4j/neo4j-client.ts +++ b/extensions/memory-neo4j/neo4j-client.ts @@ -211,32 +211,36 @@ export class Neo4jMemoryClient { async storeMemory(input: StoreMemoryInput): Promise { await this.ensureInitialized(); - const session = this.driver!.session(); - try { - const now = new Date().toISOString(); - const result = await session.run( - `CREATE (m:Memory { - id: $id, text: $text, embedding: $embedding, - importance: $importance, category: $category, - source: $source, extractionStatus: $extractionStatus, - agentId: $agentId, sessionKey: $sessionKey, - createdAt: $createdAt, updatedAt: $updatedAt, - retrievalCount: $retrievalCount, lastRetrievedAt: $lastRetrievedAt - }) - RETURN m.id AS id`, - { - ...input, - sessionKey: input.sessionKey ?? null, - createdAt: now, - updatedAt: now, - retrievalCount: 0, - lastRetrievedAt: null, - }, - ); - return result.records[0].get("id") as string; - } finally { - await session.close(); - } + return this.retryOnTransient(async () => { + const session = this.driver!.session(); + try { + const now = new Date().toISOString(); + const result = await session.run( + `CREATE (m:Memory { + id: $id, text: $text, embedding: $embedding, + importance: $importance, category: $category, + source: $source, extractionStatus: $extractionStatus, + agentId: $agentId, sessionKey: $sessionKey, + createdAt: $createdAt, updatedAt: $updatedAt, + retrievalCount: $retrievalCount, lastRetrievedAt: $lastRetrievedAt, + extractionRetries: $extractionRetries + }) + RETURN m.id AS id`, + { + ...input, + sessionKey: input.sessionKey ?? null, + createdAt: now, + updatedAt: now, + retrievalCount: 0, + lastRetrievedAt: null, + extractionRetries: 0, + }, + ); + return result.records[0].get("id") as string; + } finally { + await session.close(); + } + }); } async deleteMemory(id: string): Promise { @@ -247,29 +251,32 @@ export class Neo4jMemoryClient { throw new Error(`Invalid memory ID format: ${id}`); } - const session = this.driver!.session(); - try { - // First, decrement mentionCount on connected entities - await session.run( - `MATCH (m:Memory {id: $id})-[:MENTIONS]->(e:Entity) - SET e.mentionCount = e.mentionCount - 1`, - { id }, - ); + return this.retryOnTransient(async () => { + const session = this.driver!.session(); + try { + // Decrement mentionCount on connected entities (floor at 0 to prevent + // negative counts from parallel deletes racing on the same entity) + await session.run( + `MATCH (m:Memory {id: $id})-[:MENTIONS]->(e:Entity) + SET e.mentionCount = CASE WHEN e.mentionCount > 0 THEN e.mentionCount - 1 ELSE 0 END`, + { id }, + ); - // Then delete the memory with all its relationships - const result = await session.run( - `MATCH (m:Memory {id: $id}) - DETACH DELETE m - RETURN count(*) AS deleted`, - { id }, - ); + // Then delete the memory with all its relationships + const result = await session.run( + `MATCH (m:Memory {id: $id}) + DETACH DELETE m + RETURN count(*) AS deleted`, + { id }, + ); - const deleted = - result.records.length > 0 ? (result.records[0].get("deleted") as number) > 0 : false; - return deleted; - } finally { - await session.close(); - } + const deleted = + result.records.length > 0 ? (result.records[0].get("deleted") as number) > 0 : false; + return deleted; + } finally { + await session.close(); + } + }); } async countMemories(agentId?: string): Promise { @@ -371,39 +378,43 @@ export class Neo4jMemoryClient { agentId?: string, ): Promise { await this.ensureInitialized(); - const session = this.driver!.session(); try { - const agentFilter = agentId ? "AND node.agentId = $agentId" : ""; - const result = await session.run( - `CALL db.index.vector.queryNodes('memory_embedding_index', $limit, $embedding) - YIELD node, score - WHERE score >= $minScore ${agentFilter} - RETURN node.id AS id, node.text AS text, node.category AS category, - node.importance AS importance, node.createdAt AS createdAt, - score AS similarity - ORDER BY score DESC`, - { - embedding, - limit: neo4j.int(Math.floor(limit)), - minScore, - ...(agentId ? { agentId } : {}), - }, - ); + return await this.retryOnTransient(async () => { + const session = this.driver!.session(); + try { + const agentFilter = agentId ? "AND node.agentId = $agentId" : ""; + const result = await session.run( + `CALL db.index.vector.queryNodes('memory_embedding_index', $limit, $embedding) + YIELD node, score + WHERE score >= $minScore ${agentFilter} + RETURN node.id AS id, node.text AS text, node.category AS category, + node.importance AS importance, node.createdAt AS createdAt, + score AS similarity + ORDER BY score DESC`, + { + embedding, + limit: neo4j.int(Math.floor(limit)), + minScore, + ...(agentId ? { agentId } : {}), + }, + ); - return result.records.map((r) => ({ - id: r.get("id") as string, - text: r.get("text") as string, - category: r.get("category") as string, - importance: r.get("importance") as number, - createdAt: String(r.get("createdAt") ?? ""), - score: r.get("similarity") as number, - })); + return result.records.map((r) => ({ + id: r.get("id") as string, + text: r.get("text") as string, + category: r.get("category") as string, + importance: r.get("importance") as number, + createdAt: String(r.get("createdAt") ?? ""), + score: r.get("similarity") as number, + })); + } finally { + await session.close(); + } + }); } catch (err) { - // Graceful degradation: return empty if vector index isn't ready + // Graceful degradation: return empty if vector index isn't ready or all retries exhausted this.logger.warn(`memory-neo4j: vector search failed: ${String(err)}`); return []; - } finally { - await session.close(); } } @@ -413,49 +424,58 @@ export class Neo4jMemoryClient { */ async bm25Search(query: string, limit: number, agentId?: string): Promise { await this.ensureInitialized(); - const session = this.driver!.session(); + const escaped = escapeLucene(query); + if (!escaped.trim()) { + return []; + } + try { - const escaped = escapeLucene(query); - if (!escaped.trim()) { - return []; - } + return await this.retryOnTransient(async () => { + const session = this.driver!.session(); + try { + const agentFilter = agentId ? "AND node.agentId = $agentId" : ""; + const result = await session.run( + `CALL db.index.fulltext.queryNodes('memory_fulltext_index', $query) + YIELD node, score + WHERE true ${agentFilter} + RETURN node.id AS id, node.text AS text, node.category AS category, + node.importance AS importance, node.createdAt AS createdAt, + score AS bm25Score + ORDER BY score DESC + LIMIT $limit`, + { + query: escaped, + limit: neo4j.int(Math.floor(limit)), + ...(agentId ? { agentId } : {}), + }, + ); - const agentFilter = agentId ? "AND node.agentId = $agentId" : ""; - const result = await session.run( - `CALL db.index.fulltext.queryNodes('memory_fulltext_index', $query) - YIELD node, score - WHERE true ${agentFilter} - RETURN node.id AS id, node.text AS text, node.category AS category, - node.importance AS importance, node.createdAt AS createdAt, - score AS bm25Score - ORDER BY score DESC - LIMIT $limit`, - { query: escaped, limit: neo4j.int(Math.floor(limit)), ...(agentId ? { agentId } : {}) }, - ); + // Normalize BM25 scores to 0-1 range (divide by max) + const records = result.records.map((r) => ({ + id: r.get("id") as string, + text: r.get("text") as string, + category: r.get("category") as string, + importance: r.get("importance") as number, + createdAt: String(r.get("createdAt") ?? ""), + rawScore: r.get("bm25Score") as number, + })); - // Normalize BM25 scores to 0-1 range (divide by max) - const records = result.records.map((r) => ({ - id: r.get("id") as string, - text: r.get("text") as string, - category: r.get("category") as string, - importance: r.get("importance") as number, - createdAt: String(r.get("createdAt") ?? ""), - rawScore: r.get("bm25Score") as number, - })); - - if (records.length === 0) { - return []; - } - const maxScore = records[0].rawScore || 1; - return records.map((r) => ({ - ...r, - score: r.rawScore / maxScore, - })); + if (records.length === 0) { + return []; + } + const maxScore = records[0].rawScore || 1; + return records.map((r) => ({ + ...r, + score: r.rawScore / maxScore, + })); + } finally { + await session.close(); + } + }); } catch (err) { + // Graceful degradation: return empty if all retries exhausted this.logger.warn(`memory-neo4j: BM25 search failed: ${String(err)}`); return []; - } finally { - await session.close(); } } @@ -475,89 +495,94 @@ export class Neo4jMemoryClient { agentId?: string, ): Promise { await this.ensureInitialized(); - const session = this.driver!.session(); + const escaped = escapeLucene(query); + if (!escaped.trim()) { + return []; + } + try { - const escaped = escapeLucene(query); - if (!escaped.trim()) { - return []; - } + return await this.retryOnTransient(async () => { + const session = this.driver!.session(); + try { + // Step 1: Find matching entities + const entityResult = await session.run( + `CALL db.index.fulltext.queryNodes('entity_fulltext_index', $query) + YIELD node, score + WHERE score >= 0.5 + RETURN node.id AS entityId, node.name AS name, score + ORDER BY score DESC + LIMIT 5`, + { query: escaped }, + ); - // Step 1: Find matching entities - const entityResult = await session.run( - `CALL db.index.fulltext.queryNodes('entity_fulltext_index', $query) - YIELD node, score - WHERE score >= 0.5 - RETURN node.id AS entityId, node.name AS name, score - ORDER BY score DESC - LIMIT 5`, - { query: escaped }, - ); + const entityIds = entityResult.records.map((r) => r.get("entityId") as string); + if (entityIds.length === 0) { + return []; + } - const entityIds = entityResult.records.map((r) => r.get("entityId") as string); - if (entityIds.length === 0) { - return []; - } + // Step 2 + 3: Direct mentions + 1-hop spreading activation + const agentFilter = agentId ? "AND m.agentId = $agentId" : ""; + const result = await session.run( + `UNWIND $entityIds AS eid + // Direct: Entity ← MENTIONS ← Memory + OPTIONAL MATCH (e:Entity {id: eid})<-[rm:MENTIONS]-(m:Memory) + WHERE m IS NOT NULL ${agentFilter} + WITH m, coalesce(rm.confidence, 1.0) AS directScore + WHERE m IS NOT NULL - // Step 2 + 3: Direct mentions + 1-hop spreading activation - const agentFilter = agentId ? "AND m.agentId = $agentId" : ""; - const result = await session.run( - `UNWIND $entityIds AS eid - // Direct: Entity ← MENTIONS ← Memory - OPTIONAL MATCH (e:Entity {id: eid})<-[rm:MENTIONS]-(m:Memory) - WHERE m IS NOT NULL ${agentFilter} - WITH m, coalesce(rm.confidence, 1.0) AS directScore - WHERE m IS NOT NULL + RETURN m.id AS id, m.text AS text, m.category AS category, + m.importance AS importance, m.createdAt AS createdAt, + max(directScore) AS graphScore - RETURN m.id AS id, m.text AS text, m.category AS category, - m.importance AS importance, m.createdAt AS createdAt, - max(directScore) AS graphScore + UNION - UNION + UNWIND $entityIds AS eid + // 1-hop: Entity → relationship → Entity ← MENTIONS ← Memory + OPTIONAL MATCH (e:Entity {id: eid})-[r1:RELATED_TO|KNOWS|WORKS_AT|LIVES_AT|MARRIED_TO|PREFERS|DECIDED]-(e2:Entity) + WHERE coalesce(r1.confidence, 0.7) >= $firingThreshold + OPTIONAL MATCH (e2)<-[rm:MENTIONS]-(m:Memory) + WHERE m IS NOT NULL ${agentFilter} + WITH m, coalesce(r1.confidence, 0.7) * coalesce(rm.confidence, 1.0) AS hopScore + WHERE m IS NOT NULL - UNWIND $entityIds AS eid - // 1-hop: Entity → relationship → Entity ← MENTIONS ← Memory - OPTIONAL MATCH (e:Entity {id: eid})-[r1:RELATED_TO|KNOWS|WORKS_AT|LIVES_AT|MARRIED_TO|PREFERS|DECIDED]-(e2:Entity) - WHERE coalesce(r1.confidence, 0.7) >= $firingThreshold - OPTIONAL MATCH (e2)<-[rm:MENTIONS]-(m:Memory) - WHERE m IS NOT NULL ${agentFilter} - WITH m, coalesce(r1.confidence, 0.7) * coalesce(rm.confidence, 1.0) AS hopScore - WHERE m IS NOT NULL + RETURN m.id AS id, m.text AS text, m.category AS category, + m.importance AS importance, m.createdAt AS createdAt, + max(hopScore) AS graphScore`, + { entityIds, firingThreshold, ...(agentId ? { agentId } : {}) }, + ); - RETURN m.id AS id, m.text AS text, m.category AS category, - m.importance AS importance, m.createdAt AS createdAt, - max(hopScore) AS graphScore`, - { entityIds, firingThreshold, ...(agentId ? { agentId } : {}) }, - ); + // Deduplicate by id, keeping highest score + const byId = new Map(); + for (const record of result.records) { + const id = record.get("id") as string; + if (!id) { + continue; + } + const score = record.get("graphScore") as number; + const existing = byId.get(id); + if (!existing || score > existing.score) { + byId.set(id, { + id, + text: record.get("text") as string, + category: record.get("category") as string, + importance: record.get("importance") as number, + createdAt: String(record.get("createdAt") ?? ""), + score, + }); + } + } - // Deduplicate by id, keeping highest score - const byId = new Map(); - for (const record of result.records) { - const id = record.get("id") as string; - if (!id) { - continue; + return Array.from(byId.values()) + .toSorted((a, b) => b.score - a.score) + .slice(0, limit); + } finally { + await session.close(); } - const score = record.get("graphScore") as number; - const existing = byId.get(id); - if (!existing || score > existing.score) { - byId.set(id, { - id, - text: record.get("text") as string, - category: record.get("category") as string, - importance: record.get("importance") as number, - createdAt: String(record.get("createdAt") ?? ""), - score, - }); - } - } - - return Array.from(byId.values()) - .toSorted((a, b) => b.score - a.score) - .slice(0, limit); + }); } catch (err) { + // Graceful degradation: return empty if all retries exhausted this.logger.warn(`memory-neo4j: graph search failed: ${String(err)}`); return []; - } finally { - await session.close(); } } @@ -570,28 +595,32 @@ export class Neo4jMemoryClient { limit: number = 1, ): Promise> { await this.ensureInitialized(); - const session = this.driver!.session(); try { - const result = await session.run( - `CALL db.index.vector.queryNodes('memory_embedding_index', $limit, $embedding) - YIELD node, score - WHERE score >= $threshold - RETURN node.id AS id, node.text AS text, score AS similarity - ORDER BY score DESC`, - { embedding, limit: neo4j.int(limit), threshold }, - ); + return await this.retryOnTransient(async () => { + const session = this.driver!.session(); + try { + const result = await session.run( + `CALL db.index.vector.queryNodes('memory_embedding_index', $limit, $embedding) + YIELD node, score + WHERE score >= $threshold + RETURN node.id AS id, node.text AS text, score AS similarity + ORDER BY score DESC`, + { embedding, limit: neo4j.int(limit), threshold }, + ); - return result.records.map((r) => ({ - id: r.get("id") as string, - text: r.get("text") as string, - score: r.get("similarity") as number, - })); + return result.records.map((r) => ({ + id: r.get("id") as string, + text: r.get("text") as string, + score: r.get("similarity") as number, + })); + } finally { + await session.close(); + } + }); } catch (err) { - // If vector index isn't ready, return no duplicates (allow store) + // If vector index isn't ready or all retries exhausted, return no duplicates (allow store) this.logger.debug?.(`memory-neo4j: similarity check failed: ${String(err)}`); return []; - } finally { - await session.close(); } } @@ -609,18 +638,20 @@ export class Neo4jMemoryClient { } await this.ensureInitialized(); - const session = this.driver!.session(); - try { - await session.run( - `UNWIND $ids AS memId - MATCH (m:Memory {id: memId}) - SET m.retrievalCount = coalesce(m.retrievalCount, 0) + 1, - m.lastRetrievedAt = $now`, - { ids: memoryIds, now: new Date().toISOString() }, - ); - } finally { - await session.close(); - } + return this.retryOnTransient(async () => { + const session = this.driver!.session(); + try { + await session.run( + `UNWIND $ids AS memId + MATCH (m:Memory {id: memId}) + SET m.retrievalCount = coalesce(m.retrievalCount, 0) + 1, + m.lastRetrievedAt = $now`, + { ids: memoryIds, now: new Date().toISOString() }, + ); + } finally { + await session.close(); + } + }); } /** @@ -832,14 +863,22 @@ export class Neo4jMemoryClient { /** * Update the extraction status of a Memory node. + * Optionally increments the extractionRetries counter (for transient failure tracking). */ - async updateExtractionStatus(id: string, status: ExtractionStatus): Promise { + async updateExtractionStatus( + id: string, + status: ExtractionStatus, + options?: { incrementRetries?: boolean }, + ): Promise { await this.ensureInitialized(); const session = this.driver!.session(); try { + const retryClause = options?.incrementRetries + ? ", m.extractionRetries = coalesce(m.extractionRetries, 0) + 1" + : ""; await session.run( `MATCH (m:Memory {id: $id}) - SET m.extractionStatus = $status, m.updatedAt = $now`, + SET m.extractionStatus = $status, m.updatedAt = $now${retryClause}`, { id, status, now: new Date().toISOString() }, ); } finally { @@ -847,6 +886,24 @@ export class Neo4jMemoryClient { } } + /** + * Get the current extraction retry count for a memory. + */ + async getExtractionRetries(id: string): Promise { + await this.ensureInitialized(); + const session = this.driver!.session(); + try { + const result = await session.run( + `MATCH (m:Memory {id: $id}) + RETURN coalesce(m.extractionRetries, 0) AS retries`, + { id }, + ); + return (result.records[0]?.get("retries") as number) ?? 0; + } finally { + await session.close(); + } + } + /** * List memories with pending extraction status. * Used by the sleep cycle to batch-process extractions. @@ -854,7 +911,7 @@ export class Neo4jMemoryClient { async listPendingExtractions( limit: number = 100, agentId?: string, - ): Promise> { + ): Promise> { await this.ensureInitialized(); const session = this.driver!.session(); try { @@ -862,7 +919,8 @@ export class Neo4jMemoryClient { const result = await session.run( `MATCH (m:Memory) WHERE m.extractionStatus = 'pending' ${agentFilter} - RETURN m.id AS id, m.text AS text, m.agentId AS agentId + RETURN m.id AS id, m.text AS text, m.agentId AS agentId, + coalesce(m.extractionRetries, 0) AS extractionRetries ORDER BY m.createdAt ASC LIMIT $limit`, { limit: neo4j.int(limit), agentId }, @@ -871,6 +929,7 @@ export class Neo4jMemoryClient { id: r.get("id") as string, text: r.get("text") as string, agentId: r.get("agentId") as string, + extractionRetries: r.get("extractionRetries") as number, })); } finally { await session.close(); @@ -976,14 +1035,17 @@ export class Neo4jMemoryClient { let pairsFound = 0; for (const id of memoryData.keys()) { - const similar = await session.run( - `MATCH (src:Memory {id: $id}) - CALL db.index.vector.queryNodes('memory_embedding_index', $k, src.embedding) - YIELD node, score - WHERE node.id <> $id AND score >= $threshold - RETURN node.id AS matchId`, - { id, k: neo4j.int(10), threshold }, - ); + // Retry individual vector queries on transient errors + const similar = await this.retryOnTransient(async () => { + return session.run( + `MATCH (src:Memory {id: $id}) + CALL db.index.vector.queryNodes('memory_embedding_index', $k, src.embedding) + YIELD node, score + WHERE node.id <> $id AND score >= $threshold + RETURN node.id AS matchId`, + { id, k: neo4j.int(10), threshold }, + ); + }); for (const r of similar.records) { const matchId = r.get("matchId") as string; @@ -1045,30 +1107,56 @@ export class Neo4jMemoryClient { const survivorId = memoryIds[survivorIdx]; const toDelete = memoryIds.filter((_, i) => i !== survivorIdx); - const session = this.driver!.session(); - try { - // Transfer MENTIONS relationships from deleted memories to survivor - await session.run( - `UNWIND $toDelete AS deadId - MATCH (dead:Memory {id: deadId})-[r:MENTIONS]->(e:Entity) - MATCH (survivor:Memory {id: $survivorId}) - MERGE (survivor)-[:MENTIONS]->(e) - DELETE r`, - { toDelete, survivorId }, - ); + return this.retryOnTransient(async () => { + const session = this.driver!.session(); + try { + // Optimistic lock: verify all cluster members still exist before merging. + // New memories added or deleted between findDuplicateClusters() and this + // call could invalidate the cluster. Skip if any member is missing. + const verifyResult = await session.run( + `UNWIND $ids AS memId + OPTIONAL MATCH (m:Memory {id: memId}) + RETURN memId, m IS NOT NULL AS exists`, + { ids: memoryIds }, + ); - // Delete the duplicate memories - await session.run( - `UNWIND $toDelete AS deadId - MATCH (m:Memory {id: deadId}) - DETACH DELETE m`, - { toDelete }, - ); + const missingIds: string[] = []; + for (const r of verifyResult.records) { + if (!r.get("exists")) { + missingIds.push(r.get("memId") as string); + } + } - return { survivorId, deletedCount: toDelete.length }; - } finally { - await session.close(); - } + if (missingIds.length > 0) { + this.logger.warn( + `memory-neo4j: skipping cluster merge — ${missingIds.length} member(s) no longer exist: ${missingIds.join(", ")}`, + ); + return { survivorId, deletedCount: 0 }; + } + + // Transfer MENTIONS relationships from deleted memories to survivor + await session.run( + `UNWIND $toDelete AS deadId + MATCH (dead:Memory {id: deadId})-[r:MENTIONS]->(e:Entity) + MATCH (survivor:Memory {id: $survivorId}) + MERGE (survivor)-[:MENTIONS]->(e) + DELETE r`, + { toDelete, survivorId }, + ); + + // Delete the duplicate memories + await session.run( + `UNWIND $toDelete AS deadId + MATCH (m:Memory {id: deadId}) + DETACH DELETE m`, + { toDelete }, + ); + + return { survivorId, deletedCount: toDelete.length }; + } finally { + await session.close(); + } + }); } // -------------------------------------------------------------------------- @@ -1160,11 +1248,12 @@ export class Neo4jMemoryClient { await this.ensureInitialized(); const session = this.driver!.session(); try { - // Decrement mention counts on connected entities + // Decrement mention counts on connected entities (floor at 0 to prevent + // negative counts from parallel prune/delete operations racing on the same entity) await session.run( `UNWIND $ids AS memId MATCH (m:Memory {id: memId})-[:MENTIONS]->(e:Entity) - SET e.mentionCount = e.mentionCount - 1`, + SET e.mentionCount = CASE WHEN e.mentionCount > 0 THEN e.mentionCount - 1 ELSE 0 END`, { ids: memoryIds }, ); @@ -1581,7 +1670,7 @@ export class Neo4jMemoryClient { // -------------------------------------------------------------------------- /** - * Retry an operation on transient Neo4j errors (deadlocks, etc.) + * Retry an operation on transient Neo4j errors (deadlocks, connection blips, etc.) * with exponential backoff. Adapted from ontology project. */ private async retryOnTransient( @@ -1595,14 +1684,24 @@ export class Neo4jMemoryClient { return await fn(); } catch (err) { lastError = err; - // Check for Neo4j transient errors + // Check for Neo4j transient errors (deadlocks, connection blips, service unavailable) + const errCode = + err instanceof Error + ? ((err as unknown as Record).code as string | undefined) + : undefined; const isTransient = err instanceof Error && (err.message.includes("DeadlockDetected") || err.message.includes("TransientError") || + err.message.includes("ServiceUnavailable") || + err.message.includes("SessionExpired") || + err.message.includes("ConnectionRefused") || + err.message.includes("connection terminated") || (err.constructor.name === "Neo4jError" && - (err as unknown as Record).code === - "Neo.TransientError.Transaction.DeadlockDetected")); + typeof errCode === "string" && + (errCode.startsWith("Neo.TransientError.") || + errCode === "ServiceUnavailable" || + errCode === "SessionExpired"))); if (!isTransient || attempt >= maxAttempts - 1) { throw err; diff --git a/extensions/memory-neo4j/schema.test.ts b/extensions/memory-neo4j/schema.test.ts new file mode 100644 index 00000000000..d546c8e1ebf --- /dev/null +++ b/extensions/memory-neo4j/schema.test.ts @@ -0,0 +1,200 @@ +/** + * Tests for schema.ts — Schema Validation & Helpers. + * + * Tests the exported pure functions: escapeLucene(), validateRelationshipType(), + * and the exported constants and types. + */ + +import { describe, it, expect } from "vitest"; +import { + escapeLucene, + validateRelationshipType, + ALLOWED_RELATIONSHIP_TYPES, + MEMORY_CATEGORIES, + ENTITY_TYPES, +} from "./schema.js"; + +// ============================================================================ +// escapeLucene() +// ============================================================================ + +describe("escapeLucene", () => { + it("should return normal text unchanged", () => { + expect(escapeLucene("hello world")).toBe("hello world"); + }); + + it("should return empty string unchanged", () => { + expect(escapeLucene("")).toBe(""); + }); + + it("should escape plus sign", () => { + expect(escapeLucene("a+b")).toBe("a\\+b"); + }); + + it("should escape minus sign", () => { + expect(escapeLucene("a-b")).toBe("a\\-b"); + }); + + it("should escape ampersand", () => { + expect(escapeLucene("a&b")).toBe("a\\&b"); + }); + + it("should escape pipe", () => { + expect(escapeLucene("a|b")).toBe("a\\|b"); + }); + + it("should escape exclamation mark", () => { + expect(escapeLucene("hello!")).toBe("hello\\!"); + }); + + it("should escape parentheses", () => { + expect(escapeLucene("(group)")).toBe("\\(group\\)"); + }); + + it("should escape curly braces", () => { + expect(escapeLucene("{range}")).toBe("\\{range\\}"); + }); + + it("should escape square brackets", () => { + expect(escapeLucene("[range]")).toBe("\\[range\\]"); + }); + + it("should escape caret", () => { + expect(escapeLucene("boost^2")).toBe("boost\\^2"); + }); + + it("should escape double quotes", () => { + expect(escapeLucene('"exact"')).toBe('\\"exact\\"'); + }); + + it("should escape tilde", () => { + expect(escapeLucene("fuzzy~")).toBe("fuzzy\\~"); + }); + + it("should escape asterisk", () => { + expect(escapeLucene("wild*")).toBe("wild\\*"); + }); + + it("should escape question mark", () => { + expect(escapeLucene("single?")).toBe("single\\?"); + }); + + it("should escape colon", () => { + expect(escapeLucene("field:value")).toBe("field\\:value"); + }); + + it("should escape backslash", () => { + expect(escapeLucene("path\\file")).toBe("path\\\\file"); + }); + + it("should escape forward slash", () => { + expect(escapeLucene("a/b")).toBe("a\\/b"); + }); + + it("should escape multiple special characters in one string", () => { + expect(escapeLucene("(a+b) && c*")).toBe("\\(a\\+b\\) \\&\\& c\\*"); + }); + + it("should handle mixed normal and special characters", () => { + expect(escapeLucene("hello world! [test]")).toBe("hello world\\! \\[test\\]"); + }); + + it("should handle strings with only special characters", () => { + expect(escapeLucene("+-")).toBe("\\+\\-"); + }); +}); + +// ============================================================================ +// validateRelationshipType() +// ============================================================================ + +describe("validateRelationshipType", () => { + describe("valid relationship types", () => { + it("should accept WORKS_AT", () => { + expect(validateRelationshipType("WORKS_AT")).toBe(true); + }); + + it("should accept LIVES_AT", () => { + expect(validateRelationshipType("LIVES_AT")).toBe(true); + }); + + it("should accept KNOWS", () => { + expect(validateRelationshipType("KNOWS")).toBe(true); + }); + + it("should accept MARRIED_TO", () => { + expect(validateRelationshipType("MARRIED_TO")).toBe(true); + }); + + it("should accept PREFERS", () => { + expect(validateRelationshipType("PREFERS")).toBe(true); + }); + + it("should accept DECIDED", () => { + expect(validateRelationshipType("DECIDED")).toBe(true); + }); + + it("should accept RELATED_TO", () => { + expect(validateRelationshipType("RELATED_TO")).toBe(true); + }); + + it("should accept all ALLOWED_RELATIONSHIP_TYPES", () => { + for (const type of ALLOWED_RELATIONSHIP_TYPES) { + expect(validateRelationshipType(type)).toBe(true); + } + }); + }); + + describe("invalid relationship types", () => { + it("should reject unknown relationship type", () => { + expect(validateRelationshipType("HATES")).toBe(false); + }); + + it("should reject empty string", () => { + expect(validateRelationshipType("")).toBe(false); + }); + + it("should be case sensitive — lowercase is rejected", () => { + expect(validateRelationshipType("works_at")).toBe(false); + }); + + it("should be case sensitive — mixed case is rejected", () => { + expect(validateRelationshipType("Works_At")).toBe(false); + }); + + it("should reject types with extra whitespace", () => { + expect(validateRelationshipType(" WORKS_AT ")).toBe(false); + }); + + it("should reject potential Cypher injection", () => { + expect(validateRelationshipType("WORKS_AT]->(n) DELETE n//")).toBe(false); + }); + }); +}); + +// ============================================================================ +// Exported Constants +// ============================================================================ + +describe("exported constants", () => { + it("MEMORY_CATEGORIES should contain expected categories", () => { + expect(MEMORY_CATEGORIES).toContain("preference"); + expect(MEMORY_CATEGORIES).toContain("fact"); + expect(MEMORY_CATEGORIES).toContain("decision"); + expect(MEMORY_CATEGORIES).toContain("entity"); + expect(MEMORY_CATEGORIES).toContain("other"); + }); + + it("ENTITY_TYPES should contain expected types", () => { + expect(ENTITY_TYPES).toContain("person"); + expect(ENTITY_TYPES).toContain("organization"); + expect(ENTITY_TYPES).toContain("location"); + expect(ENTITY_TYPES).toContain("event"); + expect(ENTITY_TYPES).toContain("concept"); + }); + + it("ALLOWED_RELATIONSHIP_TYPES should be a Set", () => { + expect(ALLOWED_RELATIONSHIP_TYPES).toBeInstanceOf(Set); + expect(ALLOWED_RELATIONSHIP_TYPES.size).toBe(7); + }); +}); diff --git a/extensions/memory-neo4j/search.test.ts b/extensions/memory-neo4j/search.test.ts new file mode 100644 index 00000000000..f3ff09a53f0 --- /dev/null +++ b/extensions/memory-neo4j/search.test.ts @@ -0,0 +1,400 @@ +/** + * Tests for search.ts — Hybrid Search & RRF Fusion. + * + * Tests the exported pure logic: classifyQuery() and getAdaptiveWeights(). + * Note: fuseWithConfidenceRRF() is not exported (private module-level function) + * and is tested indirectly through hybridSearch(). + * hybridSearch() is tested with mocked Neo4j client and Embeddings. + */ + +import { describe, it, expect, vi, beforeEach } from "vitest"; +import type { SearchSignalResult } from "./schema.js"; +import { classifyQuery, getAdaptiveWeights, hybridSearch } from "./search.js"; + +// ============================================================================ +// classifyQuery() +// ============================================================================ + +describe("classifyQuery", () => { + describe("short queries (1-2 words)", () => { + it("should classify a single word as 'short'", () => { + expect(classifyQuery("dogs")).toBe("short"); + }); + + it("should classify two words as 'short'", () => { + expect(classifyQuery("best coffee")).toBe("short"); + }); + + it("should classify a single capitalized word as 'short' (word count takes priority)", () => { + expect(classifyQuery("TypeScript")).toBe("short"); + }); + + it("should handle whitespace-padded short queries", () => { + expect(classifyQuery(" hello ")).toBe("short"); + }); + }); + + describe("entity queries (proper nouns)", () => { + it("should classify query with proper noun as 'entity'", () => { + expect(classifyQuery("tell me about Tarun")).toBe("entity"); + }); + + it("should classify query with organization name as 'entity'", () => { + expect(classifyQuery("what about Google")).toBe("entity"); + }); + + it("should classify question patterns targeting entities", () => { + expect(classifyQuery("who is the CEO")).toBe("entity"); + }); + + it("should classify 'where is' patterns as entity", () => { + expect(classifyQuery("where is the office")).toBe("entity"); + }); + + it("should classify 'what does' patterns as entity", () => { + expect(classifyQuery("what does she do")).toBe("entity"); + }); + + it("should not treat common words (The, Is, etc.) as entity indicators", () => { + // "The" and "Is" are excluded from capitalized word detection + // 3 words, no proper nouns detected, no question pattern -> default + expect(classifyQuery("this is fine")).toBe("default"); + }); + }); + + describe("long queries (5+ words)", () => { + it("should classify a 5-word query as 'long'", () => { + expect(classifyQuery("what is the best framework")).toBe("long"); + }); + + it("should classify a longer sentence as 'long'", () => { + expect(classifyQuery("tell me about the history of programming languages")).toBe("long"); + }); + + it("should classify a verbose question as 'long'", () => { + expect(classifyQuery("how do i configure the database connection")).toBe("long"); + }); + }); + + describe("default queries (3-4 words, no entities)", () => { + it("should classify a 3-word lowercase query as 'default'", () => { + expect(classifyQuery("my favorite color")).toBe("default"); + }); + + it("should classify a 4-word lowercase query as 'default'", () => { + expect(classifyQuery("best practices for testing")).toBe("default"); + }); + }); + + describe("edge cases", () => { + it("should handle empty string", () => { + // Empty string splits to [""], length 1 -> "short" + expect(classifyQuery("")).toBe("short"); + }); + + it("should handle only whitespace", () => { + // " ".trim() = "", splits to [""], length 1 -> "short" + expect(classifyQuery(" ")).toBe("short"); + }); + }); +}); + +// ============================================================================ +// getAdaptiveWeights() +// ============================================================================ + +describe("getAdaptiveWeights", () => { + describe("with graph enabled", () => { + it("should boost BM25 for short queries", () => { + const [vector, bm25, graph] = getAdaptiveWeights("short", true); + expect(bm25).toBeGreaterThan(vector); + expect(vector).toBe(0.8); + expect(bm25).toBe(1.2); + expect(graph).toBe(1.0); + }); + + it("should boost graph for entity queries", () => { + const [vector, bm25, graph] = getAdaptiveWeights("entity", true); + expect(graph).toBeGreaterThan(vector); + expect(graph).toBeGreaterThan(bm25); + expect(vector).toBe(0.8); + expect(bm25).toBe(1.0); + expect(graph).toBe(1.3); + }); + + it("should boost vector for long queries", () => { + const [vector, bm25, graph] = getAdaptiveWeights("long", true); + expect(vector).toBeGreaterThan(bm25); + expect(vector).toBeGreaterThan(graph); + expect(vector).toBe(1.2); + expect(bm25).toBe(0.7); + expect(graph).toBeCloseTo(0.8); + }); + + it("should return balanced weights for default queries", () => { + const [vector, bm25, graph] = getAdaptiveWeights("default", true); + expect(vector).toBe(1.0); + expect(bm25).toBe(1.0); + expect(graph).toBe(1.0); + }); + }); + + describe("with graph disabled", () => { + it("should zero-out graph weight for short queries", () => { + const [vector, bm25, graph] = getAdaptiveWeights("short", false); + expect(graph).toBe(0); + expect(vector).toBe(0.8); + expect(bm25).toBe(1.2); + }); + + it("should zero-out graph weight for entity queries", () => { + const [, , graph] = getAdaptiveWeights("entity", false); + expect(graph).toBe(0); + }); + + it("should zero-out graph weight for long queries", () => { + const [, , graph] = getAdaptiveWeights("long", false); + expect(graph).toBe(0); + }); + + it("should zero-out graph weight for default queries", () => { + const [, , graph] = getAdaptiveWeights("default", false); + expect(graph).toBe(0); + }); + }); +}); + +// ============================================================================ +// hybridSearch() — integration test with mocked dependencies +// ============================================================================ + +describe("hybridSearch", () => { + // Create mock db and embeddings + const mockDb = { + vectorSearch: vi.fn(), + bm25Search: vi.fn(), + graphSearch: vi.fn(), + recordRetrievals: vi.fn(), + }; + + const mockEmbeddings = { + embed: vi.fn(), + embedBatch: vi.fn(), + }; + + beforeEach(() => { + vi.resetAllMocks(); + mockEmbeddings.embed.mockResolvedValue([0.1, 0.2, 0.3]); + mockDb.recordRetrievals.mockResolvedValue(undefined); + }); + + function makeSignalResult(overrides: Partial = {}): SearchSignalResult { + return { + id: "mem-1", + text: "Test memory", + category: "fact", + importance: 0.7, + createdAt: "2025-01-01T00:00:00Z", + score: 0.9, + ...overrides, + }; + } + + it("should return empty array when no signals return results", async () => { + mockDb.vectorSearch.mockResolvedValue([]); + mockDb.bm25Search.mockResolvedValue([]); + + const results = await hybridSearch( + mockDb as never, + mockEmbeddings as never, + "test query", + 5, + "agent-1", + false, + ); + + expect(results).toEqual([]); + expect(mockDb.recordRetrievals).not.toHaveBeenCalled(); + }); + + it("should fuse results from vector and BM25 signals", async () => { + const vectorResult = makeSignalResult({ id: "mem-1", score: 0.95, text: "Vector match" }); + const bm25Result = makeSignalResult({ id: "mem-2", score: 0.8, text: "BM25 match" }); + + mockDb.vectorSearch.mockResolvedValue([vectorResult]); + mockDb.bm25Search.mockResolvedValue([bm25Result]); + + const results = await hybridSearch( + mockDb as never, + mockEmbeddings as never, + "test query", + 5, + "agent-1", + false, + ); + + expect(results.length).toBe(2); + // Results should have scores normalized to 0-1 + expect(results[0].score).toBeLessThanOrEqual(1); + expect(results[0].score).toBeGreaterThanOrEqual(0); + // First result should have the highest score (normalized to 1) + expect(results[0].score).toBe(1); + }); + + it("should deduplicate across signals (same memory in multiple signals)", async () => { + const sharedResult = makeSignalResult({ id: "mem-shared", score: 0.9 }); + + mockDb.vectorSearch.mockResolvedValue([sharedResult]); + mockDb.bm25Search.mockResolvedValue([{ ...sharedResult, score: 0.85 }]); + + const results = await hybridSearch( + mockDb as never, + mockEmbeddings as never, + "test query", + 5, + "agent-1", + false, + ); + + // Should only have one result (deduplicated by ID) + expect(results.length).toBe(1); + expect(results[0].id).toBe("mem-shared"); + // Score should be higher than either individual signal (boosted by appearing in both) + expect(results[0].score).toBe(1); // It's the only result, so normalized to 1 + }); + + it("should include graph signal when graphEnabled is true", async () => { + mockDb.vectorSearch.mockResolvedValue([]); + mockDb.bm25Search.mockResolvedValue([]); + mockDb.graphSearch.mockResolvedValue([ + makeSignalResult({ id: "mem-graph", score: 0.7, text: "Graph result" }), + ]); + + const results = await hybridSearch( + mockDb as never, + mockEmbeddings as never, + "tell me about Tarun", + 5, + "agent-1", + true, + ); + + expect(mockDb.graphSearch).toHaveBeenCalled(); + expect(results.length).toBe(1); + expect(results[0].id).toBe("mem-graph"); + }); + + it("should not call graphSearch when graphEnabled is false", async () => { + mockDb.vectorSearch.mockResolvedValue([]); + mockDb.bm25Search.mockResolvedValue([]); + + await hybridSearch(mockDb as never, mockEmbeddings as never, "test query", 5, "agent-1", false); + + expect(mockDb.graphSearch).not.toHaveBeenCalled(); + }); + + it("should limit results to the requested count", async () => { + const manyResults = Array.from({ length: 10 }, (_, i) => + makeSignalResult({ id: `mem-${i}`, score: 0.9 - i * 0.05 }), + ); + + mockDb.vectorSearch.mockResolvedValue(manyResults); + mockDb.bm25Search.mockResolvedValue([]); + + const results = await hybridSearch( + mockDb as never, + mockEmbeddings as never, + "test query", + 3, + "agent-1", + false, + ); + + expect(results.length).toBe(3); + }); + + it("should record retrieval events for returned results", async () => { + mockDb.vectorSearch.mockResolvedValue([ + makeSignalResult({ id: "mem-1" }), + makeSignalResult({ id: "mem-2" }), + ]); + mockDb.bm25Search.mockResolvedValue([]); + + await hybridSearch(mockDb as never, mockEmbeddings as never, "test query", 5, "agent-1", false); + + expect(mockDb.recordRetrievals).toHaveBeenCalledWith(["mem-1", "mem-2"]); + }); + + it("should silently handle recordRetrievals failure", async () => { + mockDb.vectorSearch.mockResolvedValue([makeSignalResult({ id: "mem-1" })]); + mockDb.bm25Search.mockResolvedValue([]); + mockDb.recordRetrievals.mockRejectedValue(new Error("DB connection lost")); + + // Should not throw + const results = await hybridSearch( + mockDb as never, + mockEmbeddings as never, + "test query", + 5, + "agent-1", + false, + ); + + expect(results.length).toBe(1); + }); + + it("should normalize scores to 0-1 range", async () => { + mockDb.vectorSearch.mockResolvedValue([ + makeSignalResult({ id: "mem-1", score: 0.95 }), + makeSignalResult({ id: "mem-2", score: 0.5 }), + ]); + mockDb.bm25Search.mockResolvedValue([]); + + const results = await hybridSearch( + mockDb as never, + mockEmbeddings as never, + "test query", + 5, + "agent-1", + false, + ); + + for (const r of results) { + expect(r.score).toBeGreaterThanOrEqual(0); + expect(r.score).toBeLessThanOrEqual(1); + } + }); + + it("should use candidateMultiplier option", async () => { + mockDb.vectorSearch.mockResolvedValue([]); + mockDb.bm25Search.mockResolvedValue([]); + + await hybridSearch( + mockDb as never, + mockEmbeddings as never, + "test query", + 5, + "agent-1", + false, + { candidateMultiplier: 8 }, + ); + + // limit=5, multiplier=8 => candidateLimit = 40 + expect(mockDb.vectorSearch).toHaveBeenCalledWith(expect.any(Array), 40, 0.1, "agent-1"); + expect(mockDb.bm25Search).toHaveBeenCalledWith("test query", 40, "agent-1"); + }); + + it("should pass default agentId when not specified", async () => { + mockDb.vectorSearch.mockResolvedValue([]); + mockDb.bm25Search.mockResolvedValue([]); + + await hybridSearch(mockDb as never, mockEmbeddings as never, "test query"); + + expect(mockDb.vectorSearch).toHaveBeenCalledWith( + expect.any(Array), + expect.any(Number), + 0.1, + "default", + ); + }); +});