fix(memory): enforce guarded remote policy for embeddings

This commit is contained in:
Peter Steinberger
2026-02-22 18:13:44 +01:00
parent f6feb4144c
commit f87db7c627
6 changed files with 45 additions and 26 deletions

View File

@@ -4,12 +4,15 @@ import {
} from "../agents/api-key-rotation.js"; } from "../agents/api-key-rotation.js";
import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js"; import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
import { parseGeminiAuth } from "../infra/gemini-auth.js"; import { parseGeminiAuth } from "../infra/gemini-auth.js";
import type { SsrFPolicy } from "../infra/net/ssrf.js";
import { debugEmbeddingsLog } from "./embeddings-debug.js"; import { debugEmbeddingsLog } from "./embeddings-debug.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
export type GeminiEmbeddingClient = { export type GeminiEmbeddingClient = {
baseUrl: string; baseUrl: string;
headers: Record<string, string>; headers: Record<string, string>;
ssrfPolicy?: SsrFPolicy;
model: string; model: string;
modelPath: string; modelPath: string;
apiKeys: string[]; apiKeys: string[];
@@ -73,19 +76,26 @@ export async function createGeminiEmbeddingProvider(
...authHeaders.headers, ...authHeaders.headers,
...client.headers, ...client.headers,
}; };
const res = await fetch(endpoint, { const payload = await withRemoteHttpResponse({
method: "POST", url: endpoint,
headers, ssrfPolicy: client.ssrfPolicy,
body: JSON.stringify(body), init: {
method: "POST",
headers,
body: JSON.stringify(body),
},
onResponse: async (res) => {
if (!res.ok) {
const text = await res.text();
throw new Error(`gemini embeddings failed: ${res.status} ${text}`);
}
return (await res.json()) as {
embedding?: { values?: number[] };
embeddings?: Array<{ values?: number[] }>;
};
},
}); });
if (!res.ok) { return payload;
const payload = await res.text();
throw new Error(`gemini embeddings failed: ${res.status} ${payload}`);
}
return (await res.json()) as {
embedding?: { values?: number[] };
embeddings?: Array<{ values?: number[] }>;
};
}; };
const embedQuery = async (text: string): Promise<number[]> => { const embedQuery = async (text: string): Promise<number[]> => {
@@ -158,6 +168,7 @@ export async function resolveGeminiEmbeddingClient(
const providerConfig = options.config.models?.providers?.google; const providerConfig = options.config.models?.providers?.google;
const rawBaseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GEMINI_BASE_URL; const rawBaseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GEMINI_BASE_URL;
const baseUrl = normalizeGeminiBaseUrl(rawBaseUrl); const baseUrl = normalizeGeminiBaseUrl(rawBaseUrl);
const ssrfPolicy = buildRemoteBaseUrlPolicy(baseUrl);
const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers); const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
const headers: Record<string, string> = { const headers: Record<string, string> = {
...headerOverrides, ...headerOverrides,
@@ -176,5 +187,5 @@ export async function resolveGeminiEmbeddingClient(
embedEndpoint: `${baseUrl}/${modelPath}:embedContent`, embedEndpoint: `${baseUrl}/${modelPath}:embedContent`,
batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`, batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`,
}); });
return { baseUrl, headers, model, modelPath, apiKeys }; return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys };
} }

View File

@@ -1,3 +1,4 @@
import type { SsrFPolicy } from "../infra/net/ssrf.js";
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js"; import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js"; import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
@@ -5,6 +6,7 @@ import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.j
export type OpenAiEmbeddingClient = { export type OpenAiEmbeddingClient = {
baseUrl: string; baseUrl: string;
headers: Record<string, string>; headers: Record<string, string>;
ssrfPolicy?: SsrFPolicy;
model: string; model: string;
}; };
@@ -40,6 +42,7 @@ export async function createOpenAiEmbeddingProvider(
return await fetchRemoteEmbeddingVectors({ return await fetchRemoteEmbeddingVectors({
url, url,
headers: client.headers, headers: client.headers,
ssrfPolicy: client.ssrfPolicy,
body: { model: client.model, input }, body: { model: client.model, input },
errorPrefix: "openai embeddings failed", errorPrefix: "openai embeddings failed",
}); });
@@ -63,11 +66,11 @@ export async function createOpenAiEmbeddingProvider(
export async function resolveOpenAiEmbeddingClient( export async function resolveOpenAiEmbeddingClient(
options: EmbeddingProviderOptions, options: EmbeddingProviderOptions,
): Promise<OpenAiEmbeddingClient> { ): Promise<OpenAiEmbeddingClient> {
const { baseUrl, headers } = await resolveRemoteEmbeddingBearerClient({ const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({
provider: "openai", provider: "openai",
options, options,
defaultBaseUrl: DEFAULT_OPENAI_BASE_URL, defaultBaseUrl: DEFAULT_OPENAI_BASE_URL,
}); });
const model = normalizeOpenAiModel(options.model); const model = normalizeOpenAiModel(options.model);
return { baseUrl, headers, model }; return { baseUrl, headers, ssrfPolicy, model };
} }

View File

@@ -1,5 +1,7 @@
import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js"; import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
import type { SsrFPolicy } from "../infra/net/ssrf.js";
import type { EmbeddingProviderOptions } from "./embeddings.js"; import type { EmbeddingProviderOptions } from "./embeddings.js";
import { buildRemoteBaseUrlPolicy } from "./remote-http.js";
type RemoteEmbeddingProviderId = "openai" | "voyage"; type RemoteEmbeddingProviderId = "openai" | "voyage";
@@ -7,7 +9,7 @@ export async function resolveRemoteEmbeddingBearerClient(params: {
provider: RemoteEmbeddingProviderId; provider: RemoteEmbeddingProviderId;
options: EmbeddingProviderOptions; options: EmbeddingProviderOptions;
defaultBaseUrl: string; defaultBaseUrl: string;
}): Promise<{ baseUrl: string; headers: Record<string, string> }> { }): Promise<{ baseUrl: string; headers: Record<string, string>; ssrfPolicy?: SsrFPolicy }> {
const remote = params.options.remote; const remote = params.options.remote;
const remoteApiKey = remote?.apiKey?.trim(); const remoteApiKey = remote?.apiKey?.trim();
const remoteBaseUrl = remote?.baseUrl?.trim(); const remoteBaseUrl = remote?.baseUrl?.trim();
@@ -29,5 +31,5 @@ export async function resolveRemoteEmbeddingBearerClient(params: {
Authorization: `Bearer ${apiKey}`, Authorization: `Bearer ${apiKey}`,
...headerOverrides, ...headerOverrides,
}; };
return { baseUrl, headers }; return { baseUrl, headers, ssrfPolicy: buildRemoteBaseUrlPolicy(baseUrl) };
} }

View File

@@ -84,7 +84,7 @@ describe("voyage embedding provider", () => {
model: "voyage-4-lite", model: "voyage-4-lite",
fallback: "none", fallback: "none",
remote: { remote: {
baseUrl: "https://proxy.example.com", baseUrl: "https://example.com",
apiKey: "remote-override-key", apiKey: "remote-override-key",
headers: { "X-Custom": "123" }, headers: { "X-Custom": "123" },
}, },
@@ -95,7 +95,7 @@ describe("voyage embedding provider", () => {
const call = fetchMock.mock.calls[0]; const call = fetchMock.mock.calls[0];
expect(call).toBeDefined(); expect(call).toBeDefined();
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined]; const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
expect(url).toBe("https://proxy.example.com/embeddings"); expect(url).toBe("https://example.com/embeddings");
const headers = (init?.headers ?? {}) as Record<string, string>; const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers.Authorization).toBe("Bearer remote-override-key"); expect(headers.Authorization).toBe("Bearer remote-override-key");

View File

@@ -1,3 +1,4 @@
import type { SsrFPolicy } from "../infra/net/ssrf.js";
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js"; import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js"; import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
@@ -5,6 +6,7 @@ import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.j
export type VoyageEmbeddingClient = { export type VoyageEmbeddingClient = {
baseUrl: string; baseUrl: string;
headers: Record<string, string>; headers: Record<string, string>;
ssrfPolicy?: SsrFPolicy;
model: string; model: string;
}; };
@@ -48,6 +50,7 @@ export async function createVoyageEmbeddingProvider(
return await fetchRemoteEmbeddingVectors({ return await fetchRemoteEmbeddingVectors({
url, url,
headers: client.headers, headers: client.headers,
ssrfPolicy: client.ssrfPolicy,
body, body,
errorPrefix: "voyage embeddings failed", errorPrefix: "voyage embeddings failed",
}); });
@@ -71,11 +74,11 @@ export async function createVoyageEmbeddingProvider(
export async function resolveVoyageEmbeddingClient( export async function resolveVoyageEmbeddingClient(
options: EmbeddingProviderOptions, options: EmbeddingProviderOptions,
): Promise<VoyageEmbeddingClient> { ): Promise<VoyageEmbeddingClient> {
const { baseUrl, headers } = await resolveRemoteEmbeddingBearerClient({ const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({
provider: "voyage", provider: "voyage",
options, options,
defaultBaseUrl: DEFAULT_VOYAGE_BASE_URL, defaultBaseUrl: DEFAULT_VOYAGE_BASE_URL,
}); });
const model = normalizeVoyageModel(options.model); const model = normalizeVoyageModel(options.model);
return { baseUrl, headers, model }; return { baseUrl, headers, ssrfPolicy, model };
} }

View File

@@ -93,7 +93,7 @@ describe("embedding provider remote overrides", () => {
models: { models: {
providers: { providers: {
openai: { openai: {
baseUrl: "https://provider.example/v1", baseUrl: "https://api.openai.com/v1",
headers: { headers: {
"X-Provider": "p", "X-Provider": "p",
"X-Shared": "provider", "X-Shared": "provider",
@@ -107,7 +107,7 @@ describe("embedding provider remote overrides", () => {
config: cfg as never, config: cfg as never,
provider: "openai", provider: "openai",
remote: { remote: {
baseUrl: "https://remote.example/v1", baseUrl: "https://example.com/v1",
apiKey: " remote-key ", apiKey: " remote-key ",
headers: { headers: {
"X-Shared": "remote", "X-Shared": "remote",
@@ -124,7 +124,7 @@ describe("embedding provider remote overrides", () => {
expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled(); expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled();
const url = fetchMock.mock.calls[0]?.[0]; const url = fetchMock.mock.calls[0]?.[0];
const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined; const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined;
expect(url).toBe("https://remote.example/v1/embeddings"); expect(url).toBe("https://example.com/v1/embeddings");
const headers = (init?.headers ?? {}) as Record<string, string>; const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers.Authorization).toBe("Bearer remote-key"); expect(headers.Authorization).toBe("Bearer remote-key");
expect(headers["Content-Type"]).toBe("application/json"); expect(headers["Content-Type"]).toBe("application/json");
@@ -142,7 +142,7 @@ describe("embedding provider remote overrides", () => {
models: { models: {
providers: { providers: {
openai: { openai: {
baseUrl: "https://provider.example/v1", baseUrl: "https://api.openai.com/v1",
}, },
}, },
}, },
@@ -152,7 +152,7 @@ describe("embedding provider remote overrides", () => {
config: cfg as never, config: cfg as never,
provider: "openai", provider: "openai",
remote: { remote: {
baseUrl: "https://remote.example/v1", baseUrl: "https://example.com/v1",
apiKey: " ", apiKey: " ",
}, },
model: "text-embedding-3-small", model: "text-embedding-3-small",