Skip to content

Commit 8a57c5d

Browse files
igerberclaude
andcommitted
Fix CI review Round 10: NaN set validation, design2 raw Y, controls=[]
P1: trends_nonparam now rejects NaN/missing set assignments with ValueError. P1: design2_effects always uses raw level outcomes from y_pivot (not residualized or first-differenced Y_mat). P2: controls=[] now raises ValueError instead of crashing on np.stack([]). P3: summary() overall block labeled "N/A under trends_linear" when trends + L_max>=2 (was "Cost-Benefit Delta" with NaN value). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e3d51db commit 8a57c5d

2 files changed

Lines changed: 21 additions & 2 deletions

File tree

diff_diff/chaisemartin_dhaultfoeuille.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,12 @@ def fit(
631631
# Step 4b: Covariate aggregation (DID^X, Web Appendix Section 1.2)
632632
# ------------------------------------------------------------------
633633
if controls is not None:
634+
if not controls:
635+
raise ValueError(
636+
"controls must be a non-empty list of column names, "
637+
"got an empty list. Pass controls=None to disable "
638+
"covariate adjustment."
639+
)
634640
if L_max is None:
635641
raise ValueError(
636642
"Covariate adjustment (DID^X) requires L_max >= 1. The "
@@ -1081,6 +1087,14 @@ def fit(
10811087
f"trends_nonparam column {set_col!r} not found in "
10821088
f"data. Available columns: {list(data.columns)}"
10831089
)
1090+
# Reject NaN/missing set assignments
1091+
n_na_set = int(data[set_col].isna().sum())
1092+
if n_na_set > 0:
1093+
raise ValueError(
1094+
f"trends_nonparam column {set_col!r} contains "
1095+
f"{n_na_set} NaN/missing value(s). All groups must "
1096+
f"have a valid set assignment."
1097+
)
10841098
# Aggregate set membership per group (must be time-invariant)
10851099
set_per_group = data.groupby(group)[set_col].nunique()
10861100
time_varying = set_per_group[set_per_group > 1]
@@ -2361,7 +2375,9 @@ def fit(
23612375
design2_effects=(
23622376
_compute_design2_effects(
23632377
D_mat=D_mat,
2364-
Y_mat=Y_mat if not _is_trends_linear else y_pivot.to_numpy(),
2378+
# Design-2 always uses raw level outcomes (not residualized,
2379+
# not first-differenced). Use y_pivot as the canonical raw source.
2380+
Y_mat=y_pivot.to_numpy(),
23652381
N_mat=N_mat_orig,
23662382
baselines=baselines,
23672383
first_switch_idx=first_switch_idx_arr,

diff_diff/chaisemartin_dhaultfoeuille_results.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,10 @@ def summary(self, alpha: Optional[float] = None) -> str:
580580
adj_tag = " (Trend-Adjusted)"
581581

582582
if self.L_max is not None and self.L_max >= 2:
583-
overall_label = f"Cost-Benefit Delta{adj_tag}"
583+
if has_trends:
584+
overall_label = f"Overall (N/A under trends_linear){adj_tag}"
585+
else:
586+
overall_label = f"Cost-Benefit Delta{adj_tag}"
584587
overall_row_label = self._estimand_label()
585588
elif self.L_max is not None and self.L_max == 1:
586589
overall_label = f"Per-Group ATT at Horizon 1{adj_tag}"

0 commit comments

Comments
 (0)