Skip to content

Commit fe77d84

Browse files
igerber and claude committed
Validate grouping columns for NaN, fix examples, add missing-key tests
Address remaining P1/P3 from AI review:
- Reject NaN in `by` columns before groupby (P1)
- Make docstring/RST examples illustrative pseudocode (P3)
- Add tests for partial and all-NaN grouping keys (P3)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d330d36 commit fe77d84

3 files changed

Lines changed: 54 additions & 11 deletions

File tree

diff_diff/prep.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,10 +1466,13 @@ def aggregate_survey(
14661466
... microdata, by=["state", "year"],
14671467
... outcomes="smoking_rate", survey_design=design,
14681468
... )
1469-
>>> result = DifferenceInDifferences().fit(
1470-
... panel, outcome="smoking_rate_mean",
1471-
... treatment="treated", time="year", survey_design=stage2,
1472-
... )
1469+
>>> # Add treatment/time indicators at the panel level, then fit:
1470+
>>> # panel["treated"] = ... # e.g., from policy adoption data
1471+
>>> # panel["post"] = (panel["year"] >= treatment_year).astype(int)
1472+
>>> # result = DifferenceInDifferences().fit(
1473+
>>> # panel, outcome="smoking_rate_mean",
1474+
>>> # treatment="treated", time="post", survey_design=stage2,
1475+
>>> # )
14731476
"""
14741477
import warnings
14751478
from dataclasses import replace
@@ -1508,6 +1511,15 @@ def aggregate_survey(
15081511
if data.empty:
15091512
raise ValueError("data must be non-empty")
15101513

1514+
# --- Validate grouping columns have no missing values ---
1515+
by_missing = data[by_cols].isna().any()
1516+
cols_with_na = list(by_missing[by_missing].index)
1517+
if cols_with_na:
1518+
raise ValueError(
1519+
f"Missing values in grouping column(s): {cols_with_na}. "
1520+
f"Drop or fill NaN values before calling aggregate_survey()."
1521+
)
1522+
15111523
# --- Resolve design once on full data ---
15121524
effective_design = (
15131525
replace(survey_design, lonely_psu=lonely_psu) if lonely_psu else survey_design

docs/api/prep.rst

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -284,13 +284,13 @@ Example
284284
# cell_n, cell_n_eff, srs_fallback
285285
286286
# stage2 is pre-configured: aweights + state-level clustering
287-
result = DifferenceInDifferences().fit(
288-
panel,
289-
outcome="smoking_rate_mean",
290-
treatment="treated",
291-
time="year",
292-
survey_design=stage2,
293-
)
287+
# Add treatment/time indicators at the panel level, then fit:
288+
# panel["treated"] = ... # from policy adoption data
289+
# panel["post"] = (panel["year"] >= treatment_year).astype(int)
290+
# result = DifferenceInDifferences().fit(
291+
# panel, outcome="smoking_rate_mean",
292+
# treatment="treated", time="post", survey_design=stage2,
293+
# )
294294
295295
Data Validation
296296
---------------

tests/test_prep.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2421,6 +2421,37 @@ def test_error_empty_data(self, design):
24212421
survey_design=design,
24222422
)
24232423

2424+
def test_error_missing_grouping_keys(self, micro_data, design):
2425+
"""NaN in grouping columns raises ValueError."""
2426+
data = micro_data.copy()
2427+
data.loc[0, "state"] = np.nan
2428+
with pytest.raises(ValueError, match="Missing values in grouping column"):
2429+
aggregate_survey(
2430+
data,
2431+
by=["state", "year"],
2432+
outcomes="y",
2433+
survey_design=design,
2434+
)
2435+
2436+
def test_error_all_missing_grouping_keys(self, design):
2437+
"""All-NaN grouping column raises ValueError."""
2438+
data = pd.DataFrame(
2439+
{
2440+
"state": [np.nan] * 10,
2441+
"year": np.ones(10, dtype=int),
2442+
"y": np.random.RandomState(1).normal(0, 1, 10),
2443+
"wt": np.ones(10),
2444+
}
2445+
)
2446+
design_simple = SurveyDesign(weights="wt")
2447+
with pytest.raises(ValueError, match="Missing values in grouping column"):
2448+
aggregate_survey(
2449+
data,
2450+
by=["state", "year"],
2451+
outcomes="y",
2452+
survey_design=design_simple,
2453+
)
2454+
24242455
def test_domain_estimation_preserves_full_design(self):
24252456
"""Full-design domain estimation accounts for PSUs outside the cell.
24262457

0 commit comments

Comments
 (0)