Skip to content

Commit 98d80c6

Browse files
igerberclaude
andcommitted
Fix round-4 review P1s: CS WIF normalization, IPW nuisance IF, TwoStage n_psu
- CallawaySantAnna WIF: remove inner /total_weight from indicator_diff — the final psi_wif/total_weight handles normalization once, matching R's did::wif() - CallawaySantAnna IPW covariate: add propensity score nuisance IF correction (survey-weighted Hessian, score, M2 gradient) so per-cell and aggregated SEs account for PS estimation uncertainty - TwoStageDiD: recompute n_psu/n_strata after always-treated filtering via np.unique() on subsetted arrays, then recompute survey_metadata Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e754b04 commit 98d80c6

3 files changed

Lines changed: 55 additions & 2 deletions

File tree

diff_diff/staggered.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1774,8 +1774,42 @@ def _ipw_estimation(
17741774
)
17751775
inf_func = np.concatenate([inf_treated, inf_control])
17761776

1777+
# Propensity score IF correction
1778+
# Accounts for estimation uncertainty in logistic regression coefficients
1779+
X_all_int = np.column_stack([np.ones(n_t + n_c), X_all])
1780+
pscore_all = np.concatenate([pscore_treated, pscore_control])
1781+
1782+
# Survey-weighted PS Hessian: sum(w_i * mu_i * (1-mu_i) * x_i * x_i')
1783+
W_ps = pscore_all * (1 - pscore_all)
1784+
if sw_all is not None:
1785+
W_ps = W_ps * sw_all
1786+
H = X_all_int.T @ (W_ps[:, None] * X_all_int)
1787+
try:
1788+
H_inv = np.linalg.solve(H, np.eye(H.shape[0]))
1789+
except np.linalg.LinAlgError:
1790+
H_inv = np.linalg.lstsq(H, np.eye(H.shape[0]), rcond=None)[0]
1791+
1792+
# PS score: w_i * (D_i - pi_i) * X_i
1793+
D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)])
1794+
score_ps = (D_all - pscore_all)[:, None] * X_all_int
1795+
if sw_all is not None:
1796+
score_ps = score_ps * sw_all[:, None]
1797+
asy_lin_rep_ps = score_ps @ H_inv # shape (n_t + n_c, p)
1798+
1799+
# M2: gradient of ATT w.r.t. PS parameters
1800+
att_control_weighted = np.sum(weights_control_norm * control_change)
1801+
M2 = np.mean(
1802+
(weights_control_norm * (control_change - att_control_weighted))[:, None]
1803+
* X_all_int[n_t:],
1804+
axis=0,
1805+
)
1806+
1807+
# PS correction to influence function
1808+
inf_ps_correction = asy_lin_rep_ps @ M2
1809+
inf_func = inf_func + inf_ps_correction
1810+
17771811
# SE from influence function variance
1778-
var_psi = np.sum(inf_treated**2) + np.sum(inf_control**2)
1812+
var_psi = np.sum(inf_func**2)
17791813
se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0
17801814
else:
17811815
# IPW weights for control units: p(X) / (1 - p(X))

diff_diff/staggered_aggregation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def _compute_combined_influence_function(
359359
# s_i * p_g_k (symmetric weight application)
360360
weighted_pg_term = pg_keepers[np.newaxis, :] * unit_sw[:, np.newaxis]
361361
# s_i * (1{G_i == g_k} - p_g_k) / sum(s_j)
362-
indicator_diff = (weighted_indicator - weighted_pg_term) / total_weight
362+
indicator_diff = weighted_indicator - weighted_pg_term
363363
indicator_sum_w = np.sum(indicator_diff, axis=1)
364364

365365
with np.errstate(divide="ignore", invalid="ignore", over="ignore"):

diff_diff/two_stage.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,25 @@ def fit(
312312
else None
313313
),
314314
)
315+
# Recompute n_psu/n_strata after subsetting
316+
new_n_psu = (
317+
len(np.unique(resolved_survey.psu)) if resolved_survey.psu is not None else 0
318+
)
319+
new_n_strata = (
320+
len(np.unique(resolved_survey.strata))
321+
if resolved_survey.strata is not None
322+
else 0
323+
)
324+
resolved_survey = replace(resolved_survey, n_psu=new_n_psu, n_strata=new_n_strata)
325+
# Recompute survey_metadata since it depends on these counts
326+
from diff_diff.survey import compute_survey_metadata
327+
328+
raw_w = (
329+
df[survey_design.weights].values.astype(np.float64)
330+
if survey_design.weights
331+
else np.ones(len(df), dtype=np.float64)
332+
)
333+
survey_metadata = compute_survey_metadata(resolved_survey, raw_w)
315334

316335
# Treatment indicator with anticipation
317336
effective_treat = df[first_treat] - self.anticipation

0 commit comments

Comments
 (0)