mirror of
https://github.com/openclaw/openclaw.git
synced 2026-05-26 23:03:32 +00:00
176 lines
5.2 KiB
TypeScript
176 lines
5.2 KiB
TypeScript
/**
|
|
* Embedding generation for memory-neo4j.
|
|
*
|
|
* Supports both OpenAI and Ollama providers.
|
|
*/
|
|
|
|
import OpenAI from "openai";
|
|
import type { EmbeddingProvider } from "./config.js";
|
|
import { contextLengthForModel } from "./config.js";
|
|
|
|
type Logger = {
|
|
info: (msg: string) => void;
|
|
warn: (msg: string) => void;
|
|
error: (msg: string) => void;
|
|
debug?: (msg: string) => void;
|
|
};
|
|
|
|
export class Embeddings {
|
|
private client: OpenAI | null = null;
|
|
private readonly provider: EmbeddingProvider;
|
|
private readonly baseUrl: string;
|
|
private readonly logger: Logger | undefined;
|
|
private readonly contextLength: number;
|
|
|
|
constructor(
|
|
private readonly apiKey: string | undefined,
|
|
private readonly model: string = "text-embedding-3-small",
|
|
provider: EmbeddingProvider = "openai",
|
|
baseUrl?: string,
|
|
logger?: Logger,
|
|
) {
|
|
this.provider = provider;
|
|
this.baseUrl = baseUrl ?? (provider === "ollama" ? "http://localhost:11434" : "");
|
|
this.logger = logger;
|
|
this.contextLength = contextLengthForModel(model);
|
|
|
|
if (provider === "openai") {
|
|
if (!apiKey) {
|
|
throw new Error("API key required for OpenAI embeddings");
|
|
}
|
|
this.client = new OpenAI({ apiKey });
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Truncate text to fit within the model's context length.
|
|
* Uses a conservative ~4 chars/token estimate to leave headroom.
|
|
* Truncates at a word boundary when possible.
|
|
*/
|
|
private truncateToContext(text: string): string {
|
|
const maxChars = this.contextLength * 4;
|
|
if (text.length <= maxChars) {
|
|
return text;
|
|
}
|
|
|
|
// Try to truncate at a word boundary
|
|
let truncated = text.slice(0, maxChars);
|
|
const lastSpace = truncated.lastIndexOf(" ");
|
|
if (lastSpace > maxChars * 0.8) {
|
|
truncated = truncated.slice(0, lastSpace);
|
|
}
|
|
|
|
this.logger?.debug?.(
|
|
`memory-neo4j: truncated embedding input from ${text.length} to ${truncated.length} chars (model context: ${this.contextLength} tokens)`,
|
|
);
|
|
return truncated;
|
|
}
|
|
|
|
/**
|
|
* Generate an embedding vector for a single text.
|
|
*/
|
|
async embed(text: string): Promise<number[]> {
|
|
const input = this.truncateToContext(text);
|
|
if (this.provider === "ollama") {
|
|
return this.embedOllama(input);
|
|
}
|
|
return this.embedOpenAI(input);
|
|
}
|
|
|
|
/**
|
|
* Generate embeddings for multiple texts.
|
|
* Returns array of embeddings in the same order as input.
|
|
*
|
|
* For Ollama: uses Promise.allSettled so individual failures don't break the
|
|
* entire batch. Failed embeddings are replaced with zero vectors and logged.
|
|
*/
|
|
async embedBatch(texts: string[]): Promise<number[][]> {
|
|
if (texts.length === 0) {
|
|
return [];
|
|
}
|
|
|
|
const truncated = texts.map((t) => this.truncateToContext(t));
|
|
|
|
if (this.provider === "ollama") {
|
|
// Ollama doesn't support batch, so we do sequential with resilient error handling
|
|
const results = await Promise.allSettled(truncated.map((t) => this.embedOllama(t)));
|
|
const embeddings: number[][] = [];
|
|
let failures = 0;
|
|
|
|
for (let i = 0; i < results.length; i++) {
|
|
const result = results[i];
|
|
if (result.status === "fulfilled") {
|
|
embeddings.push(result.value);
|
|
} else {
|
|
failures++;
|
|
this.logger?.warn?.(
|
|
`memory-neo4j: Ollama embedding failed for text ${i}: ${String(result.reason)}`,
|
|
);
|
|
// Use zero vector as placeholder so indices stay aligned
|
|
embeddings.push([]);
|
|
}
|
|
}
|
|
|
|
if (failures > 0) {
|
|
this.logger?.warn?.(
|
|
`memory-neo4j: ${failures}/${texts.length} Ollama embeddings failed in batch`,
|
|
);
|
|
}
|
|
|
|
return embeddings;
|
|
}
|
|
|
|
return this.embedBatchOpenAI(truncated);
|
|
}
|
|
|
|
private async embedOpenAI(text: string): Promise<number[]> {
|
|
if (!this.client) {
|
|
throw new Error("OpenAI client not initialized");
|
|
}
|
|
const response = await this.client.embeddings.create({
|
|
model: this.model,
|
|
input: text,
|
|
});
|
|
return response.data[0].embedding;
|
|
}
|
|
|
|
private async embedBatchOpenAI(texts: string[]): Promise<number[][]> {
|
|
if (!this.client) {
|
|
throw new Error("OpenAI client not initialized");
|
|
}
|
|
const response = await this.client.embeddings.create({
|
|
model: this.model,
|
|
input: texts,
|
|
});
|
|
// Sort by index to ensure correct order
|
|
return response.data.toSorted((a, b) => a.index - b.index).map((d) => d.embedding);
|
|
}
|
|
|
|
// Timeout for Ollama embedding fetch calls to prevent hanging indefinitely
|
|
private static readonly EMBED_TIMEOUT_MS = 30_000;
|
|
|
|
private async embedOllama(text: string): Promise<number[]> {
|
|
const url = `${this.baseUrl}/api/embed`;
|
|
const response = await fetch(url, {
|
|
method: "POST",
|
|
headers: { "Content-Type": "application/json" },
|
|
body: JSON.stringify({
|
|
model: this.model,
|
|
input: text,
|
|
}),
|
|
signal: AbortSignal.timeout(Embeddings.EMBED_TIMEOUT_MS),
|
|
});
|
|
|
|
if (!response.ok) {
|
|
const error = await response.text();
|
|
throw new Error(`Ollama embedding failed: ${response.status} ${error}`);
|
|
}
|
|
|
|
const data = (await response.json()) as { embeddings?: number[][] };
|
|
if (!data.embeddings?.[0]) {
|
|
throw new Error("No embedding returned from Ollama");
|
|
}
|
|
return data.embeddings[0];
|
|
}
|
|
}
|