mirror of
https://github.com/openclaw/openclaw.git
synced 2026-05-09 23:04:32 +00:00
fix(memory): enforce guarded remote policy for embeddings
This commit is contained in:
@@ -4,12 +4,15 @@ import {
|
|||||||
} from "../agents/api-key-rotation.js";
|
} from "../agents/api-key-rotation.js";
|
||||||
import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
|
import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
|
||||||
import { parseGeminiAuth } from "../infra/gemini-auth.js";
|
import { parseGeminiAuth } from "../infra/gemini-auth.js";
|
||||||
|
import type { SsrFPolicy } from "../infra/net/ssrf.js";
|
||||||
import { debugEmbeddingsLog } from "./embeddings-debug.js";
|
import { debugEmbeddingsLog } from "./embeddings-debug.js";
|
||||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||||
|
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
|
||||||
|
|
||||||
export type GeminiEmbeddingClient = {
|
export type GeminiEmbeddingClient = {
|
||||||
baseUrl: string;
|
baseUrl: string;
|
||||||
headers: Record<string, string>;
|
headers: Record<string, string>;
|
||||||
|
ssrfPolicy?: SsrFPolicy;
|
||||||
model: string;
|
model: string;
|
||||||
modelPath: string;
|
modelPath: string;
|
||||||
apiKeys: string[];
|
apiKeys: string[];
|
||||||
@@ -73,19 +76,26 @@ export async function createGeminiEmbeddingProvider(
|
|||||||
...authHeaders.headers,
|
...authHeaders.headers,
|
||||||
...client.headers,
|
...client.headers,
|
||||||
};
|
};
|
||||||
const res = await fetch(endpoint, {
|
const payload = await withRemoteHttpResponse({
|
||||||
|
url: endpoint,
|
||||||
|
ssrfPolicy: client.ssrfPolicy,
|
||||||
|
init: {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers,
|
headers,
|
||||||
body: JSON.stringify(body),
|
body: JSON.stringify(body),
|
||||||
});
|
},
|
||||||
|
onResponse: async (res) => {
|
||||||
if (!res.ok) {
|
if (!res.ok) {
|
||||||
const payload = await res.text();
|
const text = await res.text();
|
||||||
throw new Error(`gemini embeddings failed: ${res.status} ${payload}`);
|
throw new Error(`gemini embeddings failed: ${res.status} ${text}`);
|
||||||
}
|
}
|
||||||
return (await res.json()) as {
|
return (await res.json()) as {
|
||||||
embedding?: { values?: number[] };
|
embedding?: { values?: number[] };
|
||||||
embeddings?: Array<{ values?: number[] }>;
|
embeddings?: Array<{ values?: number[] }>;
|
||||||
};
|
};
|
||||||
|
},
|
||||||
|
});
|
||||||
|
return payload;
|
||||||
};
|
};
|
||||||
|
|
||||||
const embedQuery = async (text: string): Promise<number[]> => {
|
const embedQuery = async (text: string): Promise<number[]> => {
|
||||||
@@ -158,6 +168,7 @@ export async function resolveGeminiEmbeddingClient(
|
|||||||
const providerConfig = options.config.models?.providers?.google;
|
const providerConfig = options.config.models?.providers?.google;
|
||||||
const rawBaseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GEMINI_BASE_URL;
|
const rawBaseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GEMINI_BASE_URL;
|
||||||
const baseUrl = normalizeGeminiBaseUrl(rawBaseUrl);
|
const baseUrl = normalizeGeminiBaseUrl(rawBaseUrl);
|
||||||
|
const ssrfPolicy = buildRemoteBaseUrlPolicy(baseUrl);
|
||||||
const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
|
const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
|
||||||
const headers: Record<string, string> = {
|
const headers: Record<string, string> = {
|
||||||
...headerOverrides,
|
...headerOverrides,
|
||||||
@@ -176,5 +187,5 @@ export async function resolveGeminiEmbeddingClient(
|
|||||||
embedEndpoint: `${baseUrl}/${modelPath}:embedContent`,
|
embedEndpoint: `${baseUrl}/${modelPath}:embedContent`,
|
||||||
batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`,
|
batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`,
|
||||||
});
|
});
|
||||||
return { baseUrl, headers, model, modelPath, apiKeys };
|
return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys };
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import type { SsrFPolicy } from "../infra/net/ssrf.js";
|
||||||
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
|
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
|
||||||
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
|
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
|
||||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||||
@@ -5,6 +6,7 @@ import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.j
|
|||||||
export type OpenAiEmbeddingClient = {
|
export type OpenAiEmbeddingClient = {
|
||||||
baseUrl: string;
|
baseUrl: string;
|
||||||
headers: Record<string, string>;
|
headers: Record<string, string>;
|
||||||
|
ssrfPolicy?: SsrFPolicy;
|
||||||
model: string;
|
model: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -40,6 +42,7 @@ export async function createOpenAiEmbeddingProvider(
|
|||||||
return await fetchRemoteEmbeddingVectors({
|
return await fetchRemoteEmbeddingVectors({
|
||||||
url,
|
url,
|
||||||
headers: client.headers,
|
headers: client.headers,
|
||||||
|
ssrfPolicy: client.ssrfPolicy,
|
||||||
body: { model: client.model, input },
|
body: { model: client.model, input },
|
||||||
errorPrefix: "openai embeddings failed",
|
errorPrefix: "openai embeddings failed",
|
||||||
});
|
});
|
||||||
@@ -63,11 +66,11 @@ export async function createOpenAiEmbeddingProvider(
|
|||||||
export async function resolveOpenAiEmbeddingClient(
|
export async function resolveOpenAiEmbeddingClient(
|
||||||
options: EmbeddingProviderOptions,
|
options: EmbeddingProviderOptions,
|
||||||
): Promise<OpenAiEmbeddingClient> {
|
): Promise<OpenAiEmbeddingClient> {
|
||||||
const { baseUrl, headers } = await resolveRemoteEmbeddingBearerClient({
|
const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({
|
||||||
provider: "openai",
|
provider: "openai",
|
||||||
options,
|
options,
|
||||||
defaultBaseUrl: DEFAULT_OPENAI_BASE_URL,
|
defaultBaseUrl: DEFAULT_OPENAI_BASE_URL,
|
||||||
});
|
});
|
||||||
const model = normalizeOpenAiModel(options.model);
|
const model = normalizeOpenAiModel(options.model);
|
||||||
return { baseUrl, headers, model };
|
return { baseUrl, headers, ssrfPolicy, model };
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
|
import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
|
||||||
|
import type { SsrFPolicy } from "../infra/net/ssrf.js";
|
||||||
import type { EmbeddingProviderOptions } from "./embeddings.js";
|
import type { EmbeddingProviderOptions } from "./embeddings.js";
|
||||||
|
import { buildRemoteBaseUrlPolicy } from "./remote-http.js";
|
||||||
|
|
||||||
type RemoteEmbeddingProviderId = "openai" | "voyage";
|
type RemoteEmbeddingProviderId = "openai" | "voyage";
|
||||||
|
|
||||||
@@ -7,7 +9,7 @@ export async function resolveRemoteEmbeddingBearerClient(params: {
|
|||||||
provider: RemoteEmbeddingProviderId;
|
provider: RemoteEmbeddingProviderId;
|
||||||
options: EmbeddingProviderOptions;
|
options: EmbeddingProviderOptions;
|
||||||
defaultBaseUrl: string;
|
defaultBaseUrl: string;
|
||||||
}): Promise<{ baseUrl: string; headers: Record<string, string> }> {
|
}): Promise<{ baseUrl: string; headers: Record<string, string>; ssrfPolicy?: SsrFPolicy }> {
|
||||||
const remote = params.options.remote;
|
const remote = params.options.remote;
|
||||||
const remoteApiKey = remote?.apiKey?.trim();
|
const remoteApiKey = remote?.apiKey?.trim();
|
||||||
const remoteBaseUrl = remote?.baseUrl?.trim();
|
const remoteBaseUrl = remote?.baseUrl?.trim();
|
||||||
@@ -29,5 +31,5 @@ export async function resolveRemoteEmbeddingBearerClient(params: {
|
|||||||
Authorization: `Bearer ${apiKey}`,
|
Authorization: `Bearer ${apiKey}`,
|
||||||
...headerOverrides,
|
...headerOverrides,
|
||||||
};
|
};
|
||||||
return { baseUrl, headers };
|
return { baseUrl, headers, ssrfPolicy: buildRemoteBaseUrlPolicy(baseUrl) };
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ describe("voyage embedding provider", () => {
|
|||||||
model: "voyage-4-lite",
|
model: "voyage-4-lite",
|
||||||
fallback: "none",
|
fallback: "none",
|
||||||
remote: {
|
remote: {
|
||||||
baseUrl: "https://proxy.example.com",
|
baseUrl: "https://example.com",
|
||||||
apiKey: "remote-override-key",
|
apiKey: "remote-override-key",
|
||||||
headers: { "X-Custom": "123" },
|
headers: { "X-Custom": "123" },
|
||||||
},
|
},
|
||||||
@@ -95,7 +95,7 @@ describe("voyage embedding provider", () => {
|
|||||||
const call = fetchMock.mock.calls[0];
|
const call = fetchMock.mock.calls[0];
|
||||||
expect(call).toBeDefined();
|
expect(call).toBeDefined();
|
||||||
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
|
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
|
||||||
expect(url).toBe("https://proxy.example.com/embeddings");
|
expect(url).toBe("https://example.com/embeddings");
|
||||||
|
|
||||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||||
expect(headers.Authorization).toBe("Bearer remote-override-key");
|
expect(headers.Authorization).toBe("Bearer remote-override-key");
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import type { SsrFPolicy } from "../infra/net/ssrf.js";
|
||||||
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
|
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
|
||||||
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
|
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
|
||||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||||
@@ -5,6 +6,7 @@ import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.j
|
|||||||
export type VoyageEmbeddingClient = {
|
export type VoyageEmbeddingClient = {
|
||||||
baseUrl: string;
|
baseUrl: string;
|
||||||
headers: Record<string, string>;
|
headers: Record<string, string>;
|
||||||
|
ssrfPolicy?: SsrFPolicy;
|
||||||
model: string;
|
model: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -48,6 +50,7 @@ export async function createVoyageEmbeddingProvider(
|
|||||||
return await fetchRemoteEmbeddingVectors({
|
return await fetchRemoteEmbeddingVectors({
|
||||||
url,
|
url,
|
||||||
headers: client.headers,
|
headers: client.headers,
|
||||||
|
ssrfPolicy: client.ssrfPolicy,
|
||||||
body,
|
body,
|
||||||
errorPrefix: "voyage embeddings failed",
|
errorPrefix: "voyage embeddings failed",
|
||||||
});
|
});
|
||||||
@@ -71,11 +74,11 @@ export async function createVoyageEmbeddingProvider(
|
|||||||
export async function resolveVoyageEmbeddingClient(
|
export async function resolveVoyageEmbeddingClient(
|
||||||
options: EmbeddingProviderOptions,
|
options: EmbeddingProviderOptions,
|
||||||
): Promise<VoyageEmbeddingClient> {
|
): Promise<VoyageEmbeddingClient> {
|
||||||
const { baseUrl, headers } = await resolveRemoteEmbeddingBearerClient({
|
const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({
|
||||||
provider: "voyage",
|
provider: "voyage",
|
||||||
options,
|
options,
|
||||||
defaultBaseUrl: DEFAULT_VOYAGE_BASE_URL,
|
defaultBaseUrl: DEFAULT_VOYAGE_BASE_URL,
|
||||||
});
|
});
|
||||||
const model = normalizeVoyageModel(options.model);
|
const model = normalizeVoyageModel(options.model);
|
||||||
return { baseUrl, headers, model };
|
return { baseUrl, headers, ssrfPolicy, model };
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ describe("embedding provider remote overrides", () => {
|
|||||||
models: {
|
models: {
|
||||||
providers: {
|
providers: {
|
||||||
openai: {
|
openai: {
|
||||||
baseUrl: "https://provider.example/v1",
|
baseUrl: "https://api.openai.com/v1",
|
||||||
headers: {
|
headers: {
|
||||||
"X-Provider": "p",
|
"X-Provider": "p",
|
||||||
"X-Shared": "provider",
|
"X-Shared": "provider",
|
||||||
@@ -107,7 +107,7 @@ describe("embedding provider remote overrides", () => {
|
|||||||
config: cfg as never,
|
config: cfg as never,
|
||||||
provider: "openai",
|
provider: "openai",
|
||||||
remote: {
|
remote: {
|
||||||
baseUrl: "https://remote.example/v1",
|
baseUrl: "https://example.com/v1",
|
||||||
apiKey: " remote-key ",
|
apiKey: " remote-key ",
|
||||||
headers: {
|
headers: {
|
||||||
"X-Shared": "remote",
|
"X-Shared": "remote",
|
||||||
@@ -124,7 +124,7 @@ describe("embedding provider remote overrides", () => {
|
|||||||
expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled();
|
expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled();
|
||||||
const url = fetchMock.mock.calls[0]?.[0];
|
const url = fetchMock.mock.calls[0]?.[0];
|
||||||
const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined;
|
const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined;
|
||||||
expect(url).toBe("https://remote.example/v1/embeddings");
|
expect(url).toBe("https://example.com/v1/embeddings");
|
||||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||||
expect(headers.Authorization).toBe("Bearer remote-key");
|
expect(headers.Authorization).toBe("Bearer remote-key");
|
||||||
expect(headers["Content-Type"]).toBe("application/json");
|
expect(headers["Content-Type"]).toBe("application/json");
|
||||||
@@ -142,7 +142,7 @@ describe("embedding provider remote overrides", () => {
|
|||||||
models: {
|
models: {
|
||||||
providers: {
|
providers: {
|
||||||
openai: {
|
openai: {
|
||||||
baseUrl: "https://provider.example/v1",
|
baseUrl: "https://api.openai.com/v1",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -152,7 +152,7 @@ describe("embedding provider remote overrides", () => {
|
|||||||
config: cfg as never,
|
config: cfg as never,
|
||||||
provider: "openai",
|
provider: "openai",
|
||||||
remote: {
|
remote: {
|
||||||
baseUrl: "https://remote.example/v1",
|
baseUrl: "https://example.com/v1",
|
||||||
apiKey: " ",
|
apiKey: " ",
|
||||||
},
|
},
|
||||||
model: "text-embedding-3-small",
|
model: "text-embedding-3-small",
|
||||||
|
|||||||
Reference in New Issue
Block a user