Skip to content

Commit 3b2e40d

Browse files
igerberclaude
andcommitted
Add sparse size guard, remove stale SA params, add tests for PR #165 round 2
- Add _SPARSE_DENSE_THRESHOLD in two_stage.py with per-column .tocsc() fallback for large FE matrices; apply same pattern in bootstrap module - Remove min_pre_periods/min_post_periods from README SunAbraham table - Add test_removed_params_raise_typeerror for SunAbraham - Add test_sparse_fallback_path for TwoStageDiD dense/sparse equivalence Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent e97ca9d commit 3b2e40d

5 files changed

Lines changed: 71 additions & 15 deletions

File tree

README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2064,8 +2064,6 @@ SunAbraham(
20642064
| `time` | str | Time period column |
20652065
| `first_treat` | str | Column with first treatment period (0 for never-treated) |
20662066
| `covariates` | list | Covariate column names |
2067-
| `min_pre_periods` | int | Minimum pre-treatment periods to include |
2068-
| `min_post_periods` | int | Minimum post-treatment periods to include |
20692067

20702068
### SunAbrahamResults
20712069

diff_diff/two_stage.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@
2929
from scipy import sparse
3030
from scipy.sparse.linalg import factorized as sparse_factorized
3131

32+
# Maximum number of elements before falling back to per-column sparse aggregation.
33+
# 10M float64 elements ≈ 80 MB peak allocation. Above this, per-column .getcol()
34+
# trades throughput for bounded memory.
35+
_SPARSE_DENSE_THRESHOLD = 10_000_000
36+
3237
from diff_diff.linalg import solve_ols
3338
from diff_diff.two_stage_bootstrap import TwoStageDiDBootstrapMixin
3439
from diff_diff.two_stage_results import TwoStageBootstrapResults, TwoStageDiDResults # noqa: F401 (re-export)
@@ -1222,15 +1227,19 @@ def _compute_gmm_variance(
12221227
unique_clusters, cluster_indices = np.unique(cluster_ids, return_inverse=True)
12231228
G = len(unique_clusters)
12241229

1225-
# Convert sparse to dense once for efficient cluster aggregation.
1226-
# Total memory touched is identical to per-column .getcol().toarray();
1227-
# only peak allocation differs (full matrix vs one column at a time).
1228-
# For panels with >100K FE columns, consider reverting to per-column
1229-
# .getcol() to limit peak memory.
1230-
weighted_X10_dense = weighted_X10.toarray()
1230+
n_elements = weighted_X10.shape[0] * weighted_X10.shape[1]
12311231
c_by_cluster = np.zeros((G, p))
1232-
for j_col in range(p):
1233-
np.add.at(c_by_cluster[:, j_col], cluster_indices, weighted_X10_dense[:, j_col])
1232+
if n_elements > _SPARSE_DENSE_THRESHOLD:
1233+
# Per-column path: limits peak memory for large FE matrices
1234+
weighted_X10_csc = weighted_X10.tocsc()
1235+
for j_col in range(p):
1236+
col_data = weighted_X10_csc.getcol(j_col).toarray().ravel()
1237+
np.add.at(c_by_cluster[:, j_col], cluster_indices, col_data)
1238+
else:
1239+
# Dense path: faster for moderate-size matrices
1240+
weighted_X10_dense = weighted_X10.toarray()
1241+
for j_col in range(p):
1242+
np.add.at(c_by_cluster[:, j_col], cluster_indices, weighted_X10_dense[:, j_col])
12341243

12351244
# 3. Per-cluster Stage 2 scores: X'_{2g} eps_{2g}
12361245
weighted_X2 = X_2 * eps_2[:, None] # (n x k) dense

diff_diff/two_stage_bootstrap.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from diff_diff.linalg import solve_ols
1717
from diff_diff.staggered_bootstrap import _generate_bootstrap_weights_batch
18+
from diff_diff.two_stage import _SPARSE_DENSE_THRESHOLD
1819
from diff_diff.two_stage_results import TwoStageBootstrapResults
1920

2021
__all__ = [
@@ -106,12 +107,19 @@ def _compute_cluster_S_scores(
106107
unique_clusters, cluster_indices = np.unique(cluster_ids, return_inverse=True)
107108
G = len(unique_clusters)
108109

109-
# Convert sparse to dense once (see _compute_gmm_variance for memory note).
110-
# For panels with >100K FE columns, consider per-column .getcol() instead.
111-
weighted_X10_dense = weighted_X10.toarray()
110+
n_elements = weighted_X10.shape[0] * weighted_X10.shape[1]
112111
c_by_cluster = np.zeros((G, p))
113-
for j_col in range(p):
114-
np.add.at(c_by_cluster[:, j_col], cluster_indices, weighted_X10_dense[:, j_col])
112+
if n_elements > _SPARSE_DENSE_THRESHOLD:
113+
# Per-column path: limits peak memory for large FE matrices
114+
weighted_X10_csc = weighted_X10.tocsc()
115+
for j_col in range(p):
116+
col_data = weighted_X10_csc.getcol(j_col).toarray().ravel()
117+
np.add.at(c_by_cluster[:, j_col], cluster_indices, col_data)
118+
else:
119+
# Dense path: faster for moderate-size matrices
120+
weighted_X10_dense = weighted_X10.toarray()
121+
for j_col in range(p):
122+
np.add.at(c_by_cluster[:, j_col], cluster_indices, weighted_X10_dense[:, j_col])
115123

116124
weighted_X2 = X_2 * eps_2[:, None]
117125
s2_by_cluster = np.zeros((G, k))

tests/test_sun_abraham.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,6 +1442,15 @@ def test_never_treated_inf_encoding(self):
14421442
f"SE differs: inf={results_inf.overall_se}, zero={results_zero.overall_se}"
14431443
)
14441444

1445+
def test_removed_params_raise_typeerror(self):
1446+
"""Removed min_pre_periods/min_post_periods raise TypeError."""
1447+
data = generate_staggered_data(n_units=30, n_periods=6, seed=42)
1448+
sa = SunAbraham(n_bootstrap=0)
1449+
with pytest.raises(TypeError, match="unexpected keyword argument"):
1450+
sa.fit(data, "outcome", "unit", "time", "first_treat", min_pre_periods=2)
1451+
with pytest.raises(TypeError, match="unexpected keyword argument"):
1452+
sa.fit(data, "outcome", "unit", "time", "first_treat", min_post_periods=2)
1453+
14451454
def test_all_never_treated_inf_raises(self):
14461455
"""Test that all-never-treated data with np.inf encoding raises ValueError."""
14471456
data = generate_staggered_data(n_units=100, n_periods=10, n_cohorts=3, seed=42)

tests/test_two_stage.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,3 +1145,35 @@ def test_print_summary(self, capsys):
11451145
results.print_summary()
11461146
captured = capsys.readouterr()
11471147
assert "Two-Stage DiD" in captured.out
1148+
1149+
def test_sparse_fallback_path(self):
1150+
"""Size guard falls back to per-column path and produces same results."""
1151+
import diff_diff.two_stage as ts_mod
1152+
1153+
data = generate_test_data(n_units=50, n_periods=6, seed=42)
1154+
1155+
# Run with normal (high) threshold — uses dense path
1156+
result_dense = TwoStageDiD().fit(
1157+
data, outcome="outcome", unit="unit", time="time", first_treat="first_treat"
1158+
)
1159+
1160+
# Patch threshold to 1 to force per-column path on all data
1161+
orig = ts_mod._SPARSE_DENSE_THRESHOLD
1162+
try:
1163+
ts_mod._SPARSE_DENSE_THRESHOLD = 1
1164+
result_sparse = TwoStageDiD().fit(
1165+
data,
1166+
outcome="outcome",
1167+
unit="unit",
1168+
time="time",
1169+
first_treat="first_treat",
1170+
)
1171+
finally:
1172+
ts_mod._SPARSE_DENSE_THRESHOLD = orig
1173+
1174+
np.testing.assert_allclose(
1175+
result_dense.overall_att, result_sparse.overall_att, rtol=1e-10
1176+
)
1177+
np.testing.assert_allclose(
1178+
result_dense.overall_se, result_sparse.overall_se, rtol=1e-10
1179+
)

0 commit comments

Comments
 (0)