Skip to content

Commit b376043

Browse files
igerberclaude
andcommitted
Fix CI review R9: align controls aggregation to effective sample
R8's controls-block fix scoped NaN/Inf validation to the positive-weight subset (shorter data_controls) but then assigned those shorter arrays into an x_agg_input built from the full-length frame, causing a length-mismatch on any SurveyDesign.subpopulation() / zero-weight excluded row before covariate aggregation could run. Root-caused fix: derive both the validation frame AND the aggregation frame from the same positive-weight effective sample (data_eff, survey_weights_eff). Zero-weight rows are genuinely out-of-sample throughout the DID^X path now. Non-survey fits unchanged. Added TestZeroWeightSubpopulation.test_zero_weight_row_with_nan_control pinning the subpopulation contract for the DID^X path — injects a zero-weight row with NaN control value and asserts fit() succeeds. All 263 targeted tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent de8ff5e commit b376043

2 files changed

Lines changed: 37 additions & 12 deletions

File tree

diff_diff/chaisemartin_dhaultfoeuille.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -730,16 +730,18 @@ def fit(
730730
f"data. Available columns: {list(data.columns)}"
731731
)
732732
# SurveyDesign.subpopulation() contract: zero-weight rows are
733-
# out-of-sample. Scope NaN/Inf validation to positive-weight
734-
# rows so that excluded obs with missing covariates do not
735-
# abort the fit. The downstream weighted aggregation
736-
# (sum(w*x)/sum(w)) handles zero-weight rows correctly on
737-
# its own.
733+
# out-of-sample. Scope BOTH validation and aggregation to the
734+
# positive-weight subset so excluded rows with missing/invalid
735+
# covariates do not abort the fit and cell aggregation aligns
736+
# with the effective sample used by _validate_and_aggregate_to_cells.
738737
if survey_weights is not None:
739738
pos_mask_ctrl = np.asarray(survey_weights) > 0
740-
data_controls = data.loc[pos_mask_ctrl, controls].copy()
739+
data_eff = data.loc[pos_mask_ctrl]
740+
survey_weights_eff = np.asarray(survey_weights)[pos_mask_ctrl]
741741
else:
742-
data_controls = data[controls].copy()
742+
data_eff = data
743+
survey_weights_eff = None
744+
data_controls = data_eff[controls].copy()
743745
for c in controls:
744746
try:
745747
data_controls[c] = pd.to_numeric(data_controls[c])
@@ -760,14 +762,15 @@ def fit(
760762
"Remove or replace non-finite covariates before fitting."
761763
)
762764
# Aggregate covariates to cell means (same groupby as treatment/outcome).
763-
# Use the coerced copy joined with group/time from original data.
764-
x_agg_input = data[[group, time]].copy()
765+
# Build x_agg_input from the same effective-sample frame so rows
766+
# align with data_controls.
767+
x_agg_input = data_eff[[group, time]].copy()
765768
x_agg_input[controls] = data_controls[controls].values
766-
if survey_weights is not None:
769+
if survey_weights_eff is not None:
767770
# Survey-weighted covariate cell means: sum(w*x)/sum(w)
768-
x_agg_input["_w_"] = survey_weights
771+
x_agg_input["_w_"] = survey_weights_eff
769772
for c in controls:
770-
x_agg_input[f"_wx_{c}"] = survey_weights * x_agg_input[c].values
773+
x_agg_input[f"_wx_{c}"] = survey_weights_eff * x_agg_input[c].values
771774
wx_cols = [f"_wx_{c}" for c in controls]
772775
g_agg = x_agg_input.groupby([group, time], as_index=False).agg(
773776
{**{wc: "sum" for wc in wx_cols}, "_w_": "sum"}

tests/test_survey_dcdh.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,28 @@ def test_zero_weight_row_with_nan_outcome(self, base_data):
861861
)
862862
assert np.isfinite(result.overall_att)
863863

864+
def test_zero_weight_row_with_nan_control(self, base_data):
865+
"""A zero-weight row with NaN in a control column must not abort
866+
the DID^X path, and the covariate cell aggregation must use only
867+
positive-weight rows (no length-mismatch error)."""
868+
rng = np.random.default_rng(13)
869+
df_ = base_data.copy()
870+
df_["pw"] = 1.0
871+
df_["x"] = rng.normal(0, 1, size=len(df_))
872+
# Inject a zero-weight row with NaN control value
873+
sample = df_.iloc[0].copy()
874+
sample["x"] = np.nan
875+
sample["pw"] = 0.0
876+
df_ = pd.concat([df_, pd.DataFrame([sample])], ignore_index=True)
877+
sd = SurveyDesign(weights="pw")
878+
result = ChaisemartinDHaultfoeuille(seed=1).fit(
879+
df_,
880+
outcome="outcome", group="group",
881+
time="period", treatment="treatment",
882+
L_max=1, controls=["x"], survey_design=sd,
883+
)
884+
assert np.isfinite(result.overall_att)
885+
864886
def test_zero_weight_row_with_nan_heterogeneity(self, base_data):
865887
"""A zero-weight row with NaN in the heterogeneity column must
866888
not trip the heterogeneity time-invariance validator."""

0 commit comments

Comments
 (0)