diff --git a/frontend/src/components/editor/ai/__tests__/completion-utils.test.ts b/frontend/src/components/editor/ai/__tests__/completion-utils.test.ts index 449c283ce27..e5c4c8ac7d9 100644 --- a/frontend/src/components/editor/ai/__tests__/completion-utils.test.ts +++ b/frontend/src/components/editor/ai/__tests__/completion-utils.test.ts @@ -1,6 +1,15 @@ /* Copyright 2026 Marimo. All rights reserved. */ -import { beforeEach, describe, expect, it, type Mock, vi } from "vitest"; +import { + afterEach, + beforeEach, + describe, + expect, + it, + type Mock, + vi, +} from "vitest"; import { variableName } from "@/__tests__/branded"; +import * as aiContext from "@/core/ai/context/context"; import { getCodes } from "@/core/codemirror/copilot/getCodes"; import { dataSourceConnectionsAtom } from "@/core/datasets/data-source-connections"; import { DUCKDB_ENGINE } from "@/core/datasets/engines"; @@ -8,10 +17,11 @@ import { datasetsAtom } from "@/core/datasets/state"; import type { DatasetsState } from "@/core/datasets/types"; import { store } from "@/core/state/jotai"; import { variablesAtom } from "@/core/variables/state"; -import type { UIMessage } from "ai"; +import type { FileUIPart, UIMessage } from "ai"; import { codeToCells, getAICompletionBody, + getAICompletionBodyWithAttachments, isContextAttachment, MARIMO_CONTEXT_PART_TYPE, resolveChatContext, @@ -440,6 +450,42 @@ describe("isContextAttachment", () => { }); }); +describe("context attachment stamping", () => { + const rawAttachment: FileUIPart = { + type: "file", + mediaType: "image/png", + url: "data:image/png;base64,abc", + }; + + beforeEach(() => { + vi.spyOn(aiContext, "getAIContextRegistry").mockReturnValue({ + parseAllContextIds: () => ["data://t1"], + formatContextForAI: () => '', + getAttachmentsForContext: async () => [rawAttachment], + } as unknown as ReturnType); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it("stamps chat attachments as context-derived", async () => { + const { 
attachments } = await resolveChatContext("see @data://t1"); + expect(attachments).toHaveLength(1); + expect(isContextAttachment(attachments[0])).toBe(true); + // The original attachment is left untouched (we return a stamped copy). + expect(rawAttachment.providerMetadata).toBeUndefined(); + }); + + it("stamps completion attachments the same way as chat", async () => { + const { attachments } = await getAICompletionBodyWithAttachments({ + input: "see @data://t1", + }); + expect(attachments).toHaveLength(1); + expect(isContextAttachment(attachments[0])).toBe(true); + }); +}); + describe("codeToCells", () => { it("should return empty array for empty string", () => { const code = ""; diff --git a/frontend/src/components/editor/ai/completion-utils.ts b/frontend/src/components/editor/ai/completion-utils.ts index ef408595d42..d41cb095892 100644 --- a/frontend/src/components/editor/ai/completion-utils.ts +++ b/frontend/src/components/editor/ai/completion-utils.ts @@ -9,6 +9,7 @@ import { import type { ReactCodeMirrorRef } from "@uiw/react-codemirror"; import type { DataUIPart, FileUIPart, UIMessage } from "ai"; import { getAIContextRegistry } from "@/core/ai/context/context"; +import type { ContextLocatorId } from "@/core/ai/context/registry"; import { getCodes } from "@/core/codemirror/copilot/getCodes"; import type { LanguageAdapterType } from "@/core/codemirror/language/types"; import type { AiCompletionRequest } from "@/core/network/types"; @@ -89,20 +90,51 @@ export function isContextAttachment(part: UIMessage["parts"][number]): boolean { } /** - * Resolve @-context for messages. They represent referenced - * datasets, variables, or other context from the user's prompt. + * Stamp a context-derived attachment with a provenance marker. + * + * Some @-mentions resolve to file attachments (e.g. a cell's image output), + * which get appended to the user message right alongside files the user + * uploaded by hand. 
Once they're in the message the two are indistinguishable, + * so we mark the context-derived ones. This matters on message edit: we + * re-resolve context from the edited text, and `isContextAttachment` lets us + * drop only the stale context attachments while preserving the user's own + * uploads. */ -export async function resolveChatContext( +function stampContextAttachment(attachment: FileUIPart): FileUIPart { + return { + ...attachment, + providerMetadata: { + ...attachment.providerMetadata, + // Merge within the `marimo` namespace so we don't clobber any other + // marimo metadata a provider may have already set. + marimo: { + ...attachment.providerMetadata?.marimo, + ...CONTEXT_ATTACHMENT_METADATA.marimo, + }, + }, + }; +} + +interface ResolvedContext { + plainText: string; + contextIds: ContextLocatorId[]; + attachments: FileUIPart[]; +} + +/** + * Parse the @-context in `input` into its plain text, context ids, and stamped attachments. + */ +async function resolveContextAttachments( input: string, -): Promise { +): Promise { if (!input.includes(CONTEXT_TRIGGER)) { - return { contextPart: null, attachments: [] }; + return { plainText: "", contextIds: [], attachments: [] }; } const registry = getAIContextRegistry(store); const contextIds = registry.parseAllContextIds(input); if (contextIds.length === 0) { - return { contextPart: null, attachments: [] }; + return { plainText: "", contextIds: [], attachments: [] }; } const plainText = registry.formatContextForAI(contextIds); @@ -110,20 +142,24 @@ export async function resolveChatContext( let attachments: FileUIPart[] = []; try { const resolved = await registry.getAttachmentsForContext(contextIds); - attachments = resolved.map((attachment) => ({ - ...attachment, - providerMetadata: { - ...attachment.providerMetadata, - marimo: { - ...attachment.providerMetadata?.marimo, - ...CONTEXT_ATTACHMENT_METADATA.marimo, - }, - }, - })); + attachments = resolved.map(stampContextAttachment); } catch (error) { Logger.error("Error getting attachments:", error); } + return { plainText, 
contextIds, attachments }; +} + +/** + * Resolve @-context for messages. They represent referenced + * datasets, variables, or other context from the user's prompt. + */ +export async function resolveChatContext( + input: string, +): Promise { + const { plainText, contextIds, attachments } = + await resolveContextAttachments(input); + let contextPart: MarimoContextUIPart | null = null; if (plainText.trim()) { contextPart = { @@ -141,31 +177,13 @@ export async function resolveChatContext( export async function getAICompletionBodyWithAttachments({ input, }: Opts): Promise { - let contextString = ""; - let attachments: FileUIPart[] = []; - - // Skip if no '@' in the input - if (input.includes("@")) { - const registry = getAIContextRegistry(store); - const contextIds = registry.parseAllContextIds(input); - - // Get context string - contextString = registry.formatContextForAI(contextIds); - - // Get attachments - try { - attachments = await registry.getAttachmentsForContext(contextIds); - Logger.debug("Included attachments", attachments.length); - } catch (error) { - Logger.error("Error getting attachments:", error); - } - } + const { plainText, attachments } = await resolveContextAttachments(input); return { body: { includeOtherCode: getCodes(""), context: { - plainText: contextString, + plainText, schema: [], variables: [], },