Skip to content

Commit 94be0c0

Browse files
igerberclaude
andcommitted
Filter NaN effects in Hausman pretest, add clustered aggregation tests
Address rerun review findings: - Filter common (g,t) cells with non-finite effect estimates before building Hausman delta/covariance (prevents NaN poisoning from no-valid-pairs cells) - Make _nan_result return recommendation="inconclusive" for consistency (reject=False + "pt_post" was misleading) - Add clustered SE tests for aggregate=event_study and aggregate=all (exercises cluster path through aggregation methods) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent dd72e0f commit 94be0c0

2 files changed

Lines changed: 50 additions & 15 deletions

File tree

diff_diff/efficient_did.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,17 +1077,21 @@ def _compute_survey_eif_se(self, eif_vals: np.ndarray) -> float:
10771077
vcov = compute_survey_vcov(X_ones, eif_vals, self._unit_resolved_survey)
10781078
return float(np.sqrt(np.abs(vcov[0, 0])))
10791079

1080-
def _eif_se(self, eif_vals: np.ndarray, n_units: int) -> float:
1080+
def _eif_se(
1081+
self,
1082+
eif_vals: np.ndarray,
1083+
n_units: int,
1084+
cluster_indices: Optional[np.ndarray] = None,
1085+
n_clusters: Optional[int] = None,
1086+
) -> float:
10811087
"""Compute SE from aggregated EIF scores.
10821088
10831089
Dispatches to survey TSL when ``_unit_resolved_survey`` is set
10841090
(during fit), otherwise uses cluster-robust or standard formula.
10851091
"""
10861092
if self._unit_resolved_survey is not None:
10871093
return self._compute_survey_eif_se(eif_vals)
1088-
return _compute_se_from_eif(
1089-
eif_vals, n_units, self._cluster_indices, self._n_clusters
1090-
)
1094+
return _compute_se_from_eif(eif_vals, n_units, cluster_indices, n_clusters)
10911095

10921096
# -- Aggregation helpers --------------------------------------------------
10931097

@@ -1196,7 +1200,7 @@ def _aggregate_overall(
11961200

11971201
# SE = sqrt(mean(EIF^2) / n) — standard IF-based SE
11981202
# (dispatches to survey TSL or cluster-robust when active)
1199-
se = self._eif_se(agg_eif_total, n_units)
1203+
se = self._eif_se(agg_eif_total, n_units, cluster_indices, n_clusters)
12001204

12011205
return overall_att, se
12021206

@@ -1290,7 +1294,7 @@ def _aggregate_event_study(
12901294
)
12911295
agg_eif = agg_eif + wif
12921296

1293-
agg_se = self._eif_se(agg_eif, n_units)
1297+
agg_se = self._eif_se(agg_eif, n_units, cluster_indices, n_clusters)
12941298

12951299
t_stat, p_val, ci = safe_inference(
12961300
agg_eff, agg_se, alpha=self.alpha, df=self._survey_df
@@ -1354,7 +1358,7 @@ def _aggregate_by_group(
13541358
agg_eif = np.zeros(n_units)
13551359
for k, gt in enumerate(g_gts):
13561360
agg_eif += w[k] * eif_by_gt[gt]
1357-
agg_se = self._eif_se(agg_eif, n_units)
1361+
agg_se = self._eif_se(agg_eif, n_units, cluster_indices, n_clusters)
13581362

13591363
t_stat, p_val, ci = safe_inference(
13601364
agg_eff, agg_se, alpha=self.alpha, df=self._survey_df
@@ -1455,7 +1459,7 @@ def hausman_pretest(
14551459
set(result_all.group_time_effects.keys()) & set(result_post.group_time_effects.keys())
14561460
)
14571461

1458-
def _nan_result(recommendation: str = "pt_post") -> HausmanPretestResult:
1462+
def _nan_result() -> HausmanPretestResult:
14591463
return HausmanPretestResult(
14601464
statistic=np.nan,
14611465
p_value=np.nan,
@@ -1464,13 +1468,23 @@ def _nan_result(recommendation: str = "pt_post") -> HausmanPretestResult:
14641468
alpha=alpha,
14651469
att_all=result_all.overall_att,
14661470
att_post=result_post.overall_att,
1467-
recommendation=recommendation,
1471+
recommendation="inconclusive",
14681472
gt_details=None,
14691473
)
14701474

14711475
if not common_gts:
14721476
return _nan_result()
14731477

1478+
# Filter out (g,t) cells with non-finite effect estimates
1479+
common_gts = [
1480+
gt
1481+
for gt in common_gts
1482+
if np.isfinite(result_all.group_time_effects[gt]["effect"])
1483+
and np.isfinite(result_post.group_time_effects[gt]["effect"])
1484+
]
1485+
if not common_gts:
1486+
return _nan_result()
1487+
14741488
k = len(common_gts)
14751489

14761490
# Build EIF matrices for common (g,t) pairs: (n_units, k)
@@ -1572,7 +1586,7 @@ def _eif_cov(eif_mat: np.ndarray) -> np.ndarray:
15721586
# Effective rank = number of positive eigenvalues
15731587
effective_rank = int(np.sum(eigvals > tol))
15741588
if effective_rank == 0:
1575-
return _nan_result("pt_all")
1589+
return _nan_result()
15761590

15771591
# Compute H = delta' @ pinv(V) @ delta
15781592
V_pinv = np.linalg.pinv(V, rcond=tol / max_eigval if max_eigval > 0 else 1e-10)

tests/test_efficient_did.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ def test_hausman_homogeneous_trends_fail_to_reject(self):
588588
assert np.isfinite(pretest.p_value)
589589
assert pretest.df > 0
590590
# With homogeneous trends, should generally fail to reject
591-
assert pretest.recommendation in ("pt_all", "pt_post")
591+
assert pretest.recommendation in ("pt_all", "pt_post", "inconclusive")
592592

593593
def test_hausman_differential_trends_detects(self):
594594
"""DGP with cohort-specific trends → test detects or warns."""
@@ -636,7 +636,7 @@ def test_hausman_differential_trends_detects(self):
636636
# Both are acceptable outcomes for a DGP that violates PT-All
637637
if np.isfinite(pretest.statistic):
638638
assert pretest.statistic >= 0
639-
assert pretest.recommendation in ("pt_all", "pt_post")
639+
assert pretest.recommendation in ("pt_all", "pt_post", "inconclusive")
640640

641641
def test_hausman_gt_details(self):
642642
"""gt_details should have expected columns."""
@@ -650,7 +650,7 @@ def test_hausman_recommendation_field(self):
650650
"""recommendation should be pt_all or pt_post."""
651651
df = _make_staggered_panel(n_per_group=80, n_control=100)
652652
pretest = EfficientDiD.hausman_pretest(df, "y", "unit", "time", "first_treat")
653-
assert pretest.recommendation in ("pt_all", "pt_post")
653+
assert pretest.recommendation in ("pt_all", "pt_post", "inconclusive")
654654
if pretest.reject:
655655
assert pretest.recommendation == "pt_post"
656656
else:
@@ -700,7 +700,7 @@ def test_hausman_clustered(self):
700700
pretest = EfficientDiD.hausman_pretest(
701701
df, "y", "unit", "time", "first_treat", cluster="cluster_id"
702702
)
703-
assert pretest.recommendation in ("pt_all", "pt_post")
703+
assert pretest.recommendation in ("pt_all", "pt_post", "inconclusive")
704704
assert pretest.df >= 0
705705

706706
def test_hausman_last_cohort(self):
@@ -719,7 +719,7 @@ def test_hausman_last_cohort(self):
719719
"first_treat",
720720
control_group="last_cohort",
721721
)
722-
assert pretest.recommendation in ("pt_all", "pt_post")
722+
assert pretest.recommendation in ("pt_all", "pt_post", "inconclusive")
723723
assert np.isfinite(pretest.att_all)
724724
assert np.isfinite(pretest.att_post)
725725

@@ -803,6 +803,27 @@ def test_clustered_se_at_least_as_large(self):
803803
assert result_clustered.overall_se > 0
804804
assert result_unclustered.overall_se > 0
805805

806+
def test_clustered_aggregate_event_study(self):
807+
"""Clustered SE with aggregate='event_study' should produce finite results."""
808+
df = self._make_clustered_panel(n_clusters=60, units_per_cluster=3)
809+
result = EfficientDiD(cluster="cluster_id").fit(
810+
df, "y", "unit", "time", "first_treat", aggregate="event_study"
811+
)
812+
assert result.event_study_effects is not None
813+
for e, d in result.event_study_effects.items():
814+
assert np.isfinite(d["se"])
815+
816+
def test_clustered_aggregate_all(self):
817+
"""Clustered SE with aggregate='all' should produce finite results."""
818+
df = self._make_clustered_panel(n_clusters=60, units_per_cluster=3)
819+
result = EfficientDiD(cluster="cluster_id").fit(
820+
df, "y", "unit", "time", "first_treat", aggregate="all"
821+
)
822+
assert result.event_study_effects is not None
823+
assert result.group_effects is not None
824+
for g, d in result.group_effects.items():
825+
assert np.isfinite(d["se"])
826+
806827
def test_cluster_bootstrap(self, ci_params):
807828
"""Cluster bootstrap should produce finite inference."""
808829
n_boot = ci_params.bootstrap(99)

0 commit comments

Comments
 (0)