Skip to content

Commit eb5ee64

Browse files
igerberclaude
andcommitted
Fix 4 P1 issues from PR #226 review (round 2)
- ContinuousDiD: rescale IFs by n_units before compute_survey_vcov to avoid double-counting 1/n bread; use unit-level df_survey - EfficientDiD: align unit_first_panel_row to sorted all_units order; build unit-level ResolvedSurveyDesign once in fit(); use unit-level df - SunAbraham: thread survey weights into _compute_iw_effects and _compute_overall_att for survey-weighted cohort aggregation - StackedDiD: pass survey df to safe_inference for event-study and overall ATT p-values/CIs Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 29a787c commit eb5ee64

5 files changed

Lines changed: 130 additions & 79 deletions

File tree

diff_diff/continuous_did.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -509,8 +509,8 @@ def fit(
509509
overall_att_se = analytic["overall_att_se"]
510510
overall_acrt_se = analytic["overall_acrt_se"]
511511

512-
# Survey df for t-distribution inference
513-
_survey_df = survey_metadata.df_survey if survey_metadata is not None else None
512+
# Survey df for t-distribution inference (unit-level, not panel-level)
513+
_survey_df = analytic.get("df_survey")
514514

515515
overall_att_t, overall_att_p, overall_att_ci = safe_inference(
516516
overall_att, overall_att_se, self.alpha, df=_survey_df
@@ -626,7 +626,9 @@ def fit(
626626
# Compute SE: survey-aware TSL or standard sqrt(sum(IF^2))
627627
if unit_resolved_es is not None:
628628
X_ones_es = np.ones((n_units, 1))
629-
vcov_es = compute_survey_vcov(X_ones_es, if_es, unit_resolved_es)
629+
# Rescale IFs from 1/n convention to score scale for TSL
630+
if_es_tsl = if_es * n_units
631+
vcov_es = compute_survey_vcov(X_ones_es, if_es_tsl, unit_resolved_es)
630632
es_se = float(np.sqrt(np.abs(vcov_es[0, 0])))
631633
else:
632634
es_se = float(np.sqrt(np.sum(if_es**2)))
@@ -1162,22 +1164,33 @@ def _compute_analytical_se(
11621164

11631165
X_ones = np.ones((n_units, 1))
11641166

1167+
# Rescale IFs from 1/n convention to score scale for TSL sandwich.
1168+
# The per-unit IFs contain internal 1/n_t, 1/n_c scaling (for the
1169+
# unweighted SE = sqrt(sum(IF^2)) convention). compute_survey_vcov
1170+
# applies its own (X'WX)^{-1} ≈ 1/n bread, which would double-count.
1171+
# Multiplying by n_units undoes the internal scaling so TSL gives
1172+
# the correct variance.
1173+
if_att_glob_tsl = if_att_glob * n_units
1174+
if_acrt_glob_tsl = if_acrt_glob * n_units
1175+
if_att_d_tsl = if_att_d * n_units
1176+
if_acrt_d_tsl = if_acrt_d * n_units
1177+
11651178
# Overall ATT SE via compute_survey_vcov
1166-
vcov_att = compute_survey_vcov(X_ones, if_att_glob, unit_resolved)
1179+
vcov_att = compute_survey_vcov(X_ones, if_att_glob_tsl, unit_resolved)
11671180
overall_att_se = float(np.sqrt(np.abs(vcov_att[0, 0])))
11681181

11691182
# Overall ACRT SE via compute_survey_vcov
1170-
vcov_acrt = compute_survey_vcov(X_ones, if_acrt_glob, unit_resolved)
1183+
vcov_acrt = compute_survey_vcov(X_ones, if_acrt_glob_tsl, unit_resolved)
11711184
overall_acrt_se = float(np.sqrt(np.abs(vcov_acrt[0, 0])))
11721185

11731186
# Per-grid-point SEs for dose-response curves
11741187
att_d_se = np.zeros(n_grid)
11751188
acrt_d_se = np.zeros(n_grid)
11761189
for d_idx in range(n_grid):
1177-
vcov_d = compute_survey_vcov(X_ones, if_att_d[:, d_idx], unit_resolved)
1190+
vcov_d = compute_survey_vcov(X_ones, if_att_d_tsl[:, d_idx], unit_resolved)
11781191
att_d_se[d_idx] = float(np.sqrt(np.abs(vcov_d[0, 0])))
11791192

1180-
vcov_d = compute_survey_vcov(X_ones, if_acrt_d[:, d_idx], unit_resolved)
1193+
vcov_d = compute_survey_vcov(X_ones, if_acrt_d_tsl[:, d_idx], unit_resolved)
11811194
acrt_d_se[d_idx] = float(np.sqrt(np.abs(vcov_d[0, 0])))
11821195
else:
11831196
# SE = sqrt(sum(IF_i^2)), matching CallawaySantAnna's convention
@@ -1188,11 +1201,15 @@ def _compute_analytical_se(
11881201
att_d_se = np.sqrt(np.sum(if_att_d**2, axis=0))
11891202
acrt_d_se = np.sqrt(np.sum(if_acrt_d**2, axis=0))
11901203

1204+
# Return unit-level survey df when available (for t-distribution inference)
1205+
unit_df_survey = unit_resolved.df_survey if resolved_survey is not None else None
1206+
11911207
return {
11921208
"overall_att_se": overall_att_se,
11931209
"overall_acrt_se": overall_acrt_se,
11941210
"att_d_se": att_d_se,
11951211
"acrt_d_se": acrt_d_se,
1212+
"df_survey": unit_df_survey,
11961213
}
11971214

11981215
def _run_bootstrap(

diff_diff/efficient_did.py

Lines changed: 53 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def __init__(
133133
self.kernel_bandwidth = kernel_bandwidth
134134
self.is_fitted_ = False
135135
self.results_: Optional[EfficientDiDResults] = None
136-
self._survey_se_ctx: Optional[tuple] = None
136+
self._unit_resolved_survey = None
137137
self._validate_params()
138138

139139
def _validate_params(self) -> None:
@@ -361,9 +361,45 @@ def fit(
361361
all_units = sorted(df[unit].unique())
362362
n_units = len(all_units)
363363

364-
# Build unit-to-first-panel-row index (for unit-level survey collapse)
365-
_first_rows = df.groupby(unit).cumcount() == 0
366-
self._unit_first_panel_row = np.where(_first_rows)[0]
364+
# Build unit-to-first-panel-row index aligned to all_units (sorted)
365+
# order. The previous approach (groupby cumcount == 0) yielded
366+
# first-appearance order which can differ from sorted order when the
367+
# input DataFrame is not pre-sorted by unit.
368+
first_pos: Dict[Any, int] = {}
369+
for i, u in enumerate(df[unit].values):
370+
if u not in first_pos:
371+
first_pos[u] = i
372+
self._unit_first_panel_row = np.array([first_pos[u] for u in all_units])
373+
374+
# Build unit-level ResolvedSurveyDesign once (avoids repeated
375+
# construction in _compute_survey_eif_se and ensures consistent
376+
# unit-level df for safe_inference t-distribution).
377+
if resolved_survey is not None:
378+
from diff_diff.survey import ResolvedSurveyDesign
379+
380+
row_idx = self._unit_first_panel_row
381+
unit_weights_s = resolved_survey.weights[row_idx]
382+
unit_strata = (
383+
resolved_survey.strata[row_idx] if resolved_survey.strata is not None else None
384+
)
385+
unit_psu = resolved_survey.psu[row_idx] if resolved_survey.psu is not None else None
386+
unit_fpc = resolved_survey.fpc[row_idx] if resolved_survey.fpc is not None else None
387+
n_strata_u = len(np.unique(unit_strata)) if unit_strata is not None else 0
388+
n_psu_u = len(np.unique(unit_psu)) if unit_psu is not None else 0
389+
self._unit_resolved_survey = ResolvedSurveyDesign(
390+
weights=unit_weights_s,
391+
weight_type=resolved_survey.weight_type,
392+
strata=unit_strata,
393+
psu=unit_psu,
394+
fpc=unit_fpc,
395+
n_strata=n_strata_u,
396+
n_psu=n_psu_u,
397+
lonely_psu=resolved_survey.lonely_psu,
398+
)
399+
# Use unit-level df (not panel-level) for t-distribution
400+
self._survey_df = self._unit_resolved_survey.df_survey
401+
else:
402+
self._unit_resolved_survey = None
367403

368404
period_to_col = {p: i for i, p in enumerate(time_periods)}
369405
period_1 = time_periods[0]
@@ -686,11 +722,8 @@ def fit(
686722

687723
# Analytical SE = sqrt(mean(EIF^2) / n) [paper p.21]
688724
# With survey: use TSL variance via compute_survey_vcov
689-
if resolved_survey is not None:
690-
se_gt = self._compute_survey_eif_se(
691-
eif_vals,
692-
resolved_survey,
693-
)
725+
if self._unit_resolved_survey is not None:
726+
se_gt = self._compute_survey_eif_se(eif_vals)
694727
else:
695728
se_gt = float(np.sqrt(np.mean(eif_vals**2) / n_units))
696729

@@ -714,12 +747,6 @@ def fit(
714747
"Check data has sufficient observations."
715748
)
716749

717-
# ----- Store survey context for aggregation SE helpers -----
718-
# Temporarily store survey context for use in aggregation helpers.
719-
# This avoids threading survey args through the deeply nested
720-
# aggregation methods that are also used by the bootstrap mixin.
721-
self._survey_se_ctx = resolved_survey if resolved_survey is not None else None
722-
723750
# ----- Aggregation -----
724751
overall_att, overall_se = self._aggregate_overall(
725752
group_time_effects, eif_by_gt, n_units, cohort_fractions, unit_cohorts
@@ -752,9 +779,6 @@ def fit(
752779
unit_cohorts=unit_cohorts,
753780
)
754781

755-
# Clean up temporary survey context
756-
self._survey_se_ctx = None
757-
758782
# ----- Bootstrap -----
759783
bootstrap_results = None
760784
if self.n_bootstrap > 0 and eif_by_gt:
@@ -855,63 +879,27 @@ def fit(
855879

856880
# -- Survey SE helpers ----------------------------------------------------
857881

858-
def _compute_survey_eif_se(
859-
self,
860-
eif_vals: np.ndarray,
861-
resolved_survey: Any,
862-
) -> float:
882+
def _compute_survey_eif_se(self, eif_vals: np.ndarray) -> float:
863883
"""Compute SE from EIF scores using Taylor Series Linearization.
864884
865-
The EIF is at unit level (shape n_units). We collapse the
866-
panel-level resolved survey to unit level using the first-panel-row
867-
index and pass unit-level arrays to ``compute_survey_vcov``.
868-
This avoids the previous bug where expanding EIF to panel rows
869-
created one implicit PSU per period-copy, deflating SEs for
870-
weights-only and stratified-no-PSU survey designs.
885+
Uses the pre-built unit-level ``_unit_resolved_survey`` constructed
886+
once in ``fit()``, ensuring consistent unit-level arrays and
887+
avoiding repeated subsetting of panel-level survey data.
871888
"""
872-
from diff_diff.survey import ResolvedSurveyDesign, compute_survey_vcov
889+
from diff_diff.survey import compute_survey_vcov
873890

874-
row_idx = self._unit_first_panel_row
875-
n_units = len(eif_vals)
876-
877-
# Subset survey arrays to unit level
878-
unit_weights = resolved_survey.weights[row_idx]
879-
unit_strata = (
880-
resolved_survey.strata[row_idx] if resolved_survey.strata is not None else None
881-
)
882-
unit_psu = resolved_survey.psu[row_idx] if resolved_survey.psu is not None else None
883-
unit_fpc = resolved_survey.fpc[row_idx] if resolved_survey.fpc is not None else None
884-
885-
# Count unique strata/PSU in the unit-level subset
886-
n_strata_unit = len(np.unique(unit_strata)) if unit_strata is not None else 0
887-
n_psu_unit = len(np.unique(unit_psu)) if unit_psu is not None else 0
888-
889-
unit_resolved = ResolvedSurveyDesign(
890-
weights=unit_weights,
891-
weight_type=resolved_survey.weight_type,
892-
strata=unit_strata,
893-
psu=unit_psu,
894-
fpc=unit_fpc,
895-
n_strata=n_strata_unit,
896-
n_psu=n_psu_unit,
897-
lonely_psu=resolved_survey.lonely_psu,
898-
)
899-
900-
X_ones = np.ones((n_units, 1))
901-
vcov = compute_survey_vcov(X_ones, eif_vals, unit_resolved)
891+
X_ones = np.ones((len(eif_vals), 1))
892+
vcov = compute_survey_vcov(X_ones, eif_vals, self._unit_resolved_survey)
902893
return float(np.sqrt(np.abs(vcov[0, 0])))
903894

904895
def _eif_se(self, eif_vals: np.ndarray, n_units: int) -> float:
905896
"""Compute SE from aggregated EIF scores.
906897
907-
Dispatches to survey TSL when ``_survey_se_ctx`` is set (during
908-
fit), otherwise uses the standard analytical formula.
898+
Dispatches to survey TSL when ``_unit_resolved_survey`` is set
899+
(during fit), otherwise uses the standard analytical formula.
909900
"""
910-
if self._survey_se_ctx is not None:
911-
return self._compute_survey_eif_se(
912-
eif_vals,
913-
self._survey_se_ctx,
914-
)
901+
if self._unit_resolved_survey is not None:
902+
return self._compute_survey_eif_se(eif_vals)
915903
return float(np.sqrt(np.mean(eif_vals**2) / n_units))
916904

917905
# -- Aggregation helpers --------------------------------------------------

diff_diff/stacked_did.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,14 @@ def fit(
459459
idx = interaction_indices[h]
460460
effect = float(coef[idx])
461461
se = float(np.sqrt(max(vcov[idx, idx], 0.0)))
462-
t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha)
462+
_survey_df = (
463+
max(survey_metadata.df_survey, 1)
464+
if survey_metadata is not None and survey_metadata.df_survey is not None
465+
else None
466+
)
467+
t_stat, p_value, conf_int = safe_inference(
468+
effect, se, alpha=self.alpha, df=_survey_df
469+
)
463470
n_obs_h = int(np.sum((et_vals == h) & (d_vals == 1)))
464471
event_study_effects[h] = {
465472
"effect": effect,
@@ -489,7 +496,14 @@ def fit(
489496
overall_att = np.nan
490497
overall_se = np.nan
491498

492-
overall_t, overall_p, overall_ci = safe_inference(overall_att, overall_se, alpha=self.alpha)
499+
_survey_df_overall = (
500+
max(survey_metadata.df_survey, 1)
501+
if survey_metadata is not None and survey_metadata.df_survey is not None
502+
else None
503+
)
504+
overall_t, overall_p, overall_ci = safe_inference(
505+
overall_att, overall_se, alpha=self.alpha, df=_survey_df_overall
506+
)
493507

494508
# ---- Construct results ----
495509
self.results_ = StackedDiDResults(

diff_diff/sun_abraham.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,15 @@ def fit(
625625
resolved_survey=resolved_survey,
626626
)
627627

628+
# Resolve survey weight column name for cohort aggregation
629+
survey_weight_col = (
630+
survey_design.weights
631+
if survey_design is not None
632+
and hasattr(survey_design, "weights")
633+
and survey_design.weights
634+
else None
635+
)
636+
628637
# Compute interaction-weighted event study effects
629638
event_study_effects, cohort_weights = self._compute_iw_effects(
630639
df,
@@ -636,6 +645,7 @@ def fit(
636645
cohort_ses,
637646
vcov_cohort,
638647
coef_index_map,
648+
survey_weight_col=survey_weight_col,
639649
)
640650

641651
# Compute overall ATT (average of post-treatment effects)
@@ -647,6 +657,7 @@ def fit(
647657
cohort_weights,
648658
vcov_cohort,
649659
coef_index_map,
660+
survey_weight_col=survey_weight_col,
650661
)
651662

652663
overall_t, overall_p, overall_ci = safe_inference(overall_att, overall_se, alpha=self.alpha)
@@ -869,6 +880,7 @@ def _compute_iw_effects(
869880
cohort_ses: Dict[Tuple[Any, int], float],
870881
vcov_cohort: np.ndarray,
871882
coef_index_map: Dict[Tuple[Any, int], int],
883+
survey_weight_col: Optional[str] = None,
872884
) -> Tuple[Dict[int, Dict[str, Any]], Dict[int, Dict[Any, float]]]:
873885
"""
874886
Compute interaction-weighted event study effects.
@@ -878,6 +890,10 @@ def _compute_iw_effects(
878890
where w_{g,e} = n_{g,e} / Σ_g n_{g,e} is the share of observations from cohort g
879891
at event-time e among all treated observations at that event-time.
880892
893+
When survey weights are provided, n_{g,e} is the survey-weighted mass
894+
(sum of weights) rather than raw observation counts, so the estimand
895+
reflects the survey-weighted cohort composition.
896+
881897
Returns
882898
-------
883899
event_study_effects : dict
@@ -888,8 +904,15 @@ def _compute_iw_effects(
888904
event_study_effects: Dict[int, Dict[str, Any]] = {}
889905
cohort_weights: Dict[int, Dict[Any, float]] = {}
890906

891-
# Pre-compute per-event-time observation counts: n_{g,e}
892-
event_time_counts = df[df[first_treat] > 0].groupby([first_treat, "_rel_time"]).size()
907+
# Pre-compute per-event-time observation mass: n_{g,e}
908+
# With survey weights, use weighted sum; otherwise raw counts.
909+
treated_mask = df[first_treat] > 0
910+
if survey_weight_col is not None and survey_weight_col in df.columns:
911+
event_time_counts = (
912+
df[treated_mask].groupby([first_treat, "_rel_time"])[survey_weight_col].sum()
913+
)
914+
else:
915+
event_time_counts = df[treated_mask].groupby([first_treat, "_rel_time"]).size()
893916

894917
for e in rel_periods:
895918
# Get cohorts that have observations at this relative time
@@ -951,23 +974,31 @@ def _compute_overall_att(
951974
cohort_weights: Dict[int, Dict[Any, float]],
952975
vcov_cohort: np.ndarray,
953976
coef_index_map: Dict[Tuple[Any, int], int],
977+
survey_weight_col: Optional[str] = None,
954978
) -> Tuple[float, float]:
955979
"""
956980
Compute overall ATT as weighted average of post-treatment effects.
957981
982+
When survey weights are provided, the per-period weights use
983+
survey-weighted mass rather than raw observation counts.
984+
958985
Returns (att, se) tuple.
959986
"""
960987
post_effects = [(e, eff) for e, eff in event_study_effects.items() if e >= 0]
961988

962989
if not post_effects:
963990
return np.nan, np.nan
964991

965-
# Weight by number of treated observations at each relative time
992+
# Weight by (survey-weighted) mass of treated observations at each relative time
966993
post_weights = []
967994
post_estimates = []
968995

969996
for e, eff in post_effects:
970-
n_at_e = len(df[(df["_rel_time"] == e) & (df[first_treat] > 0)])
997+
mask = (df["_rel_time"] == e) & (df[first_treat] > 0)
998+
if survey_weight_col is not None and survey_weight_col in df.columns:
999+
n_at_e = df.loc[mask, survey_weight_col].sum()
1000+
else:
1001+
n_at_e = len(df[mask])
9711002
post_weights.append(max(n_at_e, 1))
9721003
post_estimates.append(eff["effect"])
9731004

tests/test_sun_abraham.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1291,7 +1291,8 @@ def test_variance_fallback_warning(self):
12911291

12921292
def patched_compute_overall_att(df, first_treat, event_study_effects,
12931293
cohort_effects, cohort_weights,
1294-
vcov_cohort, coef_index_map):
1294+
vcov_cohort, coef_index_map,
1295+
survey_weight_col=None):
12951296
# Pass an empty coef_index_map to trigger the fallback
12961297
return original_method(
12971298
df, first_treat, event_study_effects,

0 commit comments

Comments
 (0)