refactor(memory): unify embedding provider constants

This commit is contained in:
Peter Steinberger
2026-02-14 03:16:46 +01:00
parent 61b5133264
commit 03fee3c605
2 changed files with 36 additions and 15 deletions

View File

@@ -1,7 +1,7 @@
import { afterEach, describe, expect, it, vi } from "vitest"; import { afterEach, describe, expect, it, vi } from "vitest";
import * as authModule from "../agents/model-auth.js"; import * as authModule from "../agents/model-auth.js";
import { DEFAULT_GEMINI_EMBEDDING_MODEL } from "./embeddings-gemini.js"; import { DEFAULT_GEMINI_EMBEDDING_MODEL } from "./embeddings-gemini.js";
import { createEmbeddingProvider } from "./embeddings.js"; import { createEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./embeddings.js";
vi.mock("../agents/model-auth.js", () => ({ vi.mock("../agents/model-auth.js", () => ({
resolveApiKeyForProvider: vi.fn(), resolveApiKeyForProvider: vi.fn(),
@@ -303,6 +303,23 @@ describe("embedding provider local fallback", () => {
}), }),
).rejects.toThrow(/optional dependency node-llama-cpp/i); ).rejects.toThrow(/optional dependency node-llama-cpp/i);
}); });
// When node-llama-cpp cannot be imported, the local setup error should point
// the user at each remote provider as an alternative. The title says "every",
// so each remote id is asserted rather than only one of them.
it("mentions every remote provider in local setup guidance", async () => {
  // Simulate the optional dependency being absent from the install.
  const missingModule = Object.assign(new Error("Cannot find package 'node-llama-cpp'"), {
    code: "ERR_MODULE_NOT_FOUND",
  });
  importNodeLlamaCppMock.mockRejectedValue(missingModule);

  // fallback "none" keeps the local failure from being masked by another provider.
  const attempt = createEmbeddingProvider({
    config: {} as never,
    provider: "local",
    model: "text-embedding-3-small",
    fallback: "none",
  });

  // A settled promise can be observed repeatedly, so one call backs all checks.
  for (const remoteId of ["openai", "gemini", "voyage"]) {
    await expect(attempt).rejects.toThrow(new RegExp(`provider = "${remoteId}"`, "i"));
  }
});
}); });
describe("local embedding normalization", () => { describe("local embedding normalization", () => {
@@ -341,10 +358,7 @@ describe("local embedding normalization", () => {
const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0)); const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0));
expect(magnitude).toBeCloseTo(1.0, 5); expect(magnitude).toBeCloseTo(1.0, 5);
expect(resolveModelFileMock).toHaveBeenCalledWith( expect(resolveModelFileMock).toHaveBeenCalledWith(DEFAULT_LOCAL_MODEL, undefined);
"hf:ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/embeddinggemma-300m-qat-Q8_0.gguf",
undefined,
);
}); });
it("handles zero vector without division by zero", async () => { it("handles zero vector without division by zero", async () => {

View File

@@ -29,10 +29,16 @@ export type EmbeddingProvider = {
embedBatch: (texts: string[]) => Promise<number[][]>; embedBatch: (texts: string[]) => Promise<number[][]>;
}; };
/**
 * Identifiers of the remote embedding providers, in the order
 * `createEmbeddingProvider` iterates over them. The same list produces the
 * per-provider "Or set ... (remote)." hints in the local setup error text.
 */
const REMOTE_EMBEDDING_PROVIDER_IDS = ["openai", "gemini", "voyage"] as const;

/**
 * A single remote provider id, derived from the tuple above so the list and
 * the type cannot drift apart when a provider is added or removed.
 */
type RemoteEmbeddingProviderId = (typeof REMOTE_EMBEDDING_PROVIDER_IDS)[number];

/** Every concrete embedding provider: the remote ones plus `"local"`. */
export type EmbeddingProviderId = RemoteEmbeddingProviderId | "local";

/** Value accepted for the `provider` option: a concrete provider or `"auto"`. */
export type EmbeddingProviderRequest = EmbeddingProviderId | "auto";

/** Value accepted for the `fallback` option: a concrete provider or `"none"`. */
export type EmbeddingProviderFallback = EmbeddingProviderId | "none";
export type EmbeddingProviderResult = { export type EmbeddingProviderResult = {
provider: EmbeddingProvider; provider: EmbeddingProvider;
requestedProvider: "openai" | "local" | "gemini" | "voyage" | "auto"; requestedProvider: EmbeddingProviderRequest;
fallbackFrom?: "openai" | "local" | "gemini" | "voyage"; fallbackFrom?: EmbeddingProviderId;
fallbackReason?: string; fallbackReason?: string;
openAi?: OpenAiEmbeddingClient; openAi?: OpenAiEmbeddingClient;
gemini?: GeminiEmbeddingClient; gemini?: GeminiEmbeddingClient;
@@ -42,21 +48,21 @@ export type EmbeddingProviderResult = {
export type EmbeddingProviderOptions = { export type EmbeddingProviderOptions = {
config: OpenClawConfig; config: OpenClawConfig;
agentDir?: string; agentDir?: string;
provider: "openai" | "local" | "gemini" | "voyage" | "auto"; provider: EmbeddingProviderRequest;
remote?: { remote?: {
baseUrl?: string; baseUrl?: string;
apiKey?: string; apiKey?: string;
headers?: Record<string, string>; headers?: Record<string, string>;
}; };
model: string; model: string;
fallback: "openai" | "gemini" | "local" | "voyage" | "none"; fallback: EmbeddingProviderFallback;
local?: { local?: {
modelPath?: string; modelPath?: string;
modelCacheDir?: string; modelCacheDir?: string;
}; };
}; };
const DEFAULT_LOCAL_MODEL = export const DEFAULT_LOCAL_MODEL =
"hf:ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/embeddinggemma-300m-qat-Q8_0.gguf"; "hf:ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/embeddinggemma-300m-qat-Q8_0.gguf";
function canAutoSelectLocal(options: EmbeddingProviderOptions): boolean { function canAutoSelectLocal(options: EmbeddingProviderOptions): boolean {
@@ -134,7 +140,7 @@ export async function createEmbeddingProvider(
const requestedProvider = options.provider; const requestedProvider = options.provider;
const fallback = options.fallback; const fallback = options.fallback;
const createProvider = async (id: "openai" | "local" | "gemini" | "voyage") => { const createProvider = async (id: EmbeddingProviderId) => {
if (id === "local") { if (id === "local") {
const provider = await createLocalEmbeddingProvider(options); const provider = await createLocalEmbeddingProvider(options);
return { provider }; return { provider };
@@ -151,7 +157,7 @@ export async function createEmbeddingProvider(
return { provider, openAi: client }; return { provider, openAi: client };
}; };
const formatPrimaryError = (err: unknown, provider: "openai" | "local" | "gemini" | "voyage") => const formatPrimaryError = (err: unknown, provider: EmbeddingProviderId) =>
provider === "local" ? formatLocalSetupError(err) : formatErrorMessage(err); provider === "local" ? formatLocalSetupError(err) : formatErrorMessage(err);
if (requestedProvider === "auto") { if (requestedProvider === "auto") {
@@ -167,7 +173,7 @@ export async function createEmbeddingProvider(
} }
} }
for (const provider of ["openai", "gemini", "voyage"] as const) { for (const provider of REMOTE_EMBEDDING_PROVIDER_IDS) {
try { try {
const result = await createProvider(provider); const result = await createProvider(provider);
return { ...result, requestedProvider }; return { ...result, requestedProvider };
@@ -242,8 +248,9 @@ function formatLocalSetupError(err: unknown): string {
? "2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest" ? "2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest"
: null, : null,
"3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp", "3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp",
'Or set agents.defaults.memorySearch.provider = "openai" (remote).', ...REMOTE_EMBEDDING_PROVIDER_IDS.map(
'Or set agents.defaults.memorySearch.provider = "voyage" (remote).', (provider) => `Or set agents.defaults.memorySearch.provider = "${provider}" (remote).`,
),
] ]
.filter(Boolean) .filter(Boolean)
.join("\n"); .join("\n");