From 48a7957294646f524ca6cf45564fc7dfb2a48f36 Mon Sep 17 00:00:00 2001 From: David Cramer Date: Tue, 17 Mar 2026 11:18:24 -0700 Subject: [PATCH 1/2] Add multimodal message-chain eval support --- src/evaluate/index.test.ts | 106 ++++++- src/evaluate/index.ts | 39 ++- src/formatScores.test.ts | 43 +++ src/index.ts | 239 ++++++---------- src/messages.test.ts | 163 +++++++++++ src/messages.ts | 558 +++++++++++++++++++++++++++++++++++++ 6 files changed, 976 insertions(+), 172 deletions(-) create mode 100644 src/messages.test.ts create mode 100644 src/messages.ts diff --git a/src/evaluate/index.test.ts b/src/evaluate/index.test.ts index bfd954c..1c6a5e6 100644 --- a/src/evaluate/index.test.ts +++ b/src/evaluate/index.test.ts @@ -152,8 +152,110 @@ describe("evaluate", () => { }); const call = mockGenerateObject.mock.calls[0][0]; - expect(call.prompt).toContain("the task output"); - expect(call.prompt).toContain("must mention specific details"); + expect(call.messages[0].content[0].text).toContain("[ASSISTANT]"); + expect(call.messages[0].content[0].text).toContain("the task output"); + expect(call.messages[1].content).toContain("must mention specific details"); + }); + + test("passes multimodal message chains to the judge", async () => { + mockGenerateObject.mockResolvedValueOnce({ + object: { answer: "A", rationale: "Handled the transcript correctly" }, + } as any); + + const ctx = makeContext(); + await _evaluate(ctx, { + task: async () => ({ + messages: [ + { + role: "user", + parts: [ + { type: "text", text: "What is shown here?" }, + { + type: "image", + image: "data:image/png;base64,abc123", + mediaType: "image/png", + }, + ], + }, + { + role: "assistant", + parts: [{ type: "text", text: "It is a cat." }], + }, + ], + }), + criteria: "The answer should identify the subject of the image", + threshold: 1, + }); + + const call = mockGenerateObject.mock.calls[0][0]; + const transcriptText = call.messages[0].content + .filter((part: any) => part.type === "text") + .map((part: any) => part.text) + .join(""); + expect(transcriptText).toContain("[USER]\nWhat is shown here?"); + expect(transcriptText).toContain("[image image/png]"); + expect(transcriptText).toContain("[ASSISTANT]\nIt is a cat."); + expect(call.messages[0].content).toContainEqual({ + type: "image", + image: "data:image/png;base64,abc123", + mediaType: "image/png", + }); + }); + + test("does not pass tool metadata to the judge by default", async () => { + mockGenerateObject.mockResolvedValueOnce({ + object: { answer: "A", rationale: "Focused on the visible transcript" }, + } as any); + + const ctx = makeContext(); + await _evaluate(ctx, { + task: async () => ({ + messages: [ + { + role: "user", + parts: [{ type: "text", text: "What is the weather?" }], + }, + { + role: "assistant", + parts: [ + { + type: "tool-call", + toolName: "getWeather", + toolCallId: "call-1", + input: { location: "Seattle" }, + }, + ], + }, + { + role: "tool", + parts: [ + { + type: "tool-result", + toolName: "getWeather", + toolCallId: "call-1", + output: { temperature: 72 }, + }, + ], + }, + { + role: "assistant", + parts: [{ type: "text", text: "It is 72F in Seattle." }], + }, + ], + }), + criteria: "The answer should report the weather to the user", + threshold: 1, + }); + + const call = mockGenerateObject.mock.calls[0][0]; + const transcriptText = call.messages[0].content + .filter((part: any) => part.type === "text") + .map((part: any) => part.text) + .join(""); + expect(transcriptText).toContain("[USER]\nWhat is the weather?"); + expect(transcriptText).toContain("[ASSISTANT]\nIt is 72F in Seattle."); + expect(transcriptText).not.toContain("tool-call"); + expect(transcriptText).not.toContain("tool-result"); }); test("maps all answer choices to correct scores", async () => { diff --git a/src/evaluate/index.ts b/src/evaluate/index.ts index b48088e..6ff73c6 100644 --- a/src/evaluate/index.ts +++ b/src/evaluate/index.ts @@ -1,7 +1,12 @@ import { generateObject } from "ai"; import { z } from "zod"; import { assert, test } from "vitest"; -import { wrapText } from "../wrapText"; +import { + type TaskResult, + formatEvalValue, + normalizeEvaluateOutput, + toJudgeUserMessage, +} from "../messages"; type LanguageModel = Parameters[0]["model"]; @@ -13,15 +18,11 @@ export function configure(opts: { model: LanguageModel }) { const EVAL_SYSTEM = `You are assessing a submitted output based on a given criterion. Ignore differences in style, grammar, punctuation, or length. Focus only on whether the criterion is met.`; -const EVAL_PROMPT = (output: string, criteria: string) => ` -${output} - - - +const EVAL_PROMPT = (criteria: string) => ` ${criteria} -Does the submission meet the criteria? Select one option: +Does the conversation transcript meet the criteria? Select one option: (A) The criteria is fully met with no issues (B) The criteria is mostly met with minor gaps (C) The criteria is partially met with notable gaps @@ -37,7 +38,7 @@ const CHOICE_SCORES: Record = { }; interface EvaluateOptions { - task: () => Promise; + task: () => Promise; criteria: string; threshold?: number; } @@ -57,9 +58,11 @@ export async function _evaluate( ); } - let output: string; + let taskOutput: string | TaskResult; + let evaluationOutput: ReturnType; try { - output = await opts.task(); + taskOutput = await opts.task(); + evaluationOutput = normalizeEvaluateOutput(taskOutput); } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error); ctx.task.meta.eval = { @@ -84,7 +87,13 @@ export async function _evaluate( rationale: z.string(), }), system: EVAL_SYSTEM, - prompt: EVAL_PROMPT(output, opts.criteria), + messages: [ + toJudgeUserMessage(evaluationOutput.messages), + { + role: "user", + content: EVAL_PROMPT(opts.criteria), + }, + ], })); } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error); @@ -118,7 +127,13 @@ export async function _evaluate( if (score < threshold) { assert( false, - `Score: ${score} (${object.answer}) below threshold: ${threshold}\n\n## Output:\n${wrapText(output)}\n\n## Rationale:\n${wrapText(object.rationale)}`, + `Score: ${score} (${object.answer}) below threshold: ${threshold}\n\n## Output:\n${formatEvalValue( + typeof taskOutput === "string" + ? taskOutput + : "result" in taskOutput && taskOutput.result !== undefined + ? taskOutput.result + : taskOutput.messages, + )}\n\n## Rationale:\n${formatEvalValue(object.rationale)}`, ); } } diff --git a/src/formatScores.test.ts b/src/formatScores.test.ts index 81d69fc..ba4c23b 100644 --- a/src/formatScores.test.ts +++ b/src/formatScores.test.ts @@ -72,4 +72,47 @@ describe("formatScores", () => { # Scorer B [0.8]" `); }); + + it("should format message-chain outputs", () => { + const scores = [ + { + name: "Scorer A", + score: 0.2, + metadata: { + rationale: "Image description was incorrect", + output: [ + { + role: "assistant", + parts: [ + { type: "text", text: "A dog on a sofa." }, + { + type: "image", + image: "data:image/png;base64,abc", + mediaType: "image/png", + }, + ], + }, + ], + }, + }, + ]; + + const result = formatScores(scores); + + expect(result).toMatchInlineSnapshot(` + "# Scorer A [0.2] + + ## Rationale + + Image description was incorrect + + ## Response + + ## assistant + + A dog on a sofa. + + [image image/png]" + `); + }); }); diff --git a/src/index.ts b/src/index.ts index d77425d..b90ff6a 100644 --- a/src/index.ts +++ b/src/index.ts @@ -7,44 +7,24 @@ import { test, } from "vitest"; import "vitest"; +import { + type EvalDataInput, + type EvalMessage, + type TaskInput, + type TaskResult, + type ToolCall, + formatEvalValue, + getDefaultTestName, + getTaskInput, + normalizeScorerPayload, +} from "./messages"; import { wrapText } from "./wrapText"; -/** - * Represents a tool/function call made during task execution. - * Supports various LLM provider formats and use cases. - */ -export type ToolCall = { - // Core fields (required for basic usage) - name: string; - arguments?: Record; - - // Additional metadata - [key: string]: any; // Allow provider-specific fields -}; - -export type TaskResult = { - result: string; - toolCalls?: ToolCall[]; -}; - /** * Task function that processes an input and returns either a string result - * or a TaskResult object containing the result and any tool calls made. - * - * @param input - The input string to process - * @returns Promise resolving to either a string or TaskResult object - * - * @example - * // Simple tasks can just return a string - * const simpleTask: TaskFn = async (input) => "The answer is 42"; - * - * // Tasks that use tools should return TaskResult - * const taskWithTools: TaskFn = async (input) => ({ - * result: "The answer is 42", - * toolCalls: [{ name: "calculate", arguments: { expr: "6*7" }, result: 42 }] - * }); + * or a TaskResult object containing response messages and any tool calls made. */ -export type TaskFn = (input: string) => Promise; +export type TaskFn = (input: TaskInput) => Promise; export type Score = { score: number | null; @@ -57,6 +37,9 @@ export type Score = { export interface BaseScorerOptions { input: string; output: string; + messages: EvalMessage[]; + inputMessages: EvalMessage[]; + outputMessages: EvalMessage[]; toolCalls?: ToolCall[]; } @@ -91,53 +74,37 @@ declare module "vitest" { } } +function formatEvaluationOutputForDisplay( + taskOutput: string | TaskResult, +): string { + if (typeof taskOutput === "string") { + return formatEvalValue(taskOutput); + } + + if ("result" in taskOutput && taskOutput.result !== undefined) { + return formatEvalValue(taskOutput.result); + } + + return formatEvalValue(taskOutput.messages); +} + expect.extend({ /** * Evaluates a language model output against an expected answer using a scoring function. * * @deprecated Use describeEval() instead for better test organization and multiple scorers support - * @param expected - The expected (ground truth) answer, can be any type depending on the scorer - * @param taskFn - Async function that processes the input and returns the model output - * Can return either a string or TaskResult object with result and optional toolCalls - * @param scoreFn - Function that evaluates the model output against the expected answer - * @param threshold - Minimum acceptable score (0-1), defaults to 1.0 - * - * @example - * ```javascript - * test("checks capital of France", async () => { - * expect("What is the capital of France?").toEval( - * "Paris", - * async (input) => { - * const response = await queryLLM(input); - * // Recommended: return TaskResult - * return { - * result: response.text, - * toolCalls: response.toolCalls || [] - * }; - * }, - * checkFactuality, - * 0.8 - * ); - * }); - * ``` */ - // TODO: this needs to be support true extensibility with Eval scorers toEval: async function toEval( - input: string, + input: TaskInput, expected: any, taskFn: TaskFn, scoreFn: ScoreFn, threshold = 1.0, ) { - const { isNot } = this; - const taskOutput = await taskFn(input); - const output = - typeof taskOutput === "string" ? taskOutput : taskOutput.result; - const toolCalls = - typeof taskOutput === "object" ? taskOutput.toolCalls : undefined; + const normalized = normalizeScorerPayload(input, taskOutput); - let result = scoreFn({ input, expected, output, toolCalls }); + let result = scoreFn({ expected, ...normalized }); if (result instanceof Promise) { result = await result; } @@ -149,59 +116,6 @@ expect.extend({ }, }); -/** - * Creates a test suite for evaluating language model outputs. - * - * @param name - The name of the test suite - * @param options - Configuration options - * @param options.data - Async function that returns an array of test cases with input and any additional fields - * @param options.task - Function that processes the input and returns the model output - * Can return either a string or TaskResult object with result and optional toolCalls - * @param options.skipIf - Optional function that determines if tests should be skipped - * @param options.scorers - Array of scoring functions that evaluate model outputs - * @param options.threshold - Minimum acceptable average score (0-1), defaults to 1.0 - * @param options.timeout - Test timeout in milliseconds, defaults to 60000 (60s) - * - * @example - * ```javascript - * // Recommended: TaskResult format with tool tracking - * describeEval("capital cities test", { - * data: async () => [{ - * input: "What is the capital of France?", - * expected: "Paris" - * }], - * task: async (input) => { - * const response = await queryLLM(input); - * return { - * result: response.text, - * toolCalls: response.toolCalls || [] - * }; - * }, - * scorers: [checkFactuality], - * threshold: 0.8 - * }); - * - * // Example with tool usage evaluation - * describeEval("tool usage test", { - * data: async () => [{ - * input: "Search for weather in Seattle", - * expectedTools: [{ name: "weather_api", arguments: { location: "Seattle" } }] - * }], - * task: async (input) => { - * return { - * result: "The weather in Seattle is 65°F", - * toolCalls: [{ - * name: "weather_api", - * arguments: { location: "Seattle" }, - * result: { temp: 65, condition: "partly cloudy" } - * }] - * }; - * }, - * scorers: [ToolCallScorer()], - * threshold: 1.0 - * }); - * ``` - */ export function describeEval( name: string, { @@ -210,14 +124,12 @@ export function describeEval( skipIf, scorers, threshold = 1.0, - // increase default test timeout as 5s is usually not enough for - // a single factuality check timeout = 60000, beforeEach: beforeEachHook, afterEach: afterEachHook, }: { data: () => Promise< - Array<{ input: string; name?: string } & Record> + Array<{ name?: string } & EvalDataInput & Record> >; task: TaskFn; skipIf?: () => boolean; @@ -237,49 +149,60 @@ export function describeEval( } const testFn = skipIf ? test.skipIf(skipIf()) : test; - // TODO: should data just be a generator? - for (const { input, name: testName, ...params } of await data()) { + for (const testCase of await data()) { + const { + input, + messages, + name: testName, + ...params + } = testCase as { + input?: string; + messages?: EvalMessage[]; + name?: string; + } & Record; + + const taskInput = getTaskInput(input, messages); + testFn( - testName ?? input, + testName ?? getDefaultTestName(taskInput), { timeout, }, async ({ task: testTask }) => { - const taskOutput = await task(input); - const output = - typeof taskOutput === "string" ? taskOutput : taskOutput.result; - const toolCalls = - typeof taskOutput === "object" ? taskOutput.toolCalls : undefined; + const taskOutput = await task(taskInput); + const normalized = normalizeScorerPayload(taskInput, taskOutput); const scores = await Promise.all( scorers.map((scorer) => { - const result = scorer({ input, ...params, output, toolCalls }); + const result = scorer({ ...params, ...normalized }); if (result instanceof Promise) { return result; } - return new Promise((resolve) => resolve(result)); + return Promise.resolve(result); }), ); - const scoresWithName = scores.map((s, i) => ({ - ...s, - name: scorers[i].name, + + const scoresWithName = scores.map((score, index) => ({ + ...score, + name: scorers[index].name, })); const avgScore = - scores.reduce((acc, s) => acc + (s.score ?? 0), 0) / scores.length; + scores.reduce((acc, score) => acc + (score.score ?? 0), 0) / + scores.length; testTask.meta.eval = { scores: scoresWithName, avgScore, - ...(toolCalls && { toolCalls }), + ...(normalized.toolCalls && { toolCalls: normalized.toolCalls }), }; if (threshold) { assert( avgScore >= threshold, - `Score: ${avgScore} below threshold: ${threshold}\n\n## Output:\n${wrapText(output)}\n\n${formatScores( - scoresWithName, - )}`, + `Score: ${avgScore} below threshold: ${threshold}\n\n## Output:\n${formatEvaluationOutputForDisplay( + taskOutput, + )}\n\n${formatScores(scoresWithName)}`, ); } }, @@ -291,27 +214,20 @@ export function describeEval( export function formatScores(scores: (Score & { name: string })[]) { return scores .sort((a, b) => (a.score ?? 0) - (b.score ?? 0)) - .map((s) => { - const scoreLine = `# ${s.name || "Unknown"} [${(s.score ?? 0).toFixed(1)}]`; + .map((score) => { + const scoreLine = `# ${score.name || "Unknown"} [${(score.score ?? 0).toFixed(1)}]`; if ( - ((s.score ?? 0) < 1.0 && s.metadata?.rationale) || - s.metadata?.output + ((score.score ?? 0) < 1.0 && score.metadata?.rationale) || + score.metadata?.output !== undefined ) { - // Format output - handle both strings and objects - let formattedOutput = ""; - if (s.metadata?.output !== undefined) { - const output = s.metadata.output; - if (typeof output === "string") { - formattedOutput = `\n\n## Response\n\n${wrapText(output)}`; - } else { - // For objects, stringify with proper formatting - formattedOutput = `\n\n## Response\n\n${wrapText(JSON.stringify(output, null, 2))}`; - } - } + const formattedOutput = + score.metadata?.output !== undefined + ? `\n\n## Response\n\n${formatEvalValue(score.metadata.output)}` + : ""; return `${scoreLine}${ - s.metadata?.rationale - ? `\n\n## Rationale\n\n${wrapText(s.metadata.rationale)}` + score.metadata?.rationale + ? `\n\n## Rationale\n\n${wrapText(score.metadata.rationale)}` : "" }${formattedOutput}`; } @@ -321,8 +237,15 @@ export function formatScores(scores: (Score & { name: string })[]) { } export { wrapText } from "./wrapText"; +export type { + EvalDataInput, + EvalMessage, + EvalPart, + TaskInput, + TaskResult, + ToolCall, +} from "./messages"; -// Export built-in scorers export { ToolCallScorer, type ToolCallScorerOptions, diff --git a/src/messages.test.ts b/src/messages.test.ts new file mode 100644 index 0000000..8451387 --- /dev/null +++ b/src/messages.test.ts @@ -0,0 +1,163 @@ +import { describe, expect, test, vi } from "vitest"; +import { describeEval, ToolCallScorer } from "./index"; +import { + formatEvalValue, + getTaskInput, + normalizeScorerPayload, + type EvalMessage, +} from "./messages"; + +const multimodalInput: EvalMessage[] = [ + { + role: "system", + parts: [{ type: "text", text: "Answer concisely." }], + }, + { + role: "user", + parts: [ + { type: "text", text: "Describe this image" }, + { + type: "image", + image: "data:image/png;base64,abc123", + mediaType: "image/png", + }, + ], + }, +]; + +const multimodalOutput: EvalMessage[] = [ + { + role: "assistant", + parts: [{ type: "text", text: "A cat sitting on a chair." }], + }, +]; + +describe("message normalization", () => { + test("toEval passes full message chains to scorers", async () => { + const scorer = vi.fn(async (opts) => { + expect(opts.input).toBe("Answer concisely.\nDescribe this image"); + expect(opts.output).toBe("A cat sitting on a chair."); + expect(opts.inputMessages).toEqual(multimodalInput); + expect(opts.outputMessages).toEqual(multimodalOutput); + expect(opts.messages).toEqual([...multimodalInput, ...multimodalOutput]); + return { score: 1 }; + }); + + const task = vi.fn(async (input) => { + expect(input).toEqual(multimodalInput); + return { messages: multimodalOutput }; + }); + + await expect(multimodalInput).toEval( + { expected: "cat" }, + task, + scorer, + 1.0, + ); + + expect(task).toHaveBeenCalledOnce(); + expect(scorer).toHaveBeenCalledOnce(); + }); + + test("rejects eval cases that define both input and messages", () => { + expect(() => + getTaskInput("hello", [ + { role: "user", parts: [{ type: "text", text: "world" }] }, + ]), + ).toThrow( + "Each eval case must define exactly one of `input` or `messages`.", + ); + }); + + test("rejects task outputs that define both result and messages", () => { + expect(() => + normalizeScorerPayload("hello", { + result: "hi", + messages: [ + { role: "assistant", parts: [{ type: "text", text: "hi" }] }, + ], + } as any), + ).toThrow( + "Task results must define exactly one of `result` or `messages`.", + ); + }); + + test("formats transcripts safely for debug output", () => { + expect(formatEvalValue(multimodalInput)).toMatchInlineSnapshot(` + "## system + + Answer concisely. + + ## user + + Describe this image + + [image image/png]" + `); + }); +}); + +describeEval("message chain scorer payload", { + data: async () => [ + { + name: "passes full chains through describeEval", + messages: multimodalInput, + }, + ], + task: async (input) => { + expect(input).toEqual(multimodalInput); + return { messages: multimodalOutput }; + }, + scorers: [ + async (opts) => { + expect(opts.inputMessages).toEqual(multimodalInput); + expect(opts.outputMessages).toEqual(multimodalOutput); + expect(opts.messages).toEqual([...multimodalInput, ...multimodalOutput]); + expect(opts.output).toBe("A cat sitting on a chair."); + return { score: 1 }; + }, + ], +}); + +describeEval("derived tool calls from message parts", { + data: async () => [ + { + name: "tool calls are derived without an explicit toolCalls array", + input: "What is the weather in Seattle?", + expectedTools: [ + { name: "getWeather", arguments: { location: "Seattle" } }, + ], + }, + ], + task: async () => ({ + messages: [ + { + role: "assistant", + parts: [ + { + type: "tool-call", + toolName: "getWeather", + toolCallId: "call-1", + input: { location: "Seattle" }, + }, + ], + }, + { + role: "tool", + parts: [ + { + type: "tool-result", + toolName: "getWeather", + toolCallId: "call-1", + output: { temperature: 72 }, + }, + ], + }, + { + role: "assistant", + parts: [{ type: "text", text: "It is 72F in Seattle." }], + }, + ], + }), + scorers: [ToolCallScorer()], +}); diff --git a/src/messages.ts b/src/messages.ts new file mode 100644 index 0000000..74530f1 --- /dev/null +++ b/src/messages.ts @@ -0,0 +1,558 @@ +import { wrapText } from "./wrapText"; + +export type ToolCall = { + name: string; + arguments?: any; + [key: string]: any; +}; + +type EvalBasePart = { + [key: string]: any; +}; + +export type EvalTextPart = EvalBasePart & { + type: "text"; + text: string; +}; + +export type EvalImagePart = EvalBasePart & { + type: "image"; + image: unknown; + mediaType?: string; +}; + +export type EvalFilePart = EvalBasePart & { + type: "file"; + data: unknown; + mediaType: string; + filename?: string; +}; + +export type EvalReasoningPart = EvalBasePart & { + type: "reasoning"; + text: string; +}; + +export type EvalToolCallPart = EvalBasePart & { + type: "tool-call"; + toolName: string; + input?: unknown; + toolCallId?: string; +}; + +export type EvalToolResultPart = EvalBasePart & { + type: "tool-result"; + toolName: string; + output: unknown; + toolCallId?: string; +}; + +export type EvalToolErrorPart = EvalBasePart & { + type: "tool-error"; + toolName: string; + error?: unknown; + output?: unknown; + toolCallId?: string; +}; + +export type EvalSourcePart = EvalBasePart & { + type: "source"; + source?: unknown; +}; + +export type EvalPart = + | EvalTextPart + | EvalImagePart + | EvalFilePart + | EvalReasoningPart + | EvalToolCallPart + | EvalToolResultPart + | EvalToolErrorPart + | EvalSourcePart; + +export type EvalMessage = { + role: "system" | "user" | "assistant" | "tool"; + parts: EvalPart[]; + metadata?: Record; + [key: string]: any; +}; + +export type TaskInput = string | EvalMessage[]; + +export type TaskResult = + | { + result: string; + messages?: never; + toolCalls?: ToolCall[]; + } + | { + messages: EvalMessage[]; + result?: never; + toolCalls?: ToolCall[]; + }; + +export type TaskOutput = string | TaskResult; + +export type EvalDataInput = + | { + input: string; + messages?: never; + } + | { + messages: EvalMessage[]; + input?: never; + }; + +export interface NormalizedInput { + input: string; + inputMessages: EvalMessage[]; +} + +export interface NormalizedOutput { + output: string; + outputMessages: EvalMessage[]; + toolCalls?: ToolCall[]; +} + +export interface NormalizedScorerPayload + extends NormalizedInput, + NormalizedOutput { + messages: EvalMessage[]; +} + +const EMPTY_TEXT = ""; + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null && !Array.isArray(value); +} + +export function isEvalMessageArray(value: unknown): value is EvalMessage[] { + return ( + Array.isArray(value) && + value.every( + (item) => + isRecord(item) && + typeof item.role === "string" && + Array.isArray(item.parts), + ) + ); +} + +function textMessage(role: EvalMessage["role"], text: string): EvalMessage { + return { + role, + parts: [{ type: "text", text }], + }; +} + +function assertValidMessages( + messages: unknown, + fieldName: string, +): asserts messages is EvalMessage[] { + if (!isEvalMessageArray(messages)) { + throw new Error(`${fieldName} must be an array of message objects.`); + } +} + +function assertString( + value: unknown, + fieldName: string, +): asserts value is string { + if (typeof value !== "string") { + throw new Error(`${fieldName} must be a string.`); + } +} + +export function getTaskInput( + input: string | undefined, + messages: EvalMessage[] | undefined, +): TaskInput { + const hasInput = input !== undefined; + const hasMessages = messages !== undefined; + + if (hasInput === hasMessages) { + throw new Error( + "Each eval case must define exactly one of `input` or `messages`.", + ); + } + + if (hasInput) { + assertString(input, "`input`"); + return input; + } + + assertValidMessages(messages, "`messages`"); + return messages; +} + +export function normalizeEvalInput(input: TaskInput): NormalizedInput { + if (typeof input === "string") { + return { + input, + inputMessages: [textMessage("user", input)], + }; + } + + assertValidMessages(input, "Eval input"); + + return { + input: extractTextFromMessages(input), + inputMessages: input, + }; +} + +function getTaskResultVariant(taskOutput: TaskResult): "result" | "messages" { + const hasResult = "result" in taskOutput && taskOutput.result !== undefined; + const hasMessages = + "messages" in taskOutput && taskOutput.messages !== undefined; + + if (hasResult === hasMessages) { + throw new Error( + "Task results must define exactly one of `result` or `messages`.", + ); + } + + return hasResult ? "result" : "messages"; +} + +export function deriveToolCalls(messages: EvalMessage[]): ToolCall[] { + const orderedCalls: ToolCall[] = []; + const byId = new Map(); + + function findOpenCall(toolName: string) { + for (let index = orderedCalls.length - 1; index >= 0; index -= 1) { + const call = orderedCalls[index]; + if ( + call.name === toolName && + call.result === undefined && + call.error === undefined + ) { + return call; + } + } + return undefined; + } + + function upsertCall(part: { + toolName: string; + toolCallId?: string; + input?: unknown; + }) { + if (part.toolCallId) { + const existing = byId.get(part.toolCallId); + if (existing) { + if (part.input !== undefined) { + existing.arguments = part.input; + } + return existing; + } + } + + const openCall = part.toolCallId ? undefined : findOpenCall(part.toolName); + if (openCall) { + if (part.input !== undefined) { + openCall.arguments = part.input; + } + return openCall; + } + + const call: ToolCall = { + name: part.toolName, + ...(part.input !== undefined ? { arguments: part.input } : {}), + }; + + orderedCalls.push(call); + if (part.toolCallId) { + byId.set(part.toolCallId, call); + } + return call; + } + + for (const message of messages) { + for (const part of message.parts) { + if (part.type === "tool-call") { + upsertCall(part); + continue; + } + + if (part.type === "tool-result") { + const call = + (part.toolCallId ? byId.get(part.toolCallId) : undefined) ?? + findOpenCall(part.toolName) ?? + upsertCall(part); + call.result = part.output; + continue; + } + + if (part.type === "tool-error") { + const call = + (part.toolCallId ? byId.get(part.toolCallId) : undefined) ?? + findOpenCall(part.toolName) ?? + upsertCall(part); + call.error = part.error ?? + part.output ?? { + toolName: part.toolName, + }; + } + } + } + + return orderedCalls; +} + +export function normalizeTaskOutput(taskOutput: TaskOutput): NormalizedOutput { + if (typeof taskOutput === "string") { + return { + output: taskOutput, + outputMessages: [textMessage("assistant", taskOutput)], + }; + } + + if (!isRecord(taskOutput)) { + throw new Error( + "Task output must be either a string or an object with `result` or `messages`.", + ); + } + + const variant = getTaskResultVariant(taskOutput as TaskResult); + + if (variant === "result") { + assertString(taskOutput.result, "`result`"); + return { + output: taskOutput.result, + outputMessages: [textMessage("assistant", taskOutput.result)], + toolCalls: taskOutput.toolCalls, + }; + } + + assertValidMessages(taskOutput.messages, "`messages`"); + + return { + output: extractTextFromMessages(taskOutput.messages), + outputMessages: taskOutput.messages, + toolCalls: taskOutput.toolCalls ?? deriveToolCalls(taskOutput.messages), + }; +} + +export function normalizeScorerPayload( + input: TaskInput, + taskOutput: TaskOutput, +): NormalizedScorerPayload { + const normalizedInput = normalizeEvalInput(input); + const normalizedOutput = normalizeTaskOutput(taskOutput); + + return { + ...normalizedInput, + ...normalizedOutput, + messages: [ + ...normalizedInput.inputMessages, + ...normalizedOutput.outputMessages, + ], + }; +} + +export function normalizeEvaluateOutput(taskOutput: TaskOutput): { + messages: EvalMessage[]; + output: string; + toolCalls?: ToolCall[]; +} { + const normalizedOutput = normalizeTaskOutput(taskOutput); + + return { + messages: normalizedOutput.outputMessages, + output: normalizedOutput.output, + toolCalls: normalizedOutput.toolCalls, + }; +} + +function extractTextFromPart(part: EvalPart): string { + switch (part.type) { + case "text": + case "reasoning": + return part.text; + default: + return EMPTY_TEXT; + } +} + +export function extractTextFromMessages(messages: EvalMessage[]): string { + return messages + .flatMap((message) => message.parts.map(extractTextFromPart)) + .filter(Boolean) + .join("\n"); +} + +function summarizeUnknown(value: unknown): string { + if (typeof value === "string") { + return value; + } + + if (value instanceof URL) { + return value.toString(); + } + + if (value instanceof Error) { + return value.message; + } + + try { + const json = JSON.stringify(value, null, 2); + return json ?? String(value); + } catch { + return String(value); + } +} + +function formatPartForDisplay(part: EvalPart): string { + switch (part.type) { + case "text": + return part.text; + case "reasoning": + return `[reasoning]\n${part.text}`; + case "image": + return `[image${part.mediaType ? ` ${part.mediaType}` : ""}]`; + case "file": + return `[file${part.filename ? ` ${part.filename}` : ""}${part.mediaType ? ` ${part.mediaType}` : ""}]`; + case "tool-call": + return `[tool-call ${part.toolName}]${part.input !== undefined ? ` ${summarizeUnknown(part.input)}` : ""}`; + case "tool-result": + return `[tool-result ${part.toolName}] ${summarizeUnknown(part.output)}`; + case "tool-error": + return `[tool-error ${part.toolName}] ${summarizeUnknown(part.error ?? part.output)}`; + case "source": + return `[source] ${summarizeUnknown(part.source ?? part)}`; + } +} + +export function formatMessages(messages: EvalMessage[]): string { + if (messages.length === 0) { + return "(empty transcript)"; + } + + return messages + .map((message) => { + const heading = `## ${message.role}`; + const body = message.parts.length + ? message.parts.map(formatPartForDisplay).join("\n\n") + : "(empty message)"; + return `${heading}\n\n${body}`; + }) + .join("\n\n"); +} + +export function formatEvalValue(value: unknown): string { + if (typeof value === "string") { + return wrapText(value); + } + + if (isEvalMessageArray(value)) { + return formatMessages(value); + } + + return wrapText(summarizeUnknown(value)); +} + +function pushJudgeText(content: Array, text: string) { + if (text.length === 0) { + return; + } + + const lastPart = content[content.length - 1]; + if (lastPart?.type === "text") { + lastPart.text += text; + return; + } + + content.push({ type: "text", text }); +} + +export function toJudgeUserMessage(messages: EvalMessage[]) { + const visibleMessages = messages + .filter( + (message) => message.role === "user" || message.role === "assistant", + ) + .map((message) => ({ + ...message, + parts: message.parts.filter( + (part) => + part.type === "text" || part.type === "image" || part.type === "file", + ), + })) + .filter((message) => message.parts.length > 0); + + const content: Array = []; + + if (visibleMessages.length === 0) { + content.push({ + type: "text", + text: "(no user-facing transcript)", + }); + return { role: "user" as const, content }; + } + + visibleMessages.forEach((message, index) => { + if (index > 0) { + pushJudgeText(content, "\n\n"); + } + + pushJudgeText(content, `[${message.role.toUpperCase()}]\n`); + + if (message.parts.length === 0) { + pushJudgeText(content, "(empty message)"); + return; + } + + message.parts.forEach((part, partIndex) => { + if (partIndex > 0) { + pushJudgeText(content, "\n"); + } + + switch (part.type) { + case "text": + pushJudgeText(content, `${part.text}\n`); + return; + case "image": + pushJudgeText( + content, + `[image${part.mediaType ? ` ${part.mediaType}` : ""}]\n`, + ); + content.push({ + type: "image", + image: part.image, + ...(part.mediaType ? { mediaType: part.mediaType } : {}), + }); + pushJudgeText(content, "\n"); + return; + case "file": + pushJudgeText( + content, + `[file${part.filename ? ` ${part.filename}` : ""}${part.mediaType ? ` ${part.mediaType}` : ""}]\n`, + ); + content.push({ + type: "file", + data: part.data, + mediaType: part.mediaType, + ...(part.filename ? { filename: part.filename } : {}), + }); + pushJudgeText(content, "\n"); + return; + } + }); + }); + + const lastPart = content[content.length - 1]; + if (lastPart?.type === "text") { + lastPart.text = lastPart.text.trimEnd(); + } + + return { role: "user" as const, content }; +} + +export function getDefaultTestName(input: TaskInput): string { + if (typeof input === "string") { + return input; + } + + const firstText = extractTextFromMessages(input).trim(); + return firstText.length > 0 ? firstText : "message chain"; +} From 8a8e78b4ed75254f55890ab52b8df396caa20fe7 Mon Sep 17 00:00:00 2001 From: David Cramer Date: Tue, 17 Mar 2026 11:46:46 -0700 Subject: [PATCH 2/2] Simplify multimodal evals around transcripts --- src/ai-sdk-integration.test.ts | 19 +- src/evaluate/index.test.ts | 33 +-- src/evaluate/index.ts | 14 +- src/formatScores.test.ts | 2 +- src/index.ts | 28 ++- src/messages.test.ts | 94 ++++----- src/messages.ts | 355 ++++++++------------------------- 7 files changed, 159 insertions(+), 386 deletions(-) diff --git a/src/ai-sdk-integration.test.ts b/src/ai-sdk-integration.test.ts index 82443ba..d1fa3b2 100644 --- a/src/ai-sdk-integration.test.ts +++ b/src/ai-sdk-integration.test.ts @@ -72,7 +72,12 @@ describeEval("@ai/sdk ToolCallScorer", { }); return { - result: text, + transcript: [ + { + role: "assistant", + parts: [{ type: "text", text }], + }, + ], toolCalls: steps .flatMap((step) => step.toolCalls) .map((call) => ({ @@ -112,10 +117,7 @@ describeEval("@ai/sdk StructuredOutputScorer", { }), }); - return { - result: JSON.stringify(object), - toolCalls: [], - }; + return JSON.stringify(object); }, scorers: [ StructuredOutputScorer({ @@ -148,7 +150,12 @@ describeEval("@ai/sdk ToolCallScorer (No stopWhen)", { }); return { - result: text, + transcript: [ + { + role: "assistant", + parts: [{ type: "text", text }], + }, + ], toolCalls: steps .flatMap((step) => step.toolCalls) .map((call) => ({ diff --git a/src/evaluate/index.test.ts b/src/evaluate/index.test.ts index 1c6a5e6..def56db 100644 --- a/src/evaluate/index.test.ts +++ b/src/evaluate/index.test.ts @@ -157,7 +157,7 @@ describe("evaluate", () => { expect(call.messages[1].content).toContain("must mention specific details"); }); - test("passes multimodal message chains to the judge", async () => { + test("passes multimodal transcripts to the judge", async () => { mockGenerateObject.mockResolvedValueOnce({ object: { answer: "A", rationale: "Handled the transcript correctly" }, } as any); @@ -165,7 +165,7 @@ describe("evaluate", () => { const ctx = makeContext(); await _evaluate(ctx, { task: async () => ({ - messages: [ + transcript: [ { role: "user", parts: [ @@ -210,36 +210,21 @@ describe("evaluate", () => { const ctx = makeContext(); await _evaluate(ctx, { task: async () => ({ - messages: [ + transcript: [ { role: "user", parts: [{ type: "text", text: "What is the weather?" }], }, { role: "assistant", - parts: [ - { - type: "tool-call", - toolName: "getWeather", - toolCallId: "call-1", - input: { location: "Seattle" }, - }, - ], - }, - { - role: "tool", - parts: [ - { - type: "tool-result", - toolName: "getWeather", - toolCallId: "call-1", - output: { temperature: 72 }, - }, - ], + parts: [{ type: "text", text: "It is 72F in Seattle." }], }, + ], + toolCalls: [ { - role: "assistant", - parts: [{ type: "text", text: "It is 72F in Seattle." }], + name: "getWeather", + arguments: { location: "Seattle" }, + result: { temperature: 72 }, }, ], }), diff --git a/src/evaluate/index.ts b/src/evaluate/index.ts index 6ff73c6..0c8244f 100644 --- a/src/evaluate/index.ts +++ b/src/evaluate/index.ts @@ -2,7 +2,7 @@ import { generateObject } from "ai"; import { z } from "zod"; import { assert, test } from "vitest"; import { - type TaskResult, + type Transcript, formatEvalValue, normalizeEvaluateOutput, toJudgeUserMessage, @@ -38,7 +38,7 @@ const CHOICE_SCORES: Record = { }; interface EvaluateOptions { - task: () => Promise; + task: () => Promise; criteria: string; threshold?: number; } @@ -58,7 +58,7 @@ export async function _evaluate( ); } - let taskOutput: string | TaskResult; + let taskOutput: string | { transcript: Transcript }; let evaluationOutput: ReturnType; try { taskOutput = await opts.task(); @@ -88,7 +88,7 @@ export async function _evaluate( }), system: EVAL_SYSTEM, messages: [ - toJudgeUserMessage(evaluationOutput.messages), + toJudgeUserMessage(evaluationOutput.transcript), { role: "user", content: EVAL_PROMPT(opts.criteria), @@ -128,11 +128,7 @@ export async function _evaluate( assert( false, `Score: ${score} (${object.answer}) below threshold: ${threshold}\n\n## Output:\n${formatEvalValue( - typeof taskOutput === "string" - ? taskOutput - : "result" in taskOutput && taskOutput.result !== undefined - ? taskOutput.result - : taskOutput.messages, + typeof taskOutput === "string" ? taskOutput : taskOutput.transcript, )}\n\n## Rationale:\n${formatEvalValue(object.rationale)}`, ); } diff --git a/src/formatScores.test.ts b/src/formatScores.test.ts index ba4c23b..a3ebd4a 100644 --- a/src/formatScores.test.ts +++ b/src/formatScores.test.ts @@ -73,7 +73,7 @@ describe("formatScores", () => { `); }); - it("should format message-chain outputs", () => { + it("should format transcript outputs", () => { const scores = [ { name: "Scorer A", diff --git a/src/index.ts b/src/index.ts index b90ff6a..b8c85ed 100644 --- a/src/index.ts +++ b/src/index.ts @@ -9,7 +9,9 @@ import { import "vitest"; import { type EvalDataInput, - type EvalMessage, + type Transcript, + type TranscriptMessage, + type TranscriptPart, type TaskInput, type TaskResult, type ToolCall, @@ -22,7 +24,8 @@ import { wrapText } from "./wrapText"; /** * Task function that processes an input and returns either a string result - * or a TaskResult object containing response messages and any tool calls made. + * or a TaskResult object containing a multimodal response transcript and any + * tool calls made. */ export type TaskFn = (input: TaskInput) => Promise; @@ -37,9 +40,7 @@ export type Score = { export interface BaseScorerOptions { input: string; output: string; - messages: EvalMessage[]; - inputMessages: EvalMessage[]; - outputMessages: EvalMessage[]; + transcript: Transcript; toolCalls?: ToolCall[]; } @@ -81,11 +82,7 @@ function formatEvaluationOutputForDisplay( return formatEvalValue(taskOutput); } - if ("result" in taskOutput && taskOutput.result !== undefined) { - return formatEvalValue(taskOutput.result); - } - - return formatEvalValue(taskOutput.messages); + return formatEvalValue(taskOutput.transcript); } expect.extend({ @@ -152,16 +149,16 @@ export function describeEval( for (const testCase of await data()) { const { input, - messages, + transcript, name: testName, ...params } = testCase as { input?: string; - messages?: EvalMessage[]; + transcript?: Transcript; name?: string; } & Record; - const taskInput = getTaskInput(input, messages); + const taskInput = getTaskInput(input, transcript); testFn( testName ?? getDefaultTestName(taskInput), @@ -239,8 +236,9 @@ export function formatScores(scores: (Score & { name: string })[]) { export { wrapText } from "./wrapText"; export type { EvalDataInput, - EvalMessage, - EvalPart, + Transcript, + TranscriptMessage, + TranscriptPart, TaskInput, TaskResult, ToolCall, diff --git a/src/messages.test.ts b/src/messages.test.ts index 8451387..8857ca8 100644 --- a/src/messages.test.ts +++ b/src/messages.test.ts @@ -4,14 +4,10 @@ import { formatEvalValue, getTaskInput, normalizeScorerPayload, - type EvalMessage, + type Transcript, } from "./messages"; -const multimodalInput: EvalMessage[] = [ - { - role: "system", - parts: [{ type: "text", text: "Answer concisely." }], - }, +const multimodalInput: Transcript = [ { role: "user", parts: [ @@ -25,27 +21,28 @@ const multimodalInput: EvalMessage[] = [ }, ]; -const multimodalOutput: EvalMessage[] = [ +const multimodalOutput: Transcript = [ { role: "assistant", parts: [{ type: "text", text: "A cat sitting on a chair." }], }, ]; -describe("message normalization", () => { - test("toEval passes full message chains to scorers", async () => { +describe("transcript normalization", () => { + test("toEval passes a combined transcript to scorers", async () => { const scorer = vi.fn(async (opts) => { - expect(opts.input).toBe("Answer concisely.\nDescribe this image"); + expect(opts.input).toBe("Describe this image"); expect(opts.output).toBe("A cat sitting on a chair."); - expect(opts.inputMessages).toEqual(multimodalInput); - expect(opts.outputMessages).toEqual(multimodalOutput); - expect(opts.messages).toEqual([...multimodalInput, ...multimodalOutput]); + expect(opts.transcript).toEqual([ + ...multimodalInput, + ...multimodalOutput, + ]); return { score: 1 }; }); const task = vi.fn(async (input) => { expect(input).toEqual(multimodalInput); - return { messages: multimodalOutput }; + return { transcript: multimodalOutput }; }); await expect(multimodalInput).toEval( @@ -59,36 +56,27 @@ describe("message normalization", () => { expect(scorer).toHaveBeenCalledOnce(); }); - test("rejects eval cases that define both input and messages", () => { + test("rejects eval cases that define both input and transcript", () => { expect(() => getTaskInput("hello", [ { role: "user", parts: [{ type: "text", text: "world" }] }, ]), ).toThrow( - "Each eval case must define exactly one of `input` or `messages`.", + "Each eval case must define exactly one of `input` or `transcript`.", ); }); - test("rejects task outputs that define both result and messages", () => { + test("rejects task outputs without a transcript", () => { expect(() => - normalizeScorerPayload("hello", { - result: "hi", - messages: [ - { role: "assistant", parts: [{ type: "text", text: "hi" }] }, - ], - } as any), + normalizeScorerPayload("hello", { messages: [] } as any), ).toThrow( - "Task results must define exactly one of `result` or `messages`.", + "Task output must be either a string or an object with `transcript`.", ); }); test("formats transcripts safely for debug output", () => { expect(formatEvalValue(multimodalInput)).toMatchInlineSnapshot(` - "## system - - Answer concisely. - - ## user + "## user Describe this image @@ -97,32 +85,33 @@ describe("message normalization", () => { }); }); -describeEval("message chain scorer payload", { +describeEval("transcript scorer payload", { data: async () => [ { - name: "passes full chains through describeEval", - messages: multimodalInput, + name: "passes transcript through describeEval", + transcript: multimodalInput, }, ], task: async (input) => { expect(input).toEqual(multimodalInput); - return { messages: multimodalOutput }; + return { transcript: multimodalOutput }; }, scorers: [ async (opts) => { - expect(opts.inputMessages).toEqual(multimodalInput); - expect(opts.outputMessages).toEqual(multimodalOutput); - expect(opts.messages).toEqual([...multimodalInput, ...multimodalOutput]); + expect(opts.transcript).toEqual([ + ...multimodalInput, + ...multimodalOutput, + ]); expect(opts.output).toBe("A cat sitting on a chair."); return { score: 1 }; }, ], }); -describeEval("derived tool calls from message parts", { +describeEval("explicit tool call metadata", { data: async () => [ { - name: "tool calls are derived without an explicit toolCalls array", + name: "tool calls are passed separately from the transcript", input: "What is the weather in Seattle?", expectedTools: [ { name: "getWeather", arguments: { location: "Seattle" } }, @@ -130,32 +119,17 @@ describeEval("derived tool calls from message parts", { }, ], task: async () => ({ - messages: [ + transcript: [ { role: "assistant", - parts: [ - { - type: "tool-call", - toolName: "getWeather", - toolCallId: "call-1", - input: { location: "Seattle" }, - }, - ], - }, - { - role: "tool", - parts: [ - { - type: "tool-result", - toolName: "getWeather", - toolCallId: "call-1", - output: { temperature: 72 }, - }, - ], + parts: [{ type: "text", text: "It is 72F in Seattle." }], }, + ], + toolCalls: [ { - role: "assistant", - parts: [{ type: "text", text: "It is 72F in Seattle." }], + name: "getWeather", + arguments: { location: "Seattle" }, + result: { temperature: 72 }, }, ], }), diff --git a/src/messages.ts b/src/messages.ts index 74530f1..1b60351 100644 --- a/src/messages.ts +++ b/src/messages.ts @@ -2,155 +2,109 @@ import { wrapText } from "./wrapText"; export type ToolCall = { name: string; - arguments?: any; - [key: string]: any; + arguments?: unknown; + [key: string]: unknown; }; -type EvalBasePart = { - [key: string]: any; -}; - -export type EvalTextPart = EvalBasePart & { +export type TranscriptTextPart = { type: "text"; text: string; }; -export type EvalImagePart = EvalBasePart & { +export type TranscriptImagePart = { type: "image"; image: unknown; mediaType?: string; }; -export type EvalFilePart = EvalBasePart & { +export type TranscriptFilePart = { type: "file"; data: unknown; mediaType: string; filename?: string; }; -export type EvalReasoningPart = EvalBasePart & { - type: "reasoning"; - text: string; -}; +export type TranscriptPart = + | TranscriptTextPart + | TranscriptImagePart + | TranscriptFilePart; -export type EvalToolCallPart = EvalBasePart & { - type: "tool-call"; - toolName: string; - input?: unknown; - toolCallId?: string; +export type TranscriptMessage = { + role: "user" | "assistant"; + parts: TranscriptPart[]; }; -export type EvalToolResultPart = EvalBasePart & { - type: "tool-result"; - toolName: string; - output: unknown; - toolCallId?: string; -}; +export type Transcript = TranscriptMessage[]; -export type EvalToolErrorPart = EvalBasePart & { - type: "tool-error"; - toolName: string; - error?: unknown; - output?: unknown; - toolCallId?: string; -}; - -export type EvalSourcePart = EvalBasePart & { - type: "source"; - source?: unknown; -}; +export type TaskInput = string | Transcript; -export type EvalPart = - | EvalTextPart - | EvalImagePart - | EvalFilePart - | EvalReasoningPart - | EvalToolCallPart - | EvalToolResultPart - | EvalToolErrorPart - | EvalSourcePart; - -export type EvalMessage = { - role: "system" | "user" | "assistant" | "tool"; - parts: EvalPart[]; - metadata?: Record; - [key: string]: any; +export type TaskResult = { + transcript: Transcript; + toolCalls?: ToolCall[]; }; -export type TaskInput = string | EvalMessage[]; - -export type TaskResult = - | { - result: string; - messages?: never; - toolCalls?: ToolCall[]; - } - | { - messages: EvalMessage[]; - result?: never; - toolCalls?: ToolCall[]; - }; - export type TaskOutput = string | TaskResult; export type EvalDataInput = | { input: string; - messages?: never; + transcript?: never; } | { - messages: EvalMessage[]; + transcript: Transcript; input?: never; }; -export interface NormalizedInput { +interface NormalizedInput { input: string; - inputMessages: EvalMessage[]; + inputTranscript: Transcript; } -export interface NormalizedOutput { +interface NormalizedOutput { output: string; - outputMessages: EvalMessage[]; + outputTranscript: Transcript; toolCalls?: ToolCall[]; } -export interface NormalizedScorerPayload - extends NormalizedInput, - NormalizedOutput { - messages: EvalMessage[]; +export interface NormalizedScorerPayload { + input: string; + output: string; + transcript: Transcript; + toolCalls?: ToolCall[]; } -const EMPTY_TEXT = ""; - -function isRecord(value: unknown): value is Record { +function isRecord(value: unknown): value is Record { return typeof value === "object" && value !== null && !Array.isArray(value); } -export function isEvalMessageArray(value: unknown): value is EvalMessage[] { +export function isTranscript(value: unknown): value is Transcript { return ( Array.isArray(value) && value.every( - (item) => - isRecord(item) && - typeof item.role === "string" && - Array.isArray(item.parts), + (message) => + isRecord(message) && + (message.role === "user" || message.role === "assistant") && + Array.isArray(message.parts), ) ); } -function textMessage(role: EvalMessage["role"], text: string): EvalMessage { +function textMessage( + role: TranscriptMessage["role"], + text: string, +): TranscriptMessage { return { role, parts: [{ type: "text", text }], }; } -function assertValidMessages( - messages: unknown, +function assertValidTranscript( + transcript: unknown, fieldName: string, -): asserts messages is EvalMessage[] { - if (!isEvalMessageArray(messages)) { - throw new Error(`${fieldName} must be an array of message objects.`); +): asserts transcript is Transcript { + if (!isTranscript(transcript)) { + throw new Error(`${fieldName} must be an array of transcript messages.`); } } @@ -165,14 +119,14 @@ function assertString( export function getTaskInput( input: string | undefined, - messages: EvalMessage[] | undefined, + transcript: Transcript | undefined, ): TaskInput { const hasInput = input !== undefined; - const hasMessages = messages !== undefined; + const hasTranscript = transcript !== undefined; - if (hasInput === hasMessages) { + if (hasInput === hasTranscript) { throw new Error( - "Each eval case must define exactly one of `input` or `messages`.", + "Each eval case must define exactly one of `input` or `transcript`.", ); } @@ -181,156 +135,46 @@ export function getTaskInput( return input; } - assertValidMessages(messages, "`messages`"); - return messages; + assertValidTranscript(transcript, "`transcript`"); + return transcript; } export function normalizeEvalInput(input: TaskInput): NormalizedInput { if (typeof input === "string") { return { input, - inputMessages: [textMessage("user", input)], + inputTranscript: [textMessage("user", input)], }; } - assertValidMessages(input, "Eval input"); + assertValidTranscript(input, "Eval input"); return { - input: extractTextFromMessages(input), - inputMessages: input, + input: extractTextFromTranscript(input), + inputTranscript: input, }; } -function getTaskResultVariant(taskOutput: TaskResult): "result" | "messages" { - const hasResult = "result" in taskOutput && taskOutput.result !== undefined; - const hasMessages = - "messages" in taskOutput && taskOutput.messages !== undefined; - - if (hasResult === hasMessages) { - throw new Error( - "Task results must define exactly one of `result` or `messages`.", - ); - } - - return hasResult ? "result" : "messages"; -} - -export function deriveToolCalls(messages: EvalMessage[]): ToolCall[] { - const orderedCalls: ToolCall[] = []; - const byId = new Map(); - - function findOpenCall(toolName: string) { - for (let index = orderedCalls.length - 1; index >= 0; index -= 1) { - const call = orderedCalls[index]; - if ( - call.name === toolName && - call.result === undefined && - call.error === undefined - ) { - return call; - } - } - return undefined; - } - - function upsertCall(part: { - toolName: string; - toolCallId?: string; - input?: unknown; - }) { - if (part.toolCallId) { - const existing = byId.get(part.toolCallId); - if (existing) { - if (part.input !== undefined) { - existing.arguments = part.input; - } - return existing; - } - } - - const openCall = part.toolCallId ? undefined : findOpenCall(part.toolName); - if (openCall) { - if (part.input !== undefined) { - openCall.arguments = part.input; - } - return openCall; - } - - const call: ToolCall = { - name: part.toolName, - ...(part.input !== undefined ? { arguments: part.input } : {}), - }; - - orderedCalls.push(call); - if (part.toolCallId) { - byId.set(part.toolCallId, call); - } - return call; - } - - for (const message of messages) { - for (const part of message.parts) { - if (part.type === "tool-call") { - upsertCall(part); - continue; - } - - if (part.type === "tool-result") { - const call = - (part.toolCallId ? byId.get(part.toolCallId) : undefined) ?? - findOpenCall(part.toolName) ?? - upsertCall(part); - call.result = part.output; - continue; - } - - if (part.type === "tool-error") { - const call = - (part.toolCallId ? byId.get(part.toolCallId) : undefined) ?? - findOpenCall(part.toolName) ?? - upsertCall(part); - call.error = part.error ?? - part.output ?? { - toolName: part.toolName, - }; - } - } - } - - return orderedCalls; -} - export function normalizeTaskOutput(taskOutput: TaskOutput): NormalizedOutput { if (typeof taskOutput === "string") { return { output: taskOutput, - outputMessages: [textMessage("assistant", taskOutput)], + outputTranscript: [textMessage("assistant", taskOutput)], }; } - if (!isRecord(taskOutput)) { + if (!isRecord(taskOutput) || !("transcript" in taskOutput)) { throw new Error( - "Task output must be either a string or an object with `result` or `messages`.", + "Task output must be either a string or an object with `transcript`.", ); } - const variant = getTaskResultVariant(taskOutput as TaskResult); - - if (variant === "result") { - assertString(taskOutput.result, "`result`"); - return { - output: taskOutput.result, - outputMessages: [textMessage("assistant", taskOutput.result)], - toolCalls: taskOutput.toolCalls, - }; - } - - assertValidMessages(taskOutput.messages, "`messages`"); + assertValidTranscript(taskOutput.transcript, "`transcript`"); return { - output: extractTextFromMessages(taskOutput.messages), - outputMessages: taskOutput.messages, - toolCalls: taskOutput.toolCalls ?? deriveToolCalls(taskOutput.messages), + output: extractTextFromTranscript(taskOutput.transcript), + outputTranscript: taskOutput.transcript, + toolCalls: taskOutput.toolCalls, }; } @@ -342,41 +186,34 @@ export function normalizeScorerPayload( const normalizedOutput = normalizeTaskOutput(taskOutput); return { - ...normalizedInput, - ...normalizedOutput, - messages: [ - ...normalizedInput.inputMessages, - ...normalizedOutput.outputMessages, + input: normalizedInput.input, + output: normalizedOutput.output, + transcript: [ + ...normalizedInput.inputTranscript, + ...normalizedOutput.outputTranscript, ], + toolCalls: normalizedOutput.toolCalls, }; } export function normalizeEvaluateOutput(taskOutput: TaskOutput): { - messages: EvalMessage[]; + transcript: Transcript; output: string; - toolCalls?: ToolCall[]; } { const normalizedOutput = normalizeTaskOutput(taskOutput); return { - messages: normalizedOutput.outputMessages, + transcript: normalizedOutput.outputTranscript, output: normalizedOutput.output, - toolCalls: normalizedOutput.toolCalls, }; } -function extractTextFromPart(part: EvalPart): string { - switch (part.type) { - case "text": - case "reasoning": - return part.text; - default: - return EMPTY_TEXT; - } +function extractTextFromPart(part: TranscriptPart): string { + return part.type === "text" ? part.text : ""; } -export function extractTextFromMessages(messages: EvalMessage[]): string { - return messages +export function extractTextFromTranscript(transcript: Transcript): string { + return transcript .flatMap((message) => message.parts.map(extractTextFromPart)) .filter(Boolean) .join("\n"); @@ -403,33 +240,23 @@ function summarizeUnknown(value: unknown): string { } } -function formatPartForDisplay(part: EvalPart): string { +function formatPartForDisplay(part: TranscriptPart): string { switch (part.type) { case "text": return part.text; - case "reasoning": - return `[reasoning]\n${part.text}`; case "image": return `[image${part.mediaType ? ` ${part.mediaType}` : ""}]`; case "file": return `[file${part.filename ? ` ${part.filename}` : ""}${part.mediaType ? ` ${part.mediaType}` : ""}]`; - case "tool-call": - return `[tool-call ${part.toolName}]${part.input !== undefined ? ` ${summarizeUnknown(part.input)}` : ""}`; - case "tool-result": - return `[tool-result ${part.toolName}] ${summarizeUnknown(part.output)}`; - case "tool-error": - return `[tool-error ${part.toolName}] ${summarizeUnknown(part.error ?? part.output)}`; - case "source": - return `[source] ${summarizeUnknown(part.source ?? part)}`; } } -export function formatMessages(messages: EvalMessage[]): string { - if (messages.length === 0) { +export function formatTranscript(transcript: Transcript): string { + if (transcript.length === 0) { return "(empty transcript)"; } - return messages + return transcript .map((message) => { const heading = `## ${message.role}`; const body = message.parts.length @@ -445,8 +272,8 @@ export function formatEvalValue(value: unknown): string { return wrapText(value); } - if (isEvalMessageArray(value)) { - return formatMessages(value); + if (isTranscript(value)) { + return formatTranscript(value); } return wrapText(summarizeUnknown(value)); @@ -466,31 +293,18 @@ function pushJudgeText(content: Array, text: string) { content.push({ type: "text", text }); } -export function toJudgeUserMessage(messages: EvalMessage[]) { - const visibleMessages = messages - .filter( - (message) => message.role === "user" || message.role === "assistant", - ) - .map((message) => ({ - ...message, - parts: message.parts.filter( - (part) => - part.type === "text" || part.type === "image" || part.type === "file", - ), - })) - .filter((message) => message.parts.length > 0); - +export function toJudgeUserMessage(transcript: Transcript) { const content: Array = []; - if (visibleMessages.length === 0) { + if (transcript.length === 0) { content.push({ type: "text", - text: "(no user-facing transcript)", + text: "(empty transcript)", }); return { role: "user" as const, content }; } - visibleMessages.forEach((message, index) => { + transcript.forEach((message, index) => { if (index > 0) { pushJudgeText(content, "\n\n"); } @@ -535,7 +349,6 @@ export function toJudgeUserMessage(messages: EvalMessage[]) { ...(part.filename ? { filename: part.filename } : {}), }); pushJudgeText(content, "\n"); - return; } }); }); @@ -553,6 +366,6 @@ export function getDefaultTestName(input: TaskInput): string { return input; } - const firstText = extractTextFromMessages(input).trim(); - return firstText.length > 0 ? firstText : "message chain"; + const firstText = extractTextFromTranscript(input).trim(); + return firstText.length > 0 ? firstText : "transcript"; }