From 2994d4961c9b0e19b756c30c7ce20c9c4cbd1906 Mon Sep 17 00:00:00 2001 From: igerber Date: Tue, 7 Apr 2026 09:18:20 -0400 Subject: [PATCH 01/16] Add aggregate_survey() for survey microdata-to-panel aggregation Adds a new prep utility that aggregates individual-level survey data (BRFSS, ACS, CPS) to geographic-period cells with design-based precision estimates. Returns a (DataFrame, SurveyDesign) tuple ready for second-stage DiD estimation with inverse-variance weights and geographic clustering. Computes Horvitz-Thompson influence functions per cell, routing to TSL or replicate variance via existing survey infrastructure. Falls back to SRS variance when within-cell design is too thin. Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/__init__.py | 2 + diff_diff/prep.py | 326 +++++++++++- docs/api/prep.rst | 42 ++ docs/methodology/REGISTRY.md | 26 + tests/test_prep.py | 944 ++++++++++++++++++++++++++--------- 5 files changed, 1088 insertions(+), 252 deletions(-) diff --git a/diff_diff/__init__.py b/diff_diff/__init__.py index f20be546..6aed7430 100644 --- a/diff_diff/__init__.py +++ b/diff_diff/__init__.py @@ -78,6 +78,7 @@ compute_pretrends_power, ) from diff_diff.prep import ( + aggregate_survey, aggregate_to_cohorts, balance_panel, create_event_time, @@ -328,6 +329,7 @@ "generate_survey_did_data", "generate_continuous_did_data", "create_event_time", + "aggregate_survey", "aggregate_to_cohorts", "rank_control_units", # Honest DiD sensitivity analysis diff --git a/diff_diff/prep.py b/diff_diff/prep.py index 97543576..a66fbd59 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -9,25 +9,30 @@ re-exported here for backward compatibility. 
""" -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd -from diff_diff.utils import compute_synthetic_weights - # Re-export data generation functions from prep_dgp for backward compatibility -from diff_diff.prep_dgp import ( +from diff_diff.prep_dgp import ( # noqa: F401 generate_continuous_did_data, + generate_ddd_data, generate_did_data, - generate_staggered_data, + generate_event_study_data, generate_factor_data, - generate_ddd_data, generate_panel_data, - generate_event_study_data, + generate_staggered_data, generate_staggered_ddd_data, generate_survey_did_data, ) +from diff_diff.survey import ( + ResolvedSurveyDesign, + SurveyDesign, + compute_replicate_if_variance, + compute_survey_if_variance, +) +from diff_diff.utils import compute_synthetic_weights # Constants for rank_control_units _SIMILARITY_THRESHOLD_SD = 0.5 # Controls within this many SDs are "similar" @@ -1300,3 +1305,310 @@ def trim_weights( result[weight_col] = w return result + + +# --------------------------------------------------------------------------- +# Survey aggregation helpers +# --------------------------------------------------------------------------- + + +def _cell_mean_variance( + y: np.ndarray, + weights: np.ndarray, + cell_resolved: ResolvedSurveyDesign, +) -> Tuple[float, float, int, bool]: + """Compute design-based mean and variance of the weighted mean for one cell. + + Parameters + ---------- + y : np.ndarray + Outcome values for the cell (may contain NaN). + weights : np.ndarray + Resolved weights for the cell (already extracted from ResolvedSurveyDesign). + cell_resolved : ResolvedSurveyDesign + Resolved survey design subsetted to this cell. + + Returns + ------- + mean : float + Design-weighted cell mean. + variance : float + Design-based variance of the cell mean (>= 0). Uses SRS fallback + when the design-based estimate is unidentifiable. 
+ n_valid : int + Number of non-missing observations. + used_srs_fallback : bool + True if SRS variance was used instead of design-based. + """ + valid = ~np.isnan(y) + n_valid = int(np.sum(valid)) + + if n_valid == 0: + return np.nan, np.nan, 0, False + + if n_valid == 1: + y_bar = float(y[valid][0]) + return y_bar, np.nan, 1, False + + # Zero out weights for NaN observations (subpopulation approach) + w = weights.copy() + y_clean = np.where(valid, y, 0.0) + w_valid = w * valid.astype(np.float64) + sum_w = np.sum(w_valid) + + if sum_w <= 0: + return np.nan, np.nan, n_valid, False + + # Design-weighted mean + y_bar = float(np.sum(w_valid * y_clean) / sum_w) + + # Influence function: psi_i = w_i * (y_i - y_bar) / sum(w) + psi = w_valid * (y_clean - y_bar) / sum_w + + # Route to TSL or replicate variance + used_srs = False + if cell_resolved.uses_replicate_variance: + variance, _ = compute_replicate_if_variance(psi, cell_resolved) + else: + variance = compute_survey_if_variance(psi, cell_resolved) + + # SRS fallback when design-based variance is unidentifiable + if np.isnan(variance): + resid_sq = w_valid * (y_clean - y_bar) ** 2 + variance = float(np.sum(resid_sq) / (sum_w**2) * n_valid / (n_valid - 1)) + used_srs = True + + return y_bar, max(float(variance), 0.0), n_valid, used_srs + + +def aggregate_survey( + data: pd.DataFrame, + by: Union[str, List[str]], + outcomes: Union[str, List[str]], + survey_design: SurveyDesign, + covariates: Optional[Union[str, List[str]]] = None, + min_n: int = 2, + lonely_psu: Optional[str] = None, +) -> Tuple[pd.DataFrame, SurveyDesign]: + """Aggregate survey microdata to geographic-period cells with design-based precision. + + Computes design-weighted cell means and their Taylor-linearized (or + replicate-based) standard errors for each cell defined by the ``by`` + columns. Returns a panel-ready DataFrame with precision weights and a + pre-configured :class:`SurveyDesign` for second-stage DiD estimation. 
+ + This follows R's ``survey::svyby()`` pattern: the survey design is + subsetted to each cell and domain-level statistics are computed using + the within-cell strata/PSU structure. + + Parameters + ---------- + data : pd.DataFrame + Individual-level microdata. + by : str or list of str + Columns defining cells (e.g., ``["state", "year"]``). The first + element is used as the clustering variable in the returned + SurveyDesign (geographic unit for second-stage inference). + outcomes : str or list of str + Outcome variable(s) to aggregate with full precision tracking. + Each outcome produces ``{name}_mean``, ``{name}_se``, + ``{name}_n``, and ``{name}_precision`` columns. + survey_design : SurveyDesign + Survey design specification for the microdata. + covariates : str or list of str, optional + Additional variables to aggregate as design-weighted means only + (no SE/precision columns). + min_n : int, default 2 + Minimum respondents per cell; must be >= 1. Currently only validated: + the SRS fallback is triggered by an inestimable design-based variance. + lonely_psu : str, optional + Override the survey design's ``lonely_psu`` setting for within-cell + computation. One of ``"remove"``, ``"certainty"``, ``"adjust"``. + + Returns + ------- + panel_df : pd.DataFrame + Aggregated panel with columns: grouping variables, + ``{outcome}_mean``, ``{outcome}_se``, ``{outcome}_n``, + ``{outcome}_precision``, ``{covariate}_mean``, ``cell_n``, + ``cell_n_eff``, ``srs_fallback``. + second_stage_design : SurveyDesign + Pre-configured for second-stage estimation with + ``weight_type="aweight"``, precision weights from the first + outcome, and geographic clustering via ``psu``. + + Examples + -------- + >>> design = SurveyDesign(weights="finalwt", strata="strat", psu="psu") + >>> panel, stage2 = aggregate_survey( + ... microdata, by=["state", "year"], + ... outcomes="smoking_rate", survey_design=design, + ... ) + >>> result = DifferenceInDifferences().fit( + ... panel, outcome="smoking_rate_mean", + ... 
treatment="treated", time="post", survey_design=stage2, + ... ) + """ + import warnings + from dataclasses import replace + + # --- Normalize inputs --- + by_cols = [by] if isinstance(by, str) else list(by) + outcome_cols = [outcomes] if isinstance(outcomes, str) else list(outcomes) + cov_cols = ( + [covariates] if isinstance(covariates, str) else list(covariates) if covariates else [] + ) + + # --- Validate --- + all_cols = by_cols + outcome_cols + cov_cols + missing = [c for c in all_cols if c not in data.columns] + if missing: + raise ValueError(f"Columns not found in DataFrame: {missing}") + + overlap = set(by_cols) & (set(outcome_cols) | set(cov_cols)) + if overlap: + raise ValueError(f"Columns appear in both 'by' and outcomes/covariates: {overlap}") + + if not isinstance(survey_design, SurveyDesign): + raise TypeError( + f"survey_design must be a SurveyDesign instance, got {type(survey_design).__name__}" + ) + + if min_n < 1: + raise ValueError(f"min_n must be >= 1, got {min_n}") + + if lonely_psu is not None and lonely_psu not in ("remove", "certainty", "adjust"): + raise ValueError( + f"lonely_psu must be 'remove', 'certainty', or 'adjust', got '{lonely_psu}'" + ) + + # --- Resolve design once on full data --- + effective_design = ( + replace(survey_design, lonely_psu=lonely_psu) if lonely_psu else survey_design + ) + full_resolved = effective_design.resolve(data) + + # --- Per-cell computation --- + grouped = data.groupby(by_cols, sort=True) + rows: List[Dict[str, Any]] = [] + srs_cells: List[str] = [] + zero_var_cells: List[str] = [] + + for cell_key, cell_df in grouped: + cell_idx = np.array(cell_df.index) + # Convert to positional indices for array subsetting + pos_idx = data.index.get_indexer(cell_idx) + + cell_n = len(pos_idx) + cell_key_str = str(cell_key) + + # Subset arrays from full resolved design + cell_w = full_resolved.weights[pos_idx] + cell_strata = full_resolved.strata[pos_idx] if full_resolved.strata is not None else None + cell_psu = 
full_resolved.psu[pos_idx] if full_resolved.psu is not None else None + cell_fpc = full_resolved.fpc[pos_idx] if full_resolved.fpc is not None else None + + cell_n_strata = int(len(np.unique(cell_strata))) if cell_strata is not None else 0 + cell_n_psu = int(len(np.unique(cell_psu))) if cell_psu is not None else 0 + + cell_resolved = full_resolved.subset_to_units( + row_idx=pos_idx, + weights=cell_w, + strata=cell_strata, + psu=cell_psu, + fpc=cell_fpc, + n_strata=cell_n_strata, + n_psu=cell_n_psu, + ) + + # Cell-level statistics + sum_w = float(np.sum(cell_w)) + sum_w2 = float(np.sum(cell_w**2)) + cell_n_eff = (sum_w**2 / sum_w2) if sum_w2 > 0 else 0.0 + + # Build row dict with grouping columns + row: Dict[str, Any] = {} + if len(by_cols) == 1: + row[by_cols[0]] = cell_key + else: + for i, col in enumerate(by_cols): + row[col] = cell_key[i] + + row["cell_n"] = cell_n + row["cell_n_eff"] = cell_n_eff + + cell_srs_fallback = False + + # Outcomes: mean + SE + n + precision + for var in outcome_cols: + y = cell_df[var].values.astype(np.float64) + y_bar, variance, n_valid, used_srs = _cell_mean_variance(y, cell_w, cell_resolved) + se = float(np.sqrt(variance)) if not np.isnan(variance) else np.nan + + if used_srs: + cell_srs_fallback = True + + # Zero variance → precision NaN + if se == 0.0: + precision = np.nan + zero_var_cells.append(cell_key_str) + elif np.isnan(se): + precision = np.nan + else: + precision = 1.0 / variance + + row[f"{var}_mean"] = y_bar + row[f"{var}_se"] = se + row[f"{var}_n"] = n_valid + row[f"{var}_precision"] = precision + + # Covariates: mean only + for var in cov_cols: + y = cell_df[var].values.astype(np.float64) + valid = ~np.isnan(y) + w_valid = cell_w * valid.astype(np.float64) + sw = float(np.sum(w_valid)) + if sw > 0: + row[f"{var}_mean"] = float(np.sum(w_valid * np.where(valid, y, 0.0)) / sw) + else: + row[f"{var}_mean"] = np.nan + + row["srs_fallback"] = cell_srs_fallback + if cell_srs_fallback: + srs_cells.append(cell_key_str) + + 
rows.append(row) + + # --- Warnings --- + if srs_cells: + warnings.warn( + f"Design-based variance not estimable for {len(srs_cells)} cell(s); " + f"using SRS fallback: {srs_cells[:5]}" + + (f" ... and {len(srs_cells) - 5} more" if len(srs_cells) > 5 else ""), + UserWarning, + stacklevel=2, + ) + if zero_var_cells: + warnings.warn( + f"Zero variance in {len(zero_var_cells)} cell(s) (precision set to NaN): " + f"{zero_var_cells[:5]}" + + (f" ... and {len(zero_var_cells) - 5} more" if len(zero_var_cells) > 5 else ""), + UserWarning, + stacklevel=2, + ) + + # --- Assemble output --- + panel_df = pd.DataFrame(rows) + + # Sort by grouping columns + panel_df = panel_df.sort_values(by_cols).reset_index(drop=True) + + # --- Construct second-stage SurveyDesign --- + first_outcome = outcome_cols[0] + second_stage_design = SurveyDesign( + weights=f"{first_outcome}_precision", + weight_type="aweight", + psu=by_cols[0], + ) + + return panel_df, second_stage_design diff --git a/docs/api/prep.rst b/docs/api/prep.rst index 462e82d0..45406f7f 100644 --- a/docs/api/prep.rst +++ b/docs/api/prep.rst @@ -250,6 +250,48 @@ Example outcome='outcome' ) +Survey Aggregation +------------------ + +aggregate_survey +~~~~~~~~~~~~~~~~ + +Aggregate survey microdata to geographic-period cells with design-based precision. + +.. autofunction:: diff_diff.aggregate_survey + +Example +^^^^^^^ + +.. 
code-block:: python + + from diff_diff import aggregate_survey, SurveyDesign, DifferenceInDifferences + + # Define the survey design for the microdata + design = SurveyDesign(weights="finalwt", strata="strat", psu="psu") + + # Aggregate to state-year panel with design-based SEs + panel, stage2 = aggregate_survey( + microdata, + by=["state", "year"], + outcomes="smoking_rate", + covariates=["age", "income"], + survey_design=design, + ) + + # panel has: state, year, smoking_rate_mean, smoking_rate_se, + # smoking_rate_n, smoking_rate_precision, age_mean, income_mean, + # cell_n, cell_n_eff, srs_fallback + + # stage2 is pre-configured: aweights + state-level clustering + result = DifferenceInDifferences().fit( + panel, + outcome="smoking_rate_mean", + treatment="treated", + time="post", + survey_design=stage2, + ) + Data Validation --------------- diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 42f0fa5f..edd49e08 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -2303,6 +2303,32 @@ unequal selection probabilities). Linearization variance estimation, matching the R `survey` package convention that clusters are the primary sampling units. +### Survey Aggregation (`aggregate_survey`) + +Aggregation of individual-level survey microdata to geographic-period cells with +design-based precision estimates, for use as a pre-processing step before panel +DiD estimation on repeated cross-section survey data. + +- **Reference**: Lumley (2004) "Analysis of Complex Survey Samples", Journal of + Statistical Software 9(8). R `survey::svyby()` implements similar per-group + survey estimation. +- **Cell mean**: Design-weighted mean `ȳ_g = Σ w_i y_i / Σ w_i` for each cell g + defined by grouping columns (e.g., state × year). 
+- **Cell variance**: Linearized influence function `ψ_i = w_i (y_i - ȳ_g) / Σ w_j`, + then design-based variance via `compute_survey_if_variance()` (TSL) or + `compute_replicate_if_variance()` (replicate designs). This is the standard + Horvitz-Thompson linearization for a ratio estimator. +- **Precision weight**: `1 / V(ȳ_g)` used as inverse-variance weight (aweight) + in second-stage DiD estimation. +- **Note:** SRS fallback when design-based variance is unidentifiable within a cell + (e.g., all strata have singleton PSUs after cell subsetting). Formula: + `V_SRS = Σ w_i(y_i - ȳ)² / (Σ w_j)² × n/(n-1)`. Cells using SRS fallback + are flagged via `srs_fallback` column. +- **Note:** FPC values are passed through unchanged from the full design to cell + subsets — they represent population N_h per stratum, not per cell. +- **Edge case**: Zero-variance cells (all observations identical) set precision to + NaN to avoid infinite weights in second-stage WLS. + ### Survey-Aware Bootstrap (Phase 6) Two strategies for bootstrap variance under complex survey designs: diff --git a/tests/test_prep.py b/tests/test_prep.py index 0674186a..1a5d8216 100644 --- a/tests/test_prep.py +++ b/tests/test_prep.py @@ -7,6 +7,7 @@ import pytest from diff_diff.prep import ( + aggregate_survey, aggregate_to_cohorts, balance_panel, create_event_time, @@ -17,6 +18,7 @@ validate_did_data, wide_to_long, ) +from diff_diff.survey import SurveyDesign class TestMakeTreatmentIndicator: @@ -102,10 +104,9 @@ def test_treatment_start(self): def test_datetime_column(self): """Test with datetime column.""" - df = pd.DataFrame({ - "date": pd.to_datetime(["2020-01-01", "2020-06-01", "2021-01-01"]), - "y": [1, 2, 3] - }) + df = pd.DataFrame( + {"date": pd.to_datetime(["2020-01-01", "2020-06-01", "2021-01-01"]), "y": [1, 2, 3]} + ) result = make_post_indicator(df, "date", treatment_start="2020-06-01") assert result["post"].tolist() == [0, 1, 1] @@ -133,50 +134,36 @@ class TestWideToLong: def 
test_basic_conversion(self): """Test basic wide to long conversion.""" - wide_df = pd.DataFrame({ - "firm_id": [1, 2], - "sales_2019": [100, 150], - "sales_2020": [110, 160], - "sales_2021": [120, 170] - }) + wide_df = pd.DataFrame( + { + "firm_id": [1, 2], + "sales_2019": [100, 150], + "sales_2020": [110, 160], + "sales_2021": [120, 170], + } + ) result = wide_to_long( wide_df, value_columns=["sales_2019", "sales_2020", "sales_2021"], id_column="firm_id", time_name="year", - value_name="sales" + value_name="sales", ) assert len(result) == 6 assert set(result.columns) == {"firm_id", "year", "sales"} def test_with_time_values(self): """Test with explicit time values.""" - wide_df = pd.DataFrame({ - "id": [1], - "t1": [10], - "t2": [20] - }) + wide_df = pd.DataFrame({"id": [1], "t1": [10], "t2": [20]}) result = wide_to_long( - wide_df, - value_columns=["t1", "t2"], - id_column="id", - time_values=[2020, 2021] + wide_df, value_columns=["t1", "t2"], id_column="id", time_values=[2020, 2021] ) assert result["period"].tolist() == [2020, 2021] def test_preserves_other_columns(self): """Test that other columns are preserved.""" - wide_df = pd.DataFrame({ - "id": [1, 2], - "group": ["A", "B"], - "t1": [10, 20], - "t2": [15, 25] - }) - result = wide_to_long( - wide_df, - value_columns=["t1", "t2"], - id_column="id" - ) + wide_df = pd.DataFrame({"id": [1, 2], "group": ["A", "B"], "t1": [10, 20], "t2": [15, 25]}) + result = wide_to_long(wide_df, value_columns=["t1", "t2"], id_column="id") assert "group" in result.columns assert result[result["id"] == 1]["group"].tolist() == ["A", "A"] @@ -198,32 +185,26 @@ class TestBalancePanel: def test_inner_balance(self): """Test inner balance (keep complete units only).""" - df = pd.DataFrame({ - "unit": [1, 1, 1, 2, 2, 3, 3, 3], - "period": [1, 2, 3, 1, 2, 1, 2, 3], - "y": [10, 11, 12, 20, 21, 30, 31, 32] - }) + df = pd.DataFrame( + { + "unit": [1, 1, 1, 2, 2, 3, 3, 3], + "period": [1, 2, 3, 1, 2, 1, 2, 3], + "y": [10, 11, 12, 20, 21, 30, 
31, 32], + } + ) result = balance_panel(df, "unit", "period", method="inner") assert set(result["unit"].unique()) == {1, 3} assert len(result) == 6 def test_outer_balance(self): """Test outer balance (include all combinations).""" - df = pd.DataFrame({ - "unit": [1, 1, 2], - "period": [1, 2, 1], - "y": [10, 11, 20] - }) + df = pd.DataFrame({"unit": [1, 1, 2], "period": [1, 2, 1], "y": [10, 11, 20]}) result = balance_panel(df, "unit", "period", method="outer") assert len(result) == 4 # 2 units x 2 periods def test_fill_with_value(self): """Test fill method with specific value.""" - df = pd.DataFrame({ - "unit": [1, 1, 2], - "period": [1, 2, 1], - "y": [10.0, 11.0, 20.0] - }) + df = pd.DataFrame({"unit": [1, 1, 2], "period": [1, 2, 1], "y": [10.0, 11.0, 20.0]}) result = balance_panel(df, "unit", "period", method="fill", fill_value=0.0) assert len(result) == 4 missing_row = result[(result["unit"] == 2) & (result["period"] == 2)] @@ -231,11 +212,13 @@ def test_fill_with_value(self): def test_fill_forward_backward(self): """Test fill method with forward/backward fill.""" - df = pd.DataFrame({ - "unit": [1, 1, 1, 2, 2], - "period": [1, 2, 3, 1, 3], # Unit 2 missing period 2 - "y": [10.0, 11.0, 12.0, 20.0, 22.0] - }) + df = pd.DataFrame( + { + "unit": [1, 1, 1, 2, 2], + "period": [1, 2, 3, 1, 3], # Unit 2 missing period 2 + "y": [10.0, 11.0, 12.0, 20.0, 22.0], + } + ) result = balance_panel(df, "unit", "period", method="fill", fill_value=None) assert len(result) == 6 # Check that unit 2, period 2 was filled @@ -255,11 +238,9 @@ class TestValidateDidData: def test_valid_data(self): """Test validation of valid data.""" - df = pd.DataFrame({ - "y": [1.0, 2.0, 3.0, 4.0], - "treated": [0, 0, 1, 1], - "post": [0, 1, 0, 1] - }) + df = pd.DataFrame( + {"y": [1.0, 2.0, 3.0, 4.0], "treated": [0, 0, 1, 1], "post": [0, 1, 0, 1]} + ) result = validate_did_data(df, "y", "treated", "post", raise_on_error=False) assert result["valid"] is True assert len(result["errors"]) == 0 @@ -273,33 
+254,25 @@ def test_missing_column(self): def test_non_numeric_outcome(self): """Test validation catches non-numeric outcome.""" - df = pd.DataFrame({ - "y": ["a", "b", "c", "d"], - "treated": [0, 0, 1, 1], - "post": [0, 1, 0, 1] - }) + df = pd.DataFrame( + {"y": ["a", "b", "c", "d"], "treated": [0, 0, 1, 1], "post": [0, 1, 0, 1]} + ) result = validate_did_data(df, "y", "treated", "post", raise_on_error=False) assert result["valid"] is False assert any("numeric" in e for e in result["errors"]) def test_non_binary_treatment(self): """Test validation catches non-binary treatment.""" - df = pd.DataFrame({ - "y": [1.0, 2.0, 3.0], - "treated": [0, 1, 2], - "post": [0, 1, 0] - }) + df = pd.DataFrame({"y": [1.0, 2.0, 3.0], "treated": [0, 1, 2], "post": [0, 1, 0]}) result = validate_did_data(df, "y", "treated", "post", raise_on_error=False) assert result["valid"] is False assert any("binary" in e for e in result["errors"]) def test_missing_values(self): """Test validation catches missing values.""" - df = pd.DataFrame({ - "y": [1.0, np.nan, 3.0, 4.0], - "treated": [0, 0, 1, 1], - "post": [0, 1, 0, 1] - }) + df = pd.DataFrame( + {"y": [1.0, np.nan, 3.0, 4.0], "treated": [0, 0, 1, 1], "post": [0, 1, 0, 1]} + ) result = validate_did_data(df, "y", "treated", "post", raise_on_error=False) assert result["valid"] is False assert any("missing" in e for e in result["errors"]) @@ -312,12 +285,14 @@ def test_raises_on_error(self): def test_panel_validation(self): """Test panel-specific validation.""" - df = pd.DataFrame({ - "y": [1.0, 2.0, 3.0, 4.0], - "treated": [0, 0, 1, 1], - "post": [0, 1, 0, 1], - "unit": [1, 1, 2, 2] - }) + df = pd.DataFrame( + { + "y": [1.0, 2.0, 3.0, 4.0], + "treated": [0, 0, 1, 1], + "post": [0, 1, 0, 1], + "unit": [1, 1, 2, 2], + } + ) result = validate_did_data(df, "y", "treated", "post", unit="unit", raise_on_error=False) assert result["valid"] is True assert result["summary"]["n_units"] == 2 @@ -328,21 +303,25 @@ class TestSummarizeDidData: def 
test_basic_summary(self): """Test basic summary statistics.""" - df = pd.DataFrame({ - "y": [10, 11, 12, 13, 20, 21, 22, 23], - "treated": [0, 0, 1, 1, 0, 0, 1, 1], - "post": [0, 1, 0, 1, 0, 1, 0, 1] - }) + df = pd.DataFrame( + { + "y": [10, 11, 12, 13, 20, 21, 22, 23], + "treated": [0, 0, 1, 1, 0, 0, 1, 1], + "post": [0, 1, 0, 1, 0, 1, 0, 1], + } + ) summary = summarize_did_data(df, "y", "treated", "post") assert len(summary) == 5 # 4 groups + DiD estimate def test_did_estimate_included(self): """Test that DiD estimate is calculated.""" - df = pd.DataFrame({ - "y": [10, 20, 15, 30], # Perfect DiD = 30-15 - (20-10) = 5 - "treated": [0, 0, 1, 1], - "post": [0, 1, 0, 1] - }) + df = pd.DataFrame( + { + "y": [10, 20, 15, 30], # Perfect DiD = 30-15 - (20-10) = 5 + "treated": [0, 0, 1, 1], + "post": [0, 1, 0, 1], + } + ) summary = summarize_did_data(df, "y", "treated", "post") assert "DiD Estimate" in summary.index assert summary.loc["DiD Estimate", "mean"] == 5.0 @@ -369,11 +348,7 @@ def test_treatment_effect_recovery(self): true_effect = 5.0 data = generate_did_data( - n_units=200, - n_periods=4, - treatment_effect=true_effect, - noise_sd=0.5, - seed=42 + n_units=200, n_periods=4, treatment_effect=true_effect, noise_sd=0.5, seed=42 ) did = DifferenceInDifferences() @@ -405,21 +380,25 @@ class TestCreateEventTime: def test_basic_event_time(self): """Test basic event time calculation.""" - df = pd.DataFrame({ - "unit": [1, 1, 1, 2, 2, 2], - "year": [2018, 2019, 2020, 2018, 2019, 2020], - "treatment_year": [2019, 2019, 2019, 2020, 2020, 2020] - }) + df = pd.DataFrame( + { + "unit": [1, 1, 1, 2, 2, 2], + "year": [2018, 2019, 2020, 2018, 2019, 2020], + "treatment_year": [2019, 2019, 2019, 2020, 2020, 2020], + } + ) result = create_event_time(df, "year", "treatment_year") assert result["event_time"].tolist() == [-1, 0, 1, -2, -1, 0] def test_never_treated(self): """Test handling of never-treated units.""" - df = pd.DataFrame({ - "unit": [1, 1, 2, 2], - "year": [2019, 2020, 
2019, 2020], - "treatment_year": [2020, 2020, np.nan, np.nan] - }) + df = pd.DataFrame( + { + "unit": [1, 1, 2, 2], + "year": [2019, 2020, 2019, 2020], + "treatment_year": [2020, 2020, np.nan, np.nan], + } + ) result = create_event_time(df, "year", "treatment_year") assert result.loc[0, "event_time"] == -1 assert result.loc[1, "event_time"] == 0 @@ -428,10 +407,7 @@ def test_never_treated(self): def test_custom_column_name(self): """Test custom output column name.""" - df = pd.DataFrame({ - "year": [2019, 2020], - "treat_time": [2020, 2020] - }) + df = pd.DataFrame({"year": [2019, 2020], "treat_time": [2020, 2020]}) result = create_event_time(df, "year", "treat_time", new_column="rel_time") assert "rel_time" in result.columns @@ -441,12 +417,14 @@ class TestAggregateToCohorts: def test_basic_aggregation(self): """Test basic cohort aggregation.""" - df = pd.DataFrame({ - "unit": [1, 1, 2, 2, 3, 3, 4, 4], - "period": [0, 1, 0, 1, 0, 1, 0, 1], - "treated": [1, 1, 1, 1, 0, 0, 0, 0], - "y": [10, 15, 12, 17, 8, 10, 9, 11] - }) + df = pd.DataFrame( + { + "unit": [1, 1, 2, 2, 3, 3, 4, 4], + "period": [0, 1, 0, 1, 0, 1, 0, 1], + "treated": [1, 1, 1, 1, 0, 0, 0, 0], + "y": [10, 15, 12, 17, 8, 10, 9, 11], + } + ) result = aggregate_to_cohorts(df, "unit", "period", "treated", "y") assert len(result) == 4 # 2 treatment groups x 2 periods assert "mean_y" in result.columns @@ -454,13 +432,15 @@ def test_basic_aggregation(self): def test_with_covariates(self): """Test aggregation with covariates.""" - df = pd.DataFrame({ - "unit": [1, 1, 2, 2], - "period": [0, 1, 0, 1], - "treated": [1, 1, 0, 0], - "y": [10, 15, 8, 10], - "x": [1.0, 1.5, 0.5, 0.8] - }) + df = pd.DataFrame( + { + "unit": [1, 1, 2, 2], + "period": [0, 1, 0, 1], + "treated": [1, 1, 0, 0], + "y": [10, 15, 8, 10], + "x": [1.0, 1.5, 0.5, 0.8], + } + ) result = aggregate_to_cohorts(df, "unit", "period", "treated", "y", covariates=["x"]) assert "x" in result.columns @@ -478,7 +458,7 @@ def test_basic_ranking(self): 
unit_column="unit", time_column="period", outcome_column="outcome", - treatment_column="treated" + treatment_column="treated", ) assert "quality_score" in result.columns assert "outcome_trend_score" in result.columns @@ -502,7 +482,7 @@ def test_with_covariates(self): time_column="period", outcome_column="outcome", treatment_column="treated", - covariates=["x1"] + covariates=["x1"], ) assert not result["covariate_score"].isna().all() @@ -517,7 +497,7 @@ def test_explicit_treated_units(self): unit_column="unit", time_column="period", outcome_column="outcome", - treated_units=[0, 1, 2] + treated_units=[0, 1, 2], ) # Should not include treated units in ranking assert 0 not in result["unit"].values @@ -536,7 +516,7 @@ def test_exclude_units(self): time_column="period", outcome_column="outcome", treatment_column="treated", - exclude_units=[15, 16, 17] + exclude_units=[15, 16, 17], ) assert 15 not in result["unit"].values assert 16 not in result["unit"].values @@ -559,7 +539,7 @@ def test_require_units(self): outcome_column="outcome", treatment_column="treated", require_units=require, - n_top=5 + n_top=5, ) # Required units should be present for u in require: @@ -579,7 +559,7 @@ def test_n_top_limit(self): time_column="period", outcome_column="outcome", treatment_column="treated", - n_top=10 + n_top=10, ) assert len(result) == 10 @@ -597,7 +577,7 @@ def test_suggest_treatment_candidates(self): time_column="period", outcome_column="outcome", suggest_treatment_candidates=True, - n_treatment_candidates=5 + n_treatment_candidates=5, ) assert "treatment_candidate_score" in result.columns assert len(result) == 5 @@ -614,7 +594,7 @@ def test_original_unchanged(self): unit_column="unit", time_column="period", outcome_column="outcome", - treatment_column="treated" + treatment_column="treated", ) assert data.columns.tolist() == original_cols @@ -626,10 +606,7 @@ def test_error_missing_column(self): with pytest.raises(ValueError, match="not found"): rank_control_units( - data, - 
unit_column="missing_col", - time_column="period", - outcome_column="outcome" + data, unit_column="missing_col", time_column="period", outcome_column="outcome" ) def test_error_both_treatment_specs(self): @@ -645,7 +622,7 @@ def test_error_both_treatment_specs(self): time_column="period", outcome_column="outcome", treatment_column="treated", - treated_units=[0, 1] + treated_units=[0, 1], ) def test_error_require_and_exclude_same_unit(self): @@ -662,7 +639,7 @@ def test_error_require_and_exclude_same_unit(self): outcome_column="outcome", treatment_column="treated", require_units=[5], - exclude_units=[5] + exclude_units=[5], ) def test_synthetic_weight_sum(self): @@ -676,7 +653,7 @@ def test_synthetic_weight_sum(self): unit_column="unit", time_column="period", outcome_column="outcome", - treatment_column="treated" + treatment_column="treated", ) # Synthetic weights should sum to approximately 1 @@ -694,7 +671,7 @@ def test_pre_periods_explicit(self): time_column="period", outcome_column="outcome", treatment_column="treated", - pre_periods=[0, 1] # Only use first two periods + pre_periods=[0, 1], # Only use first two periods ) assert len(result) > 0 @@ -715,7 +692,7 @@ def test_weight_parameters(self): treatment_column="treated", covariates=["x1"], outcome_weight=1.0, - covariate_weight=0.0 + covariate_weight=0.0, ) # All weight on covariates @@ -727,7 +704,7 @@ def test_weight_parameters(self): treatment_column="treated", covariates=["x1"], outcome_weight=0.0, - covariate_weight=1.0 + covariate_weight=1.0, ) # Rankings should differ @@ -745,10 +722,7 @@ def test_unbalanced_panel(self): # Remove all pre-period data for one control unit control_units = data[data["treated"] == 0]["unit"].unique() unit_to_partially_remove = control_units[0] - mask = ~( - (data["unit"] == unit_to_partially_remove) & - (data["period"] < 3) - ) + mask = ~((data["unit"] == unit_to_partially_remove) & (data["period"] < 3)) unbalanced_data = data[mask].copy() result = rank_control_units( @@ 
-756,7 +730,7 @@ def test_unbalanced_panel(self): unit_column="unit", time_column="period", outcome_column="outcome", - treatment_column="treated" + treatment_column="treated", ) # Should still work and exclude the unit with no pre-treatment data @@ -776,8 +750,7 @@ def test_single_control_unit(self): single_control = control_units[0] filtered_data = data[ - (data["unit"].isin(treated_units)) | - (data["unit"] == single_control) + (data["unit"].isin(treated_units)) | (data["unit"] == single_control) ].copy() result = rank_control_units( @@ -785,7 +758,7 @@ def test_single_control_unit(self): unit_column="unit", time_column="period", outcome_column="outcome", - treatment_column="treated" + treatment_column="treated", ) assert len(result) == 1 @@ -804,7 +777,13 @@ def test_basic_generation(self): data = generate_staggered_data(n_units=50, n_periods=8, seed=42) assert len(data) == 400 # 50 units x 8 periods assert set(data.columns) == { - "unit", "period", "outcome", "first_treat", "treated", "treat", "true_effect" + "unit", + "period", + "outcome", + "first_treat", + "treated", + "treat", + "true_effect", } def test_never_treated_fraction(self): @@ -819,9 +798,7 @@ def test_cohort_periods(self): """Test custom cohort periods.""" from diff_diff.prep import generate_staggered_data - data = generate_staggered_data( - n_units=100, n_periods=10, cohort_periods=[4, 6], seed=42 - ) + data = generate_staggered_data(n_units=100, n_periods=10, cohort_periods=[4, 6], seed=42) cohorts = data.groupby("unit")["first_treat"].first().unique() assert set(cohorts) == {0, 4, 6} @@ -829,9 +806,7 @@ def test_treatment_effect_direction(self): """Test that treatment effect is positive.""" from diff_diff.prep import generate_staggered_data - data = generate_staggered_data( - n_units=100, treatment_effect=3.0, noise_sd=0.1, seed=42 - ) + data = generate_staggered_data(n_units=100, treatment_effect=3.0, noise_sd=0.1, seed=42) # Treated observations should have positive true_effect 
treated_effects = data[data["treated"] == 1]["true_effect"] assert (treated_effects > 0).all() @@ -841,8 +816,12 @@ def test_dynamic_effects(self): from diff_diff.prep import generate_staggered_data data = generate_staggered_data( - n_units=50, n_periods=10, treatment_effect=2.0, - dynamic_effects=True, effect_growth=0.1, seed=42 + n_units=50, + n_periods=10, + treatment_effect=2.0, + dynamic_effects=True, + effect_growth=0.1, + seed=42, ) # Effects should grow over time since treatment # Check a treated unit @@ -881,9 +860,7 @@ def test_basic_generation(self): data = generate_factor_data(n_units=30, n_pre=8, n_post=4, n_treated=5, seed=42) assert len(data) == 360 # 30 units x 12 periods - assert set(data.columns) == { - "unit", "period", "outcome", "treated", "treat", "true_effect" - } + assert set(data.columns) == {"unit", "period", "outcome", "treated", "treat", "true_effect"} def test_treated_units_count(self): """Test that n_treated is respected.""" @@ -908,9 +885,14 @@ def test_treatment_effect_recovery(self): true_effect = 3.0 data = generate_factor_data( - n_units=100, n_pre=10, n_post=5, n_treated=30, - treatment_effect=true_effect, noise_sd=0.1, factor_strength=0.1, - seed=42 + n_units=100, + n_pre=10, + n_post=5, + n_treated=30, + treatment_effect=true_effect, + noise_sd=0.1, + factor_strength=0.1, + seed=42, ) # Simple DiD on treated vs control, post vs pre treated_post = data[(data["treat"] == 1) & (data["period"] >= 10)]["outcome"].mean() @@ -994,17 +976,35 @@ def test_treatment_effect_recovery(self): from diff_diff.prep import generate_ddd_data true_effect = 3.0 - data = generate_ddd_data(n_per_cell=200, treatment_effect=true_effect, noise_sd=0.5, seed=42) + data = generate_ddd_data( + n_per_cell=200, treatment_effect=true_effect, noise_sd=0.5, seed=42 + ) # Manual DDD calculation - y_111 = data[(data["group"] == 1) & (data["partition"] == 1) & (data["time"] == 1)]["outcome"].mean() - y_110 = data[(data["group"] == 1) & (data["partition"] == 1) & 
(data["time"] == 0)]["outcome"].mean() - y_101 = data[(data["group"] == 1) & (data["partition"] == 0) & (data["time"] == 1)]["outcome"].mean() - y_100 = data[(data["group"] == 1) & (data["partition"] == 0) & (data["time"] == 0)]["outcome"].mean() - y_011 = data[(data["group"] == 0) & (data["partition"] == 1) & (data["time"] == 1)]["outcome"].mean() - y_010 = data[(data["group"] == 0) & (data["partition"] == 1) & (data["time"] == 0)]["outcome"].mean() - y_001 = data[(data["group"] == 0) & (data["partition"] == 0) & (data["time"] == 1)]["outcome"].mean() - y_000 = data[(data["group"] == 0) & (data["partition"] == 0) & (data["time"] == 0)]["outcome"].mean() + y_111 = data[(data["group"] == 1) & (data["partition"] == 1) & (data["time"] == 1)][ + "outcome" + ].mean() + y_110 = data[(data["group"] == 1) & (data["partition"] == 1) & (data["time"] == 0)][ + "outcome" + ].mean() + y_101 = data[(data["group"] == 1) & (data["partition"] == 0) & (data["time"] == 1)][ + "outcome" + ].mean() + y_100 = data[(data["group"] == 1) & (data["partition"] == 0) & (data["time"] == 0)][ + "outcome" + ].mean() + y_011 = data[(data["group"] == 0) & (data["partition"] == 1) & (data["time"] == 1)][ + "outcome" + ].mean() + y_010 = data[(data["group"] == 0) & (data["partition"] == 1) & (data["time"] == 0)][ + "outcome" + ].mean() + y_001 = data[(data["group"] == 0) & (data["partition"] == 0) & (data["time"] == 1)][ + "outcome" + ].mean() + y_000 = data[(data["group"] == 0) & (data["partition"] == 0) & (data["time"] == 0)][ + "outcome" + ].mean() manual_ddd = (y_111 - y_110) - (y_101 - y_100) - (y_011 - y_010) + (y_001 - y_000) assert abs(manual_ddd - true_effect) < 0.5 @@ -1027,9 +1027,7 @@ def test_basic_generation(self): data = generate_panel_data(n_units=50, n_periods=6, seed=42) assert len(data) == 300 # 50 units x 6 periods - assert set(data.columns) == { - "unit", "period", "treated", "post", "outcome", "true_effect" - } + assert set(data.columns) == {"unit", "period", "treated", "post", 
"outcome", "true_effect"} def test_treatment_fraction(self): """Test that treatment_fraction is respected.""" @@ -1072,8 +1070,12 @@ def test_non_parallel_trends(self): from diff_diff.prep import generate_panel_data data = generate_panel_data( - n_units=200, n_periods=8, parallel_trends=False, - trend_violation=1.0, noise_sd=0.1, seed=42 + n_units=200, + n_periods=8, + parallel_trends=False, + trend_violation=1.0, + noise_sd=0.1, + seed=42, ) # Calculate pre-treatment trends pre_data = data[data["post"] == 0] @@ -1116,7 +1118,13 @@ def test_basic_generation(self): data = generate_event_study_data(n_units=100, n_pre=5, n_post=5, seed=42) assert len(data) == 1000 # 100 units x 10 periods assert set(data.columns) == { - "unit", "period", "treated", "post", "outcome", "event_time", "true_effect" + "unit", + "period", + "treated", + "post", + "outcome", + "event_time", + "true_effect", } def test_event_time(self): @@ -1143,8 +1151,7 @@ def test_treatment_effect_recovery(self): true_effect = 4.0 data = generate_event_study_data( - n_units=500, n_pre=5, n_post=5, treatment_effect=true_effect, - noise_sd=0.5, seed=42 + n_units=500, n_pre=5, n_post=5, treatment_effect=true_effect, noise_sd=0.5, seed=42 ) # Simple DiD @@ -1182,8 +1189,18 @@ def test_basic_shape_and_columns(self): data = generate_survey_did_data(n_units=100, n_periods=4, cohort_periods=[2, 3], seed=42) assert len(data) == 400 # 100 units x 4 periods - expected = {"unit", "period", "outcome", "first_treat", "treated", - "true_effect", "stratum", "psu", "fpc", "weight"} + expected = { + "unit", + "period", + "outcome", + "first_treat", + "treated", + "true_effect", + "stratum", + "psu", + "fpc", + "weight", + } assert set(data.columns) == expected def test_survey_columns_valid(self): @@ -1237,8 +1254,10 @@ def test_replicate_weights(self): n_strata, psu_per = 3, 4 data = generate_survey_did_data( - n_strata=n_strata, psu_per_stratum=psu_per, - include_replicate_weights=True, seed=42, + n_strata=n_strata, + 
psu_per_stratum=psu_per, + include_replicate_weights=True, + seed=42, ) n_psu = n_strata * psu_per rep_cols = [c for c in data.columns if c.startswith("rep_")] @@ -1246,7 +1265,7 @@ def test_replicate_weights(self): # Each replicate should zero out one PSU for r in range(n_psu): - assert (data.loc[data[f"rep_{r}"] == 0, "psu"].nunique() == 1) + assert data.loc[data[f"rep_{r}"] == 0, "psu"].nunique() == 1 def test_covariates(self): """Test covariate columns are added when requested.""" @@ -1277,7 +1296,9 @@ def test_treatment_structure(self): from diff_diff.prep import generate_survey_did_data data = generate_survey_did_data( - cohort_periods=[3, 5], never_treated_frac=0.3, seed=42, + cohort_periods=[3, 5], + never_treated_frac=0.3, + seed=42, ) cohorts = set(data.groupby("unit")["first_treat"].first().unique()) assert 0 in cohorts # never-treated @@ -1308,8 +1329,10 @@ def test_jk1_minimum_psu_guard(self): # Configured count: 1 PSU total with pytest.raises(ValueError, match="at least 2 PSUs"): generate_survey_did_data( - n_strata=1, psu_per_stratum=1, - include_replicate_weights=True, seed=42, + n_strata=1, + psu_per_stratum=1, + include_replicate_weights=True, + seed=42, ) def test_jk1_one_populated_psu_guard(self): @@ -1320,9 +1343,13 @@ def test_jk1_one_populated_psu_guard(self): # 2 configured PSUs but only 1 unit -> only 1 populated PSU with pytest.raises(ValueError, match="at least 2 populated PSUs"): generate_survey_did_data( - n_units=1, n_strata=1, psu_per_stratum=2, - cohort_periods=[2], n_periods=4, - include_replicate_weights=True, seed=42, + n_units=1, + n_strata=1, + psu_per_stratum=2, + cohort_periods=[2], + n_periods=4, + include_replicate_weights=True, + seed=42, ) def test_repeated_cross_section(self): @@ -1330,7 +1357,11 @@ def test_repeated_cross_section(self): from diff_diff.prep import generate_survey_did_data data = generate_survey_did_data( - n_units=20, n_periods=4, cohort_periods=[2], panel=False, seed=42, + n_units=20, + n_periods=4, + 
cohort_periods=[2], + panel=False, + seed=42, ) assert len(data) == 80 assert data["unit"].nunique() == 80 # unique across all periods @@ -1457,11 +1488,18 @@ def test_psu_period_factor_deff_regression(self): warnings.filterwarnings("ignore") df = generate_survey_did_data( - n_units=200, n_periods=8, cohort_periods=[3, 5], - never_treated_frac=0.3, treatment_effect=2.0, - n_strata=5, psu_per_stratum=8, fpc_per_stratum=200.0, - weight_variation="moderate", psu_re_sd=2.0, - psu_period_factor=1.0, seed=42, + n_units=200, + n_periods=8, + cohort_periods=[3, 5], + never_treated_frac=0.3, + treatment_effect=2.0, + n_strata=5, + psu_per_stratum=8, + fpc_per_stratum=200.0, + weight_variation="moderate", + psu_re_sd=2.0, + psu_period_factor=1.0, + seed=42, ) sd = SurveyDesign(weights="weight", strata="stratum", psu="psu", fpc="fpc") @@ -1472,21 +1510,22 @@ def test_psu_period_factor_deff_regression(self): did = DifferenceInDifferences() r_naive = did.fit(c3, outcome="outcome", treatment="treat", time="post") r_survey = did.fit( - c3, outcome="outcome", treatment="treat", time="post", + c3, + outcome="outcome", + treatment="treat", + time="post", survey_design=sd, ) - assert r_survey.se > r_naive.se, ( - f"Survey SE ({r_survey.se:.4f}) should exceed naive SE ({r_naive.se:.4f})" - ) + assert ( + r_survey.se > r_naive.se + ), f"Survey SE ({r_survey.se:.4f}) should exceed naive SE ({r_naive.se:.4f})" # DEFF for treat_x_post must be > 1 c3["treat_x_post"] = c3["treat"] * c3["post"] resolved = sd.resolve(c3) reg = LinearRegression(include_intercept=True, survey_design=resolved) reg.fit(X=c3[["treat", "post", "treat_x_post"]].values, y=c3["outcome"].values) - deff = reg.compute_deff( - coefficient_names=["intercept", "treat", "post", "treat_x_post"] - ) + deff = reg.compute_deff(coefficient_names=["intercept", "treat", "post", "treat_x_post"]) txp_deff = deff.deff[3] # treat_x_post assert txp_deff > 1.0, f"DEFF for treat_x_post ({txp_deff:.2f}) should be > 1" @@ -1514,9 +1553,7 @@ 
def test_icc_parameter(self): from diff_diff.prep_dgp import generate_survey_did_data target_icc = 0.3 - df = generate_survey_did_data( - n_units=1000, icc=target_icc, seed=42 - ) + df = generate_survey_did_data(n_units=1000, icc=target_icc, seed=42) # ANOVA-based ICC on period 1 (pre-treatment, no TE contamination) p1 = df[df["period"] == 1] groups = p1.groupby("psu")["outcome"] @@ -1561,9 +1598,7 @@ def test_weight_cv_parameter(self): from diff_diff.prep_dgp import generate_survey_did_data target_cv = 0.5 - df = generate_survey_did_data( - n_units=1000, weight_cv=target_cv, seed=42 - ) + df = generate_survey_did_data(n_units=1000, weight_cv=target_cv, seed=42) weights = df.groupby("unit")["weight"].first().values realized_cv = weights.std() / weights.mean() assert abs(realized_cv - target_cv) < 0.15 @@ -1573,9 +1608,7 @@ def test_weight_cv_and_weight_variation_conflict(self): from diff_diff.prep_dgp import generate_survey_did_data with pytest.raises(ValueError, match="Cannot specify both weight_cv"): - generate_survey_did_data( - weight_cv=0.5, weight_variation="high", seed=42 - ) + generate_survey_did_data(weight_cv=0.5, weight_variation="high", seed=42) def test_weight_cv_nan_inf(self): """weight_cv must reject NaN and Inf.""" @@ -1622,8 +1655,7 @@ def test_informative_sampling_default_weights(self): expected_mean = 1.0 + 1.0 * (s / 4) stratum_weights = p1.loc[p1["stratum"] == s, "weight"] assert abs(stratum_weights.mean() - expected_mean) < 0.15, ( - f"Stratum {s}: expected mean ~{expected_mean}, " - f"got {stratum_weights.mean():.3f}" + f"Stratum {s}: expected mean ~{expected_mean}, " f"got {stratum_weights.mean():.3f}" ) # Within-stratum variation should exist (informative sampling) assert stratum_weights.std() > 0.01 @@ -1670,9 +1702,7 @@ def test_icc_with_covariates(self): from diff_diff.prep_dgp import generate_survey_did_data target_icc = 0.3 - df = generate_survey_did_data( - n_units=1000, icc=target_icc, add_covariates=True, seed=42 - ) + df = 
generate_survey_did_data(n_units=1000, icc=target_icc, add_covariates=True, seed=42) # ANOVA-based ICC on period 1 p1 = df[df["period"] == 1] groups = p1.groupby("psu")["outcome"] @@ -1840,9 +1870,7 @@ def test_strata_sizes(self): from diff_diff.prep_dgp import generate_survey_did_data sizes = [60, 50, 40, 30, 20] - df = generate_survey_did_data( - n_units=200, strata_sizes=sizes, seed=42 - ) + df = generate_survey_did_data(n_units=200, strata_sizes=sizes, seed=42) for s, expected in enumerate(sizes): actual = df[df["period"] == 1]["stratum"].value_counts().get(s, 0) assert actual == expected @@ -1852,9 +1880,7 @@ def test_strata_sizes_sum_mismatch(self): from diff_diff.prep_dgp import generate_survey_did_data with pytest.raises(ValueError, match="strata_sizes must sum"): - generate_survey_did_data( - n_units=200, strata_sizes=[50, 50, 50, 50, 49], seed=42 - ) + generate_survey_did_data(n_units=200, strata_sizes=[50, 50, 50, 50, 49], seed=42) def test_strata_sizes_float_rejected(self): """strata_sizes must contain integers, not floats.""" @@ -1877,12 +1903,12 @@ def test_covariate_effects_custom(self): """Custom covariate coefficients should change outcome variance.""" from diff_diff.prep_dgp import generate_survey_did_data - df_default = generate_survey_did_data( - n_units=500, add_covariates=True, seed=42 - ) + df_default = generate_survey_did_data(n_units=500, add_covariates=True, seed=42) df_large = generate_survey_did_data( - n_units=500, add_covariates=True, - covariate_effects=(2.0, 1.0), seed=42, + n_units=500, + add_covariates=True, + covariate_effects=(2.0, 1.0), + seed=42, ) # Larger coefficients → larger outcome variance assert df_large["outcome"].var() > df_default["outcome"].var() @@ -1891,12 +1917,12 @@ def test_covariate_effects_zero(self): """Zero covariate effects should produce same variance as no covariates.""" from diff_diff.prep_dgp import generate_survey_did_data - df_no_cov = generate_survey_did_data( - n_units=500, add_covariates=False, 
seed=42 - ) + df_no_cov = generate_survey_did_data(n_units=500, add_covariates=False, seed=42) df_zero = generate_survey_did_data( - n_units=500, add_covariates=True, - covariate_effects=(0.0, 0.0), seed=42, + n_units=500, + add_covariates=True, + covariate_effects=(0.0, 0.0), + seed=42, ) # Outcome variance should be similar (covariates contribute nothing) assert abs(df_zero["outcome"].var() - df_no_cov["outcome"].var()) < 0.5 @@ -1920,32 +1946,460 @@ def test_te_covariate_interaction_requires_covariates(self): from diff_diff.prep_dgp import generate_survey_did_data with pytest.raises(ValueError, match="te_covariate_interaction requires"): - generate_survey_did_data( - te_covariate_interaction=0.5, add_covariates=False, seed=42 - ) + generate_survey_did_data(te_covariate_interaction=0.5, add_covariates=False, seed=42) def test_covariate_effects_validation(self): """covariate_effects must be length 2 and finite.""" from diff_diff.prep_dgp import generate_survey_did_data with pytest.raises(ValueError, match="covariate_effects must have length 2"): - generate_survey_did_data( - add_covariates=True, covariate_effects=(1.0,), seed=42 - ) + generate_survey_did_data(add_covariates=True, covariate_effects=(1.0,), seed=42) with pytest.raises(ValueError, match="covariate_effects must be finite"): - generate_survey_did_data( - add_covariates=True, covariate_effects=(np.nan, 0.3), seed=42 - ) + generate_survey_did_data(add_covariates=True, covariate_effects=(np.nan, 0.3), seed=42) with pytest.raises(ValueError, match="covariate_effects must be finite"): - generate_survey_did_data( - add_covariates=True, covariate_effects=(0.5, np.inf), seed=42 - ) + generate_survey_did_data(add_covariates=True, covariate_effects=(0.5, np.inf), seed=42) def test_te_covariate_interaction_validation(self): """te_covariate_interaction must be finite.""" from diff_diff.prep_dgp import generate_survey_did_data with pytest.raises(ValueError, match="te_covariate_interaction must be finite"): - 
generate_survey_did_data( - add_covariates=True, te_covariate_interaction=np.nan, seed=42 + generate_survey_did_data(add_covariates=True, te_covariate_interaction=np.nan, seed=42) + + +class TestAggregateSurvey: + """Tests for aggregate_survey function.""" + + @pytest.fixture + def micro_data(self): + """Create simple microdata: 2 states, 2 years, with survey design.""" + rng = np.random.RandomState(42) + n = 400 + states = np.repeat(["CA", "TX"], n // 2) + years = np.tile(np.repeat([2019, 2020], n // 4), 2) + strata = np.repeat(np.arange(4), n // 4) + psu = np.arange(n) // 10 # 10 obs per PSU, 40 PSUs total + weights = rng.uniform(0.5, 3.0, n) + outcome = rng.normal(10, 2, n) + # Make CA slightly higher than TX to get different means + outcome[: n // 2] += 2.0 + covariate = rng.normal(50, 10, n) + + return pd.DataFrame( + { + "state": states, + "year": years, + "stratum": strata, + "cluster": psu, + "wt": weights, + "y": outcome, + "x": covariate, + } + ) + + @pytest.fixture + def design(self): + return SurveyDesign(weights="wt", strata="stratum", psu="cluster") + + def test_basic_aggregation(self, micro_data, design): + """Design-weighted means should differ from simple means.""" + panel, _ = aggregate_survey( + micro_data, + by=["state", "year"], + outcomes="y", + survey_design=design, + ) + assert len(panel) == 4 # 2 states x 2 years + + # Check design-weighted mean differs from simple mean + simple_mean = micro_data[micro_data["state"] == "CA"][micro_data["year"] == 2019][ + "y" + ].mean() + ca_2019 = panel[(panel["state"] == "CA") & (panel["year"] == 2019)] + weighted_mean = ca_2019["y_mean"].iloc[0] + # With non-uniform weights, these should differ + assert weighted_mean != pytest.approx(simple_mean, abs=0.01) + + def test_column_naming(self, micro_data, design): + """All expected columns should be present.""" + panel, _ = aggregate_survey( + micro_data, + by=["state", "year"], + outcomes="y", + covariates="x", + survey_design=design, + ) + expected = { + 
"state", + "year", + "y_mean", + "y_se", + "y_n", + "y_precision", + "x_mean", + "cell_n", + "cell_n_eff", + "srs_fallback", + } + assert expected.issubset(set(panel.columns)) + + def test_multiple_outcomes(self, micro_data, design): + """Each outcome gets own columns; SurveyDesign uses first.""" + micro_data = micro_data.copy() + micro_data["y2"] = micro_data["y"] * 2 + panel, stage2 = aggregate_survey( + micro_data, + by=["state", "year"], + outcomes=["y", "y2"], + survey_design=design, + ) + assert "y_mean" in panel.columns + assert "y2_mean" in panel.columns + assert "y_precision" in panel.columns + assert "y2_precision" in panel.columns + assert stage2.weights == "y_precision" + + def test_covariates_mean_only(self, micro_data, design): + """Covariates get mean column only, no SE/precision.""" + panel, _ = aggregate_survey( + micro_data, + by=["state", "year"], + outcomes="y", + covariates="x", + survey_design=design, + ) + assert "x_mean" in panel.columns + assert "x_se" not in panel.columns + assert "x_precision" not in panel.columns + + def test_returned_survey_design(self, micro_data, design): + """Returned SurveyDesign has correct aweight config and clustering.""" + _, stage2 = aggregate_survey( + micro_data, + by=["state", "year"], + outcomes="y", + survey_design=design, + ) + assert stage2.weight_type == "aweight" + assert stage2.weights == "y_precision" + assert stage2.psu == "state" + + def test_srs_fallback(self): + """Cells where design-based variance fails get SRS fallback.""" + # Create data where each cell has only 1 PSU per stratum + rng = np.random.RandomState(99) + n = 40 + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B"], n // 2), + "time": np.tile(np.repeat([0, 1], n // 4), 2), + "stratum": np.arange(n), # every obs is its own stratum + "psu": np.arange(n), # every obs is its own PSU + "wt": np.ones(n), + "y": rng.normal(0, 1, n), + } + ) + design = SurveyDesign(weights="wt", strata="stratum", psu="psu", lonely_psu="remove") + with 
pytest.warns(UserWarning, match="SRS fallback"): + panel, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design, + ) + assert panel["srs_fallback"].all() + # SRS SE should be finite and positive + assert (panel["y_se"] > 0).all() + assert panel["y_se"].notna().all() + + # Verify SRS SE matches manual computation for one cell + cell = data[(data["geo"] == "A") & (data["time"] == 0)] + y_vals = cell["y"].values + n_cell = len(y_vals) + expected_var = np.var(y_vals, ddof=0) / n_cell * n_cell / (n_cell - 1) + expected_se = np.sqrt(expected_var) + actual_se = panel[(panel["geo"] == "A") & (panel["time"] == 0)]["y_se"].iloc[0] + assert actual_se == pytest.approx(expected_se, rel=1e-10) + + def test_missing_values(self, micro_data, design): + """Missing values reduce var_n but cell_n stays the same.""" + micro_data = micro_data.copy() + # Set some values to NaN in CA-2019 + mask = (micro_data["state"] == "CA") & (micro_data["year"] == 2019) + idx = micro_data[mask].index[:5] + micro_data.loc[idx, "y"] = np.nan + + panel, _ = aggregate_survey( + micro_data, + by=["state", "year"], + outcomes="y", + survey_design=design, + ) + ca_2019 = panel[(panel["state"] == "CA") & (panel["year"] == 2019)] + assert ca_2019["y_n"].iloc[0] == 100 - 5 # 5 NaN + assert ca_2019["cell_n"].iloc[0] == 100 # all respondents + + def test_zero_variance_cell(self): + """When all values are identical, precision is NaN.""" + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B"], 20), + "time": np.tile(np.repeat([0, 1], 10), 2), + "wt": np.ones(40), + "y": np.concatenate( + [ + np.full(10, 5.0), # A-0: constant + np.random.RandomState(1).normal(5, 1, 10), # A-1 + np.random.RandomState(2).normal(5, 1, 10), # B-0 + np.random.RandomState(3).normal(5, 1, 10), # B-1 + ] + ), + } + ) + design = SurveyDesign(weights="wt") + with pytest.warns(UserWarning, match="Zero variance"): + panel, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + 
survey_design=design, + ) + a0 = panel[(panel["geo"] == "A") & (panel["time"] == 0)] + assert a0["y_mean"].iloc[0] == pytest.approx(5.0) + assert a0["y_se"].iloc[0] == pytest.approx(0.0) + assert np.isnan(a0["y_precision"].iloc[0]) + + def test_lonely_psu_override(self): + """lonely_psu parameter overrides survey_design setting.""" + rng = np.random.RandomState(77) + n = 40 + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B"], n // 2), + "time": np.tile(np.repeat([0, 1], n // 4), 2), + "stratum": np.repeat(np.arange(4), n // 4), + "psu": np.arange(n) // 5, + "wt": np.ones(n), + "y": rng.normal(0, 1, n), + } + ) + design = SurveyDesign(weights="wt", strata="stratum", psu="psu", lonely_psu="remove") + # Override to "certainty" — different behavior for singletons + panel_cert, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design, + lonely_psu="certainty", + ) + panel_remove, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design, + ) + # Results should exist for both + assert len(panel_cert) == 4 + assert len(panel_remove) == 4 + + def test_single_by_column(self, micro_data, design): + """Single string for by works correctly.""" + panel, stage2 = aggregate_survey( + micro_data, + by="state", + outcomes="y", + survey_design=design, + ) + assert len(panel) == 2 # CA, TX + assert "state" in panel.columns + assert stage2.psu == "state" + + def test_srs_equivalence_weights_only(self): + """With no strata/PSU, SE matches weighted SRS formula.""" + rng = np.random.RandomState(123) + n = 100 + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B"], n // 2), + "time": np.ones(n, dtype=int), + "wt": rng.uniform(0.5, 2.0, n), + "y": rng.normal(10, 2, n), + } + ) + design = SurveyDesign(weights="wt") + panel, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design, + ) + + # Manual SRS computation for cell A + cell = data[data["geo"] == "A"] + w = cell["wt"].values + y 
= cell["y"].values + # resolve() normalizes pweights to mean=1 + w_norm = w / w.mean() + sum_w = np.sum(w_norm) + y_bar = np.sum(w_norm * y) / sum_w + n_cell = len(y) + # SRS variance with weights: implicit per-obs PSU + # meat = (n/(n-1)) * sum((w*(y-ybar)/sum_w)^2) + psi = w_norm * (y - y_bar) / sum_w + variance = (n_cell / (n_cell - 1)) * np.sum(psi**2) + expected_se = np.sqrt(variance) + + actual_se = panel[panel["geo"] == "A"]["y_se"].iloc[0] + assert actual_se == pytest.approx(expected_se, rel=1e-10) + + def test_design_effect_increases_se(self): + """With PSU clustering, SE should be larger than without.""" + rng = np.random.RandomState(55) + n = 200 + psu_ids = np.arange(n) // 10 # 10 obs per PSU + # Add PSU-level random effects to create ICC + psu_effects = rng.normal(0, 3, 20) + y = rng.normal(0, 1, n) + psu_effects[psu_ids] + + data = pd.DataFrame( + { + "geo": ["A"] * n, + "time": np.ones(n, dtype=int), + "cluster": psu_ids, + "wt": np.ones(n), + "y": y, + } + ) + + design_no_psu = SurveyDesign(weights="wt") + design_psu = SurveyDesign(weights="wt", psu="cluster") + + panel_no_psu, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design_no_psu, + ) + panel_psu, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design_psu, + ) + + se_no_psu = panel_no_psu["y_se"].iloc[0] + se_psu = panel_psu["y_se"].iloc[0] + assert se_psu > se_no_psu # clustering increases SE + + def test_equal_weights_simple_mean(self): + """With equal weights, design-weighted mean equals arithmetic mean.""" + rng = np.random.RandomState(88) + n = 60 + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B"], n // 2), + "time": np.tile(np.repeat([0, 1], n // 4), 2), + "wt": np.ones(n), + "y": rng.normal(10, 2, n), + } + ) + design = SurveyDesign(weights="wt") + panel, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design, + ) + # Check A-0 cell + cell = data[(data["geo"] == "A") & 
(data["time"] == 0)] + assert panel[(panel["geo"] == "A") & (panel["time"] == 0)]["y_mean"].iloc[0] == ( + pytest.approx(cell["y"].mean(), rel=1e-12) + ) + + def test_pipeline_with_did(self): + """Full pipeline: microdata → aggregate → DiD estimation.""" + from diff_diff import DifferenceInDifferences + + # Construct microdata simulating 4 states, 2 periods, ~50 obs/cell + rng = np.random.RandomState(42) + rows = [] + for state in range(4): + treated = 1 if state < 2 else 0 + for period in [0, 1]: + n_cell = rng.randint(40, 60) + # Treatment effect in post period for treated states + te = 3.0 if (treated and period == 1) else 0.0 + for _ in range(n_cell): + strat = rng.randint(0, 3) + psu_id = state * 100 + strat * 10 + rng.randint(0, 3) + rows.append( + { + "state": state, + "period": period, + "stratum": strat, + "psu": psu_id, + "wt": rng.uniform(0.5, 2.0), + "outcome": rng.normal(10 + te, 2), + "treated": treated, + } + ) + micro = pd.DataFrame(rows) + + design = SurveyDesign(weights="wt", strata="stratum", psu="psu") + panel, stage2 = aggregate_survey( + micro, + by=["state", "period"], + outcomes="outcome", + covariates="treated", + survey_design=design, + ) + + panel["treated_bin"] = (panel["treated_mean"] > 0.5).astype(int) + + did = DifferenceInDifferences() + result = did.fit( + panel, + outcome="outcome_mean", + treatment="treated_bin", + time="period", + survey_design=stage2, + ) + assert result.att is not None + assert np.isfinite(result.att) + assert np.isfinite(result.se) + assert result.se > 0 + # ATT should be near the true effect of 3.0 + assert abs(result.att - 3.0) < 2.0 + + # --- Error tests --- + + def test_error_missing_column(self, micro_data, design): + """Missing column raises ValueError.""" + with pytest.raises(ValueError, match="Columns not found"): + aggregate_survey( + micro_data, + by=["state", "year"], + outcomes="nonexistent", + survey_design=design, + ) + + def test_error_invalid_survey_design(self, micro_data): + 
"""Non-SurveyDesign object raises TypeError.""" + with pytest.raises(TypeError, match="SurveyDesign instance"): + aggregate_survey( + micro_data, + by=["state", "year"], + outcomes="y", + survey_design="not_a_design", + ) + + def test_error_min_n_too_small(self, micro_data, design): + """min_n < 1 raises ValueError.""" + with pytest.raises(ValueError, match="min_n must be >= 1"): + aggregate_survey( + micro_data, + by=["state", "year"], + outcomes="y", + survey_design=design, + min_n=0, ) From d330d36187bbbca1dbc13834d4df3f6389f5b979 Mon Sep 17 00:00:00 2001 From: igerber Date: Tue, 7 Apr 2026 10:25:40 -0400 Subject: [PATCH 02/16] Fix P0 domain estimation, P1 min_n/empty guard, P3 docs/tests Rework aggregate_survey() to use full-design domain estimation: zero-pad influence functions outside each cell, preserving full strata/PSU structure for variance computation per Lumley (2004) Section 3.4 and the library's subpopulation() methodology. Also fix: min_n parameter now operative (forces SRS fallback), empty-input guard added, docstring examples corrected, REGISTRY.md rewritten with domain estimation language, 4 new tests added (domain regression, min_n behavior, empty input, replicate weights). Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/prep.py | 140 +++++++++++++++++------------- docs/api/prep.rst | 2 +- docs/methodology/REGISTRY.md | 23 +++-- tests/test_prep.py | 159 ++++++++++++++++++++++++++++++++--- 4 files changed, 238 insertions(+), 86 deletions(-) diff --git a/diff_diff/prep.py b/diff_diff/prep.py index a66fbd59..f2a1f95a 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -1313,20 +1313,29 @@ def trim_weights( def _cell_mean_variance( - y: np.ndarray, - weights: np.ndarray, - cell_resolved: ResolvedSurveyDesign, + y_full: np.ndarray, + full_resolved: ResolvedSurveyDesign, + cell_mask: np.ndarray, + min_n: int, ) -> Tuple[float, float, int, bool]: """Compute design-based mean and variance of the weighted mean for one cell. 
+ Uses full-design domain estimation: the influence function is zero-padded + outside the cell, preserving the full strata/PSU structure for variance + estimation. This is the methodologically correct approach for domain + estimation under complex survey designs (Lumley 2004, Section 3.4). + Parameters ---------- - y : np.ndarray - Outcome values for the cell (may contain NaN). - weights : np.ndarray - Resolved weights for the cell (already extracted from ResolvedSurveyDesign). - cell_resolved : ResolvedSurveyDesign - Resolved survey design subsetted to this cell. + y_full : np.ndarray + Outcome values for the full dataset (may contain NaN). + full_resolved : ResolvedSurveyDesign + Full-sample resolved survey design. + cell_mask : np.ndarray + Boolean mask identifying cell members in the full dataset. + min_n : int + Minimum valid observations for design-based variance. Below this + threshold, SRS fallback is used. Returns ------- @@ -1334,43 +1343,55 @@ def _cell_mean_variance( Design-weighted cell mean. variance : float Design-based variance of the cell mean (>= 0). Uses SRS fallback - when the design-based estimate is unidentifiable. + when the design-based estimate is unidentifiable or n_valid < min_n. n_valid : int - Number of non-missing observations. + Number of non-missing observations in the cell. used_srs_fallback : bool True if SRS variance was used instead of design-based. 
""" - valid = ~np.isnan(y) + y_cell = y_full[cell_mask] + w_cell = full_resolved.weights[cell_mask] + valid = ~np.isnan(y_cell) n_valid = int(np.sum(valid)) if n_valid == 0: return np.nan, np.nan, 0, False - if n_valid == 1: - y_bar = float(y[valid][0]) + if n_valid < 2: + y_bar = float(y_cell[valid][0]) return y_bar, np.nan, 1, False - # Zero out weights for NaN observations (subpopulation approach) - w = weights.copy() - y_clean = np.where(valid, y, 0.0) - w_valid = w * valid.astype(np.float64) - sum_w = np.sum(w_valid) + # Weighted mean from cell members (NaN-safe) + w_valid = w_cell * valid.astype(np.float64) + y_clean = np.where(valid, y_cell, 0.0) + sum_w = float(np.sum(w_valid)) if sum_w <= 0: return np.nan, np.nan, n_valid, False - # Design-weighted mean y_bar = float(np.sum(w_valid * y_clean) / sum_w) - # Influence function: psi_i = w_i * (y_i - y_bar) / sum(w) - psi = w_valid * (y_clean - y_bar) / sum_w - - # Route to TSL or replicate variance + # SRS fallback if below min_n threshold used_srs = False - if cell_resolved.uses_replicate_variance: - variance, _ = compute_replicate_if_variance(psi, cell_resolved) + if n_valid < min_n: + resid_sq = w_valid * (y_clean - y_bar) ** 2 + variance = float(np.sum(resid_sq) / (sum_w**2) * n_valid / (n_valid - 1)) + return y_bar, max(variance, 0.0), n_valid, True + + # Full-design domain estimation: construct full-length psi with zeros + # outside the cell, preserving full strata/PSU structure for variance + n_total = len(y_full) + psi = np.zeros(n_total) + # Positions in full array where cell member has valid data + cell_indices = np.where(cell_mask)[0] + valid_positions = cell_indices[valid] + psi[valid_positions] = w_valid[valid] * (y_clean[valid] - y_bar) / sum_w + + # Route to TSL or replicate variance using the full design + if full_resolved.uses_replicate_variance: + variance, _ = compute_replicate_if_variance(psi, full_resolved) else: - variance = compute_survey_if_variance(psi, cell_resolved) + variance = 
compute_survey_if_variance(psi, full_resolved) # SRS fallback when design-based variance is unidentifiable if np.isnan(variance): @@ -1397,9 +1418,10 @@ def aggregate_survey( columns. Returns a panel-ready DataFrame with precision weights and a pre-configured :class:`SurveyDesign` for second-stage DiD estimation. - This follows R's ``survey::svyby()`` pattern: the survey design is - subsetted to each cell and domain-level statistics are computed using - the within-cell strata/PSU structure. + Each cell is treated as a subpopulation/domain of the full survey + design: influence function values are zero-padded outside the cell, + preserving full strata/PSU structure for variance estimation per + Lumley (2004) Section 3.4. Parameters ---------- @@ -1446,7 +1468,7 @@ def aggregate_survey( ... ) >>> result = DifferenceInDifferences().fit( ... panel, outcome="smoking_rate_mean", - ... treatment="treated", time="post", survey_design=stage2, + ... treatment="treated", time="year", survey_design=stage2, ... 
) """ import warnings @@ -1482,12 +1504,21 @@ def aggregate_survey( f"lonely_psu must be 'remove', 'certainty', or 'adjust', got '{lonely_psu}'" ) + # --- Empty-input guard --- + if data.empty: + raise ValueError("data must be non-empty") + # --- Resolve design once on full data --- effective_design = ( replace(survey_design, lonely_psu=lonely_psu) if lonely_psu else survey_design ) full_resolved = effective_design.resolve(data) + # --- Precompute full-length outcome/covariate arrays --- + n_total = len(data) + all_vars = outcome_cols + cov_cols + y_arrays: Dict[str, np.ndarray] = {var: data[var].values.astype(np.float64) for var in all_vars} + # --- Per-cell computation --- grouped = data.groupby(by_cols, sort=True) rows: List[Dict[str, Any]] = [] @@ -1496,32 +1527,17 @@ def aggregate_survey( for cell_key, cell_df in grouped: cell_idx = np.array(cell_df.index) - # Convert to positional indices for array subsetting pos_idx = data.index.get_indexer(cell_idx) - cell_n = len(pos_idx) - cell_key_str = str(cell_key) + # Boolean mask for full-design domain estimation + cell_mask = np.zeros(n_total, dtype=bool) + cell_mask[pos_idx] = True - # Subset arrays from full resolved design - cell_w = full_resolved.weights[pos_idx] - cell_strata = full_resolved.strata[pos_idx] if full_resolved.strata is not None else None - cell_psu = full_resolved.psu[pos_idx] if full_resolved.psu is not None else None - cell_fpc = full_resolved.fpc[pos_idx] if full_resolved.fpc is not None else None - - cell_n_strata = int(len(np.unique(cell_strata))) if cell_strata is not None else 0 - cell_n_psu = int(len(np.unique(cell_psu))) if cell_psu is not None else 0 - - cell_resolved = full_resolved.subset_to_units( - row_idx=pos_idx, - weights=cell_w, - strata=cell_strata, - psu=cell_psu, - fpc=cell_fpc, - n_strata=cell_n_strata, - n_psu=cell_n_psu, - ) + cell_n = int(np.sum(cell_mask)) + cell_key_str = str(cell_key) - # Cell-level statistics + # Cell-level statistics (Kish ESS is a property of the 
cell) + cell_w = full_resolved.weights[cell_mask] sum_w = float(np.sum(cell_w)) sum_w2 = float(np.sum(cell_w**2)) cell_n_eff = (sum_w**2 / sum_w2) if sum_w2 > 0 else 0.0 @@ -1539,10 +1555,14 @@ def aggregate_survey( cell_srs_fallback = False - # Outcomes: mean + SE + n + precision + # Outcomes: mean + SE + n + precision (full-design domain estimation) for var in outcome_cols: - y = cell_df[var].values.astype(np.float64) - y_bar, variance, n_valid, used_srs = _cell_mean_variance(y, cell_w, cell_resolved) + y_bar, variance, n_valid, used_srs = _cell_mean_variance( + y_arrays[var], + full_resolved, + cell_mask, + min_n, + ) se = float(np.sqrt(variance)) if not np.isnan(variance) else np.nan if used_srs: @@ -1562,14 +1582,14 @@ def aggregate_survey( row[f"{var}_n"] = n_valid row[f"{var}_precision"] = precision - # Covariates: mean only + # Covariates: design-weighted mean only for var in cov_cols: - y = cell_df[var].values.astype(np.float64) - valid = ~np.isnan(y) + y_cell = y_arrays[var][cell_mask] + valid = ~np.isnan(y_cell) w_valid = cell_w * valid.astype(np.float64) sw = float(np.sum(w_valid)) if sw > 0: - row[f"{var}_mean"] = float(np.sum(w_valid * np.where(valid, y, 0.0)) / sw) + row[f"{var}_mean"] = float(np.sum(w_valid * np.where(valid, y_cell, 0.0)) / sw) else: row[f"{var}_mean"] = np.nan diff --git a/docs/api/prep.rst b/docs/api/prep.rst index 45406f7f..7ecee05b 100644 --- a/docs/api/prep.rst +++ b/docs/api/prep.rst @@ -288,7 +288,7 @@ Example panel, outcome="smoking_rate_mean", treatment="treated", - time="post", + time="year", survey_design=stage2, ) diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index edd49e08..576b5e3d 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -2310,22 +2310,21 @@ design-based precision estimates, for use as a pre-processing step before panel DiD estimation on repeated cross-section survey data. 
- **Reference**: Lumley (2004) "Analysis of Complex Survey Samples", Journal of - Statistical Software 9(8). R `survey::svyby()` implements similar per-group - survey estimation. + Statistical Software 9(8), Section 3.4 (domain estimation). - **Cell mean**: Design-weighted mean `ȳ_g = Σ w_i y_i / Σ w_i` for each cell g defined by grouping columns (e.g., state × year). -- **Cell variance**: Linearized influence function `ψ_i = w_i (y_i - ȳ_g) / Σ w_j`, - then design-based variance via `compute_survey_if_variance()` (TSL) or - `compute_replicate_if_variance()` (replicate designs). This is the standard - Horvitz-Thompson linearization for a ratio estimator. +- **Cell variance**: Each cell is treated as a subpopulation/domain of the full + survey design (consistent with `SurveyDesign.subpopulation()` and the + Subpopulation Analysis section below). The influence function + `ψ_i = w_i (y_i - ȳ_g) / Σ w_j` is zero-padded outside the cell, preserving + full strata/PSU structure for variance estimation via `compute_survey_if_variance()` + (TSL) or `compute_replicate_if_variance()` (replicate designs). - **Precision weight**: `1 / V(ȳ_g)` used as inverse-variance weight (aweight) in second-stage DiD estimation. -- **Note:** SRS fallback when design-based variance is unidentifiable within a cell - (e.g., all strata have singleton PSUs after cell subsetting). Formula: - `V_SRS = Σ w_i(y_i - ȳ)² / (Σ w_j)² × n/(n-1)`. Cells using SRS fallback - are flagged via `srs_fallback` column. -- **Note:** FPC values are passed through unchanged from the full design to cell - subsets — they represent population N_h per stratum, not per cell. +- **Note:** SRS fallback when design-based variance is unidentifiable (e.g., all + strata contribute zero variance) or when the cell has fewer than `min_n` valid + observations. Formula: `V_SRS = Σ w_i(y_i - ȳ)² / (Σ w_j)² × n/(n-1)`. + Cells using SRS fallback are flagged via `srs_fallback` column. 
- **Edge case**: Zero-variance cells (all observations identical) set precision to NaN to avoid infinite weights in second-stage WLS. diff --git a/tests/test_prep.py b/tests/test_prep.py index 1a5d8216..c637e9c9 100644 --- a/tests/test_prep.py +++ b/tests/test_prep.py @@ -2234,19 +2234,25 @@ def test_srs_equivalence_weights_only(self): survey_design=design, ) - # Manual SRS computation for cell A - cell = data[data["geo"] == "A"] - w = cell["wt"].values - y = cell["y"].values - # resolve() normalizes pweights to mean=1 - w_norm = w / w.mean() - sum_w = np.sum(w_norm) - y_bar = np.sum(w_norm * y) / sum_w - n_cell = len(y) - # SRS variance with weights: implicit per-obs PSU - # meat = (n/(n-1)) * sum((w*(y-ybar)/sum_w)^2) - psi = w_norm * (y - y_bar) / sum_w - variance = (n_cell / (n_cell - 1)) * np.sum(psi**2) + # Manual full-design domain estimation for cell A: + # psi is zero-padded to n_total; adjustment uses n_total/(n_total-1) + w_all = data["wt"].values + w_all_norm = w_all / w_all.mean() # resolve() normalizes to mean=1 + cell_mask = (data["geo"] == "A").values + y_all = data["y"].values + w_cell = w_all_norm[cell_mask] + y_cell = y_all[cell_mask] + sum_w = np.sum(w_cell) + y_bar = np.sum(w_cell * y_cell) / sum_w + + # Full-length psi with zeros outside cell + psi_full = np.zeros(n) + psi_full[cell_mask] = w_cell * (y_cell - y_bar) / sum_w + + # Implicit per-obs PSU with full-design adjustment + psi_mean = psi_full.mean() + centered = psi_full - psi_mean + variance = (n / (n - 1)) * np.sum(centered**2) expected_se = np.sqrt(variance) actual_se = panel[panel["geo"] == "A"]["y_se"].iloc[0] @@ -2403,3 +2409,130 @@ def test_error_min_n_too_small(self, micro_data, design): survey_design=design, min_n=0, ) + + def test_error_empty_data(self, design): + """Empty DataFrame raises ValueError.""" + empty = pd.DataFrame(columns=["state", "year", "y", "wt", "stratum", "cluster"]) + with pytest.raises(ValueError, match="data must be non-empty"): + aggregate_survey( + 
empty, + by=["state", "year"], + outcomes="y", + survey_design=design, + ) + + def test_domain_estimation_preserves_full_design(self): + """Full-design domain estimation accounts for PSUs outside the cell. + + Stratum 0 has PSUs {0, 1}. Cell A contains only PSU 0. + With physical subsetting, stratum 0 would be a singleton → skipped. + With full-design domain estimation, both PSUs participate → non-zero + stratum variance contribution and no SRS fallback. + """ + rng = np.random.RandomState(42) + # 2 strata, 2 PSUs each, 5 obs per PSU = 20 obs total + # Cell A = first 10 obs (stratum 0 PSU 0 + stratum 1 PSU 2) + # Cell B = last 10 obs (stratum 0 PSU 1 + stratum 1 PSU 3) + data = pd.DataFrame( + { + "geo": ["A"] * 10 + ["B"] * 10, + "time": np.ones(20, dtype=int), + "stratum": np.repeat([0, 0, 1, 1], 5), + "psu": np.repeat([0, 1, 2, 3], 5), + "wt": np.ones(20), + "y": rng.normal(10, 2, 20), + } + ) + # Reassign so cell A has only PSU 0 from stratum 0, cell B has only PSU 1 + data.loc[:4, "geo"] = "A" # stratum 0, PSU 0 + data.loc[5:9, "geo"] = "B" # stratum 0, PSU 1 + data.loc[10:14, "geo"] = "A" # stratum 1, PSU 2 + data.loc[15:19, "geo"] = "B" # stratum 1, PSU 3 + + design = SurveyDesign(weights="wt", strata="stratum", psu="psu", lonely_psu="remove") + panel, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design, + ) + + cell_a = panel[panel["geo"] == "A"] + # With full-design domain estimation: + # - Both PSUs in each stratum participate (one with zero psi) + # - No singleton PSU → no SRS fallback needed + assert not cell_a["srs_fallback"].iloc[0] + assert cell_a["y_se"].iloc[0] > 0 + assert np.isfinite(cell_a["y_se"].iloc[0]) + assert np.isfinite(cell_a["y_precision"].iloc[0]) + + def test_min_n_forces_srs_fallback(self): + """min_n parameter forces SRS fallback for small cells.""" + rng = np.random.RandomState(44) + n = 40 + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B"], n // 2), + "time": np.ones(n, dtype=int), + 
"wt": np.ones(n), + "y": rng.normal(10, 2, n), + } + ) + design = SurveyDesign(weights="wt") + + # min_n=30 → cells with 20 obs each should use SRS fallback + with pytest.warns(UserWarning, match="SRS fallback"): + panel_high, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design, + min_n=30, + ) + assert panel_high["srs_fallback"].all() + + # min_n=1 → should use design-based variance (no fallback) + panel_low, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design, + min_n=1, + ) + assert not panel_low["srs_fallback"].any() + + # SEs should differ between the two + se_high = panel_high[panel_high["geo"] == "A"]["y_se"].iloc[0] + se_low = panel_low[panel_low["geo"] == "A"]["y_se"].iloc[0] + assert se_high != pytest.approx(se_low, rel=1e-6) + + def test_replicate_weight_aggregation(self): + """Aggregation with replicate weight designs produces valid SEs.""" + from diff_diff.prep_dgp import generate_survey_did_data + + micro = generate_survey_did_data( + n_units=200, + n_periods=4, + cohort_periods=[3], + n_strata=3, + psu_per_stratum=6, + include_replicate_weights=True, + panel=False, + seed=42, + ) + # Build replicate weight column list + rep_cols = [c for c in micro.columns if c.startswith("rep_")] + design = SurveyDesign( + weights="weight", + replicate_weights=rep_cols, + replicate_method="BRR", + ) + panel, _ = aggregate_survey( + micro, + by=["stratum", "period"], + outcomes="outcome", + survey_design=design, + ) + # All cells should have finite, positive SEs + assert panel["outcome_se"].notna().all() + assert (panel["outcome_se"] > 0).all() From fe77d84a1d54bda99fd9e817bd05ea74c68d4118 Mon Sep 17 00:00:00 2001 From: igerber Date: Tue, 7 Apr 2026 10:46:59 -0400 Subject: [PATCH 03/16] Validate grouping columns for NaN, fix examples, add missing-key tests Address remaining P1/P3 from AI review: - Reject NaN in by columns before groupby (P1) - Make docstring/RST examples illustrative 
pseudocode (P3) - Add tests for partial and all-NaN grouping keys (P3) Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/prep.py | 20 ++++++++++++++++---- docs/api/prep.rst | 14 +++++++------- tests/test_prep.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 11 deletions(-) diff --git a/diff_diff/prep.py b/diff_diff/prep.py index f2a1f95a..f3dcb420 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -1466,10 +1466,13 @@ def aggregate_survey( ... microdata, by=["state", "year"], ... outcomes="smoking_rate", survey_design=design, ... ) - >>> result = DifferenceInDifferences().fit( - ... panel, outcome="smoking_rate_mean", - ... treatment="treated", time="year", survey_design=stage2, - ... ) + >>> # Add treatment/time indicators at the panel level, then fit: + >>> # panel["treated"] = ... # e.g., from policy adoption data + >>> # panel["post"] = (panel["year"] >= treatment_year).astype(int) + >>> # result = DifferenceInDifferences().fit( + >>> # panel, outcome="smoking_rate_mean", + >>> # treatment="treated", time="post", survey_design=stage2, + >>> # ) """ import warnings from dataclasses import replace @@ -1508,6 +1511,15 @@ def aggregate_survey( if data.empty: raise ValueError("data must be non-empty") + # --- Validate grouping columns have no missing values --- + by_missing = data[by_cols].isna().any() + cols_with_na = list(by_missing[by_missing].index) + if cols_with_na: + raise ValueError( + f"Missing values in grouping column(s): {cols_with_na}. " + f"Drop or fill NaN values before calling aggregate_survey()." 
+ ) + # --- Resolve design once on full data --- effective_design = ( replace(survey_design, lonely_psu=lonely_psu) if lonely_psu else survey_design diff --git a/docs/api/prep.rst b/docs/api/prep.rst index 7ecee05b..f1fded43 100644 --- a/docs/api/prep.rst +++ b/docs/api/prep.rst @@ -284,13 +284,13 @@ Example # cell_n, cell_n_eff, srs_fallback # stage2 is pre-configured: aweights + state-level clustering - result = DifferenceInDifferences().fit( - panel, - outcome="smoking_rate_mean", - treatment="treated", - time="year", - survey_design=stage2, - ) + # Add treatment/time indicators at the panel level, then fit: + # panel["treated"] = ... # from policy adoption data + # panel["post"] = (panel["year"] >= treatment_year).astype(int) + # result = DifferenceInDifferences().fit( + # panel, outcome="smoking_rate_mean", + # treatment="treated", time="post", survey_design=stage2, + # ) Data Validation --------------- diff --git a/tests/test_prep.py b/tests/test_prep.py index c637e9c9..4913ee28 100644 --- a/tests/test_prep.py +++ b/tests/test_prep.py @@ -2421,6 +2421,37 @@ def test_error_empty_data(self, design): survey_design=design, ) + def test_error_missing_grouping_keys(self, micro_data, design): + """NaN in grouping columns raises ValueError.""" + data = micro_data.copy() + data.loc[0, "state"] = np.nan + with pytest.raises(ValueError, match="Missing values in grouping column"): + aggregate_survey( + data, + by=["state", "year"], + outcomes="y", + survey_design=design, + ) + + def test_error_all_missing_grouping_keys(self, design): + """All-NaN grouping column raises ValueError.""" + data = pd.DataFrame( + { + "state": [np.nan] * 10, + "year": np.ones(10, dtype=int), + "y": np.random.RandomState(1).normal(0, 1, 10), + "wt": np.ones(10), + } + ) + design_simple = SurveyDesign(weights="wt") + with pytest.raises(ValueError, match="Missing values in grouping column"): + aggregate_survey( + data, + by=["state", "year"], + outcomes="y", + survey_design=design_simple, + ) + 
def test_domain_estimation_preserves_full_design(self): """Full-design domain estimation accounts for PSUs outside the cell. From 1bbdfa882b74e05891df20a065eef582dbe6f20e Mon Sep 17 00:00:00 2001 From: igerber Date: Tue, 7 Apr 2026 11:01:02 -0400 Subject: [PATCH 04/16] Use positional indices for cell membership, add duplicate-index test Replace label-based index lookup with stable positional row tracking via _row_pos column, so duplicate DataFrame indices cannot break or mis-map cell aggregation. Add regression test verifying identical results with duplicated vs clean indices. Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/prep.py | 7 ++++--- tests/test_prep.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/diff_diff/prep.py b/diff_diff/prep.py index f3dcb420..a96a6e19 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -1532,14 +1532,15 @@ def aggregate_survey( y_arrays: Dict[str, np.ndarray] = {var: data[var].values.astype(np.float64) for var in all_vars} # --- Per-cell computation --- - grouped = data.groupby(by_cols, sort=True) + # Use stable positional indices (safe with duplicate DataFrame indices) + row_positions = np.arange(n_total) + grouped = data.assign(_row_pos=row_positions).groupby(by_cols, sort=True) rows: List[Dict[str, Any]] = [] srs_cells: List[str] = [] zero_var_cells: List[str] = [] for cell_key, cell_df in grouped: - cell_idx = np.array(cell_df.index) - pos_idx = data.index.get_indexer(cell_idx) + pos_idx = cell_df["_row_pos"].values # Boolean mask for full-design domain estimation cell_mask = np.zeros(n_total, dtype=bool) diff --git a/tests/test_prep.py b/tests/test_prep.py index 4913ee28..2c3e9041 100644 --- a/tests/test_prep.py +++ b/tests/test_prep.py @@ -2452,6 +2452,38 @@ def test_error_all_missing_grouping_keys(self, design): survey_design=design_simple, ) + def test_duplicate_index(self): + """Duplicate DataFrame indices do not break aggregation.""" + rng = 
np.random.RandomState(77) + n = 40 + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B"], n // 2), + "time": np.tile(np.repeat([0, 1], n // 4), 2), + "wt": np.ones(n), + "y": rng.normal(10, 2, n), + } + ) + # Create duplicate indices (e.g., from concat without reset_index) + data.index = list(range(n // 2)) * 2 # 0..19, 0..19 + + design = SurveyDesign(weights="wt") + panel_dup, _ = aggregate_survey( + data, by=["geo", "time"], outcomes="y", survey_design=design + ) + + # Compare against clean-index version + data_clean = data.reset_index(drop=True) + panel_clean, _ = aggregate_survey( + data_clean, by=["geo", "time"], outcomes="y", survey_design=design + ) + + # Results should be identical + np.testing.assert_allclose( + panel_dup["y_mean"].values, panel_clean["y_mean"].values, rtol=1e-12 + ) + np.testing.assert_allclose(panel_dup["y_se"].values, panel_clean["y_se"].values, rtol=1e-12) + def test_domain_estimation_preserves_full_design(self): """Full-design domain estimation accounts for PSUs outside the cell. From 4788166cc653a4155aae2f9b5ac9ac3da841f3c3 Mon Sep 17 00:00:00 2001 From: igerber Date: Tue, 7 Apr 2026 11:08:54 -0400 Subject: [PATCH 05/16] Use groupby().indices instead of injected column for cell membership Replace _row_pos column injection with groupby().indices which returns positional indices directly. No temporary columns are added to the user's data, eliminating any column name collision risk. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/prep.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/diff_diff/prep.py b/diff_diff/prep.py index a96a6e19..16163de6 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -1532,16 +1532,15 @@ def aggregate_survey( y_arrays: Dict[str, np.ndarray] = {var: data[var].values.astype(np.float64) for var in all_vars} # --- Per-cell computation --- - # Use stable positional indices (safe with duplicate DataFrame indices) - row_positions = np.arange(n_total) - grouped = data.assign(_row_pos=row_positions).groupby(by_cols, sort=True) + # Use groupby().indices for position-based cell membership (safe with + # duplicate DataFrame indices, no column injection into user data) + grouped = data.groupby(by_cols, sort=True) + cell_indices = grouped.indices # dict of cell_key → positional indices rows: List[Dict[str, Any]] = [] srs_cells: List[str] = [] zero_var_cells: List[str] = [] - for cell_key, cell_df in grouped: - pos_idx = cell_df["_row_pos"].values - + for cell_key, pos_idx in cell_indices.items(): # Boolean mask for full-design domain estimation cell_mask = np.zeros(n_total, dtype=bool) cell_mask[pos_idx] = True From 703a7fe74b5c528bdcb01607bf76e4bb774c6a00 Mon Sep 17 00:00:00 2001 From: igerber Date: Tue, 7 Apr 2026 11:36:12 -0400 Subject: [PATCH 06/16] Exclude zero-weight rows from valid observation count Define validity as non-NaN AND positive weight so zero-weight padding rows don't inflate {outcome}_n or bypass n_valid < 2 / min_n guards. Add regression test for cell with 1 real + 9 zero-weight observations. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/prep.py | 3 ++- tests/test_prep.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/diff_diff/prep.py b/diff_diff/prep.py index 16163de6..53509939 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -1351,7 +1351,8 @@ def _cell_mean_variance( """ y_cell = y_full[cell_mask] w_cell = full_resolved.weights[cell_mask] - valid = ~np.isnan(y_cell) + # Valid = non-missing AND positive weight (zero-weight rows are padding) + valid = ~np.isnan(y_cell) & (w_cell > 0) n_valid = int(np.sum(valid)) if n_valid == 0: diff --git a/tests/test_prep.py b/tests/test_prep.py index 2c3e9041..da4edd8f 100644 --- a/tests/test_prep.py +++ b/tests/test_prep.py @@ -2452,6 +2452,36 @@ def test_error_all_missing_grouping_keys(self, design): survey_design=design_simple, ) + def test_zero_weight_rows_excluded_from_n_valid(self): + """Zero-weight rows should not count as valid observations.""" + rng = np.random.RandomState(66) + # Cell A: 1 positive-weight obs + 9 zero-weight padding + # With only 1 effective observation, SE should be NaN + data = pd.DataFrame( + { + "geo": ["A"] * 10 + ["B"] * 10, + "time": np.ones(20, dtype=int), + "wt": np.concatenate( + [ + np.array([1.0] + [0.0] * 9), # A: 1 real, 9 padding + np.ones(10), # B: all real + ] + ), + "y": rng.normal(10, 2, 20), + } + ) + design = SurveyDesign(weights="wt") + panel, _ = aggregate_survey(data, by=["geo", "time"], outcomes="y", survey_design=design) + cell_a = panel[panel["geo"] == "A"] + # Only 1 positive-weight obs → n_valid=1, SE=NaN + assert cell_a["y_n"].iloc[0] == 1 + assert np.isnan(cell_a["y_se"].iloc[0]) + + cell_b = panel[panel["geo"] == "B"] + # 10 positive-weight obs → normal SE + assert cell_b["y_n"].iloc[0] == 10 + assert cell_b["y_se"].iloc[0] > 0 + def test_duplicate_index(self): """Duplicate DataFrame indices do not break aggregation.""" rng = np.random.RandomState(77) From 
f04fe0d18cf4fae93cc33bd3facc9a55605b80c8 Mon Sep 17 00:00:00 2001 From: igerber Date: Tue, 7 Apr 2026 11:48:56 -0400 Subject: [PATCH 07/16] Add fit-ready weight column mapping NaN precision to 0.0 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The returned SurveyDesign now points at a {outcome}_weight column where NaN/Inf precision values are mapped to 0.0, so downstream fit() never rejects missing weights. Diagnostic *_precision column is preserved as-is. Add stage2-handoff test with single-observation cell (NaN precision → zero weight → fit succeeds). Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/prep.py | 17 +++++++++++--- tests/test_prep.py | 57 ++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 69 insertions(+), 5 deletions(-) diff --git a/diff_diff/prep.py b/diff_diff/prep.py index 53509939..a962817c 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -1453,8 +1453,10 @@ def aggregate_survey( panel_df : pd.DataFrame Aggregated panel with columns: grouping variables, ``{outcome}_mean``, ``{outcome}_se``, ``{outcome}_n``, - ``{outcome}_precision``, ``{covariate}_mean``, ``cell_n``, - ``cell_n_eff``, ``srs_fallback``. + ``{outcome}_precision``, ``{outcome}_weight``, + ``{covariate}_mean``, ``cell_n``, ``cell_n_eff``, + ``srs_fallback``. The ``_weight`` column is a fit-ready + version of ``_precision`` with NaN/Inf mapped to 0.0. second_stage_design : SurveyDesign Pre-configured for second-stage estimation with ``weight_type="aweight"``, precision weights from the first @@ -1637,9 +1639,18 @@ def aggregate_survey( panel_df = panel_df.sort_values(by_cols).reset_index(drop=True) # --- Construct second-stage SurveyDesign --- + # Create a fit-ready weight column: NaN/Inf precision → 0.0 so downstream + # resolve() doesn't reject missing weights. Diagnostic *_precision is kept. 
first_outcome = outcome_cols[0] + weight_col = f"{first_outcome}_weight" + panel_df[weight_col] = np.where( + np.isfinite(panel_df[f"{first_outcome}_precision"]), + panel_df[f"{first_outcome}_precision"], + 0.0, + ) + second_stage_design = SurveyDesign( - weights=f"{first_outcome}_precision", + weights=weight_col, weight_type="aweight", psu=by_cols[0], ) diff --git a/tests/test_prep.py b/tests/test_prep.py index da4edd8f..d338a166 100644 --- a/tests/test_prep.py +++ b/tests/test_prep.py @@ -2057,7 +2057,7 @@ def test_multiple_outcomes(self, micro_data, design): assert "y2_mean" in panel.columns assert "y_precision" in panel.columns assert "y2_precision" in panel.columns - assert stage2.weights == "y_precision" + assert stage2.weights == "y_weight" def test_covariates_mean_only(self, micro_data, design): """Covariates get mean column only, no SE/precision.""" @@ -2081,7 +2081,7 @@ def test_returned_survey_design(self, micro_data, design): survey_design=design, ) assert stage2.weight_type == "aweight" - assert stage2.weights == "y_precision" + assert stage2.weights == "y_weight" assert stage2.psu == "state" def test_srs_fallback(self): @@ -2452,6 +2452,59 @@ def test_error_all_missing_grouping_keys(self, design): survey_design=design_simple, ) + def test_stage2_handoff_with_nonfinite_cells(self): + """stage2 SurveyDesign works even when some cells have NaN precision.""" + from diff_diff import DifferenceInDifferences + + rng = np.random.RandomState(99) + rows = [] + for state in range(4): + treated = 1 if state < 2 else 0 + for period in [0, 1]: + te = 3.0 if (treated and period == 1) else 0.0 + n_cell = 30 + for _ in range(n_cell): + rows.append( + { + "state": state, + "period": period, + "wt": rng.uniform(0.5, 2.0), + "outcome": rng.normal(10 + te, 2), + "treated": treated, + } + ) + micro = pd.DataFrame(rows) + # Make one cell have only 1 observation → NaN SE → NaN precision + mask = (micro["state"] == 0) & (micro["period"] == 0) + micro = 
micro.drop(micro[mask].index[1:]) # keep only 1 row + + design = SurveyDesign(weights="wt") + panel, stage2 = aggregate_survey( + micro, + by=["state", "period"], + outcomes="outcome", + covariates="treated", + survey_design=design, + ) + + # The zero-variance cell should have weight=0 (not NaN) + cell_00 = panel[(panel["state"] == 0) & (panel["period"] == 0)] + assert np.isnan(cell_00["outcome_precision"].iloc[0]) # diagnostic + assert cell_00["outcome_weight"].iloc[0] == 0.0 # fit-ready + + # stage2 should work with fit() despite NaN-precision cells + panel["treated_bin"] = (panel["treated_mean"] > 0.5).astype(int) + did = DifferenceInDifferences() + result = did.fit( + panel, + outcome="outcome_mean", + treatment="treated_bin", + time="period", + survey_design=stage2, + ) + assert np.isfinite(result.att) + assert np.isfinite(result.se) + def test_zero_weight_rows_excluded_from_n_valid(self): """Zero-weight rows should not count as valid observations.""" rng = np.random.RandomState(66) From 7f84169d87f06600c23b923b705d6ad15e16020f Mon Sep 17 00:00:00 2001 From: igerber Date: Tue, 7 Apr 2026 12:12:29 -0400 Subject: [PATCH 08/16] Drop non-estimable cells before returning panel Cells with non-finite outcome mean (n_valid==0, all-missing, all-zero- weight) are dropped from the panel with a warning before constructing the stage2 SurveyDesign. This ensures fit() never encounters NaN outcomes and eliminates all-zero-weight PSUs from second-stage variance/df calculations. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/prep.py | 19 ++++++++++++++++++- tests/test_prep.py | 34 +++++++++++++++++++--------------- 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/diff_diff/prep.py b/diff_diff/prep.py index a962817c..fdd8a552 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -1638,10 +1638,27 @@ def aggregate_survey( # Sort by grouping columns panel_df = panel_df.sort_values(by_cols).reset_index(drop=True) + # --- Drop non-estimable cells --- + # Cells with non-finite mean (n_valid==0 or all-missing) cannot contribute + # to second-stage estimation and would cause fit() to reject NaN outcomes. + # Dropping them also removes all-zero-weight PSUs from the panel. + first_outcome = outcome_cols[0] + mean_col = f"{first_outcome}_mean" + nonestimable = ~np.isfinite(panel_df[mean_col].values) + if np.any(nonestimable): + n_dropped = int(np.sum(nonestimable)) + dropped_keys = panel_df.loc[nonestimable, by_cols].values.tolist() + warnings.warn( + f"Dropped {n_dropped} non-estimable cell(s) with no valid observations: " + f"{dropped_keys[:5]}" + (f" ... and {n_dropped - 5} more" if n_dropped > 5 else ""), + UserWarning, + stacklevel=2, + ) + panel_df = panel_df[~nonestimable].reset_index(drop=True) + # --- Construct second-stage SurveyDesign --- # Create a fit-ready weight column: NaN/Inf precision → 0.0 so downstream # resolve() doesn't reject missing weights. Diagnostic *_precision is kept. 
- first_outcome = outcome_cols[0] weight_col = f"{first_outcome}_weight" panel_df[weight_col] = np.where( np.isfinite(panel_df[f"{first_outcome}_precision"]), diff --git a/tests/test_prep.py b/tests/test_prep.py index d338a166..6d98675d 100644 --- a/tests/test_prep.py +++ b/tests/test_prep.py @@ -2453,7 +2453,7 @@ def test_error_all_missing_grouping_keys(self, design): ) def test_stage2_handoff_with_nonfinite_cells(self): - """stage2 SurveyDesign works even when some cells have NaN precision.""" + """Non-estimable cells are dropped; stage2 works with fit().""" from diff_diff import DifferenceInDifferences rng = np.random.RandomState(99) @@ -2474,25 +2474,29 @@ def test_stage2_handoff_with_nonfinite_cells(self): } ) micro = pd.DataFrame(rows) - # Make one cell have only 1 observation → NaN SE → NaN precision + # Make one cell all-NaN outcome → n_valid=0 → NaN mean → dropped mask = (micro["state"] == 0) & (micro["period"] == 0) - micro = micro.drop(micro[mask].index[1:]) # keep only 1 row + micro.loc[mask, "outcome"] = np.nan design = SurveyDesign(weights="wt") - panel, stage2 = aggregate_survey( - micro, - by=["state", "period"], - outcomes="outcome", - covariates="treated", - survey_design=design, - ) + with pytest.warns(UserWarning, match="non-estimable"): + panel, stage2 = aggregate_survey( + micro, + by=["state", "period"], + outcomes="outcome", + covariates="treated", + survey_design=design, + ) + + # Non-estimable cell should be dropped from panel + assert len(panel) == 7 # 8 cells - 1 dropped + assert not ((panel["state"] == 0) & (panel["period"] == 0)).any() - # The zero-variance cell should have weight=0 (not NaN) - cell_00 = panel[(panel["state"] == 0) & (panel["period"] == 0)] - assert np.isnan(cell_00["outcome_precision"].iloc[0]) # diagnostic - assert cell_00["outcome_weight"].iloc[0] == 0.0 # fit-ready + # No NaN in outcome mean or weight columns + assert panel["outcome_mean"].notna().all() + assert panel["outcome_weight"].notna().all() - # stage2 
should work with fit() despite NaN-precision cells + # stage2 should work with fit() panel["treated_bin"] = (panel["treated_mean"] > 0.5).astype(int) did = DifferenceInDifferences() result = did.fit( From fa6c73874b9552404edbdba5db7a000786b07a62 Mon Sep 17 00:00:00 2001 From: igerber Date: Tue, 7 Apr 2026 12:28:28 -0400 Subject: [PATCH 09/16] Drop zero-weight PSUs and guard empty post-drop panel Geographic units where every cell has zero weight are pruned before constructing stage2, preventing inflated survey df/variance. If all cells are dropped, raise ValueError with clear message. Add regressions for zero-weight PSU pruning and all-cells-dropped cases. Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/prep.py | 27 ++++++++++++++++++- tests/test_prep.py | 65 +++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 84 insertions(+), 8 deletions(-) diff --git a/diff_diff/prep.py b/diff_diff/prep.py index fdd8a552..e752b0c3 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -1666,10 +1666,35 @@ def aggregate_survey( 0.0, ) + # Drop geographic units (PSUs) with zero total weight — they would + # inflate survey df and distort second-stage variance estimation. + geo_col = by_cols[0] + geo_weight = panel_df.groupby(geo_col)[weight_col].sum() + zero_geos = geo_weight[geo_weight == 0].index + if len(zero_geos) > 0: + n_before = len(panel_df) + panel_df = panel_df[~panel_df[geo_col].isin(zero_geos)].reset_index(drop=True) + n_after = len(panel_df) + warnings.warn( + f"Dropped {n_before - n_after} cell(s) from {len(zero_geos)} " + f"geographic unit(s) with zero total weight: " + f"{list(zero_geos[:5])}" + + (f" ... and {len(zero_geos) - 5} more" if len(zero_geos) > 5 else ""), + UserWarning, + stacklevel=2, + ) + + # Guard: all cells dropped + if panel_df.empty: + raise ValueError( + "No estimable cells remain after aggregation. " + "All cells had missing outcomes or zero effective weight." 
+ ) + second_stage_design = SurveyDesign( weights=weight_col, weight_type="aweight", - psu=by_cols[0], + psu=geo_col, ) return panel_df, second_stage_design diff --git a/tests/test_prep.py b/tests/test_prep.py index 6d98675d..7aa0b462 100644 --- a/tests/test_prep.py +++ b/tests/test_prep.py @@ -2509,18 +2509,70 @@ def test_stage2_handoff_with_nonfinite_cells(self): assert np.isfinite(result.att) assert np.isfinite(result.se) + def test_zero_weight_psu_dropped(self): + """Geographic units with zero total weight are dropped from panel.""" + rng = np.random.RandomState(88) + # State 0: all cells have only 1 valid obs → NaN precision → weight=0 + # State 1-3: normal cells + rows = [] + for state in range(4): + for period in [0, 1]: + if state == 0: + # 1 obs per cell → NaN SE → weight=0 + rows.append( + { + "state": state, + "period": period, + "wt": 1.0, + "y": rng.normal(10, 2), + } + ) + else: + for _ in range(20): + rows.append( + { + "state": state, + "period": period, + "wt": 1.0, + "y": rng.normal(10, 2), + } + ) + data = pd.DataFrame(rows) + design = SurveyDesign(weights="wt") + with pytest.warns(UserWarning, match="zero total weight"): + panel, _ = aggregate_survey( + data, by=["state", "period"], outcomes="y", survey_design=design + ) + # State 0 should be entirely gone + assert 0 not in panel["state"].values + assert len(panel) == 6 # 3 states × 2 periods + + def test_error_all_cells_dropped(self): + """All cells non-estimable raises ValueError.""" + data = pd.DataFrame( + { + "state": ["A"] * 5, + "period": np.ones(5, dtype=int), + "wt": np.ones(5), + "y": [np.nan] * 5, + } + ) + design = SurveyDesign(weights="wt") + with pytest.raises(ValueError, match="No estimable cells remain"): + aggregate_survey(data, by=["state", "period"], outcomes="y", survey_design=design) + def test_zero_weight_rows_excluded_from_n_valid(self): """Zero-weight rows should not count as valid observations.""" rng = np.random.RandomState(66) - # Cell A: 1 positive-weight obs + 9 
zero-weight padding - # With only 1 effective observation, SE should be NaN + # Cell A: 3 positive-weight obs + 7 zero-weight padding + # n_valid should be 3, not 10 data = pd.DataFrame( { "geo": ["A"] * 10 + ["B"] * 10, "time": np.ones(20, dtype=int), "wt": np.concatenate( [ - np.array([1.0] + [0.0] * 9), # A: 1 real, 9 padding + np.array([1.0, 1.0, 1.0] + [0.0] * 7), # A: 3 real np.ones(10), # B: all real ] ), @@ -2530,12 +2582,11 @@ def test_zero_weight_rows_excluded_from_n_valid(self): design = SurveyDesign(weights="wt") panel, _ = aggregate_survey(data, by=["geo", "time"], outcomes="y", survey_design=design) cell_a = panel[panel["geo"] == "A"] - # Only 1 positive-weight obs → n_valid=1, SE=NaN - assert cell_a["y_n"].iloc[0] == 1 - assert np.isnan(cell_a["y_se"].iloc[0]) + # Only 3 positive-weight obs → n_valid=3 + assert cell_a["y_n"].iloc[0] == 3 cell_b = panel[panel["geo"] == "B"] - # 10 positive-weight obs → normal SE + # 10 positive-weight obs assert cell_b["y_n"].iloc[0] == 10 assert cell_b["y_se"].iloc[0] > 0 From d6abd0651a08d79e47839e6cadc5c8778b4390d0 Mon Sep 17 00:00:00 2001 From: igerber Date: Tue, 7 Apr 2026 12:42:08 -0400 Subject: [PATCH 10/16] Document multi-outcome filtering contract, warn on secondary data loss Multi-outcome filtering is based on the first outcome (consistent with the returned SurveyDesign using the first outcome's weights). Docstring now explicitly states this contract. Warning emitted when dropped cells had valid data for secondary outcomes. Test added for the contract. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/prep.py | 27 +++++++++++++++++++++------ tests/test_prep.py | 26 ++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/diff_diff/prep.py b/diff_diff/prep.py index e752b0c3..ed222d02 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -1435,7 +1435,11 @@ def aggregate_survey( outcomes : str or list of str Outcome variable(s) to aggregate with full precision tracking. Each outcome produces ``{name}_mean``, ``{name}_se``, - ``{name}_n``, and ``{name}_precision`` columns. + ``{name}_n``, and ``{name}_precision`` columns. When multiple + outcomes are given, panel filtering (non-estimable cell + removal, zero-weight PSU pruning) is based on the **first** + outcome only, consistent with the returned SurveyDesign. For + independent per-outcome support, call once per outcome. survey_design : SurveyDesign Survey design specification for the microdata. covariates : str or list of str, optional @@ -1648,12 +1652,23 @@ def aggregate_survey( if np.any(nonestimable): n_dropped = int(np.sum(nonestimable)) dropped_keys = panel_df.loc[nonestimable, by_cols].values.tolist() - warnings.warn( - f"Dropped {n_dropped} non-estimable cell(s) with no valid observations: " - f"{dropped_keys[:5]}" + (f" ... and {n_dropped - 5} more" if n_dropped > 5 else ""), - UserWarning, - stacklevel=2, + # Warn about secondary outcomes losing valid data in dropped cells + secondary_loss = [] + for var in outcome_cols[1:]: + valid_secondary = np.isfinite(panel_df.loc[nonestimable, f"{var}_mean"].values) + if np.any(valid_secondary): + secondary_loss.append(var) + msg = ( + f"Dropped {n_dropped} non-estimable cell(s) (based on first outcome " + f"'{first_outcome}'): {dropped_keys[:5]}" + + (f" ... and {n_dropped - 5} more" if n_dropped > 5 else "") ) + if secondary_loss: + msg += ( + f". Note: {secondary_loss} had valid data in dropped cells. " + f"For independent per-outcome support, call once per outcome." 
+ ) + warnings.warn(msg, UserWarning, stacklevel=2) panel_df = panel_df[~nonestimable].reset_index(drop=True) # --- Construct second-stage SurveyDesign --- diff --git a/tests/test_prep.py b/tests/test_prep.py index 7aa0b462..6f2a4c8c 100644 --- a/tests/test_prep.py +++ b/tests/test_prep.py @@ -2059,6 +2059,32 @@ def test_multiple_outcomes(self, micro_data, design): assert "y2_precision" in panel.columns assert stage2.weights == "y_weight" + def test_multi_outcome_filtering_contract(self): + """Multi-outcome filtering is based on first outcome; warns about secondary data loss.""" + rng = np.random.RandomState(33) + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B"], 20), + "time": np.ones(40, dtype=int), + "wt": np.ones(40), + "y1": np.concatenate( + [[np.nan] * 20, rng.normal(10, 2, 20)] + ), # A: all-NaN, B: valid + "y2": rng.normal(5, 1, 40), # valid everywhere + } + ) + design = SurveyDesign(weights="wt") + with pytest.warns(UserWarning, match="y2.*valid data in dropped"): + panel, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes=["y1", "y2"], + survey_design=design, + ) + # Cell A dropped (y1 non-estimable), even though y2 was valid + assert len(panel) == 1 + assert panel["geo"].iloc[0] == "B" + def test_covariates_mean_only(self, micro_data, design): """Covariates get mean column only, no SE/precision.""" panel, _ = aggregate_survey( From 15b3d4a4b458967a16545fda2b3cad64bef2371f Mon Sep 17 00:00:00 2001 From: igerber Date: Tue, 7 Apr 2026 12:58:33 -0400 Subject: [PATCH 11/16] Address remaining P2/P3: empty-list guards, replicate method, docs - Validate empty by/outcomes lists with clear ValueError (P2) - Fix replicate test to use JK1 matching DGP output (P3) - Update RST example with *_weight column and filtering note (P3) Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/prep.py | 5 +++++ docs/api/prep.rst | 8 ++++++-- tests/test_prep.py | 17 ++++++++++++++++- 3 files changed, 27 insertions(+), 3 deletions(-) diff --git 
a/diff_diff/prep.py b/diff_diff/prep.py index ed222d02..9a4db4df 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -1492,6 +1492,11 @@ def aggregate_survey( ) # --- Validate --- + if not by_cols: + raise ValueError("'by' must specify at least one grouping column") + if not outcome_cols: + raise ValueError("'outcomes' must specify at least one outcome variable") + all_cols = by_cols + outcome_cols + cov_cols missing = [c for c in all_cols if c not in data.columns] if missing: diff --git a/docs/api/prep.rst b/docs/api/prep.rst index f1fded43..7491c085 100644 --- a/docs/api/prep.rst +++ b/docs/api/prep.rst @@ -280,8 +280,12 @@ Example ) # panel has: state, year, smoking_rate_mean, smoking_rate_se, - # smoking_rate_n, smoking_rate_precision, age_mean, income_mean, - # cell_n, cell_n_eff, srs_fallback + # smoking_rate_n, smoking_rate_precision, smoking_rate_weight, + # age_mean, income_mean, cell_n, cell_n_eff, srs_fallback + # + # *_weight is fit-ready (NaN precision -> 0.0) + # Non-estimable cells and zero-weight geos are dropped automatically. + # Multi-outcome filtering is keyed off the first outcome. 
# stage2 is pre-configured: aweights + state-level clustering # Add treatment/time indicators at the panel level, then fit: diff --git a/tests/test_prep.py b/tests/test_prep.py index 6f2a4c8c..50e89f25 100644 --- a/tests/test_prep.py +++ b/tests/test_prep.py @@ -2436,6 +2436,21 @@ def test_error_min_n_too_small(self, micro_data, design): min_n=0, ) + def test_error_empty_by(self, micro_data, design): + """Empty by list raises ValueError.""" + with pytest.raises(ValueError, match="at least one grouping column"): + aggregate_survey(micro_data, by=[], outcomes="y", survey_design=design) + + def test_error_empty_outcomes(self, micro_data, design): + """Empty outcomes list raises ValueError.""" + with pytest.raises(ValueError, match="at least one outcome"): + aggregate_survey( + micro_data, + by=["state", "year"], + outcomes=[], + survey_design=design, + ) + def test_error_empty_data(self, design): """Empty DataFrame raises ValueError.""" empty = pd.DataFrame(columns=["state", "year", "y", "wt", "stratum", "cluster"]) @@ -2752,7 +2767,7 @@ def test_replicate_weight_aggregation(self): design = SurveyDesign( weights="weight", replicate_weights=rep_cols, - replicate_method="BRR", + replicate_method="JK1", ) panel, _ = aggregate_survey( micro, From 8ac222222b7a00e1bd1a010b8dd8037b88a54c2b Mon Sep 17 00:00:00 2001 From: igerber Date: Tue, 7 Apr 2026 14:07:51 -0400 Subject: [PATCH 12/16] Validate numeric dtype for outcomes/covariates before aggregation Non-numeric columns now raise a clear ValueError instead of failing inside astype(np.float64). Add regression test for string column. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/prep.py | 6 ++++++ tests/test_prep.py | 12 ++++++++++++ 2 files changed, 18 insertions(+) diff --git a/diff_diff/prep.py b/diff_diff/prep.py index 9a4db4df..be725945 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -1541,6 +1541,12 @@ def aggregate_survey( # --- Precompute full-length outcome/covariate arrays --- n_total = len(data) all_vars = outcome_cols + cov_cols + non_numeric = [v for v in all_vars if not np.issubdtype(data[v].dtype, np.number)] + if non_numeric: + raise ValueError( + f"Non-numeric column(s) in outcomes/covariates: {non_numeric}. " + f"All outcome and covariate columns must be numeric." + ) y_arrays: Dict[str, np.ndarray] = {var: data[var].values.astype(np.float64) for var in all_vars} # --- Per-cell computation --- diff --git a/tests/test_prep.py b/tests/test_prep.py index 50e89f25..00f1d28b 100644 --- a/tests/test_prep.py +++ b/tests/test_prep.py @@ -2451,6 +2451,18 @@ def test_error_empty_outcomes(self, micro_data, design): survey_design=design, ) + def test_error_non_numeric_outcome(self, micro_data, design): + """Non-numeric outcome column raises ValueError.""" + data = micro_data.copy() + data["label"] = "foo" + with pytest.raises(ValueError, match="Non-numeric column"): + aggregate_survey( + data, + by=["state", "year"], + outcomes="label", + survey_design=design, + ) + def test_error_empty_data(self, design): """Empty DataFrame raises ValueError.""" empty = pd.DataFrame(columns=["state", "year", "y", "wt", "stratum", "cluster"]) From 50ab3bdf372e8093e7c09fe97025d666788d80ea Mon Sep 17 00:00:00 2001 From: igerber Date: Tue, 7 Apr 2026 14:20:38 -0400 Subject: [PATCH 13/16] Use pd.api.types.is_numeric_dtype for nullable dtype support Replace np.issubdtype with pd.api.types.is_numeric_dtype so pandas nullable extension dtypes (Int64, Float64) are accepted as numeric. Add regression test with Float64 outcome column. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/prep.py | 2 +- tests/test_prep.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/diff_diff/prep.py b/diff_diff/prep.py index be725945..9b323a8f 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -1541,7 +1541,7 @@ def aggregate_survey( # --- Precompute full-length outcome/covariate arrays --- n_total = len(data) all_vars = outcome_cols + cov_cols - non_numeric = [v for v in all_vars if not np.issubdtype(data[v].dtype, np.number)] + non_numeric = [v for v in all_vars if not pd.api.types.is_numeric_dtype(data[v])] if non_numeric: raise ValueError( f"Non-numeric column(s) in outcomes/covariates: {non_numeric}. " diff --git a/tests/test_prep.py b/tests/test_prep.py index 00f1d28b..6d6c1850 100644 --- a/tests/test_prep.py +++ b/tests/test_prep.py @@ -2463,6 +2463,21 @@ def test_error_non_numeric_outcome(self, micro_data, design): survey_design=design, ) + def test_nullable_numeric_dtypes(self): + """Pandas nullable Int64/Float64 dtypes are accepted as numeric.""" + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B"], 10), + "time": np.ones(20, dtype=int), + "wt": np.ones(20), + "y": pd.array(np.random.RandomState(1).normal(0, 1, 20), dtype="Float64"), + } + ) + design = SurveyDesign(weights="wt") + panel, _ = aggregate_survey(data, by=["geo", "time"], outcomes="y", survey_design=design) + assert len(panel) == 2 + assert panel["y_mean"].notna().all() + def test_error_empty_data(self, design): """Empty DataFrame raises ValueError.""" empty = pd.DataFrame(columns=["state", "year", "y", "wt", "stratum", "cluster"]) From bb5d6609ca2dc61f1cc6ed7582c919b7a985a502 Mon Sep 17 00:00:00 2001 From: igerber Date: Tue, 7 Apr 2026 14:32:39 -0400 Subject: [PATCH 14/16] Add replicate-weight min_n fallback regression test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Exercises the SRS fallback path under a JK1 replicate-weight design, 
verifying that fallback SEs are finite/positive and differ from the replicate-based SEs. Covers the min_n × replicate interaction. Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_prep.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/test_prep.py b/tests/test_prep.py index 6d6c1850..f6db1fc7 100644 --- a/tests/test_prep.py +++ b/tests/test_prep.py @@ -2805,3 +2805,53 @@ def test_replicate_weight_aggregation(self): # All cells should have finite, positive SEs assert panel["outcome_se"].notna().all() assert (panel["outcome_se"] > 0).all() + + def test_replicate_weight_min_n_fallback(self): + """SRS fallback works correctly under replicate-weight designs.""" + from diff_diff.prep_dgp import generate_survey_did_data + + micro = generate_survey_did_data( + n_units=200, + n_periods=4, + cohort_periods=[3], + n_strata=3, + psu_per_stratum=6, + include_replicate_weights=True, + panel=False, + seed=42, + ) + rep_cols = [c for c in micro.columns if c.startswith("rep_")] + design = SurveyDesign( + weights="weight", + replicate_weights=rep_cols, + replicate_method="JK1", + ) + + # min_n high enough to force SRS fallback on all cells + with pytest.warns(UserWarning, match="SRS fallback"): + panel_srs, _ = aggregate_survey( + micro, + by=["stratum", "period"], + outcomes="outcome", + survey_design=design, + min_n=9999, + ) + assert panel_srs["srs_fallback"].all() + assert panel_srs["outcome_se"].notna().all() + assert (panel_srs["outcome_se"] > 0).all() + + # Default min_n → replicate-based variance (no fallback) + panel_rep, _ = aggregate_survey( + micro, + by=["stratum", "period"], + outcomes="outcome", + survey_design=design, + ) + assert not panel_rep["srs_fallback"].any() + + # SEs should differ between SRS fallback and replicate-based + assert not np.allclose( + panel_srs["outcome_se"].values, + panel_rep["outcome_se"].values, + rtol=1e-6, + ) From e7eefa18f83a72ba783b27ef1a115e4ea32eb6bf Mon Sep 17 
00:00:00 2001 From: igerber Date: Tue, 7 Apr 2026 14:48:34 -0400 Subject: [PATCH 15/16] Normalize weights in SRS fallback for scale invariance Both SRS fallback branches now normalize positive weights to mean=1 before computing variance, ensuring SEs are invariant to constant weight rescaling (important for replicate designs that preserve raw weight scale). Add scale-invariance regression test with 5x rescaling. Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/prep.py | 20 ++++++++++++++++---- tests/test_prep.py | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/diff_diff/prep.py b/diff_diff/prep.py index 9b323a8f..c5bdbdb2 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -1373,10 +1373,17 @@ def _cell_mean_variance( y_bar = float(np.sum(w_valid * y_clean) / sum_w) # SRS fallback if below min_n threshold + # Normalize positive weights to mean=1 so fallback is scale-invariant + # (replicate designs preserve raw weight scale per survey.py:L189-240) used_srs = False if n_valid < min_n: - resid_sq = w_valid * (y_clean - y_bar) ** 2 - variance = float(np.sum(resid_sq) / (sum_w**2) * n_valid / (n_valid - 1)) + w_norm = w_valid.copy() + w_pos = w_norm[w_norm > 0] + if len(w_pos) > 0: + w_norm[w_norm > 0] = w_pos / w_pos.mean() + sum_wn = float(np.sum(w_norm)) + resid_sq = w_norm * (y_clean - y_bar) ** 2 + variance = float(np.sum(resid_sq) / (sum_wn**2) * n_valid / (n_valid - 1)) return y_bar, max(variance, 0.0), n_valid, True # Full-design domain estimation: construct full-length psi with zeros @@ -1396,8 +1403,13 @@ def _cell_mean_variance( # SRS fallback when design-based variance is unidentifiable if np.isnan(variance): - resid_sq = w_valid * (y_clean - y_bar) ** 2 - variance = float(np.sum(resid_sq) / (sum_w**2) * n_valid / (n_valid - 1)) + w_norm = w_valid.copy() + w_pos = w_norm[w_norm > 0] + if len(w_pos) > 0: + w_norm[w_norm > 0] = w_pos / w_pos.mean() + sum_wn = float(np.sum(w_norm)) + 
resid_sq = w_norm * (y_clean - y_bar) ** 2 + variance = float(np.sum(resid_sq) / (sum_wn**2) * n_valid / (n_valid - 1)) used_srs = True return y_bar, max(float(variance), 0.0), n_valid, used_srs diff --git a/tests/test_prep.py b/tests/test_prep.py index f6db1fc7..e08c48e2 100644 --- a/tests/test_prep.py +++ b/tests/test_prep.py @@ -2855,3 +2855,43 @@ def test_replicate_weight_min_n_fallback(self): panel_rep["outcome_se"].values, rtol=1e-6, ) + + def test_srs_fallback_scale_invariant(self): + """SRS fallback SEs are invariant to constant weight rescaling.""" + rng = np.random.RandomState(55) + n = 60 + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B", "C"], n // 3), + "time": np.ones(n, dtype=int), + "wt": rng.uniform(0.5, 2.0, n), + "y": rng.normal(10, 2, n), + } + ) + design1 = SurveyDesign(weights="wt") + + # Force SRS fallback with high min_n + with pytest.warns(UserWarning, match="SRS fallback"): + panel1, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design1, + min_n=9999, + ) + + # Rescale weights by 5x → should give identical SEs + data2 = data.copy() + data2["wt"] = data2["wt"] * 5.0 + design2 = SurveyDesign(weights="wt") + with pytest.warns(UserWarning, match="SRS fallback"): + panel2, _ = aggregate_survey( + data2, + by=["geo", "time"], + outcomes="y", + survey_design=design2, + min_n=9999, + ) + + np.testing.assert_allclose(panel1["y_se"].values, panel2["y_se"].values, rtol=1e-10) + np.testing.assert_allclose(panel1["y_mean"].values, panel2["y_mean"].values, rtol=1e-10) From a4371125e09bfaa0153a0c3d581a326cceda79f1 Mon Sep 17 00:00:00 2001 From: igerber Date: Tue, 7 Apr 2026 16:06:33 -0400 Subject: [PATCH 16/16] Skip aggregate_survey doc snippet in snippet tests The RST example references undefined 'microdata' variable, same as the existing wide_to_long skip pattern. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_doc_snippets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_doc_snippets.py b/tests/test_doc_snippets.py index cb4296da..8b02b565 100644 --- a/tests/test_doc_snippets.py +++ b/tests/test_doc_snippets.py @@ -97,6 +97,7 @@ def _extract_snippets(rst_path: Path) -> List[Tuple[int, str]]: r"pip\s+install", r"wild_bootstrap_se\(X,", # low-level array API pseudo-code r"wide_to_long\(", # references undefined wide_data variable + r"aggregate_survey\(", # references undefined microdata variable ] # Third-party packages imported by comparison-page snippets that may not