Skip to content

Commit 3968f0c

Browse files
igerberclaude
andcommitted
PR #464 R2 polish: thread df_name through agent_workflow (codex P1)
CI codex review on the rebased SHA flagged a newly-identified P1: agent_workflow() hardcoded the dataframe symbol `df` into every emitted call. A caller who does `panel = pd.read_parquet(...)` and then calls diff_diff.agent_workflow(panel, ...) gets back a script that references `df` and NameErrors on first execution. Fix: add `df_name: str = "df"` parameter, threaded into: - profile_call (`diff_diff.profile_panel({df_name}, ...)`) - both Step 3 fit example branches (CallawaySantAnna + DifferenceInDifferences) Default `df_name="df"` preserves prior behavior verbatim. Caller binds their dataframe to any identifier and passes the name; emitted script runs in their namespace without renaming. Tests added: - `test_df_name_templates_into_script`: default vs `df_name="panel"`, static AST scan confirms no bare `df` Name node remains. - `test_df_name_panel_script_executes_in_panel_namespace`: stubs `diff_diff` with MagicMock, exec's the panel-named script in a namespace where `panel` exists and `df` does not — NameError on the dataframe symbol would fail the test. Total tests: 21 → 23 in test_agent_workflow.py. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent e1d0f33 commit 3968f0c

2 files changed

Lines changed: 87 additions & 3 deletions

File tree

diff_diff/agent_workflow.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def agent_workflow(
6969
treatment: str,
7070
outcome: str,
7171
first_treat: Optional[str] = None,
72+
df_name: str = "df",
7273
verbose: bool = True,
7374
) -> Dict[str, Any]:
7475
"""Print the recommended diff-diff workflow with your column names wired in.
@@ -99,6 +100,18 @@ def agent_workflow(
99100
example, matching the actual fit signatures (passing
100101
``treatment=`` to CallawaySantAnna's ``.fit()`` would raise
101102
TypeError).
103+
df_name : str, default ``"df"``
104+
Identifier under which the caller's dataframe is bound in
105+
their namespace. Templated verbatim into the emitted script
106+
as the first positional argument of every call
107+
(``profile_panel({df_name}, ...)``,
108+
``<Estimator>().fit({df_name}, ...)``) so the script is
109+
directly executable when the caller's local variable matches.
110+
If the caller has ``panel = pd.read_parquet(...)``, passing
111+
``df_name="panel"`` produces a script that references
112+
``panel`` instead of ``df``. Must be a valid Python identifier
113+
(not enforced; non-identifier values produce a script that
114+
won't parse).
102115
verbose : bool, default True
103116
If True, print the script to stdout. The dict is always
104117
returned regardless.
@@ -140,7 +153,7 @@ def agent_workflow(
140153
del df # intentionally unused: orchestrator templates from column names only
141154

142155
profile_call = (
143-
f"diff_diff.profile_panel(df, "
156+
f"diff_diff.profile_panel({df_name}, "
144157
f"{_join_kwargs(unit=unit, time=time, treatment=treatment, outcome=outcome)})"
145158
)
146159
guide_call = 'diff_diff.get_llm_guide("autonomous")'
@@ -163,7 +176,7 @@ def agent_workflow(
163176
fit_example_kwargs = _join_kwargs(
164177
outcome=outcome, unit=unit, time=time, first_treat=first_treat
165178
)
166-
fit_example_call = f"diff_diff.CallawaySantAnna().fit(df, {fit_example_kwargs})"
179+
fit_example_call = f"diff_diff.CallawaySantAnna().fit({df_name}, {fit_example_kwargs})"
167180
step3_label_lines = [
168181
"Step 3 - Fit. Your data has `first_treat` -> staggered structure.",
169182
"`first_treat` alone does NOT identify a single estimator; pick by",
@@ -178,7 +191,9 @@ def agent_workflow(
178191
fit_example_kwargs = _join_kwargs(
179192
outcome=outcome, unit=unit, time=time, treatment=treatment
180193
)
181-
fit_example_call = f"diff_diff.DifferenceInDifferences().fit(df, {fit_example_kwargs})"
194+
fit_example_call = (
195+
f"diff_diff.DifferenceInDifferences().fit({df_name}, {fit_example_kwargs})"
196+
)
182197
step3_label_lines = [
183198
"Step 3 - Fit. Pick a candidate from Step 2's patterns based on your",
184199
"treatment/time shape. The example below shows the simple 2x2 case",

tests/test_agent_workflow.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,75 @@ def test_emitted_script_prints_report(df):
242242
assert ".full_report())" in script
243243

244244

245+
def test_df_name_templates_into_script(df):
246+
"""Caller can rename the dataframe symbol in the emitted script.
247+
248+
Default (df_name="df"): script references `df`.
249+
Custom (df_name="panel"): every emitted call uses `panel` and no
250+
bare `df` identifier appears in the runnable code paths.
251+
"""
252+
import ast
253+
254+
out_default = diff_diff.agent_workflow(
255+
df,
256+
unit="firm_id",
257+
time="year",
258+
treatment="treated",
259+
outcome="logwage",
260+
verbose=False,
261+
)
262+
out_panel = diff_diff.agent_workflow(
263+
df,
264+
unit="firm_id",
265+
time="year",
266+
treatment="treated",
267+
outcome="logwage",
268+
df_name="panel",
269+
verbose=False,
270+
)
271+
# Default behavior preserved.
272+
assert "profile_panel(df," in out_default["script"]
273+
# Custom name flows through profile_call AND fit_example_call.
274+
assert "profile_panel(panel," in out_panel["script"]
275+
assert ".fit(panel," in out_panel["script"]
276+
# Static reference scan: parse the panel-script and confirm no `df`
277+
# Name node exists — catches template drift where a `df` reference
278+
# slips in outside the templated points.
279+
tree = ast.parse(out_panel["script"])
280+
names = {n.id for n in ast.walk(tree) if isinstance(n, ast.Name)}
281+
assert "df" not in names, (
282+
f"emitted script with df_name='panel' still references `df`: "
283+
f"identifier names found = {sorted(names)}"
284+
)
285+
assert "panel" in names
286+
287+
288+
def test_df_name_panel_script_executes_in_panel_namespace(df):
289+
"""The emitted script must resolve all names in a namespace where
290+
`panel` exists and `df` does not. We stub out `diff_diff` with a
291+
MagicMock so calls don't actually fit; the test is purely about
292+
symbol resolution, not numerical correctness — if the script still
293+
referenced `df` anywhere in runnable code, exec() would NameError.
294+
"""
295+
import unittest.mock
296+
297+
out = diff_diff.agent_workflow(
298+
df,
299+
unit="firm_id",
300+
time="year",
301+
treatment="treated",
302+
outcome="logwage",
303+
df_name="panel",
304+
verbose=False,
305+
)
306+
ns = {
307+
"diff_diff": unittest.mock.MagicMock(),
308+
"panel": "sentinel_df_object",
309+
# Deliberately no `df` key — script must not reference it.
310+
}
311+
exec(compile(out["script"], "<test_df_name>", "exec"), ns)
312+
313+
245314
def test_does_not_inspect_df():
246315
# Pure orchestrator: a structurally-empty DataFrame must still produce
247316
# the templated script (no df inspection happens).

0 commit comments

Comments
 (0)