@@ -338,7 +338,8 @@ def _compute_nearest_treated_distance_staggered(
338338 coords : Tuple [str , str ],
339339 metric : SpilloverMetric ,
340340 first_treat_by_unit : Dict [Any , Any ],
341- ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
341+ d_bar : Optional [float ] = None ,
342+ ) -> Tuple [np .ndarray , np .ndarray , np .ndarray , Optional [np .ndarray ]]:
342343 """Return per-row nearest-treated distance for the staggered case.
343344
344345 For each (unit, period) observation, find the minimum distance to any
@@ -361,6 +362,13 @@ def _compute_nearest_treated_distance_staggered(
361362 first_treat_by_unit : dict
362363 Mapping from unit identifier to onset time (or ``np.inf`` for
363364 never-treated). Generated by :func:`_extract_treatment_onsets`.
365+ d_bar : float, optional
366+ When supplied, the function additionally computes the per-row
367+ **spillover-trigger onset** (earliest cohort onset whose treated
368+ units fall within ``d_bar`` of unit ``i``) reusing the cohort
369+ loop. Used by :func:`_compute_event_time_per_row` to avoid a
370+ duplicate cohort pass on the event-study path
371+ (PR #456 R6 performance fix).
364372
365373 Notes
366374 -----
@@ -377,6 +385,11 @@ def _compute_nearest_treated_distance_staggered(
377385 Aligned unit identifier per row (for downstream broadcasting).
378386 row_time : ndarray of shape (n_rows,)
379387 Aligned time identifier per row.
388+ trigger_onset_per_row : ndarray of shape (n_rows,) or None
389+ ``None`` when ``d_bar`` is None or no finite onset exists (the
390+ degenerate early return). Otherwise: per-row earliest cohort onset
391+ whose treated units fall within ``d_bar`` of the row's unit,
392+ broadcast from per-unit. NaN if never within ``d_bar`` of any cohort.
380393 """
381394 unit_coords_df = (
382395 data [[unit , coords [0 ], coords [1 ]]].drop_duplicates (subset = [unit ]).set_index (unit )
@@ -389,13 +402,16 @@ def _compute_nearest_treated_distance_staggered(
389402 row_time = np .asarray (data [time ].values )
390403 n_rows = len (row_unit )
391404 d_it = np .full (n_rows , np .inf , dtype = np .float64 )
405+ trigger_onset_per_unit_pos : Optional [np .ndarray ] = (
406+ np .full (len (unit_index ), np .nan , dtype = np .float64 ) if d_bar is not None else None
407+ )
392408
393409 # Determine the cohort onset times that exist in the data (excluding never-treated).
394410 unique_onsets = sorted ({ft for ft in first_treat_by_unit .values () if np .isfinite (ft )})
395411 if not unique_onsets :
396412 # Degenerate: no treated units. Caller should have rejected this
397413 # in `_validate_spillover_inputs`, but defensively return inf.
398- return d_it , row_unit , row_time
414+ return d_it , row_unit , row_time , None
399415
400416 # Row's unit position. Invariant across cohort iterations — compute
401417 # once outside the loop.
@@ -426,7 +442,25 @@ def _compute_nearest_treated_distance_staggered(
426442 update_mask = affected_rows & (row_cohort_dist < d_it )
427443 d_it [update_mask ] = row_cohort_dist [update_mask ]
428444
429- return d_it , row_unit , row_time
445+ # Reuse this same cohort distance computation for the per-unit
446+ # spillover-trigger onset when d_bar is supplied. The trigger is
447+ # the FIRST cohort whose treated units fall within d_bar of unit
448+ # i — once locked it persists for later cohort iterations. Using
449+ # cumulative-treated distances here is fine: if a unit is in
450+ # range of cohort c1, dists_to_cohort at onset=c1 already detects
451+ # it; later iterations with extra treated units only shrink the
452+ # distance, never grow it back above d_bar.
453+ if trigger_onset_per_unit_pos is not None :
454+ in_range_for_cohort = dists_to_cohort <= d_bar
455+ not_yet_triggered = np .isnan (trigger_onset_per_unit_pos )
456+ trigger_onset_per_unit_pos [in_range_for_cohort & not_yet_triggered ] = onset
457+
458+ # Broadcast per-unit trigger to rows when computed.
459+ if trigger_onset_per_unit_pos is not None :
460+ trigger_onset_per_row = trigger_onset_per_unit_pos [row_pos ]
461+ else :
462+ trigger_onset_per_row = None
463+ return d_it , row_unit , row_time , trigger_onset_per_row
430464
431465
432466def _compute_event_time_per_row (
@@ -439,6 +473,7 @@ def _compute_event_time_per_row(
439473 coords : Tuple [str , str ],
440474 metric : SpilloverMetric ,
441475 d_bar : float ,
476+ precomputed_trigger_onset_per_row : Optional [np .ndarray ] = None ,
442477) -> Tuple [np .ndarray , np .ndarray ]:
443478 """Compute two event-time clocks per row for Wave C event-study mode.
444479
@@ -475,6 +510,16 @@ def _compute_event_time_per_row(
475510 -------
476511 K_direct : ndarray of shape (n_rows,), float64 with NaN where undefined.
477512 K_spill : ndarray of shape (n_rows,), float64 with NaN where undefined.
513+
514+ Notes
515+ -----
516+ PR #456 R6 performance fix: when ``precomputed_trigger_onset_per_row``
517+ is supplied (as :func:`_compute_nearest_treated_distance_staggered`
518+ now optionally returns when called with ``d_bar=...``), the cohort
519+ loop is skipped — K_spill is derived directly from the precomputed
520+ trigger. The fallback (compute trigger inline) is kept for unit-test
521+ callers and other code paths that don't have access to the staggered
522+ distance helper's output.
478523 """
479524 n_rows = len (row_unit )
480525 row_time_arr = np .asarray (row_time , dtype = np .float64 )
@@ -485,8 +530,19 @@ def _compute_event_time_per_row(
485530 direct_defined = np .isfinite (own_onsets )
486531 K_direct [direct_defined ] = row_time_arr [direct_defined ] - own_onsets [direct_defined ]
487532
488- # trigger_onset[i] = first effective_onset among cohorts whose treated
489- # units have d(i, treated_in_cohort) <= d_bar.
533+ if precomputed_trigger_onset_per_row is not None :
534+ # Fast path: reuse trigger onsets already computed by the staggered
535+ # distance helper. Avoids a duplicate cohort loop.
536+ row_trigger = np .asarray (precomputed_trigger_onset_per_row , dtype = np .float64 )
537+ K_spill = np .full (n_rows , np .nan , dtype = np .float64 )
538+ triggered = np .isfinite (row_trigger )
539+ post_trigger = triggered & (row_time_arr >= row_trigger )
540+ K_spill [post_trigger ] = row_time_arr [post_trigger ] - row_trigger [post_trigger ]
541+ return K_direct , K_spill
542+
543+ # Fallback path (test callers, etc.): compute trigger inline via own
544+ # cohort loop. trigger_onset[i] = first effective_onset among cohorts
545+ # whose treated units have d(i, treated_in_cohort) <= d_bar.
490546 unit_coords_df = (
491547 data [[unit , coords [0 ], coords [1 ]]].drop_duplicates (subset = [unit ]).set_index (unit )
492548 )
@@ -2167,14 +2223,21 @@ def fit(
21672223 unit_coords_for_validation .shape [0 ],
21682224 )
21692225
2226+ # Capture the spillover-trigger onsets alongside d_it on the
2227+ # staggered path so the event-study branch below can reuse them
2228+ # without redoing the cohort distance loop (PR #456 R6 perf fix).
2229+ trigger_onset_per_row_cached : Optional [np .ndarray ] = None
21702230 if is_staggered :
2171- d_it_per_row , _ , _ = _compute_nearest_treated_distance_staggered (
2172- data ,
2173- unit = unit ,
2174- time = time ,
2175- coords = self .conley_coords ,
2176- metric = self .conley_metric ,
2177- first_treat_by_unit = effective_onsets ,
2231+ d_it_per_row , _ , _ , trigger_onset_per_row_cached = (
2232+ _compute_nearest_treated_distance_staggered (
2233+ data ,
2234+ unit = unit ,
2235+ time = time ,
2236+ coords = self .conley_coords ,
2237+ metric = self .conley_metric ,
2238+ first_treat_by_unit = effective_onsets ,
2239+ d_bar = self ._effective_d_bar if self .event_study else None ,
2240+ )
21782241 )
21792242 else :
21802243 # Non-staggered: single common onset. Build d_i per unit once,
@@ -2398,6 +2461,10 @@ def fit(
23982461 ),
23992462 metric = self .conley_metric ,
24002463 d_bar = self ._effective_d_bar ,
2464+ # PR #456 R6 perf fix: on the staggered path, reuse the
2465+ # trigger onsets computed during the d_it cohort loop
2466+ # instead of redoing the dense pairwise pass.
2467+ precomputed_trigger_onset_per_row = trigger_onset_per_row_cached ,
24012468 )
24022469 # event_study=True without conley_coords requires fallback coords for
24032470 # ring-trigger computation. The validator already requires either
0 commit comments