Skip to content

Commit c5eca78

Browse files
igerber and claude committed
Fix MultiPeriodDiD bootstrap+survey fallback and relabel n_treated/n_control as observations from PR #218 review (round 14)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 7ba2124 commit c5eca78

3 files changed

Lines changed: 44 additions & 10 deletions

File tree

diff_diff/estimators.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -838,13 +838,16 @@ def fit( # type: ignore[override]
838838
ValueError
839839
If required parameters are missing or data validation fails.
840840
"""
841-
# Warn if wild bootstrap is requested but not supported
841+
# Fall back to analytical inference if wild bootstrap requested
842+
# (must happen before _resolve_survey_for_fit which rejects bootstrap+survey)
843+
effective_inference = self.inference
842844
if self.inference == "wild_bootstrap":
843845
warnings.warn(
844846
"Wild bootstrap inference is not yet supported for MultiPeriodDiD. "
845847
"Using analytical inference instead.",
846848
UserWarning,
847849
)
850+
effective_inference = "analytical"
848851

849852
# Validate basic inputs
850853
if outcome is None or treatment is None or time is None:
@@ -992,7 +995,7 @@ def fit( # type: ignore[override]
992995
from diff_diff.survey import _resolve_effective_cluster, _resolve_survey_for_fit
993996

994997
resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
995-
_resolve_survey_for_fit(survey_design, data, self.inference)
998+
_resolve_survey_for_fit(survey_design, data, effective_inference)
996999
)
9971000

9981001
# Handle absorbed fixed effects (within-transformation)

diff_diff/results.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -34,9 +34,9 @@ class DiDResults:
3434
n_obs : int
3535
Number of observations used in estimation.
3636
n_treated : int
37-
Number of treated units.
37+
Number of treated observations.
3838
n_control : int
39-
Number of control units.
39+
Number of control observations.
4040
"""
4141

4242
att: float
@@ -93,8 +93,8 @@ def summary(self, alpha: Optional[float] = None) -> str:
9393
"=" * 70,
9494
"",
9595
f"{'Observations:':<25} {self.n_obs:>10}",
96-
f"{'Treated units:':<25} {self.n_treated:>10}",
97-
f"{'Control units:':<25} {self.n_control:>10}",
96+
f"{'Treated obs:':<25} {self.n_treated:>10}",
97+
f"{'Control obs:':<25} {self.n_control:>10}",
9898
]
9999

100100
if self.r_squared is not None:
@@ -645,9 +645,9 @@ class SyntheticDiDResults:
645645
n_obs : int
646646
Number of observations used in estimation.
647647
n_treated : int
648-
Number of treated units.
648+
Number of treated observations.
649649
n_control : int
650-
Number of control units.
650+
Number of control observations.
651651
unit_weights : dict
652652
Dictionary mapping control unit IDs to their synthetic weights.
653653
time_weights : dict
@@ -714,8 +714,8 @@ def summary(self, alpha: Optional[float] = None) -> str:
714714
"=" * 75,
715715
"",
716716
f"{'Observations:':<25} {self.n_obs:>10}",
717-
f"{'Treated units:':<25} {self.n_treated:>10}",
718-
f"{'Control units:':<25} {self.n_control:>10}",
717+
f"{'Treated obs:':<25} {self.n_treated:>10}",
718+
f"{'Control obs:':<25} {self.n_control:>10}",
719719
f"{'Pre-treatment periods:':<25} {len(self.pre_periods):>10}",
720720
f"{'Post-treatment periods:':<25} {len(self.post_periods):>10}",
721721
]

tests/test_survey.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2793,3 +2793,34 @@ def test_zero_score_dispersion_stratified_psu(self):
27932793
vcov = compute_survey_vcov(X, residuals, resolved=resolved)
27942794
# Zero residuals → zero scores → zero V_h per stratum → zero vcov
27952795
np.testing.assert_array_equal(vcov, np.zeros((2, 2)))
2796+
2797+
2798+
class TestRound14Fixes:
2799+
"""Tests for PR #218 review round 14 fixes."""
2800+
2801+
def test_multiperiod_bootstrap_survey_fallback(self):
2802+
"""MultiPeriodDiD with wild_bootstrap + survey_design falls back gracefully."""
2803+
np.random.seed(42)
2804+
n = 40
2805+
df = pd.DataFrame(
2806+
{
2807+
"outcome": np.random.randn(n),
2808+
"treated": np.array([1] * 20 + [0] * 20),
2809+
"time": np.tile([0, 1, 2, 3], 10),
2810+
"w": np.ones(n),
2811+
}
2812+
)
2813+
sd = SurveyDesign(weights="w", weight_type="pweight")
2814+
mpd = MultiPeriodDiD(inference="wild_bootstrap")
2815+
# Should warn about fallback and produce valid analytical results
2816+
with warnings.catch_warnings():
2817+
warnings.simplefilter("ignore")
2818+
result = mpd.fit(
2819+
df,
2820+
outcome="outcome",
2821+
treatment="treated",
2822+
time="time",
2823+
post_periods=[2, 3],
2824+
survey_design=sd,
2825+
)
2826+
assert np.isfinite(result.avg_att)

0 commit comments

Comments
 (0)