Skip to content

Commit c0c7e5a

Browse files
igerberclaude
andcommitted
Fix M2 gradient scaling: use np.sum instead of np.mean over control subsets
The M2 gradient terms in PS nuisance corrections used np.mean() over control subsets, introducing an extra 1/n_c divisor. R's DRDID computes M2 as colMeans() over the full n-sample (zeros for treated), then divides by mean(w.cont) — the n's cancel, giving sum(w*resid*X)/sum(w). With our Hajek-normalized weights (w_norm = w/sum(w)), np.sum(w_norm*resid*X) directly yields sum(w*resid*X)/sum(w), matching R after cancellation. The single /n on the correction line remains as the psi-to-phi conversion. Applied at all 5 PS correction sites (panel survey IPW/DR, panel non-survey DR, RCS IPW, RCS DR). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 43c5547 commit c0c7e5a

1 file changed

Lines changed: 10 additions & 8 deletions

File tree

diff_diff/staggered.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2072,8 +2072,10 @@ def _ipw_estimation(
20722072
asy_lin_rep_psi = score_ps @ H_psi_inv
20732073

20742074
att_control_weighted = np.sum(weights_control_norm * control_change)
2075-
# R: M2 = colMeans(w.cont * (y - att) * X)
2076-
M2 = np.mean(
2075+
# R: M2 = colMeans(w.cont * (y - att) * X) / mean(w.cont)
2076+
# np.sum (not mean): subset sum with normalized weights matches
2077+
# R's full-sample colMeans/mean(w) after cancellation
2078+
M2 = np.sum(
20772079
(weights_control_norm * (control_change - att_control_weighted))[:, None]
20782080
* X_all_int[n_t:],
20792081
axis=0,
@@ -2331,7 +2333,7 @@ def _doubly_robust(
23312333
asy_lin_rep_psi = score_ps @ H_psi_inv
23322334

23332335
dr_resid_control = m_control - control_change
2334-
M2_dr = np.mean(
2336+
M2_dr = np.sum(
23352337
((weights_control / sw_t_sum) * dr_resid_control)[:, None]
23362338
* X_all_int[n_t:],
23372339
axis=0,
@@ -2394,7 +2396,7 @@ def _doubly_robust(
23942396
asy_lin_rep_psi = score_ps @ H_psi_inv
23952397

23962398
dr_resid_control = m_control - control_change
2397-
M2_dr = np.mean(
2399+
M2_dr = np.sum(
23982400
((weights_control / n_t) * dr_resid_control)[:, None] * X_all_int[n_t:],
23992401
axis=0,
24002402
)
@@ -3152,8 +3154,8 @@ def _ipw_estimation_rc(
31523154
cs_slice = slice(n_gt + n_gs + n_ct, None)
31533155

31543156
M2 = np.zeros(X_all_int.shape[1])
3155-
M2 += np.mean(ipw_resid_ct[:, None] * X_all_int[ct_slice], axis=0)
3156-
M2 -= np.mean(ipw_resid_cs[:, None] * X_all_int[cs_slice], axis=0)
3157+
M2 += np.sum(ipw_resid_ct[:, None] * X_all_int[ct_slice], axis=0)
3158+
M2 -= np.sum(ipw_resid_cs[:, None] * X_all_int[cs_slice], axis=0)
31573159

31583160
# psi-scale correction, convert to phi
31593161
inf_all = inf_all + (asy_lin_rep_psi @ M2) / n_all
@@ -3469,12 +3471,12 @@ def _doubly_robust_rc(
34693471

34703472
M2 = np.zeros(X_all_int.shape[1])
34713473
if sum_w_ipw_ct > 0:
3472-
M2 -= np.mean(
3474+
M2 -= np.sum(
34733475
((w_ipw_ct * dr_resid_ct / sum_w_ipw_ct)[:, None] * X_all_int[ct_slice]),
34743476
axis=0,
34753477
)
34763478
if sum_w_ipw_cs > 0:
3477-
M2 += np.mean(
3479+
M2 += np.sum(
34783480
((w_ipw_cs * dr_resid_cs / sum_w_ipw_cs)[:, None] * X_all_int[cs_slice]),
34793481
axis=0,
34803482
)

0 commit comments

Comments
 (0)