@@ -181,6 +181,13 @@ class CallawaySantAnna(
181181 Trimming bound for propensity scores. Scores are clipped to
182182 ``[pscore_trim, 1 - pscore_trim]`` before weight computation
183183 in IPW and DR estimation. Must be in ``(0, 0.5)``.
184+ panel : bool, default=True
185+ Whether the data is a balanced/unbalanced panel (units observed
186+ across multiple time periods). Set to ``False`` for repeated
187+ cross-sections where each observation has a unique unit ID and
188+ units do not repeat across periods. When ``False``, uses
189+ cross-sectional DRDID (Sant'Anna & Zhao 2020, Section 4) with
190+ per-observation influence functions.
184191
185192 Attributes
186193 ----------
@@ -1783,6 +1790,7 @@ def fit(
17831790 pscore_trim = self .pscore_trim ,
17841791 survey_metadata = survey_metadata ,
17851792 event_study_vcov = event_study_vcov ,
1793+ panel = self .panel ,
17861794 )
17871795
17881796 self .is_fitted_ = True
@@ -2972,6 +2980,40 @@ def _ipw_estimation_rc(
29722980 inf_control = np .concatenate ([inf_ct , inf_cs ])
29732981 inf_all = np .concatenate ([inf_treated , inf_control ])
29742982
2983+ # PS IF correction for cross-sectional IPW
2984+ X_all_int = np .column_stack ([np .ones (len (D_all )), X_all ])
2985+ pscore_all = pscore # already computed and clipped
2986+
2987+ W_ps = pscore_all * (1 - pscore_all )
2988+ if sw_all is not None :
2989+ W_ps = W_ps * sw_all
2990+ H_ps = X_all_int .T @ (W_ps [:, None ] * X_all_int )
2991+ H_ps_inv = _safe_inv (H_ps )
2992+
2993+ score_ps = (D_all - pscore_all )[:, None ] * X_all_int
2994+ if sw_all is not None :
2995+ score_ps = score_ps * sw_all [:, None ]
2996+ asy_lin_rep_ps = score_ps @ H_ps_inv # (n_all, p+1)
2997+
2998+ # M2: gradient of IPW ATT w.r.t. PS parameters
2999+ # Control IPW residuals from both periods
3000+ ipw_resid_ct = w_ct_norm * (y_ct - mu_ct_ipw )
3001+ ipw_resid_cs = w_cs_norm * (y_cs - mu_cs_ipw )
3002+ # Zero for treated observations
3003+ M2_rc = np .zeros (X_all_int .shape [1 ])
3004+ # Control-t contribution
3005+ M2_rc += np .mean (
3006+ ipw_resid_ct [:, None ] * X_all_int [n_gt + n_gs : n_gt + n_gs + n_ct ],
3007+ axis = 0 ,
3008+ )
3009+ # Control-s contribution (opposite sign -- base period)
3010+ M2_rc -= np .mean (
3011+ ipw_resid_cs [:, None ] * X_all_int [n_gt + n_gs + n_ct :],
3012+ axis = 0 ,
3013+ )
3014+
3015+ inf_all = inf_all + asy_lin_rep_ps @ M2_rc
3016+
29753017 se = float (np .sqrt (np .sum (inf_all ** 2 )))
29763018
29773019 idx_all = None
@@ -3121,6 +3163,70 @@ def _doubly_robust_rc(
31213163 inf_control = np .concatenate ([inf_ct , inf_cs ])
31223164 inf_all = np .concatenate ([inf_treated , inf_control ])
31233165
3166+ # --- PS IF correction ---
3167+ X_all_int = np .column_stack ([np .ones (len (D_all )), X_all ])
3168+ pscore_all = pscore
3169+
3170+ W_ps = pscore_all * (1 - pscore_all )
3171+ if sw_all is not None :
3172+ W_ps = W_ps * sw_all
3173+ H_ps = X_all_int .T @ (W_ps [:, None ] * X_all_int )
3174+ H_ps_inv = _safe_inv (H_ps )
3175+
3176+ score_ps = (D_all - pscore_all )[:, None ] * X_all_int
3177+ if sw_all is not None :
3178+ score_ps = score_ps * sw_all [:, None ]
3179+ asy_lin_rep_ps = score_ps @ H_ps_inv
3180+
3181+ # M2_dr: uses DR residuals (m-y) instead of raw y
3182+ dr_resid_ct = m_ct - y_ct # control period-t DR residuals
3183+ dr_resid_cs = m_cs - y_cs # control period-s DR residuals
3184+ normalizer = np .sum (sw_gt ) if sw_gt is not None else n_gt
3185+ M2_dr = np .zeros (X_all_int .shape [1 ])
3186+ # Control-t: (w_ct/normalizer) * (m_ct - y_ct) * X
3187+ ct_slice = slice (n_gt + n_gs , n_gt + n_gs + n_ct )
3188+ M2_dr += np .mean (
3189+ ((w_ct / normalizer ) * dr_resid_ct )[:, None ] * X_all_int [ct_slice ],
3190+ axis = 0 ,
3191+ )
3192+ # Control-s: -(w_cs/normalizer) * (m_cs - y_cs) * X (opposite sign)
3193+ cs_slice = slice (n_gt + n_gs + n_ct , None )
3194+ M2_dr -= np .mean (
3195+ ((w_cs / normalizer ) * dr_resid_cs )[:, None ] * X_all_int [cs_slice ],
3196+ axis = 0 ,
3197+ )
3198+
3199+ inf_all = inf_all + asy_lin_rep_ps @ M2_dr
3200+
3201+ # --- OR IF correction -- period t model ---
3202+ W_t = sw_ct if sw_ct is not None else np .ones (n_ct )
3203+ bread_t = _safe_inv (X_ct_int .T @ (W_t [:, None ] * X_ct_int ))
3204+
3205+ # M1_t: dATT/dbeta_t (from treated-t prediction and control-t augmentation)
3206+ sw_gt_vals = sw_gt if sw_gt is not None else np .ones (n_gt )
3207+ M1_t = (
3208+ - np .sum (sw_gt_vals [:, None ] * X_gt_int , axis = 0 )
3209+ + np .sum (w_ct [:, None ] * X_ct_int , axis = 0 )
3210+ ) / normalizer
3211+
3212+ asy_lin_rep_or_t = (W_t * (y_ct - m_ct ))[:, None ] * X_ct_int @ bread_t
3213+ # Apply only to control-t portion of inf_all
3214+ inf_all [n_gt + n_gs : n_gt + n_gs + n_ct ] += asy_lin_rep_or_t @ M1_t
3215+
3216+ # --- OR IF correction -- period s model ---
3217+ W_s = sw_cs if sw_cs is not None else np .ones (n_cs )
3218+ bread_s = _safe_inv (X_cs_int .T @ (W_s [:, None ] * X_cs_int ))
3219+
3220+ sw_gs_vals = sw_gs if sw_gs is not None else np .ones (n_gs )
3221+ M1_s = (
3222+ np .sum (sw_gs_vals [:, None ] * X_gs_int , axis = 0 )
3223+ - np .sum (w_cs [:, None ] * X_cs_int , axis = 0 )
3224+ ) / normalizer
3225+
3226+ asy_lin_rep_or_s = (W_s * (y_cs - m_cs ))[:, None ] * X_cs_int @ bread_s
3227+ # Apply only to control-s portion of inf_all
3228+ inf_all [n_gt + n_gs + n_ct :] += asy_lin_rep_or_s @ M1_s
3229+
31243230 se = float (np .sqrt (np .sum (inf_all ** 2 )))
31253231
31263232 idx_all = None
0 commit comments