diff --git a/src/agents/pi-embedded-helpers/turns.ts b/src/agents/pi-embedded-helpers/turns.ts index ed927d32cad..f6dddb20a04 100644 --- a/src/agents/pi-embedded-helpers/turns.ts +++ b/src/agents/pi-embedded-helpers/turns.ts @@ -1,11 +1,14 @@ import type { AgentMessage } from "@mariozechner/pi-agent-core"; -/** - * Validates and fixes conversation turn sequences for Gemini API. - * Gemini requires strict alternating user→assistant→tool→user pattern. - * Merges consecutive assistant messages together. - */ -export function validateGeminiTurns(messages: AgentMessage[]): AgentMessage[] { +function validateTurnsWithConsecutiveMerge(params: { + messages: AgentMessage[]; + role: TRole; + merge: ( + previous: Extract, + current: Extract, + ) => Extract; +}): AgentMessage[] { + const { messages, role, merge } = params; if (!Array.isArray(messages) || messages.length === 0) { return messages; } @@ -25,28 +28,13 @@ export function validateGeminiTurns(messages: AgentMessage[]): AgentMessage[] { continue; } - if (msgRole === lastRole && lastRole === "assistant") { + if (msgRole === lastRole && lastRole === role) { const lastMsg = result[result.length - 1]; - const currentMsg = msg as Extract; + const currentMsg = msg as Extract; if (lastMsg && typeof lastMsg === "object") { - const lastAsst = lastMsg as Extract; - const mergedContent = [ - ...(Array.isArray(lastAsst.content) ? lastAsst.content : []), - ...(Array.isArray(currentMsg.content) ? currentMsg.content : []), - ]; - - const merged: Extract = { - ...lastAsst, - content: mergedContent, - ...(currentMsg.usage && { usage: currentMsg.usage }), - ...(currentMsg.stopReason && { stopReason: currentMsg.stopReason }), - ...(currentMsg.errorMessage && { - errorMessage: currentMsg.errorMessage, - }), - }; - - result[result.length - 1] = merged; + const lastTyped = lastMsg as Extract; + result[result.length - 1] = merge(lastTyped, currentMsg); continue; } } @@ -58,6 +46,38 @@ export function validateGeminiTurns(messages: AgentMessage[]): AgentMessage[] { return result; } +function mergeConsecutiveAssistantTurns( + previous: Extract, + current: Extract, +): Extract { + const mergedContent = [ + ...(Array.isArray(previous.content) ? previous.content : []), + ...(Array.isArray(current.content) ? current.content : []), + ]; + return { + ...previous, + content: mergedContent, + ...(current.usage && { usage: current.usage }), + ...(current.stopReason && { stopReason: current.stopReason }), + ...(current.errorMessage && { + errorMessage: current.errorMessage, + }), + }; +} + +/** + * Validates and fixes conversation turn sequences for Gemini API. + * Gemini requires strict alternating user→assistant→tool→user pattern. + * Merges consecutive assistant messages together. + */ +export function validateGeminiTurns(messages: AgentMessage[]): AgentMessage[] { + return validateTurnsWithConsecutiveMerge({ + messages, + role: "assistant", + merge: mergeConsecutiveAssistantTurns, + }); +} + export function mergeConsecutiveUserTurns( previous: Extract, current: Extract, @@ -80,40 +100,9 @@ export function mergeConsecutiveUserTurns( * Merges consecutive user messages together. */ export function validateAnthropicTurns(messages: AgentMessage[]): AgentMessage[] { - if (!Array.isArray(messages) || messages.length === 0) { - return messages; - } - - const result: AgentMessage[] = []; - let lastRole: string | undefined; - - for (const msg of messages) { - if (!msg || typeof msg !== "object") { - result.push(msg); - continue; - } - - const msgRole = (msg as { role?: unknown }).role as string | undefined; - if (!msgRole) { - result.push(msg); - continue; - } - - if (msgRole === lastRole && lastRole === "user") { - const lastMsg = result[result.length - 1]; - const currentMsg = msg as Extract; - - if (lastMsg && typeof lastMsg === "object") { - const lastUser = lastMsg as Extract; - const merged = mergeConsecutiveUserTurns(lastUser, currentMsg); - result[result.length - 1] = merged; - continue; - } - } - - result.push(msg); - lastRole = msgRole; - } - - return result; + return validateTurnsWithConsecutiveMerge({ + messages, + role: "user", + merge: mergeConsecutiveUserTurns, + }); }