fix(memory): route batch APIs through guarded remote HTTP

This commit is contained in:
Peter Steinberger
2026-02-22 18:14:00 +01:00
parent f87db7c627
commit eb041daee2
5 changed files with 178 additions and 108 deletions

View File

@@ -3,6 +3,7 @@ import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
import { debugEmbeddingsLog } from "./embeddings-debug.js"; import { debugEmbeddingsLog } from "./embeddings-debug.js";
import type { GeminiEmbeddingClient } from "./embeddings-gemini.js"; import type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
import { hashText } from "./internal.js"; import { hashText } from "./internal.js";
import { withRemoteHttpResponse } from "./remote-http.js";
export type GeminiBatchRequest = { export type GeminiBatchRequest = {
custom_id: string; custom_id: string;
@@ -93,19 +94,25 @@ async function submitGeminiBatch(params: {
baseUrl, baseUrl,
requests: params.requests.length, requests: params.requests.length,
}); });
const fileRes = await fetch(uploadUrl, { const filePayload = await withRemoteHttpResponse({
url: uploadUrl,
ssrfPolicy: params.gemini.ssrfPolicy,
init: {
method: "POST", method: "POST",
headers: { headers: {
...buildBatchHeaders(params.gemini, { json: false }), ...buildBatchHeaders(params.gemini, { json: false }),
"Content-Type": uploadPayload.contentType, "Content-Type": uploadPayload.contentType,
}, },
body: uploadPayload.body, body: uploadPayload.body,
}); },
onResponse: async (fileRes) => {
if (!fileRes.ok) { if (!fileRes.ok) {
const text = await fileRes.text(); const text = await fileRes.text();
throw new Error(`gemini batch file upload failed: ${fileRes.status} ${text}`); throw new Error(`gemini batch file upload failed: ${fileRes.status} ${text}`);
} }
const filePayload = (await fileRes.json()) as { name?: string; file?: { name?: string } }; return (await fileRes.json()) as { name?: string; file?: { name?: string } };
},
});
const fileId = filePayload.name ?? filePayload.file?.name; const fileId = filePayload.name ?? filePayload.file?.name;
if (!fileId) { if (!fileId) {
throw new Error("gemini batch file upload failed: missing file id"); throw new Error("gemini batch file upload failed: missing file id");
@@ -125,11 +132,15 @@ async function submitGeminiBatch(params: {
batchEndpoint, batchEndpoint,
fileId, fileId,
}); });
const batchRes = await fetch(batchEndpoint, { return await withRemoteHttpResponse({
url: batchEndpoint,
ssrfPolicy: params.gemini.ssrfPolicy,
init: {
method: "POST", method: "POST",
headers: buildBatchHeaders(params.gemini, { json: true }), headers: buildBatchHeaders(params.gemini, { json: true }),
body: JSON.stringify(batchBody), body: JSON.stringify(batchBody),
}); },
onResponse: async (batchRes) => {
if (batchRes.ok) { if (batchRes.ok) {
return (await batchRes.json()) as GeminiBatchStatus; return (await batchRes.json()) as GeminiBatchStatus;
} }
@@ -140,6 +151,8 @@ async function submitGeminiBatch(params: {
); );
} }
throw new Error(`gemini batch create failed: ${batchRes.status} ${text}`); throw new Error(`gemini batch create failed: ${batchRes.status} ${text}`);
},
});
} }
async function fetchGeminiBatchStatus(params: { async function fetchGeminiBatchStatus(params: {
@@ -152,14 +165,20 @@ async function fetchGeminiBatchStatus(params: {
: `batches/${params.batchName}`; : `batches/${params.batchName}`;
const statusUrl = `${baseUrl}/${name}`; const statusUrl = `${baseUrl}/${name}`;
debugEmbeddingsLog("memory embeddings: gemini batch status", { statusUrl }); debugEmbeddingsLog("memory embeddings: gemini batch status", { statusUrl });
const res = await fetch(statusUrl, { return await withRemoteHttpResponse({
url: statusUrl,
ssrfPolicy: params.gemini.ssrfPolicy,
init: {
headers: buildBatchHeaders(params.gemini, { json: true }), headers: buildBatchHeaders(params.gemini, { json: true }),
}); },
onResponse: async (res) => {
if (!res.ok) { if (!res.ok) {
const text = await res.text(); const text = await res.text();
throw new Error(`gemini batch status failed: ${res.status} ${text}`); throw new Error(`gemini batch status failed: ${res.status} ${text}`);
} }
return (await res.json()) as GeminiBatchStatus; return (await res.json()) as GeminiBatchStatus;
},
});
} }
async function fetchGeminiFileContent(params: { async function fetchGeminiFileContent(params: {
@@ -170,14 +189,20 @@ async function fetchGeminiFileContent(params: {
const file = params.fileId.startsWith("files/") ? params.fileId : `files/${params.fileId}`; const file = params.fileId.startsWith("files/") ? params.fileId : `files/${params.fileId}`;
const downloadUrl = `${baseUrl}/${file}:download`; const downloadUrl = `${baseUrl}/${file}:download`;
debugEmbeddingsLog("memory embeddings: gemini batch download", { downloadUrl }); debugEmbeddingsLog("memory embeddings: gemini batch download", { downloadUrl });
const res = await fetch(downloadUrl, { return await withRemoteHttpResponse({
url: downloadUrl,
ssrfPolicy: params.gemini.ssrfPolicy,
init: {
headers: buildBatchHeaders(params.gemini, { json: true }), headers: buildBatchHeaders(params.gemini, { json: true }),
}); },
onResponse: async (res) => {
if (!res.ok) { if (!res.ok) {
const text = await res.text(); const text = await res.text();
throw new Error(`gemini batch file content failed: ${res.status} ${text}`); throw new Error(`gemini batch file content failed: ${res.status} ${text}`);
} }
return await res.text(); return await res.text();
},
});
} }
function parseGeminiBatchOutput(text: string): GeminiBatchOutputLine[] { function parseGeminiBatchOutput(text: string): GeminiBatchOutputLine[] {

View File

@@ -5,6 +5,7 @@ import { runEmbeddingBatchGroups } from "./batch-runner.js";
import { uploadBatchJsonlFile } from "./batch-upload.js"; import { uploadBatchJsonlFile } from "./batch-upload.js";
import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js"; import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
import type { OpenAiEmbeddingClient } from "./embeddings-openai.js"; import type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
import { withRemoteHttpResponse } from "./remote-http.js";
export type OpenAiBatchRequest = { export type OpenAiBatchRequest = {
custom_id: string; custom_id: string;
@@ -54,6 +55,7 @@ async function submitOpenAiBatch(params: {
return await postJsonWithRetry<OpenAiBatchStatus>({ return await postJsonWithRetry<OpenAiBatchStatus>({
url: `${baseUrl}/batches`, url: `${baseUrl}/batches`,
headers: buildBatchHeaders(params.openAi, { json: true }), headers: buildBatchHeaders(params.openAi, { json: true }),
ssrfPolicy: params.openAi.ssrfPolicy,
body: { body: {
input_file_id: inputFileId, input_file_id: inputFileId,
endpoint: OPENAI_BATCH_ENDPOINT, endpoint: OPENAI_BATCH_ENDPOINT,
@@ -72,14 +74,20 @@ async function fetchOpenAiBatchStatus(params: {
batchId: string; batchId: string;
}): Promise<OpenAiBatchStatus> { }): Promise<OpenAiBatchStatus> {
const baseUrl = normalizeBatchBaseUrl(params.openAi); const baseUrl = normalizeBatchBaseUrl(params.openAi);
const res = await fetch(`${baseUrl}/batches/${params.batchId}`, { return await withRemoteHttpResponse({
url: `${baseUrl}/batches/${params.batchId}`,
ssrfPolicy: params.openAi.ssrfPolicy,
init: {
headers: buildBatchHeaders(params.openAi, { json: true }), headers: buildBatchHeaders(params.openAi, { json: true }),
}); },
onResponse: async (res) => {
if (!res.ok) { if (!res.ok) {
const text = await res.text(); const text = await res.text();
throw new Error(`openai batch status failed: ${res.status} ${text}`); throw new Error(`openai batch status failed: ${res.status} ${text}`);
} }
return (await res.json()) as OpenAiBatchStatus; return (await res.json()) as OpenAiBatchStatus;
},
});
} }
async function fetchOpenAiFileContent(params: { async function fetchOpenAiFileContent(params: {
@@ -87,14 +95,20 @@ async function fetchOpenAiFileContent(params: {
fileId: string; fileId: string;
}): Promise<string> { }): Promise<string> {
const baseUrl = normalizeBatchBaseUrl(params.openAi); const baseUrl = normalizeBatchBaseUrl(params.openAi);
const res = await fetch(`${baseUrl}/files/${params.fileId}/content`, { return await withRemoteHttpResponse({
url: `${baseUrl}/files/${params.fileId}/content`,
ssrfPolicy: params.openAi.ssrfPolicy,
init: {
headers: buildBatchHeaders(params.openAi, { json: true }), headers: buildBatchHeaders(params.openAi, { json: true }),
}); },
onResponse: async (res) => {
if (!res.ok) { if (!res.ok) {
const text = await res.text(); const text = await res.text();
throw new Error(`openai batch file content failed: ${res.status} ${text}`); throw new Error(`openai batch file content failed: ${res.status} ${text}`);
} }
return await res.text(); return await res.text();
},
});
} }
function parseOpenAiBatchOutput(text: string): OpenAiBatchOutputLine[] { function parseOpenAiBatchOutput(text: string): OpenAiBatchOutputLine[] {

View File

@@ -4,6 +4,7 @@ import {
type BatchHttpClientConfig, type BatchHttpClientConfig,
} from "./batch-utils.js"; } from "./batch-utils.js";
import { hashText } from "./internal.js"; import { hashText } from "./internal.js";
import { withRemoteHttpResponse } from "./remote-http.js";
export async function uploadBatchJsonlFile(params: { export async function uploadBatchJsonlFile(params: {
client: BatchHttpClientConfig; client: BatchHttpClientConfig;
@@ -20,16 +21,22 @@ export async function uploadBatchJsonlFile(params: {
`memory-embeddings.${hashText(String(Date.now()))}.jsonl`, `memory-embeddings.${hashText(String(Date.now()))}.jsonl`,
); );
const fileRes = await fetch(`${baseUrl}/files`, { const filePayload = await withRemoteHttpResponse({
url: `${baseUrl}/files`,
ssrfPolicy: params.client.ssrfPolicy,
init: {
method: "POST", method: "POST",
headers: buildBatchHeaders(params.client, { json: false }), headers: buildBatchHeaders(params.client, { json: false }),
body: form, body: form,
}); },
onResponse: async (fileRes) => {
if (!fileRes.ok) { if (!fileRes.ok) {
const text = await fileRes.text(); const text = await fileRes.text();
throw new Error(`${params.errorPrefix}: ${fileRes.status} ${text}`); throw new Error(`${params.errorPrefix}: ${fileRes.status} ${text}`);
} }
const filePayload = (await fileRes.json()) as { id?: string }; return (await fileRes.json()) as { id?: string };
},
});
if (!filePayload.id) { if (!filePayload.id) {
throw new Error(`${params.errorPrefix}: missing file id`); throw new Error(`${params.errorPrefix}: missing file id`);
} }

View File

@@ -1,6 +1,9 @@
import type { SsrFPolicy } from "../infra/net/ssrf.js";
export type BatchHttpClientConfig = { export type BatchHttpClientConfig = {
baseUrl?: string; baseUrl?: string;
headers?: Record<string, string>; headers?: Record<string, string>;
ssrfPolicy?: SsrFPolicy;
}; };
export function normalizeBatchBaseUrl(client: BatchHttpClientConfig): string { export function normalizeBatchBaseUrl(client: BatchHttpClientConfig): string {

View File

@@ -7,6 +7,7 @@ import { runEmbeddingBatchGroups } from "./batch-runner.js";
import { uploadBatchJsonlFile } from "./batch-upload.js"; import { uploadBatchJsonlFile } from "./batch-upload.js";
import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js"; import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
import type { VoyageEmbeddingClient } from "./embeddings-voyage.js"; import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
import { withRemoteHttpResponse } from "./remote-http.js";
/** /**
* Voyage Batch API Input Line format. * Voyage Batch API Input Line format.
@@ -58,6 +59,7 @@ async function submitVoyageBatch(params: {
return await postJsonWithRetry<VoyageBatchStatus>({ return await postJsonWithRetry<VoyageBatchStatus>({
url: `${baseUrl}/batches`, url: `${baseUrl}/batches`,
headers: buildBatchHeaders(params.client, { json: true }), headers: buildBatchHeaders(params.client, { json: true }),
ssrfPolicy: params.client.ssrfPolicy,
body: { body: {
input_file_id: inputFileId, input_file_id: inputFileId,
endpoint: VOYAGE_BATCH_ENDPOINT, endpoint: VOYAGE_BATCH_ENDPOINT,
@@ -80,14 +82,20 @@ async function fetchVoyageBatchStatus(params: {
batchId: string; batchId: string;
}): Promise<VoyageBatchStatus> { }): Promise<VoyageBatchStatus> {
const baseUrl = normalizeBatchBaseUrl(params.client); const baseUrl = normalizeBatchBaseUrl(params.client);
const res = await fetch(`${baseUrl}/batches/${params.batchId}`, { return await withRemoteHttpResponse({
url: `${baseUrl}/batches/${params.batchId}`,
ssrfPolicy: params.client.ssrfPolicy,
init: {
headers: buildBatchHeaders(params.client, { json: true }), headers: buildBatchHeaders(params.client, { json: true }),
}); },
onResponse: async (res) => {
if (!res.ok) { if (!res.ok) {
const text = await res.text(); const text = await res.text();
throw new Error(`voyage batch status failed: ${res.status} ${text}`); throw new Error(`voyage batch status failed: ${res.status} ${text}`);
} }
return (await res.json()) as VoyageBatchStatus; return (await res.json()) as VoyageBatchStatus;
},
});
} }
async function readVoyageBatchError(params: { async function readVoyageBatchError(params: {
@@ -96,9 +104,13 @@ async function readVoyageBatchError(params: {
}): Promise<string | undefined> { }): Promise<string | undefined> {
try { try {
const baseUrl = normalizeBatchBaseUrl(params.client); const baseUrl = normalizeBatchBaseUrl(params.client);
const res = await fetch(`${baseUrl}/files/${params.errorFileId}/content`, { return await withRemoteHttpResponse({
url: `${baseUrl}/files/${params.errorFileId}/content`,
ssrfPolicy: params.client.ssrfPolicy,
init: {
headers: buildBatchHeaders(params.client, { json: true }), headers: buildBatchHeaders(params.client, { json: true }),
}); },
onResponse: async (res) => {
if (!res.ok) { if (!res.ok) {
const text = await res.text(); const text = await res.text();
throw new Error(`voyage batch error file content failed: ${res.status} ${text}`); throw new Error(`voyage batch error file content failed: ${res.status} ${text}`);
@@ -113,6 +125,8 @@ async function readVoyageBatchError(params: {
.filter(Boolean) .filter(Boolean)
.map((line) => JSON.parse(line) as VoyageBatchOutputLine); .map((line) => JSON.parse(line) as VoyageBatchOutputLine);
return extractBatchErrorMessage(lines); return extractBatchErrorMessage(lines);
},
});
} catch (err) { } catch (err) {
return formatUnavailableBatchError(err); return formatUnavailableBatchError(err);
} }
@@ -228,18 +242,24 @@ export async function runVoyageEmbeddingBatches(params: {
} }
const baseUrl = normalizeBatchBaseUrl(params.client); const baseUrl = normalizeBatchBaseUrl(params.client);
const contentRes = await fetch(`${baseUrl}/files/${completed.outputFileId}/content`, { const errors: string[] = [];
const remaining = new Set(group.map((request) => request.custom_id));
await withRemoteHttpResponse({
url: `${baseUrl}/files/${completed.outputFileId}/content`,
ssrfPolicy: params.client.ssrfPolicy,
init: {
headers: buildBatchHeaders(params.client, { json: true }), headers: buildBatchHeaders(params.client, { json: true }),
}); },
onResponse: async (contentRes) => {
if (!contentRes.ok) { if (!contentRes.ok) {
const text = await contentRes.text(); const text = await contentRes.text();
throw new Error(`voyage batch file content failed: ${contentRes.status} ${text}`); throw new Error(`voyage batch file content failed: ${contentRes.status} ${text}`);
} }
const errors: string[] = []; if (!contentRes.body) {
const remaining = new Set(group.map((request) => request.custom_id)); return;
}
if (contentRes.body) {
const reader = createInterface({ const reader = createInterface({
input: Readable.fromWeb( input: Readable.fromWeb(
contentRes.body as unknown as import("stream/web").ReadableStream, contentRes.body as unknown as import("stream/web").ReadableStream,
@@ -254,7 +274,8 @@ export async function runVoyageEmbeddingBatches(params: {
const line = JSON.parse(rawLine) as VoyageBatchOutputLine; const line = JSON.parse(rawLine) as VoyageBatchOutputLine;
applyEmbeddingBatchOutputLine({ line, remaining, errors, byCustomId }); applyEmbeddingBatchOutputLine({ line, remaining, errors, byCustomId });
} }
} },
});
if (errors.length > 0) { if (errors.length > 0) {
throw new Error(`voyage batch ${batchInfo.id} failed: ${errors.join("; ")}`); throw new Error(`voyage batch ${batchInfo.id} failed: ${errors.join("; ")}`);