Skip to content

Commit ab65708

Browse files
igerber and claude committed
Fix P0/P1 findings from AI review: TripleDiff IPW/DR survey threading, CS SE formula
- P0: Thread survey_weights through TripleDifference IPW and DR call chains (_ipw_estimation, _doubly_robust, _compute_did_rc_ipw, _compute_did_rc_dr). Survey weights now enter Riesz representers for weighted Hajek averages.
- P1: Fix CallawaySantAnna no-covariate survey SE to derive from sum(IF^2) instead of sum(w_norm * (y-mean)^2). All 4 locations now consistent with stored influence functions.
- P1: Update REGISTRY.md TripleDifference entry to reflect full survey support (was still marked as "IPW/DR deferred").
- P2: Add behavioral tests for TripleDiff IPW/DR survey: non-uniform weights change ATT, uniform weights match unweighted.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a8af057 commit ab65708

4 files changed

Lines changed: 181 additions & 20 deletions

File tree

diff_diff/staggered.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -715,13 +715,15 @@ def _compute_all_att_gt_vectorized(
715715
mu_c = float(np.sum(sw_c_norm * control_change))
716716
att = mu_t - mu_c
717717

718-
var_t = float(np.sum(sw_t_norm * (treated_change - mu_t) ** 2))
719-
var_c = float(np.sum(sw_c_norm * (control_change - mu_c) ** 2))
720-
se = float(np.sqrt(var_t + var_c)) if (n_t > 0 and n_c > 0) else 0.0
721-
722718
# Influence function (survey-weighted)
723719
inf_treated = sw_t_norm * (treated_change - mu_t)
724720
inf_control = -sw_c_norm * (control_change - mu_c)
721+
# SE derived from IF: sqrt(sum(IF_i^2))
722+
se = (
723+
float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2)))
724+
if (n_t > 0 and n_c > 0)
725+
else 0.0
726+
)
725727
sw_sum = float(np.sum(sw_t))
726728
else:
727729
att = float(np.mean(treated_change) - np.mean(control_change))
@@ -1624,9 +1626,11 @@ def _outcome_regression(
16241626
inf_func = np.concatenate([inf_treated, inf_control])
16251627

16261628
# SE from influence function variance
1627-
var_t = float(np.sum(sw_t_norm * (treated_change - mu_t) ** 2))
1628-
var_c = float(np.sum(sw_c_norm * (control_change - mu_c) ** 2))
1629-
se = float(np.sqrt(var_t + var_c)) if (n_t > 0 and n_c > 0) else 0.0
1629+
se = (
1630+
float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2)))
1631+
if (n_t > 0 and n_c > 0)
1632+
else 0.0
1633+
)
16301634
else:
16311635
att = float(np.mean(treated_change) - np.mean(control_change))
16321636

@@ -1787,9 +1791,11 @@ def _ipw_estimation(
17871791
inf_control = -sw_c_norm * (control_change - mu_c)
17881792
inf_func = np.concatenate([inf_treated, inf_control])
17891793

1790-
var_t = float(np.sum(sw_t_norm * (treated_change - mu_t) ** 2))
1791-
var_c = float(np.sum(sw_c_norm * (control_change - mu_c) ** 2))
1792-
se = float(np.sqrt(var_t + var_c)) if (n_t > 0 and n_c > 0) else 0.0
1794+
se = (
1795+
float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2)))
1796+
if (n_t > 0 and n_c > 0)
1797+
else 0.0
1798+
)
17931799
else:
17941800
p_treat = n_treated / n_total # unconditional propensity score
17951801

@@ -1998,9 +2004,11 @@ def _doubly_robust(
19982004
inf_control = -sw_c_norm * (control_change - mu_c)
19992005
inf_func = np.concatenate([inf_treated, inf_control])
20002006

2001-
var_t = float(np.sum(sw_t_norm * (treated_change - mu_t) ** 2))
2002-
var_c = float(np.sum(sw_c_norm * (control_change - mu_c) ** 2))
2003-
se = float(np.sqrt(var_t + var_c)) if (n_t > 0 and n_c > 0) else 0.0
2007+
se = (
2008+
float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2)))
2009+
if (n_t > 0 and n_c > 0)
2010+
else 0.0
2011+
)
20042012
else:
20052013
att = float(np.mean(treated_change) - np.mean(control_change))
20062014

diff_diff/triple_diff.py

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -565,9 +565,25 @@ def fit(
565565
resolved_survey=resolved_survey,
566566
)
567567
elif self.estimation_method == "ipw":
568-
att, se, r_squared, pscore_stats = self._ipw_estimation(y, G, P, T, X)
568+
att, se, r_squared, pscore_stats = self._ipw_estimation(
569+
y,
570+
G,
571+
P,
572+
T,
573+
X,
574+
survey_weights=survey_weights,
575+
resolved_survey=resolved_survey,
576+
)
569577
else: # doubly robust
570-
att, se, r_squared, pscore_stats = self._doubly_robust(y, G, P, T, X)
578+
att, se, r_squared, pscore_stats = self._doubly_robust(
579+
y,
580+
G,
581+
P,
582+
T,
583+
X,
584+
survey_weights=survey_weights,
585+
resolved_survey=resolved_survey,
586+
)
571587

572588
# Compute inference
573589
# When survey design is active, use survey df (n_PSU - n_strata)
@@ -758,6 +774,8 @@ def _ipw_estimation(
758774
P: np.ndarray,
759775
T: np.ndarray,
760776
X: Optional[np.ndarray],
777+
survey_weights: Optional[np.ndarray] = None,
778+
resolved_survey=None,
761779
) -> Tuple[float, float, Optional[float], Optional[Dict[str, float]]]:
762780
"""
763781
Estimate ATT using inverse probability weighting via three-DiD
@@ -767,7 +785,15 @@ def _ipw_estimation(
767785
subgroup membership P(subgroup=4|X) within {j, 4} subset.
768786
Matches R's triplediff::ddd() with est_method="ipw".
769787
"""
770-
return self._estimate_ddd_decomposition(y, G, P, T, X)
788+
return self._estimate_ddd_decomposition(
789+
y,
790+
G,
791+
P,
792+
T,
793+
X,
794+
survey_weights=survey_weights,
795+
resolved_survey=resolved_survey,
796+
)
771797

772798
def _doubly_robust(
773799
self,
@@ -776,6 +802,8 @@ def _doubly_robust(
776802
P: np.ndarray,
777803
T: np.ndarray,
778804
X: Optional[np.ndarray],
805+
survey_weights: Optional[np.ndarray] = None,
806+
resolved_survey=None,
779807
) -> Tuple[float, float, Optional[float], Optional[Dict[str, float]]]:
780808
"""
781809
Estimate ATT using doubly robust estimation via three-DiD
@@ -786,7 +814,15 @@ def _doubly_robust(
786814
correctly specified. Matches R's triplediff::ddd() with
787815
est_method="dr".
788816
"""
789-
return self._estimate_ddd_decomposition(y, G, P, T, X)
817+
return self._estimate_ddd_decomposition(
818+
y,
819+
G,
820+
P,
821+
T,
822+
X,
823+
survey_weights=survey_weights,
824+
resolved_survey=resolved_survey,
825+
)
790826

791827
def _estimate_ddd_decomposition(
792828
self,
@@ -1186,7 +1222,17 @@ def _compute_did_rc(
11861222
Matches R's triplediff::compute_did_rc().
11871223
"""
11881224
if est_method == "ipw":
1189-
return self._compute_did_rc_ipw(y, post, PA4, PAa, pscore, covX, hessian, n)
1225+
return self._compute_did_rc_ipw(
1226+
y,
1227+
post,
1228+
PA4,
1229+
PAa,
1230+
pscore,
1231+
covX,
1232+
hessian,
1233+
n,
1234+
weights=weights,
1235+
)
11901236
elif est_method == "reg":
11911237
return self._compute_did_rc_reg(
11921238
y,
@@ -1215,6 +1261,7 @@ def _compute_did_rc(
12151261
or_trt_post,
12161262
hessian,
12171263
n,
1264+
weights=weights,
12181265
)
12191266

12201267
def _compute_did_rc_ipw(
@@ -1227,6 +1274,7 @@ def _compute_did_rc_ipw(
12271274
covX: np.ndarray,
12281275
hessian: Optional[np.ndarray],
12291276
n: int,
1277+
weights: Optional[np.ndarray] = None,
12301278
) -> Tuple[float, np.ndarray]:
12311279
"""IPW DiD for a single pairwise comparison (RC)."""
12321280
# Riesz representers (IPW weights * indicators)
@@ -1235,6 +1283,13 @@ def _compute_did_rc_ipw(
12351283
riesz_control_pre = pscore * PAa * (1 - post) / (1 - pscore)
12361284
riesz_control_post = pscore * PAa * post / (1 - pscore)
12371285

1286+
# Incorporate survey weights into Riesz representers
1287+
if weights is not None:
1288+
riesz_treat_pre = riesz_treat_pre * weights
1289+
riesz_treat_post = riesz_treat_post * weights
1290+
riesz_control_pre = riesz_control_pre * weights
1291+
riesz_control_post = riesz_control_post * weights
1292+
12381293
# Hajek-normalized cell-time means
12391294
def _hajek(riesz, y_vals):
12401295
denom = np.mean(riesz)
@@ -1393,6 +1448,7 @@ def _compute_did_rc_dr(
13931448
or_trt_post: np.ndarray,
13941449
hessian: Optional[np.ndarray],
13951450
n: int,
1451+
weights: Optional[np.ndarray] = None,
13961452
) -> Tuple[float, np.ndarray]:
13971453
"""Doubly robust DiD for a single pairwise comparison (RC)."""
13981454
or_ctrl = post * or_ctrl_post + (1 - post) * or_ctrl_pre
@@ -1406,6 +1462,16 @@ def _compute_did_rc_dr(
14061462
riesz_dt1 = PA4 * post
14071463
riesz_dt0 = PA4 * (1 - post)
14081464

1465+
# Incorporate survey weights into Riesz representers
1466+
if weights is not None:
1467+
riesz_treat_pre = riesz_treat_pre * weights
1468+
riesz_treat_post = riesz_treat_post * weights
1469+
riesz_control_pre = riesz_control_pre * weights
1470+
riesz_control_post = riesz_control_post * weights
1471+
riesz_d = riesz_d * weights
1472+
riesz_dt1 = riesz_dt1 * weights
1473+
riesz_dt0 = riesz_dt0 * weights
1474+
14091475
# DR cell-time components
14101476
def _safe_ratio(num, denom):
14111477
return num / denom if denom > 0 else 0.0

docs/methodology/REGISTRY.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,8 +1245,7 @@ has no additional effect.
12451245
- [x] Influence function SE: std(w3·IF_3 + w2·IF_2 - w1·IF_1) / sqrt(n)
12461246
- [x] Cluster-robust SE via Liang-Zeger variance on influence function
12471247
- [x] ATT and SE match R within <0.001% for all methods and DGP types
1248-
- [x] Survey design support (Phase 3): regression method with weighted OLS + TSL on combined influence functions; IPW/DR deferred
1249-
- **Note:** TripleDifference IPW/DR with survey weights deferred until weighted solve_logit() (Phase 5)
1248+
- [x] Survey design support: all methods (reg, IPW, DR) with weighted OLS/logit + TSL on combined influence functions. Weighted solve_logit() for propensity scores in IPW/DR paths.
12501249

12511250
---
12521251

tests/test_survey_phase4.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,94 @@ def test_ipw_survey_results_finite(self, ddd_survey_data):
715715
assert np.isfinite(result.se)
716716
assert result.survey_metadata is not None
717717

718+
def test_ipw_nonuniform_weights_change_att(self, ddd_survey_data):
719+
"""Non-uniform survey weights should change IPW ATT vs unweighted."""
720+
sd = SurveyDesign(weights="weight")
721+
r_no = TripleDifference(estimation_method="ipw").fit(
722+
ddd_survey_data,
723+
"outcome",
724+
"group",
725+
"partition",
726+
"time",
727+
)
728+
r_sv = TripleDifference(estimation_method="ipw").fit(
729+
ddd_survey_data,
730+
"outcome",
731+
"group",
732+
"partition",
733+
"time",
734+
survey_design=sd,
735+
)
736+
assert not np.isclose(
737+
r_no.att, r_sv.att, atol=1e-6
738+
), "Non-uniform survey weights should change IPW ATT"
739+
740+
def test_dr_nonuniform_weights_change_att(self, ddd_survey_data):
741+
"""Non-uniform survey weights should change DR ATT vs unweighted."""
742+
sd = SurveyDesign(weights="weight")
743+
r_no = TripleDifference(estimation_method="dr").fit(
744+
ddd_survey_data,
745+
"outcome",
746+
"group",
747+
"partition",
748+
"time",
749+
)
750+
r_sv = TripleDifference(estimation_method="dr").fit(
751+
ddd_survey_data,
752+
"outcome",
753+
"group",
754+
"partition",
755+
"time",
756+
survey_design=sd,
757+
)
758+
assert not np.isclose(
759+
r_no.att, r_sv.att, atol=1e-6
760+
), "Non-uniform survey weights should change DR ATT"
761+
762+
def test_ipw_uniform_weights_match_unweighted(self, ddd_survey_data):
763+
"""Uniform survey weights should match unweighted IPW result."""
764+
data = ddd_survey_data.copy()
765+
data["uw"] = 1.0
766+
sd = SurveyDesign(weights="uw")
767+
r_no = TripleDifference(estimation_method="ipw").fit(
768+
data,
769+
"outcome",
770+
"group",
771+
"partition",
772+
"time",
773+
)
774+
r_sv = TripleDifference(estimation_method="ipw").fit(
775+
data,
776+
"outcome",
777+
"group",
778+
"partition",
779+
"time",
780+
survey_design=sd,
781+
)
782+
assert np.isclose(r_no.att, r_sv.att, atol=1e-6)
783+
784+
def test_dr_uniform_weights_match_unweighted(self, ddd_survey_data):
785+
"""Uniform survey weights should match unweighted DR result."""
786+
data = ddd_survey_data.copy()
787+
data["uw"] = 1.0
788+
sd = SurveyDesign(weights="uw")
789+
r_no = TripleDifference(estimation_method="dr").fit(
790+
data,
791+
"outcome",
792+
"group",
793+
"partition",
794+
"time",
795+
)
796+
r_sv = TripleDifference(estimation_method="dr").fit(
797+
data,
798+
"outcome",
799+
"group",
800+
"partition",
801+
"time",
802+
survey_design=sd,
803+
)
804+
assert np.isclose(r_no.att, r_sv.att, atol=1e-6)
805+
718806

719807
# =============================================================================
720808
# TestCallawaySantAnnaSurveyInference

0 commit comments

Comments
 (0)