Skip to content

Commit 3798bd4

Browse files
igerberclaude
andcommitted
Fix ContinuousDiD fweight TSL scaling and BaconDecomposition empty-cell guard from PR #226 review (round 6)
- ContinuousDiD: rescale IFs by unit-level total survey mass (unit_resolved.weights.sum()) instead of hard-coded n_units, so TSL SEs are correct for fweight designs where mass != n_units - BaconDecomposition: add empty-cell guard in _compute_treated_vs_never before np.average() to prevent crashes on unbalanced/filtered panels Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d26d498 commit 3798bd4

2 files changed

Lines changed: 20 additions & 9 deletions

File tree

diff_diff/bacon.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,15 @@ def _compute_treated_vs_never(
822822
never_pre_mask = never_mask & df[time].isin(pre_periods)
823823
never_post_mask = never_mask & df[time].isin(post_periods)
824824

825+
# Guard against empty cells (unbalanced/filtered panels)
826+
if not (
827+
np.any(treated_pre_mask)
828+
and np.any(treated_post_mask)
829+
and np.any(never_pre_mask)
830+
and np.any(never_post_mask)
831+
):
832+
return None
833+
825834
treated_pre = np.average(y[treated_pre_mask], weights=w[treated_pre_mask])
826835
treated_post = np.average(y[treated_post_mask], weights=w[treated_post_mask])
827836
never_pre = np.average(y[never_pre_mask], weights=w[never_pre_mask])

diff_diff/continuous_did.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -639,8 +639,9 @@ def fit(
639639
# Compute SE: survey-aware TSL or standard sqrt(sum(IF^2))
640640
if unit_resolved_es is not None:
641641
X_ones_es = np.ones((n_units, 1))
642-
# Rescale IFs from 1/n convention to score scale for TSL
643-
if_es_tsl = if_es * n_units
642+
# Rescale IFs by total survey mass (not n_units) for fweight support
643+
tsl_scale_es = float(unit_resolved_es.weights.sum())
644+
if_es_tsl = if_es * tsl_scale_es
644645
vcov_es = compute_survey_vcov(X_ones_es, if_es_tsl, unit_resolved_es)
645646
es_se = float(np.sqrt(np.abs(vcov_es[0, 0])))
646647
else:
@@ -1198,13 +1199,14 @@ def _compute_analytical_se(
11981199
# Rescale IFs from 1/n convention to score scale for TSL sandwich.
11991200
# The per-unit IFs contain internal 1/n_t, 1/n_c scaling (for the
12001201
# unweighted SE = sqrt(sum(IF^2)) convention). compute_survey_vcov
1201-
# applies its own (X'WX)^{-1} ≈ 1/n bread, which would double-count.
1202-
# Multiplying by n_units undoes the internal scaling so TSL gives
1203-
# the correct variance.
1204-
if_att_glob_tsl = if_att_glob * n_units
1205-
if_acrt_glob_tsl = if_acrt_glob * n_units
1206-
if_att_d_tsl = if_att_d * n_units
1207-
if_acrt_d_tsl = if_acrt_d * n_units
1202+
# applies its own (X'WX)^{-1} bread, which would double-count.
1203+
# Rescale by the unit-level total survey mass (= n_units for
1204+
# pweight/aweight, but can differ for fweight).
1205+
tsl_scale = float(unit_resolved.weights.sum())
1206+
if_att_glob_tsl = if_att_glob * tsl_scale
1207+
if_acrt_glob_tsl = if_acrt_glob * tsl_scale
1208+
if_att_d_tsl = if_att_d * tsl_scale
1209+
if_acrt_d_tsl = if_acrt_d * tsl_scale
12081210

12091211
# Overall ATT SE via compute_survey_vcov
12101212
vcov_att = compute_survey_vcov(X_ones, if_att_glob_tsl, unit_resolved)

0 commit comments

Comments
 (0)