Commit b1e0237

Merge pull request #165 from igerber/tech-debt-paydown
Address tech debt from code reviews (PRs #115-#159)
2 parents: edbb5ca + 969ae82

17 files changed
Lines changed: 397 additions & 325 deletions

README.md

Lines changed: 3 additions & 3 deletions
@@ -1983,7 +1983,7 @@ TROP(
     max_iter=100,        # Max iterations for factor estimation
     tol=1e-6,            # Convergence tolerance
     alpha=0.05,          # Significance level for CIs
-    n_bootstrap=200,     # Bootstrap replications
+    n_bootstrap=200,     # Bootstrap replications (minimum 2; TROP requires bootstrap for SEs)
     seed=None            # Random seed
 )
 ```
@@ -2064,8 +2064,6 @@ SunAbraham(
 | `time` | str | Time period column |
 | `first_treat` | str | Column with first treatment period (0 for never-treated) |
 | `covariates` | list | Covariate column names |
-| `min_pre_periods` | int | Minimum pre-treatment periods to include |
-| `min_post_periods` | int | Minimum post-treatment periods to include |
 
 ### SunAbrahamResults
 
@@ -2105,6 +2103,7 @@ ImputationDiD(
     alpha=0.05,                      # Significance level for CIs
     cluster=None,                    # Column for cluster-robust SEs
     n_bootstrap=0,                   # Bootstrap iterations (0 = analytical)
+    bootstrap_weights='rademacher',  # 'rademacher', 'mammen', or 'webb'
     seed=None,                       # Random seed
     rank_deficient_action='warn',    # 'warn', 'error', or 'silent'
     horizon_max=None,                # Max event-study horizon
@@ -2159,6 +2158,7 @@ TwoStageDiD(
     alpha=0.05,                      # Significance level for CIs
     cluster=None,                    # Column for cluster-robust SEs (defaults to unit)
     n_bootstrap=0,                   # Bootstrap iterations (0 = analytical GMM SEs)
+    bootstrap_weights='rademacher',  # 'rademacher', 'mammen', or 'webb'
     seed=None,                       # Random seed
     rank_deficient_action='warn',    # 'warn', 'error', or 'silent'
     horizon_max=None,                # Max event-study horizon
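For context, the three `bootstrap_weights` options documented above are standard multiplier-bootstrap distributions. A minimal sketch of how such weights could be drawn; the package's actual generator (`_generate_bootstrap_weights_batch`) is not shown in this diff, so treat the function below as an illustration under assumed textbook parameterizations, not the library's implementation:

```python
import numpy as np

def draw_multiplier_weights(n_bootstrap, n_clusters, kind, rng):
    """Draw an (n_bootstrap, n_clusters) matrix of mean-0, variance-1 weights."""
    if kind == "rademacher":
        # +1 or -1 with equal probability
        return rng.choice([-1.0, 1.0], size=(n_bootstrap, n_clusters))
    if kind == "mammen":
        # Mammen two-point distribution: -(sqrt(5)-1)/2 with prob
        # (sqrt(5)+1)/(2*sqrt(5)), else (sqrt(5)+1)/2
        phi = (1.0 + np.sqrt(5.0)) / 2.0
        p = (np.sqrt(5.0) + 1.0) / (2.0 * np.sqrt(5.0))
        draws = rng.random((n_bootstrap, n_clusters))
        return np.where(draws < p, 1.0 - phi, phi)
    if kind == "webb":
        # Webb six-point distribution, each point with probability 1/6
        points = np.array([-np.sqrt(1.5), -1.0, -np.sqrt(0.5),
                           np.sqrt(0.5), 1.0, np.sqrt(1.5)])
        return rng.choice(points, size=(n_bootstrap, n_clusters))
    raise ValueError(f"unknown bootstrap_weights: {kind!r}")

rng = np.random.default_rng(0)
w = draw_multiplier_weights(2000, 50, "mammen", rng)
print(w.mean(), w.var())  # both should be close to 0 and 1 respectively
```

All three distributions have mean 0 and variance 1, which is what makes them interchangeable multipliers; Webb weights are often preferred when the number of clusters is small.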

TODO.md

Lines changed: 11 additions & 11 deletions
@@ -28,7 +28,7 @@ Target: < 1000 lines per module for maintainability.
 | ~~`trop.py`~~ | ~~2904~~ ~2560 | ✅ Partially split: results extracted to `trop_results.py` (~340 lines) |
 | ~~`imputation.py`~~ | ~~2480~~ ~1740 | ✅ Split into imputation.py, imputation_results.py, imputation_bootstrap.py |
 | ~~`two_stage.py`~~ | ~~2209~~ ~1490 | ✅ Split into two_stage.py, two_stage_results.py, two_stage_bootstrap.py |
-| `utils.py` | 1879 | Monitor -- legacy placebo functions stay to avoid circular imports |
+| `utils.py` | 1780 | Monitor -- legacy placebo function removed |
 | `visualization.py` | 1678 | Monitor -- growing but cohesive |
 | `linalg.py` | 1537 | Monitor -- unified backend, splitting would hurt cohesion |
 | `honest_did.py` | 1511 | Acceptable |
@@ -58,18 +58,18 @@ Deferred items from PR reviews that were not addressed before merge.
 
 | Issue | Location | PR | Priority |
 |-------|----------|----|----------|
-| TwoStageDiD & ImputationDiD bootstrap hardcodes Rademacher only; no `bootstrap_weights` parameter unlike CallawaySantAnna | `two_stage_bootstrap.py`, `imputation_bootstrap.py` | #156, #141 | Medium |
-| TwoStageDiD GMM score logic duplicated between analytic/bootstrap with inconsistent NaN/overflow handling | `two_stage.py`, `two_stage_bootstrap.py` | #156 | Medium |
-| ImputationDiD weight construction duplicated between aggregation and bootstrap (drift risk) -- has explicit code comment acknowledging duplication | `imputation.py`, `imputation_bootstrap.py` | #141 | Medium |
-| ImputationDiD dense `(A0'A0).toarray()` scales O((U+T+K)^2), OOM risk on large panels | `imputation.py` | #141 | Medium |
+| ~~TwoStageDiD & ImputationDiD bootstrap hardcodes Rademacher only; no `bootstrap_weights` parameter unlike CallawaySantAnna~~ | ~~`two_stage_bootstrap.py`, `imputation_bootstrap.py`~~ | ~~#156, #141~~ | ✅ Fixed: Added `bootstrap_weights` parameter to both estimators |
+| ~~TwoStageDiD GMM score logic duplicated between analytic/bootstrap with inconsistent NaN/overflow handling~~ | ~~`two_stage.py`, `two_stage_bootstrap.py`~~ | ~~#156~~ | ✅ Fixed: Unified via `_compute_gmm_scores()` static method |
+| ~~ImputationDiD weight construction duplicated between aggregation and bootstrap (drift risk)~~ | ~~`imputation.py`, `imputation_bootstrap.py`~~ | ~~#141~~ | ✅ Fixed: Extracted `_compute_target_weights()` helper in `imputation_bootstrap.py` |
+| ImputationDiD dense `(A0'A0).toarray()` scales O((U+T+K)^2), OOM risk on large panels | `imputation.py` | #141 | Medium (deferred -- only triggers when sparse solver fails; fixing requires sparse least-squares alternatives) |
 
 #### Performance
 
 | Issue | Location | PR | Priority |
 |-------|----------|----|----------|
-| TwoStageDiD per-column `.toarray()` in loop for cluster scores | `two_stage_bootstrap.py` | #156 | Medium |
+| ~~TwoStageDiD per-column `.toarray()` in loop for cluster scores~~ | ~~`two_stage_bootstrap.py`~~ | ~~#156~~ | ✅ Fixed: Single `.toarray()` call replaces per-column loop |
 | ImputationDiD event-study SEs recompute full conservative variance per horizon (should cache A0/A1 factorization) | `imputation.py` | #141 | Low |
-| Legacy `compute_placebo_effects` uses deprecated projected-gradient weights (marked deprecated, users directed to `SyntheticDiD`) | `utils.py:1689-1691` | #145 | Low |
+| ~~Legacy `compute_placebo_effects` uses deprecated projected-gradient weights~~ | ~~`utils.py:1689-1691`~~ | ~~#145~~ | ✅ Fixed: Removed function entirely |
 | Rust faer SVD ndarray-to-faer conversion overhead (minimal vs SVD cost) | `rust/src/linalg.rs:67` | #115 | Low |
 
 #### Testing/Docs
@@ -78,11 +78,11 @@ Deferred items from PR reviews that were not addressed before merge.
 |-------|----------|----|----------|
 | Tutorial notebooks not executed in CI | `docs/tutorials/*.ipynb` | #159 | Low |
 | ~~TwoStageDiD `test_nan_propagation` is a no-op~~ | ~~`tests/test_two_stage.py:643-652`~~ | ~~#156~~ | ✅ Fixed |
-| ImputationDiD bootstrap + covariate path untested | `tests/test_imputation.py` | #141 | Low |
-| TROP `n_bootstrap >= 2` validation missing (can yield 0/NaN SE silently) | `trop.py:462` | #124 | Low |
-| SunAbraham deprecated `min_pre_periods`/`min_post_periods` still in `fit()` docstring | `sun_abraham.py:458-487` | #153 | Low |
+| ~~ImputationDiD bootstrap + covariate path untested~~ | ~~`tests/test_imputation.py`~~ | ~~#141~~ | ✅ Fixed: Added `test_bootstrap_with_covariates` |
+| ~~TROP `n_bootstrap >= 2` validation missing (can yield 0/NaN SE silently)~~ | ~~`trop.py:462`~~ | ~~#124~~ | ✅ Fixed: Added `ValueError` for `n_bootstrap < 2` |
+| ~~SunAbraham deprecated `min_pre_periods`/`min_post_periods` still in `fit()` docstring~~ | ~~`sun_abraham.py:458-487`~~ | ~~#153~~ | ✅ Fixed: Removed deprecated params from `fit()` |
 | R comparison tests spawn separate `Rscript` per test (slow CI) | `tests/test_methodology_twfe.py:294` | #139 | Low |
-| Rust TROP bootstrap SE returns 0.0 instead of NaN for <2 samples | `rust/src/trop.rs:1038-1054` | #115 | Low |
+| ~~Rust TROP bootstrap SE returns 0.0 instead of NaN for <2 samples~~ | ~~`rust/src/trop.rs:1038-1054`~~ | ~~#115~~ | ✅ Already fixed: Returns `f64::NAN` at `rust/src/trop.rs:1034` |
 
 ---
 
diff_diff/imputation.py

Lines changed: 13 additions & 16 deletions
@@ -22,7 +22,7 @@
 from scipy import sparse, stats
 from scipy.sparse.linalg import spsolve
 
-from diff_diff.imputation_bootstrap import ImputationDiDBootstrapMixin
+from diff_diff.imputation_bootstrap import ImputationDiDBootstrapMixin, _compute_target_weights
 from diff_diff.imputation_results import ImputationBootstrapResults, ImputationDiDResults  # noqa: F401 (re-export)
 from diff_diff.linalg import solve_ols
 from diff_diff.utils import safe_inference
@@ -63,6 +63,8 @@ class ImputationDiD(ImputationDiDBootstrapMixin):
     n_bootstrap : int, default=0
         Number of bootstrap iterations. If 0, uses analytical inference
         (conservative variance from Theorem 3).
+    bootstrap_weights : str, default="rademacher"
+        Type of bootstrap weights: "rademacher", "mammen", or "webb".
     seed : int, optional
         Random seed for reproducibility.
     rank_deficient_action : str, default="warn"
@@ -126,6 +128,7 @@ def __init__(
         alpha: float = 0.05,
         cluster: Optional[str] = None,
         n_bootstrap: int = 0,
+        bootstrap_weights: str = "rademacher",
         seed: Optional[int] = None,
         rank_deficient_action: str = "warn",
         horizon_max: Optional[int] = None,
@@ -136,6 +139,11 @@ def __init__(
                 f"rank_deficient_action must be 'warn', 'error', or 'silent', "
                 f"got '{rank_deficient_action}'"
             )
+        if bootstrap_weights not in ("rademacher", "mammen", "webb"):
+            raise ValueError(
+                f"bootstrap_weights must be 'rademacher', 'mammen', or 'webb', "
+                f"got '{bootstrap_weights}'"
+            )
         if aux_partition not in ("cohort_horizon", "cohort", "horizon"):
             raise ValueError(
                 f"aux_partition must be 'cohort_horizon', 'cohort', or 'horizon', "
@@ -146,6 +154,7 @@ def __init__(
         self.alpha = alpha
         self.cluster = cluster
         self.n_bootstrap = n_bootstrap
+        self.bootstrap_weights = bootstrap_weights
         self.seed = seed
         self.rank_deficient_action = rank_deficient_action
         self.horizon_max = horizon_max
@@ -1359,15 +1368,7 @@ def _aggregate_event_study(
            effect = float(np.mean(valid_tau))
 
            # Compute SE via conservative variance with horizon-specific weights
-           weights_h = np.zeros(int(omega_1_mask.sum()))
-           # Map h_mask (relative to df_1) to weights array
-           h_indices_in_omega1 = np.where(h_mask)[0]
-           n_valid = len(valid_tau)
-           # Only weight valid (finite) observations
-           finite_mask = np.isfinite(tau_hat[h_mask])
-           valid_h_indices = h_indices_in_omega1[finite_mask]
-           for idx in valid_h_indices:
-               weights_h[idx] = 1.0 / n_valid
+           weights_h, n_valid = _compute_target_weights(tau_hat, h_mask)
 
            se = self._compute_conservative_variance(
                df=df,
@@ -1477,12 +1478,7 @@ def _aggregate_group(
            effect = float(np.mean(valid_tau))
 
            # Compute SE with group-specific weights
-           weights_g = np.zeros(int(omega_1_mask.sum()))
-           finite_mask = np.isfinite(tau_hat) & g_mask
-           g_indices = np.where(finite_mask)[0]
-           n_valid = len(valid_tau)
-           for idx in g_indices:
-               weights_g[idx] = 1.0 / n_valid
+           weights_g, _ = _compute_target_weights(tau_hat, g_mask)
 
            se = self._compute_conservative_variance(
                df=df,
@@ -1664,6 +1660,7 @@ def get_params(self) -> Dict[str, Any]:
            "alpha": self.alpha,
            "cluster": self.cluster,
            "n_bootstrap": self.n_bootstrap,
+           "bootstrap_weights": self.bootstrap_weights,
            "seed": self.seed,
            "rank_deficient_action": self.rank_deficient_action,
            "horizon_max": self.horizon_max,
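Both sides of the `_aggregate_group` change compute the same weights; that equivalence can be checked standalone, with `_compute_target_weights` copied from the hunk that adds it to `imputation_bootstrap.py` and toy inputs standing in for real estimator state:

```python
import numpy as np

def _compute_target_weights(tau_hat, target_mask):
    # Helper as added in this commit (body copied from the diff)
    finite_target = np.isfinite(tau_hat) & target_mask
    n_valid = int(finite_target.sum())
    weights = np.zeros(len(tau_hat))
    if n_valid > 0:
        weights[np.where(finite_target)[0]] = 1.0 / n_valid
    return weights, n_valid

# Toy per-observation effects (one NaN) and a "group" mask
tau_hat = np.array([0.5, np.nan, 1.5, 2.0, 0.0])
g_mask = np.array([True, True, True, False, True])

# Old loop-based construction (the deleted side of the diff)
finite_mask = np.isfinite(tau_hat) & g_mask
valid_tau = tau_hat[finite_mask]
weights_g = np.zeros(len(tau_hat))
g_indices = np.where(finite_mask)[0]
n_valid = len(valid_tau)
for idx in g_indices:
    weights_g[idx] = 1.0 / n_valid

new_weights, new_n = _compute_target_weights(tau_hat, g_mask)
print(np.allclose(weights_g, new_weights), new_n)  # True 3
```

The NaN and out-of-group positions get weight 0, and the remaining weights sum to 1, so the aggregated effect stays an equal-weighted mean over valid observations.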

diff_diff/imputation_bootstrap.py

Lines changed: 41 additions & 18 deletions
@@ -19,6 +19,39 @@
 ]
 
 
+def _compute_target_weights(
+    tau_hat: np.ndarray,
+    target_mask: np.ndarray,
+) -> "tuple[np.ndarray, int]":
+    """
+    Equal weights for finite tau_hat observations within target_mask.
+
+    Used by both aggregation and bootstrap paths to avoid weight logic
+    duplication.
+
+    Parameters
+    ----------
+    tau_hat : np.ndarray
+        Per-observation treatment effects (may contain NaN).
+    target_mask : np.ndarray
+        Boolean mask selecting the target subset within tau_hat.
+
+    Returns
+    -------
+    weights : np.ndarray
+        Weight array (same length as tau_hat). 1/n_valid for finite
+        observations in target_mask, 0 elsewhere.
+    n_valid : int
+        Number of finite observations in the target subset.
+    """
+    finite_target = np.isfinite(tau_hat) & target_mask
+    n_valid = int(finite_target.sum())
+    weights = np.zeros(len(tau_hat))
+    if n_valid > 0:
+        weights[np.where(finite_target)[0]] = 1.0 / n_valid
+    return weights, n_valid
+
+
 class ImputationDiDBootstrapMixin:
     """Mixin providing bootstrap inference methods for ImputationDiD."""
 
@@ -91,7 +124,8 @@ def _precompute_bootstrap_psi(
 
        For each aggregation target (overall, per-horizon, per-group), computes
        psi_i = sum_t v_it * epsilon_tilde_it for each cluster. The multiplier
-       bootstrap then perturbs these psi sums with Rademacher weights.
+       bootstrap then perturbs these psi sums with multiplier weights
+       (rademacher/mammen/webb; configurable via ``bootstrap_weights``).
 
        Computational cost scales with the number of aggregation targets, since
        each target requires its own v_untreated computation (weight-dependent).
@@ -120,13 +154,10 @@ def _precompute_bootstrap_psi(
        result["overall"] = (overall_psi, cluster_ids)
 
        # Event study: per-horizon weights
-       # NOTE: weight logic duplicated from _aggregate_event_study.
-       # If weight scheme changes there, update here too.
        if event_study_effects:
            result["event_study"] = {}
            df_1 = df.loc[omega_1_mask]
            rel_times = df_1["_rel_time"].values
-           n_omega_1 = int(omega_1_mask.sum())
 
            # Balanced cohort mask (same logic as _aggregate_event_study)
            balanced_mask = None
@@ -150,37 +181,28 @@ def _precompute_bootstrap_psi(
                h_mask = rel_times == h
                if balanced_mask is not None:
                    h_mask = h_mask & balanced_mask
-               weights_h = np.zeros(n_omega_1)
-               finite_h = np.isfinite(tau_hat) & h_mask
-               n_valid_h = int(finite_h.sum())
+               weights_h, n_valid_h = _compute_target_weights(tau_hat, h_mask)
                if n_valid_h == 0:
                    continue
-               weights_h[np.where(finite_h)[0]] = 1.0 / n_valid_h
 
                psi_h, _ = self._compute_cluster_psi_sums(**common, weights=weights_h)
                result["event_study"][h] = psi_h
 
        # Group effects: per-group weights
-       # NOTE: weight logic duplicated from _aggregate_group.
-       # If weight scheme changes there, update here too.
        if group_effects:
            result["group"] = {}
            df_1 = df.loc[omega_1_mask]
            cohorts = df_1[first_treat].values
-           n_omega_1 = int(omega_1_mask.sum())
 
            for g in group_effects:
                if group_effects[g].get("n_obs", 0) == 0:
                    continue
                if not np.isfinite(group_effects[g].get("effect", np.nan)):
                    continue
                g_mask = cohorts == g
-               weights_g = np.zeros(n_omega_1)
-               finite_g = np.isfinite(tau_hat) & g_mask
-               n_valid_g = int(finite_g.sum())
+               weights_g, n_valid_g = _compute_target_weights(tau_hat, g_mask)
                if n_valid_g == 0:
                    continue
-               weights_g[np.where(finite_g)[0]] = 1.0 / n_valid_g
 
                psi_g, _ = self._compute_cluster_psi_sums(**common, weights=weights_g)
                result["group"][g] = psi_g
@@ -197,7 +219,8 @@ def _run_bootstrap(
        """
        Run multiplier bootstrap on pre-computed influence function sums.
 
-       Uses T_b = sum_i w_b_i * psi_i where w_b_i are Rademacher weights
+       Uses T_b = sum_i w_b_i * psi_i where w_b_i are multiplier weights
+       (rademacher/mammen/webb; configurable via ``bootstrap_weights``)
        and psi_i are cluster-level influence function sums from Theorem 3.
        SE = std(T_b, ddof=1).
        """
@@ -216,7 +239,7 @@ def _run_bootstrap(
 
        # Generate ALL weights upfront: shape (n_bootstrap, n_clusters)
        all_weights = _generate_bootstrap_weights_batch(
-           self.n_bootstrap, n_clusters, "rademacher", rng
+           self.n_bootstrap, n_clusters, self.bootstrap_weights, rng
        )
 
        # Overall ATT bootstrap draws
@@ -295,7 +318,7 @@ def _run_bootstrap(
 
        return ImputationBootstrapResults(
            n_bootstrap=self.n_bootstrap,
-           weight_type="rademacher",
+           weight_type=self.bootstrap_weights,
            alpha=self.alpha,
            overall_att_se=overall_se,
            overall_att_ci=overall_ci,
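The `_run_bootstrap` docstring above describes the multiplier scheme: T_b = sum_i w_b_i * psi_i, with SE = std(T_b, ddof=1). A self-contained sketch of that aggregation step with toy psi values (in the estimator, the cluster-level influence sums come from `_precompute_bootstrap_psi`; the percentile CI shown here is one common choice, not necessarily the package's):

```python
import numpy as np

rng = np.random.default_rng(42)

# Toy cluster-level influence-function sums psi_i; made-up numbers
# standing in for the Theorem 3 quantities.
psi = rng.normal(size=20)

n_bootstrap = 500
# Rademacher multipliers, shape (n_bootstrap, n_clusters)
all_weights = rng.choice([-1.0, 1.0], size=(n_bootstrap, len(psi)))

# T_b = sum_i w_b_i * psi_i, one draw per bootstrap replication
t_draws = all_weights @ psi

se = np.std(t_draws, ddof=1)
ci = (np.quantile(t_draws, 0.025), np.quantile(t_draws, 0.975))
print(se, ci)
```

Because the multipliers are independent with mean 0 and variance 1, Var(T_b) = sum_i psi_i^2 exactly, so the bootstrap SE should land near sqrt(sum_i psi_i^2) for large `n_bootstrap`.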

diff_diff/imputation_results.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ class ImputationBootstrapResults:
    n_bootstrap : int
        Number of bootstrap iterations.
    weight_type : str
-       Type of bootstrap weights (currently "rademacher" only).
+       Type of bootstrap weights: "rademacher", "mammen", or "webb".
    alpha : float
        Significance level used for confidence intervals.
    overall_att_se : float

diff_diff/sun_abraham.py

Lines changed: 0 additions & 22 deletions
@@ -433,8 +433,6 @@ def fit(
        time: str,
        first_treat: str,
        covariates: Optional[List[str]] = None,
-       min_pre_periods: int = 1,
-       min_post_periods: int = 1,
    ) -> SunAbrahamResults:
        """
        Fit the Sun-Abraham estimator using saturated regression.
@@ -454,10 +452,6 @@ def fit(
            Use 0 (or np.inf) for never-treated units.
        covariates : list, optional
            List of covariate column names to include in regression.
-       min_pre_periods : int, default=1
-           **Deprecated**: Accepted but ignored. Will be removed in a future version.
-       min_post_periods : int, default=1
-           **Deprecated**: Accepted but ignored. Will be removed in a future version.
 
        Returns
        -------
@@ -469,22 +463,6 @@ def fit(
        ValueError
            If required columns are missing or data validation fails.
        """
-       # Deprecation warnings for unimplemented parameters
-       if min_pre_periods != 1:
-           warnings.warn(
-               "min_pre_periods is not yet implemented and will be ignored. "
-               "This parameter will be removed in a future version.",
-               FutureWarning,
-               stacklevel=2,
-           )
-       if min_post_periods != 1:
-           warnings.warn(
-               "min_post_periods is not yet implemented and will be ignored. "
-               "This parameter will be removed in a future version.",
-               FutureWarning,
-               stacklevel=2,
-           )
-
        # Validate inputs
        required_cols = [outcome, unit, time, first_treat]
        if covariates:

diff_diff/trop.py

Lines changed: 7 additions & 1 deletion
@@ -93,7 +93,7 @@ class TROP:
    alpha : float, default=0.05
        Significance level for confidence intervals.
    n_bootstrap : int, default=200
-       Number of bootstrap replications for variance estimation.
+       Number of bootstrap replications for variance estimation. Must be >= 2.
    seed : int, optional
        Random seed for reproducibility.
 
@@ -156,6 +156,12 @@ def __init__(
        self.lambda_unit_grid = lambda_unit_grid or [0.0, 0.1, 0.5, 1.0, 2.0, 5.0]
        self.lambda_nn_grid = lambda_nn_grid or [0.0, 0.01, 0.1, 1.0, 10.0]
 
+       if n_bootstrap < 2:
+           raise ValueError(
+               "n_bootstrap must be >= 2 for TROP (bootstrap variance "
+               "estimation is always used)"
+           )
+
        self.max_iter = max_iter
        self.tol = tol
        self.alpha = alpha
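The guard added above exists because a sample standard deviation with `ddof=1` is undefined for fewer than two draws; NumPy returns NaN, which is the behavior the Rust side now mirrors with `f64::NAN`. A quick illustration:

```python
import warnings
import numpy as np

one_draw = np.array([0.7])

with warnings.catch_warnings():
    # np.std emits a RuntimeWarning when degrees of freedom <= 0
    warnings.simplefilter("ignore", RuntimeWarning)
    se = np.std(one_draw, ddof=1)

print(se)  # nan: the variance denominator n - ddof is zero
```

Validating `n_bootstrap >= 2` up front turns that silent NaN (or a misleading 0.0) into an immediate, explicit error at construction time.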
