refactor(memory): share batch provider scaffolding

This commit is contained in:
Peter Steinberger
2026-02-22 20:26:14 +00:00
parent f8171ffcdc
commit ad51372f78
5 changed files with 104 additions and 99 deletions
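
Call sites stay source-compatible: the per-provider wait/poll/timeout/concurrency/debug parameters now come from the shared EmbeddingBatchExecutionParams intersection, but callers still pass one flat options object. A minimal sketch of an unchanged call site, assuming an openAi client and a requests array built elsewhere (values are illustrative):

// Sketch only: the flat object satisfies both the provider-specific fields
// and the shared EmbeddingBatchExecutionParams half of the intersection.
const embeddings = await runOpenAiEmbeddingBatches({
  openAi,                 // OpenAiEmbeddingClient, configured elsewhere (assumed)
  agentId: "agent-main",  // illustrative value
  requests,               // OpenAiBatchRequest[]
  wait: true,             // shared execution params start here
  pollIntervalMs: 5_000,
  timeoutMs: 60 * 60 * 1000,
  concurrency: 2,
  debug: (message, data) => console.error(message, data),
});
// Result: Map<string, number[]> of embeddings keyed by custom_id.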

View File

@@ -1,4 +1,8 @@
-import { runEmbeddingBatchGroups } from "./batch-runner.js";
+import {
+  buildEmbeddingBatchGroupOptions,
+  runEmbeddingBatchGroups,
+  type EmbeddingBatchExecutionParams,
+} from "./batch-runner.js";
 import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
 import { debugEmbeddingsLog } from "./embeddings-debug.js";
 import type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
@@ -261,25 +265,18 @@ async function waitForGeminiBatch(params: {
 }
 }
 
-export async function runGeminiEmbeddingBatches(params: {
-  gemini: GeminiEmbeddingClient;
-  agentId: string;
-  requests: GeminiBatchRequest[];
-  wait: boolean;
-  pollIntervalMs: number;
-  timeoutMs: number;
-  concurrency: number;
-  debug?: (message: string, data?: Record<string, unknown>) => void;
-}): Promise<Map<string, number[]>> {
+export async function runGeminiEmbeddingBatches(
+  params: {
+    gemini: GeminiEmbeddingClient;
+    agentId: string;
+    requests: GeminiBatchRequest[];
+  } & EmbeddingBatchExecutionParams,
+): Promise<Map<string, number[]>> {
   return await runEmbeddingBatchGroups({
-    requests: params.requests,
-    maxRequests: GEMINI_BATCH_MAX_REQUESTS,
-    wait: params.wait,
-    pollIntervalMs: params.pollIntervalMs,
-    timeoutMs: params.timeoutMs,
-    concurrency: params.concurrency,
-    debug: params.debug,
-    debugLabel: "memory embeddings: gemini batch submit",
+    ...buildEmbeddingBatchGroupOptions(params, {
+      maxRequests: GEMINI_BATCH_MAX_REQUESTS,
+      debugLabel: "memory embeddings: gemini batch submit",
+    }),
     runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
       const batchInfo = await submitGeminiBatch({
         gemini: params.gemini,

View File

@@ -1,7 +1,16 @@
 import { extractBatchErrorMessage, formatUnavailableBatchError } from "./batch-error-utils.js";
 import { postJsonWithRetry } from "./batch-http.js";
 import { applyEmbeddingBatchOutputLine } from "./batch-output.js";
-import { runEmbeddingBatchGroups } from "./batch-runner.js";
+import {
+  EMBEDDING_BATCH_ENDPOINT,
+  type EmbeddingBatchStatus,
+  type ProviderBatchOutputLine,
+} from "./batch-provider-common.js";
+import {
+  buildEmbeddingBatchGroupOptions,
+  runEmbeddingBatchGroups,
+  type EmbeddingBatchExecutionParams,
+} from "./batch-runner.js";
 import { uploadBatchJsonlFile } from "./batch-upload.js";
 import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
 import type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
@@ -17,26 +26,10 @@ export type OpenAiBatchRequest = {
   };
 };
 
-export type OpenAiBatchStatus = {
-  id?: string;
-  status?: string;
-  output_file_id?: string | null;
-  error_file_id?: string | null;
-};
-
-export type OpenAiBatchOutputLine = {
-  custom_id?: string;
-  response?: {
-    status_code?: number;
-    body?: {
-      data?: Array<{ embedding?: number[]; index?: number }>;
-      error?: { message?: string };
-    };
-  };
-  error?: { message?: string };
-};
-
-export const OPENAI_BATCH_ENDPOINT = "/v1/embeddings";
+export type OpenAiBatchStatus = EmbeddingBatchStatus;
+export type OpenAiBatchOutputLine = ProviderBatchOutputLine;
+
+export const OPENAI_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT;
 
 const OPENAI_BATCH_COMPLETION_WINDOW = "24h";
 const OPENAI_BATCH_MAX_REQUESTS = 50000;
@@ -185,25 +178,18 @@ async function waitForOpenAiBatch(params: {
 }
 }
 
-export async function runOpenAiEmbeddingBatches(params: {
-  openAi: OpenAiEmbeddingClient;
-  agentId: string;
-  requests: OpenAiBatchRequest[];
-  wait: boolean;
-  pollIntervalMs: number;
-  timeoutMs: number;
-  concurrency: number;
-  debug?: (message: string, data?: Record<string, unknown>) => void;
-}): Promise<Map<string, number[]>> {
+export async function runOpenAiEmbeddingBatches(
+  params: {
+    openAi: OpenAiEmbeddingClient;
+    agentId: string;
+    requests: OpenAiBatchRequest[];
+  } & EmbeddingBatchExecutionParams,
+): Promise<Map<string, number[]>> {
   return await runEmbeddingBatchGroups({
-    requests: params.requests,
-    maxRequests: OPENAI_BATCH_MAX_REQUESTS,
-    wait: params.wait,
-    pollIntervalMs: params.pollIntervalMs,
-    timeoutMs: params.timeoutMs,
-    concurrency: params.concurrency,
-    debug: params.debug,
-    debugLabel: "memory embeddings: openai batch submit",
+    ...buildEmbeddingBatchGroupOptions(params, {
+      maxRequests: OPENAI_BATCH_MAX_REQUESTS,
+      debugLabel: "memory embeddings: openai batch submit",
+    }),
     runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
       const batchInfo = await submitOpenAiBatch({
         openAi: params.openAi,

View File

@@ -0,0 +1,12 @@
+import type { EmbeddingBatchOutputLine } from "./batch-output.js";
+
+export type EmbeddingBatchStatus = {
+  id?: string;
+  status?: string;
+  output_file_id?: string | null;
+  error_file_id?: string | null;
+};
+
+export type ProviderBatchOutputLine = EmbeddingBatchOutputLine;
+
+export const EMBEDDING_BATCH_ENDPOINT = "/v1/embeddings";
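
ProviderBatchOutputLine is just an alias for the EmbeddingBatchOutputLine that batch-output.js already exports. That shared type is not shown in this commit; judging from the identical per-provider definitions deleted from the OpenAI and Voyage modules, its shape is presumably:

// Presumed shape of EmbeddingBatchOutputLine (defined in batch-output.js,
// not part of this diff), inferred from the removed provider-local types:
type EmbeddingBatchOutputLine = {
  custom_id?: string;
  response?: {
    status_code?: number;
    body?: {
      data?: Array<{ embedding?: number[]; index?: number }>;
      error?: { message?: string };
    };
  };
  error?: { message?: string };
};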

View File

@@ -1,15 +1,23 @@
 import { splitBatchRequests } from "./batch-utils.js";
 import { runWithConcurrency } from "./internal.js";
 
-export async function runEmbeddingBatchGroups<TRequest>(params: {
-  requests: TRequest[];
-  maxRequests: number;
+export type EmbeddingBatchExecutionParams = {
   wait: boolean;
   pollIntervalMs: number;
   timeoutMs: number;
   concurrency: number;
-  debugLabel: string;
   debug?: (message: string, data?: Record<string, unknown>) => void;
+};
+
+export async function runEmbeddingBatchGroups<TRequest>(params: {
+  requests: TRequest[];
+  maxRequests: number;
+  wait: EmbeddingBatchExecutionParams["wait"];
+  pollIntervalMs: EmbeddingBatchExecutionParams["pollIntervalMs"];
+  timeoutMs: EmbeddingBatchExecutionParams["timeoutMs"];
+  concurrency: EmbeddingBatchExecutionParams["concurrency"];
+  debugLabel: string;
+  debug?: EmbeddingBatchExecutionParams["debug"];
   runGroup: (args: {
     group: TRequest[];
     groupIndex: number;
@@ -38,3 +46,19 @@ export async function runEmbeddingBatchGroups<TRequest>(params: {
   await runWithConcurrency(tasks, params.concurrency);
   return byCustomId;
 }
+
+export function buildEmbeddingBatchGroupOptions<TRequest>(
+  params: { requests: TRequest[] } & EmbeddingBatchExecutionParams,
+  options: { maxRequests: number; debugLabel: string },
+) {
+  return {
+    requests: params.requests,
+    maxRequests: options.maxRequests,
+    wait: params.wait,
+    pollIntervalMs: params.pollIntervalMs,
+    timeoutMs: params.timeoutMs,
+    concurrency: params.concurrency,
+    debug: params.debug,
+    debugLabel: options.debugLabel,
+  };
+}
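
With the helper in place, a provider module only contributes its request type, its batch limit and debug label, and a runGroup callback; everything else is shared. A hedged sketch of what a hypothetical fourth provider could look like (AcmeBatchRequest and submitAcmeBatch are invented for illustration):

// Hypothetical provider wiring on top of the shared scaffolding; the Acme
// names below are illustrative and not part of this commit.
export async function runAcmeEmbeddingBatches(
  params: {
    agentId: string;
    requests: AcmeBatchRequest[];
  } & EmbeddingBatchExecutionParams,
): Promise<Map<string, number[]>> {
  return await runEmbeddingBatchGroups({
    ...buildEmbeddingBatchGroupOptions(params, {
      maxRequests: 50_000, // per-provider limit
      debugLabel: "memory embeddings: acme batch submit",
    }),
    // Submit one chunk of requests and record embeddings into byCustomId.
    runGroup: async ({ group, byCustomId }) => {
      await submitAcmeBatch({ agentId: params.agentId, requests: group, byCustomId });
    },
  });
}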

View File

@@ -3,7 +3,16 @@ import { Readable } from "node:stream";
 import { extractBatchErrorMessage, formatUnavailableBatchError } from "./batch-error-utils.js";
 import { postJsonWithRetry } from "./batch-http.js";
 import { applyEmbeddingBatchOutputLine } from "./batch-output.js";
-import { runEmbeddingBatchGroups } from "./batch-runner.js";
+import {
+  EMBEDDING_BATCH_ENDPOINT,
+  type EmbeddingBatchStatus,
+  type ProviderBatchOutputLine,
+} from "./batch-provider-common.js";
+import {
+  buildEmbeddingBatchGroupOptions,
+  runEmbeddingBatchGroups,
+  type EmbeddingBatchExecutionParams,
+} from "./batch-runner.js";
 import { uploadBatchJsonlFile } from "./batch-upload.js";
 import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
 import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
@@ -20,26 +29,10 @@ export type VoyageBatchRequest = {
   };
 };
 
-export type VoyageBatchStatus = {
-  id?: string;
-  status?: string;
-  output_file_id?: string | null;
-  error_file_id?: string | null;
-};
-
-export type VoyageBatchOutputLine = {
-  custom_id?: string;
-  response?: {
-    status_code?: number;
-    body?: {
-      data?: Array<{ embedding?: number[]; index?: number }>;
-      error?: { message?: string };
-    };
-  };
-  error?: { message?: string };
-};
-
-export const VOYAGE_BATCH_ENDPOINT = "/v1/embeddings";
+export type VoyageBatchStatus = EmbeddingBatchStatus;
+export type VoyageBatchOutputLine = ProviderBatchOutputLine;
+
+export const VOYAGE_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT;
 
 const VOYAGE_BATCH_COMPLETION_WINDOW = "12h";
 const VOYAGE_BATCH_MAX_REQUESTS = 50000;
@@ -179,25 +172,18 @@ async function waitForVoyageBatch(params: {
 }
 }
 
-export async function runVoyageEmbeddingBatches(params: {
-  client: VoyageEmbeddingClient;
-  agentId: string;
-  requests: VoyageBatchRequest[];
-  wait: boolean;
-  pollIntervalMs: number;
-  timeoutMs: number;
-  concurrency: number;
-  debug?: (message: string, data?: Record<string, unknown>) => void;
-}): Promise<Map<string, number[]>> {
+export async function runVoyageEmbeddingBatches(
+  params: {
+    client: VoyageEmbeddingClient;
+    agentId: string;
+    requests: VoyageBatchRequest[];
+  } & EmbeddingBatchExecutionParams,
+): Promise<Map<string, number[]>> {
   return await runEmbeddingBatchGroups({
-    requests: params.requests,
-    maxRequests: VOYAGE_BATCH_MAX_REQUESTS,
-    wait: params.wait,
-    pollIntervalMs: params.pollIntervalMs,
-    timeoutMs: params.timeoutMs,
-    concurrency: params.concurrency,
-    debug: params.debug,
-    debugLabel: "memory embeddings: voyage batch submit",
+    ...buildEmbeddingBatchGroupOptions(params, {
+      maxRequests: VOYAGE_BATCH_MAX_REQUESTS,
+      debugLabel: "memory embeddings: voyage batch submit",
+    }),
     runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
       const batchInfo = await submitVoyageBatch({
         client: params.client,