diff --git a/src/memory/batch-gemini.ts b/src/memory/batch-gemini.ts
index 50f3b3f9460..998f283b676 100644
--- a/src/memory/batch-gemini.ts
+++ b/src/memory/batch-gemini.ts
@@ -1,4 +1,8 @@
-import { runEmbeddingBatchGroups } from "./batch-runner.js";
+import {
+  buildEmbeddingBatchGroupOptions,
+  runEmbeddingBatchGroups,
+  type EmbeddingBatchExecutionParams,
+} from "./batch-runner.js";
 import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
 import { debugEmbeddingsLog } from "./embeddings-debug.js";
 import type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
@@ -261,25 +265,18 @@ async function waitForGeminiBatch(params: {
   }
 }
 
-export async function runGeminiEmbeddingBatches(params: {
-  gemini: GeminiEmbeddingClient;
-  agentId: string;
-  requests: GeminiBatchRequest[];
-  wait: boolean;
-  pollIntervalMs: number;
-  timeoutMs: number;
-  concurrency: number;
-  debug?: (message: string, data?: Record<string, unknown>) => void;
-}): Promise<Map<string, number[]>> {
+export async function runGeminiEmbeddingBatches(
+  params: {
+    gemini: GeminiEmbeddingClient;
+    agentId: string;
+    requests: GeminiBatchRequest[];
+  } & EmbeddingBatchExecutionParams,
+): Promise<Map<string, number[]>> {
   return await runEmbeddingBatchGroups({
-    requests: params.requests,
-    maxRequests: GEMINI_BATCH_MAX_REQUESTS,
-    wait: params.wait,
-    pollIntervalMs: params.pollIntervalMs,
-    timeoutMs: params.timeoutMs,
-    concurrency: params.concurrency,
-    debug: params.debug,
-    debugLabel: "memory embeddings: gemini batch submit",
+    ...buildEmbeddingBatchGroupOptions(params, {
+      maxRequests: GEMINI_BATCH_MAX_REQUESTS,
+      debugLabel: "memory embeddings: gemini batch submit",
+    }),
     runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
       const batchInfo = await submitGeminiBatch({
         gemini: params.gemini,
diff --git a/src/memory/batch-openai.ts b/src/memory/batch-openai.ts
index c1a0a97c4db..158b75faf1f 100644
--- a/src/memory/batch-openai.ts
+++ b/src/memory/batch-openai.ts
@@ -1,7 +1,16 @@
 import { extractBatchErrorMessage, formatUnavailableBatchError } from "./batch-error-utils.js";
 import { postJsonWithRetry } from "./batch-http.js";
 import { applyEmbeddingBatchOutputLine } from "./batch-output.js";
-import { runEmbeddingBatchGroups } from "./batch-runner.js";
+import {
+  EMBEDDING_BATCH_ENDPOINT,
+  type EmbeddingBatchStatus,
+  type ProviderBatchOutputLine,
+} from "./batch-provider-common.js";
+import {
+  buildEmbeddingBatchGroupOptions,
+  runEmbeddingBatchGroups,
+  type EmbeddingBatchExecutionParams,
+} from "./batch-runner.js";
 import { uploadBatchJsonlFile } from "./batch-upload.js";
 import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
 import type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
@@ -17,26 +26,10 @@ export type OpenAiBatchRequest = {
   };
 };
 
-export type OpenAiBatchStatus = {
-  id?: string;
-  status?: string;
-  output_file_id?: string | null;
-  error_file_id?: string | null;
-};
+export type OpenAiBatchStatus = EmbeddingBatchStatus;
+export type OpenAiBatchOutputLine = ProviderBatchOutputLine;
 
-export type OpenAiBatchOutputLine = {
-  custom_id?: string;
-  response?: {
-    status_code?: number;
-    body?: {
-      data?: Array<{ embedding?: number[]; index?: number }>;
-      error?: { message?: string };
-    };
-  };
-  error?: { message?: string };
-};
-
-export const OPENAI_BATCH_ENDPOINT = "/v1/embeddings";
+export const OPENAI_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT;
 
 const OPENAI_BATCH_COMPLETION_WINDOW = "24h";
 const OPENAI_BATCH_MAX_REQUESTS = 50000;
@@ -185,25 +178,18 @@ async function waitForOpenAiBatch(params: {
   }
 }
 
-export async function runOpenAiEmbeddingBatches(params: {
-  openAi: OpenAiEmbeddingClient;
-  agentId: string;
-  requests: OpenAiBatchRequest[];
-  wait: boolean;
-  pollIntervalMs: number;
-  timeoutMs: number;
-  concurrency: number;
-  debug?: (message: string, data?: Record<string, unknown>) => void;
-}): Promise<Map<string, number[]>> {
+export async function runOpenAiEmbeddingBatches(
+  params: {
+    openAi: OpenAiEmbeddingClient;
+    agentId: string;
+    requests: OpenAiBatchRequest[];
+  } & EmbeddingBatchExecutionParams,
+): Promise<Map<string, number[]>> {
   return await runEmbeddingBatchGroups({
-    requests: params.requests,
-    maxRequests: OPENAI_BATCH_MAX_REQUESTS,
-    wait: params.wait,
-    pollIntervalMs: params.pollIntervalMs,
-    timeoutMs: params.timeoutMs,
-    concurrency: params.concurrency,
-    debug: params.debug,
-    debugLabel: "memory embeddings: openai batch submit",
+    ...buildEmbeddingBatchGroupOptions(params, {
+      maxRequests: OPENAI_BATCH_MAX_REQUESTS,
+      debugLabel: "memory embeddings: openai batch submit",
+    }),
     runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
       const batchInfo = await submitOpenAiBatch({
         openAi: params.openAi,
diff --git a/src/memory/batch-provider-common.ts b/src/memory/batch-provider-common.ts
new file mode 100644
index 00000000000..878387ffd6d
--- /dev/null
+++ b/src/memory/batch-provider-common.ts
@@ -0,0 +1,12 @@
+import type { EmbeddingBatchOutputLine } from "./batch-output.js";
+
+export type EmbeddingBatchStatus = {
+  id?: string;
+  status?: string;
+  output_file_id?: string | null;
+  error_file_id?: string | null;
+};
+
+export type ProviderBatchOutputLine = EmbeddingBatchOutputLine;
+
+export const EMBEDDING_BATCH_ENDPOINT = "/v1/embeddings";
diff --git a/src/memory/batch-runner.ts b/src/memory/batch-runner.ts
index 52045a3a268..aa1785095bb 100644
--- a/src/memory/batch-runner.ts
+++ b/src/memory/batch-runner.ts
@@ -1,15 +1,23 @@
 import { splitBatchRequests } from "./batch-utils.js";
 import { runWithConcurrency } from "./internal.js";
 
-export async function runEmbeddingBatchGroups<TRequest>(params: {
-  requests: TRequest[];
-  maxRequests: number;
+export type EmbeddingBatchExecutionParams = {
   wait: boolean;
   pollIntervalMs: number;
   timeoutMs: number;
   concurrency: number;
-  debugLabel: string;
   debug?: (message: string, data?: Record<string, unknown>) => void;
+};
+
+export async function runEmbeddingBatchGroups<TRequest>(params: {
+  requests: TRequest[];
+  maxRequests: number;
+  wait: EmbeddingBatchExecutionParams["wait"];
+  pollIntervalMs: EmbeddingBatchExecutionParams["pollIntervalMs"];
+  timeoutMs: EmbeddingBatchExecutionParams["timeoutMs"];
+  concurrency: EmbeddingBatchExecutionParams["concurrency"];
+  debugLabel: string;
+  debug?: EmbeddingBatchExecutionParams["debug"];
   runGroup: (args: {
     group: TRequest[];
     groupIndex: number;
@@ -38,3 +46,19 @@ export async function runEmbeddingBatchGroups<TRequest>(params: {
   await runWithConcurrency(tasks, params.concurrency);
   return byCustomId;
 }
+
+export function buildEmbeddingBatchGroupOptions<TRequest>(
+  params: { requests: TRequest[] } & EmbeddingBatchExecutionParams,
+  options: { maxRequests: number; debugLabel: string },
+) {
+  return {
+    requests: params.requests,
+    maxRequests: options.maxRequests,
+    wait: params.wait,
+    pollIntervalMs: params.pollIntervalMs,
+    timeoutMs: params.timeoutMs,
+    concurrency: params.concurrency,
+    debug: params.debug,
+    debugLabel: options.debugLabel,
+  };
+}
diff --git a/src/memory/batch-voyage.ts b/src/memory/batch-voyage.ts
index 322adedc311..07722ac19f2 100644
--- a/src/memory/batch-voyage.ts
+++ b/src/memory/batch-voyage.ts
@@ -3,7 +3,16 @@ import { Readable } from "node:stream";
 import { extractBatchErrorMessage, formatUnavailableBatchError } from "./batch-error-utils.js";
 import { postJsonWithRetry } from "./batch-http.js";
 import { applyEmbeddingBatchOutputLine } from "./batch-output.js";
-import { runEmbeddingBatchGroups } from "./batch-runner.js";
+import {
+  EMBEDDING_BATCH_ENDPOINT,
+  type EmbeddingBatchStatus,
+  type ProviderBatchOutputLine,
+} from "./batch-provider-common.js";
+import {
+  buildEmbeddingBatchGroupOptions,
+  runEmbeddingBatchGroups,
+  type EmbeddingBatchExecutionParams,
+} from "./batch-runner.js";
 import { uploadBatchJsonlFile } from "./batch-upload.js";
 import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
 import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
@@ -20,26 +29,10 @@ export type VoyageBatchRequest = {
   };
 };
 
-export type VoyageBatchStatus = {
-  id?: string;
-  status?: string;
-  output_file_id?: string | null;
-  error_file_id?: string | null;
-};
+export type VoyageBatchStatus = EmbeddingBatchStatus;
+export type VoyageBatchOutputLine = ProviderBatchOutputLine;
 
-export type VoyageBatchOutputLine = {
-  custom_id?: string;
-  response?: {
-    status_code?: number;
-    body?: {
-      data?: Array<{ embedding?: number[]; index?: number }>;
-      error?: { message?: string };
-    };
-  };
-  error?: { message?: string };
-};
-
-export const VOYAGE_BATCH_ENDPOINT = "/v1/embeddings";
+export const VOYAGE_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT;
 
 const VOYAGE_BATCH_COMPLETION_WINDOW = "12h";
 const VOYAGE_BATCH_MAX_REQUESTS = 50000;
@@ -179,25 +172,18 @@ async function waitForVoyageBatch(params: {
   }
 }
 
-export async function runVoyageEmbeddingBatches(params: {
-  client: VoyageEmbeddingClient;
-  agentId: string;
-  requests: VoyageBatchRequest[];
-  wait: boolean;
-  pollIntervalMs: number;
-  timeoutMs: number;
-  concurrency: number;
-  debug?: (message: string, data?: Record<string, unknown>) => void;
-}): Promise<Map<string, number[]>> {
+export async function runVoyageEmbeddingBatches(
+  params: {
+    client: VoyageEmbeddingClient;
+    agentId: string;
+    requests: VoyageBatchRequest[];
+  } & EmbeddingBatchExecutionParams,
+): Promise<Map<string, number[]>> {
   return await runEmbeddingBatchGroups({
-    requests: params.requests,
-    maxRequests: VOYAGE_BATCH_MAX_REQUESTS,
-    wait: params.wait,
-    pollIntervalMs: params.pollIntervalMs,
-    timeoutMs: params.timeoutMs,
-    concurrency: params.concurrency,
-    debug: params.debug,
-    debugLabel: "memory embeddings: voyage batch submit",
+    ...buildEmbeddingBatchGroupOptions(params, {
+      maxRequests: VOYAGE_BATCH_MAX_REQUESTS,
+      debugLabel: "memory embeddings: voyage batch submit",
+    }),
     runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
       const batchInfo = await submitVoyageBatch({
         client: params.client,