77
88import warnings
99from dataclasses import dataclass , field
10- from typing import Any , Dict , List , Optional , Tuple
10+ from typing import Any , Dict , List , Optional , Set , Tuple
1111
1212import numpy as np
1313import pandas as pd
@@ -1117,7 +1117,7 @@ def fit(
11171117
11181118 # Compute overall ATT (simple aggregation)
11191119 overall_att , overall_se = self ._aggregate_simple (
1120- group_time_effects , influence_func_info , df , unit
1120+ group_time_effects , influence_func_info , df , unit , precomputed
11211121 )
11221122 overall_t = overall_att / overall_se if overall_se > 0 else 0.0
11231123 overall_p = compute_p_value (overall_t )
@@ -1471,6 +1471,7 @@ def _aggregate_simple(
14711471 influence_func_info : Dict ,
14721472 df : pd .DataFrame ,
14731473 unit : str ,
1474+ precomputed : Optional [PrecomputedData ] = None ,
14741475 ) -> Tuple [float , float ]:
14751476 """
14761477 Compute simple weighted average of ATT(g,t).
@@ -1508,7 +1509,7 @@ def _aggregate_simple(
15081509 # Compute SE using influence function aggregation with wif adjustment
15091510 overall_se = self ._compute_aggregated_se_with_wif (
15101511 gt_pairs , weights_norm , effects , groups_for_gt ,
1511- influence_func_info , df , unit
1512+ influence_func_info , df , unit , precomputed
15121513 )
15131514
15141515 return overall_att , overall_se
@@ -1582,6 +1583,7 @@ def _compute_aggregated_se_with_wif(
15821583 influence_func_info : Dict ,
15831584 df : pd .DataFrame ,
15841585 unit : str ,
1586+ precomputed : Optional [PrecomputedData ] = None ,
15851587 ) -> float :
15861588 """
15871589 Compute SE with weight influence function (wif) adjustment.
@@ -1605,22 +1607,23 @@ def _compute_aggregated_se_with_wif(
16051607 return 0.0
16061608
16071609 # Build unit index mapping
1608- all_units = set ()
1610+ all_units_set : Set [ Any ] = set ()
16091611 for (g , t ) in gt_pairs :
16101612 if (g , t ) in influence_func_info :
16111613 info = influence_func_info [(g , t )]
1612- all_units .update (info ['treated_units' ])
1613- all_units .update (info ['control_units' ])
1614+ all_units_set .update (info ['treated_units' ])
1615+ all_units_set .update (info ['control_units' ])
16141616
1615- if not all_units :
1617+ if not all_units_set :
16161618 return 0.0
16171619
1618- all_units = sorted (all_units )
1620+ all_units = sorted (all_units_set )
16191621 n_units = len (all_units )
16201622 unit_to_idx = {u : i for i , u in enumerate (all_units )}
16211623
16221624 # Get unique groups and their information
16231625 unique_groups = sorted (set (groups_for_gt ))
1626+ unique_groups_set = set (unique_groups )
16241627 group_to_idx = {g : i for i , g in enumerate (unique_groups )}
16251628
16261629 # Compute group-level probabilities matching R's formula:
@@ -1639,6 +1642,10 @@ def _compute_aggregated_se_with_wif(
16391642 pg_keepers = np .array ([pg_by_group [group_to_idx [g ]] for g in groups_for_gt ])
16401643 sum_pg_keepers = np .sum (pg_keepers )
16411644
1645+ # Guard against zero weights (no keepers = no variance)
1646+ if sum_pg_keepers == 0 :
1647+ return 0.0
1648+
16421649 # Standard aggregated influence (without wif)
16431650 psi_standard = np .zeros (n_units )
16441651
@@ -1649,62 +1656,66 @@ def _compute_aggregated_se_with_wif(
16491656 info = influence_func_info [(g , t )]
16501657 w = weights [j ]
16511658
1652- for i , uid in enumerate (info ['treated_units' ]):
1653- idx = unit_to_idx [uid ]
1654- psi_standard [idx ] += w * info ['treated_inf' ][i ]
1655-
1656- for i , uid in enumerate (info ['control_units' ]):
1657- idx = unit_to_idx [uid ]
1658- psi_standard [idx ] += w * info ['control_inf' ][i ]
1659-
1660- # Build unit-group membership indicator
1661- unit_groups = {}
1662- for uid in all_units :
1663- unit_first_treat = df [df [unit ] == uid ]['first_treat' ].iloc [0 ]
1664- if unit_first_treat in unique_groups :
1665- unit_groups [uid ] = unit_first_treat
1666- else :
1667- unit_groups [uid ] = None # Never-treated or other
1668-
1669- # Compute wif using R's exact formula (iterate over keepers, not groups)
1670- # R's wif function:
1659+ # Vectorized influence function aggregation for treated units
1660+ treated_indices = np .array ([unit_to_idx [uid ] for uid in info ['treated_units' ]])
1661+ if len (treated_indices ) > 0 :
1662+ np .add .at (psi_standard , treated_indices , w * info ['treated_inf' ])
1663+
1664+ # Vectorized influence function aggregation for control units
1665+ control_indices = np .array ([unit_to_idx [uid ] for uid in info ['control_units' ]])
1666+ if len (control_indices ) > 0 :
1667+ np .add .at (psi_standard , control_indices , w * info ['control_inf' ])
1668+
1669+ # Build unit-group array using precomputed data if available
1670+ # This is O(n_units) instead of O(n_units × n_obs) DataFrame lookups
1671+ if precomputed is not None :
1672+ # Use precomputed cohort mapping
1673+ precomputed_units = precomputed ['all_units' ]
1674+ precomputed_cohorts = precomputed ['unit_cohorts' ]
1675+ precomputed_unit_to_idx = precomputed ['unit_to_idx' ]
1676+
1677+ # Build unit_groups_array for the units in this SE computation
1678+ # A value of -1 indicates never-treated or other (not in unique_groups)
1679+ unit_groups_array = np .full (n_units , - 1 , dtype = np .float64 )
1680+ for i , uid in enumerate (all_units ):
1681+ if uid in precomputed_unit_to_idx :
1682+ cohort = precomputed_cohorts [precomputed_unit_to_idx [uid ]]
1683+ if cohort in unique_groups_set :
1684+ unit_groups_array [i ] = cohort
1685+ else :
1686+ # Fallback: build from DataFrame (slow path for backward compatibility)
1687+ unit_groups_array = np .full (n_units , - 1 , dtype = np .float64 )
1688+ for i , uid in enumerate (all_units ):
1689+ unit_first_treat = df [df [unit ] == uid ]['first_treat' ].iloc [0 ]
1690+ if unit_first_treat in unique_groups_set :
1691+ unit_groups_array [i ] = unit_first_treat
1692+
1693+ # Vectorized WIF computation
1694+ # R's wif formula:
16711695 # if1[i,k] = (indicator(G_i == group_k) - pg[k]) / sum(pg[keepers])
16721696 # if2[i,k] = indicator_sum[i] * pg[k] / sum(pg[keepers])^2
16731697 # wif[i,k] = if1[i,k] - if2[i,k]
1674- #
1675- # Then: wif_contrib[i] = sum_k(wif[i,k] * att[k])
1698+ # wif_contrib[i] = sum_k(wif[i,k] * att[k])
16761699
1677- n_keepers = len (gt_pairs )
1678- wif_contrib = np .zeros (n_units )
1700+ # Build indicator matrix: (n_units, n_keepers)
1701+ # indicator_matrix[i, k] = 1.0 if unit i belongs to group for keeper k
1702+ groups_for_gt_array = np .array (groups_for_gt )
1703+ indicator_matrix = (unit_groups_array [:, np .newaxis ] == groups_for_gt_array [np .newaxis , :]).astype (np .float64 )
16791704
1680- # Pre-compute indicator_sum for each unit
1705+ # Vectorized indicator_sum: sum over keepers
16811706 # indicator_sum[i] = sum_k(indicator(G_i == group_k) - pg[k])
1682- indicator_sum = np .zeros (n_units )
1683- for j , g in enumerate (groups_for_gt ):
1684- pg_k = pg_keepers [j ]
1685- for uid in all_units :
1686- i = unit_to_idx [uid ]
1687- unit_g = unit_groups [uid ]
1688- indicator = 1.0 if unit_g == g else 0.0
1689- indicator_sum [i ] += (indicator - pg_k )
1690-
1691- # Compute wif contribution for each keeper
1692- for j , (g , t ) in enumerate (gt_pairs ):
1693- pg_k = pg_keepers [j ]
1694- att_k = effects [j ]
1695-
1696- for uid in all_units :
1697- i = unit_to_idx [uid ]
1698- unit_g = unit_groups [uid ]
1699- indicator = 1.0 if unit_g == g else 0.0
1700-
1701- # R's formula for wif
1702- if1_ik = (indicator - pg_k ) / sum_pg_keepers
1703- if2_ik = indicator_sum [i ] * pg_k / (sum_pg_keepers ** 2 )
1704- wif_ik = if1_ik - if2_ik
1705-
1706- # Add contribution: wif[i,k] * att[k]
1707- wif_contrib [i ] += wif_ik * att_k
1707+ indicator_sum = np .sum (indicator_matrix - pg_keepers , axis = 1 )
1708+
1709+ # Vectorized wif matrix computation
1710+ # if1_matrix[i,k] = (indicator[i,k] - pg[k]) / sum_pg
1711+ if1_matrix = (indicator_matrix - pg_keepers ) / sum_pg_keepers
1712+ # if2_matrix[i,k] = indicator_sum[i] * pg[k] / sum_pg^2
1713+ if2_matrix = np .outer (indicator_sum , pg_keepers ) / (sum_pg_keepers ** 2 )
1714+ wif_matrix = if1_matrix - if2_matrix
1715+
1716+ # Single matrix-vector multiply for all contributions
1717+ # wif_contrib[i] = sum_k(wif[i,k] * att[k])
1718+ wif_contrib = wif_matrix @ effects
17081719
17091720 # Scale by 1/n_units to match R's getSE formula: sqrt(mean(IF^2)/n)
17101721 psi_wif = wif_contrib / n_units
0 commit comments