@@ -1479,29 +1479,36 @@ def _aggregate_simple(
14791479
14801480 Standard errors are computed using influence function aggregation,
14811481 which properly accounts for covariances across (g,t) pairs due to
1482- shared control units. This matches R's `did` package approach.
1482+ shared control units. This includes the wif (weight influence function)
1483+ adjustment from R's `did` package that accounts for uncertainty in
1484+ estimating the group-size weights.
14831485 """
14841486 effects = []
14851487 weights_list = []
14861488 gt_pairs = []
1489+ groups_for_gt = []
14871490
14881491 for (g , t ), data in group_time_effects .items ():
14891492 effects .append (data ['effect' ])
14901493 weights_list .append (data ['n_treated' ])
14911494 gt_pairs .append ((g , t ))
1495+ groups_for_gt .append (g )
14921496
14931497 effects = np .array (effects )
14941498 weights = np .array (weights_list , dtype = float )
1499+ groups_for_gt = np .array (groups_for_gt )
14951500
14961501 # Normalize weights
1497- weights = weights / np .sum (weights )
1502+ total_weight = np .sum (weights )
1503+ weights_norm = weights / total_weight
14981504
14991505 # Weighted average
1500- overall_att = np .sum (weights * effects )
1506+ overall_att = np .sum (weights_norm * effects )
15011507
1502- # Compute SE using influence function aggregation
1503- overall_se = self ._compute_aggregated_se (
1504- gt_pairs , weights , influence_func_info
1508+ # Compute SE using influence function aggregation with wif adjustment
1509+ overall_se = self ._compute_aggregated_se_with_wif (
1510+ gt_pairs , weights_norm , effects , groups_for_gt ,
1511+ influence_func_info , df , unit
15051512 )
15061513
15071514 return overall_att , overall_se
@@ -1566,6 +1573,162 @@ def _compute_aggregated_se(
15661573 variance = np .sum (psi_overall ** 2 )
15671574 return np .sqrt (variance )
15681575
1576+ def _compute_aggregated_se_with_wif (
1577+ self ,
1578+ gt_pairs : List [Tuple [Any , Any ]],
1579+ weights : np .ndarray ,
1580+ effects : np .ndarray ,
1581+ groups_for_gt : np .ndarray ,
1582+ influence_func_info : Dict ,
1583+ df : pd .DataFrame ,
1584+ unit : str ,
1585+ ) -> float :
1586+ """
1587+ Compute SE with weight influence function (wif) adjustment.
1588+
1589+ This matches R's `did` package approach for "simple" aggregation,
1590+ which accounts for uncertainty in estimating group-size weights.
1591+
1592+ The wif adjustment adds variance due to the fact that aggregation
1593+ weights w_g = n_g / N depend on estimated group sizes.
1594+
1595+ Formula:
1596+ agg_inf_i = Σ_gt w_gt × inf_i_gt + wif_i × ATT_gt
1597+ se = sqrt(Σ_i (agg_inf_i)²)
1598+
1599+ where wif_i captures how unit i influences the weight estimation.
1600+ """
1601+ if not influence_func_info :
1602+ return 0.0
1603+
1604+ # Build unit index mapping
1605+ all_units = set ()
1606+ for (g , t ) in gt_pairs :
1607+ if (g , t ) in influence_func_info :
1608+ info = influence_func_info [(g , t )]
1609+ all_units .update (info ['treated_units' ])
1610+ all_units .update (info ['control_units' ])
1611+
1612+ if not all_units :
1613+ return 0.0
1614+
1615+ all_units = sorted (all_units )
1616+ n_units = len (all_units )
1617+ unit_to_idx = {u : i for i , u in enumerate (all_units )}
1618+
1619+ # Get unique groups and their information
1620+ unique_groups = sorted (set (groups_for_gt ))
1621+ n_groups = len (unique_groups )
1622+ group_to_idx = {g : i for i , g in enumerate (unique_groups )}
1623+
1624+ # Compute group-level probabilities (proportion of treated in each group)
1625+ # pg[g] = n_g / N where N = total treated across all groups
1626+ group_sizes = {}
1627+ for g in unique_groups :
1628+ # Count unique treated units in this group
1629+ treated_in_g = df [df ['first_treat' ] == g ][unit ].nunique ()
1630+ group_sizes [g ] = treated_in_g
1631+
1632+ total_treated = sum (group_sizes .values ())
1633+ pg = np .array ([group_sizes [g ] / total_treated for g in unique_groups ])
1634+
1635+ # Standard aggregated influence (without wif)
1636+ psi_standard = np .zeros (n_units )
1637+
1638+ for j , (g , t ) in enumerate (gt_pairs ):
1639+ if (g , t ) not in influence_func_info :
1640+ continue
1641+
1642+ info = influence_func_info [(g , t )]
1643+ w = weights [j ]
1644+
1645+ for i , uid in enumerate (info ['treated_units' ]):
1646+ idx = unit_to_idx [uid ]
1647+ psi_standard [idx ] += w * info ['treated_inf' ][i ]
1648+
1649+ for i , uid in enumerate (info ['control_units' ]):
1650+ idx = unit_to_idx [uid ]
1651+ psi_standard [idx ] += w * info ['control_inf' ][i ]
1652+
1653+ # Compute wif adjustment
1654+ # wif captures the influence of each unit on the weight estimation
1655+ # For simple aggregation with group-size weights:
1656+ # wif_i = [I(G_i = g) - pg[g]] / sum(pg) for numerator effect
1657+ # - adjustment for denominator effect
1658+ #
1659+ # R's formula (computed at GROUP level, not (g,t) level):
1660+ # if1 = (1*(G == g) - pg) / sum(pg)
1661+ # if2 = rowSums(1*(G == g) - pg) * (pg / sum(pg)^2)
1662+ # wif = if1 - if2
1663+ #
1664+ # The wif matrix is then multiplied by group-level aggregated ATT
1665+
1666+ # Build unit-group membership indicator
1667+ unit_groups = {}
1668+ for uid in all_units :
1669+ unit_first_treat = df [df [unit ] == uid ]['first_treat' ].iloc [0 ]
1670+ if unit_first_treat in unique_groups :
1671+ unit_groups [uid ] = unit_first_treat
1672+ else :
1673+ unit_groups [uid ] = None # Never-treated or other
1674+
1675+ # Compute group-level aggregated ATT (average ATT for each group)
1676+ # This matches R's approach where wif is multiplied by group-level ATT
1677+ group_att = {}
1678+ group_weight_sum = {}
1679+ for j , (g , t ) in enumerate (gt_pairs ):
1680+ if g not in group_att :
1681+ group_att [g ] = 0.0
1682+ group_weight_sum [g ] = 0.0
1683+ # Weight within group is equal (simple average over time periods)
1684+ group_att [g ] += effects [j ]
1685+ group_weight_sum [g ] += 1
1686+
1687+ for g in unique_groups :
1688+ if g in group_att and group_weight_sum [g ] > 0 :
1689+ group_att [g ] /= group_weight_sum [g ]
1690+ else :
1691+ group_att [g ] = 0.0
1692+
1693+ # Compute wif contribution for each unit (at GROUP level)
1694+ psi_wif = np .zeros (n_units )
1695+
1696+ # For each GROUP (not each (g,t) pair)
1697+ for g in unique_groups :
1698+ g_idx = group_to_idx [g ]
1699+ att_g = group_att [g ]
1700+
1701+ for uid in all_units :
1702+ i = unit_to_idx [uid ]
1703+ unit_g = unit_groups [uid ]
1704+
1705+ # Indicator: 1 if unit belongs to group g, 0 otherwise
1706+ indicator = 1.0 if unit_g == g else 0.0
1707+
1708+ # wif_i for this group
1709+ # Formula: (indicator - pg[g]) - Σ_g' (indicator_g' - pg_g') × pg_g
1710+ wif_i = (indicator - pg [g_idx ])
1711+
1712+ # Denominator adjustment
1713+ denom_adj = 0.0
1714+ for g_prime in unique_groups :
1715+ g_prime_idx = group_to_idx [g_prime ]
1716+ ind_prime = 1.0 if unit_g == g_prime else 0.0
1717+ denom_adj += (ind_prime - pg [g_prime_idx ]) * pg [g_idx ]
1718+
1719+ wif_i = wif_i - denom_adj
1720+
1721+ # Scale by group ATT and add to wif contribution
1722+ # Scale by 1/n_units to match R's getSE formula: sqrt(mean(IF^2)/n)
1723+ psi_wif [i ] += wif_i * att_g / n_units
1724+
1725+ # Combine standard and wif terms
1726+ psi_total = psi_standard + psi_wif
1727+
1728+ # Compute variance
1729+ variance = np .sum (psi_total ** 2 )
1730+ return np .sqrt (variance )
1731+
15691732 def _aggregate_event_study (
15701733 self ,
15711734 group_time_effects : Dict ,
0 commit comments