Skip to content

Commit f04fe0d

Browse files
igerberclaude
andcommitted
Add fit-ready weight column mapping NaN precision to 0.0
The returned SurveyDesign now points at a {outcome}_weight column where NaN/Inf precision values are mapped to 0.0, so downstream fit() never rejects missing weights. Diagnostic *_precision column is preserved as-is. Add stage2-handoff test with single-observation cell (NaN precision → zero weight → fit succeeds). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 703a7fe commit f04fe0d

2 files changed

Lines changed: 69 additions & 5 deletions

File tree

diff_diff/prep.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1453,8 +1453,10 @@ def aggregate_survey(
14531453
panel_df : pd.DataFrame
14541454
Aggregated panel with columns: grouping variables,
14551455
``{outcome}_mean``, ``{outcome}_se``, ``{outcome}_n``,
1456-
``{outcome}_precision``, ``{covariate}_mean``, ``cell_n``,
1457-
``cell_n_eff``, ``srs_fallback``.
1456+
``{outcome}_precision``, ``{outcome}_weight``,
1457+
``{covariate}_mean``, ``cell_n``, ``cell_n_eff``,
1458+
``srs_fallback``. The ``_weight`` column is a fit-ready
1459+
version of ``_precision`` with NaN/Inf mapped to 0.0.
14581460
second_stage_design : SurveyDesign
14591461
Pre-configured for second-stage estimation with
14601462
``weight_type="aweight"``, precision weights from the first
@@ -1637,9 +1639,18 @@ def aggregate_survey(
16371639
panel_df = panel_df.sort_values(by_cols).reset_index(drop=True)
16381640

16391641
# --- Construct second-stage SurveyDesign ---
1642+
# Create a fit-ready weight column: NaN/Inf precision → 0.0 so downstream
1643+
# resolve() doesn't reject missing weights. Diagnostic *_precision is kept.
16401644
first_outcome = outcome_cols[0]
1645+
weight_col = f"{first_outcome}_weight"
1646+
panel_df[weight_col] = np.where(
1647+
np.isfinite(panel_df[f"{first_outcome}_precision"]),
1648+
panel_df[f"{first_outcome}_precision"],
1649+
0.0,
1650+
)
1651+
16411652
second_stage_design = SurveyDesign(
1642-
weights=f"{first_outcome}_precision",
1653+
weights=weight_col,
16431654
weight_type="aweight",
16441655
psu=by_cols[0],
16451656
)

tests/test_prep.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2057,7 +2057,7 @@ def test_multiple_outcomes(self, micro_data, design):
20572057
assert "y2_mean" in panel.columns
20582058
assert "y_precision" in panel.columns
20592059
assert "y2_precision" in panel.columns
2060-
assert stage2.weights == "y_precision"
2060+
assert stage2.weights == "y_weight"
20612061

20622062
def test_covariates_mean_only(self, micro_data, design):
20632063
"""Covariates get mean column only, no SE/precision."""
@@ -2081,7 +2081,7 @@ def test_returned_survey_design(self, micro_data, design):
20812081
survey_design=design,
20822082
)
20832083
assert stage2.weight_type == "aweight"
2084-
assert stage2.weights == "y_precision"
2084+
assert stage2.weights == "y_weight"
20852085
assert stage2.psu == "state"
20862086

20872087
def test_srs_fallback(self):
@@ -2452,6 +2452,59 @@ def test_error_all_missing_grouping_keys(self, design):
24522452
survey_design=design_simple,
24532453
)
24542454

2455+
def test_stage2_handoff_with_nonfinite_cells(self):
2456+
"""stage2 SurveyDesign works even when some cells have NaN precision."""
2457+
from diff_diff import DifferenceInDifferences
2458+
2459+
rng = np.random.RandomState(99)
2460+
rows = []
2461+
for state in range(4):
2462+
treated = 1 if state < 2 else 0
2463+
for period in [0, 1]:
2464+
te = 3.0 if (treated and period == 1) else 0.0
2465+
n_cell = 30
2466+
for _ in range(n_cell):
2467+
rows.append(
2468+
{
2469+
"state": state,
2470+
"period": period,
2471+
"wt": rng.uniform(0.5, 2.0),
2472+
"outcome": rng.normal(10 + te, 2),
2473+
"treated": treated,
2474+
}
2475+
)
2476+
micro = pd.DataFrame(rows)
2477+
# Make one cell have only 1 observation → NaN SE → NaN precision
2478+
mask = (micro["state"] == 0) & (micro["period"] == 0)
2479+
micro = micro.drop(micro[mask].index[1:]) # keep only 1 row
2480+
2481+
design = SurveyDesign(weights="wt")
2482+
panel, stage2 = aggregate_survey(
2483+
micro,
2484+
by=["state", "period"],
2485+
outcomes="outcome",
2486+
covariates="treated",
2487+
survey_design=design,
2488+
)
2489+
2490+
# The zero-variance cell should have weight=0 (not NaN)
2491+
cell_00 = panel[(panel["state"] == 0) & (panel["period"] == 0)]
2492+
assert np.isnan(cell_00["outcome_precision"].iloc[0]) # diagnostic
2493+
assert cell_00["outcome_weight"].iloc[0] == 0.0 # fit-ready
2494+
2495+
# stage2 should work with fit() despite NaN-precision cells
2496+
panel["treated_bin"] = (panel["treated_mean"] > 0.5).astype(int)
2497+
did = DifferenceInDifferences()
2498+
result = did.fit(
2499+
panel,
2500+
outcome="outcome_mean",
2501+
treatment="treated_bin",
2502+
time="period",
2503+
survey_design=stage2,
2504+
)
2505+
assert np.isfinite(result.att)
2506+
assert np.isfinite(result.se)
2507+
24552508
def test_zero_weight_rows_excluded_from_n_valid(self):
24562509
"""Zero-weight rows should not count as valid observations."""
24572510
rng = np.random.RandomState(66)

0 commit comments

Comments
 (0)