@@ -379,22 +379,17 @@ def _run_multiplier_bootstrap(
379379 control_weights @ control_inf
380380 )
381381
382- perturbations = self ._check_and_fix_nonfinite (
383- perturbations , f"bootstrap perturbations for ATT(g,t) { gt_pairs [j ]} "
384- )
382+ # Let non-finite values propagate - they will be handled at statistics computation
385383 bootstrap_atts_gt [:, j ] = original_atts [j ] + perturbations
386384
387385 # Vectorized overall ATT: matrix-vector multiply
388386 # Shape: (n_bootstrap,)
389- # Suppress RuntimeWarnings for edge cases
387+ # Suppress RuntimeWarnings for edge cases - non-finite values handled at statistics computation
390388 with np .errstate (divide = 'ignore' , invalid = 'ignore' , over = 'ignore' ):
391389 bootstrap_overall = bootstrap_atts_gt @ overall_weights
392390
393- bootstrap_overall = self ._check_and_fix_nonfinite (
394- bootstrap_overall , "bootstrap overall ATT aggregation"
395- )
396-
397391 # Vectorized event study aggregation
392+ # Non-finite values handled at statistics computation stage
398393 rel_periods : List [int ] = []
399394 bootstrap_event_study : Optional [Dict [int , np .ndarray ]] = None
400395 if event_study_info is not None :
@@ -409,11 +404,8 @@ def _run_multiplier_bootstrap(
409404 with np .errstate (divide = 'ignore' , invalid = 'ignore' , over = 'ignore' ):
410405 bootstrap_event_study [e ] = bootstrap_atts_gt [:, gt_indices ] @ weights
411406
412- bootstrap_event_study [e ] = self ._check_and_fix_nonfinite (
413- bootstrap_event_study [e ], f"bootstrap event study aggregation (e={ e } )"
414- )
415-
416407 # Vectorized group aggregation
408+ # Non-finite values handled at statistics computation stage
417409 group_list : List [Any ] = []
418410 bootstrap_group : Optional [Dict [Any , np .ndarray ]] = None
419411 if group_agg_info is not None :
@@ -427,26 +419,24 @@ def _run_multiplier_bootstrap(
427419 with np .errstate (divide = 'ignore' , invalid = 'ignore' , over = 'ignore' ):
428420 bootstrap_group [g ] = bootstrap_atts_gt [:, gt_indices ] @ weights
429421
430- bootstrap_group [g ] = self ._check_and_fix_nonfinite (
431- bootstrap_group [g ], f"bootstrap group aggregation (g={ g } )"
432- )
433-
434422 # Compute bootstrap statistics for ATT(g,t)
435423 gt_ses = {}
436424 gt_cis = {}
437425 gt_p_values = {}
438426
439427 for j , gt in enumerate (gt_pairs ):
440428 se , ci , p_value = self ._compute_effect_bootstrap_stats (
441- original_atts [j ], bootstrap_atts_gt [:, j ]
429+ original_atts [j ], bootstrap_atts_gt [:, j ],
430+ context = f"ATT(g={ gt [0 ]} , t={ gt [1 ]} )"
442431 )
443432 gt_ses [gt ] = se
444433 gt_cis [gt ] = ci
445434 gt_p_values [gt ] = p_value
446435
447436 # Compute bootstrap statistics for overall ATT
448437 overall_se , overall_ci , overall_p_value = self ._compute_effect_bootstrap_stats (
449- original_overall , bootstrap_overall
438+ original_overall , bootstrap_overall ,
439+ context = "overall ATT"
450440 )
451441
452442 # Compute bootstrap statistics for event study effects
@@ -461,7 +451,8 @@ def _run_multiplier_bootstrap(
461451
462452 for e in rel_periods :
463453 se , ci , p_value = self ._compute_effect_bootstrap_stats (
464- event_study_info [e ]['effect' ], bootstrap_event_study [e ]
454+ event_study_info [e ]['effect' ], bootstrap_event_study [e ],
455+ context = f"event study (e={ e } )"
465456 )
466457 event_study_ses [e ] = se
467458 event_study_cis [e ] = ci
@@ -479,7 +470,8 @@ def _run_multiplier_bootstrap(
479470
480471 for g in group_list :
481472 se , ci , p_value = self ._compute_effect_bootstrap_stats (
482- group_agg_info [g ]['effect' ], bootstrap_group [g ]
473+ group_agg_info [g ]['effect' ], bootstrap_group [g ],
474+ context = f"group effect (g={ g } )"
483475 )
484476 group_effect_ses [g ] = se
485477 group_effect_cis [g ] = ci
@@ -640,16 +632,23 @@ def _compute_effect_bootstrap_stats(
640632 self ,
641633 original_effect : float ,
642634 boot_dist : np .ndarray ,
635+ context : str = "bootstrap distribution" ,
643636 ) -> Tuple [float , Tuple [float , float ], float ]:
644637 """
645638 Compute bootstrap statistics for a single effect.
646639
640+ Non-finite bootstrap samples are dropped and a warning is issued if any
641+ are present. If too few valid samples remain (<50%), returns NaN for all
642+ statistics to signal invalid inference.
643+
647644 Parameters
648645 ----------
649646 original_effect : float
650647 Original point estimate.
651648 boot_dist : np.ndarray
652649 Bootstrap distribution of the effect.
650+ context : str, optional
651+ Description for warning messages, by default "bootstrap distribution".
653652
654653 Returns
655654 -------
@@ -660,35 +659,65 @@ def _compute_effect_bootstrap_stats(
660659 p_value : float
661660 Bootstrap p-value.
662661 """
663- se = float (np .std (boot_dist , ddof = 1 ))
664- ci = self ._compute_percentile_ci (boot_dist , self .alpha )
665- p_value = self ._compute_bootstrap_pvalue (original_effect , boot_dist )
662+ # Filter out non-finite values
663+ finite_mask = np .isfinite (boot_dist )
664+ n_valid = np .sum (finite_mask )
665+ n_total = len (boot_dist )
666+
667+ if n_valid < n_total :
668+ import warnings
669+ n_nonfinite = n_total - n_valid
670+ warnings .warn (
671+ f"Dropping { n_nonfinite } /{ n_total } non-finite bootstrap samples in { context } . "
672+ "This may occur with very small samples or extreme weights. "
673+ "Bootstrap estimates based on remaining valid samples." ,
674+ RuntimeWarning ,
675+ stacklevel = 3
676+ )
677+
678+ # Check if we have enough valid samples
679+ if n_valid < n_total * 0.5 :
680+ import warnings
681+ warnings .warn (
682+ f"Too few valid bootstrap samples ({ n_valid } /{ n_total } ) in { context } . "
683+ "Returning NaN for SE/CI/p-value to signal invalid inference." ,
684+ RuntimeWarning ,
685+ stacklevel = 3
686+ )
687+ return np .nan , (np .nan , np .nan ), np .nan
688+
689+ # Use only valid samples
690+ valid_dist = boot_dist [finite_mask ]
691+
692+ se = float (np .std (valid_dist , ddof = 1 ))
693+ ci = self ._compute_percentile_ci (valid_dist , self .alpha )
694+ p_value = self ._compute_bootstrap_pvalue (original_effect , valid_dist )
666695 return se , ci , p_value
667696
668- def _check_and_fix_nonfinite (self , arr : np .ndarray , context : str ) -> np .ndarray :
669- """Check for non-finite values and warn if found .
697+ def _mask_nonfinite_samples (self , arr : np .ndarray , context : str ) -> np .ndarray :
698+ """Return boolean mask of finite samples, warning if any dropped .
670699
671700 Parameters
672701 ----------
673702 arr : np.ndarray
674- Array to check.
703+ Array to check (1D bootstrap distribution) .
675704 context : str
676705 Description of where this check is happening (for warning message).
677706
678707 Returns
679708 -------
680709 np.ndarray
681- Array with non-finite values replaced by 0.0 .
710+ Boolean mask where True indicates finite (valid) samples .
682711 """
683- if not np .all (np .isfinite (arr )):
712+ finite_mask = np .isfinite (arr )
713+ if not np .all (finite_mask ):
684714 import warnings
685- n_nonfinite = np .sum (~ np . isfinite ( arr ) )
715+ n_nonfinite = np .sum (~ finite_mask )
686716 warnings .warn (
687- f"Non-finite values ( { n_nonfinite } /{ arr .size } ) in { context } . "
717+ f"Dropping { n_nonfinite } /{ arr .size } non-finite bootstrap samples in { context } . "
688718 "This may occur with very small samples or extreme weights. "
689- "Bootstrap estimates may be unreliable ." ,
719+ "Bootstrap estimates based on remaining valid samples ." ,
690720 RuntimeWarning ,
691721 stacklevel = 3
692722 )
693- return np .where (np .isfinite (arr ), arr , 0.0 )
694- return arr
723+ return finite_mask
0 commit comments