Skip to content

Commit 1166024

Browse files
igerber and claude committed
Fix weighted-mass IF linearization in ContinuousDiD and TripleDifference from PR #226 review (round 3)
- ContinuousDiD: store survey-weighted treated/control masses and weighted dpsi_bar in bootstrap_info; use weighted masses for p_1, p_0, n_total in IF construction so TSL linearizes the weighted estimator
- TripleDifference: use survey-weighted subgroup mass (sum(w_sub)) instead of raw counts (n_sub) for pairwise IF combination weights w3, w2, w1 when survey design is active

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent bb3b64a commit 1166024

2 files changed

Lines changed: 34 additions & 8 deletions

File tree

diff_diff/continuous_did.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,10 @@ def fit(
608608
control_idx = b_info["control_indices"]
609609
n_t = b_info["n_treated"]
610610
n_c = b_info["n_control"]
611+
# Use survey-weighted masses when available
612+
if "w_treated" in b_info:
613+
n_t = b_info["w_treated"]
614+
n_c = b_info["w_control"]
611615
n_total_gt = n_t + n_c
612616
p_1 = n_t / n_total_gt
613617
p_0 = n_c / n_total_gt
@@ -972,11 +976,18 @@ def _compute_dose_response_gt(
972976
treated_indices = np.where(treated_mask)[0]
973977
control_indices = np.where(control_mask)[0]
974978

979+
# dpsi_bar: mean derivative basis vector (weighted when survey)
980+
if w_treated is not None:
981+
dpsi_bar = np.average(dPsi_treated, axis=0, weights=w_treated)
982+
else:
983+
dpsi_bar = np.mean(dPsi_treated, axis=0)
984+
975985
bootstrap_info = {
976986
"bread": bread,
977987
"ee_treated": ee_treated,
978988
"ee_control": ee_control,
979989
"psi_bar": psi_bar,
990+
"dpsi_bar": dpsi_bar,
980991
"beta_hat": beta_hat,
981992
"beta_pred": beta_pred,
982993
"treated_indices": treated_indices,
@@ -993,6 +1004,11 @@ def _compute_dose_response_gt(
9931004
"acrt_glob": acrt_glob,
9941005
}
9951006

1007+
# Store survey-weighted masses for IF linearization
1008+
if w_treated is not None:
1009+
bootstrap_info["w_treated"] = float(np.sum(w_treated))
1010+
bootstrap_info["w_control"] = float(np.sum(w_control))
1011+
9961012
return {
9971013
"att_d": att_d,
9981014
"acrt_d": acrt_d,
@@ -1080,13 +1096,17 @@ def _compute_analytical_se(
10801096
control_idx = info["control_indices"]
10811097
n_t = info["n_treated"]
10821098
n_c = info["n_control"]
1099+
# Use survey-weighted masses when available
1100+
if "w_treated" in info:
1101+
n_t = info["w_treated"]
1102+
n_c = info["w_control"]
10831103
bread = info["bread"]
10841104
ee_treated = info["ee_treated"]
10851105
ee_control = info["ee_control"]
10861106
psi_bar = info["psi_bar"]
1107+
dpsi_bar = info["dpsi_bar"]
10871108
Psi_eval = info["Psi_eval"]
10881109
dPsi_eval = info["dPsi_eval"]
1089-
dPsi_treated = info["dPsi_treated"]
10901110
att_glob_gt = info["att_glob"]
10911111
mu_0 = info["mu_0"]
10921112
delta_y_treated = info["delta_y_treated"]
@@ -1119,7 +1139,6 @@ def _compute_analytical_se(
11191139
if_acrt_d[idx] += w * (dPsi_eval @ beta_pert)
11201140

11211141
# ACRT_glob IF: (1/n_t) sum_j dpsi(D_j)' @ beta_pert
1122-
dpsi_bar = np.mean(dPsi_treated, axis=0)
11231142
for k, idx in enumerate(treated_idx):
11241143
beta_pert = bread @ ee_treated[k] / n_t
11251144
if_acrt_glob[idx] += w * float(dpsi_bar @ beta_pert)

diff_diff/triple_diff.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,12 +1025,19 @@ def _estimate_ddd_decomposition(
10251025
)
10261026

10271027
# Influence function weights (matching R's att_dr_rc)
1028-
n3 = np.sum((subgroup == 3) | (subgroup == 4))
1029-
n2 = np.sum((subgroup == 2) | (subgroup == 4))
1030-
n1 = np.sum((subgroup == 1) | (subgroup == 4))
1031-
w3 = n / n3
1032-
w2 = n / n2
1033-
w1 = n / n1
1028+
if survey_weights is not None:
1029+
n3 = np.sum(survey_weights[(subgroup == 3) | (subgroup == 4)])
1030+
n2 = np.sum(survey_weights[(subgroup == 2) | (subgroup == 4)])
1031+
n1 = np.sum(survey_weights[(subgroup == 1) | (subgroup == 4)])
1032+
n_total = np.sum(survey_weights)
1033+
else:
1034+
n3 = np.sum((subgroup == 3) | (subgroup == 4))
1035+
n2 = np.sum((subgroup == 2) | (subgroup == 4))
1036+
n1 = np.sum((subgroup == 1) | (subgroup == 4))
1037+
n_total = n
1038+
w3 = n_total / n3
1039+
w2 = n_total / n2
1040+
w1 = n_total / n1
10341041

10351042
inf_func = (
10361043
w3 * did_results[3]["inf"] + w2 * did_results[2]["inf"] - w1 * did_results[1]["inf"]

0 commit comments

Comments
 (0)