Skip to content

Commit bf427af

Browse files
igerberclaude
andcommitted
Rewrite Hausman pretest to ES(e) per Theorem A.1, guard cluster+survey
Address rerun-3 review findings: P1 fixes: - Hausman pretest now aggregates to post-treatment event-study ES(e) before computing test statistic, matching Theorem A.1 (was using raw (g,t) cells including pre-treatment placebos) - Raise NotImplementedError when both cluster and survey_design are set (cluster was silently ignored under survey TSL dispatch) - Add clustered bootstrap+aggregate='all' test covering bootstrap-updated event-study and group effect SEs P2: Update REGISTRY.md Hausman note to describe ES(e) aggregation. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 94be0c0 commit bf427af

4 files changed

Lines changed: 108 additions & 65 deletions

File tree

diff_diff/efficient_did.py

Lines changed: 77 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,12 @@ def fit(
348348
"""
349349
self._validate_params()
350350

351+
if self.cluster is not None and survey_design is not None:
352+
raise NotImplementedError(
353+
"cluster and survey_design cannot both be set. "
354+
"Use survey_design with PSU/strata for cluster-robust inference."
355+
)
356+
351357
# Resolve survey design if provided
352358
from diff_diff.survey import _resolve_survey_for_fit
353359

@@ -1475,74 +1481,92 @@ def _nan_result() -> HausmanPretestResult:
14751481
if not common_gts:
14761482
return _nan_result()
14771483

1478-
# Filter out (g,t) cells with non-finite effect estimates
1479-
common_gts = [
1480-
gt
1481-
for gt in common_gts
1482-
if np.isfinite(result_all.group_time_effects[gt]["effect"])
1483-
and np.isfinite(result_post.group_time_effects[gt]["effect"])
1484-
]
1485-
if not common_gts:
1486-
return _nan_result()
1487-
1488-
k = len(common_gts)
1489-
1490-
# Build EIF matrices for common (g,t) pairs: (n_units, k)
14911484
eif_all = result_all.influence_functions
14921485
eif_post = result_post.influence_functions
14931486
assert eif_all is not None and eif_post is not None
14941487
n_units = len(next(iter(eif_all.values())))
14951488

1496-
eif_all_mat = np.column_stack([eif_all[gt] for gt in common_gts])
1497-
eif_post_mat = np.column_stack([eif_post[gt] for gt in common_gts])
1498-
1499-
# Filter out (g,t) pairs with non-finite EIF values
1500-
finite_mask = np.all(np.isfinite(eif_all_mat), axis=0) & np.all(
1501-
np.isfinite(eif_post_mat), axis=0
1489+
# --- Aggregate to post-treatment ES(e) per Theorem A.1 ---
1490+
# Derive cohort fractions from data for proper weights
1491+
all_units_list = sorted(data[unit].unique())
1492+
unit_cohorts = (
1493+
data.groupby(unit)[first_treat].first().reindex(all_units_list).values.astype(float)
15021494
)
1503-
if not np.all(finite_mask):
1504-
n_dropped = int(np.sum(~finite_mask))
1505-
common_gts = [gt for gt, m in zip(common_gts, finite_mask) if m]
1506-
eif_all_mat = eif_all_mat[:, finite_mask]
1507-
eif_post_mat = eif_post_mat[:, finite_mask]
1508-
k = len(common_gts)
1509-
if k == 0:
1510-
return _nan_result()
1511-
warnings.warn(
1512-
f"Dropped {n_dropped} (g,t) pair(s) with non-finite EIF values "
1513-
"from Hausman test.",
1514-
UserWarning,
1515-
stacklevel=2,
1516-
)
1517-
1518-
# Recompute delta after filtering
1519-
delta = np.array(
1520-
[
1521-
result_post.group_time_effects[gt]["effect"]
1522-
- result_all.group_time_effects[gt]["effect"]
1523-
for gt in common_gts
1524-
]
1495+
cohort_fractions: Dict[float, float] = {}
1496+
for g in set(result_all.groups) | set(result_post.groups):
1497+
cohort_fractions[g] = float(np.sum(unit_cohorts == g)) / n_units
1498+
1499+
def _aggregate_es(
1500+
gt_effects: Dict, eif_dict: Dict, groups: List, ant: int
1501+
) -> Dict[int, Tuple[float, np.ndarray]]:
1502+
"""Aggregate (g,t) effects to post-treatment ES(e) with cohort weights."""
1503+
by_e: Dict[int, List[Tuple[float, float, np.ndarray]]] = {}
1504+
for (g, t), d in gt_effects.items():
1505+
e = int(t - g)
1506+
if e < -ant: # pre-treatment beyond anticipation window
1507+
continue
1508+
if not np.isfinite(d["effect"]):
1509+
continue
1510+
if (g, t) not in eif_dict:
1511+
continue
1512+
eif_vec = eif_dict[(g, t)]
1513+
if not np.all(np.isfinite(eif_vec)):
1514+
continue
1515+
pg = cohort_fractions.get(g, 0.0)
1516+
if e not in by_e:
1517+
by_e[e] = []
1518+
by_e[e].append((d["effect"], pg, eif_vec))
1519+
1520+
result: Dict[int, Tuple[float, np.ndarray]] = {}
1521+
for e, items in by_e.items():
1522+
if e < 0: # restrict to post-treatment (e >= 0)
1523+
continue
1524+
effs = np.array([x[0] for x in items])
1525+
pgs = np.array([x[1] for x in items])
1526+
eifs = [x[2] for x in items]
1527+
total_pg = pgs.sum()
1528+
w = pgs / total_pg if total_pg > 0 else np.ones(len(pgs)) / len(pgs)
1529+
es_eff = float(np.sum(w * effs))
1530+
es_eif = np.zeros(n_units)
1531+
for k_idx in range(len(eifs)):
1532+
es_eif += w[k_idx] * eifs[k_idx]
1533+
result[e] = (es_eff, es_eif)
1534+
return result
1535+
1536+
es_all = _aggregate_es(
1537+
result_all.group_time_effects, eif_all, result_all.groups, anticipation
1538+
)
1539+
es_post = _aggregate_es(
1540+
result_post.group_time_effects, eif_post, result_post.groups, anticipation
15251541
)
15261542

1527-
# Also filter units with non-finite EIF values (row-wise)
1543+
# Find common post-treatment horizons
1544+
common_e = sorted(set(es_all.keys()) & set(es_post.keys()))
1545+
if not common_e:
1546+
return _nan_result()
1547+
1548+
delta = np.array([es_post[e][0] - es_all[e][0] for e in common_e])
1549+
1550+
# Build ES(e)-level EIF matrices
1551+
eif_all_mat = np.column_stack([es_all[e][1] for e in common_e])
1552+
eif_post_mat = np.column_stack([es_post[e][1] for e in common_e])
1553+
1554+
# Filter units with non-finite EIF values
15281555
row_finite = np.all(np.isfinite(eif_all_mat), axis=1) & np.all(
15291556
np.isfinite(eif_post_mat), axis=1
15301557
)
1531-
# Build cluster mapping for covariance if needed
15321558
cl_idx: Optional[np.ndarray] = None
15331559
n_cl: Optional[int] = None
15341560
if cluster is not None:
1535-
all_units = sorted(data[unit].unique())
1536-
cl_idx, n_cl = _validate_and_build_cluster_mapping(data, unit, cluster, all_units)
1537-
1561+
cl_idx, n_cl = _validate_and_build_cluster_mapping(data, unit, cluster, all_units_list)
15381562
if not np.all(row_finite):
15391563
eif_all_mat = eif_all_mat[row_finite]
15401564
eif_post_mat = eif_post_mat[row_finite]
15411565
n_units = int(np.sum(row_finite))
15421566
if cl_idx is not None:
15431567
cl_idx = cl_idx[row_finite]
15441568

1545-
# Compute full covariance matrices using shared _cluster_aggregate
1569+
# Compute full covariance matrices
15461570
if cl_idx is not None and n_cl is not None:
15471571

15481572
def _eif_cov(eif_mat: np.ndarray) -> np.ndarray:
@@ -1559,7 +1583,6 @@ def _eif_cov(eif_mat: np.ndarray) -> np.ndarray:
15591583

15601584
V = cov_post - cov_all
15611585

1562-
# If covariance has NaN/Inf, test is unreliable
15631586
if not np.all(np.isfinite(V)):
15641587
warnings.warn(
15651588
"Hausman covariance matrix contains non-finite values. " "The test is unreliable.",
@@ -1583,29 +1606,23 @@ def _eif_cov(eif_mat: np.ndarray) -> np.ndarray:
15831606
stacklevel=2,
15841607
)
15851608

1586-
# Effective rank = number of positive eigenvalues
15871609
effective_rank = int(np.sum(eigvals > tol))
15881610
if effective_rank == 0:
15891611
return _nan_result()
15901612

1591-
# Compute H = delta' @ pinv(V) @ delta
15921613
V_pinv = np.linalg.pinv(V, rcond=tol / max_eigval if max_eigval > 0 else 1e-10)
15931614
H = float(delta @ V_pinv @ delta)
1594-
H = max(H, 0.0) # numerical floor
1615+
H = max(H, 0.0)
15951616

15961617
p_value = float(chi2.sf(H, df=effective_rank))
15971618
reject = p_value < alpha
15981619

1599-
# Build per-(g,t) details DataFrame
1600-
gt_details = pd.DataFrame(
1620+
es_details = pd.DataFrame(
16011621
{
1602-
"group": [gt[0] for gt in common_gts],
1603-
"time": [gt[1] for gt in common_gts],
1604-
"att_all": [result_all.group_time_effects[gt]["effect"] for gt in common_gts],
1605-
"att_post": [result_post.group_time_effects[gt]["effect"] for gt in common_gts],
1622+
"relative_period": common_e,
1623+
"es_all": [es_all[e][0] for e in common_e],
1624+
"es_post": [es_post[e][0] for e in common_e],
16061625
"delta": delta,
1607-
"se_all": [result_all.group_time_effects[gt]["se"] for gt in common_gts],
1608-
"se_post": [result_post.group_time_effects[gt]["se"] for gt in common_gts],
16091626
}
16101627
)
16111628

@@ -1618,5 +1635,5 @@ def _eif_cov(eif_mat: np.ndarray) -> np.ndarray:
16181635
att_all=result_all.overall_att,
16191636
att_post=result_post.overall_att,
16201637
recommendation="pt_post" if reject else "pt_all",
1621-
gt_details=gt_details,
1638+
gt_details=es_details,
16221639
)

diff_diff/efficient_did_results.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class HausmanPretestResult:
4242
recommendation: str
4343
"""``"pt_all"`` if fail to reject, ``"pt_post"`` if reject."""
4444
gt_details: Optional[pd.DataFrame] = None
45-
"""Per-(g,t) details: ATT_all, ATT_post, delta, SE_all, SE_post."""
45+
"""Per-event-study-horizon details: relative_period, es_all, es_post, delta."""
4646

4747
def __repr__(self) -> str:
4848
return (

docs/methodology/REGISTRY.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ where `q_{g,e} = pi_g / sum_{g' in G_{trt,e}} pi_{g'}`.
680680
- **Note:** EfficientDiD bootstrap with survey weights deferred to Phase 5
681681
- **Note:** EfficientDiD covariates (DR path) with survey weights deferred — the doubly robust nuisance estimation does not yet thread survey weights through sieve/kernel steps
682682
- **Note:** Cluster-robust SEs use the standard Liang-Zeger clustered sandwich estimator applied to EIF values: aggregate EIF within clusters, center, and compute variance with G/(G-1) small-sample correction. Cluster bootstrap generates multiplier weights at the cluster level (all units in a cluster share the same weight). Analytical clustered SEs are the default when `cluster` is set; cluster bootstrap is opt-in via `n_bootstrap > 0`.
683-
- **Note:** Hausman pretest uses the full cross-(g,t) covariance matrix from EIF values (Theorem A.1), not a diagonal approximation. The variance-difference matrix V = Cov(ATT_post) - Cov(ATT_all) is inverted via Moore-Penrose pseudoinverse to handle finite-sample non-positive-definiteness. Effective rank of V (number of positive eigenvalues) is used as degrees of freedom. Substantially negative eigenvalues trigger a warning.
683+
- **Note:** Hausman pretest operates on the post-treatment event-study vector ES(e) per Theorem A.1. Both PT-All and PT-Post fits are aggregated to ES(e) using cohort-size weights before computing the test statistic H = delta' V^{-1} delta where delta = ES_post - ES_all and V = Cov(ES_post) - Cov(ES_all). Covariance is computed from aggregated ES(e)-level EIF values. The variance-difference matrix V is inverted via Moore-Penrose pseudoinverse to handle finite-sample non-positive-definiteness. Effective rank of V (number of positive eigenvalues) is used as degrees of freedom.
684684
- **Note:** Last-cohort-as-control (`control_group="last_cohort"`) reclassifies the latest treatment cohort as pseudo-never-treated and drops time periods at/after that cohort's treatment start. This is distinct from CallawaySantAnna's `not_yet_treated` option which dynamically selects not-yet-treated units per (g,t) pair.
685685

686686
---

tests/test_efficient_did.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -638,13 +638,15 @@ def test_hausman_differential_trends_detects(self):
638638
assert pretest.statistic >= 0
639639
assert pretest.recommendation in ("pt_all", "pt_post", "inconclusive")
640640

641-
def test_hausman_gt_details(self):
642-
"""gt_details should have expected columns."""
641+
def test_hausman_es_details(self):
642+
"""gt_details should have event-study columns per Theorem A.1."""
643643
df = _make_staggered_panel(n_per_group=80, n_control=100)
644644
pretest = EfficientDiD.hausman_pretest(df, "y", "unit", "time", "first_treat")
645645
assert pretest.gt_details is not None
646-
expected_cols = {"group", "time", "att_all", "att_post", "delta", "se_all", "se_post"}
646+
expected_cols = {"relative_period", "es_all", "es_post", "delta"}
647647
assert set(pretest.gt_details.columns) == expected_cols
648+
# All relative periods should be post-treatment (>= 0)
649+
assert all(e >= 0 for e in pretest.gt_details["relative_period"])
648650

649651
def test_hausman_recommendation_field(self):
650652
"""recommendation should be pt_all or pt_post."""
@@ -893,6 +895,30 @@ def test_single_cluster_raises(self):
893895
with pytest.raises(ValueError, match="at least 2 clusters"):
894896
EfficientDiD(cluster="cluster_id").fit(df, "y", "unit", "time", "first_treat")
895897

898+
def test_cluster_plus_survey_raises(self):
899+
"""cluster + survey_design should raise NotImplementedError."""
900+
df = _make_staggered_panel(n_per_group=60, n_control=80)
901+
df["cluster_id"] = df["unit"] % 5
902+
df["w"] = 1.0
903+
with pytest.raises(NotImplementedError, match="cluster and survey_design"):
904+
EfficientDiD(cluster="cluster_id").fit(
905+
df, "y", "unit", "time", "first_treat", survey_design="w"
906+
)
907+
908+
def test_clustered_bootstrap_aggregate_all(self, ci_params):
909+
"""Clustered bootstrap with aggregate='all' should produce finite results."""
910+
n_boot = ci_params.bootstrap(99)
911+
df = self._make_clustered_panel(n_clusters=60, units_per_cluster=3)
912+
result = EfficientDiD(cluster="cluster_id", n_bootstrap=n_boot, seed=42).fit(
913+
df, "y", "unit", "time", "first_treat", aggregate="all"
914+
)
915+
assert result.event_study_effects is not None
916+
assert result.group_effects is not None
917+
for e, d in result.event_study_effects.items():
918+
assert np.isfinite(d["se"])
919+
for g, d in result.group_effects.items():
920+
assert np.isfinite(d["se"])
921+
896922

897923
class TestSmallCohortWarning:
898924
"""Small cohort warnings for numerical stability."""

0 commit comments

Comments
 (0)