Skip to content

Commit 903edc7

Browse files
igerberclaude
andcommitted
Address PR #95 round 2: refactor p-value, extend errstate, strengthen tests
- Refactor _compute_bootstrap_pvalue to accept an n_valid parameter, eliminating the duplicated p-value logic in _compute_effect_bootstrap_stats
- Extend np.errstate coverage in staggered_aggregation.py to wrap all WIF division operations (not just the matrix multiplication)
- Add a deviation note to the Methodology Registry documenting the defensive enhancement over the R/Stata reference implementations
- Strengthen test assertions: verify that warnings are captured and that a NaN SE is accompanied by validity warnings
- Add test_validity_threshold_nan_se for edge-case coverage

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent ba0eee6 commit 903edc7

4 files changed

Lines changed: 97 additions & 18 deletions

File tree

diff_diff/staggered_aggregation.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -273,16 +273,17 @@ def _compute_aggregated_se_with_wif(
273273
indicator_sum = np.sum(indicator_matrix - pg_keepers, axis=1)
274274

275275
# Vectorized wif matrix computation
276-
# if1_matrix[i,k] = (indicator[i,k] - pg[k]) / sum_pg
277-
if1_matrix = (indicator_matrix - pg_keepers) / sum_pg_keepers
278-
# if2_matrix[i,k] = indicator_sum[i] * pg[k] / sum_pg^2
279-
if2_matrix = np.outer(indicator_sum, pg_keepers) / (sum_pg_keepers ** 2)
280-
wif_matrix = if1_matrix - if2_matrix
281-
282-
# Single matrix-vector multiply for all contributions
283-
# wif_contrib[i] = sum_k(wif[i,k] * att[k])
284276
# Suppress RuntimeWarnings for edge cases (small samples, extreme weights)
277+
# in division operations and matrix multiplication
285278
with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
279+
# if1_matrix[i,k] = (indicator[i,k] - pg[k]) / sum_pg
280+
if1_matrix = (indicator_matrix - pg_keepers) / sum_pg_keepers
281+
# if2_matrix[i,k] = indicator_sum[i] * pg[k] / sum_pg^2
282+
if2_matrix = np.outer(indicator_sum, pg_keepers) / (sum_pg_keepers ** 2)
283+
wif_matrix = if1_matrix - if2_matrix
284+
285+
# Single matrix-vector multiply for all contributions
286+
# wif_contrib[i] = sum_k(wif[i,k] * att[k])
286287
wif_contrib = wif_matrix @ effects
287288

288289
# Check for non-finite values from edge cases

diff_diff/staggered_bootstrap.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -605,13 +605,30 @@ def _compute_bootstrap_pvalue(
605605
self,
606606
original_effect: float,
607607
boot_dist: np.ndarray,
608+
n_valid: Optional[int] = None,
608609
) -> float:
609610
"""
610611
Compute two-sided bootstrap p-value.
611612
612613
Uses the percentile method: p-value is the proportion of bootstrap
613614
estimates on the opposite side of zero from the original estimate,
614615
doubled for two-sided test.
616+
617+
Parameters
618+
----------
619+
original_effect : float
620+
Original point estimate.
621+
boot_dist : np.ndarray
622+
Bootstrap distribution of the effect.
623+
n_valid : int, optional
624+
Number of valid bootstrap samples. If None, uses self.n_bootstrap.
625+
Use this when boot_dist has already been filtered for non-finite values
626+
to ensure the p-value floor is based on the actual valid sample count.
627+
628+
Returns
629+
-------
630+
float
631+
Two-sided bootstrap p-value.
615632
"""
616633
if original_effect >= 0:
617634
# Proportion of bootstrap estimates <= 0
@@ -623,8 +640,9 @@ def _compute_bootstrap_pvalue(
623640
# Two-sided p-value
624641
p_value = min(2 * p_one_sided, 1.0)
625642

626-
# Ensure minimum p-value
627-
p_value = max(p_value, 1 / (self.n_bootstrap + 1))
643+
# Ensure minimum p-value using n_valid if provided, otherwise n_bootstrap
644+
n_for_floor = n_valid if n_valid is not None else self.n_bootstrap
645+
p_value = max(p_value, 1 / (n_for_floor + 1))
628646

629647
return float(p_value)
630648

@@ -693,12 +711,7 @@ def _compute_effect_bootstrap_stats(
693711
se = float(np.std(valid_dist, ddof=1))
694712
ci = self._compute_percentile_ci(valid_dist, self.alpha)
695713

696-
# Compute p-value inline with correct floor based on valid sample count
697-
if original_effect >= 0:
698-
p_one_sided = np.mean(valid_dist <= 0)
699-
else:
700-
p_one_sided = np.mean(valid_dist >= 0)
701-
p_value = min(2 * p_one_sided, 1.0)
702-
p_value = max(p_value, 1 / (n_valid_bootstrap + 1)) # Floor uses valid count
714+
# Compute p-value using shared method with correct floor based on valid sample count
715+
p_value = self._compute_bootstrap_pvalue(original_effect, valid_dist, n_valid=n_valid_bootstrap)
703716

704-
return se, ci, float(p_value)
717+
return se, ci, p_value

docs/methodology/REGISTRY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ Aggregations:
208208
- Analytic SE: Returns NaN to signal invalid inference (not biased via zeroing)
209209
- Bootstrap: Drops non-finite samples, warns, and adjusts p-value floor accordingly
210210
- Threshold: Returns NaN if <50% of bootstrap samples are valid
211+
- **Note**: This is a defensive enhancement over the reference implementations (R's `did::att_gt`, Stata's `csdid`), which in these edge cases may raise an error or produce unhandled inf/NaN values without an informative warning.
211212

212213
**Reference implementation(s):**
213214
- R: `did::att_gt()` (Callaway & Sant'Anna's official package)

tests/test_staggered.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,13 +741,77 @@ def test_extreme_weights_warning(self):
741741
first_treat='first_treat'
742742
)
743743

744+
# Collect warning messages for inspection
745+
warning_messages = [str(warning.message) for warning in w]
746+
744747
# ATT should be finite
745748
assert np.isfinite(boot_results.overall_att), "ATT should be finite"
749+
746750
# Bootstrap SE based on valid samples - may be finite or NaN
747751
assert boot_results.bootstrap_results is not None, "Bootstrap results should exist"
748752
assert np.isfinite(boot_results.overall_se) or np.isnan(boot_results.overall_se), \
749753
"Bootstrap SE should be finite or NaN (not inf)"
750754

755+
# If SE is NaN, verify it's due to validity threshold (should have warning)
756+
if np.isnan(boot_results.overall_se):
757+
assert any("valid" in msg.lower() or "nan" in msg.lower() for msg in warning_messages), \
758+
"NaN SE should be accompanied by warning about validity"
759+
760+
def test_validity_threshold_nan_se(self):
761+
"""Test that <50% valid bootstrap samples returns NaN SE with warning.
762+
763+
This tests the methodology-aligned behavior where invalid inference
764+
is signaled via NaN rather than biased estimates.
765+
"""
766+
import warnings
767+
np.random.seed(42)
768+
769+
# Create minimal dataset that might trigger edge cases
770+
n_units, n_periods = 10, 3
771+
units = np.repeat(np.arange(n_units), n_periods)
772+
times = np.tile(np.arange(n_periods), n_units)
773+
774+
# Only 1 treated unit - very extreme
775+
first_treat = np.zeros(n_units)
776+
first_treat[0] = 1
777+
first_treat_expanded = np.repeat(first_treat, n_periods)
778+
779+
post = (times >= first_treat_expanded) & (first_treat_expanded > 0)
780+
outcomes = 1.0 + 2.0 * post + np.random.randn(len(units)) * 0.5
781+
782+
data = pd.DataFrame({
783+
'unit': units,
784+
'time': times,
785+
'outcome': outcomes,
786+
'first_treat': first_treat_expanded.astype(int),
787+
})
788+
789+
# Use low n_bootstrap to trigger warning and potentially non-finite samples
790+
cs_boot = CallawaySantAnna(n_bootstrap=30, seed=42)
791+
792+
with warnings.catch_warnings(record=True) as w:
793+
warnings.simplefilter("always")
794+
boot_results = cs_boot.fit(
795+
data,
796+
outcome='outcome',
797+
unit='unit',
798+
time='time',
799+
first_treat='first_treat'
800+
)
801+
802+
warning_messages = [str(warning.message) for warning in w]
803+
804+
# Should get the low n_bootstrap warning
805+
assert any("n_bootstrap" in msg for msg in warning_messages), \
806+
"Should warn about low n_bootstrap"
807+
808+
# Bootstrap results should exist
809+
assert boot_results.bootstrap_results is not None, "Bootstrap results should exist"
810+
811+
# SE constraints: finite or NaN (never inf)
812+
assert np.isfinite(boot_results.overall_se) or np.isnan(boot_results.overall_se), \
813+
"Bootstrap SE should be finite or NaN (not inf)"
814+
751815
def test_near_collinear_covariates(self):
752816
"""Test that near-collinear covariates are handled gracefully."""
753817
data = generate_staggered_data_with_covariates(seed=42)

0 commit comments

Comments
 (0)