Skip to content

Commit 01d2306

Browse files
authored
Merge pull request #65 from igerber/optimize-callaway-wif-performance
Optimize CallawaySantAnna WIF computation with vectorized NumPy operations
2 parents 3849196 + 1e686e6 commit 01d2306

2 files changed

Lines changed: 129 additions & 106 deletions

File tree

diff_diff/staggered.py

Lines changed: 69 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import warnings
99
from dataclasses import dataclass, field
10-
from typing import Any, Dict, List, Optional, Tuple
10+
from typing import Any, Dict, List, Optional, Set, Tuple
1111

1212
import numpy as np
1313
import pandas as pd
@@ -1117,7 +1117,7 @@ def fit(
11171117

11181118
# Compute overall ATT (simple aggregation)
11191119
overall_att, overall_se = self._aggregate_simple(
1120-
group_time_effects, influence_func_info, df, unit
1120+
group_time_effects, influence_func_info, df, unit, precomputed
11211121
)
11221122
overall_t = overall_att / overall_se if overall_se > 0 else 0.0
11231123
overall_p = compute_p_value(overall_t)
@@ -1471,6 +1471,7 @@ def _aggregate_simple(
14711471
influence_func_info: Dict,
14721472
df: pd.DataFrame,
14731473
unit: str,
1474+
precomputed: Optional[PrecomputedData] = None,
14741475
) -> Tuple[float, float]:
14751476
"""
14761477
Compute simple weighted average of ATT(g,t).
@@ -1508,7 +1509,7 @@ def _aggregate_simple(
15081509
# Compute SE using influence function aggregation with wif adjustment
15091510
overall_se = self._compute_aggregated_se_with_wif(
15101511
gt_pairs, weights_norm, effects, groups_for_gt,
1511-
influence_func_info, df, unit
1512+
influence_func_info, df, unit, precomputed
15121513
)
15131514

15141515
return overall_att, overall_se
@@ -1582,6 +1583,7 @@ def _compute_aggregated_se_with_wif(
15821583
influence_func_info: Dict,
15831584
df: pd.DataFrame,
15841585
unit: str,
1586+
precomputed: Optional[PrecomputedData] = None,
15851587
) -> float:
15861588
"""
15871589
Compute SE with weight influence function (wif) adjustment.
@@ -1605,22 +1607,23 @@ def _compute_aggregated_se_with_wif(
16051607
return 0.0
16061608

16071609
# Build unit index mapping
1608-
all_units = set()
1610+
all_units_set: Set[Any] = set()
16091611
for (g, t) in gt_pairs:
16101612
if (g, t) in influence_func_info:
16111613
info = influence_func_info[(g, t)]
1612-
all_units.update(info['treated_units'])
1613-
all_units.update(info['control_units'])
1614+
all_units_set.update(info['treated_units'])
1615+
all_units_set.update(info['control_units'])
16141616

1615-
if not all_units:
1617+
if not all_units_set:
16161618
return 0.0
16171619

1618-
all_units = sorted(all_units)
1620+
all_units = sorted(all_units_set)
16191621
n_units = len(all_units)
16201622
unit_to_idx = {u: i for i, u in enumerate(all_units)}
16211623

16221624
# Get unique groups and their information
16231625
unique_groups = sorted(set(groups_for_gt))
1626+
unique_groups_set = set(unique_groups)
16241627
group_to_idx = {g: i for i, g in enumerate(unique_groups)}
16251628

16261629
# Compute group-level probabilities matching R's formula:
@@ -1639,6 +1642,10 @@ def _compute_aggregated_se_with_wif(
16391642
pg_keepers = np.array([pg_by_group[group_to_idx[g]] for g in groups_for_gt])
16401643
sum_pg_keepers = np.sum(pg_keepers)
16411644

1645+
# Guard against zero weights (no keepers = no variance)
1646+
if sum_pg_keepers == 0:
1647+
return 0.0
1648+
16421649
# Standard aggregated influence (without wif)
16431650
psi_standard = np.zeros(n_units)
16441651

@@ -1649,62 +1656,66 @@ def _compute_aggregated_se_with_wif(
16491656
info = influence_func_info[(g, t)]
16501657
w = weights[j]
16511658

1652-
for i, uid in enumerate(info['treated_units']):
1653-
idx = unit_to_idx[uid]
1654-
psi_standard[idx] += w * info['treated_inf'][i]
1655-
1656-
for i, uid in enumerate(info['control_units']):
1657-
idx = unit_to_idx[uid]
1658-
psi_standard[idx] += w * info['control_inf'][i]
1659-
1660-
# Build unit-group membership indicator
1661-
unit_groups = {}
1662-
for uid in all_units:
1663-
unit_first_treat = df[df[unit] == uid]['first_treat'].iloc[0]
1664-
if unit_first_treat in unique_groups:
1665-
unit_groups[uid] = unit_first_treat
1666-
else:
1667-
unit_groups[uid] = None # Never-treated or other
1668-
1669-
# Compute wif using R's exact formula (iterate over keepers, not groups)
1670-
# R's wif function:
1659+
# Vectorized influence function aggregation for treated units
1660+
treated_indices = np.array([unit_to_idx[uid] for uid in info['treated_units']])
1661+
if len(treated_indices) > 0:
1662+
np.add.at(psi_standard, treated_indices, w * info['treated_inf'])
1663+
1664+
# Vectorized influence function aggregation for control units
1665+
control_indices = np.array([unit_to_idx[uid] for uid in info['control_units']])
1666+
if len(control_indices) > 0:
1667+
np.add.at(psi_standard, control_indices, w * info['control_inf'])
1668+
1669+
# Build unit-group array using precomputed data if available
1670+
# This is O(n_units) instead of O(n_units × n_obs) DataFrame lookups
1671+
if precomputed is not None:
1672+
# Use precomputed cohort mapping
1673+
precomputed_units = precomputed['all_units']
1674+
precomputed_cohorts = precomputed['unit_cohorts']
1675+
precomputed_unit_to_idx = precomputed['unit_to_idx']
1676+
1677+
# Build unit_groups_array for the units in this SE computation
1678+
# A value of -1 indicates never-treated or other (not in unique_groups)
1679+
unit_groups_array = np.full(n_units, -1, dtype=np.float64)
1680+
for i, uid in enumerate(all_units):
1681+
if uid in precomputed_unit_to_idx:
1682+
cohort = precomputed_cohorts[precomputed_unit_to_idx[uid]]
1683+
if cohort in unique_groups_set:
1684+
unit_groups_array[i] = cohort
1685+
else:
1686+
# Fallback: build from DataFrame (slow path for backward compatibility)
1687+
unit_groups_array = np.full(n_units, -1, dtype=np.float64)
1688+
for i, uid in enumerate(all_units):
1689+
unit_first_treat = df[df[unit] == uid]['first_treat'].iloc[0]
1690+
if unit_first_treat in unique_groups_set:
1691+
unit_groups_array[i] = unit_first_treat
1692+
1693+
# Vectorized WIF computation
1694+
# R's wif formula:
16711695
# if1[i,k] = (indicator(G_i == group_k) - pg[k]) / sum(pg[keepers])
16721696
# if2[i,k] = indicator_sum[i] * pg[k] / sum(pg[keepers])^2
16731697
# wif[i,k] = if1[i,k] - if2[i,k]
1674-
#
1675-
# Then: wif_contrib[i] = sum_k(wif[i,k] * att[k])
1698+
# wif_contrib[i] = sum_k(wif[i,k] * att[k])
16761699

1677-
n_keepers = len(gt_pairs)
1678-
wif_contrib = np.zeros(n_units)
1700+
# Build indicator matrix: (n_units, n_keepers)
1701+
# indicator_matrix[i, k] = 1.0 if unit i belongs to group for keeper k
1702+
groups_for_gt_array = np.array(groups_for_gt)
1703+
indicator_matrix = (unit_groups_array[:, np.newaxis] == groups_for_gt_array[np.newaxis, :]).astype(np.float64)
16791704

1680-
# Pre-compute indicator_sum for each unit
1705+
# Vectorized indicator_sum: sum over keepers
16811706
# indicator_sum[i] = sum_k(indicator(G_i == group_k) - pg[k])
1682-
indicator_sum = np.zeros(n_units)
1683-
for j, g in enumerate(groups_for_gt):
1684-
pg_k = pg_keepers[j]
1685-
for uid in all_units:
1686-
i = unit_to_idx[uid]
1687-
unit_g = unit_groups[uid]
1688-
indicator = 1.0 if unit_g == g else 0.0
1689-
indicator_sum[i] += (indicator - pg_k)
1690-
1691-
# Compute wif contribution for each keeper
1692-
for j, (g, t) in enumerate(gt_pairs):
1693-
pg_k = pg_keepers[j]
1694-
att_k = effects[j]
1695-
1696-
for uid in all_units:
1697-
i = unit_to_idx[uid]
1698-
unit_g = unit_groups[uid]
1699-
indicator = 1.0 if unit_g == g else 0.0
1700-
1701-
# R's formula for wif
1702-
if1_ik = (indicator - pg_k) / sum_pg_keepers
1703-
if2_ik = indicator_sum[i] * pg_k / (sum_pg_keepers ** 2)
1704-
wif_ik = if1_ik - if2_ik
1705-
1706-
# Add contribution: wif[i,k] * att[k]
1707-
wif_contrib[i] += wif_ik * att_k
1707+
indicator_sum = np.sum(indicator_matrix - pg_keepers, axis=1)
1708+
1709+
# Vectorized wif matrix computation
1710+
# if1_matrix[i,k] = (indicator[i,k] - pg[k]) / sum_pg
1711+
if1_matrix = (indicator_matrix - pg_keepers) / sum_pg_keepers
1712+
# if2_matrix[i,k] = indicator_sum[i] * pg[k] / sum_pg^2
1713+
if2_matrix = np.outer(indicator_sum, pg_keepers) / (sum_pg_keepers ** 2)
1714+
wif_matrix = if1_matrix - if2_matrix
1715+
1716+
# Single matrix-vector multiply for all contributions
1717+
# wif_contrib[i] = sum_k(wif[i,k] * att[k])
1718+
wif_contrib = wif_matrix @ effects
17081719

17091720
# Scale by 1/n_units to match R's getSE formula: sqrt(mean(IF^2)/n)
17101721
psi_wif = wif_contrib / n_units

docs/benchmarks.rst

Lines changed: 60 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ Summary Table
7676
- **PASS**
7777
* - CallawaySantAnna
7878
- < 1e-10
79-
- < 1%
79+
- 0.0%
8080
- Yes
8181
- **PASS**
8282
* - SyntheticDiD
@@ -171,27 +171,28 @@ Callaway-Sant'Anna Results
171171
- 2.519
172172
- < 1e-10
173173
* - SE
174-
- 0.062
175-
- 0.062
176174
- 0.063
177-
- 2.3%
175+
- 0.063
176+
- 0.063
177+
- 0.0%
178178
* - Time (s)
179-
- 0.005
180-
- 0.005
181-
- 0.071
182-
- **14x faster**
179+
- 0.007 ± 0.000
180+
- 0.007 ± 0.000
181+
- 0.070 ± 0.001
182+
- **10x faster**
183183

184-
**Validation**: PASS - Both point estimates and standard errors match R closely.
184+
**Validation**: PASS - Both point estimates and standard errors match R exactly.
185185

186186
**Key findings from investigation:**
187187

188188
1. **Individual ATT(g,t) effects match perfectly** (~1e-11 difference)
189189
2. **Never-treated coding**: R's ``did`` package requires ``first_treat=Inf``
190190
for never-treated units. diff-diff accepts ``first_treat=0``. The benchmark
191191
converts 0 to Inf for R compatibility.
192-
3. **Standard errors**: As of v1.5.0, analytical SEs use influence function
193-
aggregation (matching R's approach), resulting in < 3% SE difference across
194-
all scales. Both analytical and bootstrap inference now match R closely.
192+
3. **Standard errors**: As of v2.0.2, analytical SEs match R's ``did`` package
193+
exactly (0.0% difference). The weight influence function (wif) formula was
194+
corrected to match R's implementation, achieving numerical equivalence across
195+
all dataset scales.
195196

196197
Performance Comparison
197198
----------------------
@@ -270,37 +271,37 @@ Three-Way Performance Summary
270271
- R (s)
271272
- Python Pure (s)
272273
- Python Rust (s)
273-
- Rust/R
274+
- Pure/R
274275
- Rust/Pure
275276
* - small
276-
- 0.071
277-
- 0.005
278-
- 0.005
279-
- **14.1x**
277+
- 0.070
278+
- 0.007
279+
- 0.007
280+
- **10x**
280281
- 1.0x
281282
* - 1k
282283
- 0.114
283-
- 0.012
284-
- 0.012
285-
- **9.4x**
284+
- 0.013
285+
- 0.013
286+
- **9x**
286287
- 1.0x
287288
* - 5k
288-
- 0.341
289-
- 0.055
290-
- 0.056
291-
- **6.1x**
289+
- 0.345
290+
- 0.053
291+
- 0.051
292+
- **7x**
292293
- 1.0x
293294
* - 10k
294-
- 0.726
295-
- 0.156
296-
- 0.155
297-
- **4.7x**
295+
- 0.727
296+
- 0.134
297+
- 0.138
298+
- **5x**
298299
- 1.0x
299300
* - 20k
300-
- 1.464
301-
- 0.404
302-
- 0.411
303-
- **3.6x**
301+
- 1.490
302+
- 0.352
303+
- 0.358
304+
- **4x**
304305
- 1.0x
305306

306307
**SyntheticDiD Results:**
@@ -391,10 +392,10 @@ Dataset Sizes
391392
Key Observations
392393
~~~~~~~~~~~~~~~~
393394

394-
1. **diff-diff is dramatically faster than R**:
395+
1. **Performance varies by estimator and scale**:
395396

396-
- **BasicDiD/TWFE**: 2-18x faster than R
397-
- **CallawaySantAnna**: 4-14x faster than R
397+
- **BasicDiD/TWFE**: 2-18x faster than R at all scales
398+
- **CallawaySantAnna**: 4-10x faster than R at all scales (vectorized WIF computation)
398399
- **SyntheticDiD**: 565-2234x faster than R (R takes 24 minutes at 10k scale!)
399400

400401
2. **Rust backend benefit depends on the estimator**:
@@ -410,15 +411,20 @@ Key Observations
410411
- **Bootstrap inference**: May help with parallelized iterations
411412
- **BasicDiD/CallawaySantAnna**: Optional - pure Python is equally fast
412413

413-
4. **Scaling behavior**: Both Python implementations show excellent scaling.
414-
At 10K scale (500K observations for SyntheticDiD), Rust completes in
415-
~2.6 seconds vs ~20 seconds for pure Python vs ~24 minutes for R.
414+
4. **Scaling behavior**: Python implementations show excellent scaling behavior
415+
across all estimators. SyntheticDiD is 565x faster than R at 10k scale.
416+
CallawaySantAnna achieves **exact SE accuracy** (0.0% difference) while
417+
being 4-10x faster than R through vectorized NumPy operations.
416418

417419
5. **No Rust required for most use cases**: Users without Rust/maturin can
418420
install diff-diff and get full functionality with excellent performance.
419-
For BasicDiD and CallawaySantAnna, pure Python achieves the same speed as Rust.
420421
Only SyntheticDiD benefits significantly from the Rust backend.
421422

423+
6. **CallawaySantAnna accuracy and speed**: As of v2.0.3, CallawaySantAnna
424+
achieves both exact numerical accuracy (0.0% SE difference from R) AND
425+
superior performance (4-10x faster than R) through vectorized weight
426+
influence function (WIF) computation using NumPy matrix operations.
427+
422428
Performance Optimization Details
423429
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
424430

@@ -436,7 +442,12 @@ The performance improvements come from:
436442
4. **Vectorized bootstrap** (CallawaySantAnna): Matrix operations instead of
437443
nested loops, batch weight generation
438444

439-
5. **Optional Rust backend** (v2.0.0): PyO3-based Rust extension for compute-intensive
445+
5. **Vectorized WIF computation** (CallawaySantAnna, v2.0.3): Weight influence
446+
function computation uses NumPy matrix operations instead of O(n_units × n_keepers)
447+
nested loops. The indicator matrix, if1/if2 matrices, and wif contribution are
448+
computed using broadcasting and matrix multiplication: ``wif_contrib = wif_matrix @ effects``
449+
450+
6. **Optional Rust backend** (v2.0.0): PyO3-based Rust extension for compute-intensive
440451
operations (OLS, robust variance, bootstrap weights, simplex projection)
441452

442453
Why is diff-diff Fast?
@@ -496,12 +507,13 @@ Results Comparison
496507
1. **Point estimates match exactly**: The overall ATT of -0.039951 is identical
497508
between diff-diff and R's ``did`` package, validating the core estimation logic.
498509

499-
2. **Standard errors match**: As of v1.5.0, analytical SEs use influence function
500-
aggregation (matching R's approach), resulting in < 1% difference. Both point
501-
estimates and standard errors now match R's ``did`` package.
510+
2. **Standard errors match exactly**: As of v2.0.2, analytical SEs use the corrected
511+
weight influence function formula, achieving 0.0% difference from R's ``did``
512+
package. Both point estimates and standard errors are numerically equivalent.
502513

503-
3. **Performance**: diff-diff is ~14x faster than R on this real-world dataset,
504-
consistent with the synthetic data benchmarks at small scale.
514+
3. **Performance**: diff-diff is ~10x faster than R on this real-world dataset
515+
at small scale. Performance scales differently at larger sizes (see performance
516+
tables above).
505517

506518
This validation on real-world data with known published results confirms that
507519
diff-diff produces correct estimates that match the reference R implementation.
@@ -576,9 +588,9 @@ When to Trust Results
576588
match R closely. Use ``variance_method="placebo"`` (default) to match R's
577589
inference. Results are fully validated.
578590

579-
- **CallawaySantAnna**: Group-time effects (ATT(g,t)) are reliable. Overall
580-
ATT aggregation may differ from R due to weighting choices. When comparing
581-
to R ``did`` package, verify aggregation settings match.
591+
- **CallawaySantAnna**: Both group-time effects (ATT(g,t)) and overall ATT
592+
aggregation match R exactly. Standard errors are numerically equivalent
593+
(0.0% difference) as of v2.0.2.
582594

583595
Known Differences
584596
~~~~~~~~~~~~~~~~~

0 commit comments

Comments
 (0)