refactor(agents): share turn validation skeleton

This commit is contained in:
Peter Steinberger
2026-02-15 05:55:36 +00:00
parent 485b78bb94
commit 806c8b3129

View File

@@ -1,11 +1,14 @@
import type { AgentMessage } from "@mariozechner/pi-agent-core"; import type { AgentMessage } from "@mariozechner/pi-agent-core";
/** function validateTurnsWithConsecutiveMerge<TRole extends "assistant" | "user">(params: {
* Validates and fixes conversation turn sequences for Gemini API. messages: AgentMessage[];
* Gemini requires strict alternating user→assistant→tool→user pattern. role: TRole;
* Merges consecutive assistant messages together. merge: (
*/ previous: Extract<AgentMessage, { role: TRole }>,
export function validateGeminiTurns(messages: AgentMessage[]): AgentMessage[] { current: Extract<AgentMessage, { role: TRole }>,
) => Extract<AgentMessage, { role: TRole }>;
}): AgentMessage[] {
const { messages, role, merge } = params;
if (!Array.isArray(messages) || messages.length === 0) { if (!Array.isArray(messages) || messages.length === 0) {
return messages; return messages;
} }
@@ -25,28 +28,13 @@ export function validateGeminiTurns(messages: AgentMessage[]): AgentMessage[] {
continue; continue;
} }
if (msgRole === lastRole && lastRole === "assistant") { if (msgRole === lastRole && lastRole === role) {
const lastMsg = result[result.length - 1]; const lastMsg = result[result.length - 1];
const currentMsg = msg as Extract<AgentMessage, { role: "assistant" }>; const currentMsg = msg as Extract<AgentMessage, { role: TRole }>;
if (lastMsg && typeof lastMsg === "object") { if (lastMsg && typeof lastMsg === "object") {
const lastAsst = lastMsg as Extract<AgentMessage, { role: "assistant" }>; const lastTyped = lastMsg as Extract<AgentMessage, { role: TRole }>;
const mergedContent = [ result[result.length - 1] = merge(lastTyped, currentMsg);
...(Array.isArray(lastAsst.content) ? lastAsst.content : []),
...(Array.isArray(currentMsg.content) ? currentMsg.content : []),
];
const merged: Extract<AgentMessage, { role: "assistant" }> = {
...lastAsst,
content: mergedContent,
...(currentMsg.usage && { usage: currentMsg.usage }),
...(currentMsg.stopReason && { stopReason: currentMsg.stopReason }),
...(currentMsg.errorMessage && {
errorMessage: currentMsg.errorMessage,
}),
};
result[result.length - 1] = merged;
continue; continue;
} }
} }
@@ -58,6 +46,38 @@ export function validateGeminiTurns(messages: AgentMessage[]): AgentMessage[] {
return result; return result;
} }
function mergeConsecutiveAssistantTurns(
previous: Extract<AgentMessage, { role: "assistant" }>,
current: Extract<AgentMessage, { role: "assistant" }>,
): Extract<AgentMessage, { role: "assistant" }> {
const mergedContent = [
...(Array.isArray(previous.content) ? previous.content : []),
...(Array.isArray(current.content) ? current.content : []),
];
return {
...previous,
content: mergedContent,
...(current.usage && { usage: current.usage }),
...(current.stopReason && { stopReason: current.stopReason }),
...(current.errorMessage && {
errorMessage: current.errorMessage,
}),
};
}
/**
* Validates and fixes conversation turn sequences for Gemini API.
* Gemini requires strict alternating user→assistant→tool→user pattern.
* Merges consecutive assistant messages together.
*/
export function validateGeminiTurns(messages: AgentMessage[]): AgentMessage[] {
return validateTurnsWithConsecutiveMerge({
messages,
role: "assistant",
merge: mergeConsecutiveAssistantTurns,
});
}
export function mergeConsecutiveUserTurns( export function mergeConsecutiveUserTurns(
previous: Extract<AgentMessage, { role: "user" }>, previous: Extract<AgentMessage, { role: "user" }>,
current: Extract<AgentMessage, { role: "user" }>, current: Extract<AgentMessage, { role: "user" }>,
@@ -80,40 +100,9 @@ export function mergeConsecutiveUserTurns(
* Merges consecutive user messages together. * Merges consecutive user messages together.
*/ */
export function validateAnthropicTurns(messages: AgentMessage[]): AgentMessage[] { export function validateAnthropicTurns(messages: AgentMessage[]): AgentMessage[] {
if (!Array.isArray(messages) || messages.length === 0) { return validateTurnsWithConsecutiveMerge({
return messages; messages,
} role: "user",
merge: mergeConsecutiveUserTurns,
const result: AgentMessage[] = []; });
let lastRole: string | undefined;
for (const msg of messages) {
if (!msg || typeof msg !== "object") {
result.push(msg);
continue;
}
const msgRole = (msg as { role?: unknown }).role as string | undefined;
if (!msgRole) {
result.push(msg);
continue;
}
if (msgRole === lastRole && lastRole === "user") {
const lastMsg = result[result.length - 1];
const currentMsg = msg as Extract<AgentMessage, { role: "user" }>;
if (lastMsg && typeof lastMsg === "object") {
const lastUser = lastMsg as Extract<AgentMessage, { role: "user" }>;
const merged = mergeConsecutiveUserTurns(lastUser, currentMsg);
result[result.length - 1] = merged;
continue;
}
}
result.push(msg);
lastRole = msgRole;
}
return result;
} }