Skip to content

Commit e754b04

Browse files
igerber and claude committed
Fix round-3 review P1s: DDD double-weighting, CS WIF scaling, CS covariate nuisance IF
- TripleDifference: remove double-weighting in the IPW/DR moment corrections. Because the Riesz representers already incorporate survey weights, moment means use np.mean() rather than np.average(weights=). Removed the _wmean_ax0 helper.
- CallawaySantAnna WIF: apply s_i symmetrically to both the indicator and pg terms in the weighted share estimator IF. Normalize by total_weight (the sum of survey weights) instead of n_units.
- CallawaySantAnna outcome regression covariate IF: add a weighted regression nuisance IF correction (the asymptotic linear representation of beta from WLS, projected onto the weighted treated covariate mean). The IPW and DR IFs are unchanged (IPW matches the unweighted structure; DR is self-correcting per Theorem 3.1).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ae6b278 commit e754b04

3 files changed

Lines changed: 70 additions & 42 deletions

File tree

diff_diff/staggered.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1591,9 +1591,34 @@ def _outcome_regression(
15911591
sw_c_norm = sw_control / np.sum(sw_control)
15921592
att = float(np.sum(sw_t_norm * treated_residuals))
15931593

1594+
# --- Regression nuisance IF correction ---
1595+
# Account for uncertainty in beta estimation
1596+
X_c = np.column_stack([np.ones(n_c), X_control])
1597+
X_t = np.column_stack([np.ones(n_t), X_treated])
1598+
1599+
# Weighted bread: (X'WX)^{-1}
1600+
XWX = X_c.T @ (X_c * sw_control[:, None])
1601+
try:
1602+
XWX_inv = np.linalg.solve(XWX, np.eye(XWX.shape[0]))
1603+
except np.linalg.LinAlgError:
1604+
XWX_inv = np.linalg.lstsq(XWX, np.eye(XWX.shape[0]), rcond=None)[0]
1605+
1606+
# Per-control regression score: w_i * x_i * resid_i
1607+
resid_c = control_change - X_c @ beta
1608+
score_c = X_c * (sw_control * resid_c)[:, None]
1609+
asy_lin_rep_reg = score_c @ XWX_inv # shape (n_c, p)
1610+
1611+
# Weighted treated covariate mean
1612+
X_treated_mean_w = np.average(X_t, axis=0, weights=sw_treated)
1613+
1614+
# Regression IF correction for control observations
1615+
inf_control_reg_corr = asy_lin_rep_reg @ X_treated_mean_w
1616+
15941617
# Influence function (survey-weighted)
15951618
inf_treated = sw_t_norm * (treated_residuals - att)
1596-
inf_control = -sw_c_norm * (control_change - np.sum(sw_c_norm * control_change))
1619+
inf_control = (
1620+
-sw_c_norm * (control_change - np.dot(X_c, beta)) + inf_control_reg_corr / n_c
1621+
)
15971622
inf_func = np.concatenate([inf_treated, inf_control])
15981623

15991624
# SE from influence function variance

diff_diff/staggered_aggregation.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,8 @@ def _compute_combined_influence_function(
341341
).astype(np.float64)
342342

343343
if survey_w is not None:
344-
# Survey-weighted WIF: indicator entries are sw_i / sum(sw_all)
344+
# Survey-weighted WIF for group-share estimator p_g = sum(s_i * 1{G_i=g}) / sum(s_j).
345+
# IF_i(p_g) = s_i * (1{G_i=g} - p_g) / sum(s_j)
345346
# Build per-unit weight vector aligned to our index space
346347
if global_unit_to_idx is not None and precomputed is not None:
347348
unit_sw = np.zeros(n_units)
@@ -353,12 +354,16 @@ def _compute_combined_influence_function(
353354
else:
354355
unit_sw = np.ones(n_units)
355356

356-
# Weighted indicator: sw_i * 1{G_i == g_k} / sum(sw_all)
357-
weighted_indicator = indicator_matrix * (unit_sw / total_weight)[:, np.newaxis]
358-
indicator_sum_w = np.sum(weighted_indicator - pg_keepers, axis=1)
357+
# s_i * 1{G_i == g_k}
358+
weighted_indicator = indicator_matrix * unit_sw[:, np.newaxis]
359+
# s_i * p_g_k (symmetric weight application)
360+
weighted_pg_term = pg_keepers[np.newaxis, :] * unit_sw[:, np.newaxis]
361+
# s_i * (1{G_i == g_k} - p_g_k) / sum(s_j)
362+
indicator_diff = (weighted_indicator - weighted_pg_term) / total_weight
363+
indicator_sum_w = np.sum(indicator_diff, axis=1)
359364

360365
with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
361-
if1_matrix = (weighted_indicator - pg_keepers) / sum_pg_keepers
366+
if1_matrix = indicator_diff / sum_pg_keepers
362367
if2_matrix = np.outer(indicator_sum_w, pg_keepers) / (sum_pg_keepers**2)
363368
wif_matrix = if1_matrix - if2_matrix
364369
wif_contrib = wif_matrix @ effects
@@ -386,8 +391,9 @@ def _compute_combined_influence_function(
386391
nan_result = np.full(n_units, np.nan)
387392
return nan_result, all_units
388393

389-
# Scale by 1/n_units to match R's getSE formula
390-
psi_wif = wif_contrib / n_units
394+
# Scale by 1/total_weight to match R's getSE formula
395+
# (for non-survey, total_weight == n_units; for survey, total_weight == sum(sw))
396+
psi_wif = wif_contrib / total_weight
391397

392398
# Combine standard and wif terms
393399
psi_total = psi_standard + psi_wif

diff_diff/triple_diff.py

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,26 +1329,16 @@ def _hajek(riesz, y_vals):
13291329
score_ps = score_ps * weights[:, None]
13301330
asy_lin_rep_ps = score_ps @ hessian
13311331

1332-
if weights is not None:
1333-
M2_pre = np.average(
1334-
(riesz_control_pre * (y - att_control_pre))[:, None] * covX,
1335-
axis=0,
1336-
weights=weights,
1337-
) / np.mean(riesz_control_pre)
1338-
M2_post = np.average(
1339-
(riesz_control_post * (y - att_control_post))[:, None] * covX,
1340-
axis=0,
1341-
weights=weights,
1342-
) / np.mean(riesz_control_post)
1343-
else:
1344-
M2_pre = np.mean(
1345-
(riesz_control_pre * (y - att_control_pre))[:, None] * covX,
1346-
axis=0,
1347-
) / np.mean(riesz_control_pre)
1348-
M2_post = np.mean(
1349-
(riesz_control_post * (y - att_control_post))[:, None] * covX,
1350-
axis=0,
1351-
) / np.mean(riesz_control_post)
1332+
# Riesz representers already incorporate survey weights,
1333+
# so use np.mean (not np.average with weights) to avoid double-weighting.
1334+
M2_pre = np.mean(
1335+
(riesz_control_pre * (y - att_control_pre))[:, None] * covX,
1336+
axis=0,
1337+
) / np.mean(riesz_control_pre)
1338+
M2_post = np.mean(
1339+
(riesz_control_post * (y - att_control_post))[:, None] * covX,
1340+
axis=0,
1341+
) / np.mean(riesz_control_post)
13521342
inf_control_ps = asy_lin_rep_ps @ (M2_post - M2_pre)
13531343
inf_control = inf_control + inf_control_ps
13541344

@@ -1616,19 +1606,15 @@ def _safe_ratio(num, denom):
16161606
)
16171607

16181608
# OR correction for treated
1619-
def _wmean_ax0(arr):
1620-
"""Weighted or unweighted column mean."""
1621-
if weights is not None:
1622-
return np.average(arr, axis=0, weights=weights)
1623-
return np.mean(arr, axis=0)
1624-
1609+
# Riesz representers already incorporate survey weights,
1610+
# so use np.mean (not weighted average) to avoid double-weighting.
16251611
M1_post = (
1626-
(-_wmean_ax0((riesz_treat_post * post)[:, None] * covX) / m_riesz_treat_post)
1612+
(-np.mean((riesz_treat_post * post)[:, None] * covX, axis=0) / m_riesz_treat_post)
16271613
if m_riesz_treat_post > 0
16281614
else np.zeros(covX.shape[1])
16291615
)
16301616
M1_pre = (
1631-
(-_wmean_ax0((riesz_treat_pre * (1 - post))[:, None] * covX) / m_riesz_treat_pre)
1617+
(-np.mean((riesz_treat_pre * (1 - post))[:, None] * covX, axis=0) / m_riesz_treat_pre)
16321618
if m_riesz_treat_pre > 0
16331619
else np.zeros(covX.shape[1])
16341620
)
@@ -1653,15 +1639,19 @@ def _wmean_ax0(arr):
16531639
# PS correction for control
16541640
M2_pre = (
16551641
(
1656-
_wmean_ax0((riesz_control_pre * (y - or_ctrl - att_control_pre))[:, None] * covX)
1642+
np.mean(
1643+
(riesz_control_pre * (y - or_ctrl - att_control_pre))[:, None] * covX, axis=0
1644+
)
16571645
/ m_riesz_control_pre
16581646
)
16591647
if m_riesz_control_pre > 0
16601648
else np.zeros(covX.shape[1])
16611649
)
16621650
M2_post = (
16631651
(
1664-
_wmean_ax0((riesz_control_post * (y - or_ctrl - att_control_post))[:, None] * covX)
1652+
np.mean(
1653+
(riesz_control_post * (y - or_ctrl - att_control_post))[:, None] * covX, axis=0
1654+
)
16651655
/ m_riesz_control_post
16661656
)
16671657
if m_riesz_control_post > 0
@@ -1671,12 +1661,15 @@ def _wmean_ax0(arr):
16711661

16721662
# OR correction for control
16731663
M3_post = (
1674-
(-_wmean_ax0((riesz_control_post * post)[:, None] * covX) / m_riesz_control_post)
1664+
(-np.mean((riesz_control_post * post)[:, None] * covX, axis=0) / m_riesz_control_post)
16751665
if m_riesz_control_post > 0
16761666
else np.zeros(covX.shape[1])
16771667
)
16781668
M3_pre = (
1679-
(-_wmean_ax0((riesz_control_pre * (1 - post))[:, None] * covX) / m_riesz_control_pre)
1669+
(
1670+
-np.mean((riesz_control_pre * (1 - post))[:, None] * covX, axis=0)
1671+
/ m_riesz_control_pre
1672+
)
16801673
if m_riesz_control_pre > 0
16811674
else np.zeros(covX.shape[1])
16821675
)
@@ -1704,12 +1697,16 @@ def _wmean_ax0(arr):
17041697

17051698
# OR combination
17061699
mom_post = (
1707-
_wmean_ax0((riesz_d[:, None] / m_riesz_d - riesz_dt1[:, None] / m_riesz_dt1) * covX)
1700+
np.mean(
1701+
(riesz_d[:, None] / m_riesz_d - riesz_dt1[:, None] / m_riesz_dt1) * covX, axis=0
1702+
)
17081703
if (m_riesz_d > 0 and m_riesz_dt1 > 0)
17091704
else np.zeros(covX.shape[1])
17101705
)
17111706
mom_pre = (
1712-
_wmean_ax0((riesz_d[:, None] / m_riesz_d - riesz_dt0[:, None] / m_riesz_dt0) * covX)
1707+
np.mean(
1708+
(riesz_d[:, None] / m_riesz_d - riesz_dt0[:, None] / m_riesz_dt0) * covX, axis=0
1709+
)
17131710
if (m_riesz_d > 0 and m_riesz_dt0 > 0)
17141711
else np.zeros(covX.shape[1])
17151712
)

0 commit comments

Comments
 (0)