mirror of
https://github.com/openclaw/openclaw.git
synced 2026-05-23 05:48:11 +00:00
memory-neo4j: harden error handling, concurrency safety, config validation + add tests
This commit is contained in:
549
extensions/memory-neo4j/config.test.ts
Normal file
549
extensions/memory-neo4j/config.test.ts
Normal file
@@ -0,0 +1,549 @@
|
||||
/**
|
||||
* Tests for config.ts — Configuration Parsing.
|
||||
*
|
||||
* Tests memoryNeo4jConfigSchema.parse(), vectorDimsForModel(), and resolveExtractionConfig().
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from "vitest";
|
||||
import { memoryNeo4jConfigSchema, vectorDimsForModel, resolveExtractionConfig } from "./config.js";
|
||||
|
||||
// ============================================================================
|
||||
// memoryNeo4jConfigSchema.parse()
|
||||
// ============================================================================
|
||||
|
||||
describe("memoryNeo4jConfigSchema.parse", () => {
|
||||
// Snapshot of the environment taken when the suite is defined. Tests below
// mutate process.env freely; the hook restores this snapshot after each one.
const originalEnv = { ...process.env };

afterEach(() => {
  // Restore in place instead of reassigning `process.env`: replacing the
  // object swaps Node's special env object for a plain one, which changes
  // semantics (e.g. implicit string coercion) for every later test.
  for (const key of Object.keys(process.env)) {
    if (!(key in originalEnv)) {
      delete process.env[key];
    }
  }
  Object.assign(process.env, originalEnv);
});
|
||||
|
||||
describe("valid complete configs", () => {
  // URI shared by every local-instance case in this group.
  const localUri = "bolt://localhost:7687";

  it("should parse a minimal valid config with ollama provider", () => {
    const parsed = memoryNeo4jConfigSchema.parse({
      neo4j: { uri: localUri, user: "neo4j", password: "test" },
      embedding: { provider: "ollama" },
    });

    // Connection settings are carried over; `user` surfaces as `username`.
    const connection = parsed.neo4j;
    expect(connection.uri).toBe("bolt://localhost:7687");
    expect(connection.username).toBe("neo4j");
    expect(connection.password).toBe("test");

    // Embedding section falls back to the ollama defaults.
    const embedding = parsed.embedding;
    expect(embedding.provider).toBe("ollama");
    expect(embedding.model).toBe("mxbai-embed-large");
    expect(embedding.apiKey).toBeUndefined();

    // Everything that was omitted takes its default.
    expect(parsed.autoCapture).toBe(true);
    expect(parsed.autoRecall).toBe(true);
    expect(parsed.coreMemory.enabled).toBe(true);
    expect(parsed.coreMemory.maxEntries).toBe(50);
  });

  it("should parse a full config with openai provider", () => {
    const neo4jSection = {
      uri: "neo4j+s://cloud.neo4j.io:7687",
      username: "admin",
      password: "secret",
    };
    const embeddingSection = {
      provider: "openai",
      apiKey: "sk-test-key",
      model: "text-embedding-3-large",
    };
    const coreMemorySection = {
      enabled: false,
      maxEntries: 100,
      refreshAtContextPercent: 75,
    };

    const parsed = memoryNeo4jConfigSchema.parse({
      neo4j: neo4jSection,
      embedding: embeddingSection,
      autoCapture: false,
      autoRecall: false,
      coreMemory: coreMemorySection,
    });

    // Every explicitly supplied value must come back unchanged.
    expect(parsed.neo4j.uri).toBe("neo4j+s://cloud.neo4j.io:7687");
    expect(parsed.neo4j.username).toBe("admin");
    expect(parsed.neo4j.password).toBe("secret");
    expect(parsed.embedding.provider).toBe("openai");
    expect(parsed.embedding.apiKey).toBe("sk-test-key");
    expect(parsed.embedding.model).toBe("text-embedding-3-large");
    expect(parsed.autoCapture).toBe(false);
    expect(parsed.autoRecall).toBe(false);
    expect(parsed.coreMemory.enabled).toBe(false);
    expect(parsed.coreMemory.maxEntries).toBe(100);
    expect(parsed.coreMemory.refreshAtContextPercent).toBe(75);
  });

  it("should support 'user' field as alias for 'username' in neo4j config", () => {
    const { neo4j } = memoryNeo4jConfigSchema.parse({
      neo4j: { uri: localUri, user: "custom-user", password: "pass" },
      embedding: { provider: "ollama" },
    });
    expect(neo4j.username).toBe("custom-user");
  });

  it("should support 'username' field in neo4j config", () => {
    const { neo4j } = memoryNeo4jConfigSchema.parse({
      neo4j: { uri: localUri, username: "custom-user", password: "pass" },
      embedding: { provider: "ollama" },
    });
    expect(neo4j.username).toBe("custom-user");
  });

  it("should default neo4j username to 'neo4j' when not specified", () => {
    const { neo4j } = memoryNeo4jConfigSchema.parse({
      neo4j: { uri: localUri, password: "pass" },
      embedding: { provider: "ollama" },
    });
    expect(neo4j.username).toBe("neo4j");
  });
});
|
||||
|
||||
describe("missing required fields", () => {
  // Wrap a raw config in a thunk so it can be handed to expect(...).toThrow().
  const parsing = (raw: unknown) => () => memoryNeo4jConfigSchema.parse(raw);

  // Error raised for any top-level value that is not a plain object.
  const configRequired = "memory-neo4j config required";

  it("should throw when config is null", () => {
    expect(parsing(null)).toThrow(configRequired);
  });

  it("should throw when config is undefined", () => {
    expect(parsing(undefined)).toThrow(configRequired);
  });

  it("should throw when config is not an object", () => {
    expect(parsing("string")).toThrow(configRequired);
  });

  it("should throw when config is an array", () => {
    expect(parsing([])).toThrow(configRequired);
  });

  it("should throw when neo4j section is missing", () => {
    const withoutNeo4j = { embedding: { provider: "ollama" } };
    expect(parsing(withoutNeo4j)).toThrow("neo4j config section is required");
  });

  it("should throw when neo4j.uri is missing", () => {
    const withoutUri = {
      neo4j: { password: "test" },
      embedding: { provider: "ollama" },
    };
    expect(parsing(withoutUri)).toThrow("neo4j.uri is required");
  });

  it("should throw when neo4j.uri is empty string", () => {
    const emptyUri = {
      neo4j: { uri: "", password: "test" },
      embedding: { provider: "ollama" },
    };
    expect(parsing(emptyUri)).toThrow("neo4j.uri is required");
  });
});
|
||||
|
||||
describe("environment variable resolution", () => {
  it("should resolve ${ENV_VAR} in neo4j.password", () => {
    process.env.TEST_NEO4J_PASSWORD = "resolved-password";

    const neo4jSection = {
      uri: "bolt://localhost:7687",
      password: "${TEST_NEO4J_PASSWORD}",
    };
    const parsed = memoryNeo4jConfigSchema.parse({
      neo4j: neo4jSection,
      embedding: { provider: "ollama" },
    });

    // The placeholder is replaced by the variable's value.
    expect(parsed.neo4j.password).toBe("resolved-password");
  });

  it("should resolve ${ENV_VAR} in embedding.apiKey", () => {
    process.env.TEST_OPENAI_KEY = "sk-from-env";

    const parsed = memoryNeo4jConfigSchema.parse({
      neo4j: { uri: "bolt://localhost:7687", password: "" },
      embedding: { provider: "openai", apiKey: "${TEST_OPENAI_KEY}" },
    });

    expect(parsed.embedding.apiKey).toBe("sk-from-env");
  });

  it("should throw when referenced env var is not set", () => {
    // Make sure the variable really is absent before parsing.
    delete process.env.NONEXISTENT_VAR;

    const rawConfig = {
      neo4j: {
        uri: "bolt://localhost:7687",
        password: "${NONEXISTENT_VAR}",
      },
      embedding: { provider: "ollama" },
    };
    const attempt = () => memoryNeo4jConfigSchema.parse(rawConfig);

    expect(attempt).toThrow("Environment variable NONEXISTENT_VAR is not set");
  });
});
|
||||
|
||||
describe("default values", () => {
  // Parse a config with an empty neo4j password and the given embedding section;
  // everything else is left out so the schema has to supply defaults.
  const parseWithEmbedding = (embedding: Record<string, unknown>) =>
    memoryNeo4jConfigSchema.parse({
      neo4j: { uri: "bolt://localhost:7687", password: "" },
      embedding,
    });

  // The most common baseline: ollama provider, nothing else specified.
  const parseOllamaBaseline = () => parseWithEmbedding({ provider: "ollama" });

  it("should default autoCapture to true", () => {
    expect(parseOllamaBaseline().autoCapture).toBe(true);
  });

  it("should default autoRecall to true", () => {
    expect(parseOllamaBaseline().autoRecall).toBe(true);
  });

  it("should default coreMemory.enabled to true", () => {
    expect(parseOllamaBaseline().coreMemory.enabled).toBe(true);
  });

  it("should default coreMemory.maxEntries to 50", () => {
    expect(parseOllamaBaseline().coreMemory.maxEntries).toBe(50);
  });

  it("should default refreshAtContextPercent to undefined", () => {
    expect(parseOllamaBaseline().coreMemory.refreshAtContextPercent).toBeUndefined();
  });

  it("should default embedding model to mxbai-embed-large for ollama", () => {
    expect(parseOllamaBaseline().embedding.model).toBe("mxbai-embed-large");
  });

  it("should default embedding model to text-embedding-3-small for openai", () => {
    const parsed = parseWithEmbedding({ provider: "openai", apiKey: "sk-test" });
    expect(parsed.embedding.model).toBe("text-embedding-3-small");
  });

  it("should default neo4j.password to empty string when not provided", () => {
    // This case omits the password entirely, so it cannot use the helper.
    const parsed = memoryNeo4jConfigSchema.parse({
      neo4j: { uri: "bolt://localhost:7687" },
      embedding: { provider: "ollama" },
    });
    expect(parsed.neo4j.password).toBe("");
  });
});
|
||||
|
||||
describe("provider validation", () => {
  // Neo4j section reused by every case; only the embedding section varies.
  const neo4jSection = { uri: "bolt://localhost:7687", password: "" };

  it("should require apiKey for openai provider", () => {
    const attempt = () =>
      memoryNeo4jConfigSchema.parse({
        neo4j: { ...neo4jSection },
        embedding: { provider: "openai" },
      });

    expect(attempt).toThrow("embedding.apiKey is required for OpenAI provider");
  });

  it("should not require apiKey for ollama provider", () => {
    const { embedding } = memoryNeo4jConfigSchema.parse({
      neo4j: { ...neo4jSection },
      embedding: { provider: "ollama" },
    });

    expect(embedding.apiKey).toBeUndefined();
  });

  it("should default to openai when no provider is specified", () => {
    // Only an apiKey is supplied; the provider must fall back to openai.
    const { embedding } = memoryNeo4jConfigSchema.parse({
      neo4j: { ...neo4jSection },
      embedding: { apiKey: "sk-test" },
    });

    expect(embedding.provider).toBe("openai");
  });

  it("should accept embedding.baseUrl", () => {
    const { embedding } = memoryNeo4jConfigSchema.parse({
      neo4j: { ...neo4jSection },
      embedding: { provider: "ollama", baseUrl: "http://my-ollama:11434" },
    });

    expect(embedding.baseUrl).toBe("http://my-ollama:11434");
  });
});
|
||||
|
||||
describe("unknown keys rejected", () => {
  // Build a thunk that parses the given raw config, for use with toThrow().
  const parsing = (raw: Record<string, unknown>) => () => memoryNeo4jConfigSchema.parse(raw);

  it("should reject unknown top-level keys", () => {
    const raw = {
      neo4j: { uri: "bolt://localhost:7687", password: "" },
      embedding: { provider: "ollama" },
      unknownKey: "value",
    };
    expect(parsing(raw)).toThrow("unknown keys: unknownKey");
  });

  it("should reject unknown neo4j keys", () => {
    const raw = {
      neo4j: { uri: "bolt://localhost:7687", password: "", port: 7687 },
      embedding: { provider: "ollama" },
    };
    expect(parsing(raw)).toThrow("unknown keys: port");
  });

  it("should reject unknown embedding keys", () => {
    const raw = {
      neo4j: { uri: "bolt://localhost:7687", password: "" },
      embedding: { provider: "ollama", temperature: 0.5 },
    };
    expect(parsing(raw)).toThrow("unknown keys: temperature");
  });

  it("should reject unknown coreMemory keys", () => {
    const raw = {
      neo4j: { uri: "bolt://localhost:7687", password: "" },
      embedding: { provider: "ollama" },
      coreMemory: { unknownField: true },
    };
    expect(parsing(raw)).toThrow("unknown keys: unknownField");
  });
});
|
||||
|
||||
describe("refreshAtContextPercent edge cases", () => {
  // Parse a minimal valid config whose only variable is
  // coreMemory.refreshAtContextPercent.
  const parseWithRefresh = (refreshAtContextPercent: unknown) =>
    memoryNeo4jConfigSchema.parse({
      neo4j: { uri: "bolt://localhost:7687", password: "" },
      embedding: { provider: "ollama" },
      coreMemory: { refreshAtContextPercent },
    });

  it("should accept refreshAtContextPercent of 1 (minimum valid)", () => {
    expect(parseWithRefresh(1).coreMemory.refreshAtContextPercent).toBe(1);
  });

  it("should accept refreshAtContextPercent of 100 (maximum valid)", () => {
    expect(parseWithRefresh(100).coreMemory.refreshAtContextPercent).toBe(100);
  });

  // Values at or below zero do not throw: they are dropped, which leaves the
  // setting undefined. Only values above 100 are a hard error. The titles say
  // "ignore" rather than "reject" so they match what is asserted.
  it("should ignore refreshAtContextPercent of 0 (treated as disabled)", () => {
    expect(parseWithRefresh(0).coreMemory.refreshAtContextPercent).toBeUndefined();
  });

  it("should reject refreshAtContextPercent over 100 by throwing", () => {
    expect(() => parseWithRefresh(150)).toThrow(
      "coreMemory.refreshAtContextPercent must be between 1 and 100",
    );
  });

  it("should ignore negative refreshAtContextPercent (treated as disabled)", () => {
    expect(parseWithRefresh(-10).coreMemory.refreshAtContextPercent).toBeUndefined();
  });

  it("should ignore non-number refreshAtContextPercent", () => {
    expect(parseWithRefresh("50").coreMemory.refreshAtContextPercent).toBeUndefined();
  });
});
|
||||
|
||||
describe("extraction config section", () => {
  it("should parse extraction config when provided", () => {
    // All extraction values are literals here, so no environment setup is needed.
    const config = memoryNeo4jConfigSchema.parse({
      neo4j: { uri: "bolt://localhost:7687", password: "" },
      embedding: { provider: "ollama" },
      extraction: {
        apiKey: "or-test-key",
        model: "google/gemini-2.0-flash-001",
        baseUrl: "https://openrouter.ai/api/v1",
      },
    });

    expect(config.extraction).toBeDefined();
    // Optional chaining instead of `!`: a missing section fails the assertion
    // rather than throwing a TypeError.
    expect(config.extraction?.apiKey).toBe("or-test-key");
    expect(config.extraction?.model).toBe("google/gemini-2.0-flash-001");
  });

  it("should not include extraction when section is empty", () => {
    const config = memoryNeo4jConfigSchema.parse({
      neo4j: { uri: "bolt://localhost:7687", password: "" },
      embedding: { provider: "ollama" },
      extraction: {},
    });

    expect(config.extraction).toBeUndefined();
  });

  it("should reject unknown keys in extraction section", () => {
    expect(() =>
      memoryNeo4jConfigSchema.parse({
        neo4j: { uri: "bolt://localhost:7687", password: "" },
        embedding: { provider: "ollama" },
        extraction: { badKey: "value" },
      }),
    ).toThrow("unknown keys: badKey");
  });
});
|
||||
});
|
||||
|
||||
// ============================================================================
|
||||
// vectorDimsForModel()
|
||||
// ============================================================================
|
||||
|
||||
describe("vectorDimsForModel", () => {
  describe("known models", () => {
    // Model name -> expected dimension count; one test is generated per row.
    const knownModels: ReadonlyArray<readonly [string, number]> = [
      ["text-embedding-3-small", 1536],
      ["text-embedding-3-large", 3072],
      ["mxbai-embed-large", 1024],
      ["nomic-embed-text", 768],
      ["all-minilm", 384],
    ];

    for (const [model, dims] of knownModels) {
      it(`should return ${dims} for ${model}`, () => {
        expect(vectorDimsForModel(model)).toBe(dims);
      });
    }
  });

  describe("prefix matching", () => {
    it("should match versioned model names via prefix", () => {
      // A ":latest" tag must resolve to the base model's dimensions.
      const dims = vectorDimsForModel("mxbai-embed-large:latest");
      expect(dims).toBe(1024);
    });

    it("should match model with additional version suffix", () => {
      const dims = vectorDimsForModel("nomic-embed-text:v1.5");
      expect(dims).toBe(768);
    });
  });

  describe("unknown models", () => {
    // Description suffix -> input that is expected to hit the default.
    const fallbackCases: ReadonlyArray<readonly [string, string]> = [
      ["unknown model", "unknown-model"],
      ["empty string", ""],
      ["unrecognized prefix", "custom-embed-v2"],
    ];

    for (const [label, model] of fallbackCases) {
      it(`should return default 1024 for ${label}`, () => {
        expect(vectorDimsForModel(model)).toBe(1024);
      });
    }
  });
});
|
||||
|
||||
// ============================================================================
|
||||
// resolveExtractionConfig()
|
||||
// ============================================================================
|
||||
|
||||
describe("resolveExtractionConfig", () => {
  // Snapshot of the environment taken when the suite is defined; each test
  // may set or delete variables and the hook below undoes that.
  const originalEnv = { ...process.env };

  afterEach(() => {
    // Restore in place instead of reassigning `process.env`: replacing the
    // object swaps Node's special env object for a plain one, which changes
    // semantics (e.g. implicit string coercion) for every later test.
    for (const key of Object.keys(process.env)) {
      if (!(key in originalEnv)) {
        delete process.env[key];
      }
    }
    Object.assign(process.env, originalEnv);
  });

  it("should return disabled config when no API key or explicit baseUrl", () => {
    delete process.env.OPENROUTER_API_KEY;
    const config = resolveExtractionConfig();
    expect(config.enabled).toBe(false);
    expect(config.apiKey).toBe("");
  });

  it("should enable when OPENROUTER_API_KEY env var is set", () => {
    process.env.OPENROUTER_API_KEY = "or-env-key";
    const config = resolveExtractionConfig();
    expect(config.enabled).toBe(true);
    expect(config.apiKey).toBe("or-env-key");
  });

  it("should enable when plugin config provides apiKey", () => {
    delete process.env.OPENROUTER_API_KEY;
    const config = resolveExtractionConfig({
      apiKey: "or-plugin-key",
      model: "custom-model",
      baseUrl: "https://custom.ai/api",
    });
    expect(config.enabled).toBe(true);
    expect(config.apiKey).toBe("or-plugin-key");
    expect(config.model).toBe("custom-model");
    expect(config.baseUrl).toBe("https://custom.ai/api");
  });

  it("should enable when baseUrl is explicitly set (local Ollama, no API key)", () => {
    delete process.env.OPENROUTER_API_KEY;
    const config = resolveExtractionConfig({
      model: "llama3",
      baseUrl: "http://localhost:11434/v1",
    });
    expect(config.enabled).toBe(true);
    expect(config.apiKey).toBe("");
    expect(config.baseUrl).toBe("http://localhost:11434/v1");
  });

  it("should use defaults for model and baseUrl", () => {
    // Clear every variable that could override the built-in defaults.
    delete process.env.OPENROUTER_API_KEY;
    delete process.env.EXTRACTION_MODEL;
    delete process.env.EXTRACTION_BASE_URL;
    const config = resolveExtractionConfig();
    expect(config.model).toBe("google/gemini-2.0-flash-001");
    expect(config.baseUrl).toBe("https://openrouter.ai/api/v1");
  });

  it("should use EXTRACTION_MODEL env var", () => {
    delete process.env.OPENROUTER_API_KEY;
    process.env.EXTRACTION_MODEL = "meta/llama-3-70b";
    const config = resolveExtractionConfig();
    expect(config.model).toBe("meta/llama-3-70b");
  });

  it("should use EXTRACTION_BASE_URL env var", () => {
    delete process.env.OPENROUTER_API_KEY;
    process.env.EXTRACTION_BASE_URL = "https://my-proxy.ai/v1";
    const config = resolveExtractionConfig();
    expect(config.baseUrl).toBe("https://my-proxy.ai/v1");
  });

  it("should always set temperature to 0.0 and maxRetries to 2", () => {
    const config = resolveExtractionConfig();
    expect(config.temperature).toBe(0.0);
    expect(config.maxRetries).toBe(2);
  });
});
|
||||
@@ -63,7 +63,7 @@ export const MEMORY_CATEGORIES = [
|
||||
|
||||
export type MemoryCategory = (typeof MEMORY_CATEGORIES)[number];
|
||||
|
||||
const EMBEDDING_DIMENSIONS: Record<string, number> = {
|
||||
export const EMBEDDING_DIMENSIONS: Record<string, number> = {
|
||||
// OpenAI models
|
||||
"text-embedding-3-small": 1536,
|
||||
"text-embedding-3-large": 3072,
|
||||
@@ -75,7 +75,7 @@ const EMBEDDING_DIMENSIONS: Record<string, number> = {
|
||||
};
|
||||
|
||||
// Default dimension for unknown models (Ollama models vary)
|
||||
const DEFAULT_EMBEDDING_DIMS = 1024;
|
||||
export const DEFAULT_EMBEDDING_DIMS = 1024;
|
||||
|
||||
export function vectorDimsForModel(model: string): number {
|
||||
// Check exact match first
|
||||
@@ -88,7 +88,8 @@ export function vectorDimsForModel(model: string): number {
|
||||
return dims;
|
||||
}
|
||||
}
|
||||
// Return default for unknown models
|
||||
// Return default for unknown models — callers should warn when this path is taken,
|
||||
// as the default 1024 dimensions may not match the actual model's output.
|
||||
return DEFAULT_EMBEDDING_DIMS;
|
||||
}
|
||||
|
||||
@@ -164,6 +165,20 @@ export const memoryNeo4jConfigSchema = {
|
||||
if (typeof neo4jRaw.uri !== "string" || !neo4jRaw.uri) {
|
||||
throw new Error("neo4j.uri is required");
|
||||
}
|
||||
// Validate URI scheme — must be a valid Neo4j connection protocol
|
||||
const VALID_NEO4J_SCHEMES = [
|
||||
"bolt://",
|
||||
"bolt+s://",
|
||||
"bolt+ssc://",
|
||||
"neo4j://",
|
||||
"neo4j+s://",
|
||||
"neo4j+ssc://",
|
||||
];
|
||||
if (!VALID_NEO4J_SCHEMES.some((scheme) => neo4jRaw.uri.startsWith(scheme))) {
|
||||
throw new Error(
|
||||
`neo4j.uri must start with a valid scheme (${VALID_NEO4J_SCHEMES.join(", ")}), got: "${neo4jRaw.uri}"`,
|
||||
);
|
||||
}
|
||||
|
||||
const neo4jPassword =
|
||||
typeof neo4jRaw.password === "string" ? resolveEnvVars(neo4jRaw.password) : "";
|
||||
@@ -212,7 +227,19 @@ export const memoryNeo4jConfigSchema = {
|
||||
const coreMemoryEnabled = coreMemoryRaw?.enabled !== false; // enabled by default
|
||||
const coreMemoryMaxEntries =
|
||||
typeof coreMemoryRaw?.maxEntries === "number" ? coreMemoryRaw.maxEntries : 50;
|
||||
// refreshAtContextPercent: number between 0-100, or undefined to disable
|
||||
if (coreMemoryMaxEntries <= 0) {
|
||||
throw new Error(`coreMemory.maxEntries must be greater than 0, got: ${coreMemoryMaxEntries}`);
|
||||
}
|
||||
// refreshAtContextPercent: number between 1-99 to be effective, or undefined to disable.
|
||||
// Values at 0 or below are ignored (disables refresh). Values above 100 are invalid.
|
||||
if (
|
||||
typeof coreMemoryRaw?.refreshAtContextPercent === "number" &&
|
||||
coreMemoryRaw.refreshAtContextPercent > 100
|
||||
) {
|
||||
throw new Error(
|
||||
`coreMemory.refreshAtContextPercent must be between 1 and 100, got: ${coreMemoryRaw.refreshAtContextPercent}`,
|
||||
);
|
||||
}
|
||||
const refreshAtContextPercent =
|
||||
typeof coreMemoryRaw?.refreshAtContextPercent === "number" &&
|
||||
coreMemoryRaw.refreshAtContextPercent > 0 &&
|
||||
|
||||
192
extensions/memory-neo4j/embeddings.test.ts
Normal file
192
extensions/memory-neo4j/embeddings.test.ts
Normal file
@@ -0,0 +1,192 @@
|
||||
/**
|
||||
* Tests for embeddings.ts — Embedding Provider.
|
||||
*
|
||||
* Tests the Embeddings class with mocked OpenAI client and mocked fetch for Ollama.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, afterEach } from "vitest";
|
||||
|
||||
// ============================================================================
|
||||
// Constructor
|
||||
// ============================================================================
|
||||
|
||||
describe("Embeddings constructor", () => {
  it("should throw when OpenAI provider is used without API key", async () => {
    const { Embeddings } = await import("./embeddings.js");

    // Construction itself must fail when the openai provider has no key.
    const construct = () => new Embeddings(undefined, "text-embedding-3-small", "openai");

    expect(construct).toThrow("API key required for OpenAI embeddings");
  });

  it("should not require API key for ollama provider", async () => {
    const { Embeddings } = await import("./embeddings.js");

    const instance = new Embeddings(undefined, "mxbai-embed-large", "ollama");

    expect(instance).toBeDefined();
  });
});
|
||||
|
||||
// ============================================================================
|
||||
// Ollama embed
|
||||
// ============================================================================
|
||||
|
||||
describe("Embeddings - Ollama provider", () => {
  // Keep the real fetch so each test can stub it and have it put back afterwards.
  const realFetch = globalThis.fetch;

  afterEach(() => {
    globalThis.fetch = realFetch;
  });

  // Response-like object for a successful call whose JSON body is `payload`.
  const okJson = (payload: unknown) => ({
    ok: true,
    json: () => Promise.resolve(payload),
  });

  it("should call Ollama API with correct request body", async () => {
    const { Embeddings } = await import("./embeddings.js");
    const mockVector = [0.1, 0.2, 0.3, 0.4];
    globalThis.fetch = vi.fn().mockResolvedValue(okJson({ embeddings: [mockVector] }));

    const embeddings = new Embeddings(undefined, "mxbai-embed-large", "ollama");
    const vector = await embeddings.embed("test text");

    expect(vector).toEqual(mockVector);

    // The request must hit the default local endpoint with model + input in the body.
    const expectedBody = JSON.stringify({
      model: "mxbai-embed-large",
      input: "test text",
    });
    expect(globalThis.fetch).toHaveBeenCalledWith(
      "http://localhost:11434/api/embed",
      expect.objectContaining({
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: expectedBody,
      }),
    );
  });

  it("should use custom baseUrl for Ollama", async () => {
    const { Embeddings } = await import("./embeddings.js");
    const mockVector = [0.5, 0.6];
    globalThis.fetch = vi.fn().mockResolvedValue(okJson({ embeddings: [mockVector] }));

    const embeddings = new Embeddings(
      undefined,
      "mxbai-embed-large",
      "ollama",
      "http://my-host:11434",
    );
    await embeddings.embed("test");

    expect(globalThis.fetch).toHaveBeenCalledWith(
      "http://my-host:11434/api/embed",
      expect.any(Object),
    );
  });

  it("should throw when Ollama returns error status", async () => {
    const { Embeddings } = await import("./embeddings.js");
    const serverError = {
      ok: false,
      status: 500,
      text: () => Promise.resolve("Internal Server Error"),
    };
    globalThis.fetch = vi.fn().mockResolvedValue(serverError);

    const embeddings = new Embeddings(undefined, "mxbai-embed-large", "ollama");

    await expect(embeddings.embed("test")).rejects.toThrow("Ollama embedding failed: 500");
  });

  it("should throw when Ollama returns no embeddings", async () => {
    const { Embeddings } = await import("./embeddings.js");
    // Successful status, but the embeddings list is empty.
    globalThis.fetch = vi.fn().mockResolvedValue(okJson({ embeddings: [] }));

    const embeddings = new Embeddings(undefined, "mxbai-embed-large", "ollama");

    await expect(embeddings.embed("test")).rejects.toThrow("No embedding returned from Ollama");
  });

  it("should throw when Ollama returns null embeddings", async () => {
    const { Embeddings } = await import("./embeddings.js");
    // Successful status, but the embeddings field is missing altogether.
    globalThis.fetch = vi.fn().mockResolvedValue(okJson({}));

    const embeddings = new Embeddings(undefined, "mxbai-embed-large", "ollama");

    await expect(embeddings.embed("test")).rejects.toThrow("No embedding returned from Ollama");
  });

  it("should propagate fetch errors for Ollama", async () => {
    const { Embeddings } = await import("./embeddings.js");
    const networkFailure = new Error("Network error");
    globalThis.fetch = vi.fn().mockRejectedValue(networkFailure);

    const embeddings = new Embeddings(undefined, "mxbai-embed-large", "ollama");

    await expect(embeddings.embed("test")).rejects.toThrow("Network error");
  });
});
|
||||
|
||||
// ============================================================================
|
||||
// OpenAI embed (via mocked client internals)
|
||||
// ============================================================================
|
||||
|
||||
describe("Embeddings - OpenAI provider", () => {
  it("should create instance with OpenAI provider when API key provided", async () => {
    const { Embeddings } = await import("./embeddings.js");

    // Construction with a key and a known model must not throw.
    const instance = new Embeddings("sk-test-key", "text-embedding-3-small", "openai");

    expect(instance).toBeDefined();
  });

  it("should have embed and embedBatch methods", async () => {
    const { Embeddings } = await import("./embeddings.js");

    const instance = new Embeddings("sk-test-key", "text-embedding-3-small", "openai");
    const embedKind = typeof instance.embed;
    const embedBatchKind = typeof instance.embedBatch;

    expect(embedKind).toBe("function");
    expect(embedBatchKind).toBe("function");
  });
});
|
||||
|
||||
// ============================================================================
|
||||
// Batch embedding
|
||||
// ============================================================================
|
||||
|
||||
describe("Embeddings - embedBatch", () => {
  // Keep the real fetch so the Ollama case can stub it and have it restored.
  const originalFetch = globalThis.fetch;

  afterEach(() => {
    globalThis.fetch = originalFetch;
  });

  it("should return empty array for empty input (openai)", async () => {
    const { Embeddings } = await import("./embeddings.js");
    const emb = new Embeddings("sk-test", "text-embedding-3-small", "openai");
    const results = await emb.embedBatch([]);
    expect(results).toEqual([]);
  });

  it("should return empty array for empty input (ollama)", async () => {
    const { Embeddings } = await import("./embeddings.js");
    const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama");
    const results = await emb.embedBatch([]);
    expect(results).toEqual([]);
  });

  // The title describes what is asserted: one request per input text.
  // Nothing here checks the order or concurrency of those requests, so it
  // makes no claim that they are sequential.
  it("should make one request per text for Ollama batch (no native batch support)", async () => {
    const { Embeddings } = await import("./embeddings.js");
    let callCount = 0;
    globalThis.fetch = vi.fn().mockImplementation(() => {
      callCount++;
      // Each call returns a distinct two-element vector derived from the call number.
      return Promise.resolve({
        ok: true,
        json: () => Promise.resolve({ embeddings: [[callCount * 0.1, callCount * 0.2]] }),
      });
    });

    const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama");
    const results = await emb.embedBatch(["text1", "text2", "text3"]);

    // Three inputs -> three separate requests and three results.
    expect(globalThis.fetch).toHaveBeenCalledTimes(3);
    expect(results).toHaveLength(3);
    // Each result should be a vector of the mocked length.
    for (const r of results) {
      expect(Array.isArray(r)).toBe(true);
      expect(r.length).toBe(2);
    }
  });
});
|
||||
@@ -7,19 +7,29 @@
|
||||
import OpenAI from "openai";
|
||||
import type { EmbeddingProvider } from "./config.js";
|
||||
|
||||
/**
 * Minimal logging sink accepted by Embeddings.
 * `info`, `warn` and `error` are required; `debug` may be omitted.
 */
interface Logger {
  info: (msg: string) => void;
  warn: (msg: string) => void;
  error: (msg: string) => void;
  debug?: (msg: string) => void;
}
|
||||
|
||||
export class Embeddings {
|
||||
private client: OpenAI | null = null;
|
||||
private readonly provider: EmbeddingProvider;
|
||||
private readonly baseUrl: string;
|
||||
private readonly logger: Logger | undefined;
|
||||
|
||||
constructor(
|
||||
private readonly apiKey: string | undefined,
|
||||
private readonly model: string = "text-embedding-3-small",
|
||||
provider: EmbeddingProvider = "openai",
|
||||
baseUrl?: string,
|
||||
logger?: Logger,
|
||||
) {
|
||||
this.provider = provider;
|
||||
this.baseUrl = baseUrl ?? (provider === "ollama" ? "http://localhost:11434" : "");
|
||||
this.logger = logger;
|
||||
|
||||
if (provider === "openai") {
|
||||
if (!apiKey) {
|
||||
@@ -42,6 +52,9 @@ export class Embeddings {
|
||||
/**
|
||||
* Generate embeddings for multiple texts.
|
||||
* Returns array of embeddings in the same order as input.
|
||||
*
|
||||
* For Ollama: uses Promise.allSettled so individual failures don't break the
|
||||
* entire batch. Failed embeddings are replaced with empty arrays ([]) and logged.
|
||||
*/
|
||||
async embedBatch(texts: string[]): Promise<number[][]> {
|
||||
if (texts.length === 0) {
|
||||
@@ -49,8 +62,32 @@ export class Embeddings {
|
||||
}
|
||||
|
||||
if (this.provider === "ollama") {
|
||||
// Ollama doesn't support batch, so we do sequential
|
||||
return Promise.all(texts.map((t) => this.embedOllama(t)));
|
||||
// No batch request is used for Ollama: one request per text, started together, with resilient error handling
|
||||
const results = await Promise.allSettled(texts.map((t) => this.embedOllama(t)));
|
||||
const embeddings: number[][] = [];
|
||||
let failures = 0;
|
||||
|
||||
for (let i = 0; i < results.length; i++) {
|
||||
const result = results[i];
|
||||
if (result.status === "fulfilled") {
|
||||
embeddings.push(result.value);
|
||||
} else {
|
||||
failures++;
|
||||
this.logger?.warn?.(
|
||||
`memory-neo4j: Ollama embedding failed for text ${i}: ${String(result.reason)}`,
|
||||
);
|
||||
// Use an empty array (not a zero vector) as placeholder so indices stay aligned
|
||||
embeddings.push([]);
|
||||
}
|
||||
}
|
||||
|
||||
if (failures > 0) {
|
||||
this.logger?.warn?.(
|
||||
`memory-neo4j: ${failures}/${texts.length} Ollama embeddings failed in batch`,
|
||||
);
|
||||
}
|
||||
|
||||
return embeddings;
|
||||
}
|
||||
|
||||
return this.embedBatchOpenAI(texts);
|
||||
@@ -79,6 +116,9 @@ export class Embeddings {
|
||||
return response.data.toSorted((a, b) => a.index - b.index).map((d) => d.embedding);
|
||||
}
|
||||
|
||||
// Timeout for Ollama embedding fetch calls to prevent hanging indefinitely
|
||||
private static readonly EMBED_TIMEOUT_MS = 30_000;
|
||||
|
||||
private async embedOllama(text: string): Promise<number[]> {
|
||||
const url = `${this.baseUrl}/api/embed`;
|
||||
const response = await fetch(url, {
|
||||
@@ -88,6 +128,7 @@ export class Embeddings {
|
||||
model: this.model,
|
||||
input: text,
|
||||
}),
|
||||
signal: AbortSignal.timeout(Embeddings.EMBED_TIMEOUT_MS),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
|
||||
760
extensions/memory-neo4j/extractor.test.ts
Normal file
760
extensions/memory-neo4j/extractor.test.ts
Normal file
@@ -0,0 +1,760 @@
|
||||
/**
|
||||
* Tests for extractor.ts — Extraction Logic.
|
||||
*
|
||||
* Tests exported functions: extractEntities(), extractUserMessages(), runBackgroundExtraction().
|
||||
* Note: validateExtractionResult() is not exported; it is tested indirectly through extractEntities().
|
||||
* Note: passesAttentionGate() is defined in index.ts and not exported; cannot be tested directly.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import type { ExtractionConfig } from "./config.js";
|
||||
import { extractUserMessages, extractEntities, runBackgroundExtraction } from "./extractor.js";
|
||||
|
||||
// ============================================================================
|
||||
// extractUserMessages()
|
||||
// ============================================================================
|
||||
|
||||
describe("extractUserMessages", () => {
  /** Builds a chat message object with the given role and content. */
  const message = (role: string, content: unknown) => ({ role, content });
  /** Shorthand for a message authored by the user. */
  const fromUser = (content: unknown) => message("user", content);

  it("should extract string content from user messages", () => {
    const extracted = extractUserMessages([
      fromUser("I prefer TypeScript over JavaScript"),
      fromUser("My favorite color is blue"),
    ]);
    expect(extracted).toEqual(["I prefer TypeScript over JavaScript", "My favorite color is blue"]);
  });

  it("should extract text from content block arrays", () => {
    // Only the text blocks are expected back; the image block is dropped
    const blocks = [
      { type: "text", text: "Hello, this is a content block message" },
      { type: "image", url: "http://example.com/img.png" },
      { type: "text", text: "Another text block in same message" },
    ];
    const extracted = extractUserMessages([fromUser(blocks)]);
    expect(extracted).toEqual([
      "Hello, this is a content block message",
      "Another text block in same message",
    ]);
  });

  it("should filter out assistant messages", () => {
    const extracted = extractUserMessages([
      fromUser("This is a user message that is long enough"),
      message("assistant", "This is an assistant message"),
    ]);
    expect(extracted).toEqual(["This is a user message that is long enough"]);
  });

  it("should filter out system messages", () => {
    const extracted = extractUserMessages([
      message("system", "You are a helpful assistant with context"),
      fromUser("This is a user message that is long enough"),
    ]);
    expect(extracted).toEqual(["This is a user message that is long enough"]);
  });

  it("should filter out messages shorter than 10 characters", () => {
    const extracted = extractUserMessages([
      fromUser("short"), // 5 chars
      fromUser("1234567890"), // exactly 10 chars
      fromUser("This is longer than ten characters"),
    ]);
    expect(extracted).toEqual(["1234567890", "This is longer than ten characters"]);
  });

  it("should filter out messages containing <relevant-memories>", () => {
    const injected =
      "<relevant-memories>Some injected context that should be ignored</relevant-memories>";
    const extracted = extractUserMessages([
      fromUser("Normal user message that is long enough here"),
      fromUser(injected),
    ]);
    expect(extracted).toEqual(["Normal user message that is long enough here"]);
  });

  it("should filter out messages containing <system>", () => {
    const extracted = extractUserMessages([
      fromUser("<system>System markup that should be filtered</system>"),
      fromUser("Normal user message that is long enough here"),
    ]);
    expect(extracted).toEqual(["Normal user message that is long enough here"]);
  });

  it("should handle null and non-object messages gracefully", () => {
    const mixed: unknown[] = [
      null,
      undefined,
      "not an object",
      42,
      fromUser("Valid message with enough length"),
    ];
    expect(extractUserMessages(mixed)).toEqual(["Valid message with enough length"]);
  });

  it("should return empty array when no user messages exist", () => {
    const extracted = extractUserMessages([message("assistant", "Only assistant messages")]);
    expect(extracted).toEqual([]);
  });

  it("should return empty array for empty input", () => {
    expect(extractUserMessages([])).toEqual([]);
  });

  it("should handle messages where content is neither string nor array", () => {
    const oddContent: unknown[] = [fromUser(42), fromUser(null), fromUser({ nested: true })];
    expect(extractUserMessages(oddContent)).toEqual([]);
  });
});
|
||||
|
||||
// ============================================================================
|
||||
// extractEntities() — tests validateExtractionResult() indirectly
|
||||
// ============================================================================
|
||||
|
||||
describe("extractEntities", () => {
  // callOpenRouter goes through the global fetch, so each test installs a stub
  const realFetch = globalThis.fetch;

  beforeEach(() => {
    vi.restoreAllMocks();
  });

  afterEach(() => {
    globalThis.fetch = realFetch;
  });

  const enabledConfig: ExtractionConfig = {
    enabled: true,
    apiKey: "test-key",
    model: "test-model",
    baseUrl: "https://test.ai/api/v1",
    temperature: 0.0,
    maxRetries: 0, // No retries in tests
  };

  const disabledConfig: ExtractionConfig = { ...enabledConfig, enabled: false };

  /** Stubs fetch with a chat-completion style reply carrying `content`. */
  function stubLlmReply(content: string, status = 200): void {
    const ok = status >= 200 && status < 300;
    globalThis.fetch = vi.fn().mockResolvedValue({
      ok,
      status,
      text: async () => content,
      json: async () => ({ choices: [{ message: { content } }] }),
    });
  }

  /** Serves `payload` as the LLM's JSON reply and runs extraction on `text`. */
  function extractFrom(payload: unknown, text = "test") {
    stubLlmReply(JSON.stringify(payload));
    return extractEntities(text, enabledConfig);
  }

  it("should return null result when extraction is disabled", async () => {
    const outcome = await extractEntities("test text", disabledConfig);
    expect(outcome.result).toBeNull();
    expect(outcome.transientFailure).toBe(false);
  });

  it("should extract valid entities from LLM response", async () => {
    const { result } = await extractFrom(
      {
        category: "fact",
        entities: [
          { name: "Tarun", type: "person", aliases: ["boss"], description: "The CEO" },
          { name: "Abundent", type: "organization" },
        ],
        relationships: [
          { source: "Tarun", target: "Abundent", type: "WORKS_AT", confidence: 0.95 },
        ],
        tags: [{ name: "Leadership", category: "business" }],
      },
      "Tarun works at Abundent",
    );
    expect(result).not.toBeNull();
    const { category, entities, relationships, tags } = result!;
    expect(category).toBe("fact");

    // Entities should be normalized to lowercase
    expect(entities).toHaveLength(2);
    const [person, organization] = entities;
    expect(person.name).toBe("tarun");
    expect(person.type).toBe("person");
    expect(person.aliases).toEqual(["boss"]);
    expect(person.description).toBe("The CEO");
    expect(organization.name).toBe("abundent");
    expect(organization.type).toBe("organization");

    // Relationships should be normalized to lowercase source/target
    expect(relationships).toHaveLength(1);
    const [worksAt] = relationships;
    expect(worksAt.source).toBe("tarun");
    expect(worksAt.target).toBe("abundent");
    expect(worksAt.type).toBe("WORKS_AT");
    expect(worksAt.confidence).toBe(0.95);

    // Tags should be normalized to lowercase
    expect(tags).toHaveLength(1);
    const [tag] = tags;
    expect(tag.name).toBe("leadership");
    expect(tag.category).toBe("business");
  });

  it("should handle empty extraction result", async () => {
    const { result } = await extractFrom(
      { category: "other", entities: [], relationships: [], tags: [] },
      "just a greeting",
    );
    expect(result).not.toBeNull();
    const { entities, relationships, tags } = result!;
    expect(entities).toEqual([]);
    expect(relationships).toEqual([]);
    expect(tags).toEqual([]);
  });

  it("should handle missing fields in LLM response", async () => {
    // Reply is an empty object: no category, entities, relationships, or tags
    const { result } = await extractFrom({}, "some text");
    expect(result).not.toBeNull();
    const { category, entities, relationships, tags } = result!;
    expect(category).toBeUndefined();
    expect(entities).toEqual([]);
    expect(relationships).toEqual([]);
    expect(tags).toEqual([]);
  });

  it("should filter out invalid entity types (fallback to concept)", async () => {
    const { result } = await extractFrom({
      entities: [
        { name: "Widget", type: "gadget" }, // invalid type -> concept
        { name: "Paris", type: "location" }, // valid type
      ],
      relationships: [],
      tags: [],
    });
    const { entities } = result!;
    expect(entities).toHaveLength(2);
    expect(entities[0].type).toBe("concept"); // invalid type falls back to concept
    expect(entities[1].type).toBe("location");
  });

  it("should filter out invalid relationship types", async () => {
    const { result } = await extractFrom({
      entities: [],
      relationships: [
        { source: "a", target: "b", type: "WORKS_AT", confidence: 0.9 }, // valid
        { source: "a", target: "b", type: "HATES", confidence: 0.9 }, // invalid type
      ],
      tags: [],
    });
    const { relationships } = result!;
    expect(relationships).toHaveLength(1);
    expect(relationships[0].type).toBe("WORKS_AT");
  });

  it("should clamp confidence to 0-1 range", async () => {
    const { result } = await extractFrom({
      entities: [],
      relationships: [
        { source: "a", target: "b", type: "KNOWS", confidence: 1.5 }, // over 1
        { source: "c", target: "d", type: "KNOWS", confidence: -0.5 }, // under 0
      ],
      tags: [],
    });
    const { relationships } = result!;
    expect(relationships[0].confidence).toBe(1);
    expect(relationships[1].confidence).toBe(0);
  });

  it("should default confidence to 0.7 when not a number", async () => {
    const { result } = await extractFrom({
      entities: [],
      relationships: [{ source: "a", target: "b", type: "KNOWS", confidence: "high" }],
      tags: [],
    });
    expect(result!.relationships[0].confidence).toBe(0.7);
  });

  it("should filter out entities without name", async () => {
    const { result } = await extractFrom({
      entities: [
        { name: "", type: "person" }, // empty name -> filtered
        { name: " ", type: "person" }, // whitespace-only name -> filtered (after trim)
        { name: "valid", type: "person" }, // valid
      ],
      relationships: [],
      tags: [],
    });
    const { entities } = result!;
    expect(entities).toHaveLength(1);
    expect(entities[0].name).toBe("valid");
  });

  it("should filter out entities with non-object shape", async () => {
    const { result } = await extractFrom({
      entities: [null, "not an entity", 42, { name: "valid", type: "person" }],
      relationships: [],
      tags: [],
    });
    expect(result!.entities).toHaveLength(1);
  });

  it("should filter out entities missing required fields", async () => {
    const { result } = await extractFrom({
      entities: [
        { type: "person" }, // missing name
        { name: "test" }, // missing type
        { name: "valid", type: "person" }, // has both
      ],
      relationships: [],
      tags: [],
    });
    const { entities } = result!;
    expect(entities).toHaveLength(1);
    expect(entities[0].name).toBe("valid");
  });

  it("should default tag category to 'topic' when missing", async () => {
    const { result } = await extractFrom({
      entities: [],
      relationships: [],
      tags: [{ name: "neo4j" }], // no category
    });
    expect(result!.tags[0].category).toBe("topic");
  });

  it("should filter out tags with empty names", async () => {
    const { result } = await extractFrom({
      entities: [],
      relationships: [],
      tags: [
        { name: "", category: "tech" }, // empty -> filtered
        { name: " ", category: "tech" }, // whitespace-only -> filtered
        { name: "valid", category: "tech" },
      ],
    });
    const { tags } = result!;
    expect(tags).toHaveLength(1);
    expect(tags[0].name).toBe("valid");
  });

  it("should reject invalid category values", async () => {
    const { result } = await extractFrom({
      category: "invalid-category",
      entities: [],
      relationships: [],
      tags: [],
    });
    expect(result!.category).toBeUndefined();
  });

  it("should accept valid category values", async () => {
    const accepted = ["preference", "fact", "decision", "entity", "other"];
    for (const category of accepted) {
      const { result } = await extractFrom(
        { category, entities: [], relationships: [], tags: [] },
        `test ${category}`,
      );
      expect(result!.category).toBe(category);
    }
  });

  it("should return null result for malformed JSON response (permanent failure)", async () => {
    stubLlmReply("not valid json at all");

    const outcome = await extractEntities("test", enabledConfig);
    // callOpenRouter returns the raw string, JSON.parse fails, catch returns null
    expect(outcome.result).toBeNull();
    expect(outcome.transientFailure).toBe(false);
  });

  it("should return null result when API returns error status", async () => {
    globalThis.fetch = vi.fn().mockResolvedValue({
      ok: false,
      status: 500,
      text: () => Promise.resolve("Internal Server Error"),
    });

    const outcome = await extractEntities("test", enabledConfig);
    // API error 500 is not in the transient list (only 429, 502, 503, 504)
    expect(outcome.result).toBeNull();
  });

  it("should return null result when API returns no content", async () => {
    globalThis.fetch = vi.fn().mockResolvedValue({
      ok: true,
      status: 200,
      json: () => Promise.resolve({ choices: [{ message: { content: null } }] }),
    });

    const outcome = await extractEntities("test", enabledConfig);
    expect(outcome.result).toBeNull();
    expect(outcome.transientFailure).toBe(false);
  });

  it("should normalize alias strings to lowercase", async () => {
    const { result } = await extractFrom({
      entities: [{ name: "John", type: "person", aliases: ["Johnny", "JOHN", "j.doe"] }],
      relationships: [],
      tags: [],
    });
    expect(result!.entities[0].aliases).toEqual(["johnny", "john", "j.doe"]);
  });

  it("should filter out non-string aliases", async () => {
    const { result } = await extractFrom({
      entities: [{ name: "John", type: "person", aliases: ["valid", 42, null, "also-valid"] }],
      relationships: [],
      tags: [],
    });
    expect(result!.entities[0].aliases).toEqual(["valid", "also-valid"]);
  });
});
|
||||
|
||||
// ============================================================================
|
||||
// runBackgroundExtraction()
|
||||
// ============================================================================
|
||||
|
||||
describe("runBackgroundExtraction", () => {
  const realFetch = globalThis.fetch;

  type Mock = ReturnType<typeof vi.fn>;

  let mockLogger: Record<"info" | "warn" | "error" | "debug", Mock>;
  let mockDb: Record<
    | "updateExtractionStatus"
    | "mergeEntity"
    | "createMentions"
    | "createEntityRelationship"
    | "tagMemory"
    | "updateMemoryCategory",
    Mock
  >;
  let mockEmbeddings: Record<"embed" | "embedBatch", Mock>;

  /** A mock async method that resolves with nothing. */
  const resolvesVoid = (): Mock => vi.fn().mockResolvedValue(undefined);

  beforeEach(() => {
    vi.restoreAllMocks();

    // Fresh doubles per test so call counts never leak between cases
    mockLogger = { info: vi.fn(), warn: vi.fn(), error: vi.fn(), debug: vi.fn() };
    mockDb = {
      updateExtractionStatus: resolvesVoid(),
      mergeEntity: resolvesVoid(),
      createMentions: resolvesVoid(),
      createEntityRelationship: resolvesVoid(),
      tagMemory: resolvesVoid(),
      updateMemoryCategory: resolvesVoid(),
    };
    mockEmbeddings = {
      embed: vi.fn().mockResolvedValue([0.1, 0.2, 0.3]),
      embedBatch: vi.fn().mockResolvedValue([[0.1, 0.2, 0.3]]),
    };
  });

  afterEach(() => {
    globalThis.fetch = realFetch;
  });

  const enabledConfig: ExtractionConfig = {
    enabled: true,
    apiKey: "test-key",
    model: "test-model",
    baseUrl: "https://test.ai/api/v1",
    temperature: 0.0,
    maxRetries: 0,
  };

  const disabledConfig: ExtractionConfig = { ...enabledConfig, enabled: false };

  /** Stubs fetch with a successful chat-completion reply whose content is `payload` as JSON. */
  function stubLlmPayload(payload: unknown): void {
    const content = JSON.stringify(payload);
    globalThis.fetch = vi.fn().mockResolvedValue({
      ok: true,
      status: 200,
      json: async () => ({ choices: [{ message: { content } }] }),
    });
  }

  /** Runs the function under test against the current doubles. */
  const runExtraction = (
    memoryId: string,
    text: string,
    config: ExtractionConfig = enabledConfig,
  ) =>
    runBackgroundExtraction(
      memoryId,
      text,
      mockDb as never,
      mockEmbeddings as never,
      config,
      mockLogger,
    );

  it("should skip extraction and mark as 'skipped' when disabled", async () => {
    await runExtraction("mem-1", "test text", disabledConfig);
    expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "skipped");
  });

  it("should mark as 'failed' when extraction returns null", async () => {
    globalThis.fetch = vi.fn().mockResolvedValue({
      ok: false,
      status: 500,
      text: () => Promise.resolve("error"),
    });

    await runExtraction("mem-1", "test text");
    expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "failed");
  });

  it("should mark as 'complete' when extraction result is empty", async () => {
    stubLlmPayload({ entities: [], relationships: [], tags: [] });

    await runExtraction("mem-1", "test text");
    expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "complete");
  });

  it("should merge entities, create mentions, and mark complete", async () => {
    stubLlmPayload({
      category: "fact",
      entities: [{ name: "Alice", type: "person" }],
      relationships: [],
      tags: [],
    });

    await runExtraction("mem-1", "Alice is a developer");

    const expectedEntity = expect.objectContaining({ name: "alice", type: "person" });
    expect(mockDb.mergeEntity).toHaveBeenCalledWith(expectedEntity);
    expect(mockDb.createMentions).toHaveBeenCalledWith("mem-1", "alice", "context", 1.0);
    expect(mockDb.updateMemoryCategory).toHaveBeenCalledWith("mem-1", "fact");
    expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "complete");
  });

  it("should create entity relationships", async () => {
    stubLlmPayload({
      entities: [
        { name: "Alice", type: "person" },
        { name: "Acme", type: "organization" },
      ],
      relationships: [{ source: "Alice", target: "Acme", type: "WORKS_AT", confidence: 0.9 }],
      tags: [],
    });

    await runExtraction("mem-1", "Alice works at Acme");

    expect(mockDb.createEntityRelationship).toHaveBeenCalledWith("alice", "acme", "WORKS_AT", 0.9);
  });

  it("should tag memories", async () => {
    stubLlmPayload({
      entities: [],
      relationships: [],
      tags: [{ name: "Programming", category: "tech" }],
    });

    await runExtraction("mem-1", "test text");

    expect(mockDb.tagMemory).toHaveBeenCalledWith("mem-1", "programming", "tech");
  });

  it("should not update category when result has no category", async () => {
    stubLlmPayload({
      entities: [{ name: "Test", type: "concept" }],
      relationships: [],
      tags: [],
    });

    await runExtraction("mem-1", "test");

    expect(mockDb.updateMemoryCategory).not.toHaveBeenCalled();
  });

  it("should handle entity merge failure gracefully", async () => {
    stubLlmPayload({
      entities: [
        { name: "Alice", type: "person" },
        { name: "Bob", type: "person" },
      ],
      relationships: [],
      tags: [],
    });

    // First entity merge fails, second succeeds
    mockDb.mergeEntity
      .mockRejectedValueOnce(new Error("merge failed"))
      .mockResolvedValueOnce(undefined);

    await runExtraction("mem-1", "Alice and Bob");

    // Should still continue and complete
    expect(mockDb.mergeEntity).toHaveBeenCalledTimes(2);
    expect(mockDb.updateExtractionStatus).toHaveBeenCalledWith("mem-1", "complete");
    expect(mockLogger.warn).toHaveBeenCalled();
  });

  it("should log extraction results", async () => {
    stubLlmPayload({
      category: "fact",
      entities: [{ name: "Test", type: "concept" }],
      relationships: [{ source: "a", target: "b", type: "RELATED_TO", confidence: 0.8 }],
      tags: [{ name: "tech" }],
    });

    await runExtraction("mem-12345678-abcd", "test");

    expect(mockLogger.info).toHaveBeenCalledWith(expect.stringContaining("extraction complete"));
  });
});
|
||||
@@ -63,6 +63,9 @@ Rules:
|
||||
// OpenRouter API Client
|
||||
// ============================================================================
|
||||
|
||||
// Timeout for LLM and embedding fetch calls to prevent hanging indefinitely
|
||||
const FETCH_TIMEOUT_MS = 30_000;
|
||||
|
||||
async function callOpenRouter(config: ExtractionConfig, prompt: string): Promise<string | null> {
|
||||
for (let attempt = 0; attempt <= config.maxRetries; attempt++) {
|
||||
try {
|
||||
@@ -78,6 +81,7 @@ async function callOpenRouter(config: ExtractionConfig, prompt: string): Promise
|
||||
temperature: config.temperature,
|
||||
response_format: { type: "json_object" },
|
||||
}),
|
||||
signal: AbortSignal.timeout(FETCH_TIMEOUT_MS),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
@@ -104,30 +108,68 @@ async function callOpenRouter(config: ExtractionConfig, prompt: string): Promise
|
||||
// Entity Extraction
|
||||
// ============================================================================
|
||||
|
||||
/** Max retries for transient extraction failures before marking permanently failed */
|
||||
const MAX_EXTRACTION_RETRIES = 3;
|
||||
|
||||
/**
 * Classify a failure as transient (worth retrying later) or permanent.
 *
 * An error counts as transient when:
 * - its name is "AbortError" or "TimeoutError", or
 * - its message or string `code` contains (case-insensitively) one of the
 *   markers below: timeouts, refused/reset connections, DNS lookup failures,
 *   generic network / "fetch failed" / "socket hang up" errors, or an
 *   "api error" with status 429, 502, 503 or 504.
 *
 * Everything else, including any non-Error value, is treated as permanent.
 *
 * @param err - The caught value (anything may be thrown, hence `unknown`).
 * @returns true if the failure looks transient, false otherwise.
 */
function isTransientError(err: unknown): boolean {
  if (!(err instanceof Error)) return false;

  // Aborts and timeouts are recognised by name, independent of message wording.
  if (err.name === "AbortError" || err.name === "TimeoutError") return true;

  // Node system errors expose the errno name on `code` (e.g. "ETIMEDOUT"),
  // which is not guaranteed to be repeated in the message.
  const code = (err as Error & { code?: unknown }).code;
  const haystack = `${err.message} ${typeof code === "string" ? code : ""}`.toLowerCase();

  const transientMarkers = [
    "timeout",
    "timed out", // "timeout" does not match this wording
    "etimedout", // ...nor the errno spelling
    "econnrefused",
    "econnreset",
    "enotfound",
    "eai_again", // temporary DNS failure, same family as ENOTFOUND
    "network",
    "fetch failed",
    "socket hang up",
    "api error 429",
    "api error 502",
    "api error 503",
    "api error 504",
  ];

  return transientMarkers.some((marker) => haystack.includes(marker));
}
|
||||
|
||||
/**
 * Extract entities and relationships from a memory text using an LLM.
 *
 * Returns `{ result, transientFailure }`:
 * - `result` is the validated ExtractionResult, or `null` when extraction is
 *   disabled, the LLM call threw, the LLM returned no content, or the content
 *   could not be parsed/validated.
 * - `transientFailure` is `true` only when the LLM call threw an error that
 *   `isTransientError` classifies as transient (caller should retry later).
 *   Every other `null` result is reported as a permanent failure.
 *
 * This function does not throw for call or parse failures; they are folded
 * into the returned object.
 *
 * @param text - The memory text substituted into the extraction prompt.
 * @param config - Extraction settings; `config.enabled === false` short-circuits.
 */
export async function extractEntities(
  text: string,
  config: ExtractionConfig,
): Promise<{ result: ExtractionResult | null; transientFailure: boolean }> {
  if (!config.enabled) {
    return { result: null, transientFailure: false };
  }

  // Use a replacer function rather than a replacement string: with a string,
  // sequences such as "$&", "$$", "$`" and "$'" inside `text` would be expanded
  // as replacement patterns and silently corrupt the prompt.
  const prompt = ENTITY_EXTRACTION_PROMPT.replace("{text}", () => text);

  let content: string | null;
  try {
    content = await callOpenRouter(config, prompt);
  } catch (err) {
    // Network/timeout errors are transient — caller should retry
    return { result: null, transientFailure: isTransientError(err) };
  }

  if (!content) {
    return { result: null, transientFailure: false };
  }

  try {
    const parsed = JSON.parse(content) as Record<string, unknown>;
    return { result: validateExtractionResult(parsed), transientFailure: false };
  } catch {
    // JSON parse failure is permanent — LLM returned malformed output
    return { result: null, transientFailure: false };
  }
}
|
||||
|
||||
@@ -213,7 +255,11 @@ function validateExtractionResult(raw: Record<string, unknown>): ExtractionResul
|
||||
* 3. Create MENTIONS relationships from Memory → Entity
|
||||
* 4. Create inter-Entity relationships (WORKS_AT, KNOWS, etc.)
|
||||
* 5. Tag the memory
|
||||
* 6. Update extractionStatus to "complete" or "failed"
|
||||
* 6. Update extractionStatus to "complete", "pending" (transient retry), or "failed"
|
||||
*
|
||||
* Transient failures (network/timeout) leave status as "pending" with an incremented
|
||||
* retry counter. After MAX_EXTRACTION_RETRIES transient failures, the memory is
|
||||
* permanently marked "failed". Permanent failures (malformed JSON) are immediately "failed".
|
||||
*/
|
||||
export async function runBackgroundExtraction(
|
||||
memoryId: string,
|
||||
@@ -222,6 +268,7 @@ export async function runBackgroundExtraction(
|
||||
embeddings: Embeddings,
|
||||
config: ExtractionConfig,
|
||||
logger: Logger,
|
||||
currentRetries: number = 0,
|
||||
): Promise<void> {
|
||||
if (!config.enabled) {
|
||||
await db.updateExtractionStatus(memoryId, "skipped").catch(() => {});
|
||||
@@ -229,10 +276,28 @@ export async function runBackgroundExtraction(
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await extractEntities(text, config);
|
||||
const { result, transientFailure } = await extractEntities(text, config);
|
||||
|
||||
if (!result) {
|
||||
await db.updateExtractionStatus(memoryId, "failed");
|
||||
if (transientFailure) {
|
||||
// Transient failure (network/timeout) — leave as pending for retry
|
||||
const retries = currentRetries + 1;
|
||||
if (retries >= MAX_EXTRACTION_RETRIES) {
|
||||
logger.warn(
|
||||
`memory-neo4j: extraction permanently failed for ${memoryId.slice(0, 8)} after ${retries} transient retries`,
|
||||
);
|
||||
await db.updateExtractionStatus(memoryId, "failed", { incrementRetries: true });
|
||||
} else {
|
||||
logger.info(
|
||||
`memory-neo4j: extraction transient failure for ${memoryId.slice(0, 8)}, will retry (${retries}/${MAX_EXTRACTION_RETRIES})`,
|
||||
);
|
||||
// Keep status as "pending" but increment retry counter
|
||||
await db.updateExtractionStatus(memoryId, "pending", { incrementRetries: true });
|
||||
}
|
||||
} else {
|
||||
// Permanent failure (JSON parse, empty response, etc.)
|
||||
await db.updateExtractionStatus(memoryId, "failed");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -309,8 +374,21 @@ export async function runBackgroundExtraction(
|
||||
(result.category ? `, category=${result.category}` : ""),
|
||||
);
|
||||
} catch (err) {
|
||||
logger.warn(`memory-neo4j: extraction failed for ${memoryId.slice(0, 8)}: ${String(err)}`);
|
||||
await db.updateExtractionStatus(memoryId, "failed").catch(() => {});
|
||||
// Unexpected error during graph operations — treat as transient if retry budget remains
|
||||
const isTransient = isTransientError(err);
|
||||
if (isTransient && currentRetries + 1 < MAX_EXTRACTION_RETRIES) {
|
||||
logger.warn(
|
||||
`memory-neo4j: extraction transient error for ${memoryId.slice(0, 8)}, will retry: ${String(err)}`,
|
||||
);
|
||||
await db
|
||||
.updateExtractionStatus(memoryId, "pending", { incrementRetries: true })
|
||||
.catch(() => {});
|
||||
} else {
|
||||
logger.warn(`memory-neo4j: extraction failed for ${memoryId.slice(0, 8)}: ${String(err)}`);
|
||||
await db
|
||||
.updateExtractionStatus(memoryId, "failed", { incrementRetries: true })
|
||||
.catch(() => {});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -533,6 +611,14 @@ export async function runSleepCycle(
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Phase 3: Core Promotion (using pre-computed scores from Phase 2)
|
||||
//
|
||||
// Design note on staleness: The effective scores and Pareto threshold were
|
||||
// computed in Phase 2 and may be slightly stale by the time Phases 3/4 run.
|
||||
// This is acceptable because: (a) the sleep cycle is a background maintenance
|
||||
// task that runs infrequently (not concurrent with itself), (b) the scoring
|
||||
// formula is deterministic based on stored properties that change slowly, and
|
||||
// (c) promotion/demotion are reversible in the next cycle. The alternative
|
||||
// (re-querying scores per phase) adds latency without meaningful accuracy gain.
|
||||
// --------------------------------------------------------------------------
|
||||
if (!abortSignal?.aborted && paretoThreshold > 0) {
|
||||
onPhaseStart?.("promotion");
|
||||
@@ -631,7 +717,15 @@ export async function runSleepCycle(
|
||||
const chunk = pending.slice(i, i + EXTRACTION_CONCURRENCY);
|
||||
const outcomes = await Promise.allSettled(
|
||||
chunk.map((memory) =>
|
||||
runBackgroundExtraction(memory.id, memory.text, db, embeddings, config, logger),
|
||||
runBackgroundExtraction(
|
||||
memory.id,
|
||||
memory.text,
|
||||
db,
|
||||
embeddings,
|
||||
config,
|
||||
logger,
|
||||
memory.extractionRetries,
|
||||
),
|
||||
),
|
||||
);
|
||||
|
||||
|
||||
@@ -18,6 +18,8 @@ import { randomUUID } from "node:crypto";
|
||||
import { stringEnum } from "openclaw/plugin-sdk";
|
||||
import type { MemoryCategory, MemorySource } from "./schema.js";
|
||||
import {
|
||||
DEFAULT_EMBEDDING_DIMS,
|
||||
EMBEDDING_DIMENSIONS,
|
||||
MEMORY_CATEGORIES,
|
||||
memoryNeo4jConfigSchema,
|
||||
resolveExtractionConfig,
|
||||
@@ -46,6 +48,25 @@ const memoryNeo4jPlugin = {
|
||||
const extractionConfig = resolveExtractionConfig(cfg.extraction);
|
||||
const vectorDim = vectorDimsForModel(cfg.embedding.model);
|
||||
|
||||
// Warn on empty neo4j password (may be valid for some setups, but usually a misconfiguration)
|
||||
if (!cfg.neo4j.password) {
|
||||
api.logger.warn(
|
||||
"memory-neo4j: neo4j.password is empty — this may be intentional for passwordless setups, but verify your configuration",
|
||||
);
|
||||
}
|
||||
|
||||
// Warn when using default embedding dimensions for an unknown model
|
||||
const isKnownModel =
|
||||
cfg.embedding.model in EMBEDDING_DIMENSIONS ||
|
||||
Object.keys(EMBEDDING_DIMENSIONS).some((known) => cfg.embedding.model.startsWith(known));
|
||||
if (!isKnownModel) {
|
||||
api.logger.warn(
|
||||
`memory-neo4j: unknown embedding model "${cfg.embedding.model}" — using default ${DEFAULT_EMBEDDING_DIMS} dimensions. ` +
|
||||
`If your model outputs a different dimension, vector operations will fail. ` +
|
||||
`Known models: ${Object.keys(EMBEDDING_DIMENSIONS).join(", ")}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Create shared resources
|
||||
const db = new Neo4jMemoryClient(
|
||||
cfg.neo4j.uri,
|
||||
@@ -59,6 +80,7 @@ const memoryNeo4jPlugin = {
|
||||
cfg.embedding.model,
|
||||
cfg.embedding.provider,
|
||||
cfg.embedding.baseUrl,
|
||||
api.logger,
|
||||
);
|
||||
|
||||
api.logger.debug?.(
|
||||
@@ -499,23 +521,68 @@ const memoryNeo4jPlugin = {
|
||||
console.log(" Phase 7: Orphan Cleanup — Remove disconnected nodes\n");
|
||||
|
||||
try {
|
||||
// Validate sleep cycle CLI parameters before running
|
||||
const batchSize = opts.batchSize ? parseInt(opts.batchSize, 10) : undefined;
|
||||
const delay = opts.delay ? parseInt(opts.delay, 10) : undefined;
|
||||
const decayHalfLife = opts.decayHalfLife
|
||||
? parseInt(opts.decayHalfLife, 10)
|
||||
: undefined;
|
||||
const decayThreshold = opts.decayThreshold
|
||||
? parseFloat(opts.decayThreshold)
|
||||
: undefined;
|
||||
const pareto = opts.pareto ? parseFloat(opts.pareto) : undefined;
|
||||
const promotionMinAge = opts.promotionMinAge
|
||||
? parseInt(opts.promotionMinAge, 10)
|
||||
: undefined;
|
||||
|
||||
if (batchSize != null && (Number.isNaN(batchSize) || batchSize <= 0)) {
|
||||
console.error("Error: --batch-size must be greater than 0");
|
||||
process.exitCode = 1;
|
||||
return;
|
||||
}
|
||||
if (delay != null && (Number.isNaN(delay) || delay < 0)) {
|
||||
console.error("Error: --delay must be >= 0");
|
||||
process.exitCode = 1;
|
||||
return;
|
||||
}
|
||||
if (decayHalfLife != null && (Number.isNaN(decayHalfLife) || decayHalfLife <= 0)) {
|
||||
console.error("Error: --decay-half-life must be greater than 0");
|
||||
process.exitCode = 1;
|
||||
return;
|
||||
}
|
||||
if (
|
||||
decayThreshold != null &&
|
||||
(Number.isNaN(decayThreshold) || decayThreshold < 0 || decayThreshold > 1)
|
||||
) {
|
||||
console.error("Error: --decay-threshold must be between 0 and 1");
|
||||
process.exitCode = 1;
|
||||
return;
|
||||
}
|
||||
if (pareto != null && (Number.isNaN(pareto) || pareto < 0 || pareto > 1)) {
|
||||
console.error("Error: --pareto must be between 0 and 1");
|
||||
process.exitCode = 1;
|
||||
return;
|
||||
}
|
||||
if (
|
||||
promotionMinAge != null &&
|
||||
(Number.isNaN(promotionMinAge) || promotionMinAge < 0)
|
||||
) {
|
||||
console.error("Error: --promotion-min-age must be >= 0");
|
||||
process.exitCode = 1;
|
||||
return;
|
||||
}
|
||||
|
||||
await db.ensureInitialized();
|
||||
|
||||
const result = await runSleepCycle(db, embeddings, extractionConfig, api.logger, {
|
||||
agentId: opts.agent,
|
||||
dedupThreshold: opts.dedupThreshold ? parseFloat(opts.dedupThreshold) : undefined,
|
||||
paretoPercentile: opts.pareto ? parseFloat(opts.pareto) : undefined,
|
||||
promotionMinAgeDays: opts.promotionMinAge
|
||||
? parseInt(opts.promotionMinAge, 10)
|
||||
: undefined,
|
||||
decayRetentionThreshold: opts.decayThreshold
|
||||
? parseFloat(opts.decayThreshold)
|
||||
: undefined,
|
||||
decayBaseHalfLifeDays: opts.decayHalfLife
|
||||
? parseInt(opts.decayHalfLife, 10)
|
||||
: undefined,
|
||||
extractionBatchSize: opts.batchSize ? parseInt(opts.batchSize, 10) : undefined,
|
||||
extractionDelayMs: opts.delay ? parseInt(opts.delay, 10) : undefined,
|
||||
paretoPercentile: pareto,
|
||||
promotionMinAgeDays: promotionMinAge,
|
||||
decayRetentionThreshold: decayThreshold,
|
||||
decayBaseHalfLifeDays: decayHalfLife,
|
||||
extractionBatchSize: batchSize,
|
||||
extractionDelayMs: delay,
|
||||
onPhaseStart: (phase) => {
|
||||
const phaseNames = {
|
||||
dedup: "Phase 1: Deduplication",
|
||||
@@ -611,12 +678,45 @@ const memoryNeo4jPlugin = {
|
||||
const midSessionRefreshAt = new Map<string, number>();
|
||||
const MIN_TOKENS_SINCE_REFRESH = 10_000; // Only refresh if context grew by 10k+ tokens
|
||||
|
||||
// Track session timestamps for TTL-based cleanup. Without this, bootstrappedSessions
|
||||
// and midSessionRefreshAt leak entries for sessions that ended without an explicit
|
||||
// after_compaction event (e.g., normal session end on long-running gateways).
|
||||
const SESSION_TTL_MS = 24 * 60 * 60 * 1000; // 24 hours
|
||||
const sessionLastSeen = new Map<string, number>();
|
||||
let lastTtlSweep = Date.now();
|
||||
|
||||
/**
 * Drop tracking state for sessions not seen within SESSION_TTL_MS.
 *
 * Throttled: does nothing if the previous sweep ran less than 5 minutes ago.
 * For each expired key, the entry is removed from bootstrappedSessions,
 * midSessionRefreshAt and sessionLastSeen.
 */
function pruneStaleSessionEntries(): void {
  const sweepTime = Date.now();
  const elapsedSinceSweep = sweepTime - lastTtlSweep;
  // Only sweep at most once per 5 minutes to avoid overhead
  if (elapsedSinceSweep < 5 * 60 * 1000) {
    return;
  }
  lastTtlSweep = sweepTime;

  const oldestAllowed = sweepTime - SESSION_TTL_MS;
  // Removing entries from a Map while iterating over it is safe.
  sessionLastSeen.forEach((lastSeenAt, staleKey) => {
    if (lastSeenAt < oldestAllowed) {
      bootstrappedSessions.delete(staleKey);
      midSessionRefreshAt.delete(staleKey);
      sessionLastSeen.delete(staleKey);
    }
  });
}
|
||||
|
||||
/**
 * Record the current time as the session's last-seen timestamp, then give the
 * TTL sweep a chance to run.
 */
function touchSession(sessionKey: string): void {
  const seenAt = Date.now();
  sessionLastSeen.set(sessionKey, seenAt);
  pruneStaleSessionEntries();
}
|
||||
|
||||
// After compaction: clear bootstrap flag and mid-session refresh tracking
|
||||
if (cfg.coreMemory.enabled) {
|
||||
api.on("after_compaction", async (_event, ctx) => {
|
||||
if (ctx.sessionKey) {
|
||||
bootstrappedSessions.delete(ctx.sessionKey);
|
||||
midSessionRefreshAt.delete(ctx.sessionKey);
|
||||
sessionLastSeen.delete(ctx.sessionKey);
|
||||
api.logger.info?.(
|
||||
`memory-neo4j: cleared bootstrap/refresh flags for session ${ctx.sessionKey} after compaction`,
|
||||
);
|
||||
@@ -624,6 +724,18 @@ const memoryNeo4jPlugin = {
|
||||
});
|
||||
}
|
||||
|
||||
// Session end: clean up tracking entries for completed sessions.
|
||||
// The sessionId from session_end may match sessionKey in some implementations;
|
||||
// this provides best-effort cleanup alongside the TTL-based sweep above.
|
||||
api.on("session_end", async (_event, ctx) => {
|
||||
const key = ctx.sessionId;
|
||||
if (key) {
|
||||
bootstrappedSessions.delete(key);
|
||||
midSessionRefreshAt.delete(key);
|
||||
sessionLastSeen.delete(key);
|
||||
}
|
||||
});
|
||||
|
||||
// Mid-session core memory refresh: re-inject core memories when context grows past threshold
|
||||
// This counters the "lost in the middle" phenomenon by placing core memories closer to end of context
|
||||
const refreshThreshold = cfg.coreMemory.refreshAtContextPercent;
|
||||
@@ -666,6 +778,7 @@ const memoryNeo4jPlugin = {
|
||||
|
||||
// Record this refresh
|
||||
midSessionRefreshAt.set(sessionKey, event.estimatedUsedTokens);
|
||||
touchSession(sessionKey);
|
||||
|
||||
const content = coreMemories.map((m) => `- ${m.text}`).join("\n");
|
||||
api.logger.info?.(
|
||||
@@ -769,6 +882,7 @@ const memoryNeo4jPlugin = {
|
||||
if (coreMemories.length === 0) {
|
||||
if (sessionKey) {
|
||||
bootstrappedSessions.add(sessionKey);
|
||||
touchSession(sessionKey);
|
||||
}
|
||||
api.logger.debug?.(
|
||||
`memory-neo4j: no core memories found for agent=${agentId}, marking session as bootstrapped`,
|
||||
@@ -805,6 +919,7 @@ const memoryNeo4jPlugin = {
|
||||
|
||||
if (sessionKey) {
|
||||
bootstrappedSessions.add(sessionKey);
|
||||
touchSession(sessionKey);
|
||||
}
|
||||
// Log at info level when actually injecting, debug for skips
|
||||
api.logger.info?.(
|
||||
|
||||
@@ -77,14 +77,15 @@ describe("mid-session core memory refresh", () => {
|
||||
expect(config.coreMemory.refreshAtContextPercent).toBeUndefined();
|
||||
});
|
||||
|
||||
it("should reject refreshAtContextPercent over 100", async () => {
|
||||
it("should throw for refreshAtContextPercent over 100", async () => {
|
||||
const { memoryNeo4jConfigSchema } = await import("./config.js");
|
||||
const config = memoryNeo4jConfigSchema.parse({
|
||||
neo4j: { uri: "bolt://localhost:7687", user: "neo4j", password: "test" },
|
||||
embedding: { provider: "ollama" },
|
||||
coreMemory: { refreshAtContextPercent: 150 },
|
||||
});
|
||||
expect(config.coreMemory.refreshAtContextPercent).toBeUndefined();
|
||||
expect(() =>
|
||||
memoryNeo4jConfigSchema.parse({
|
||||
neo4j: { uri: "bolt://localhost:7687", user: "neo4j", password: "test" },
|
||||
embedding: { provider: "ollama" },
|
||||
coreMemory: { refreshAtContextPercent: 150 },
|
||||
}),
|
||||
).toThrow("coreMemory.refreshAtContextPercent must be between 1 and 100");
|
||||
});
|
||||
|
||||
it("should default to undefined when not specified", async () => {
|
||||
|
||||
@@ -211,32 +211,36 @@ export class Neo4jMemoryClient {
|
||||
|
||||
async storeMemory(input: StoreMemoryInput): Promise<string> {
|
||||
await this.ensureInitialized();
|
||||
const session = this.driver!.session();
|
||||
try {
|
||||
const now = new Date().toISOString();
|
||||
const result = await session.run(
|
||||
`CREATE (m:Memory {
|
||||
id: $id, text: $text, embedding: $embedding,
|
||||
importance: $importance, category: $category,
|
||||
source: $source, extractionStatus: $extractionStatus,
|
||||
agentId: $agentId, sessionKey: $sessionKey,
|
||||
createdAt: $createdAt, updatedAt: $updatedAt,
|
||||
retrievalCount: $retrievalCount, lastRetrievedAt: $lastRetrievedAt
|
||||
})
|
||||
RETURN m.id AS id`,
|
||||
{
|
||||
...input,
|
||||
sessionKey: input.sessionKey ?? null,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
retrievalCount: 0,
|
||||
lastRetrievedAt: null,
|
||||
},
|
||||
);
|
||||
return result.records[0].get("id") as string;
|
||||
} finally {
|
||||
await session.close();
|
||||
}
|
||||
return this.retryOnTransient(async () => {
|
||||
const session = this.driver!.session();
|
||||
try {
|
||||
const now = new Date().toISOString();
|
||||
const result = await session.run(
|
||||
`CREATE (m:Memory {
|
||||
id: $id, text: $text, embedding: $embedding,
|
||||
importance: $importance, category: $category,
|
||||
source: $source, extractionStatus: $extractionStatus,
|
||||
agentId: $agentId, sessionKey: $sessionKey,
|
||||
createdAt: $createdAt, updatedAt: $updatedAt,
|
||||
retrievalCount: $retrievalCount, lastRetrievedAt: $lastRetrievedAt,
|
||||
extractionRetries: $extractionRetries
|
||||
})
|
||||
RETURN m.id AS id`,
|
||||
{
|
||||
...input,
|
||||
sessionKey: input.sessionKey ?? null,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
retrievalCount: 0,
|
||||
lastRetrievedAt: null,
|
||||
extractionRetries: 0,
|
||||
},
|
||||
);
|
||||
return result.records[0].get("id") as string;
|
||||
} finally {
|
||||
await session.close();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async deleteMemory(id: string): Promise<boolean> {
|
||||
@@ -247,29 +251,32 @@ export class Neo4jMemoryClient {
|
||||
throw new Error(`Invalid memory ID format: ${id}`);
|
||||
}
|
||||
|
||||
const session = this.driver!.session();
|
||||
try {
|
||||
// First, decrement mentionCount on connected entities
|
||||
await session.run(
|
||||
`MATCH (m:Memory {id: $id})-[:MENTIONS]->(e:Entity)
|
||||
SET e.mentionCount = e.mentionCount - 1`,
|
||||
{ id },
|
||||
);
|
||||
return this.retryOnTransient(async () => {
|
||||
const session = this.driver!.session();
|
||||
try {
|
||||
// Decrement mentionCount on connected entities (floor at 0 to prevent
|
||||
// negative counts from parallel deletes racing on the same entity)
|
||||
await session.run(
|
||||
`MATCH (m:Memory {id: $id})-[:MENTIONS]->(e:Entity)
|
||||
SET e.mentionCount = CASE WHEN e.mentionCount > 0 THEN e.mentionCount - 1 ELSE 0 END`,
|
||||
{ id },
|
||||
);
|
||||
|
||||
// Then delete the memory with all its relationships
|
||||
const result = await session.run(
|
||||
`MATCH (m:Memory {id: $id})
|
||||
DETACH DELETE m
|
||||
RETURN count(*) AS deleted`,
|
||||
{ id },
|
||||
);
|
||||
// Then delete the memory with all its relationships
|
||||
const result = await session.run(
|
||||
`MATCH (m:Memory {id: $id})
|
||||
DETACH DELETE m
|
||||
RETURN count(*) AS deleted`,
|
||||
{ id },
|
||||
);
|
||||
|
||||
const deleted =
|
||||
result.records.length > 0 ? (result.records[0].get("deleted") as number) > 0 : false;
|
||||
return deleted;
|
||||
} finally {
|
||||
await session.close();
|
||||
}
|
||||
const deleted =
|
||||
result.records.length > 0 ? (result.records[0].get("deleted") as number) > 0 : false;
|
||||
return deleted;
|
||||
} finally {
|
||||
await session.close();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async countMemories(agentId?: string): Promise<number> {
|
||||
@@ -371,39 +378,43 @@ export class Neo4jMemoryClient {
|
||||
agentId?: string,
|
||||
): Promise<SearchSignalResult[]> {
|
||||
await this.ensureInitialized();
|
||||
const session = this.driver!.session();
|
||||
try {
|
||||
const agentFilter = agentId ? "AND node.agentId = $agentId" : "";
|
||||
const result = await session.run(
|
||||
`CALL db.index.vector.queryNodes('memory_embedding_index', $limit, $embedding)
|
||||
YIELD node, score
|
||||
WHERE score >= $minScore ${agentFilter}
|
||||
RETURN node.id AS id, node.text AS text, node.category AS category,
|
||||
node.importance AS importance, node.createdAt AS createdAt,
|
||||
score AS similarity
|
||||
ORDER BY score DESC`,
|
||||
{
|
||||
embedding,
|
||||
limit: neo4j.int(Math.floor(limit)),
|
||||
minScore,
|
||||
...(agentId ? { agentId } : {}),
|
||||
},
|
||||
);
|
||||
return await this.retryOnTransient(async () => {
|
||||
const session = this.driver!.session();
|
||||
try {
|
||||
const agentFilter = agentId ? "AND node.agentId = $agentId" : "";
|
||||
const result = await session.run(
|
||||
`CALL db.index.vector.queryNodes('memory_embedding_index', $limit, $embedding)
|
||||
YIELD node, score
|
||||
WHERE score >= $minScore ${agentFilter}
|
||||
RETURN node.id AS id, node.text AS text, node.category AS category,
|
||||
node.importance AS importance, node.createdAt AS createdAt,
|
||||
score AS similarity
|
||||
ORDER BY score DESC`,
|
||||
{
|
||||
embedding,
|
||||
limit: neo4j.int(Math.floor(limit)),
|
||||
minScore,
|
||||
...(agentId ? { agentId } : {}),
|
||||
},
|
||||
);
|
||||
|
||||
return result.records.map((r) => ({
|
||||
id: r.get("id") as string,
|
||||
text: r.get("text") as string,
|
||||
category: r.get("category") as string,
|
||||
importance: r.get("importance") as number,
|
||||
createdAt: String(r.get("createdAt") ?? ""),
|
||||
score: r.get("similarity") as number,
|
||||
}));
|
||||
return result.records.map((r) => ({
|
||||
id: r.get("id") as string,
|
||||
text: r.get("text") as string,
|
||||
category: r.get("category") as string,
|
||||
importance: r.get("importance") as number,
|
||||
createdAt: String(r.get("createdAt") ?? ""),
|
||||
score: r.get("similarity") as number,
|
||||
}));
|
||||
} finally {
|
||||
await session.close();
|
||||
}
|
||||
});
|
||||
} catch (err) {
|
||||
// Graceful degradation: return empty if vector index isn't ready
|
||||
// Graceful degradation: return empty if vector index isn't ready or all retries exhausted
|
||||
this.logger.warn(`memory-neo4j: vector search failed: ${String(err)}`);
|
||||
return [];
|
||||
} finally {
|
||||
await session.close();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -413,49 +424,58 @@ export class Neo4jMemoryClient {
|
||||
*/
|
||||
async bm25Search(query: string, limit: number, agentId?: string): Promise<SearchSignalResult[]> {
|
||||
await this.ensureInitialized();
|
||||
const session = this.driver!.session();
|
||||
const escaped = escapeLucene(query);
|
||||
if (!escaped.trim()) {
|
||||
return [];
|
||||
}
|
||||
|
||||
try {
|
||||
const escaped = escapeLucene(query);
|
||||
if (!escaped.trim()) {
|
||||
return [];
|
||||
}
|
||||
return await this.retryOnTransient(async () => {
|
||||
const session = this.driver!.session();
|
||||
try {
|
||||
const agentFilter = agentId ? "AND node.agentId = $agentId" : "";
|
||||
const result = await session.run(
|
||||
`CALL db.index.fulltext.queryNodes('memory_fulltext_index', $query)
|
||||
YIELD node, score
|
||||
WHERE true ${agentFilter}
|
||||
RETURN node.id AS id, node.text AS text, node.category AS category,
|
||||
node.importance AS importance, node.createdAt AS createdAt,
|
||||
score AS bm25Score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit`,
|
||||
{
|
||||
query: escaped,
|
||||
limit: neo4j.int(Math.floor(limit)),
|
||||
...(agentId ? { agentId } : {}),
|
||||
},
|
||||
);
|
||||
|
||||
const agentFilter = agentId ? "AND node.agentId = $agentId" : "";
|
||||
const result = await session.run(
|
||||
`CALL db.index.fulltext.queryNodes('memory_fulltext_index', $query)
|
||||
YIELD node, score
|
||||
WHERE true ${agentFilter}
|
||||
RETURN node.id AS id, node.text AS text, node.category AS category,
|
||||
node.importance AS importance, node.createdAt AS createdAt,
|
||||
score AS bm25Score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit`,
|
||||
{ query: escaped, limit: neo4j.int(Math.floor(limit)), ...(agentId ? { agentId } : {}) },
|
||||
);
|
||||
// Normalize BM25 scores to 0-1 range (divide by max)
|
||||
const records = result.records.map((r) => ({
|
||||
id: r.get("id") as string,
|
||||
text: r.get("text") as string,
|
||||
category: r.get("category") as string,
|
||||
importance: r.get("importance") as number,
|
||||
createdAt: String(r.get("createdAt") ?? ""),
|
||||
rawScore: r.get("bm25Score") as number,
|
||||
}));
|
||||
|
||||
// Normalize BM25 scores to 0-1 range (divide by max)
|
||||
const records = result.records.map((r) => ({
|
||||
id: r.get("id") as string,
|
||||
text: r.get("text") as string,
|
||||
category: r.get("category") as string,
|
||||
importance: r.get("importance") as number,
|
||||
createdAt: String(r.get("createdAt") ?? ""),
|
||||
rawScore: r.get("bm25Score") as number,
|
||||
}));
|
||||
|
||||
if (records.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const maxScore = records[0].rawScore || 1;
|
||||
return records.map((r) => ({
|
||||
...r,
|
||||
score: r.rawScore / maxScore,
|
||||
}));
|
||||
if (records.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const maxScore = records[0].rawScore || 1;
|
||||
return records.map((r) => ({
|
||||
...r,
|
||||
score: r.rawScore / maxScore,
|
||||
}));
|
||||
} finally {
|
||||
await session.close();
|
||||
}
|
||||
});
|
||||
} catch (err) {
|
||||
// Graceful degradation: return empty if all retries exhausted
|
||||
this.logger.warn(`memory-neo4j: BM25 search failed: ${String(err)}`);
|
||||
return [];
|
||||
} finally {
|
||||
await session.close();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -475,89 +495,94 @@ export class Neo4jMemoryClient {
|
||||
agentId?: string,
|
||||
): Promise<SearchSignalResult[]> {
|
||||
await this.ensureInitialized();
|
||||
const session = this.driver!.session();
|
||||
const escaped = escapeLucene(query);
|
||||
if (!escaped.trim()) {
|
||||
return [];
|
||||
}
|
||||
|
||||
try {
|
||||
const escaped = escapeLucene(query);
|
||||
if (!escaped.trim()) {
|
||||
return [];
|
||||
}
|
||||
return await this.retryOnTransient(async () => {
|
||||
const session = this.driver!.session();
|
||||
try {
|
||||
// Step 1: Find matching entities
|
||||
const entityResult = await session.run(
|
||||
`CALL db.index.fulltext.queryNodes('entity_fulltext_index', $query)
|
||||
YIELD node, score
|
||||
WHERE score >= 0.5
|
||||
RETURN node.id AS entityId, node.name AS name, score
|
||||
ORDER BY score DESC
|
||||
LIMIT 5`,
|
||||
{ query: escaped },
|
||||
);
|
||||
|
||||
// Step 1: Find matching entities
|
||||
const entityResult = await session.run(
|
||||
`CALL db.index.fulltext.queryNodes('entity_fulltext_index', $query)
|
||||
YIELD node, score
|
||||
WHERE score >= 0.5
|
||||
RETURN node.id AS entityId, node.name AS name, score
|
||||
ORDER BY score DESC
|
||||
LIMIT 5`,
|
||||
{ query: escaped },
|
||||
);
|
||||
const entityIds = entityResult.records.map((r) => r.get("entityId") as string);
|
||||
if (entityIds.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const entityIds = entityResult.records.map((r) => r.get("entityId") as string);
|
||||
if (entityIds.length === 0) {
|
||||
return [];
|
||||
}
|
||||
// Step 2 + 3: Direct mentions + 1-hop spreading activation
|
||||
const agentFilter = agentId ? "AND m.agentId = $agentId" : "";
|
||||
const result = await session.run(
|
||||
`UNWIND $entityIds AS eid
|
||||
// Direct: Entity ← MENTIONS ← Memory
|
||||
OPTIONAL MATCH (e:Entity {id: eid})<-[rm:MENTIONS]-(m:Memory)
|
||||
WHERE m IS NOT NULL ${agentFilter}
|
||||
WITH m, coalesce(rm.confidence, 1.0) AS directScore
|
||||
WHERE m IS NOT NULL
|
||||
|
||||
// Step 2 + 3: Direct mentions + 1-hop spreading activation
|
||||
const agentFilter = agentId ? "AND m.agentId = $agentId" : "";
|
||||
const result = await session.run(
|
||||
`UNWIND $entityIds AS eid
|
||||
// Direct: Entity ← MENTIONS ← Memory
|
||||
OPTIONAL MATCH (e:Entity {id: eid})<-[rm:MENTIONS]-(m:Memory)
|
||||
WHERE m IS NOT NULL ${agentFilter}
|
||||
WITH m, coalesce(rm.confidence, 1.0) AS directScore
|
||||
WHERE m IS NOT NULL
|
||||
RETURN m.id AS id, m.text AS text, m.category AS category,
|
||||
m.importance AS importance, m.createdAt AS createdAt,
|
||||
max(directScore) AS graphScore
|
||||
|
||||
RETURN m.id AS id, m.text AS text, m.category AS category,
|
||||
m.importance AS importance, m.createdAt AS createdAt,
|
||||
max(directScore) AS graphScore
|
||||
UNION
|
||||
|
||||
UNION
|
||||
UNWIND $entityIds AS eid
|
||||
// 1-hop: Entity → relationship → Entity ← MENTIONS ← Memory
|
||||
OPTIONAL MATCH (e:Entity {id: eid})-[r1:RELATED_TO|KNOWS|WORKS_AT|LIVES_AT|MARRIED_TO|PREFERS|DECIDED]-(e2:Entity)
|
||||
WHERE coalesce(r1.confidence, 0.7) >= $firingThreshold
|
||||
OPTIONAL MATCH (e2)<-[rm:MENTIONS]-(m:Memory)
|
||||
WHERE m IS NOT NULL ${agentFilter}
|
||||
WITH m, coalesce(r1.confidence, 0.7) * coalesce(rm.confidence, 1.0) AS hopScore
|
||||
WHERE m IS NOT NULL
|
||||
|
||||
UNWIND $entityIds AS eid
|
||||
// 1-hop: Entity → relationship → Entity ← MENTIONS ← Memory
|
||||
OPTIONAL MATCH (e:Entity {id: eid})-[r1:RELATED_TO|KNOWS|WORKS_AT|LIVES_AT|MARRIED_TO|PREFERS|DECIDED]-(e2:Entity)
|
||||
WHERE coalesce(r1.confidence, 0.7) >= $firingThreshold
|
||||
OPTIONAL MATCH (e2)<-[rm:MENTIONS]-(m:Memory)
|
||||
WHERE m IS NOT NULL ${agentFilter}
|
||||
WITH m, coalesce(r1.confidence, 0.7) * coalesce(rm.confidence, 1.0) AS hopScore
|
||||
WHERE m IS NOT NULL
|
||||
RETURN m.id AS id, m.text AS text, m.category AS category,
|
||||
m.importance AS importance, m.createdAt AS createdAt,
|
||||
max(hopScore) AS graphScore`,
|
||||
{ entityIds, firingThreshold, ...(agentId ? { agentId } : {}) },
|
||||
);
|
||||
|
||||
RETURN m.id AS id, m.text AS text, m.category AS category,
|
||||
m.importance AS importance, m.createdAt AS createdAt,
|
||||
max(hopScore) AS graphScore`,
|
||||
{ entityIds, firingThreshold, ...(agentId ? { agentId } : {}) },
|
||||
);
|
||||
// Deduplicate by id, keeping highest score
|
||||
const byId = new Map<string, SearchSignalResult>();
|
||||
for (const record of result.records) {
|
||||
const id = record.get("id") as string;
|
||||
if (!id) {
|
||||
continue;
|
||||
}
|
||||
const score = record.get("graphScore") as number;
|
||||
const existing = byId.get(id);
|
||||
if (!existing || score > existing.score) {
|
||||
byId.set(id, {
|
||||
id,
|
||||
text: record.get("text") as string,
|
||||
category: record.get("category") as string,
|
||||
importance: record.get("importance") as number,
|
||||
createdAt: String(record.get("createdAt") ?? ""),
|
||||
score,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Deduplicate by id, keeping highest score
|
||||
const byId = new Map<string, SearchSignalResult>();
|
||||
for (const record of result.records) {
|
||||
const id = record.get("id") as string;
|
||||
if (!id) {
|
||||
continue;
|
||||
return Array.from(byId.values())
|
||||
.toSorted((a, b) => b.score - a.score)
|
||||
.slice(0, limit);
|
||||
} finally {
|
||||
await session.close();
|
||||
}
|
||||
const score = record.get("graphScore") as number;
|
||||
const existing = byId.get(id);
|
||||
if (!existing || score > existing.score) {
|
||||
byId.set(id, {
|
||||
id,
|
||||
text: record.get("text") as string,
|
||||
category: record.get("category") as string,
|
||||
importance: record.get("importance") as number,
|
||||
createdAt: String(record.get("createdAt") ?? ""),
|
||||
score,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return Array.from(byId.values())
|
||||
.toSorted((a, b) => b.score - a.score)
|
||||
.slice(0, limit);
|
||||
});
|
||||
} catch (err) {
|
||||
// Graceful degradation: return empty if all retries exhausted
|
||||
this.logger.warn(`memory-neo4j: graph search failed: ${String(err)}`);
|
||||
return [];
|
||||
} finally {
|
||||
await session.close();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -570,28 +595,32 @@ export class Neo4jMemoryClient {
|
||||
limit: number = 1,
|
||||
): Promise<Array<{ id: string; text: string; score: number }>> {
|
||||
await this.ensureInitialized();
|
||||
const session = this.driver!.session();
|
||||
try {
|
||||
const result = await session.run(
|
||||
`CALL db.index.vector.queryNodes('memory_embedding_index', $limit, $embedding)
|
||||
YIELD node, score
|
||||
WHERE score >= $threshold
|
||||
RETURN node.id AS id, node.text AS text, score AS similarity
|
||||
ORDER BY score DESC`,
|
||||
{ embedding, limit: neo4j.int(limit), threshold },
|
||||
);
|
||||
return await this.retryOnTransient(async () => {
|
||||
const session = this.driver!.session();
|
||||
try {
|
||||
const result = await session.run(
|
||||
`CALL db.index.vector.queryNodes('memory_embedding_index', $limit, $embedding)
|
||||
YIELD node, score
|
||||
WHERE score >= $threshold
|
||||
RETURN node.id AS id, node.text AS text, score AS similarity
|
||||
ORDER BY score DESC`,
|
||||
{ embedding, limit: neo4j.int(limit), threshold },
|
||||
);
|
||||
|
||||
return result.records.map((r) => ({
|
||||
id: r.get("id") as string,
|
||||
text: r.get("text") as string,
|
||||
score: r.get("similarity") as number,
|
||||
}));
|
||||
return result.records.map((r) => ({
|
||||
id: r.get("id") as string,
|
||||
text: r.get("text") as string,
|
||||
score: r.get("similarity") as number,
|
||||
}));
|
||||
} finally {
|
||||
await session.close();
|
||||
}
|
||||
});
|
||||
} catch (err) {
|
||||
// If vector index isn't ready, return no duplicates (allow store)
|
||||
// If vector index isn't ready or all retries exhausted, return no duplicates (allow store)
|
||||
this.logger.debug?.(`memory-neo4j: similarity check failed: ${String(err)}`);
|
||||
return [];
|
||||
} finally {
|
||||
await session.close();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -609,18 +638,20 @@ export class Neo4jMemoryClient {
|
||||
}
|
||||
|
||||
await this.ensureInitialized();
|
||||
const session = this.driver!.session();
|
||||
try {
|
||||
await session.run(
|
||||
`UNWIND $ids AS memId
|
||||
MATCH (m:Memory {id: memId})
|
||||
SET m.retrievalCount = coalesce(m.retrievalCount, 0) + 1,
|
||||
m.lastRetrievedAt = $now`,
|
||||
{ ids: memoryIds, now: new Date().toISOString() },
|
||||
);
|
||||
} finally {
|
||||
await session.close();
|
||||
}
|
||||
return this.retryOnTransient(async () => {
|
||||
const session = this.driver!.session();
|
||||
try {
|
||||
await session.run(
|
||||
`UNWIND $ids AS memId
|
||||
MATCH (m:Memory {id: memId})
|
||||
SET m.retrievalCount = coalesce(m.retrievalCount, 0) + 1,
|
||||
m.lastRetrievedAt = $now`,
|
||||
{ ids: memoryIds, now: new Date().toISOString() },
|
||||
);
|
||||
} finally {
|
||||
await session.close();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -832,14 +863,22 @@ export class Neo4jMemoryClient {
|
||||
|
||||
/**
|
||||
* Update the extraction status of a Memory node.
|
||||
* Optionally increments the extractionRetries counter (for transient failure tracking).
|
||||
*/
|
||||
async updateExtractionStatus(id: string, status: ExtractionStatus): Promise<void> {
|
||||
async updateExtractionStatus(
|
||||
id: string,
|
||||
status: ExtractionStatus,
|
||||
options?: { incrementRetries?: boolean },
|
||||
): Promise<void> {
|
||||
await this.ensureInitialized();
|
||||
const session = this.driver!.session();
|
||||
try {
|
||||
const retryClause = options?.incrementRetries
|
||||
? ", m.extractionRetries = coalesce(m.extractionRetries, 0) + 1"
|
||||
: "";
|
||||
await session.run(
|
||||
`MATCH (m:Memory {id: $id})
|
||||
SET m.extractionStatus = $status, m.updatedAt = $now`,
|
||||
SET m.extractionStatus = $status, m.updatedAt = $now${retryClause}`,
|
||||
{ id, status, now: new Date().toISOString() },
|
||||
);
|
||||
} finally {
|
||||
@@ -847,6 +886,24 @@ export class Neo4jMemoryClient {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current extraction retry count for a memory.
|
||||
*/
|
||||
async getExtractionRetries(id: string): Promise<number> {
  // Make sure the client has finished its setup before opening a session.
  await this.ensureInitialized();
  const session = this.driver!.session();
  try {
    const { records } = await session.run(
      `MATCH (m:Memory {id: $id})
       RETURN coalesce(m.extractionRetries, 0) AS retries`,
      { id },
    );
    // When no Memory node matches there is no first record; report zero retries.
    const retries = records[0]?.get("retries") as number | undefined;
    return retries ?? 0;
  } finally {
    // Always release the session, whether the query succeeded or threw.
    await session.close();
  }
}
|
||||
|
||||
/**
|
||||
* List memories with pending extraction status.
|
||||
* Used by the sleep cycle to batch-process extractions.
|
||||
@@ -854,7 +911,7 @@ export class Neo4jMemoryClient {
|
||||
async listPendingExtractions(
|
||||
limit: number = 100,
|
||||
agentId?: string,
|
||||
): Promise<Array<{ id: string; text: string; agentId: string }>> {
|
||||
): Promise<Array<{ id: string; text: string; agentId: string; extractionRetries: number }>> {
|
||||
await this.ensureInitialized();
|
||||
const session = this.driver!.session();
|
||||
try {
|
||||
@@ -862,7 +919,8 @@ export class Neo4jMemoryClient {
|
||||
const result = await session.run(
|
||||
`MATCH (m:Memory)
|
||||
WHERE m.extractionStatus = 'pending' ${agentFilter}
|
||||
RETURN m.id AS id, m.text AS text, m.agentId AS agentId
|
||||
RETURN m.id AS id, m.text AS text, m.agentId AS agentId,
|
||||
coalesce(m.extractionRetries, 0) AS extractionRetries
|
||||
ORDER BY m.createdAt ASC
|
||||
LIMIT $limit`,
|
||||
{ limit: neo4j.int(limit), agentId },
|
||||
@@ -871,6 +929,7 @@ export class Neo4jMemoryClient {
|
||||
id: r.get("id") as string,
|
||||
text: r.get("text") as string,
|
||||
agentId: r.get("agentId") as string,
|
||||
extractionRetries: r.get("extractionRetries") as number,
|
||||
}));
|
||||
} finally {
|
||||
await session.close();
|
||||
@@ -976,14 +1035,17 @@ export class Neo4jMemoryClient {
|
||||
|
||||
let pairsFound = 0;
|
||||
for (const id of memoryData.keys()) {
|
||||
const similar = await session.run(
|
||||
`MATCH (src:Memory {id: $id})
|
||||
CALL db.index.vector.queryNodes('memory_embedding_index', $k, src.embedding)
|
||||
YIELD node, score
|
||||
WHERE node.id <> $id AND score >= $threshold
|
||||
RETURN node.id AS matchId`,
|
||||
{ id, k: neo4j.int(10), threshold },
|
||||
);
|
||||
// Retry individual vector queries on transient errors
|
||||
const similar = await this.retryOnTransient(async () => {
|
||||
return session.run(
|
||||
`MATCH (src:Memory {id: $id})
|
||||
CALL db.index.vector.queryNodes('memory_embedding_index', $k, src.embedding)
|
||||
YIELD node, score
|
||||
WHERE node.id <> $id AND score >= $threshold
|
||||
RETURN node.id AS matchId`,
|
||||
{ id, k: neo4j.int(10), threshold },
|
||||
);
|
||||
});
|
||||
|
||||
for (const r of similar.records) {
|
||||
const matchId = r.get("matchId") as string;
|
||||
@@ -1045,30 +1107,56 @@ export class Neo4jMemoryClient {
|
||||
const survivorId = memoryIds[survivorIdx];
|
||||
const toDelete = memoryIds.filter((_, i) => i !== survivorIdx);
|
||||
|
||||
const session = this.driver!.session();
|
||||
try {
|
||||
// Transfer MENTIONS relationships from deleted memories to survivor
|
||||
await session.run(
|
||||
`UNWIND $toDelete AS deadId
|
||||
MATCH (dead:Memory {id: deadId})-[r:MENTIONS]->(e:Entity)
|
||||
MATCH (survivor:Memory {id: $survivorId})
|
||||
MERGE (survivor)-[:MENTIONS]->(e)
|
||||
DELETE r`,
|
||||
{ toDelete, survivorId },
|
||||
);
|
||||
return this.retryOnTransient(async () => {
|
||||
const session = this.driver!.session();
|
||||
try {
|
||||
// Optimistic lock: verify all cluster members still exist before merging.
|
||||
// New memories added or deleted between findDuplicateClusters() and this
|
||||
// call could invalidate the cluster. Skip if any member is missing.
|
||||
const verifyResult = await session.run(
|
||||
`UNWIND $ids AS memId
|
||||
OPTIONAL MATCH (m:Memory {id: memId})
|
||||
RETURN memId, m IS NOT NULL AS exists`,
|
||||
{ ids: memoryIds },
|
||||
);
|
||||
|
||||
// Delete the duplicate memories
|
||||
await session.run(
|
||||
`UNWIND $toDelete AS deadId
|
||||
MATCH (m:Memory {id: deadId})
|
||||
DETACH DELETE m`,
|
||||
{ toDelete },
|
||||
);
|
||||
const missingIds: string[] = [];
|
||||
for (const r of verifyResult.records) {
|
||||
if (!r.get("exists")) {
|
||||
missingIds.push(r.get("memId") as string);
|
||||
}
|
||||
}
|
||||
|
||||
return { survivorId, deletedCount: toDelete.length };
|
||||
} finally {
|
||||
await session.close();
|
||||
}
|
||||
if (missingIds.length > 0) {
|
||||
this.logger.warn(
|
||||
`memory-neo4j: skipping cluster merge — ${missingIds.length} member(s) no longer exist: ${missingIds.join(", ")}`,
|
||||
);
|
||||
return { survivorId, deletedCount: 0 };
|
||||
}
|
||||
|
||||
// Transfer MENTIONS relationships from deleted memories to survivor
|
||||
await session.run(
|
||||
`UNWIND $toDelete AS deadId
|
||||
MATCH (dead:Memory {id: deadId})-[r:MENTIONS]->(e:Entity)
|
||||
MATCH (survivor:Memory {id: $survivorId})
|
||||
MERGE (survivor)-[:MENTIONS]->(e)
|
||||
DELETE r`,
|
||||
{ toDelete, survivorId },
|
||||
);
|
||||
|
||||
// Delete the duplicate memories
|
||||
await session.run(
|
||||
`UNWIND $toDelete AS deadId
|
||||
MATCH (m:Memory {id: deadId})
|
||||
DETACH DELETE m`,
|
||||
{ toDelete },
|
||||
);
|
||||
|
||||
return { survivorId, deletedCount: toDelete.length };
|
||||
} finally {
|
||||
await session.close();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -1160,11 +1248,12 @@ export class Neo4jMemoryClient {
|
||||
await this.ensureInitialized();
|
||||
const session = this.driver!.session();
|
||||
try {
|
||||
// Decrement mention counts on connected entities
|
||||
// Decrement mention counts on connected entities (floor at 0 to prevent
|
||||
// negative counts from parallel prune/delete operations racing on the same entity)
|
||||
await session.run(
|
||||
`UNWIND $ids AS memId
|
||||
MATCH (m:Memory {id: memId})-[:MENTIONS]->(e:Entity)
|
||||
SET e.mentionCount = e.mentionCount - 1`,
|
||||
SET e.mentionCount = CASE WHEN e.mentionCount > 0 THEN e.mentionCount - 1 ELSE 0 END`,
|
||||
{ ids: memoryIds },
|
||||
);
|
||||
|
||||
@@ -1581,7 +1670,7 @@ export class Neo4jMemoryClient {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Retry an operation on transient Neo4j errors (deadlocks, etc.)
|
||||
* Retry an operation on transient Neo4j errors (deadlocks, connection blips, etc.)
|
||||
* with exponential backoff. Adapted from ontology project.
|
||||
*/
|
||||
private async retryOnTransient<T>(
|
||||
@@ -1595,14 +1684,24 @@ export class Neo4jMemoryClient {
|
||||
return await fn();
|
||||
} catch (err) {
|
||||
lastError = err;
|
||||
// Check for Neo4j transient errors
|
||||
// Check for Neo4j transient errors (deadlocks, connection blips, service unavailable)
|
||||
const errCode =
|
||||
err instanceof Error
|
||||
? ((err as unknown as Record<string, unknown>).code as string | undefined)
|
||||
: undefined;
|
||||
const isTransient =
|
||||
err instanceof Error &&
|
||||
(err.message.includes("DeadlockDetected") ||
|
||||
err.message.includes("TransientError") ||
|
||||
err.message.includes("ServiceUnavailable") ||
|
||||
err.message.includes("SessionExpired") ||
|
||||
err.message.includes("ConnectionRefused") ||
|
||||
err.message.includes("connection terminated") ||
|
||||
(err.constructor.name === "Neo4jError" &&
|
||||
(err as unknown as Record<string, unknown>).code ===
|
||||
"Neo.TransientError.Transaction.DeadlockDetected"));
|
||||
typeof errCode === "string" &&
|
||||
(errCode.startsWith("Neo.TransientError.") ||
|
||||
errCode === "ServiceUnavailable" ||
|
||||
errCode === "SessionExpired")));
|
||||
|
||||
if (!isTransient || attempt >= maxAttempts - 1) {
|
||||
throw err;
|
||||
|
||||
200
extensions/memory-neo4j/schema.test.ts
Normal file
200
extensions/memory-neo4j/schema.test.ts
Normal file
@@ -0,0 +1,200 @@
|
||||
/**
|
||||
* Tests for schema.ts — Schema Validation & Helpers.
|
||||
*
|
||||
* Tests the exported pure functions: escapeLucene(), validateRelationshipType(),
|
||||
* and the exported constants and types.
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from "vitest";
|
||||
import {
|
||||
escapeLucene,
|
||||
validateRelationshipType,
|
||||
ALLOWED_RELATIONSHIP_TYPES,
|
||||
MEMORY_CATEGORIES,
|
||||
ENTITY_TYPES,
|
||||
} from "./schema.js";
|
||||
|
||||
// ============================================================================
|
||||
// escapeLucene()
|
||||
// ============================================================================
|
||||
|
||||
describe("escapeLucene", () => {
  // Each row is [test title, raw input, expected escaped output].
  // A backslash is expected in front of every Lucene special character.
  const cases: ReadonlyArray<readonly [string, string, string]> = [
    ["should return normal text unchanged", "hello world", "hello world"],
    ["should return empty string unchanged", "", ""],
    ["should escape plus sign", "a+b", "a\\+b"],
    ["should escape minus sign", "a-b", "a\\-b"],
    ["should escape ampersand", "a&b", "a\\&b"],
    ["should escape pipe", "a|b", "a\\|b"],
    ["should escape exclamation mark", "hello!", "hello\\!"],
    ["should escape parentheses", "(group)", "\\(group\\)"],
    ["should escape curly braces", "{range}", "\\{range\\}"],
    ["should escape square brackets", "[range]", "\\[range\\]"],
    ["should escape caret", "boost^2", "boost\\^2"],
    ["should escape double quotes", '"exact"', '\\"exact\\"'],
    ["should escape tilde", "fuzzy~", "fuzzy\\~"],
    ["should escape asterisk", "wild*", "wild\\*"],
    ["should escape question mark", "single?", "single\\?"],
    ["should escape colon", "field:value", "field\\:value"],
    ["should escape backslash", "path\\file", "path\\\\file"],
    ["should escape forward slash", "a/b", "a\\/b"],
    [
      "should escape multiple special characters in one string",
      "(a+b) && c*",
      "\\(a\\+b\\) \\&\\& c\\*",
    ],
    [
      "should handle mixed normal and special characters",
      "hello world! [test]",
      "hello world\\! \\[test\\]",
    ],
    ["should handle strings with only special characters", "+-", "\\+\\-"],
  ];

  for (const [title, raw, escaped] of cases) {
    it(title, () => {
      expect(escapeLucene(raw)).toBe(escaped);
    });
  }
});
|
||||
|
||||
// ============================================================================
|
||||
// validateRelationshipType()
|
||||
// ============================================================================
|
||||
|
||||
describe("validateRelationshipType", () => {
  describe("valid relationship types", () => {
    // One generated test per relationship type that must be accepted.
    const acceptedTypes = [
      "WORKS_AT",
      "LIVES_AT",
      "KNOWS",
      "MARRIED_TO",
      "PREFERS",
      "DECIDED",
      "RELATED_TO",
    ];

    for (const relType of acceptedTypes) {
      it(`should accept ${relType}`, () => {
        expect(validateRelationshipType(relType)).toBe(true);
      });
    }

    it("should accept all ALLOWED_RELATIONSHIP_TYPES", () => {
      // Every member of the exported allow-list must validate.
      ALLOWED_RELATIONSHIP_TYPES.forEach((relType) => {
        expect(validateRelationshipType(relType)).toBe(true);
      });
    });
  });

  describe("invalid relationship types", () => {
    // Inputs that must be rejected: unknown names, wrong casing, padding,
    // and a string shaped like a Cypher injection attempt.
    const rejected: ReadonlyArray<{ title: string; candidate: string }> = [
      { title: "should reject unknown relationship type", candidate: "HATES" },
      { title: "should reject empty string", candidate: "" },
      { title: "should be case sensitive — lowercase is rejected", candidate: "works_at" },
      { title: "should be case sensitive — mixed case is rejected", candidate: "Works_At" },
      { title: "should reject types with extra whitespace", candidate: " WORKS_AT " },
      {
        title: "should reject potential Cypher injection",
        candidate: "WORKS_AT]->(n) DELETE n//",
      },
    ];

    for (const { title, candidate } of rejected) {
      it(title, () => {
        expect(validateRelationshipType(candidate)).toBe(false);
      });
    }
  });
});
|
||||
|
||||
// ============================================================================
|
||||
// Exported Constants
|
||||
// ============================================================================
|
||||
|
||||
describe("exported constants", () => {
  it("MEMORY_CATEGORIES should contain expected categories", () => {
    // Checked in a fixed order so the first missing category is the one reported.
    for (const category of ["preference", "fact", "decision", "entity", "other"]) {
      expect(MEMORY_CATEGORIES).toContain(category);
    }
  });

  it("ENTITY_TYPES should contain expected types", () => {
    for (const entityType of ["person", "organization", "location", "event", "concept"]) {
      expect(ENTITY_TYPES).toContain(entityType);
    }
  });

  it("ALLOWED_RELATIONSHIP_TYPES should be a Set", () => {
    const allowed = ALLOWED_RELATIONSHIP_TYPES;
    expect(allowed).toBeInstanceOf(Set);
    // The allow-list is expected to hold exactly seven relationship types.
    expect(allowed.size).toBe(7);
  });
});
|
||||
400
extensions/memory-neo4j/search.test.ts
Normal file
400
extensions/memory-neo4j/search.test.ts
Normal file
@@ -0,0 +1,400 @@
|
||||
/**
|
||||
* Tests for search.ts — Hybrid Search & RRF Fusion.
|
||||
*
|
||||
* Tests the exported pure logic: classifyQuery() and getAdaptiveWeights().
|
||||
* Note: fuseWithConfidenceRRF() is not exported (private module-level function)
|
||||
* and is tested indirectly through hybridSearch().
|
||||
* hybridSearch() is tested with mocked Neo4j client and Embeddings.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
import type { SearchSignalResult } from "./schema.js";
|
||||
import { classifyQuery, getAdaptiveWeights, hybridSearch } from "./search.js";
|
||||
|
||||
// ============================================================================
|
||||
// classifyQuery()
|
||||
// ============================================================================
|
||||
|
||||
describe("classifyQuery", () => {
  // Each case is [test title, query text, expected classification].
  type QueryCase = readonly [string, string, string];

  const groups: ReadonlyArray<{ heading: string; cases: ReadonlyArray<QueryCase> }> = [
    {
      heading: "short queries (1-2 words)",
      cases: [
        ["should classify a single word as 'short'", "dogs", "short"],
        ["should classify two words as 'short'", "best coffee", "short"],
        [
          "should classify a single capitalized word as 'short' (word count takes priority)",
          "TypeScript",
          "short",
        ],
        ["should handle whitespace-padded short queries", " hello ", "short"],
      ],
    },
    {
      heading: "entity queries (proper nouns)",
      cases: [
        ["should classify query with proper noun as 'entity'", "tell me about Tarun", "entity"],
        [
          "should classify query with organization name as 'entity'",
          "what about Google",
          "entity",
        ],
        ["should classify question patterns targeting entities", "who is the CEO", "entity"],
        ["should classify 'where is' patterns as entity", "where is the office", "entity"],
        ["should classify 'what does' patterns as entity", "what does she do", "entity"],
        // Three words, no proper noun and no question pattern, so the
        // expected classification is the fallback one.
        [
          "should not treat common words (The, Is, etc.) as entity indicators",
          "this is fine",
          "default",
        ],
      ],
    },
    {
      heading: "long queries (5+ words)",
      cases: [
        ["should classify a 5-word query as 'long'", "what is the best framework", "long"],
        [
          "should classify a longer sentence as 'long'",
          "tell me about the history of programming languages",
          "long",
        ],
        [
          "should classify a verbose question as 'long'",
          "how do i configure the database connection",
          "long",
        ],
      ],
    },
    {
      heading: "default queries (3-4 words, no entities)",
      cases: [
        ["should classify a 3-word lowercase query as 'default'", "my favorite color", "default"],
        [
          "should classify a 4-word lowercase query as 'default'",
          "best practices for testing",
          "default",
        ],
      ],
    },
    {
      heading: "edge cases",
      cases: [
        // An empty query is expected to count as a single (empty) word.
        ["should handle empty string", "", "short"],
        // A whitespace-only query is expected to behave like the empty one.
        ["should handle only whitespace", " ", "short"],
      ],
    },
  ];

  for (const { heading, cases } of groups) {
    describe(heading, () => {
      for (const [title, query, expected] of cases) {
        it(title, () => {
          expect(classifyQuery(query)).toBe(expected);
        });
      }
    });
  }
});
|
||||
|
||||
// ============================================================================
|
||||
// getAdaptiveWeights()
|
||||
// ============================================================================
|
||||
|
||||
describe("getAdaptiveWeights", () => {
  // The function returns a tuple ordered as [vector, bm25, graph].
  describe("with graph enabled", () => {
    it("should boost BM25 for short queries", () => {
      const [vectorW, bm25W, graphW] = getAdaptiveWeights("short", true);
      expect(bm25W).toBeGreaterThan(vectorW);
      expect(vectorW).toBe(0.8);
      expect(bm25W).toBe(1.2);
      expect(graphW).toBe(1.0);
    });

    it("should boost graph for entity queries", () => {
      const [vectorW, bm25W, graphW] = getAdaptiveWeights("entity", true);
      expect(graphW).toBeGreaterThan(vectorW);
      expect(graphW).toBeGreaterThan(bm25W);
      expect(vectorW).toBe(0.8);
      expect(bm25W).toBe(1.0);
      expect(graphW).toBe(1.3);
    });

    it("should boost vector for long queries", () => {
      const [vectorW, bm25W, graphW] = getAdaptiveWeights("long", true);
      expect(vectorW).toBeGreaterThan(bm25W);
      expect(vectorW).toBeGreaterThan(graphW);
      expect(vectorW).toBe(1.2);
      expect(bm25W).toBe(0.7);
      // Compared approximately rather than with strict equality.
      expect(graphW).toBeCloseTo(0.8);
    });

    it("should return balanced weights for default queries", () => {
      const [vectorW, bm25W, graphW] = getAdaptiveWeights("default", true);
      for (const weight of [vectorW, bm25W, graphW]) {
        expect(weight).toBe(1.0);
      }
    });
  });

  describe("with graph disabled", () => {
    it("should zero-out graph weight for short queries", () => {
      const [vectorW, bm25W, graphW] = getAdaptiveWeights("short", false);
      expect(graphW).toBe(0);
      // The other two weights keep their short-query values.
      expect(vectorW).toBe(0.8);
      expect(bm25W).toBe(1.2);
    });

    // The remaining query types only need the graph weight checked.
    for (const queryType of ["entity", "long", "default"] as const) {
      it(`should zero-out graph weight for ${queryType} queries`, () => {
        const graphW = getAdaptiveWeights(queryType, false)[2];
        expect(graphW).toBe(0);
      });
    }
  });
});
|
||||
|
||||
// ============================================================================
|
||||
// hybridSearch() — integration test with mocked dependencies
|
||||
// ============================================================================
|
||||
|
||||
describe("hybridSearch", () => {
  // Stand-ins for the database client and the embeddings provider.
  const mockDb = {
    vectorSearch: vi.fn(),
    bm25Search: vi.fn(),
    graphSearch: vi.fn(),
    recordRetrievals: vi.fn(),
  };
  const mockEmbeddings = { embed: vi.fn(), embedBatch: vi.fn() };

  // Everything hybridSearch accepts after the db and embeddings arguments.
  type SearchTail =
    Parameters<typeof hybridSearch> extends [unknown, unknown, ...infer Tail] ? Tail : never;

  /** Invoke hybridSearch against the mocks, forwarding the remaining arguments untouched. */
  const runSearch = (...tail: SearchTail) =>
    hybridSearch(mockDb as never, mockEmbeddings as never, ...tail);

  /** Build a signal result with sensible defaults; any field can be overridden. */
  const makeSignalResult = (overrides: Partial<SearchSignalResult> = {}): SearchSignalResult => ({
    id: "mem-1",
    text: "Test memory",
    category: "fact",
    importance: 0.7,
    createdAt: "2025-01-01T00:00:00Z",
    score: 0.9,
    ...overrides,
  });

  /** Stub the vector and BM25 signals; the graph signal is left unstubbed. */
  const stubSignals = (vector: SearchSignalResult[], bm25: SearchSignalResult[]): void => {
    mockDb.vectorSearch.mockResolvedValue(vector);
    mockDb.bm25Search.mockResolvedValue(bm25);
  };

  beforeEach(() => {
    // Start every test from clean mocks, then restore the two shared defaults.
    vi.resetAllMocks();
    mockEmbeddings.embed.mockResolvedValue([0.1, 0.2, 0.3]);
    mockDb.recordRetrievals.mockResolvedValue(undefined);
  });

  it("should return empty array when no signals return results", async () => {
    stubSignals([], []);

    const found = await runSearch("test query", 5, "agent-1", false);

    expect(found).toEqual([]);
    expect(mockDb.recordRetrievals).not.toHaveBeenCalled();
  });

  it("should fuse results from vector and BM25 signals", async () => {
    stubSignals(
      [makeSignalResult({ id: "mem-1", score: 0.95, text: "Vector match" })],
      [makeSignalResult({ id: "mem-2", score: 0.8, text: "BM25 match" })],
    );

    const found = await runSearch("test query", 5, "agent-1", false);

    expect(found.length).toBe(2);
    // Fused scores are expected to land in the 0-1 range...
    expect(found[0].score).toBeLessThanOrEqual(1);
    expect(found[0].score).toBeGreaterThanOrEqual(0);
    // ...with the top-ranked result sitting exactly at 1.
    expect(found[0].score).toBe(1);
  });

  it("should deduplicate across signals (same memory in multiple signals)", async () => {
    const shared = makeSignalResult({ id: "mem-shared", score: 0.9 });
    stubSignals([shared], [{ ...shared, score: 0.85 }]);

    const found = await runSearch("test query", 5, "agent-1", false);

    // The same id coming from both signals must collapse into one entry.
    expect(found.length).toBe(1);
    expect(found[0].id).toBe("mem-shared");
    // As the only result, its normalized score is expected to be 1.
    expect(found[0].score).toBe(1);
  });

  it("should include graph signal when graphEnabled is true", async () => {
    stubSignals([], []);
    mockDb.graphSearch.mockResolvedValue([
      makeSignalResult({ id: "mem-graph", score: 0.7, text: "Graph result" }),
    ]);

    const found = await runSearch("tell me about Tarun", 5, "agent-1", true);

    expect(mockDb.graphSearch).toHaveBeenCalled();
    expect(found.length).toBe(1);
    expect(found[0].id).toBe("mem-graph");
  });

  it("should not call graphSearch when graphEnabled is false", async () => {
    stubSignals([], []);

    await runSearch("test query", 5, "agent-1", false);

    expect(mockDb.graphSearch).not.toHaveBeenCalled();
  });

  it("should limit results to the requested count", async () => {
    // Ten candidates with strictly decreasing scores.
    const candidates: SearchSignalResult[] = [];
    for (let i = 0; i < 10; i++) {
      candidates.push(makeSignalResult({ id: `mem-${i}`, score: 0.9 - i * 0.05 }));
    }
    stubSignals(candidates, []);

    const found = await runSearch("test query", 3, "agent-1", false);

    expect(found.length).toBe(3);
  });

  it("should record retrieval events for returned results", async () => {
    stubSignals([makeSignalResult({ id: "mem-1" }), makeSignalResult({ id: "mem-2" })], []);

    await runSearch("test query", 5, "agent-1", false);

    expect(mockDb.recordRetrievals).toHaveBeenCalledWith(["mem-1", "mem-2"]);
  });

  it("should silently handle recordRetrievals failure", async () => {
    stubSignals([makeSignalResult({ id: "mem-1" })], []);
    mockDb.recordRetrievals.mockRejectedValue(new Error("DB connection lost"));

    // The rejected bookkeeping call must not surface to the caller.
    const found = await runSearch("test query", 5, "agent-1", false);

    expect(found.length).toBe(1);
  });

  it("should normalize scores to 0-1 range", async () => {
    stubSignals(
      [makeSignalResult({ id: "mem-1", score: 0.95 }), makeSignalResult({ id: "mem-2", score: 0.5 })],
      [],
    );

    const found = await runSearch("test query", 5, "agent-1", false);

    for (const { score } of found) {
      expect(score).toBeGreaterThanOrEqual(0);
      expect(score).toBeLessThanOrEqual(1);
    }
  });

  it("should use candidateMultiplier option", async () => {
    stubSignals([], []);

    await runSearch("test query", 5, "agent-1", false, { candidateMultiplier: 8 });

    // A limit of 5 with a multiplier of 8 should request 40 candidates per signal.
    expect(mockDb.vectorSearch).toHaveBeenCalledWith(expect.any(Array), 40, 0.1, "agent-1");
    expect(mockDb.bm25Search).toHaveBeenCalledWith("test query", 40, "agent-1");
  });

  it("should pass default agentId when not specified", async () => {
    stubSignals([], []);

    // Only the query is supplied; every other argument falls back to its default.
    await runSearch("test query");

    expect(mockDb.vectorSearch).toHaveBeenCalledWith(
      expect.any(Array),
      expect.any(Number),
      0.1,
      "default",
    );
  });
});
|
||||
Reference in New Issue
Block a user