@@ -565,9 +565,25 @@ def fit(
565565 resolved_survey = resolved_survey ,
566566 )
567567 elif self .estimation_method == "ipw" :
568- att , se , r_squared , pscore_stats = self ._ipw_estimation (y , G , P , T , X )
568+ att , se , r_squared , pscore_stats = self ._ipw_estimation (
569+ y ,
570+ G ,
571+ P ,
572+ T ,
573+ X ,
574+ survey_weights = survey_weights ,
575+ resolved_survey = resolved_survey ,
576+ )
569577 else : # doubly robust
570- att , se , r_squared , pscore_stats = self ._doubly_robust (y , G , P , T , X )
578+ att , se , r_squared , pscore_stats = self ._doubly_robust (
579+ y ,
580+ G ,
581+ P ,
582+ T ,
583+ X ,
584+ survey_weights = survey_weights ,
585+ resolved_survey = resolved_survey ,
586+ )
571587
572588 # Compute inference
573589 # When survey design is active, use survey df (n_PSU - n_strata)
@@ -758,6 +774,8 @@ def _ipw_estimation(
758774 P : np .ndarray ,
759775 T : np .ndarray ,
760776 X : Optional [np .ndarray ],
777+ survey_weights : Optional [np .ndarray ] = None ,
778+ resolved_survey = None ,
761779 ) -> Tuple [float , float , Optional [float ], Optional [Dict [str , float ]]]:
762780 """
763781 Estimate ATT using inverse probability weighting via three-DiD
@@ -767,7 +785,15 @@ def _ipw_estimation(
767785 subgroup membership P(subgroup=4|X) within {j, 4} subset.
768786 Matches R's triplediff::ddd() with est_method="ipw".
769787 """
770- return self ._estimate_ddd_decomposition (y , G , P , T , X )
788+ return self ._estimate_ddd_decomposition (
789+ y ,
790+ G ,
791+ P ,
792+ T ,
793+ X ,
794+ survey_weights = survey_weights ,
795+ resolved_survey = resolved_survey ,
796+ )
771797
772798 def _doubly_robust (
773799 self ,
@@ -776,6 +802,8 @@ def _doubly_robust(
776802 P : np .ndarray ,
777803 T : np .ndarray ,
778804 X : Optional [np .ndarray ],
805+ survey_weights : Optional [np .ndarray ] = None ,
806+ resolved_survey = None ,
779807 ) -> Tuple [float , float , Optional [float ], Optional [Dict [str , float ]]]:
780808 """
781809 Estimate ATT using doubly robust estimation via three-DiD
@@ -786,7 +814,15 @@ def _doubly_robust(
786814 correctly specified. Matches R's triplediff::ddd() with
787815 est_method="dr".
788816 """
789- return self ._estimate_ddd_decomposition (y , G , P , T , X )
817+ return self ._estimate_ddd_decomposition (
818+ y ,
819+ G ,
820+ P ,
821+ T ,
822+ X ,
823+ survey_weights = survey_weights ,
824+ resolved_survey = resolved_survey ,
825+ )
790826
791827 def _estimate_ddd_decomposition (
792828 self ,
@@ -1186,7 +1222,17 @@ def _compute_did_rc(
11861222 Matches R's triplediff::compute_did_rc().
11871223 """
11881224 if est_method == "ipw" :
1189- return self ._compute_did_rc_ipw (y , post , PA4 , PAa , pscore , covX , hessian , n )
1225+ return self ._compute_did_rc_ipw (
1226+ y ,
1227+ post ,
1228+ PA4 ,
1229+ PAa ,
1230+ pscore ,
1231+ covX ,
1232+ hessian ,
1233+ n ,
1234+ weights = weights ,
1235+ )
11901236 elif est_method == "reg" :
11911237 return self ._compute_did_rc_reg (
11921238 y ,
@@ -1215,6 +1261,7 @@ def _compute_did_rc(
12151261 or_trt_post ,
12161262 hessian ,
12171263 n ,
1264+ weights = weights ,
12181265 )
12191266
12201267 def _compute_did_rc_ipw (
@@ -1227,6 +1274,7 @@ def _compute_did_rc_ipw(
12271274 covX : np .ndarray ,
12281275 hessian : Optional [np .ndarray ],
12291276 n : int ,
1277+ weights : Optional [np .ndarray ] = None ,
12301278 ) -> Tuple [float , np .ndarray ]:
12311279 """IPW DiD for a single pairwise comparison (RC)."""
12321280 # Riesz representers (IPW weights * indicators)
@@ -1235,6 +1283,13 @@ def _compute_did_rc_ipw(
12351283 riesz_control_pre = pscore * PAa * (1 - post ) / (1 - pscore )
12361284 riesz_control_post = pscore * PAa * post / (1 - pscore )
12371285
1286+ # Incorporate survey weights into Riesz representers
1287+ if weights is not None :
1288+ riesz_treat_pre = riesz_treat_pre * weights
1289+ riesz_treat_post = riesz_treat_post * weights
1290+ riesz_control_pre = riesz_control_pre * weights
1291+ riesz_control_post = riesz_control_post * weights
1292+
12381293 # Hajek-normalized cell-time means
12391294 def _hajek (riesz , y_vals ):
12401295 denom = np .mean (riesz )
@@ -1393,6 +1448,7 @@ def _compute_did_rc_dr(
13931448 or_trt_post : np .ndarray ,
13941449 hessian : Optional [np .ndarray ],
13951450 n : int ,
1451+ weights : Optional [np .ndarray ] = None ,
13961452 ) -> Tuple [float , np .ndarray ]:
13971453 """Doubly robust DiD for a single pairwise comparison (RC)."""
13981454 or_ctrl = post * or_ctrl_post + (1 - post ) * or_ctrl_pre
@@ -1406,6 +1462,16 @@ def _compute_did_rc_dr(
14061462 riesz_dt1 = PA4 * post
14071463 riesz_dt0 = PA4 * (1 - post )
14081464
1465+ # Incorporate survey weights into Riesz representers
1466+ if weights is not None :
1467+ riesz_treat_pre = riesz_treat_pre * weights
1468+ riesz_treat_post = riesz_treat_post * weights
1469+ riesz_control_pre = riesz_control_pre * weights
1470+ riesz_control_post = riesz_control_post * weights
1471+ riesz_d = riesz_d * weights
1472+ riesz_dt1 = riesz_dt1 * weights
1473+ riesz_dt0 = riesz_dt0 * weights
1474+
14091475 # DR cell-time components
14101476 def _safe_ratio (num , denom ):
14111477 return num / denom if denom > 0 else 0.0
0 commit comments