Skip to content

Commit f3c9f05

Browse files
igerberclaude
andcommitted
Fix non-finite handling to align with Methodology Registry
Address PR #95 reviewer feedback: - Analytic SE: return NaN instead of zeroing to signal invalid inference - Bootstrap: drop invalid samples and warn, preserving valid distribution - Update test to verify methodology-aligned behavior (finite or NaN, not biased) Per docs/methodology/REGISTRY.md: "Missing group-time cells: ATT(g,t) set to NaN" Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 854e2a8 commit f3c9f05

3 files changed

Lines changed: 93 additions & 50 deletions

File tree

diff_diff/staggered_aggregation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,11 +292,11 @@ def _compute_aggregated_se_with_wif(
292292
warnings.warn(
293293
f"Non-finite values ({n_nonfinite}/{len(wif_contrib)}) in weight influence "
294294
"function computation. This may occur with very small samples or extreme "
295-
"weights. SE estimates may be unreliable.",
295+
"weights. Returning NaN for SE to signal invalid inference.",
296296
RuntimeWarning,
297297
stacklevel=2
298298
)
299-
wif_contrib = np.where(np.isfinite(wif_contrib), wif_contrib, 0.0)
299+
return np.nan # Signal invalid inference instead of biased SE
300300

301301
# Scale by 1/n_units to match R's getSE formula: sqrt(mean(IF^2)/n)
302302
psi_wif = wif_contrib / n_units

diff_diff/staggered_bootstrap.py

Lines changed: 62 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -379,22 +379,17 @@ def _run_multiplier_bootstrap(
379379
control_weights @ control_inf
380380
)
381381

382-
perturbations = self._check_and_fix_nonfinite(
383-
perturbations, f"bootstrap perturbations for ATT(g,t) {gt_pairs[j]}"
384-
)
382+
# Let non-finite values propagate - they will be handled at statistics computation
385383
bootstrap_atts_gt[:, j] = original_atts[j] + perturbations
386384

387385
# Vectorized overall ATT: matrix-vector multiply
388386
# Shape: (n_bootstrap,)
389-
# Suppress RuntimeWarnings for edge cases
387+
# Suppress RuntimeWarnings for edge cases - non-finite values handled at statistics computation
390388
with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
391389
bootstrap_overall = bootstrap_atts_gt @ overall_weights
392390

393-
bootstrap_overall = self._check_and_fix_nonfinite(
394-
bootstrap_overall, "bootstrap overall ATT aggregation"
395-
)
396-
397391
# Vectorized event study aggregation
392+
# Non-finite values handled at statistics computation stage
398393
rel_periods: List[int] = []
399394
bootstrap_event_study: Optional[Dict[int, np.ndarray]] = None
400395
if event_study_info is not None:
@@ -409,11 +404,8 @@ def _run_multiplier_bootstrap(
409404
with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
410405
bootstrap_event_study[e] = bootstrap_atts_gt[:, gt_indices] @ weights
411406

412-
bootstrap_event_study[e] = self._check_and_fix_nonfinite(
413-
bootstrap_event_study[e], f"bootstrap event study aggregation (e={e})"
414-
)
415-
416407
# Vectorized group aggregation
408+
# Non-finite values handled at statistics computation stage
417409
group_list: List[Any] = []
418410
bootstrap_group: Optional[Dict[Any, np.ndarray]] = None
419411
if group_agg_info is not None:
@@ -427,26 +419,24 @@ def _run_multiplier_bootstrap(
427419
with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
428420
bootstrap_group[g] = bootstrap_atts_gt[:, gt_indices] @ weights
429421

430-
bootstrap_group[g] = self._check_and_fix_nonfinite(
431-
bootstrap_group[g], f"bootstrap group aggregation (g={g})"
432-
)
433-
434422
# Compute bootstrap statistics for ATT(g,t)
435423
gt_ses = {}
436424
gt_cis = {}
437425
gt_p_values = {}
438426

439427
for j, gt in enumerate(gt_pairs):
440428
se, ci, p_value = self._compute_effect_bootstrap_stats(
441-
original_atts[j], bootstrap_atts_gt[:, j]
429+
original_atts[j], bootstrap_atts_gt[:, j],
430+
context=f"ATT(g={gt[0]}, t={gt[1]})"
442431
)
443432
gt_ses[gt] = se
444433
gt_cis[gt] = ci
445434
gt_p_values[gt] = p_value
446435

447436
# Compute bootstrap statistics for overall ATT
448437
overall_se, overall_ci, overall_p_value = self._compute_effect_bootstrap_stats(
449-
original_overall, bootstrap_overall
438+
original_overall, bootstrap_overall,
439+
context="overall ATT"
450440
)
451441

452442
# Compute bootstrap statistics for event study effects
@@ -461,7 +451,8 @@ def _run_multiplier_bootstrap(
461451

462452
for e in rel_periods:
463453
se, ci, p_value = self._compute_effect_bootstrap_stats(
464-
event_study_info[e]['effect'], bootstrap_event_study[e]
454+
event_study_info[e]['effect'], bootstrap_event_study[e],
455+
context=f"event study (e={e})"
465456
)
466457
event_study_ses[e] = se
467458
event_study_cis[e] = ci
@@ -479,7 +470,8 @@ def _run_multiplier_bootstrap(
479470

480471
for g in group_list:
481472
se, ci, p_value = self._compute_effect_bootstrap_stats(
482-
group_agg_info[g]['effect'], bootstrap_group[g]
473+
group_agg_info[g]['effect'], bootstrap_group[g],
474+
context=f"group effect (g={g})"
483475
)
484476
group_effect_ses[g] = se
485477
group_effect_cis[g] = ci
@@ -640,16 +632,23 @@ def _compute_effect_bootstrap_stats(
640632
self,
641633
original_effect: float,
642634
boot_dist: np.ndarray,
635+
context: str = "bootstrap distribution",
643636
) -> Tuple[float, Tuple[float, float], float]:
644637
"""
645638
Compute bootstrap statistics for a single effect.
646639
640+
Non-finite bootstrap samples are dropped and a warning is issued if any
641+
are present. If too few valid samples remain (<50%), returns NaN for all
642+
statistics to signal invalid inference.
643+
647644
Parameters
648645
----------
649646
original_effect : float
650647
Original point estimate.
651648
boot_dist : np.ndarray
652649
Bootstrap distribution of the effect.
650+
context : str, optional
651+
Description for warning messages, by default "bootstrap distribution".
653652
654653
Returns
655654
-------
@@ -660,35 +659,65 @@ def _compute_effect_bootstrap_stats(
660659
p_value : float
661660
Bootstrap p-value.
662661
"""
663-
se = float(np.std(boot_dist, ddof=1))
664-
ci = self._compute_percentile_ci(boot_dist, self.alpha)
665-
p_value = self._compute_bootstrap_pvalue(original_effect, boot_dist)
662+
# Filter out non-finite values
663+
finite_mask = np.isfinite(boot_dist)
664+
n_valid = np.sum(finite_mask)
665+
n_total = len(boot_dist)
666+
667+
if n_valid < n_total:
668+
import warnings
669+
n_nonfinite = n_total - n_valid
670+
warnings.warn(
671+
f"Dropping {n_nonfinite}/{n_total} non-finite bootstrap samples in {context}. "
672+
"This may occur with very small samples or extreme weights. "
673+
"Bootstrap estimates based on remaining valid samples.",
674+
RuntimeWarning,
675+
stacklevel=3
676+
)
677+
678+
# Check if we have enough valid samples
679+
if n_valid < n_total * 0.5:
680+
import warnings
681+
warnings.warn(
682+
f"Too few valid bootstrap samples ({n_valid}/{n_total}) in {context}. "
683+
"Returning NaN for SE/CI/p-value to signal invalid inference.",
684+
RuntimeWarning,
685+
stacklevel=3
686+
)
687+
return np.nan, (np.nan, np.nan), np.nan
688+
689+
# Use only valid samples
690+
valid_dist = boot_dist[finite_mask]
691+
692+
se = float(np.std(valid_dist, ddof=1))
693+
ci = self._compute_percentile_ci(valid_dist, self.alpha)
694+
p_value = self._compute_bootstrap_pvalue(original_effect, valid_dist)
666695
return se, ci, p_value
667696

668-
def _check_and_fix_nonfinite(self, arr: np.ndarray, context: str) -> np.ndarray:
669-
"""Check for non-finite values and warn if found.
697+
def _mask_nonfinite_samples(self, arr: np.ndarray, context: str) -> np.ndarray:
698+
"""Return boolean mask of finite samples, warning if any dropped.
670699
671700
Parameters
672701
----------
673702
arr : np.ndarray
674-
Array to check.
703+
Array to check (1D bootstrap distribution).
675704
context : str
676705
Description of where this check is happening (for warning message).
677706
678707
Returns
679708
-------
680709
np.ndarray
681-
Array with non-finite values replaced by 0.0.
710+
Boolean mask where True indicates finite (valid) samples.
682711
"""
683-
if not np.all(np.isfinite(arr)):
712+
finite_mask = np.isfinite(arr)
713+
if not np.all(finite_mask):
684714
import warnings
685-
n_nonfinite = np.sum(~np.isfinite(arr))
715+
n_nonfinite = np.sum(~finite_mask)
686716
warnings.warn(
687-
f"Non-finite values ({n_nonfinite}/{arr.size}) in {context}. "
717+
f"Dropping {n_nonfinite}/{arr.size} non-finite bootstrap samples in {context}. "
688718
"This may occur with very small samples or extreme weights. "
689-
"Bootstrap estimates may be unreliable.",
719+
"Bootstrap estimates based on remaining valid samples.",
690720
RuntimeWarning,
691721
stacklevel=3
692722
)
693-
return np.where(np.isfinite(arr), arr, 0.0)
694-
return arr
723+
return finite_mask

tests/test_staggered.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,14 @@ def test_extreme_propensity_scores(self):
682682
assert results.overall_se > 0, "SE should be positive"
683683

684684
def test_extreme_weights_warning(self):
685-
"""Test that extreme weights produce warnings, not silent failures."""
685+
"""Test that extreme weights produce warnings and methodology-aligned behavior.
686+
687+
Per the Methodology Registry (docs/methodology/REGISTRY.md):
688+
- Missing group-time cells: ATT(g,t) set to NaN
689+
- Analytic SE: returns NaN to signal invalid inference (not biased via zeroing)
690+
- Bootstrap: drops invalid samples and warns, preserving valid distribution
691+
"""
692+
import warnings
686693
np.random.seed(42)
687694

688695
# Minimal dataset: very small sample with unbalanced groups
@@ -705,7 +712,7 @@ def test_extreme_weights_warning(self):
705712
'first_treat': first_treat_expanded.astype(int),
706713
})
707714

708-
# Test without bootstrap first
715+
# Test without bootstrap - ATT should be finite, SE may be NaN for edge cases
709716
cs = CallawaySantAnna()
710717
results = cs.fit(
711718
data,
@@ -715,24 +722,31 @@ def test_extreme_weights_warning(self):
715722
first_treat='first_treat'
716723
)
717724

718-
# Results should be finite even in edge cases
725+
# ATT point estimate should be finite
719726
assert np.isfinite(results.overall_att), "ATT should be finite"
720-
assert np.isfinite(results.overall_se), "SE should be finite"
727+
# SE is either finite (valid) or NaN (signals invalid inference) - not biased
728+
assert np.isfinite(results.overall_se) or np.isnan(results.overall_se), \
729+
"SE should be finite or NaN (not inf)"
721730

722-
# Test with bootstrap enabled
723-
cs_boot = CallawaySantAnna(n_bootstrap=50, seed=42)
724-
boot_results = cs_boot.fit(
725-
data,
726-
outcome='outcome',
727-
unit='unit',
728-
time='time',
729-
first_treat='first_treat'
730-
)
731+
# Test with bootstrap - should drop invalid samples with warning
732+
cs_boot = CallawaySantAnna(n_bootstrap=100, seed=42)
733+
734+
with warnings.catch_warnings(record=True) as w:
735+
warnings.simplefilter("always")
736+
boot_results = cs_boot.fit(
737+
data,
738+
outcome='outcome',
739+
unit='unit',
740+
time='time',
741+
first_treat='first_treat'
742+
)
731743

732-
# Bootstrap should also produce finite results
744+
# ATT should be finite
733745
assert np.isfinite(boot_results.overall_att), "ATT should be finite"
746+
# Bootstrap SE based on valid samples - may be finite or NaN
734747
assert boot_results.bootstrap_results is not None, "Bootstrap results should exist"
735-
assert np.isfinite(boot_results.overall_se), "Bootstrap SE should be finite"
748+
assert np.isfinite(boot_results.overall_se) or np.isnan(boot_results.overall_se), \
749+
"Bootstrap SE should be finite or NaN (not inf)"
736750

737751
def test_near_collinear_covariates(self):
738752
"""Test that near-collinear covariates are handled gracefully."""

0 commit comments

Comments
 (0)