feat(memory): add Ollama embedding provider

- Add Ollama as embedding provider for memory search (provider/fallback)
- Keep main state (Mistral) and support both in types, schema, runtime
- Add embeddings-ollama.ts and tests

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
nico-hoff
2026-02-25 10:08:56 +01:00
committed by Gustavo Madeira Santana
parent 4ba5937ef9
commit 00317343e7
9 changed files with 153 additions and 17 deletions

View File

@@ -9,7 +9,7 @@ export type ResolvedMemorySearchConfig = {
enabled: boolean;
sources: Array<"memory" | "sessions">;
extraPaths: string[];
provider: "openai" | "local" | "gemini" | "voyage" | "mistral" | "auto";
provider: "openai" | "local" | "gemini" | "voyage" | "mistral" | "ollama" | "auto";
remote?: {
baseUrl?: string;
apiKey?: string;
@@ -25,7 +25,7 @@ export type ResolvedMemorySearchConfig = {
experimental: {
sessionMemory: boolean;
};
fallback: "openai" | "gemini" | "local" | "voyage" | "mistral" | "none";
fallback: "openai" | "gemini" | "local" | "voyage" | "mistral" | "ollama" | "none";
model: string;
local: {
modelPath?: string;

View File

@@ -724,7 +724,7 @@ export const FIELD_HELP: Record<string, string> = {
"agents.defaults.memorySearch.experimental.sessionMemory":
"Indexes session transcripts into memory search so responses can reference prior chat turns. Keep this off unless transcript recall is needed, because indexing cost and storage usage both increase.",
"agents.defaults.memorySearch.provider":
'Selects the embedding backend used to build/query memory vectors: "openai", "gemini", "voyage", "mistral", or "local". Keep your most reliable provider here and configure fallback for resilience.',
'Selects the embedding backend used to build/query memory vectors: "openai", "gemini", "voyage", "mistral", "ollama", or "local". Keep your most reliable provider here and configure fallback for resilience.',
"agents.defaults.memorySearch.model":
"Embedding model override used by the selected memory provider when a non-default model is required. Set this only when you need explicit recall quality/cost tuning beyond provider defaults.",
"agents.defaults.memorySearch.remote.baseUrl":
@@ -746,7 +746,7 @@ export const FIELD_HELP: Record<string, string> = {
"agents.defaults.memorySearch.local.modelPath":
"Specifies the local embedding model source for local memory search, such as a GGUF file path or `hf:` URI. Use this only when provider is `local`, and verify model compatibility before large index rebuilds.",
"agents.defaults.memorySearch.fallback":
'Backup provider used when primary embeddings fail: "openai", "gemini", "voyage", "mistral", "local", or "none". Set a real fallback for production reliability; use "none" only if you prefer explicit failures.',
'Backup provider used when primary embeddings fail: "openai", "gemini", "voyage", "mistral", "ollama", "local", or "none". Set a real fallback for production reliability; use "none" only if you prefer explicit failures.',
"agents.defaults.memorySearch.store.path":
"Sets where the SQLite memory index is stored on disk for each agent. Keep the default `~/.openclaw/memory/{agentId}.sqlite` unless you need custom storage placement or backup policy alignment.",
"agents.defaults.memorySearch.store.vector.enabled":

View File

@@ -324,7 +324,7 @@ export type MemorySearchConfig = {
sessionMemory?: boolean;
};
/** Embedding provider mode. */
provider?: "openai" | "gemini" | "local" | "voyage" | "mistral";
provider?: "openai" | "gemini" | "local" | "voyage" | "mistral" | "ollama";
remote?: {
baseUrl?: string;
apiKey?: string;
@@ -343,7 +343,7 @@ export type MemorySearchConfig = {
};
};
/** Fallback behavior when embeddings fail. */
fallback?: "openai" | "gemini" | "local" | "voyage" | "mistral" | "none";
fallback?: "openai" | "gemini" | "local" | "voyage" | "mistral" | "ollama" | "none";
/** Embedding model id (remote) or alias (local). */
model?: string;
/** Local embedding settings (node-llama-cpp). */

View File

@@ -557,6 +557,7 @@ export const MemorySearchSchema = z
z.literal("gemini"),
z.literal("voyage"),
z.literal("mistral"),
z.literal("ollama"),
])
.optional(),
remote: z
@@ -584,6 +585,7 @@ export const MemorySearchSchema = z
z.literal("local"),
z.literal("voyage"),
z.literal("mistral"),
z.literal("ollama"),
z.literal("none"),
])
.optional(),

View File

@@ -0,0 +1,31 @@
import { describe, it, expect, vi } from "vitest";
import type { OpenClawConfig } from "../config/config.js";
import { createOllamaEmbeddingProvider } from "./embeddings-ollama.js";
describe("embeddings-ollama", () => {
it("calls /api/embeddings and returns normalized vectors", async () => {
const fetchMock = vi.fn(
async () =>
new Response(JSON.stringify({ embedding: [3, 4] }), {
status: 200,
headers: { "content-type": "application/json" },
}),
);
// @ts-expect-error test override
globalThis.fetch = fetchMock;
const { provider } = await createOllamaEmbeddingProvider({
config: {} as OpenClawConfig,
provider: "ollama",
model: "nomic-embed-text",
fallback: "none",
remote: { baseUrl: "http://127.0.0.1:11434" },
});
const v = await provider.embedQuery("hi");
expect(fetchMock).toHaveBeenCalledTimes(1);
// normalized [3,4] => [0.6,0.8]
expect(v[0]).toBeCloseTo(0.6, 5);
expect(v[1]).toBeCloseTo(0.8, 5);
});
});

View File

@@ -0,0 +1,72 @@
import { formatErrorMessage } from "../infra/errors.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
export type OllamaEmbeddingClient = {
embedBatch: (texts: string[]) => Promise<number[][]>;
};
function sanitizeAndNormalizeEmbedding(vec: number[]): number[] {
const sanitized = vec.map((value) => (Number.isFinite(value) ? value : 0));
const magnitude = Math.sqrt(sanitized.reduce((sum, value) => sum + value * value, 0));
if (magnitude < 1e-10) {
return sanitized;
}
return sanitized.map((value) => value / magnitude);
}
export async function createOllamaEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: OllamaEmbeddingClient }> {
const baseUrl = options.remote?.baseUrl?.trim() || "http://127.0.0.1:11434";
const model = options.model || "nomic-embed-text";
const headers: Record<string, string> = {
"content-type": "application/json",
...options.remote?.headers,
};
// Ollama doesn't require an API key by default. If users set one (proxy), allow it.
const apiKey = options.remote?.apiKey;
if (apiKey) {
headers.authorization = `Bearer ${apiKey}`;
}
const embedOne = async (text: string): Promise<number[]> => {
const res = await fetch(`${baseUrl.replace(/\/$/, "")}/api/embeddings`, {
method: "POST",
headers,
body: JSON.stringify({ model, prompt: text }),
});
if (!res.ok) {
throw new Error(`Ollama embeddings HTTP ${res.status}: ${await res.text()}`);
}
const json = (await res.json()) as { embedding?: number[] };
if (!Array.isArray(json.embedding)) {
throw new Error(`Ollama embeddings response missing embedding[]`);
}
return sanitizeAndNormalizeEmbedding(json.embedding);
};
const provider: EmbeddingProvider = {
id: "ollama",
model,
embedQuery: embedOne,
embedBatch: async (texts: string[]) => {
// Ollama /api/embeddings is single-prompt; parallelize with a small fanout.
// Keep it simple and let caller batch size control overall load.
return await Promise.all(texts.map(embedOne));
},
};
const client: OllamaEmbeddingClient = {
embedBatch: async (texts) => {
try {
return await provider.embedBatch(texts);
} catch (err) {
throw new Error(formatErrorMessage(err), { cause: err });
}
},
};
return { provider, client };
}

View File

@@ -8,6 +8,7 @@ import {
createMistralEmbeddingProvider,
type MistralEmbeddingClient,
} from "./embeddings-mistral.js";
import { createOllamaEmbeddingProvider, type OllamaEmbeddingClient } from "./embeddings-ollama.js";
import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js";
import { createVoyageEmbeddingProvider, type VoyageEmbeddingClient } from "./embeddings-voyage.js";
import { importNodeLlamaCpp } from "./node-llama.js";
@@ -25,6 +26,7 @@ export type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
export type { MistralEmbeddingClient } from "./embeddings-mistral.js";
export type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
export type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
export type { OllamaEmbeddingClient } from "./embeddings-ollama.js";
export type EmbeddingProvider = {
id: string;
@@ -34,11 +36,11 @@ export type EmbeddingProvider = {
embedBatch: (texts: string[]) => Promise<number[][]>;
};
export type EmbeddingProviderId = "openai" | "local" | "gemini" | "voyage" | "mistral";
export type EmbeddingProviderId = "openai" | "local" | "gemini" | "voyage" | "mistral" | "ollama";
export type EmbeddingProviderRequest = EmbeddingProviderId | "auto";
export type EmbeddingProviderFallback = EmbeddingProviderId | "none";
const REMOTE_EMBEDDING_PROVIDER_IDS = ["openai", "gemini", "voyage", "mistral"] as const;
const REMOTE_EMBEDDING_PROVIDER_IDS = ["openai", "gemini", "voyage", "mistral", "ollama"] as const;
export type EmbeddingProviderResult = {
provider: EmbeddingProvider | null;
@@ -50,6 +52,7 @@ export type EmbeddingProviderResult = {
gemini?: GeminiEmbeddingClient;
voyage?: VoyageEmbeddingClient;
mistral?: MistralEmbeddingClient;
ollama?: OllamaEmbeddingClient;
};
export type EmbeddingProviderOptions = {
@@ -152,6 +155,10 @@ export async function createEmbeddingProvider(
const provider = await createLocalEmbeddingProvider(options);
return { provider };
}
if (id === "ollama") {
const { provider, client } = await createOllamaEmbeddingProvider(options);
return { provider, ollama: client };
}
if (id === "gemini") {
const { provider, client } = await createGeminiEmbeddingProvider(options);
return { provider, gemini: client };

View File

@@ -20,6 +20,7 @@ import {
type EmbeddingProvider,
type GeminiEmbeddingClient,
type MistralEmbeddingClient,
type OllamaEmbeddingClient,
type OpenAiEmbeddingClient,
type VoyageEmbeddingClient,
} from "./embeddings.js";
@@ -91,11 +92,12 @@ export abstract class MemoryManagerSyncOps {
protected abstract readonly workspaceDir: string;
protected abstract readonly settings: ResolvedMemorySearchConfig;
protected provider: EmbeddingProvider | null = null;
protected fallbackFrom?: "openai" | "local" | "gemini" | "voyage" | "mistral";
protected fallbackFrom?: "openai" | "local" | "gemini" | "voyage" | "mistral" | "ollama";
protected openAi?: OpenAiEmbeddingClient;
protected gemini?: GeminiEmbeddingClient;
protected voyage?: VoyageEmbeddingClient;
protected mistral?: MistralEmbeddingClient;
protected ollama?: OllamaEmbeddingClient;
protected abstract batch: {
enabled: boolean;
wait: boolean;
@@ -350,7 +352,10 @@ export abstract class MemoryManagerSyncOps {
this.fts.available = result.ftsAvailable;
if (result.ftsError) {
this.fts.loadError = result.ftsError;
log.warn(`fts unavailable: ${result.ftsError}`);
// Only warn when hybrid search is enabled; otherwise this is expected noise.
if (this.fts.enabled) {
log.warn(`fts unavailable: ${result.ftsError}`);
}
}
}
@@ -958,7 +963,13 @@ export abstract class MemoryManagerSyncOps {
if (this.fallbackFrom) {
return false;
}
const fallbackFrom = this.provider.id as "openai" | "gemini" | "local" | "voyage" | "mistral";
const fallbackFrom = this.provider.id as
| "openai"
| "gemini"
| "local"
| "voyage"
| "mistral"
| "ollama";
const fallbackModel =
fallback === "gemini"
@@ -988,6 +999,7 @@ export abstract class MemoryManagerSyncOps {
this.gemini = fallbackResult.gemini;
this.voyage = fallbackResult.voyage;
this.mistral = fallbackResult.mistral;
this.ollama = fallbackResult.ollama;
this.providerKey = this.computeProviderKey();
this.batch = this.resolveBatchConfig();
log.warn(`memory embeddings: switched to fallback provider (${fallback})`, { reason });

View File

@@ -13,6 +13,7 @@ import {
type EmbeddingProviderResult,
type GeminiEmbeddingClient,
type MistralEmbeddingClient,
type OllamaEmbeddingClient,
type OpenAiEmbeddingClient,
type VoyageEmbeddingClient,
} from "./embeddings.js";
@@ -48,14 +49,22 @@ export class MemoryIndexManager extends MemoryManagerEmbeddingOps implements Mem
protected readonly workspaceDir: string;
protected readonly settings: ResolvedMemorySearchConfig;
protected provider: EmbeddingProvider | null;
private readonly requestedProvider: "openai" | "local" | "gemini" | "voyage" | "mistral" | "auto";
protected fallbackFrom?: "openai" | "local" | "gemini" | "voyage" | "mistral";
private readonly requestedProvider:
| "openai"
| "local"
| "gemini"
| "voyage"
| "mistral"
| "ollama"
| "auto";
protected fallbackFrom?: "openai" | "local" | "gemini" | "voyage" | "mistral" | "ollama";
protected fallbackReason?: string;
private readonly providerUnavailableReason?: string;
protected openAi?: OpenAiEmbeddingClient;
protected gemini?: GeminiEmbeddingClient;
protected voyage?: VoyageEmbeddingClient;
protected mistral?: MistralEmbeddingClient;
protected ollama?: OllamaEmbeddingClient;
protected batch: {
enabled: boolean;
wait: boolean;
@@ -185,6 +194,7 @@ export class MemoryIndexManager extends MemoryManagerEmbeddingOps implements Mem
this.gemini = params.providerResult.gemini;
this.voyage = params.providerResult.voyage;
this.mistral = params.providerResult.mistral;
this.ollama = params.providerResult.ollama;
this.sources = new Set(params.settings.sources);
this.db = this.openDatabase();
this.providerKey = this.computeProviderKey();
@@ -289,9 +299,11 @@ export class MemoryIndexManager extends MemoryManagerEmbeddingOps implements Mem
return merged;
}
const keywordResults = hybrid.enabled
? await this.searchKeyword(cleaned, candidates).catch(() => [])
: [];
// If FTS isn't available, hybrid mode cannot use keyword search; degrade to vector-only.
const keywordResults =
hybrid.enabled && this.fts.enabled && this.fts.available
? await this.searchKeyword(cleaned, candidates).catch(() => [])
: [];
const queryVec = await this.embedQueryWithTimeout(cleaned);
const hasVector = queryVec.some((v) => v !== 0);
@@ -299,7 +311,7 @@ export class MemoryIndexManager extends MemoryManagerEmbeddingOps implements Mem
? await this.searchVector(queryVec, candidates).catch(() => [])
: [];
if (!hybrid.enabled) {
if (!hybrid.enabled || !this.fts.enabled || !this.fts.available) {
return vectorResults.filter((entry) => entry.score >= minScore).slice(0, maxResults);
}