memory-neo4j: harden error handling, concurrency safety, config validation + add tests

This commit is contained in:
Tarun Sukhani
2026-02-05 12:23:36 +00:00
parent c1371b639e
commit 3082c53a76
11 changed files with 2769 additions and 291 deletions

View File

@@ -0,0 +1,549 @@
/**
* Tests for config.ts — Configuration Parsing.
*
* Tests memoryNeo4jConfigSchema.parse(), vectorDimsForModel(), and resolveExtractionConfig().
*/
import { describe, it, expect, beforeEach, afterEach, vi } from "vitest";
import { memoryNeo4jConfigSchema, vectorDimsForModel, resolveExtractionConfig } from "./config.js";
// ============================================================================
// memoryNeo4jConfigSchema.parse()
// ============================================================================
describe("memoryNeo4jConfigSchema.parse", () => {
// Snapshot the environment once so every test can be rolled back to it.
const originalEnv = { ...process.env };
afterEach(() => {
  // Restore in place instead of reassigning `process.env`: Node's env object
  // coerces assigned values to strings, and replacing it with a plain object
  // would drop that behavior for the rest of the test run.
  for (const key of Object.keys(process.env)) {
    if (!(key in originalEnv)) {
      delete process.env[key];
    }
  }
  Object.assign(process.env, originalEnv);
});
describe("valid complete configs", () => {
  /** Parses a config made of the given neo4j section plus a bare ollama embedding section. */
  const parseOllamaWith = (neo4j: Record<string, unknown>) =>
    memoryNeo4jConfigSchema.parse({ neo4j, embedding: { provider: "ollama" } });

  it("should parse a minimal valid config with ollama provider", () => {
    const parsed = parseOllamaWith({
      uri: "bolt://localhost:7687",
      user: "neo4j",
      password: "test",
    });
    const { neo4j, embedding, coreMemory } = parsed;

    expect(neo4j.uri).toBe("bolt://localhost:7687");
    expect(neo4j.username).toBe("neo4j");
    expect(neo4j.password).toBe("test");
    expect(embedding.provider).toBe("ollama");
    expect(embedding.model).toBe("mxbai-embed-large");
    expect(embedding.apiKey).toBeUndefined();
    expect(parsed.autoCapture).toBe(true);
    expect(parsed.autoRecall).toBe(true);
    expect(coreMemory.enabled).toBe(true);
    expect(coreMemory.maxEntries).toBe(50);
  });

  it("should parse a full config with openai provider", () => {
    // Every optional section is spelled out so no default is exercised here.
    const fullInput = {
      neo4j: {
        uri: "neo4j+s://cloud.neo4j.io:7687",
        username: "admin",
        password: "secret",
      },
      embedding: {
        provider: "openai",
        apiKey: "sk-test-key",
        model: "text-embedding-3-large",
      },
      autoCapture: false,
      autoRecall: false,
      coreMemory: {
        enabled: false,
        maxEntries: 100,
        refreshAtContextPercent: 75,
      },
    };
    const parsed = memoryNeo4jConfigSchema.parse(fullInput);
    const { neo4j, embedding, coreMemory } = parsed;

    expect(neo4j.uri).toBe("neo4j+s://cloud.neo4j.io:7687");
    expect(neo4j.username).toBe("admin");
    expect(neo4j.password).toBe("secret");
    expect(embedding.provider).toBe("openai");
    expect(embedding.apiKey).toBe("sk-test-key");
    expect(embedding.model).toBe("text-embedding-3-large");
    expect(parsed.autoCapture).toBe(false);
    expect(parsed.autoRecall).toBe(false);
    expect(coreMemory.enabled).toBe(false);
    expect(coreMemory.maxEntries).toBe(100);
    expect(coreMemory.refreshAtContextPercent).toBe(75);
  });

  it("should support 'user' field as alias for 'username' in neo4j config", () => {
    const { neo4j } = parseOllamaWith({
      uri: "bolt://localhost:7687",
      user: "custom-user",
      password: "pass",
    });
    expect(neo4j.username).toBe("custom-user");
  });

  it("should support 'username' field in neo4j config", () => {
    const { neo4j } = parseOllamaWith({
      uri: "bolt://localhost:7687",
      username: "custom-user",
      password: "pass",
    });
    expect(neo4j.username).toBe("custom-user");
  });

  it("should default neo4j username to 'neo4j' when not specified", () => {
    const { neo4j } = parseOllamaWith({ uri: "bolt://localhost:7687", password: "pass" });
    expect(neo4j.username).toBe("neo4j");
  });
});
describe("missing required fields", () => {
  /** Asserts that parsing `input` throws an error whose message contains `message`. */
  const expectParseFailure = (input: unknown, message: string) => {
    expect(() => memoryNeo4jConfigSchema.parse(input)).toThrow(message);
  };

  it("should throw when config is null", () => {
    expectParseFailure(null, "memory-neo4j config required");
  });

  it("should throw when config is undefined", () => {
    expectParseFailure(undefined, "memory-neo4j config required");
  });

  it("should throw when config is not an object", () => {
    expectParseFailure("string", "memory-neo4j config required");
  });

  it("should throw when config is an array", () => {
    expectParseFailure([], "memory-neo4j config required");
  });

  it("should throw when neo4j section is missing", () => {
    expectParseFailure(
      { embedding: { provider: "ollama" } },
      "neo4j config section is required",
    );
  });

  it("should throw when neo4j.uri is missing", () => {
    expectParseFailure(
      { neo4j: { password: "test" }, embedding: { provider: "ollama" } },
      "neo4j.uri is required",
    );
  });

  it("should throw when neo4j.uri is empty string", () => {
    expectParseFailure(
      { neo4j: { uri: "", password: "test" }, embedding: { provider: "ollama" } },
      "neo4j.uri is required",
    );
  });
});
describe("environment variable resolution", () => {
  const boltUri = "bolt://localhost:7687";

  it("should resolve ${ENV_VAR} in neo4j.password", () => {
    process.env.TEST_NEO4J_PASSWORD = "resolved-password";

    const { neo4j } = memoryNeo4jConfigSchema.parse({
      neo4j: { uri: boltUri, password: "${TEST_NEO4J_PASSWORD}" },
      embedding: { provider: "ollama" },
    });

    expect(neo4j.password).toBe("resolved-password");
  });

  it("should resolve ${ENV_VAR} in embedding.apiKey", () => {
    process.env.TEST_OPENAI_KEY = "sk-from-env";

    const { embedding } = memoryNeo4jConfigSchema.parse({
      neo4j: { uri: boltUri, password: "" },
      embedding: { provider: "openai", apiKey: "${TEST_OPENAI_KEY}" },
    });

    expect(embedding.apiKey).toBe("sk-from-env");
  });

  it("should throw when referenced env var is not set", () => {
    // Make sure the placeholder below really points at an unset variable.
    delete process.env.NONEXISTENT_VAR;

    const parseWithMissingVar = () =>
      memoryNeo4jConfigSchema.parse({
        neo4j: { uri: boltUri, password: "${NONEXISTENT_VAR}" },
        embedding: { provider: "ollama" },
      });

    expect(parseWithMissingVar).toThrow("Environment variable NONEXISTENT_VAR is not set");
  });
});
describe("default values", () => {
  /** Parses the smallest ollama config used by most tests in this group (fresh input per call). */
  const parseMinimal = () =>
    memoryNeo4jConfigSchema.parse({
      neo4j: { uri: "bolt://localhost:7687", password: "" },
      embedding: { provider: "ollama" },
    });

  it("should default autoCapture to true", () => {
    expect(parseMinimal().autoCapture).toBe(true);
  });

  it("should default autoRecall to true", () => {
    expect(parseMinimal().autoRecall).toBe(true);
  });

  it("should default coreMemory.enabled to true", () => {
    expect(parseMinimal().coreMemory.enabled).toBe(true);
  });

  it("should default coreMemory.maxEntries to 50", () => {
    expect(parseMinimal().coreMemory.maxEntries).toBe(50);
  });

  it("should default refreshAtContextPercent to undefined", () => {
    expect(parseMinimal().coreMemory.refreshAtContextPercent).toBeUndefined();
  });

  it("should default embedding model to mxbai-embed-large for ollama", () => {
    expect(parseMinimal().embedding.model).toBe("mxbai-embed-large");
  });

  it("should default embedding model to text-embedding-3-small for openai", () => {
    const { embedding } = memoryNeo4jConfigSchema.parse({
      neo4j: { uri: "bolt://localhost:7687", password: "" },
      embedding: { provider: "openai", apiKey: "sk-test" },
    });
    expect(embedding.model).toBe("text-embedding-3-small");
  });

  it("should default neo4j.password to empty string when not provided", () => {
    const { neo4j } = memoryNeo4jConfigSchema.parse({
      neo4j: { uri: "bolt://localhost:7687" },
      embedding: { provider: "ollama" },
    });
    expect(neo4j.password).toBe("");
  });
});
describe("provider validation", () => {
  /** Parses a config with a fixed local neo4j section and the given embedding section. */
  const parseWithEmbedding = (embedding: Record<string, unknown>) =>
    memoryNeo4jConfigSchema.parse({
      neo4j: { uri: "bolt://localhost:7687", password: "" },
      embedding,
    });

  it("should require apiKey for openai provider", () => {
    const parseWithoutKey = () => parseWithEmbedding({ provider: "openai" });
    expect(parseWithoutKey).toThrow("embedding.apiKey is required for OpenAI provider");
  });

  it("should not require apiKey for ollama provider", () => {
    const { embedding } = parseWithEmbedding({ provider: "ollama" });
    expect(embedding.apiKey).toBeUndefined();
  });

  it("should default to openai when no provider is specified", () => {
    // Only an apiKey is supplied; the provider is expected to fall back to openai.
    const { embedding } = parseWithEmbedding({ apiKey: "sk-test" });
    expect(embedding.provider).toBe("openai");
  });

  it("should accept embedding.baseUrl", () => {
    const { embedding } = parseWithEmbedding({
      provider: "ollama",
      baseUrl: "http://my-ollama:11434",
    });
    expect(embedding.baseUrl).toBe("http://my-ollama:11434");
  });
});
describe("unknown keys rejected", () => {
  /** Builds a fresh valid baseline config that each test extends with one stray key. */
  const validBase = () => ({
    neo4j: { uri: "bolt://localhost:7687", password: "" },
    embedding: { provider: "ollama" },
  });

  /** Asserts that parsing `input` fails and names `key` as the unknown key. */
  const expectUnknownKey = (input: unknown, key: string) => {
    expect(() => memoryNeo4jConfigSchema.parse(input)).toThrow(`unknown keys: ${key}`);
  };

  it("should reject unknown top-level keys", () => {
    expectUnknownKey({ ...validBase(), unknownKey: "value" }, "unknownKey");
  });

  it("should reject unknown neo4j keys", () => {
    const base = validBase();
    expectUnknownKey({ ...base, neo4j: { ...base.neo4j, port: 7687 } }, "port");
  });

  it("should reject unknown embedding keys", () => {
    const base = validBase();
    expectUnknownKey(
      { ...base, embedding: { ...base.embedding, temperature: 0.5 } },
      "temperature",
    );
  });

  it("should reject unknown coreMemory keys", () => {
    expectUnknownKey({ ...validBase(), coreMemory: { unknownField: true } }, "unknownField");
  });
});
describe("refreshAtContextPercent edge cases", () => {
  /**
   * Parses a minimal ollama config whose coreMemory section carries the given
   * refreshAtContextPercent and returns the parsed setting.
   */
  const parseRefreshPercent = (value: unknown) =>
    memoryNeo4jConfigSchema.parse({
      neo4j: { uri: "bolt://localhost:7687", password: "" },
      embedding: { provider: "ollama" },
      coreMemory: { refreshAtContextPercent: value },
    }).coreMemory.refreshAtContextPercent;

  it("should accept refreshAtContextPercent of 1 (minimum valid)", () => {
    expect(parseRefreshPercent(1)).toBe(1);
  });

  it("should accept refreshAtContextPercent of 100 (maximum valid)", () => {
    expect(parseRefreshPercent(100)).toBe(100);
  });

  // Zero and negative values do not throw: they parse to undefined, which
  // leaves the refresh feature disabled. The titles say "ignore" to match.
  it("should ignore refreshAtContextPercent of 0 (refresh disabled)", () => {
    expect(parseRefreshPercent(0)).toBeUndefined();
  });

  it("should reject refreshAtContextPercent over 100 by throwing", () => {
    expect(() => parseRefreshPercent(150)).toThrow(
      "coreMemory.refreshAtContextPercent must be between 1 and 100",
    );
  });

  it("should ignore negative refreshAtContextPercent (refresh disabled)", () => {
    expect(parseRefreshPercent(-10)).toBeUndefined();
  });

  it("should ignore non-number refreshAtContextPercent", () => {
    expect(parseRefreshPercent("50")).toBeUndefined();
  });
});
describe("extraction config section", () => {
  /** Parses a minimal ollama config carrying the given extraction section. */
  const parseWithExtraction = (extraction: Record<string, unknown>) =>
    memoryNeo4jConfigSchema.parse({
      neo4j: { uri: "bolt://localhost:7687", password: "" },
      embedding: { provider: "ollama" },
      extraction,
    });

  it("should parse extraction config when provided", () => {
    const config = parseWithExtraction({
      apiKey: "or-test-key",
      model: "google/gemini-2.0-flash-001",
      baseUrl: "https://openrouter.ai/api/v1",
    });

    expect(config.extraction).toBeDefined();
    // Optional chaining instead of `!`: if the section were missing, the
    // assertions fail with a clear mismatch rather than a TypeError.
    expect(config.extraction?.apiKey).toBe("or-test-key");
    expect(config.extraction?.model).toBe("google/gemini-2.0-flash-001");
  });

  it("should not include extraction when section is empty", () => {
    expect(parseWithExtraction({}).extraction).toBeUndefined();
  });

  it("should reject unknown keys in extraction section", () => {
    expect(() => parseWithExtraction({ badKey: "value" })).toThrow("unknown keys: badKey");
  });
});
});
// ============================================================================
// vectorDimsForModel()
// ============================================================================
describe("vectorDimsForModel", () => {
  /** Registers one test asserting that `model` resolves to `dims` dimensions. */
  const expectDims = (title: string, model: string, dims: number) => {
    it(title, () => {
      expect(vectorDimsForModel(model)).toBe(dims);
    });
  };

  describe("known models", () => {
    expectDims("should return 1536 for text-embedding-3-small", "text-embedding-3-small", 1536);
    expectDims("should return 3072 for text-embedding-3-large", "text-embedding-3-large", 3072);
    expectDims("should return 1024 for mxbai-embed-large", "mxbai-embed-large", 1024);
    expectDims("should return 768 for nomic-embed-text", "nomic-embed-text", 768);
    expectDims("should return 384 for all-minilm", "all-minilm", 384);
  });

  describe("prefix matching", () => {
    // A tagged name such as mxbai-embed-large:latest should resolve like its base name.
    expectDims(
      "should match versioned model names via prefix",
      "mxbai-embed-large:latest",
      1024,
    );
    expectDims(
      "should match model with additional version suffix",
      "nomic-embed-text:v1.5",
      768,
    );
  });

  describe("unknown models", () => {
    expectDims("should return default 1024 for unknown model", "unknown-model", 1024);
    expectDims("should return default 1024 for empty string", "", 1024);
    expectDims("should return default 1024 for unrecognized prefix", "custom-embed-v2", 1024);
  });
});
// ============================================================================
// resolveExtractionConfig()
// ============================================================================
describe("resolveExtractionConfig", () => {
  // Environment variables that the tests below show influencing the resolved config.
  const EXTRACTION_ENV_VARS = [
    "OPENROUTER_API_KEY",
    "EXTRACTION_MODEL",
    "EXTRACTION_BASE_URL",
  ] as const;

  const originalEnv = { ...process.env };

  beforeEach(() => {
    // Start every test from a clean slate so results do not depend on whatever
    // the developer machine or CI runner happens to export.
    for (const name of EXTRACTION_ENV_VARS) {
      delete process.env[name];
    }
  });

  afterEach(() => {
    // Restore in place instead of reassigning `process.env`, so Node's special
    // env object (with its string coercion) stays installed.
    for (const key of Object.keys(process.env)) {
      if (!(key in originalEnv)) {
        delete process.env[key];
      }
    }
    Object.assign(process.env, originalEnv);
  });

  it("should return disabled config when no API key or explicit baseUrl", () => {
    const config = resolveExtractionConfig();
    expect(config.enabled).toBe(false);
    expect(config.apiKey).toBe("");
  });

  it("should enable when OPENROUTER_API_KEY env var is set", () => {
    process.env.OPENROUTER_API_KEY = "or-env-key";
    const config = resolveExtractionConfig();
    expect(config.enabled).toBe(true);
    expect(config.apiKey).toBe("or-env-key");
  });

  it("should enable when plugin config provides apiKey", () => {
    const config = resolveExtractionConfig({
      apiKey: "or-plugin-key",
      model: "custom-model",
      baseUrl: "https://custom.ai/api",
    });
    expect(config.enabled).toBe(true);
    expect(config.apiKey).toBe("or-plugin-key");
    expect(config.model).toBe("custom-model");
    expect(config.baseUrl).toBe("https://custom.ai/api");
  });

  it("should enable when baseUrl is explicitly set (local Ollama, no API key)", () => {
    const config = resolveExtractionConfig({
      model: "llama3",
      baseUrl: "http://localhost:11434/v1",
    });
    expect(config.enabled).toBe(true);
    expect(config.apiKey).toBe("");
    expect(config.baseUrl).toBe("http://localhost:11434/v1");
  });

  it("should use defaults for model and baseUrl", () => {
    const config = resolveExtractionConfig();
    expect(config.model).toBe("google/gemini-2.0-flash-001");
    expect(config.baseUrl).toBe("https://openrouter.ai/api/v1");
  });

  it("should use EXTRACTION_MODEL env var", () => {
    process.env.EXTRACTION_MODEL = "meta/llama-3-70b";
    const config = resolveExtractionConfig();
    expect(config.model).toBe("meta/llama-3-70b");
  });

  it("should use EXTRACTION_BASE_URL env var", () => {
    process.env.EXTRACTION_BASE_URL = "https://my-proxy.ai/v1";
    const config = resolveExtractionConfig();
    expect(config.baseUrl).toBe("https://my-proxy.ai/v1");
  });

  it("should always set temperature to 0.0 and maxRetries to 2", () => {
    const config = resolveExtractionConfig();
    expect(config.temperature).toBe(0.0);
    expect(config.maxRetries).toBe(2);
  });
});

View File

@@ -63,7 +63,7 @@ export const MEMORY_CATEGORIES = [
export type MemoryCategory = (typeof MEMORY_CATEGORIES)[number];
const EMBEDDING_DIMENSIONS: Record<string, number> = {
export const EMBEDDING_DIMENSIONS: Record<string, number> = {
// OpenAI models
"text-embedding-3-small": 1536,
"text-embedding-3-large": 3072,
@@ -75,7 +75,7 @@ const EMBEDDING_DIMENSIONS: Record<string, number> = {
};
// Default dimension for unknown models (Ollama models vary)
const DEFAULT_EMBEDDING_DIMS = 1024;
export const DEFAULT_EMBEDDING_DIMS = 1024;
export function vectorDimsForModel(model: string): number {
// Check exact match first
@@ -88,7 +88,8 @@ export function vectorDimsForModel(model: string): number {
return dims;
}
}
// Return default for unknown models
// Return default for unknown models — callers should warn when this path is taken,
// as the default 1024 dimensions may not match the actual model's output.
return DEFAULT_EMBEDDING_DIMS;
}
@@ -164,6 +165,20 @@ export const memoryNeo4jConfigSchema = {
if (typeof neo4jRaw.uri !== "string" || !neo4jRaw.uri) {
throw new Error("neo4j.uri is required");
}
// Validate URI scheme — must be a valid Neo4j connection protocol
const VALID_NEO4J_SCHEMES = [
"bolt://",
"bolt+s://",
"bolt+ssc://",
"neo4j://",
"neo4j+s://",
"neo4j+ssc://",
];
if (!VALID_NEO4J_SCHEMES.some((scheme) => neo4jRaw.uri.startsWith(scheme))) {
throw new Error(
`neo4j.uri must start with a valid scheme (${VALID_NEO4J_SCHEMES.join(", ")}), got: "${neo4jRaw.uri}"`,
);
}
const neo4jPassword =
typeof neo4jRaw.password === "string" ? resolveEnvVars(neo4jRaw.password) : "";
@@ -212,7 +227,19 @@ export const memoryNeo4jConfigSchema = {
const coreMemoryEnabled = coreMemoryRaw?.enabled !== false; // enabled by default
const coreMemoryMaxEntries =
typeof coreMemoryRaw?.maxEntries === "number" ? coreMemoryRaw.maxEntries : 50;
// refreshAtContextPercent: number between 0-100, or undefined to disable
if (coreMemoryMaxEntries <= 0) {
throw new Error(`coreMemory.maxEntries must be greater than 0, got: ${coreMemoryMaxEntries}`);
}
// refreshAtContextPercent: number between 1-100 to be effective, or undefined to disable.
// Values at 0 or below are ignored (disables refresh). Values above 100 are invalid.
if (
typeof coreMemoryRaw?.refreshAtContextPercent === "number" &&
coreMemoryRaw.refreshAtContextPercent > 100
) {
throw new Error(
`coreMemory.refreshAtContextPercent must be between 1 and 100, got: ${coreMemoryRaw.refreshAtContextPercent}`,
);
}
const refreshAtContextPercent =
typeof coreMemoryRaw?.refreshAtContextPercent === "number" &&
coreMemoryRaw.refreshAtContextPercent > 0 &&

View File

@@ -0,0 +1,192 @@
/**
* Tests for embeddings.ts — Embedding Provider.
*
* Tests the Embeddings class: the Ollama path runs against a mocked global fetch, while the
* OpenAI path is only constructed (no requests are sent).
*/
import { describe, it, expect, vi, afterEach } from "vitest";
// ============================================================================
// Constructor
// ============================================================================
describe("Embeddings constructor", () => {
  /** Loads the Embeddings class lazily, the same way each test did inline. */
  const loadEmbeddings = async () => (await import("./embeddings.js")).Embeddings;

  it("should throw when OpenAI provider is used without API key", async () => {
    const Embeddings = await loadEmbeddings();
    const constructWithoutKey = () => new Embeddings(undefined, "text-embedding-3-small", "openai");
    expect(constructWithoutKey).toThrow("API key required for OpenAI embeddings");
  });

  it("should not require API key for ollama provider", async () => {
    const Embeddings = await loadEmbeddings();
    const instance = new Embeddings(undefined, "mxbai-embed-large", "ollama");
    expect(instance).toBeDefined();
  });
});
// ============================================================================
// Ollama embed
// ============================================================================
describe("Embeddings - Ollama provider", () => {
  const realFetch = globalThis.fetch;

  afterEach(() => {
    globalThis.fetch = realFetch;
  });

  /** Replaces the global fetch with a mock that resolves to `response`. */
  const stubFetchResponse = (response: unknown) => {
    globalThis.fetch = vi.fn().mockResolvedValue(response);
  };

  /** Builds a successful fetch response whose JSON body is `payload`. */
  const okJson = (payload: unknown) => ({
    ok: true,
    json: () => Promise.resolve(payload),
  });

  /** Creates an Ollama-backed Embeddings instance, optionally with a custom base URL. */
  const createOllama = async (baseUrl?: string) => {
    const { Embeddings } = await import("./embeddings.js");
    return baseUrl === undefined
      ? new Embeddings(undefined, "mxbai-embed-large", "ollama")
      : new Embeddings(undefined, "mxbai-embed-large", "ollama", baseUrl);
  };

  it("should call Ollama API with correct request body", async () => {
    const vector = [0.1, 0.2, 0.3, 0.4];
    stubFetchResponse(okJson({ embeddings: [vector] }));

    const ollama = await createOllama();
    const embedded = await ollama.embed("test text");

    expect(embedded).toEqual(vector);
    const expectedBody = JSON.stringify({
      model: "mxbai-embed-large",
      input: "test text",
    });
    expect(globalThis.fetch).toHaveBeenCalledWith(
      "http://localhost:11434/api/embed",
      expect.objectContaining({
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: expectedBody,
      }),
    );
  });

  it("should use custom baseUrl for Ollama", async () => {
    stubFetchResponse(okJson({ embeddings: [[0.5, 0.6]] }));

    const ollama = await createOllama("http://my-host:11434");
    await ollama.embed("test");

    expect(globalThis.fetch).toHaveBeenCalledWith(
      "http://my-host:11434/api/embed",
      expect.any(Object),
    );
  });

  it("should throw when Ollama returns error status", async () => {
    stubFetchResponse({
      ok: false,
      status: 500,
      text: () => Promise.resolve("Internal Server Error"),
    });

    const ollama = await createOllama();
    await expect(ollama.embed("test")).rejects.toThrow("Ollama embedding failed: 500");
  });

  it("should throw when Ollama returns no embeddings", async () => {
    stubFetchResponse(okJson({ embeddings: [] }));

    const ollama = await createOllama();
    await expect(ollama.embed("test")).rejects.toThrow("No embedding returned from Ollama");
  });

  it("should throw when Ollama returns null embeddings", async () => {
    // The response body has no `embeddings` field at all.
    stubFetchResponse(okJson({}));

    const ollama = await createOllama();
    await expect(ollama.embed("test")).rejects.toThrow("No embedding returned from Ollama");
  });

  it("should propagate fetch errors for Ollama", async () => {
    globalThis.fetch = vi.fn().mockRejectedValue(new Error("Network error"));

    const ollama = await createOllama();
    await expect(ollama.embed("test")).rejects.toThrow("Network error");
  });
});
// ============================================================================
// OpenAI provider (construction only — no client calls are made)
// ============================================================================
describe("Embeddings - OpenAI provider", () => {
  /** Constructs an OpenAI-backed Embeddings instance with a dummy key; no request is sent. */
  const createOpenAi = async () => {
    const { Embeddings } = await import("./embeddings.js");
    return new Embeddings("sk-test-key", "text-embedding-3-small", "openai");
  };

  it("should create instance with OpenAI provider when API key provided", async () => {
    // Construction succeeding with valid params is the whole assertion here.
    const instance = await createOpenAi();
    expect(instance).toBeDefined();
  });

  it("should have embed and embedBatch methods", async () => {
    const instance = await createOpenAi();
    expect(typeof instance.embed).toBe("function");
    expect(typeof instance.embedBatch).toBe("function");
  });
});
// ============================================================================
// Batch embedding
// ============================================================================
describe("Embeddings - embedBatch", () => {
  const originalFetch = globalThis.fetch;

  afterEach(() => {
    globalThis.fetch = originalFetch;
  });

  it("should return empty array for empty input (openai)", async () => {
    const { Embeddings } = await import("./embeddings.js");
    const emb = new Embeddings("sk-test", "text-embedding-3-small", "openai");
    const results = await emb.embedBatch([]);
    expect(results).toEqual([]);
  });

  it("should return empty array for empty input (ollama)", async () => {
    const { Embeddings } = await import("./embeddings.js");
    const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama");
    const results = await emb.embedBatch([]);
    expect(results).toEqual([]);
  });

  // The batch path fans out one request per text; it does not wait for one
  // request to finish before starting the next, so the title avoids "sequential".
  it("should issue one request per text for Ollama batch (no native batch support)", async () => {
    const { Embeddings } = await import("./embeddings.js");
    let callCount = 0;
    globalThis.fetch = vi.fn().mockImplementation(() => {
      callCount++;
      return Promise.resolve({
        ok: true,
        json: () => Promise.resolve({ embeddings: [[callCount * 0.1, callCount * 0.2]] }),
      });
    });
    const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama");
    const results = await emb.embedBatch(["text1", "text2", "text3"]);
    // Should make 3 separate calls
    expect(globalThis.fetch).toHaveBeenCalledTimes(3);
    expect(results).toHaveLength(3);
    // Each result should be a vector
    for (const r of results) {
      expect(Array.isArray(r)).toBe(true);
      expect(r.length).toBe(2);
    }
  });

  it("should keep results index-aligned when one Ollama request fails", async () => {
    const { Embeddings } = await import("./embeddings.js");
    // Fail by input text rather than by call order so the test does not depend
    // on how the requests are scheduled.
    globalThis.fetch = vi.fn().mockImplementation((_url: string, init: { body: string }) => {
      const { input } = JSON.parse(init.body) as { input: string };
      if (input === "bad") {
        return Promise.reject(new Error("Network error"));
      }
      return Promise.resolve({
        ok: true,
        json: () => Promise.resolve({ embeddings: [[0.1, 0.2]] }),
      });
    });
    const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama");
    const results = await emb.embedBatch(["text1", "bad", "text3"]);
    // The batch must not reject; the failed slot holds an empty placeholder.
    expect(results).toHaveLength(3);
    expect(results[0]).toEqual([0.1, 0.2]);
    expect(results[1]).toEqual([]);
    expect(results[2]).toEqual([0.1, 0.2]);
  });
});

View File

@@ -7,19 +7,29 @@
import OpenAI from "openai";
import type { EmbeddingProvider } from "./config.js";
type Logger = {
info: (msg: string) => void;
warn: (msg: string) => void;
error: (msg: string) => void;
debug?: (msg: string) => void;
};
export class Embeddings {
private client: OpenAI | null = null;
private readonly provider: EmbeddingProvider;
private readonly baseUrl: string;
private readonly logger: Logger | undefined;
constructor(
private readonly apiKey: string | undefined,
private readonly model: string = "text-embedding-3-small",
provider: EmbeddingProvider = "openai",
baseUrl?: string,
logger?: Logger,
) {
this.provider = provider;
this.baseUrl = baseUrl ?? (provider === "ollama" ? "http://localhost:11434" : "");
this.logger = logger;
if (provider === "openai") {
if (!apiKey) {
@@ -42,6 +52,9 @@ export class Embeddings {
/**
* Generate embeddings for multiple texts.
* Returns array of embeddings in the same order as input.
*
* For Ollama: uses Promise.allSettled so individual failures don't break the
* entire batch. Failed embeddings are replaced with empty arrays ([]) and logged.
*/
async embedBatch(texts: string[]): Promise<number[][]> {
if (texts.length === 0) {
@@ -49,8 +62,32 @@ export class Embeddings {
}
if (this.provider === "ollama") {
// Ollama doesn't support batch, so we do sequential
return Promise.all(texts.map((t) => this.embedOllama(t)));
// Ollama batching is not used here: issue one request per text (concurrently) with resilient error handling
const results = await Promise.allSettled(texts.map((t) => this.embedOllama(t)));
const embeddings: number[][] = [];
let failures = 0;
for (let i = 0; i < results.length; i++) {
const result = results[i];
if (result.status === "fulfilled") {
embeddings.push(result.value);
} else {
failures++;
this.logger?.warn?.(
`memory-neo4j: Ollama embedding failed for text ${i}: ${String(result.reason)}`,
);
// Use an empty array as placeholder so indices stay aligned with the input
embeddings.push([]);
}
}
if (failures > 0) {
this.logger?.warn?.(
`memory-neo4j: ${failures}/${texts.length} Ollama embeddings failed in batch`,
);
}
return embeddings;
}
return this.embedBatchOpenAI(texts);
@@ -79,6 +116,9 @@ export class Embeddings {
return response.data.toSorted((a, b) => a.index - b.index).map((d) => d.embedding);
}
// Timeout for Ollama embedding fetch calls to prevent hanging indefinitely
private static readonly EMBED_TIMEOUT_MS = 30_000;
private async embedOllama(text: string): Promise<number[]> {
const url = `${this.baseUrl}/api/embed`;
const response = await fetch(url, {
@@ -88,6 +128,7 @@ export class Embeddings {
model: this.model,
input: text,
}),
signal: AbortSignal.timeout(Embeddings.EMBED_TIMEOUT_MS),
});
if (!response.ok) {

View File

@@ -0,0 +1,760 @@
/**
* Tests for extractor.ts — Extraction Logic.
*
* Tests exported functions: extractEntities(), extractUserMessages(), runBackgroundExtraction().
* Note: validateExtractionResult() is not exported; it is tested indirectly through extractEntities().
* Note: passesAttentionGate() is defined in index.ts and not exported; cannot be tested directly.
*/
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
import type { ExtractionConfig } from "./config.js";
import { extractUserMessages, extractEntities, runBackgroundExtraction } from "./extractor.js";
// ============================================================================
// extractUserMessages()
// ============================================================================
describe("extractUserMessages", () => {
  /** Builds a message object with the "user" role and the given content. */
  const fromUser = (content: unknown) => ({ role: "user", content });

  it("should extract string content from user messages", () => {
    const extracted = extractUserMessages([
      fromUser("I prefer TypeScript over JavaScript"),
      fromUser("My favorite color is blue"),
    ]);
    expect(extracted).toEqual(["I prefer TypeScript over JavaScript", "My favorite color is blue"]);
  });

  it("should extract text from content block arrays", () => {
    // Only the text blocks are expected back; the image block is dropped.
    const blocks = [
      { type: "text", text: "Hello, this is a content block message" },
      { type: "image", url: "http://example.com/img.png" },
      { type: "text", text: "Another text block in same message" },
    ];
    const extracted = extractUserMessages([fromUser(blocks)]);
    expect(extracted).toEqual([
      "Hello, this is a content block message",
      "Another text block in same message",
    ]);
  });

  it("should filter out assistant messages", () => {
    const extracted = extractUserMessages([
      fromUser("This is a user message that is long enough"),
      { role: "assistant", content: "This is an assistant message" },
    ]);
    expect(extracted).toEqual(["This is a user message that is long enough"]);
  });

  it("should filter out system messages", () => {
    const extracted = extractUserMessages([
      { role: "system", content: "You are a helpful assistant with context" },
      fromUser("This is a user message that is long enough"),
    ]);
    expect(extracted).toEqual(["This is a user message that is long enough"]);
  });

  it("should filter out messages shorter than 10 characters", () => {
    const extracted = extractUserMessages([
      fromUser("short"), // 5 chars
      fromUser("1234567890"), // exactly 10 chars
      fromUser("This is longer than ten characters"),
    ]);
    expect(extracted).toEqual(["1234567890", "This is longer than ten characters"]);
  });

  it("should filter out messages containing <relevant-memories>", () => {
    const extracted = extractUserMessages([
      fromUser("Normal user message that is long enough here"),
      fromUser(
        "<relevant-memories>Some injected context that should be ignored</relevant-memories>",
      ),
    ]);
    expect(extracted).toEqual(["Normal user message that is long enough here"]);
  });

  it("should filter out messages containing <system>", () => {
    const extracted = extractUserMessages([
      fromUser("<system>System markup that should be filtered</system>"),
      fromUser("Normal user message that is long enough here"),
    ]);
    expect(extracted).toEqual(["Normal user message that is long enough here"]);
  });

  it("should handle null and non-object messages gracefully", () => {
    const mixedInput = [
      null,
      undefined,
      "not an object",
      42,
      fromUser("Valid message with enough length"),
    ];
    expect(extractUserMessages(mixedInput as unknown[])).toEqual([
      "Valid message with enough length",
    ]);
  });

  it("should return empty array when no user messages exist", () => {
    const onlyAssistant = [{ role: "assistant", content: "Only assistant messages" }];
    expect(extractUserMessages(onlyAssistant)).toEqual([]);
  });

  it("should return empty array for empty input", () => {
    expect(extractUserMessages([])).toEqual([]);
  });

  it("should handle messages where content is neither string nor array", () => {
    const oddContent = [fromUser(42), fromUser(null), fromUser({ nested: true })];
    expect(extractUserMessages(oddContent as unknown[])).toEqual([]);
  });
});
// ============================================================================
// extractEntities() — tests validateExtractionResult() indirectly
// ============================================================================
describe("extractEntities", () => {
// We need to mock `fetch` since callOpenRouter uses global fetch
const originalFetch = globalThis.fetch;
beforeEach(() => {
vi.restoreAllMocks();
});
afterEach(() => {
globalThis.fetch = originalFetch;
});
const enabledConfig: ExtractionConfig = {
enabled: true,
apiKey: "test-key",
model: "test-model",
baseUrl: "https://test.ai/api/v1",
temperature: 0.0,
maxRetries: 0, // No retries in tests
};
const disabledConfig: ExtractionConfig = {
...enabledConfig,
enabled: false,
};
/**
 * Swap the global `fetch` for a stub that resolves to a chat-completion-shaped response.
 *
 * `content` is exposed both as the raw body (`text()`) and as the first choice's
 * message content (`json()`); `ok` reflects whether `status` is in the 2xx range.
 */
function mockFetchResponse(content: string, status = 200) {
  const isSuccessStatus = status >= 200 && status < 300;
  const stubbedResponse = {
    ok: isSuccessStatus,
    status,
    text: () => Promise.resolve(content),
    json: () => Promise.resolve({ choices: [{ message: { content } }] }),
  };
  globalThis.fetch = vi.fn().mockResolvedValue(stubbedResponse);
}
it("should return null result when extraction is disabled", async () => {
const { result, transientFailure } = await extractEntities("test text", disabledConfig);
expect(result).toBeNull();
expect(transientFailure).toBe(false);
});
it("should extract valid entities from LLM response", async () => {
mockFetchResponse(
JSON.stringify({
category: "fact",
entities: [
{ name: "Tarun", type: "person", aliases: ["boss"], description: "The CEO" },
{ name: "Abundent", type: "organization" },
],
relationships: [
{ source: "Tarun", target: "Abundent", type: "WORKS_AT", confidence: 0.95 },
],
tags: [{ name: "Leadership", category: "business" }],
}),
);
const { result } = await extractEntities("Tarun works at Abundent", enabledConfig);
expect(result).not.toBeNull();
expect(result!.category).toBe("fact");
// Entities should be normalized to lowercase
expect(result!.entities).toHaveLength(2);
expect(result!.entities[0].name).toBe("tarun");
expect(result!.entities[0].type).toBe("person");
expect(result!.entities[0].aliases).toEqual(["boss"]);
expect(result!.entities[0].description).toBe("The CEO");
expect(result!.entities[1].name).toBe("abundent");
expect(result!.entities[1].type).toBe("organization");
// Relationships should be normalized to lowercase source/target
expect(result!.relationships).toHaveLength(1);
expect(result!.relationships[0].source).toBe("tarun");
expect(result!.relationships[0].target).toBe("abundent");
expect(result!.relationships[0].type).toBe("WORKS_AT");
expect(result!.relationships[0].confidence).toBe(0.95);
// Tags should be normalized to lowercase
expect(result!.tags).toHaveLength(1);
expect(result!.tags[0].name).toBe("leadership");
expect(result!.tags[0].category).toBe("business");
});
it("should handle empty extraction result", async () => {
mockFetchResponse(
JSON.stringify({
category: "other",
entities: [],
relationships: [],
tags: [],
}),
);
const { result } = await extractEntities("just a greeting", enabledConfig);
expect(result).not.toBeNull();
expect(result!.entities).toEqual([]);
expect(result!.relationships).toEqual([]);
expect(result!.tags).toEqual([]);
});
it("should handle missing fields in LLM response", async () => {
mockFetchResponse(
JSON.stringify({
// No category, entities, relationships, or tags
}),
);
const { result } = await extractEntities("some text", enabledConfig);
expect(result).not.toBeNull();
expect(result!.category).toBeUndefined();
expect(result!.entities).toEqual([]);
expect(result!.relationships).toEqual([]);
expect(result!.tags).toEqual([]);
});
it("should filter out invalid entity types (fallback to concept)", async () => {
mockFetchResponse(
JSON.stringify({
entities: [
{ name: "Widget", type: "gadget" }, // invalid type -> concept
{ name: "Paris", type: "location" }, // valid type
],
relationships: [],
tags: [],
}),
);
const { result } = await extractEntities("test", enabledConfig);
expect(result!.entities).toHaveLength(2);
expect(result!.entities[0].type).toBe("concept"); // invalid type falls back to concept
expect(result!.entities[1].type).toBe("location");
});
it("should filter out invalid relationship types", async () => {
mockFetchResponse(
JSON.stringify({
entities: [],
relationships: [
{ source: "a", target: "b", type: "WORKS_AT", confidence: 0.9 }, // valid
{ source: "a", target: "b", type: "HATES", confidence: 0.9 }, // invalid type
],
tags: [],
}),
);
const { result } = await extractEntities("test", enabledConfig);
expect(result!.relationships).toHaveLength(1);
expect(result!.relationships[0].type).toBe("WORKS_AT");
});
it("should clamp confidence to 0-1 range", async () => {
mockFetchResponse(
JSON.stringify({
entities: [],
relationships: [
{ source: "a", target: "b", type: "KNOWS", confidence: 1.5 }, // over 1
{ source: "c", target: "d", type: "KNOWS", confidence: -0.5 }, // under 0
],
tags: [],
}),
);
const { result } = await extractEntities("test", enabledConfig);
expect(result!.relationships[0].confidence).toBe(1);
expect(result!.relationships[1].confidence).toBe(0);
});
it("should default confidence to 0.7 when not a number", async () => {
mockFetchResponse(
JSON.stringify({
entities: [],
relationships: [{ source: "a", target: "b", type: "KNOWS", confidence: "high" }],
tags: [],
}),
);
const { result } = await extractEntities("test", enabledConfig);
expect(result!.relationships[0].confidence).toBe(0.7);
});
it("should filter out entities without name", async () => {
mockFetchResponse(
JSON.stringify({
entities: [
{ name: "", type: "person" }, // empty name -> filtered
{ name: " ", type: "person" }, // whitespace-only name -> filtered (after trim)
{ name: "valid", type: "person" }, // valid
],
relationships: [],
tags: [],
}),
);
const { result } = await extractEntities("test", enabledConfig);
expect(result!.entities).toHaveLength(1);
expect(result!.entities[0].name).toBe("valid");
});
it("should filter out entities with non-object shape", async () => {
mockFetchResponse(
JSON.stringify({
entities: [null, "not an entity", 42, { name: "valid", type: "person" }],
relationships: [],
tags: [],
}),
);
const { result } = await extractEntities("test", enabledConfig);
expect(result!.entities).toHaveLength(1);
});
it("should filter out entities missing required fields", async () => {
mockFetchResponse(
JSON.stringify({
entities: [
{ type: "person" }, // missing name
{ name: "test" }, // missing type
{ name: "valid", type: "person" }, // has both
],
relationships: [],
tags: [],
}),
);
const { result } = await extractEntities("test", enabledConfig);
expect(result!.entities).toHaveLength(1);
expect(result!.entities[0].name).toBe("valid");
});
it("should default tag category to 'topic' when missing", async () => {
mockFetchResponse(
JSON.stringify({
entities: [],
relationships: [],
tags: [{ name: "neo4j" }], // no category
}),
);
const { result } = await extractEntities("test", enabledConfig);
expect(result!.tags[0].category).toBe("topic");
});
it("should filter out tags with empty names", async () => {
mockFetchResponse(
JSON.stringify({
entities: [],
relationships: [],
tags: [
{ name: "", category: "tech" }, // empty -> filtered
{ name: " ", category: "tech" }, // whitespace-only -> filtered
{ name: "valid", category: "tech" },
],
}),
);
const { result } = await extractEntities("test", enabledConfig);
expect(result!.tags).toHaveLength(1);
expect(result!.tags[0].name).toBe("valid");
});
it("should reject invalid category values", async () => {
mockFetchResponse(
JSON.stringify({
category: "invalid-category",
entities: [],
relationships: [],
tags: [],
}),
);
const { result } = await extractEntities("test", enabledConfig);
expect(result!.category).toBeUndefined();
});
it("should accept valid category values", async () => {
for (const category of ["preference", "fact", "decision", "entity", "other"]) {
mockFetchResponse(
JSON.stringify({
category,
entities: [],
relationships: [],
tags: [],
}),
);
const { result } = await extractEntities(`test ${category}`, enabledConfig);
expect(result!.category).toBe(category);
}
});
it("should return null result for malformed JSON response (permanent failure)", async () => {
mockFetchResponse("not valid json at all");
const { result, transientFailure } = await extractEntities("test", enabledConfig);
// callOpenRouter returns the raw string, JSON.parse fails, catch returns null
expect(result).toBeNull();
expect(transientFailure).toBe(false);
});
it("should return null result when API returns error status", async () => {
  // HTTP 500 is deliberately absent from the transient status list (429, 502, 503, 504),
  // so the failure must be reported as permanent rather than retryable.
  globalThis.fetch = vi.fn().mockResolvedValue({
    ok: false,
    status: 500,
    text: () => Promise.resolve("Internal Server Error"),
  });
  const { result, transientFailure } = await extractEntities("test", enabledConfig);
  expect(result).toBeNull();
  // Pin the classification the comment above describes; without this a regression that
  // starts treating 500 as transient would go unnoticed.
  expect(transientFailure).toBe(false);
});
it("should return null result when API returns no content", async () => {
globalThis.fetch = vi.fn().mockResolvedValue({
ok: true,
status: 200,
json: () => Promise.resolve({ choices: [{ message: { content: null } }] }),
});
const { result, transientFailure } = await extractEntities("test", enabledConfig);
expect(result).toBeNull();
expect(transientFailure).toBe(false);
});
it("should normalize alias strings to lowercase", async () => {
mockFetchResponse(
JSON.stringify({
entities: [{ name: "John", type: "person", aliases: ["Johnny", "JOHN", "j.doe"] }],
relationships: [],
tags: [],
}),
);
const { result } = await extractEntities("test", enabledConfig);
expect(result!.entities[0].aliases).toEqual(["johnny", "john", "j.doe"]);
});
it("should filter out non-string aliases", async () => {
mockFetchResponse(
JSON.stringify({
entities: [{ name: "John", type: "person", aliases: ["valid", 42, null, "also-valid"] }],
relationships: [],
tags: [],
}),
);
const { result } = await extractEntities("test", enabledConfig);
expect(result!.entities[0].aliases).toEqual(["valid", "also-valid"]);
});
});
// ============================================================================
// runBackgroundExtraction()
// ============================================================================
describe("runBackgroundExtraction", () => {
const originalFetch = globalThis.fetch;
let mockLogger: {
info: ReturnType<typeof vi.fn>;
warn: ReturnType<typeof vi.fn>;
error: ReturnType<typeof vi.fn>;
debug: ReturnType<typeof vi.fn>;
};
let mockDb: {
updateExtractionStatus: ReturnType<typeof vi.fn>;
mergeEntity: ReturnType<typeof vi.fn>;
createMentions: ReturnType<typeof vi.fn>;
createEntityRelationship: ReturnType<typeof vi.fn>;
tagMemory: ReturnType<typeof vi.fn>;
updateMemoryCategory: ReturnType<typeof vi.fn>;
};
let mockEmbeddings: {
embed: ReturnType<typeof vi.fn>;
embedBatch: ReturnType<typeof vi.fn>;
};
beforeEach(() => {
vi.restoreAllMocks();
mockLogger = {
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
debug: vi.fn(),
};
mockDb = {
updateExtractionStatus: vi.fn().mockResolvedValue(undefined),
mergeEntity: vi.fn().mockResolvedValue(undefined),
createMentions: vi.fn().mockResolvedValue(undefined),
createEntityRelationship: vi.fn().mockResolvedValue(undefined),
tagMemory: vi.fn().mockResolvedValue(undefined),
updateMemoryCategory: vi.fn().mockResolvedValue(undefined),
};
mockEmbeddings = {
embed: vi.fn().mockResolvedValue([0.1, 0.2, 0.3]),
embedBatch: vi.fn().mockResolvedValue([[0.1, 0.2, 0.3]]),
};
});
afterEach(() => {
globalThis.fetch = originalFetch;
});
const enabledConfig: ExtractionConfig = {
enabled: true,
apiKey: "test-key",
model: "test-model",
baseUrl: "https://test.ai/api/v1",
temperature: 0.0,
maxRetries: 0,
};
const disabledConfig: ExtractionConfig = {
...enabledConfig,
enabled: false,
};
/**
 * Swap the global `fetch` for a stub that always resolves to a successful (200)
 * chat-completion-shaped response whose first choice carries `content`.
 */
function mockFetchResponse(content: string) {
  const successfulResponse = {
    ok: true,
    status: 200,
    json: () => Promise.resolve({ choices: [{ message: { content } }] }),
  };
  globalThis.fetch = vi.fn().mockResolvedValue(successfulResponse);
}
it("should skip extraction and mark as 'skipped' when disabled", async () => {
await runBackgroundExtraction(
"mem-1",
"test text",
mockDb as never,
mockEmbeddings as never,
disabledConfig,
mockLogger,
);
expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "skipped");
});
it("should mark as 'failed' when extraction returns null", async () => {
globalThis.fetch = vi.fn().mockResolvedValue({
ok: false,
status: 500,
text: () => Promise.resolve("error"),
});
await runBackgroundExtraction(
"mem-1",
"test text",
mockDb as never,
mockEmbeddings as never,
enabledConfig,
mockLogger,
);
expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "failed");
});
it("should mark as 'complete' when extraction result is empty", async () => {
mockFetchResponse(
JSON.stringify({
entities: [],
relationships: [],
tags: [],
}),
);
await runBackgroundExtraction(
"mem-1",
"test text",
mockDb as never,
mockEmbeddings as never,
enabledConfig,
mockLogger,
);
expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "complete");
});
it("should merge entities, create mentions, and mark complete", async () => {
mockFetchResponse(
JSON.stringify({
category: "fact",
entities: [{ name: "Alice", type: "person" }],
relationships: [],
tags: [],
}),
);
await runBackgroundExtraction(
"mem-1",
"Alice is a developer",
mockDb as never,
mockEmbeddings as never,
enabledConfig,
mockLogger,
);
expect(mockDb.mergeEntity).toHaveBeenCalledWith(
expect.objectContaining({
name: "alice",
type: "person",
}),
);
expect(mockDb.createMentions).toHaveBeenCalledWith("mem-1", "alice", "context", 1.0);
expect(mockDb.updateMemoryCategory).toHaveBeenCalledWith("mem-1", "fact");
expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "complete");
});
it("should create entity relationships", async () => {
mockFetchResponse(
JSON.stringify({
entities: [
{ name: "Alice", type: "person" },
{ name: "Acme", type: "organization" },
],
relationships: [{ source: "Alice", target: "Acme", type: "WORKS_AT", confidence: 0.9 }],
tags: [],
}),
);
await runBackgroundExtraction(
"mem-1",
"Alice works at Acme",
mockDb as never,
mockEmbeddings as never,
enabledConfig,
mockLogger,
);
expect(mockDb.createEntityRelationship).toHaveBeenCalledWith("alice", "acme", "WORKS_AT", 0.9);
});
it("should tag memories", async () => {
mockFetchResponse(
JSON.stringify({
entities: [],
relationships: [],
tags: [{ name: "Programming", category: "tech" }],
}),
);
await runBackgroundExtraction(
"mem-1",
"test text",
mockDb as never,
mockEmbeddings as never,
enabledConfig,
mockLogger,
);
expect(mockDb.tagMemory).toHaveBeenCalledWith("mem-1", "programming", "tech");
});
it("should not update category when result has no category", async () => {
mockFetchResponse(
JSON.stringify({
entities: [{ name: "Test", type: "concept" }],
relationships: [],
tags: [],
}),
);
await runBackgroundExtraction(
"mem-1",
"test",
mockDb as never,
mockEmbeddings as never,
enabledConfig,
mockLogger,
);
expect(mockDb.updateMemoryCategory).not.toHaveBeenCalled();
});
it("should handle entity merge failure gracefully", async () => {
mockFetchResponse(
JSON.stringify({
entities: [
{ name: "Alice", type: "person" },
{ name: "Bob", type: "person" },
],
relationships: [],
tags: [],
}),
);
// First entity merge fails, second succeeds
mockDb.mergeEntity.mockRejectedValueOnce(new Error("merge failed"));
mockDb.mergeEntity.mockResolvedValueOnce(undefined);
await runBackgroundExtraction(
"mem-1",
"Alice and Bob",
mockDb as never,
mockEmbeddings as never,
enabledConfig,
mockLogger,
);
// Should still continue and complete
expect(mockDb.mergeEntity).toHaveBeenCalledTimes(2);
expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "complete");
expect(mockLogger.warn).toHaveBeenCalled();
});
it("should log extraction results", async () => {
mockFetchResponse(
JSON.stringify({
category: "fact",
entities: [{ name: "Test", type: "concept" }],
relationships: [{ source: "a", target: "b", type: "RELATED_TO", confidence: 0.8 }],
tags: [{ name: "tech" }],
}),
);
await runBackgroundExtraction(
"mem-12345678-abcd",
"test",
mockDb as never,
mockEmbeddings as never,
enabledConfig,
mockLogger,
);
expect(mockLogger.info).toHaveBeenCalledWith(expect.stringContaining("extraction complete"));
});
});

View File

@@ -63,6 +63,9 @@ Rules:
// OpenRouter API Client
// ============================================================================
// Timeout for LLM and embedding fetch calls to prevent hanging indefinitely
const FETCH_TIMEOUT_MS = 30_000;
async function callOpenRouter(config: ExtractionConfig, prompt: string): Promise<string | null> {
for (let attempt = 0; attempt <= config.maxRetries; attempt++) {
try {
@@ -78,6 +81,7 @@ async function callOpenRouter(config: ExtractionConfig, prompt: string): Promise
temperature: config.temperature,
response_format: { type: "json_object" },
}),
signal: AbortSignal.timeout(FETCH_TIMEOUT_MS),
});
if (!response.ok) {
@@ -104,30 +108,68 @@ async function callOpenRouter(config: ExtractionConfig, prompt: string): Promise
// Entity Extraction
// ============================================================================
/** Max retries for transient extraction failures before marking permanently failed */
const MAX_EXTRACTION_RETRIES = 3;
/**
 * Classify a caught value as a transient failure (worth retrying later) or a permanent one.
 *
 * Only `Error` instances can be transient. An error qualifies when its `name` is
 * `AbortError` or `TimeoutError`, or when its lower-cased message contains one of the
 * network/timeout markers or retryable API status markers (429, 502, 503, 504).
 *
 * @param err - The value caught from a failed operation.
 * @returns `true` if the failure looks transient, `false` otherwise.
 */
function isTransientError(err: unknown): boolean {
  if (!(err instanceof Error)) {
    return false;
  }
  const loweredMessage = err.message.toLowerCase();
  const transientNames = ["AbortError", "TimeoutError"];
  if (transientNames.includes(err.name)) {
    return true;
  }
  // Substrings that identify network-level problems or retryable HTTP statuses
  const transientMarkers = [
    "timeout",
    "econnrefused",
    "econnreset",
    "enotfound",
    "network",
    "fetch failed",
    "socket hang up",
    "api error 429",
    "api error 502",
    "api error 503",
    "api error 504",
  ];
  return transientMarkers.some((marker) => loweredMessage.includes(marker));
}
/**
* Extract entities and relationships from a memory text using LLM.
*
* Returns { result, transientFailure }:
* - result is the ExtractionResult or null if extraction returned nothing useful
* - transientFailure is true if the failure was due to a network/timeout issue
* (caller should retry later) vs a permanent failure (bad JSON, etc.)
*/
export async function extractEntities(
text: string,
config: ExtractionConfig,
): Promise<ExtractionResult | null> {
): Promise<{ result: ExtractionResult | null; transientFailure: boolean }> {
if (!config.enabled) {
return null;
return { result: null, transientFailure: false };
}
const prompt = ENTITY_EXTRACTION_PROMPT.replace("{text}", text);
let content: string | null;
try {
const content = await callOpenRouter(config, prompt);
if (!content) {
return null;
}
content = await callOpenRouter(config, prompt);
} catch (err) {
// Network/timeout errors are transient — caller should retry
return { result: null, transientFailure: isTransientError(err) };
}
if (!content) {
return { result: null, transientFailure: false };
}
try {
const parsed = JSON.parse(content) as Record<string, unknown>;
return validateExtractionResult(parsed);
return { result: validateExtractionResult(parsed), transientFailure: false };
} catch {
// Will be handled by caller; don't throw for parse errors
return null;
// JSON parse failure is permanent — LLM returned malformed output
return { result: null, transientFailure: false };
}
}
@@ -213,7 +255,11 @@ function validateExtractionResult(raw: Record<string, unknown>): ExtractionResul
* 3. Create MENTIONS relationships from Memory → Entity
* 4. Create inter-Entity relationships (WORKS_AT, KNOWS, etc.)
* 5. Tag the memory
* 6. Update extractionStatus to "complete" or "failed"
* 6. Update extractionStatus to "complete", "pending" (transient retry), or "failed"
*
* Transient failures (network/timeout) leave status as "pending" with an incremented
* retry counter. After MAX_EXTRACTION_RETRIES transient failures, the memory is
* permanently marked "failed". Permanent failures (malformed JSON) are immediately "failed".
*/
export async function runBackgroundExtraction(
memoryId: string,
@@ -222,6 +268,7 @@ export async function runBackgroundExtraction(
embeddings: Embeddings,
config: ExtractionConfig,
logger: Logger,
currentRetries: number = 0,
): Promise<void> {
if (!config.enabled) {
await db.updateExtractionStatus(memoryId, "skipped").catch(() => {});
@@ -229,10 +276,28 @@ export async function runBackgroundExtraction(
}
try {
const result = await extractEntities(text, config);
const { result, transientFailure } = await extractEntities(text, config);
if (!result) {
await db.updateExtractionStatus(memoryId, "failed");
if (transientFailure) {
// Transient failure (network/timeout) — leave as pending for retry
const retries = currentRetries + 1;
if (retries >= MAX_EXTRACTION_RETRIES) {
logger.warn(
`memory-neo4j: extraction permanently failed for ${memoryId.slice(0, 8)} after ${retries} transient retries`,
);
await db.updateExtractionStatus(memoryId, "failed", { incrementRetries: true });
} else {
logger.info(
`memory-neo4j: extraction transient failure for ${memoryId.slice(0, 8)}, will retry (${retries}/${MAX_EXTRACTION_RETRIES})`,
);
// Keep status as "pending" but increment retry counter
await db.updateExtractionStatus(memoryId, "pending", { incrementRetries: true });
}
} else {
// Permanent failure (JSON parse, empty response, etc.)
await db.updateExtractionStatus(memoryId, "failed");
}
return;
}
@@ -309,8 +374,21 @@ export async function runBackgroundExtraction(
(result.category ? `, category=${result.category}` : ""),
);
} catch (err) {
logger.warn(`memory-neo4j: extraction failed for ${memoryId.slice(0, 8)}: ${String(err)}`);
await db.updateExtractionStatus(memoryId, "failed").catch(() => {});
// Unexpected error during graph operations — retry only if the error is transient and retry budget remains
const isTransient = isTransientError(err);
if (isTransient && currentRetries + 1 < MAX_EXTRACTION_RETRIES) {
logger.warn(
`memory-neo4j: extraction transient error for ${memoryId.slice(0, 8)}, will retry: ${String(err)}`,
);
await db
.updateExtractionStatus(memoryId, "pending", { incrementRetries: true })
.catch(() => {});
} else {
logger.warn(`memory-neo4j: extraction failed for ${memoryId.slice(0, 8)}: ${String(err)}`);
await db
.updateExtractionStatus(memoryId, "failed", { incrementRetries: true })
.catch(() => {});
}
}
}
@@ -533,6 +611,14 @@ export async function runSleepCycle(
// --------------------------------------------------------------------------
// Phase 3: Core Promotion (using pre-computed scores from Phase 2)
//
// Design note on staleness: The effective scores and Pareto threshold were
// computed in Phase 2 and may be slightly stale by the time Phases 3/4 run.
// This is acceptable because: (a) the sleep cycle is a background maintenance
// task that runs infrequently (not concurrent with itself), (b) the scoring
// formula is deterministic based on stored properties that change slowly, and
// (c) promotion/demotion are reversible in the next cycle. The alternative
// (re-querying scores per phase) adds latency without meaningful accuracy gain.
// --------------------------------------------------------------------------
if (!abortSignal?.aborted && paretoThreshold > 0) {
onPhaseStart?.("promotion");
@@ -631,7 +717,15 @@ export async function runSleepCycle(
const chunk = pending.slice(i, i + EXTRACTION_CONCURRENCY);
const outcomes = await Promise.allSettled(
chunk.map((memory) =>
runBackgroundExtraction(memory.id, memory.text, db, embeddings, config, logger),
runBackgroundExtraction(
memory.id,
memory.text,
db,
embeddings,
config,
logger,
memory.extractionRetries,
),
),
);

View File

@@ -18,6 +18,8 @@ import { randomUUID } from "node:crypto";
import { stringEnum } from "openclaw/plugin-sdk";
import type { MemoryCategory, MemorySource } from "./schema.js";
import {
DEFAULT_EMBEDDING_DIMS,
EMBEDDING_DIMENSIONS,
MEMORY_CATEGORIES,
memoryNeo4jConfigSchema,
resolveExtractionConfig,
@@ -46,6 +48,25 @@ const memoryNeo4jPlugin = {
const extractionConfig = resolveExtractionConfig(cfg.extraction);
const vectorDim = vectorDimsForModel(cfg.embedding.model);
// Warn on empty neo4j password (may be valid for some setups, but usually a misconfiguration)
if (!cfg.neo4j.password) {
api.logger.warn(
"memory-neo4j: neo4j.password is empty — this may be intentional for passwordless setups, but verify your configuration",
);
}
// Warn when using default embedding dimensions for an unknown model
const isKnownModel =
cfg.embedding.model in EMBEDDING_DIMENSIONS ||
Object.keys(EMBEDDING_DIMENSIONS).some((known) => cfg.embedding.model.startsWith(known));
if (!isKnownModel) {
api.logger.warn(
`memory-neo4j: unknown embedding model "${cfg.embedding.model}" — using default ${DEFAULT_EMBEDDING_DIMS} dimensions. ` +
`If your model outputs a different dimension, vector operations will fail. ` +
`Known models: ${Object.keys(EMBEDDING_DIMENSIONS).join(", ")}`,
);
}
// Create shared resources
const db = new Neo4jMemoryClient(
cfg.neo4j.uri,
@@ -59,6 +80,7 @@ const memoryNeo4jPlugin = {
cfg.embedding.model,
cfg.embedding.provider,
cfg.embedding.baseUrl,
api.logger,
);
api.logger.debug?.(
@@ -499,23 +521,68 @@ const memoryNeo4jPlugin = {
console.log(" Phase 7: Orphan Cleanup — Remove disconnected nodes\n");
try {
// Validate sleep cycle CLI parameters before running
const batchSize = opts.batchSize ? parseInt(opts.batchSize, 10) : undefined;
const delay = opts.delay ? parseInt(opts.delay, 10) : undefined;
const decayHalfLife = opts.decayHalfLife
? parseInt(opts.decayHalfLife, 10)
: undefined;
const decayThreshold = opts.decayThreshold
? parseFloat(opts.decayThreshold)
: undefined;
const pareto = opts.pareto ? parseFloat(opts.pareto) : undefined;
const promotionMinAge = opts.promotionMinAge
? parseInt(opts.promotionMinAge, 10)
: undefined;
if (batchSize != null && (Number.isNaN(batchSize) || batchSize <= 0)) {
console.error("Error: --batch-size must be greater than 0");
process.exitCode = 1;
return;
}
if (delay != null && (Number.isNaN(delay) || delay < 0)) {
console.error("Error: --delay must be >= 0");
process.exitCode = 1;
return;
}
if (decayHalfLife != null && (Number.isNaN(decayHalfLife) || decayHalfLife <= 0)) {
console.error("Error: --decay-half-life must be greater than 0");
process.exitCode = 1;
return;
}
if (
decayThreshold != null &&
(Number.isNaN(decayThreshold) || decayThreshold < 0 || decayThreshold > 1)
) {
console.error("Error: --decay-threshold must be between 0 and 1");
process.exitCode = 1;
return;
}
if (pareto != null && (Number.isNaN(pareto) || pareto < 0 || pareto > 1)) {
console.error("Error: --pareto must be between 0 and 1");
process.exitCode = 1;
return;
}
if (
promotionMinAge != null &&
(Number.isNaN(promotionMinAge) || promotionMinAge < 0)
) {
console.error("Error: --promotion-min-age must be >= 0");
process.exitCode = 1;
return;
}
await db.ensureInitialized();
const result = await runSleepCycle(db, embeddings, extractionConfig, api.logger, {
agentId: opts.agent,
dedupThreshold: opts.dedupThreshold ? parseFloat(opts.dedupThreshold) : undefined,
paretoPercentile: opts.pareto ? parseFloat(opts.pareto) : undefined,
promotionMinAgeDays: opts.promotionMinAge
? parseInt(opts.promotionMinAge, 10)
: undefined,
decayRetentionThreshold: opts.decayThreshold
? parseFloat(opts.decayThreshold)
: undefined,
decayBaseHalfLifeDays: opts.decayHalfLife
? parseInt(opts.decayHalfLife, 10)
: undefined,
extractionBatchSize: opts.batchSize ? parseInt(opts.batchSize, 10) : undefined,
extractionDelayMs: opts.delay ? parseInt(opts.delay, 10) : undefined,
paretoPercentile: pareto,
promotionMinAgeDays: promotionMinAge,
decayRetentionThreshold: decayThreshold,
decayBaseHalfLifeDays: decayHalfLife,
extractionBatchSize: batchSize,
extractionDelayMs: delay,
onPhaseStart: (phase) => {
const phaseNames = {
dedup: "Phase 1: Deduplication",
@@ -611,12 +678,45 @@ const memoryNeo4jPlugin = {
const midSessionRefreshAt = new Map<string, number>();
const MIN_TOKENS_SINCE_REFRESH = 10_000; // Only refresh if context grew by 10k+ tokens
// Track session timestamps for TTL-based cleanup. Without this, bootstrappedSessions
// and midSessionRefreshAt leak entries for sessions that ended without an explicit
// after_compaction event (e.g., normal session end on long-running gateways).
const SESSION_TTL_MS = 24 * 60 * 60 * 1000; // 24 hours
const sessionLastSeen = new Map<string, number>();
let lastTtlSweep = Date.now();
/**
 * Drop tracking state for every session whose last-seen timestamp is older than
 * SESSION_TTL_MS. Throttled: returns immediately when the previous sweep ran less
 * than five minutes ago.
 */
function pruneStaleSessionEntries(): void {
  const currentTime = Date.now();
  const sinceLastSweep = currentTime - lastTtlSweep;
  // Throttle sweeps to one per 5-minute window to keep overhead low
  if (sinceLastSweep < 5 * 60 * 1000) {
    return;
  }
  lastTtlSweep = currentTime;
  const expiryCutoff = currentTime - SESSION_TTL_MS;
  sessionLastSeen.forEach((lastSeenAt, sessionKey) => {
    if (lastSeenAt >= expiryCutoff) {
      return;
    }
    // Stale session: remove it from all three tracking collections
    bootstrappedSessions.delete(sessionKey);
    midSessionRefreshAt.delete(sessionKey);
    sessionLastSeen.delete(sessionKey);
  });
}
/** Record the current time as the session's last activity, then invoke the stale-entry sweep. */
function touchSession(sessionKey: string): void {
  const seenAt = Date.now();
  sessionLastSeen.set(sessionKey, seenAt);
  pruneStaleSessionEntries();
}
// After compaction: clear bootstrap flag and mid-session refresh tracking
if (cfg.coreMemory.enabled) {
api.on("after_compaction", async (_event, ctx) => {
if (ctx.sessionKey) {
bootstrappedSessions.delete(ctx.sessionKey);
midSessionRefreshAt.delete(ctx.sessionKey);
sessionLastSeen.delete(ctx.sessionKey);
api.logger.info?.(
`memory-neo4j: cleared bootstrap/refresh flags for session ${ctx.sessionKey} after compaction`,
);
@@ -624,6 +724,18 @@ const memoryNeo4jPlugin = {
});
}
// Session end: drop tracking entries for the finished session.
// This keys on ctx.sessionId, which may or may not coincide with the sessionKey used
// elsewhere, so treat it as best-effort cleanup that complements the TTL-based sweep.
api.on("session_end", async (_event, ctx) => {
  const endedSessionId = ctx.sessionId;
  if (!endedSessionId) {
    return;
  }
  bootstrappedSessions.delete(endedSessionId);
  midSessionRefreshAt.delete(endedSessionId);
  sessionLastSeen.delete(endedSessionId);
});
// Mid-session core memory refresh: re-inject core memories when context grows past threshold
// This counters the "lost in the middle" phenomenon by placing core memories closer to end of context
const refreshThreshold = cfg.coreMemory.refreshAtContextPercent;
@@ -666,6 +778,7 @@ const memoryNeo4jPlugin = {
// Record this refresh
midSessionRefreshAt.set(sessionKey, event.estimatedUsedTokens);
touchSession(sessionKey);
const content = coreMemories.map((m) => `- ${m.text}`).join("\n");
api.logger.info?.(
@@ -769,6 +882,7 @@ const memoryNeo4jPlugin = {
if (coreMemories.length === 0) {
if (sessionKey) {
bootstrappedSessions.add(sessionKey);
touchSession(sessionKey);
}
api.logger.debug?.(
`memory-neo4j: no core memories found for agent=${agentId}, marking session as bootstrapped`,
@@ -805,6 +919,7 @@ const memoryNeo4jPlugin = {
if (sessionKey) {
bootstrappedSessions.add(sessionKey);
touchSession(sessionKey);
}
// Log at info level when actually injecting, debug for skips
api.logger.info?.(

View File

@@ -77,14 +77,15 @@ describe("mid-session core memory refresh", () => {
expect(config.coreMemory.refreshAtContextPercent).toBeUndefined();
});
it("should reject refreshAtContextPercent over 100", async () => {
it("should throw for refreshAtContextPercent over 100", async () => {
const { memoryNeo4jConfigSchema } = await import("./config.js");
const config = memoryNeo4jConfigSchema.parse({
neo4j: { uri: "bolt://localhost:7687", user: "neo4j", password: "test" },
embedding: { provider: "ollama" },
coreMemory: { refreshAtContextPercent: 150 },
});
expect(config.coreMemory.refreshAtContextPercent).toBeUndefined();
expect(() =>
memoryNeo4jConfigSchema.parse({
neo4j: { uri: "bolt://localhost:7687", user: "neo4j", password: "test" },
embedding: { provider: "ollama" },
coreMemory: { refreshAtContextPercent: 150 },
}),
).toThrow("coreMemory.refreshAtContextPercent must be between 1 and 100");
});
it("should default to undefined when not specified", async () => {

View File

@@ -211,32 +211,36 @@ export class Neo4jMemoryClient {
/**
 * Create a new Memory node and return its id.
 *
 * The write runs inside retryOnTransient(), so a transient Neo4j failure
 * re-runs the whole attempt. Each attempt opens its own session and always
 * closes it, so a failed attempt never leaks a session into the next one.
 *
 * New nodes start with retrievalCount = 0, extractionRetries = 0 and
 * lastRetrievedAt = null; createdAt and updatedAt share one timestamp.
 *
 * @param input - Memory fields, forwarded as query parameters.
 * @returns The id of the created Memory node.
 */
async storeMemory(input: StoreMemoryInput): Promise<string> {
  await this.ensureInitialized();
  return this.retryOnTransient(async () => {
    const session = this.driver!.session();
    try {
      const now = new Date().toISOString();
      const result = await session.run(
        `CREATE (m:Memory {
           id: $id, text: $text, embedding: $embedding,
           importance: $importance, category: $category,
           source: $source, extractionStatus: $extractionStatus,
           agentId: $agentId, sessionKey: $sessionKey,
           createdAt: $createdAt, updatedAt: $updatedAt,
           retrievalCount: $retrievalCount, lastRetrievedAt: $lastRetrievedAt,
           extractionRetries: $extractionRetries
         })
         RETURN m.id AS id`,
        {
          ...input,
          // Neo4j parameters must be defined; map a missing sessionKey to null
          sessionKey: input.sessionKey ?? null,
          createdAt: now,
          updatedAt: now,
          retrievalCount: 0,
          lastRetrievedAt: null,
          extractionRetries: 0,
        },
      );
      return result.records[0].get("id") as string;
    } finally {
      await session.close();
    }
  });
}
async deleteMemory(id: string): Promise<boolean> {
@@ -247,29 +251,32 @@ export class Neo4jMemoryClient {
throw new Error(`Invalid memory ID format: ${id}`);
}
const session = this.driver!.session();
try {
// First, decrement mentionCount on connected entities
await session.run(
`MATCH (m:Memory {id: $id})-[:MENTIONS]->(e:Entity)
SET e.mentionCount = e.mentionCount - 1`,
{ id },
);
return this.retryOnTransient(async () => {
const session = this.driver!.session();
try {
// Decrement mentionCount on connected entities (floor at 0 to prevent
// negative counts from parallel deletes racing on the same entity)
await session.run(
`MATCH (m:Memory {id: $id})-[:MENTIONS]->(e:Entity)
SET e.mentionCount = CASE WHEN e.mentionCount > 0 THEN e.mentionCount - 1 ELSE 0 END`,
{ id },
);
// Then delete the memory with all its relationships
const result = await session.run(
`MATCH (m:Memory {id: $id})
DETACH DELETE m
RETURN count(*) AS deleted`,
{ id },
);
// Then delete the memory with all its relationships
const result = await session.run(
`MATCH (m:Memory {id: $id})
DETACH DELETE m
RETURN count(*) AS deleted`,
{ id },
);
const deleted =
result.records.length > 0 ? (result.records[0].get("deleted") as number) > 0 : false;
return deleted;
} finally {
await session.close();
}
const deleted =
result.records.length > 0 ? (result.records[0].get("deleted") as number) > 0 : false;
return deleted;
} finally {
await session.close();
}
});
}
async countMemories(agentId?: string): Promise<number> {
@@ -371,39 +378,43 @@ export class Neo4jMemoryClient {
agentId?: string,
): Promise<SearchSignalResult[]> {
await this.ensureInitialized();
const session = this.driver!.session();
try {
const agentFilter = agentId ? "AND node.agentId = $agentId" : "";
const result = await session.run(
`CALL db.index.vector.queryNodes('memory_embedding_index', $limit, $embedding)
YIELD node, score
WHERE score >= $minScore ${agentFilter}
RETURN node.id AS id, node.text AS text, node.category AS category,
node.importance AS importance, node.createdAt AS createdAt,
score AS similarity
ORDER BY score DESC`,
{
embedding,
limit: neo4j.int(Math.floor(limit)),
minScore,
...(agentId ? { agentId } : {}),
},
);
return await this.retryOnTransient(async () => {
const session = this.driver!.session();
try {
const agentFilter = agentId ? "AND node.agentId = $agentId" : "";
const result = await session.run(
`CALL db.index.vector.queryNodes('memory_embedding_index', $limit, $embedding)
YIELD node, score
WHERE score >= $minScore ${agentFilter}
RETURN node.id AS id, node.text AS text, node.category AS category,
node.importance AS importance, node.createdAt AS createdAt,
score AS similarity
ORDER BY score DESC`,
{
embedding,
limit: neo4j.int(Math.floor(limit)),
minScore,
...(agentId ? { agentId } : {}),
},
);
return result.records.map((r) => ({
id: r.get("id") as string,
text: r.get("text") as string,
category: r.get("category") as string,
importance: r.get("importance") as number,
createdAt: String(r.get("createdAt") ?? ""),
score: r.get("similarity") as number,
}));
return result.records.map((r) => ({
id: r.get("id") as string,
text: r.get("text") as string,
category: r.get("category") as string,
importance: r.get("importance") as number,
createdAt: String(r.get("createdAt") ?? ""),
score: r.get("similarity") as number,
}));
} finally {
await session.close();
}
});
} catch (err) {
// Graceful degradation: return empty if vector index isn't ready
// Graceful degradation: return empty if vector index isn't ready or all retries exhausted
this.logger.warn(`memory-neo4j: vector search failed: ${String(err)}`);
return [];
} finally {
await session.close();
}
}
@@ -413,49 +424,58 @@ export class Neo4jMemoryClient {
*/
async bm25Search(query: string, limit: number, agentId?: string): Promise<SearchSignalResult[]> {
  await this.ensureInitialized();

  // Escape Lucene syntax up front; an empty query never needs a session.
  const escaped = escapeLucene(query);
  if (!escaped.trim()) {
    return [];
  }

  try {
    return await this.retryOnTransient(async () => {
      // One session per attempt, so a retry never reuses a failed session.
      const session = this.driver!.session();
      try {
        const agentFilter = agentId ? "AND node.agentId = $agentId" : "";
        const result = await session.run(
          `CALL db.index.fulltext.queryNodes('memory_fulltext_index', $query)
           YIELD node, score
           WHERE true ${agentFilter}
           RETURN node.id AS id, node.text AS text, node.category AS category,
                  node.importance AS importance, node.createdAt AS createdAt,
                  score AS bm25Score
           ORDER BY score DESC
           LIMIT $limit`,
          { query: escaped, limit: neo4j.int(Math.floor(limit)), ...(agentId ? { agentId } : {}) },
        );

        const records = result.records.map((r) => ({
          id: r.get("id") as string,
          text: r.get("text") as string,
          category: r.get("category") as string,
          importance: r.get("importance") as number,
          createdAt: String(r.get("createdAt") ?? ""),
          rawScore: r.get("bm25Score") as number,
        }));

        if (records.length === 0) {
          return [];
        }

        // Normalize BM25 scores to 0-1 by dividing by the max. Rows arrive
        // ordered by score DESC, so the first row holds the max; fall back
        // to 1 to avoid dividing by zero.
        const maxScore = records[0].rawScore || 1;
        return records.map((r) => ({
          ...r,
          score: r.rawScore / maxScore,
        }));
      } finally {
        await session.close();
      }
    });
  } catch (err) {
    // Graceful degradation: return empty if all retries exhausted
    this.logger.warn(`memory-neo4j: BM25 search failed: ${String(err)}`);
    return [];
  }
}
@@ -475,89 +495,94 @@ export class Neo4jMemoryClient {
agentId?: string,
): Promise<SearchSignalResult[]> {
await this.ensureInitialized();
const session = this.driver!.session();
const escaped = escapeLucene(query);
if (!escaped.trim()) {
return [];
}
try {
const escaped = escapeLucene(query);
if (!escaped.trim()) {
return [];
}
return await this.retryOnTransient(async () => {
const session = this.driver!.session();
try {
// Step 1: Find matching entities
const entityResult = await session.run(
`CALL db.index.fulltext.queryNodes('entity_fulltext_index', $query)
YIELD node, score
WHERE score >= 0.5
RETURN node.id AS entityId, node.name AS name, score
ORDER BY score DESC
LIMIT 5`,
{ query: escaped },
);
// Step 1: Find matching entities
const entityResult = await session.run(
`CALL db.index.fulltext.queryNodes('entity_fulltext_index', $query)
YIELD node, score
WHERE score >= 0.5
RETURN node.id AS entityId, node.name AS name, score
ORDER BY score DESC
LIMIT 5`,
{ query: escaped },
);
const entityIds = entityResult.records.map((r) => r.get("entityId") as string);
if (entityIds.length === 0) {
return [];
}
const entityIds = entityResult.records.map((r) => r.get("entityId") as string);
if (entityIds.length === 0) {
return [];
}
// Step 2 + 3: Direct mentions + 1-hop spreading activation
const agentFilter = agentId ? "AND m.agentId = $agentId" : "";
const result = await session.run(
`UNWIND $entityIds AS eid
// Direct: Entity ← MENTIONS ← Memory
OPTIONAL MATCH (e:Entity {id: eid})<-[rm:MENTIONS]-(m:Memory)
WHERE m IS NOT NULL ${agentFilter}
WITH m, coalesce(rm.confidence, 1.0) AS directScore
WHERE m IS NOT NULL
// Step 2 + 3: Direct mentions + 1-hop spreading activation
const agentFilter = agentId ? "AND m.agentId = $agentId" : "";
const result = await session.run(
`UNWIND $entityIds AS eid
// Direct: Entity ← MENTIONS ← Memory
OPTIONAL MATCH (e:Entity {id: eid})<-[rm:MENTIONS]-(m:Memory)
WHERE m IS NOT NULL ${agentFilter}
WITH m, coalesce(rm.confidence, 1.0) AS directScore
WHERE m IS NOT NULL
RETURN m.id AS id, m.text AS text, m.category AS category,
m.importance AS importance, m.createdAt AS createdAt,
max(directScore) AS graphScore
RETURN m.id AS id, m.text AS text, m.category AS category,
m.importance AS importance, m.createdAt AS createdAt,
max(directScore) AS graphScore
UNION
UNION
UNWIND $entityIds AS eid
// 1-hop: Entity → relationship → Entity ← MENTIONS ← Memory
OPTIONAL MATCH (e:Entity {id: eid})-[r1:RELATED_TO|KNOWS|WORKS_AT|LIVES_AT|MARRIED_TO|PREFERS|DECIDED]-(e2:Entity)
WHERE coalesce(r1.confidence, 0.7) >= $firingThreshold
OPTIONAL MATCH (e2)<-[rm:MENTIONS]-(m:Memory)
WHERE m IS NOT NULL ${agentFilter}
WITH m, coalesce(r1.confidence, 0.7) * coalesce(rm.confidence, 1.0) AS hopScore
WHERE m IS NOT NULL
UNWIND $entityIds AS eid
// 1-hop: Entity → relationship → Entity ← MENTIONS ← Memory
OPTIONAL MATCH (e:Entity {id: eid})-[r1:RELATED_TO|KNOWS|WORKS_AT|LIVES_AT|MARRIED_TO|PREFERS|DECIDED]-(e2:Entity)
WHERE coalesce(r1.confidence, 0.7) >= $firingThreshold
OPTIONAL MATCH (e2)<-[rm:MENTIONS]-(m:Memory)
WHERE m IS NOT NULL ${agentFilter}
WITH m, coalesce(r1.confidence, 0.7) * coalesce(rm.confidence, 1.0) AS hopScore
WHERE m IS NOT NULL
RETURN m.id AS id, m.text AS text, m.category AS category,
m.importance AS importance, m.createdAt AS createdAt,
max(hopScore) AS graphScore`,
{ entityIds, firingThreshold, ...(agentId ? { agentId } : {}) },
);
RETURN m.id AS id, m.text AS text, m.category AS category,
m.importance AS importance, m.createdAt AS createdAt,
max(hopScore) AS graphScore`,
{ entityIds, firingThreshold, ...(agentId ? { agentId } : {}) },
);
// Deduplicate by id, keeping highest score
const byId = new Map<string, SearchSignalResult>();
for (const record of result.records) {
const id = record.get("id") as string;
if (!id) {
continue;
}
const score = record.get("graphScore") as number;
const existing = byId.get(id);
if (!existing || score > existing.score) {
byId.set(id, {
id,
text: record.get("text") as string,
category: record.get("category") as string,
importance: record.get("importance") as number,
createdAt: String(record.get("createdAt") ?? ""),
score,
});
}
}
// Deduplicate by id, keeping highest score
const byId = new Map<string, SearchSignalResult>();
for (const record of result.records) {
const id = record.get("id") as string;
if (!id) {
continue;
return Array.from(byId.values())
.toSorted((a, b) => b.score - a.score)
.slice(0, limit);
} finally {
await session.close();
}
const score = record.get("graphScore") as number;
const existing = byId.get(id);
if (!existing || score > existing.score) {
byId.set(id, {
id,
text: record.get("text") as string,
category: record.get("category") as string,
importance: record.get("importance") as number,
createdAt: String(record.get("createdAt") ?? ""),
score,
});
}
}
return Array.from(byId.values())
.toSorted((a, b) => b.score - a.score)
.slice(0, limit);
});
} catch (err) {
// Graceful degradation: return empty if all retries exhausted
this.logger.warn(`memory-neo4j: graph search failed: ${String(err)}`);
return [];
} finally {
await session.close();
}
}
@@ -570,28 +595,32 @@ export class Neo4jMemoryClient {
limit: number = 1,
): Promise<Array<{ id: string; text: string; score: number }>> {
await this.ensureInitialized();
const session = this.driver!.session();
try {
const result = await session.run(
`CALL db.index.vector.queryNodes('memory_embedding_index', $limit, $embedding)
YIELD node, score
WHERE score >= $threshold
RETURN node.id AS id, node.text AS text, score AS similarity
ORDER BY score DESC`,
{ embedding, limit: neo4j.int(limit), threshold },
);
return await this.retryOnTransient(async () => {
const session = this.driver!.session();
try {
const result = await session.run(
`CALL db.index.vector.queryNodes('memory_embedding_index', $limit, $embedding)
YIELD node, score
WHERE score >= $threshold
RETURN node.id AS id, node.text AS text, score AS similarity
ORDER BY score DESC`,
{ embedding, limit: neo4j.int(limit), threshold },
);
return result.records.map((r) => ({
id: r.get("id") as string,
text: r.get("text") as string,
score: r.get("similarity") as number,
}));
return result.records.map((r) => ({
id: r.get("id") as string,
text: r.get("text") as string,
score: r.get("similarity") as number,
}));
} finally {
await session.close();
}
});
} catch (err) {
// If vector index isn't ready, return no duplicates (allow store)
// If vector index isn't ready or all retries exhausted, return no duplicates (allow store)
this.logger.debug?.(`memory-neo4j: similarity check failed: ${String(err)}`);
return [];
} finally {
await session.close();
}
}
@@ -609,18 +638,20 @@ export class Neo4jMemoryClient {
}
await this.ensureInitialized();
const session = this.driver!.session();
try {
await session.run(
`UNWIND $ids AS memId
MATCH (m:Memory {id: memId})
SET m.retrievalCount = coalesce(m.retrievalCount, 0) + 1,
m.lastRetrievedAt = $now`,
{ ids: memoryIds, now: new Date().toISOString() },
);
} finally {
await session.close();
}
return this.retryOnTransient(async () => {
const session = this.driver!.session();
try {
await session.run(
`UNWIND $ids AS memId
MATCH (m:Memory {id: memId})
SET m.retrievalCount = coalesce(m.retrievalCount, 0) + 1,
m.lastRetrievedAt = $now`,
{ ids: memoryIds, now: new Date().toISOString() },
);
} finally {
await session.close();
}
});
}
/**
@@ -832,14 +863,22 @@ export class Neo4jMemoryClient {
/**
* Update the extraction status of a Memory node.
* Optionally increments the extractionRetries counter (for transient failure tracking).
*/
async updateExtractionStatus(id: string, status: ExtractionStatus): Promise<void> {
async updateExtractionStatus(
id: string,
status: ExtractionStatus,
options?: { incrementRetries?: boolean },
): Promise<void> {
await this.ensureInitialized();
const session = this.driver!.session();
try {
const retryClause = options?.incrementRetries
? ", m.extractionRetries = coalesce(m.extractionRetries, 0) + 1"
: "";
await session.run(
`MATCH (m:Memory {id: $id})
SET m.extractionStatus = $status, m.updatedAt = $now`,
SET m.extractionStatus = $status, m.updatedAt = $now${retryClause}`,
{ id, status, now: new Date().toISOString() },
);
} finally {
@@ -847,6 +886,24 @@ export class Neo4jMemoryClient {
}
}
/**
 * Get the current extraction retry count for a memory.
 *
 * @param id - Id of the Memory node to look up.
 * @returns The stored extractionRetries value as a plain JS number, or 0
 *   when the memory does not exist or has no counter yet.
 */
async getExtractionRetries(id: string): Promise<number> {
  await this.ensureInitialized();
  const session = this.driver!.session();
  try {
    const result = await session.run(
      `MATCH (m:Memory {id: $id})
       RETURN coalesce(m.extractionRetries, 0) AS retries`,
      { id },
    );
    const raw: unknown = result.records[0]?.get("retries");
    if (typeof raw === "number") {
      return raw;
    }
    if (typeof raw === "bigint") {
      return Number(raw);
    }
    // Cypher integers (e.g. the literal 0 from coalesce) come back from the
    // driver as Integer objects unless lossless integers are disabled, so
    // convert instead of pretending the object is a number.
    if (raw != null && typeof (raw as { toNumber?: unknown }).toNumber === "function") {
      return (raw as { toNumber(): number }).toNumber();
    }
    return 0;
  } finally {
    await session.close();
  }
}
/**
* List memories with pending extraction status.
* Used by the sleep cycle to batch-process extractions.
@@ -854,7 +911,7 @@ export class Neo4jMemoryClient {
async listPendingExtractions(
limit: number = 100,
agentId?: string,
): Promise<Array<{ id: string; text: string; agentId: string }>> {
): Promise<Array<{ id: string; text: string; agentId: string; extractionRetries: number }>> {
await this.ensureInitialized();
const session = this.driver!.session();
try {
@@ -862,7 +919,8 @@ export class Neo4jMemoryClient {
const result = await session.run(
`MATCH (m:Memory)
WHERE m.extractionStatus = 'pending' ${agentFilter}
RETURN m.id AS id, m.text AS text, m.agentId AS agentId
RETURN m.id AS id, m.text AS text, m.agentId AS agentId,
coalesce(m.extractionRetries, 0) AS extractionRetries
ORDER BY m.createdAt ASC
LIMIT $limit`,
{ limit: neo4j.int(limit), agentId },
@@ -871,6 +929,7 @@ export class Neo4jMemoryClient {
id: r.get("id") as string,
text: r.get("text") as string,
agentId: r.get("agentId") as string,
extractionRetries: r.get("extractionRetries") as number,
}));
} finally {
await session.close();
@@ -976,14 +1035,17 @@ export class Neo4jMemoryClient {
let pairsFound = 0;
for (const id of memoryData.keys()) {
const similar = await session.run(
`MATCH (src:Memory {id: $id})
CALL db.index.vector.queryNodes('memory_embedding_index', $k, src.embedding)
YIELD node, score
WHERE node.id <> $id AND score >= $threshold
RETURN node.id AS matchId`,
{ id, k: neo4j.int(10), threshold },
);
// Retry individual vector queries on transient errors
const similar = await this.retryOnTransient(async () => {
return session.run(
`MATCH (src:Memory {id: $id})
CALL db.index.vector.queryNodes('memory_embedding_index', $k, src.embedding)
YIELD node, score
WHERE node.id <> $id AND score >= $threshold
RETURN node.id AS matchId`,
{ id, k: neo4j.int(10), threshold },
);
});
for (const r of similar.records) {
const matchId = r.get("matchId") as string;
@@ -1045,30 +1107,56 @@ export class Neo4jMemoryClient {
const survivorId = memoryIds[survivorIdx];
const toDelete = memoryIds.filter((_, i) => i !== survivorIdx);
const session = this.driver!.session();
try {
// Transfer MENTIONS relationships from deleted memories to survivor
await session.run(
`UNWIND $toDelete AS deadId
MATCH (dead:Memory {id: deadId})-[r:MENTIONS]->(e:Entity)
MATCH (survivor:Memory {id: $survivorId})
MERGE (survivor)-[:MENTIONS]->(e)
DELETE r`,
{ toDelete, survivorId },
);
return this.retryOnTransient(async () => {
const session = this.driver!.session();
try {
// Optimistic lock: verify all cluster members still exist before merging.
// New memories added or deleted between findDuplicateClusters() and this
// call could invalidate the cluster. Skip if any member is missing.
const verifyResult = await session.run(
`UNWIND $ids AS memId
OPTIONAL MATCH (m:Memory {id: memId})
RETURN memId, m IS NOT NULL AS exists`,
{ ids: memoryIds },
);
// Delete the duplicate memories
await session.run(
`UNWIND $toDelete AS deadId
MATCH (m:Memory {id: deadId})
DETACH DELETE m`,
{ toDelete },
);
const missingIds: string[] = [];
for (const r of verifyResult.records) {
if (!r.get("exists")) {
missingIds.push(r.get("memId") as string);
}
}
return { survivorId, deletedCount: toDelete.length };
} finally {
await session.close();
}
if (missingIds.length > 0) {
this.logger.warn(
`memory-neo4j: skipping cluster merge — ${missingIds.length} member(s) no longer exist: ${missingIds.join(", ")}`,
);
return { survivorId, deletedCount: 0 };
}
// Transfer MENTIONS relationships from deleted memories to survivor
await session.run(
`UNWIND $toDelete AS deadId
MATCH (dead:Memory {id: deadId})-[r:MENTIONS]->(e:Entity)
MATCH (survivor:Memory {id: $survivorId})
MERGE (survivor)-[:MENTIONS]->(e)
DELETE r`,
{ toDelete, survivorId },
);
// Delete the duplicate memories
await session.run(
`UNWIND $toDelete AS deadId
MATCH (m:Memory {id: deadId})
DETACH DELETE m`,
{ toDelete },
);
return { survivorId, deletedCount: toDelete.length };
} finally {
await session.close();
}
});
}
// --------------------------------------------------------------------------
@@ -1160,11 +1248,12 @@ export class Neo4jMemoryClient {
await this.ensureInitialized();
const session = this.driver!.session();
try {
// Decrement mention counts on connected entities
// Decrement mention counts on connected entities (floor at 0 to prevent
// negative counts from parallel prune/delete operations racing on the same entity)
await session.run(
`UNWIND $ids AS memId
MATCH (m:Memory {id: memId})-[:MENTIONS]->(e:Entity)
SET e.mentionCount = e.mentionCount - 1`,
SET e.mentionCount = CASE WHEN e.mentionCount > 0 THEN e.mentionCount - 1 ELSE 0 END`,
{ ids: memoryIds },
);
@@ -1581,7 +1670,7 @@ export class Neo4jMemoryClient {
// --------------------------------------------------------------------------
/**
* Retry an operation on transient Neo4j errors (deadlocks, etc.)
* Retry an operation on transient Neo4j errors (deadlocks, connection blips, etc.)
* with exponential backoff. Adapted from ontology project.
*/
private async retryOnTransient<T>(
@@ -1595,14 +1684,24 @@ export class Neo4jMemoryClient {
return await fn();
} catch (err) {
lastError = err;
// Check for Neo4j transient errors
// Check for Neo4j transient errors (deadlocks, connection blips, service unavailable)
const errCode =
err instanceof Error
? ((err as unknown as Record<string, unknown>).code as string | undefined)
: undefined;
const isTransient =
err instanceof Error &&
(err.message.includes("DeadlockDetected") ||
err.message.includes("TransientError") ||
err.message.includes("ServiceUnavailable") ||
err.message.includes("SessionExpired") ||
err.message.includes("ConnectionRefused") ||
err.message.includes("connection terminated") ||
(err.constructor.name === "Neo4jError" &&
(err as unknown as Record<string, unknown>).code ===
"Neo.TransientError.Transaction.DeadlockDetected"));
typeof errCode === "string" &&
(errCode.startsWith("Neo.TransientError.") ||
errCode === "ServiceUnavailable" ||
errCode === "SessionExpired")));
if (!isTransient || attempt >= maxAttempts - 1) {
throw err;

View File

@@ -0,0 +1,200 @@
/**
* Tests for schema.ts — Schema Validation & Helpers.
*
* Tests the exported pure functions: escapeLucene(), validateRelationshipType(),
* and the exported constants and types.
*/
import { describe, it, expect } from "vitest";
import {
escapeLucene,
validateRelationshipType,
ALLOWED_RELATIONSHIP_TYPES,
MEMORY_CATEGORIES,
ENTITY_TYPES,
} from "./schema.js";
// ============================================================================
// escapeLucene()
// ============================================================================
describe("escapeLucene", () => {
  // Each row registers one test: title, raw input, expected escaped output.
  const cases: ReadonlyArray<{ title: string; input: string; expected: string }> = [
    { title: "should return normal text unchanged", input: "hello world", expected: "hello world" },
    { title: "should return empty string unchanged", input: "", expected: "" },
    { title: "should escape plus sign", input: "a+b", expected: "a\\+b" },
    { title: "should escape minus sign", input: "a-b", expected: "a\\-b" },
    { title: "should escape ampersand", input: "a&b", expected: "a\\&b" },
    { title: "should escape pipe", input: "a|b", expected: "a\\|b" },
    { title: "should escape exclamation mark", input: "hello!", expected: "hello\\!" },
    { title: "should escape parentheses", input: "(group)", expected: "\\(group\\)" },
    { title: "should escape curly braces", input: "{range}", expected: "\\{range\\}" },
    { title: "should escape square brackets", input: "[range]", expected: "\\[range\\]" },
    { title: "should escape caret", input: "boost^2", expected: "boost\\^2" },
    { title: "should escape double quotes", input: '"exact"', expected: '\\"exact\\"' },
    { title: "should escape tilde", input: "fuzzy~", expected: "fuzzy\\~" },
    { title: "should escape asterisk", input: "wild*", expected: "wild\\*" },
    { title: "should escape question mark", input: "single?", expected: "single\\?" },
    { title: "should escape colon", input: "field:value", expected: "field\\:value" },
    { title: "should escape backslash", input: "path\\file", expected: "path\\\\file" },
    { title: "should escape forward slash", input: "a/b", expected: "a\\/b" },
    {
      title: "should escape multiple special characters in one string",
      input: "(a+b) && c*",
      expected: "\\(a\\+b\\) \\&\\& c\\*",
    },
    {
      title: "should handle mixed normal and special characters",
      input: "hello world! [test]",
      expected: "hello world\\! \\[test\\]",
    },
    {
      title: "should handle strings with only special characters",
      input: "+-",
      expected: "\\+\\-",
    },
  ];

  for (const { title, input, expected } of cases) {
    it(title, () => {
      expect(escapeLucene(input)).toBe(expected);
    });
  }
});
// ============================================================================
// validateRelationshipType()
// ============================================================================
describe("validateRelationshipType", () => {
  describe("valid relationship types", () => {
    // One test per known type, registered in a fixed order.
    const knownTypes = [
      "WORKS_AT",
      "LIVES_AT",
      "KNOWS",
      "MARRIED_TO",
      "PREFERS",
      "DECIDED",
      "RELATED_TO",
    ];
    for (const relType of knownTypes) {
      it(`should accept ${relType}`, () => {
        expect(validateRelationshipType(relType)).toBe(true);
      });
    }

    it("should accept all ALLOWED_RELATIONSHIP_TYPES", () => {
      ALLOWED_RELATIONSHIP_TYPES.forEach((relType) => {
        expect(validateRelationshipType(relType)).toBe(true);
      });
    });
  });

  describe("invalid relationship types", () => {
    // [test title, candidate that must be rejected]
    const rejected: ReadonlyArray<readonly [string, string]> = [
      ["should reject unknown relationship type", "HATES"],
      ["should reject empty string", ""],
      ["should be case sensitive — lowercase is rejected", "works_at"],
      ["should be case sensitive — mixed case is rejected", "Works_At"],
      ["should reject types with extra whitespace", " WORKS_AT "],
      ["should reject potential Cypher injection", "WORKS_AT]->(n) DELETE n//"],
    ];
    for (const [title, candidate] of rejected) {
      it(title, () => {
        expect(validateRelationshipType(candidate)).toBe(false);
      });
    }
  });
});
// ============================================================================
// Exported Constants
// ============================================================================
describe("exported constants", () => {
  it("MEMORY_CATEGORIES should contain expected categories", () => {
    const expectedCategories = ["preference", "fact", "decision", "entity", "other"];
    for (const category of expectedCategories) {
      expect(MEMORY_CATEGORIES).toContain(category);
    }
  });

  it("ENTITY_TYPES should contain expected types", () => {
    const expectedTypes = ["person", "organization", "location", "event", "concept"];
    for (const entityType of expectedTypes) {
      expect(ENTITY_TYPES).toContain(entityType);
    }
  });

  it("ALLOWED_RELATIONSHIP_TYPES should be a Set", () => {
    // Check the container type first, then the exact number of allowed types
    expect(ALLOWED_RELATIONSHIP_TYPES).toBeInstanceOf(Set);
    expect(ALLOWED_RELATIONSHIP_TYPES.size).toBe(7);
  });
});

View File

@@ -0,0 +1,400 @@
/**
* Tests for search.ts — Hybrid Search & RRF Fusion.
*
* Tests the exported pure logic: classifyQuery() and getAdaptiveWeights().
* Note: fuseWithConfidenceRRF() is not exported (private module-level function)
* and is tested indirectly through hybridSearch().
* hybridSearch() is tested with mocked Neo4j client and Embeddings.
*/
import { describe, it, expect, vi, beforeEach } from "vitest";
import type { SearchSignalResult } from "./schema.js";
import { classifyQuery, getAdaptiveWeights, hybridSearch } from "./search.js";
// ============================================================================
// classifyQuery()
// ============================================================================
describe("classifyQuery", () => {
  // Registers one test per row: [title, query, expected classification].
  const registerCases = (
    rows: ReadonlyArray<readonly [title: string, query: string, expected: string]>,
  ): void => {
    for (const [title, query, expected] of rows) {
      it(title, () => {
        expect(classifyQuery(query)).toBe(expected);
      });
    }
  };

  describe("short queries (1-2 words)", () => {
    registerCases([
      ["should classify a single word as 'short'", "dogs", "short"],
      ["should classify two words as 'short'", "best coffee", "short"],
      [
        "should classify a single capitalized word as 'short' (word count takes priority)",
        "TypeScript",
        "short",
      ],
      ["should handle whitespace-padded short queries", " hello ", "short"],
    ]);
  });

  describe("entity queries (proper nouns)", () => {
    registerCases([
      ["should classify query with proper noun as 'entity'", "tell me about Tarun", "entity"],
      ["should classify query with organization name as 'entity'", "what about Google", "entity"],
      ["should classify question patterns targeting entities", "who is the CEO", "entity"],
      ["should classify 'where is' patterns as entity", "where is the office", "entity"],
      ["should classify 'what does' patterns as entity", "what does she do", "entity"],
      // "The" and "Is" are excluded from capitalized word detection:
      // 3 words, no proper nouns detected, no question pattern -> default
      [
        "should not treat common words (The, Is, etc.) as entity indicators",
        "this is fine",
        "default",
      ],
    ]);
  });

  describe("long queries (5+ words)", () => {
    registerCases([
      ["should classify a 5-word query as 'long'", "what is the best framework", "long"],
      [
        "should classify a longer sentence as 'long'",
        "tell me about the history of programming languages",
        "long",
      ],
      [
        "should classify a verbose question as 'long'",
        "how do i configure the database connection",
        "long",
      ],
    ]);
  });

  describe("default queries (3-4 words, no entities)", () => {
    registerCases([
      ["should classify a 3-word lowercase query as 'default'", "my favorite color", "default"],
      [
        "should classify a 4-word lowercase query as 'default'",
        "best practices for testing",
        "default",
      ],
    ]);
  });

  describe("edge cases", () => {
    registerCases([
      // Empty string splits to [""], length 1 -> "short"
      ["should handle empty string", "", "short"],
      // Whitespace trims to "", splits to [""], length 1 -> "short"
      ["should handle only whitespace", " ", "short"],
    ]);
  });
});
// ============================================================================
// getAdaptiveWeights()
// ============================================================================
describe("getAdaptiveWeights", () => {
  /** Query-type argument accepted by getAdaptiveWeights, derived from its signature. */
  type QueryKind = Parameters<typeof getAdaptiveWeights>[0];

  /**
   * Calls getAdaptiveWeights and labels the returned triple by position:
   * first element -> vector, second -> bm25, third -> graph.
   */
  const weightsFor = (queryKind: QueryKind, graphEnabled: boolean) => {
    const [vector, bm25, graph] = getAdaptiveWeights(queryKind, graphEnabled);
    return { vector, bm25, graph };
  };

  describe("with graph enabled", () => {
    it("should boost BM25 for short queries", () => {
      const weights = weightsFor("short", true);

      expect(weights.bm25).toBeGreaterThan(weights.vector);
      expect(weights.vector).toBe(0.8);
      expect(weights.bm25).toBe(1.2);
      expect(weights.graph).toBe(1.0);
    });

    it("should boost graph for entity queries", () => {
      const weights = weightsFor("entity", true);

      expect(weights.graph).toBeGreaterThan(weights.vector);
      expect(weights.graph).toBeGreaterThan(weights.bm25);
      expect(weights.vector).toBe(0.8);
      expect(weights.bm25).toBe(1.0);
      expect(weights.graph).toBe(1.3);
    });

    it("should boost vector for long queries", () => {
      const weights = weightsFor("long", true);

      expect(weights.vector).toBeGreaterThan(weights.bm25);
      expect(weights.vector).toBeGreaterThan(weights.graph);
      expect(weights.vector).toBe(1.2);
      expect(weights.bm25).toBe(0.7);
      // Only this weight is compared approximately rather than with strict equality.
      expect(weights.graph).toBeCloseTo(0.8);
    });

    it("should return balanced weights for default queries", () => {
      const weights = weightsFor("default", true);

      expect(weights.vector).toBe(1.0);
      expect(weights.bm25).toBe(1.0);
      expect(weights.graph).toBe(1.0);
    });
  });

  describe("with graph disabled", () => {
    it("should zero-out graph weight for short queries", () => {
      const weights = weightsFor("short", false);

      expect(weights.graph).toBe(0);
      // The remaining two weights match the graph-enabled "short" expectations.
      expect(weights.vector).toBe(0.8);
      expect(weights.bm25).toBe(1.2);
    });

    // The other query types only need the graph weight checked.
    for (const queryKind of ["entity", "long", "default"] as const) {
      it(`should zero-out graph weight for ${queryKind} queries`, () => {
        expect(weightsFor(queryKind, false).graph).toBe(0);
      });
    }
  });
});
// ============================================================================
// hybridSearch() — integration test with mocked dependencies
// ============================================================================
describe("hybridSearch", () => {
  // Mocked Neo4j client: one vi.fn() per method these tests stub or assert on.
  const mockDb = {
    vectorSearch: vi.fn(),
    bm25Search: vi.fn(),
    graphSearch: vi.fn(),
    recordRetrievals: vi.fn(),
  };

  // Mocked embeddings provider.
  const mockEmbeddings = {
    embed: vi.fn(),
    embedBatch: vi.fn(),
  };

  // The mocks are partial objects, so they are type-erased once here and the
  // erased handles are what gets passed to hybridSearch.
  const db = mockDb as never;
  const embeddings = mockEmbeddings as never;

  /** Field values used for every fabricated signal result unless overridden. */
  const baseSignal: SearchSignalResult = {
    id: "mem-1",
    text: "Test memory",
    category: "fact",
    importance: 0.7,
    createdAt: "2025-01-01T00:00:00Z",
    score: 0.9,
  };

  /** Builds a signal result from the base values, with `overrides` taking precedence. */
  const makeSignalResult = (overrides: Partial<SearchSignalResult> = {}): SearchSignalResult => ({
    ...baseSignal,
    ...overrides,
  });

  /** Stubs the vector and BM25 signals; each one resolves to an empty list by default. */
  const stubSignals = (
    vectorHits: SearchSignalResult[] = [],
    bm25Hits: SearchSignalResult[] = [],
  ): void => {
    mockDb.vectorSearch.mockResolvedValue(vectorHits);
    mockDb.bm25Search.mockResolvedValue(bm25Hits);
  };

  /**
   * Runs hybridSearch against the mocks as agent "agent-1".
   * Defaults: query "test query", limit 5, graph signal disabled.
   */
  const searchAsAgent1 = (query = "test query", limit = 5, graphEnabled = false) =>
    hybridSearch(db, embeddings, query, limit, "agent-1", graphEnabled);

  beforeEach(() => {
    // Wipe call history and stubbed implementations, then restore the two
    // stubs every test relies on.
    vi.resetAllMocks();
    mockEmbeddings.embed.mockResolvedValue([0.1, 0.2, 0.3]);
    mockDb.recordRetrievals.mockResolvedValue(undefined);
  });

  it("should return empty array when no signals return results", async () => {
    stubSignals();

    const results = await searchAsAgent1();

    expect(results).toEqual([]);
    expect(mockDb.recordRetrievals).not.toHaveBeenCalled();
  });

  it("should fuse results from vector and BM25 signals", async () => {
    stubSignals(
      [makeSignalResult({ id: "mem-1", score: 0.95, text: "Vector match" })],
      [makeSignalResult({ id: "mem-2", score: 0.8, text: "BM25 match" })],
    );

    const results = await searchAsAgent1();

    expect(results).toHaveLength(2);
    // The top score must lie within [0, 1] ...
    expect(results[0].score).toBeLessThanOrEqual(1);
    expect(results[0].score).toBeGreaterThanOrEqual(0);
    // ... and, being the highest, is expected to be normalized to exactly 1.
    expect(results[0].score).toBe(1);
  });

  it("should deduplicate across signals (same memory in multiple signals)", async () => {
    const shared = makeSignalResult({ id: "mem-shared", score: 0.9 });
    stubSignals([shared], [{ ...shared, score: 0.85 }]);

    const results = await searchAsAgent1();

    // Same ID reported by both signals -> a single fused entry.
    expect(results).toHaveLength(1);
    expect(results[0].id).toBe("mem-shared");
    // As the only result it is normalized to 1, whatever boost it received
    // for appearing in both signals.
    expect(results[0].score).toBe(1);
  });

  it("should include graph signal when graphEnabled is true", async () => {
    stubSignals();
    mockDb.graphSearch.mockResolvedValue([
      makeSignalResult({ id: "mem-graph", score: 0.7, text: "Graph result" }),
    ]);

    const results = await searchAsAgent1("tell me about Tarun", 5, true);

    expect(mockDb.graphSearch).toHaveBeenCalled();
    expect(results).toHaveLength(1);
    expect(results[0].id).toBe("mem-graph");
  });

  it("should not call graphSearch when graphEnabled is false", async () => {
    stubSignals();

    await searchAsAgent1();

    expect(mockDb.graphSearch).not.toHaveBeenCalled();
  });

  it("should limit results to the requested count", async () => {
    // Ten distinct candidates with strictly decreasing scores.
    const candidates = Array.from({ length: 10 }, (_, index) =>
      makeSignalResult({ id: `mem-${index}`, score: 0.9 - index * 0.05 }),
    );
    stubSignals(candidates);

    const results = await searchAsAgent1("test query", 3);

    expect(results).toHaveLength(3);
  });

  it("should record retrieval events for returned results", async () => {
    stubSignals([makeSignalResult({ id: "mem-1" }), makeSignalResult({ id: "mem-2" })]);

    await searchAsAgent1();

    expect(mockDb.recordRetrievals).toHaveBeenCalledWith(["mem-1", "mem-2"]);
  });

  it("should silently handle recordRetrievals failure", async () => {
    stubSignals([makeSignalResult({ id: "mem-1" })]);
    mockDb.recordRetrievals.mockRejectedValue(new Error("DB connection lost"));

    // The rejected bookkeeping call must not surface as a rejected search.
    const results = await searchAsAgent1();

    expect(results).toHaveLength(1);
  });

  it("should normalize scores to 0-1 range", async () => {
    stubSignals([
      makeSignalResult({ id: "mem-1", score: 0.95 }),
      makeSignalResult({ id: "mem-2", score: 0.5 }),
    ]);

    const results = await searchAsAgent1();

    for (const { score } of results) {
      expect(score).toBeGreaterThanOrEqual(0);
      expect(score).toBeLessThanOrEqual(1);
    }
  });

  it("should use candidateMultiplier option", async () => {
    stubSignals();

    // Called directly because this is the only test passing an options object.
    await hybridSearch(db, embeddings, "test query", 5, "agent-1", false, {
      candidateMultiplier: 8,
    });

    // limit=5, multiplier=8 => candidateLimit = 40
    expect(mockDb.vectorSearch).toHaveBeenCalledWith(expect.any(Array), 40, 0.1, "agent-1");
    expect(mockDb.bm25Search).toHaveBeenCalledWith("test query", 40, "agent-1");
  });

  it("should pass default agentId when not specified", async () => {
    stubSignals();

    // Called directly with only the query so every optional argument is omitted.
    await hybridSearch(db, embeddings, "test query");

    expect(mockDb.vectorSearch).toHaveBeenCalledWith(
      expect.any(Array),
      expect.any(Number),
      0.1,
      "default",
    );
  });
});