diff --git a/frontend/src/components/editor/ai/__tests__/completion-utils.test.ts b/frontend/src/components/editor/ai/__tests__/completion-utils.test.ts
index 449c283ce27..e5c4c8ac7d9 100644
--- a/frontend/src/components/editor/ai/__tests__/completion-utils.test.ts
+++ b/frontend/src/components/editor/ai/__tests__/completion-utils.test.ts
@@ -1,6 +1,15 @@
/* Copyright 2026 Marimo. All rights reserved. */
-import { beforeEach, describe, expect, it, type Mock, vi } from "vitest";
+import {
+ afterEach,
+ beforeEach,
+ describe,
+ expect,
+ it,
+ type Mock,
+ vi,
+} from "vitest";
import { variableName } from "@/__tests__/branded";
+import * as aiContext from "@/core/ai/context/context";
import { getCodes } from "@/core/codemirror/copilot/getCodes";
import { dataSourceConnectionsAtom } from "@/core/datasets/data-source-connections";
import { DUCKDB_ENGINE } from "@/core/datasets/engines";
@@ -8,10 +17,11 @@ import { datasetsAtom } from "@/core/datasets/state";
import type { DatasetsState } from "@/core/datasets/types";
import { store } from "@/core/state/jotai";
import { variablesAtom } from "@/core/variables/state";
-import type { UIMessage } from "ai";
+import type { FileUIPart, UIMessage } from "ai";
import {
codeToCells,
getAICompletionBody,
+ getAICompletionBodyWithAttachments,
isContextAttachment,
MARIMO_CONTEXT_PART_TYPE,
resolveChatContext,
@@ -440,6 +450,42 @@ describe("isContextAttachment", () => {
});
});
+describe("context attachment stamping", () => {
+ const rawAttachment: FileUIPart = {
+ type: "file",
+ mediaType: "image/png",
+ url: "data:image/png;base64,abc",
+ };
+
+ beforeEach(() => {
+ vi.spyOn(aiContext, "getAIContextRegistry").mockReturnValue({
+ parseAllContextIds: () => ["data://t1"],
+ formatContextForAI: () => '',
+ getAttachmentsForContext: async () => [rawAttachment],
+ } as unknown as ReturnType);
+ });
+
+ afterEach(() => {
+ vi.restoreAllMocks();
+ });
+
+ it("stamps chat attachments as context-derived", async () => {
+ const { attachments } = await resolveChatContext("see @data://t1");
+ expect(attachments).toHaveLength(1);
+ expect(isContextAttachment(attachments[0])).toBe(true);
+ // The original attachment is left untouched (we return a stamped copy).
+ expect(rawAttachment.providerMetadata).toBeUndefined();
+ });
+
+ it("stamps completion attachments the same way as chat", async () => {
+ const { attachments } = await getAICompletionBodyWithAttachments({
+ input: "see @data://t1",
+ });
+ expect(attachments).toHaveLength(1);
+ expect(isContextAttachment(attachments[0])).toBe(true);
+ });
+});
+
describe("codeToCells", () => {
it("should return empty array for empty string", () => {
const code = "";
diff --git a/frontend/src/components/editor/ai/completion-utils.ts b/frontend/src/components/editor/ai/completion-utils.ts
index ef408595d42..d41cb095892 100644
--- a/frontend/src/components/editor/ai/completion-utils.ts
+++ b/frontend/src/components/editor/ai/completion-utils.ts
@@ -9,6 +9,7 @@ import {
import type { ReactCodeMirrorRef } from "@uiw/react-codemirror";
import type { DataUIPart, FileUIPart, UIMessage } from "ai";
import { getAIContextRegistry } from "@/core/ai/context/context";
+import type { ContextLocatorId } from "@/core/ai/context/registry";
import { getCodes } from "@/core/codemirror/copilot/getCodes";
import type { LanguageAdapterType } from "@/core/codemirror/language/types";
import type { AiCompletionRequest } from "@/core/network/types";
@@ -89,20 +90,51 @@ export function isContextAttachment(part: UIMessage["parts"][number]): boolean {
}
/**
- * Resolve @-context for messages. They represent referenced
- * datasets, variables, or other context from the user's prompt.
+ * Stamp a context-derived attachment with a provenance marker.
+ *
+ * Some @-mentions resolve to file attachments (e.g. a cell's image output),
+ * which get appended to the user message right alongside files the user
+ * uploaded by hand. Once they're in the message the two are indistinguishable,
+ * so we mark the context-derived ones. This matters on message edit: we
+ * re-resolve context from the edited text, and `isContextAttachment` lets us
+ * drop only the stale context attachments while preserving the user's own
+ * uploads
*/
-export async function resolveChatContext(
+function stampContextAttachment(attachment: FileUIPart): FileUIPart {
+ return {
+ ...attachment,
+ providerMetadata: {
+ ...attachment.providerMetadata,
+ // Merge within the `marimo` namespace so we don't clobber any other
+ // marimo metadata a provider may have already set.
+ marimo: {
+ ...attachment.providerMetadata?.marimo,
+ ...CONTEXT_ATTACHMENT_METADATA.marimo,
+ },
+ },
+ };
+}
+
+interface ResolvedContext {
+ plainText: string;
+ contextIds: ContextLocatorId[];
+ attachments: FileUIPart[];
+}
+
+/**
+ * Parse @-context for messages
+ */
+async function resolveContextAttachments(
input: string,
-): Promise {
+): Promise {
if (!input.includes(CONTEXT_TRIGGER)) {
- return { contextPart: null, attachments: [] };
+ return { plainText: "", contextIds: [], attachments: [] };
}
const registry = getAIContextRegistry(store);
const contextIds = registry.parseAllContextIds(input);
if (contextIds.length === 0) {
- return { contextPart: null, attachments: [] };
+ return { plainText: "", contextIds: [], attachments: [] };
}
const plainText = registry.formatContextForAI(contextIds);
@@ -110,20 +142,24 @@ export async function resolveChatContext(
let attachments: FileUIPart[] = [];
try {
const resolved = await registry.getAttachmentsForContext(contextIds);
- attachments = resolved.map((attachment) => ({
- ...attachment,
- providerMetadata: {
- ...attachment.providerMetadata,
- marimo: {
- ...attachment.providerMetadata?.marimo,
- ...CONTEXT_ATTACHMENT_METADATA.marimo,
- },
- },
- }));
+ attachments = resolved.map(stampContextAttachment);
} catch (error) {
Logger.error("Error getting attachments:", error);
}
+ return { plainText, contextIds, attachments };
+}
+
+/**
+ * Resolve @-context for messages. They represent referenced
+ * datasets, variables, or other context from the user's prompt.
+ */
+export async function resolveChatContext(
+ input: string,
+): Promise {
+ const { plainText, contextIds, attachments } =
+ await resolveContextAttachments(input);
+
let contextPart: MarimoContextUIPart | null = null;
if (plainText.trim()) {
contextPart = {
@@ -141,31 +177,13 @@ export async function resolveChatContext(
export async function getAICompletionBodyWithAttachments({
input,
}: Opts): Promise {
- let contextString = "";
- let attachments: FileUIPart[] = [];
-
- // Skip if no '@' in the input
- if (input.includes("@")) {
- const registry = getAIContextRegistry(store);
- const contextIds = registry.parseAllContextIds(input);
-
- // Get context string
- contextString = registry.formatContextForAI(contextIds);
-
- // Get attachments
- try {
- attachments = await registry.getAttachmentsForContext(contextIds);
- Logger.debug("Included attachments", attachments.length);
- } catch (error) {
- Logger.error("Error getting attachments:", error);
- }
- }
+ const { plainText, attachments } = await resolveContextAttachments(input);
return {
body: {
includeOtherCode: getCodes(""),
context: {
- plainText: contextString,
+ plainText,
schema: [],
variables: [],
},