Skip to content

Commit 4bf566d

Browse files
igerber and claude committed
Address CI review: RCS IF corrections, aggregation weights, replicate VCV, panel on results
Fix 5 findings from PR #240 CI review:
- Add cross-sectional nuisance IF corrections (PS + OR) to _ipw_estimation_rc and _doubly_robust_rc, matching panel path methodology
- Use fixed full-sample cohort masses for unweighted RCS aggregation weights (consistency with WIF group-share denominator)
- Guard replicate-weight designs from full event-study VCV (diagonal fallback)
- Add panel field to CallawaySantAnnaResults, fix summary labels for RCS
- Add panel to class docstring, replicate VCV test, RCS IF correction test

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 04bd263 commit 4bf566d

5 files changed

Lines changed: 245 additions & 9 deletions

File tree

diff_diff/staggered.py

Lines changed: 106 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -181,6 +181,13 @@ class CallawaySantAnna(
181181
Trimming bound for propensity scores. Scores are clipped to
182182
``[pscore_trim, 1 - pscore_trim]`` before weight computation
183183
in IPW and DR estimation. Must be in ``(0, 0.5)``.
184+
panel : bool, default=True
185+
Whether the data is a balanced/unbalanced panel (units observed
186+
across multiple time periods). Set to ``False`` for repeated
187+
cross-sections where each observation has a unique unit ID and
188+
units do not repeat across periods. Uses cross-sectional DRDID
189+
(Sant'Anna & Zhao 2020, Section 4) with per-observation influence
190+
functions.
184191
185192
Attributes
186193
----------
@@ -1783,6 +1790,7 @@ def fit(
17831790
pscore_trim=self.pscore_trim,
17841791
survey_metadata=survey_metadata,
17851792
event_study_vcov=event_study_vcov,
1793+
panel=self.panel,
17861794
)
17871795

17881796
self.is_fitted_ = True
@@ -2972,6 +2980,40 @@ def _ipw_estimation_rc(
29722980
inf_control = np.concatenate([inf_ct, inf_cs])
29732981
inf_all = np.concatenate([inf_treated, inf_control])
29742982

2983+
# PS IF correction for cross-sectional IPW
2984+
X_all_int = np.column_stack([np.ones(len(D_all)), X_all])
2985+
pscore_all = pscore # already computed and clipped
2986+
2987+
W_ps = pscore_all * (1 - pscore_all)
2988+
if sw_all is not None:
2989+
W_ps = W_ps * sw_all
2990+
H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int)
2991+
H_ps_inv = _safe_inv(H_ps)
2992+
2993+
score_ps = (D_all - pscore_all)[:, None] * X_all_int
2994+
if sw_all is not None:
2995+
score_ps = score_ps * sw_all[:, None]
2996+
asy_lin_rep_ps = score_ps @ H_ps_inv # (n_all, p+1)
2997+
2998+
# M2: gradient of IPW ATT w.r.t. PS parameters
2999+
# Control IPW residuals from both periods
3000+
ipw_resid_ct = w_ct_norm * (y_ct - mu_ct_ipw)
3001+
ipw_resid_cs = w_cs_norm * (y_cs - mu_cs_ipw)
3002+
# Zero for treated observations
3003+
M2_rc = np.zeros(X_all_int.shape[1])
3004+
# Control-t contribution
3005+
M2_rc += np.mean(
3006+
ipw_resid_ct[:, None] * X_all_int[n_gt + n_gs : n_gt + n_gs + n_ct],
3007+
axis=0,
3008+
)
3009+
# Control-s contribution (opposite sign -- base period)
3010+
M2_rc -= np.mean(
3011+
ipw_resid_cs[:, None] * X_all_int[n_gt + n_gs + n_ct :],
3012+
axis=0,
3013+
)
3014+
3015+
inf_all = inf_all + asy_lin_rep_ps @ M2_rc
3016+
29753017
se = float(np.sqrt(np.sum(inf_all**2)))
29763018

29773019
idx_all = None
@@ -3121,6 +3163,70 @@ def _doubly_robust_rc(
31213163
inf_control = np.concatenate([inf_ct, inf_cs])
31223164
inf_all = np.concatenate([inf_treated, inf_control])
31233165

3166+
# --- PS IF correction ---
3167+
X_all_int = np.column_stack([np.ones(len(D_all)), X_all])
3168+
pscore_all = pscore
3169+
3170+
W_ps = pscore_all * (1 - pscore_all)
3171+
if sw_all is not None:
3172+
W_ps = W_ps * sw_all
3173+
H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int)
3174+
H_ps_inv = _safe_inv(H_ps)
3175+
3176+
score_ps = (D_all - pscore_all)[:, None] * X_all_int
3177+
if sw_all is not None:
3178+
score_ps = score_ps * sw_all[:, None]
3179+
asy_lin_rep_ps = score_ps @ H_ps_inv
3180+
3181+
# M2_dr: uses DR residuals (m-y) instead of raw y
3182+
dr_resid_ct = m_ct - y_ct # control period-t DR residuals
3183+
dr_resid_cs = m_cs - y_cs # control period-s DR residuals
3184+
normalizer = np.sum(sw_gt) if sw_gt is not None else n_gt
3185+
M2_dr = np.zeros(X_all_int.shape[1])
3186+
# Control-t: (w_ct/normalizer) * (m_ct - y_ct) * X
3187+
ct_slice = slice(n_gt + n_gs, n_gt + n_gs + n_ct)
3188+
M2_dr += np.mean(
3189+
((w_ct / normalizer) * dr_resid_ct)[:, None] * X_all_int[ct_slice],
3190+
axis=0,
3191+
)
3192+
# Control-s: -(w_cs/normalizer) * (m_cs - y_cs) * X (opposite sign)
3193+
cs_slice = slice(n_gt + n_gs + n_ct, None)
3194+
M2_dr -= np.mean(
3195+
((w_cs / normalizer) * dr_resid_cs)[:, None] * X_all_int[cs_slice],
3196+
axis=0,
3197+
)
3198+
3199+
inf_all = inf_all + asy_lin_rep_ps @ M2_dr
3200+
3201+
# --- OR IF correction -- period t model ---
3202+
W_t = sw_ct if sw_ct is not None else np.ones(n_ct)
3203+
bread_t = _safe_inv(X_ct_int.T @ (W_t[:, None] * X_ct_int))
3204+
3205+
# M1_t: dATT/dbeta_t (from treated-t prediction and control-t augmentation)
3206+
sw_gt_vals = sw_gt if sw_gt is not None else np.ones(n_gt)
3207+
M1_t = (
3208+
-np.sum(sw_gt_vals[:, None] * X_gt_int, axis=0)
3209+
+ np.sum(w_ct[:, None] * X_ct_int, axis=0)
3210+
) / normalizer
3211+
3212+
asy_lin_rep_or_t = (W_t * (y_ct - m_ct))[:, None] * X_ct_int @ bread_t
3213+
# Apply only to control-t portion of inf_all
3214+
inf_all[n_gt + n_gs : n_gt + n_gs + n_ct] += asy_lin_rep_or_t @ M1_t
3215+
3216+
# --- OR IF correction -- period s model ---
3217+
W_s = sw_cs if sw_cs is not None else np.ones(n_cs)
3218+
bread_s = _safe_inv(X_cs_int.T @ (W_s[:, None] * X_cs_int))
3219+
3220+
sw_gs_vals = sw_gs if sw_gs is not None else np.ones(n_gs)
3221+
M1_s = (
3222+
np.sum(sw_gs_vals[:, None] * X_gs_int, axis=0)
3223+
- np.sum(w_cs[:, None] * X_cs_int, axis=0)
3224+
) / normalizer
3225+
3226+
asy_lin_rep_or_s = (W_s * (y_cs - m_cs))[:, None] * X_cs_int @ bread_s
3227+
# Apply only to control-s portion of inf_all
3228+
inf_all[n_gt + n_gs + n_ct :] += asy_lin_rep_or_s @ M1_s
3229+
31243230
se = float(np.sqrt(np.sum(inf_all**2)))
31253231

31263232
idx_all = None

diff_diff/staggered_aggregation.py

Lines changed: 45 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -73,15 +73,31 @@ def _aggregate_simple(
7373
if g > 0: # exclude never-treated (0)
7474
survey_cohort_weights[g] = float(np.sum(sw[unit_cohorts == g]))
7575

76+
# For unweighted RCS: use fixed full-sample cohort counts so that
77+
# aggregation weights match the WIF group-share denominator.
78+
rcs_cohort_counts = None
79+
if (
80+
precomputed is not None
81+
and not precomputed.get("is_panel", True)
82+
and survey_cohort_weights is None
83+
):
84+
unit_cohorts = precomputed["unit_cohorts"]
85+
rcs_cohort_counts = {}
86+
for g in np.unique(unit_cohorts):
87+
if g > 0:
88+
rcs_cohort_counts[g] = int(np.sum(unit_cohorts == g))
89+
7690
for (g, t), data in group_time_effects.items():
7791
# Only include post-treatment effects (t >= g - anticipation)
7892
# Pre-treatment effects are for parallel trends, not overall ATT
7993
if t < g - self.anticipation:
8094
continue
8195
effects.append(data["effect"])
82-
# Use fixed cohort-level survey weight sum for aggregation
96+
# Use fixed cohort-level weights for aggregation
8397
if survey_cohort_weights is not None and g in survey_cohort_weights:
8498
weights_list.append(survey_cohort_weights[g])
99+
elif rcs_cohort_counts is not None and g in rcs_cohort_counts:
100+
weights_list.append(rcs_cohort_counts[g])
85101
else:
86102
weights_list.append(data["n_treated"])
87103
gt_pairs.append((g, t))
@@ -571,15 +587,29 @@ def _aggregate_event_study(
571587
if g > 0:
572588
survey_cohort_weights[g] = float(np.sum(sw[unit_cohorts == g]))
573589

590+
# For unweighted RCS: fixed full-sample cohort counts (matching WIF)
591+
rcs_cohort_counts = None
592+
if (
593+
precomputed is not None
594+
and not precomputed.get("is_panel", True)
595+
and survey_cohort_weights is None
596+
):
597+
unit_cohorts_es = precomputed["unit_cohorts"]
598+
rcs_cohort_counts = {}
599+
for g in np.unique(unit_cohorts_es):
600+
if g > 0:
601+
rcs_cohort_counts[g] = int(np.sum(unit_cohorts_es == g))
602+
574603
for (g, t), data in group_time_effects.items():
575604
e = t - g # Relative time
576605
if e not in effects_by_e:
577606
effects_by_e[e] = []
578-
w = (
579-
survey_cohort_weights[g]
580-
if survey_cohort_weights is not None and g in survey_cohort_weights
581-
else data["n_treated"]
582-
)
607+
if survey_cohort_weights is not None and g in survey_cohort_weights:
608+
w = survey_cohort_weights[g]
609+
elif rcs_cohort_counts is not None and g in rcs_cohort_counts:
610+
w = rcs_cohort_counts[g]
611+
else:
612+
w = data["n_treated"]
583613
effects_by_e[e].append(
584614
(
585615
(g, t), # Keep track of the (g,t) pair
@@ -733,8 +763,16 @@ def _aggregate_event_study(
733763

734764
meat, _, _ = _compute_stratified_psu_meat(Psi, resolved_survey)
735765
event_study_vcov = meat
766+
elif (
767+
resolved_survey is not None
768+
and hasattr(resolved_survey, "uses_replicate_variance")
769+
and resolved_survey.uses_replicate_variance
770+
):
771+
# Replicate-weight: fall back to None (diagonal in HonestDiD)
772+
# until multivariate replicate VCV is implemented
773+
event_study_vcov = None
736774
else:
737-
# Simple sum-of-outer-products (no survey or replicate-only)
775+
# No survey: simple sum-of-outer-products
738776
event_study_vcov = Psi.T @ Psi
739777
except (ValueError, np.linalg.LinAlgError):
740778
pass # Fall back to diagonal (None)

diff_diff/staggered_results.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -111,6 +111,7 @@ class CallawaySantAnnaResults:
111111
alpha: float = 0.05
112112
control_group: str = "never_treated"
113113
base_period: str = "varying"
114+
panel: bool = True
114115
event_study_effects: Optional[Dict[int, Dict[str, Any]]] = field(default=None)
115116
group_effects: Optional[Dict[Any, Dict[str, Any]]] = field(default=None)
116117
influence_functions: Optional["np.ndarray"] = field(default=None, repr=False)
@@ -155,8 +156,8 @@ def summary(self, alpha: Optional[float] = None) -> str:
155156
"=" * 85,
156157
"",
157158
f"{'Total observations:':<30} {self.n_obs:>10}",
158-
f"{'Treated units:':<30} {self.n_treated_units:>10}",
159-
f"{'Never-treated units:':<30} {self.n_control_units:>10}",
159+
f"{'Treated ' + ('obs:' if not self.panel else 'units:'):<30} {self.n_treated_units:>10}",
160+
f"{'Control ' + ('obs:' if not self.panel else 'units:'):<30} {self.n_control_units:>10}",
160161
f"{'Treatment cohorts:':<30} {len(self.groups):>10}",
161162
f"{'Time periods:':<30} {len(self.time_periods):>10}",
162163
f"{'Control group:':<30} {self.control_group:>10}",

tests/test_honest_did.py

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1243,6 +1243,43 @@ def test_no_survey_gives_none_df(self):
12431243
assert h_result.df_survey is None
12441244
assert h_result.survey_metadata is None
12451245

1246+
def test_replicate_weight_uses_diagonal_fallback(self):
1247+
"""Replicate-weight designs should NOT produce full event_study_vcov."""
1248+
from diff_diff import CallawaySantAnna, SurveyDesign, generate_staggered_data
1249+
1250+
data = generate_staggered_data(n_units=100, n_periods=5, seed=42)
1251+
unit_ids = data["unit"].unique()
1252+
n_units = len(unit_ids)
1253+
unit_map = {uid: i for i, uid in enumerate(unit_ids)}
1254+
idx = data["unit"].map(unit_map).values
1255+
1256+
# Create replicate weights (4 replicates)
1257+
rng = np.random.default_rng(42)
1258+
data["weight"] = (1.0 + 0.3 * (np.arange(n_units) % 3))[idx]
1259+
for k in range(4):
1260+
data[f"repwt_{k}"] = data["weight"] * rng.uniform(0.8, 1.2, len(data))
1261+
# Make constant within unit
1262+
unit_rw = data.groupby("unit")[f"repwt_{k}"].first()
1263+
data[f"repwt_{k}"] = data["unit"].map(unit_rw)
1264+
1265+
sd = SurveyDesign(
1266+
weights="weight",
1267+
replicate_weights=[f"repwt_{k}" for k in range(4)],
1268+
replicate_method="JK1",
1269+
)
1270+
cs_result = CallawaySantAnna().fit(
1271+
data,
1272+
"outcome",
1273+
"unit",
1274+
"period",
1275+
"first_treat",
1276+
survey_design=sd,
1277+
aggregate="event_study",
1278+
)
1279+
1280+
# event_study_vcov should be None (diagonal fallback for replicate designs)
1281+
assert cs_result.event_study_vcov is None
1282+
12461283

12471284
# =============================================================================
12481285
# Tests for Visualization (without matplotlib)

tests/test_staggered_rc.py

Lines changed: 54 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -349,3 +349,57 @@ def test_empty_cell_nan(self):
349349
v["effect"] for v in result.group_time_effects.values() if np.isfinite(v["effect"])
350350
]
351351
assert len(finite_effects) > 0
352+
353+
354+
# =============================================================================
355+
# Methodology: IF corrections change SE
356+
# =============================================================================
357+
358+
359+
class TestIFCorrections:
360+
"""Verify RCS DR/IPW IF corrections are non-trivial."""
361+
362+
def test_dr_se_differs_from_reg_rc(self, rc_data_with_covariates):
363+
"""DR and reg should give different SEs in RCS (DR has IF corrections)."""
364+
r_reg = CallawaySantAnna(estimation_method="reg", panel=False).fit(
365+
rc_data_with_covariates,
366+
"outcome",
367+
"unit",
368+
"period",
369+
"first_treat",
370+
covariates=["x1"],
371+
)
372+
r_dr = CallawaySantAnna(estimation_method="dr", panel=False).fit(
373+
rc_data_with_covariates,
374+
"outcome",
375+
"unit",
376+
"period",
377+
"first_treat",
378+
covariates=["x1"],
379+
)
380+
# SEs should differ (DR has nuisance IF corrections)
381+
assert r_reg.overall_se != r_dr.overall_se
382+
383+
def test_panel_field_on_results(self, rc_data):
384+
"""panel=False should be reflected on CallawaySantAnnaResults."""
385+
result = CallawaySantAnna(estimation_method="reg", panel=False).fit(
386+
rc_data,
387+
"outcome",
388+
"unit",
389+
"period",
390+
"first_treat",
391+
)
392+
assert result.panel is False
393+
394+
def test_summary_labels_rcs(self, rc_data):
395+
"""Summary should use 'obs' labels for RCS, not 'units'."""
396+
result = CallawaySantAnna(estimation_method="reg", panel=False).fit(
397+
rc_data,
398+
"outcome",
399+
"unit",
400+
"period",
401+
"first_treat",
402+
)
403+
summary = result.summary()
404+
assert "obs:" in summary
405+
assert "units:" not in summary.split("\n")[3] # Treated line

0 commit comments

Comments
 (0)