Skip to content

Commit f580dae

Browse files
igerberclaude
andcommitted
Address AI review P1/P2 findings for Phase 3 PR B
P1 fixes: - DID^X residualization no longer leaks into per-period path: per_period_effects uses raw Y_mat, only multi-horizon path sees residualized outcomes - Added to_dataframe levels for heterogeneity and linear_trends P2 fixes: - Covariate coercion no longer mutates caller's DataFrame - Vectorized residualization (einsum replaces nested loop) - Heterogeneity test guards against rank-deficient OLS - Added estimand contract test for controls + L_max=1 - REGISTRY note clarifies per_period_effects stays unadjusted Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4cb0c11 commit f580dae

4 files changed

Lines changed: 104 additions & 15 deletions

File tree

diff_diff/chaisemartin_dhaultfoeuille.py

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -618,21 +618,26 @@ def fit(
618618
f"Control column(s) {missing_controls!r} not found in "
619619
f"data. Available columns: {list(data.columns)}"
620620
)
621+
# Work on a copy to avoid mutating the caller's DataFrame
622+
data_controls = data[controls].copy()
621623
for c in controls:
622624
try:
623-
data[c] = pd.to_numeric(data[c])
625+
data_controls[c] = pd.to_numeric(data_controls[c])
624626
except (ValueError, TypeError) as exc:
625627
raise ValueError(
626628
f"Could not coerce control column {c!r} to numeric: {exc}"
627629
) from exc
628-
n_nan = int(data[c].isna().sum())
630+
n_nan = int(data_controls[c].isna().sum())
629631
if n_nan > 0:
630632
raise ValueError(
631633
f"Control column {c!r} contains {n_nan} NaN value(s). "
632634
"Drop or impute missing covariates before fitting."
633635
)
634-
# Aggregate covariates to cell means (same groupby as treatment/outcome)
635-
x_cell_agg = data.groupby([group, time], as_index=False)[controls].mean()
636+
# Aggregate covariates to cell means (same groupby as treatment/outcome).
637+
# Use the coerced copy joined with group/time from original data.
638+
x_agg_input = data[[group, time]].copy()
639+
x_agg_input[controls] = data_controls[controls].values
640+
x_cell_agg = x_agg_input.groupby([group, time], as_index=False)[controls].mean()
636641
cell = cell.merge(x_cell_agg, on=[group, time], how="left")
637642

638643
# ------------------------------------------------------------------
@@ -948,13 +953,19 @@ def fit(
948953
)
949954
_switch_metadata_computed = True
950955

951-
Y_mat, covariate_diagnostics = _compute_covariate_residualization(
956+
Y_mat_residualized, covariate_diagnostics = _compute_covariate_residualization(
952957
Y_mat=Y_mat,
953958
X_cell=X_cell,
954959
N_mat=N_mat,
955960
baselines=baselines,
956961
first_switch_idx=first_switch_idx_arr,
957962
)
963+
# Keep raw Y_mat for the per-period DID path (which does not
964+
# support covariate residualization - it uses binary joiner/leaver
965+
# categorization). The residualized matrix is used only by the
966+
# per-group multi-horizon path (L_max >= 1).
967+
Y_mat_raw = Y_mat
968+
Y_mat = Y_mat_residualized
958969

959970
# ------------------------------------------------------------------
960971
# Step 7c: First-differencing for linear trends (DID^{fd})
@@ -1061,8 +1072,13 @@ def fit(
10611072
a11_minus_zeroed_arr,
10621073
) = _compute_per_period_dids(
10631074
D_mat=D_mat,
1064-
Y_mat=Y_mat,
1065-
N_mat=N_mat,
1075+
# Use raw (unadjusted) outcomes for per-period DID. Covariate
1076+
# residualization applies only to the per-group multi-horizon
1077+
# path (L_max >= 1). The per-period path uses binary
1078+
# joiner/leaver categorization and is not part of the DID^X
1079+
# contract (Web Appendix Section 1.2).
1080+
Y_mat=Y_mat_raw if controls is not None else Y_mat,
1081+
N_mat=N_mat_orig,
10661082
periods=all_periods,
10671083
)
10681084
if a11_warnings:
@@ -1489,7 +1505,8 @@ def fit(
14891505
U_centered_leavers,
14901506
) = _compute_cohort_recentered_inputs(
14911507
D_mat=D_mat,
1492-
Y_mat=Y_mat,
1508+
# Phase 1 IF uses per-period structure: use raw outcomes
1509+
Y_mat=Y_mat_raw if controls is not None else Y_mat,
14931510
N_mat=N_mat_orig,
14941511
n_10_t_arr=n_10_t_arr,
14951512
n_00_t_arr=n_00_t_arr,
@@ -2751,12 +2768,17 @@ def _compute_covariate_residualization(
27512768
}
27522769

27532770
# Residualize Y at levels for all groups with this baseline.
2754-
# Y_tilde[g, t] = Y[g, t] - X[g, t] @ theta_hat
2771+
# Vectorized level residualization: Y_tilde[g, t] = Y[g, t] - X[g, t] @ theta_hat
27552772
group_indices = np.where(d_mask)[0]
2756-
for g in group_indices:
2757-
for t in range(n_periods):
2758-
if N_mat[g, t] > 0 and np.all(np.isfinite(X_cell[g, t])):
2759-
Y_resid[g, t] = Y_mat[g, t] - float(X_cell[g, t] @ theta_hat)
2773+
if len(group_indices) > 0:
2774+
# X_sub: (n_d_groups, n_periods, n_covariates), theta: (n_covariates,)
2775+
X_sub = X_cell[group_indices] # (n_d, T, K)
2776+
adjustment = np.einsum("gtk,k->gt", X_sub, theta_hat) # (n_d, T)
2777+
# Mask: only adjust cells that are observed and have finite covariates
2778+
valid = (N_mat[group_indices] > 0) & np.all(np.isfinite(X_sub), axis=2)
2779+
Y_resid[group_indices] = np.where(
2780+
valid, Y_mat[group_indices] - adjustment, Y_mat[group_indices]
2781+
)
27602782

27612783
return Y_resid, diagnostics
27622784

@@ -2902,14 +2924,29 @@ def _compute_heterogeneity_test(
29022924
else:
29032925
design = x_arr
29042926

2927+
# Guard: need more observations than parameters
2928+
n_params = design.shape[1]
2929+
if n_obs <= n_params:
2930+
results[l_h] = {
2931+
"beta": float("nan"), "se": float("nan"),
2932+
"t_stat": float("nan"), "p_value": float("nan"),
2933+
"conf_int": (float("nan"), float("nan")),
2934+
"n_obs": n_obs,
2935+
}
2936+
continue
2937+
29052938
coefs, _residuals, vcov = solve_ols(
29062939
design, dep_arr,
29072940
return_vcov=True,
29082941
rank_deficient_action="warn",
29092942
)
29102943

29112944
beta_het = float(coefs[0])
2912-
se_het = float(np.sqrt(vcov[0, 0])) if vcov is not None else float("nan")
2945+
# NaN-safe: if vcov is None or target coefficient variance is NaN
2946+
# (rank-deficient), all inference fields are NaN.
2947+
se_het = float("nan")
2948+
if vcov is not None and np.isfinite(vcov[0, 0]) and vcov[0, 0] > 0:
2949+
se_het = float(np.sqrt(vcov[0, 0]))
29132950
t_stat, p_val, ci = safe_inference(beta_het, se_het, alpha=alpha, df=None)
29142951

29152952
results[l_h] = {

diff_diff/chaisemartin_dhaultfoeuille_results.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,6 +1029,28 @@ def to_dataframe(self, level: str = "overall") -> pd.DataFrame:
10291029
)
10301030
return self.twfe_weights.copy()
10311031

1032+
elif level == "heterogeneity":
1033+
if self.heterogeneity_effects is None:
1034+
raise ValueError(
1035+
"Heterogeneity test results not available. Pass "
1036+
"heterogeneity='column_name' to fit()."
1037+
)
1038+
rows = []
1039+
for h, data in sorted(self.heterogeneity_effects.items()):
1040+
rows.append({"horizon": h, **data})
1041+
return pd.DataFrame(rows)
1042+
1043+
elif level == "linear_trends":
1044+
if self.linear_trends_effects is None:
1045+
raise ValueError(
1046+
"Linear trends effects not available. Pass "
1047+
"trends_linear=True to fit()."
1048+
)
1049+
rows = []
1050+
for h, data in sorted(self.linear_trends_effects.items()):
1051+
rows.append({"horizon": h, **data})
1052+
return pd.DataFrame(rows)
1053+
10321054
else:
10331055
raise ValueError(
10341056
f"Unknown level: {level!r}. Use 'overall', 'joiners_leavers', "

docs/methodology/REGISTRY.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ Alternative: Multiplier bootstrap clustered at group via the `n_bootstrap` param
609609

610610
- **Note (deviation from R DIDmultiplegtDYN):** Phase 1 requires panels with a **balanced baseline** (every group observed at the first global period) and **no interior period gaps**. The Step 5b validation in `fit()` enforces this contract: groups missing the baseline raise `ValueError`; groups with interior gaps are dropped with a `UserWarning`; groups with **terminal missingness** (early exit / right-censoring — observed at the baseline but missing one or more later periods) are retained and contribute from their observed periods only. R `DIDmultiplegtDYN` accepts unbalanced panels with documented missing-treatment-before-first-switch handling. Python's restriction is a Phase 1 limitation: the cohort enumeration uses `D_{g,1}` as the canonical baseline (so the baseline observation must exist) and the first-switch detection walks adjacent observed periods (so interior gaps create ambiguous transition counts). Terminal missingness is supported because the per-period `present = (N_mat[:, t] > 0) & (N_mat[:, t-1] > 0)` guard appears at three sites in the variance computation (`_compute_per_period_dids`, `_compute_full_per_group_contributions`, `_compute_cohort_recentered_inputs`) and cleanly masks out missing transitions without propagating NaN into the arithmetic. **Workaround for unbalanced panels:** pre-process your data to back-fill the baseline (or drop late-entry groups before fitting), or use R `DIDmultiplegtDYN` until a future phase lifts the restriction. The Step 5b `ValueError` and `UserWarning` messages name the offending group IDs so you can locate them quickly.
611611

612-
- **Note (Phase 3 DID^X covariate adjustment):** Implements the residualization-style covariate adjustment from Web Appendix Section 1.2 (Assumption 11). For each baseline treatment value `d`, estimates `theta_hat_d` via OLS of first-differenced outcomes on first-differenced covariates with time FEs, restricted to not-yet-treated observations. Residualizes at levels: `Y_tilde[g,t] = Y[g,t] - X[g,t] @ theta_hat_d`. All downstream DID computations use residualized outcomes. This is NOT doubly-robust, NOT IPW, NOT Callaway-Sant'Anna-style. Plug-in IF (treating `theta_hat` as fixed) is valid by FWL theorem. Requires `L_max >= 1`. Activated via `controls=["col1", "col2"]` in `fit()`.
612+
- **Note (Phase 3 DID^X covariate adjustment):** When `controls` is set, `per_period_effects` (the Phase 1 per-period DID_M decomposition) remains **unadjusted** (computed on raw outcomes). The covariate residualization applies only to the per-group `DID_{g,l}` path (`L_max >= 1`), which produces `event_study_effects` and `overall_att`. This means `per_period_effects` and `event_study_effects[1]` may diverge when controls are active - by design (the per-period path uses binary joiner/leaver categorization and is not part of the DID^X contract). Implements the residualization-style covariate adjustment from Web Appendix Section 1.2 (Assumption 11). For each baseline treatment value `d`, estimates `theta_hat_d` via OLS of first-differenced outcomes on first-differenced covariates with time FEs, restricted to not-yet-treated observations. Residualizes at levels: `Y_tilde[g,t] = Y[g,t] - X[g,t] @ theta_hat_d`. All downstream DID computations use residualized outcomes. This is NOT doubly-robust, NOT IPW, NOT Callaway-Sant'Anna-style. Plug-in IF (treating `theta_hat` as fixed) is valid by FWL theorem. Requires `L_max >= 1`. Activated via `controls=["col1", "col2"]` in `fit()`.
613613

614614
- **Note (Phase 3 DID^{fd} linear trends):** Implements group-specific linear trends from Web Appendix Section 1.3 (Assumption 12, Lemma 6). Uses the Z_mat transformation: `Z[g,t] = Y[g,t] - Y[g,t-1]` (first-differenced outcomes). Since `DID_{g,l}(Z) = DID^{fd}_{g,l}` algebraically, the existing multi-horizon DID code produces trend-adjusted estimates when fed Z_mat. Requires F_g >= 3 (at least 2 pre-switch periods); groups with F_g < 3 are excluded with a `UserWarning`. Cumulated level effects `delta^{fd}_l = sum_{l'=1}^l DID^{fd}_{l'}` stored in `results.linear_trends_effects`. Cumulated SE uses conservative upper bound (sum of per-horizon SEs); cross-horizon covariance from IF vectors is a library extension (paper proves Theorem 1 per-horizon, not cross-horizon). When combined with DID^X, residualization is applied first, then first-differencing (per paper assumption ordering). Activated via `trends_linear=True` in `fit()`.
615615

tests/test_chaisemartin_dhaultfoeuille.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2421,6 +2421,36 @@ def test_controls_with_multi_horizon(self):
24212421
assert np.isfinite(r.event_study_effects[h]["effect"])
24222422
assert np.isfinite(r.event_study_effects[h]["se"])
24232423

2424+
def test_controls_lmax1_estimand_contract(self):
2425+
"""DID^X with L_max=1: per_period_effects stay raw, overall uses DID^X_1."""
2426+
df = self._make_panel_with_covariates()
2427+
est = ChaisemartinDHaultfoeuille(seed=1)
2428+
2429+
# Fit without controls for raw per-period baseline
2430+
r_raw = est.fit(df, "outcome", "group", "period", "treatment")
2431+
# Fit with controls
2432+
r_x = est.fit(
2433+
df, "outcome", "group", "period", "treatment",
2434+
controls=["X1"], L_max=1,
2435+
)
2436+
2437+
# per_period_effects should be UNADJUSTED (raw Phase 1 DID_M)
2438+
# because the per-period path does not support covariate adjustment
2439+
for period_key in r_raw.per_period_effects:
2440+
if period_key in r_x.per_period_effects:
2441+
raw_eff = r_raw.per_period_effects[period_key]
2442+
x_eff = r_x.per_period_effects[period_key]
2443+
assert raw_eff["did_plus_t"] == pytest.approx(
2444+
x_eff["did_plus_t"], abs=1e-10
2445+
), f"per_period_effects should be unadjusted at period {period_key}"
2446+
2447+
# overall_att should come from event_study_effects[1] (DID^X_1)
2448+
assert r_x.overall_att == pytest.approx(
2449+
r_x.event_study_effects[1]["effect"], abs=1e-10
2450+
)
2451+
# and should differ from the raw overall_att (covariate effect)
2452+
assert r_x.overall_att != r_raw.overall_att
2453+
24242454

24252455
class TestLinearTrends:
24262456
"""DID^{fd} group-specific linear trends (ROADMAP item 3b)."""

0 commit comments

Comments
 (0)