@@ -1592,11 +1592,14 @@ def _compute_aggregated_se_with_wif(
15921592 The wif adjustment adds variance due to the fact that aggregation
15931593 weights w_g = n_g / N depend on estimated group sizes.
15941594
1595- Formula:
1596- agg_inf_i = Σ_gt w_gt × inf_i_gt + wif_i × ATT_gt
1597- se = sqrt(Σ_i (agg_inf_i)²)
1598-
1599- where wif_i captures how unit i influences the weight estimation.
1595+ Formula (matching R's did::aggte):
1596+ agg_inf_i = Σ_k w_k × inf_i_k + wif_i × ATT_k
1597+ se = sqrt(mean(agg_inf^2) / n)
1598+
1599+ where:
1600+ - k indexes "keepers" (post-treatment (g,t) pairs)
1601+ - w_k = pg[k] / sum(pg[keepers]) where pg = n_g / n_all
1602+ - wif captures how unit i influences the weight estimation
16001603 """
16011604 if not influence_func_info :
16021605 return 0.0
@@ -1618,19 +1621,23 @@ def _compute_aggregated_se_with_wif(
16181621
16191622 # Get unique groups and their information
16201623 unique_groups = sorted (set (groups_for_gt ))
1621- n_groups = len (unique_groups )
16221624 group_to_idx = {g : i for i , g in enumerate (unique_groups )}
16231625
1624- # Compute group-level probabilities (proportion of treated in each group)
1625- # pg[g] = n_g / N where N = total treated across all groups
1626+ # Compute group-level probabilities matching R's formula:
1627+ # pg[g] = n_g / n_all (fraction of ALL units in group g)
1628+ # This differs from our old formula which used n_g / total_treated
16261629 group_sizes = {}
16271630 for g in unique_groups :
1628- # Count unique treated units in this group
16291631 treated_in_g = df [df ['first_treat' ] == g ][unit ].nunique ()
16301632 group_sizes [g ] = treated_in_g
16311633
1632- total_treated = sum (group_sizes .values ())
1633- pg = np .array ([group_sizes [g ] / total_treated for g in unique_groups ])
1634+ # pg indexed by group
1635+ pg_by_group = np .array ([group_sizes [g ] / n_units for g in unique_groups ])
1636+
1637+ # pg indexed by keeper (each (g,t) pair gets its group's pg)
1638+ # This matches R's: pg <- pgg[match(group, originalglist)]
1639+ pg_keepers = np .array ([pg_by_group [group_to_idx [g ]] for g in groups_for_gt ])
1640+ sum_pg_keepers = np .sum (pg_keepers )
16341641
16351642 # Standard aggregated influence (without wif)
16361643 psi_standard = np .zeros (n_units )
@@ -1650,19 +1657,6 @@ def _compute_aggregated_se_with_wif(
16501657 idx = unit_to_idx [uid ]
16511658 psi_standard [idx ] += w * info ['control_inf' ][i ]
16521659
1653- # Compute wif adjustment
1654- # wif captures the influence of each unit on the weight estimation
1655- # For simple aggregation with group-size weights:
1656- # wif_i = [I(G_i = g) - pg[g]] / sum(pg) for numerator effect
1657- # - adjustment for denominator effect
1658- #
1659- # R's formula (computed at GROUP level, not (g,t) level):
1660- # if1 = (1*(G == g) - pg) / sum(pg)
1661- # if2 = rowSums(1*(G == g) - pg) * (pg / sum(pg)^2)
1662- # wif = if1 - if2
1663- #
1664- # The wif matrix is then multiplied by group-level aggregated ATT
1665-
16661660 # Build unit-group membership indicator
16671661 unit_groups = {}
16681662 for uid in all_units :
@@ -1672,60 +1666,54 @@ def _compute_aggregated_se_with_wif(
16721666 else :
16731667 unit_groups [uid ] = None # Never-treated or other
16741668
1675- # Compute group-level aggregated ATT (average ATT for each group)
1676- # This matches R's approach where wif is multiplied by group-level ATT
1677- group_att = {}
1678- group_weight_sum = {}
1679- for j , (g , t ) in enumerate (gt_pairs ):
1680- if g not in group_att :
1681- group_att [g ] = 0.0
1682- group_weight_sum [g ] = 0.0
1683- # Weight within group is equal (simple average over time periods)
1684- group_att [g ] += effects [j ]
1685- group_weight_sum [g ] += 1
1669+ # Compute wif using R's exact formula (iterate over keepers, not groups)
1670+ # R's wif function:
1671+ # if1[i,k] = (indicator(G_i == group_k) - pg[k]) / sum(pg[keepers])
1672+ # if2[i,k] = indicator_sum[i] * pg[k] / sum(pg[keepers])^2
1673+ # wif[i,k] = if1[i,k] - if2[i,k]
1674+ #
1675+ # Then: wif_contrib[i] = sum_k(wif[i,k] * att[k])
16861676
1687- for g in unique_groups :
1688- if g in group_att and group_weight_sum [g ] > 0 :
1689- group_att [g ] /= group_weight_sum [g ]
1690- else :
1691- group_att [g ] = 0.0
1677+ n_keepers = len (gt_pairs )
1678+ wif_contrib = np .zeros (n_units )
16921679
1693- # Compute wif contribution for each unit (at GROUP level)
1694- psi_wif = np .zeros (n_units )
1680+ # Pre-compute indicator_sum for each unit
1681+ # indicator_sum[i] = sum_k(indicator(G_i == group_k) - pg[k])
1682+ indicator_sum = np .zeros (n_units )
1683+ for j , g in enumerate (groups_for_gt ):
1684+ pg_k = pg_keepers [j ]
1685+ for uid in all_units :
1686+ i = unit_to_idx [uid ]
1687+ unit_g = unit_groups [uid ]
1688+ indicator = 1.0 if unit_g == g else 0.0
1689+ indicator_sum [i ] += (indicator - pg_k )
16951690
1696- # For each GROUP (not each (g,t) pair)
1697- for g in unique_groups :
1698- g_idx = group_to_idx [ g ]
1699- att_g = group_att [ g ]
1691+ # Compute wif contribution for each keeper
1692+ for j , ( g , t ) in enumerate ( gt_pairs ) :
1693+ pg_k = pg_keepers [ j ]
1694+ att_k = effects [ j ]
17001695
17011696 for uid in all_units :
17021697 i = unit_to_idx [uid ]
17031698 unit_g = unit_groups [uid ]
1704-
1705- # Indicator: 1 if unit belongs to group g, 0 otherwise
17061699 indicator = 1.0 if unit_g == g else 0.0
17071700
1708- # wif_i for this group
1709- # Formula: (indicator - pg[g]) - Σ_g' (indicator_g' - pg_g') × pg_g
1710- wif_i = (indicator - pg [g_idx ])
1711-
1712- # Denominator adjustment
1713- denom_adj = 0.0
1714- for g_prime in unique_groups :
1715- g_prime_idx = group_to_idx [g_prime ]
1716- ind_prime = 1.0 if unit_g == g_prime else 0.0
1717- denom_adj += (ind_prime - pg [g_prime_idx ]) * pg [g_idx ]
1701+ # R's formula for wif
1702+ if1_ik = (indicator - pg_k ) / sum_pg_keepers
1703+ if2_ik = indicator_sum [i ] * pg_k / (sum_pg_keepers ** 2 )
1704+ wif_ik = if1_ik - if2_ik
17181705
1719- wif_i = wif_i - denom_adj
1706+ # Add contribution: wif[i,k] * att[k]
1707+ wif_contrib [i ] += wif_ik * att_k
17201708
1721- # Scale by group ATT and add to wif contribution
1722- # Scale by 1/n_units to match R's getSE formula: sqrt(mean(IF^2)/n)
1723- psi_wif [i ] += wif_i * att_g / n_units
1709+ # Scale by 1/n_units to match R's getSE formula: sqrt(mean(IF^2)/n)
1710+ psi_wif = wif_contrib / n_units
17241711
17251712 # Combine standard and wif terms
17261713 psi_total = psi_standard + psi_wif
17271714
1728- # Compute variance
1715+ # Compute variance and SE
1716+ # R's formula: sqrt(mean(IF^2) / n) = sqrt(sum(IF^2) / n^2)
17291717 variance = np .sum (psi_total ** 2 )
17301718 return np .sqrt (variance )
17311719
0 commit comments