Skip to content

Commit 6165bc4

Browse files
igerber and claude committed
Fix TripleDiff TSL double-weighting, rewrite CS reg covariate survey IF
- TripleDifference: divide out survey weights from the IF before passing it to compute_survey_vcov, since the Riesz representers already incorporate the weights and TSL would multiply by the weights again.
- CallawaySantAnna _outcome_regression: rewrite the survey covariate IF to follow the DRDID panel OR structure — all terms are consistently scaled by 1/sw_t_sum, and the nuisance correction is divided by sw_t_sum for correct normalization.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ec4dda3 commit 6165bc4

2 files changed

Lines changed: 34 additions & 22 deletions

File tree

diff_diff/staggered.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1603,43 +1603,48 @@ def _outcome_regression(
16031603
treated_residuals = treated_change - predicted_control
16041604

16051605
if sw_treated is not None:
1606-
sw_t_norm = sw_treated / np.sum(sw_treated)
1607-
sw_c_norm = sw_control / np.sum(sw_control)
1606+
sw_t_sum = float(np.sum(sw_treated))
1607+
sw_t_norm = sw_treated / sw_t_sum
16081608
att = float(np.sum(sw_t_norm * treated_residuals))
16091609

1610-
# --- Regression nuisance IF correction ---
1611-
# Account for uncertainty in beta estimation
1610+
# --- DRDID panel OR influence function (survey-weighted) ---
1611+
# Following Sant'Anna & Zhao (2020) Theorem 3.1 for the OR estimator.
1612+
# All IF terms are scaled by 1/sw_t_sum so that sum(IF^2) gives V(ATT).
16121613
X_c = np.column_stack([np.ones(n_c), X_control])
16131614
X_t = np.column_stack([np.ones(n_t), X_treated])
16141615

1615-
# Weighted bread: (X'WX)^{-1}
1616+
# Treated component: w_i * (ΔY_i - m(X_i) - ATT) / sum(w_treated)
1617+
inf_treated = (sw_treated / sw_t_sum) * (treated_residuals - att)
1618+
1619+
# Control outcome-regression component
1620+
predicted_c = np.dot(X_c, beta)
1621+
inf_control_or = -(sw_control / sw_t_sum) * (control_change - predicted_c)
1622+
1623+
# Regression nuisance IF correction (accounts for beta estimation)
1624+
# Hessian of WLS: H = X_c' W_c X_c
16161625
XWX = X_c.T @ (X_c * sw_control[:, None])
16171626
try:
16181627
XWX_inv = np.linalg.solve(XWX, np.eye(XWX.shape[0]))
16191628
except np.linalg.LinAlgError:
16201629
XWX_inv = np.linalg.lstsq(XWX, np.eye(XWX.shape[0]), rcond=None)[0]
16211630

1622-
# Per-control regression score: w_i * x_i * resid_i
1623-
resid_c = control_change - X_c @ beta
1631+
# Per-control score: w_i * x_i * (y_i - x_i'beta)
1632+
resid_c = control_change - predicted_c
16241633
score_c = X_c * (sw_control * resid_c)[:, None]
1625-
asy_lin_rep_reg = score_c @ XWX_inv # shape (n_c, p)
1634+
asy_lin_rep_reg = score_c @ XWX_inv # (n_c, p)
16261635

1627-
# Weighted treated covariate mean
1628-
X_treated_mean_w = np.average(X_t, axis=0, weights=sw_treated)
1636+
# Projection direction: survey-weighted treated covariate mean
1637+
X_treated_mean_w = np.sum(X_t * sw_treated[:, None], axis=0) / sw_t_sum
16291638

1630-
# Regression IF correction for control observations
1631-
inf_control_reg_corr = asy_lin_rep_reg @ X_treated_mean_w
1639+
# Correction: how beta uncertainty affects ATT
1640+
inf_control_reg_corr = (asy_lin_rep_reg @ X_treated_mean_w) / sw_t_sum
16321641

1633-
# Influence function (survey-weighted)
1634-
inf_treated = sw_t_norm * (treated_residuals - att)
1635-
inf_control = (
1636-
-sw_c_norm * (control_change - np.dot(X_c, beta)) + inf_control_reg_corr
1637-
)
1642+
inf_control = inf_control_or + inf_control_reg_corr
16381643
inf_func = np.concatenate([inf_treated, inf_control])
16391644

16401645
# SE from influence function variance
1641-
var_psi = np.sum(inf_treated**2) + np.sum(inf_control**2)
1642-
se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0
1646+
se = float(np.sqrt(np.sum(inf_func**2)))
1647+
se = se if se > 0 else 0.0
16431648
else:
16441649
att = float(np.mean(treated_residuals))
16451650

diff_diff/triple_diff.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,11 +1077,18 @@ def _estimate_ddd_decomposition(
10771077

10781078
if resolved_survey is not None:
10791079
# Survey-weighted SE via TSL on the combined influence function.
1080-
# Treat the IF as a single-parameter score vector:
1081-
# compute_survey_vcov(ones, IF, resolved) gives V(ATT).
1080+
# The pairwise IFs already incorporate survey weights (via weighted
1081+
# Riesz representers), but compute_survey_vcov multiplies by weights
1082+
# again internally. Divide out the survey weights to get the
1083+
# unweighted IF that TSL will correctly re-weight.
10821084
from diff_diff.survey import compute_survey_vcov
10831085

1084-
vcov_survey = compute_survey_vcov(np.ones((n, 1)), inf_func, resolved_survey)
1086+
inf_for_tsl = inf_func.copy()
1087+
sw = survey_weights
1088+
if sw is not None:
1089+
nz = sw > 0
1090+
inf_for_tsl[nz] = inf_for_tsl[nz] / sw[nz]
1091+
vcov_survey = compute_survey_vcov(np.ones((n, 1)), inf_for_tsl, resolved_survey)
10851092
se = float(np.sqrt(vcov_survey[0, 0]))
10861093
elif self._cluster_ids is not None:
10871094
# Cluster-robust SE: sum IF within clusters, then Liang-Zeger variance

0 commit comments

Comments
 (0)