Skip to content

Commit cfb822e

Browse files
igerberclaude
andcommitted
Address PR #456 R6 review (1 P2 perf + 1 P3 docs)
P2 perf: remove duplicate cohort distance pass. Previously the event- study path computed cohort-by-unit distances twice on staggered panels: once in _compute_nearest_treated_distance_staggered for d_it (running min), then again in _compute_event_time_per_row to recover the per-row spillover-trigger onset. On large staggered panels this doubled the dominant spatial work. Fix: thread d_bar into _compute_nearest_treated_distance_staggered as an optional kwarg. When supplied, the cohort loop now ALSO computes trigger_onset_per_unit (the first cohort whose treated units fall within d_bar of unit i) and broadcasts it to rows. The helper's return is now a 4-tuple (d_it, row_unit, row_time, trigger_onset_or_None). _compute_event_time_per_row accepts an optional precomputed_trigger_onset_per_row that, when supplied (as fit() now does on the staggered event-study path), skips the redundant cohort loop. Falls back to inline computation for unit-test callers. Test callsites for _compute_nearest_treated_distance_staggered updated to handle the new 4-tuple via `d_it, row_unit, row_time, _trigger = ...`. P3 docs: llms-full.txt and api/spillover.rst now explicitly state that event_study=True requires horizon_max>=1 or None (horizon_max=0 is rejected, with redirect to event_study=False for the aggregate spec). The previous wording described horizon_max=0 as a meaningful collapsed design, which contradicted the new R5 rejection. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 59bdc1b commit cfb822e

4 files changed

Lines changed: 88 additions & 17 deletions

File tree

diff_diff/guides/llms-full.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ SpilloverDiD(
478478
alpha: float = 0.05,
479479
anticipation: int = 0,
480480
event_study: bool = False, # Wave C: per-event-time × ring decomposition (Butts Table 2)
481-
horizon_max: int | None = None, # Bin event-times outside [-H,+H] into endpoint pools (event-study mode)
481+
horizon_max: int | None = None, # Bin event-times outside [-H,+H] into endpoint pools (event-study mode); H>=1 or None — H=0 rejected (use event_study=False for aggregate spec)
482482
rank_deficient_action: str = "warn",
483483
)
484484
```

diff_diff/spillover.py

Lines changed: 79 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,8 @@ def _compute_nearest_treated_distance_staggered(
338338
coords: Tuple[str, str],
339339
metric: SpilloverMetric,
340340
first_treat_by_unit: Dict[Any, Any],
341-
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
341+
d_bar: Optional[float] = None,
342+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]:
342343
"""Return per-row nearest-treated distance for the staggered case.
343344
344345
For each (unit, period) observation, find the minimum distance to any
@@ -361,6 +362,13 @@ def _compute_nearest_treated_distance_staggered(
361362
first_treat_by_unit : dict
362363
Mapping from unit identifier to onset time (or ``np.inf`` for
363364
never-treated). Generated by :func:`_extract_treatment_onsets`.
365+
d_bar : float, optional
366+
When supplied, the function additionally computes the per-row
367+
**spillover-trigger onset** (earliest cohort onset whose treated
368+
units fall within ``d_bar`` of unit ``i``) reusing the cohort
369+
loop. Used by :func:`_compute_event_time_per_row` to avoid a
370+
duplicate cohort pass on the event-study path
371+
(PR #456 R6 performance fix).
364372
365373
Notes
366374
-----
@@ -377,6 +385,11 @@ def _compute_nearest_treated_distance_staggered(
377385
Aligned unit identifier per row (for downstream broadcasting).
378386
row_time : ndarray of shape (n_rows,)
379387
Aligned time identifier per row.
388+
trigger_onset_per_row : ndarray of shape (n_rows,) or None
389+
``None`` when ``d_bar`` is None. Otherwise: per-row earliest
390+
cohort onset whose treated units fall within ``d_bar`` of the
391+
row's unit, broadcast from per-unit. NaN for rows whose unit is
392+
never within ``d_bar`` of any cohort.
380393
"""
381394
unit_coords_df = (
382395
data[[unit, coords[0], coords[1]]].drop_duplicates(subset=[unit]).set_index(unit)
@@ -389,13 +402,16 @@ def _compute_nearest_treated_distance_staggered(
389402
row_time = np.asarray(data[time].values)
390403
n_rows = len(row_unit)
391404
d_it = np.full(n_rows, np.inf, dtype=np.float64)
405+
trigger_onset_per_unit_pos: Optional[np.ndarray] = (
406+
np.full(len(unit_index), np.nan, dtype=np.float64) if d_bar is not None else None
407+
)
392408

393409
# Determine the cohort onset times that exist in the data (excluding never-treated).
394410
unique_onsets = sorted({ft for ft in first_treat_by_unit.values() if np.isfinite(ft)})
395411
if not unique_onsets:
396412
# Degenerate: no treated units. Caller should have rejected this
397413
# in `_validate_spillover_inputs`, but defensively return inf.
398-
return d_it, row_unit, row_time
414+
return d_it, row_unit, row_time, None
399415

400416
# Row's unit position. Invariant across cohort iterations — compute
401417
# once outside the loop.
@@ -426,7 +442,25 @@ def _compute_nearest_treated_distance_staggered(
426442
update_mask = affected_rows & (row_cohort_dist < d_it)
427443
d_it[update_mask] = row_cohort_dist[update_mask]
428444

429-
return d_it, row_unit, row_time
445+
# Reuse this same cohort distance computation for the per-unit
446+
# spillover-trigger onset when d_bar is supplied. The trigger is
447+
# the FIRST cohort whose treated units fall within d_bar of unit
448+
# i — once locked it persists for later cohort iterations. Using
449+
# cumulative-treated distances here is fine: if a unit is in
450+
# range of cohort c1, dists_to_cohort at onset=c1 already detects
451+
# it; later iterations with extra treated units only shrink the
452+
# distance, never grow it back above d_bar.
453+
if trigger_onset_per_unit_pos is not None:
454+
in_range_for_cohort = dists_to_cohort <= d_bar
455+
not_yet_triggered = np.isnan(trigger_onset_per_unit_pos)
456+
trigger_onset_per_unit_pos[in_range_for_cohort & not_yet_triggered] = onset
457+
458+
# Broadcast per-unit trigger to rows when computed.
459+
if trigger_onset_per_unit_pos is not None:
460+
trigger_onset_per_row = trigger_onset_per_unit_pos[row_pos]
461+
else:
462+
trigger_onset_per_row = None
463+
return d_it, row_unit, row_time, trigger_onset_per_row
430464

431465

432466
def _compute_event_time_per_row(
@@ -439,6 +473,7 @@ def _compute_event_time_per_row(
439473
coords: Tuple[str, str],
440474
metric: SpilloverMetric,
441475
d_bar: float,
476+
precomputed_trigger_onset_per_row: Optional[np.ndarray] = None,
442477
) -> Tuple[np.ndarray, np.ndarray]:
443478
"""Compute two event-time clocks per row for Wave C event-study mode.
444479
@@ -475,6 +510,16 @@ def _compute_event_time_per_row(
475510
-------
476511
K_direct : ndarray of shape (n_rows,), float64 with NaN where undefined.
477512
K_spill : ndarray of shape (n_rows,), float64 with NaN where undefined.
513+
514+
Notes
515+
-----
516+
PR #456 R6 performance fix: when ``precomputed_trigger_onset_per_row``
517+
is supplied (as :func:`_compute_nearest_treated_distance_staggered`
518+
now optionally returns when called with ``d_bar=...``), the cohort
519+
loop is skipped — K_spill is derived directly from the precomputed
520+
trigger. The fallback (compute trigger inline) is kept for unit-test
521+
callers and other code paths that don't have access to the staggered
522+
distance helper's output.
478523
"""
479524
n_rows = len(row_unit)
480525
row_time_arr = np.asarray(row_time, dtype=np.float64)
@@ -485,8 +530,19 @@ def _compute_event_time_per_row(
485530
direct_defined = np.isfinite(own_onsets)
486531
K_direct[direct_defined] = row_time_arr[direct_defined] - own_onsets[direct_defined]
487532

488-
# trigger_onset[i] = first effective_onset among cohorts whose treated
489-
# units have d(i, treated_in_cohort) <= d_bar.
533+
if precomputed_trigger_onset_per_row is not None:
534+
# Fast path: reuse trigger onsets already computed by the staggered
535+
# distance helper. Avoids a duplicate cohort loop.
536+
row_trigger = np.asarray(precomputed_trigger_onset_per_row, dtype=np.float64)
537+
K_spill = np.full(n_rows, np.nan, dtype=np.float64)
538+
triggered = np.isfinite(row_trigger)
539+
post_trigger = triggered & (row_time_arr >= row_trigger)
540+
K_spill[post_trigger] = row_time_arr[post_trigger] - row_trigger[post_trigger]
541+
return K_direct, K_spill
542+
543+
# Fallback path (test callers, etc.): compute trigger inline via own
544+
# cohort loop. trigger_onset[i] = first effective_onset among cohorts
545+
# whose treated units have d(i, treated_in_cohort) <= d_bar.
490546
unit_coords_df = (
491547
data[[unit, coords[0], coords[1]]].drop_duplicates(subset=[unit]).set_index(unit)
492548
)
@@ -2167,14 +2223,21 @@ def fit(
21672223
unit_coords_for_validation.shape[0],
21682224
)
21692225

2226+
# Capture the spillover-trigger onsets alongside d_it on the
2227+
# staggered path so the event-study branch below can reuse them
2228+
# without redoing the cohort distance loop (PR #456 R6 perf fix).
2229+
trigger_onset_per_row_cached: Optional[np.ndarray] = None
21702230
if is_staggered:
2171-
d_it_per_row, _, _ = _compute_nearest_treated_distance_staggered(
2172-
data,
2173-
unit=unit,
2174-
time=time,
2175-
coords=self.conley_coords,
2176-
metric=self.conley_metric,
2177-
first_treat_by_unit=effective_onsets,
2231+
d_it_per_row, _, _, trigger_onset_per_row_cached = (
2232+
_compute_nearest_treated_distance_staggered(
2233+
data,
2234+
unit=unit,
2235+
time=time,
2236+
coords=self.conley_coords,
2237+
metric=self.conley_metric,
2238+
first_treat_by_unit=effective_onsets,
2239+
d_bar=self._effective_d_bar if self.event_study else None,
2240+
)
21782241
)
21792242
else:
21802243
# Non-staggered: single common onset. Build d_i per unit once,
@@ -2398,6 +2461,10 @@ def fit(
23982461
),
23992462
metric=self.conley_metric,
24002463
d_bar=self._effective_d_bar,
2464+
# PR #456 R6 perf fix: on the staggered path, reuse the
2465+
# trigger onsets computed during the d_it cohort loop
2466+
# instead of redoing the dense pairwise pass.
2467+
precomputed_trigger_onset_per_row=trigger_onset_per_row_cached,
24012468
)
24022469
# event_study=True without conley_coords requires fallback coords for
24032470
# ring-trigger computation. The validator already requires either

docs/api/spillover.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,11 @@ planned as follow-up enhancements:
195195
period ``-1 - anticipation`` (TwoStageDiD parity). ``horizon_max``
196196
bins event-times into endpoint pools (no row drop — divergence
197197
from TwoStageDiD's filtering semantic, intentional per
198-
``feedback_no_silent_failures``). Scalar ``att`` becomes a
198+
``feedback_no_silent_failures``). ``horizon_max`` must be ``>=1`` or
199+
``None`` under ``event_study=True``; ``horizon_max=0`` is rejected
200+
(the single bin ``k=0`` leaves no event-time pair to anchor the
201+
reference period — for a single aggregate effect, use
202+
``event_study=False`` instead). Scalar ``att`` becomes a
199203
sample-share-weighted average of post-treatment ``tau_k`` with SE
200204
from linear-combination inference on the post-treatment vcov block.
201205
Per-event-time SEs share the same Wave B Gardner-GMM caveat

tests/test_spillover.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ class TestComputeNearestTreatedDistanceStaggered:
359359

360360
def test_inf_pre_any_treatment(self, staggered_panel):
361361
df, ft = staggered_panel
362-
d_it, row_unit, row_time = _compute_nearest_treated_distance_staggered(
362+
d_it, row_unit, row_time, _trigger = _compute_nearest_treated_distance_staggered(
363363
df,
364364
unit="unit",
365365
time="time",
@@ -373,7 +373,7 @@ def test_inf_pre_any_treatment(self, staggered_panel):
373373

374374
def test_cohort_a_active_at_t1(self, staggered_panel):
375375
df, ft = staggered_panel
376-
d_it, row_unit, row_time = _compute_nearest_treated_distance_staggered(
376+
d_it, row_unit, row_time, _trigger = _compute_nearest_treated_distance_staggered(
377377
df,
378378
unit="unit",
379379
time="time",
@@ -394,7 +394,7 @@ def test_cohort_a_active_at_t1(self, staggered_panel):
394394

395395
def test_running_min_across_cohorts_at_t2(self, staggered_panel):
396396
df, ft = staggered_panel
397-
d_it, row_unit, row_time = _compute_nearest_treated_distance_staggered(
397+
d_it, row_unit, row_time, _trigger = _compute_nearest_treated_distance_staggered(
398398
df,
399399
unit="unit",
400400
time="time",

0 commit comments

Comments
 (0)