Skip to content

Commit ada7494

Browse files
igerber and claude committed
Fix wif formula to exactly match R's did package SE computation
The weight influence function (wif) formula had two issues causing a ~1.1% SE difference vs R:

1. pg computation: used n_g / total_treated instead of n_g / n_all.
2. wif iteration: iterated over groups with averaged ATT instead of over keepers (post-treatment (g,t) pairs) with individual ATT(g,t) values.

Now implements R's exact formula:

- if1[i,k] = (indicator(G_i == group_k) - pg[k]) / sum(pg[keepers])
- if2[i,k] = indicator_sum[i] * pg[k] / sum(pg[keepers])^2
- wif[i,k] = if1[i,k] - if2[i,k]
- wif_contrib[i] = sum_k(wif[i,k] * att[k])

Result: SE now matches R to within 0.01% (essentially an exact match).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent a3837d0 commit ada7494

2 files changed

Lines changed: 54 additions & 66 deletions

File tree

diff_diff/staggered.py

Lines changed: 51 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,11 +1592,14 @@ def _compute_aggregated_se_with_wif(
15921592
The wif adjustment adds variance due to the fact that aggregation
15931593
weights w_g = n_g / N depend on estimated group sizes.
15941594
1595-
Formula:
1596-
agg_inf_i = Σ_gt w_gt × inf_i_gt + wif_i × ATT_gt
1597-
se = sqrt(Σ_i (agg_inf_i)²)
1598-
1599-
where wif_i captures how unit i influences the weight estimation.
1595+
Formula (matching R's did::aggte):
1596+
agg_inf_i = Σ_k (w_k × inf_i_k + wif_i_k × ATT_k)
1597+
se = sqrt(mean(agg_inf^2) / n)
1598+
1599+
where:
1600+
- k indexes "keepers" (post-treatment (g,t) pairs)
1601+
- w_k = pg[k] / sum(pg[keepers]) where pg = n_g / n_all
1602+
- wif captures how unit i influences the weight estimation
16001603
"""
16011604
if not influence_func_info:
16021605
return 0.0
@@ -1618,19 +1621,23 @@ def _compute_aggregated_se_with_wif(
16181621

16191622
# Get unique groups and their information
16201623
unique_groups = sorted(set(groups_for_gt))
1621-
n_groups = len(unique_groups)
16221624
group_to_idx = {g: i for i, g in enumerate(unique_groups)}
16231625

1624-
# Compute group-level probabilities (proportion of treated in each group)
1625-
# pg[g] = n_g / N where N = total treated across all groups
1626+
# Compute group-level probabilities matching R's formula:
1627+
# pg[g] = n_g / n_all (fraction of ALL units in group g)
1628+
# This differs from our old formula which used n_g / total_treated
16261629
group_sizes = {}
16271630
for g in unique_groups:
1628-
# Count unique treated units in this group
16291631
treated_in_g = df[df['first_treat'] == g][unit].nunique()
16301632
group_sizes[g] = treated_in_g
16311633

1632-
total_treated = sum(group_sizes.values())
1633-
pg = np.array([group_sizes[g] / total_treated for g in unique_groups])
1634+
# pg indexed by group
1635+
pg_by_group = np.array([group_sizes[g] / n_units for g in unique_groups])
1636+
1637+
# pg indexed by keeper (each (g,t) pair gets its group's pg)
1638+
# This matches R's: pg <- pgg[match(group, originalglist)]
1639+
pg_keepers = np.array([pg_by_group[group_to_idx[g]] for g in groups_for_gt])
1640+
sum_pg_keepers = np.sum(pg_keepers)
16341641

16351642
# Standard aggregated influence (without wif)
16361643
psi_standard = np.zeros(n_units)
@@ -1650,19 +1657,6 @@ def _compute_aggregated_se_with_wif(
16501657
idx = unit_to_idx[uid]
16511658
psi_standard[idx] += w * info['control_inf'][i]
16521659

1653-
# Compute wif adjustment
1654-
# wif captures the influence of each unit on the weight estimation
1655-
# For simple aggregation with group-size weights:
1656-
# wif_i = [I(G_i = g) - pg[g]] / sum(pg) for numerator effect
1657-
# - adjustment for denominator effect
1658-
#
1659-
# R's formula (computed at GROUP level, not (g,t) level):
1660-
# if1 = (1*(G == g) - pg) / sum(pg)
1661-
# if2 = rowSums(1*(G == g) - pg) * (pg / sum(pg)^2)
1662-
# wif = if1 - if2
1663-
#
1664-
# The wif matrix is then multiplied by group-level aggregated ATT
1665-
16661660
# Build unit-group membership indicator
16671661
unit_groups = {}
16681662
for uid in all_units:
@@ -1672,60 +1666,54 @@ def _compute_aggregated_se_with_wif(
16721666
else:
16731667
unit_groups[uid] = None # Never-treated or other
16741668

1675-
# Compute group-level aggregated ATT (average ATT for each group)
1676-
# This matches R's approach where wif is multiplied by group-level ATT
1677-
group_att = {}
1678-
group_weight_sum = {}
1679-
for j, (g, t) in enumerate(gt_pairs):
1680-
if g not in group_att:
1681-
group_att[g] = 0.0
1682-
group_weight_sum[g] = 0.0
1683-
# Weight within group is equal (simple average over time periods)
1684-
group_att[g] += effects[j]
1685-
group_weight_sum[g] += 1
1669+
# Compute wif using R's exact formula (iterate over keepers, not groups)
1670+
# R's wif function:
1671+
# if1[i,k] = (indicator(G_i == group_k) - pg[k]) / sum(pg[keepers])
1672+
# if2[i,k] = indicator_sum[i] * pg[k] / sum(pg[keepers])^2
1673+
# wif[i,k] = if1[i,k] - if2[i,k]
1674+
#
1675+
# Then: wif_contrib[i] = sum_k(wif[i,k] * att[k])
16861676

1687-
for g in unique_groups:
1688-
if g in group_att and group_weight_sum[g] > 0:
1689-
group_att[g] /= group_weight_sum[g]
1690-
else:
1691-
group_att[g] = 0.0
1677+
n_keepers = len(gt_pairs)
1678+
wif_contrib = np.zeros(n_units)
16921679

1693-
# Compute wif contribution for each unit (at GROUP level)
1694-
psi_wif = np.zeros(n_units)
1680+
# Pre-compute indicator_sum for each unit
1681+
# indicator_sum[i] = sum_k(indicator(G_i == group_k) - pg[k])
1682+
indicator_sum = np.zeros(n_units)
1683+
for j, g in enumerate(groups_for_gt):
1684+
pg_k = pg_keepers[j]
1685+
for uid in all_units:
1686+
i = unit_to_idx[uid]
1687+
unit_g = unit_groups[uid]
1688+
indicator = 1.0 if unit_g == g else 0.0
1689+
indicator_sum[i] += (indicator - pg_k)
16951690

1696-
# For each GROUP (not each (g,t) pair)
1697-
for g in unique_groups:
1698-
g_idx = group_to_idx[g]
1699-
att_g = group_att[g]
1691+
# Compute wif contribution for each keeper
1692+
for j, (g, t) in enumerate(gt_pairs):
1693+
pg_k = pg_keepers[j]
1694+
att_k = effects[j]
17001695

17011696
for uid in all_units:
17021697
i = unit_to_idx[uid]
17031698
unit_g = unit_groups[uid]
1704-
1705-
# Indicator: 1 if unit belongs to group g, 0 otherwise
17061699
indicator = 1.0 if unit_g == g else 0.0
17071700

1708-
# wif_i for this group
1709-
# Formula: (indicator - pg[g]) - Σ_g' (indicator_g' - pg_g') × pg_g
1710-
wif_i = (indicator - pg[g_idx])
1711-
1712-
# Denominator adjustment
1713-
denom_adj = 0.0
1714-
for g_prime in unique_groups:
1715-
g_prime_idx = group_to_idx[g_prime]
1716-
ind_prime = 1.0 if unit_g == g_prime else 0.0
1717-
denom_adj += (ind_prime - pg[g_prime_idx]) * pg[g_idx]
1701+
# R's formula for wif
1702+
if1_ik = (indicator - pg_k) / sum_pg_keepers
1703+
if2_ik = indicator_sum[i] * pg_k / (sum_pg_keepers ** 2)
1704+
wif_ik = if1_ik - if2_ik
17181705

1719-
wif_i = wif_i - denom_adj
1706+
# Add contribution: wif[i,k] * att[k]
1707+
wif_contrib[i] += wif_ik * att_k
17201708

1721-
# Scale by group ATT and add to wif contribution
1722-
# Scale by 1/n_units to match R's getSE formula: sqrt(mean(IF^2)/n)
1723-
psi_wif[i] += wif_i * att_g / n_units
1709+
# Scale by 1/n_units to match R's getSE formula: sqrt(mean(IF^2)/n)
1710+
psi_wif = wif_contrib / n_units
17241711

17251712
# Combine standard and wif terms
17261713
psi_total = psi_standard + psi_wif
17271714

1728-
# Compute variance
1715+
# Compute variance and SE
1716+
# R's formula: sqrt(mean(IF^2) / n) = sqrt(sum(IF^2) / n^2)
17291717
variance = np.sum(psi_total ** 2)
17301718
return np.sqrt(variance)
17311719

tests/test_se_accuracy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,10 +235,10 @@ def test_se_vs_r_benchmark(self):
235235
assert att_diff < 1e-8, \
236236
f"ATT differs from R: {results.overall_att} vs {r_overall_att}"
237237

238-
# SE should be within 2% of R (target after wif adjustment)
238+
# SE should match R to within 0.01% (after wif fix)
239239
se_diff_pct = abs(results.overall_se - r_overall_se) / r_overall_se * 100
240-
assert se_diff_pct < 2.0, \
241-
f"SE differs from R by {se_diff_pct:.2f}%, expected <2%"
240+
assert se_diff_pct < 0.01, \
241+
f"SE differs from R by {se_diff_pct:.4f}%, expected <0.01%"
242242

243243
def test_timing_performance(self, cs_results):
244244
"""

0 commit comments

Comments
 (0)