Skip to content

Commit 5bead50

Browse files
igerber and claude committed
Address P2 review findings: remove mutable state, deduplicate cluster covariance, update docstrings
- Remove _cluster_indices, _n_clusters, _store_eif instance attributes; pass cluster info explicitly through aggregation methods and store_eif as a fit() parameter
- Extract shared _cluster_aggregate() helper used by both _compute_se_from_eif (scalar SE) and hausman_pretest (covariance matrix)
- Update class docstring to document cluster, control_group parameters

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 78b19e7 commit 5bead50

1 file changed

Lines changed: 92 additions & 37 deletions

File tree

diff_diff/efficient_did.py

Lines changed: 92 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,39 @@
4545
__all__ = ["EfficientDiD", "EfficientDiDResults", "EDiDBootstrapResults"]
4646

4747

48+
def _cluster_aggregate(
49+
eif_mat: np.ndarray,
50+
cluster_indices: np.ndarray,
51+
n_clusters: int,
52+
) -> np.ndarray:
53+
"""Sum EIF values within clusters and center.
54+
55+
Parameters
56+
----------
57+
eif_mat : ndarray, shape (n_units,) or (n_units, k)
58+
EIF values — 1-D for a single estimand, 2-D for multiple.
59+
cluster_indices : ndarray, shape (n_units,)
60+
Integer cluster assignment per unit.
61+
n_clusters : int
62+
Number of unique clusters.
63+
64+
Returns
65+
-------
66+
ndarray, shape (n_clusters,) or (n_clusters, k)
67+
Centered cluster-level sums.
68+
"""
69+
if eif_mat.ndim == 1:
70+
sums = np.bincount(cluster_indices, weights=eif_mat, minlength=n_clusters).astype(float)
71+
else:
72+
sums = np.column_stack(
73+
[
74+
np.bincount(cluster_indices, weights=eif_mat[:, j], minlength=n_clusters)
75+
for j in range(eif_mat.shape[1])
76+
]
77+
).astype(float)
78+
return sums - sums.mean(axis=0)
79+
80+
4881
def _compute_se_from_eif(
4982
eif: np.ndarray,
5083
n_units: int,
@@ -58,10 +91,9 @@ def _compute_se_from_eif(
5891
center, and apply G/(G-1) small-sample correction.
5992
"""
6093
if cluster_indices is not None and n_clusters is not None:
61-
cluster_sums = np.bincount(cluster_indices, weights=eif, minlength=n_clusters)
62-
cluster_mean = np.mean(cluster_sums)
94+
centered = _cluster_aggregate(eif, cluster_indices, n_clusters)
6395
correction = n_clusters / (n_clusters - 1) if n_clusters > 1 else 1.0
64-
var = correction * np.mean((cluster_sums - cluster_mean) ** 2) / n_units
96+
var = correction * np.mean(centered**2) / n_units
6597
return float(np.sqrt(max(var, 0.0)))
6698
return float(np.sqrt(np.mean(eif**2) / n_units))
6799

@@ -91,8 +123,17 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
91123
alpha : float, default 0.05
92124
Significance level.
93125
cluster : str or None
94-
Column name for cluster-robust SEs (not yet implemented —
95-
currently only unit-level inference).
126+
Column name for cluster-robust SEs. When set, analytical SEs
127+
use the Liang-Zeger clustered sandwich estimator on EIF values.
128+
With ``n_bootstrap > 0``, bootstrap weights are generated at the
129+
cluster level (all units in a cluster share the same weight).
130+
control_group : str, default ``"never_treated"``
131+
Which units serve as the comparison group:
132+
``"never_treated"`` requires a never-treated cohort (raises if
133+
none exist); ``"last_cohort"`` reclassifies the latest treatment
134+
cohort as pseudo-never-treated and drops post-treatment periods
135+
for that cohort. Distinct from CallawaySantAnna's
136+
``"not_yet_treated"`` — see REGISTRY.md for details.
96137
n_bootstrap : int, default 0
97138
Number of multiplier bootstrap iterations (0 = analytical only).
98139
bootstrap_weights : str, default ``"rademacher"``
@@ -151,7 +192,6 @@ def __init__(
151192
self.kernel_bandwidth = kernel_bandwidth
152193
self.is_fitted_ = False
153194
self.results_: Optional[EfficientDiDResults] = None
154-
self._store_eif = False
155195
self._validate_params()
156196

157197
def _validate_params(self) -> None:
@@ -229,6 +269,7 @@ def fit(
229269
covariates: Optional[List[str]] = None,
230270
aggregate: Optional[str] = None,
231271
balance_e: Optional[int] = None,
272+
store_eif: bool = False,
232273
) -> EfficientDiDResults:
233274
"""Fit the Efficient DiD estimator.
234275
@@ -397,8 +438,6 @@ def fit(
397438
else:
398439
unit_cluster_indices = None
399440
n_clusters = None
400-
self._cluster_indices = unit_cluster_indices
401-
self._n_clusters = n_clusters
402441

403442
period_to_col = {p: i for i, p in enumerate(time_periods)}
404443
period_1 = time_periods[0]
@@ -709,9 +748,7 @@ def fit(
709748
eif_by_gt[(g, t)] = eif_vals
710749

711750
# Analytical SE = sqrt(mean(EIF^2) / n) [paper p.21]
712-
se_gt = _compute_se_from_eif(
713-
eif_vals, n_units, self._cluster_indices, self._n_clusters
714-
)
751+
se_gt = _compute_se_from_eif(eif_vals, n_units, unit_cluster_indices, n_clusters)
715752

716753
t_stat, p_val, ci = safe_inference(att_gt, se_gt, alpha=self.alpha)
717754

@@ -733,7 +770,13 @@ def fit(
733770

734771
# ----- Aggregation -----
735772
overall_att, overall_se = self._aggregate_overall(
736-
group_time_effects, eif_by_gt, n_units, cohort_fractions, unit_cohorts
773+
group_time_effects,
774+
eif_by_gt,
775+
n_units,
776+
cohort_fractions,
777+
unit_cohorts,
778+
cluster_indices=unit_cluster_indices,
779+
n_clusters=n_clusters,
737780
)
738781
overall_t, overall_p, overall_ci = safe_inference(overall_att, overall_se, alpha=self.alpha)
739782

@@ -750,6 +793,8 @@ def fit(
750793
time_periods,
751794
balance_e,
752795
unit_cohorts=unit_cohorts,
796+
cluster_indices=unit_cluster_indices,
797+
n_clusters=n_clusters,
753798
)
754799
if aggregate in ("group", "all"):
755800
group_effects = self._aggregate_by_group(
@@ -759,6 +804,8 @@ def fit(
759804
cohort_fractions,
760805
treatment_groups,
761806
unit_cohorts=unit_cohorts,
807+
cluster_indices=unit_cluster_indices,
808+
n_clusters=n_clusters,
762809
)
763810

764811
# ----- Bootstrap -----
@@ -772,8 +819,8 @@ def fit(
772819
balance_e=balance_e,
773820
treatment_groups=treatment_groups,
774821
cohort_fractions=cohort_fractions,
775-
cluster_indices=self._cluster_indices,
776-
n_clusters=self._n_clusters,
822+
cluster_indices=unit_cluster_indices,
823+
n_clusters=n_clusters,
777824
)
778825
# Update estimates with bootstrap inference
779826
overall_se = bootstrap_results.overall_att_se
@@ -850,7 +897,7 @@ def fit(
850897
efficient_weights=stored_weights if stored_weights else None,
851898
omega_condition_numbers=stored_cond if stored_cond else None,
852899
control_group=self.control_group,
853-
influence_functions=eif_by_gt if self._store_eif else None,
900+
influence_functions=eif_by_gt if store_eif else None,
854901
bootstrap_results=bootstrap_results,
855902
estimation_path="dr" if use_covariates else "nocov",
856903
sieve_k_max=self.sieve_k_max,
@@ -918,6 +965,8 @@ def _aggregate_overall(
918965
n_units: int,
919966
cohort_fractions: Dict[float, float],
920967
unit_cohorts: np.ndarray,
968+
cluster_indices: Optional[np.ndarray] = None,
969+
n_clusters: Optional[int] = None,
921970
) -> Tuple[float, float]:
922971
"""Compute overall ATT with WIF-adjusted SE.
923972
@@ -965,7 +1014,7 @@ def _aggregate_overall(
9651014
agg_eif_total = agg_eif + wif # both O(1) scale
9661015

9671016
# SE = sqrt(mean(EIF^2) / n) — standard IF-based SE
968-
se = _compute_se_from_eif(agg_eif_total, n_units, self._cluster_indices, self._n_clusters)
1017+
se = _compute_se_from_eif(agg_eif_total, n_units, cluster_indices, n_clusters)
9691018

9701019
return overall_att, se
9711020

@@ -979,6 +1028,8 @@ def _aggregate_event_study(
9791028
time_periods: List[Any],
9801029
balance_e: Optional[int] = None,
9811030
unit_cohorts: Optional[np.ndarray] = None,
1031+
cluster_indices: Optional[np.ndarray] = None,
1032+
n_clusters: Optional[int] = None,
9821033
) -> Dict[int, Dict[str, Any]]:
9831034
"""Aggregate ATT(g,t) by relative time e = t - g.
9841035
@@ -1057,7 +1108,7 @@ def _aggregate_event_study(
10571108
)
10581109
agg_eif = agg_eif + wif
10591110

1060-
agg_se = _compute_se_from_eif(agg_eif, n_units, self._cluster_indices, self._n_clusters)
1111+
agg_se = _compute_se_from_eif(agg_eif, n_units, cluster_indices, n_clusters)
10611112

10621113
t_stat, p_val, ci = safe_inference(agg_eff, agg_se, alpha=self.alpha)
10631114
result[e] = {
@@ -1079,6 +1130,8 @@ def _aggregate_by_group(
10791130
cohort_fractions: Dict[float, float],
10801131
treatment_groups: List[Any],
10811132
unit_cohorts: Optional[np.ndarray] = None,
1133+
cluster_indices: Optional[np.ndarray] = None,
1134+
n_clusters: Optional[int] = None,
10821135
) -> Dict[Any, Dict[str, Any]]:
10831136
"""Aggregate ATT(g,t) by treatment cohort.
10841137
@@ -1117,7 +1170,7 @@ def _aggregate_by_group(
11171170
agg_eif = np.zeros(n_units)
11181171
for k, gt in enumerate(g_gts):
11191172
agg_eif += w[k] * eif_by_gt[gt]
1120-
agg_se = _compute_se_from_eif(agg_eif, n_units, self._cluster_indices, self._n_clusters)
1173+
agg_se = _compute_se_from_eif(agg_eif, n_units, cluster_indices, n_clusters)
11211174

11221175
t_stat, p_val, ci = safe_inference(agg_eff, agg_se, alpha=self.alpha)
11231176
result[g] = {
@@ -1206,12 +1259,10 @@ def hausman_pretest(
12061259
)
12071260

12081261
edid_all = cls(pt_assumption="all", alpha=alpha, **common_kwargs)
1209-
edid_all._store_eif = True
1210-
result_all = edid_all.fit(**fit_kwargs)
1262+
result_all = edid_all.fit(**fit_kwargs, store_eif=True)
12111263

12121264
edid_post = cls(pt_assumption="post", alpha=alpha, **common_kwargs)
1213-
edid_post._store_eif = True
1214-
result_post = edid_post.fit(**fit_kwargs)
1265+
result_post = edid_post.fit(**fit_kwargs, store_eif=True)
12151266

12161267
# Find common (g,t) pairs — PT-Post pairs are a subset of PT-All
12171268
common_gts = sorted(
@@ -1277,31 +1328,35 @@ def _nan_result(recommendation: str = "pt_post") -> HausmanPretestResult:
12771328
row_finite = np.all(np.isfinite(eif_all_mat), axis=1) & np.all(
12781329
np.isfinite(eif_post_mat), axis=1
12791330
)
1280-
cl_idx = edid_all._cluster_indices
1331+
# Build cluster mapping for covariance if needed
1332+
cl_idx: Optional[np.ndarray] = None
1333+
n_cl: Optional[int] = None
1334+
if cluster is not None:
1335+
all_units = sorted(data[unit].unique())
1336+
cluster_col = data.groupby(unit)[cluster].first()
1337+
cluster_ids = cluster_col.reindex(all_units).values
1338+
unique_clusters = np.unique(cluster_ids)
1339+
n_cl = len(unique_clusters)
1340+
cluster_to_idx = {c: i for i, c in enumerate(unique_clusters)}
1341+
cl_idx = np.array([cluster_to_idx[c] for c in cluster_ids])
1342+
12811343
if not np.all(row_finite):
12821344
eif_all_mat = eif_all_mat[row_finite]
12831345
eif_post_mat = eif_post_mat[row_finite]
12841346
n_units = int(np.sum(row_finite))
12851347
if cl_idx is not None:
12861348
cl_idx = cl_idx[row_finite]
12871349

1288-
# Compute full covariance matrices
1289-
if cl_idx is not None:
1290-
n_cl = edid_all._n_clusters
1350+
# Compute full covariance matrices using shared _cluster_aggregate
1351+
if cl_idx is not None and n_cl is not None:
12911352

1292-
def _cluster_cov(eif_mat: np.ndarray) -> np.ndarray:
1293-
s_mat = np.column_stack(
1294-
[
1295-
np.bincount(cl_idx, weights=eif_mat[:, j], minlength=n_cl)
1296-
for j in range(eif_mat.shape[1])
1297-
]
1298-
)
1299-
s_centered = s_mat - s_mat.mean(axis=0)
1353+
def _eif_cov(eif_mat: np.ndarray) -> np.ndarray:
1354+
centered = _cluster_aggregate(eif_mat, cl_idx, n_cl)
13001355
correction = n_cl / (n_cl - 1) if n_cl > 1 else 1.0
1301-
return correction * (s_centered.T @ s_centered) / (n_units**2)
1356+
return correction * (centered.T @ centered) / (n_units**2)
13021357

1303-
cov_all = _cluster_cov(eif_all_mat)
1304-
cov_post = _cluster_cov(eif_post_mat)
1358+
cov_all = _eif_cov(eif_all_mat)
1359+
cov_post = _eif_cov(eif_post_mat)
13051360
else:
13061361
with np.errstate(over="ignore", invalid="ignore"):
13071362
cov_all = (eif_all_mat.T @ eif_all_mat) / (n_units**2)

0 commit comments

Comments
 (0)