4545__all__ = ["EfficientDiD" , "EfficientDiDResults" , "EDiDBootstrapResults" ]
4646
4747
def _cluster_aggregate(
    eif_mat: np.ndarray,
    cluster_indices: np.ndarray,
    n_clusters: int,
) -> np.ndarray:
    """Sum EIF values within clusters and center.

    Parameters
    ----------
    eif_mat : ndarray, shape (n_units,) or (n_units, k)
        EIF values — 1-D for a single estimand, 2-D for multiple.
    cluster_indices : ndarray, shape (n_units,)
        Non-negative integer cluster assignment per unit.
    n_clusters : int
        Number of unique clusters. Clusters with no units contribute a
        sum of zero (the output is padded to at least this many rows).

    Returns
    -------
    ndarray, shape (n_clusters,) or (n_clusters, k)
        Cluster-level sums with the across-cluster mean subtracted
        (per column for 2-D input).

    Raises
    ------
    ValueError
        If ``eif_mat`` is neither 1-D nor 2-D.
    """
    if eif_mat.ndim == 1:
        # With ``weights`` given, bincount already returns float64, so no
        # extra dtype conversion (and copy) is needed.
        sums = np.bincount(cluster_indices, weights=eif_mat, minlength=n_clusters)
    elif eif_mat.ndim == 2:
        n_cols = eif_mat.shape[1]
        if n_cols == 0:
            # column_stack cannot stack an empty list; return an empty
            # (n_clusters, 0) result instead of raising.
            sums = np.zeros((n_clusters, 0), dtype=float)
        else:
            # bincount only accepts 1-D weights, so aggregate column by column.
            sums = np.column_stack(
                [
                    np.bincount(
                        cluster_indices,
                        weights=eif_mat[:, j],
                        minlength=n_clusters,
                    )
                    for j in range(n_cols)
                ]
            )
    else:
        raise ValueError(
            f"eif_mat must be 1-D or 2-D, got array with ndim={eif_mat.ndim}"
        )
    # Center across clusters so downstream variance formulas can use
    # plain second moments of the returned values.
    return sums - sums.mean(axis=0)
79+
80+
4881def _compute_se_from_eif (
4982 eif : np .ndarray ,
5083 n_units : int ,
@@ -58,10 +91,9 @@ def _compute_se_from_eif(
5891 center, and apply G/(G-1) small-sample correction.
5992 """
6093 if cluster_indices is not None and n_clusters is not None :
61- cluster_sums = np .bincount (cluster_indices , weights = eif , minlength = n_clusters )
62- cluster_mean = np .mean (cluster_sums )
94+ centered = _cluster_aggregate (eif , cluster_indices , n_clusters )
6395 correction = n_clusters / (n_clusters - 1 ) if n_clusters > 1 else 1.0
64- var = correction * np .mean (( cluster_sums - cluster_mean ) ** 2 ) / n_units
96+ var = correction * np .mean (centered ** 2 ) / n_units
6597 return float (np .sqrt (max (var , 0.0 )))
6698 return float (np .sqrt (np .mean (eif ** 2 ) / n_units ))
6799
@@ -91,8 +123,17 @@ class EfficientDiD(EfficientDiDBootstrapMixin):
91123 alpha : float, default 0.05
92124 Significance level.
93125 cluster : str or None
94- Column name for cluster-robust SEs (not yet implemented —
95- currently only unit-level inference).
126+ Column name for cluster-robust SEs. When set, analytical SEs
127+ use the Liang-Zeger clustered sandwich estimator on EIF values.
128+ With ``n_bootstrap > 0``, bootstrap weights are generated at the
129+ cluster level (all units in a cluster share the same weight).
130+ control_group : str, default ``"never_treated"``
131+ Which units serve as the comparison group:
132+ ``"never_treated"`` requires a never-treated cohort (raises if
133+ none exist); ``"last_cohort"`` reclassifies the latest treatment
134+ cohort as pseudo-never-treated and drops post-treatment periods
135+ for that cohort. Distinct from CallawaySantAnna's
136+ ``"not_yet_treated"`` — see REGISTRY.md for details.
96137 n_bootstrap : int, default 0
97138 Number of multiplier bootstrap iterations (0 = analytical only).
98139 bootstrap_weights : str, default ``"rademacher"``
@@ -151,7 +192,6 @@ def __init__(
151192 self .kernel_bandwidth = kernel_bandwidth
152193 self .is_fitted_ = False
153194 self .results_ : Optional [EfficientDiDResults ] = None
154- self ._store_eif = False
155195 self ._validate_params ()
156196
157197 def _validate_params (self ) -> None :
@@ -229,6 +269,7 @@ def fit(
229269 covariates : Optional [List [str ]] = None ,
230270 aggregate : Optional [str ] = None ,
231271 balance_e : Optional [int ] = None ,
272+ store_eif : bool = False ,
232273 ) -> EfficientDiDResults :
233274 """Fit the Efficient DiD estimator.
234275
@@ -397,8 +438,6 @@ def fit(
397438 else :
398439 unit_cluster_indices = None
399440 n_clusters = None
400- self ._cluster_indices = unit_cluster_indices
401- self ._n_clusters = n_clusters
402441
403442 period_to_col = {p : i for i , p in enumerate (time_periods )}
404443 period_1 = time_periods [0 ]
@@ -709,9 +748,7 @@ def fit(
709748 eif_by_gt [(g , t )] = eif_vals
710749
711750 # Analytical SE = sqrt(mean(EIF^2) / n) [paper p.21]
712- se_gt = _compute_se_from_eif (
713- eif_vals , n_units , self ._cluster_indices , self ._n_clusters
714- )
751+ se_gt = _compute_se_from_eif (eif_vals , n_units , unit_cluster_indices , n_clusters )
715752
716753 t_stat , p_val , ci = safe_inference (att_gt , se_gt , alpha = self .alpha )
717754
@@ -733,7 +770,13 @@ def fit(
733770
734771 # ----- Aggregation -----
735772 overall_att , overall_se = self ._aggregate_overall (
736- group_time_effects , eif_by_gt , n_units , cohort_fractions , unit_cohorts
773+ group_time_effects ,
774+ eif_by_gt ,
775+ n_units ,
776+ cohort_fractions ,
777+ unit_cohorts ,
778+ cluster_indices = unit_cluster_indices ,
779+ n_clusters = n_clusters ,
737780 )
738781 overall_t , overall_p , overall_ci = safe_inference (overall_att , overall_se , alpha = self .alpha )
739782
@@ -750,6 +793,8 @@ def fit(
750793 time_periods ,
751794 balance_e ,
752795 unit_cohorts = unit_cohorts ,
796+ cluster_indices = unit_cluster_indices ,
797+ n_clusters = n_clusters ,
753798 )
754799 if aggregate in ("group" , "all" ):
755800 group_effects = self ._aggregate_by_group (
@@ -759,6 +804,8 @@ def fit(
759804 cohort_fractions ,
760805 treatment_groups ,
761806 unit_cohorts = unit_cohorts ,
807+ cluster_indices = unit_cluster_indices ,
808+ n_clusters = n_clusters ,
762809 )
763810
764811 # ----- Bootstrap -----
@@ -772,8 +819,8 @@ def fit(
772819 balance_e = balance_e ,
773820 treatment_groups = treatment_groups ,
774821 cohort_fractions = cohort_fractions ,
775- cluster_indices = self . _cluster_indices ,
776- n_clusters = self . _n_clusters ,
822+ cluster_indices = unit_cluster_indices ,
823+ n_clusters = n_clusters ,
777824 )
778825 # Update estimates with bootstrap inference
779826 overall_se = bootstrap_results .overall_att_se
@@ -850,7 +897,7 @@ def fit(
850897 efficient_weights = stored_weights if stored_weights else None ,
851898 omega_condition_numbers = stored_cond if stored_cond else None ,
852899 control_group = self .control_group ,
853- influence_functions = eif_by_gt if self . _store_eif else None ,
900+ influence_functions = eif_by_gt if store_eif else None ,
854901 bootstrap_results = bootstrap_results ,
855902 estimation_path = "dr" if use_covariates else "nocov" ,
856903 sieve_k_max = self .sieve_k_max ,
@@ -918,6 +965,8 @@ def _aggregate_overall(
918965 n_units : int ,
919966 cohort_fractions : Dict [float , float ],
920967 unit_cohorts : np .ndarray ,
968+ cluster_indices : Optional [np .ndarray ] = None ,
969+ n_clusters : Optional [int ] = None ,
921970 ) -> Tuple [float , float ]:
922971 """Compute overall ATT with WIF-adjusted SE.
923972
@@ -965,7 +1014,7 @@ def _aggregate_overall(
9651014 agg_eif_total = agg_eif + wif # both O(1) scale
9661015
9671016 # SE = sqrt(mean(EIF^2) / n) — standard IF-based SE
968- se = _compute_se_from_eif (agg_eif_total , n_units , self . _cluster_indices , self . _n_clusters )
1017+ se = _compute_se_from_eif (agg_eif_total , n_units , cluster_indices , n_clusters )
9691018
9701019 return overall_att , se
9711020
@@ -979,6 +1028,8 @@ def _aggregate_event_study(
9791028 time_periods : List [Any ],
9801029 balance_e : Optional [int ] = None ,
9811030 unit_cohorts : Optional [np .ndarray ] = None ,
1031+ cluster_indices : Optional [np .ndarray ] = None ,
1032+ n_clusters : Optional [int ] = None ,
9821033 ) -> Dict [int , Dict [str , Any ]]:
9831034 """Aggregate ATT(g,t) by relative time e = t - g.
9841035
@@ -1057,7 +1108,7 @@ def _aggregate_event_study(
10571108 )
10581109 agg_eif = agg_eif + wif
10591110
1060- agg_se = _compute_se_from_eif (agg_eif , n_units , self . _cluster_indices , self . _n_clusters )
1111+ agg_se = _compute_se_from_eif (agg_eif , n_units , cluster_indices , n_clusters )
10611112
10621113 t_stat , p_val , ci = safe_inference (agg_eff , agg_se , alpha = self .alpha )
10631114 result [e ] = {
@@ -1079,6 +1130,8 @@ def _aggregate_by_group(
10791130 cohort_fractions : Dict [float , float ],
10801131 treatment_groups : List [Any ],
10811132 unit_cohorts : Optional [np .ndarray ] = None ,
1133+ cluster_indices : Optional [np .ndarray ] = None ,
1134+ n_clusters : Optional [int ] = None ,
10821135 ) -> Dict [Any , Dict [str , Any ]]:
10831136 """Aggregate ATT(g,t) by treatment cohort.
10841137
@@ -1117,7 +1170,7 @@ def _aggregate_by_group(
11171170 agg_eif = np .zeros (n_units )
11181171 for k , gt in enumerate (g_gts ):
11191172 agg_eif += w [k ] * eif_by_gt [gt ]
1120- agg_se = _compute_se_from_eif (agg_eif , n_units , self . _cluster_indices , self . _n_clusters )
1173+ agg_se = _compute_se_from_eif (agg_eif , n_units , cluster_indices , n_clusters )
11211174
11221175 t_stat , p_val , ci = safe_inference (agg_eff , agg_se , alpha = self .alpha )
11231176 result [g ] = {
@@ -1206,12 +1259,10 @@ def hausman_pretest(
12061259 )
12071260
12081261 edid_all = cls (pt_assumption = "all" , alpha = alpha , ** common_kwargs )
1209- edid_all ._store_eif = True
1210- result_all = edid_all .fit (** fit_kwargs )
1262+ result_all = edid_all .fit (** fit_kwargs , store_eif = True )
12111263
12121264 edid_post = cls (pt_assumption = "post" , alpha = alpha , ** common_kwargs )
1213- edid_post ._store_eif = True
1214- result_post = edid_post .fit (** fit_kwargs )
1265+ result_post = edid_post .fit (** fit_kwargs , store_eif = True )
12151266
12161267 # Find common (g,t) pairs — PT-Post pairs are a subset of PT-All
12171268 common_gts = sorted (
@@ -1277,31 +1328,35 @@ def _nan_result(recommendation: str = "pt_post") -> HausmanPretestResult:
12771328 row_finite = np .all (np .isfinite (eif_all_mat ), axis = 1 ) & np .all (
12781329 np .isfinite (eif_post_mat ), axis = 1
12791330 )
1280- cl_idx = edid_all ._cluster_indices
1331+ # Build cluster mapping for covariance if needed
1332+ cl_idx : Optional [np .ndarray ] = None
1333+ n_cl : Optional [int ] = None
1334+ if cluster is not None :
1335+ all_units = sorted (data [unit ].unique ())
1336+ cluster_col = data .groupby (unit )[cluster ].first ()
1337+ cluster_ids = cluster_col .reindex (all_units ).values
1338+ unique_clusters = np .unique (cluster_ids )
1339+ n_cl = len (unique_clusters )
1340+ cluster_to_idx = {c : i for i , c in enumerate (unique_clusters )}
1341+ cl_idx = np .array ([cluster_to_idx [c ] for c in cluster_ids ])
1342+
12811343 if not np .all (row_finite ):
12821344 eif_all_mat = eif_all_mat [row_finite ]
12831345 eif_post_mat = eif_post_mat [row_finite ]
12841346 n_units = int (np .sum (row_finite ))
12851347 if cl_idx is not None :
12861348 cl_idx = cl_idx [row_finite ]
12871349
1288- # Compute full covariance matrices
1289- if cl_idx is not None :
1290- n_cl = edid_all ._n_clusters
1350+ # Compute full covariance matrices using shared _cluster_aggregate
1351+ if cl_idx is not None and n_cl is not None :
12911352
1292- def _cluster_cov (eif_mat : np .ndarray ) -> np .ndarray :
1293- s_mat = np .column_stack (
1294- [
1295- np .bincount (cl_idx , weights = eif_mat [:, j ], minlength = n_cl )
1296- for j in range (eif_mat .shape [1 ])
1297- ]
1298- )
1299- s_centered = s_mat - s_mat .mean (axis = 0 )
1353+ def _eif_cov (eif_mat : np .ndarray ) -> np .ndarray :
1354+ centered = _cluster_aggregate (eif_mat , cl_idx , n_cl )
13001355 correction = n_cl / (n_cl - 1 ) if n_cl > 1 else 1.0
1301- return correction * (s_centered .T @ s_centered ) / (n_units ** 2 )
1356+ return correction * (centered .T @ centered ) / (n_units ** 2 )
13021357
1303- cov_all = _cluster_cov (eif_all_mat )
1304- cov_post = _cluster_cov (eif_post_mat )
1358+ cov_all = _eif_cov (eif_all_mat )
1359+ cov_post = _eif_cov (eif_post_mat )
13051360 else :
13061361 with np .errstate (over = "ignore" , invalid = "ignore" ):
13071362 cov_all = (eif_all_mat .T @ eif_all_mat ) / (n_units ** 2 )
0 commit comments