Skip to content

Commit aeca1d8

Browse files
igerber and claude committed
Fix CI review R5: survey TWFE math consistency + zero-weight row filter
- P1 #1: _compute_twfe_diagnostic now uses cell_weight (w_gt when available, else n_gt) for the FE regressions, the normalization denominator, the contribution weights, and the Corollary 1 observation shares. On survey-backed inputs the outputs now match the observation-level pweighted TWFE estimand; the non-survey path is byte-identical.
- P1 #2: Zero-weight rows are dropped before the groupby in _validate_and_aggregate_to_cells when weights are provided, so that d_min/d_max/n_gt reflect the effective sample. This prevents zero-weight subpopulation rows from tripping the fuzzy-DiD guard or inflating downstream n_gt counts.
- P2: Two new regression tests in test_survey_dcdh.py:
  - TestSurveyTWFEOracle.test_survey_twfe_matches_obs_level_pweighted_ols verifies that beta_fe matches an observation-level pweighted OLS under survey (it would fail if n_gt were still used).
  - TestZeroWeightSubpopulation.test_mixed_zero_weight_row_excluded_from_validation verifies that an injected zero-weight row with the opposite treatment value doesn't trip the within-cell constancy check.

All 256 targeted tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 97789f5 commit aeca1d8

2 files changed

Lines changed: 131 additions & 12 deletions

File tree

diff_diff/chaisemartin_dhaultfoeuille.py

Lines changed: 37 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -210,6 +210,18 @@ def _validate_and_aggregate_to_cells(
210210

211211
# 5. Cell aggregation (compute min/max for within-cell check)
212212
if weights is not None:
213+
# Zero-weight rows are out-of-sample (e.g., via
214+
# SurveyDesign.subpopulation()). Pre-filter them before the
215+
# groupby so that d_min/d_max/n_gt reflect the effective sample
216+
# and a zero-weight obs cannot trip the within-cell treatment-
217+
# constancy check or inflate downstream n_gt counts.
218+
weights_arr = np.asarray(weights, dtype=np.float64)
219+
pos_mask = weights_arr > 0
220+
if not pos_mask.all():
221+
df = df.loc[pos_mask].reset_index(drop=True)
222+
weights_arr = weights_arr[pos_mask]
223+
weights = weights_arr
224+
213225
# Survey-weighted cell aggregation.
214226
# y_gt = sum(w_i * y_i) / sum(w_i) within each (g, t) cell.
215227
# Treatment is constant within cells (checked below), so weighted
@@ -4828,7 +4840,16 @@ def _compute_twfe_diagnostic(
48284840
"""
48294841
X, _ = _build_group_time_design(cell, group_col, time_col)
48304842
d_arr = cell["d_gt"].to_numpy().astype(float)
4831-
n_arr = cell["n_gt"].to_numpy().astype(float)
4843+
# Cell weight for Theorem 1: under survey_design, survey-weighted
4844+
# cell totals (w_gt) replace raw cell counts (n_gt) so the FE
4845+
# regressions, normalization denominator, and Corollary 1 shares
4846+
# match the observation-level pweighted TWFE estimand. Without
4847+
# survey_design (w_gt column absent), fall back to n_gt — the
4848+
# non-survey path is unchanged.
4849+
if "w_gt" in cell.columns:
4850+
cell_weight = cell["w_gt"].to_numpy().astype(float)
4851+
else:
4852+
cell_weight = cell["n_gt"].to_numpy().astype(float)
48324853
y_arr = cell["y_gt"].to_numpy().astype(float)
48334854

48344855
# Step 1-2: regress d on FE
@@ -4837,13 +4858,13 @@ def _compute_twfe_diagnostic(
48374858
d_arr,
48384859
return_vcov=False,
48394860
rank_deficient_action=rank_deficient_action,
4840-
weights=n_arr,
4861+
weights=cell_weight,
48414862
)
48424863
eps = residuals_d
48434864

48444865
# Step 3: per-cell weights — normalize by sum over treated cells
48454866
treated_mask = d_arr == 1
4846-
denom = float((n_arr[treated_mask] * eps[treated_mask]).sum())
4867+
denom = float((cell_weight[treated_mask] * eps[treated_mask]).sum())
48474868
if denom == 0:
48484869
# Cannot normalize: the design has zero treated mass after FE absorption.
48494870
# Warn so the user knows the diagnostic returned NaN values rather than
@@ -4866,12 +4887,14 @@ def _compute_twfe_diagnostic(
48664887
sigma_fe=float("nan"),
48674888
beta_fe=float("nan"),
48684889
)
4869-
w_gt = (n_arr * eps) / denom
4890+
contribution_weights = (cell_weight * eps) / denom
48704891

48714892
weights_df = cell[[group_col, time_col]].copy()
4872-
weights_df["weight"] = w_gt
4893+
weights_df["weight"] = contribution_weights
48734894

4874-
fraction_negative = float((w_gt[treated_mask] < 0).sum() / treated_mask.sum())
4895+
fraction_negative = float(
4896+
(contribution_weights[treated_mask] < 0).sum() / treated_mask.sum()
4897+
)
48754898

48764899
# Step 5: plain TWFE regression of y on (FE + d_gt)
48774900
X_with_d = np.column_stack([X, d_arr.reshape(-1, 1)])
@@ -4880,7 +4903,7 @@ def _compute_twfe_diagnostic(
48804903
y_arr,
48814904
return_vcov=False,
48824905
rank_deficient_action=rank_deficient_action,
4883-
weights=n_arr,
4906+
weights=cell_weight,
48844907
)
48854908
beta_fe = float(coef_fe[-1])
48864909

@@ -4897,12 +4920,14 @@ def _compute_twfe_diagnostic(
48974920
# sigma(w) = sqrt(sum_treated(s * (w_paper - 1)^2))
48984921
# sigma_fe = |beta_fe| / sigma(w)
48994922
#
4900-
# where s_{g,t} = N_{g,t} / N_1 are observation shares.
4923+
# where s_{g,t} = N_{g,t} / N_1 are observation shares (under
4924+
# survey_design, cell_weight is w_gt so shares are effective-
4925+
# weight shares; non-survey path is byte-identical).
49014926
eps_treated = eps[treated_mask]
4902-
n_treated_arr = n_arr[treated_mask]
4903-
n1 = float(n_treated_arr.sum()) # total treated observations
4904-
if n1 > 0:
4905-
shares = n_treated_arr / n1 # s_{g,t} = N_{g,t} / N_1
4927+
w_treated_arr = cell_weight[treated_mask]
4928+
w1 = float(w_treated_arr.sum()) # total treated weight (N_1 or W_1)
4929+
if w1 > 0:
4930+
shares = w_treated_arr / w1 # s_{g,t} = w_{g,t} / w_1
49064931
denom_paper = float((shares * eps_treated).sum())
49074932
if abs(denom_paper) > 0:
49084933
w_paper = eps_treated / denom_paper # paper's w_{g,t}

tests/test_survey_dcdh.py

Lines changed: 94 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -725,3 +725,97 @@ def test_twfe_helper_rejects_non_pweight(self, base_data):
725725
time="period", treatment="treatment",
726726
survey_design=sd,
727727
)
728+
729+
730+
# ── Test: TWFE diagnostic oracle under survey ───────────────────────
731+
732+
733+
class TestSurveyTWFEOracle:
734+
"""twfe_beta_fe under survey must match an observation-level pweighted
735+
TWFE regression on the same data (proving w_gt is used, not n_gt)."""
736+
737+
def test_survey_twfe_matches_obs_level_pweighted_ols(self, data_with_survey):
738+
from diff_diff.chaisemartin_dhaultfoeuille import twowayfeweights
739+
from diff_diff.linalg import solve_ols
740+
741+
sd = SurveyDesign(weights="pw")
742+
helper = twowayfeweights(
743+
data_with_survey,
744+
outcome="outcome", group="group",
745+
time="period", treatment="treatment",
746+
survey_design=sd,
747+
)
748+
assert np.isfinite(helper.beta_fe)
749+
750+
# Build observation-level TWFE design with group and period FE
751+
# (reference category dropped) and treatment indicator.
752+
df_ = data_with_survey.copy()
753+
groups_u = sorted(df_["group"].unique())
754+
periods_u = sorted(df_["period"].unique())
755+
g_map = {g: i for i, g in enumerate(groups_u)}
756+
t_map = {t: i for i, t in enumerate(periods_u)}
757+
g_idx = df_["group"].map(g_map).to_numpy()
758+
t_idx = df_["period"].map(t_map).to_numpy()
759+
n = len(df_)
760+
X_g = np.zeros((n, len(groups_u) - 1))
761+
X_t = np.zeros((n, len(periods_u) - 1))
762+
for i in range(n):
763+
if g_idx[i] > 0:
764+
X_g[i, g_idx[i] - 1] = 1.0
765+
if t_idx[i] > 0:
766+
X_t[i, t_idx[i] - 1] = 1.0
767+
intercept = np.ones((n, 1))
768+
treat = df_["treatment"].to_numpy().astype(float).reshape(-1, 1)
769+
X_obs = np.hstack([intercept, X_g, X_t, treat])
770+
y_obs = df_["outcome"].to_numpy().astype(float)
771+
w_obs = df_["pw"].to_numpy().astype(float)
772+
773+
coef, _, _ = solve_ols(
774+
X_obs, y_obs,
775+
weights=w_obs, weight_type="pweight",
776+
return_vcov=False,
777+
)
778+
beta_oracle = float(coef[-1])
779+
# Point-estimate match (one obs per cell in this fixture; so the
780+
# cell-level WLS with cell_weight == w_gt equals the obs-level
781+
# WLS with w_obs weights).
782+
assert helper.beta_fe == pytest.approx(beta_oracle, rel=1e-6), (
783+
f"helper.beta_fe={helper.beta_fe} oracle={beta_oracle} "
784+
f"— TWFE diagnostic must use w_gt under survey"
785+
)
786+
787+
788+
# ── Test: Zero-weight subpopulation exclusion ──────────────────────
789+
790+
791+
class TestZeroWeightSubpopulation:
792+
"""Zero-weight rows must not trip fuzzy-DiD guard or inflate counts."""
793+
794+
def test_mixed_zero_weight_row_excluded_from_validation(self, base_data):
795+
"""A cell with a positive-weight treated obs and a zero-weight
796+
obs with a different treatment value must fit cleanly — the
797+
zero-weight row is out-of-sample (SurveyDesign.subpopulation())."""
798+
df_ = base_data.copy()
799+
df_["pw"] = 1.0
800+
# Pick a treated (g, t) cell. Add a zero-weight row in the same
801+
# cell with the opposite treatment value. Unweighted d_min != d_max
802+
# would trip the fuzzy-DiD guard; pre-filtering zero-weight rows
803+
# must bypass it.
804+
treated_mask = df_["treatment"] == 1
805+
if not treated_mask.any():
806+
pytest.skip("no treated row in fixture")
807+
sample = df_[treated_mask].iloc[0].copy()
808+
# Flip treatment on the injected row, give it zero weight
809+
sample["treatment"] = 0
810+
sample["pw"] = 0.0
811+
df_ = pd.concat([df_, pd.DataFrame([sample])], ignore_index=True)
812+
sd = SurveyDesign(weights="pw")
813+
814+
# Must succeed (not raise fuzzy-DiD ValueError)
815+
result = ChaisemartinDHaultfoeuille(seed=1).fit(
816+
df_,
817+
outcome="outcome", group="group",
818+
time="period", treatment="treatment",
819+
survey_design=sd,
820+
)
821+
assert np.isfinite(result.overall_att)

0 commit comments

Comments
 (0)